Skip to content

Channels

Channel

Channel base class. All channels inherit from this class.

As in NEURON, a Channel is considered a distributed process, which means that its conductances are to be specified in S/cm2 and its currents are to be specified in uA/cm2.

Source code in jaxley/channels/channel.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Channel:
    """Channel base class. All channels inherit from this class.

    As in NEURON, a `Channel` is considered a distributed process, which means that its
    conductances are to be specified in `S/cm2` and its currents are to be specified in
    `uA/cm2`.

    Subclasses fill in `channel_params`, `channel_states` and `current_name` in their
    constructor and override `update_states` and `compute_current`.
    """

    _name = None
    channel_params = None
    channel_states = None
    current_name = None

    def __init__(self, name: Optional[str] = None):
        """Store the channel name.

        Args:
            name: Name of the channel. If it is `None` (or empty), the class name is
                used instead.
        """
        self._name = name if name else self.__class__.__name__

    @property
    def name(self) -> Optional[str]:
        """The name of the channel (by default, this is the class name)."""
        return self._name

    @staticmethod
    def _rename_keys(mapping, old_prefix: str, new_prefix: str):
        """Return a copy of `mapping` with `old_prefix` swapped for `new_prefix`.

        Keys that do not start with `old_prefix` are kept as they are. `None` is passed
        through, so channels that never declared the dictionary can still be renamed.
        """
        if mapping is None:
            return None
        return {
            (
                new_prefix + key[len(old_prefix) :]
                if key.startswith(old_prefix)
                else key
            ): value
            for key, value in mapping.items()
        }

    def change_name(self, new_name: str) -> "Channel":
        """Change the channel name.

        All keys of `channel_params` and `channel_states` that carry the old name as
        prefix are renamed accordingly; other keys are left untouched.

        Args:
            new_name: The new name of the channel.

        Returns:
            Renamed channel, such that this function is chainable.
        """
        old_prefix = self._name + "_"
        new_prefix = new_name + "_"

        self._name = new_name
        self.channel_params = self._rename_keys(
            self.channel_params, old_prefix, new_prefix
        )
        self.channel_states = self._rename_keys(
            self.channel_states, old_prefix, new_prefix
        )
        return self

    def update_states(self, states, dt, v, params) -> Dict[str, jnp.ndarray]:
        """Return the updated states.

        Args:
            states: All states of the compartment.
            dt: Time step.
            v: Voltage of the compartment in mV.
            params: Parameters of the channel.

        Returns:
            Dictionary with the updated channel states.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Given channel states and voltage, return the current through the channel.

        Args:
            states: All states of the compartment.
            v: Voltage of the compartment in mV.
            params: Parameters of the channel (conductances in `S/cm2`).

        Returns:
            Current in `uA/cm2`.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

name: Optional[str] property

The name of the channel (by default, this is the class name).

change_name(new_name)

Change the channel name.

Parameters:

Name Type Description Default
new_name str

The new name of the channel.

required

Returns:

Type Description

Renamed channel, such that this function is chainable.

Source code in jaxley/channels/channel.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def change_name(self, new_name: str):
    """Change the channel name.

    All keys of `channel_params` and `channel_states` that carry the old name as
    prefix are renamed accordingly; other keys are left untouched.

    Args:
        new_name: The new name of the channel.

    Returns:
        Renamed channel, such that this function is chainable.
    """
    old_prefix = self._name + "_"
    new_prefix = new_name + "_"

    def _renamed(mapping):
        # `None` means the dictionary was never declared; there is nothing to rename.
        if mapping is None:
            return None
        return {
            (
                new_prefix + key[len(old_prefix) :]
                if key.startswith(old_prefix)
                else key
            ): value
            for key, value in mapping.items()
        }

    self._name = new_name
    self.channel_params = _renamed(self.channel_params)
    self.channel_states = _renamed(self.channel_states)
    return self

compute_current(states, v, params)

Given channel states and voltage, return the current through the channel.

Parameters:

Name Type Description Default
states Dict[str, ndarray]

All states of the compartment.

required
v

Voltage of the compartment in mV.

required
params Dict[str, ndarray]

Parameters of the channel (conductances in S/cm2).

required

Returns:

Type Description

Current in uA/cm2.

Source code in jaxley/channels/channel.py
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Compute the current flowing through the channel.

    This is the abstract hook of the base class; concrete channels override it.

    Args:
        states: All states of the compartment.
        v: Voltage of the compartment in mV.
        params: Parameters of the channel (conductances in `S/cm2`).

    Returns:
        Current in `uA/cm2`.

    Raises:
        NotImplementedError: Always, since the base class defines no current.
    """
    raise NotImplementedError

update_states(states, dt, v, params)

Return the updated states.

Source code in jaxley/channels/channel.py
62
63
64
65
66
def update_states(self, states, dt, v, params) -> Dict[str, jnp.ndarray]:
    """Return the updated states.

    This is the abstract hook of the base class; concrete channels override it.

    Args:
        states: All states of the compartment.
        dt: Time step.
        v: Voltage of the compartment in mV.
        params: Parameters of the channel.

    Returns:
        Dictionary with the updated channel states.

    Raises:
        NotImplementedError: Always, since the base class defines no dynamics.
    """
    raise NotImplementedError

HH

Bases: Channel

Hodgkin-Huxley channel.

