Skip to content

Simulation

integrate(module, params=[], *, param_state=None, data_stimuli=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)

Solves ODE and simulates neuron model.

Parameters:

Name Type Description Default
params List[Dict[str, ndarray]]

Trainable parameters returned by get_parameters().

[]
param_state Optional[List[Dict]]

Parameters returned by data_set.

None
data_stimuli Optional[Tuple[ndarray, DataFrame]]

Outputs of .data_stimulate(), only needed if stimuli change across function calls.

None
t_max Optional[float]

Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max is smaller, then the stimulus will be truncated.

None
delta_t float

Time step of the solver in milliseconds.

0.025
solver str

Which ODE solver to use. Either of [“fwd_euler”, “bwd_euler”, “cranck”].

'bwd_euler'
voltage_solver str

Algorithm to solve tridiagonal systems. The different options only affect bwd_euler and cranck solvers. Either of [“stone”, “thomas”], where stone is much faster on GPU for long branches with many compartments and thomas is slightly faster on CPU (thomas is used in NEURON).

'jaxley.stone'
checkpoint_lengths Optional[List[int]]

Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths) must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths can lead to longer simulation time. If None, no checkpointing is applied.

None
all_states Optional[Dict]

An optional initial state that was returned by a previous jx.integrate(..., return_states=True) run. Overrides potentially trainable initial states.

None
return_states bool

If True, it returns all states such that the current state of the Module can be set with set_states.

