Skip to content

Utils

compute_children_indices(parents)

Return all children indices of every branch.

Example:

parents = [-1, 0, 0]
compute_children_indices(parents) -> [[1, 2], [], []]

Source code in jaxley/utils/cell_utils.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def compute_children_indices(parents) -> List[jnp.ndarray]:
    """Return all children indices of every branch.

    Args:
        parents: For every branch, the index of its parent branch. May be a plain
            Python list or an array; it is converted to an array internally.

    Returns:
        A list with one integer index array per branch. Entry `b` holds the indices
        of all branches whose parent is `b`.

    Example:
    ```
    parents = [-1, 0, 0]
    compute_children_indices(parents) -> [[1, 2], [], []]
    ```
    """
    # Without this conversion, `list == int` evaluates to a single `False` instead of
    # an element-wise mask, so the documented list example would find no children.
    parents = np.asarray(parents)
    return [np.where(parents == b)[0] for b in range(len(parents))]

compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)

Return the coupling conductance between two compartments.

Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

radius: um r_a: ohm cm length_single_compartment: um coupling_conds: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2

Source code in jaxley/utils/cell_utils.py
223
224
225
226
227
228
229
230
231
232
233
234
def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):
    """Compute the coupling conductance linking two compartments.

    The formula follows `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.

    Units:
        `rad1`, `rad2` (radius): um
        `r_a1`, `r_a2` (`r_a`): ohm cm
        `l1`, `l2` (length of a single compartment): um
        result: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2
    """
    numerator = rad1 * rad2**2
    denominator = r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2
    # The trailing factor of 10**7 converts (S / cm / um) into (mS / cm^2).
    return numerator / denominator / l1 * 10**7

compute_coupling_cond_branchpoint(rad, r_a, l)

Return the coupling conductance between one compartment and a comp with l=0.

From https://en.wikipedia.org/wiki/Compartmental_neuron_models

If one compartment has l=0.0 then the equations simplify.

R_long = \sum_i r_a * L_i/2 / crosssection_i

with crosssection = pi * r**2

For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection

Then, g_long = crosssection * 2 / L / r_a

Then, the effective conductance is g_long / cylinder_area, with cylinder_area = 2 * pi * r * L. So: g = pi * r**2 * 2 / L / r_a / (2 * pi * r * L), which simplifies to g = r / r_a / L**2

Source code in jaxley/utils/cell_utils.py
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
def compute_coupling_cond_branchpoint(rad, r_a, l):
    r"""Coupling conductance between a compartment and a zero-length compartment.

    Based on https://en.wikipedia.org/wiki/Compartmental_neuron_models

    When one of the two compartments has l=0.0, the general expression simplifies.
    Starting from

    R_long = \sum_i r_a * L_i/2 / crosssection_i,  crosssection = pi * r**2

    a single compartment with L>0 gives
    R_long = r_a * L/2 / crosssection
    and therefore
    g_long = crosssection * 2 / L / r_a

    The effective conductance is g_long divided by the cylinder surface area:
    g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L
    g = r / r_a / L**2
    """
    conductance = rad / r_a / l**2  # S / cm / um
    return conductance * 10**7  # -> mS / cm^2

compute_impact_on_node(rad, r_a, l)

Compute the weight with which a compartment influences its node.

In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0

Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \propto rad**2 / L / r_a

This equation can be multiplied by any constant.

Source code in jaxley/utils/cell_utils.py
260
261
262
263
264
265
266
267
268
269
270
271
272
def compute_impact_on_node(rad, r_a, l):
    r"""Return the weight with which a compartment contributes to its node.

    Kirchhoffs current law requires the current at a branch point to be proportional
    to the crosssection of the compartment. Proportionality is sufficient because
    the branch point equation is homogeneous:
    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`

    With R_long = r_a * L/2 / crosssection it follows that
    g_long = crosssection * 2 / L / r_a \propto rad**2 / L / r_a

    Any constant factor may be applied to this expression."""
    squared_radius = rad**2
    return squared_radius / r_a / l

compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)

Return (row, col) to build the sparse matrix defining the voltage eqs.

