
Speeding up simulations with JIT-compilation and GPUs

In this tutorial, you will learn how to:

  • run parameter sweeps in Jaxley
  • use jit to compile your simulations and make them faster
  • use vmap to parallelize simulations on GPUs

Here is a code snippet that you will understand by the end of this tutorial:

from jax import jit, vmap
import jax.numpy as jnp
import numpy as np

import jaxley as jx

cell = ...  # See tutorial on Basics of Jaxley.

def simulate(params):
    param_state = None
    param_state = cell.data_set("Na_gNa", params[0], param_state)
    param_state = cell.data_set("K_gK", params[1], param_state)
    return jx.integrate(cell, param_state=param_state)

# Define 100 sets of sodium and potassium conductances.
all_params = jnp.asarray(np.random.rand(100, 2))

# Fast for-loops with jit compilation.
jitted_simulate = jit(simulate)
voltages = [jitted_simulate(params) for params in all_params]

# Using vmap for parallelization.
vmapped_simulate = vmap(jitted_simulate, in_axes=(0,))
voltages = vmapped_simulate(all_params)

In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by several orders of magnitude. This can be achieved in two ways:

  • by using JIT compilation
  • by using GPU parallelization

Let’s get started!

Using GPU or CPU

In Jaxley, you can choose whether to run on a GPU or a CPU by adding the following lines at the beginning of your script:

from jax import config
config.update("jax_platform_name", "cpu")  # Use "gpu" to run on a GPU.

JAX (and Jaxley) also allow you to choose between float32 and float64. float32 is typically faster, especially on GPUs, but we have experienced stability issues when simulating morphologically detailed neurons with float32.

config.update("jax_enable_x64", True)  # Set to False to use `float32`.
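Because these settings only state what JAX should use, it can be worth checking what is actually in effect. A minimal sketch using standard JAX calls (jax.default_backend(), jax.devices()); it only inspects the configuration and runs no Jaxley code:

```python
import jax
from jax import config

# Enable 64-bit floats; arrays created before this call keep their old dtype.
config.update("jax_enable_x64", True)

import jax.numpy as jnp

backend = jax.default_backend()  # name of the backend in use, e.g. "cpu" or "gpu"
devices = jax.devices()          # devices available on that backend
x = jnp.ones(3)                  # default floating dtype now follows the x64 flag

print("backend:", backend)
print("devices:", devices)
print("dtype:", x.dtype)         # float64, because x64 is enabled
```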

Next, we import the relevant libraries:

import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

import jaxley as jx
from jaxley.channels import Na, K, Leak

Building the cell or network

We first build a cell (or network) in the same way as we showed in the previous tutorials:

dt = 0.025
t_max = 10.0

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])

cell.insert(Na())
cell.insert(K())
cell.insert(Leak())

cell.delete_stimuli()
current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)
cell.branch(0).loc(0.0).stimulate(current)

cell.delete_recordings()
cell.branch(0).loc(0.0).record()
Added 1 external_states. See `.externals` for details.
Added 1 recordings. See `.recordings` for details.

Parameter sweeps

Assume you want to run the same cell with many different values for the sodium and potassium conductances, for example for genetic algorithms or parameter sweeps. To do this efficiently in Jaxley, you have to use the data_set() method (in combination with jit and vmap, as shown later):

def simulate(params):
    param_state = None
    param_state = cell.data_set("Na_gNa", params[0], param_state)
    param_state = cell.data_set("K_gK", params[1], param_state)
    return jx.integrate(cell, param_state=param_state)

The .data_set() method takes three arguments:

1) the name of the parameter you want to set. Jaxley allows you to set the following parameters: “radius”, “length”, “axial_resistivity”, as well as all parameters of channels and synapses.
2) the value of the parameter.
3) a param_state, which is initialized as None; each call to .data_set() returns an updated param_state. The final param_state has to be passed to jx.integrate().
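A plausible reason for threading param_state through every call, rather than changing the cell in place, is that JAX transformations such as jit only see values that enter through function arguments. The toy sketch below imitates that pattern. toy_data_set and toy_integrate are hypothetical stand-ins written for illustration, not Jaxley's implementation, and the default values and the summed "output" are made up:

```python
import jax
import jax.numpy as jnp

def toy_data_set(key, val, param_state):
    """Hypothetical data_set-style helper: returns a NEW state, never mutates."""
    entry = {"key": key, "val": val}
    return [entry] if param_state is None else param_state + [entry]

def toy_integrate(param_state):
    """Hypothetical 'simulation': apply overrides to made-up default values."""
    params = {"Na_gNa": 0.12, "K_gK": 0.036}
    for entry in param_state or []:
        params[entry["key"]] = entry["val"]
    # Stand-in for a real simulation: just combine the two conductances.
    return params["Na_gNa"] + params["K_gK"]

@jax.jit
def simulate(params):
    # The swept values arrive as traced arguments, so jit can compile this once.
    param_state = None
    param_state = toy_data_set("Na_gNa", params[0], param_state)
    param_state = toy_data_set("K_gK", params[1], param_state)
    return toy_integrate(param_state)

print(simulate(jnp.array([0.5, 0.25])))  # 0.75
```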

Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:

# Define 5 sets of sodium and potassium conductances.
all_params = jnp.asarray(np.random.rand(5, 2))

voltages = jnp.asarray([simulate(params) for params in all_params])
print("voltages.shape", voltages.shape)
voltages.shape (5, 1, 402)

The resulting voltages have shape (num_simulations, num_recordings, num_timesteps).
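To inspect a sweep, you can plot one recording of every simulation. plot_sweep is a hypothetical helper; it assumes that samples are spaced dt apart starting at t = 0, and that time is in ms and voltage in mV. It is exercised here on random placeholder data of the same shape, so it runs without Jaxley:

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # non-interactive backend, so the sketch also runs headless
import matplotlib.pyplot as plt

def plot_sweep(voltages, dt, recording=0):
    """Plot one recording of each simulation in a sweep.

    voltages: array of shape (num_simulations, num_recordings, num_timesteps).
    Assumes samples are dt apart, starting at t = 0.
    """
    voltages = np.asarray(voltages)
    num_simulations, _, num_timesteps = voltages.shape
    time = np.arange(num_timesteps) * dt
    fig, ax = plt.subplots(figsize=(6, 3))
    for i in range(num_simulations):
        ax.plot(time, voltages[i, recording], label=f"simulation {i}")
    ax.set_xlabel("time (ms)")
    ax.set_ylabel("voltage (mV)")
    ax.legend()
    return fig, ax

# Placeholder data with the shape printed above: (5, 1, 402).
placeholder = -70.0 + np.random.randn(5, 1, 402)
fig, ax = plot_sweep(placeholder, dt=0.025)
```

In the tutorial itself, you would call plot_sweep(voltages, dt) followed by plt.show().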

Stimulus sweeps

In addition to running sweeps across multiple parameters, you can also run sweeps across multiple stimuli (e.g., step current stimuli of different amplitudes). You can achieve this with the data_stimulate() method:

def simulate_stimulus(i_amp):
    # Build a step current whose amplitude is the swept quantity.
    current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=i_amp, delta_t=dt, t_max=t_max)

    data_stimuli = None
    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)
    return jx.integrate(cell, data_stimuli=data_stimuli)
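The idea of a stimulus sweep can be tried out without Jaxley. The sketch below uses a hypothetical passive membrane (one leak conductance, forward-Euler integration via jax.lax.scan) in place of jx.integrate, and a hand-written step current in place of jx.step_current; all names and constants are made up for illustration:

```python
import jax
import jax.numpy as jnp

DT, T_MAX = 0.025, 10.0
N_STEPS = int(round(T_MAX / DT))  # 400 time steps

def toy_step_current(i_delay, i_dur, i_amp):
    # Hypothetical step current: i_amp between i_delay and i_delay + i_dur.
    t = jnp.arange(N_STEPS) * DT
    return jnp.where((t >= i_delay) & (t < i_delay + i_dur), i_amp, 0.0)

def toy_simulate_stimulus(i_amp, g_leak=0.1, e_leak=-70.0):
    current = toy_step_current(1.0, 1.0, i_amp)

    def step(v, i_t):
        # Forward-Euler update of a passive (leak-only) membrane.
        v = v + DT * (-g_leak * (v - e_leak) + i_t)
        return v, v

    v0 = jnp.asarray(e_leak, dtype=current.dtype)
    _, trace = jax.lax.scan(step, v0, current)
    return trace

# Sweep over stimulus amplitudes with a plain for-loop.
amplitudes = jnp.array([0.0, 0.05, 0.1, 0.2])
traces = jnp.stack([toy_simulate_stimulus(a) for a in amplitudes])
print(traces.shape)  # (4, 400)
```

Because the amplitude is an ordinary function argument, the same jit and vmap techniques described below apply to this kind of sweep as well.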

Speeding up for loops via jit compilation

We can speed up such parameter sweeps (or stimulus sweeps) with jit compilation. jit compiles the simulation the first time it is run, so that every subsequent simulation will be much faster. This can be achieved by defining a new function which uses JAX’s jit():

jitted_simulate = jit(simulate)
# First run, will be slow.
voltages = jitted_simulate(all_params[0])
# More runs, will be much faster.
voltages = jnp.asarray([jitted_simulate(params) for params in all_params])
print("voltages.shape", voltages.shape)
voltages.shape (5, 1, 402)

jit compilation can make simulations up to 10,000 times faster, especially small simulations with few compartments. For very large models, the gain obtained with jit will be much smaller (jit may even provide no speed-up at all).
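The difference between the first (compiling) call and later calls can be measured with a few lines of standard JAX. This sketch times a hypothetical two-conductance toy membrane instead of a Jaxley cell, so the absolute numbers say nothing about Jaxley's performance. block_until_ready() is needed because JAX dispatches work asynchronously:

```python
import time

import jax
import jax.numpy as jnp

DT, N_STEPS = 0.025, 400

def toy_simulate(params):
    """Hypothetical two-conductance passive membrane (not a Jaxley model)."""
    g_a, g_b = params[0], params[1]
    t = jnp.arange(N_STEPS) * DT
    i_ext = jnp.where((t >= 1.0) & (t < 2.0), 0.1, 0.0).astype(params.dtype)

    def step(v, i_t):
        v = v + DT * (-g_a * (v - 50.0) - g_b * (v + 77.0) + i_t)
        return v, v

    v0 = jnp.asarray(-70.0, dtype=params.dtype)
    _, trace = jax.lax.scan(step, v0, i_ext)
    return trace

jitted = jax.jit(toy_simulate)
params = jnp.array([0.3, 0.6])

start = time.perf_counter()
trace = jitted(params).block_until_ready()  # first call: trace, compile, run
first_call = time.perf_counter() - start

start = time.perf_counter()
jitted(params).block_until_ready()          # later calls reuse the compiled code
second_call = time.perf_counter() - start

print(f"first call: {first_call:.4f} s, second call: {second_call:.6f} s")
```

jit caches one compiled version per input shape and dtype, so calling the jitted function with a parameter array of a different shape triggers a new compilation.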

Speeding up with GPU parallelization via vmap

Another way to speed up parameter sweeps is GPU parallelization. Parallelization in Jaxley can be achieved by using JAX’s vmap(). To do this, we first create a new function that handles multiple parameter sets directly:

# Using vmap for parallelization.
vmapped_simulate = vmap(jitted_simulate)

We can then run this function on all parameter sets at once (here, all_params.shape == (5, 2)), and Jaxley will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you selected "gpu" as the platform at the beginning of this tutorial.

voltages = vmapped_simulate(all_params)

GPU parallelization with vmap can give a large speed-up, which can easily be 2-3 orders of magnitude.
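vmap computes the same result as a Python loop, but as one batched computation. The sketch below checks this on a hypothetical toy membrane (not a Jaxley model) and also illustrates in_axes: 0 maps over the leading axis of the parameters, while None passes the same stimulus amplitude to every simulation:

```python
import jax
import jax.numpy as jnp
import numpy as np

DT, N_STEPS = 0.025, 400

def toy_simulate(params, i_amp):
    """Hypothetical two-conductance passive membrane (not a Jaxley model)."""
    g_a, g_b = params[0], params[1]
    t = jnp.arange(N_STEPS) * DT
    i_ext = jnp.where((t >= 1.0) & (t < 2.0), i_amp, 0.0).astype(params.dtype)

    def step(v, i_t):
        v = v + DT * (-g_a * (v - 50.0) - g_b * (v + 77.0) + i_t)
        return v, v

    v0 = jnp.asarray(-70.0, dtype=params.dtype)
    _, trace = jax.lax.scan(step, v0, i_ext)
    return trace

all_params = jnp.asarray(np.random.rand(100, 2))

# Map over axis 0 of params; pass the same i_amp to every simulation.
vmapped = jax.vmap(toy_simulate, in_axes=(0, None))
voltages = vmapped(all_params, 0.1)

# Reference: the same sweep as a plain Python loop.
looped = jnp.stack([toy_simulate(p, 0.1) for p in all_params])
print(voltages.shape)  # (100, 400)
```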

Combining jit and vmap

Finally, you can also combine jit and vmap. For example, you can run multiple batches of many parallel simulations: each batch is parallelized with vmap, and the simulation of each batch is compiled with jit:

jitted_vmapped_simulate = jit(vmap(simulate))
for batch in range(10):
    # Every batch has the same shape (5, 2), so compilation happens only once.
    all_params = jnp.asarray(np.random.rand(5, 2))
    voltages_batch = jitted_vmapped_simulate(all_params)
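Batching can also help when a single vmap over a very large sweep might not fit into GPU memory at once. A minimal sketch, again with a hypothetical toy model in place of a Jaxley cell: all_params is split into equally sized batches, so jit(vmap(...)) compiles once and every later batch reuses the compiled function. batch_size is an arbitrary illustrative choice:

```python
import jax
import jax.numpy as jnp
import numpy as np

DT, N_STEPS = 0.025, 400

def toy_simulate(params):
    """Hypothetical two-conductance passive membrane (not a Jaxley model)."""
    g_a, g_b = params[0], params[1]
    t = jnp.arange(N_STEPS) * DT
    i_ext = jnp.where((t >= 1.0) & (t < 2.0), 0.1, 0.0).astype(params.dtype)

    def step(v, i_t):
        v = v + DT * (-g_a * (v - 50.0) - g_b * (v + 77.0) + i_t)
        return v, v

    v0 = jnp.asarray(-70.0, dtype=params.dtype)
    _, trace = jax.lax.scan(step, v0, i_ext)
    return trace

all_params = jnp.asarray(np.random.rand(100, 2))
batch_size = 25  # 100 splits into 4 equal batches, so only one compilation

jitted_vmapped = jax.jit(jax.vmap(toy_simulate))
batches = [
    jitted_vmapped(all_params[start:start + batch_size])
    for start in range(0, all_params.shape[0], batch_size)
]
voltages = jnp.concatenate(batches)  # shape (100, 400)
```

If the number of parameter sets is not a multiple of batch_size, the smaller last batch triggers one additional compilation.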

That’s all you have to know about jit and vmap! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.

Next steps

If you want to learn more, we recommend reading the tutorial on building channel and synapse models, or the tutorial on groups, which make your Jaxley simulations more elegant and convenient to interact with.

Alternatively, you can jump directly to the tutorial on training biophysical networks, which teaches you how to optimize the parameters of biophysical models with gradient descent.