Description

An end-to-end training framework for spiking neural networks (SNNs) that combines zeroth-order optimization and meta-learning for gradient estimation.

Background

Spiking neural networks (SNNs) imitate biological neurons through the spiking function, and are highly efficient on neuromorphic hardware. The building block of SNNs is the leaky integrate-and-fire (LIF) neuron, a simple model that balances biological realism and computational practicality. LIF-based SNNs are governed by the following equations:

$$
u_t = \beta \, u_{t-1}\left(1 - y_{t-1}\right) + W x_t, \qquad y_t = h\!\left(u_t - V_{th}\right)
$$

Here, $u_t$ is the membrane potential, $x_t$ is the input, $W$ is the learnable weight matrix, $y_t$ is the output, $V_{th}$ is the threshold, and $\beta$ is the leak factor. The output is binary: the neuron spikes when the membrane potential exceeds the threshold. The membrane potential leaks over time and is reset to zero when the potential crosses the threshold. The spiking function $h$ is the Heaviside step function, and serves as the non-linear activation function. Sadly, the Heaviside step function is non-differentiable, so we can't just backpropagate through the network to train it.
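To make the dynamics concrete, here is a minimal sketch of a single fully-connected LIF layer. PyTorch and all names here (`LIFLayer`, `beta`, `v_th`) are illustrative assumptions, not the framework's actual API.

```python
import torch
import torch.nn as nn


class LIFLayer(nn.Module):
    """Single fully-connected LIF layer with a hard reset to zero (sketch)."""

    def __init__(self, in_features, out_features, beta=0.9, v_th=1.0):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features, bias=False)  # learnable weight matrix W
        self.beta = beta    # leak factor
        self.v_th = v_th    # firing threshold

    def forward(self, x_seq):
        # x_seq: (time, batch, in_features)
        u = torch.zeros(x_seq.shape[1], self.fc.out_features, device=x_seq.device)
        spikes = []
        for x_t in x_seq:
            u = self.beta * u + self.fc(x_t)     # leaky integration of the input current
            y_t = (u >= self.v_th).float()       # Heaviside spike: non-differentiable
            u = u * (1.0 - y_t)                  # hard reset to zero after a spike
            spikes.append(y_t)
        return torch.stack(spikes)               # binary spike train, (time, batch, out_features)
```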

Approach

Gradient Estimation

Luckily, we can use zeroth-order optimization to estimate the gradient of a smooth approximation of the Heaviside step function. This is the method proposed in the LocalZO paper. They use the 2-point estimator with antithetic sampling to estimate the gradient of the Heaviside step function. The 2-point estimator is defined as:

$$
\hat{\nabla} f(x) = \frac{\lambda_d}{2\delta}\left[f(x + \delta z) - f(x - \delta z)\right] z
$$

where $\lambda_d$ is a dimension-dependent factor, $\delta$ is the smoothing radius, and $z$ is a random perturbation sampled from a distribution such that $\mathbb{E}\!\left[z z^\top\right] = I$. Applied to the Heaviside step function $h$, the 2-point estimator averaged over $m$ samples becomes:

$$
\widehat{h'}(u) = \frac{1}{2 m \delta} \sum_{i=1}^{m} \left[h(u + \delta z_i) - h(u - \delta z_i)\right] z_i
= \frac{1}{2 m \delta} \sum_{i=1}^{m} |z_i| \, \mathbb{1}\!\left[\,|u| < \delta |z_i|\,\right]
$$
Averaging over $m$ samples helps reduce the Monte Carlo estimation error. Sadly (again), the variance of zeroth-order estimates is dominated by the dimensionality of the function. Although we can now push gradients through the network, the variance can prevent us from training the network effectively.
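As a concrete illustration, here is a hedged sketch of how such a zeroth-order surrogate could be wired into automatic differentiation: the forward pass emits the exact Heaviside spike, while the backward pass substitutes the $m$-sample 2-point estimate above. PyTorch, the class name `ZOHeaviside`, the constants `DELTA` and `M`, and the Gaussian choice for $z$ are assumptions, not necessarily the exact LocalZO implementation.

```python
import torch

DELTA = 0.05  # smoothing radius delta (assumed value)
M = 5         # number of zeroth-order samples m (assumed value)


class ZOHeaviside(torch.autograd.Function):
    """Exact Heaviside in the forward pass, 2-point ZO estimate in the backward pass."""

    @staticmethod
    def forward(ctx, u):
        ctx.save_for_backward(u)
        return (u >= 0).float()                   # exact spike

    @staticmethod
    def backward(ctx, grad_output):
        (u,) = ctx.saved_tensors
        grad_est = torch.zeros_like(u)
        for _ in range(M):
            z = torch.randn_like(u)               # perturbation with E[z^2] = 1
            # [h(u + delta*z) - h(u - delta*z)] * z  ==  |z| * 1[|u| < delta*|z|]
            grad_est += z.abs() * (u.abs() < DELTA * z.abs()).float()
        return grad_output * grad_est / (2.0 * DELTA * M)
```

Inside an LIF layer, this would replace the hard threshold, e.g. `y_t = ZOHeaviside.apply(u - v_th)`.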

Variance Reduction

Luckily (again), we can learn how to reduce the variance of the gradient estimator. This is the main idea behind the Learning to Learn paper, and it has been explored in the zeroth-order setting. The basic idea is to parameterize the optimizer function using a long short-term memory (LSTM) network $g_\phi$, called the meta-optimizer. The forward pass of the meta-optimizer takes our gradient estimator as input and outputs a variance-reduced descent direction:

$$
\theta_{t+1} = \theta_t + \Delta\theta_t, \qquad \left[\Delta\theta_t,\; s_{t+1}\right] = g_\phi\!\left(\hat{\nabla}_t,\; s_t\right)
$$

where $\hat{\nabla}_t$ is the zeroth-order gradient estimate, $\Delta\theta_t$ is the proposed update, and $s_t$ is the LSTM hidden state.
To train the meta-optimizer, we take the weighted sum of the losses along the optimization trajectory it proposes over $T$ time steps:

$$
\mathcal{L}(\phi) = \sum_{t=1}^{T} w_t \, \mathcal{L}\!\left(\theta_t\right)
$$
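The sketch below shows one way this could look in code: a coordinate-wise LSTM maps the gradient estimate to a parameter update, and the meta-objective is the weighted sum of task losses along the unrolled trajectory. All names (`MetaOptimizer`, `grad_fn`, `T`, `weights`) are illustrative assumptions rather than the framework's API, and `grad_fn` stands in for a zeroth-order estimator such as the one above.

```python
import torch
import torch.nn as nn


class MetaOptimizer(nn.Module):
    """Coordinate-wise LSTM: every parameter is treated as an independent batch element."""

    def __init__(self, hidden_size=20):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size)
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, grad_est, state=None):
        # grad_est: flattened gradient estimate, shape (num_params,)
        g, state = self.lstm(grad_est.view(1, -1, 1), state)   # seq_len=1, batch=num_params
        return self.out(g).view(-1), state                     # proposed update and new hidden state


def meta_loss(meta_opt, theta, loss_fn, grad_fn, T=20, weights=None):
    """Unroll T optimizer steps and return the weighted sum of task losses (the meta-objective)."""
    weights = weights if weights is not None else [1.0] * T
    state, total = None, 0.0
    for t in range(T):
        loss = loss_fn(theta)              # task loss at the current parameters
        grad_est = grad_fn(theta)          # zeroth-order gradient estimate (assumed helper)
        step, state = meta_opt(grad_est, state)
        theta = theta + step               # apply the LSTM's proposed descent direction
        total = total + weights[t] * loss
    return total                           # backpropagated through the unrolled trajectory
```
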
Results

We apply the LocalZO method to train a convolutional SNN on the MNIST dataset and benchmark against other optimizers:

(Figure: optimizer benchmark on MNIST.)

We also train meta-optimizers on multi-layer perceptron classifiers and test their transferability to the convolutional SNN (denoted linear). Lastly, we train a meta-optimizer on the validation set only and evaluate how well it optimizes on the training set (denoted V).