False
Source code in jaxley/integrate.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
def integrate(
    module: Module,
    params: List[Dict[str, jnp.ndarray]] = [],
    *,
    param_state: Optional[List[Dict]] = None,
    data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
    t_max: Optional[float] = None,
    delta_t: float = 0.025,
    solver: str = "bwd_euler",
    voltage_solver: str = "jaxley.stone",
    checkpoint_lengths: Optional[List[int]] = None,
    all_states: Optional[Dict] = None,
    return_states: bool = False,
) -> jnp.ndarray:
    """
    Solves ODE and simulates neuron model.

    Args:
        module: The module to simulate. It must already be initialized
            (`module.initialized` is asserted).
        params: Trainable parameters returned by `get_parameters()`.
            NOTE(review): the mutable `[]` default is kept for interface
            compatibility; this function only passes it on to `params_to_pstate`.
        param_state: Parameters returned by `data_set`.
        data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change
            across function calls.
        t_max: Duration of the simulation in milliseconds. If `t_max` is greater than
            the length of the stimulus input, the stimulus will be padded at the end
            with zeros. If `t_max` is smaller, then the stimulus will be truncated.
            Externals other than the stimulus `"i"` are never padded: if one of them
            is shorter than `t_max`, a `NotImplementedError` is raised.
        delta_t: Time step of the solver in milliseconds.
        solver: Which ODE solver to use. Either of ["fwd_euler", "bwd_euler", "cranck"].
        voltage_solver: Algorithm used to solve the voltage equations; it is passed
            unchanged to `module.step`. The different options only affect the
            `bwd_euler` and `cranck` solvers. NOTE(review): this parameter was
            previously documented under the name `tridiag_solver` with the options
            ["stone", "thomas"] (`stone` being much faster on GPU for long branches
            with many compartments, `thomas` slightly faster on CPU and used in
            NEURON); confirm the exact accepted strings (the default is
            `"jaxley.stone"`) against `Module.step`.
        checkpoint_lengths: Number of timesteps at every level of checkpointing. The
            `prod(checkpoint_lengths)` must be larger or equal to the desired number of
            simulated timesteps. Warning: the simulation is run for
            `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated
            to the desired simulation length. Therefore, a poor choice of
            `checkpoint_lengths` can lead to longer simulation time. If `None`, no
            checkpointing is applied.
        all_states: An optional initial state that was returned by a previous
            `jx.integrate(..., return_states=True)` run. Overrides potentially
            trainable initial states.
        return_states: If True, it returns all states such that the current state of
            the `Module` can be set with `set_states`.

    Returns:
        The recordings, with the initial state prepended along the time axis and the
        result transposed so that time is the last axis. If `return_states` is True,
        a tuple `(recordings, all_states)` is returned instead, where `all_states`
        is the final state of the simulation.

    Raises:
        AssertionError: If the module is not initialized, or if
            `prod(checkpoint_lengths)` is smaller than the number of simulated steps.
        NotImplementedError: If `t_max` is longer than a non-stimulus external.
    """

    assert module.initialized, "Module is not initialized, run `.initialize()`."
    module.to_jax()  # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.

    # Work on copies so that the externals stored on the module are not modified.
    externals = module.externals.copy()
    external_inds = module.external_inds.copy()

    # Merge stimuli inserted into the module with the ones passed as `data_stimuli`.
    has_inserted_stimulus = "i" in module.externals
    if data_stimuli is not None:
        data_currents = data_stimuli[0]
        data_inds = data_stimuli[1].comp_index.to_numpy()
        if has_inserted_stimulus:
            externals["i"] = jnp.concatenate([externals["i"], data_currents])
            external_inds["i"] = jnp.concatenate([external_inds["i"], data_inds])
        else:
            externals["i"] = data_currents
            external_inds["i"] = data_inds
    elif not has_inserted_stimulus:
        # No stimulus at all: use an empty placeholder so that "i" always exists.
        externals["i"] = jnp.asarray([[]]).astype("float")
        external_inds["i"] = jnp.asarray([]).astype("int32")

    # NOTE(review): `externals` always contains the key "i" at this point (every
    # branch above guarantees it), so this check cannot currently trigger. It is
    # kept unchanged; confirm whether a missing `t_max` without any stimulus or
    # clamp should be rejected here.
    if not externals.keys():
        # No stimulus was inserted and no clamp was set.
        assert (
            t_max is not None
        ), "If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`."

    for key in externals:
        externals[key] = externals[key].T  # Shape `(time, num_stimuli)`.

    rec_inds = module.recordings.rec_index.to_numpy()
    rec_states = module.recordings.state.to_numpy()

    # Shorten or pad stimulus depending on `t_max`.
    if t_max is not None:
        t_max_steps = int(t_max // delta_t + 1)

        # Only the stimulus is zero-padded; other externals must be long enough.
        num_stim_steps = externals["i"].shape[0]
        if t_max_steps > num_stim_steps:
            pad = jnp.zeros((t_max_steps - num_stim_steps, externals["i"].shape[1]))
            externals["i"] = jnp.concatenate((externals["i"], pad))

        for key in externals:
            if t_max_steps > externals[key].shape[0]:
                raise NotImplementedError(
                    "clamp must be at least as long as simulation."
                )
            externals[key] = externals[key][:t_max_steps, :]

    # Make the `trainable_params` of the same shape as the `param_state`, such that they
    # can be processed together by `get_all_parameters`.
    pstate = params_to_pstate(params, module.indices_set_by_trainables)

    # Gather parameters from `make_trainable` and `data_set` into a single list.
    if param_state is not None:
        pstate += param_state

    all_params = module.get_all_parameters(pstate)
    if all_states is None:
        all_states = module.get_all_states(pstate, all_params, delta_t)

    def _read_recordings(state):
        """Collect the recorded entries of `state` into a single array."""
        return jnp.asarray(
            [
                state[rec_state][rec_ind]
                for rec_state, rec_ind in zip(rec_states, rec_inds)
            ]
        )

    def _body_fun(state, externals):
        state = module.step(
            state,
            delta_t,
            external_inds,
            externals,
            params=all_params,
            solver=solver,
            voltage_solver=voltage_solver,
        )
        return state, _read_recordings(state)

    # If necessary, pad the externals with zeros in order to simulate sufficiently
    # long. The total simulation length will be `prod(checkpoint_lengths)`. At the
    # end, we return only the first `nsteps_to_return` elements (plus the initial
    # state).
    example_key = next(iter(externals))
    nsteps_to_return = len(externals[example_key])
    if checkpoint_lengths is None:
        checkpoint_lengths = [nsteps_to_return]
        length = nsteps_to_return
    else:
        length = prod(checkpoint_lengths)
        # Validate before allocating the padding: a negative size would otherwise
        # fail inside `jnp.zeros` with an unrelated error message.
        assert (
            nsteps_to_return <= length
        ), "The desired simulation duration is longer than `prod(checkpoint_lengths)`."
        size_difference = length - nsteps_to_return
        for key in externals:
            # Each external is padded with its own number of columns; stimuli and
            # clamps do not need to have the same width.
            padding = jnp.zeros((size_difference, externals[key].shape[1]))
            externals[key] = jnp.concatenate([externals[key], padding])

    # Record the initial state.
    init_recording = jnp.expand_dims(_read_recordings(all_states), axis=0)

    # Run simulation.
    all_states, recordings = nested_checkpoint_scan(
        _body_fun,
        all_states,
        externals,
        length=length,
        nested_lengths=checkpoint_lengths,
    )
    recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T
    return (recs, all_states) if return_states else recs

exponential_euler(x, dt, x_inf, x_tau)

An exact solver for the linear dynamical system dx/dt = -(x - x_inf) / x_tau.

Source code in jaxley/solver_gate.py
36
37
38
39
40
41
42
43
44
def exponential_euler(
    x: jnp.ndarray,
    dt: float,
    x_inf: jnp.ndarray,
    x_tau: jnp.ndarray,
):
    """Advance `x` by one step of length `dt` under `dx/dt = -(x - x_inf) / x_tau`.

    This is the exact solution of the linear system over one step, treating
    `x_inf` and `x_tau` as constant during that step.

    Args:
        x: Current value of the state.
        dt: Step size.
        x_inf: Value that `x` relaxes towards.
        x_tau: Time constant of the relaxation.

    Returns:
        The updated state: `x` weighted by the decay factor `exp(-dt / x_tau)` plus
        `x_inf` weighted by one minus that factor.
    """
    # The exponential is evaluated through `save_exp`, which clips its argument.
    decay = save_exp(-dt / x_tau)
    approach = 1.0 - decay
    return x * decay + x_inf * approach

save_exp(x, max_value=20.0)

Clip the input to a maximum value and return its exponential.

Source code in jaxley/solver_gate.py
 7
 8
 9
10
def save_exp(x, max_value: float = 20.0):
    """Clip the input to a maximum value and return its exponential.

    Args:
        x: Exponent(s); a scalar or an array.
        max_value: Upper bound applied to `x` before exponentiation, so that the
            result never exceeds `exp(max_value)`.

    Returns:
        `exp(min(x, max_value))`, computed elementwise.
    """
    # `jnp.minimum` applies the upper bound without relying on the `a_max` keyword
    # of `jnp.clip`, which newer JAX releases deprecated in favour of `max`.
    return jnp.exp(jnp.minimum(x, max_value))

solve_inf_gate_exponential(x, dt, s_inf, tau_s)

solves dx/dt = (s_inf - x) / tau_s via exponential Euler

Parameters:

Name Type Description Default
x ndarray

gate variable

required
dt float

time_delta

required
s_inf ndarray

Steady-state value that the gate variable relaxes towards.

required
tau_s ndarray

Time constant of the relaxation.

required

Returns:

Name Type Description
_type_

updated gate

Source code in jaxley/solver_gate.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def solve_inf_gate_exponential(
    x: jnp.ndarray,
    dt: float,
    s_inf: jnp.ndarray,
    tau_s: jnp.ndarray,
):
    """Take one exponential-Euler step of `dx/dt = (s_inf - x) / tau_s`.

    Args:
        x: Current value of the gate variable.
        dt: Step size.
        s_inf: Steady-state value that the gate relaxes towards.
        tau_s: Time constant of the relaxation.

    Returns:
        The updated gate variable.
    """
    # Decay factor `exp(-dt / tau_s)`; `save_exp` clips the exponent from above.
    decay = save_exp((-1.0 / tau_s) * dt)
    return x * decay + s_inf * (1.0 - decay)

step_voltage_explicit(voltages, voltage_terms, constant_terms, coupling_conds_bwd, coupling_conds_fwd, branch_cond_fwd, branch_cond_bwd, nbranches, parents, delta_t)

Solve one timestep of branched nerve equations with explicit (forward) Euler.

Source code in jaxley/solver_voltage.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def step_voltage_explicit(
    voltages: jnp.ndarray,
    voltage_terms: jnp.ndarray,
    constant_terms: jnp.ndarray,
    coupling_conds_bwd: jnp.ndarray,
    coupling_conds_fwd: jnp.ndarray,
    branch_cond_fwd: jnp.ndarray,
    branch_cond_bwd: jnp.ndarray,
    nbranches: int,
    parents: jnp.ndarray,
    delta_t: float,
) -> jnp.ndarray:
    """Advance the branched nerve equations by one explicit (forward) Euler step.

    The voltages and the two membrane terms are viewed as `(nbranches, -1)`
    arrays, the vectorfield is evaluated with `voltage_vectorfield`, and the
    voltages are moved along it by `delta_t`. The coupling and branch
    conductances are passed to `voltage_vectorfield` without reshaping.

    Returns:
        The updated voltages with shape `(nbranches, -1)`.
    """
    per_branch = (nbranches, -1)
    v = jnp.reshape(voltages, per_branch)
    linear_terms = jnp.reshape(voltage_terms, per_branch)
    offset_terms = jnp.reshape(constant_terms, per_branch)

    dv_dt = voltage_vectorfield(
        parents,
        v,
        linear_terms,
        offset_terms,
        coupling_conds_bwd,
        coupling_conds_fwd,
        branch_cond_fwd,
        branch_cond_bwd,
    )
    return v + delta_t * dv_dt

step_voltage_implicit(voltages, voltage_terms, constant_terms, coupling_conds_upper, coupling_conds_lower, summed_coupling_conds, branchpoint_conds_children, branchpoint_conds_parents, branchpoint_weights_children, branchpoint_weights_parents, par_inds, child_inds, nbranches, solver, delta_t, children_in_level, parents_in_level, root_inds, branchpoint_group_inds, debug_states)

Solve one timestep of branched nerve equations with implicit (backward) Euler.

Source code in jaxley/solver_voltage.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def step_voltage_implicit(
    voltages,
    voltage_terms,
    constant_terms,
    coupling_conds_upper,
    coupling_conds_lower,
    summed_coupling_conds,
    branchpoint_conds_children,
    branchpoint_conds_parents,
    branchpoint_weights_children,
    branchpoint_weights_parents,
    par_inds,
    child_inds,
    nbranches,
    solver: str,
    delta_t,
    children_in_level,
    parents_in_level,
    root_inds,
    branchpoint_group_inds,
    debug_states,
):
    """Solve one timestep of branched nerve equations with implicit (backward) Euler.

    The function proceeds in four stages:

    1. Reshape the per-compartment inputs to `(nbranches, -1)` and let
       `define_all_tridiags` assemble the per-branch system (`lowers`, `diags`,
       `uppers`, `solves`).
    2. Build the branchpoint part of the system: the diagonal entries come from
       `group_and_sum` over the concatenated parent and child weights, the
       branchpoint right-hand side starts at zero, and the branchpoint
       conductances are scaled by `-delta_t`.
    3. Scatter the branchpoint conductances and weights into vectors of length
       `nbranches`, indexed by `child_inds` and `par_inds`.
    4. Run `_triang_branched` followed by `_backsub_branched` and return the
       resulting `solves`.

    Args:
        voltages, voltage_terms, constant_terms: Per-compartment quantities; each is
            reshaped to `(nbranches, -1)`.
        coupling_conds_upper, coupling_conds_lower, summed_coupling_conds:
            Coupling conductances; each is reshaped to `(nbranches, -1)`.
        branchpoint_conds_children, branchpoint_conds_parents: Conductances at the
            branchpoints; multiplied by `-delta_t` before use.
        branchpoint_weights_children, branchpoint_weights_parents: Weights at the
            branchpoints; summed per group to form the branchpoint diagonal.
        par_inds, child_inds: Positions at which the parent and child branchpoint
            values are written into the length-`nbranches` vectors.
        nbranches: Number of branches.
        solver: Passed on to `_triang_branched` and `_backsub_branched`.
        delta_t: Time step.
        children_in_level, parents_in_level, root_inds, debug_states: Passed on
            unchanged to `_triang_branched` and `_backsub_branched`.
        branchpoint_group_inds: Group index of every entry of the concatenated
            parent and child weights, as consumed by `group_and_sum`.

    Returns:
        The `solves` array returned by `_backsub_branched` (presumably the new
        voltages with shape `(nbranches, -1)` -- confirm against that helper).
    """
    voltages = jnp.reshape(voltages, (nbranches, -1))
    voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))
    constant_terms = jnp.reshape(constant_terms, (nbranches, -1))
    coupling_conds_upper = jnp.reshape(coupling_conds_upper, (nbranches, -1))
    coupling_conds_lower = jnp.reshape(coupling_conds_lower, (nbranches, -1))
    summed_coupling_conds = jnp.reshape(summed_coupling_conds, (nbranches, -1))

    # Define quasi-tridiagonal system.
    lowers, diags, uppers, solves = define_all_tridiags(
        voltages,
        voltage_terms,
        constant_terms,
        nbranches,
        coupling_conds_upper,
        coupling_conds_lower,
        summed_coupling_conds,
        delta_t,
    )
    # Parent weights come first, then child weights; `branchpoint_group_inds` is
    # expected to follow the same ordering (TODO confirm against the caller).
    all_branchpoint_vals = jnp.concatenate(
        [branchpoint_weights_parents, branchpoint_weights_children]
    )
    # One diagonal entry per branchpoint: the negated, group-wise sum of the weights.
    num_branchpoints = len(branchpoint_conds_parents)
    branchpoint_diags = -group_and_sum(
        all_branchpoint_vals, branchpoint_group_inds, num_branchpoints
    )
    # The right-hand side of the branchpoint equations starts at zero.
    branchpoint_solves = jnp.zeros((num_branchpoints,))

    # Scale the branchpoint conductances by the (negative) time step.
    branchpoint_conds_children = -delta_t * branchpoint_conds_children
    branchpoint_conds_parents = -delta_t * branchpoint_conds_parents

    # Here, I move all child and parent indices towards a branchpoint into a larger
    # vector. This is wasteful, but it makes indexing much easier. JIT compiling
    # makes the speed difference negligible.
    # Children.
    bp_conds_children = jnp.zeros(nbranches)
    bp_weights_children = jnp.zeros(nbranches)
    # Parents.
    bp_conds_parents = jnp.zeros(nbranches)
    bp_weights_parents = jnp.zeros(nbranches)

    # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.
    # `len(inds) == 0` is the case for branches and compartments.
    if num_branchpoints > 0:
        bp_conds_children = bp_conds_children.at[child_inds].set(
            branchpoint_conds_children
        )
        bp_weights_children = bp_weights_children.at[child_inds].set(
            branchpoint_weights_children
        )
        bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)
        bp_weights_parents = bp_weights_parents.at[par_inds].set(
            branchpoint_weights_parents
        )

    # Triangulate the linear system of equations.
    # NOTE: the order of the unpacked values differs from the argument order; it
    # must match the return order of `_triang_branched`.
    (
        diags,
        lowers,
        solves,
        uppers,
        branchpoint_diags,
        branchpoint_solves,
        bp_weights_children,
        bp_conds_parents,
    ) = _triang_branched(
        lowers,
        diags,
        uppers,
        solves,
        bp_conds_children,
        bp_conds_parents,
        bp_weights_children,
        bp_weights_parents,
        branchpoint_diags,
        branchpoint_solves,
        solver,
        children_in_level,
        parents_in_level,
        root_inds,
        debug_states,
    )

    # Backsubstitute the linear system of equations.
    # Only `solves` is used afterwards; the remaining outputs are discarded.
    (
        solves,
        lowers,
        diags,
        bp_weights_parents,
        branchpoint_solves,
        bp_conds_children,
    ) = _backsub_branched(
        lowers,
        diags,
        uppers,
        solves,
        bp_conds_children,
        bp_conds_parents,
        bp_weights_children,
        bp_weights_parents,
        branchpoint_diags,
        branchpoint_solves,
        solver,
        children_in_level,
        parents_in_level,
        root_inds,
        debug_states,
    )

    return solves