Source code in jaxley/channels/hh.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
class HH(Channel):
    """Hodgkin-Huxley channel with sodium, potassium and leak conductances."""

    def __init__(self, name: Optional[str] = None):
        """Set up default parameters and gate states, prefixed with the channel name."""
        super().__init__(name)
        pre = self._name
        defaults = (
            ("gNa", 0.12),
            ("gK", 0.036),
            ("gLeak", 0.0003),
            ("eNa", 50.0),
            ("eK", -77.0),
            ("eLeak", -54.3),
        )
        self.channel_params = {f"{pre}_{key}": value for key, value in defaults}
        self.channel_states = {f"{pre}_{gate}": 0.2 for gate in ("m", "h", "n")}
        self.current_name = "i_HH"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Advance the m, h and n gates by one step `dt` and return them."""
        pre = self._name
        key_m, key_h, key_n = f"{pre}_m", f"{pre}_h", f"{pre}_n"
        m, h, n = states[key_m], states[key_h], states[key_n]
        m_next = solve_gate_exponential(m, dt, *self.m_gate(v))
        h_next = solve_gate_exponential(h, dt, *self.h_gate(v))
        n_next = solve_gate_exponential(n, dt, *self.n_gate(v))
        return {key_m: m_next, key_h: h_next, key_n: n_next}

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the summed sodium, potassium and leak current at voltage `v`."""
        pre = self._name
        m, h, n = states[f"{pre}_m"], states[f"{pre}_h"], states[f"{pre}_n"]

        # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
        g_na = params[f"{pre}_gNa"] * (m**3) * h * 1000
        g_k = params[f"{pre}_gK"] * n**4 * 1000
        g_leak = params[f"{pre}_gLeak"] * 1000

        i_na = g_na * (v - params[f"{pre}_eNa"])
        i_k = g_k * (v - params[f"{pre}_eK"])
        i_leak = g_leak * (v - params[f"{pre}_eLeak"])
        return i_na + i_k + i_leak

    def init_state(self, v, params):
        """Return gate values `alpha / (alpha + beta)` evaluated at voltage `v`."""
        pre = self._name
        a_m, b_m = self.m_gate(v)
        a_h, b_h = self.h_gate(v)
        a_n, b_n = self.n_gate(v)
        return {
            f"{pre}_m": a_m / (a_m + b_m),
            f"{pre}_h": a_h / (a_h + b_h),
            f"{pre}_n": a_n / (a_n + b_n),
        }

    @staticmethod
    def m_gate(v):
        """Return `(alpha, beta)` of the m gate at voltage `v`."""
        a = 0.1 * _vtrap(-(v + 40), 10)
        b = 4.0 * save_exp(-(v + 65) / 18)
        return a, b

    @staticmethod
    def h_gate(v):
        """Return `(alpha, beta)` of the h gate at voltage `v`."""
        a = 0.07 * save_exp(-(v + 65) / 20)
        b = 1.0 / (save_exp(-(v + 35) / 10) + 1)
        return a, b

    @staticmethod
    def n_gate(v):
        """Return `(alpha, beta)` of the n gate at voltage `v`."""
        a = 0.01 * _vtrap(-(v + 55), 10)
        b = 0.125 * save_exp(-(v + 65) / 80)
        return a, b

compute_current(states, v, params)

Return current through HH channels.

Source code in jaxley/channels/hh.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the summed sodium, potassium and leak current at voltage `v`.

    Args:
        states: Must contain the `m`, `h` and `n` gates, prefixed with the channel name.
        v: Voltage of the compartment.
        params: Must contain the prefixed conductances and reversal potentials.
    """
    pre = self._name
    m, h, n = states[f"{pre}_m"], states[f"{pre}_h"], states[f"{pre}_n"]

    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    g_na = params[f"{pre}_gNa"] * (m**3) * h * 1000
    g_k = params[f"{pre}_gK"] * n**4 * 1000
    g_leak = params[f"{pre}_gLeak"] * 1000

    i_na = g_na * (v - params[f"{pre}_eNa"])
    i_k = g_k * (v - params[f"{pre}_eK"])
    i_leak = g_leak * (v - params[f"{pre}_eLeak"])
    return i_na + i_k + i_leak

init_state(v, params)

Initialize the state at the fixed point of the gate dynamics.

Source code in jaxley/channels/hh.py
66
67
68
69
70
71
72
73
74
75
76
def init_state(self, v, params):
    """Return gate values `alpha / (alpha + beta)` evaluated at voltage `v`.

    `params` is accepted for interface compatibility and not read here.
    """
    pre = self._name
    gates = {}
    a_m, b_m = self.m_gate(v)
    a_h, b_h = self.h_gate(v)
    a_n, b_n = self.n_gate(v)
    gates[f"{pre}_m"] = a_m / (a_m + b_m)
    gates[f"{pre}_h"] = a_h / (a_h + b_h)
    gates[f"{pre}_n"] = a_n / (a_n + b_n)
    return gates

update_states(states, dt, v, params)

Return updated HH channel state.

Source code in jaxley/channels/hh.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Advance the m, h and n gates by one step `dt` and return them.

    `params` is accepted for interface compatibility and not read here.
    """
    pre = self._name
    key_m, key_h, key_n = f"{pre}_m", f"{pre}_h", f"{pre}_n"
    m, h, n = states[key_m], states[key_h], states[key_n]
    m_next = solve_gate_exponential(m, dt, *self.m_gate(v))
    h_next = solve_gate_exponential(h, dt, *self.h_gate(v))
    n_next = solve_gate_exponential(n, dt, *self.n_gate(v))
    return {key_m: m_next, key_h: h_next, key_n: n_next}

Pospischil

Bases: Channel

Leak current

Source code in jaxley/channels/pospischil.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class Leak(Channel):
    """Leak current: a channel without any gating state."""

    def __init__(self, name: Optional[str] = None):
        """Set up the default conductance and reversal potential of the leak."""
        super().__init__(name)
        pre = self._name
        self.channel_params = dict(
            [
                (f"{pre}_gLeak", 1e-4),
                (f"{pre}_eLeak", -70.0),
            ]
        )
        self.channel_states = dict()
        self.current_name = f"i_{pre}"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Return an empty dictionary, since the leak has no state to update."""
        return dict()

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the leak current at voltage `v`."""
        pre = self._name
        # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
        conductance = params[f"{pre}_gLeak"] * 1000
        driving_force = v - params[f"{pre}_eLeak"]
        return conductance * driving_force

    def init_state(self, v, params):
        """Return an empty dictionary, since the leak has no state to initialize."""
        return dict()

compute_current(states, v, params)

