Defining groups for easier handling of complex networks
In this tutorial, you will learn how to:
- define groups (aka sectionlists) to simplify interactions with Jaxley
Here is a code snippet that you will understand by the end of this tutorial:
import jax.numpy as jnp
from jax import jit, vmap

import jaxley as jx

net = ...  # See tutorial on Basics of Jaxley.
net.cell(0).add_to_group("fast_spiking")
net.cell(1).add_to_group("slow_spiking")

def simulate(params):
    param_state = None
    param_state = net.fast_spiking.data_set("HH_gNa", params[0], param_state)
    param_state = net.slow_spiking.data_set("HH_gNa", params[1], param_state)
    return jx.integrate(net, param_state=param_state)

# Define sodium conductances for the fast- and slow-spiking neurons.
params = jnp.asarray([1.0, 0.1])

# Run simulation.
voltages = simulate(params)
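Conceptually, each data_set call in the snippet writes one entry of params into every compartment of a group. The pure-JAX sketch below mimics that mapping so you can see why jit and vmap fit naturally once simulate depends only on params. The helper expand_group_params and the two-cell, eight-compartments-per-cell layout are hypothetical and are not Jaxley's internal data structures.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout (an assumption, not taken from the tutorial): two cells
# with 8 compartments each, so compartments 0-7 are cell 0 and 8-15 are cell 1.
NCOMP_PER_CELL = 8
fast_comps = jnp.arange(0, NCOMP_PER_CELL)                   # "fast_spiking" group
slow_comps = jnp.arange(NCOMP_PER_CELL, 2 * NCOMP_PER_CELL)  # "slow_spiking" group


def expand_group_params(params):
    """Hypothetical helper: write params[0] into every fast-spiking compartment
    and params[1] into every slow-spiking compartment."""
    gna = jnp.zeros(2 * NCOMP_PER_CELL)
    gna = gna.at[fast_comps].set(params[0])
    gna = gna.at[slow_comps].set(params[1])
    return gna


# vmap evaluates the function for a batch of parameter sets; jit compiles it.
batched = jax.jit(jax.vmap(expand_group_params))
param_batch = jnp.asarray([[1.0, 0.1], [0.5, 0.05], [0.2, 0.2]])
gna_batch = batched(param_batch)
print(gna_batch.shape)  # (3, 16): one row of per-compartment values per parameter set
```

Because the group-to-compartment mapping is fixed, only the two group values vary between simulations, which keeps the batched function small.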
In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad
import jaxley as jx
from jaxley.channels import Na, K, Leak
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import fully_connect
First, we define a network as you saw in the previous tutorial:
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=2)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1])
network = jx.Network([cell for _ in range(3)])
pre = network.cell([0, 1])
post = network.cell([2])
fully_connect(pre, post, IonotropicSynapse())
network.insert(Na())
network.insert(K())
network.insert(Leak())
Group: apical dendrites
Assume that, in each of the three neurons in this network, the second and fourth branches are apical dendrites. We can define this as:
for cell_ind in range(3):
    network.cell(cell_ind).branch(1).add_to_group("apical")
    network.cell(cell_ind).branch(3).add_to_group("apical")
After this, we can access network.apical as we previously accessed anything else:
network.apical.set("radius", 0.3)
network.apical.view
| | comp_index | branch_index | cell_index | length | radius | axial_resistivity | capacitance | v | Na | Na_gNa | ... | K_gK | eK | K_n | Leak | Leak_gLeak | Leak_eLeak | global_comp_index | global_branch_index | global_cell_index | controlled_by_param |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
2 | 2 | 1 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 2 | 1 | 0 | 0 |
3 | 3 | 1 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 3 | 1 | 0 | 0 |
6 | 6 | 3 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 6 | 3 | 0 | 0 |
7 | 7 | 3 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 7 | 3 | 0 | 0 |
10 | 10 | 5 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 10 | 5 | 1 | 0 |
11 | 11 | 5 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 11 | 5 | 1 | 0 |
14 | 14 | 7 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 14 | 7 | 1 | 0 |
15 | 15 | 7 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 15 | 7 | 1 | 0 |
18 | 18 | 9 | 2 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 18 | 9 | 2 | 0 |
19 | 19 | 9 | 2 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 19 | 9 | 2 | 0 |
22 | 22 | 11 | 2 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 22 | 11 | 2 | 0 |
23 | 23 | 11 | 2 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.05 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 23 | 11 | 2 | 0 |
12 rows × 25 columns
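The comp_index column above can be reproduced by hand. The sketch below assumes that cells, branches, and compartments are numbered consecutively (3 cells × 4 branches × 2 compartments, matching nseg=2) and uses a hypothetical helper comp_indices_of_branches; it only illustrates which compartments the "apical" group covers and is not how Jaxley stores groups.

```python
import numpy as np

# Layout of the tutorial's network: 3 cells x 4 branches x 2 compartments (nseg=2).
NCELLS, NBRANCHES, NSEG = 3, 4, 2


def comp_indices_of_branches(cell_inds, branch_inds):
    """Hypothetical helper: global compartment indices of the given local branches
    in each of the given cells, assuming consecutive numbering."""
    inds = []
    for c in cell_inds:
        for b in branch_inds:
            global_branch = c * NBRANCHES + b
            inds.extend(range(global_branch * NSEG, (global_branch + 1) * NSEG))
    return np.asarray(sorted(inds))


apical = comp_indices_of_branches(range(NCELLS), [1, 3])
print(apical.tolist())  # [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23]

# Setting a value on the group amounts to a masked update of these compartments.
radius = np.full(NCELLS * NBRANCHES * NSEG, 1.0)
radius[apical] = 0.3
```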
Group: fast spiking
Similarly, you could define a group of fast-spiking cells. Assume that the first and second cells are fast-spiking:
network.cell(0).add_to_group("fast_spiking")
network.cell(1).add_to_group("fast_spiking")
network.fast_spiking.set("Na_gNa", 0.4)
network.fast_spiking.view
| | comp_index | branch_index | cell_index | length | radius | axial_resistivity | capacitance | v | Na | Na_gNa | ... | K_gK | eK | K_n | Leak | Leak_gLeak | Leak_eLeak | global_comp_index | global_branch_index | global_cell_index | controlled_by_param |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 0 | 0 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 0 | 0 | 0 | 0 |
1 | 1 | 0 | 0 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 1 | 0 | 0 | 0 |
2 | 2 | 1 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 2 | 1 | 0 | 0 |
3 | 3 | 1 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 3 | 1 | 0 | 0 |
4 | 4 | 2 | 0 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 4 | 2 | 0 | 0 |
5 | 5 | 2 | 0 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 5 | 2 | 0 | 0 |
6 | 6 | 3 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 6 | 3 | 0 | 0 |
7 | 7 | 3 | 0 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 7 | 3 | 0 | 0 |
8 | 8 | 4 | 1 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 8 | 4 | 1 | 0 |
9 | 9 | 4 | 1 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 9 | 4 | 1 | 0 |
10 | 10 | 5 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 10 | 5 | 1 | 0 |
11 | 11 | 5 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 11 | 5 | 1 | 0 |
12 | 12 | 6 | 1 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 12 | 6 | 1 | 0 |
13 | 13 | 6 | 1 | 10.0 | 1.0 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 13 | 6 | 1 | 0 |
14 | 14 | 7 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 14 | 7 | 1 | 0 |
15 | 15 | 7 | 1 | 10.0 | 0.3 | 5000.0 | 1.0 | -70.0 | True | 0.4 | ... | 0.005 | -90.0 | 0.2 | True | 0.0001 | -70.0 | 15 | 7 | 1 | 0 |
16 rows × 25 columns
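In the table above, some rows still show a radius of 0.3. These are compartments that belong to both the "fast_spiking" and the "apical" group, since a compartment can be a member of several groups. The sketch below represents each group as a Python set of global compartment indices (an illustrative choice, not Jaxley's internal representation) and assumes 8 compartments per cell, numbered consecutively.

```python
import numpy as np

NCELLS, NBRANCHES, NSEG = 3, 4, 2
NCOMP_PER_CELL = NBRANCHES * NSEG  # 8 compartments per cell

# Groups as sets of global compartment indices (illustrative assumption).
fast_spiking = set(range(0 * NCOMP_PER_CELL, 2 * NCOMP_PER_CELL))  # cells 0 and 1
apical = {2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23}  # from the apical view

# Setting a value on a group only touches that group's compartments.
na_gna = np.full(NCELLS * NCOMP_PER_CELL, 0.05)
na_gna[sorted(fast_spiking)] = 0.4

# Compartments that are in both groups keep the apical radius of 0.3.
both = sorted(fast_spiking & apical)
print(len(fast_spiking), both)  # 16 [2, 3, 6, 7, 10, 11, 14, 15]
```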
Groups from SWC files
Note: If you are reading SWC morphologies, you can automatically assign groups with jx.read_swc(file_name, assign_groups=True). After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.
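The group names match the structure-type column of the SWC format, whose standard codes are 1 (soma), 2 (axon), 3 (dendrite, usually basal), and 4 (apical dendrite). Assuming that assign_groups=True relies on this column (an assumption; check the jx.read_swc documentation for the exact behavior), the grouping idea can be sketched as follows. The helper group_swc_samples and the tiny morphology are made up for illustration, and the sketch groups SWC sample points rather than Jaxley branches or compartments.

```python
from collections import defaultdict

# Standard SWC type codes (columns: id, type, x, y, z, radius, parent).
# Mapping type 3 to "basal" is an assumption for this sketch.
SWC_TYPE_TO_GROUP = {1: "soma", 2: "axon", 3: "basal", 4: "apical"}

# A tiny made-up morphology for illustration only.
swc_text = """\
# id type x y z radius parent
1 1 0.0 0.0 0.0 5.0 -1
2 2 -10.0 0.0 0.0 0.5 1
3 3 10.0 0.0 0.0 1.0 1
4 4 0.0 10.0 0.0 1.0 1
5 4 0.0 20.0 0.0 0.8 4
"""


def group_swc_samples(text):
    """Hypothetical helper: map each group name to the SWC sample ids of that type."""
    groups = defaultdict(list)
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip comments and blank lines
        fields = line.split()
        sample_id, type_code = int(fields[0]), int(fields[1])
        groups[SWC_TYPE_TO_GROUP.get(type_code, "undefined")].append(sample_id)
    return dict(groups)


print(group_swc_samples(swc_text))
# {'soma': [1], 'axon': [2], 'basal': [3], 'apical': [4, 5]}
```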
How groups are interpreted by .make_trainable()
If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:
network.fast_spiking.make_trainable("Na_gNa")
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1
As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:
network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)}]
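A shared parameter means that one value is broadcast to all compartments of the group, so its gradient collects the contributions of every one of those compartments. The pure-JAX sketch below illustrates this with a made-up loss (the sum of all conductances) standing in for a loss on simulated voltages; the array layout (24 compartments, the first 16 belonging to the fast-spiking cells) is an assumption matching this tutorial's network, not Jaxley's internals.

```python
import jax
import jax.numpy as jnp

NCOMP = 24                           # 3 cells x 8 compartments
fast_comps = jnp.arange(16)          # compartments of cells 0 and 1
default_gna = jnp.full(NCOMP, 0.05)  # value in every compartment before training


def per_comp_gna(shared):
    """Broadcast one shared trainable value to every fast-spiking compartment."""
    return default_gna.at[fast_comps].set(shared[0])


def toy_loss(shared):
    # Hypothetical stand-in for a loss on simulated voltages.
    return jnp.sum(per_comp_gna(shared))


shared = jnp.asarray([0.4])  # a single parameter, like the output of get_parameters()
grad = jax.grad(toy_loss)(shared)
print(grad)  # one entry equal to 16.0: the sum over all 16 fast-spiking compartments
```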
If, instead, you want a separate parameter for every fast-spiking cell, you should not use the group but index the cells directly (remember that the fast-spiking neurons have indices [0, 1]). Here, this is illustrated with the axial resistivity:
network.cell([0,1]).make_trainable("axial_resistivity")
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3
network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)},
{'axial_resistivity': Array([5000., 5000.], dtype=float64)}]
This generated two parameters for the axial resistivity, one for each cell.
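By contrast with the shared group parameter, per-cell parameters give each cell its own value, which is then expanded to that cell's compartments. The pure-JAX sketch below assumes 8 compartments per cell and a toy loss (resistivity summed over compartments, scaled by 5000) chosen only to make the gradient easy to read; per_comp_ra is a hypothetical helper, not a Jaxley function.

```python
import jax
import jax.numpy as jnp

NCOMP_PER_CELL = 8
NCELLS = 3
trainable_cells = jnp.asarray([0, 1])  # like network.cell([0, 1])
cell_of_comp = jnp.repeat(jnp.arange(NCELLS), NCOMP_PER_CELL)  # owning cell per compartment


def per_comp_ra(per_cell, default=5000.0):
    """Give each trainable cell its own value; the remaining cell keeps the default."""
    ra_per_cell = jnp.full(NCELLS, default).at[trainable_cells].set(per_cell)
    return ra_per_cell[cell_of_comp]  # expand one value per cell to its compartments


def toy_loss(per_cell):
    # Hypothetical stand-in for a loss on simulated voltages.
    return jnp.sum(per_comp_ra(per_cell) / 5000.0)


per_cell = jnp.asarray([5000.0, 5000.0])  # two parameters, as in get_parameters()
grad = jax.grad(toy_loss)(per_cell)
print(grad)  # two entries, each collecting only its own cell's 8 compartments
```

Each gradient entry depends only on its own cell, which is exactly what separates this setup from the single shared parameter of the group.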
Summary
Groups allow you to organize your simulation in a more intuitive way, and they allow you to share parameters with make_trainable().
If you have not done so already, we recommend checking out the tutorial on how to compute gradients and train biophysical models.