pprop.optimization package

Module contents

This module provides an Adam optimiser for minimising a loss of the form \(L(f(\boldsymbol{\theta}))\), where \(f\) is a Propagator and \(L\) is a user-supplied scalar loss function.

Gradients are computed via the chain rule:

\[\frac{\partial L}{\partial \boldsymbol{\theta}} = \frac{\partial L}{\partial \mathbf{f}} \cdot \frac{\partial \mathbf{f}}{\partial \boldsymbol{\theta}}\]

where \(\partial \mathbf{f}/\partial \boldsymbol{\theta}\) is obtained analytically from eval_and_grad(), and \(\partial L / \partial \mathbf{f}\) is either supplied directly by the caller via grad_L, or estimated by central finite differences via _numerical_grad().
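The central-finite-difference idea behind _numerical_grad() can be sketched as follows. This is an illustrative re-implementation under assumed defaults (the helper's actual name, signature, and step size may differ): each component of f_vals is perturbed by \(\pm\varepsilon\) and the loss difference is divided by \(2\varepsilon\).

```python
import numpy as np

def central_diff_grad(L, f_vals, eps=1e-6):
    """Estimate dL/df component-wise by central finite differences.

    A sketch of the scheme _numerical_grad() is described as using;
    not the library's actual implementation.
    """
    f_vals = np.asarray(f_vals, dtype=float)
    grad = np.empty_like(f_vals)
    for i in range(f_vals.size):
        e = np.zeros_like(f_vals)
        e[i] = eps
        # central difference: (L(f + e) - L(f - e)) / (2 eps)
        grad[i] = (L(f_vals + e) - L(f_vals - e)) / (2.0 * eps)
    return grad
```

Note the cost: each gradient evaluation requires 2 * num_obs calls to L, which is why supplying grad_L is preferred when an exact gradient is available.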

pprop.optimization.adam(L, propagator, params_init, lr=0.001, num_steps=1000, print_every=100, grad_L=None)[source]

Minimise \(L(f(\boldsymbol{\theta}))\) using the Adam optimiser.

At each step the gradient is assembled via the chain rule:

\[\nabla_{\boldsymbol{\theta}} L = \underbrace{\nabla_{\mathbf{f}} L}_{\text{grad\_L or finite diff.}} \cdot \underbrace{\frac{\partial \mathbf{f}}{\partial \boldsymbol{\theta}}}_{\text{analytic}}\]

The gradient \(\nabla_{\mathbf{f}} L\) is computed in one of two ways:

  • If grad_L is provided, it is called directly. This is exact and efficient; a natural choice is jax.grad(L) when L is written with JAX-compatible operations.

  • If grad_L is None (default), the gradient is estimated by central finite differences via _numerical_grad(). This requires no assumptions on L beyond it being callable.
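The gradient assembly and a single Adam update can be sketched as below. This is a simplified illustration, not the library's code: it assumes the analytic Jacobian \(\partial\mathbf{f}/\partial\boldsymbol{\theta}\) has shape (num_obs, num_params), and uses the standard Adam defaults.

```python
import numpy as np

def chain_rule_grad(grad_f, jac):
    """Assemble dL/dtheta = dL/df . df/dtheta.

    grad_f : (num_obs,) gradient of the loss w.r.t. f_vals
    jac    : (num_obs, num_params) analytic Jacobian from eval_and_grad()
    """
    return grad_f @ jac

def adam_step(theta, grad, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam update with bias-corrected first and second moments (t >= 1)."""
    m = b1 * m + (1 - b1) * grad          # first-moment estimate
    v = b2 * v + (1 - b2) * grad**2       # second-moment estimate
    m_hat = m / (1 - b1**t)               # bias correction
    v_hat = v / (1 - b2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v
```

The optimiser only ever touches the assembled (num_params,) gradient, so the two ways of obtaining \(\nabla_{\mathbf{f}} L\) are interchangeable from Adam's point of view.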

Parameters:
  • L (Callable[[ndarray], float]) – Scalar loss function. Receives f_vals of shape (num_obs,) and returns a float.

  • propagator (Propagator) – A propagated Propagator instance exposing an eval_and_grad(params) method.

  • params_init (ndarray of shape (num_params,)) – Initial parameter vector. A copy is taken so the original is not modified.

  • lr (float, optional) – Adam learning rate. Defaults to 1e-3.

  • num_steps (int, optional) – Number of optimisation steps. Defaults to 1000.

  • print_every (int, optional) – Interval, in steps, between printed progress lines. Set to 0 for silent operation. Defaults to 100.

  • grad_L (Callable[[ndarray], ndarray], optional) – Gradient of L with respect to its input f_vals. Should return an array of shape (num_obs,). If None, central finite differences are used instead. A typical choice is jax.grad(L) when L is JAX-compatible.

Returns:

dict with keys:

  • ``params`` (ndarray of shape (num_params,)) – Final parameter vector after optimisation.

  • ``fun`` (float) – Loss value at the final parameters.

  • ``history`` (list[float]) – Loss value recorded at every step.

Return type:

dict

Examples

NumPy loss: finite differences used automatically:

>>> import numpy as np
>>> result = adam(lambda f: float(np.sum(f**2)), propagator, params_init)

JAX loss: exact gradient via jax.grad:

>>> import jax
>>> import jax.numpy as jnp
>>> L_jax = lambda f: jnp.sum(f**2)
>>> result = adam(L_jax, propagator, params_init, grad_L=jax.grad(L_jax))