Return current.

Source code in jaxley/channels/pospischil.py
58
59
60
61
62
63
64
65
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the leak current at voltage `v`.

    `states` is accepted for interface compatibility and not read here.
    """
    pre = self._name
    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    conductance = params[f"{pre}_gLeak"] * 1000
    driving_force = v - params[f"{pre}_eLeak"]
    return conductance * driving_force

update_states(states, dt, v, params)

No state to update.

Source code in jaxley/channels/pospischil.py
48
49
50
51
52
53
54
55
56
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Return an empty dictionary, since this channel has no state to update."""
    return dict()

Bases: Channel

Sodium channel

Source code in jaxley/channels/pospischil.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class Na(Channel):
    """Sodium channel with an `m` and an `h` gate."""

    def __init__(self, name: Optional[str] = None):
        """Set up default parameters and gate states of the sodium channel."""
        super().__init__(name)
        pre = self._name
        params = {f"{pre}_gNa": 50e-3}
        # `eNa` and `vt` are global parameters and carry no channel prefix.
        params["eNa"] = 50.0
        params["vt"] = -60.0
        self.channel_params = params
        self.channel_states = {f"{pre}_{gate}": 0.2 for gate in ("m", "h")}
        self.current_name = "i_Na"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Advance the m and h gates by one step `dt` and return them."""
        key_m = f"{self._name}_m"
        key_h = f"{self._name}_h"
        m, h = states[key_m], states[key_h]
        m_next = solve_gate_exponential(m, dt, *self.m_gate(v, params["vt"]))
        h_next = solve_gate_exponential(h, dt, *self.h_gate(v, params["vt"]))
        return {key_m: m_next, key_h: h_next}

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the sodium current at voltage `v`."""
        pre = self._name
        m, h = states[f"{pre}_m"], states[f"{pre}_h"]

        # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
        g_na = params[f"{pre}_gNa"] * (m**3) * h * 1000
        return g_na * (v - params["eNa"])

    def init_state(self, v, params):
        """Return gate values `alpha / (alpha + beta)` evaluated at voltage `v`."""
        pre = self._name
        a_m, b_m = self.m_gate(v, params["vt"])
        a_h, b_h = self.h_gate(v, params["vt"])
        return {
            f"{pre}_m": a_m / (a_m + b_m),
            f"{pre}_h": a_h / (a_h + b_h),
        }

    @staticmethod
    def m_gate(v, vt):
        """Return `(alpha, beta)` of the m gate for voltage `v` and threshold `vt`."""
        dv_open = v - vt - 13.0
        a = 0.32 * efun(-0.25 * dv_open) / 0.25

        dv_close = v - vt - 40.0
        b = 0.28 * efun(0.2 * dv_close) / 0.2
        return a, b

    @staticmethod
    def h_gate(v, vt):
        """Return `(alpha, beta)` of the h gate for voltage `v` and threshold `vt`."""
        dv_open = v - vt - 17.0
        a = 0.128 * save_exp(-dv_open / 18.0)

        dv_close = v - vt - 40.0
        b = 4.0 / (save_exp(-dv_close / 5.0) + 1.0)
        return a, b

compute_current(states, v, params)

Return current.

Source code in jaxley/channels/pospischil.py
 99
100
101
102
103
104
105
106
107
108
109
110
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the sodium current at voltage `v`.

    Reads the prefixed `m` and `h` gates, the prefixed `gNa` and the global `eNa`.
    """
    pre = self._name
    m, h = states[f"{pre}_m"], states[f"{pre}_h"]

    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    g_na = params[f"{pre}_gNa"] * (m**3) * h * 1000
    return g_na * (v - params["eNa"])

init_state(v, params)

Initialize the state at the fixed point of the gate dynamics.

Source code in jaxley/channels/pospischil.py
112
113
114
115
116
117
118
119
120
def init_state(self, v, params):
    """Return gate values `alpha / (alpha + beta)` evaluated at voltage `v`.

    The gate functions are evaluated with the global threshold `params["vt"]`.
    """
    pre = self._name
    gates = {}
    a_m, b_m = self.m_gate(v, params["vt"])
    a_h, b_h = self.h_gate(v, params["vt"])
    gates[f"{pre}_m"] = a_m / (a_m + b_m)
    gates[f"{pre}_h"] = a_h / (a_h + b_h)
    return gates

update_states(states, dt, v, params)

Update state.

Source code in jaxley/channels/pospischil.py
85
86
87
88
89
90
91
92
93
94
95
96
97
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Advance the m and h gates by one step `dt` and return them.

    The gate functions are evaluated with the global threshold `params["vt"]`.
    """
    key_m = f"{self._name}_m"
    key_h = f"{self._name}_h"
    m, h = states[key_m], states[key_h]
    m_next = solve_gate_exponential(m, dt, *self.m_gate(v, params["vt"]))
    h_next = solve_gate_exponential(h, dt, *self.h_gate(v, params["vt"]))
    return {key_m: m_next, key_h: h_next}

Bases: Channel

Potassium channel

Source code in jaxley/channels/pospischil.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class K(Channel):
    """Potassium channel with a single `n` gate."""

    def __init__(self, name: Optional[str] = None):
        """Set up default parameters and the gate state of the potassium channel."""
        super().__init__(name)
        pre = self._name
        params = {f"{pre}_gK": 5e-3}
        # `eK` and `vt` are global parameters and carry no channel prefix.
        params["eK"] = -90.0
        params["vt"] = -60.0
        self.channel_params = params
        self.channel_states = {f"{pre}_n": 0.2}
        self.current_name = "i_K"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Advance the n gate by one step `dt` and return it."""
        key_n = f"{self._name}_n"
        n_next = solve_gate_exponential(
            states[key_n], dt, *self.n_gate(v, params["vt"])
        )
        return {key_n: n_next}

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the potassium current at voltage `v`."""
        pre = self._name
        n = states[f"{pre}_n"]

        # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
        g_k = params[f"{pre}_gK"] * (n**4) * 1000
        driving_force = v - params["eK"]
        return g_k * driving_force

    def init_state(self, v, params):
        """Return the gate value `alpha / (alpha + beta)` evaluated at voltage `v`."""
        a_n, b_n = self.n_gate(v, params["vt"])
        return {f"{self._name}_n": a_n / (a_n + b_n)}

    @staticmethod
    def n_gate(v, vt):
        """Return `(alpha, beta)` of the n gate for voltage `v` and threshold `vt`."""
        dv_open = v - vt - 15.0
        a = 0.032 * efun(-0.2 * dv_open) / 0.2

        dv_close = v - vt - 10.0
        b = 0.5 * save_exp(-dv_close / 40.0)
        return a, b

