# Source code for pypomp.panel.panel
"""
This module implements the OOP structure for PanelPOMP models.
"""
import jax
import pandas as pd
from pypomp.core.pomp import Pomp
from .validation_mixin import PanelValidationMixin
from .estimation_mixin import PanelEstimationMixin
from .analysis_mixin import PanelAnalysisMixin
from pypomp.core.results import ResultsHistory
from pypomp.core.parameters import PanelParameters
from pypomp.core.metadata import ModelMetadata
from copy import deepcopy
[docs]
class PanelPomp(PanelValidationMixin, PanelEstimationMixin, PanelAnalysisMixin):
    """
    The PanelPomp class represents a panel of partially observed Markov process models.

    It extends the single-unit POMP framework to handle multiple units that share
    structural characteristics but may have distinct parameter values and observations.
    In particular, the class provides methods for:

    - Simulation of panel models
    - Particle filtering for panel models
    - Marginalized Panel Iterated Filtering (MPIF)
    - Gradient descent via automatic differentiation

    Parameters
    ----------
    Pomp_dict : dict[str, Pomp]
        A dictionary mapping unit names to Pomp objects. Each Pomp object represents
        a single unit in the panel data. The keys are used as unit identifiers.
        The Pomp objects are stored by reference (not copied), and their ``theta``
        attribute is set to None during construction.
    theta : PanelParameters | dict | list, optional
        A PanelParameters object, a dictionary with "shared" and "unit_specific"
        keys, or a list of such dictionaries. Anything other than a
        PanelParameters instance (including None) is passed to
        ``PanelParameters(theta=...)``.
    """

    unit_objects: dict[str, Pomp]  # unit name -> single-unit Pomp model
    theta: PanelParameters  # panel-level parameters (shared and unit-specific)
    results_history: ResultsHistory  # starts empty in __init__
    fresh_key: jax.Array | None  # JAX PRNG key; None after __init__
    metadata: ModelMetadata
    # Parameter-name lists captured from ``theta`` at construction time.
    canonical_param_names: list[str]
    canonical_shared_param_names: list[str]
    canonical_unit_param_names: list[str]

    def __init__(
        self,
        Pomp_dict: dict[str, Pomp],
        theta: PanelParameters
        | dict[str, pd.DataFrame | None]
        | list[dict[str, pd.DataFrame | None]]
        | None = None,
    ):
        """
        Build a panel from per-unit Pomp objects and optional panel parameters.

        Side effect: after ``_validate_params_and_units()`` returns, the
        ``theta`` attribute of every Pomp object in ``Pomp_dict`` is set to
        None. The caller's objects are modified in place.
        """
        # PanelParameters(theta=None) handles the "no parameters" case, so a
        # single branch covers None, dicts and lists alike.
        if isinstance(theta, PanelParameters):
            self.theta = theta
        else:
            self.theta = PanelParameters(theta=theta)
        self.unit_objects = Pomp_dict
        self.results_history = ResultsHistory()
        self.fresh_key = None
        self.metadata = ModelMetadata()
        self.canonical_param_names = self.theta.get_param_names()
        self.canonical_shared_param_names = self.theta.get_shared_param_names()
        self.canonical_unit_param_names = self.theta.get_unit_param_names()
        self._validate_params_and_units()
        # Parameters live on the panel (self.theta). Clear the per-unit copies,
        # presumably so stale unit-level values are never used.
        for pomp_obj in self.unit_objects.values():
            pomp_obj.theta = None  # type: ignore

    def get_unit_names(self) -> list[str]:
        """Return the unit names in the iteration order of ``unit_objects``."""
        return list(self.unit_objects)

    def print_summary(self, n: int = 5):
        """
        Print a summary of the PanelPomp object.

        Parameters
        ----------
        n : int
            Forwarded to ``self.results_history.print_summary(n=n)``.
        """
        print("Basics:")
        print("-------")
        print(f"Number of units: {len(self.unit_objects)}")
        print(f"Number of parameters: {len(self.canonical_param_names)}")
        # Per-unit sizes are reported for the first unit only. The panel can be
        # empty (e.g. __setstate__ falls back to an empty dict), in which case
        # there is no first unit to report on.
        if self.unit_objects:
            first_pomp = next(iter(self.unit_objects.values()))
            print(f"Number of observations (first unit): {len(first_pomp.ys)}")
            print(
                f"Number of time steps (first unit): {len(first_pomp._dt_array_extended)}"
            )
        print(f"Number of parameter sets: {self.theta.num_replicates()}")
        print()
        self.results_history.print_summary(n=n)

    def __eq__(self, other):
        """
        Check structural equality with another PanelPomp object.

        Two PanelPomp instances are considered equal if they:

        - Are of the same type (``other`` is an instance of ``type(self)``)
        - Have identical canonical parameter name lists
        - Have equal PanelParameters (self.theta)
        - Have the same unit names in the same order
        - Have unit Pomp objects that compare equal pairwise
        - Have equal results_history
        - Have equal fresh_key values (compared via their raw key data), or
          both None

        ``metadata`` is not part of the comparison.
        """
        if not isinstance(other, type(self)):
            return False
        # Canonical parameter structure
        if self.canonical_param_names != other.canonical_param_names:
            return False
        if self.canonical_shared_param_names != other.canonical_shared_param_names:
            return False
        if self.canonical_unit_param_names != other.canonical_unit_param_names:
            return False
        # Panel parameters
        if self.theta != other.theta:
            return False
        # Unit objects: same names, same order, equal models
        unit_names = list(self.unit_objects)
        if unit_names != list(other.unit_objects):
            return False
        for unit in unit_names:
            if self.unit_objects[unit] != other.unit_objects[unit]:
                return False
        if self.results_history != other.results_history:
            return False
        # PRNG keys: both None, or both present with identical key data.
        if (self.fresh_key is None) != (other.fresh_key is None):
            return False
        if self.fresh_key is not None:
            if not jax.numpy.array_equal(
                jax.random.key_data(self.fresh_key),
                jax.random.key_data(other.fresh_key),
            ):
                return False
        return True

    @staticmethod
    def merge(*panel_pomp_objs: "PanelPomp") -> "PanelPomp":
        """
        Merge replications from multiple PanelPomp objects into a single object.

        All panel objects must have the same units and canonical parameter names.
        The result is a deep copy of the first object. Its ``theta`` is replaced
        by ``PanelParameters.merge`` of all inputs' ``theta``, and its
        ``results_history`` by ``ResultsHistory.merge`` of all inputs' histories.
        It keeps the first object's ``fresh_key``.

        Raises
        ------
        ValueError
            If no objects are given, or if canonical parameter names or unit
            names differ from those of the first object.
        TypeError
            If an object is not a PanelPomp, or is not an instance of the
            first object's type.
        """
        if not panel_pomp_objs:
            raise ValueError("At least one PanelPomp object must be provided.")
        first = panel_pomp_objs[0]
        for obj in panel_pomp_objs:
            # Also checking against PanelPomp catches a non-PanelPomp first
            # argument, which isinstance(obj, type(first)) alone would accept.
            if not (isinstance(obj, PanelPomp) and isinstance(obj, type(first))):
                raise TypeError("All merged objects must be of type PanelPomp.")
            if obj.canonical_param_names != first.canonical_param_names:
                raise ValueError(
                    "All PanelPomp objects must have the same canonical_param_names."
                )
            if obj.canonical_shared_param_names != first.canonical_shared_param_names:
                raise ValueError(
                    "All PanelPomp objects must have the same canonical_shared_param_names."
                )
            if obj.canonical_unit_param_names != first.canonical_unit_param_names:
                raise ValueError(
                    "All PanelPomp objects must have the same canonical_unit_param_names."
                )
            if list(obj.unit_objects) != list(first.unit_objects):
                raise ValueError("All PanelPomp objects must have the same unit names.")
        merged_theta = PanelParameters.merge(*[obj.theta for obj in panel_pomp_objs])
        merged_history = ResultsHistory.merge(
            *[obj.results_history for obj in panel_pomp_objs]
        )
        merged_panel_pomp = deepcopy(first)
        merged_panel_pomp.theta = merged_theta
        merged_panel_pomp.results_history = merged_history
        # deepcopy goes through __getstate__/__setstate__, which rebuilds the key
        # with wrap_key_data(key_data(...)). Reassigning keeps the first object's
        # exact key value instead of that rebuilt copy.
        merged_panel_pomp.fresh_key = first.fresh_key
        return merged_panel_pomp

    def __getstate__(self):
        """
        Custom pickling method to handle wrapped function objects. This is
        necessary because the JAX-wrapped functions in the Pomp objects are not picklable.

        Returns a copy of the instance ``__dict__`` with two substitutions:

        - ``fresh_key`` is replaced by ``_fresh_key_data``, its raw key data.
          This entry is absent when the key is None.
        - ``unit_objects`` is replaced by ``_unit_objects_state``, a mapping of
          unit name to that Pomp's own ``__getstate__()``.
        """
        state = self.__dict__.copy()
        if self.fresh_key is not None:
            state["_fresh_key_data"] = jax.random.key_data(self.fresh_key)
        state.pop("fresh_key", None)
        if getattr(self, "unit_objects", None) is not None:
            state["_unit_objects_state"] = {
                unit_name: pomp_obj.__getstate__()
                for unit_name, pomp_obj in self.unit_objects.items()
            }
            state.pop("unit_objects", None)
        return state

    def __setstate__(self, state):
        """
        Custom unpickling method to reconstruct wrapped function objects. This is
        necessary because the JAX-wrapped functions in the Pomp objects are not picklable.

        Reverses ``__getstate__``:

        - Re-wraps ``_fresh_key_data`` into ``fresh_key``, or sets it to None.
        - Rebuilds each unit's Pomp from its saved state.
        - If no unit state is present, ``unit_objects`` becomes an empty dict.
        """
        self.__dict__.update(state)
        if "_fresh_key_data" in state:
            self.fresh_key = jax.random.wrap_key_data(state["_fresh_key_data"])
        elif "fresh_key" not in self.__dict__:
            self.fresh_key = None
        self.__dict__.pop("_fresh_key_data", None)
        if "_unit_objects_state" in state:
            unit_objects = {}
            for unit_name, pomp_state in state["_unit_objects_state"].items():
                # Skip Pomp.__init__; Pomp.__setstate__ restores the object.
                pomp_obj = Pomp.__new__(Pomp)
                pomp_obj.__setstate__(pomp_state)
                unit_objects[unit_name] = pomp_obj
            self.unit_objects = unit_objects
            del self.__dict__["_unit_objects_state"]
        else:
            self.unit_objects = {}