This is run at init, not during runtime.

Source code in jaxley/utils/cell_utils.py
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def compute_morphology_indices_in_levels(
    num_branchpoints,
    child_belongs_to_branchpoint,
    par_inds,
    child_inds,
):
    """Build the (row, col) index pairs of the sparse voltage-equation matrix.

    Executed once at `init`, not during runtime.

    Returns:
        A dict with two entries. `"children"` pairs every entry of `child_inds` with
        the matching entry of `child_belongs_to_branchpoint`; `"parents"` pairs every
        entry of `par_inds` with `0, ..., num_branchpoints - 1`. Each value has one
        pair per row.
    """
    child_pairs = jnp.stack([child_inds, child_belongs_to_branchpoint]).T
    parent_pairs = jnp.stack([par_inds, jnp.arange(num_branchpoints)]).T
    return {"children": child_pairs, "parents": parent_pairs}

convert_point_process_to_distributed(current, radius, length)

Convert current point process (nA) to distributed current (uA/cm2).

This function gets called for synapses and for external stimuli.

Parameters:

Name Type Description Default
current ndarray

Current in nA.

required
radius ndarray

Compartment radius in um.

required
length ndarray

Compartment length in um.

required
Return

Current in uA/cm2.

Source code in jaxley/utils/cell_utils.py
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
def convert_point_process_to_distributed(
    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray
) -> jnp.ndarray:
    """Convert current point process (nA) to distributed current (uA/cm2).

    This function gets called for synapses and for external stimuli. The input
    `current` is never modified.

    Args:
        current: Current in `nA`.
        radius: Compartment radius in `um`.
        length: Compartment length in `um`.

    Return:
        Current in `uA/cm2`.
    """
    area = 2 * pi * radius * length  # Lateral surface of the cylinder in um^2.
    # Divide out-of-place: an in-place `/=` would overwrite the caller's buffer when
    # `current` is a NumPy array and would raise for integer-typed arrays.
    current_density = current / area  # nA / um^2
    return current_density * 100_000  # Convert (nA / um^2) to (uA / cm^2)

equal_segments(branch_property, nseg_per_branch)

Generates segments where some property is the same in each segment.

Parameters:

Name Type Description Default
branch_property list

List of values of the property in each branch. Should have len(branch_property) == num_branches.

required
Source code in jaxley/utils/cell_utils.py
13
14
15
16
17
18
19
20
21
def equal_segments(branch_property: list, nseg_per_branch: int):
    """Create segments in which a property is constant along every branch.

    Args:
        branch_property: List of values of the property in each branch. Should have
            `len(branch_property) == num_branches`.
        nseg_per_branch: How many times each branch value is repeated.

    Returns:
        Array in which row `b` holds `branch_property[b]` repeated
        `nseg_per_branch` times.
    """
    assert isinstance(branch_property, list), "branch_property must be a list."
    # Stack one copy of the per-branch values per segment, then flip to
    # (branch, segment) layout.
    stacked = [branch_property for _ in range(nseg_per_branch)]
    return jnp.asarray(stacked).T

get_num_neighbours(num_children, nseg_per_branch, num_branches)

Number of neighbours of each compartment.

Source code in jaxley/utils/cell_utils.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def get_num_neighbours(
    num_children: jnp.ndarray,
    nseg_per_branch: int,
    num_branches: int,
):
    """
    Number of neighbours of each compartment.

    Every compartment starts with two neighbours. The compartment at index
    `nseg_per_branch - 1` is then set to one neighbour, and finally the first
    compartment of every branch is set to `num_children + 1`.
    """
    num_compartments = num_branches * nseg_per_branch
    first_comp_of_branch = jnp.arange(num_branches) * nseg_per_branch

    counts = 2 * jnp.ones(num_compartments)
    counts = counts.at[nseg_per_branch - 1].set(1.0)
    # Applied last, so it wins over the line above when `nseg_per_branch == 1`.
    return counts.at[first_comp_of_branch].set(num_children + 1.0)

group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)

Group values by whether they have the same integer and sum values within group.

This is used to construct the last diagonals at the branch points.

Written by ChatGPT.