compute_current(states, v, params)

Return current.

Source code in jaxley/channels/pospischil.py
168
169
170
171
172
173
174
175
176
177
178
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the potassium current at voltage `v`.

    Reads the prefixed `n` gate, the prefixed `gK` and the global `eK`.
    """
    pre = self._name
    n = states[f"{pre}_n"]

    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    g_k = params[f"{pre}_gK"] * (n**4) * 1000
    driving_force = v - params["eK"]
    return g_k * driving_force

init_state(v, params)

Initialize the state at the fixed point of the gate dynamics.

Source code in jaxley/channels/pospischil.py
180
181
182
183
184
def init_state(self, v, params):
    """Return the gate value `alpha / (alpha + beta)` evaluated at voltage `v`.

    The gate function is evaluated with the global threshold `params["vt"]`.
    """
    a_n, b_n = self.n_gate(v, params["vt"])
    n_start = a_n / (a_n + b_n)
    return {f"{self._name}_n": n_start}

update_states(states, dt, v, params)

Update state.

Source code in jaxley/channels/pospischil.py
155
156
157
158
159
160
161
162
163
164
165
166
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Advance the n gate by one step `dt` and return it.

    The gate function is evaluated with the global threshold `params["vt"]`.
    """
    key_n = f"{self._name}_n"
    n_now = states[key_n]
    n_next = solve_gate_exponential(n_now, dt, *self.n_gate(v, params["vt"]))
    return {key_n: n_next}

Bases: Channel

Slow M Potassium channel

Source code in jaxley/channels/pospischil.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
class Km(Channel):
    """Slow M Potassium channel"""

    def __init__(self, name: Optional[str] = None):
        """Set up default parameters and the gate state of the M channel."""
        super().__init__(name)
        prefix = self._name
        self.channel_params = {
            f"{prefix}_gKm": 0.004e-3,
            f"{prefix}_taumax": 4000.0,
            "eK": -90.0,  # Global parameter, not prefixed with the channel name.
        }
        self.channel_states = {f"{prefix}_p": 0.2}
        self.current_name = "i_K"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Advance the p gate by one step `dt` and return it."""
        prefix = self._name
        p = states[f"{prefix}_p"]
        new_p = solve_inf_gate_exponential(
            p, dt, *self.p_gate(v, params[f"{prefix}_taumax"])
        )
        return {f"{prefix}_p": new_p}

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the M current at voltage `v`."""
        prefix = self._name
        p = states[f"{prefix}_p"]

        # Multiply with 1000 to convert Siemens to milli Siemens.
        gKm = params[f"{prefix}_gKm"] * p * 1000  # mS/cm^2
        return gKm * (v - params["eK"])

    def init_state(self, v, params):
        """Initialize the state at the fixed point of the gate dynamics.

        Args:
            v: Voltage of the compartment.
            params: Must contain the prefixed `taumax`.

        Returns:
            Dictionary holding the steady-state value of the p gate.
        """
        prefix = self._name
        # `p_gate` returns `(p_inf, tau_p)` rather than an `(alpha, beta)` pair, so
        # the steady-state value is `p_inf` itself.
        p_inf, _ = self.p_gate(v, params[f"{prefix}_taumax"])
        return {f"{prefix}_p": p_inf}

    @staticmethod
    def p_gate(v, taumax):
        """Return `(p_inf, tau_p)` of the p gate at voltage `v`."""
        v_p = v + 35.0
        p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))

        tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))

        return p_inf, tau_p

compute_current(states, v, params)

Return current.

Source code in jaxley/channels/pospischil.py
225
226
227
228
229
230
231
232
233
234
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the M current at voltage `v`.

    Reads the prefixed `p` gate, the prefixed `gKm` and the global `eK`.
    """
    pre = self._name
    gate = states[f"{pre}_p"]

    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    conductance = params[f"{pre}_gKm"] * gate * 1000
    driving_force = v - params["eK"]
    return conductance * driving_force

init_state(v, params)

Initialize the state at the fixed point of the gate dynamics.

Source code in jaxley/channels/pospischil.py
236
237
238
239
240
def init_state(self, v, params):
    """Initialize the state at the fixed point of the gate dynamics.

    Args:
        v: Voltage of the compartment.
        params: Must contain the prefixed `taumax`.

    Returns:
        Dictionary holding the steady-state value of the p gate.
    """
    prefix = self._name
    # `p_gate` returns `(p_inf, tau_p)` rather than an `(alpha, beta)` pair, so the
    # steady-state value is `p_inf` itself.
    p_inf, _ = self.p_gate(v, params[f"{prefix}_taumax"])
    return {f"{prefix}_p": p_inf}

update_states(states, dt, v, params)

Update state.

Source code in jaxley/channels/pospischil.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Advance the p gate by one step `dt` and return it.

    The gate function is evaluated with the prefixed `taumax` parameter.
    """
    key_p = f"{self._name}_p"
    p_now = states[key_p]
    p_inf, tau_p = self.p_gate(v, params[f"{self._name}_taumax"])
    return {key_p: solve_inf_gate_exponential(p_now, dt, p_inf, tau_p)}

Bases: Channel

L-type Calcium channel

