Source code for pprop.optimization

"""
This module provides an Adam optimiser for minimising a loss of the form
:math:`L(f(\\boldsymbol{\\theta}))`, where :math:`f` is a
:class:`~pprop.propagator.Propagator` and :math:`L` is a user-supplied scalar
loss function.

Gradients are computed via the chain rule:

.. math::

    \\frac{\\partial L}{\\partial \\boldsymbol{\\theta}}
    = \\frac{\\partial L}{\\partial \\mathbf{f}}
      \\cdot \\frac{\\partial \\mathbf{f}}{\\partial \\boldsymbol{\\theta}}

where :math:`\\partial \\mathbf{f}/\\partial \\boldsymbol{\\theta}` is
obtained analytically from
:meth:`~pprop.propagator.Propagator.eval_and_grad`, and
:math:`\\partial L / \\partial \\mathbf{f}` is either supplied directly by
the caller via ``grad_L``, or estimated by central finite differences via
:func:`_numerical_grad`.
"""
from __future__ import annotations

from typing import Callable, Optional

import numpy as np
import optax


def adam(
    L: Callable[[np.ndarray], float],
    propagator,
    params_init: np.ndarray,
    lr: float = 1e-3,
    num_steps: int = 1000,
    print_every: int = 100,
    grad_L: Optional[Callable[[np.ndarray], np.ndarray]] = None,
) -> dict:
    """
    Minimize :math:`L(f(\\boldsymbol{\\theta}))` using the Adam optimiser.

    At each step the gradient is assembled via the chain rule:

    .. math::

        \\nabla_{\\boldsymbol{\\theta}} L
        = \\underbrace{\\nabla_{\\mathbf{f}} L}_{\\text{grad\\_L or finite diff.}}
          \\cdot \\underbrace{\\frac{\\partial \\mathbf{f}}{\\partial \\boldsymbol{\\theta}}}_{\\text{analytic}}

    The gradient :math:`\\nabla_{\\mathbf{f}} L` is computed in one of two
    ways:

    - If ``grad_L`` is provided, it is called directly. This is exact and
      efficient; a natural choice is ``jax.grad(L)`` when ``L`` is written
      with JAX-compatible operations.
    - If ``grad_L`` is ``None`` (default), the gradient is estimated by
      central finite differences via :func:`_numerical_grad`. This requires
      no assumptions on ``L`` beyond it being callable.

    Parameters
    ----------
    L : Callable[[ndarray], float]
        Scalar loss function. Receives ``f_vals`` of shape ``(num_obs,)``
        and returns a float.
    propagator : Propagator
        A propagated :class:`~pprop.propagator.Propagator` instance exposing
        an ``eval_and_grad(params)`` method.
    params_init : ndarray of shape (num_params,)
        Initial parameter vector. A copy is taken so the original is not
        modified.
    lr : float, optional
        Adam learning rate. Defaults to ``1e-3``.
    num_steps : int, optional
        Number of optimisation steps. Defaults to ``1000``.
    print_every : int, optional
        Print a progress line every this many steps. Set to ``0`` for
        silent operation. Defaults to ``100``.
    grad_L : Callable[[ndarray], ndarray], optional
        Gradient of ``L`` with respect to its input ``f_vals``. Should
        return an array of shape ``(num_obs,)``. If ``None``, central
        finite differences are used instead. A typical choice is
        ``jax.grad(L)`` when ``L`` is JAX-compatible.

    Returns
    -------
    dict with keys:
        ``params`` : ndarray of shape (num_params,)
            Final parameter vector after optimisation.
        ``fun`` : float
            Loss value at the final parameters.
        ``history`` : list[float]
            Loss value recorded at every step.

    Examples
    --------
    NumPy loss: finite differences used automatically:

    >>> result = adam(lambda f: float(np.sum(f**2)), propagator, params_init)

    JAX loss: exact gradient via ``jax.grad``:

    >>> import jax
    >>> import jax.numpy as jnp
    >>> L_jax = lambda f: jnp.sum(f**2)
    >>> result = adam(L_jax, propagator, params_init, grad_L=jax.grad(L_jax))
    """
    optimizer = optax.adam(lr)
    params = params_init.copy().astype(float)
    opt_state = optimizer.init(params)
    history: list[float] = []

    # Build the gradient callable once outside the loop.
    # If the user supplies grad_L we use it directly; otherwise we wrap L
    # in a central finite-difference estimator.
    _grad_L: Callable[[np.ndarray], np.ndarray] = (
        grad_L if grad_L is not None else _numerical_grad(L)
    )

    for step in range(1, num_steps + 1):
        # Evaluate f(θ) and its Jacobian ∂f/∂θ analytically.
        f_vals, f_grads = propagator.eval_and_grad(params)  # (num_obs,), (num_obs, num_params)

        # Evaluate the scalar loss and ∂L/∂f.
        loss = float(L(f_vals))
        dLdf = _grad_L(f_vals)  # (num_obs,)

        # Chain rule: ∂L/∂θ = (∂L/∂f) @ (∂f/∂θ)
        grad = dLdf @ f_grads  # (num_params,)
        history.append(loss)

        # Apply one Adam step and update parameters.
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)

        if print_every and step % print_every == 0:
            print(f" step {step:5d}/{num_steps} loss = {loss:.8f}")

    return {
        "params": params,
        "fun": float(L(propagator(params))),
        "history": history,
    }


def _numerical_grad(
    L: Callable[[np.ndarray], float],
    eps: float = 1e-5,
) -> Callable[[np.ndarray], np.ndarray]:
    """
    Return a central finite-difference gradient function for ``L``.

    For each component :math:`f_i`, the partial derivative is approximated
    as:

    .. math::

        \\frac{\\partial L}{\\partial f_i}
        \\approx \\frac{L(\\mathbf{f} + \\epsilon\\,\\mathbf{e}_i) - L(\\mathbf{f} - \\epsilon\\,\\mathbf{e}_i)}{2\\epsilon}

    Parameters
    ----------
    L : Callable[[ndarray], float]
        Scalar loss function.
    eps : float, optional
        Finite-difference step size. Defaults to ``1e-5``.

    Returns
    -------
    Callable[[ndarray], ndarray]
        A function that accepts ``f_vals`` of shape ``(num_obs,)`` and
        returns the estimated gradient of the same shape.
    """

    def _grad(f_vals: np.ndarray) -> np.ndarray:
        g = np.zeros_like(f_vals)
        for i in range(len(f_vals)):
            fp = f_vals.copy()
            fp[i] += eps
            fm = f_vals.copy()
            fm[i] -= eps
            g[i] = (L(fp) - L(fm)) / (2 * eps)
        return g

    return _grad
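
# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the pprop API).
# ``_ToyPropagator`` below is a hypothetical stand-in invented for this
# example; it only mimics the interface that :func:`adam` assumes of a
# propagated Propagator: it is callable and exposes ``eval_and_grad(params)``
# returning ``(f_vals, f_grads)`` of shapes ``(num_obs,)`` and
# ``(num_obs, num_params)``. Here f is linear, f(θ) = A @ θ, so its Jacobian
# is simply A.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyPropagator:
        """Hypothetical linear stand-in: f(θ) = A @ θ with Jacobian A."""

        def __init__(self, A: np.ndarray) -> None:
            self.A = A

        def __call__(self, params: np.ndarray) -> np.ndarray:
            return self.A @ np.asarray(params, dtype=float)

        def eval_and_grad(self, params: np.ndarray):
            # f values of shape (num_obs,), Jacobian of shape (num_obs, num_params).
            return self(params), self.A

    rng = np.random.default_rng(0)
    toy = _ToyPropagator(rng.normal(size=(5, 3)))  # num_obs=5, num_params=3
    params0 = rng.normal(size=3)

    # Finite-difference path: grad_L is omitted, so ∂L/∂f comes from
    # _numerical_grad. Minimises L(f) = sum(f**2), driving f(θ) towards 0.
    result = adam(
        lambda f: float(np.sum(f**2)),
        toy,
        params0,
        lr=1e-1,
        num_steps=200,
        print_every=50,
    )
    print("final loss:", result["fun"])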