Defining groups
In this tutorial, you will learn how to:
- define groups (aka sectionlists) to simplify interactions with Jaxley
Here is a code snippet that you will understand by the end of this tutorial:
from jax import jit, vmap
import jax.numpy as jnp
import jaxley as jx
net = ... # See tutorial on Basics of Jaxley.
net.cell(0).add_to_group("fast_spiking")
net.cell(1).add_to_group("slow_spiking")
def simulate(params):
    param_state = None
    param_state = net.fast_spiking.data_set("HH_gNa", params[0], param_state)
    param_state = net.slow_spiking.data_set("HH_gNa", params[1], param_state)
    return jx.integrate(net, param_state=param_state)
# Define sodium for fast and slow spiking neurons.
params = jnp.asarray([1.0, 0.1])
# Run simulation.
voltages = simulate(params)
In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
import jaxley as jx
from jaxley.channels import Na, K, Leak
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import fully_connect
First, we define a network as you saw in the previous tutorial:
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=2)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1])
network = jx.Network([cell for _ in range(3)])
pre = network.cell([0, 1])
post = network.cell([2])
fully_connect(pre, post, IonotropicSynapse())
network.insert(Na())
network.insert(K())
network.insert(Leak())
Group: apical dendrites
Assume that, in each of the three neurons in this network, the second and fourth branches (indices 1 and 3) are apical dendrites. We can define this as:
for cell_ind in range(3):
    network.cell(cell_ind).branch(1).add_to_group("apical")
    network.cell(cell_ind).branch(3).add_to_group("apical")
After this, we can access network.apical in the same way as we previously accessed any other part of the network:
network.apical.set("radius", 0.3)
network.apical.view
View with 3 different channels. Use `.nodes` for details.
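To make the bookkeeping concrete, here is a plain-Python sketch of what the apical group contains. It is an illustration under stated assumptions only: the dictionary of sets and the helper add_branch_to_group are hypothetical and are not how Jaxley stores groups internally, and the starting radius of 1.0 is a placeholder. With 3 cells, branches 1 and 3 in each cell, and 2 compartments per branch, the group holds 12 compartments, and a group-wide set writes the new radius to exactly those.

```python
from collections import defaultdict

# Network shape from the example above: 3 cells, 4 branches each, 2 compartments per branch.
N_CELLS, N_BRANCHES, N_COMPS = 3, 4, 2

# Hypothetical group registry: group name -> set of (cell, branch, comp) triples.
groups = defaultdict(set)

def add_branch_to_group(name, cell, branch):
    """Hypothetical helper: add every compartment of one branch to a named group."""
    for comp in range(N_COMPS):
        groups[name].add((cell, branch, comp))

# Mirrors the loop above: branches 1 and 3 of every cell become "apical".
for cell in range(N_CELLS):
    add_branch_to_group("apical", cell, 1)
    add_branch_to_group("apical", cell, 3)

# Placeholder radius of 1.0 everywhere (not necessarily Jaxley's default).
radius = {
    (c, b, k): 1.0
    for c in range(N_CELLS)
    for b in range(N_BRANCHES)
    for k in range(N_COMPS)
}

# A group-wide set(...) writes one value to every member of the group.
for key in groups["apical"]:
    radius[key] = 0.3
```

Calling the helper repeatedly with the same name simply grows the same set, which mirrors how the two add_to_group("apical") calls per cell above end up in one group.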
Group: fast spiking
Similarly, you could define a group of fast-spiking cells. Assume that the first and second cells (indices 0 and 1) are fast-spiking:
network.cell(0).add_to_group("fast_spiking")
network.cell(1).add_to_group("fast_spiking")
network.fast_spiking.set("Na_gNa", 0.4)
network.fast_spiking.view
View with 3 different channels. Use `.nodes` for details.
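Groups act as labels rather than a partition of the network: in this tutorial, branches 1 and 3 of cells 0 and 1 are both apical and fast-spiking. The set-based sketch below, again hypothetical and not Jaxley's internal representation, shows this overlap and shows that setting Na_gNa on the group only touches its members (None stands for "left at its previous value").

```python
# Hypothetical set-based model of two groups in the 3-cell network above
# (4 branches per cell, 2 compartments per branch). Not Jaxley internals.
N_CELLS, N_BRANCHES, N_COMPS = 3, 4, 2

def compartments(cells, branches):
    """All (cell, branch, comp) triples for the given cells and branches."""
    return {(c, b, k) for c in cells for b in branches for k in range(N_COMPS)}

apical = compartments(range(N_CELLS), [1, 3])            # branches 1 and 3 of every cell
fast_spiking = compartments([0, 1], range(N_BRANCHES))   # whole cells 0 and 1

# One compartment can carry several group labels at once.
both = apical & fast_spiking

# Setting Na_gNa = 0.4 on fast_spiking touches only its members.
all_comps = compartments(range(N_CELLS), range(N_BRANCHES))
gNa = {key: None for key in all_comps}  # None = left at its previous value
for key in fast_spiking:
    gNa[key] = 0.4
```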
Groups from SWC files
If you are reading .swc morphologies, you can automatically assign groups with jx.read_swc(file_name, nseg=n, assign_groups=True). After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.
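The four group names match the structure types defined by the SWC format, where the second column of every sample line is a type code (1 soma, 2 axon, 3 (basal) dendrite, 4 apical dendrite). Assuming that assign_groups=True groups samples by this column (an assumption; we have not checked Jaxley's implementation), the idea can be sketched in plain Python. The function group_swc_samples and the toy morphology are hypothetical.

```python
from collections import defaultdict

# Standard SWC structure-type codes (second column of each sample line).
SWC_TYPES = {1: "soma", 2: "axon", 3: "basal", 4: "apical"}

def group_swc_samples(swc_text):
    """Hypothetical helper: map group name -> list of sample ids via the SWC type column."""
    groups = defaultdict(list)
    for line in swc_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and header comments
        sample_id, type_id = line.split()[:2]
        groups[SWC_TYPES.get(int(type_id), "custom")].append(int(sample_id))
    return dict(groups)

# Toy morphology, columns: id type x y z radius parent
example = """
# toy neuron
1 1 0.0 0.0 0.0 5.0 -1
2 3 5.0 0.0 0.0 1.0 1
3 3 10.0 0.0 0.0 1.0 2
4 4 0.0 5.0 0.0 1.0 1
5 2 0.0 -5.0 0.0 0.5 1
"""

groups = group_swc_samples(example)
# groups == {'soma': [1], 'basal': [2, 3], 'apical': [4], 'axon': [5]}
```

Type codes outside 1 to 4 are collected under a hypothetical "custom" label here, since the SWC format leaves them user-defined.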
How groups are interpreted by .make_trainable()
If you make a parameter of a group trainable, it is treated as a single parameter shared by all members of the group:
network.fast_spiking.make_trainable("Na_gNa")
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1
As such, get_parameters() returns only a single trainable parameter, which is the sodium conductance of every compartment of every fast-spiking neuron:
network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)}]
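Parameter sharing can be pictured as one trainable value that is broadcast to every compartment of the group. The NumPy sketch below works under toy assumptions (16 compartments for the two fast-spiking cells, a made-up quadratic loss with hypothetical targets) and shows the consequence for training: by the chain rule, the gradient with respect to the shared value is the sum of the per-compartment gradients. It illustrates the concept only and is not Jaxley's implementation.

```python
import numpy as np

# Toy layout: 2 fast-spiking cells x 4 branches x 2 compartments = 16 compartments.
n_comps = 16
shared_gNa = np.array([0.4])  # one trainable value, as returned by get_parameters() above

# Broadcast: every compartment of the group reads entry 0.
index = np.zeros(n_comps, dtype=int)
per_comp_gNa = shared_gNa[index]

# Made-up loss sum_i (g_i - target_i)^2 with hypothetical targets.
targets = np.linspace(0.1, 0.5, n_comps)
per_comp_grad = 2.0 * (per_comp_gNa - targets)

# Chain rule through the broadcast: gradients of all members are summed.
shared_grad = np.zeros_like(shared_gNa)
np.add.at(shared_grad, index, per_comp_grad)
```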
If, instead, you want a separate parameter for every fast-spiking cell, do not use the group; index the cells directly. For example, to train the axial resistivity of each fast-spiking cell separately (recall that the fast-spiking neurons have indices [0, 1]):
network.cell([0,1]).make_trainable("axial_resistivity")
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3
network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)},
{'axial_resistivity': Array([5000., 5000.], dtype=float64)}]
This generated two parameters for the axial resistivity, one per cell.
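In contrast to the group case, indexing the cells directly gives each cell its own entry. The NumPy sketch below uses hypothetical per-compartment gradients and an arbitrary step size of 10. It shows that each compartment reads its own cell's value and that gradients are accumulated per cell, so a single update can move the two axial resistivities apart. This is an illustration, not Jaxley's internal code.

```python
import numpy as np

# Toy layout: fast-spiking cells 0 and 1, 4 branches x 2 compartments = 8 compartments each.
comps_per_cell = 8
cell_params = np.array([5000.0, 5000.0])  # one axial resistivity per cell, as above

# Each compartment reads the entry of the cell it belongs to.
index = np.repeat(np.arange(cell_params.size), comps_per_cell)
per_comp = cell_params[index]

# Hypothetical per-compartment gradients: cell 0 is pushed up, cell 1 down.
per_comp_grad = np.where(index == 0, -1.0, 1.0)

# Gradients are summed per cell, so each cell gets its own update.
cell_grad = np.zeros_like(cell_params)
np.add.at(cell_grad, index, per_comp_grad)
updated = cell_params - 10.0 * cell_grad  # one plain gradient-descent step
```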
Summary
Groups allow you to organize your simulation in a more intuitive way, and they allow you to share parameters with make_trainable().