Simulation

add_clamps(externals, external_inds, data_clamps=None)

Adds clamps to the external inputs.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
externals | Dict | Current external inputs. | required
external_inds | Dict | Current external indices. | required
data_clamps | Optional[Tuple[str, ndarray, DataFrame]] | Additional data clamps. | None

Returns:

Type | Description
--- | ---
Tuple[Dict, Dict] | Updated external inputs and indices.

Source code in jaxley/integrate.py
def add_clamps(
    externals: Dict,
    external_inds: Dict,
    data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,
) -> Tuple[Dict, Dict]:
    """Adds clamps to the external inputs.

    Args:
        externals (Dict): Current external inputs.
        external_inds (Dict): Current external indices.
        data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.

    Returns:
        Tuple[Dict, Dict]: Updated external inputs and indices.
    """
    # If a clamp is inserted, add it to the external inputs.
    if data_clamps is not None:
        state_name, clamps, inds = data_clamps
        if state_name in externals.keys():
            externals[state_name] = jnp.concatenate([externals[state_name], clamps])
            external_inds[state_name] = jnp.concatenate(
                [external_inds[state_name], inds.index.to_numpy()]
            )
        else:
            externals[state_name] = clamps
            external_inds[state_name] = inds.index.to_numpy()

    return externals, external_inds
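
A minimal usage sketch (illustrative; the compartment indices and clamp values are made up, and the clamp array is assumed to have shape (num_clamps, num_timesteps), matching the transpose that `integrate` applies later):

import jax.numpy as jnp
import pandas as pd
from jaxley.integrate import add_clamps

# Clamp the state "v" at two hypothetical compartments for 100 steps.
clamp_values = jnp.full((2, 100), -70.0)
clamp_rows = pd.DataFrame(index=[3, 7])  # the index holds compartment indices

externals, external_inds = add_clamps(
    {}, {}, data_clamps=("v", clamp_values, clamp_rows)
)
print(external_inds["v"])  # -> [3 7]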

add_stimuli(externals, external_inds, data_stimuli=None)

Extends the external inputs with the stimuli.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
externals | Dict | Current external inputs. | required
external_inds | Dict | Current external indices. | required
data_stimuli | Optional[Tuple[ndarray, DataFrame]] | Additional data stimuli. | None

Returns:

Type | Description
--- | ---
Tuple[Dict, Dict] | Updated external inputs and indices.

Source code in jaxley/integrate.py
def add_stimuli(
    externals: Dict,
    external_inds: Dict,
    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
) -> Tuple[Dict, Dict]:
    """Extends the external inputs with the stimuli.

    Args:
        externals (Dict): Current external inputs.
        external_inds (Dict): Current external indices.
        data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.

    Returns:
        Tuple[Dict, Dict]: Updated external inputs and indices.
    """
    # If stimulus is inserted, add it to the external inputs.
    if "i" in externals.keys() or data_stimuli is not None:
        if "i" in externals.keys():
            if data_stimuli is not None:
                externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
                external_inds["i"] = jnp.concatenate(
                    [external_inds["i"], data_stimuli[2].index.to_numpy()]
                )
        else:
            externals["i"] = data_stimuli[1]
            external_inds["i"] = data_stimuli[2].index.to_numpy()

    return externals, external_inds
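
A minimal usage sketch (illustrative). Note that although the annotation declares a two-tuple, the function body indexes `data_stimuli[1]` and `data_stimuli[2]`, so the first element of the tuple is never accessed here:

import jax.numpy as jnp
import pandas as pd
from jaxley.integrate import add_stimuli

currents = jnp.zeros((1, 100))       # shape (num_stimuli, num_timesteps)
stim_rows = pd.DataFrame(index=[0])  # the index holds compartment indices

# The first tuple element is not accessed by `add_stimuli` itself.
externals, external_inds = add_stimuli({}, {}, (None, currents, stim_rows))
print(external_inds["i"])  # -> [0]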

build_init_and_step_fn(module, voltage_solver='jaxley.stone', solver='bwd_euler')

This function returns the init_fn and step_fn, which initialize the parameters and states of the neuron model and then step through the model.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
module | Module | A Module object, e.g. a cell. | required
voltage_solver | str | Voltage solver used in step. | 'jaxley.stone'
solver | str | ODE solver. | 'bwd_euler'

Returns:

Type | Description
--- | ---
Tuple[Callable, Callable] | init_fn, step_fn: functions that initialize the state and parameters, and perform a single integration step, respectively.

Source code in jaxley/integrate.py
def build_init_and_step_fn(
    module: Module,
    voltage_solver: str = "jaxley.stone",
    solver: str = "bwd_euler",
) -> Tuple[Callable, Callable]:
    """This function returns the `init_fn` and `step_fn` which initialize the
    parameters and states of the neuron model and then step through the model

    Args:
        module (Module): A `Module` object, e.g. a cell.
        voltage_solver (str, optional): Voltage solver used in step. Defaults to "jaxley.stone".
        solver (str, optional): ODE solver. Defaults to "bwd_euler".

    Returns:
        init_fn, step_fn: Functions that initialize the state and parameters, and perform
            a single integration step, respectively.
    """
    # Initialize the external inputs and their indices.
    external_inds = module.external_inds.copy()

    def init_fn(
        params: List[Dict[str, jnp.ndarray]],
        all_states: Optional[Dict] = None,
        param_state: Optional[List[Dict]] = None,
        delta_t: float = 0.025,
    ) -> Tuple[Dict, Dict]:
        """Initializes the parameters and states of the neuron model.

        Args:
            params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.
            all_states (Optional[Dict], optional): States, if already initialized. Defaults to None.
            param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`. Defaults to None.
            delta_t (float, optional): Step size. Defaults to 0.025.

        Returns:
            Tuple[Dict, Dict]: All states and parameters.
        """
        # Make the `trainable_params` of the same shape as the `param_state`, such that
        # they can be processed together by `get_all_parameters`.
        pstate = params_to_pstate(params, module.indices_set_by_trainables)
        if param_state is not None:
            pstate += param_state

        all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)
        all_states = (
            module.get_all_states(pstate, all_params, delta_t)
            if all_states is None
            else all_states
        )
        return all_states, all_params

    def step_fn(
        all_states: Dict,
        all_params: Dict,
        externals: Dict,
        external_inds: Dict = external_inds,
        delta_t: float = 0.025,
    ) -> Dict:
        """Performs a single integration step with step size delta_t.

        Args:
            all_states (Dict): Current state of the neuron model.
            all_params (Dict): Current parameters of the neuron model.
            externals (Dict): External inputs.
            external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.
            delta_t (float, optional): Time step. Defaults to 0.025.

        Returns:
            Dict: Updated states.
        """
        state = all_states
        state = module.step(
            state,
            delta_t,
            external_inds,
            externals,
            params=all_params,
            solver=solver,
            voltage_solver=voltage_solver,
        )
        return state

    return init_fn, step_fn
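
A sketch of stepping a model by hand (illustrative; it assumes a single-compartment module with Hodgkin-Huxley channels from `jaxley.channels` and mirrors the setup that `integrate` performs before calling `init_fn`):

import jaxley as jx
from jaxley.channels import HH
from jaxley.integrate import build_init_and_step_fn

comp = jx.Compartment()
comp.insert(HH())
comp.to_jax()  # `integrate` also calls this before `init_fn`.

init_fn, step_fn = build_init_and_step_fn(comp)
all_states, all_params = init_fn(params=[], delta_t=0.025)

# Step the model forward by 10 steps of 0.025 ms without external inputs.
for _ in range(10):
    all_states = step_fn(all_states, all_params, externals={}, delta_t=0.025)
print(all_states["v"])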

integrate(module, params=[], *, param_state=None, data_stimuli=None, data_clamps=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)

Solves ODE and simulates neuron model.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
module | Module | The Module to simulate. | required
params | List[Dict[str, ndarray]] | Trainable parameters returned by get_parameters(). | []
param_state | Optional[List[Dict]] | Parameters returned by data_set. | None
data_stimuli | Optional[Tuple[ndarray, DataFrame]] | Outputs of .data_stimulate(), only needed if stimuli change across function calls. | None
data_clamps | Optional[Tuple[str, ndarray, DataFrame]] | Outputs of .data_clamp(), only needed if clamps change across function calls. | None
t_max | Optional[float] | Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max is smaller, the stimulus will be truncated. | None
delta_t | float | Time step of the solver in milliseconds. | 0.025
solver | str | Which ODE solver to use. One of ["fwd_euler", "bwd_euler", "crank_nicolson"]. | 'bwd_euler'
voltage_solver | str | Algorithm to solve tridiagonal systems. The different options only affect the bwd_euler and crank_nicolson solvers. Either of ["stone", "thomas"], where stone is much faster on GPU for long branches with many compartments and thomas is slightly faster on CPU (thomas is used in NEURON). | 'jaxley.stone'
checkpoint_lengths | Optional[List[int]] | Number of timesteps at every level of checkpointing. prod(checkpoint_lengths) must be larger than or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps and the result is post-hoc truncated to the desired simulation length, so a poor choice of checkpoint_lengths can lead to longer simulation times. If None, no checkpointing is applied. | None
all_states | Optional[Dict] | An optional initial state returned by a previous jx.integrate(..., return_states=True) run. Overrides potentially trainable initial states. | None
return_states | bool | If True, also returns all states, such that the current state of the Module can be set with set_states. | False

Returns:

Type | Description
--- | ---
ndarray | Recordings of shape (num_recordings, num_timesteps + 1); if return_states=True, a tuple of (recordings, all_states).
Source code in jaxley/integrate.py
def integrate(
    module: Module,
    params: List[Dict[str, jnp.ndarray]] = [],
    *,
    param_state: Optional[List[Dict]] = None,
    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
    data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,
    t_max: Optional[float] = None,
    delta_t: float = 0.025,
    solver: str = "bwd_euler",
    voltage_solver: str = "jaxley.stone",
    checkpoint_lengths: Optional[List[int]] = None,
    all_states: Optional[Dict] = None,
    return_states: bool = False,
) -> jnp.ndarray:
    """
    Solves ODE and simulates neuron model.

    Args:
        module: The `Module` to simulate.
        params: Trainable parameters returned by `get_parameters()`.
        param_state: Parameters returned by `data_set`.
        data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change
            across function calls.
        data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across
            function calls.
        t_max: Duration of the simulation in milliseconds. If `t_max` is greater than
            the length of the stimulus input, the stimulus will be padded at the end
            with zeros. If `t_max` is smaller, the stimulus will be truncated.
        delta_t: Time step of the solver in milliseconds.
        solver: Which ODE solver to use. Either of ["fwd_euler", "bwd_euler",
            "crank_nicolson"].
        voltage_solver: Algorithm to solve tridiagonal systems. The different options
            only affect `bwd_euler` and `crank_nicolson` solvers. Either of ["stone",
            "thomas"], where `stone` is much faster on GPU for long branches
            with many compartments and `thomas` is slightly faster on CPU (`thomas` is
            used in NEURON).
        checkpoint_lengths: Number of timesteps at every level of checkpointing. The
            `prod(checkpoint_lengths)` must be larger or equal to the desired number of
            simulated timesteps. Warning: the simulation is run for
            `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated
            to the desired simulation length. Therefore, a poor choice of
            `checkpoint_lengths` can lead to longer simulation time. If `None`, no
            checkpointing is applied.
        all_states: An optional initial state that was returned by a previous
            `jx.integrate(..., return_states=True)` run. Overrides potentially
            trainable initial states.
        return_states: If True, it returns all states such that the current state of
            the `Module` can be set with `set_states`.
    """

    assert module.initialized, "Module is not initialized, run `._initialize()`."
    module.to_jax()  # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.

    # Initialize the external inputs and their indices.
    externals = module.externals.copy()
    external_inds = module.external_inds.copy()

    # If stimulus is inserted, add it to the external inputs.
    externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)

    # If a clamp is inserted, add it to the external inputs.
    externals, external_inds = add_clamps(externals, external_inds, data_clamps)

    if not externals.keys():
        # No stimulus was inserted and no clamp was set.
        assert (
            t_max is not None
        ), "If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`."

    for key in externals.keys():
        externals[key] = externals[key].T  # Shape `(time, num_stimuli)`.

    if module.recordings.empty:
        raise ValueError("No recordings are set. Please set them.")
    rec_inds = module.recordings.rec_index.to_numpy()
    rec_states = module.recordings.state.to_numpy()

    # Shorten or pad stimulus depending on `t_max`.
    if t_max is not None:
        t_max_steps = int(t_max // delta_t + 1)

        # Pad or truncate the stimulus.
        for key in externals.keys():
            if t_max_steps > externals[key].shape[0]:
                if key == "i":
                    pad = jnp.zeros(
                        (t_max_steps - externals["i"].shape[0], externals["i"].shape[1])
                    )
                    externals["i"] = jnp.concatenate((externals["i"], pad))
                else:
                    raise NotImplementedError(
                        "clamp must be at least as long as simulation."
                    )
            else:
                externals[key] = externals[key][:t_max_steps, :]

    init_fn, step_fn = build_init_and_step_fn(
        module, voltage_solver=voltage_solver, solver=solver
    )
    all_states, all_params = init_fn(params, all_states, param_state, delta_t)

    def _body_fun(state, externals):
        state = step_fn(state, all_params, externals, external_inds, delta_t)
        recs = jnp.asarray(
            [
                state[rec_state][rec_ind]
                for rec_state, rec_ind in zip(rec_states, rec_inds)
            ]
        )
        return state, recs

    # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.
    # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we
    # return only the first `nsteps_to_return` elements (plus the initial state).
    if externals:
        example_key = list(externals.keys())[0]
        nsteps_to_return = len(externals[example_key])
    else:
        nsteps_to_return = t_max_steps

    if checkpoint_lengths is None:
        checkpoint_lengths = [nsteps_to_return]
        length = nsteps_to_return
    else:
        length = prod(checkpoint_lengths)
        size_difference = length - nsteps_to_return
        assert (
            nsteps_to_return <= length
        ), "The desired simulation duration is longer than `prod(nested_length)`."
        if externals:
            dummy_external = jnp.zeros(
                (size_difference, externals[example_key].shape[1])
            )
            for key in externals.keys():
                externals[key] = jnp.concatenate([externals[key], dummy_external])

    # Record the initial state.
    init_recs = jnp.asarray(
        [
            all_states[rec_state][rec_ind]
            for rec_state, rec_ind in zip(rec_states, rec_inds)
        ]
    )
    init_recording = jnp.expand_dims(init_recs, axis=0)

    # Run simulation.
    all_states, recordings = nested_checkpoint_scan(
        _body_fun,
        all_states,
        externals,
        length=length,
        nested_lengths=checkpoint_lengths,
    )
    recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T
    return (recs, all_states) if return_states else recs
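
An end-to-end sketch (illustrative, assuming standard Jaxley usage; `jx.step_current` and `HH` exist in Jaxley, but the amplitudes and durations here are made up):

import jaxley as jx
from jaxley.channels import HH

comp = jx.Compartment()
comp.insert(HH())
comp.record("v")  # Required: `integrate` raises if no recordings are set.

dt, t_max = 0.025, 10.0
current = jx.step_current(
    i_delay=1.0, i_dur=2.0, i_amp=0.1, delta_t=dt, t_max=t_max
)
comp.stimulate(current)

# Recordings have shape (num_recordings, num_timesteps + 1): the initial
# state is prepended to the simulated trajectory.
recs = jx.integrate(comp, delta_t=dt)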

exponential_euler(x, dt, x_inf, x_tau)

An exact solver for the linear dynamical system dx/dt = -(x - x_inf) / x_tau.

Source code in jaxley/solver_gate.py
def exponential_euler(
    x: jnp.ndarray,
    dt: float,
    x_inf: jnp.ndarray,
    x_tau: jnp.ndarray,
):
    """An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`."""
    exp_term = save_exp(-dt / x_tau)
    return x * exp_term + x_inf * (1.0 - exp_term)
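
Because the update uses the exact solution of the linear ODE, the result is independent of how the time interval is subdivided; an illustrative check:

import jax.numpy as jnp
from jaxley.solver_gate import exponential_euler

x, x_inf, x_tau = jnp.array(0.0), jnp.array(1.0), jnp.array(5.0)
one_step = exponential_euler(x, 1.0, x_inf, x_tau)  # a single 1.0 ms step
many_steps = x
for _ in range(10):  # ten 0.1 ms steps covering the same interval
    many_steps = exponential_euler(many_steps, 0.1, x_inf, x_tau)
print(jnp.allclose(one_step, many_steps))  # True (up to float precision)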

save_exp(x, max_value=20.0)

Clip the input to a maximum value and return its exponential.

Source code in jaxley/solver_gate.py
def save_exp(x, max_value: float = 20.0):
    """Clip the input to a maximum value and return its exponential."""
    x = jnp.clip(x, a_max=max_value)
    return jnp.exp(x)
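
The clipping bounds the result at exp(max_value), which prevents overflow to inf (and NaNs in downstream gradients) for large inputs; for example:

import jax.numpy as jnp
from jaxley.solver_gate import save_exp

print(save_exp(jnp.array(1.0)))    # ~2.718, unchanged
print(save_exp(jnp.array(100.0)))  # ~4.85e8 = exp(20.0), not inf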

solve_inf_gate_exponential(x, dt, s_inf, tau_s)

Solves dx/dt = (s_inf - x) / tau_s via exponential Euler.

Parameters:

Name | Type | Description | Default
--- | --- | --- | ---
x | ndarray | Gate variable. | required
dt | float | Time step. | required
s_inf | ndarray | Steady-state value of the gate. | required
tau_s | ndarray | Time constant of the gate. | required

Returns:

Type | Description
--- | ---
ndarray | Updated gate variable.

Source code in jaxley/solver_gate.py
def solve_inf_gate_exponential(
    x: jnp.ndarray,
    dt: float,
    s_inf: jnp.ndarray,
    tau_s: jnp.ndarray,
):
    """solves dx/dt = (s_inf - x) / tau_s
    via exponential Euler

    Args:
        x (jnp.ndarray): gate variable
        dt (float): time_delta
        s_inf (jnp.ndarray): _description_
        tau_s (jnp.ndarray): _description_

    Returns:
        _type_: updated gate
    """
    slope = -1.0 / tau_s
    exp_term = save_exp(slope * dt)
    return x * exp_term + s_inf * (1.0 - exp_term)
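
Note that this computes the same update as `exponential_euler` above, only with differently named arguments; an illustrative check:

import jax.numpy as jnp
from jaxley.solver_gate import exponential_euler, solve_inf_gate_exponential

x, dt = jnp.array(0.2), 0.025
a = solve_inf_gate_exponential(x, dt, jnp.array(0.8), jnp.array(4.0))
b = exponential_euler(x, dt, jnp.array(0.8), jnp.array(4.0))
print(jnp.allclose(a, b))  # True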

step_voltage_explicit(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)

Solve one timestep of branched nerve equations with explicit (forward) Euler.

Source code in jaxley/solver_voltage.py
def step_voltage_explicit(
    voltages: jnp.ndarray,
    voltage_terms: jnp.ndarray,
    constant_terms: jnp.ndarray,
    axial_conductances: jnp.ndarray,
    internal_node_inds: jnp.ndarray,
    sinks: jnp.ndarray,
    sources: jnp.ndarray,
    types: jnp.ndarray,
    ncomp_per_branch: jnp.ndarray,
    par_inds: jnp.ndarray,
    child_inds: jnp.ndarray,
    nbranches: int,
    solver: str,
    delta_t: float,
    idx: JaxleySolveIndexer,
    debug_states,
) -> jnp.ndarray:
    """Solve one timestep of branched nerve equations with explicit (forward) Euler."""
    voltages = jnp.reshape(voltages, (nbranches, -1))
    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))
    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))

    update = _voltage_vectorfield(
        voltages,
        voltage_terms,
        constant_terms,
        types,
        sources,
        sinks,
        axial_conductances,
        par_inds,
        child_inds,
        nbranches,
        solver,
        delta_t,
        idx,
        debug_states,
    )
    new_voltages = voltages + delta_t * update
    return new_voltages.ravel(order="C")
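
Calling this function directly requires Jaxley's internal solve indexer, so here is a self-contained sketch (not Jaxley's API) of the same scheme on a single unbranched cable: the vector field combines the channel terms with axial coupling between neighboring compartments, and the state is advanced by one explicit Euler step. The function name and the uniform g_axial are assumptions for illustration; the sign conventions match those visible in the implicit solver below.

import jax.numpy as jnp

def explicit_step_single_branch(v, voltage_terms, constant_terms, g_axial, dt):
    # Axial current into each compartment from its neighbors: g * (v_nb - v).
    coupling = jnp.zeros_like(v)
    coupling = coupling.at[:-1].add(g_axial * (v[1:] - v[:-1]))
    coupling = coupling.at[1:].add(g_axial * (v[:-1] - v[1:]))
    # dv/dt = -voltage_terms * v + constant_terms + coupling.
    dv = -voltage_terms * v + constant_terms + coupling
    return v + dt * dv

v = jnp.full(4, -70.0)
v = explicit_step_single_branch(v, jnp.full(4, 0.1), jnp.full(4, -7.0), 5.0, 0.025)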

step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)

Solve one timestep of branched nerve equations with implicit (backward) Euler.

Source code in jaxley/solver_voltage.py
def step_voltage_implicit_with_jaxley_spsolve(
    voltages: jnp.ndarray,
    voltage_terms: jnp.ndarray,
    constant_terms: jnp.ndarray,
    axial_conductances: jnp.ndarray,
    internal_node_inds: jnp.ndarray,
    sinks: jnp.ndarray,
    sources: jnp.ndarray,
    types: jnp.ndarray,
    ncomp_per_branch: jnp.ndarray,
    par_inds: jnp.ndarray,
    child_inds: jnp.ndarray,
    nbranches: int,
    solver: str,
    delta_t: float,
    idx: JaxleySolveIndexer,
    debug_states,
):
    """Solve one timestep of branched nerve equations with implicit (backward) Euler."""
    # Build diagonals.
    c2c = np.isin(types, [0, 1, 2])
    total_ncomp = idx.cumsum_ncomp[-1]
    diags = jnp.ones(total_ncomp)

    # if-case needed because `.at` does not allow empty inputs, but the input is
    # empty for compartments.
    if len(sinks[c2c]) > 0:
        diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c])

    diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms)

    # Build solves.
    solves = jnp.zeros(total_ncomp)
    solves = solves.at[idx.mask(internal_node_inds)].add(
        voltages + delta_t * constant_terms
    )

    # Build upper and lower within the branch.
    c2c = types == 0  # c2c = compartment-to-compartment.

    # Build uppers.
    uppers = jnp.zeros(total_ncomp)
    upper_inds = sources[c2c] > sinks[c2c]
    sinks_upper = sinks[c2c][upper_inds]
    if len(sinks_upper) > 0:
        uppers = uppers.at[idx.mask(sinks_upper)].add(
            -delta_t * axial_conductances[c2c][upper_inds]
        )

    # Build lowers.
    lowers = jnp.zeros(total_ncomp)
    lower_inds = sources[c2c] < sinks[c2c]
    sinks_lower = sinks[c2c][lower_inds]
    if len(sinks_lower) > 0:
        lowers = lowers.at[idx.mask(sinks_lower)].add(
            -delta_t * axial_conductances[c2c][lower_inds]
        )

    # Build branchpoint conductances.
    branchpoint_conds_parents = axial_conductances[types == 1]
    branchpoint_conds_children = axial_conductances[types == 2]
    branchpoint_weights_parents = axial_conductances[types == 3]
    branchpoint_weights_children = axial_conductances[types == 4]
    all_branchpoint_vals = jnp.concatenate(
        [branchpoint_weights_parents, branchpoint_weights_children]
    )
    # Find unique group identifiers
    num_branchpoints = len(branchpoint_conds_parents)
    branchpoint_diags = -group_and_sum(
        all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints
    )
    branchpoint_solves = jnp.zeros((num_branchpoints,))

    branchpoint_conds_children = -delta_t * branchpoint_conds_children
    branchpoint_conds_parents = -delta_t * branchpoint_conds_parents

    # Here, I move all child and parent indices towards a branchpoint into a larger
    # vector. This is wasteful, but it makes indexing much easier. JIT compiling
    # makes the speed difference negligible.
    # Children.
    bp_conds_children = jnp.zeros(nbranches)
    bp_weights_children = jnp.zeros(nbranches)
    # Parents.
    bp_conds_parents = jnp.zeros(nbranches)
    bp_weights_parents = jnp.zeros(nbranches)

    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.
    # `len(inds) == 0` is the case for branches and compartments.
    if num_branchpoints > 0:
        bp_conds_children = bp_conds_children.at[child_inds].set(
            branchpoint_conds_children
        )
        bp_weights_children = bp_weights_children.at[child_inds].set(
            branchpoint_weights_children
        )
        bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)
        bp_weights_parents = bp_weights_parents.at[par_inds].set(
            branchpoint_weights_parents
        )

    # Triangulate the linear system of equations.
    (
        diags,
        lowers,
        solves,
        uppers,
        branchpoint_diags,
        branchpoint_solves,
        bp_weights_children,
        bp_conds_parents,
    ) = _triang_branched(
        lowers,
        diags,
        uppers,
        solves,
        bp_conds_children,
        bp_conds_parents,
        bp_weights_children,
        bp_weights_parents,
        branchpoint_diags,
        branchpoint_solves,
        solver,
        ncomp_per_branch,
        idx,
        debug_states,
    )

    # Backsubstitute the linear system of equations.
    (
        solves,
        lowers,
        diags,
        bp_weights_parents,
        branchpoint_solves,
        bp_conds_children,
    ) = _backsub_branched(
        lowers,
        diags,
        uppers,
        solves,
        bp_conds_children,
        bp_conds_parents,
        bp_weights_children,
        bp_weights_parents,
        branchpoint_diags,
        branchpoint_solves,
        solver,
        ncomp_per_branch,
        idx,
        debug_states,
    )
    return solves.ravel(order="C")[idx.mask(internal_node_inds)]
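
Within a single branch the system is tridiagonal, and the triangulation plus backsubstitution above generalizes the classic Thomas algorithm (the `thomas` option) to branched morphologies. A self-contained single-branch sketch (not Jaxley's API; names are illustrative):

import jax.numpy as jnp

def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system; `lower`/`upper` have length n - 1."""
    n = diag.shape[0]
    # Triangulate: eliminate the lower diagonal, top to bottom.
    for i in range(1, n):
        w = lower[i - 1] / diag[i - 1]
        diag = diag.at[i].add(-w * upper[i - 1])
        rhs = rhs.at[i].add(-w * rhs[i - 1])
    # Backsubstitute, bottom to top.
    x = jnp.zeros(n).at[n - 1].set(rhs[n - 1] / diag[n - 1])
    for i in range(n - 2, -1, -1):
        x = x.at[i].set((rhs[i] - upper[i] * x[i + 1]) / diag[i])
    return x

# Check against a hand-verified solution of a 3x3 system.
print(thomas_solve(
    jnp.array([1.0, 1.0]),        # lower diagonal
    jnp.array([4.0, 4.0, 4.0]),   # main diagonal
    jnp.array([1.0, 1.0]),        # upper diagonal
    jnp.array([5.0, 6.0, 5.0]),   # right-hand side
))  # -> [1. 1. 1.]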