Source code for pypomp.core.results

from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import pandas as pd
import xarray as xr
import numpy as np
import jax

from .rw_sigma import RWSigma
from .parameters import PanelParameters


from ..maths import logmeanexp, logmeanexp_se


@dataclass
class BaseResult(ABC):
    """Base class for all result types."""

    method: str
    execution_time: float | None
    key: jax.Array
    timestamp: pd.Timestamp = field(default_factory=pd.Timestamp.now)

    def __post_init__(self):
        """Post-initialization hook."""
        pass

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """
        Structural equality for all result types.

        Compares:
        - type
        - method string
        - execution_time
        - timestamp
        - JAX key contents (via key_data)
        """
        if not isinstance(other, type(self)):
            return False

        if self.method != other.method:
            return False

        if self.execution_time != other.execution_time:
            return False

        if self.timestamp != other.timestamp:
            return False

        # Compare JAX keys by underlying data
        if not jax.numpy.array_equal(
            jax.random.key_data(self.key), jax.random.key_data(other.key)
        ):
            return False

        return True

    def __getstate__(self):
        """
        Custom pickling: store JAX key as raw bits (key is not always picklable directly).
        """
        state = vars(self).copy()
        if self.key is not None:
            state["_key_data"] = jax.random.key_data(self.key)
        state.pop("key", None)
        return state

    def __setstate__(self, state):
        """
        Custom unpickling: reconstruct JAX key from raw bits.
        """
        vars(self).update(state)
        if "_key_data" in state:
            self.key = jax.random.wrap_key_data(state["_key_data"])
        vars(self).pop("_key_data", None)

    @abstractmethod
    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Convert result to DataFrame."""
        pass

    @abstractmethod
    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return conditional log-likelihoods as a DataFrame."""
        pass

    @abstractmethod
    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return Effective Sample Size as a DataFrame."""
        pass

    @abstractmethod
    def print_summary(self):
        """Print a summary of this result."""
        pass


@dataclass
class PompBaseResult(BaseResult):
    """Common base for results produced from a single Pomp model."""

    # One plain dict per parameter set.
    theta: list[dict] = field(default_factory=list)

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """
        Structural equality for Pomp result types.

        Two results are equal when the BaseResult comparison succeeds and
        their ``theta`` lists compare equal.
        """
        base_matches = super().__eq__(other)
        if not base_matches:
            return False

        # Lists of plain dicts: Python's built-in structural comparison suffices.
        return self.theta == other.theta


@dataclass
class PanelPompBaseResult(BaseResult):
    """Common base for results produced from a PanelPomp model."""

    # Panel parameters, or None when no parameters are attached.
    theta: "PanelParameters | None" = None

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """
        Structural equality for PanelPomp result types.

        On top of the BaseResult comparison, the ``theta`` attributes must
        either both be None or compare equal as PanelParameters.
        """
        if not super().__eq__(other):
            return False

        mine, theirs = self.theta, other.theta

        # Exactly one side missing -> different; both missing -> equal.
        if mine is None or theirs is None:
            return mine is None and theirs is None

        if mine != theirs:
            return False
        return True


@dataclass
class PompPFilterResult(PompBaseResult):
    """Container for the output of Pomp.pfilter()."""

    logLiks: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    J: int = 0
    reps: int = 1
    thresh: float = 0.0
    CLL_da: xr.DataArray | None = None
    ESS_da: xr.DataArray | None = None
    filter_mean: xr.DataArray | None = None
    prediction_mean: xr.DataArray | None = None

    def __post_init__(self):
        """Force the method label to "pfilter"."""
        self.method = "pfilter"

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """Compare base fields, settings, log-likelihoods and optional diagnostics."""

        def arrays_match(left, right):
            # xarray objects use .equals(); anything else is compared as a
            # NumPy array with NaNs treated as equal.
            if isinstance(left, xr.DataArray) and isinstance(right, xr.DataArray):
                return left.equals(right)
            return np.array_equal(
                np.asarray(left), np.asarray(right), equal_nan=True
            )

        if not super().__eq__(other):
            return False

        settings_differ = (
            self.J != other.J or self.reps != other.reps or self.thresh != other.thresh
        )
        if settings_differ:
            return False

        if not arrays_match(self.logLiks, other.logLiks):
            return False

        # Diagnostics are optional: both sides must agree on presence first.
        for attr in ("CLL_da", "ESS_da", "filter_mean", "prediction_mean"):
            mine = getattr(self, attr)
            theirs = getattr(other, attr)
            if (mine is None) != (theirs is None):
                return False
            if mine is None:
                continue
            if not arrays_match(mine, theirs):
                return False

        return True

    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Return one row per parameter set with logLik, se and the parameters."""
        if not self.theta or self.logLiks.size == 0:
            return pd.DataFrame()

        raw = np.asarray(getattr(self.logLiks, "values", self.logLiks))
        loglik_est = logmeanexp(raw, axis=-1, ignore_nan=ignore_nan)

        # A standard error is only computed when the last axis has more than
        # one entry; otherwise the column is filled with NaN.
        std_err = (
            logmeanexp_se(raw, axis=-1, ignore_nan=ignore_nan)
            if raw.shape[-1] > 1
            else np.full_like(loglik_est, np.nan, dtype=float)
        )

        params = pd.DataFrame(self.theta)
        summary = pd.DataFrame(
            {
                "theta_idx": np.arange(len(params), dtype=int),
                "logLik": loglik_est.astype(float),
                "se": std_err.astype(float),
            }
        )
        frames = [summary.reset_index(drop=True), params.reset_index(drop=True)]
        return pd.concat(frames, axis=1)

    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return the conditional log-likelihoods in long format.

        With ``average=True`` the values are combined with ``logmeanexp``
        over axis 1 before conversion.
        """
        stored = self.CLL_da
        if stored is None or stored.size == 0:
            return pd.DataFrame()

        if not average:
            return stored.to_dataframe(name="CLL").reset_index()

        averaged = logmeanexp(np.asarray(stored.values), axis=1)
        n_theta, n_time = averaged.shape

        # Reuse existing coordinates when present, else fall back to 0..n-1.
        if "theta_idx" in stored.coords:
            theta_coord = stored.coords["theta_idx"].values
        else:
            theta_coord = np.arange(n_theta)
        if "time" in stored.coords:
            time_coord = stored.coords["time"].values
        else:
            time_coord = np.arange(n_time)

        averaged_da = xr.DataArray(
            averaged,
            dims=["theta_idx", "time"],
            coords={"theta_idx": theta_coord, "time": time_coord},
        )
        return averaged_da.to_dataframe(name="CLL").reset_index()

    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return the effective sample sizes in long format.

        With ``average=True`` the mean over the "rep" dimension is taken first.
        """
        stored = self.ESS_da
        if stored is None or stored.size == 0:
            return pd.DataFrame()

        source = stored.mean(dim="rep") if average else stored
        return source.to_dataframe(name="ESS").reset_index()

    def traces(self) -> pd.DataFrame:
        """Return the pfilter result laid out like a traces table (iteration 0)."""
        if not self.theta or len(self.logLiks) == 0:
            return pd.DataFrame()

        raw = np.asarray(getattr(self.logLiks, "values", self.logLiks))
        loglik_est = logmeanexp(raw, axis=-1)

        count = len(self.theta)
        head = pd.DataFrame(
            {
                "theta_idx": np.arange(count, dtype=int),
                "iteration": np.zeros(count, dtype=int),
                "method": self.method,
                "logLik": loglik_est.astype(float),
            }
        )
        params = pd.DataFrame(self.theta).reset_index(drop=True)
        return pd.concat([head.reset_index(drop=True), params], axis=1)

    def print_summary(self):
        """Print the settings of this run followed by the five best rows."""
        print(f"Method: {self.method}")
        print(f"Number of parameter sets: {len(self.theta)}")
        print(f"Number of particles (J): {self.J}")
        print(f"Number of replicates: {self.reps}")
        print(f"Resampling threshold: {self.thresh}")
        print(f"Execution time: {self.execution_time} seconds")

        table = self.to_dataframe()
        if table.empty:
            return
        print("\nTop 5 Results:")
        best = table.sort_values("logLik", ascending=False).head(5)
        print(best.to_string())

    @staticmethod
    def merge(*results: "PompPFilterResult") -> "PompPFilterResult":
        """
        Combine several PompPFilterResult objects into one.

        Every input must share J, thresh and reps with the first one. The
        theta lists are chained, and logLiks plus the optional diagnostics are
        concatenated along "theta_idx". The merged execution time is the
        largest non-None execution time (None if there is none); method and
        key are taken from the first input.

        Raises:
            ValueError: if no results are given or a setting differs.
            TypeError: if an input is not of the first input's type.
        """
        # TODO: handle keys in a better way
        if not results:
            raise ValueError("At least one PompPFilterResult object must be provided.")

        first = results[0]
        # Checked in this order for every result: J, thresh, reps.
        required_matches = (
            (
                "J",
                "All PompPFilterResult objects must have the same J (number of particles).",
            ),
            (
                "thresh",
                "All PompPFilterResult objects must have the same thresh (resampling threshold).",
            ),
            (
                "reps",
                "All PompPFilterResult objects must have the same reps (number of replicates).",
            ),
        )
        for candidate in results:
            if not isinstance(candidate, type(first)):
                raise TypeError("All merged objects must be of type PompPFilterResult.")
            for attr, message in required_matches:
                if getattr(candidate, attr) != getattr(first, attr):
                    raise ValueError(message)

        merged_theta = [entry for r in results for entry in r.theta]

        def concat_along_theta(arrays):
            return xr.concat(arrays, dim="theta_idx")

        nonempty_logliks = [r.logLiks for r in results if r.logLiks.size > 0]
        if nonempty_logliks:
            merged_logLiks: xr.DataArray = concat_along_theta(nonempty_logliks)  # type: ignore[assignment]
        else:
            merged_logLiks = xr.DataArray([])

        def merge_optional_diagnostic(name: str) -> xr.DataArray | None:
            # Skip inputs where the diagnostic is absent or empty.
            present = [
                diag
                for diag in (getattr(r, name) for r in results)
                if diag is not None and diag.size > 0
            ]
            if not present:
                return None
            return concat_along_theta(present)  # type: ignore[return-value]

        merged_CLL = merge_optional_diagnostic("CLL_da")
        merged_ESS = merge_optional_diagnostic("ESS_da")
        merged_filter_mean = merge_optional_diagnostic("filter_mean")
        merged_prediction_mean = merge_optional_diagnostic("prediction_mean")

        known_times = [
            r.execution_time for r in results if r.execution_time is not None
        ]
        longest = max(known_times) if known_times else None

        return PompPFilterResult(
            method=first.method,
            execution_time=longest,
            key=first.key,
            theta=merged_theta,
            logLiks=merged_logLiks,
            J=first.J,
            reps=first.reps,
            thresh=first.thresh,
            CLL_da=merged_CLL,
            ESS_da=merged_ESS,
            filter_mean=merged_filter_mean,
            prediction_mean=merged_prediction_mean,
        )
