Optimization

TypeOptimizer

optax wrapper which allows different argument values for different params.

Source code in jaxley/optimize/optimizer.py
class TypeOptimizer:
    """`optax` wrapper which allows different argument values for different params."""

    def __init__(
        self,
        optimizer: Callable,
        optimizer_args: Dict[str, Any],
        opt_params: List[Dict[str, jnp.ndarray]],
    ):
        """Create the optimizers.

        This requires access to `opt_params` in order to know how many optimizers
        should be created. It creates `len(opt_params)` optimizers.

        Example usage:
        ```
        lrs = {"HH_gNa": 0.01, "radius": 1.0}
        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
        opt_state = optimizer.init(opt_params)
        ```

        ```
        optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
        optimizer = TypeOptimizer(
            lambda args: optax.sgd(args[0], momentum=args[1]),
            optimizer_args,
            opt_params
        )
        opt_state = optimizer.init(opt_params)
        ```

        Args:
            optimizer: A Callable that takes the learning rate and returns the
                `optax.optimizer` which should be used.
            optimizer_args: The arguments for different kinds of parameters.
                Each item of the dictionary will be passed to the `Callable` passed to
                `optimizer`.
            opt_params: The parameters to be optimized. The exact values are not used,
                only the number of elements in the list and the key of each dict.
        """
        self.base_optimizer = optimizer

        self.optimizers = []
        for params in opt_params:
            names = list(params.keys())
            assert len(names) == 1, "Multiple parameters were added at once."
            name = names[0]
            optimizer = self.base_optimizer(optimizer_args[name])
            self.optimizers.append({name: optimizer})

    def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:
        """Initialize the optimizers. Equivalent to `optax.optimizers.init()`."""
        opt_states = []
        for params, optimizer in zip(opt_params, self.optimizers):
            name = list(optimizer.keys())[0]
            opt_state = optimizer[name].init(params)
            opt_states.append(opt_state)
        return opt_states

    def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:
        """Update the optimizers. Equivalent to `optax.optimizers.update()`."""
        all_updates = []
        new_opt_states = []
        for grad, state, opt in zip(gradient, opt_state, self.optimizers):
            name = list(opt.keys())[0]
            updates, new_opt_state = opt[name].update(grad, state)
            all_updates.append(updates)
            new_opt_states.append(new_opt_state)
        return all_updates, new_opt_states

__init__(optimizer, optimizer_args, opt_params)

Create the optimizers.

This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.

Example usage:

lrs = {"HH_gNa": 0.01, "radius": 1.0}
optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
opt_state = optimizer.init(opt_params)

optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
optimizer = TypeOptimizer(
    lambda args: optax.sgd(args[0], momentum=args[1]),
    optimizer_args,
    opt_params
)
opt_state = optimizer.init(opt_params)

Parameters:

optimizer (Callable, required): A Callable that takes the learning rate and returns the `optax` optimizer which should be used.

optimizer_args (Dict[str, Any], required): The arguments for different kinds of parameters. Each item of the dictionary will be passed to the `Callable` passed to `optimizer`.

opt_params (List[Dict[str, ndarray]], required): The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.
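
For orientation, a minimal sketch of the expected `opt_params` structure: a list of single-key dictionaries, one entry per parameter group (the constructor asserts exactly one key per entry). The values and shapes below are purely illustrative; in Jaxley such a list would normally be produced by the module's trainable-parameter machinery rather than written by hand.

```
import jax.numpy as jnp
import optax

from jaxley.optimize.optimizer import TypeOptimizer

# Illustrative opt_params: a list of single-key dicts, one per parameter group.
# Only the number of entries and the key of each dict matter to TypeOptimizer.
opt_params = [
    {"HH_gNa": jnp.asarray([0.12])},
    {"radius": jnp.asarray([1.0, 1.0, 1.0])},
]

# One learning rate per parameter key.
lrs = {"HH_gNa": 0.01, "radius": 1.0}
optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
opt_state = optimizer.init(opt_params)
```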
Source code in jaxley/optimize/optimizer.py
def __init__(
    self,
    optimizer: Callable,
    optimizer_args: Dict[str, Any],
    opt_params: List[Dict[str, jnp.ndarray]],
):
    """Create the optimizers.

    This requires access to `opt_params` in order to know how many optimizers
    should be created. It creates `len(opt_params)` optimizers.

    Example usage:
    ```
    lrs = {"HH_gNa": 0.01, "radius": 1.0}
    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
    opt_state = optimizer.init(opt_params)
    ```

    ```
    optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
    optimizer = TypeOptimizer(
        lambda args: optax.sgd(args[0], momentum=args[1]),
        optimizer_args,
        opt_params
    )
    opt_state = optimizer.init(opt_params)
    ```

    Args:
        optimizer: A Callable that takes the learning rate and returns the
            `optax.optimizer` which should be used.
        optimizer_args: The arguments for different kinds of parameters.
            Each item of the dictionary will be passed to the `Callable` passed to
            `optimizer`.
        opt_params: The parameters to be optimized. The exact values are not used,
            only the number of elements in the list and the key of each dict.
    """
    self.base_optimizer = optimizer

    self.optimizers = []
    for params in opt_params:
        names = list(params.keys())
        assert len(names) == 1, "Multiple parameters were added at once."
        name = names[0]
        optimizer = self.base_optimizer(optimizer_args[name])
        self.optimizers.append({name: optimizer})

init(opt_params)

Initialize the optimizers. Equivalent to optax.optimizers.init().

Source code in jaxley/optimize/optimizer.py
def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:
    """Initialize the optimizers. Equivalent to `optax.optimizers.init()`."""
    opt_states = []
    for params, optimizer in zip(opt_params, self.optimizers):
        name = list(optimizer.keys())[0]
        opt_state = optimizer[name].init(params)
        opt_states.append(opt_state)
    return opt_states

update(gradient, opt_state)

Update the optimizers. Equivalent to optax.optimizers.update().

Source code in jaxley/optimize/optimizer.py
def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:
    """Update the optimizers. Equivalent to `optax.optimizers.update()`."""
    all_updates = []
    new_opt_states = []
    for grad, state, opt in zip(gradient, opt_state, self.optimizers):
        name = list(opt.keys())[0]
        updates, new_opt_state = opt[name].update(grad, state)
        all_updates.append(updates)
        new_opt_states.append(new_opt_state)
    return all_updates, new_opt_states
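
A minimal sketch of how `update` fits into a training step, assuming the `optimizer`, `opt_state`, and `opt_params` from the example above. `loss_fn` is a hypothetical placeholder; the key point is that the gradient has the same list-of-single-key-dict structure as `opt_params`, and the returned updates are applied entry-wise with `optax.apply_updates`.

```
import jax
import jax.numpy as jnp
import optax

# Hypothetical loss over the list-of-dicts parameter structure.
def loss_fn(params):
    return sum(jnp.sum(leaf**2) for p in params for leaf in p.values())

# Gradients mirror the structure of opt_params.
gradient = jax.grad(loss_fn)(opt_params)

# One optimizer step: per-parameter updates and new optimizer states.
updates, opt_state = optimizer.update(gradient, opt_state)

# Apply each update to its corresponding single-key parameter dict.
opt_params = [optax.apply_updates(p, u) for p, u in zip(opt_params, updates)]
```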

ParamTransform

Parameter transformation utility.

This class is used to transform parameters from an unconstrained space to a constrained space and back. If the range is bounded both from above and below, we use the sigmoid function to transform the parameters. If the range is only bounded from below or above, we use softplus.

Attributes:

lowers: A dictionary of lower bounds for each parameter (None for no bound).
uppers: A dictionary of upper bounds for each parameter (None for no bound).
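
A minimal round-trip sketch, assuming the same single-key-dict parameter structure used by `TypeOptimizer` above. The bound values are illustrative: `radius` is constrained on both sides (sigmoid branch), `HH_gNa` only from below (softplus branch).

```
import jax.numpy as jnp

from jaxley.optimize.transforms import ParamTransform

# Illustrative bounds; None means no bound on that side.
lowers = {"radius": 0.1, "HH_gNa": 0.0}
uppers = {"radius": 5.0, "HH_gNa": None}
transform = ParamTransform(lowers, uppers)

# Unconstrained parameters, e.g. as produced by a gradient step.
unconstrained = [
    {"radius": jnp.asarray([-2.0, 3.0])},
    {"HH_gNa": jnp.asarray([0.5])},
]

# forward maps into the bounded space; inverse maps back to the original values.
constrained = transform.forward(unconstrained)
recovered = transform.inverse(constrained)
```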

Source code in jaxley/optimize/transforms.py
class ParamTransform:
    """Parameter transformation utility.

    This class is used to transform parameters from an unconstrained space to a constrained space
    and back. If the range is bounded both from above and below, we use the sigmoid function to
    transform the parameters. If the range is only bounded from below or above, we use softplus.

    Attributes:
        lowers: A dictionary of lower bounds for each parameter (None for no bound).
        uppers: A dictionary of upper bounds for each parameter (None for no bound).

    """

    def __init__(self, lowers: Dict[str, float], uppers: Dict[str, float]):
        """Initialize the bounds.

        Args:
            lowers: A dictionary of lower bounds for each parameter (None for no bound).
            uppers: A dictionary of upper bounds for each parameter (None for no bound).
        """

        self.lowers = lowers
        self.uppers = uppers

    def forward(self, params: List[Dict[str, jnp.ndarray]]) -> jnp.ndarray:
        """Pushes unconstrained parameters through a tf such that they fit the interval.

        Args:
            params: A list of dictionaries with unconstrained parameters.

        Returns:
            A list of dictionaries with transformed parameters.

        """

        tf_params = []
        for param in params:
            key = list(param.keys())[0]

            # If constrained from below and above, use sigmoid
            if self.lowers[key] is not None and self.uppers[key] is not None:
                tf = (
                    sigmoid(param[key]) * (self.uppers[key] - self.lowers[key])
                    + self.lowers[key]
                )
                tf_params.append({key: tf})

            # If constrained from below, use softplus
            elif self.lowers[key] is not None:
                tf = softplus(param[key]) + self.lowers[key]
                tf_params.append({key: tf})

            # If constrained from above, use negative softplus
            elif self.uppers[key] is not None:
                tf = -softplus(-param[key]) + self.uppers[key]
                tf_params.append({key: tf})

            # Else just pass through
            else:
                tf_params.append({key: param[key]})

        return tf_params

    def inverse(self, params: jnp.ndarray) -> jnp.ndarray:
        """Takes parameters from within the interval and makes them unconstrained.

        Args:
            params: A list of dictionaries with transformed parameters.

        Returns:
            A list of dictionaries with unconstrained parameters.
        """

        tf_params = []
        for param in params:
            key = list(param.keys())[0]

            # If constrained from below and above, use expit
            if self.lowers[key] is not None and self.uppers[key] is not None:
                tf = expit(
                    (param[key] - self.lowers[key])
                    / (self.uppers[key] - self.lowers[key])
                )
                tf_params.append({key: tf})

            # If constrained from below, use inv_softplus
            elif self.lowers[key] is not None:
                tf = inv_softplus(param[key] - self.lowers[key])
                tf_params.append({key: tf})

            # If constrained from above, use negative inv_softplus
            elif self.uppers[key] is not None:
                tf = -inv_softplus(-(param[key] - self.uppers[key]))
                tf_params.append({key: tf})

            # else just pass through
            else:
                tf_params.append({key: param[key]})

        return tf_params

__init__(lowers, uppers)

Initialize the bounds.

Parameters:

lowers (Dict[str, float], required): A dictionary of lower bounds for each parameter (None for no bound).
uppers (Dict[str, float], required): A dictionary of upper bounds for each parameter (None for no bound).
Source code in jaxley/optimize/transforms.py
def __init__(self, lowers: Dict[str, float], uppers: Dict[str, float]):
    """Initialize the bounds.

    Args:
        lowers: A dictionary of lower bounds for each parameter (None for no bound).
        uppers: A dictionary of upper bounds for each parameter (None for no bound).
    """

    self.lowers = lowers
    self.uppers = uppers

forward(params)

Pushes unconstrained parameters through a transform such that they fit the interval.

Parameters:

params (List[Dict[str, ndarray]], required): A list of dictionaries with unconstrained parameters.

Returns:

List[Dict[str, ndarray]]: A list of dictionaries with transformed parameters.
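
To make the two-sided case concrete, a small illustrative check: with the sigmoid branch, any real-valued input is mapped into the open interval between the bounds. The bound values here are arbitrary.

```
import jax.numpy as jnp

from jaxley.optimize.transforms import ParamTransform

# Two-sided bound on a single parameter: outputs always lie in (0.1, 5.0).
transform = ParamTransform({"radius": 0.1}, {"radius": 5.0})
out = transform.forward([{"radius": jnp.asarray([-100.0, 0.0, 100.0])}])
# out[0]["radius"] is approximately [0.1, 2.55, 5.0]
```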

Source code in jaxley/optimize/transforms.py
def forward(self, params: List[Dict[str, jnp.ndarray]]) -> jnp.ndarray:
    """Pushes unconstrained parameters through a tf such that they fit the interval.

    Args:
        params: A list of dictionaries with unconstrained parameters.

    Returns:
        A list of dictionaries with transformed parameters.

    """

    tf_params = []
    for param in params:
        key = list(param.keys())[0]

        # If constrained from below and above, use sigmoid
        if self.lowers[key] is not None and self.uppers[key] is not None:
            tf = (
                sigmoid(param[key]) * (self.uppers[key] - self.lowers[key])
                + self.lowers[key]
            )
            tf_params.append({key: tf})

        # If constrained from below, use softplus
        elif self.lowers[key] is not None:
            tf = softplus(param[key]) + self.lowers[key]
            tf_params.append({key: tf})

        # If constrained from above, use negative softplus
        elif self.uppers[key] is not None:
            tf = -softplus(-param[key]) + self.uppers[key]
            tf_params.append({key: tf})

        # Else just pass through
        else:
            tf_params.append({key: param[key]})

    return tf_params

inverse(params)

Takes parameters from within the interval and makes them unconstrained.

Parameters:

params (List[Dict[str, ndarray]], required): A list of dictionaries with transformed parameters.

Returns:

List[Dict[str, ndarray]]: A list of dictionaries with unconstrained parameters.

Source code in jaxley/optimize/transforms.py
def inverse(self, params: jnp.ndarray) -> jnp.ndarray:
    """Takes parameters from within the interval and makes them unconstrained.

    Args:
        params: A list of dictionaries with transformed parameters.

    Returns:
        A list of dictionaries with unconstrained parameters.
    """

    tf_params = []
    for param in params:
        key = list(param.keys())[0]

        # If constrained from below and above, use expit
        if self.lowers[key] is not None and self.uppers[key] is not None:
            tf = expit(
                (param[key] - self.lowers[key])
                / (self.uppers[key] - self.lowers[key])
            )
            tf_params.append({key: tf})

        # If constrained from below, use inv_softplus
        elif self.lowers[key] is not None:
            tf = inv_softplus(param[key] - self.lowers[key])
            tf_params.append({key: tf})

        # If constrained from above, use negative inv_softplus
        elif self.uppers[key] is not None:
            tf = -inv_softplus(-(param[key] - self.uppers[key]))
            tf_params.append({key: tf})

        # else just pass through
        else:
            tf_params.append({key: param[key]})

    return tf_params

expit(x)

Inverse sigmoid (expit)

Source code in jaxley/optimize/transforms.py
def expit(x: jnp.ndarray) -> jnp.ndarray:
    """Inverse sigmoid (expit)"""
    return -jnp.log(1 / x - 1)

inv_softplus(x)

Inverse softplus.

Source code in jaxley/optimize/transforms.py
def inv_softplus(x: jnp.ndarray) -> jnp.ndarray:
    """Inverse softplus."""
    return jnp.log(jnp.exp(x) - 1)

sigmoid(x)

Sigmoid.

Source code in jaxley/optimize/transforms.py
def sigmoid(x: jnp.ndarray) -> jnp.ndarray:
    """Sigmoid."""
    return 1 / (1 + save_exp(-x))

softplus(x)

Softplus.

Source code in jaxley/optimize/transforms.py
def softplus(x: jnp.ndarray) -> jnp.ndarray:
    """Softplus."""
    return jnp.log(1 + jnp.exp(x))
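
Under the definitions above, these helpers are mutual inverses: `inv_softplus` undoes `softplus`, and `expit` (the inverse sigmoid, often called logit elsewhere) undoes `sigmoid`. A quick illustrative check:

```
import jax.numpy as jnp

from jaxley.optimize.transforms import expit, inv_softplus, sigmoid, softplus

x = jnp.asarray([-2.0, 0.0, 3.0])

# Round trips recover the input up to floating-point error.
assert jnp.allclose(inv_softplus(softplus(x)), x, atol=1e-4)
assert jnp.allclose(expit(sigmoid(x)), x, atol=1e-4)
```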