Source code for pypomp.core.pomp

"""
This module implements the OOP structure for POMP models.
"""

import importlib
import cloudpickle
from copy import deepcopy
import time
from typing import Callable, Any
import numpy as np
import jax
import jax.numpy as jnp
import pandas as pd
import warnings
from typing import Union, overload, Literal
from .viz import plot_traces_internal, plot_simulations_internal

from pypomp.types import ThetaInput
from .metadata import ModelMetadata
from pypomp import functional as F
from .model_struct import _RInit, _RProc, _DMeas, _RMeas
import xarray as xr
from .algorithms.helpers import _calc_ys_covars
from .rw_sigma import RWSigma
from .learning_rate import LearningRate
from .par_trans import ParTrans
from .optimizer import Optimizer, Adam
from .results import ResultsHistory, PompPFilterResult, PompMIFResult, PompTrainResult
from .parameters import PompParameters
from pypomp.maths import logmeanexp
from pypomp import benchmarks
from pypomp.functional.structs import PompStruct


class Pomp:
    """
    A class representing a Partially Observed Markov Process (POMP) model.

    This class provides a structured way to define and work with POMP models,
    which are used for modeling time series data where the underlying state
    process is only partially observed. The class encapsulates the model
    components including the initial state distribution, process model, and
    measurement model.

    In particular, the class provides methods for:

    - Simulation of the model
    - Particle filtering
    - Iterated filtering
    - Model training using a differentiable particle filter

    **⚠️ IMPORTANT: Defining Model Components**

    The `rinit`, `rproc`, `dmeas`, and `rmeas` arguments expect user-defined
    functions. **You MUST read the documentation for each argument to
    understand the required argument names, type hints, and return types.**
    The `Pomp` object will fail to initialize if these functions do not
    strictly adhere to the specifications.

    - **State initialization simulator (rinit):** See :ref:`rinit-tutorial`.
    - **State transition simulator (rproc):** See :ref:`rproc-tutorial`.
    - **Measurement density (dmeas):** See :ref:`dmeas-tutorial`.
    - **Measurement simulator (rmeas):** See :ref:`rmeas-tutorial`.

    Parameters
    ----------
    ys : pd.DataFrame
        The measurement data frame. The row index must contain the
        observation times.
    theta : ThetaInput
        Initial parameter(s) for the model. Accepts:

        - A single dictionary: dict[str, Numeric]
        - A list of dictionaries: list[dict[str, Numeric]]
        - An existing PompParameters object

        Numeric values (e.g. jax.Array, int) are automatically coerced to
        standard Python floats for internal storage. Vectorized methods
        (like pfilter) will run in parallel over list/PompParameters inputs.
    statenames : tuple[str, ...] or list[str]
        All latent state variable names.
    t0 : float
        The initial time for the model (typically before the first
        observation).
    rinit : Callable
        Initial state simulator function.
    rproc : Callable
        Process simulator function (defining a single time step).
    dmeas : Callable, optional
        Measurement density function (log-likelihood).
    rmeas : Callable, optional
        Measurement simulator function.
    par_trans : ParTrans, optional
        Parameter transformation object used to move parameters between the
        natural space and the estimation space. Defaults to the identity
        transformation.
    covars : pd.DataFrame, optional
        Time-varying covariates. The row index must contain the covariate
        times.
    nstep : int, optional
        The number of integration steps to take between observations.
        Passed automatically to the `RProc` component. Must be None if `dt`
        is provided.
    dt : float, optional
        Fixed time step size for the process model. Passed automatically to
        the `RProc` component. Must be None if `nstep` is provided.
    accumvars : tuple[str, ...], optional
        Names of accumulator state variables (e.g., incidence tracking).
        These are reset to 0 at the start of each observation interval.
    validate_logic : bool, optional
        Whether to validate the logic of the model components.
    """

    ys: pd.DataFrame
    """The measurement data frame with observation times as the index."""

    _theta: PompParameters
    """Internal storage for model parameters in canonical order."""

    canonical_param_names: list[str]
    """Ordered list of parameter names used throughout the model."""

    statenames: list[str]
    """Names of all latent state variables in the process model."""

    t0: float
    """Initial time for the model (typically before the first observation)."""

    rinit: _RInit
    """Simulator for the initial state distribution."""

    rproc: _RProc
    """Process model simulator handling state transitions between observation times."""

    dmeas: _DMeas | None
    """Measurement density used to evaluate the likelihood of observations."""

    rmeas: _RMeas | None
    """Measurement simulator used to generate synthetic observations."""

    par_trans: ParTrans
    """Parameter transformation object mapping between natural and estimation spaces."""

    covars: pd.DataFrame | None
    """Time-varying covariates for the model, if applicable."""

    _covars_extended: np.ndarray | None
    """Internal covariate array interpolated/aligned to the integration grid."""

    _nstep_array: np.ndarray
    """Number of integration steps between successive observation times."""

    _dt_array_extended: np.ndarray
    """Time step sizes for each integration step over the full time grid."""

    _max_steps_per_interval: int
    """Maximum number of integration steps between any two observation times."""

    accumvars: list[str] | None
    """Names of accumulator state variables that are reset at each observation time."""

    _accumvars_indices: tuple[int, ...] | None
    """Indices of accumulator state variables within the full state vector."""

    results_history: ResultsHistory
    """History of results from `pfilter`, `mif`, and `train` calls."""

    fresh_key: jax.Array | None
    """Running a method that takes a key will store a fresh, unused key here."""

    metadata: ModelMetadata
    """Environment and version metadata initialized when this instance was built."""

    def __init__(
        self,
        ys: pd.DataFrame,
        theta: ThetaInput,
        statenames: tuple[str, ...] | list[str],
        t0: float,
        rinit: Callable,
        rproc: Callable,
        dmeas: Callable | None = None,
        rmeas: Callable | None = None,
        par_trans: ParTrans | None = None,
        nstep: int | None = None,
        dt: float | None = None,
        accumvars: tuple[str, ...] | list[str] | None = None,
        covars: pd.DataFrame | None = None,
        validate_logic: bool = True,
    ):
        """
        Validate the inputs and build the wrapped model components.

        See the class docstring for the meaning of each argument.

        Raises:
            TypeError: If `ys` is not a DataFrame, or `covars` is neither a
                DataFrame nor None.
            ValueError: If `statenames` is missing or is not a tuple/list of
                strings, if `accumvars` is not a tuple/list of strings or
                names a variable absent from `statenames`, or if both
                `dmeas` and `rmeas` are None.
        """
        if not isinstance(ys, pd.DataFrame):
            raise TypeError("ys must be a pandas DataFrame")
        if covars is not None and not isinstance(covars, pd.DataFrame):
            raise TypeError("covars must be a pandas DataFrame or None")

        # The property setter coerces dict / list input into PompParameters.
        self.theta = theta
        self.canonical_param_names = self._theta.get_param_names()

        # State names cannot be inferred from the user functions; require them.
        if statenames is None:
            raise ValueError(
                "statenames must be provided as a list of state variable names"
            )
        # Tuples are accepted as well as lists, matching the signature and
        # the error message below.
        if not isinstance(statenames, (list, tuple)) or not all(
            isinstance(name, str) for name in statenames
        ):
            raise ValueError("statenames must be a tuple or list of strings")

        if accumvars is not None:
            # A bare string would otherwise be iterated character by character.
            if isinstance(accumvars, str) or not all(
                isinstance(name, str) for name in accumvars
            ):
                raise ValueError("accumvars must be a tuple or list of strings")
            if not all(name in statenames for name in accumvars):
                raise ValueError("all accumvars must be in statenames")
            state_order = tuple(statenames)
            self._accumvars_indices = tuple(
                state_order.index(name) for name in accumvars
            )
        else:
            self._accumvars_indices = None

        self.statenames = list(statenames)
        self.accumvars = list(accumvars) if accumvars is not None else None
        self.ys = ys
        self.covars = covars
        self.t0 = float(t0)
        self.results_history = ResultsHistory()
        self.fresh_key = None
        self.metadata = ModelMetadata()
        self.covar_names = list(covars.columns) if covars is not None else []
        # Explicit None test so a user-supplied object is never replaced
        # merely because it evaluates as falsy.
        self.par_trans = par_trans if par_trans is not None else ParTrans()

        self.rinit = _RInit(
            struct=rinit,
            statenames=self.statenames,
            param_names=self.canonical_param_names,
            covar_names=self.covar_names,
            par_trans=self.par_trans,
            validate_logic=validate_logic,
        )

        if dmeas is not None:
            self.dmeas = _DMeas(
                struct=dmeas,
                statenames=self.statenames,
                param_names=self.canonical_param_names,
                covar_names=self.covar_names,
                par_trans=self.par_trans,
                y_names=list(self.ys.columns),
                validate_logic=validate_logic,
            )
        else:
            self.dmeas = None

        if rmeas is not None:
            self.rmeas = _RMeas(
                struct=rmeas,
                statenames=self.statenames,
                param_names=self.canonical_param_names,
                covar_names=self.covar_names,
                par_trans=self.par_trans,
                y_names=list(self.ys.columns),
                validate_logic=validate_logic,
            )
        else:
            self.rmeas = None

        if self.dmeas is None and self.rmeas is None:
            raise ValueError("You must supply at least one of dmeas or rmeas")

        (
            self._covars_extended,
            self._dt_array_extended,
            self._nstep_array,
            self._max_steps_per_interval,
        ) = _calc_ys_covars(
            t0=self.t0,
            times=np.array(self.ys.index),
            ctimes=np.array(self.covars.index) if self.covars is not None else None,
            covars=np.array(self.covars) if self.covars is not None else None,
            dt=dt,
            nstep=nstep,
            order="linear",
        )

        self.rproc = _RProc(
            struct=rproc,
            statenames=self.statenames,
            param_names=self.canonical_param_names,
            covar_names=self.covar_names,
            par_trans=self.par_trans,
            nstep=nstep,
            dt=dt,
            accumvars=self._accumvars_indices,
            validate_logic=validate_logic,
            nstep_array=self._nstep_array,
            max_steps_bound=self._max_steps_per_interval,
        )

    @property
    def theta(self) -> PompParameters:
        """The model parameters currently stored on this object."""
        return self._theta

    @theta.setter
    def theta(self, value: ThetaInput):
        """Store `value`, wrapping it in PompParameters unless it already is one."""
        if isinstance(value, PompParameters):
            self._theta = value
        else:
            self._theta = PompParameters(value)

    def _prepare_theta_input(
        self,
        theta: ThetaInput,
    ) -> PompParameters:
        """
        Prepare the theta input for the method.

        Args:
            theta: None (use `self.theta`), a dictionary, a list of
                dictionaries, or a PompParameters object.

        Returns:
            PompParameters: `self.theta` when `theta` is None, otherwise the
            (possibly newly wrapped) input.

        Raises:
            TypeError: If `theta` is of an unsupported type.
            ValueError: If the parameter names differ from
                `canonical_param_names` as a set.
        """
        if theta is None:
            return self.theta
        if isinstance(theta, (dict, list)):
            theta = PompParameters(theta)
        elif not isinstance(theta, PompParameters):
            raise TypeError(
                "theta must be a dictionary, a list of dictionaries, or a PompParameters object"
            )
        if set(theta.get_param_names()) != set(self.canonical_param_names):
            raise ValueError(
                "theta parameter names must match canonical_param_names up to reordering"
            )
        return theta

    def _update_fresh_key(
        self, key: jax.Array | None = None
    ) -> tuple[jax.Array, jax.Array]:
        """
        Updates the fresh_key attribute and returns a new key along with the old key.

        Args:
            key: Key to split. When None, the stored `fresh_key` is used.

        Returns:
            tuple[jax.Array, jax.Array]: A tuple containing the new key and the old key.
                The old key is the key that was used to update the fresh_key attribute.
                The new key is the key that should be used for the next method call.

        Raises:
            ValueError: If both `key` and `self.fresh_key` are None.
        """
        old_key = self.fresh_key if key is None else key
        if old_key is None:
            raise ValueError(
                "Both the key argument and the fresh_key attribute are None. At least one key must be given."
            )
        # One half is kept for later calls, the other is handed to the caller.
        self.fresh_key, new_key = jax.random.split(old_key)
        return new_key, old_key

    def to_struct(self) -> PompStruct:
        """
        Exports the static data and compiled simulator functions into a lightweight
        JAX PyTree (PompStruct) for use with the functional API (pypomp.functional).

        Returns:
            PompStruct: The compiled structural representation of the model.
        """
        return PompStruct(
            ys=jnp.array(self.ys),
            dt_array_extended=jnp.array(self._dt_array_extended),
            nstep_array=jnp.array(self._nstep_array),
            t0=self.t0,
            times=jnp.array(self.ys.index),
            covars_extended=jnp.array(self._covars_extended)
            if self._covars_extended is not None
            else None,
            accumvars=self.rproc.accumvars,
            rinit_pf=self.rinit.struct_pf,
            rproc_pf=self.rproc.struct_pf_interp,
            dmeas_pf=self.dmeas.struct_pf if self.dmeas is not None else None,
            rinit_per=self.rinit.struct_per,
            rproc_per=self.rproc.struct_per_interp,
            dmeas_per=self.dmeas.struct_per if self.dmeas is not None else None,
            rmeas_pf=self.rmeas.struct_pf if self.rmeas is not None else None,
        )