Source code in jaxley/channels/pospischil.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
class CaL(Channel):
    """L-type Calcium channel with a `q` and an `r` gate."""

    def __init__(self, name: Optional[str] = None):
        """Set up default parameters and gate states of the L-type channel."""
        super().__init__(name)
        pre = self._name
        params = {f"{pre}_gCaL": 0.1e-3}
        # `eCa` is a global parameter and carries no channel prefix.
        params["eCa"] = 120.0
        self.channel_params = params
        self.channel_states = {f"{pre}_{gate}": 0.2 for gate in ("q", "r")}
        self.current_name = "i_Ca"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Advance the q and r gates by one step `dt` and return them."""
        key_q = f"{self._name}_q"
        key_r = f"{self._name}_r"
        q, r = states[key_q], states[key_r]
        q_next = solve_gate_exponential(q, dt, *self.q_gate(v))
        r_next = solve_gate_exponential(r, dt, *self.r_gate(v))
        return {key_q: q_next, key_r: r_next}

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the L-type calcium current at voltage `v`."""
        pre = self._name
        q, r = states[f"{pre}_q"], states[f"{pre}_r"]

        # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
        g_cal = params[f"{pre}_gCaL"] * (q**2) * r * 1000
        driving_force = v - params["eCa"]
        return g_cal * driving_force

    def init_state(self, v, params):
        """Return gate values `alpha / (alpha + beta)` evaluated at voltage `v`."""
        pre = self._name
        a_q, b_q = self.q_gate(v)
        a_r, b_r = self.r_gate(v)
        return {
            f"{pre}_q": a_q / (a_q + b_q),
            f"{pre}_r": a_r / (a_r + b_r),
        }

    @staticmethod
    def q_gate(v):
        """Return `(alpha, beta)` of the q gate at voltage `v`."""
        dv_open = -v - 27.0
        a = 0.055 * efun(dv_open / 3.8) * 3.8

        dv_close = -v - 75.0
        b = 0.94 * save_exp(dv_close / 17.0)
        return a, b

    @staticmethod
    def r_gate(v):
        """Return `(alpha, beta)` of the r gate at voltage `v`."""
        dv_open = -v - 13.0
        a = 0.000457 * save_exp(dv_open / 50)

        dv_close = -v - 15.0
        b = 0.0065 / (save_exp(dv_close / 28.0) + 1)
        return a, b

compute_current(states, v, params)

Return current.

Source code in jaxley/channels/pospischil.py
279
280
281
282
283
284
285
286
287
288
289
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the L-type calcium current at voltage `v`.

    Reads the prefixed `q` and `r` gates, the prefixed `gCaL` and the global `eCa`.
    """
    pre = self._name
    q, r = states[f"{pre}_q"], states[f"{pre}_r"]

    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    g_cal = params[f"{pre}_gCaL"] * (q**2) * r * 1000
    driving_force = v - params["eCa"]
    return g_cal * driving_force

init_state(v, params)

Initialize the state at the fixed point of the gate dynamics.

Source code in jaxley/channels/pospischil.py
291
292
293
294
295
296
297
298
299
def init_state(self, v, params):
    """Return gate values `alpha / (alpha + beta)` evaluated at voltage `v`.

    `params` is accepted for interface compatibility and not read here.
    """
    pre = self._name
    gates = {}
    a_q, b_q = self.q_gate(v)
    a_r, b_r = self.r_gate(v)
    gates[f"{pre}_q"] = a_q / (a_q + b_q)
    gates[f"{pre}_r"] = a_r / (a_r + b_r)
    return gates

update_states(states, dt, v, params)

Update state.

Source code in jaxley/channels/pospischil.py
265
266
267
268
269
270
271
272
273
274
275
276
277
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Advance the q and r gates by one step `dt` and return them.

    `params` is accepted for interface compatibility and not read here.
    """
    key_q = f"{self._name}_q"
    key_r = f"{self._name}_r"
    q, r = states[key_q], states[key_r]
    q_next = solve_gate_exponential(q, dt, *self.q_gate(v))
    r_next = solve_gate_exponential(r, dt, *self.r_gate(v))
    return {key_q: q_next, key_r: r_next}

Bases: Channel

T-type Calcium channel

Source code in jaxley/channels/pospischil.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
class CaT(Channel):
    """T-type Calcium channel"""

    def __init__(self, name: Optional[str] = None):
        """Set up default parameters and the gate state of the T-type channel."""
        super().__init__(name)
        prefix = self._name
        self.channel_params = {
            f"{prefix}_gCaT": 0.4e-4,
            f"{prefix}_vx": 2.0,
            "eCa": 120.0,  # Global parameter, not prefixed with `CaT`.
        }
        self.channel_states = {f"{prefix}_u": 0.2}
        self.current_name = "i_Ca"

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        dt,
        v,
        params: Dict[str, jnp.ndarray],
    ):
        """Advance the u gate by one step `dt` and return it."""
        prefix = self._name
        u = states[f"{prefix}_u"]
        new_u = solve_inf_gate_exponential(
            u, dt, *self.u_gate(v, params[f"{prefix}_vx"])
        )
        return {f"{prefix}_u": new_u}

    def compute_current(
        self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
    ):
        """Return the T-type calcium current at voltage `v`."""
        prefix = self._name
        u = states[f"{prefix}_u"]
        # The s gate is not a state: it is evaluated directly from the voltage.
        s_inf = 1.0 / (1.0 + save_exp(-(v + params[f"{prefix}_vx"] + 57.0) / 6.2))

        # Multiply with 1000 to convert Siemens to milli Siemens.
        gCaT = params[f"{prefix}_gCaT"] * (s_inf**2) * u * 1000  # mS/cm^2

        return gCaT * (v - params["eCa"])

    def init_state(self, v, params):
        """Initialize the state at the fixed point of the gate dynamics.

        Args:
            v: Voltage of the compartment.
            params: Must contain the prefixed `vx`.

        Returns:
            Dictionary holding the steady-state value of the u gate.
        """
        prefix = self._name
        # `u_gate` returns `(u_inf, tau_u)` rather than an `(alpha, beta)` pair, so
        # the steady-state value is `u_inf` itself.
        u_inf, _ = self.u_gate(v, params[f"{prefix}_vx"])
        return {f"{prefix}_u": u_inf}

    @staticmethod
    def u_gate(v, vx):
        """Return `(u_inf, tau_u)` of the u gate for voltage `v` and shift `vx`."""
        v_u1 = v + vx + 81.0
        u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))

        # NOTE(review): the 30.8 offset sits inside the numerator here; confirm
        # against the reference model that it should not be added outside the
        # quotient instead.
        tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (
            3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))
        )

        return u_inf, tau_u

