Functional API

The pypomp.functional module provides a collection of pure, stateless JAX functions for model simulation and inference. While the object-oriented Pomp class is recommended for most users, the functional API is intended for advanced users who need to:

  1. Compose algorithms inside custom JAX loops (such as jax.lax.scan) or higher-order functions (such as jax.vmap).

  2. Differentiate an entire algorithm end to end (e.g., by applying jax.grad to mop).
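To illustrate why purity matters for both points, the sketch below uses a hypothetical toy estimator, noisy_loglik (not a pypomp function). Its output depends only on its arguments, a parameter vector and a PRNG key. That makes it safe to compile with jax.jit, map over a parameter grid with jax.vmap, place inside a custom jax.lax.scan loop, and differentiate through that loop with jax.grad. The pypomp.functional routines are meant to compose in the same way, but their exact arguments and return values are not shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


# Hypothetical toy estimator (not part of pypomp): a Monte Carlo estimate of
# E[log N(y | mu + sd * eps, 1)] with eps ~ N(0, 1) and theta = (mu, log sd).
# It is pure: the result depends only on theta and key.
def noisy_loglik(theta, key, y=0.5, n=256):
    mu, log_sd = theta[0], theta[1]
    eps = jax.random.normal(key, (n,))
    return jnp.mean(norm.logpdf(y, loc=mu + jnp.exp(log_sd) * eps, scale=1.0))


key = jax.random.PRNGKey(0)
theta = jnp.array([0.0, -1.0])

# 1. Compile it.
fast = jax.jit(noisy_loglik)

# 2. Map it over a grid of parameter vectors, sharing one key.
grid = jnp.stack([jnp.linspace(-1.0, 1.0, 5), jnp.full(5, -1.0)], axis=1)
grid_vals = jax.vmap(noisy_loglik, in_axes=(0, None))(grid, key)


# 3. Use it inside a custom scan loop that averages over fresh keys.
def averaged(theta, key, reps=10):
    def body(total, k):
        return total + noisy_loglik(theta, k), None

    total, _ = jax.lax.scan(body, jnp.zeros(()), jax.random.split(key, reps))
    return total / reps


# 4. Differentiate the whole loop end to end.
grad_avg = jax.grad(averaged)(theta, key)
```

For this toy objective the exact gradient at theta = (0, -1) is (0.5, -exp(-2)), so grad_avg should land close to those values.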

PompStruct

To use the functional API, first export your model’s structural data and compiled functions into a PompStruct by calling the model's to_struct() method.

PompStruct(ys, dt_array_extended, ...)

A lightweight, immutable JAX PyTree holding the static data and compiled simulator functions for a POMP model.
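A PyTree of this kind can be sketched as follows. The class ToyStruct and its fields are hypothetical stand-ins, not pypomp's actual PompStruct definition. The sketch shows the split that the description implies. Arrays such as the observations become PyTree leaves, which JAX transformations trace. The model's simulator functions are stored as static auxiliary data. Freezing the dataclass keeps the object immutable.

```python
import dataclasses
from typing import Callable

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ToyStruct:
    """Hypothetical stand-in for PompStruct (not pypomp's class).

    Array fields are PyTree leaves, so jit/vmap/grad trace them; the model
    functions are static auxiliary data, so JAX treats them as constants.
    """
    ys: jax.Array
    dt: jax.Array
    rproc: Callable
    dmeas: Callable

    def tree_flatten(self):
        children = (self.ys, self.dt)        # traced data
        aux_data = (self.rproc, self.dmeas)  # static; must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


def rproc(x, key, dt):
    # Toy process model: one Euler step of a Brownian motion.
    return x + jnp.sqrt(dt) * jax.random.normal(key, jnp.shape(x))


def dmeas(y, x):
    # Toy measurement model: Gaussian log-density with unit variance.
    return norm.logpdf(y, loc=x, scale=1.0)


struct = ToyStruct(jnp.array([0.1, 0.3, -0.2]), jnp.asarray(0.5), rproc, dmeas)


@jax.jit
def first_obs_loglik(s, x):
    # The struct passes through jit like any other PyTree argument.
    return s.dmeas(s.ys[0], x)


val = first_obs_loglik(struct, 0.0)
```

Only the two arrays are leaves. Attempting to reassign a field raises dataclasses.FrozenInstanceError.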

Core Algorithms

pfilter(struct, thetas_array, J, thresh, keys)

Pure functional implementation of the particle filter with J particles, for composing the filter inside custom JAX loops or higher-order functions.
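The core recursion of a particle filter can be written as a pure function over jax.lax.scan: propagate, weight, accumulate the conditional log-likelihood, resample. The sketch below is a plain bootstrap filter for a hypothetical one-dimensional Gaussian random-walk model. The name toy_pfilter and its arguments are illustrative only. It resamples at every step, so it has no counterpart to pfilter's thresh argument. It also takes a single parameter vector rather than a thetas_array.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def toy_pfilter(theta, ys, J, key):
    """Bootstrap particle filter log-likelihood estimate (sketch).

    Hypothetical toy model, not pypomp's pfilter:
        x_0 = 0,  x_n = x_{n-1} + sigma_proc * eps_n,
        y_n ~ Normal(x_n, sigma_obs),
    with theta = (log sigma_proc, log sigma_obs).
    """
    sig_proc, sig_obs = jnp.exp(theta[0]), jnp.exp(theta[1])

    def step(carry, y):
        x, key, loglik = carry
        key, k_proc, k_res = jax.random.split(key, 3)
        # Propagate every particle through the process model.
        x = x + sig_proc * jax.random.normal(k_proc, (J,))
        # Log measurement density of the observation under each particle.
        logw = norm.logpdf(y, loc=x, scale=sig_obs)
        # Conditional log-likelihood: log of the average particle weight.
        loglik = loglik + logsumexp(logw) - jnp.log(J)
        # Multinomial resampling with probabilities proportional to weights.
        idx = jax.random.categorical(k_res, logw, shape=(J,))
        return (x[idx], key, loglik), None

    init = (jnp.zeros(J), key, jnp.zeros(()))
    (_, _, loglik), _ = jax.lax.scan(step, init, ys)
    return loglik


ys = jnp.array([0.2, -0.1, 0.4, 0.0])
theta = jnp.array([-1.0, -0.5])

# J sets array shapes, so it is marked static when jitting.
pf = jax.jit(toy_pfilter, static_argnums=2)
ll = pf(theta, ys, 5000, jax.random.PRNGKey(0))

# Independent replicate estimates by mapping over keys.
reps = jax.vmap(lambda k: pf(theta, ys, 1000, k))(
    jax.random.split(jax.random.PRNGKey(1), 4))
```

For this linear Gaussian model the exact log-likelihood is available from a Kalman filter, which makes it a convenient check on the estimate.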

mif(struct, thetas_array, sigmas_array, ...)

Pure functional implementation of the iterated filtering algorithm, for composing it inside custom JAX loops or higher-order functions.
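Iterated filtering can be sketched in the same style. The function below is a hypothetical IF2-style toy, not pypomp's mif. The algorithm works as follows:

- Each particle carries its own parameter vector.
- At every observation, the parameters take a Gaussian random-walk step with standard deviation cooling**m * sigmas.
- States and parameters are resampled together.
- The surviving parameter swarm seeds the next of M passes.

Treating sigmas as the counterpart of mif's sigmas_array (per-parameter perturbation scales) is an assumption.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def toy_if2(theta0, sigmas, ys, J, M, cooling, key):
    """Iterated filtering (IF2-style) sketch for a toy model.

    Hypothetical toy model, not pypomp's mif:
        x_0 = 0,  x_n = x_{n-1} + exp(theta[0]) * eps_n,
        y_n ~ Normal(x_n, exp(theta[1])).
    Returns the final swarm mean and the swarm mean after each pass.
    """
    swarm = jnp.tile(theta0, (J, 1))  # (J, d) parameter particles

    def one_pass(carry, m):
        swarm, key = carry
        scale = cooling ** m  # perturbations shrink geometrically per pass

        def step(carry, y):
            x, th, key = carry
            key, k_th, k_proc, k_res = jax.random.split(key, 4)
            # Random-walk perturbation of every parameter particle.
            th = th + scale * sigmas * jax.random.normal(k_th, th.shape)
            # Each state particle uses its own parameter particle.
            x = x + jnp.exp(th[:, 0]) * jax.random.normal(k_proc, (J,))
            logw = norm.logpdf(y, loc=x, scale=jnp.exp(th[:, 1]))
            idx = jax.random.categorical(k_res, logw, shape=(J,))
            # States and parameters are resampled together.
            return (x[idx], th[idx], key), None

        init = (jnp.zeros(J), swarm, key)
        (_, swarm, key), _ = jax.lax.scan(step, init, ys)
        return (swarm, key), swarm.mean(axis=0)

    (swarm, _), trace = jax.lax.scan(one_pass, (swarm, key), jnp.arange(M))
    return swarm.mean(axis=0), trace


ys = jnp.array([0.3, 0.1, 0.6, 0.4, 0.9, 0.7])
theta0 = jnp.array([0.0, 0.0])
est, trace = toy_if2(theta0, jnp.array([0.05, 0.05]), ys,
                     200, 20, 0.9, jax.random.PRNGKey(0))
```

With sigmas set to zero, the swarm never moves and the estimate stays at theta0. This makes a quick sanity check of the bookkeeping.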

simulate(struct, thetas_array, nsim, keys[, ...])

Pure functional implementation of model simulation, generating nsim simulations, for composing it inside custom JAX loops or higher-order functions.
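Simulation is the simplest case to sketch. One pure function simulates a single path with jax.lax.scan, and jax.vmap maps it over one PRNG key per simulation. toy_simulate and its random-walk model are hypothetical. They do not reflect pypomp's simulate signature or output format.

```python
import jax
import jax.numpy as jnp


def toy_simulate(theta, n_times, nsim, key):
    """Simulate nsim independent (state, observation) paths (sketch).

    Hypothetical toy model, not pypomp's simulate:
        x_0 = 0,  x_n = x_{n-1} + exp(theta[0]) * eps_n,
        y_n = x_n + exp(theta[1]) * nu_n.
    Returns two arrays of shape (nsim, n_times).
    """
    sig_proc, sig_obs = jnp.exp(theta[0]), jnp.exp(theta[1])

    def one_path(key):
        def step(x, k):
            k_proc, k_obs = jax.random.split(k)
            x = x + sig_proc * jax.random.normal(k_proc)
            y = x + sig_obs * jax.random.normal(k_obs)
            return x, (x, y)

        _, (xs, ys) = jax.lax.scan(step, jnp.zeros(()),
                                   jax.random.split(key, n_times))
        return xs, ys

    # One key per simulation; vmap runs the paths in parallel.
    return jax.vmap(one_path)(jax.random.split(key, nsim))


theta = jnp.array([-1.0, -2.0])
xs, ys = toy_simulate(theta, 10, 3, jax.random.PRNGKey(0))
```

Because the function is pure, the same key always reproduces the same paths.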

Differentiable Particle Filtering

These functions are primarily used for gradient-based parameter estimation. mop and dpop are designed to be fully differentiable with respect to the model parameters (thetas_array).

train(struct, thetas_array, J, optimizer, M, ...)

Pure functional implementation of gradient-based parameter optimization, for composing it inside custom JAX loops or higher-order functions.
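A gradient-based training loop composes naturally with jax.lax.scan and jax.value_and_grad. The sketch below is hypothetical. toy_train performs plain gradient ascent with a fixed learning rate instead of accepting an optimizer object. It is demonstrated on a deterministic objective whose maximizer is known in closed form. In practice the objective would be a differentiable, key-dependent log-likelihood estimate, such as one from a differentiable particle filter. That is why the loop draws a fresh key at every step.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def toy_train(loglik_fn, theta0, lr, M, key):
    """Run M steps of plain gradient ascent on loglik_fn(theta, key) (sketch).

    Hypothetical, not pypomp's train: the update is vanilla gradient ascent
    rather than a configurable optimizer. Each step receives its own key.
    """
    value_and_grad = jax.value_and_grad(loglik_fn)

    def step(theta, k):
        ll, g = value_and_grad(theta, k)
        return theta + lr * g, ll

    theta, lls = jax.lax.scan(step, theta0, jax.random.split(key, M))
    return theta, lls


# Deterministic toy objective: mean Gaussian log-likelihood of fixed data,
# with theta = (mu, log_sd). Its maximizer is the sample mean and the
# population-style (divide by n) standard deviation.
data = jnp.array([1.2, 0.7, 1.9, 1.4, 0.8])


def loglik(theta, key):
    del key  # unused here; a particle-filter objective would consume it
    return jnp.mean(norm.logpdf(data, loc=theta[0], scale=jnp.exp(theta[1])))


theta_hat, lls = toy_train(loglik, jnp.array([0.0, 0.0]), 0.1, 300,
                           jax.random.PRNGKey(0))
```

The learning rate is kept small because the curvature in mu grows as the fitted standard deviation shrinks. A larger step can make plain gradient ascent diverge.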

mop(struct, thetas_array, J, alpha, keys)

Pure functional implementation of the MOP differentiable particle filter with J particles, for composing it inside custom JAX loops or higher-order functions.
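The key trick behind this kind of differentiable filter fits in a few lines. The code below is a from-scratch toy in the spirit of the MOP-α estimator, not pypomp's mop. It assumes alpha acts as a discount factor in [0, 1] on accumulated weights. Resampling uses stop-gradient measurement weights. Each particle carries log(g / stop_gradient(g)), which is zero in value but passes the measurement-density score into later steps. The toy model and the name toy_mop are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def toy_mop(theta, ys, J, alpha, key):
    """MOP-alpha style differentiable particle filter (sketch).

    Hypothetical toy model, not pypomp's mop:
        x_0 = 0,  x_n = x_{n-1} + exp(theta[0]) * eps_n,
        y_n ~ Normal(x_n, exp(theta[1])).
    """
    sig_proc, sig_obs = jnp.exp(theta[0]), jnp.exp(theta[1])

    def step(carry, y):
        x, logw_f, key, loglik = carry
        key, k_proc, k_res = jax.random.split(key, 3)
        # Reparameterized process step: gradients flow through x.
        x = x + sig_proc * jax.random.normal(k_proc, (J,))
        # Discount the log-weights carried from earlier steps by alpha.
        logw_p = alpha * logw_f
        logg = norm.logpdf(y, loc=x, scale=sig_obs)
        loglik = loglik + logsumexp(logw_p + logg) - logsumexp(logw_p)
        # Resample with gradient-free weights so the discrete draw
        # contributes no (undefined) derivative.
        logg_sg = jax.lax.stop_gradient(logg)
        idx = jax.random.categorical(k_res, logg_sg, shape=(J,))
        # log(g / stop_gradient(g)) is zero in value but carries the
        # score of the measurement density forward.
        logw_f = (logw_p + logg - logg_sg)[idx]
        return (x[idx], logw_f, key, loglik), None

    init = (jnp.zeros(J), jnp.zeros(J), key, jnp.zeros(()))
    (_, _, _, loglik), _ = jax.lax.scan(step, init, ys)
    return loglik


ys = jnp.array([0.2, -0.1, 0.4, 0.0])
theta = jnp.array([-1.0, -0.5])
key = jax.random.PRNGKey(0)
ll, grad = jax.value_and_grad(toy_mop)(theta, ys, 500, 0.9, key)
```

In this sketch the forward value does not depend on alpha, since every carried log-weight is zero in value. alpha only changes how much gradient information from earlier time points is retained.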

dpop(struct, thetas_array, J, alpha, ...)

Pure functional implementation of the DPOP differentiable particle filter with J particles, for composing it inside custom JAX loops or higher-order functions.