[docs] @staticmethod def sample_params( param_bounds: dict[str, tuple[float, float]], n: int, key: jax.Array ) -> list[dict[str, float]]: """ Samples multiple sets of parameters from independent uniform distributions. This utility method generates random parameter vectors within specified lower and upper bounds. It is commonly used to create initial parameter guesses or 'starting points' for global optimization. Args: param_bounds (dict): Dictionary mapping parameter names to (lower, upper) bounds n (int): Number of parameter sets to sample key (jax.Array): JAX random key for reproducibility Returns: list[dict]: List of n dictionaries containing sampled parameters """ keys = jax.random.split(key, len(param_bounds)) param_sets = [] for i in range(n): params = {} for j, (param_name, (lower, upper)) in enumerate(param_bounds.items()): subkey = jax.random.split(keys[j], n)[i] params[param_name] = float( jax.random.uniform(subkey, shape=(), minval=lower, maxval=upper) ) param_sets.append(params) return param_sets
def print_metadata(self) -> None:
    """
    Print the metadata recorded for this `Pomp` instance.

    The work is delegated to the `print_metadata` method of the object
    stored in `self.metadata`; nothing is returned.
    """
    recorded = self.metadata
    recorded.print_metadata()
def pfilter(
    self,
    J: int,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
    thresh: float = 0,
    reps: int = 1,
    CLL: bool = False,
    ESS: bool = False,
    filter_mean: bool = False,
    prediction_mean: bool = False,
    track_time: bool = True,
) -> None:
    """
    Evaluates the likelihood of the model via the particle filter (bootstrap filter).

    The particle filter (also known as Sequential Monte Carlo) estimates the
    log-likelihood of the data given a specific set of parameters by
    propagating a swarm of particles through the latent state space. It can
    also be used to estimate the latent states over time (via filtering or
    prediction means).

    This implementation leverages JAX to efficiently vectorize the algorithm
    across multiple parameter sets simultaneously. Results are automatically
    stored in the model's history and can be accessed using `self.results()`.

    Args:
        J (int): The number of particles
        key (jax.Array, optional): The random key. Defaults to self.fresh_key.
        theta (ThetaInput, optional): Parameters involved in the POMP model.
            Defaults to self.theta. Accepts:

            - A single dictionary: dict[str, Numeric]
            - A list of dictionaries: list[dict[str, Numeric]]
            - An existing PompParameters object

            Providing a list or PompParameters object enables faster,
            vectorized execution across all parameter sets.
        thresh (float, optional): Threshold value to determine whether to
            resample particles. Defaults to 0.
        reps (int, optional): Number of replicates to run. Defaults to 1.
        CLL (bool, optional): Boolean flag controlling whether to compute and
            store the conditional log-likelihoods at each time point.
        ESS (bool, optional): Boolean flag controlling whether to compute and
            store the effective sample size at each time point.
        filter_mean (bool, optional): Boolean flag controlling whether to
            compute and store the filtered mean at each time point.
        prediction_mean (bool, optional): Boolean flag controlling whether to
            compute and store the prediction mean at each time point.
        track_time (bool, optional): Boolean flag controlling whether to
            track the execution time.

    Returns:
        None. Updates `self.results_history` with a `PompPFilterResult`
        containing the log-likelihoods, and optionally the conditional
        log-likelihoods (CLL), effective sample size (ESS), filtered means,
        and prediction means if requested.

    Raises:
        ValueError: If `self.dmeas` is None, `J < 1`, `reps < 1`, or no
            random key is available.
    """
    start_time = time.time()

    theta_obj_in = deepcopy(self._prepare_theta_input(theta))
    n_theta_reps = theta_obj_in.num_replicates()

    # Validate before splitting the key so that a rejected call leaves
    # self.fresh_key untouched (same ordering as `train`).
    if self.dmeas is None:
        raise ValueError("self.dmeas cannot be None")
    if J < 1:
        raise ValueError("J should be greater than 0.")
    if reps < 1:
        raise ValueError("reps should be greater than 0.")

    new_key, old_key = self._update_fresh_key(key)

    thetas_array = theta_obj_in.to_jax_array(self.canonical_param_names)
    # One key per (parameter set, replicate) pair.
    rep_keys = jax.random.split(new_key, n_theta_reps * reps).reshape(
        n_theta_reps, reps, *new_key.shape
    )

    if len(jax.devices()) > 1:
        # Partition the leading (parameter set) axis across devices.
        mesh = jax.sharding.Mesh(jax.devices(), axis_names=("theta_reps",))
        sharding_spec = jax.sharding.NamedSharding(
            mesh, jax.sharding.PartitionSpec("theta_reps", None)
        )
        rep_keys_sharding_spec = jax.sharding.NamedSharding(
            mesh,
            jax.sharding.PartitionSpec(
                "theta_reps", *([None] * (rep_keys.ndim - 1))
            ),
        )
        thetas_array = jax.device_put(thetas_array, sharding_spec)
        rep_keys = jax.device_put(rep_keys, rep_keys_sharding_spec)

    results_jax = F.pfilter(
        self.to_struct(),
        thetas_array,
        J,
        thresh,
        rep_keys,
        CLL,
        ESS,
        filter_mean,
        prediction_mean,
    )
    results = jax.device_get(results_jax)
    del results_jax

    logLiks = results["logLik"]
    logLik_da = xr.DataArray(logLiks, dims=["theta_idx", "rep"])

    if track_time is True:
        execution_time = time.time() - start_time
    else:
        execution_time = None

    # Optional per-time diagnostics: kept only when requested AND returned.
    per_time_dims = ["theta_idx", "rep", "time"]
    per_state_dims = per_time_dims + ["state"]
    requested = {
        "CLL": (CLL, per_time_dims),
        "ESS": (ESS, per_time_dims),
        "filter_mean": (filter_mean, per_state_dims),
        "prediction_mean": (prediction_mean, per_state_dims),
    }
    diagnostics = {
        name: xr.DataArray(
            results[name], dims=dims, coords={"time": self.ys.index}
        )
        if wanted and name in results
        else None
        for name, (wanted, dims) in requested.items()
    }
    del results

    # Average the replicate likelihoods on the natural scale.
    logLik_estimates = logmeanexp(logLiks, axis=-1, ignore_nan=False)
    theta_obj_in.logLik = logLik_estimates
    self.theta = theta_obj_in

    result = PompPFilterResult(
        method="pfilter",
        execution_time=execution_time,
        key=old_key,
        theta=theta_obj_in.to_list(),
        logLiks=logLik_da,
        J=J,
        reps=reps,
        thresh=thresh,
        CLL_da=diagnostics["CLL"],
        ESS_da=diagnostics["ESS"],
        filter_mean=diagnostics["filter_mean"],
        prediction_mean=diagnostics["prediction_mean"],
    )
    self.results_history.add(result)
