Source code for pypomp.core.parameters.panel

from __future__ import annotations
import pandas as pd
import jax.numpy as jnp
import numpy as np
import jax
import xarray as xr
from typing import (
    Union,
    Literal,
    Any,
    overload,
)

from .base import ParameterSet
from ..par_trans import ParTrans


def _standardize_panel_theta(
    theta: Union[
        dict[str, pd.DataFrame | None],
        list[dict[str, pd.DataFrame | None]],
        None,
    ],
) -> tuple[xr.Dataset, list[str], list[str]]:
    """Convert user-supplied panel parameters into a canonical ``xarray.Dataset``.

    Parameters
    ----------
    theta : dict | list[dict] | None
        A dictionary with exactly the keys ``"shared"`` and ``"unit_specific"``
        (each a ``pd.DataFrame`` or ``None``), an iterable of such
        dictionaries (one per replicate), or ``None``. Shared frames are
        indexed by parameter name and must have exactly one column.
        Unit-specific frames are indexed by parameter name with one column
        per unit.

    Returns
    -------
    tuple[xr.Dataset, list[str], list[str]]
        The dataset, the shared parameter names and the unit-specific
        parameter names. The dataset holds ``"shared"`` with dimensions
        ``("theta_idx", "parameter")`` and ``"unit_specific"`` with dimensions
        ``("theta_idx", "unit", "parameter")``. ``None`` or an empty iterable
        yields an empty dataset and two empty name lists.

    Raises
    ------
    TypeError
        If ``theta`` is not a dictionary or an iterable of dictionaries, or
        if a dictionary value is neither ``None`` nor a ``pd.DataFrame``.
    ValueError
        If a dictionary has the wrong keys, if ``None`` is used for only some
        replicates, if a shared frame does not have exactly one column, if a
        parameter name is both shared and unit-specific, or if the replicates
        disagree on index or columns.
    """
    # Normalise the input to a list of per-replicate dictionaries. The type
    # check happens *before* the list conversion so that it can actually fire.
    if theta is None:
        theta_list = []
    elif isinstance(theta, dict):
        theta_list = [theta]
    elif isinstance(theta, (str, bytes)):
        raise TypeError("theta must be a dictionary or a list of dictionaries")
    else:
        try:
            theta_list = list(theta)
        except TypeError as err:
            raise TypeError(
                "theta must be a dictionary or a list of dictionaries"
            ) from err

    if not theta_list:
        # Zero replicates: return an empty, correctly-dimensioned dataset
        # instead of failing on the reference lookup further down.
        shared_da = xr.DataArray(
            np.empty((0, 0)),
            dims=["theta_idx", "parameter"],
            coords={"theta_idx": [], "parameter": []},
        )
        unit_specific_da = xr.DataArray(
            np.empty((0, 0, 0)),
            dims=["theta_idx", "unit", "parameter"],
            coords={"theta_idx": [], "unit": [], "parameter": []},
        )
        ds = xr.Dataset(
            data_vars={
                "shared": shared_da,
                "unit_specific": unit_specific_da,
            }
        )
        return ds, [], []

    # Copy the structures, convert to floats, and validate keys to avoid side-effects
    clean_theta = []
    for i, t in enumerate(theta_list):
        try:
            keys = set(t.keys())
        except AttributeError as err:
            raise TypeError(
                f"theta must be a dictionary or a list of dictionaries. "
                f"Found {type(t).__name__} in item {i}."
            ) from err
        if keys != {"shared", "unit_specific"}:
            raise ValueError(
                f"Each parameter dictionary must have exactly the keys 'shared' and 'unit_specific'. "
                f"Found keys {keys} in item {i}."
            )
        if not all(isinstance(v, (pd.DataFrame, type(None))) for v in t.values()):
            # Report the per-value types; type(t.values()) is always dict_values.
            found_types = [type(v).__name__ for v in t.values()]
            raise TypeError(
                f"All values in each dictionary must be None or pd.DataFrames. "
                f"Found values of type {found_types} in item {i}."
            )
        # astype(float) returns a new frame, so the caller's data is untouched.
        clean_theta.append(
            {
                name: None if t[name] is None else t[name].astype(float)
                for name in ("shared", "unit_specific")
            }
        )

    # Consistency checks: None must be used for all replicates or for none.
    shared_none = [t["shared"] is None for t in clean_theta]
    unit_none = [t["unit_specific"] is None for t in clean_theta]
    if any(shared_none) and not all(shared_none):
        raise ValueError(
            "Some, but not all, shared parameters are None. This is not supported."
        )
    if any(unit_none) and not all(unit_none):
        raise ValueError(
            "Some, but not all, unit-specific parameters are None. This is not supported."
        )

    # The first replicate is the reference all others are compared against.
    ref = clean_theta[0]
    has_shared = ref["shared"] is not None
    has_unit = ref["unit_specific"] is not None

    if has_shared:
        ref_s_idx = ref["shared"].index
        if len(ref["shared"].columns) != 1:
            raise ValueError("Shared parameters must have exactly one column.")
    else:
        ref_s_idx = []

    if has_unit:
        ref_u_idx = ref["unit_specific"].index
        ref_u_cols = ref["unit_specific"].columns
    else:
        ref_u_idx = []
        ref_u_cols = []

    overlap = set(ref_s_idx).intersection(ref_u_idx)
    if overlap:
        raise ValueError(
            f"Parameter name(s) found in both shared and unit-specific parameters: {sorted(overlap)}"
        )

    for i, t in enumerate(clean_theta[1:], 1):
        if t["shared"] is not None:
            if not t["shared"].index.equals(ref_s_idx):
                raise ValueError(f"Shared parameter index mismatch at replicate {i}.")
            # Only column 0 is read below, so extra columns would be dropped silently.
            if len(t["shared"].columns) != 1:
                raise ValueError("Shared parameters must have exactly one column.")
        if t["unit_specific"] is not None:
            if not t["unit_specific"].index.equals(ref_u_idx):
                raise ValueError(f"Unit parameter index mismatch at replicate {i}.")
            if not t["unit_specific"].columns.equals(ref_u_cols):
                raise ValueError(f"Unit columns mismatch at replicate {i}.")

    # Gather names
    shared_names_list = list(ref_s_idx)
    unit_specific_names_list = list(ref_u_idx)
    unit_names = list(ref_u_cols)

    reps = len(clean_theta)

    if has_shared:
        shared_values = np.stack(
            [t["shared"].loc[shared_names_list].iloc[:, 0].values for t in clean_theta]
        )
    else:
        shared_values = np.zeros((reps, 0))

    if has_unit:
        # Frames are (parameter x unit); transpose to (unit x parameter).
        unit_values = np.stack(
            [
                t["unit_specific"].loc[unit_specific_names_list, unit_names].T.values
                for t in clean_theta
            ]
        )
    else:
        unit_values = np.zeros((reps, len(unit_names), 0))

    shared_da = xr.DataArray(
        shared_values,
        dims=["theta_idx", "parameter"],
        coords={
            "theta_idx": np.arange(reps),
            "parameter": shared_names_list,
        },
    )
    unit_specific_da = xr.DataArray(
        unit_values,
        dims=["theta_idx", "unit", "parameter"],
        coords={
            "theta_idx": np.arange(reps),
            "unit": unit_names,
            "parameter": unit_specific_names_list,
        },
    )

    ds = xr.Dataset(
        data_vars={
            "shared": shared_da,
            "unit_specific": unit_specific_da,
        }
    )
    ds.attrs["shared_names"] = shared_names_list
    ds.attrs["unit_specific_names"] = unit_specific_names_list
    return ds, shared_names_list, unit_specific_names_list


