Defining groups for easier handling of complex networks

In this tutorial, you will learn how to:

  • define groups (aka sectionlists) to simplify interactions with Jaxley

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap
import jax.numpy as jnp
import jaxley as jx


net = ...  # See tutorial on Basics of Jaxley.

net.cell(0).add_to_group("fast_spiking")
net.cell(1).add_to_group("slow_spiking")

def simulate(params):
    param_state = None
    param_state = net.fast_spiking.data_set("HH_gNa", params[0], param_state)
    param_state = net.slow_spiking.data_set("HH_gNa", params[1], param_state)
    return jx.integrate(net, param_state=param_state)

# Define sodium conductances for the fast- and slow-spiking neurons.
params = jnp.asarray([1.0, 0.1])

# Run simulation.
voltages = simulate(params)

In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx
from jaxley.channels import Na, K, Leak
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import fully_connect

First, we define a network as you saw in the previous tutorial:

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=2)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1])
network = jx.Network([cell for _ in range(3)])

pre = network.cell([0, 1])
post = network.cell([2])
fully_connect(pre, post, IonotropicSynapse())

network.insert(Na())
network.insert(K())
network.insert(Leak())

Group: apical dendrites

Assume that, in each of the three neurons in this network, the second and fourth branches are apical dendrites. We can define this as:

for cell_ind in range(3):
    network.cell(cell_ind).branch(1).add_to_group("apical")
    network.cell(cell_ind).branch(3).add_to_group("apical")

After this, we can access network.apical just like any other part of the network:

network.apical.set("radius", 0.3)
network.apical.view
    comp_index  branch_index  cell_index  length  radius  axial_resistivity  capacitance      v    Na  Na_gNa  ...   K_gK     eK  K_n  Leak  Leak_gLeak  Leak_eLeak  global_comp_index  global_branch_index  global_cell_index  controlled_by_param
 2           2             1           0    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  2                    1                  0                    0
 3           3             1           0    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  3                    1                  0                    0
 6           6             3           0    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  6                    3                  0                    0
 7           7             3           0    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  7                    3                  0                    0
10          10             5           1    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 10                    5                  1                    0
11          11             5           1    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 11                    5                  1                    0
14          14             7           1    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 14                    7                  1                    0
15          15             7           1    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 15                    7                  1                    0
18          18             9           2    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 18                    9                  2                    0
19          19             9           2    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 19                    9                  2                    0
22          22            11           2    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 22                   11                  2                    0
23          23            11           2    10.0     0.3             5000.0          1.0  -70.0  True    0.05  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 23                   11                  2                    0

12 rows × 25 columns
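
To see where these twelve rows come from, here is a minimal plain-Python sketch that recomputes the global compartment indices of the apical group. It is an illustration only, not Jaxley's implementation: it assumes that compartments are numbered consecutively by cell, then branch, then compartment, and the constants and the helper global_comp_indices are hypothetical names.

```python
# Toy bookkeeping, not Jaxley's internals: recompute which global
# compartment indices end up in the "apical" group defined above.
N_CELLS = 3              # cells in the network
BRANCHES_PER_CELL = 4    # one entry per branch in parents=[-1, 0, 0, 1]
COMPS_PER_BRANCH = 2     # nseg=2


def global_comp_indices(cell, branch):
    """Global compartment indices of one branch (hypothetical helper).

    Assumes consecutive numbering: by cell, then branch, then compartment.
    """
    global_branch = cell * BRANCHES_PER_CELL + branch
    start = global_branch * COMPS_PER_BRANCH
    return list(range(start, start + COMPS_PER_BRANCH))


groups = {"apical": []}
for cell in range(N_CELLS):
    for branch in (1, 3):  # second and fourth branch of every cell
        groups["apical"].extend(global_comp_indices(cell, branch))

print(groups["apical"])
# [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23]
```

The printed indices match the global_comp_index column of the view above.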

Group: fast spiking

Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:

network.cell(0).add_to_group("fast_spiking")
network.cell(1).add_to_group("fast_spiking")
network.fast_spiking.set("Na_gNa", 0.4)
network.fast_spiking.view
    comp_index  branch_index  cell_index  length  radius  axial_resistivity  capacitance      v    Na  Na_gNa  ...   K_gK     eK  K_n  Leak  Leak_gLeak  Leak_eLeak  global_comp_index  global_branch_index  global_cell_index  controlled_by_param
 0           0             0           0    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  0                    0                  0                    0
 1           1             0           0    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  1                    0                  0                    0
 2           2             1           0    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  2                    1                  0                    0
 3           3             1           0    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  3                    1                  0                    0
 4           4             2           0    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  4                    2                  0                    0
 5           5             2           0    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  5                    2                  0                    0
 6           6             3           0    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  6                    3                  0                    0
 7           7             3           0    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  7                    3                  0                    0
 8           8             4           1    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  8                    4                  1                    0
 9           9             4           1    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                  9                    4                  1                    0
10          10             5           1    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 10                    5                  1                    0
11          11             5           1    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 11                    5                  1                    0
12          12             6           1    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 12                    6                  1                    0
13          13             6           1    10.0     1.0             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 13                    6                  1                    0
14          14             7           1    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 14                    7                  1                    0
15          15             7           1    10.0     0.3             5000.0          1.0  -70.0  True     0.4  ...  0.005  -90.0  0.2  True      0.0001       -70.0                 15                    7                  1                    0

16 rows × 25 columns
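
Groups can overlap: the rows with radius 0.3 in the table above are compartments that belong to both fast_spiking and apical. The following sketch illustrates this with plain Python sets, using the compartment indices shown in the two views. The variable names are illustrative and not part of Jaxley.

```python
# Compartment indices as listed in the two `.view` tables above.
apical = {2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23}
fast_spiking = set(range(16))  # all compartments of cells 0 and 1

# Members of both groups: radius 0.3 (apical) and Na_gNa 0.4 (fast spiking).
both = sorted(apical & fast_spiking)

# Apical compartments of the third cell, which is not in fast_spiking.
apical_only = sorted(apical - fast_spiking)

print(both)         # [2, 3, 6, 7, 10, 11, 14, 15]
print(apical_only)  # [18, 19, 22, 23]
```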

Groups from SWC files

Note: If you are reading SWC morphologies, you can automatically assign groups with jx.read_swc(file_name, assign_groups=True). After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.
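
In the SWC format, the second column of every sample line is a structure identifier (1 = soma, 2 = axon, 3 = basal dendrite, 4 = apical dendrite), which is presumably what the automatic group assignment builds on. The sketch below illustrates that idea in plain Python on a made-up four-sample morphology. It is not how Jaxley parses files; SWC_TYPE_TO_GROUP, group_swc_samples, and the example data are all hypothetical.

```python
# Standard SWC structure identifiers (second column of every sample line).
SWC_TYPE_TO_GROUP = {1: "soma", 2: "axon", 3: "basal", 4: "apical"}

# Made-up example data; columns: id, type, x, y, z, radius, parent id.
EXAMPLE_SWC = """\
# id type x y z radius parent
1 1 0.0 0.0 0.0 5.0 -1
2 3 10.0 0.0 0.0 1.0 1
3 4 0.0 10.0 0.0 1.0 1
4 2 0.0 -10.0 0.0 0.5 1
"""


def group_swc_samples(swc_text):
    """Map group name -> list of sample ids (hypothetical helper)."""
    groups = {}
    for line in swc_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        fields = line.split()
        sample_id, type_id = int(fields[0]), int(fields[1])
        # Identifiers outside the four standard ones go into "custom".
        name = SWC_TYPE_TO_GROUP.get(type_id, "custom")
        groups.setdefault(name, []).append(sample_id)
    return groups


swc_groups = group_swc_samples(EXAMPLE_SWC)
print(swc_groups)
# {'soma': [1], 'basal': [2], 'apical': [3], 'axon': [4]}
```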

How groups are interpreted by .make_trainable()

If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:

network.fast_spiking.make_trainable("Na_gNa")
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1

As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:

network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)}]

If you instead want a separate parameter for every fast-spiking cell, do not use the group; index the cells directly (recall that the fast-spiking neurons have indices [0, 1]). The example below does this for the axial resistivity:

network.cell([0,1]).make_trainable("axial_resistivity")
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3
network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)},
 {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]

This generated two parameters for the axial resistivity, one for each cell.
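
One way to picture the difference between the two make_trainable() calls is to look at how a few trainable values are expanded to per-compartment values. The following plain-Python sketch only illustrates that bookkeeping, assuming eight compartments per cell as in the network above. The helper expand is hypothetical and does not reflect Jaxley's internals, and the value 5100.0 is made up to make the two cells distinguishable.

```python
COMPS_PER_CELL = 8  # 4 branches x 2 compartments, as in the network above


def expand(values, index_sets):
    """Broadcast each trainable value to its own set of compartments.

    Hypothetical helper: values[i] is assigned to every index in index_sets[i].
    """
    assert len(values) == len(index_sets)
    per_comp = {}
    for value, indices in zip(values, index_sets):
        for idx in indices:
            per_comp[idx] = value
    return per_comp


cell0 = list(range(0, COMPS_PER_CELL))
cell1 = list(range(COMPS_PER_CELL, 2 * COMPS_PER_CELL))

# Group: a single shared value covers all 16 compartments of both cells.
shared = expand([0.4], [cell0 + cell1])

# Per cell: two values, each covering the 8 compartments of one cell.
per_cell = expand([5000.0, 5100.0], [cell0, cell1])

print(len(set(shared.values())), len(set(per_cell.values())))
# 1 2
```

With the group, changing the one trainable value changes all sixteen compartments together; with per-cell indexing, each cell can move independently.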

Summary

Groups let you organize your simulation more intuitively, and they let you share parameters across group members with make_trainable().