pypomp.functional.train

pypomp.functional.train(struct: PompStruct, thetas_array: Array, J: int, optimizer: str, M: int, eta: Array, c: float, max_ls_itn: int, thresh: float, scale: bool, ls: bool, alpha: float | Array, keys: Array, alpha_cooling: float, n_monitors: int, clip_norm: float | None = None, beta1: float = 0.9, beta2: float = 0.999, epsilon: float = 1e-08) → tuple[Array, Array]

This is a pure functional implementation of the optimization algorithm, intended for users who need to compose it within custom JAX loops or higher-order functions. For a more user-friendly (but non-functional) interface, see pypomp.core.pomp.Pomp.train().

This function performs maximum likelihood estimation (MLE) by treating the particle filter as a differentiable computational graph. It computes gradients of the particle-filter log-likelihood estimate with respect to the parameters via reverse-mode automatic differentiation in JAX, and updates the parameters with a gradient-based optimizer (e.g., Adam, SGD) that minimizes the negative log-likelihood.
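To make the idea of differentiating through a particle filter concrete, the sketch below hand-rolls a bootstrap filter for a hypothetical one-dimensional random-walk model (`pf_loglik` and the model are illustrative, not part of pypomp) and obtains the gradient with `jax.value_and_grad`. It is a naive estimator under stated assumptions: the random key is held fixed and the resampling indices are treated as constants, which generally biases the gradient. pypomp's own gradient estimator may handle resampling differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def pf_loglik(theta, ys, key, J=200):
    """Bootstrap particle filter estimate of log p(y_{1:T} | theta) for a toy model
    (hypothetical, for illustration only):
        x_t = x_{t-1} + exp(theta[0]) * eps_t,   y_t ~ N(x_t, exp(theta[1])^2),   x_0 = 0.
    """
    sigma_proc = jnp.exp(theta[0])
    sigma_obs = jnp.exp(theta[1])

    def step(carry, y):
        x, key = carry
        key, k_proc, k_res = jax.random.split(key, 3)
        # Propagate with reparameterized noise so x depends smoothly on sigma_proc.
        x = x + sigma_proc * jax.random.normal(k_proc, x.shape)
        # Log measurement density of each particle.
        logw = norm.logpdf(y, loc=x, scale=sigma_obs)
        # Conditional log-likelihood estimate: log of the mean weight.
        ll_t = logsumexp(logw) - jnp.log(x.shape[0])
        # Multinomial resampling; the indices are discrete, so no gradient
        # flows through the resampling decision itself (naive assumption).
        idx = jax.random.categorical(k_res, jax.lax.stop_gradient(logw), shape=x.shape)
        return (x[idx], key), ll_t

    _, lls = jax.lax.scan(step, (jnp.zeros(J), key), ys)
    return lls.sum()


key = jax.random.PRNGKey(0)
ys = jnp.cumsum(jax.random.normal(jax.random.PRNGKey(1), (50,)))  # synthetic observations
theta = jnp.array([0.0, 0.0])  # (log sigma_proc, log sigma_obs)

# Reusing the same key fixes the random draws, so the estimate is a
# deterministic function of theta that reverse-mode AD can differentiate.
neg_ll, grad = jax.value_and_grad(lambda th: -pf_loglik(th, ys, key))(theta)
```

Holding the key fixed (common random numbers) is what makes the log-likelihood estimate a well-defined function to differentiate; a fresh key per iteration gives a stochastic gradient instead.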

This implementation uses JAX vectorization to run the algorithm on multiple initial parameter sets (the n_reps rows of thetas_array) simultaneously.
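A minimal sketch of this vectorization pattern, assuming a toy objective: `train_one` (hypothetical) runs minibatch SGD from a single starting point and returns histories shaped like this function's outputs, and `jax.vmap` maps it over the rows of a `thetas_array` and a matching batch of keys. The Gaussian objective only stands in for a particle-filter likelihood estimate.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Synthetic data standing in for observations (true mean 3, true sd 2).
DATA = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(42), (500,))


def nll(theta, batch):
    """Mean negative Gaussian log-likelihood; theta = (mean, log_sd).
    A toy stand-in for a particle-filter likelihood estimate."""
    return -jnp.mean(norm.logpdf(batch, loc=theta[0], scale=jnp.exp(theta[1])))


def train_one(theta0, key, M=100, eta=0.05, batch_size=64):
    """Minibatch SGD for ONE starting point (hypothetical helper).
    Returns (nll history of shape (M,), theta history of shape (M + 1, n_params))."""

    def step(theta, k):
        idx = jax.random.choice(k, DATA.shape[0], (batch_size,), replace=False)
        loss, g = jax.value_and_grad(nll)(theta, DATA[idx])
        new_theta = theta - eta * g
        # Record the loss at the current theta, and the updated theta.
        return new_theta, (loss, new_theta)

    _, (losses, thetas) = jax.lax.scan(step, theta0, jax.random.split(key, M))
    theta_hist = jnp.concatenate([theta0[None, :], thetas], axis=0)
    return losses, theta_hist


# vmap maps train_one over the leading (replicate) axis of both inputs.
thetas_array = jnp.array([[0.0, 0.0], [5.0, 1.0], [1.0, 0.5]])  # (n_reps, n_params)
keys = jax.random.split(jax.random.PRNGKey(0), thetas_array.shape[0])  # one key per replicate
losses, theta_hist = jax.vmap(train_one)(thetas_array, keys)
# losses.shape == (3, 100); theta_hist.shape == (3, 101, 2)
```

Writing the per-replicate logic once and letting `jax.vmap` add the batch dimension keeps the single-start code simple while still compiling to one batched computation.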

Parameters:
  • struct (PompStruct) – The compiled structural representation of the POMP model.

  • thetas_array (jax.Array) – Array of initial parameters. Shape (n_reps, n_params).

  • J (int) – Number of particles used by the particle filter.

  • optimizer (str) – Name of the optimizer used for the parameter updates.

  • M (int) – Number of optimization iterations.

  • eta (jax.Array) – Learning rates (step sizes), one per iteration and parameter. Shape (M, n_params).

  • c (float) – Armijo condition (sufficient-decrease) constant used by the line search.

  • max_ls_itn (int) – Maximum number of line search iterations.

  • thresh (float) – Resampling threshold for the particle filter.

  • scale (bool) – Whether to scale the search direction.

  • ls (bool) – Whether to use line search.

  • alpha (float | jax.Array) – Alpha parameter.

  • keys (jax.Array) – Random keys, one per replicate. Shape (n_reps, …).

  • alpha_cooling (float) – Alpha cooling factor.

  • n_monitors (int) – Number of monitors.

  • clip_norm (float | None) – Maximum gradient norm used for clipping. If None (the default), gradients are not clipped.

  • beta1 (float) – Exponential decay rate for the first-moment estimates (Adam).

  • beta2 (float) – Exponential decay rate for the second-moment estimates (Adam).

  • epsilon (float) – Small constant added to the Adam denominator for numerical stability.
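Several of these parameters (eta, c, max_ls_itn, clip_norm, beta1, beta2, epsilon) correspond to standard optimizer machinery. The sketch below shows one common way to combine them: norm clipping, a textbook Adam direction, and backtracking until the Armijo sufficient-decrease condition holds. The helper names (`clip_by_norm`, `adam_direction`, `armijo_scale`, `adam_ls_step`) and their defaults are hypothetical; this is not pypomp's implementation, and it ignores scale, alpha and alpha_cooling.

```python
import functools

import jax
import jax.numpy as jnp


def clip_by_norm(g, clip_norm):
    """Rescale g so its Euclidean norm is at most clip_norm; no-op if clip_norm is None."""
    if clip_norm is None:
        return g
    return g * jnp.minimum(1.0, clip_norm / (jnp.linalg.norm(g) + 1e-12))


def adam_direction(g, m, v, t, beta1=0.9, beta2=0.999, epsilon=1e-8):
    """Standard Adam moment update (t is the 1-based step count for bias correction).
    Returns the descent direction and the updated moments."""
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g**2
    m_hat = m / (1.0 - beta1**t)
    v_hat = v / (1.0 - beta2**t)
    return -m_hat / (jnp.sqrt(v_hat) + epsilon), m, v


def armijo_scale(f, theta, fx, g, step, c=1e-4, max_ls_itn=10, shrink=0.5):
    """Backtracking: shrink s until f(theta + s*step) <= f(theta) + c*s*<g, step>,
    performing at most max_ls_itn shrinks."""
    slope = jnp.vdot(g, step)

    def not_done(state):
        s, i = state
        fails = f(theta + s * step) > fx + c * s * slope
        return jnp.logical_and(fails, i < max_ls_itn)

    def shrink_step(state):
        s, i = state
        return s * shrink, i + 1

    init = (jnp.ones((), theta.dtype), jnp.int32(0))
    s, _ = jax.lax.while_loop(not_done, shrink_step, init)
    return s


def adam_ls_step(f, theta, m, v, t, eta, c=1e-4, max_ls_itn=10, clip_norm=None,
                 beta1=0.9, beta2=0.999, epsilon=1e-8):
    """One iteration: gradient, optional clipping, Adam direction, Armijo backtracking."""
    fx, g = jax.value_and_grad(f)(theta)
    d, m, v = adam_direction(clip_by_norm(g, clip_norm), m, v, t, beta1, beta2, epsilon)
    # The Armijo slope uses the unclipped gradient so the test reflects f itself.
    s = armijo_scale(f, theta, fx, g, eta * d, c, max_ls_itn)
    return theta + s * eta * d, m, v, fx


target = jnp.array([1.0, -2.0])


def objective(th):
    # Toy objective standing in for a negative log-likelihood estimate.
    return jnp.sum((th - target) ** 2)


step = jax.jit(functools.partial(adam_ls_step, objective, clip_norm=1.0))
theta, m, v = jnp.zeros(2), jnp.zeros(2), jnp.zeros(2)
eta = jnp.full((200, 2), 0.1)  # per-iteration, per-parameter step sizes, shape (M, n_params)
for t in range(1, 201):
    theta, m, v, fx = step(theta, m, v, t, eta[t - 1])
```

Clipping only changes the gradient fed to Adam; the line search then decides how much of the proposed step `eta * d` to take.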

Returns:

  • Negative log-likelihood history. Shape (n_reps, M).

  • Theta (parameter) history. Shape (n_reps, M+1, n_params).

Return type:

tuple[jax.Array, jax.Array]
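As an illustration of consuming the returned arrays, the sketch below picks the replicate and iteration with the smallest recorded negative log-likelihood. It assumes (this is not stated above) that entry [r, m] of the negative log-likelihood history was evaluated at entry [r, m] of the theta history; `best_estimate` is a hypothetical helper and the arrays are synthetic.

```python
import jax.numpy as jnp


def best_estimate(nll_hist, theta_hist):
    """Locate the smallest recorded negative log-likelihood across replicates and iterations.

    Assumes nll_hist[r, m] was evaluated at theta_hist[r, m]; NaNs (e.g. from
    diverged replicates) are ignored. Returns (replicate, iteration, theta, nll)."""
    n_iter = nll_hist.shape[1]
    r, m = divmod(int(jnp.nanargmin(nll_hist)), n_iter)
    return r, m, theta_hist[r, m], nll_hist[r, m]


# Synthetic histories with the documented shapes: n_reps=2, M=3, n_params=1.
nll_hist = jnp.array([[5.0, 4.0, 3.5], [6.0, 2.0, jnp.nan]])
theta_hist = jnp.arange(8.0).reshape(2, 4, 1)
r, m, theta, nll = best_estimate(nll_hist, theta_hist)
```

Because each recorded value is a Monte Carlo estimate, the smallest one tends to be optimistic; re-estimating the log-likelihood at the selected parameters with fresh particle filter runs gives a fairer comparison between candidates.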