def mif(
    self,
    J: int,
    M: int,
    rw_sd: RWSigma,
    a: float,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
    thresh: float = 0,
    n_monitors: int = 0,
    track_time: bool = True,
) -> None:
    """
    Estimates model parameters by maximizing the marginal likelihood via the
    Iterated Filtering (IF2) algorithm.

    The Iterated Filtering algorithm estimates maximum likelihood parameters
    by introducing random perturbations to the parameters and sequentially
    filtering them alongside the state variables. Over successive iterations
    (cooling cycles), the perturbation variance is decayed, allowing the
    parameters to converge to their MLEs.

    This implementation leverages JAX to efficiently vectorize the algorithm
    across multiple initial parameter sets simultaneously. Results are
    automatically stored in the model's history and can be accessed using
    `self.results()`.

    Args:
        J (int): The number of particles.
        M (int): Number of algorithm iterations.
        rw_sd (RWSigma): Random walk sigma object.
        a (float): Decay factor for RWSigma over 50 iterations.
        key (jax.Array, optional): The random key for reproducibility.
            Defaults to self.fresh_key.
        theta (ThetaInput, optional): Parameters involved in the POMP model.
            Defaults to self.theta. Accepts:

            - A single dictionary: dict[str, Numeric]
            - A list of dictionaries: list[dict[str, Numeric]]
            - An existing PompParameters object

            Providing a list or PompParameters object enables faster,
            vectorized execution across all parameter sets.
        thresh (float): Resampling threshold. Defaults to 0.
        n_monitors (int): Number of particle filter runs to average for
            log-likelihood estimation. Defaults to 0 (uses estimate from
            perturbed filter).
        track_time (bool): Boolean flag controlling whether to track the
            execution time.

    Returns:
        None. Updates `self.results_history` with a `PompMIFResult`
        containing the log-likelihoods, parameter traces, and diagnostic
        information from the Iterated Filtering (IF2) run.

    Raises:
        ValueError: If the names in `rw_sd` do not match
            `canonical_param_names`, if `self.dmeas` is None, if `J < 1`,
            or if no random key is available.
    """
    start_time = time.time()

    param_names = self.canonical_param_names
    rw_param_names = list(rw_sd.all_names)
    if set(rw_param_names) != set(param_names):
        raise ValueError(
            "rw_sd.sigmas keys must match canonical_param_names up to reordering. "
            f"Got {sorted(rw_param_names)}, expected {sorted(self.canonical_param_names)}."
        )

    theta_obj_in = deepcopy(self._prepare_theta_input(theta))
    theta_list_in = theta_obj_in.to_list()
    n_reps = theta_obj_in.num_replicates()

    # Validate before splitting the key so that a rejected call leaves
    # self.fresh_key untouched (same ordering as `train`).
    if self.dmeas is None:
        raise ValueError("self.dmeas cannot be None")
    if J < 1:
        raise ValueError("J should be greater than 0.")

    # The algorithm runs in the estimation space.
    theta_obj_in.transform(self.par_trans, direction="to_est")
    sigmas_array, sigmas_init_array = rw_sd._return_arrays(param_names=param_names)
    theta_array = theta_obj_in.to_jax_array(param_names)

    new_key, old_key = self._update_fresh_key(key)
    keys = jax.random.split(new_key, n_reps)

    # Replicate every starting point across the J particles (new leading axis).
    theta_tiled = jnp.tile(theta_array, (J, 1, 1))

    if len(jax.devices()) > 1:
        # Partition the replicate axis (axis 1) across devices.
        mesh = jax.sharding.Mesh(jax.devices(), axis_names=("reps",))
        sharding_spec = jax.sharding.NamedSharding(
            mesh, jax.sharding.PartitionSpec(None, "reps", None)
        )
        theta_tiled = jax.device_put(theta_tiled, sharding_spec)

    nLLs_jax, theta_traces_jax, final_thetas_jax = F.mif(
        self.to_struct(),
        theta_tiled,
        sigmas_array,
        sigmas_init_array,
        M,
        a,
        J,
        thresh,
        keys,
        n_monitors,
    )
    nLLs = jax.device_get(nLLs_jax)
    theta_traces = jax.device_get(theta_traces_jax)
    final_thetas = jax.device_get(final_thetas_jax)
    del nLLs_jax, theta_traces_jax, final_thetas_jax

    trace_vars = ["logLik"] + param_names
    trace_data = np.zeros((n_reps, M + 1, len(trace_vars)), dtype=float)
    for i in range(n_reps):
        # Prepend nan for the log-likelihood of the initial parameters
        trace_data[i, :, 0] = np.concatenate([np.array([np.nan]), -nLLs[i]])
        # Transform traces from estimation space to natural space
        trace_data[i, :, 1:] = self.par_trans.transform_array(
            theta_traces[i], param_names, direction="from_est"
        )

    traces_da = xr.DataArray(
        trace_data,
        dims=["theta_idx", "iteration", "variable"],
        coords={
            "theta_idx": np.arange(n_reps),
            "iteration": np.arange(M + 1),
            "variable": trace_vars,
        },
    )

    # Point estimate per replicate: mean over axis 0 of the final array for
    # that replicate, mapped back to the natural scale.
    theta_estimates = [
        self.par_trans.to_floats(
            theta=dict(
                zip(param_names, np.mean(final_thetas[i], axis=0).tolist())
            ),
            direction="from_est",
        )
        for i in range(n_reps)
    ]
    # NOTE(review): the whole negated nLLs array is passed as logLik, not
    # only its last iteration - confirm this is what PompParameters expects.
    logLik_estimates = -nLLs
    self.theta = PompParameters(theta_estimates, logLik=logLik_estimates)

    if track_time is True:
        execution_time = time.time() - start_time
    else:
        execution_time = None

    result = PompMIFResult(
        method="mif",
        execution_time=execution_time,
        key=old_key,
        theta=theta_list_in,
        traces_da=traces_da,
        J=J,
        M=M,
        rw_sd=rw_sd,
        a=a,
        thresh=thresh,
        n_monitors=n_monitors,
    )
    self.results_history.add(result)
