import jax
from .structs import PompStruct
from ..core.algorithms.mif import _jv_mif_internal
[docs]
def mif(
    struct: PompStruct,
    thetas_array: jax.Array,
    sigmas_array: jax.Array,
    sigmas_init_array: jax.Array,
    M: int,
    a: float,
    J: int,
    thresh: float,
    keys: jax.Array,
    n_monitors: int,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """
    Run the Iterated Filtering (MIF) algorithm as a pure function.

    This entry point is meant for callers who want to embed iterated filtering
    inside their own JAX loops or higher-order functions. A more user-friendly
    (but non-functional) interface is available as
    :meth:`pypomp.core.pomp.Pomp.mif`.

    The computation is vectorized with JAX so that several initial parameter
    sets are processed simultaneously.

    Args:
        struct (PompStruct): The compiled structural representation of the
            POMP model. Its fields are forwarded unchanged to the internal
            routine.
        thetas_array (jax.Array): Initial parameters, shape
            (J, n_reps, n_params). The batch dimension for `vmap` is the
            second axis (`n_reps`), not the first.
        sigmas_array (jax.Array): Random walk sigmas, shape (n_params,).
        sigmas_init_array (jax.Array): Initial random walk sigmas, shape
            (n_params,).
        M (int): Number of iterations.
        a (float): Cooling factor.
        J (int): Number of particles.
        thresh (float): Resampling threshold.
        keys (jax.Array): Random keys, shape (n_reps, ...).
        n_monitors (int): Number of monitors for likelihood averaging.

    Returns:
        tuple[jax.Array, jax.Array, jax.Array]: In order,

        * negative log-likelihood history, shape (n_reps, M);
        * parameter trace history, shape (n_reps, M+1, n_params);
        * final particle swarm, shape (n_reps, J, n_params).
    """
    # `_jv_mif_internal` takes everything positionally, so the groups below
    # are assembled (and later spliced) in exactly the order it expects.
    leading_fields = (
        struct.dt_array_extended,
        struct.nstep_array,
        struct.t0,
        struct.times,
        struct.ys,
        struct.rinit_per,
        struct.rproc_per,
        struct.dmeas_per,
    )
    accumvars = struct.accumvars
    covars_extended = struct.covars_extended
    pf_fields = (
        struct.rinit_pf,
        struct.rproc_pf,
        struct.dmeas_pf,
    )

    outputs = _jv_mif_internal(
        thetas_array,
        *leading_fields,
        sigmas_array,
        sigmas_init_array,
        accumvars,
        covars_extended,
        M,
        a,
        J,
        thresh,
        keys,
        *pf_fields,
        n_monitors,
        # NOTE(review): trailing flag is hard-coded to False; its meaning is
        # defined by `_jv_mif_internal` and is not visible from this wrapper.
        False,
    )

    # Only the first three entries of the internal result are exposed.
    neg_loglik_history = outputs[0]
    param_trace_history = outputs[1]
    final_particles = outputs[2]
    return neg_loglik_history, param_trace_history, final_particles