pypomp.functional.mif

pypomp.functional.mif(struct: PompStruct, thetas_array: Array, sigmas_array: Array, sigmas_init_array: Array, M: int, a: float, J: int, thresh: float, keys: Array, n_monitors: int) → tuple[Array, Array, Array]

This is a pure functional implementation of the Iterated Filtering algorithm, intended for users who need to compose it within custom JAX loops or higher-order functions. For a more user-friendly, object-oriented interface, see pypomp.core.pomp.Pomp.mif().

This implementation uses JAX's vmap to vectorize the algorithm across multiple initial parameter sets (replicates), so that all replicates run simultaneously.
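The vectorization convention can be illustrated with a toy function. The sketch below is not pypomp code: toy_replicate is a hypothetical stand-in for the per-replicate work, and the sizes are arbitrary. It shows how jax.vmap with in_axes=(1, 0) maps over the second axis of a (J, n_reps, n_params) array while pairing each replicate with its own PRNG key.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for one replicate's computation (not pypomp code).
# It receives a single replicate's swarm of shape (J, n_params) and one key,
# adds small Gaussian noise, and returns the swarm mean of shape (n_params,).
def toy_replicate(swarm, key):
    noise = 0.1 * jax.random.normal(key, swarm.shape)
    return jnp.mean(swarm + noise, axis=0)

J, n_reps, n_params = 100, 4, 3
thetas = jnp.ones((J, n_reps, n_params))
keys = jax.random.split(jax.random.PRNGKey(0), n_reps)

# in_axes=(1, 0): batch over axis 1 of thetas (n_reps) and axis 0 of keys.
batched = jax.vmap(toy_replicate, in_axes=(1, 0))
out = batched(thetas, keys)
print(out.shape)  # (4, 3)
```

The mapped axis is removed inside the function and reinserted as the leading axis of the output, which is consistent with the outputs of mif being documented with n_reps first.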

Parameters:
  • struct (PompStruct) – The compiled structural representation of the POMP model.

  • thetas_array (jax.Array) – Array of initial parameters. Shape (J, n_reps, n_params). Note that the batch dimension for vmap is the second axis (n_reps).

  • sigmas_array (jax.Array) – Random-walk perturbation scales (sigmas), one per parameter. Shape (n_params,).

  • sigmas_init_array (jax.Array) – Initial random-walk perturbation scales, one per parameter. Shape (n_params,).

  • M (int) – Number of iterations.

  • a (float) – Cooling factor.

  • J (int) – Number of particles.

  • thresh (float) – Resampling threshold.

  • keys (jax.Array) – JAX PRNG keys, one per replicate. Shape (n_reps, …).

  • n_monitors (int) – Number of monitors for likelihood averaging.
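One plausible way to build the array arguments with the documented shapes is sketched below. The problem sizes, the uniform starting values, and the sigma magnitudes are placeholder assumptions for illustration only; struct, M, a, thresh, and n_monitors are not constructed here.

```python
import jax
import jax.numpy as jnp

# Placeholder problem sizes (assumptions for illustration).
J, n_reps, n_params = 500, 8, 3

key = jax.random.PRNGKey(42)
key, init_key = jax.random.split(key)

# One placeholder starting vector per replicate, copied across all J particles
# to give the documented shape (J, n_reps, n_params).
starts = jax.random.uniform(init_key, (n_reps, n_params), minval=0.0, maxval=1.0)
thetas_array = jnp.broadcast_to(starts[None, :, :], (J, n_reps, n_params))

# Random-walk sigmas, one per parameter (placeholder magnitudes).
sigmas_array = jnp.full((n_params,), 0.02)
sigmas_init_array = jnp.full((n_params,), 0.1)

# One PRNG key per replicate: leading dimension n_reps.
keys = jax.random.split(key, n_reps)

print(thetas_array.shape)  # (500, 8, 3)
```

These arrays would then be passed to mif together with struct, M, a, J, thresh, and n_monitors.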

Returns:

A tuple of three arrays, in this order:

  • Negative log-likelihood history. Shape (n_reps, M).

  • Parameter trace history. Shape (n_reps, M+1, n_params).

  • Final particle swarm. Shape (n_reps, J, n_params).
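The sketch below shows one way the three returned arrays might be unpacked and summarized. It uses randomly generated placeholder arrays that only match the documented shapes; they are not real mif results. It assumes the last row of each parameter trace holds the parameters after the final iteration, which the M+1 length suggests (a starting value plus one entry per iteration) but which is an assumption here.

```python
import jax
import jax.numpy as jnp

n_reps, M, J, n_params = 4, 10, 200, 3

# Placeholder outputs with the documented shapes (random values, not results).
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
nll_history = jax.random.uniform(k1, (n_reps, M), minval=100.0, maxval=200.0)
theta_traces = jax.random.normal(k2, (n_reps, M + 1, n_params))
final_swarm = jax.random.normal(k3, (n_reps, J, n_params))

# Replicate with the lowest negative log-likelihood at the last iteration.
best = int(jnp.argmin(nll_history[:, -1]))

# That replicate's parameter vector from the last trace entry (assumed to be
# the value after the final iteration): shape (n_params,).
best_theta = theta_traces[best, -1]

# Mean of each replicate's final particle swarm: shape (n_reps, n_params).
swarm_means = final_swarm.mean(axis=1)
```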

Return type:

tuple[jax.Array, jax.Array, jax.Array]