compute_current(states, v, params)

Return current.

Source code in jaxley/channels/pospischil.py
349
350
351
352
353
354
355
356
357
358
359
360
def compute_current(
    self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
    """Return the T-type calcium current at voltage `v`.

    Reads the prefixed `u` gate, the prefixed `gCaT` and `vx`, and the global `eCa`.
    """
    pre = self._name
    gate_u = states[f"{pre}_u"]
    # The s gate is not a state: it is evaluated directly from the voltage.
    gate_s = 1.0 / (1.0 + save_exp(-(v + params[f"{pre}_vx"] + 57.0) / 6.2))

    # The factor 1000 converts Siemens to milli Siemens (mS/cm^2).
    g_cat = params[f"{pre}_gCaT"] * (gate_s**2) * gate_u * 1000
    driving_force = v - params["eCa"]
    return g_cat * driving_force

init_state(v, params)

Initialize the state at the fixed point of the gate dynamics.

Source code in jaxley/channels/pospischil.py
362
363
364
365
366
def init_state(self, v, params):
    """Initialize the state at the fixed point of the gate dynamics.

    Args:
        v: Voltage of the compartment.
        params: Must contain the prefixed `vx`.

    Returns:
        Dictionary holding the steady-state value of the u gate.
    """
    prefix = self._name
    # `u_gate` returns `(u_inf, tau_u)` rather than an `(alpha, beta)` pair, so the
    # steady-state value is `u_inf` itself.
    u_inf, _ = self.u_gate(v, params[f"{prefix}_vx"])
    return {f"{prefix}_u": u_inf}

update_states(states, dt, v, params)

Update state.

Source code in jaxley/channels/pospischil.py
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    dt,
    v,
    params: Dict[str, jnp.ndarray],
):
    """Advance the u gate by one step `dt` and return it.

    The gate function is evaluated with the prefixed `vx` parameter.
    """
    key_u = f"{self._name}_u"
    u_now = states[key_u]
    u_inf, tau_u = self.u_gate(v, params[f"{self._name}_vx"])
    return {key_u: solve_inf_gate_exponential(u_now, dt, u_inf, tau_u)}

Synapses

Synapse

Base class for a synapse.

As in NEURON, a Synapse is considered a point process, which means that its conductances are to be specified in uS and its currents are to be specified in nA.

Source code in jaxley/synapses/synapse.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
class Synapse:
    """Base class for a synapse.

    As in NEURON, a `Synapse` is considered a point process, which means that its
    conductances are to be specified in `uS` and its currents are to be specified in
    `nA`.

    Subclasses are expected to fill `synapse_params` and `synapse_states` and to
    override `update_states` and `compute_current`.
    """

    _name = None
    synapse_params = None
    synapse_states = None

    def __init__(self, name: Optional[str] = None):
        # Fall back to the class name when no (or an empty) name is given.
        self._name = name if name else self.__class__.__name__

    @property
    def name(self) -> Optional[str]:
        """The name of the synapse (by default, this is the class name)."""
        return self._name

    def change_name(self, new_name: str):
        """Change the synapse name.

        Every key of `synapse_params` and `synapse_states` that starts with the
        old name followed by an underscore is re-prefixed with the new name;
        all other keys are left unchanged.

        Args:
            new_name: The new name of the synapse.

        Returns:
            Renamed synapse, such that this function is chainable.
        """
        old_prefix = self._name + "_"
        new_prefix = new_name + "_"

        self._name = new_name
        self.synapse_params = {
            (
                new_prefix + key[len(old_prefix) :]
                if key.startswith(old_prefix)
                else key
            ): value
            for key, value in self.synapse_params.items()
        }

        self.synapse_states = {
            (
                new_prefix + key[len(old_prefix) :]
                if key.startswith(old_prefix)
                else key
            ): value
            for key, value in self.synapse_states.items()
        }
        return self

    def update_states(
        self,
        states: Dict[str, jnp.ndarray],
        delta_t: float,
        pre_voltage: jnp.ndarray,
        post_voltage: jnp.ndarray,
        params: Dict[str, jnp.ndarray],
    ) -> Dict[str, jnp.ndarray]:
        """ODE update step.

        Args:
            states: States of the synapse.
            delta_t: Time step in `ms`.
            pre_voltage: Voltage of the presynaptic compartment, shape `()`.
            post_voltage: Voltage of the postsynaptic compartment, shape `()`.
            params: Parameters of the synapse. Conductances in `uS`.

        Returns:
            Updated states.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def compute_current(
        self,
        states: Dict[str, jnp.ndarray],
        pre_voltage: jnp.ndarray,
        post_voltage: jnp.ndarray,
        params: Dict[str, jnp.ndarray],
    ) -> jnp.ndarray:
        """Return current through one synapse in `nA`.

        Internally, we use `jax.vmap` to vectorize this function across many synapses.

        Args:
            states: States of the synapse.
            pre_voltage: Voltage of the presynaptic compartment, shape `()`.
            post_voltage: Voltage of the postsynaptic compartment, shape `()`.
            params: Parameters of the synapse. Conductances in `uS`.

        Returns:
            Current through the synapse in `nA`, shape `()`.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

change_name(new_name)

Change the synapse name.

Parameters:

Name Type Description Default
new_name str

The new name of the synapse.

required

Returns:

Type Description

