Building ion channel models

In this tutorial, you will learn how to:

  • define your own ion channel models beyond the preconfigured channels in Jaxley

This tutorial assumes that you have already learned how to build basic simulations.

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx

First, we define a cell as you saw in the previous tutorial:

comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])

You have also already learned how to insert preconfigured channels into Jaxley models:

from jaxley.channels import Na, K, Leak

cell.insert(Na())
cell.insert(K())
cell.insert(Leak())

In this tutorial, we will show you how to build your own channel and synapse models.

Your own channel

Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.

import jax.numpy as jnp
from jaxley.channels import Channel
from jaxley.solver_gate import solve_gate_exponential


def exp_update_alpha(x, y):
    return x / (jnp.exp(x / y) - 1.0)

class Potassium(Channel):
    """Potassium channel."""

    def __init__(self, name = None):
        self.current_is_in_mA_per_cm2 = True
        super().__init__(name)
        self.channel_params = {"gK_new": 1e-4}
        self.channel_states = {"n_new": 0.0}
        self.current_name = "i_K"

    def update_states(self, states, dt, v, params):
        """Update state."""
        ns = states["n_new"]
        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)
        beta = 0.125 * jnp.exp(-(v + 65) / 80)
        new_n = solve_gate_exponential(ns, dt, alpha, beta)
        return {"n_new": new_n}

    def compute_current(self, states, v, params):
        """Return current."""
        ns = states["n_new"]
        kd_conds = params["gK_new"] * ns**4  # S/cm^2

        e_kd = -77.0        
        return kd_conds * (v - e_kd)

    def init_state(self, states, v, params, delta_t):
        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)
        beta = 0.125 * jnp.exp(-(v + 65) / 80)
        return {"n_new": alpha / (alpha + beta)}

Let’s look at each part of this in detail.

The function below is simply a helper used to compute the rate alpha that is passed to the solver of the gate variable:

def exp_update_alpha(x, y):
    return x / (jnp.exp(x / y) - 1.0)
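
One thing to be aware of: this expression has the form 0/0 when x = 0, which for the potassium rate used in this tutorial happens at v = -55 mV. Below is a minimal sketch of a guarded variant, written in plain Python with the standard math module so that it can be checked in isolation. The name exp_update_alpha_safe and the tolerance tol are illustrative assumptions and not part of Jaxley:

```python
import math


def exp_update_alpha(x, y):
    """Same formula as above, written with `math` instead of `jnp`."""
    return x / (math.exp(x / y) - 1.0)


def exp_update_alpha_safe(x, y, tol=1e-6):
    """Hypothetical guarded variant (not part of Jaxley).

    Near x = 0 the formula is 0/0, and its limit is y. Inside
    |x / y| < tol we use the first-order Taylor expansion
    y * (1 - x / (2 * y)) instead of the original expression.
    """
    if abs(x / y) < tol:
        return y * (1.0 - x / (2.0 * y))
    return x / (math.exp(x / y) - 1.0)


# Away from the singularity both versions agree.
print(exp_update_alpha(10.0, 10.0), exp_update_alpha_safe(10.0, 10.0))
# At x = 0 the guarded version returns the limit y.
print(exp_update_alpha_safe(0.0, 10.0))
```

Inside a jitted channel, a Python if on a traced value does not work, so the same guard would have to be expressed with jnp.where rather than with a branch.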

Next, we define our channel as a class. It should inherit from the Channel class and define channel_params, channel_states, and current_name. You also need to set self.current_is_in_mA_per_cm2 = True as the first line of your __init__() method. This acknowledges that your current is returned in mA/cm^2 (not in uA/cm^2, as was required in Jaxley versions 0.4.0 or older).

class Potassium(Channel):
    """Potassium channel."""

    def __init__(self, name=None):
        self.current_is_in_mA_per_cm2 = True
        super().__init__(name)
        self.channel_params = {"gK_new": 1e-4}
        self.channel_states = {"n_new": 0.0}
        self.current_name = "i_K"

Next, we have the update_states() method, which updates the gating variables:

    def update_states(self, states, dt, v, params):

Every channel you define must have an update_states() method which takes exactly these five arguments (self, states, dt, v, params). The states argument is a dictionary which contains all states that are updated (including states of other channels). v is a jnp.ndarray which contains the voltage of a single compartment (shape ()). Let’s get the state of the potassium channel which we are building here:

ns = states["n_new"]

Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:

alpha = 0.01 * exp_update_alpha(-(v + 55), 10)
beta = 0.125 * jnp.exp(-(v + 65) / 80)
new_n = solve_gate_exponential(ns, dt, alpha, beta)
return {"n_new": new_n}
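
To make the update step concrete, here is a minimal sketch of an exponential Euler step for a gate that follows dn/dt = alpha * (1 - n) - beta * n. The function exp_euler_gate is a hypothetical stand-in written in plain Python. It is assumed to correspond to what solve_gate_exponential computes, but it is not copied from the library, so check the Jaxley source for the actual implementation:

```python
import math


def exp_euler_gate(n, dt, alpha, beta):
    """Hypothetical stand-in for an exponential Euler gate update.

    Solves dn/dt = alpha * (1 - n) - beta * n exactly over one step of
    length dt, treating alpha and beta as constant during the step.
    """
    n_inf = alpha / (alpha + beta)  # steady-state value of the gate
    tau = 1.0 / (alpha + beta)      # time constant (ms)
    decay = math.exp(-dt / tau)
    return n * decay + n_inf * (1.0 - decay)


alpha, beta = 0.058, 0.125  # example rates in 1/ms (illustrative values)
n = 0.0
for _ in range(4000):       # 4000 steps of 0.025 ms = 100 ms
    n = exp_euler_gate(n, 0.025, alpha, beta)

# With the voltage (and hence the rates) held fixed, n relaxes to n_inf.
print(n, alpha / (alpha + beta))
```

Because the step is exact for constant rates, it stays stable even when dt is large compared to the time constant, which is one reason this scheme is a common choice for gating variables.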

A channel also needs a compute_current() method which returns the current through the channel:

    def compute_current(self, states, v, params):
        ns = states["n_new"]

        # Conductance in S/cm^2. Multiplied by a voltage in mV, it gives a current in mA/cm^2.
        kd_conds = params["gK_new"] * ns**4  # S/cm^2

        e_kd = -77.0  # Reversal potential in mV.
        current = kd_conds * (v - e_kd)
        return current
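
As a quick sanity check of the units, the following sketch evaluates the same expression by hand in plain Python. The gate value and the voltage are illustrative numbers chosen for this example, not values taken from a simulation:

```python
g_k = 1e-4   # maximal conductance, S/cm^2 (default of "gK_new" above)
n = 0.5      # gate value (illustrative)
v = -65.0    # membrane voltage, mV (illustrative)
e_k = -77.0  # reversal potential, mV

conductance = g_k * n**4           # S/cm^2
current = conductance * (v - e_k)  # S/cm^2 * mV = mA/cm^2

print(current)         # current density in mA/cm^2
print(current * 1000)  # the same value in uA/cm^2
```

With these numbers the current density is about 7.5e-5 mA/cm^2. It is positive because v lies above the reversal potential.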

Finally, you can optionally implement the init_state() method. It is used to automatically compute the initial state from the voltage when cell.init_states() is run. In the example above, it returns the steady-state value alpha / (alpha + beta) of the gate.
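
To get a feeling for the values that such an initialization produces, the sketch below evaluates the steady state alpha / (alpha + beta) of the potassium gate at a few voltages. It uses plain Python with the standard math module so that it runs without Jaxley. The helper name n_steady_state is an assumption made for this example:

```python
import math


def n_steady_state(v):
    """Steady-state gate value alpha / (alpha + beta) at voltage v (mV)."""
    x = -(v + 55.0)
    alpha = 0.01 * x / (math.exp(x / 10.0) - 1.0)
    beta = 0.125 * math.exp(-(v + 65.0) / 80.0)
    return alpha / (alpha + beta)


# The gate opens further as the membrane depolarizes.
for v in (-80.0, -65.0, -40.0):
    print(v, n_steady_state(v))
```

Initializing the gate at its steady state avoids an artificial transient at the start of the simulation that would otherwise come from the arbitrary default of 0.0.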

Alright, done! We can now insert this channel into any jx.Module such as our cell:

cell.insert(Potassium())
cell.delete_stimuli()
current = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)
cell.branch(0).comp(0).stimulate(current)

cell.delete_recordings()
cell.branch(0).comp(0).record()
Added 1 external_states. See `.externals` for details.
Added 1 recordings. See `.recordings` for details.
s = jx.integrate(cell)
fig, ax = plt.subplots(1, 1, figsize=(4, 2))
t = np.arange(s.shape[1] - 1) * 0.025  # Time axis in ms (dt = 0.025 ms).
_ = ax.plot(t, s.T[:-1])
_ = ax.set_ylim([-80, 50])
_ = ax.set_xlabel("Time (ms)")
_ = ax.set_ylabel("Voltage (mV)")

Your own synapse

The parts below assume that you have already learned how to build network simulations in Jaxley.

As with channels, a synapse must define the two methods update_states() and compute_current(), with exactly the input arguments shown below.

Below is an example of how to define your own synapse model in Jaxley:

import jax.numpy as jnp
from jaxley.synapses.synapse import Synapse


class TestSynapse(Synapse):
    """
    Compute synaptic current and update synapse state.
    """
    def __init__(self, name = None):
        super().__init__(name)
        self.synapse_params = {"gChol": 0.001, "eChol": 0.0}
        self.synapse_states = {"s_chol": 0.1}

    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):
        """Return updated synapse state."""
        s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))
        exp_term = jnp.exp(-delta_t)
        new_s = states["s_chol"] * exp_term + s_inf * (1.0 - exp_term)
        return {"s_chol": new_s}

    def compute_current(self, states, pre_voltage, post_voltage, params):
        g_syn = params["gChol"] * states["s_chol"]
        return g_syn * (post_voltage - params["eChol"])

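As you can see above, synapses closely follow how channels are defined. The main difference is that update_states() and compute_current() each take two voltages: the pre-synaptic voltage (a jnp.ndarray of shape ()) and the post-synaptic voltage (a jnp.ndarray of shape ()). In this example, the state update depends only on the pre-synaptic voltage and the current depends only on the post-synaptic voltage.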
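
The arithmetic of this synapse can be checked outside of a network. The sketch below reimplements the two methods as plain Python functions (the names synapse_step and synapse_current are assumptions made for this example) and holds the pre-synaptic voltage fixed at -35 mV, where s_inf equals 0.5:

```python
import math


def synapse_step(s, delta_t, pre_voltage):
    """One update of the synaptic gate, as in TestSynapse.update_states."""
    s_inf = 1.0 / (1.0 + math.exp((-35.0 - pre_voltage) / 10.0))
    exp_term = math.exp(-delta_t)
    return s * exp_term + s_inf * (1.0 - exp_term)


def synapse_current(s, post_voltage, g=0.001, e_rev=0.0):
    """Current as in TestSynapse.compute_current (g, e_rev: gChol, eChol)."""
    return g * s * (post_voltage - e_rev)


s = 0.1               # initial value of "s_chol"
for _ in range(400):  # 400 steps of 0.025 ms = 10 ms
    s = synapse_step(s, 0.025, pre_voltage=-35.0)

print(s)  # has relaxed close to s_inf = 0.5
# Negative here, because the post-synaptic voltage lies below eChol.
print(synapse_current(s, post_voltage=-70.0))
```

Note that exp(-delta_t) implies a fixed relaxation time constant of 1 ms. If you want to fit this time scale, one option is to expose it as an additional entry in synapse_params.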

net = jx.Network([cell for _ in range(3)])
from jaxley.connect import connect

pre = net.cell(0).branch(0).loc(0.0)
post = net.cell(1).branch(0).loc(0.0)
connect(pre, post, TestSynapse())
net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))
for i in range(3):
    net.cell(i).branch(0).loc(0.0).record()
Added 1 external_states. See `.externals` for details.
Added 1 recordings. See `.recordings` for details.
Added 1 recordings. See `.recordings` for details.
Added 1 recordings. See `.recordings` for details.
s = jx.integrate(net)
fig, ax = plt.subplots(1, 1, figsize=(4, 2))
t = np.arange(s.shape[1] - 1) * 0.025  # Time axis in ms (dt = 0.025 ms).
_ = ax.plot(t, s.T[:-1])
_ = ax.set_ylim([-80, 50])
_ = ax.set_xlabel("Time (ms)")
_ = ax.set_ylabel("Voltage (mV)")

That’s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!

This tutorial does not have an immediate follow-up tutorial. If you have not done so already, you can check out our tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.