Source code in jaxley/utils/cell_utils.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def group_and_sum(
    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int
) -> jnp.ndarray:
    """Sum values that share the same integer group index.

    This is used to construct the last diagonals at the branch points.

    Args:
        values_to_sum: Values to accumulate.
        inds_to_group_by: Group index of every value.
        num_branchpoints: Number of groups, i.e. length of the returned array.

    Returns:
        Array of length `num_branchpoints` holding the per-group sums.
    """
    sums = jnp.zeros(num_branchpoints)
    # With no branchpoints (single branches and compartments) there is nothing to
    # scatter into; the original author notes `.at[inds]` needs non-empty `inds`.
    if not num_branchpoints > 0:
        return sums
    return sums.at[inds_to_group_by].add(values_to_sum)

index_of_loc(branch_ind, loc, nseg_per_branch)

Returns the index of a segment given a loc in [0, 1] and the index of a branch.

This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.

Parameters:

Name Type Description Default
branch_ind int

Index of the branch.

required
loc float

Location (in [0, 1]) along that branch.

required
nseg_per_branch int

Number of segments of each branch.

required

Returns:

Type Description
int

The index of the compartment within the entire cell.

Source code in jaxley/utils/cell_utils.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def index_of_loc(branch_ind: int, loc: float, nseg_per_branch: int) -> int:
    """Map a location in [0, 1] on a branch to a global compartment index.

    Locations (e.g. of synapses) are specified as a value between 0 and 1 and have
    to be snapped to a discrete segment. The segment whose centre is closest to
    `loc` is chosen.

    Args:
        branch_ind: Index of the branch.
        loc: Location (in [0, 1]) along that branch.
        nseg_per_branch: Number of segments of each branch.

    Returns:
        The index of the compartment within the entire cell.
    """
    half_width = 0.5 / nseg_per_branch
    # Segment centres are evenly spaced and sit half a segment away from each end.
    centres = np.linspace(half_width, 1 - half_width, nseg_per_branch)
    nearest = np.argmin(np.abs(centres - loc))
    return branch_ind * nseg_per_branch + nearest

interpolate_xyz(loc, coords)

Perform a linear interpolation between xyz-coordinates.

Parameters:

Name Type Description Default
loc float

The location in [0,1] along the branch.

required
coords ndarray

Array containing the reconstructed xyzr points of the branch.

required
Return

Interpolated xyz coordinate at loc, shape `(3,)`.

Source code in jaxley/utils/cell_utils.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def interpolate_xyz(loc: float, coords: np.ndarray):
    """Linearly interpolate between the traced xyz-coordinates of a branch.

    Args:
        loc: The location in [0,1] along the branch.
        coords: Array containing the reconstructed xyzr points of the branch. Only
            the first three columns are used.

    Return:
        Interpolated xyz coordinate at `loc`, shape `(3,)`.
    """
    xyz = coords[:, :3]
    # Euclidean distance between consecutive traced points.
    step_lengths = np.sqrt(np.sum(np.diff(xyz, axis=0) ** 2, axis=1))
    # Cumulative path length, starting at 0 for the first point.
    cumulative = np.insert(np.cumsum(step_lengths), 0, 0)
    # Normalise to [0, 1]; the floor avoids dividing by zero for zero-length branches.
    normalised = cumulative / np.maximum(1e-8, cumulative[-1])

    return v_interp(loc, normalised, xyz)

linear_segments(initial_val, endpoint_vals, parents, nseg_per_branch)

Generates segments where some property is linearly interpolated.

Parameters:

Name Type Description Default
initial_val float

The value at the tip of the soma.

required
endpoint_vals list

The value at the endpoints of each branch.