Renamed synapse, such that this function is chainable.

Source code in jaxley/synapses/synapse.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def change_name(self, new_name: str):
    """Rename the synapse and re-prefix its parameter and state keys.

    Keys of `synapse_params` and `synapse_states` that begin with the old name
    followed by an underscore get the new name as prefix instead; any other
    key is carried over unchanged.

    Args:
        new_name: The new name of the synapse.

    Returns:
        The renamed synapse itself, so that calls can be chained.
    """
    stale = self._name + "_"
    fresh = new_name + "_"

    def _rekey(key):
        # Keys that do not carry the old prefix are kept untouched.
        if not key.startswith(stale):
            return key
        return fresh + key[len(stale) :]

    self._name = new_name
    self.synapse_params = {_rekey(k): val for k, val in self.synapse_params.items()}
    self.synapse_states = {_rekey(k): val for k, val in self.synapse_states.items()}
    return self

compute_current(states, pre_voltage, post_voltage, params)

Return current through one synapse in nA.

Internally, we use jax.vmap to vectorize this function across many synapses.

Parameters:

Name Type Description Default
states Dict[str, ndarray]

States of the synapse.

required
pre_voltage ndarray

Voltage of the presynaptic compartment, shape ().

required
post_voltage ndarray

Voltage of the postsynaptic compartment, shape ().

required
params Dict[str, ndarray]

Parameters of the synapse. Conductances in uS.

required

Returns:

Type Description
ndarray

Current through the synapse in nA, shape ().

