Source code for pypomp.functional.pfilter
import jax
from .structs import PompStruct
from ..core.algorithms.pfilter import _vmapped_pfilter_internal2
[docs]
def pfilter(
    struct: PompStruct,
    thetas_array: jax.Array,
    J: int,
    thresh: float,
    keys: jax.Array,
    CLL: bool = False,
    ESS: bool = False,
    filter_mean: bool = False,
    prediction_mean: bool = False,
) -> dict[str, jax.Array]:
    """Run the particle filter as a pure function over a batch of parameter sets.

    Functional counterpart of :meth:`pypomp.core.pomp.Pomp.pfilter`, intended
    for composition inside user-written JAX loops and higher-order functions.
    The work is delegated to a JAX-vectorized internal routine so that several
    parameter sets are filtered in one call.

    Args:
        struct (PompStruct): Compiled structural representation of the POMP
            model. Supplies the time arrays, ``ys``, ``accumvars``, the
            extended covariates and the ``rinit_pf`` / ``rproc_pf`` /
            ``dmeas_pf`` components forwarded to the filter.
        thetas_array (jax.Array): Parameter sets to evaluate, shape
            ``(n_reps, n_params)``.
        J (int): Number of particles.
        thresh (float): Resampling threshold.
        keys (jax.Array): Random keys; documented shape ``(n_reps, reps, ...)``.
            TODO(review): confirm the exact layout expected by
            ``_vmapped_pfilter_internal2``.
        CLL (bool): If True, also return conditional log-likelihoods.
        ESS (bool): If True, also return effective sample sizes.
        filter_mean (bool): If True, also return filtered state means.
        prediction_mean (bool): If True, also return predicted state means.

    Returns:
        dict[str, jax.Array]: The internal filter's result dictionary, with its
        ``neg_loglik`` entry removed and replaced by ``logLik`` (its negation).
        ``logLik`` is always present; ``CLL``, ``ESS``, ``filter_mean`` and
        ``prediction_mean`` are present when the matching flag is True.
    """
    # Fields are gathered in exactly the order the internal routine expects
    # them positionally (and in the order the original code read them).
    schedule_and_data = (
        struct.dt_array_extended,
        struct.nstep_array,
        struct.t0,
        struct.times,
        struct.ys,
    )
    model_components = (struct.rinit_pf, struct.rproc_pf, struct.dmeas_pf)
    requested_outputs = (CLL, ESS, filter_mean, prediction_mean)

    out = _vmapped_pfilter_internal2(
        thetas_array,
        *schedule_and_data,
        J,
        *model_components,
        struct.accumvars,
        struct.covars_extended,
        thresh,
        keys,
        *requested_outputs,
        # NOTE(review): trailing positional flag whose meaning is defined by
        # _vmapped_pfilter_internal2 (not visible here) — confirm before changing.
        False,
    )

    # The internal routine reports the negative log-likelihood; publish its
    # negation under the public ``logLik`` key and drop the internal name.
    out["logLik"] = -out.pop("neg_loglik")
    return out