pypomp.functional.dpop

pypomp.functional.dpop(struct: PompStruct, thetas_array: Array, J: int, alpha: float, process_weight_index: int, keys: Array) → Array

This is a pure functional implementation of the DPOP differentiable particle filter, intended for users who need to compose it within custom JAX loops or higher-order functions.
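The kind of composition mentioned above can be sketched as a gradient-descent loop built from jax.value_and_grad and jax.lax.scan. This is a minimal sketch, not pypomp code: because building a PompStruct is outside the scope of this page, it uses a hypothetical stand-in, toy_objective, that only mimics the assumed input and output shapes of dpop ((n_reps, n_params) parameters and (n_reps, …) keys in, one loss per replicate out). The names toy_objective and fit, the learning rate, and the step count are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def toy_objective(thetas_array, keys):
    """Hypothetical stand-in for dpop: (n_reps, n_params) params and
    (n_reps, ...) keys -> (n_reps,) noisy losses. Not a POMP model."""
    def one_rep(theta, key):
        noise = 0.01 * jax.random.normal(key, theta.shape)  # mimics Monte Carlo noise
        return jnp.sum((theta - 1.0 + noise) ** 2)
    return jax.vmap(one_rep)(thetas_array, keys)


def fit(thetas_array, key, n_steps=200, lr=0.1):
    """Plain gradient descent on every replicate at once, inside lax.scan."""
    # Summing over replicates is safe: each row's loss depends only on that row,
    # so the gradient of the sum is the per-replicate gradient.
    loss_and_grad = jax.value_and_grad(
        lambda th, ks: jnp.sum(toy_objective(th, ks)))

    def step(thetas, step_key):
        keys = jax.random.split(step_key, thetas.shape[0])  # one key per replicate
        loss, grad = loss_and_grad(thetas, keys)
        return thetas - lr * grad, loss

    step_keys = jax.random.split(key, n_steps)  # fresh randomness at every step
    thetas, losses = jax.lax.scan(step, thetas_array, step_keys)
    return thetas, losses
```

With the real function one would, under the same shape assumptions, replace toy_objective with something like `lambda th, ks: dpop(struct, th, J, alpha, process_weight_index, ks)`, closing over the remaining arguments so that only the parameters are differentiated.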

Like pypomp.functional.mop(), this function is a fully differentiable objective function for parameter estimation. Unlike mop(), however, it additionally incorporates a per-interval transition log-weight, which is assumed to be stored in one of the state components (the one selected by process_weight_index).

The user-specified process model is expected to accumulate this process log-weight over a single observation interval. The corresponding state component should therefore be reset to zero at the beginning of each interval; declaring it in accumvars handles this automatically.
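As an illustration of this convention, here is a minimal pure-JAX sketch of a process model that accumulates a log-weight over one observation interval. Everything in it is hypothetical: the state layout (X_INDEX, LOGW_INDEX), the Euler dynamics, and the particular log-weight increment are placeholders, not pypomp's process-model API or any model shipped with the library. The explicit reset at the start of simulate_interval stands in for what listing the component in accumvars is described as doing.

```python
import jax
import jax.numpy as jnp

X_INDEX = 0     # hypothetical: index of the latent state
LOGW_INDEX = 1  # hypothetical: index of the log-weight (what process_weight_index would point to)


def euler_step(state, key, dt, theta):
    """One hypothetical Euler step that also adds a log-weight increment."""
    x = state[X_INDEX]
    drift = -theta * x
    x_new = x + drift * dt + jnp.sqrt(dt) * jax.random.normal(key)
    # Placeholder increment; a real model defines its own transition log-weight.
    log_increment = -0.5 * drift**2 * dt
    return state.at[X_INDEX].set(x_new).at[LOGW_INDEX].add(log_increment)


def simulate_interval(state, key, theta, n_steps=10, dt=0.1):
    """Advance one observation interval, accumulating the log-weight from zero."""
    # Reset at the start of the interval (the role accumvars is described as playing).
    state = state.at[LOGW_INDEX].set(0.0)
    keys = jax.random.split(key, n_steps)

    def body(s, k):
        return euler_step(s, k, dt, theta), None

    state, _ = jax.lax.scan(body, state, keys)
    return state
```

After the call, state[LOGW_INDEX] holds only the log-weight accumulated during this interval, regardless of what the component contained beforehand.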

Parameters:
  • struct (PompStruct) – The compiled structural representation of the POMP model.

  • thetas_array (jax.Array) – Array of initial parameters. Shape (n_reps, n_params).

  • J (int) – Number of particles.

  • alpha (float) – Alpha parameter for DPOP.

  • process_weight_index (int) – Index of the state component that holds the accumulated process log-weight.

  • keys (jax.Array) – Random keys. Shape (n_reps, …).
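Inputs with the documented leading shapes could be prepared as in the sketch below. The sizes, the base parameter vector, and the idea of perturbing it per replicate are hypothetical choices for illustration. The page only specifies keys as having shape (n_reps, …), so whether dpop expects legacy uint32 keys (as produced by jax.random.PRNGKey) or typed keys is an assumption here.

```python
import jax
import jax.numpy as jnp

n_reps, n_params = 8, 3  # hypothetical sizes

# Hypothetical starting values: one perturbed copy of a base vector per replicate.
base_theta = jnp.array([0.5, 1.0, 2.0])
key = jax.random.PRNGKey(42)
perturb_key, filter_key = jax.random.split(key)
thetas_array = base_theta + 0.1 * jax.random.normal(perturb_key, (n_reps, n_params))

# One PRNG key per replicate, giving keys a leading dimension of n_reps.
keys = jax.random.split(filter_key, n_reps)
```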

Returns:

Negative DPOP log-likelihood estimates.

Return type:

jax.Array
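Because the returned values are negative log-likelihood estimates, they can be used directly as losses, or their sign can be flipped to compare replicates on the log-likelihood scale. The sketch below assumes one estimate per row of thetas_array (shape (n_reps,)); the numbers are made up for illustration and are not output from pypomp.

```python
import jax.numpy as jnp

# Hypothetical dpop output for three replicates (negative log-likelihoods).
neg_loglik = jnp.array([105.2, 98.7, 101.3])
# Hypothetical parameter sets that produced them, shape (n_reps, n_params).
thetas_array = jnp.array([[0.4, 1.1], [0.5, 1.0], [0.6, 0.9]])

loglik = -neg_loglik                # flip the sign to get log-likelihood estimates
best = int(jnp.argmin(neg_loglik))  # smallest negative log-lik = largest log-lik
best_theta = thetas_array[best]
```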