Source code in jaxley/synapses/synapse.py
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def compute_current(
    self,
    states: Dict[str, jnp.ndarray],
    pre_voltage: jnp.ndarray,
    post_voltage: jnp.ndarray,
    params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
    """Return current through one synapse in `nA`.

    Internally, we use `jax.vmap` to vectorize this function across many synapses.

    Args:
        states: States of the synapse.
        pre_voltage: Voltage of the presynaptic compartment, shape `()`.
        post_voltage: Voltage of the postsynaptic compartment, shape `()`.
        params: Parameters of the synapse. Conductances in `uS`.

    Returns:
        Current through the synapse in `nA`, shape `()`.

    Raises:
        NotImplementedError: Always; subclasses must override this method.
    """
    # `self` is required so that the method can be invoked on an instance and
    # so that the signature matches the overrides in the subclasses.
    raise NotImplementedError

update_states(states, delta_t, pre_voltage, post_voltage, params)

ODE update step.

Parameters:

Name Type Description Default
states Dict[str, ndarray]

States of the synapse.

required
delta_t float

Time step in ms.

required
pre_voltage ndarray

Voltage of the presynaptic compartment, shape ().

required
post_voltage ndarray

Voltage of the postsynaptic compartment, shape ().

required
params Dict[str, ndarray]

Parameters of the synapse. Conductances in uS.

required

Returns:

Type Description
Dict[str, ndarray]

Updated states.

Source code in jaxley/synapses/synapse.py
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def update_states(
    self,
    states: Dict[str, jnp.ndarray],
    delta_t: float,
    pre_voltage: jnp.ndarray,
    post_voltage: jnp.ndarray,
    params: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
    """ODE update step.

    Args:
        states: States of the synapse.
        delta_t: Time step in `ms`.
        pre_voltage: Voltage of the presynaptic compartment, shape `()`.
        post_voltage: Voltage of the postsynaptic compartment, shape `()`.
        params: Parameters of the synapse. Conductances in `uS`.

    Returns:
        Updated states.

    Raises:
        NotImplementedError: Always; subclasses must override this method.
    """
    # `self` is required so that the method can be invoked on an instance and
    # so that the signature matches the overrides in the subclasses.
    raise NotImplementedError

Ionotropic Synapse

Bases: Synapse

Compute synaptic current and update synapse state for a generic ionotropic synapse.

The synapse state “s” is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.

The synaptic parameters are
  • gS: the maximal conductance across the postsynaptic membrane (uS)
  • e_syn: the reversal potential across the postsynaptic membrane (mV)
  • k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic receptor (s^-1)
Details of this implementation can be found in the following book chapter

L. F. Abbott and E. Marder, “Modeling Small Networks,” in Methods in Neuronal Modeling, C. Koch and I. Segev, Eds. Cambridge: MIT Press, 1998.

Source code in jaxley/synapses/ionotropic.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class IonotropicSynapse(Synapse):
    """Generic ionotropic synapse with a single gating state.

    The synapse state "s" is the probability that a postsynaptic receptor channel is
    open, and this depends on the amount of neurotransmitter released, which is in turn
    dependent on the presynaptic voltage.

    The synaptic parameters are:
        - gS: the maximal conductance across the postsynaptic membrane (uS)
        - e_syn: the reversal potential across the postsynaptic membrane (mV)
        - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic
            receptor (s^-1)

    Details of this implementation can be found in the following book chapter:
        L. F. Abbott and E. Marder, "Modeling Small Networks," in Methods in Neuronal
        Modeling, C. Koch and I. Segev, Eds. Cambridge: MIT Press, 1998.
    """

    def __init__(self, name: Optional[str] = None):
        """Set up the default parameters and the initial state.

        Args:
            name: Optional synapse name; it prefixes every parameter and state key.
        """
        super().__init__(name)
        defaults = {"gS": 1e-4, "e_syn": 0.0, "k_minus": 0.025}
        self.synapse_params = {
            f"{self._name}_{key}": value for key, value in defaults.items()
        }
        self.synapse_states = {f"{self._name}_s": 0.2}

    def update_states(
        self,
        states: Dict,
        delta_t: float,
        pre_voltage: float,
        post_voltage: float,
        params: Dict,
    ) -> Dict:
        """Return the updated synapse state.

        The state relaxes exponentially towards a steady-state value that is a
        sigmoid of the presynaptic voltage.

        Args:
            states: Synapse states; the entry `"<name>_s"` is read.
            delta_t: Time step.
            pre_voltage: Voltage of the presynaptic compartment.
            post_voltage: Voltage of the postsynaptic compartment (unused here).
            params: Synapse parameters; the entry `"<name>_k_minus"` is read.

        Returns:
            Dictionary with the updated `"<name>_s"` entry.
        """
        name = self._name
        threshold = -35.0  # mV
        width = 10.0  # mV

        steady_state = 1.0 / (1.0 + save_exp((threshold - pre_voltage) / width))
        time_constant = (1.0 - steady_state) / params[f"{name}_k_minus"]

        rate = -1.0 / time_constant
        decay = save_exp(rate * delta_t)

        # Exponential relaxation from the old state towards the steady state.
        relaxed = states[f"{name}_s"] * decay + steady_state * (1.0 - decay)
        return {f"{name}_s": relaxed}

    def compute_current(
        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict
    ) -> float:
        """Return the synaptic current.

        Args:
            states: Synapse states; the entry `"<name>_s"` is read.
            pre_voltage: Voltage of the presynaptic compartment (unused here).
            post_voltage: Voltage of the postsynaptic compartment.
            params: Synapse parameters; `"<name>_gS"` and `"<name>_e_syn"` are read.

        Returns:
            `gS * s * (post_voltage - e_syn)`.
        """
        name = self._name
        conductance = params[f"{name}_gS"] * states[f"{name}_s"]
        driving_force = post_voltage - params[f"{name}_e_syn"]
        return conductance * driving_force

update_states(states, delta_t, pre_voltage, post_voltage, params)

Return the updated synapse state.

Source code in jaxley/synapses/ionotropic.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def update_states(
    self,
    states: Dict,
    delta_t: float,
    pre_voltage: float,
    post_voltage: float,
    params: Dict,
) -> Dict:
    """Return the updated synapse state.

    The state relaxes exponentially towards a steady-state value that is a
    sigmoid of the presynaptic voltage.

    Args:
        states: Synapse states; the entry `"<name>_s"` is read.
        delta_t: Time step.
        pre_voltage: Voltage of the presynaptic compartment.
        post_voltage: Voltage of the postsynaptic compartment (unused here).
        params: Synapse parameters; the entry `"<name>_k_minus"` is read.

    Returns:
        Dictionary with the updated `"<name>_s"` entry.
    """
    name = self._name
    threshold = -35.0  # mV
    width = 10.0  # mV

    steady_state = 1.0 / (1.0 + save_exp((threshold - pre_voltage) / width))
    time_constant = (1.0 - steady_state) / params[f"{name}_k_minus"]

    rate = -1.0 / time_constant
    decay = save_exp(rate * delta_t)

    # Exponential relaxation from the old state towards the steady state.
    relaxed = states[f"{name}_s"] * decay + steady_state * (1.0 - decay)
    return {f"{name}_s": relaxed}

TanH Rate Synapse

Bases: Synapse

Compute synaptic current for tanh synapse (no state).

Source code in jaxley/synapses/tanh_rate.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class TanhRateSynapse(Synapse):
    """Stateless synapse whose current is a tanh of the presynaptic voltage."""

    def __init__(self, name: Optional[str] = None):
        """Set up the default parameters; this synapse has no states.

        Args:
            name: Optional synapse name; it prefixes every parameter key.
        """
        super().__init__(name)
        defaults = {"gS": 1e-4, "x_offset": -70.0, "slope": 1.0}
        self.synapse_params = {
            f"{self._name}_{key}": value for key, value in defaults.items()
        }
        self.synapse_states = dict()

    def update_states(
        self,
        states: Dict,
        delta_t: float,
        pre_voltage: float,
        post_voltage: float,
        params: Dict,
    ) -> Dict:
        """Return the updated states, which are always empty for this synapse."""
        return dict()

    def compute_current(
        self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict
    ) -> float:
        """Return the synaptic current.

        Args:
            states: Synapse states (unused, the synapse is stateless).
            pre_voltage: Voltage of the presynaptic compartment.
            post_voltage: Voltage of the postsynaptic compartment (unused here).
            params: Synapse parameters; `"<name>_gS"`, `"<name>_x_offset"` and
                `"<name>_slope"` are read.

        Returns:
            `-gS * tanh((pre_voltage - x_offset) * slope)`.
        """
        name = self._name
        sign_and_gain = -1 * params[f"{name}_gS"]
        tanh_input = (pre_voltage - params[f"{name}_x_offset"]) * params[
            f"{name}_slope"
        ]
        return sign_and_gain * jnp.tanh(tanh_input)

compute_current(states, pre_voltage, post_voltage, params)

Return the synaptic current.

Source code in jaxley/synapses/tanh_rate.py
37
38
39
40
41
42
43
44
45
46
47
48
49
def compute_current(
    self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict
) -> float:
    """Return the synaptic current.

    Args:
        states: Synapse states (unused, the synapse is stateless).
        pre_voltage: Voltage of the presynaptic compartment.
        post_voltage: Voltage of the postsynaptic compartment (unused here).
        params: Synapse parameters; `"<name>_gS"`, `"<name>_x_offset"` and
            `"<name>_slope"` are read.

    Returns:
        `-gS * tanh((pre_voltage - x_offset) * slope)`.
    """
    name = self._name
    sign_and_gain = -1 * params[f"{name}_gS"]
    tanh_input = (pre_voltage - params[f"{name}_x_offset"]) * params[f"{name}_slope"]
    return sign_and_gain * jnp.tanh(tanh_input)

update_states(states, delta_t, pre_voltage, post_voltage, params)

Return the updated synapse states (always empty, as this synapse has no state).

Source code in jaxley/synapses/tanh_rate.py
26
27
28
29
30
31
32
33
34
35
def update_states(
    self,
    states: Dict,
    delta_t: float,
    pre_voltage: float,
    post_voltage: float,
    params: Dict,
) -> Dict:
    """Return the updated states, which are always empty for this synapse.

    All arguments are accepted for interface compatibility and ignored.
    """
    return dict()