Skip to content

Utils

build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)

Return the radiuses of branches given SWC file xyzr.

Returns a flattened array of shape (num_branches * ncomp,), ordered branch by branch.

Parameters:

Name Type Description Default
radius_fns List[Callable]

Functions which, given compartment locations return the radius.

required
branch_indices List[int]

The indices of the branches for which to return the radiuses.

required
min_radius Optional[float]

If passed, the radiuses are clipped to be at least as large.

required
ncomp int

The number of compartments that every branch is discretized into.

required
Source code in jaxley/utils/cell_utils.py (lines 268–298)
def build_radiuses_from_xyzr(
    radius_fns: List[Callable],
    branch_indices: List[int],
    min_radius: Optional[float],
    ncomp: int,
) -> np.ndarray:
    """Return the radiuses of branches given SWC file xyzr.

    Every radius function is evaluated at the centers of the `ncomp` equally long
    compartments of its branch. The results are returned flattened in branch-major
    order: entries `[i * ncomp:(i + 1) * ncomp]` belong to `branch_indices[i]`.

    Args:
        radius_fns: Functions which, given compartment locations (values in [0, 1]
            along the branch), return the radius at those locations.
        branch_indices: The indices of the branches for which to return the radiuses.
        min_radius: If passed, the radiuses are clipped to be at least as large. If
            `None`, all radiuses must be strictly positive.
        ncomp: The number of compartments that every branch is discretized into.

    Returns:
        Flat array of radiuses with shape `(len(branch_indices) * ncomp,)`.

    Raises:
        AssertionError: If `min_radius` is `None` and any radius is not positive.
    """
    # Compartment locations are at the center of the internal nodes.
    non_split = 1 / ncomp
    range_ = np.linspace(non_split / 2, 1 - non_split / 2, ncomp)

    # Build radiuses, one row per requested branch, then flatten row by row.
    radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])
    radiuses_each = radiuses.ravel(order="C")
    if min_radius is None:
        assert np.all(
            radiuses_each > 0.0
        ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`."
    else:
        # Out-of-place clipping: unlike a masked assignment this promotes integer
        # radiuses instead of truncating a float `min_radius`.
        radiuses_each = np.maximum(radiuses_each, min_radius)

    return radiuses_each

compute_axial_conductances(comp_edges, params)

Given comp_edges, radius, length, r_a, cm, compute the axial conductances.

Note that the resulting axial conductances will already be divided by the capacitance cm.

Source code in jaxley/utils/cell_utils.py (lines 704–765)
def compute_axial_conductances(
    comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]
) -> jnp.ndarray:
    """Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.

    Note that the resulting compartment-to-compartment and
    branchpoint-to-compartment conductances will already be divided by the
    capacitance `cm`. Compartment-to-branchpoint weights are not divided by `cm`;
    they are scaled by a constant factor instead.

    Args:
        comp_edges: Edge table with columns `type`, `source` and `sink`.
            `type == 0` selects compartment-to-compartment edges, `type` in `{1, 2}`
            branchpoint-to-compartment edges and `type` in `{3, 4}`
            compartment-to-branchpoint edges.
        params: Per-compartment parameters. The keys `radius`,
            `axial_resistivity`, `length` and `capacitance` are read.

    Returns:
        The concatenation `[c2c, bp2c, c2bp]` of the axial coupling conductances.
        Each part follows the row order of its edges in `comp_edges`.
    """
    # `Compartment-to-compartment` (c2c) axial coupling conductances.
    condition = comp_edges["type"].to_numpy() == 0
    source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
    sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())

    if len(sink_comp_inds) > 0:
        conds_c2c = (
            vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(
                params["radius"][sink_comp_inds],
                params["radius"][source_comp_inds],
                params["axial_resistivity"][sink_comp_inds],
                params["axial_resistivity"][source_comp_inds],
                params["length"][sink_comp_inds],
                params["length"][source_comp_inds],
            )
            / params["capacitance"][sink_comp_inds]
        )
    else:
        # Indexing with an empty index array is avoided; contribute nothing instead.
        conds_c2c = jnp.asarray([])

    # `branchpoint-to-compartment` (bp2c) axial coupling conductances. Only the
    # sink compartment's properties matter because the branchpoint has no length.
    condition = comp_edges["type"].isin([1, 2])
    sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())

    if len(sink_comp_inds) > 0:
        conds_bp2c = (
            vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(
                params["radius"][sink_comp_inds],
                params["axial_resistivity"][sink_comp_inds],
                params["length"][sink_comp_inds],
            )
            / params["capacitance"][sink_comp_inds]
        )
    else:
        conds_bp2c = jnp.asarray([])

    # `compartment-to-branchpoint` (c2bp) axial coupling conductances.
    condition = comp_edges["type"].isin([3, 4])
    source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())

    if len(source_comp_inds) > 0:
        conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(
            params["radius"][source_comp_inds],
            params["axial_resistivity"][source_comp_inds],
            params["length"][source_comp_inds],
        )
        # For numerical stability. These values are very small, but their scale
        # does not matter.
        conds_c2bp *= 1_000
    else:
        conds_c2bp = jnp.asarray([])

    # All axial coupling conductances.
    return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])

compute_children_and_parents(branch_edges)

Build indices used during `._init_morph_custom_spsolve()`.

Source code in jaxley/utils/cell_utils.py (lines 768–776)
def compute_children_and_parents(
    branch_edges: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray, jnp.ndarray]:
    """Build indices used during `._init_morph_custom_spsolve()`.

    Args:
        branch_edges: Table with the columns `parent_branch_index` and
            `child_branch_index`, one row per parent/child branch connection.

    Returns:
        A 3-tuple `(par_inds, child_inds, child_belongs_to_branchpoint)`:
        - `par_inds`: the sorted unique parent branch indices.
        - `child_inds`: the child branch index of every row of `branch_edges`.
        - `child_belongs_to_branchpoint`: for every row, the position of its
          parent branch within `par_inds` (parents remapped to consecutive ints).
    """
    par_inds = branch_edges["parent_branch_index"].to_numpy()
    child_inds = branch_edges["child_branch_index"].to_numpy()
    # Must be computed from the per-row parents before they are deduplicated below.
    child_belongs_to_branchpoint = remap_to_consecutive(par_inds)
    par_inds = np.unique(par_inds)
    return par_inds, child_inds, child_belongs_to_branchpoint

compute_children_indices(parents)

Return all children indices of every branch.

Example:

parents = [-1, 0, 0]
compute_children_indices(parents) -> [[1, 2], [], []]

Source code in jaxley/utils/cell_utils.py (lines 453–466)
def compute_children_indices(parents) -> List[np.ndarray]:
    """Return all children indices of every branch.

    Example:
    ```
    parents = [-1, 0, 0]
    compute_children_indices(parents) -> [[1, 2], [], []]
    ```

    Args:
        parents: Parent index of every branch (e.g. `-1` for a branch without a
            parent). May be a Python list or an array.

    Returns:
        A list with one integer array per branch, holding the indices of that
        branch's children in increasing order.
    """
    # Convert first: comparing a plain list with an int (`[-1, 0, 0] == 0`) yields
    # a single `False` instead of an element-wise mask.
    parents = np.asarray(parents)
    return [np.where(parents == b)[0] for b in range(len(parents))]

compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)

Return the coupling conductance between two compartments.

Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

Units: radius in um; r_a in ohm cm; length_single_compartment in um; coupling_conds in S * um / cm / um^2 = S / cm / um, which is multiplied by 10**7 to obtain mS / cm^2.

Source code in jaxley/utils/cell_utils.py (lines 515–526)
def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):
    """Return the coupling conductance between two compartments.

    Formula from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`. The
    result is divided by `l1`, the length of the first compartment.

    Units:
        radius: um
        r_a: ohm cm
        length of a single compartment: um
        result: S * um / cm / um^2 = S / cm / um, scaled by 10**7 to mS / cm^2
    """
    # Split into numerator and denominator for readability; evaluation order is
    # unchanged.
    numerator = rad1 * rad2**2
    denominator = r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2
    # The factor 10**7 converts (S / cm / um) -> (mS / cm^2).
    return numerator / denominator / l1 * 10**7

compute_coupling_cond_branchpoint(rad, r_a, l)

Return the coupling conductance between one compartment and a comp with l=0.

From https://en.wikipedia.org/wiki/Compartmental_neuron_models

If one compartment has l=0.0 then the equations simplify.

$$R_\text{long} = \sum_i \frac{r_a \, L_i / 2}{A_i},$$

with the cross-section $A_i = \pi r_i^2$.

For a single compartment with $L > 0$, this turns into $R_\text{long} = \frac{r_a \, L / 2}{\pi r^2}$.

Then, $g_\text{long} = \frac{2 \pi r^2}{L \, r_a}$.

Then, the effective conductance is $g_\text{long}$ divided by the cylinder surface area $2 \pi r L$. So: $g = \frac{2 \pi r^2}{L \, r_a} \cdot \frac{1}{2 \pi r L} = \frac{r}{r_a L^2}$.

Source code in jaxley/utils/cell_utils.py (lines 529–549)
def compute_coupling_cond_branchpoint(rad, r_a, l):
    r"""Return the coupling conductance between one compartment and a comp with l=0.

    Based on https://en.wikipedia.org/wiki/Compartmental_neuron_models. When one
    of the two compartments has length zero, the general formula simplifies.

    In general, R_long = \sum_i r_a * L_i/2 / crosssection_i, where
    crosssection = pi * r**2.

    Only the compartment with L > 0 contributes, so
    R_long = r_a * L/2 / crosssection and hence g_long = crosssection * 2 / L / r_a.

    Dividing by the cylinder surface area 2 * pi * r * L gives the effective
    conductance:
    g = pi * r**2 * 2 / L / r_a / (2 * pi * r * L) = r / r_a / L**2
    """
    conductance = rad / r_a / l**2
    return conductance * 10**7  # Convert (S / cm / um) -> (mS / cm^2)

compute_impact_on_node(rad, r_a, l)

Compute the weight with which a compartment influences its node.

In order to satisfy Kirchhoff's current law, the current at a branch point must be proportional to the cross-section of the compartment. We only require proportionality here because the branch point equation reads: $g_1 (V_1 - V_b) + g_2 (V_2 - V_b) = 0$.

Because $R_\text{long} = \frac{r_a L / 2}{A}$ for a cross-section $A = \pi r^2$, we get $g_\text{long} = \frac{2 A}{L \, r_a} \propto \frac{r^2}{L \, r_a}$.

This equation can be multiplied by any constant.

Source code in jaxley/utils/cell_utils.py (lines 552–564)
def compute_impact_on_node(rad, r_a, l):
    r"""Compute the weight with which a compartment influences its node.

    Kirchhoff's current law at a branch point,
    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`,
    is unchanged when all `g_i` are multiplied by a common constant, so the
    weights only need to be proportional to the true conductances.

    With R_long = r_a * L/2 / crosssection, the conductance is
    g_long = crosssection * 2 / L / r_a \propto rad**2 / L / r_a,
    and the latter expression (without constant factors) is returned."""
    squared_radius = rad**2
    return squared_radius / r_a / l

compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)

Return (row, col) to build the sparse matrix defining the voltage eqs.

This is run at init, not during runtime.

Source code in jaxley/utils/cell_utils.py (lines 648–666)
def compute_morphology_indices_in_levels(
    num_branchpoints,
    child_belongs_to_branchpoint,
    par_inds,
    child_inds,
):
    """Return (row, col) to build the sparse matrix defining the voltage eqs.

    This is run at `init`, not during runtime.

    Returns:
        A dict with two index arrays (for 1-D inputs, of shape `(N, 2)`):
        - `"children"`: rows `(child_inds[i], child_belongs_to_branchpoint[i])`.
        - `"parents"`: rows `(par_inds[j], j)` for `j` in `range(num_branchpoints)`.
    """
    child_pairs = jnp.stack([child_inds, child_belongs_to_branchpoint])
    parent_pairs = jnp.stack([par_inds, jnp.arange(num_branchpoints)])
    return {
        "children": jnp.transpose(child_pairs),
        "parents": jnp.transpose(parent_pairs),
    }

convert_point_process_to_distributed(current, radius, length)

Convert current point process (nA) to distributed current (uA/cm2).

This function gets called for synapses and for external stimuli.

Parameters:

Name Type Description Default
current ndarray

Current in nA.

required
radius ndarray

Compartment radius in um.

required
length ndarray

Compartment length in um.

required
Return

Current in uA/cm2.

Source code in jaxley/utils/cell_utils.py (lines 613–630)
def convert_point_process_to_distributed(
    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray
) -> jnp.ndarray:
    """Convert current point process (nA) to distributed current (uA/cm2).

    This function gets called for synapses and for external stimuli. The current
    is divided by the lateral surface area `2 * pi * radius * length` of the
    compartment. The input `current` is not modified.

    Args:
        current: Current in `nA`.
        radius: Compartment radius in `um`.
        length: Compartment length in `um`.

    Return:
        Current in `uA/cm2`.
    """
    area = 2 * pi * radius * length  # um^2
    # Out-of-place division. `current /= area` would overwrite a NumPy input
    # array in place and raise for integer-typed arrays.
    current_density = current / area  # nA / um^2
    return current_density * 100_000  # Convert (nA / um^2) to (uA / cm^2)

equal_segments(branch_property, ncomp_per_branch)

Generates segments where some property is the same in each segment.

Parameters:

Name Type Description Default
branch_property list

List of values of the property in each branch. Should have len(branch_property) == num_branches.

required
Source code in jaxley/utils/cell_utils.py (lines 301–309)
def equal_segments(branch_property: list, ncomp_per_branch: int):
    """Generates segments where some property is the same in each segment.

    Args:
        branch_property: List of values of the property in each branch. Should have
            `len(branch_property) == num_branches`.
        ncomp_per_branch: Number of compartments every branch is split into.

    Returns:
        Array of shape `(num_branches, ncomp_per_branch)` in which row `b` repeats
        `branch_property[b]` for every compartment.
    """
    assert isinstance(branch_property, list), "branch_property must be a list."
    per_comp_rows = jnp.asarray([branch_property for _ in range(ncomp_per_branch)])
    return jnp.transpose(per_comp_rows)

get_num_neighbours(num_children, ncomp_per_branch, num_branches)

Number of neighbours of each compartment.

Source code in jaxley/utils/cell_utils.py (lines 469–482)
def get_num_neighbours(
    num_children: jnp.ndarray,
    ncomp_per_branch: int,
    num_branches: int,
):
    """
    Number of neighbours of each compartment.

    Args:
        num_children: Presumably the number of child branches of every branch
            (one entry per branch).
        ncomp_per_branch: Number of compartments per branch (same for all).
        num_branches: Number of branches.

    Returns:
        Float array of length `num_branches * ncomp_per_branch`. Every entry is 2,
        except entry `ncomp_per_branch - 1`, which is 1, and the first compartment
        of each branch `b` (index `b * ncomp_per_branch`), which is
        `num_children[b] + 1`. The latter wins if both rules hit the same entry.
    """
    first_comp_of_each_branch = jnp.arange(num_branches) * ncomp_per_branch
    counts = jnp.full((num_branches * ncomp_per_branch,), 2.0)
    # Order matters: the per-branch assignment must be applied last.
    counts = counts.at[ncomp_per_branch - 1].set(1.0)
    counts = counts.at[first_comp_of_each_branch].set(num_children + 1.0)
    return counts

group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)

Group values by whether they have the same integer and sum values within group.

This is used to construct the last diagonals at the branch points.

Written by ChatGPT.

Source code in jaxley/utils/cell_utils.py (lines 669–686)
def group_and_sum(
    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int
) -> jnp.ndarray:
    """Group values by whether they have the same integer and sum values within group.

    This is used to construct the last diagonals at the branch points. Every
    `values_to_sum[i]` is added to entry `inds_to_group_by[i]` of a zero-initialized
    array of length `num_branchpoints`. Groups that receive no value stay zero.

    Written by ChatGPT.
    """
    totals = jnp.zeros(num_branchpoints)
    # Without branch points (the case for branches and compartments) there is
    # nothing to scatter, and `.at[inds]` must not receive empty `inds`.
    if num_branchpoints <= 0:
        return totals
    return totals.at[inds_to_group_by].add(values_to_sum)

interpolate_xyzr(loc, coords)

Perform a linear interpolation between xyz-coordinates.

Parameters:

Name Type Description Default
loc float

The location in [0,1] along the branch.

required
coords ndarray

Array containing the reconstructed xyzr points of the branch.

required
Return

Interpolated xyz coordinate at loc, shape (3,).

Source code in jaxley/utils/cell_utils.py (lines 579–593)
def interpolate_xyzr(loc: float, coords: np.ndarray):
    """Perform a linear interpolation between xyz-coordinates.

    The points of `coords` are parametrized by their cumulative path length along
    the xyz polyline, normalized to [0, 1]. The interpolation at `loc` is delegated
    to `v_interp`.

    Args:
        loc: The location in [0,1] along the branch.
        coords: Array containing the reconstructed xyzr points of the branch. Only
            the first three columns are used to measure path length.

    Return:
        Interpolated coordinate at `loc` as returned by `v_interp` (documented as
        shape `(3,)`; NOTE(review): `coords` holds xyzr, so confirm whether the
        radius column is included).
    """
    xyz = coords[:, :3]
    segment_lengths = np.sqrt(np.sum(np.diff(xyz, axis=0) ** 2, axis=1))
    cumulative_lengths = np.insert(np.cumsum(segment_lengths), 0, 0)
    # Guard against a zero total length (all points identical) to avoid 0 / 0.
    total_length = np.maximum(1e-8, cumulative_lengths[-1])
    return v_interp(loc, cumulative_lengths / total_length, coords)

linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)

Generates segments where some property is linearly interpolated.

Parameters:

Name Type Description Default
initial_val float

The value at the tip of the soma.

required
endpoint_vals list

The value at the endpoints of each branch.

required
Source code in jaxley/utils/cell_utils.py (lines 312–335)
def linear_segments(
    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, ncomp_per_branch: int
):
    """Generates segments where some property is linearly interpolated.

    Args:
        initial_val: The value at the tip of the soma.
        endpoint_vals: The value at the endpoints of each branch.
        parents: Parent index of every branch. A parent index of `-1` selects
            `initial_val` as the start value of that branch.
        ncomp_per_branch: Number of compartments per branch.

    Returns:
        Array of shape `(num_branches, ncomp_per_branch)`. Within branch `b` the
        value goes linearly from `endpoint_vals[b]` (first compartment) to the
        value of its parent (last compartment).
    """
    num_branches = len(parents)
    # `initial_val` is appended last so that parent index -1 selects it.
    node_values = jnp.asarray(endpoint_vals + [initial_val])
    branch_inds = jnp.arange(num_branches)
    start = node_values[parents[branch_inds]]
    end = node_values[branch_inds]
    locs = jnp.linspace(1, 0, ncomp_per_branch)
    # Broadcast (num_branches, 1) against (1, ncomp_per_branch).
    return (end - start)[:, None] * locs[None, :] + start[:, None]

loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)

Return location corresponding to global compartment index.

Source code in jaxley/utils/cell_utils.py (lines 507–512)
def loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch):
    """Return location corresponding to global compartment index.

    The location is the compartment center expressed as a fraction of its branch:
    local compartment `k` of a branch with `n` compartments maps to `(k + 0.5) / n`.
    """
    # Presumably the global index of the branch's first compartment.
    branch_offset = cumsum_leading_zero(ncomp_per_branch)[global_branch_index]
    num_comps_in_branch = ncomp_per_branch[global_branch_index]
    local_index = global_comp_index - branch_offset
    return (0.5 + local_index) / num_comps_in_branch

local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)

Returns the local index of a comp given a loc [0, 1] and the index of a branch.

This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.

Parameters:

Name Type Description Default
global_branch_ind int

Index of the branch.

required
loc float

Location (in [0, 1]) along that branch.

required
ncomp_per_branch int

Number of segments of each branch.

required

Returns:

Type Description
int

The local index of the compartment.

Source code in jaxley/utils/cell_utils.py (lines 485–504)
def local_index_of_loc(
    loc: float, global_branch_ind: int, ncomp_per_branch: np.ndarray
) -> int:
    """Returns the local index of a comp given a loc [0, 1] and the index of a branch.

    This is used because we specify locations such as synapses as a value between 0 and
    1. We have to convert this onto a discrete segment here.

    Args:
        loc: Location (in [0, 1]) along that branch.
        global_branch_ind: Index of the branch.
        ncomp_per_branch: Number of compartments of each branch (array-like, indexed
            by branch).

    Returns:
        The local index of the compartment whose center is closest to `loc`.
    """
    ncomp = ncomp_per_branch[global_branch_ind]  # only for convenience.
    # Centers of the `ncomp` equally long compartments of this branch.
    comp_centers = np.linspace(0.5 / ncomp, 1 - 0.5 / ncomp, ncomp)
    # On a tie (loc exactly between two centers), `argmin` picks the lower index.
    return int(np.argmin(np.abs(comp_centers - loc)))

merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)

Build full list of which branches are solved in which iteration.

From the branching pattern of single cells, this “merges” them into a single ordering of branches.

Parameters:

Name Type Description Default
cumsum_num_branches List[int]

cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this should be a list containing [0, 10, 25, 30].

required
arrs List[List[ndarray]]

A list of a list of arrays that should be merged.

required
exclude_first bool

If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates “no parent”) entry should not be changed.

True

Returns:

Type Description
list of ndarray

A list of arrays which contain the branch indices that are computed at each level (i.e., iteration).

Source code in jaxley/utils/cell_utils.py (lines 338–386)
def merge_cells(
    cumsum_num_branches: List[int],
    cumsum_num_branchpoints: List[int],
    arrs: List[List[np.ndarray]],
    exclude_first: bool = True,
) -> List[np.ndarray]:
    """
    Build full list of which branches are solved in which iteration.

    From the branching pattern of single cells, this "merges" them into a single
    ordering of branches.

    Args:
        cumsum_num_branches: cumulative number of branches. E.g., for three cells with
            10, 15, and 5 branches respectively, this should be a list containing
            `[0, 10, 25, 30]`.
        cumsum_num_branchpoints: cumulative number of branchpoints, analogous to
            `cumsum_num_branches`.
        arrs: One list of per-level arrays for every cell that should be merged.
            Every array is offset by `[cumsum_num_branches[i],
            cumsum_num_branchpoints[i]]` (NOTE(review): this suggests rows of
            `(branch_index, branchpoint_index)` pairs — confirm with callers).
        exclude_first: If `True`, the first element of each list in `arrs` should
            remain unchanged (e.g. a `-1` "no parent" entry). This mode is not
            implemented and raises `NotImplementedError`.

    Returns:
        A list of arrays which contain the branch indices that are computed at each
        level (i.e., iteration). Level `i` concatenates the level-`i` arrays of all
        cells that have at least `i + 1` levels, in cell order. Empty `arrs` yields
        an empty list.

    Raises:
        NotImplementedError: If `exclude_first` is `True`.
    """
    if exclude_first:
        raise NotImplementedError(
            "`merge_cells(..., exclude_first=True)` is not implemented."
        )

    # Shift every cell's indices by the number of branches / branchpoints of all
    # preceding cells so that indices are global.
    shifted = []
    for i, levels in enumerate(arrs):
        offset = np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])
        shifted.append([level + offset for level in levels])

    max_len = max((len(levels) for levels in arrs), default=0)
    return [
        np.concatenate([levels[i] for levels in shifted if len(levels) > i])
        for i in range(max_len)
    ]

params_to_pstate(params, indices_set_by_trainables)

Make outputs get_parameters() conform with outputs of .data_set().

make_trainable() followed by params=get_parameters() does not return indices because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params)). Therefore, in jx.integrate, we run the function to add indices to the dict. The outputs of params_to_pstate are of the same shape as the outputs of .data_set().

Source code in jaxley/utils/cell_utils.py (lines 596–610)
def params_to_pstate(
    params: List[Dict[str, jnp.ndarray]],
    indices_set_by_trainables: List[jnp.ndarray],
):
    """Make outputs `get_parameters()` conform with outputs of `.data_set()`.

    `make_trainable()` followed by `params=get_parameters()` does not return indices
    because these indices would also be differentiated by `jax.grad` (as soon as
    the `params` are passed to `def simulate(params)`). Therefore, in `jx.integrate`,
    this function pairs every parameter dict with its indices. Each output entry is
    `{"key": name, "val": value, "indices": indices}`, the same format as the
    outputs of `.data_set()`. Only the first item of every parameter dict is used."""
    pstate = []
    for param, indices in zip(params, indices_set_by_trainables):
        name, value = list(param.items())[0]
        pstate.append({"key": name, "val": value, "indices": indices})
    return pstate

query_channel_states_and_params(d, keys, idcs)

Get dict with subset of keys and values from d.

This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.

states = {'eCa': Array([0., 0., nan])}

will be restricted to states = {'eCa': Array([0., 0.])}

Only loops over necessary keys, as opposed to looping over d.items().

Source code in jaxley/utils/cell_utils.py (lines 689–701)
def query_channel_states_and_params(d, keys, idcs):
    """Get dict with subset of keys and values from d.

    Restricts a dict in which every item contains __all__ states to the given
    `keys`, and every remaining value to the positions `idcs` relevant for the
    channel. E.g.

    ```states = {'eCa': Array([ 0.,  0., nan])}```

    will be
    ```states = {'eCa': Array([ 0.,  0.])}```

    Only loops over necessary keys, as opposed to looping over `d.items()`."""
    return {key: d.get(key)[idcs] for key in keys}

remap_to_consecutive(arr)

Maps an array of integers to an array of consecutive integers.

E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]

Source code in jaxley/utils/cell_utils.py (lines 567–573)
def remap_to_consecutive(arr):
    """Maps an array of integers to an array of consecutive integers.

    Every value is replaced by its rank among the sorted unique values of `arr`.

    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`
    """
    unique_and_inverse = jnp.unique(arr, return_inverse=True)
    return unique_and_inverse[1]

compute_rotation_matrix(axis, angle)

Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.

Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.

Parameters:

Name Type Description Default
axis ndarray

The axis of rotation.

required
angle float

The angle of rotation in radians.

required

Returns:

Type Description
ndarray

A 3x3 rotation matrix.

Source code in jaxley/utils/plot_utils.py (lines 71–97)
def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray:
    """
    Return the rotation matrix associated with counterclockwise rotation about
    the given axis by the given angle.

    Can be used to rotate a coordinate vector by multiplying it with the rotation
    matrix. The matrix is assembled from Euler-Rodrigues parameters of the
    normalized axis.

    Args:
        axis: The axis of rotation (need not be normalized).
        angle: The angle of rotation in radians.

    Returns:
        A 3x3 rotation matrix.
    """
    unit_axis = axis / np.sqrt(np.dot(axis, axis))
    half_angle = angle / 2.0
    a = np.cos(half_angle)
    b, c, d = -unit_axis * np.sin(half_angle)

    a2, b2, c2, d2 = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d

    first_row = [a2 + b2 - c2 - d2, 2 * (bc + ad), 2 * (bd - ac)]
    second_row = [2 * (bc - ad), a2 + c2 - b2 - d2, 2 * (cd + ab)]
    third_row = [2 * (bd + ac), 2 * (cd - ab), a2 + d2 - b2 - c2]
    return np.array([first_row, second_row, third_row])

create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)

Generates mesh points for a cone frustum, with optional domes at either end.

This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).

Parameters:

Name Type Description Default
length float

The length of the frustum.

required
radius_bottom float

The radius of the bottom of the frustum.

required
radius_top float

The radius of the top of the frustum.

required
bottom_dome bool

If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom.

False
top_dome bool

If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top.

False
resolution int

Defines the resolution of the mesh. If too low (typically <10), can result in errors. Lower values give a simpler mesh, which is useful for plotting.

100

Returns:

Type Description
ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py (lines 100–178)
def create_cone_frustum_mesh(
    length: float,
    radius_bottom: float,
    radius_top: float,
    bottom_dome: bool = False,
    top_dome: bool = False,
    resolution: int = 100,
) -> ndarray:
    """Generates mesh points for a cone frustum, with optional domes at either end.

    This is used to render the traced morphology in 3D (and to project it to 2D)
    as part of `plot_morph`. Sections between two traced coordinates with two
    different radii can be represented by a cone frustum. Additionally, the ends
    of the frustum can be capped with hemispheres to ensure that two neighbouring
    frustums are connected smoothly (like ball joints).

    Args:
        length: The length of the frustum.
        radius_bottom: The radius of the bottom of the frustum.
        radius_top: The radius of the top of the frustum.
        bottom_dome: If True, a dome is added to the bottom of the frustum.
            The dome is a hemisphere with radius `radius_bottom`.
        top_dome: If True, a dome is added to the top of the frustum.
            The dome is a hemisphere with radius `radius_top`.
        resolution: defines the resolution of the mesh.
            If too low (typically <10), can result in errors.
            Lower values give a simpler mesh, which is useful for plotting.

    Returns:
        An array of mesh points of shape `(3, resolution, resolution)` holding the
        x, y and z coordinates. The shape extends along z from 0 to the total
        height (`length` plus the radii of any domes).
    """

    def _safe_arccos(arg: ndarray) -> ndarray:
        # Snap values numerically equal to +-1 onto +-1 and clip the rest into
        # [-1, 1] so that rounding errors cannot produce NaNs in `arccos`.
        arg = np.where(np.isclose(arg, 1, atol=1e-6, rtol=1e-6), 1.0, arg)
        arg = np.where(np.isclose(arg, -1, atol=1e-6, rtol=1e-6), -1.0, arg)
        return np.arccos(np.clip(arg, -1.0, 1.0))

    t = np.linspace(0, 2 * np.pi, resolution)

    # Determine the total height including domes
    total_height = length
    total_height += radius_bottom if bottom_dome else 0
    total_height += radius_top if top_dome else 0

    z = np.linspace(0, total_height, resolution)
    t_grid, z_coords = np.meshgrid(t, z)
    r_coords = np.zeros_like(t_grid)

    # Bottom hemisphere: radius grows from 0 at z=0 to `radius_bottom` at z=radius.
    if bottom_dome:
        dome_mask = z_coords < radius_bottom
        phi = _safe_arccos(1 - z_coords[dome_mask] / radius_bottom)
        r_coords[dome_mask] = radius_bottom * np.sin(phi)

    # Frustum: radius interpolated linearly from bottom to top.
    frustum_start = radius_bottom if bottom_dome else 0
    frustum_end = total_height - (radius_top if top_dome else 0)
    frustum_mask = (z_coords >= frustum_start) & (z_coords <= frustum_end)
    z_frustum = z_coords[frustum_mask] - frustum_start
    # A zero-length frustum would give 0 / 0; use the bottom radius there instead.
    if length != 0:
        frac = z_frustum / length
    else:
        frac = np.zeros_like(z_frustum)
    r_coords[frustum_mask] = radius_bottom + (radius_top - radius_bottom) * frac

    # Top hemisphere: radius shrinks from `radius_top` to 0 at the very top.
    if top_dome:
        dome_start = total_height - radius_top
        dome_mask = z_coords > dome_start
        phi = _safe_arccos((z_coords[dome_mask] - dome_start) / radius_top)
        r_coords[dome_mask] = radius_top * np.sin(phi)

    x_coords = r_coords * np.cos(t_grid)
    y_coords = r_coords * np.sin(t_grid)

    return np.stack([x_coords, y_coords, z_coords])

create_cylinder_mesh(length, radius, resolution=100)

Generates mesh points for a cylinder.

This is used to render cylindrical compartments in 3D (and to project it to 2D) as part of plot_comps.

Parameters:

Name Type Description Default
length float

The length of the cylinder.

required
radius float

The radius of the cylinder.

required
resolution int

Defines the resolution of the mesh. If too low (typically <10), can result in errors. Lower values give a simpler mesh, which is useful for plotting.

100

Returns:

Type Description
ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py (lines 181–206)
def create_cylinder_mesh(
    length: float, radius: float, resolution: int = 100
) -> ndarray:
    """Generates mesh points for a cylinder.

    The cylinder is aligned with the z-axis and centered at the origin, i.e. it
    spans z in [-length / 2, length / 2]. It is used to render cylindrical
    compartments in 3D (and to project them to 2D) as part of `plot_comps`.

    Args:
        length: The length of the cylinder.
        radius: The radius of the cylinder.
        resolution: Defines the resolution of the mesh, i.e. the number of
            samples along the circumference and along the length. If too low
            (typically <10), it can result in errors. A lower resolution gives
            a simpler mesh for plotting.

    Returns:
        An array of mesh points of shape `(3, resolution, resolution)` holding
        the x, y and z coordinate grids.
    """
    half_length = length / 2
    angle_grid, height_grid = np.meshgrid(
        np.linspace(0, 2 * np.pi, resolution),
        np.linspace(-half_length, half_length, resolution),
    )
    return np.stack(
        [radius * np.cos(angle_grid), radius * np.sin(angle_grid), height_grid]
    )

create_sphere_mesh(radius, resolution=100)

Generates mesh points for a sphere.

This is used to render spherical compartments in 3D (and to project it to 2D) as part of plot_comps.

Parameters:

Name Type Description Default
radius float

The radius of the sphere.

required
resolution int

Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower resolution gives a simpler mesh for plotting.

100

Returns:

Type Description
ndarray

An array of mesh points.

Source code in jaxley/utils/plot_utils.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:
    """Generates mesh points for a sphere centered at the origin.

    This is used to render spherical compartments in 3D (and to project them to
    2D) as part of `plot_comps`.

    Args:
        radius: The radius of the sphere.
        resolution: Defines the resolution of the mesh, i.e. the number of
            samples of the polar and of the azimuthal angle. If too low
            (typically <10), it can result in errors. A lower resolution gives
            a simpler mesh for plotting.

    Returns:
        An array of mesh points of shape `(3, resolution, resolution)` holding
        the x, y and z coordinate grids.
    """
    polar_grid, azimuth_grid = np.meshgrid(
        np.linspace(0, np.pi, resolution),
        np.linspace(0, 2 * np.pi, resolution),
    )

    # Radius of the horizontal circle at each polar angle.
    ring_radius = radius * np.sin(polar_grid)
    return np.stack(
        [
            ring_radius * np.cos(azimuth_grid),
            ring_radius * np.sin(azimuth_grid),
            radius * np.cos(polar_grid),
        ]
    )

extract_outline(points)

Get the outline of a 2D/3D shape.

Extracts the subset of points which form the convex hull, i.e. the outline of the input points.

Parameters:

Name Type Description Default
points ndarray

An array of points / coordinates.

required

Returns:

Type Description
ndarray

An array of points which form the convex hull.

Source code in jaxley/utils/plot_utils.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def extract_outline(points: ndarray) -> ndarray:
    """Get the outline of a 2D/3D shape.

    Selects the subset of `points` that are vertices of their convex hull, i.e.
    the outline of the input points.

    Args:
        points: An array of points / coordinates, one point per row.

    Returns:
        The rows of `points` which form the convex hull, in the order reported
        by `ConvexHull.vertices`.
    """
    return points[ConvexHull(points).vertices]

plot_comps(module_or_view, dims=(0, 1), col='k', ax=None, comp_plot_kwargs={}, true_comp_length=True, resolution=100)

Plot compartmentalized neural morphology.

Plots the projection of the cylindrical compartments.

Parameters:

Name Type Description Default
module_or_view Union[Module, View]

The module or view to plot.

required
dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

(0, 1)
col str

The color for all compartments

'k'
ax Optional[Axes]

The matplotlib axis to plot on.

None
comp_plot_kwargs Dict

The plot kwargs for plt.fill.

{}
true_comp_length bool

If True, the length of the compartment is used, i.e. the length of the traced neurite. This means for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end point of the neurite. This can lead to overlapping and misaligned cylinders. Setting this to False will use the straight-line distance instead for nicer plots.

True
resolution int

Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower resolution gives a simpler mesh for plotting.

100

Returns:

Type Description
Axes

Plot of the compartmentalized morphology.

Source code in jaxley/utils/plot_utils.py
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def plot_comps(
    module_or_view: Union["jx.Module", "jx.View"],
    dims: Tuple[int] = (0, 1),
    col: str = "k",
    ax: Optional[Axes] = None,
    comp_plot_kwargs: Optional[Dict] = None,
    true_comp_length: bool = True,
    resolution: int = 100,
) -> Axes:
    """Plot compartmentalized neural morphology.

    Plots the projection of the cylindrical compartments. Branches that consist
    of a single traced point are drawn as a sphere (3D) or a circle (2D).

    Args:
        module_or_view: The module or view to plot.
        dims: The dimensions to plot / to project the cylinder onto,
            i.e. [0,1] xy-plane or [0,1,2] for 3D.
        col: The color for all compartments.
        ax: The matplotlib axis to plot on. If None, a new figure is created
            (with a 3D axis if three dims are given).
        comp_plot_kwargs: The plot kwargs for plt.fill. They are also forwarded
            to the sphere / circle of single-point branches. Defaults to no
            extra kwargs.
        true_comp_length: If True, the length of the compartment is used, i.e. the
            length of the traced neurite. This means for zig-zagging neurites the
            cylinders will be longer than the straight-line distance between the
            start and end point of the neurite. This can lead to overlapping and
            misaligned cylinders. Setting this to False will use the straight-line
            distance instead for nicer plots.
        resolution: Defines the resolution of the mesh.
            If too low (typically <10), it can result in errors.
            A lower resolution gives a simpler mesh for plotting.

    Returns:
        Plot of the compartmentalized morphology.
    """
    # `None` instead of a `{}` default, which would be shared across calls.
    if comp_plot_kwargs is None:
        comp_plot_kwargs = {}

    if ax is None:
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d")

    assert not np.any(
        np.isnan(module_or_view.xyzr[0][:, :3])
    ), "missing xyz coordinates."
    if "x" not in module_or_view.nodes.columns:
        module_or_view.compute_compartment_centers()

    dims_arr = np.array(dims)
    for idx, xyzr in zip(module_or_view._branches_in_view, module_or_view.xyzr):
        locs = xyzr[:, :3]
        if locs.shape[0] == 1:  # assume spherical comp
            # Scalar radius; `xyzr[:, -1]` would be a length-1 array.
            radius = xyzr[0, -1]
            center = xyzr[0, :3]
            if len(dims) == 3:
                xyz = create_sphere_mesh(radius, resolution)
                ax = plot_mesh(
                    xyz,
                    np.array([0, 0, 1]),
                    center,
                    dims_arr,
                    ax,
                    color=col,
                    **comp_plot_kwargs,
                )
            else:
                ax.add_artist(
                    plt.Circle(locs[0, dims], radius, color=col, **comp_plot_kwargs)
                )
        else:
            # Cumulative path length along the traced points, then the points at
            # `ncomp + 1` equally spaced path lengths (interpolated via
            # `v_interp`). Consecutive points give each compartment's axis.
            lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1))
            lens = np.cumsum([0] + lens.tolist())
            comp_ends = v_interp(
                np.linspace(0, lens[-1], module_or_view.ncomp + 1), lens, locs
            ).T
            axes = np.diff(comp_ends, axis=0)
            cylinder_lens = np.sqrt(np.sum(axes**2, axis=1))

            branch_df = module_or_view.nodes[
                module_or_view.nodes["global_branch_index"] == idx
            ]
            for l, axis, (_, comp) in zip(cylinder_lens, axes, branch_df.iterrows()):
                center = comp[["x", "y", "z"]]
                radius = comp["radius"]
                length = comp["length"] if true_comp_length else l
                xyz = create_cylinder_mesh(length, radius, resolution)
                ax = plot_mesh(
                    xyz,
                    axis,
                    center,
                    dims_arr,
                    ax,
                    color=col,
                    **comp_plot_kwargs,
                )
    return ax

plot_graph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})

Plot morphology.

Parameters:

Name Type Description Default
xyzr ndarray

The coordinates of the morphology.

required
dims Tuple[int]

Which dimensions to plot, as 0-based indices: 0=x, 1=y, 2=z coordinate. Must be a tuple of two or three of them.

(0, 1)
col str

The color for all branches.

'k'
ax Optional[Axes]

The matplotlib axis to plot on.

None
type str

Either line or scatter.

'line'
morph_plot_kwargs Dict

The plot kwargs for plt.plot or plt.scatter.

{}
Source code in jaxley/utils/plot_utils.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def plot_graph(
    xyzr: ndarray,
    dims: Tuple[int] = (0, 1),
    col: str = "k",
    ax: Optional[Axes] = None,
    type: str = "line",
    morph_plot_kwargs: Optional[Dict] = None,
) -> Axes:
    """Plot morphology.

    Every branch is drawn from its traced points, either as a connected line or
    as a scatter of points.

    Args:
        xyzr: The coordinates of the morphology. It is iterated branch by branch
            and each branch array is indexed as ``branch[:, dims]``.
        dims: Which dimensions to plot, as 0-based column indices: 0=x, 1=y,
            2=z coordinate. Must be a tuple of two or three of them.
        col: The color for all branches.
        ax: The matplotlib axis to plot on. If None, a new figure is created
            (with a 3D axis if three dims are given).
        type: Either `line` or `scatter`. Matched case-insensitively as a
            substring; `line` takes precedence.
        morph_plot_kwargs: The plot kwargs for plt.plot or plt.scatter.
            Defaults to no extra kwargs.

    Returns:
        The axis the morphology was plotted on.

    Raises:
        NotImplementedError: If `type` contains neither `line` nor `scatter`.
    """
    # `None` instead of a `{}` default, which would be shared across calls.
    if morph_plot_kwargs is None:
        morph_plot_kwargs = {}

    if ax is None:
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d")

    # Resolve the drawing method once instead of per branch. (`type` shadows the
    # builtin but is part of the public signature, so it is kept.)
    plot_type = type.lower()
    if "line" in plot_type:
        draw = ax.plot
    elif "scatter" in plot_type:
        draw = ax.scatter
    else:
        raise NotImplementedError(
            f"Unknown plot type {type!r}; expected 'line' or 'scatter'."
        )

    for coords_of_branch in xyzr:
        points = coords_of_branch[:, dims].T
        draw(*points, color=col, **morph_plot_kwargs)

    return ax

plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)

Plot the 2D projection of a volume mesh on a cardinal plane.

Plot a volume mesh (e.g. a cylinder) that is oriented in 3D space, or its projection onto a cardinal plane. - rotate the mesh to orient it lengthwise along a given orientation vector - project it onto the plotted dimensions and move it to its center - in 2D: compute the outline of the projected mesh and fill the area inside the outline; in 3D: draw the mesh as a surface

Parameters:

Name Type Description Default
mesh_points ndarray

coordinates of the xyz mesh that define the volume

required
orientation ndarray

orientation vector. The cylinder will be oriented along this vector.

required
center ndarray

The x,y,z coordinates of the center of the cylinder.

required
dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] for the xy-plane or [0,1,2] for 3D.

required
ax Axes

The matplotlib axis to plot on.

None

Returns:

Type Description
Axes

Plot of the cylinder projection.

Source code in jaxley/utils/plot_utils.py
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def plot_mesh(
    mesh_points: ndarray,
    orientation: ndarray,
    center: ndarray,
    dims: Tuple[int],
    ax: Optional[Axes] = None,
    **kwargs,
) -> Axes:
    """Plot a volume mesh after orienting and placing it in 3D space.

    The mesh is expected to be built lengthwise along the z-axis (e.g. by
    `create_cylinder_mesh`). The steps are:
    - rotate the mesh so that its z-axis points along `orientation`
    - project it onto the plotted dimensions `dims` and move it to `center`
    - 2D: compute the outline (convex hull) of the projected mesh and fill it
    - 3D: draw the mesh as a surface

    Args:
        mesh_points: Coordinates of the xyz mesh that define the volume, given
            as three equally shaped coordinate grids (x, y, z).
        orientation: Orientation vector. The mesh's z-axis will be aligned with
            this vector. A zero vector leaves the mesh unrotated.
        center: The x,y,z coordinates the mesh is moved to.
        dims: The dimensions to plot / to project the mesh onto,
            i.e. [0,1] xy-plane or [0,1,2] for 3D.
        ax: The matplotlib axis to plot on. If None, a new figure is created.
        **kwargs: Forwarded to `ax.fill` (2D) or `ax.plot_surface` (3D).

    Returns:
        The axis the mesh was plotted on.
    """
    # Array form so that `dims` is always used as an index list (a plain tuple
    # would select a single element instead of rows).
    dims = np.asarray(dims)

    if ax is None:
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d")

    # Rotation matrix that maps the z-axis onto the (normalized) orientation.
    z_axis = np.array([0, 0, 1])
    orientation = np.asarray(orientation, dtype=float)
    orientation_norm = np.linalg.norm(orientation)
    if orientation_norm == 0:
        # Degenerate (zero-length) orientation: nothing to align to, so keep
        # the mesh as it is instead of producing NaNs.
        rotation_matrix = np.eye(3)
    else:
        orientation = orientation / orientation_norm
        # Clip guards arccos against rounding slightly outside [-1, 1].
        cos_angle = np.clip(np.dot(z_axis, orientation), -1.0, 1.0)
        rotation_axis = np.cross(z_axis, orientation)
        if np.allclose(rotation_axis, 0):
            # (Anti-)parallel to z: no rotation axis can be derived from the
            # cross product. Parallel needs no rotation; anti-parallel needs a
            # half turn (here about the x-axis) so asymmetric meshes point the
            # right way.
            if cos_angle > 0:
                rotation_matrix = np.eye(3)
            else:
                rotation_matrix = np.diag([1.0, -1.0, -1.0])
        else:
            rotation_matrix = compute_rotation_matrix(
                rotation_axis, np.arccos(cos_angle)
            )

    # Rotate all mesh points at once; shape (3, n_points).
    x_mesh, y_mesh, z_mesh = mesh_points
    flat_points = np.array([x_mesh.flatten(), y_mesh.flatten(), z_mesh.flatten()])
    rotated_points = rotation_matrix @ flat_points

    # Project onto the plotted dimensions and move to the center.
    center = np.asarray(center, dtype=float)
    projected_points = rotated_points[dims] + center[dims, np.newaxis]

    if len(dims) < 3:
        # Fill the outline (convex hull) of the projected mesh.
        mesh_outline = extract_outline(projected_points.T).T
        ax.fill(*mesh_outline, **kwargs)
    else:
        # Restore the grid layout for a 3D surface plot.
        ax.plot_surface(
            *projected_points.reshape(3, *np.shape(x_mesh)), **kwargs
        )
    return ax

plot_morph(module_or_view, dims=(0, 1), col='k', ax=None, resolution=100, morph_plot_kwargs={})

Plot the detailed morphology.

Plots the morphology as it was traced. That means at every point that was traced a disc of radius r is plotted. The outlines of the discs are then connected to form the morphology. This means every trace segment can be represented by a cone frustum. To prevent breaks in the morphology, each segment is connected with a ball joint.

Parameters:

Name Type Description Default
module_or_view Union[Module, View]

The module or view to plot.

required
dims Tuple[int]

The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

(0, 1)
col str

The color for all branches

'k'
ax Optional[Axes]

The matplotlib axis to plot on.

None
morph_plot_kwargs Dict

The plot kwargs for plt.fill.

{}
resolution int

Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower resolution gives a simpler mesh for plotting.

100

Returns:

Type Description
Axes

Plot of the detailed morphology.

Source code in jaxley/utils/plot_utils.py
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def plot_morph(
    module_or_view: Union["jx.Module", "jx.View"],
    dims: Tuple[int] = (0, 1),
    col: str = "k",
    ax: Optional[Axes] = None,
    resolution: int = 100,
    morph_plot_kwargs: Optional[Dict] = None,
) -> Axes:
    """Plot the detailed morphology.

    Plots the morphology as it was traced. That means at every point that was
    traced a disc of radius `r` is plotted. The outlines of the discs are then
    connected to form the morphology. This means every trace segment can be
    represented by a cone frustum. To prevent breaks in the morphology, each
    segment is connected with a ball joint.

    Args:
        module_or_view: The module or view to plot.
        dims: The dimensions to plot / to project the cylinder onto,
            i.e. [0,1] xy-plane or [0,1,2] for 3D.
        col: The color for all branches.
        ax: The matplotlib axis to plot on. If None, a new figure is created
            (with a 3D axis if three dims are given).
        resolution: Defines the resolution of the mesh.
            If too low (typically <10), it can result in errors.
            A lower resolution gives a simpler mesh for plotting.
        morph_plot_kwargs: The plot kwargs for plt.fill. Defaults to no extra
            kwargs.

    Returns:
        Plot of the detailed morphology.
    """
    # `None` instead of a `{}` default, which would be shared across calls.
    if morph_plot_kwargs is None:
        morph_plot_kwargs = {}

    if ax is None:
        fig = plt.figure(figsize=(3, 3))
        ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d")
    if len(dims) == 3:
        warn(
            "rendering large morphologies in 3D can take a while. Consider projecting to 2D instead."
        )

    assert not np.any(
        np.isnan(module_or_view.xyzr[0][:, :3])
    ), "missing xyz coordinates."

    dims_arr = np.array(dims)
    for xyzr in module_or_view.xyzr:
        if len(xyzr) > 1:
            # One domed frustum per pair of consecutive traced points. `xyzr1` is
            # the later point and `xyzr2` the earlier one: the mesh is moved by
            # `xyzr1`'s coordinates and oriented along `xyzr2 - xyzr1`.
            for xyzr1, xyzr2 in zip(xyzr[1:, :], xyzr[:-1, :]):
                dxyz = xyzr2[:3] - xyzr1[:3]
                length = np.sqrt(np.sum(dxyz**2))
                points = create_cone_frustum_mesh(
                    length,
                    xyzr1[-1],
                    xyzr2[-1],
                    bottom_dome=True,
                    top_dome=True,
                    resolution=resolution,
                )
                plot_mesh(
                    points,
                    dxyz,
                    xyzr1[:3],
                    dims_arr,
                    color=col,
                    ax=ax,
                    **morph_plot_kwargs,
                )
        else:
            # Single traced point: a zero-length frustum with two domes. The
            # radius is taken as a scalar, like in the multi-point branch.
            radius = xyzr[0, -1]
            points = create_cone_frustum_mesh(
                0,
                radius,
                radius,
                bottom_dome=True,
                top_dome=True,
                resolution=resolution,
            )
            plot_mesh(
                points,
                np.ones(3),
                xyzr[0, :3],
                dims=dims_arr,
                color=col,
                ax=ax,
                **morph_plot_kwargs,
            )

    return ax

nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)

A version of lax.scan that supports recursive gradient checkpointing.

Code taken from: https://github.com/google/jax/issues/2139

The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

Parameters:

Name Type Description Default
f Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]

function to scan over.

required
init Carry

initial value.

required
xs Dict[str, ndarray]

scanned over values.

required
length Optional[int]

Length of the leading axis of all arrays in xs (optional; if given, it must equal the product of nested_lengths).

None
nested_lengths Sequence[int]

required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

required
scan_fn

function matching the API of lax.scan

scan
checkpoint_fn Callable[[Func], Func]

function matching the API of jax.checkpoint.

checkpoint
Source code in jaxley/utils/jax_utils.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def nested_checkpoint_scan(
    f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],
    init: Carry,
    xs: Dict[str, jnp.ndarray],
    length: Optional[int] = None,
    *,
    nested_lengths: Sequence[int],
    scan_fn=jax.lax.scan,
    checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,
):
    """A version of lax.scan that supports recursive gradient checkpointing.

    Code taken from: https://github.com/google/jax/issues/2139

    The interface matches `lax.scan`, except for the additional, required
    keyword-only argument `nested_lengths`.

    Gradient calculations need O(max(nested_lengths)) memory instead of the
    O(prod(nested_lengths)) of an unnested scan. The price is that the forward
    pass is re-evaluated `len(nested_lengths) - 1` times. With a single element
    in `nested_lengths` this is just `lax.scan`.

    Args:
        f: function to scan over.
        init: initial value.
        xs: scanned over values.
        length: Length of the leading axis of all arrays in `xs`. Optional.
        nested_lengths: required list of lengths to scan over for each level of
            checkpointing. The product of nested_lengths must match length (if
            provided) and the size of the leading axis for all arrays in ``xs``.
        scan_fn: function matching the API of lax.scan
        checkpoint_fn: function matching the API of jax.checkpoint.

    Raises:
        ValueError: If `length` is given and differs from the product of
            `nested_lengths`.
    """
    if length is not None:
        if length != math.prod(nested_lengths):
            raise ValueError(f"inconsistent {length=} and {nested_lengths=}")

    def _split_leading_axis(leaf):
        # Replace the single leading axis by one axis per nesting level; all
        # trailing axes are kept unchanged.
        arr = jnp.asarray(leaf)
        return arr.reshape((*nested_lengths, *arr.shape[1:]))

    nested_xs = jax.tree_util.tree_map(_split_leading_axis, xs)
    return _inner_nested_scan(
        f, init, nested_xs, nested_lengths, scan_fn, checkpoint_fn
    )

gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)

Compute current at the post synapse.

It sums the synaptic currents that come into each compartment. It returns two arrays (the summed voltage terms and the summed constant terms), each with as many elements as there are compartments.

Source code in jaxley/utils/syn_utils.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def gather_synapes(
    number_of_compartments: jnp.ndarray,
    post_syn_comp_inds: np.ndarray,
    current_each_synapse_voltage_term: jnp.ndarray,
    current_each_synapse_constant_term: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute current at the post synapse.

    Sums the per-synapse current terms into the compartment each synapse
    targets; synapses that share a post-synaptic compartment are added up.

    Args:
        number_of_compartments: Number of compartments, i.e. the length of both
            returned arrays.
        post_syn_comp_inds: For every synapse, the index of its post-synaptic
            compartment.
        current_each_synapse_voltage_term: Voltage term of every synapse.
        current_each_synapse_constant_term: Constant term of every synapse.

    Returns:
        Tuple `(voltage_terms, constant_terms)`, each holding one summed value
        per compartment.
    """
    # Scatter scalar updates along axis 0 of a 1D operand.
    scatter_dims = ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    target_inds = post_syn_comp_inds[:, None]

    def _sum_into_compartments(per_synapse_values):
        return scatter_add(
            jnp.zeros((number_of_compartments,)),
            target_inds,
            per_synapse_values,
            scatter_dims,
        )

    voltage_terms = _sum_into_compartments(current_each_synapse_voltage_term)
    constant_terms = _sum_into_compartments(current_each_synapse_constant_term)
    return voltage_terms, constant_terms