Source code for pypomp.functional.simulate

import jax
import jax.numpy as jnp
from .structs import PompStruct
from ..core.algorithms.simulate import _jv_simulate_internal


def simulate(
    struct: PompStruct,
    thetas_array: jax.Array,
    nsim: int,
    keys: jax.Array,
    times: jax.Array | None = None,
) -> tuple[jax.Array, jax.Array]:
    """
    Simulate latent states and observations from a compiled POMP model.

    This is the purely functional entry point for simulation, meant to be
    composed inside custom JAX loops or higher-order functions. A more
    user-friendly (but non-functional) interface is available at
    :meth:`pypomp.core.pomp.Pomp.simulate`.

    The latent state is advanced through time with the process model
    (`rproc`), and simulated observations are drawn from the measurement
    model (`rmeas`). The work is delegated to a JAX routine that is
    vectorized over parameter sets and simulation replicates.

    Args:
        struct (PompStruct): The compiled structural representation of the
            POMP model.
        thetas_array (jax.Array): Array of initial parameters.
            Shape (n_reps, n_params).
        nsim (int): Number of simulations.
        keys (jax.Array): Random keys. Shape (n_reps, ...).
        times (jax.Array | None): Custom observation times. When ``None``,
            ``struct.times`` is used.

    Returns:
        tuple[jax.Array, jax.Array]:
            X_sims: simulated states.
                Shape (n_reps, nsim, len(times), n_states)
            Y_sims: simulated observations.
                Shape (n_reps, nsim, len(times), n_obs)
    """
    # Fall back to the model's own observation times when none are supplied.
    obs_times = times
    if obs_times is None:
        obs_times = struct.times

    # Observation dimension comes from the stored data; a model that carries
    # no data is treated as having a single observed variable.
    if struct.ys is None:
        obs_dim = 1
    else:
        obs_dim = struct.ys.shape[1]

    raw_states, raw_obs = _jv_simulate_internal(
        struct.rinit_pf,
        struct.rproc_pf,
        struct.rmeas_pf,
        thetas_array,
        struct.t0,
        obs_times,
        struct.dt_array_extended,
        struct.nstep_array,
        obs_dim,
        struct.covars_extended,
        struct.accumvars,
        nsim,
        keys,
    )

    # The internal routine lays results out as (n_reps, time, dim, nsim);
    # callers expect (n_reps, nsim, time, dim).
    axis_order = (0, 3, 1, 2)
    states = jnp.transpose(raw_states, axis_order)
    observations = jnp.transpose(raw_obs, axis_order)
    return states, observations