required
Source code in jaxley/utils/cell_utils.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def linear_segments(
    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, nseg_per_branch: int
):
    """Create segments in which a property is linearly interpolated along branches.

    Args:
        initial_val: The value at the tip of the soma.
        endpoint_vals: The value at the endpoints of each branch.
        parents: Parent index of every branch.
        nseg_per_branch: Number of segments of each branch.

    Returns:
        Array of shape `(num_branches, nseg_per_branch)` with the interpolated values.
    """
    # `initial_val` goes last so that a parent index of -1 picks it up.
    endpoint_values = jnp.asarray(endpoint_vals + [initial_val])
    branch_count = len(parents)

    # Flat (segment-major) branch index and location of every compartment.
    branch_of_comp = jnp.tile(jnp.arange(branch_count), nseg_per_branch)
    loc_of_comp = jnp.linspace(1, 0, nseg_per_branch).repeat(branch_count)

    # A branch starts at its parent's endpoint value and ends at its own.
    start = endpoint_values[parents[branch_of_comp]]
    end = endpoint_values[branch_of_comp]
    interpolated = (end - start) * loc_of_comp + start

    return jnp.reshape(interpolated, (nseg_per_branch, branch_count)).T

loc_of_index(global_comp_index, nseg)

Return location corresponding to index.

Source code in jaxley/utils/cell_utils.py
216
217
218
219
220
def loc_of_index(global_comp_index, nseg):
    """Return the location in [0, 1] of the centre of the indexed compartment.

    Args:
        global_comp_index: Compartment index; only its position within the branch
            (`global_comp_index % nseg`) matters.
        nseg: Number of segments per branch.
    """
    half_width = 0.5 / nseg
    centres = np.linspace(half_width, 1 - half_width, nseg)
    return centres[global_comp_index % nseg]

merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)

Build full list of which branches are solved in which iteration.

From the branching pattern of single cells, this “merges” them into a single ordering of branches.

Parameters:

Name Type Description Default
cumsum_num_branches List[int]

cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this should be a list containing [0, 10, 25, 30].

required
arrs List[List[ndarray]]

A list of a list of arrays that should be merged.

required
exclude_first bool

If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates “no parent”) entry should not be changed.

True

Returns:

Type Description
ndarray

A list of arrays which contain the branch indices that are computed at each

ndarray

level (i.e., iteration).

Source code in jaxley/utils/cell_utils.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
def merge_cells(
    cumsum_num_branches: List[int],
    cumsum_num_branchpoints: List[int],
    arrs: List[List[jnp.ndarray]],
    exclude_first: bool = True,
) -> List[jnp.ndarray]:
    """
    Build full list of which branches are solved in which iteration.

    From the branching pattern of single cells, this "merges" them into a single
    ordering of branches.

    Args:
        cumsum_num_branches: cumulative number of branches. E.g., for three cells with
            10, 15, and 5 branches respectively, this should be a list containing
            `[0, 10, 25, 30]`.
        cumsum_num_branchpoints: cumulative number of branchpoints, indexed per cell
            like `cumsum_num_branches`.
        arrs: A list (one entry per cell) of a list (one entry per level) of arrays
            that should be merged. The two-element offset
            `[cumsum_num_branches[i], cumsum_num_branchpoints[i]]` is added to every
            array of cell `i`, so each array must broadcast against it.
        exclude_first: Not implemented. Must be passed as `False`; the default of
            `True` raises.

    Returns:
        A list of arrays which contain the branch indices that are computed at each
        level (i.e., iteration). Empty if `arrs` is empty.

    Raises:
        NotImplementedError: If `exclude_first` is `True`.
    """
    if exclude_first:
        raise NotImplementedError("`exclude_first=True` is not implemented.")

    # Shift the indices of every cell so that they are unique across all cells.
    shifted_cells = []
    for cell_ind, levels in enumerate(arrs):
        offset = jnp.asarray(
            [cumsum_num_branches[cell_ind], cumsum_num_branchpoints[cell_ind]]
        )
        shifted_cells.append([level + offset for level in levels])

    # Cells can have different depths; a level only collects the cells reaching it.
    num_levels = max((len(levels) for levels in shifted_cells), default=0)
    merged = []
    for level_ind in range(num_levels):
        in_level = [
            levels[level_ind] for levels in shifted_cells if len(levels) > level_ind
        ]
        merged.append(jnp.concatenate(in_level))

    return merged

params_to_pstate(params, indices_set_by_trainables)

Make outputs get_parameters() conform with outputs of .data_set().

make_trainable() followed by params=get_parameters() does not return indices because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params)). Therefore, in jx.integrate, we run this function to add indices to the dict. The outputs of params_to_pstate are of the same shape as the outputs of .data_set().

Source code in jaxley/utils/cell_utils.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def params_to_pstate(
    params: List[Dict[str, jnp.ndarray]],
    indices_set_by_trainables: List[jnp.ndarray],
):
    """Bring the output of `get_parameters()` into the format of `.data_set()`.

    `make_trainable()` followed by `params=get_parameters()` does not return indices
    because these indices would also be differentiated by `jax.grad` (as soon as
    the `params` are passed to `def simulate(params)`). Therefore, in `jx.integrate`,
    this function is run to add the indices to the dict. The outputs of
    `params_to_pstate` are of the same shape as the outputs of `.data_set()`.

    Only the first item of every dict in `params` is used.
    """
    pstate = []
    for param, inds in zip(params, indices_set_by_trainables):
        key, val = list(param.items())[0]
        pstate.append({"key": key, "val": val, "indices": inds})
    return pstate

query_channel_states_and_params(d, keys, idcs)

Get dict with subset of keys and values from d.

This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.

states = {'eCa': Array([ 0., 0., nan])}

will be states = {'eCa': Array([ 0., 0.])}

Only loops over necessary keys, as opposed to looping over d.items().

Source code in jaxley/utils/cell_utils.py
401
402
403
404
405
406
407
408
409
410
411
412
413
def query_channel_states_and_params(d, keys, idcs):
    """Return a dict restricted to `keys`, with every value indexed by `idcs`.

    This is used to restrict a dict where every item contains __all__ states to only
    the ones that are relevant for the channel. E.g.

    ```states = {'eCa': Array([ 0.,  0., nan])}```

    will be
    ```states = {'eCa': Array([ 0.,  0.])}```

    Only the requested keys are visited, as opposed to looping over `d.items()`."""
    return {key: d.get(key)[idcs] for key in keys}

remap_to_consecutive(arr)

Maps an array of integers to an array of consecutive integers.

E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]

Source code in jaxley/utils/cell_utils.py
275
276
277
278
279
280
281
def remap_to_consecutive(arr):
    """Relabel an array of integers with consecutive integers starting at 0.

    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`
    """
    # The inverse mapping of `unique` is each element's rank among the sorted unique
    # values, which is exactly the consecutive relabelling.
    return jnp.unique(arr, return_inverse=True)[1]

plot_morph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})

Plot morphology.

Parameters:

Name Type Description Default
dims Tuple[int]

Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.

(0, 1)
type str

Either line or scatter.

'line'
col str

The color for all branches.

