pprop.optimization package¶
Module contents¶
This module provides an Adam optimiser for minimising a loss of the form
\(L(f(\boldsymbol{\theta}))\), where \(f\) is a
Propagator and \(L\) is a user-supplied scalar
loss function.
Gradients are computed via the chain rule:
\[\nabla_{\boldsymbol{\theta}} L = \frac{\partial L}{\partial \mathbf{f}} \cdot \frac{\partial \mathbf{f}}{\partial \boldsymbol{\theta}}\]
where \(\partial \mathbf{f}/\partial \boldsymbol{\theta}\) is obtained analytically from eval_and_grad(), and \(\partial L / \partial \mathbf{f}\) is either supplied directly by the caller via grad_L, or estimated by central finite differences via _numerical_grad().
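The central-difference fallback can be sketched as follows. Note that _numerical_grad() is private and its exact signature is not documented here, so the function below is a hypothetical stand-in illustrating the technique: each component of f_vals is perturbed by ±eps and the loss difference is divided by 2·eps.

```python
import numpy as np

def numerical_grad(L, f_vals, eps=1e-6):
    """Estimate dL/df by central finite differences.

    Hypothetical sketch of the private _numerical_grad() helper:
    O(num_obs) extra loss evaluations, second-order accurate in eps.
    """
    f_vals = np.asarray(f_vals, dtype=float)
    grad = np.zeros_like(f_vals)
    for i in range(f_vals.size):
        bump = np.zeros_like(f_vals)
        bump[i] = eps
        # central difference: (L(f + eps e_i) - L(f - eps e_i)) / (2 eps)
        grad[i] = (L(f_vals + bump) - L(f_vals - bump)) / (2.0 * eps)
    return grad
```

For a quadratic loss such as sum(f**2) the central difference is exact up to floating-point rounding, which makes it a convenient sanity check against an analytic gradient.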
- pprop.optimization.adam(L, propagator, params_init, lr=0.001, num_steps=1000, print_every=100, grad_L=None)[source]¶
Minimise \(L(f(\boldsymbol{\theta}))\) using the Adam optimiser.
At each step the gradient is assembled via the chain rule:
\[\nabla_{\boldsymbol{\theta}} L = \underbrace{\nabla_{\mathbf{f}} L}_{\text{grad\_L or finite diff.}} \cdot \underbrace{\frac{\partial \mathbf{f}}{\partial \boldsymbol{\theta}}}_{\text{analytic}}\]The gradient \(\nabla_{\mathbf{f}} L\) is computed in one of two ways:
- If grad_L is provided, it is called directly. This is exact and efficient; a natural choice is jax.grad(L) when L is written with JAX-compatible operations.
- If grad_L is None (the default), the gradient is estimated by central finite differences via _numerical_grad(). This requires no assumptions on L beyond it being callable.
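In the first case, assembling the full parameter gradient reduces to a single vector–Jacobian product. A minimal sketch, assuming eval_and_grad(params) returns the pair (f_vals, Jacobian) with the Jacobian of shape (num_obs, num_params) (this return convention is an assumption, not documented API):

```python
import numpy as np

def chain_rule_grad(params, eval_and_grad, grad_L):
    """Assemble dL/dtheta = (dL/df) @ (df/dtheta).

    Hypothetical sketch: grad_L supplies the exact gradient of L
    (e.g. jax.grad(L)); the Jacobian comes from the propagator.
    """
    f_vals, jac = eval_and_grad(params)          # jac: (num_obs, num_params)
    return np.asarray(grad_L(f_vals)) @ jac      # result: (num_params,)
```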
- Parameters:
L (Callable[[ndarray], float]) – Scalar loss function. Receives f_vals of shape (num_obs,) and returns a float.
propagator (Propagator) – A propagated Propagator instance exposing an eval_and_grad(params) method.
params_init (ndarray of shape (num_params,)) – Initial parameter vector. A copy is taken so the original is not modified.
lr (float, optional) – Adam learning rate. Defaults to 1e-3.
num_steps (int, optional) – Number of optimisation steps. Defaults to 1000.
print_every (int, optional) – Print a progress line every this many steps. Set to 0 for silent operation. Defaults to 100.
grad_L (Callable[[ndarray], ndarray], optional) – Gradient of L with respect to its input f_vals. Should return an array of shape (num_obs,). If None, central finite differences are used instead. A typical choice is jax.grad(L) when L is JAX-compatible.
- Returns:
dict with keys
``params`` (ndarray of shape (num_params,)) – Final parameter vector after optimisation.
``fun`` (float) – Loss value at the final parameters.
``history`` (list[float]) – Loss value recorded at every step.
- Return type:
dict
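The Adam update itself follows the standard first- and second-moment recursion. A self-contained sketch under the usual defaults (beta1, beta2, and eps here are the conventional values, not confirmed from this module's source):

```python
import numpy as np

def adam_sketch(grad_fn, params_init, lr=1e-3, num_steps=1000,
                beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimal Adam loop; grad_fn(params) returns the full gradient.

    Illustrative only: the module's adam() additionally records a loss
    history and prints progress, which are omitted here.
    """
    params = np.array(params_init, dtype=float)  # copy, as documented
    m = np.zeros_like(params)                    # first-moment estimate
    v = np.zeros_like(params)                    # second-moment estimate
    for t in range(1, num_steps + 1):
        g = grad_fn(params)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)               # bias correction
        v_hat = v / (1 - beta2**t)
        params -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return params
```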
Examples
NumPy loss: finite differences used automatically:
>>> import numpy as np
>>> result = adam(lambda f: float(np.sum(f**2)), propagator, params_init)
JAX loss: exact gradient via jax.grad:
>>> import jax
>>> import jax.numpy as jnp
>>> L_jax = lambda f: jnp.sum(f**2)
>>> result = adam(L_jax, propagator, params_init, grad_L=jax.grad(L_jax))