Source code for pypomp.functional.dpop

import jax
from .structs import PompStruct
from ..core.algorithms.dpop import _vmapped_dpop_internal


def dpop(
    struct: PompStruct,
    thetas_array: jax.Array,
    J: int,
    alpha: float,
    process_weight_index: int,
    keys: jax.Array,
) -> jax.Array:
    """
    Pure functional implementation of the DPOP differentiable particle
    filter, intended for users who need to compose it within custom JAX
    loops or higher-order functions.

    Like :func:`pypomp.functional.mop`, this function is a fully
    differentiable objective function for parameter estimation. Unlike
    ``mop``, it additionally incorporates a per-interval transition
    log-weight, which is assumed to be stored in one of the state
    components.

    The user-specified process model is expected to accumulate this process
    log-weight over a single observation interval. At the beginning of each
    interval, the corresponding state component should be reset to zero,
    which ``accumvars`` handles naturally.

    Args:
        struct (PompStruct): The compiled structural representation of the
            POMP model.
        thetas_array (jax.Array): Array of initial parameters.
            Shape (n_reps, n_params).
        J (int): Number of particles.
        alpha (float): Alpha parameter for DPOP.
        process_weight_index (int): Index of the state component that holds
            the per-interval process log-weight.
        keys (jax.Array): Random keys. Shape (n_reps, ...).

    Returns:
        jax.Array: Negative DPOP log-likelihood estimates.
    """
    return _vmapped_dpop_internal(
        thetas_array,
        struct.ys,
        struct.dt_array_extended,
        struct.nstep_array,
        struct.t0,
        struct.times,
        J,
        struct.rinit_pf,
        struct.rproc_pf,
        struct.dmeas_pf,
        struct.accumvars,
        struct.covars_extended,
        alpha,
        process_weight_index,
        len(struct.times),
        keys,
    )
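The convention the docstring asks of the process model (accumulate a log-weight in one state component over an observation interval, with that component reset to zero at the start of each interval) can be illustrated with a toy model. This is a minimal sketch under stated assumptions: the two-component state layout, `PROCESS_WEIGHT_INDEX`, `toy_step`, and `run_interval` are hypothetical and not part of pypomp. Using the log-density of the step's own Gaussian increment as the weight is purely an illustrative choice. In pypomp the reset is performed through ``accumvars`` rather than by hand as done here.

```python
import jax
import jax.numpy as jnp

# Hypothetical state layout: [x, logw]. Index 1 plays the role of the
# ``process_weight_index`` argument of ``dpop``.
PROCESS_WEIGHT_INDEX = 1


def toy_step(state, key, dt):
    """One Euler step of a toy random walk that also adds a log-weight
    increment (here, illustratively, log N(dx; 0, dt)) to the weight slot."""
    x, logw = state[0], state[PROCESS_WEIGHT_INDEX]
    eps = jax.random.normal(key)
    dx = jnp.sqrt(dt) * eps
    log_inc = -0.5 * (dx**2 / dt + jnp.log(2 * jnp.pi * dt))
    return jnp.array([x + dx, logw + log_inc])


def run_interval(state, key, nstep, dt):
    """Advance the state across one observation interval of ``nstep`` steps.

    The weight slot is zeroed first (hypothetical stand-in for what
    ``accumvars`` does), so the returned log-weight covers this interval only.
    """
    state = state.at[PROCESS_WEIGHT_INDEX].set(0.0)
    step_keys = jax.random.split(key, nstep)

    def body(s, k):
        return toy_step(s, k, dt), None

    state, _ = jax.lax.scan(body, state, step_keys)
    return state


if __name__ == "__main__":
    s = jnp.array([0.0, 0.0])
    for k in jax.random.split(jax.random.PRNGKey(0), 3):
        s = run_interval(s, k, nstep=5, dt=0.1)
        print(s)  # the log-weight printed is for the latest interval only
```

Because the reset happens at the start of every interval, the weight component never carries contributions from earlier intervals. This matches the docstring's requirement that the log-weight be accumulated over a single observation interval. For calling `dpop` itself, the inputs follow the shapes in its docstring; for example, `jax.random.split(jax.random.PRNGKey(0), n_reps)` yields keys of shape (n_reps, 2), one per row of `thetas_array`.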