'k'
Source code in jaxley/utils/plot_utils.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def plot_morph(
    xyzr,
    dims: Tuple[int] = (0, 1),
    col: str = "k",
    ax: Optional[Axes] = None,
    type: str = "line",
    morph_plot_kwargs: Dict = {},
):
    """Draw a morphology branch by branch onto a matplotlib axis.

    Args:
        xyzr: Iterable with one coordinate array per branch.
        dims: The two columns of each branch's coordinate array to plot against each
            other (used as `coords[:, dims]`). NOTE(review): given the default
            `(0, 1)` this is presumably 0=x, 1=y, 2=z - confirm.
        col: The color for all branches.
        ax: Axis to draw on. A new 3x3 inch figure is created if `None`.
        type: Either `line` or `scatter` (matched case-insensitively as a substring).
        morph_plot_kwargs: Extra keyword arguments forwarded to `ax.plot` or
            `ax.scatter`.

    Returns:
        The axis that was drawn on.

    Raises:
        NotImplementedError: If `type` contains neither `line` nor `scatter` and
            there is at least one branch to draw.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(3, 3))

    for branch_coords in xyzr:
        horizontal, vertical = branch_coords[:, dims].T
        kind = type.lower()

        if "line" in kind:
            ax.plot(horizontal, vertical, color=col, **morph_plot_kwargs)
        elif "scatter" in kind:
            ax.scatter(horizontal, vertical, color=col, **morph_plot_kwargs)
        else:
            raise NotImplementedError

    return ax

swc_to_jaxley(fname, max_branch_len=100.0, sort=True, num_lines=None)

Read an SWC file and bring morphology into jaxley compatible formats.

Parameters:

Name Type Description Default
fname str

Path to swc file.

required
max_branch_len float

Maximal length of one branch. If a branch exceeds this length, it is split into equal parts such that each subbranch is below max_branch_len.

100.0
num_lines Optional[int]

Number of lines of the SWC file to read.

None
Source code in jaxley/utils/swc.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def swc_to_jaxley(
    fname: str,
    max_branch_len: float = 100.0,
    sort: bool = True,
    num_lines: Optional[int] = None,
) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]:
    """Read an SWC file and bring morphology into `jaxley` compatible formats.

    The file is loaded with `np.loadtxt`, split into branches by
    `_split_into_branches_and_sort`, and parents, path lengths and radius functions
    are then derived per branch. If more than one branch has no parent (`-1`), an
    extra root branch is prepended and the former roots become its children.

    Args:
        fname: Path to swc file.
        max_branch_len: Maximal length of one branch. If a branch exceeds this length,
            it is split into equal parts such that each subbranch is below
            `max_branch_len`.
        sort: Forwarded unchanged to `_split_into_branches_and_sort`.
        num_lines: Number of lines of the SWC file to read. `None` keeps all rows.

    Returns:
        A tuple `(parents, pathlengths, radius_fns, types, all_coords_of_branches)`
        with one entry per branch in each list: the parent index (`-1` for the
        root), the path length (zero lengths are clipped to `1.0`), a callable
        generating radii, the type, and the rows of columns 2 to 5 of the file that
        belong to the branch.
    """
    # The whole file is parsed first; `num_lines` only truncates the parsed rows.
    content = np.loadtxt(fname)[:num_lines]
    # Column 1 holds the type of every traced point (1 is treated as soma below).
    types = content[:, 1]
    # True if only the very first point is of type 1.
    is_single_point_soma = types[0] == 1 and types[1] != 1

    if is_single_point_soma:
        # Warn here, but the conversion of the length happens in `_compute_pathlengths`.
        warn(
            "Found a soma which consists of a single traced point. `Jaxley` "
            "interprets this soma as a spherical compartment with radius "
            "specified in the SWC file, i.e. with surface area 4*pi*r*r."
        )
    # From here on `types` is the value returned by the helper, not the raw column.
    sorted_branches, types = _split_into_branches_and_sort(
        content,
        max_branch_len=max_branch_len,
        is_single_point_soma=is_single_point_soma,
        sort=sort,
    )

    parents = _build_parents(sorted_branches)
    each_length = _compute_pathlengths(
        sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
    )
    # Total path length of a branch is the sum of its traced pieces.
    pathlengths = [np.sum(length_traced) for length_traced in each_length]
    for i, pathlen in enumerate(pathlengths):
        if pathlen == 0.0:
            warn("Found a segment with length 0. Clipping it to 1.0")
            pathlengths[i] = 1.0
    radius_fns = _radius_generating_fns(
        sorted_branches, content[:, 5], each_length, parents, types
    )

    # More than one branch without parent: prepend an artificial root branch.
    if np.sum(np.asarray(parents) == -1) > 1.0:
        # The new branch gets index 0, so all existing indices shift by one and the
        # former roots (-1) now point at it.
        parents = np.asarray([-1] + parents)
        parents[1:] += 1
        parents = parents.tolist()
        # The padded branch has length 0.1 and the constant radius of the first row.
        pathlengths = [0.1] + pathlengths
        radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns
        sorted_branches = [[0]] + sorted_branches

        # Type of padded section is assumed to be of `custom` type:
        # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
        types = [5.0] + types

    all_coords_of_branches = []
    for i, branch in enumerate(sorted_branches):
        # Remove 1 because `content` is an array that is indexed from 0.
        branch = np.asarray(branch) - 1

        # Deal with additional branch that might have been added above in the lines
        # `if np.sum(np.asarray(parents) == -1) > 1.0:`
        # (its entry `0` became `-1` by the subtraction and is mapped to row 0).
        branch[branch < 0] = 0

        # Get traced coordinates of the branch.
        # Columns 2 to 5 - presumably x, y, z and radius; confirm against SWC spec.
        coords_of_branch = content[branch, 2:6]
        all_coords_of_branches.append(coords_of_branch)

    return parents, pathlengths, radius_fns, types, all_coords_of_branches

nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)

A version of lax.scan that supports recursive gradient checkpointing.

Code taken from: https://github.com/google/jax/issues/2139

The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

Parameters:

Name Type Description Default
f Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]

function to scan over.

required
init Carry

initial value.

required
xs Dict[str, ndarray]

scanned over values.

required
length Optional[int]

leading length of all dimensions

None
nested_lengths Sequence[int]

required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

required
scan_fn

function matching the API of lax.scan

scan
checkpoint_fn Callable[[Func], Func]

function matching the API of jax.checkpoint.

checkpoint
Source code in jaxley/utils/jax_utils.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def nested_checkpoint_scan(
    f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],
    init: Carry,
    xs: Dict[str, jnp.ndarray],
    length: Optional[int] = None,
    *,
    nested_lengths: Sequence[int],
    scan_fn=jax.lax.scan,
    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,
):
    """A version of lax.scan that supports recursive gradient checkpointing.

    Code taken from: https://github.com/google/jax/issues/2139

    The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for
    the required `nested_lengths` argument.

    The key feature of `nested_checkpoint_scan` is that gradient calculations
    require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested
    scans, which it achieves by re-evaluating the forward pass
    `len(nested_lengths) - 1` times.

    `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a
    single element.

    Args:
        f: function to scan over.
        init: initial value.
        xs: scanned over values.
        length: leading length of all dimensions
        nested_lengths: required list of lengths to scan over for each level of
            checkpointing. The product of nested_lengths must match length (if
            provided) and the size of the leading axis for all arrays in ``xs``.
        scan_fn: function matching the API of lax.scan
        checkpoint_fn: function matching the API of jax.checkpoint.

    Returns:
        The result of `_inner_nested_scan` on the reshaped inputs.

    Raises:
        ValueError: If `length` is given and differs from the product of
            `nested_lengths`.
    """
    if length is not None and length != math.prod(nested_lengths):
        raise ValueError(f"inconsistent {length=} and {nested_lengths=}")

    def nested_reshape(x):
        # Split the leading axis into one axis per nesting level.
        x = jnp.asarray(x)
        new_shape = tuple(nested_lengths) + x.shape[1:]
        return x.reshape(new_shape)

    # `jax.tree_map` is deprecated and absent from recent JAX releases;
    # `jax.tree_util.tree_map` is the equivalent that exists in old and new versions.
    sub_xs = jax.tree_util.tree_map(nested_reshape, xs)
    return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)

gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)

Compute current at the post synapse.

All this does is sum the synaptic currents that come into a particular compartment. It returns two arrays (voltage terms and constant terms), each with as many elements as there are compartments.

Source code in jaxley/utils/syn_utils.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def gather_synapes(
    number_of_compartments: jnp.ndarray,
    post_syn_comp_inds: np.ndarray,
    current_each_synapse_voltage_term: jnp.ndarray,
    current_each_synapse_constant_term: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Sum synaptic currents per post-synaptic compartment.

    Every synapse contributes its voltage term and its constant term to the
    compartment listed for it in `post_syn_comp_inds`; contributions landing in the
    same compartment are added up.

    Returns:
        Two arrays (summed voltage terms, summed constant terms), each with as many
        elements as there are compartments.
    """
    # Scalar updates scattered along the single operand axis.
    dnums = ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    scatter_inds = post_syn_comp_inds[:, None]

    def _sum_into_compartments(per_synapse_term):
        empty = jnp.zeros((number_of_compartments,))
        return scatter_add(empty, scatter_inds, per_synapse_term, dnums)

    summed_voltage_terms = _sum_into_compartments(current_each_synapse_voltage_term)
    summed_constant_terms = _sum_into_compartments(current_each_synapse_constant_term)
    return summed_voltage_terms, summed_constant_terms