def train(
    self,
    J: int,
    M: int,
    eta: LearningRate,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
    optimizer: Optimizer = Adam(),
    alpha: float = 0.97,
    thresh: int = 0,
    alpha_cooling: float = 1.0,
    n_monitors: int = 1,
    track_time: bool = True,
) -> None:
    """
    Fit the model parameters by gradient-based optimization of a
    differentiable particle filter.

    The particle filter is treated as a differentiable computation: JAX
    supplies gradients of the log-likelihood with respect to the parameters
    by reverse-mode automatic differentiation, and the chosen optimizer
    (e.g., Adam, SGD) uses them to update the parameters. All starting
    parameter sets are processed together in one vectorized call. The
    outcome is appended to the model's history and can be retrieved with
    `self.results()`.

    Args:
        J (int): The number of particles in the MOP objective for obtaining
            the gradient and/or Hessian.
        M (int): Maximum iteration for the gradient descent optimization.
        eta (LearningRate): Learning rates per parameter as a LearningRate
            object.
        key (jax.Array, optional): The random key for reproducibility.
            Defaults to self.fresh_key.
        theta (ThetaInput, optional): Parameters involved in the POMP model.
            Defaults to self.theta. Accepts:

            - A single dictionary: dict[str, Numeric]
            - A list of dictionaries: list[dict[str, Numeric]]
            - An existing PompParameters object

            Providing a list or PompParameters object enables faster,
            vectorized execution across all parameter sets.
        optimizer (Optimizer, optional): The optimizer configuration object
            to use (e.g., `pp.Adam()`, `pp.SGD()`, `pp.Newton()`,
            `pp.FullMatrixAdam()`, etc.). Defaults to `pp.Adam()`.
            Hyperparameters like learning rate scaling, line search
            (`scale`, `ls`, `c`, `max_ls_itn`), gradient clipping
            (`clip_norm`), or Adam beta values are configured directly
            inside the optimizer instance.
        alpha (float, optional): Discount factor for MOP.
        thresh (int, optional): Threshold value to determine whether to
            resample particles.
        alpha_cooling (float, optional): Cooling factor for the MOP discount
            factor (alpha) using cosine decay. This factor represents the
            multiplier for the distance of alpha from 1.0 by the end of
            training (i.e., alpha approaches 1.0). Defaults to 1.0 (no
            cooling).
        n_monitors (int, optional): Number of particle filter runs to
            average for log-likelihood estimation.
        track_time (bool, optional): Boolean flag controlling whether to
            track the execution time.

    Returns:
        None. Updates `self.results_history` with a `PompTrainResult`
        containing the log-likelihoods, parameter traces, and optimizer
        details from the training run.
    """
    t_begin = time.time()
    names = self.canonical_param_names

    start_params = deepcopy(self._prepare_theta_input(theta))
    # Keep a natural-scale copy of the inputs for the result record, then
    # move the working copy to the estimation scale.
    start_params_natural = start_params.to_list()
    start_params.transform(self.par_trans, direction="to_est")
    n_reps = start_params.num_replicates()

    if self.dmeas is None:
        raise ValueError("self.dmeas cannot be None")
    if J < 1:
        raise ValueError("J should be greater than 0")
    if not isinstance(eta, LearningRate):
        raise TypeError("eta must be a LearningRate object")

    # Learning-rate schedule laid out in canonical parameter order
    eta_array = eta.to_array(names, M)

    new_key, old_key = self._update_fresh_key(key)
    keys = jnp.array(jax.random.split(new_key, n_reps))
    theta_array = start_params.to_jax_array(names)

    # Hyperparameters come from the optimizer object; the Adam-style ones
    # fall back to defaults when the optimizer does not define them.
    opt_name = optimizer.__class__.__name__
    beta1 = getattr(optimizer, "beta1", 0.9)
    beta2 = getattr(optimizer, "beta2", 0.999)
    fallback_epsilon = 1e-8 if opt_name == "Adam" else 1e-4
    epsilon = getattr(optimizer, "epsilon", fallback_epsilon)
    c = optimizer.c
    max_ls_itn = optimizer.max_ls_itn
    clip_norm = optimizer.clip_norm
    scale = optimizer.scale
    ls = optimizer.ls

    nLLs, theta_ests = F.train(
        self.to_struct(),
        theta_array,
        J,
        opt_name,
        M,
        eta_array,
        c,
        max_ls_itn,
        thresh,
        scale,
        ls,
        alpha,
        keys,
        alpha_cooling,
        n_monitors,
        clip_norm,
        beta1,
        beta2,
        epsilon,
    )

    # Map each replicate's parameter trace back to the natural scale.
    natural_traces = []
    for i in range(n_reps):
        natural_traces.append(
            self.par_trans.transform_array(
                np.asarray(theta_ests[i]), names, direction="from_est"
            )
        )
    theta_ests_natural = np.stack(natural_traces, axis=0)

    # Log-likelihood goes in the first "variable" slot, parameters after it.
    trace_values = np.concatenate(
        [
            -nLLs[..., np.newaxis],  # shape: (theta_idx, iteration, 1)
            theta_ests_natural,  # shape: (theta_idx, iteration, n_theta)
        ],
        axis=-1,
    )
    joined_array = xr.DataArray(
        trace_values,
        dims=["theta_idx", "iteration", "variable"],
        coords={
            "theta_idx": range(n_reps),
            "iteration": range(M + 1),
            "variable": ["logLik"] + names,
        },
    )

    # The last iterate of each replicate becomes the stored estimate.
    final_estimates = []
    for i in range(n_reps):
        last_iterate = dict(zip(names, theta_ests[i, -1, :].tolist()))
        final_estimates.append(
            self.par_trans.to_floats(theta=last_iterate, direction="from_est")
        )
    self.theta = PompParameters(final_estimates, logLik=np.asarray(-nLLs))

    execution_time = None
    if track_time is True:
        nLLs.block_until_ready()
        execution_time = time.time() - t_begin

    self.results_history.add(
        PompTrainResult(
            method="train",
            execution_time=execution_time,
            key=old_key,
            theta=start_params_natural,
            traces_da=joined_array,
            optimizer=optimizer,
            J=J,
            M=M,
            eta=eta,
            alpha=alpha,
            thresh=thresh,
            alpha_cooling=alpha_cooling,
        )
    )
def dpop_train(
    self,
    J: int,
    M: int,
    eta: LearningRate,
    optimizer: str = "Adam",
    alpha: float = 0.8,
    decay: float = 0.0,
    process_weight_state: str | None = None,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
) -> tuple[jax.Array, jax.Array]:
    """
    Optimizes model parameters using the DPOP differentiable particle filter
    and gradient-based methods.

    This method trains the model parameters to maximize the DPOP objective
    function using first-order optimizers like Adam or SGD, with optional
    learning rate decay. Gradients are computed efficiently via JAX
    reverse-mode automatic differentiation.

    Only the first parameter set of `theta` is used as the starting point,
    and the outcome is returned rather than appended to
    `self.results_history`.

    Parameters
    ----------
    J : int
        Number of particles.
    M : int
        Number of gradient steps.
    eta : LearningRate
        Learning rates per parameter as a LearningRate object.
    optimizer : str, default "Adam"
        Optimizer to use: "Adam" or "SGD".
    alpha : float, default 0.8
        DPOP discount / cooling factor.
    decay : float, default 0.0
        Learning-rate decay coefficient. At iteration m, the effective
        learning rate is ``eta / (1 + decay * m)``.
    process_weight_state : str or None, default None
        Name of the state component that stores the accumulated process
        log-weight (e.g. ``"logw"``). Although it defaults to None, a value
        is required.
    key : jax.Array or None, default None
        Random key. If None, uses ``self.fresh_key``.
    theta : ThetaInput, default None
        Optional initial parameter(s). Accepts dict[str, Numeric],
        list[dict[str, Numeric]], or PompParameters. Numeric values are
        coerced to floats. Defaults to self.theta.

    Returns
    -------
    nll_history : jax.Array, shape (M+1,)
        Mean DPOP negative log-likelihood per observation at each step.
    theta_history : jax.Array, shape (M+1, p)
        Parameter vector (estimation space) at each step.

    Raises
    ------
    TypeError
        If `eta` is not a LearningRate object.
    ValueError
        If `self.dmeas` is None, if `process_weight_state` is None or is not
        one of `self.statenames`, or if no random key is available.
    """
    from .algorithms.train_dpop import dpop_train as _dpop_train

    # Validate every argument before splitting the key, so that a rejected
    # call leaves self.fresh_key untouched.
    if not isinstance(eta, LearningRate):
        raise TypeError("eta must be a LearningRate object")
    if self.dmeas is None:
        raise ValueError("dpop_train requires self.dmeas to be not None.")
    if process_weight_state is None:
        raise ValueError(
            "dpop_train requires a process-weight state. "
            "Please provide `process_weight_state` as the name of the "
            "state variable that accumulates the transition log-weight "
            "(e.g. 'logw')."
        )
    try:
        process_weight_index = int(self.statenames.index(process_weight_state))
    except ValueError as e:
        raise ValueError(
            f"State '{process_weight_state}' not found in statenames "
            f"{self.statenames}"
        ) from e

    theta_obj = self._prepare_theta_input(theta)
    new_key, _ = self._update_fresh_key(key)

    # Starting point: first parameter set, moved to the estimation scale
    # and laid out in canonical order.
    param_names = self.canonical_param_names
    theta_nat = theta_obj.to_list()[0]
    theta_est_dict = self.par_trans.to_est(theta_nat)
    theta_init = jnp.array([theta_est_dict[name] for name in param_names])

    # For now, dpop_train only uses a constant learning rate across iterations
    # Extract the first row of the schedule
    eta_array = eta.to_array(param_names, M)[0]

    theta_hist, nll_hist = _dpop_train(
        theta_init=theta_init,
        ys=jnp.array(self.ys.values),
        dt_array_extended=self._dt_array_extended,
        nstep_array=self._nstep_array,
        t0=self.t0,
        times=jnp.array(self.ys.index.values),
        J=J,
        rinitializer=self.rinit.struct_pf,
        rprocess_interp=self.rproc.struct_pf_interp,
        dmeasure=self.dmeas.struct_pf,
        accumvars=self.rproc.accumvars,
        covars_extended=self._covars_extended,
        alpha=alpha,
        process_weight_index=process_weight_index,
        ntimes=len(self.ys),
        key=new_key,
        M=M,
        eta=eta_array,
        optimizer=optimizer,
        decay=decay,
    )
    # The helper returns (theta, nll); this method returns them as (nll, theta).
    return nll_hist, theta_hist
@overload
def simulate(
    self,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
    times: jax.Array | None = None,
    nsim: int = 1,
    as_pomp: Literal[False] = False,
) -> tuple[pd.DataFrame, pd.DataFrame]: ...


@overload
def simulate(
    self,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
    times: jax.Array | None = None,
    nsim: int = 1,
    *,
    as_pomp: Literal[True],
) -> "Pomp": ...


def simulate(
    self,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
    times: jax.Array | None = None,
    nsim: int = 1,
    as_pomp: bool = False,
) -> Union[tuple[pd.DataFrame, pd.DataFrame], "Pomp"]:
    """
    Simulate latent states and observations from the POMP model.

    Simulations are produced by ``F.simulate`` for every parameter set in
    ``theta`` and every one of the ``nsim`` replicates, then reshaped into
    long-format DataFrames.

    Args:
        key (jax.Array, optional): The random key for random number
            generation. Defaults to self.fresh_key.
        theta (ThetaInput, optional): Parameters involved in the POMP
            model. Defaults to self.theta. Accepts a single dictionary,
            a list of dictionaries, or an existing PompParameters object.
        times (jax.Array, optional): Times at which to generate
            observations. Defaults to self.ys.index.
        nsim (int): The number of simulations to perform. Defaults to 1.
        as_pomp (bool): If True, returns a new Pomp object containing the
            simulated observations for the first parameter replicate and
            simulation, instead of DataFrames. ``nsim`` is forced to 1
            (with a UserWarning if a larger value was requested).

    Returns:
        If as_pomp is False:
            tuple[pd.DataFrame, pd.DataFrame]: ``(states, observations)``
            in long format. Both frames start with the columns
            ``theta_idx``, ``sim`` and ``time``; the remaining columns
            are named ``state_<i>`` / ``obs_<i>``. The state frame has
            one extra time point per series, ``self.t0``, prepended.
        If as_pomp is True:
            Pomp: A deep copy of this model whose ``ys`` holds the
            simulated observations for ``theta_idx == 0`` and
            ``sim == 0`` and whose ``theta`` is the first parameter set.

    Raises:
        ValueError: If ``self.rmeas`` is None.
    """
    if as_pomp and nsim > 1:
        warnings.warn(
            "as_pomp is True, but nsim > 1. Only 1 simulation will be performed as_pomp overrides nsim.",
            UserWarning,
        )
    if as_pomp:
        nsim = 1

    theta_obj_in = self._prepare_theta_input(theta)
    if self.rmeas is None:
        raise ValueError(
            "self.rmeas cannot be None. Did you forget to supply it to the object or method?"
        )

    thetas_array = theta_obj_in.to_jax_array(self.canonical_param_names)
    new_key, _ = self._update_fresh_key(key)
    # One sub-key per parameter set.
    keys = jax.random.split(new_key, thetas_array.shape[0])

    if times is None:
        times_array = jnp.array(self.ys.index)
    else:
        times_array = times

    X_sims, Y_sims = F.simulate(
        self.to_struct(),
        thetas_array,
        nsim,
        keys,
        times=times_array,
    )

    def _as_long_frame(
        values: Union[jax.Array, np.ndarray],
        time_axis: Union[jax.Array, np.ndarray, pd.Index],
        label: str,
    ) -> pd.DataFrame:
        # Flatten a (n_theta, n_sim, n_time, n_feat) array into one row
        # per (theta, sim, time) with id columns in front.
        cube = np.asarray(values)
        n_theta, n_rep, n_time, n_feat = cube.shape
        n_series = n_theta * n_rep
        frame = pd.DataFrame(
            cube.reshape(n_series * n_time, n_feat),
            columns=pd.Index([f"{label}_{j}" for j in range(n_feat)]),
        )
        frame.insert(0, "theta_idx", np.repeat(np.arange(n_theta), n_rep * n_time))
        frame.insert(1, "sim", np.tile(np.repeat(np.arange(n_rep), n_time), n_theta))
        frame.insert(2, "time", np.tile(np.asarray(time_axis).reshape(-1), n_series))
        return frame

    # States are reported at t0 as well as at every observation time.
    state_times = np.concatenate([np.array([self.t0]), np.array(times_array)])
    X_sims_long = _as_long_frame(X_sims, state_times, "state")
    Y_sims_long = _as_long_frame(Y_sims, np.array(times_array), "obs")

    if not as_pomp:
        return X_sims_long, Y_sims_long

    # Keep only the first parameter set / first simulation and reshape it
    # to look like the original ``ys`` (time index, same column labels).
    is_first = (Y_sims_long["theta_idx"] == 0) & (Y_sims_long["sim"] == 0)
    first_long = Y_sims_long[is_first].copy()
    simulated_ys = pd.DataFrame(first_long.drop(columns=["theta_idx", "sim", "time"]))
    simulated_ys.index = pd.Index(first_long["time"])
    simulated_ys.columns = self.ys.columns

    pomp_copy = deepcopy(self)
    pomp_copy.ys = simulated_ys
    pomp_copy.theta = theta_obj_in.subset([0])
    return pomp_copy