@dataclass
class PompMIFResult(PompBaseResult):
    """Result from Pomp.mif() method.

    ``traces_da`` is expected to carry "theta_idx", "iteration" and
    "variable" dimensions; those are the names this class indexes, pivots
    and concatenates on.
    """

    traces_da: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))  # type: ignore[assignment]
    # Number of particles.
    J: int = 0
    # Number of iterations.
    M: int = 0
    rw_sd: RWSigma | None = None
    # Cooling fraction.
    a: float = 0.0
    # Resampling threshold.
    thresh: float = 0.0
    n_monitors: int = 0

    def __post_init__(self):
        """Set method to mif."""
        self.method = "mif"

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """Structural equality including traces and algorithmic settings."""
        if not super().__eq__(other):
            return False

        if (
            self.J != other.J
            or self.M != other.M
            or self.a != other.a
            or self.thresh != other.thresh
            or self.n_monitors != other.n_monitors
        ):
            return False

        # rw_sd comparison: rely on its own __eq__ if present
        if (self.rw_sd is None) != (other.rw_sd is None):
            return False
        if self.rw_sd is not None and self.rw_sd != other.rw_sd:
            return False

        # traces_da
        if isinstance(self.traces_da, xr.DataArray) and isinstance(
            other.traces_da, xr.DataArray
        ):
            if not self.traces_da.equals(other.traces_da):
                return False
        else:
            if not np.array_equal(
                np.asarray(self.traces_da),
                np.asarray(other.traces_da),
                equal_nan=True,
            ):
                return False

        return True

    def _has_traces(self) -> bool:
        """Return True when ``traces_da`` holds at least one element.

        Checking ``sizes`` alone is not enough: the default
        ``xr.DataArray([])`` has one dimension of length zero, so its
        ``sizes`` mapping is non-empty even though there is no data.
        """
        traces_da = self.traces_da
        return (
            traces_da is not None
            and hasattr(traces_da, "sizes")
            and bool(traces_da.sizes)
            and traces_da.size > 0
        )

    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Convert mif result to DataFrame.

        Takes the last iteration of the traces and returns the columns
        ``theta_idx``, ``logLik``, ``se`` (always NaN) and one column per
        key of the first theta dict.

        Args:
            ignore_nan: Accepted for interface compatibility; not used here.

        Returns:
            The summary DataFrame, or an empty DataFrame when there are no
            traces or no parameter sets.
        """
        if not self._has_traces() or not self.theta:
            return pd.DataFrame()

        df = (
            self.traces_da.isel(iteration=-1)
            .to_dataset(dim="variable")
            .to_dataframe()
            .reset_index()
        )
        param_names = list(self.theta[0].keys())
        cols = ["theta_idx", "logLik"] + param_names
        df = pd.DataFrame(df[cols])
        df.insert(2, "se", np.nan)
        return df

    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return an empty DataFrame; this result type stores no CLL data."""
        return pd.DataFrame()

    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return an empty DataFrame; this result type stores no ESS data."""
        return pd.DataFrame()

    def traces(self) -> pd.DataFrame:
        """Return traces DataFrame for this mif result.

        Returns:
            The traces pivoted on "variable" with a ``method`` column set to
            "mif", or an empty DataFrame when there are no traces.
        """
        if not self._has_traces():
            return pd.DataFrame()
        return (
            self.traces_da.to_dataset(dim="variable")
            .to_dataframe()
            .reset_index()
            .assign(method="mif")
        )

    def print_summary(self):
        """Print summary of mif result."""
        print(f"Method: {self.method}")
        print(f"Number of parameter sets: {len(self.theta)}")
        print(f"Number of particles (J): {self.J}")
        print(f"Number of iterations (M): {self.M}")
        print(f"Cooling fraction (a): {self.a}")
        print(f"Resampling threshold: {self.thresh}")
        print(f"Number of monitors: {self.n_monitors}")
        print(f"Execution time: {self.execution_time} seconds")
        df = self.to_dataframe()
        if not df.empty:
            print("\nTop 5 Results:")
            df_sorted = df.sort_values("logLik", ascending=False).head(5)
            print(df_sorted.to_string())

    @staticmethod
    def merge(*results: "PompMIFResult") -> "PompMIFResult":
        """Merge replications from multiple PompMIFResult objects into a single object.

        All inputs must agree on J, M, a, thresh, n_monitors and rw_sd. Theta
        lists are chained and non-empty traces are concatenated along
        "theta_idx". The execution time is the largest non-None execution
        time; method, key and settings come from the first input.

        Raises:
            ValueError: if no results are given or a setting differs.
            TypeError: if an input is not of the first input's type.
        """
        if len(results) == 0:
            raise ValueError("At least one PompMIFResult object must be provided.")

        first = results[0]
        for result in results:
            if not isinstance(result, type(first)):
                raise TypeError("All merged objects must be of type PompMIFResult.")
            if (
                result.J != first.J
                or result.M != first.M
                or result.a != first.a
                or result.thresh != first.thresh
                or result.n_monitors != first.n_monitors
            ):
                raise ValueError(
                    "All PompMIFResult objects must have the same J, M, a, and thresh."
                )
            if (result.rw_sd is None) != (first.rw_sd is None) or (
                result.rw_sd is not None and result.rw_sd != first.rw_sd
            ):
                raise ValueError("All PompMIFResult objects must have the same rw_sd.")

        merged_theta = []
        for result in results:
            merged_theta.extend(result.theta)

        # Empty placeholders are skipped so they do not break the concat.
        trace_arrays = [r.traces_da for r in results if r.traces_da.size > 0]
        merged_traces = (
            xr.concat(trace_arrays, dim="theta_idx")
            if trace_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        execution_times = [
            r.execution_time for r in results if r.execution_time is not None
        ]
        max_execution_time = max(execution_times) if execution_times else None

        return PompMIFResult(
            method=first.method,
            execution_time=max_execution_time,
            key=first.key,
            theta=merged_theta,
            traces_da=merged_traces,
            J=first.J,
            M=first.M,
            rw_sd=first.rw_sd,
            a=first.a,
            thresh=first.thresh,
            n_monitors=first.n_monitors,
        )
@dataclass
class PompTrainResult(PompBaseResult):
    """Result from Pomp.train() method.

    ``traces_da`` is expected to carry "theta_idx", "iteration" and
    "variable" dimensions; those are the names this class indexes, pivots
    and concatenates on.
    """

    traces_da: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))  # type: ignore[assignment]
    optimizer: str = "SGD"
    # Number of particles.
    J: int = 0
    # Number of iterations.
    M: int = 0
    # Learning rate.
    eta: dict[str, float] = field(default_factory=lambda: {})
    # Discount factor.
    alpha: float = 0.97
    # Resampling threshold.
    thresh: int = 0
    # Whether line search is used.
    ls: bool = False
    # Armijo constant (reported only when ls is true).
    c: float = 0.1
    # Max line search iterations (reported only when ls is true).
    max_ls_itn: int = 10
    # Cooling factor for eta.
    eta_cooling: float = 1.0
    # Cooling factor for alpha.
    alpha_cooling: float = 1.0

    def __post_init__(self):
        """Set method to train."""
        self.method = "train"

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """Structural equality including traces and optimizer settings."""
        if not super().__eq__(other):
            return False

        scalar_fields = [
            "optimizer",
            "J",
            "M",
            "eta",
            "alpha",
            "thresh",
            "ls",
            "c",
            "max_ls_itn",
            "eta_cooling",
            "alpha_cooling",
        ]
        for name in scalar_fields:
            if getattr(self, name) != getattr(other, name):
                return False

        # traces_da
        if isinstance(self.traces_da, xr.DataArray) and isinstance(
            other.traces_da, xr.DataArray
        ):
            if not self.traces_da.equals(other.traces_da):
                return False
        else:
            if not np.array_equal(
                np.asarray(self.traces_da),
                np.asarray(other.traces_da),
                equal_nan=True,
            ):
                return False

        return True

    def _has_traces(self) -> bool:
        """Return True when ``traces_da`` holds at least one element.

        Checking ``sizes`` alone is not enough: the default
        ``xr.DataArray([])`` has one dimension of length zero, so its
        ``sizes`` mapping is non-empty even though there is no data.
        """
        traces_da = self.traces_da
        return (
            traces_da is not None
            and hasattr(traces_da, "sizes")
            and bool(traces_da.sizes)
            and traces_da.size > 0
        )

    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Convert train result to DataFrame.

        Takes the last iteration of the traces and returns the columns
        ``theta_idx``, ``logLik``, ``se`` (always NaN) and one column per
        key of the first theta dict.

        Args:
            ignore_nan: Accepted for interface compatibility; not used here.

        Returns:
            The summary DataFrame, or an empty DataFrame when there are no
            traces or no parameter sets.
        """
        if not self._has_traces() or not self.theta:
            return pd.DataFrame()

        df = (
            self.traces_da.isel(iteration=-1)
            .to_dataset(dim="variable")
            .to_dataframe()
            .reset_index()
        )
        param_names = list(self.theta[0].keys())
        cols = ["theta_idx", "logLik"] + param_names
        df = pd.DataFrame(df[cols])
        df.insert(2, "se", np.nan)
        return df

    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return an empty DataFrame; this result type stores no CLL data."""
        return pd.DataFrame()

    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return an empty DataFrame; this result type stores no ESS data."""
        return pd.DataFrame()

    def traces(self) -> pd.DataFrame:
        """Return traces DataFrame for this train result.

        Returns:
            The traces pivoted on "variable" with a ``method`` column set to
            "train", or an empty DataFrame when there are no traces.
        """
        if not self._has_traces():
            return pd.DataFrame()
        return (
            self.traces_da.to_dataset(dim="variable")
            .to_dataframe()
            .reset_index()
            .assign(method="train")
        )

    def print_summary(self):
        """Print summary of train result."""
        print(f"Method: {self.method}")
        print(f"Number of parameter sets: {len(self.theta)}")
        print(f"Optimizer: {self.optimizer}")
        print(f"Number of particles (J): {self.J}")
        print(f"Number of iterations (M): {self.M}")
        print(f"Learning rate (eta): {self.eta}")
        print(f"Discount factor (alpha): {self.alpha}")
        print(f"Resampling threshold: {self.thresh}")
        print(f"Line search: {self.ls}")
        if self.ls:
            print(f"Armijo constant (c): {self.c}")
            print(f"Max line search iterations: {self.max_ls_itn}")
        print(f"Cooling factor for eta: {self.eta_cooling}")
        print(f"Cooling factor for alpha: {self.alpha_cooling}")
        print(f"Execution time: {self.execution_time} seconds")
        df = self.to_dataframe()
        if not df.empty:
            print("\nTop 5 Results:")
            df_sorted = df.sort_values("logLik", ascending=False).head(5)
            print(df_sorted.to_string())

    @staticmethod
    def merge(*results: "PompTrainResult") -> "PompTrainResult":
        """Merge replications from multiple PompTrainResult objects into a single object.

        All inputs must agree on optimizer, J, M, eta, alpha, thresh, ls, c
        and max_ls_itn. Theta lists are chained and non-empty traces are
        concatenated along "theta_idx". The execution time is the largest
        non-None execution time; method, key and settings come from the
        first input.

        Raises:
            ValueError: if no results are given or a checked setting differs.
            TypeError: if an input is not of the first input's type.
        """
        if len(results) == 0:
            raise ValueError("At least one PompTrainResult object must be provided.")

        first = results[0]
        # NOTE(review): eta_cooling and alpha_cooling are compared in __eq__
        # but not validated here; the merged result silently takes them from
        # the first input. Confirm whether that is intended.
        scalar_fields = [
            "optimizer",
            "J",
            "M",
            "eta",
            "alpha",
            "thresh",
            "ls",
            "c",
            "max_ls_itn",
        ]
        for result in results:
            if not isinstance(result, type(first)):
                raise TypeError("All merged objects must be of type PompTrainResult.")
            for field_name in scalar_fields:
                if getattr(result, field_name) != getattr(first, field_name):
                    raise ValueError(
                        f"All PompTrainResult objects must have the same {field_name}."
                    )

        merged_theta = []
        for result in results:
            merged_theta.extend(result.theta)

        # Empty placeholders are skipped so they do not break the concat.
        trace_arrays = [r.traces_da for r in results if r.traces_da.size > 0]
        merged_traces = (
            xr.concat(trace_arrays, dim="theta_idx")
            if trace_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        execution_times = [
            r.execution_time for r in results if r.execution_time is not None
        ]
        max_execution_time = max(execution_times) if execution_times else None

        return PompTrainResult(
            method=first.method,
            execution_time=max_execution_time,
            key=first.key,
            theta=merged_theta,
            traces_da=merged_traces,
            optimizer=first.optimizer,
            J=first.J,
            M=first.M,
            eta=first.eta,
            alpha=first.alpha,
            thresh=first.thresh,
            ls=first.ls,
            c=first.c,
            max_ls_itn=first.max_ls_itn,
            eta_cooling=first.eta_cooling,
            alpha_cooling=first.alpha_cooling,
        )
@dataclass
class PanelPompPFilterResult(PanelPompBaseResult):
    """Result from PanelPomp.pfilter() method.

    ``logLiks`` is expected to have a "unit" coordinate and its last axis is
    reduced with ``logmeanexp`` to obtain per-unit log-likelihoods.
    """

    logLiks: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    # Number of particles.
    J: int = 0
    # Number of replicates.
    reps: int = 1
    # Resampling threshold.
    thresh: float = 0.0
    theta: "PanelParameters | None" = None
    CLL_da: xr.DataArray | None = None
    ESS_da: xr.DataArray | None = None
    filter_mean: xr.DataArray | None = None
    prediction_mean: xr.DataArray | None = None

    def __post_init__(self):
        """Set method to pfilter."""
        self.method = "pfilter"

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """Structural equality including panel log-likelihoods and diagnostics."""
        if not super().__eq__(other):
            return False

        if self.J != other.J or self.reps != other.reps or self.thresh != other.thresh:
            return False

        if isinstance(self.logLiks, xr.DataArray) and isinstance(
            other.logLiks, xr.DataArray
        ):
            if not self.logLiks.equals(other.logLiks):
                return False
        else:
            if not np.array_equal(
                np.asarray(self.logLiks),
                np.asarray(other.logLiks),
                equal_nan=True,
            ):
                return False

        # Optional diagnostics: presence must match before contents are compared.
        for name in ["CLL_da", "ESS_da", "filter_mean", "prediction_mean"]:
            a = getattr(self, name)
            b = getattr(other, name)
            if (a is None) != (b is None):
                return False
            if a is not None and b is not None:
                if isinstance(a, xr.DataArray) and isinstance(b, xr.DataArray):
                    if not a.equals(b):
                        return False
                else:
                    if not np.array_equal(np.asarray(a), np.asarray(b), equal_nan=True):
                        return False

        return True

    def _theta_frames(self) -> tuple[list, list]:
        """Collect the "shared" and "unit_specific" entries of every theta replicate.

        Returns:
            ``(shared_list, unit_specific_list)``; entries that are missing
            (``None``) are skipped, and both lists are empty when ``theta``
            is ``None``.
        """
        shared_list: list[pd.DataFrame] = []
        unit_specific_list: list[pd.DataFrame] = []
        if self.theta is None:
            return shared_list, unit_specific_list

        for i in range(len(self.theta._theta)):
            entry = self.theta._theta[i]
            shared_df = entry.get("shared")
            unit_specific_df = entry.get("unit_specific")
            if shared_df is not None:
                shared_list.append(shared_df)
            if unit_specific_df is not None:
                unit_specific_list.append(unit_specific_df)
        return shared_list, unit_specific_list

    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Convert panel pfilter result to DataFrame.

        Produces one row per (theta_idx, unit) with the columns
        ``theta_idx``, ``shared logLik`` (sum of the unit log-likelihoods of
        that parameter set), ``unit`` and ``unit logLik``, followed by any
        shared and unit-specific parameter columns found in ``theta``.

        Args:
            ignore_nan: Forwarded to ``logmeanexp``.

        Returns:
            The long-format DataFrame, or an empty DataFrame when there are
            no log-likelihoods.
        """
        if self.logLiks.size == 0:
            return pd.DataFrame()

        ll = logmeanexp(self.logLiks.values, axis=-1, ignore_nan=ignore_nan)
        unit_ll = pd.DataFrame(ll, columns=self.logLiks.coords["unit"].values)

        # Sum over the unit columns BEFORE theta_idx is added: DataFrame.assign
        # evaluates callables in order on the partially updated frame, so a
        # row-sum computed afterwards would also add the theta_idx value.
        shared_ll = unit_ll.sum(axis=1)

        df = unit_ll.assign(
            theta_idx=range(len(unit_ll)),
            **{"shared logLik": shared_ll},
        ).melt(
            id_vars=["theta_idx", "shared logLik"],
            var_name="unit",
            value_name="unit logLik",
        )

        # Extract shared/unit_specific from theta
        shared_list, unit_specific_list = self._theta_frames()

        if shared_list:
            s_params = pd.concat(shared_list, axis=1).T.reset_index(drop=True)
            df = df.join(s_params, on="theta_idx")

        if unit_specific_list:
            u_params = (
                pd.concat(unit_specific_list, keys=range(len(unit_specific_list)))
                .stack()
                .unstack(level=1)
                .reset_index()
            )
            # The first two columns produced by reset_index identify the
            # replicate and the unit respectively.
            col_names = list(u_params.columns)
            u_params.rename(
                columns={col_names[0]: "theta_idx", col_names[1]: "unit"},
                inplace=True,
            )
            df = df.merge(u_params, on=["theta_idx", "unit"], how="left")

        return df

    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return conditional log-likelihoods as a DataFrame.

        Args:
            average: When true, combine values with ``logmeanexp`` over
                axis 2 and return them indexed by theta_idx, unit and time.

        Returns:
            Long-format DataFrame with a ``CLL`` column, or an empty
            DataFrame when no CLL data is stored.
        """
        if self.CLL_da is None or self.CLL_da.size == 0:
            return pd.DataFrame()

        cll: xr.DataArray = self.CLL_da
        if average:
            cll_values = np.asarray(cll.values)
            cll_avg = logmeanexp(cll_values, axis=2)
            avg_da = xr.DataArray(
                cll_avg,
                dims=["theta_idx", "unit", "time"],
                coords={
                    "theta_idx": cll.coords["theta_idx"].values
                    if "theta_idx" in cll.coords
                    else np.arange(cll_avg.shape[0]),
                    "unit": cll.coords["unit"].values,
                    "time": cll.coords["time"].values
                    if "time" in cll.coords
                    else np.arange(cll_avg.shape[2]),
                },
            )
            return avg_da.to_dataframe(name="CLL").reset_index()
        else:
            return cll.to_dataframe(name="CLL").reset_index()

    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return Effective Sample Size as a DataFrame.

        Args:
            average: When true, take the mean over the "rep" dimension first.

        Returns:
            Long-format DataFrame with an ``ESS`` column, or an empty
            DataFrame when no ESS data is stored.
        """
        if self.ESS_da is None or self.ESS_da.size == 0:
            return pd.DataFrame()

        ess: xr.DataArray = self.ESS_da
        if average:
            ess_avg = ess.mean(dim="rep")
            return ess_avg.to_dataframe(name="ESS").reset_index()
        else:
            return ess.to_dataframe(name="ESS").reset_index()

    def traces(self) -> pd.DataFrame:
        """Return pfilter results formatted as traces (long format).

        The output stacks one "shared" row per parameter set (sum of the unit
        log-likelihoods) on top of one row per (theta_idx, unit), and tags
        every row with ``method="pfilter"`` and ``iteration=1``.

        Returns:
            The long-format DataFrame, or an empty DataFrame when there are
            no log-likelihoods.
        """
        if self.logLiks.size == 0:
            return pd.DataFrame()

        ll = logmeanexp(self.logLiks.values, axis=-1)
        reps = np.arange(len(ll))

        df_s = pd.DataFrame(
            {"theta_idx": reps, "unit": "shared", "logLik": ll.sum(axis=1)}
        )
        df_u = (
            pd.DataFrame(ll, columns=self.logLiks.coords["unit"].values, index=reps)
            .melt(ignore_index=False, var_name="unit", value_name="logLik")
            .reset_index()
            .rename(columns={"index": "theta_idx"})
        )

        shared_list, unit_specific_list = self._theta_frames()

        if shared_list:
            p_s = pd.concat(shared_list, axis=1).T.set_axis(reps, axis=0)
            df_s = df_s.join(p_s, on="theta_idx")
            df_u = df_u.join(p_s, on="theta_idx")

        if unit_specific_list:
            p_u = pd.concat(unit_specific_list, keys=reps).stack().unstack(level=1)
            p_u.index.names = ["theta_idx", "unit"]
            df_u = df_u.join(p_u, on=["theta_idx", "unit"])

        return pd.concat([df_s, df_u], ignore_index=True).assign(
            method="pfilter", iteration=1
        )

    def print_summary(self):
        """Print summary of panel pfilter result."""
        print(f"Method: {self.method}")
        print(
            f"Number of parameter sets: {self.theta.num_replicates() if self.theta is not None else 0}"
        )
        print(f"Number of particles (J): {self.J}")
        print(f"Number of replicates: {self.reps}")
        print(f"Resampling threshold: {self.thresh}")
        print(f"Execution time: {self.execution_time} seconds")
        df = self.to_dataframe()
        if not df.empty:
            print("\nTop 5 Results:")
            df_sorted = df.sort_values("shared logLik", ascending=False).head(5)
            print(df_sorted.to_string())

    @staticmethod
    def merge(*results: "PanelPompPFilterResult") -> "PanelPompPFilterResult":
        """Merge replications from multiple PanelPompPFilterResult objects into a single object.

        All inputs must agree on J, reps and thresh. Non-None thetas are
        merged with ``PanelParameters.merge``; logLiks and the optional
        diagnostics are concatenated along "theta_idx". The execution time is
        the largest non-None execution time; method and key come from the
        first input.

        Raises:
            ValueError: if no results are given or a setting differs.
            TypeError: if an input is not of the first input's type.
        """
        if len(results) == 0:
            raise ValueError(
                "At least one PanelPompPFilterResult object must be provided."
            )

        first = results[0]
        for result in results:
            if not isinstance(result, type(first)):
                raise TypeError(
                    "All merged objects must be of type PanelPompPFilterResult."
                )
            if (
                result.J != first.J
                or result.reps != first.reps
                or result.thresh != first.thresh
            ):
                raise ValueError(
                    "All PanelPompPFilterResult objects must have the same J, reps, and thresh."
                )

        merged_theta = (
            PanelParameters.merge(*[r.theta for r in results if r.theta is not None])
            if any(r.theta is not None for r in results)
            else None
        )

        logLik_arrays = [r.logLiks for r in results if r.logLiks.size > 0]
        merged_logLiks = (
            xr.concat(logLik_arrays, dim="theta_idx")
            if logLik_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        def merge_optional_diagnostic(name: str) -> xr.DataArray | None:
            # Absent or empty diagnostics are skipped; None if nothing remains.
            arrays = [
                getattr(r, name)
                for r in results
                if getattr(r, name) is not None and getattr(r, name).size > 0
            ]
            return xr.concat(arrays, dim="theta_idx") if arrays else None  # type: ignore[return-value]

        execution_times = [
            r.execution_time for r in results if r.execution_time is not None
        ]
        max_execution_time = max(execution_times) if execution_times else None

        return PanelPompPFilterResult(
            method=first.method,
            execution_time=max_execution_time,
            key=first.key,
            theta=merged_theta,
            logLiks=merged_logLiks,
            J=first.J,
            reps=first.reps,
            thresh=first.thresh,
            CLL_da=merge_optional_diagnostic("CLL_da"),
            ESS_da=merge_optional_diagnostic("ESS_da"),
            filter_mean=merge_optional_diagnostic("filter_mean"),
            prediction_mean=merge_optional_diagnostic("prediction_mean"),
        )
@dataclass
class PanelPompMIFResult(PanelPompBaseResult):
    """Result from PanelPomp.mif() method.

    ``shared_traces`` and ``unit_traces`` are expected to carry
    "theta_idx", "iteration" and "variable" dimensions; those are the names
    this class indexes, pivots and concatenates on.
    """

    shared_traces: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    unit_traces: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    logLiks: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    theta: "PanelParameters | None" = None
    # Number of particles.
    J: int = 0
    # Number of iterations.
    M: int = 0
    rw_sd: RWSigma | None = None
    # Cooling fraction.
    a: float = 0.0
    # Resampling threshold.
    thresh: float = 0.0
    n_monitors: int = 0
    block: bool = True

    def __post_init__(self):
        """Set method to mif."""
        self.method = "mif"

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """Structural equality including traces, log-likelihoods, and settings."""
        if not super().__eq__(other):
            return False

        if (
            self.J != other.J
            or self.M != other.M
            or self.a != other.a
            or self.thresh != other.thresh
            or self.n_monitors != other.n_monitors
            or self.block != other.block
        ):
            return False

        if (self.rw_sd is None) != (other.rw_sd is None):
            return False
        if self.rw_sd is not None and self.rw_sd != other.rw_sd:
            return False

        for name in ["shared_traces", "unit_traces", "logLiks"]:
            a = getattr(self, name)
            b = getattr(other, name)
            if isinstance(a, xr.DataArray) and isinstance(b, xr.DataArray):
                if not a.equals(b):
                    return False
            else:
                if not np.array_equal(np.asarray(a), np.asarray(b), equal_nan=True):
                    return False

        return True

    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Convert panel mif result to DataFrame.

        Uses the last iteration of both trace arrays and returns one row per
        (theta_idx, unit) with ``theta_idx``, ``iteration``,
        ``shared logLik``, ``unit`` and ``unit logLik`` first, followed by
        the remaining columns.

        Args:
            ignore_nan: Accepted for interface compatibility; not used here.

        Returns:
            The DataFrame, or an empty DataFrame when there are no shared
            traces (same guard as ``traces()``).
        """
        # Without this guard the default empty DataArray has no "iteration"
        # dimension and isel() below raises, which also breaks print_summary().
        if self.shared_traces.size == 0:
            return pd.DataFrame()

        s_df = (
            self.shared_traces.isel(iteration=-1)
            .to_dataset(dim="variable")
            .to_dataframe()
            .rename(columns={"logLik": "shared logLik"})
        )
        u_df = (
            self.unit_traces.isel(iteration=-1)
            .to_dataset(dim="variable")
            .to_dataframe()
            .rename(columns={"unitLogLik": "unit logLik"})
        )

        # Drop the duplicate so the join does not see "iteration" on both sides.
        if "iteration" in s_df.columns:
            s_df = s_df.drop(columns=["iteration"])

        u_df = u_df.join(s_df, on="theta_idx").reset_index()

        leading = ["theta_idx", "iteration", "shared logLik", "unit", "unit logLik"]
        leading_set = set(leading)
        cols = leading + [c for c in u_df.columns if c not in leading_set]
        u_df = u_df[cols]
        assert isinstance(u_df, pd.DataFrame)
        return u_df

    def traces(self) -> pd.DataFrame:
        """Return panel mif results formatted as traces (long format).

        Shared rows get ``unit="shared"``; unit rows have ``unitLogLik``
        renamed to ``logLik`` and receive the shared parameter columns of the
        matching (theta_idx, iteration). Every row is tagged
        ``method="mif"``.

        Returns:
            The long-format DataFrame, or an empty DataFrame when there are
            no shared traces.
        """
        if self.shared_traces.size == 0:
            return pd.DataFrame()

        df_s = (
            self.shared_traces.to_dataset(dim="variable").to_dataframe().reset_index()
        )
        df_s["unit"] = "shared"

        df_u = self.unit_traces.to_dataset(dim="variable").to_dataframe().reset_index()
        df_u = df_u.rename(columns={"unitLogLik": "logLik"})

        # Everything in the shared frame that is not bookkeeping is a parameter.
        meta_cols = {"theta_idx", "iteration", "logLik", "unit"}
        shared_params = [c for c in df_s.columns if c not in meta_cols]

        if shared_params:
            df_u = df_u.merge(
                df_s[["theta_idx", "iteration"] + shared_params],
                on=["theta_idx", "iteration"],
                how="left",
            )

        return pd.concat([df_s, df_u], ignore_index=True).assign(method="mif")

    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return conditional log-likelihoods as a DataFrame (always empty here)."""
        return pd.DataFrame()

    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return Effective Sample Size as a DataFrame (always empty here)."""
        return pd.DataFrame()

    def print_summary(self):
        """Print summary of panel mif result."""
        print(f"Method: {self.method}")
        print(
            f"Number of parameter sets: {self.theta.num_replicates() if self.theta is not None else 0}"
        )
        print(f"Number of particles (J): {self.J}")
        print(f"Number of iterations (M): {self.M}")
        print(f"Cooling fraction (a): {self.a}")
        print(f"Resampling threshold: {self.thresh}")
        print(f"Number of monitors: {self.n_monitors}")
        print(f"Block: {self.block}")
        print(f"Execution time: {self.execution_time} seconds")
        df = self.to_dataframe()
        if not df.empty:
            print("\nTop 5 Results:")
            df_sorted = df.sort_values("shared logLik", ascending=False).head(5)
            print(df_sorted.to_string())

    @staticmethod
    def merge(*results: "PanelPompMIFResult") -> "PanelPompMIFResult":
        """Merge replications from multiple PanelPompMIFResult objects into a single object.

        All inputs must agree on J, M, a, thresh, n_monitors, block and
        rw_sd. Non-None thetas are merged with ``PanelParameters.merge``;
        shared traces, unit traces and logLiks are each concatenated along
        "theta_idx". The execution time is the largest non-None execution
        time; method, key and settings come from the first input.

        Raises:
            ValueError: if no results are given or a setting differs.
            TypeError: if an input is not of the first input's type.
        """
        if len(results) == 0:
            raise ValueError("At least one PanelPompMIFResult object must be provided.")

        first = results[0]
        for result in results:
            if not isinstance(result, type(first)):
                raise TypeError(
                    "All merged objects must be of type PanelPompMIFResult."
                )
            if (
                result.J != first.J
                or result.M != first.M
                or result.a != first.a
                or result.thresh != first.thresh
                or result.n_monitors != first.n_monitors
                or result.block != first.block
            ):
                raise ValueError(
                    "All PanelPompMIFResult objects must have the same J, M, a, thresh, and block."
                )
            if (result.rw_sd is None) != (first.rw_sd is None) or (
                result.rw_sd is not None and result.rw_sd != first.rw_sd
            ):
                raise ValueError(
                    "All PanelPompMIFResult objects must have the same rw_sd."
                )

        merged_theta = (
            PanelParameters.merge(*[r.theta for r in results if r.theta is not None])
            if any(r.theta is not None for r in results)
            else None
        )

        # For each array family, empty placeholders are skipped before concat.
        shared_trace_arrays = [
            r.shared_traces for r in results if r.shared_traces.size > 0
        ]
        merged_shared_traces = (
            xr.concat(shared_trace_arrays, dim="theta_idx")
            if shared_trace_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        unit_trace_arrays = [r.unit_traces for r in results if r.unit_traces.size > 0]
        merged_unit_traces = (
            xr.concat(unit_trace_arrays, dim="theta_idx")
            if unit_trace_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        logLik_arrays = [r.logLiks for r in results if r.logLiks.size > 0]
        merged_logLiks = (
            xr.concat(logLik_arrays, dim="theta_idx")
            if logLik_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        execution_times = [
            r.execution_time for r in results if r.execution_time is not None
        ]
        max_execution_time = max(execution_times) if execution_times else None

        return PanelPompMIFResult(
            method=first.method,
            execution_time=max_execution_time,
            key=first.key,
            theta=merged_theta,
            shared_traces=merged_shared_traces,
            unit_traces=merged_unit_traces,
            logLiks=merged_logLiks,
            J=first.J,
            M=first.M,
            rw_sd=first.rw_sd,
            a=first.a,
            thresh=first.thresh,
            n_monitors=first.n_monitors,
            block=first.block,
        )
@dataclass
class PanelPompTrainResult(PanelPompBaseResult):
    """Result from PanelPomp.train() method.

    Holds the shared and unit-specific traces, the log-likelihoods, the
    parameter object and the optimizer settings of a training run. The
    ``method`` attribute is always forced to ``"train"``.
    """

    shared_traces: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    unit_traces: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    logLiks: xr.DataArray = field(default_factory=lambda: xr.DataArray([]))
    theta: "PanelParameters | None" = None
    optimizer: str = "SGD"
    J: int = 0
    M: int = 0
    eta: dict[str, float] | float = field(default_factory=lambda: {})
    alpha: float = 0.97
    eta_cooling: float = 1.0
    alpha_cooling: float = 1.0

    def __post_init__(self):
        """Force ``method`` to ``"train"`` regardless of the value passed in."""
        self.method = "train"

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """Structural equality including traces, log-likelihoods, and settings."""
        if not super().__eq__(other):
            return False

        if (
            self.optimizer != other.optimizer
            or self.J != other.J
            or self.M != other.M
            or self.eta != other.eta
            or self.alpha != other.alpha
            or self.eta_cooling != other.eta_cooling
            or self.alpha_cooling != other.alpha_cooling
        ):
            return False

        for name in ["shared_traces", "unit_traces", "logLiks"]:
            a = getattr(self, name)
            b = getattr(other, name)
            if isinstance(a, xr.DataArray) and isinstance(b, xr.DataArray):
                if not a.equals(b):
                    return False
            else:
                # Fallback for values that are not both DataArrays.
                if not np.array_equal(np.asarray(a), np.asarray(b), equal_nan=True):
                    return False

        return True

    def to_dataframe(self, ignore_nan: bool = False) -> pd.DataFrame:
        """Return the final-iteration values, one row per (theta_idx, unit).

        The last iteration of the unit traces is joined with the last
        iteration of the shared traces on ``theta_idx``. The log-likelihood
        columns are renamed to ``"unit logLik"`` and ``"shared logLik"``.

        Args:
            ignore_nan: Accepted for interface compatibility; not used here.

        Returns:
            A DataFrame whose leading columns are ``theta_idx``,
            ``iteration``, ``shared logLik``, ``unit`` and ``unit logLik``,
            followed by the remaining columns. An empty DataFrame when either
            trace array has no elements.
        """
        # The default (empty) traces have no "iteration" dimension, so
        # isel() below would raise; report "no results" instead.
        if self.shared_traces.size == 0 or self.unit_traces.size == 0:
            return pd.DataFrame()

        s_df = (
            self.shared_traces.isel(iteration=-1)
            .to_dataset(dim="variable")
            .to_dataframe()
            .rename(columns={"logLik": "shared logLik"})
        )
        u_df = (
            self.unit_traces.isel(iteration=-1)
            .to_dataset(dim="variable")
            .to_dataframe()
            .rename(columns={"unitLogLik": "unit logLik"})
        )

        # Drop the shared frame's "iteration" column so it does not clash
        # with the unit frame's column of the same name in the join.
        if "iteration" in s_df.columns:
            s_df = s_df.drop(columns=["iteration"])

        u_df = u_df.join(s_df, on="theta_idx").reset_index()

        leading = ["theta_idx", "iteration", "shared logLik", "unit", "unit logLik"]
        leading_set = set(leading)
        cols = leading + [c for c in u_df.columns if c not in leading_set]
        u_df = u_df[cols]
        assert isinstance(u_df, pd.DataFrame)
        return u_df

    def traces(self) -> pd.DataFrame:
        """Return panel train results formatted as traces (long format).

        Returns:
            * An empty DataFrame when ``shared_traces`` has no elements.
            * When ``theta`` is set and its first replicate has a
              ``"unit_specific"`` entry: the unit rows with the shared
              parameters (index of the ``"shared"`` entry) left-merged on
              ``(theta_idx, iteration)``.
            * When ``theta`` is set without a ``"unit_specific"`` entry: the
              shared rows only.
            * When ``theta`` is ``None``: shared rows followed by unit rows.

            Every non-empty result carries a ``method`` column set to
            ``"train"``.
        """
        if self.shared_traces.size == 0:
            return pd.DataFrame()

        df_s = (
            self.shared_traces.to_dataset(dim="variable").to_dataframe().reset_index()
        )
        df_s["unit"] = "shared"

        df_u = self.unit_traces.to_dataset(dim="variable").to_dataframe().reset_index()
        df_u = df_u.rename(columns={"unitLogLik": "logLik"})

        if self.theta is not None:
            shared_df = self.theta.theta[0].get("shared")
            unit_df = self.theta.theta[0].get("unit_specific")
            shared_params = list(shared_df.index) if shared_df is not None else []

            if unit_df is None:
                return df_s.assign(method="train")

            df_long = df_u.merge(
                df_s[["theta_idx", "iteration"] + shared_params],
                on=["theta_idx", "iteration"],
                how="left",
            )
            # Tag the rows like every other return path does, so the history
            # view can rely on the "method" column being present.
            return df_long.assign(method="train")

        return pd.concat([df_s, df_u], ignore_index=True).assign(method="train")

    def CLL(self, average: bool = False) -> pd.DataFrame:
        """Return conditional log-likelihoods as a DataFrame.

        This implementation always returns an empty DataFrame; ``average`` is
        accepted but not used.
        """
        return pd.DataFrame()

    def ESS(self, average: bool = False) -> pd.DataFrame:
        """Return Effective Sample Size as a DataFrame.

        This implementation always returns an empty DataFrame; ``average`` is
        accepted but not used.
        """
        return pd.DataFrame()

    def print_summary(self):
        """Print the run settings and, when available, the top 5 rows of
        ``to_dataframe()`` ordered by descending ``"shared logLik"``."""
        print(f"Method: {self.method}")
        print(
            f"Number of parameter sets: {self.theta.num_replicates() if self.theta is not None else 0}"
        )
        print(f"Optimizer: {self.optimizer}")
        print(f"Number of particles (J): {self.J}")
        print(f"Number of iterations (M): {self.M}")
        print(f"Learning rate (eta): {self.eta}")
        print(f"Discount factor (alpha): {self.alpha}")
        print(f"Cooling factor for eta: {self.eta_cooling}")
        print(f"Cooling factor for alpha: {self.alpha_cooling}")
        print(f"Execution time: {self.execution_time} seconds")

        df = self.to_dataframe()
        if not df.empty:
            print("\nTop 5 Results:")
            df_sorted = df.sort_values("shared logLik", ascending=False).head(5)
            print(df_sorted.to_string())

    @staticmethod
    def merge(*results: "PanelPompTrainResult") -> "PanelPompTrainResult":
        """Merge replications from multiple PanelPompTrainResult objects into a single object.

        Traces and log-likelihoods are concatenated along ``theta_idx``
        (empty arrays are skipped). The merged result takes ``method``,
        ``key`` and all optimizer settings from the first argument, and the
        largest non-``None`` ``execution_time`` of the inputs.

        Raises:
            ValueError: If no results are given, or if the optimizer settings
                differ between inputs.
            TypeError: If an input is not an instance of the first input's type.
        """
        if not results:
            raise ValueError(
                "At least one PanelPompTrainResult object must be provided."
            )

        first = results[0]
        for result in results:
            if not isinstance(result, type(first)):
                raise TypeError(
                    "All merged objects must be of type PanelPompTrainResult."
                )
            if (
                result.optimizer != first.optimizer
                or result.J != first.J
                or result.M != first.M
                or result.eta != first.eta
                or result.alpha != first.alpha
                or result.eta_cooling != first.eta_cooling
                or result.alpha_cooling != first.alpha_cooling
            ):
                raise ValueError(
                    "All PanelPompTrainResult objects must have the same optimizer, J, M, eta, alpha, eta_cooling, and alpha_cooling."
                )

        merged_theta = (
            PanelParameters.merge(*[r.theta for r in results if r.theta is not None])
            if any(r.theta is not None for r in results)
            else None
        )

        shared_trace_arrays = [
            r.shared_traces for r in results if r.shared_traces.size > 0
        ]
        merged_shared_traces = (
            xr.concat(shared_trace_arrays, dim="theta_idx")
            if shared_trace_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        unit_trace_arrays = [r.unit_traces for r in results if r.unit_traces.size > 0]
        merged_unit_traces = (
            xr.concat(unit_trace_arrays, dim="theta_idx")
            if unit_trace_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        logLik_arrays = [r.logLiks for r in results if r.logLiks.size > 0]
        merged_logLiks = (
            xr.concat(logLik_arrays, dim="theta_idx")
            if logLik_arrays
            else xr.DataArray([])
        )  # type: ignore[assignment]

        execution_times = [
            r.execution_time for r in results if r.execution_time is not None
        ]
        max_execution_time = max(execution_times) if execution_times else None

        return PanelPompTrainResult(
            method=first.method,
            execution_time=max_execution_time,
            key=first.key,
            theta=merged_theta,
            shared_traces=merged_shared_traces,
            unit_traces=merged_unit_traces,
            logLiks=merged_logLiks,
            optimizer=first.optimizer,
            J=first.J,
            M=first.M,
            eta=first.eta,
            alpha=first.alpha,
            eta_cooling=first.eta_cooling,
            alpha_cooling=first.alpha_cooling,
        )
class ResultsHistory:
    """Container class for managing result history.

    Entries are kept in insertion order. The container supports indexing,
    ``len()`` and iteration, and offers DataFrame views over single entries
    or over the whole history.
    """

    # Instance attribute, created in __init__. Declared here for type
    # checkers only: this class is not a dataclass, so dataclasses.field()
    # would just leave a Field object as the class attribute.
    _entries: list[BaseResult]

    def __init__(self):
        self._entries = []

    def add(self, result: BaseResult):
        """Add a result entry."""
        self._entries.append(result)

    def __eq__(self, other) -> bool:  # type: ignore[override]
        """
        Structural equality for ResultsHistory.

        Two histories are equal if they contain the same sequence of result
        objects (compared via their own __eq__ implementations).
        """
        if not isinstance(other, type(self)):
            return False

        if len(self._entries) != len(other._entries):
            return False

        return not any(a != b for a, b in zip(self._entries, other._entries))

    def __getitem__(self, index):
        """Get result by index."""
        return self._entries[index]

    def __len__(self):
        """Get number of entries."""
        return len(self._entries)

    def __iter__(self):
        """Iterate over entries."""
        return iter(self._entries)

    def clear(self):
        """Clear all entries from the history."""
        self._entries.clear()

    def last(self) -> BaseResult:
        """Get last entry.

        Raises:
            ValueError: If the history is empty.
        """
        if not self._entries:
            raise ValueError("History is empty")
        return self._entries[-1]

    def results(self, index: int = -1, ignore_nan: bool = False) -> pd.DataFrame:
        """Get results DataFrame for entry at index (empty if no entries)."""
        if not self._entries:
            return pd.DataFrame()
        result = self._entries[index]
        return result.to_dataframe(ignore_nan=ignore_nan)

    def CLL(self, index: int = -1, average: bool = False) -> pd.DataFrame:
        """Get conditional log-likelihoods for entry at index (empty if no entries)."""
        if not self._entries:
            return pd.DataFrame()
        result = self._entries[index]
        return result.CLL(average=average)

    def ESS(self, index: int = -1, average: bool = False) -> pd.DataFrame:
        """Get Effective Sample Size for entry at index (empty if no entries)."""
        if not self._entries:
            return pd.DataFrame()
        result = self._entries[index]
        return result.ESS(average=average)

    def time(self) -> pd.DataFrame:
        """Return execution times DataFrame.

        One row per entry with columns ``method`` and ``time``; the index is
        named ``history_index``.
        """
        rows = [
            {"method": res.method, "time": res.execution_time}
            for res in self._entries
        ]
        df = pd.DataFrame(rows)
        df.index.name = "history_index"
        return df

    def traces(self) -> pd.DataFrame:
        """
        Return traces DataFrame from entire result history.

        Handles continuous iteration counting across chained runs (e.g., MIF -> MIF)
        and aligns checkpoints (PFilter).

        Entries without a ``traces`` method, or whose traces are empty, are
        skipped. The result is sorted by ``theta_idx``, then ``unit`` (when
        present), then ``iteration``, with the columns ``theta_idx``,
        ``unit``, ``iteration``, ``method`` and ``logLik`` placed first when
        they exist.
        """
        if not self._entries:
            return pd.DataFrame()

        all_dfs = []
        # Highest iteration number emitted so far, per theta_idx.
        global_iter_counters: dict[int, int] = {}

        for res in self._entries:
            if not hasattr(res, "traces"):
                continue

            df = res.traces()  # pyright: ignore[reportAttributeAccessIssue]
            if df.empty:
                continue

            is_estimation = res.method in ["mif", "train"]

            unique_reps = df["theta_idx"].unique()
            offsets_map = {r: global_iter_counters.get(r, 0) for r in unique_reps}
            row_offsets = df["theta_idx"].map(offsets_map)

            if is_estimation:
                # Keep iteration 0 only for replicates with no earlier
                # history; for chained runs it would repeat the previous
                # run's final iteration number.
                mask = (df["iteration"] > 0) | (row_offsets == 0)
                df = df.loc[mask].copy()
                row_offsets = df["theta_idx"].map(offsets_map)
                df["iteration"] = df["iteration"] + row_offsets

                new_maxes = df.groupby("theta_idx")["iteration"].max()
                for r, mx in new_maxes.items():
                    global_iter_counters[r] = int(mx)
            else:
                # LOGIC: PFilter is a snapshot.
                # Plot it at the current "end" of the timeline.
                df = df.copy()
                df["iteration"] = row_offsets
                # We do NOT increment the global_iter_counters here

            if not df.empty:
                all_dfs.append(df)

        if not all_dfs:
            return pd.DataFrame()

        result_df = pd.concat(all_dfs, ignore_index=True)

        sort_cols = ["theta_idx", "iteration"]
        if "unit" in result_df.columns:
            sort_cols.insert(1, "unit")
        result_df = result_df.sort_values(sort_cols).reset_index(drop=True)

        canonical_first = ["theta_idx", "unit", "iteration", "method", "logLik"]
        existing_first = [c for c in canonical_first if c in result_df.columns]
        remaining = [c for c in result_df.columns if c not in existing_first]
        result_df = result_df[existing_first + remaining]

        assert isinstance(result_df, pd.DataFrame), (
            "result_df is not a DataFrame; something went wrong"
        )
        return result_df

    def print_summary(self):
        """Print summary of all entries."""
        if not self._entries:
            print("No results history.")
            return

        print("Results history:")
        print("----------------")
        for idx, entry in enumerate(self._entries, 1):
            print(f"Results entry {idx}:")
            entry.print_summary()
            print()

    @staticmethod
    def merge(*histories: "ResultsHistory") -> "ResultsHistory":
        """Merge replications from multiple ResultsHistory objects into a single object.

        Entries are merged position by position using the ``merge`` method of
        the entry type found at that position.

        Raises:
            ValueError: If no histories are given, if the histories differ in
                length, if the entries at one position have different types,
                or if the entry type has no ``merge`` method.
        """
        if not histories:
            raise ValueError("At least one ResultsHistory object must be provided.")

        # Check if all histories have the same number of entries
        entry_lengths = [len(h._entries) for h in histories]
        if len(set(entry_lengths)) != 1:
            raise ValueError(
                f"Cannot merge ResultsHistory objects: differing number of entries ({entry_lengths})"
            )

        merged_history = ResultsHistory()
        # Lengths are equal, so zip() walks every position of every history.
        for i, results_at_position in enumerate(
            zip(*(h._entries for h in histories))
        ):
            result_type = type(results_at_position[0])

            if not all(isinstance(r, result_type) for r in results_at_position):
                raise ValueError(
                    f"Results at position {i} have different types and cannot be merged."
                )

            if not hasattr(result_type, "merge"):
                raise ValueError(
                    f"Result type {result_type} does not have a merge method."
                )

            merged_history.add(result_type.merge(*results_at_position))

        return merged_history