Optimization
TypeOptimizer
optax wrapper that allows different argument values for different parameters.
Source code in jaxley/optimize/optimizer.py
__init__(optimizer, optimizer_args, opt_params)
Create the optimizers.
This requires access to opt_params in order to know how many optimizers should be created. It creates len(opt_params) optimizers.
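Example usage:

    import optax
    from jaxley.optimize.optimizer import TypeOptimizer

    # opt_params is the list of parameter dicts to be optimized.

    # One argument (the learning rate) per kind of parameter.
    lrs = {"HH_gNa": 0.01, "radius": 1.0}
    optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
    opt_state = optimizer.init(opt_params)

    # Several arguments (learning rate and momentum) per kind of parameter.
    optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
    optimizer = TypeOptimizer(
        lambda args: optax.sgd(args[0], momentum=args[1]),
        optimizer_args,
        opt_params,
    )
    opt_state = optimizer.init(opt_params)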
Parameters:
Name | Type | Description | Default |
---|---|---|---|
optimizer | Callable | A Callable that takes the learning rate and returns the optax optimizer to use. | required |
optimizer_args | Dict[str, Any] | The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable given as optimizer. | required |
opt_params | List[Dict[str, ndarray]] | The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict. | required |
Source code in jaxley/optimize/optimizer.py
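The idea of creating one optimizer per entry of opt_params can be illustrated with a small, library-free sketch. Everything in it is an assumption made for illustration: `make_sgd` and `PerKeyOptimizer` are hypothetical stand-ins (neither optax nor the jaxley implementation), parameters are plain floats rather than arrays, and each dict in opt_params is assumed to hold exactly one key.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

ParamDict = Dict[str, float]


def make_sgd(lr: float) -> Tuple[Callable, Callable]:
    """Hypothetical stand-in for an optax-style optimizer: an (init, update) pair."""

    def init(param: ParamDict) -> None:
        return None  # plain SGD keeps no state

    def update(grad: ParamDict, state: None) -> Tuple[ParamDict, None]:
        # Return the parameter updates (negative scaled gradient) and the state.
        return {k: -lr * g for k, g in grad.items()}, state

    return init, update


class PerKeyOptimizer:
    """Illustrative sketch: one optimizer per entry of `opt_params`."""

    def __init__(
        self,
        optimizer: Callable,
        optimizer_args: Dict[str, Any],
        opt_params: List[ParamDict],
    ):
        # Only the key of each dict is used, never the parameter values.
        # Assumption: each dict holds exactly one key.
        self.optimizers = [
            optimizer(optimizer_args[next(iter(p))]) for p in opt_params
        ]

    def init(self, opt_params: List[ParamDict]) -> List[Optional[Any]]:
        return [opt[0](p) for opt, p in zip(self.optimizers, opt_params)]

    def update(self, gradient: List[ParamDict], opt_state: List[Any]):
        results = [
            opt[1](g, s) for opt, g, s in zip(self.optimizers, gradient, opt_state)
        ]
        updates = [u for u, _ in results]
        new_state = [s for _, s in results]
        return updates, new_state


opt_params = [{"HH_gNa": 0.1}, {"radius": 2.0}]
lrs = {"HH_gNa": 0.01, "radius": 1.0}
optimizer = PerKeyOptimizer(make_sgd, lrs, opt_params)
opt_state = optimizer.init(opt_params)

# Made-up gradients, one dict per entry of opt_params.
grads = [{"HH_gNa": 1.0}, {"radius": 0.5}]
updates, opt_state = optimizer.update(grads, opt_state)
```

Each entry is scaled by its own learning rate, which is the behavior the wrapper is meant to provide on top of a single shared optimizer.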
init(opt_params)
Initialize the optimizers. Equivalent to optax.optimizers.init().
Source code in jaxley/optimize/optimizer.py
update(gradient, opt_state)
Update the optimizers. Equivalent to optax.optimizers.update().
Source code in jaxley/optimize/optimizer.py
ParamTransform
Parameter transformation utility.
This class is used to transform parameters from an unconstrained space to a constrained space and back. If the range is bounded both from above and below, we use the sigmoid function to transform the parameters. If the range is only bounded from below or above, we use softplus.
Attributes:
Name | Type | Description |
---|---|---|
lowers | | A dictionary of lower bounds for each parameter (None for no bound). |
uppers | | A dictionary of upper bounds for each parameter (None for no bound). |
Source code in jaxley/optimize/transforms.py
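The page does not spell out the formulas behind these transformations. One common way to write them, for an unconstrained value $x$, a constrained value $y$, a lower bound $l$, and an upper bound $u$, is sketched below. The symbols and the exact expressions are assumptions for illustration and may differ from what transforms.py implements.

```latex
\[
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\operatorname{softplus}(x) = \log\left(1 + e^{x}\right)
\]
\[
\begin{aligned}
\text{bounded below and above:} \quad & y = l + (u - l)\,\sigma(x), & x &= \log\frac{y - l}{u - y}, \\
\text{bounded below only:} \quad & y = l + \operatorname{softplus}(x), & x &= \log\left(e^{\,y - l} - 1\right), \\
\text{bounded above only:} \quad & y = u - \operatorname{softplus}(x), & x &= \log\left(e^{\,u - y} - 1\right).
\end{aligned}
\]
```

In each row the left expression maps any real $x$ strictly inside the bounds, and the right expression is its inverse.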
__init__(lowers, uppers)
Initialize the bounds.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
lowers | Dict[str, float] | A dictionary of lower bounds for each parameter (None for no bound). | required |
uppers | Dict[str, float] | A dictionary of upper bounds for each parameter (None for no bound). | required |
Source code in jaxley/optimize/transforms.py
forward(params)
Pushes unconstrained parameters through a transformation so that they fall within the interval given by the bounds.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | List[Dict[str, ndarray]] | A list of dictionaries with unconstrained parameters. | required |
Returns:
Type | Description |
---|---|
ndarray | A list of dictionaries with transformed parameters. |
Source code in jaxley/optimize/transforms.py
inverse(params)
Takes parameters from within the interval and makes them unconstrained.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | ndarray | A list of dictionaries with transformed parameters. | required |
Returns:
Type | Description |
---|---|
ndarray | A list of dictionaries with unconstrained parameters. |
Source code in jaxley/optimize/transforms.py
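The forward and inverse mappings can be sketched as a pair of functions that dispatch on which bounds are set for each key. This is a minimal illustration under stated assumptions, not the code in transforms.py: it uses plain floats instead of arrays, takes the bounds as explicit arguments instead of storing them on a class, treats a parameter with no bounds as untransformed, and uses the straightforward (not overflow-safe) formulas.

```python
import math
from typing import Dict, List, Optional

Bounds = Dict[str, Optional[float]]
Params = List[Dict[str, float]]


def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def _softplus(x: float) -> float:
    return math.log1p(math.exp(x))


def _inv_sigmoid(p: float) -> float:
    return math.log(p / (1.0 - p))


def _inv_softplus(y: float) -> float:
    return math.log(math.expm1(y))


def forward(params: Params, lowers: Bounds, uppers: Bounds) -> Params:
    """Map unconstrained values into the interval given by the bounds."""
    out = []
    for entry in params:
        new = {}
        for key, x in entry.items():
            lo, up = lowers[key], uppers[key]
            if lo is not None and up is not None:
                new[key] = lo + (up - lo) * _sigmoid(x)  # two-sided: sigmoid
            elif lo is not None:
                new[key] = lo + _softplus(x)  # bounded below only
            elif up is not None:
                new[key] = up - _softplus(x)  # bounded above only
            else:
                new[key] = x  # assumption: no bound means identity
        out.append(new)
    return out


def inverse(params: Params, lowers: Bounds, uppers: Bounds) -> Params:
    """Map values from within the interval back to the unconstrained space."""
    out = []
    for entry in params:
        new = {}
        for key, y in entry.items():
            lo, up = lowers[key], uppers[key]
            if lo is not None and up is not None:
                new[key] = _inv_sigmoid((y - lo) / (up - lo))
            elif lo is not None:
                new[key] = _inv_softplus(y - lo)
            elif up is not None:
                new[key] = _inv_softplus(up - y)
            else:
                new[key] = y
        out.append(new)
    return out


# Made-up bounds: HH_gNa in (0, 1), radius bounded below by 0.1 only.
lowers = {"HH_gNa": 0.0, "radius": 0.1}
uppers = {"HH_gNa": 1.0, "radius": None}

unconstrained = [{"HH_gNa": 0.0}, {"radius": -1.5}]
constrained = forward(unconstrained, lowers, uppers)
recovered = inverse(constrained, lowers, uppers)
```

Optimizing in the unconstrained space and calling the forward mapping before each simulation keeps every parameter inside its bounds without clipping gradients or values.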
expit(x)
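Inverse sigmoid. Note on naming: the inverse of the sigmoid is conventionally called the logit, whereas "expit" usually refers to the sigmoid itself (for example in scipy.special.expit).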
Source code in jaxley/optimize/transforms.py
inv_softplus(x)
Inverse softplus.
Source code in jaxley/optimize/transforms.py
sigmoid(x)
Sigmoid.
Source code in jaxley/optimize/transforms.py
softplus(x)
Softplus.
Source code in jaxley/optimize/transforms.py
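For reference, the four helper functions can be written in an overflow-safe way for scalar inputs. The sketch below is an assumption-laden illustration, not the code in transforms.py: it uses Python's math module instead of JAX arrays, and it names the inverse sigmoid `logit` (the role this page lists under expit).

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, branching on the sign so exp() never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def logit(p: float) -> float:
    """Inverse sigmoid for 0 < p < 1."""
    return math.log(p) - math.log1p(-p)


def softplus(x: float) -> float:
    """log(1 + exp(x)), rewritten so that large |x| does not overflow."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def inv_softplus(y: float) -> float:
    """Inverse softplus for y > 0: log(exp(y) - 1) = y + log(1 - exp(-y))."""
    return y + math.log(-math.expm1(-y))
```

The naive forms `1 / (1 + exp(-x))` and `log(1 + exp(x))` raise an OverflowError in pure Python for inputs of large magnitude, which is why the rewritten expressions branch on the sign or factor out the dominant term.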