def probe(
    self,
    probes: dict[str, Callable[[pd.DataFrame], float]],
    nsim: int = 100,
    key: jax.Array | None = None,
    theta: ThetaInput = None,
) -> pd.DataFrame:
    """
    Compare 'probes' (summary statistics) of the real data against the
    same statistics computed on simulated data.

    Each probe is evaluated once on ``self.ys`` and once on every
    simulated dataset returned by ``self.simulate``. Simulated datasets
    are presented to the probes with the same layout as ``self.ys``:
    time as the index and ``self.ys.columns`` as column labels.

    Args:
        probes (dict[str, Callable[[pd.DataFrame], float]]): A dictionary
            of probe functions. Each function receives a DataFrame of
            observations and must return a value convertible by
            ``float``. Example: `{"mean": lambda df: df["obs"].mean()}`
        nsim (int, optional): Number of simulations to run per parameter
            set. Defaults to 100.
        key (jax.Array, optional): JAX random key for the simulations.
        theta (ThetaInput, optional): Parameters to simulate from.

    Returns:
        pd.DataFrame: A long-format DataFrame with columns:
            `probe`, `value`, `is_real_data`, `theta_idx`, `sim`.
            Rows for the real data have ``pd.NA`` in `theta_idx` and
            `sim`. The columns are present even when ``probes`` is empty.
    """
    sim_result = self.simulate(nsim=nsim, key=key, theta=theta, as_pomp=False)
    assert isinstance(sim_result, tuple)  # narrows the Union for type checkers
    _, y_sims = sim_result

    columns = ["probe", "value", "is_real_data", "theta_idx", "sim"]

    # Probes on the observed data.
    records = [
        {
            "probe": name,
            "value": float(func(self.ys)),
            "is_real_data": True,
            "theta_idx": pd.NA,
            "sim": pd.NA,
        }
        for name, func in probes.items()
    ]

    # Probes on each simulated dataset. Iterating the groups directly
    # (instead of a side-effecting ``groupby.apply``) keeps the loop
    # explicit and does not depend on ``group.name``/``include_groups``.
    for (theta_idx, sim_id), group in y_sims.groupby(["theta_idx", "sim"]):
        df = group.drop(columns=["theta_idx", "sim", "time"])
        df.index = pd.Index(group["time"])
        df.columns = self.ys.columns
        for name, func in probes.items():
            records.append(
                {
                    "probe": name,
                    "value": float(func(df)),
                    "is_real_data": False,
                    "theta_idx": theta_idx,
                    "sim": sim_id,
                }
            )

    # Passing ``columns`` guarantees the documented schema for empty input.
    return pd.DataFrame(records, columns=columns)
def traces(self) -> pd.DataFrame:
    """
    Return the full trace of log-likelihoods and parameters.

    This simply forwards to ``self.results_history.traces()``, which
    covers the entire result history. Columns are

    - theta_idx: The index of the parameter set (for all methods)
    - iteration: The global iteration number for that parameter set
      (increments over all mif/train calls for that set; for pfilter,
      the last iteration for that set)
    - method: 'pfilter', 'mif', or 'train'
    - logLik: The log-likelihood estimate (averaged over reps for pfilter)
    - <param>: One column for each parameter
    """
    history = self.results_history
    return history.traces()
def results(self, index: int = -1, ignore_nan: bool = False) -> pd.DataFrame:
    """
    Return the results of one method run from the model's history.

    Forwards to ``self.results_history.results`` for the run stored at
    position ``index`` (e.g. a previous `pfilter`, `mif`, or `train`
    call).

    Args:
        index (int): The index of the result to return. Defaults to -1
            (the last result).
        ignore_nan (bool): If True, ignore NaNs when computing the
            log-likelihood.

    Returns:
        pd.DataFrame: A DataFrame with the results of the method run at
            the given index.
    """
    history = self.results_history
    return history.results(index=index, ignore_nan=ignore_nan)
def CLL(self, index: int = -1, average: bool = False) -> pd.DataFrame:
    """
    Return the conditional log-likelihoods of one method run.

    Forwards to ``self.results_history.CLL``.

    Args:
        index (int, optional): The index of the result to retrieve.
            Defaults to -1.
        average (bool, optional): Whether to average the conditional
            log-likelihoods over replicates using logmeanexp.
            Defaults to False.

    Returns:
        pd.DataFrame: A tidy DataFrame with the conditional
            log-likelihoods.
    """
    history = self.results_history
    return history.CLL(index=index, average=average)
def ESS(self, index: int = -1, average: bool = False) -> pd.DataFrame:
    """
    Return the effective sample size of one method run.

    Forwards to ``self.results_history.ESS``.

    Args:
        index (int, optional): The index of the result to retrieve.
            Defaults to -1.
        average (bool, optional): Whether to average the effective
            sample size over replicates using arithmetic mean.
            Defaults to False.

    Returns:
        pd.DataFrame: A tidy DataFrame with the effective sample size.
    """
    history = self.results_history
    return history.ESS(index=index, average=average)
def time(self):
    """
    Summarize the execution times of the methods run so far.

    Forwards to ``self.results_history.time()``.

    Returns:
        pd.DataFrame: A DataFrame where each row contains:
            - 'method': The name of the method run.
            - 'time': The execution time in seconds.
    """
    history = self.results_history
    return history.time()
def prune(self, n: int = 1, refill: bool = True):
    """
    Keep only the top `n` parameter replicates.

    The work is delegated to ``self.theta.prune``; this method returns
    nothing. Typically used after an estimation run (like `pfilter` or
    `mif`) to drop poorly performing parameter sets based on their most
    recent log-likelihood estimates.

    Args:
        n (int): Number of top thetas to keep.
        refill (bool): If True, repeat the top n thetas to match the
            previous number of theta sets.
    """
    parameters = self.theta
    parameters.prune(n=n, refill=refill)
def plot_traces(self, show: bool = True) -> Any:
    """
    Plot the parameter and log-likelihood traces from the entire result
    history.

    Each facet shows a parameter or logLik. The x-axis is iteration,
    y-axis is value. Lines connect mif/train points for the same
    replication; pfilter points are dots. Color by replication.

    Args:
        show (bool): Whether to display the plot. Defaults to True.

    Returns:
        Whatever ``plot_traces_internal`` produced (possibly None).
    """
    fig = plot_traces_internal(self.traces(), title="Pomp Traces")
    if show and fig is not None:
        fig.show()
    return fig
