# Source code for zea.backend.optimizer

"""Simple implementation of optimizers that support multi-backend autodiff."""

import keras


def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
    """Construct optimizer triple for Adam.

    Implementation adapted from `JAX's example
    <https://jax.readthedocs.io/en/latest/_modules/jax/example_libraries/optimizers.html#adam>`_

    See example usage: `JAX's example usage
    <https://jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html#jax.example_libraries.optimizers.adam>`_

    Args:
        step_size: positive scalar
        b1: optional, a positive scalar value for beta_1, the exponential decay
            rate for the first moment estimates (default 0.9).
        b2: optional, a positive scalar value for beta_2, the exponential decay
            rate for the second moment estimates (default 0.999).
        eps: optional, a positive scalar value for epsilon, a small constant
            for numerical stability (default 1e-8).

    Returns:
        An (init_fun, update_fun, get_params) triple.
    """

    def init(x0):
        """Build the initial optimizer state for parameter ``x0``.

        The state tuple is ``(params, first moment, second moment, step count)``,
        with both moment estimates zero-initialized to the shape of ``x0``.
        """
        return x0, keras.ops.zeros_like(x0), keras.ops.zeros_like(x0), 0

    def update(g, state):
        """Apply one Adam step.

        Args:
            g (array): gradient
            state (tuple): state of the optimizer
                (x, m, v, i) = (parameter, first moment estimate,
                second moment estimate, iteration count)

        Returns:
            state: updated state
        """
        params, first_moment, second_moment, count = state
        # Exponential moving averages of the gradient and its square.
        first_moment = b1 * first_moment + (1 - b1) * g
        second_moment = b2 * second_moment + (1 - b2) * keras.ops.square(g)
        # Bias-corrected estimates (count is 0-based, hence count + 1).
        corrected_m = first_moment / (1 - b1 ** (count + 1))
        corrected_v = second_moment / (1 - b2 ** (count + 1))
        params = params - step_size * corrected_m / (keras.ops.sqrt(corrected_v) + eps)
        return params, first_moment, second_moment, count + 1

    def get_params(state):
        """Returns just the parameter from the state."""
        return state[0]

    return init, update, get_params