Optimization

TypeOptimizer

optax wrapper which allows different argument values for different params.

Source code in jaxley/optimize/optimizer.py
class TypeOptimizer:
    """`optax` wrapper which allows different argument values for different params."""

    def __init__(
        self,
        optimizer: Callable,
        optimizer_args: Dict[str, Any],
        opt_params: List[Dict[str, jnp.ndarray]],
    ):
        """Create the optimizers.

        This requires access to `opt_params` in order to know how many optimizers
        should be created. It creates `len(opt_params)` optimizers.

        Example usage:
        ```
        lrs = {"HH_gNa": 0.01, "radius": 1.0}
        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
        opt_state = optimizer.init(opt_params)
        ```

        ```
        optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
        optimizer = TypeOptimizer(
            lambda args: optax.sgd(args[0], momentum=args[1]),
            optimizer_args,
            opt_params
        )
        opt_state = optimizer.init(opt_params)
        ```

        Args:
            optimizer: A Callable that takes the per-parameter arguments (e.g.,
                the learning rate) and returns the `optax` optimizer to use.
            optimizer_args: The arguments for different kinds of parameters.
                Each item of the dictionary will be passed to the `Callable` passed to
                `optimizer`.
            opt_params: The parameters to be optimized. The exact values are not used,
                only the number of elements in the list and the key of each dict.
        """
        self.base_optimizer = optimizer

        self.optimizers = []
        for params in opt_params:
            names = list(params.keys())
            assert len(names) == 1, "Multiple parameters were added at once."
            name = names[0]
            optimizer = self.base_optimizer(optimizer_args[name])
            self.optimizers.append({name: optimizer})

    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:
        """Initialize the optimizers. Equivalent to `optax.optimizers.init()`."""
        opt_states = []
        for params, optimizer in zip(opt_params, self.optimizers):
            name = list(optimizer.keys())[0]
            opt_state = optimizer[name].init(params)
            opt_states.append(opt_state)
        return opt_states

    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:
        """Update the optimizers. Equivalent to `optax.optimizers.update()`."""
        all_updates = []
        new_opt_states = []
        for grad, state, opt in zip(gradient, opt_state, self.optimizers):
            name = list(opt.keys())[0]
            updates, new_opt_state = opt[name].update(grad, state)
            all_updates.append(updates)
            new_opt_states.append(new_opt_state)
        return all_updates, new_opt_states

__init__(optimizer, optimizer_args, opt_params)

Create the optimizers.

This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.

Example usage:

lrs = {"HH_gNa": 0.01, "radius": 1.0}
optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
opt_state = optimizer.init(opt_params)

optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
optimizer = TypeOptimizer(
    lambda args: optax.sgd(args[0], momentum=args[1]),
    optimizer_args,
    opt_params
)
opt_state = optimizer.init(opt_params)

Parameters:

    optimizer (Callable, required): A Callable that takes the per-parameter
        arguments (e.g., the learning rate) and returns the optax optimizer to
        use.
    optimizer_args (Dict[str, Any], required): The arguments for different
        kinds of parameters. Each item of the dictionary will be passed to the
        Callable passed to optimizer.
    opt_params (List[Dict[str, ndarray]], required): The parameters to be
        optimized. The exact values are not used, only the number of elements
        in the list and the key of each dict.
Source code in jaxley/optimize/optimizer.py
def __init__(
    self,
    optimizer: Callable,
    optimizer_args: Dict[str, Any],
    opt_params: List[Dict[str, jnp.ndarray]],
):
    """Create the optimizers.

    This requires access to `opt_params` in order to know how many optimizers
    should be created. It creates `len(opt_params)` optimizers.

    Example usage:
    ```
    lrs = {"HH_gNa": 0.01, "radius": 1.0}
    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
    opt_state = optimizer.init(opt_params)
    ```

    ```
    optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
    optimizer = TypeOptimizer(
        lambda args: optax.sgd(args[0], momentum=args[1]),
        optimizer_args,
        opt_params
    )
    opt_state = optimizer.init(opt_params)
    ```

    Args:
        optimizer: A Callable that takes the per-parameter arguments (e.g.,
            the learning rate) and returns the `optax` optimizer to use.
        optimizer_args: The arguments for different kinds of parameters.
            Each item of the dictionary will be passed to the `Callable` passed to
            `optimizer`.
        opt_params: The parameters to be optimized. The exact values are not used,
            only the number of elements in the list and the key of each dict.
    """
    self.base_optimizer = optimizer

    self.optimizers = []
    for params in opt_params:
        names = list(params.keys())
        assert len(names) == 1, "Multiple parameters were added at once."
        name = names[0]
        optimizer = self.base_optimizer(optimizer_args[name])
        self.optimizers.append({name: optimizer})

init(opt_params)

Initialize the optimizers. Equivalent to calling each optax optimizer's init().

Source code in jaxley/optimize/optimizer.py
def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:
    """Initialize the optimizers. Equivalent to `optax.optimizers.init()`."""
    opt_states = []
    for params, optimizer in zip(opt_params, self.optimizers):
        name = list(optimizer.keys())[0]
        opt_state = optimizer[name].init(params)
        opt_states.append(opt_state)
    return opt_states

update(gradient, opt_state)

Update the optimizers. Equivalent to calling each optax optimizer's update().

Source code in jaxley/optimize/optimizer.py
def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:
    """Update the optimizers. Equivalent to `optax.optimizers.update()`."""
    all_updates = []
    new_opt_states = []
    for grad, state, opt in zip(gradient, opt_state, self.optimizers):
        name = list(opt.keys())[0]
        updates, new_opt_state = opt[name].update(grad, state)
        all_updates.append(updates)
        new_opt_states.append(new_opt_state)
    return all_updates, new_opt_states
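
Putting the pieces together, a minimal training-step sketch (the parameter names and gradient values below are hypothetical; optax.apply_updates applies the returned updates to the parameters):

import jax.numpy as jnp
import optax

from jaxley.optimize.optimizer import TypeOptimizer

# One single-key dict per parameter type, as TypeOptimizer expects.
opt_params = [{"HH_gNa": jnp.asarray([0.1])}, {"radius": jnp.asarray([1.0])}]

lrs = {"HH_gNa": 0.01, "radius": 1.0}
optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
opt_state = optimizer.init(opt_params)

# Gradients must mirror the structure of opt_params (e.g. obtained via jax.grad).
gradient = [{"HH_gNa": jnp.asarray([0.5])}, {"radius": jnp.asarray([-0.2])}]
updates, opt_state = optimizer.update(gradient, opt_state)
opt_params = optax.apply_updates(opt_params, updates)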

AffineTransform

Bases: Transform

Source code in jaxley/optimize/transforms.py
class AffineTransform(Transform):
    def __init__(self, scale: ArrayLike, shift: ArrayLike):
        """This transform rescales and shifts the input.

        Args:
            scale (ArrayLike): Scaling factor.
            shift (ArrayLike): Additive shift.

        Raises:
            ValueError: If `scale` is (close to) zero, since the transform must be invertible.
        """
        if jnp.allclose(scale, 0):
            raise ValueError("a cannot be zero, must be invertible")
        self.a = scale
        self.b = shift

    def forward(self, x: ArrayLike) -> Array:
        return self.a * x + self.b

    def inverse(self, x: ArrayLike) -> Array:
        return (x - self.b) / self.a

__init__(scale, shift)

This transform rescales and shifts the input.

Parameters:

    scale (ArrayLike, required): Scaling factor.
    shift (ArrayLike, required): Additive shift.

Raises:

    ValueError: If scale is (close to) zero, since the transform must be
        invertible.

Source code in jaxley/optimize/transforms.py
def __init__(self, scale: ArrayLike, shift: ArrayLike):
    """This transform rescales and shifts the input.

    Args:
        scale (ArrayLike): Scaling factor.
        shift (ArrayLike): Additive shift.

    Raises:
        ValueError: If `scale` is (close to) zero, since the transform must be invertible.
    """
    if jnp.allclose(scale, 0):
        raise ValueError("a cannot be zero, must be invertible")
    self.a = scale
    self.b = shift
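
A minimal round-trip sketch (the scale, shift, and input values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import AffineTransform

tf = AffineTransform(scale=2.0, shift=1.0)
x = jnp.asarray([0.0, 1.0, 2.0])
y = tf.forward(x)                      # 2 * x + 1 -> [1., 3., 5.]
assert jnp.allclose(tf.inverse(y), x)  # inverse undoes forward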

ChainTransform

Bases: Transform

Chaining together multiple transformations

Source code in jaxley/optimize/transforms.py
class ChainTransform(Transform):
    """Chaining together multiple transformations"""

    def __init__(self, transforms: Sequence[Transform]) -> None:
        """A chain of transformations

        Args:
            transforms (Sequence[Transform]): Transforms to apply
        """
        super().__init__()
        self.transforms = transforms

    def forward(self, x: ArrayLike) -> Array:
        for transform in self.transforms:
            x = transform(x)
        return x

    def inverse(self, y: ArrayLike) -> Array:
        for transform in reversed(self.transforms):
            y = transform.inverse(y)
        return y

__init__(transforms)

A chain of transformations

Parameters:

    transforms (Sequence[Transform], required): Transforms to apply.
Source code in jaxley/optimize/transforms.py
def __init__(self, transforms: Sequence[Transform]) -> None:
    """A chain of transformations

    Args:
        transforms (Sequence[Transform]): Transforms to apply
    """
    super().__init__()
    self.transforms = transforms
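
A hypothetical chain that first squashes values into (0, 1) and then rescales them to (-1, 1); forward applies the transforms in list order, inverse in reversed order:

import jax.numpy as jnp

from jaxley.optimize.transforms import (
    AffineTransform,
    ChainTransform,
    SigmoidTransform,
)

tf = ChainTransform([SigmoidTransform(0.0, 1.0), AffineTransform(2.0, -1.0)])
x = jnp.asarray([-3.0, 0.0, 3.0])
y = tf.forward(x)                      # every entry lies in (-1, 1)
assert jnp.allclose(tf.inverse(y), x)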

CustomTransform

Bases: Transform

Custom transformation

Source code in jaxley/optimize/transforms.py
class CustomTransform(Transform):
    """Custom transformation"""

    def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:
        """A custom transformation using a user-defined froward and
        inverse function

        Args:
            forward_fn (Callable): Forward transformation
            inverse_fn (Callable): Inverse transformation
        """
        super().__init__()
        self.forward_fn = forward_fn
        self.inverse_fn = inverse_fn

    def forward(self, x: ArrayLike) -> Array:
        return self.forward_fn(x)

    def inverse(self, y: ArrayLike) -> Array:
        return self.inverse_fn(y)

__init__(forward_fn, inverse_fn)

A custom transformation using user-defined forward and inverse functions.

Parameters:

    forward_fn (Callable, required): Forward transformation.
    inverse_fn (Callable, required): Inverse transformation.
Source code in jaxley/optimize/transforms.py
def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:
    """A custom transformation using a user-defined froward and
    inverse function

    Args:
        forward_fn (Callable): Forward transformation
        inverse_fn (Callable): Inverse transformation
    """
    super().__init__()
    self.forward_fn = forward_fn
    self.inverse_fn = inverse_fn
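
A minimal sketch using a log/exp pair as the user-defined bijection for positive-valued parameters (the values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import CustomTransform

tf = CustomTransform(forward_fn=jnp.log, inverse_fn=jnp.exp)
x = jnp.asarray([0.5, 1.0, 2.0])
assert jnp.allclose(tf.inverse(tf.forward(x)), x)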

MaskedTransform

Bases: Transform

Source code in jaxley/optimize/transforms.py
class MaskedTransform(Transform):
    def __init__(self, mask: ArrayLike, transform: Transform) -> None:
        """A masked transformation

        Args:
            mask (ArrayLike): Which elements to transform
            transform (Transform): Transformation to apply
        """
        super().__init__()
        self.mask = mask
        self.transform = transform

    def forward(self, x: ArrayLike) -> Array:
        return jnp.where(self.mask, self.transform.forward(x), x)

    def inverse(self, y: ArrayLike) -> Array:
        return jnp.where(self.mask, self.transform.inverse(y), y)

__init__(mask, transform)

A masked transformation

Parameters:

    mask (ArrayLike, required): Which elements to transform.
    transform (Transform, required): Transformation to apply.
Source code in jaxley/optimize/transforms.py
def __init__(self, mask: ArrayLike, transform: Transform) -> None:
    """A masked transformation

    Args:
        mask (ArrayLike): Which elements to transform
        transform (Transform): Transformation to apply
    """
    super().__init__()
    self.mask = mask
    self.transform = transform
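
A minimal sketch that rescales only the masked (first) entry and leaves the other untouched (the mask and values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import AffineTransform, MaskedTransform

mask = jnp.asarray([True, False])
tf = MaskedTransform(mask, AffineTransform(scale=10.0, shift=0.0))
y = tf.forward(jnp.asarray([1.0, 1.0]))  # -> [10., 1.]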

NegSoftplusTransform

Bases: SoftplusTransform

Negative softplus transformation.

Source code in jaxley/optimize/transforms.py
class NegSoftplusTransform(SoftplusTransform):
    """Negative softplus transformation."""

    def __init__(self, upper: ArrayLike) -> None:
        """This transform maps any value bijectively to the interval (-inf, upper].

        Args:
            upper (ArrayLike): Upper bound of the interval.
        """
        super().__init__(upper)

    def forward(self, x: ArrayLike) -> Array:
        return -super().forward(-x)

    def inverse(self, y: ArrayLike) -> Array:
        return -super().inverse(-y)

__init__(upper)

This transform maps any value bijectively to the interval (-inf, upper].

Parameters:

    upper (ArrayLike, required): Upper bound of the interval.
Source code in jaxley/optimize/transforms.py
def __init__(self, upper: ArrayLike) -> None:
    """This transform maps any value bijectively to the interval (-inf, upper].

    Args:
        upper (ArrayLike): Upper bound of the interval.
    """
    super().__init__(upper)
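
A minimal sketch bounding values from above, here by 0 (the values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import NegSoftplusTransform

tf = NegSoftplusTransform(upper=0.0)
x = jnp.asarray([-5.0, 0.0, 5.0])
y = tf.forward(x)                      # every entry is below 0
assert jnp.allclose(tf.inverse(y), x)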

ParamTransform

Parameter transformation utility.

This class is used to transform parameters, usually from an unconstrained space to a constrained space and back (because most biophysical parameters are bounded). The user can specify a PyTree of transforms that are applied to the parameters.

Attributes:

    tf_dict: A PyTree of transforms for each parameter.

Source code in jaxley/optimize/transforms.py
class ParamTransform:
    """Parameter transformation utility.

    This class is used to transform parameters, usually from an unconstrained space to a
    constrained space and back (because most biophysical parameters are bounded). The user
    can specify a PyTree of transforms that are applied to the parameters.

    Attributes:
        tf_dict: A PyTree of transforms for each parameter.

    """

    def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:
        """Creates a new ParamTransform object.

        Args:
            tf_dict: A PyTree of transforms for each parameter.
        """

        self.tf_dict = tf_dict

    def forward(
        self, params: List[Dict[str, ArrayLike]] | ArrayLike
    ) -> Dict[str, Array]:
        """Pushes unconstrained parameters through a tf such that they fit the interval.

        Args:
            params: A list of dictionaries (or any PyTree) with unconstrained parameters.

        Returns:
            A list of dictionaries (or any PyTree) with transformed parameters.

        """

        return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)

    def inverse(
        self, params: List[Dict[str, ArrayLike]] | ArrayLike
    ) -> Dict[str, Array]:
        """Takes parameters from within the interval and makes them unconstrained.

        Args:
            params: A list of dictionaries (or any PyTree) with transformed parameters.

        Returns:
            A list of dictionaries (or any PyTree) with unconstrained parameters.
        """

        return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)

__init__(tf_dict)

Creates a new ParamTransform object.

Parameters:

    tf_dict (List[Dict[str, Transform]] | Transform, required): A PyTree of
        transforms for each parameter.
Source code in jaxley/optimize/transforms.py
def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:
    """Creates a new ParamTransform object.

    Args:
        tf_dict: A PyTree of transforms for each parameter.
    """

    self.tf_dict = tf_dict

forward(params)

Pushes unconstrained parameters through a tf such that they fit the interval.

Parameters:

    params (List[Dict[str, ArrayLike]] | ArrayLike, required): A list of
        dictionaries (or any PyTree) with unconstrained parameters.

Returns:

    A list of dictionaries (or any PyTree) with transformed parameters.

Source code in jaxley/optimize/transforms.py
def forward(
    self, params: List[Dict[str, ArrayLike]] | ArrayLike
) -> Dict[str, Array]:
    """Pushes unconstrained parameters through a tf such that they fit the interval.

    Args:
        params: A list of dictionaries (or any PyTree) with unconstrained parameters.

    Returns:
        A list of dictionaries (or any PyTree) with transformed parameters.

    """

    return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)

inverse(params)

Takes parameters from within the interval and makes them unconstrained.

Parameters:

    params (List[Dict[str, ArrayLike]] | ArrayLike, required): A list of
        dictionaries (or any PyTree) with transformed parameters.

Returns:

    A list of dictionaries (or any PyTree) with unconstrained parameters.

Source code in jaxley/optimize/transforms.py
def inverse(
    self, params: List[Dict[str, ArrayLike]] | ArrayLike
) -> Dict[str, Array]:
    """Takes parameters from within the interval and makes them unconstrained.

    Args:
        params: A list of dictionaries (or any PyTree) with transformed parameters.

    Returns:
        A list of dictionaries (or any PyTree) with unconstrained parameters.
    """

    return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)
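
A minimal end-to-end sketch (the parameter names, bounds, and values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import (
    ParamTransform,
    SigmoidTransform,
    SoftplusTransform,
)

tf_dict = [
    {"HH_gNa": SigmoidTransform(lower=0.0, upper=1.0)},
    {"radius": SoftplusTransform(lower=0.1)},
]
transform = ParamTransform(tf_dict)

unconstrained = [{"HH_gNa": jnp.asarray([0.5])}, {"radius": jnp.asarray([-2.0])}]
constrained = transform.forward(unconstrained)  # values respect the bounds above
recovered = transform.inverse(constrained)      # round-trips to unconstrained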

SigmoidTransform

Bases: Transform

Sigmoid transformation.

Source code in jaxley/optimize/transforms.py
class SigmoidTransform(Transform):
    """Sigmoid transformation."""

    def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:
        """This transform maps any value bijectively to the interval [lower, upper].

        Args:
            lower (ArrayLike): Lower bound of the interval.
            upper (ArrayLike): Upper bound of the interval.
        """
        super().__init__()
        self.lower = lower
        self.width = upper - lower

    def forward(self, x: ArrayLike) -> Array:
        y = 1.0 / (1.0 + save_exp(-x))
        return self.lower + self.width * y

    def inverse(self, y: ArrayLike) -> Array:
        x = (y - self.lower) / self.width
        x = -jnp.log((1.0 / x) - 1.0)
        return x

__init__(lower, upper)

This transform maps any value bijectively to the interval [lower, upper].

Parameters:

    lower (ArrayLike, required): Lower bound of the interval.
    upper (ArrayLike, required): Upper bound of the interval.
Source code in jaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:
    """This transform maps any value bijectively to the interval [lower, upper].

    Args:
        lower (ArrayLike): Lower bound of the interval.
        upper (ArrayLike): Upper bound of the interval.
    """
    super().__init__()
    self.lower = lower
    self.width = upper - lower
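
A minimal round-trip sketch (the bounds and values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import SigmoidTransform

tf = SigmoidTransform(lower=0.0, upper=1.0)
x = jnp.asarray([-4.0, 0.0, 4.0])
y = tf.forward(x)                      # squashed into (0, 1)
assert jnp.allclose(tf.inverse(y), x)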

SoftplusTransform

Bases: Transform

Softplus transformation.

Source code in jaxley/optimize/transforms.py
class SoftplusTransform(Transform):
    """Softplus transformation."""

    def __init__(self, lower: ArrayLike) -> None:
        """This transform maps any value bijectively to the interval [lower, inf).

        Args:
            lower (ArrayLike): Lower bound of the interval.
        """
        super().__init__()
        self.lower = lower

    def forward(self, x: ArrayLike) -> Array:
        return jnp.log1p(save_exp(x)) + self.lower

    def inverse(self, y: ArrayLike) -> Array:
        return jnp.log(save_exp(y - self.lower) - 1.0)

__init__(lower)

This transform maps any value bijectively to the interval [lower, inf).

Parameters:

    lower (ArrayLike, required): Lower bound of the interval.
Source code in jaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike) -> None:
    """This transform maps any value bijectively to the interval [lower, inf).

    Args:
        lower (ArrayLike): Lower bound of the interval.
    """
    super().__init__()
    self.lower = lower
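
A minimal sketch bounding values from below (the bound and values are hypothetical):

import jax.numpy as jnp

from jaxley.optimize.transforms import SoftplusTransform

tf = SoftplusTransform(lower=0.1)
x = jnp.asarray([-3.0, 0.0, 3.0])
y = tf.forward(x)                      # every entry exceeds 0.1
assert jnp.allclose(tf.inverse(y), x)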