def plot_simulations(
    self,
    key: jax.Array,
    nsim: int = 20,
    mode: str = "lines",
    theta: ThetaInput = None,
    show: bool = True,
) -> Any:
    """
    Plot simulated observation trajectories against the observed data.

    Observations are simulated with ``self.simulate`` and handed,
    together with ``self.ys``, to ``plot_simulations_internal``.

    Args:
        key (jax.Array): JAX random key for simulation.
        nsim (int): Number of simulations to perform. Defaults to 20.
        mode (str): Plotting mode, either "lines" (individual sims) or
            "quantiles" (shaded region). Defaults to "lines".
        theta (ThetaInput, optional): Parameters to use for simulation.
            Defaults to the first replicate in self.theta.
        show (bool): Whether to display the plot. Defaults to True.

    Returns:
        Whatever ``plot_simulations_internal`` produced (possibly None).
    """
    if theta is None:
        # Fall back to the stored parameters, narrowed to the first
        # replicate when several are present.
        if self.theta and self.theta.num_replicates() > 1:
            theta = self.theta.subset([0])
        else:
            theta = self.theta

    simulated = self.simulate(nsim=nsim, theta=theta, key=key)
    _, sims = simulated
    fig = plot_simulations_internal(sims, self.ys, mode=mode)
    if show and fig is not None:
        fig.show()
    return fig
def print_summary(self, n: int = 5):
    """
    Print a high-level summary of the model and its estimation history.

    The output consists of:

    - A "Basics" section with the number of observations
      (``len(self.ys)``), time steps (``len(self._dt_array_extended)``),
      parameters and parameter sets.
    - A blank line.
    - The summary printed by ``self.results_history.print_summary``.

    Args:
        n (int): Forwarded unchanged to
            ``self.results_history.print_summary``. Defaults to 5.
    """
    print("Basics:", "-------", sep="\n")
    print(f"Number of observations: {len(self.ys)}")
    print(f"Number of time steps: {len(self._dt_array_extended)}")
    parameters = self.theta
    print(f"Number of parameters: {parameters.num_params()}")
    print(f"Number of parameter sets: {parameters.num_replicates()}")
    print()
    self.results_history.print_summary(n=n)
def __eq__(self, other):
    """
    Check structural equality with another Pomp object.

    Returns False as soon as one of the following comparisons fails,
    checked in this order:

    - ``other`` is an instance of ``type(self)``
    - canonical parameter names
    - parameter sets (``theta``)
    - data (``ys``) and covariates (``covars``, ``_covars_extended``)
    - step arrays (``_nstep_array``, ``_dt_array_extended``) and
      ``_max_steps_per_interval``
    - state names and initial time ``t0`` (compared as floats)
    - model components (``rinit``, ``rproc``, ``dmeas``, ``rmeas``),
      relying on their own equality implementations
    - ``results_history`` and ``par_trans``
    - ``fresh_key`` (both None, or identical raw key data)
    """
    if not isinstance(other, type(self)):
        return False

    same_array = jax.numpy.array_equal

    # Parameter names and parameter sets
    if self.canonical_param_names != other.canonical_param_names:
        return False
    if self.theta != other.theta:
        return False

    # Observations
    if not self.ys.equals(other.ys):
        return False

    # Covariates: both absent, or both present and equal
    if (self.covars is None) != (other.covars is None):
        return False
    if self.covars is not None and not self.covars.equals(other.covars):
        return False

    # Extended covariates (None or JAX array)
    if (self._covars_extended is None) != (other._covars_extended is None):
        return False
    if self._covars_extended is not None and not same_array(
        self._covars_extended, other._covars_extended
    ):
        return False

    # Time-stepping arrays
    if not same_array(self._nstep_array, other._nstep_array):
        return False
    if not same_array(self._dt_array_extended, other._dt_array_extended):
        return False
    if self._max_steps_per_interval != other._max_steps_per_interval:
        return False

    # State names and initial time
    if self.statenames != other.statenames:
        return False
    if float(self.t0) != float(other.t0):
        return False

    # Mandatory model components
    if self.rinit != other.rinit:
        return False
    if self.rproc != other.rproc:
        return False

    # Optional model components: both absent, or both present and equal
    for component in ("dmeas", "rmeas"):
        mine = getattr(self, component)
        theirs = getattr(other, component)
        if (mine is None) != (theirs is None):
            return False
        if mine is not None and mine != theirs:
            return False

    if self.results_history != other.results_history:
        return False
    if self.par_trans != other.par_trans:
        return False

    # fresh_key: both None, or numerically equal raw key data
    if (self.fresh_key is None) != (other.fresh_key is None):
        return False
    if self.fresh_key is not None and not same_array(
        jax.random.key_data(self.fresh_key),
        jax.random.key_data(other.fresh_key),
    ):
        return False

    return True
[docs] @staticmethod def merge(*pomp_objs: "Pomp") -> "Pomp": """ Merges multiple `Pomp` objects into a single instance by combining their parameter replicates and results histories. All provided `Pomp` objects must share the same structural components (e.g., state names, parameter names, and model logic). The resulting object will contain the union of all parameter sets and their corresponding estimation results, which is particularly useful for consolidating parallelized simulation or estimation runs. """ if len(pomp_objs) == 0: raise ValueError("At least one Pomp object must be provided.") first = pomp_objs[0] for obj in pomp_objs: if not isinstance(obj, type(first)): raise TypeError("All merged objects must be of type Pomp.") if obj.canonical_param_names != first.canonical_param_names: raise ValueError( "All Pomp objects must have the same canonical_param_names." ) if obj.statenames != first.statenames: raise ValueError("All Pomp objects must have the same statenames.") if not obj.ys.equals(first.ys): raise ValueError("All Pomp objects must have the same ys data.") if obj.t0 != first.t0: raise ValueError("All Pomp objects must have the same t0.") if obj.rinit != first.rinit or obj.rproc != first.rproc: raise ValueError("All Pomp objects must have the same rinit and rproc.") if (obj.dmeas is None) != (first.dmeas is None): raise ValueError( "All Pomp objects must have the same dmeas (both None or both not None)." ) if obj.dmeas is not None and obj.dmeas != first.dmeas: raise ValueError("All Pomp objects must have the same dmeas.") if (obj.rmeas is None) != (first.rmeas is None): raise ValueError( "All Pomp objects must have the same rmeas (both None or both not None)." 
) if obj.rmeas is not None and obj.rmeas != first.rmeas: raise ValueError("All Pomp objects must have the same rmeas.") if obj.par_trans != first.par_trans: raise ValueError("All Pomp objects must have the same par_trans.") merged_theta = PompParameters.merge(*[obj._theta for obj in pomp_objs]) merged_history = ResultsHistory.merge( *[obj.results_history for obj in pomp_objs] ) merged_pomp = deepcopy(first) merged_pomp._theta = merged_theta merged_pomp.results_history = merged_history merged_pomp.fresh_key = first.fresh_key return merged_pomp
def __getstate__(self) -> dict[str, Any]:
    """
    Build the picklable state of the object.

    The wrapped model components (``rinit``, ``rproc``, ``dmeas``,
    ``rmeas``) and ``fresh_key`` are removed from the state. In their
    place the state carries:

    - ``_<name>_func_bytes``: the component's ``original_func``
      serialized with cloudpickle (only for components that expose a
      ``struct`` attribute);
    - ``_rproc_dt``, ``_rproc_nstep``, ``_rproc_accumvars``: the
      corresponding ``rproc`` attributes (None when absent);
    - ``_fresh_key_data``: the raw key data of ``fresh_key`` (only when
      the key is not None).

    Returns:
        dict: A shallow copy of ``self.__dict__`` modified as described.
    """
    state = self.__dict__.copy()

    # Use cloudpickle to store model functions by-value. This ensures that
    # the unpickling environment does not require the original source modules.
    if hasattr(self.rinit, "struct"):
        original_func = self.rinit.original_func
        state["_rinit_func_bytes"] = cloudpickle.dumps(original_func)

    if hasattr(self.rproc, "struct"):
        original_func = self.rproc.original_func
        state["_rproc_func_bytes"] = cloudpickle.dumps(original_func)
        # Settings needed to rebuild the _RProc wrapper in __setstate__.
        state["_rproc_dt"] = getattr(self.rproc, "dt", None)
        state["_rproc_nstep"] = getattr(self.rproc, "nstep", None)
        state["_rproc_accumvars"] = getattr(self.rproc, "accumvars", None)

    # dmeas and rmeas are optional, hence the extra None checks.
    if self.dmeas is not None and hasattr(self.dmeas, "struct"):
        original_func = self.dmeas.original_func
        state["_dmeas_func_bytes"] = cloudpickle.dumps(original_func)

    if self.rmeas is not None and hasattr(self.rmeas, "struct"):
        original_func = self.rmeas.original_func
        state["_rmeas_func_bytes"] = cloudpickle.dumps(original_func)

    # Store JAX key as raw bits (key is not picklable directly)
    if self.fresh_key is not None:
        state["_fresh_key_data"] = jax.random.key_data(self.fresh_key)

    # Remove the wrapped objects and key from state
    state.pop("rinit", None)
    state.pop("rproc", None)
    state.pop("dmeas", None)
    state.pop("rmeas", None)
    state.pop("fresh_key", None)

    return state