voltage_vectorfield(parents, voltages, voltage_terms, constant_terms, coupling_conds_bwd, coupling_conds_fwd, branch_cond_fwd, branch_cond_bwd)

Evaluate the vectorfield of the nerve equation.

Source code in jaxley/solver_voltage.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def voltage_vectorfield(
    parents: jnp.ndarray,
    voltages: jnp.ndarray,
    voltage_terms: jnp.ndarray,
    constant_terms: jnp.ndarray,
    coupling_conds_bwd: jnp.ndarray,
    coupling_conds_fwd: jnp.ndarray,
    branch_cond_fwd: jnp.ndarray,
    branch_cond_bwd: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate the vectorfield of the nerve equation.

    `voltages`, `voltage_terms` and `constant_terms` are indexed as
    `[branch, compartment]`. The result is the sum of three contributions:

    * the membrane term `-voltage_terms * voltages + constant_terms`,
    * the coupling between neighbouring compartments of the same branch, and
    * if `branch_cond_bwd` is non-empty, the coupling between the last
      compartment of every branch and compartment 0 of the branch given by
      `parents`.

    Returns:
        An array with the same shape as `voltages`.
    """
    # Views on neighbouring compartments within each branch.
    lower_v = voltages[:, :-1]
    upper_v = voltages[:, 1:]

    # Membrane contribution, then the within-branch coupling in both directions.
    field = -voltage_terms * voltages + constant_terms
    field = field.at[:, :-1].add((upper_v - lower_v) * coupling_conds_bwd)
    field = field.at[:, 1:].add((lower_v - upper_v) * coupling_conds_fwd)

    # Without branch conductances there is nothing more to add.
    if len(branch_cond_bwd) == 0:
        return field

    parent_first_v = voltages[parents, 0]
    last_comp_v = voltages[:, -1]

    # Contribution of the parent branch to the last compartment of each branch.
    field = field.at[:, -1].add((parent_first_v - last_comp_v) * branch_cond_bwd)

    # Contribution of each branch to compartment 0 of its parent. Several branches
    # can share a parent, so the updates are accumulated with a scatter-add
    # rather than written by plain indexing.
    to_parent = (last_comp_v - parent_first_v) * branch_cond_fwd
    scatter_inds = jnp.stack([parents, jnp.zeros_like(parents)], axis=1)
    scatter_dims = ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0, 1),
        scatter_dims_to_operand_dims=(0, 1),
    )
    return scatter_add(field, scatter_inds, to_parent, scatter_dims)