Source code for pypomp.functional.train

import jax
from .structs import PompStruct
from ..core.algorithms.train import _vmapped_train_internal


def train(
    struct: PompStruct,
    thetas_array: jax.Array,
    J: int,
    optimizer: str,
    M: int,
    eta: jax.Array,
    c: float,
    max_ls_itn: int,
    thresh: float,
    scale: bool,
    ls: bool,
    alpha: float | jax.Array,
    keys: jax.Array,
    alpha_cooling: float,
    n_monitors: int,
    clip_norm: float | None = None,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 1e-8,
) -> tuple[jax.Array, jax.Array]:
    """
    This is a pure functional implementation of the optimization algorithm,
    intended for users who need to compose it within custom JAX loops or
    higher-order functions. For a more user-friendly (but non-functional)
    interface, see :meth:`pypomp.core.pomp.Pomp.train`.

    This function performs Maximum Likelihood Estimation (MLE) by treating the
    particle filter as a differentiable computational graph. It computes
    gradients of the log-likelihood with respect to the parameters via
    reverse-mode automatic differentiation (using JAX), and updates the
    parameters using optimizers (e.g., Adam, SGD).

    This implementation leverages JAX to efficiently vectorize the algorithm
    across multiple initial parameter sets simultaneously.

    Args:
        struct (PompStruct): The compiled structural representation of the
            POMP model.
        thetas_array (jax.Array): Array of initial parameters.
            Shape (n_reps, n_params).
        J (int): Number of particles.
        optimizer (str): Optimizer choice.
        M (int): Number of iterations.
        eta (jax.Array): Learning rates array. Shape (M, n_params).
        c (float): Armijo condition constant.
        max_ls_itn (int): Max line search iterations.
        thresh (float): Resampling threshold.
        scale (bool): Whether to scale direction.
        ls (bool): Whether to use line search.
        alpha (float | jax.Array): Alpha parameter.
        keys (jax.Array): Random keys. Shape (n_reps, ...).
        alpha_cooling (float): Alpha cooling factor.
        n_monitors (int): Number of monitors.
        clip_norm (float | None): Gradient clipping norm.
        beta1 (float): Exponential decay rate for first moment estimates.
        beta2 (float): Exponential decay rate for second moment estimates.
        epsilon (float): Small constant for numerical stability.

    Returns:
        tuple[jax.Array, jax.Array]:
            - Negative logLik history: Shape (n_reps, M)
            - Theta history: Shape (n_reps, M+1, n_params)
    """
    return _vmapped_train_internal(
        thetas_array,
        struct.ys,
        struct.dt_array_extended,
        struct.nstep_array,
        struct.t0,
        struct.times,
        struct.rinit_pf,
        struct.rproc_pf,
        struct.dmeas_pf,
        struct.accumvars,
        struct.covars_extended,
        J,
        optimizer,
        M,
        eta,
        c,
        max_ls_itn,
        thresh,
        scale,
        ls,
        alpha,
        keys,
        alpha_cooling,
        n_monitors,
        clip_norm,
        beta1,
        beta2,
        epsilon,
    )
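
The docstring describes gradient-based optimization vectorized over several initial parameter sets, with a loss history of shape (n_reps, M) and a parameter history of shape (n_reps, M+1, n_params). The sketch below illustrates that pattern only; it is not pypomp's implementation. `toy_neg_loglik`, `train_one`, and `train_many` are hypothetical names, the noisy quadratic stands in for a particle-filter log-likelihood estimate, and plain gradient descent stands in for the optimizers that `train` supports.

```python
import jax
import jax.numpy as jnp


def toy_neg_loglik(theta, key):
    # Hypothetical stand-in for a particle-filter estimate: a quadratic with
    # its minimum at theta = 1, plus additive Monte Carlo-style noise.
    noise = 0.01 * jax.random.normal(key)
    return jnp.sum((theta - 1.0) ** 2) + noise


def train_one(theta0, key, eta):
    # eta has shape (M, n_params): one learning-rate vector per iteration.
    M = eta.shape[0]
    step_keys = jax.random.split(key, M)

    def step(theta, inputs):
        step_key, lr = inputs
        # Reverse-mode AD gives the loss and its gradient in one pass.
        loss, grad = jax.value_and_grad(toy_neg_loglik)(theta, step_key)
        new_theta = theta - lr * grad
        return new_theta, (loss, new_theta)

    _, (losses, thetas) = jax.lax.scan(step, theta0, (step_keys, eta))
    # Prepend the starting point so the history has M + 1 rows.
    theta_hist = jnp.concatenate([theta0[None, :], thetas], axis=0)
    return losses, theta_hist


# Vectorize over replicates: one starting point and one key per replicate,
# with the learning-rate schedule shared across replicates.
train_many = jax.vmap(train_one, in_axes=(0, 0, None))

n_reps, n_params, M = 4, 3, 20
thetas_array = jnp.zeros((n_reps, n_params))
keys = jax.random.split(jax.random.PRNGKey(0), n_reps)
eta = jnp.full((M, n_params), 0.1)

neg_ll_hist, theta_hist = train_many(thetas_array, keys, eta)
print(neg_ll_hist.shape, theta_hist.shape)
```

In this sketch each recorded loss is evaluated at the parameters before the update, which is why there are M losses but M+1 parameter vectors; whether pypomp records its history in the same order is an assumption.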