Speeding up simulations with JIT-compilation and GPUs¶
In this tutorial, you will learn how to:
- make parameter sweeps in Jaxley
- use jit to compile your simulations and make them faster
- use vmap to parallelize simulations on GPUs
Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap
cell = ... # See tutorial on Basics of Jaxley.
def simulate(params):
    param_state = None
    param_state = cell.data_set("Na_gNa", params[0], param_state)
    param_state = cell.data_set("K_gK", params[1], param_state)
    return jx.integrate(cell, param_state=param_state)
# Define 100 sets of sodium and potassium conductances.
all_params = jnp.asarray(np.random.rand(100, 2))
# Fast for-loops with jit compilation.
jitted_simulate = jit(simulate)
voltages = [jitted_simulate(params) for params in all_params]
# Using vmap for parallelization.
vmapped_simulate = vmap(jitted_simulate, in_axes=(0,))
voltages = vmapped_simulate(all_params)
In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in two ways:
- by using JIT compilation
- by using GPU parallelization
Let’s get started!
Using GPU or CPU¶
In Jaxley you can set whether you want to use gpu or cpu with the following lines at the beginning of your script:
from jax import config
config.update("jax_platform_name", "cpu")
JAX (and Jaxley) also allow you to choose between float32 and float64. Especially on GPUs, float32 will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32.
config.update("jax_enable_x64", True) # Set to false to use `float32`.
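As a quick sanity check, you can inspect jax.devices() to see which platform JAX ended up using:
import jax
print(jax.devices())  # e.g., [CpuDevice(id=0)] on CPU, or a CUDA device on GPU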
Next, we will import relevant libraries:
import matplotlib.pyplot as plt
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap
import jaxley as jx
from jaxley.channels import Na, K, Leak
Building the cell or network¶
We first build a cell (or network) in the same way as we showed in the previous tutorials:
dt = 0.025
t_max = 10.0
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])
cell.insert(Na())
cell.insert(K())
cell.insert(Leak())
cell.delete_stimuli()
current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)
cell.branch(0).loc(0.0).stimulate(current)
cell.delete_recordings()
cell.branch(0).loc(0.0).record()
Added 1 external_states. See `.externals` for details.
Added 1 recordings. See `.recordings` for details.
Parameter sweeps¶
Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley, you have to use the data_set() method (in combination with jit and vmap, as shown later):
def simulate(params):
    param_state = None
    param_state = cell.data_set("Na_gNa", params[0], param_state)
    param_state = cell.data_set("K_gK", params[1], param_state)
    return jx.integrate(cell, param_state=param_state)
The .data_set() method takes three arguments:
1) the name of the parameter you want to set. Jaxley allows you to set the following parameters: “radius”, “length”, “axial_resistivity”, as well as all parameters of channels and synapses.
2) the value of the parameter.
3) a param_state, which is initialized as None and is modified by .data_set(). This has to be passed to jx.integrate().
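The same pattern works for morphological parameters; below is a minimal sketch that sweeps radius and axial resistivity instead of channel conductances (simulate_morphology is a hypothetical name, not used elsewhere in this tutorial):
def simulate_morphology(params):
    # "radius" and "axial_resistivity" are among the settable parameters listed above.
    param_state = None
    param_state = cell.data_set("radius", params[0], param_state)
    param_state = cell.data_set("axial_resistivity", params[1], param_state)
    return jx.integrate(cell, param_state=param_state)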
Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:
# Define 5 sets of sodium and potassium conductances.
all_params = jnp.asarray(np.random.rand(5, 2))
voltages = jnp.asarray([simulate(params) for params in all_params])
print("voltages.shape", voltages.shape)
voltages.shape (5, 1, 402)
The resulting voltages have shape (num_simulations, num_recordings, num_timesteps).
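To inspect the sweep, you can plot all recorded traces with matplotlib (imported above); a minimal sketch that infers the time axis from the output shape:
time_vec = np.arange(voltages.shape[-1]) * dt
fig, ax = plt.subplots(1, 1)
for v in voltages:
    ax.plot(time_vec, v[0])  # first (and only) recording of each simulation
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Voltage (mV)")
plt.show()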
Stimulus sweeps¶
In addition to running sweeps across multiple parameters, you can also run sweeps across multiple stimuli (e.g., step current stimuli of different amplitudes). You can achieve this with the data_stimulate() method:
def simulate(i_amp):
    current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=i_amp, delta_t=0.025, t_max=10.0)
    data_stimuli = None
    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)
    return jx.integrate(cell, data_stimuli=data_stimuli)
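Such stimulus sweeps can be batched just like parameter sweeps, using jit and vmap (both introduced below); a small sketch, assuming a 1D array of amplitudes:
amplitudes = jnp.asarray([0.05, 0.1, 0.2])
# Compile and map `simulate` over the amplitude axis.
batched_simulate = jit(vmap(simulate))
voltages_by_amp = batched_simulate(amplitudes)  # shape (3, num_recordings, num_timesteps)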
Speeding up for-loops via jit compilation¶
We can speed up such parameter sweeps (or stimulus sweeps) with jit compilation. jit compilation will compile the simulation when it is run for the first time, such that every subsequent simulation will be much faster. This can be achieved by defining a new function which uses JAX's jit():
jitted_simulate = jit(simulate)
# First run, will be slow.
voltages = jitted_simulate(all_params[0])
# More runs, will be much faster.
voltages = jnp.asarray([jitted_simulate(params) for params in all_params])
print("voltages.shape", voltages.shape)
voltages.shape (5, 1, 402)
jit compilation can make simulations up to 10,000 times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit will be much smaller (jit may even provide no speed-up at all).
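You can measure the effect yourself; a minimal timing sketch using Python's time module (block_until_ready() is needed because JAX dispatches computations asynchronously):
import time

start = time.perf_counter()
jitted_simulate(all_params[0]).block_until_ready()  # first call: includes compilation
print("First call:", time.perf_counter() - start)

start = time.perf_counter()
jitted_simulate(all_params[1]).block_until_ready()  # later calls reuse the compiled code
print("Second call:", time.perf_counter() - start)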
Speeding up with GPU parallelization via vmap¶
Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley can be achieved with JAX's vmap. To do this, we first create a new function that handles multiple parameter sets directly:
# Using vmap for parallelization.
vmapped_simulate = vmap(jitted_simulate)
We can then run this method on all parameter sets (all_params.shape == (5, 2)), and Jaxley will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu as the device at the beginning of this tutorial.
voltages = vmapped_simulate(all_params)
GPU parallelization with vmap can give a large speed-up, easily 2-3 orders of magnitude.
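By default, vmap maps over the leading axis of every argument. If your simulation function takes several arguments, the in_axes argument controls which of them are batched. Below is a sketch with a hypothetical two-argument variant (simulate_with_amp is not part of the original tutorial) that batches conductance sets while sharing a single stimulus amplitude:
def simulate_with_amp(params, i_amp):
    param_state = None
    param_state = cell.data_set("Na_gNa", params[0], param_state)
    param_state = cell.data_set("K_gK", params[1], param_state)
    current = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=i_amp, delta_t=0.025, t_max=10.0)
    data_stimuli = None
    data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)
    return jx.integrate(cell, param_state=param_state, data_stimuli=data_stimuli)

# Batch over parameter sets (axis 0) and broadcast the scalar amplitude (None).
batched = vmap(simulate_with_amp, in_axes=(0, None))
voltages = batched(all_params, 0.1)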
Combining jit and vmap¶
Finally, you can also combine jit and vmap. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap, and simulating each batch can be compiled with jit:
jitted_vmapped_simulate = jit(vmap(simulate))
for batch in range(10):
    all_params = jnp.asarray(np.random.rand(5, 2))
    voltages_batch = jitted_vmapped_simulate(all_params)
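If you want to keep the results of all batches, you can collect them in a list and concatenate afterwards; a small usage sketch:
all_voltages = []
for batch in range(10):
    batch_params = jnp.asarray(np.random.rand(5, 2))
    all_voltages.append(jitted_vmapped_simulate(batch_params))
# Shape: (num_batches * batch_size, num_recordings, num_timesteps).
all_voltages = jnp.concatenate(all_voltages, axis=0)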
That’s all you have to know about jit and vmap! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.
Next steps¶
If you want to learn more, we recommend reading the tutorial on building channel and synapse models, or the tutorial on groups, which allow you to make your Jaxley simulations more elegant and convenient to interact with.
Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.