def __setstate__(self, state: dict[str, Any]) -> None:
    """
    Restore the object from a pickled state.

    Plain attributes are copied back first. ``fresh_key`` and the model
    components are then rebuilt from the temporary ``_...`` entries
    written by ``__getstate__`` (or from the legacy by-reference
    entries), and those temporary entries are finally removed from the
    instance.

    Reconstruction is best-effort: a failure to rebuild the key or a
    model function emits a ``UserWarning`` instead of raising, and the
    affected attribute is left as None.

    NOTE(review): the function bytes are loaded with
    ``cloudpickle.loads``, which can execute arbitrary code; only
    unpickle data from trusted sources.
    """
    # Restore basic attributes
    self.__dict__.update(state)

    # Reconstruct JAX key from raw bits
    if "_fresh_key_data" in state:
        try:
            self.fresh_key = jax.random.wrap_key_data(state["_fresh_key_data"])
        except Exception as e:
            warnings.warn(f"Failed to reconstruct JAX fresh_key: {e}", UserWarning)
            self.fresh_key = None
    elif "fresh_key" not in self.__dict__:
        # No key was stored (it was None when pickled).
        self.fresh_key = None

    def _load_func(prefix: str) -> Any:
        # Return the stored object for component `prefix`, or None when
        # nothing was stored or loading failed (a warning is emitted).
        func_bytes_key = f"_{prefix}_func_bytes"
        func_name_key = f"_{prefix}_func_name"
        module_key = f"_{prefix}_module"

        try:
            # Modern approach (by-value): Uses cloudpickle bytes to remove
            # environment dependencies.
            if func_bytes_key in state:
                return cloudpickle.loads(state[func_bytes_key])

            # Legacy approach (by-reference): Provided for backward compatibility
            # with objects pickled in older versions of pypomp.
            elif func_name_key in state:
                module = importlib.import_module(state[module_key])
                return getattr(module, state[func_name_key])
        except Exception as e:
            warnings.warn(
                f"Failed to reconstruct {prefix} function: {e}. "
                f"The model may be unusable for simulations or estimation.",
                UserWarning,
            )
        return None

    # Reconstruct rinit
    obj_rinit = _load_func("rinit")
    if obj_rinit is not None:
        if isinstance(obj_rinit, _RInit):
            # Already a wrapper: use it as-is.
            self.rinit = obj_rinit
        else:
            self.rinit = _RInit(
                struct=obj_rinit,
                statenames=self.statenames,
                param_names=self.canonical_param_names,
                covar_names=self.covar_names,
                par_trans=self.par_trans,
            )

    # Reconstruct rproc
    obj_rproc = _load_func("rproc")
    if obj_rproc is not None:
        if isinstance(obj_rproc, _RProc):
            self.rproc = obj_rproc
        else:
            kwargs = {}
            if state.get("_rproc_dt") is not None:
                kwargs["dt"] = state["_rproc_dt"]
            # nstep is passed to the constructor only when dt is absent.
            # NOTE(review): presumably _RProc takes dt or nstep but not
            # both at construction — confirm against _RProc.
            if (
                state.get("_rproc_nstep") is not None
                and state.get("_rproc_dt") is None
            ):
                kwargs["nstep"] = state["_rproc_nstep"]
            if state.get("_rproc_accumvars") is not None:
                kwargs["accumvars"] = state["_rproc_accumvars"]

            self.rproc = _RProc(
                struct=obj_rproc,
                statenames=self.statenames,
                param_names=self.canonical_param_names,
                covar_names=self.covar_names,
                par_trans=self.par_trans,
                **kwargs,
            )
            # When both dt and nstep were stored, nstep is restored by
            # direct assignment after construction.
            if state.get("_rproc_nstep") is not None:
                if state.get("_rproc_dt") is not None:
                    self.rproc.nstep = state["_rproc_nstep"]

    # Reconstruct dmeas
    obj_dmeas = _load_func("dmeas")
    if obj_dmeas is not None:
        if isinstance(obj_dmeas, _DMeas):
            self.dmeas = obj_dmeas
        else:
            self.dmeas = _DMeas(
                struct=obj_dmeas,
                statenames=self.statenames,
                param_names=self.canonical_param_names,
                covar_names=self.covar_names,
                par_trans=self.par_trans,
                y_names=list(self.ys.columns) if hasattr(self, "ys") else None,
            )

    # Reconstruct rmeas
    obj_rmeas = _load_func("rmeas")
    if obj_rmeas is not None:
        if isinstance(obj_rmeas, _RMeas):
            self.rmeas = obj_rmeas
        else:
            self.rmeas = _RMeas(
                struct=obj_rmeas,
                statenames=self.statenames,
                param_names=self.canonical_param_names,
                covar_names=self.covar_names,
                par_trans=self.par_trans,
                y_names=list(self.ys.columns) if hasattr(self, "ys") else None,
            )

    # Set defaults if reconstruction failed or was missing
    if not hasattr(self, "rinit"):
        self.rinit = None  # type: ignore
    if not hasattr(self, "rproc"):
        self.rproc = None  # type: ignore
    if not hasattr(self, "rmeas"):
        self.rmeas = None
    if not hasattr(self, "dmeas"):
        self.dmeas = None

    # Clean up temporary state variables
    # (they were copied onto the instance by the __dict__.update above)
    for key in [
        "_rinit_func_bytes",
        "_rinit_func_name",
        "_rinit_module",
        "_rproc_func_bytes",
        "_rproc_func_name",
        "_rproc_dt",
        "_rproc_nstep",
        "_rproc_accumvars",
        "_rproc_module",
        "_dmeas_func_bytes",
        "_dmeas_func_name",
        "_dmeas_module",
        "_rmeas_func_bytes",
        "_rmeas_func_name",
        "_rmeas_module",
        "_fresh_key_data",
    ]:
        if key in self.__dict__:
            del self.__dict__[key]
def arma(
    self,
    order: tuple[int, int, int] = (1, 0, 1),
    log_ys: bool = False,
    suppress_warnings: bool = True,
) -> float:
    """
    Fit an independent ARIMA model to the observation data and return
    the estimated log-likelihood.

    This is a wrapper around `pypomp.benchmarks.arma`, called on
    ``self.ys``.

    Args:
        order (tuple, optional): The (p, d, q) order for the ARIMA
            model. Defaults to (1, 0, 1).
        log_ys (bool, optional): If True, fits the model to log(y+1).
            Defaults to False.
        suppress_warnings (bool, optional): If True, suppresses
            individual warnings from statsmodels and issues a summary
            warning instead. Defaults to True.

    Returns:
        float: The sum of the log-likelihoods.
    """
    fit_options = {
        "order": order,
        "log_ys": log_ys,
        "suppress_warnings": suppress_warnings,
    }
    return benchmarks.arma(self.ys, **fit_options)
def negbin(
    self, autoregressive: bool = False, suppress_warnings: bool = True
) -> float:
    """
    Fit a Negative Binomial model to the observation data and return the
    log-likelihood.

    This is a wrapper around `pypomp.benchmarks.negbin`, called on
    ``self.ys``.

    Args:
        autoregressive (bool, optional): If True, fits an AR(1) model.
            Defaults to False (iid).
        suppress_warnings (bool, optional): If True, suppresses
            individual warnings from statsmodels/optimization and issues
            a summary warning instead. Defaults to True.

    Returns:
        float: The sum of the log-likelihoods.
    """
    fit_options = {
        "autoregressive": autoregressive,
        "suppress_warnings": suppress_warnings,
    }
    return benchmarks.negbin(self.ys, **fit_options)