class PanelParameters(ParameterSet[xr.Dataset]):
    """
    Manages parameters for PanelPomp models.

    Internal storage is an ``xarray.Dataset`` with two variables:
    ``"shared"`` with dimensions ``("theta_idx", "parameter")`` and
    ``"unit_specific"`` with dimensions ``("theta_idx", "unit", "parameter")``.

    Parameters
    ----------
    theta : PanelParameters | dict | list | xr.Dataset, optional
        Parameters for the panel model. Accepts:

        - A single dictionary with ``"shared"`` and ``"unit_specific"`` keys
          (each containing a DataFrame).
        - A list of such dictionaries.
        - An existing :class:`~pypomp.core.parameters.PanelParameters` object
          (its data, names and ``estimation_scale`` are copied; the
          ``estimation_scale`` argument is ignored in that case).
        - An existing ``xarray.Dataset`` with ``"shared"`` and
          ``"unit_specific"`` variables.
    logLik_unit : np.ndarray, optional
        A numpy array of unit-specific log-likelihoods of shape
        ``(num_theta_idx, n_units)``.
    estimation_scale : bool, optional
        Whether the parameters are on the estimation scale. Defaults to False.
    """

    _data: xr.Dataset
    estimation_scale: bool
    _logLik_unit: np.ndarray
    _logLik: np.ndarray
    _canonical_shared_param_names: list[str]
    _canonical_unit_param_names: list[str]
    _canonical_param_names: list[str]

    def __init__(
        self,
        theta: Union[
            dict[str, pd.DataFrame | None],
            list[dict[str, pd.DataFrame | None]],
            "PanelParameters",
            xr.Dataset,
            None,
        ],
        logLik_unit: np.ndarray | None = None,
        estimation_scale: bool = False,
    ):
        if isinstance(theta, PanelParameters):
            # Copy constructor: duplicate everything so the two objects are independent.
            self._data = theta._data.copy(deep=True)
            self.estimation_scale = theta.estimation_scale
            self._canonical_shared_param_names = list(
                theta._canonical_shared_param_names
            )
            self._canonical_unit_param_names = list(theta._canonical_unit_param_names)
            self._canonical_param_names = list(theta._canonical_param_names)
            self._logLik_unit = (
                theta.logLik_unit.copy()
                if logLik_unit is None
                else self._format_logLik_unit(
                    logLik_unit, self._data.sizes["theta_idx"]
                )
            )
            self._logLik = self._logLik_unit.sum(axis=1)
            return

        if isinstance(theta, xr.Dataset):
            self._data = theta.copy(deep=True)
            s_names, u_names = self._names_from_dataset(self._data)
        else:
            self._data, raw_s, raw_u = _standardize_panel_theta(theta)
            s_names = [str(x) for x in raw_s]
            u_names = [str(x) for x in raw_u]
        self._canonical_shared_param_names = s_names
        self._canonical_unit_param_names = u_names

        self.estimation_scale = estimation_scale
        self._logLik_unit = self._format_logLik_unit(
            logLik_unit, self._data.sizes["theta_idx"]
        )
        self._logLik = self._logLik_unit.sum(axis=1)
        self._canonical_param_names = self._ordered_union(s_names, u_names)

    @staticmethod
    def _names_from_dataset(ds: xr.Dataset) -> tuple[list[str], list[str]]:
        """Read shared and unit-specific parameter names from a Dataset.

        The ``"shared_names"`` / ``"unit_specific_names"`` attrs are used when
        present; otherwise the ``"parameter"`` coordinate of the matching
        variable is used, or an empty list if the variable is absent.
        """
        names = []
        for attr_key, var_key in (
            ("shared_names", "shared"),
            ("unit_specific_names", "unit_specific"),
        ):
            raw = ds.attrs.get(attr_key)
            if raw is None:
                raw = (
                    list(ds[var_key].coords["parameter"].values)
                    if var_key in ds
                    else []
                )
            names.append([str(x) for x in raw])
        return names[0], names[1]

    @staticmethod
    def _ordered_union(shared_names: list[str], unit_names: list[str]) -> list[str]:
        """Return the de-duplicated names, shared first, in a reproducible order.

        ``list(set(...))`` would give an order that changes between
        interpreter runs because string hashing is randomised.
        """
        return list(dict.fromkeys([*shared_names, *unit_names]))

    @staticmethod
    def _rank_descending(values: np.ndarray) -> np.ndarray:
        """Return indices that sort ``values`` from largest to smallest, NaN last.

        ``argsort`` places NaN at the end, so a plain reversal would rank NaN
        entries first; they are moved to the back here.
        """
        order = values.argsort()[::-1]
        is_nan = np.isnan(values[order])
        return np.concatenate([order[~is_nan], order[is_nan]])

    def _format_logLik_unit(
        self, ll_unit: np.ndarray | None, n_reps: int
    ) -> np.ndarray:
        """Standardize logLik dimensions.

        Returns an ``(n_reps, n_units)`` float array. ``None`` becomes an
        all-NaN array. A 1-D input with a single replicate is reshaped to one
        row without a length check.

        Raises
        ------
        ValueError
            If the shape does not match ``(n_reps, n_units)`` and the input is
            not an empty array for a panel with no known units.
        """
        n_units = 0
        if n_reps > 0:
            n_units = len(self.get_unit_names())

        if ll_unit is None:
            return np.full((n_reps, n_units), np.nan)

        ll_unit = np.array(ll_unit, dtype=float)
        if ll_unit.ndim == 1 and n_reps == 1:
            return ll_unit.reshape(1, -1)

        if ll_unit.shape != (n_reps, n_units):
            if n_units == 0 and ll_unit.size == 0:
                return np.empty((n_reps, 0))
            raise ValueError(
                f"logLik_unit shape mismatch: {ll_unit.shape} vs ({n_reps}, {n_units})"
            )
        return ll_unit

    @property
    def logLik(self) -> np.ndarray:
        """
        Get the overall log-likelihood for each parameter set (theta_idx).
        """
        return self._logLik

    @logLik.setter
    def logLik(self, value):
        raise AttributeError(
            "Cannot set logLik directly on PanelParameters. "
            "Please assign unit-specific log-likelihoods to logLik_unit instead."
        )

    @property
    def logLik_unit(self) -> np.ndarray:
        """
        Get or set the unit-specific log-likelihoods for each parameter set (theta_idx).
        """
        return self._logLik_unit

    @logLik_unit.setter
    def logLik_unit(self, value):
        self._logLik_unit = self._format_logLik_unit(value, self.num_replicates())
        # The total is always derived from the per-unit values.
        self._logLik = self._logLik_unit.sum(axis=1)

    def get_shared_param_names(self) -> list[str]:
        """Return the list of shared parameter names."""
        return self._canonical_shared_param_names

    def get_unit_param_names(self) -> list[str]:
        """Return the list of unit-specific parameter names."""
        return self._canonical_unit_param_names

    def get_unit_names(self) -> list[str]:
        """Return the list of unit names."""
        if (
            "unit_specific" in self._data
            and "unit" in self._data["unit_specific"].coords
        ):
            return list(self._data["unit_specific"].coords["unit"].values)
        return []

    def subset(self, indices: Union[int, list[int], slice]) -> "PanelParameters":
        """
        Return a new PanelParameters object with the specified parameter set
        (theta_idx) indices.

        A single integer (Python or NumPy) is treated as a one-element list so
        the ``theta_idx`` dimension is kept. The result is re-indexed from 0.
        """
        if isinstance(indices, (int, np.integer)):
            indices = [int(indices)]

        sub_data = self._data.isel(theta_idx=indices)
        sub_data.coords["theta_idx"] = np.arange(sub_data.sizes["theta_idx"])
        sub_ll = self._logLik_unit[indices]

        return PanelParameters(
            sub_data, logLik_unit=sub_ll, estimation_scale=self.estimation_scale
        )

    def to_jax_array(
        self,
        param_names: list[str] | None = None,
        unit_names: list[str] | None = None,
        **kwargs,
    ) -> jax.Array:
        """
        Convert to a JAX array matching the order of param_names and unit_names.

        Parameters
        ----------
        param_names : list[str], optional
            A list of parameter names in the desired order. If None (default),
            the order returned by ``get_param_names()`` is used.
        unit_names : list[str], optional
            A list of unit names in the desired order. If None (default),
            returns array for all units.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        jax.Array
            A JAX array of shape (num_theta_idx, n_units, n_params). Shared
            parameters are broadcast across units. With zero replicates an
            empty ``(0, 0, 0)`` array is returned.

        Raises
        ------
        ValueError
            If ``unit_names`` is None and no units are stored.
        KeyError
            If a parameter name is unknown, or a unit-specific parameter is
            requested for a unit that is not stored.
        """
        if param_names is None:
            param_names = self.get_param_names()

        reps = self.num_replicates()
        if reps == 0:
            return jnp.empty((0, 0, 0))

        existing_units = self.get_unit_names()
        if unit_names is None:
            if not existing_units:
                raise ValueError(
                    "unit_names required when no unit_specific parameters exist"
                )
            unit_names = existing_units

        n_units = len(unit_names)
        n_params = len(param_names)

        shared_keys = set(self._canonical_shared_param_names)
        specific_keys = set(self._canonical_unit_param_names)

        # Pre-validate keys and units. The unit check does not depend on the
        # parameter, so it is evaluated once and only reported when a
        # unit-specific parameter is actually requested.
        has_unknown_unit = any(u not in existing_units for u in unit_names)
        for p_name in param_names:
            if p_name not in shared_keys and p_name not in specific_keys:
                raise KeyError(f"Parameter '{p_name}' not found.")
            if p_name in specific_keys and has_unknown_unit:
                raise KeyError(f"Unit mismatch for parameter {p_name}")

        out_array = np.zeros((reps, n_units, n_params))

        for p_idx, p_name in enumerate(param_names):
            if p_name in specific_keys:
                out_array[:, :, p_idx] = (
                    self._data["unit_specific"]
                    .sel(parameter=p_name, unit=unit_names)
                    .values
                )
            else:  # p_name in shared_keys
                shared_vals = self._data["shared"].sel(parameter=p_name).values
                out_array[:, :, p_idx] = np.broadcast_to(
                    shared_vals[:, None], (reps, n_units)
                )

        return jnp.array(out_array)

    def mix_and_match(self) -> None:
        """
        Mixes unit-specific and shared parameters independently by sorting each
        unit's unit-specific parameters and the shared parameters in descending
        order of their respective log-likelihood contribution.

        Shared parameters are ranked by the total log-likelihood and each
        unit's parameters by that unit's log-likelihood. NaN log-likelihoods
        are ranked last. The object is modified in place; nothing happens when
        there are no replicates.
        """
        unit_names = self.get_unit_names()
        reps = self.num_replicates()
        if reps == 0:
            return

        shared_keys = self._canonical_shared_param_names
        specific_keys = self._canonical_unit_param_names

        # Reorder shared parameters
        shared_ranks = self._rank_descending(self._logLik)
        new_shared_da = (
            self._data["shared"]
            .sel(parameter=shared_keys)
            .isel(theta_idx=shared_ranks)
            .assign_coords(theta_idx=np.arange(reps))
        )

        # Reorder unit-specific parameters and unit log-likelihoods
        new_unit_values = np.zeros((reps, len(unit_names), len(specific_keys)))
        new_ll_unit = np.zeros_like(self._logLik_unit)

        for u_idx, unit in enumerate(unit_names):
            best_idx = self._rank_descending(self._logLik_unit[:, u_idx])
            new_unit_values[:, u_idx, :] = (
                self._data["unit_specific"]
                .sel(unit=unit, parameter=specific_keys)
                .isel(theta_idx=best_idx)
                .values
            )
            new_ll_unit[:, u_idx] = self._logLik_unit[best_idx, u_idx]

        unit_specific_da = xr.DataArray(
            new_unit_values,
            dims=["theta_idx", "unit", "parameter"],
            coords={
                "theta_idx": np.arange(reps),
                "unit": unit_names,
                "parameter": specific_keys,
            },
        )

        # Carry the name attrs so the rebuilt dataset looks like the ones
        # produced by _standardize_panel_theta and merge.
        self._data = xr.Dataset(
            data_vars={
                "shared": new_shared_da,
                "unit_specific": unit_specific_da,
            },
            attrs={
                "shared_names": list(shared_keys),
                "unit_specific_names": list(specific_keys),
            },
        )
        self._logLik_unit = new_ll_unit
        self._logLik = new_ll_unit.sum(axis=1)

    @overload
    def params(
        self, as_list: Literal[True] = True
    ) -> list[dict[str, pd.DataFrame | None]]: ...

    @overload
    def params(self, as_list: Literal[False]) -> xr.Dataset: ...

    @overload
    def params(
        self, as_list: bool = True
    ) -> list[dict[str, pd.DataFrame | None]] | xr.Dataset: ...

    def params(
        self, as_list: bool = True
    ) -> list[dict[str, pd.DataFrame | None]] | xr.Dataset:
        """
        Get the parameters in this set.

        Parameters
        ----------
        as_list : bool, default True
            If True, returns the parameters as a list of dictionaries with keys
            "shared" and "unit_specific". If False, returns the internal xarray
            Dataset.

        Returns
        -------
        list[dict[str, pd.DataFrame | None]] | xr.Dataset
            The parameters either as a list of dictionaries or as a Dataset.
        """
        return super().params(as_list)

    def set_params(
        self,
        value: dict[str, pd.DataFrame | None]
        | list[dict[str, pd.DataFrame | None]]
        | xr.Dataset,
    ) -> None:
        """
        Set or overwrite the parameter values.

        The stored log-likelihoods are reset to NaN because they no longer
        correspond to the new values.

        Parameters
        ----------
        value : dict[str, pd.DataFrame | None] | list[dict[str, pd.DataFrame | None]] | xr.Dataset
            The new panel parameter values. Accepts:

            - A single dictionary with ``"shared"`` and ``"unit_specific"`` keys
              (each containing a DataFrame).
            - A list of such dictionaries.
            - An existing :class:`xarray.Dataset` of panel parameters.

        Raises
        ------
        ValueError
            If ``value`` is None.
        """
        if value is None:
            raise ValueError("theta cannot be None")

        if isinstance(value, xr.Dataset):
            self._data = value.copy(deep=True)
            s_names, u_names = self._names_from_dataset(self._data)
        else:
            self._data, raw_s, raw_u = _standardize_panel_theta(value)
            s_names = [str(x) for x in raw_s]
            u_names = [str(x) for x in raw_u]

        self._canonical_shared_param_names = s_names
        self._canonical_unit_param_names = u_names
        self._canonical_param_names = self._ordered_union(s_names, u_names)
        self._logLik_unit = self._format_logLik_unit(None, self.num_replicates())
        self._logLik = self._logLik_unit.sum(axis=1)

    def _to_list(self) -> list[dict[str, pd.DataFrame | None]]:
        """Return the parameter sets as a list of dictionaries with 'shared' and 'unit_specific' DataFrames.

        An entry is None when there are no parameters of that kind (or, for
        'unit_specific', no units).
        """
        reps = self.num_replicates()
        if reps == 0:
            return []

        shared_names = self._canonical_shared_param_names
        unit_specific_names = self._canonical_unit_param_names
        unit_names = self.get_unit_names()

        out = []
        for j in range(reps):
            t_dict = {}
            if shared_names:
                t_dict["shared"] = pd.DataFrame(
                    self._data["shared"]
                    .isel(theta_idx=j)
                    .sel(parameter=shared_names)
                    .values,
                    index=pd.Index(shared_names),
                    columns=["shared"],
                )
            else:
                t_dict["shared"] = None

            if unit_specific_names and unit_names:
                # Stored as (unit, parameter); frames are (parameter, unit).
                t_dict["unit_specific"] = pd.DataFrame(
                    self._data["unit_specific"]
                    .isel(theta_idx=j)
                    .sel(parameter=unit_specific_names, unit=unit_names)
                    .values.T,
                    index=pd.Index(unit_specific_names),
                    columns=pd.Index(unit_names),
                )
            else:
                t_dict["unit_specific"] = None

            out.append(t_dict)
        return out

    def _replicated_logLik(self, n: int) -> dict[str, np.ndarray]:
        """Return the constructor kwargs holding the unit log-likelihoods tiled ``n`` times."""
        if self._logLik_unit.size > 0:
            new_ll_unit = np.tile(self._logLik_unit, (n, 1))
        else:
            new_ll_unit = np.empty((n * self.num_replicates(), 0))
        return {"logLik_unit": new_ll_unit}

    def _slice_logLik(self, indices: np.ndarray) -> None:
        """Keep only the log-likelihood rows selected by ``indices`` (in place)."""
        self._logLik_unit = self._logLik_unit[indices]
        self._logLik = self._logLik_unit.sum(axis=1)

    def _eq_logLik(self, other: "PanelParameters") -> bool:
        """Return True when names and unit log-likelihoods match (NaN equals NaN)."""
        if self._canonical_shared_param_names != other._canonical_shared_param_names:
            return False
        if self._canonical_unit_param_names != other._canonical_unit_param_names:
            return False
        return np.array_equal(self._logLik_unit, other._logLik_unit, equal_nan=True)

    def _getitem_int(self, index: int) -> dict[str, pd.DataFrame | None]:
        """Return the parameter dictionary of a single replicate."""
        return self._to_list()[index]

    def _transform_and_load(
        self,
        par_trans: ParTrans,
        param_list: list[Any],
        direction: Literal["to_est", "from_est"],
    ) -> None:
        """Transform ``param_list`` with ``par_trans`` and store the result as the new data."""
        transformed_list = par_trans._panel_transform_list(
            param_list, direction=direction
        )
        ds, s_names, u_names = _standardize_panel_theta(transformed_list)
        self._data = ds
        self._canonical_shared_param_names = [str(x) for x in s_names]
        self._canonical_unit_param_names = [str(x) for x in u_names]

    @staticmethod
    def merge(*param_objs: "PanelParameters") -> "PanelParameters":
        """Merge replications from multiple PanelParameters objects.

        All objects must agree on shared names, unit-specific names, unit
        names and estimation scale. Replicates are concatenated in argument
        order and re-indexed from 0.

        Raises
        ------
        ValueError
            If no object is given or the objects are not compatible.
        TypeError
            If an argument is not a PanelParameters.
        """
        if len(param_objs) == 0:
            raise ValueError("At least one PanelParameters object must be provided.")

        first = param_objs[0]
        for obj in param_objs:
            if not isinstance(obj, PanelParameters):
                raise TypeError("All merged objects must be of type PanelParameters.")
            if obj._canonical_shared_param_names != first._canonical_shared_param_names:
                raise ValueError(
                    "All PanelParameters objects must have the same canonical shared parameter names."
                )
            if obj._canonical_unit_param_names != first._canonical_unit_param_names:
                raise ValueError(
                    "All PanelParameters objects must have the same canonical unit parameter names."
                )
            if obj.estimation_scale != first.estimation_scale:
                raise ValueError(
                    "All PanelParameters objects must have the same estimation scale."
                )
            if obj.get_unit_names() != first.get_unit_names():
                raise ValueError(
                    "All PanelParameters objects must have the same unit names."
                )

        merged_data = xr.concat([obj._data for obj in param_objs], dim="theta_idx")
        merged_data.coords["theta_idx"] = np.arange(merged_data.sizes["theta_idx"])
        merged_data.attrs["shared_names"] = first._canonical_shared_param_names
        merged_data.attrs["unit_specific_names"] = first._canonical_unit_param_names

        # param_objs is non-empty here, so there is always something to concatenate.
        merged_logLik_unit = np.concatenate(
            [obj._logLik_unit for obj in param_objs], axis=0
        )

        return PanelParameters(
            merged_data,
            logLik_unit=merged_logLik_unit,
            estimation_scale=first.estimation_scale,
        )