Skip to content

Building and using ion channel models

In this tutorial, you will learn how to:

  • define your own ion channel models beyond the preconfigured channels in Jaxley

This tutorial assumes that you have already learned how to build basic simulations.

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx

First, we define a cell as you saw in the previous tutorial:

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])

You have also already learned how to insert preconfigured channels into Jaxley models:

cell.insert(Na())
cell.insert(K())
cell.insert(Leak())

In this tutorial, we will show you how to build your own channel and synapse models.

Your own channel

Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.

import jax.numpy as jnp
from jaxley.channels import Channel
from jaxley.solver_gate import solve_gate_exponential


def exp_update_alpha(x, y):
    return x / (jnp.exp(x / y) - 1.0)

class Potassium(Channel):
    """Potassium channel."""

    def __init__(self, name = None):
        super().__init__(name)
        self.channel_params = {"gK_new": 1e-4}
        self.channel_states = {"n_new": 0.0}
        self.current_name = "i_K"

    def update_states(self, states, dt, v, params):
        """Update state."""
        ns = states["n_new"]
        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)
        beta = 0.125 * jnp.exp(-(v + 65) / 80)
        new_n = solve_gate_exponential(ns, dt, alpha, beta)
        return {"n_new": new_n}

    def compute_current(self, states, v, params):
        """Return current."""
        ns = states["n_new"]

        # Multiply with 1000 to convert Siemens to milli Siemens.
        kd_conds = params["gK_new"] * ns**4 * 1000  # mS/cm^2

        e_kd = -77.0        
        return kd_conds * (v - e_kd)

    def init_state(self, v, params):
        alpha = 0.01 * exp_update_alpha(-(v + 55), 10)
        beta = 0.125 * jnp.exp(-(v + 65) / 80)
        return {"n_new": alpha / (alpha + beta)}

Let’s look at each part of this in detail.

The below is simply a helper function for the solver of the gate variables:

def exp_update_alpha(x, y):
    return x / (jnp.exp(x / y) - 1.0)

Next, we define our channel as a class. It should inherit from the Channel class and define channel_params, channel_states, and current_name.

class Potassium(Channel):
    """Potassium channel."""

    def __init__(self, name=None):
        super().__init__(name)
        self.channel_params = {"gK_new": 1e-4}
        self.channel_states = {"n_new": 0.0}
        self.current_name = "i_K"

Next, we have the update_states() method, which updates the gating variables:

    def update_states(self, states, dt, v, params):

The inputs states to the update_states method is a dictionary which contains all states that are updated (including states of other channels). v is a jnp.ndarray which contains the voltage of a single compartment (shape ()). Let’s get the state of the potassium channel which we are building here:

ns = states["n_new"]

Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:

alpha = 0.01 * exp_update_alpha(-(v + 55), 10)
beta = 0.125 * jnp.exp(-(v + 65) / 80)
new_n = solve_gate_exponential(ns, dt, alpha, beta)
return {"n_new": new_n}

A channel also needs a compute_current() method which returns the current throught the channel:

    def compute_current(self, states, v, params):
        ns = states["n_new"]

        # Multiply with 1000 to convert Siemens to milli Siemens.
        kd_conds = params["gK_new"] * ns**4 * 1000  # mS/cm^2

        e_kd = -77.0        
        current = kd_conds * (v - e_kd)
        return current

Finally, the init_state() method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states() is run.

Alright, done! We can now insert this channel into any jx.Module such as our cell:

cell.insert(Potassium())
cell.delete_stimuli()
current = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)
cell.branch(0).comp(0).stimulate(current)

cell.delete_recordings()
cell.branch(0).comp(0).record()
Added 1 external_states. See `.externals` for details.
Added 1 recordings. See `.recordings` for details.
s = jx.integrate(cell)
fig, ax = plt.subplots(1, 1, figsize=(4, 2))
_ = ax.plot(s.T[:-1])
_ = ax.set_ylim([-80, 50])
_ = ax.set_xlabel("Time (ms)")
_ = ax.set_ylabel("Voltage (mV)")

png

Your own synapse

The parts below assume that you have already learned how to build network simulations in Jaxley.

The below is an example of how to define your own synapse model in Jaxley:`

import jax.numpy as jnp
from jaxley.synapses.synapse import Synapse


class TestSynapse(Synapse):
    """
    Compute syanptic current and update syanpse state.
    """
    def __init__(self, name = None):
        super().__init__(name)
        self.synapse_params = {"gChol": 0.001, "eChol": 0.0}
        self.synapse_states = {"s_chol": 0.1}

    def update_states(self, states, delta_t, pre_voltage, post_voltage, params):
        """Return updated synapse state and current."""
        s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))
        exp_term = jnp.exp(-delta_t)
        new_s = states["s_chol"] * exp_term + s_inf * (1.0 - exp_term)
        return {"s_chol": new_s}

    def compute_current(self, states, pre_voltage, post_voltage, params):
        g_syn = params["gChol"] * states["s_chol"]
        return g_syn * (post_voltage - params["eChol"])

As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current method takes two voltages: the pre-synaptic voltage (a jnp.ndarray of shape ()) and the post-synaptic voltage (a jnp.ndarray of shape ()).

net = jx.Network([cell for _ in range(3)])
from jaxley.connect import connect

pre = net.cell(0).branch(0).loc(0.0)
post = net.cell(1).branch(0).loc(0.0)
connect(pre, post, TestSynapse())
/Users/michaeldeistler/Documents/phd/jaxley/jaxley/modules/base.py:1533: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.
  self.pointer.edges = pd.concat(
net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))
for i in range(3):
    net.cell(i).branch(0).loc(0.0).record()
Added 1 external_states. See `.externals` for details.
Added 1 recordings. See `.recordings` for details.
Added 1 recordings. See `.recordings` for details.
Added 1 recordings. See `.recordings` for details.
s = jx.integrate(net)
fig, ax = plt.subplots(1, 1, figsize=(4, 2))
_ = ax.plot(s.T[:-1])
_ = ax.set_ylim([-80, 50])
_ = ax.set_xlabel("Time (ms)")
_ = ax.set_ylabel("Voltage (mV)")

png

That’s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!

This tutorial does not have an immediate follow-up tutorial. You could read the tutorial on groups, which allow to make your Jaxley simulations more elegant and convenient to interact with.

Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.