pypomp.functional.mif
- pypomp.functional.mif(struct: PompStruct, thetas_array: Array, sigmas_array: Array, sigmas_init_array: Array, M: int, a: float, J: int, thresh: float, keys: Array, n_monitors: int) → tuple[Array, Array, Array]
This is a pure functional implementation of the Iterated Filtering algorithm, intended for users who need to compose it within custom JAX loops or higher-order functions. For a more user-friendly (but not purely functional) interface, see pypomp.core.pomp.Pomp.mif().
This implementation uses JAX to vectorize the algorithm across multiple initial parameter sets simultaneously.
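As a rough illustration of vectorizing over multiple initial parameter sets, the sketch below applies jax.vmap along the second axis of an array shaped (J, n_reps, n_params). The function swarm_mean is a toy stand-in invented for this example; it is not part of pypomp and says nothing about how mif is implemented internally.

```python
import jax
import jax.numpy as jnp

J, n_reps, n_params = 4, 3, 2

# Toy stand-in (hypothetical, not pypomp code): takes one particle swarm
# of shape (J, n_params) and returns its per-parameter mean.
def swarm_mean(swarm):
    return swarm.mean(axis=0)

# Dummy parameter array laid out as (J, n_reps, n_params).
thetas = jnp.arange(J * n_reps * n_params, dtype=jnp.float32).reshape(
    J, n_reps, n_params
)

# in_axes=1 maps the function over the replicate axis, so each call
# sees a (J, n_params) slice; results are stacked along axis 0.
per_rep_means = jax.vmap(swarm_mean, in_axes=1)(thetas)
print(per_rep_means.shape)
```

The output has one row per replicate, which is the same leading-axis convention the return values of this function are documented to use.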
- Parameters:
struct (PompStruct) – The compiled structural representation of the POMP model.
thetas_array (jax.Array) – Array of initial parameters. Shape (J, n_reps, n_params). Note that the batch dimension for vmap is the second axis (n_reps).
sigmas_array (jax.Array) – Array of random walk sigmas. Shape (n_params,).
sigmas_init_array (jax.Array) – Array of initial random walk sigmas. Shape (n_params,).
M (int) – Number of iterations.
a (float) – Cooling factor.
J (int) – Number of particles.
thresh (float) – Resampling threshold.
keys (jax.Array) – Random keys. Shape (n_reps, …).
n_monitors (int) – Number of monitors for likelihood averaging.
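The array arguments can be prepared along the lines of the sketch below, which only builds inputs with the documented shapes and does not call mif. The starting values, sigma magnitudes, and the choice to give every particle in a replicate the same starting point are assumptions made for illustration; a PompStruct would still have to be obtained from a model.

```python
import jax
import jax.numpy as jnp

J, n_reps, n_params = 100, 4, 3

# Hypothetical starting values, one row per replicate: (n_reps, n_params).
starts = jnp.array([
    [0.1, 1.0, 5.0],
    [0.2, 1.5, 4.0],
    [0.3, 2.0, 3.0],
    [0.4, 2.5, 2.0],
])

# Assumption: every particle in a replicate starts at the same point.
# Broadcasting adds the leading particle axis: (J, n_reps, n_params).
thetas_array = jnp.broadcast_to(starts, (J, n_reps, n_params))

# Illustrative random walk scales, one per parameter: (n_params,).
sigmas_array = jnp.full((n_params,), 0.02)
sigmas_init_array = jnp.full((n_params,), 0.1)

# One random key per replicate; the leading axis has length n_reps.
keys = jax.random.split(jax.random.PRNGKey(0), n_reps)

print(thetas_array.shape, sigmas_array.shape, keys.shape[0])
```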
- Returns:
A tuple of three arrays:
Negative log-likelihood history – shape (n_reps, M).
Parameter trace history – shape (n_reps, M+1, n_params).
Final particle swarm – shape (n_reps, J, n_params).
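One possible way to work with the three returned arrays is sketched below. Random placeholder arrays with the documented shapes stand in for real output, and the variable names are invented for this example. Picking the replicate with the lowest final negative log-likelihood is just one reasonable summary, not something the function does itself.

```python
import jax
import jax.numpy as jnp

n_reps, M, J, n_params = 4, 10, 100, 3
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Placeholders with the documented shapes (not real mif output).
neg_loglik = jax.random.uniform(k1, (n_reps, M))
traces = jax.random.normal(k2, (n_reps, M + 1, n_params))
swarm = jax.random.normal(k3, (n_reps, J, n_params))

# Replicate whose last iteration has the smallest negative log-likelihood.
best = jnp.argmin(neg_loglik[:, -1])

# Parameter vector at the end of that replicate's trace: (n_params,).
best_params = traces[best, -1]

# Mean of that replicate's final particle swarm: (n_params,).
best_swarm_mean = swarm[best].mean(axis=0)

print(int(best), best_params.shape, best_swarm_mean.shape)
```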
- Return type: