Source code for pypomp.core.parameters

"""
This module defines the parameter classes for Pomp and PanelPomp models.
It handles input validation, standardization, and conversion to JAX arrays.
"""

from __future__ import annotations
from abc import ABC, abstractmethod
import copy
import pandas as pd
import jax.numpy as jnp
import numpy as np
import jax
from typing import Union, Literal, Mapping, cast, overload
from .par_trans import ParTrans
from pypomp.types import Numeric, ThetaInput, ParamDict


class ParameterSet(ABC):
    """
    Abstract base class for parameter sets used in POMP models.
    """

    @abstractmethod
    def to_jax_array(self, param_names: list[str], **kwargs) -> jax.Array:
        """
        Converts the parameters to a JAX array suitable for model functions.

        Args:
            param_names: A list of canonical parameter names expected by the model.
            **kwargs: Additional context required for conversion (e.g. unit names).

        Returns:
            A JAX array representing the parameters.
            - For Pomp: Shape (reps, n_params)
            - For PanelPomp: Shape (reps, n_units, n_params)
        """
        pass

    @abstractmethod
    def num_replicates(self) -> int:
        """Returns the number of parameter replicates (J)."""
        pass

    @abstractmethod
    def subset(self, indices: Union[int, list[int], slice]) -> "ParameterSet":
        """
        Returns a new ParameterSet containing only the specified replicate indices.
        """
        pass

    @abstractmethod
    def get_param_names(self) -> list[str] | tuple[list[str], list[str]]:
        """Returns the list of parameter names contained in this set."""
        pass


[docs] class PompParameters(ParameterSet): """ Manages parameters for a standard Pomp model. Internal storage is a list of dictionaries. Parameters ---------- theta : ThetaInput Parameters for the model. Accepts: - A single dictionary: dict[str, Numeric] - A list of dictionaries: list[dict[str, Numeric]] - An existing PompParameters object logLik : np.ndarray, optional A numpy array of log-likelihoods. estimation_scale : bool, optional Whether the parameters are in the estimation scale. Defaults to False. """ _params: list[dict[str, float]] _canonical_param_names: list[str] estimation_scale: bool _logLik: np.ndarray def __init__( self, theta: ThetaInput, logLik: np.ndarray | None = None, estimation_scale: bool = False, ): if theta is None: self._params = [] self._logLik = np.full(0, np.nan) self._canonical_param_names = [] self.estimation_scale = False return if isinstance(theta, PompParameters): # Copy constructor behavior (shallow copy of list) self._params = list(theta._params) self._logLik = ( theta.logLik.copy() if logLik is None else self._format_logLik(logLik, len(self._params)) ) self._canonical_param_names = theta._canonical_param_names self.estimation_scale = theta.estimation_scale return theta_dicts: list[dict[str, Numeric]] = [] # Normalize input to list of dicts if isinstance(theta, Mapping): theta_dicts = [dict(theta)] elif isinstance(theta, (list, tuple)): theta_dicts = [dict(t) if not isinstance(t, dict) else t for t in theta] else: try: theta_dicts = [dict(t) for t in theta] except (TypeError, ValueError): raise TypeError( "theta must be a Mapping, Sequence of Mappings, or PompParameters" ) self._validate_raw(theta_dicts) self._params = cast(list[dict[str, float]], theta_dicts) self._canonical_param_names = list(self._params[0].keys()) self.estimation_scale = estimation_scale self._logLik = self._format_logLik(logLik, len(self._params)) def _format_logLik(self, ll: np.ndarray | None, n_reps: int) -> np.ndarray: """Helper to standardize logLik input.""" if ll is None: return np.full(n_reps, np.nan) ll = np.array(ll, dtype=float) if ll.ndim == 0: # Handle single scalar input (broadcast) return np.full(n_reps, ll) if len(ll) != n_reps: raise ValueError( f"Length of logLik ({len(ll)}) must match parameters ({n_reps})" ) return ll def _validate_raw(self, theta: list[dict[str, Numeric]]): if not isinstance(theta, list): raise TypeError("theta must be a list of dictionaries") if len(theta) == 0: raise ValueError("theta cannot be empty") if not all(isinstance(t, dict) for t in theta): raise TypeError("All elements in theta must be dictionaries") for i, t in enumerate(theta): for key, value in t.items(): if isinstance(value, (int, np.number, jax.Array)) and not isinstance( value, bool ): try: t[key] = float(value) except (TypeError, ValueError): pass if not isinstance(t[key], float): raise TypeError( f"Parameter '{key}' at index {i} is not a float: got {type(t[key]).__name__}" ) # Ensure all dicts have identical keys first_keys = set(theta[0].keys()) for i, t in enumerate(theta[1:]): if set(t.keys()) != first_keys: raise ValueError( f"Parameter set at index {i + 1} has different keys than the first set. " f"Expected {first_keys}, got {set(t.keys())}" ) def _child_PompParameters( self, theta: ThetaInput = None, logLik: np.ndarray | None = None, estimation_scale: bool | None = None, ): """ Make a new PompParameters object with current attributes as the default. """ # Explicitly handle None to avoid ambiguous truth-value checks on arrays theta_sel = self._params if theta is None else theta estimation_scale_sel = ( self.estimation_scale if estimation_scale is None else estimation_scale ) logLik_sel = self._logLik if logLik is None else logLik return PompParameters( theta=theta_sel, logLik=logLik_sel, estimation_scale=estimation_scale_sel )
[docs] def to_jax_array(self, param_names: list[str], **kwargs) -> jax.Array: """ Convert to JAX array matching the order of param_names. Returns shape (n_reps, n_params). """ # Logic formerly in _theta_dict_to_array try: ordered_values = [[t[name] for name in param_names] for t in self._params] except KeyError as e: raise KeyError( f"Parameter {e} expected by model but missing from parameter set." ) return jnp.array(ordered_values)
@property def logLik(self) -> np.ndarray: return self._logLik @logLik.setter def logLik(self, value): self._logLik = self._format_logLik(value, self.num_replicates())
[docs] def to_jax_array_canonical(self) -> jax.Array: return self.to_jax_array(self._canonical_param_names)
[docs] def num_replicates(self) -> int: return len(self._params)
[docs] def num_params(self) -> int: return len(self._canonical_param_names)
[docs] def subset(self, indices: Union[int, list[int], slice]) -> "PompParameters": if isinstance(indices, int): indices = [indices] # Determine subset based on type if isinstance(indices, slice): subset_params = self._params[indices] subset_logLik = self._logLik[indices] else: subset_params = [self._params[i] for i in indices] subset_logLik = self._logLik[indices] return self._child_PompParameters(subset_params, logLik=subset_logLik)
[docs] def get_param_names(self) -> list[str]: if not self._params: return [] return list(self._canonical_param_names)
[docs] def to_list(self) -> list[ParamDict]: """Returns the internal list of dictionaries.""" return cast(list[ParamDict], self._params)
[docs] def transform( self, par_trans: ParTrans, direction: Literal["to_est", "from_est"] | None = None, ): """ Transform the parameters to or from the estimation parameter space. Args: par_trans: The parameter transformation object. direction: The direction of transformation. If None, the direction is determined by the estimation_scale attribute. """ auto = direction is None if auto: direction = "from_est" if self.estimation_scale else "to_est" if direction not in ["to_est", "from_est"]: raise ValueError(f"Invalid direction: {direction}") if direction == "to_est" and not self.estimation_scale: self._params = [ par_trans.to_floats(theta_i, "to_est") for theta_i in self._params ] self.estimation_scale = True elif direction == "from_est" and self.estimation_scale: self._params = [ par_trans.to_floats(theta_i, "from_est") for theta_i in self._params ] self.estimation_scale = False else: # If this statement is reached, the parameters are already in the correct estimation space. Nothing needs to be done. pass
[docs] def prune(self, n: int = 1, refill: bool = True) -> None: """ Replace internal parameter sets with the top `n` based on stored log-likelihoods. Args: n: Number of top parameter sets to keep. refill: If True, repeat the top `n` parameter sets to match the previous number of replicates. If False, keep only the `n` sets. """ n_reps = self.num_replicates() if n_reps == 0: raise ValueError("No parameter sets available to prune.") if n < 1: raise ValueError("n must be at least 1.") if self._logLik is None or np.all(np.isnan(self._logLik)): raise ValueError("No valid log-likelihoods available to prune (all nan).") # Indices of top-n log-likelihoods (descending order) top_indices = self._logLik.argsort()[-n:][::-1] top_params = [self._params[i] for i in top_indices] top_logLik = self._logLik[top_indices] if refill: prev_len = n_reps repeats = (prev_len + n - 1) // n # Ceiling division new_params = (top_params * repeats)[:prev_len] new_logLik = np.tile(top_logLik, repeats)[:prev_len] else: new_params = top_params new_logLik = top_logLik self._params = new_params self._logLik = new_logLik
@overload def __getitem__(self, index: int) -> dict[str, float]: ... @overload def __getitem__(self, index: slice) -> "PompParameters": ... def __getitem__(self, index: int | slice) -> dict[str, float] | "PompParameters": """ Support indexing like theta[0] or theta[0:2] or theta[0]["param_name"]. - Integer index: returns the dict at that position - Slice: returns a new PompParameters object """ if isinstance(index, int): # Integer index: return the dict directly return self._params[index] elif isinstance(index, slice): # Slice: return a new PompParameters object return self._child_PompParameters(self._params[index]) else: raise TypeError(f"Invalid index: {index}. Must be an integer or slice.") def __iter__(self): """Support iteration over parameter sets.""" return iter(self._params) def __len__(self) -> int: """Return the number of parameter replicates.""" return len(self._params) def __mul__(self, n: int) -> "PompParameters": """ Support replication like theta * 3. Returns a new PompParameters with n copies of the parameter sets. """ if not isinstance(n, int): return NotImplemented if n < 0: raise ValueError("Multiplication factor must be non-negative") if n == 0: raise ValueError("Cannot create empty PompParameters") # Replicate the parameter sets n times replicated_params = self._params * n replicated_logLik = np.tile(self._logLik, n) return self._child_PompParameters(replicated_params, logLik=replicated_logLik) def __rmul__(self, n: int) -> "PompParameters": """Support left multiplication like 3 * theta.""" return self.__mul__(n) def __repr__(self) -> str: """String representation for debugging.""" return f"PompParameters(n_replicates={len(self._params)}, n_params={len(self._canonical_param_names)})" def __eq__(self, other) -> bool: """ Check equality with another PompParameters object. Two PompParameters are equal if their canonical parameter names and parameter sets are equal. """ if not isinstance(other, type(self)): return False # Compare canonical parameter names if self._canonical_param_names != other._canonical_param_names: return False # Compare parameter lists if len(self._params) != len(other._params): return False for p1, p2 in zip(self._params, other._params): if p1 != p2: return False # Check same scale if self.estimation_scale != other.estimation_scale: return False return True
[docs] @staticmethod def merge(*param_objs: "PompParameters") -> "PompParameters": """ Merge replications from an arbitrary number of PompParameters objects into a single PompParameters object. All objects must have the same canonical parameter names and estimation scale. Usage: `merged = PompParameters.merge(p1, p2, p3, ...)` """ if len(param_objs) == 0: raise ValueError("At least one PompParameters object must be provided.") first = param_objs[0] for obj in param_objs: if not isinstance(obj, type(first)): raise TypeError("All merged objects must be of type PompParameters.") if obj._canonical_param_names != first._canonical_param_names: raise ValueError( "All PompParameters objects must have the same canonical parameter names." ) if obj.estimation_scale != first.estimation_scale: raise ValueError( "All PompParameters objects must have the same estimation scale." ) all_params = [] all_logLik = [] for obj in param_objs: all_params.extend(obj._params) all_logLik.append(obj._logLik) merged_logLik = np.concatenate(all_logLik) if all_logLik else np.array([]) return PompParameters(all_params, logLik=merged_logLik)
[docs] class PanelParameters(ParameterSet): """ Manages parameters for PanelPomp models. Internal storage is a list of dictionaries, always containing "shared" and "unit_specific" keys mapping to DataFrames (which may be empty). Parameters ---------- theta : PanelParameters | dict | list, optional Parameters for the panel model. Accepts: - A single dictionary with "shared" and "unit_specific" keys. - A list of such dictionaries. - An existing PanelParameters object. logLik_unit : np.ndarray, optional A numpy array of unit-specific log-likelihoods. estimation_scale : bool, optional Whether the parameters are in the estimation scale. Defaults to False. """ _theta: list[dict[str, pd.DataFrame | None]] estimation_scale: bool _logLik_unit: np.ndarray _logLik: np.ndarray _canonical_shared_param_names: list[str] _canonical_unit_param_names: list[str] _canonical_param_names: list[str] def __init__( self, theta: Union[ dict[str, pd.DataFrame | None], list[dict[str, pd.DataFrame | None]], "PanelParameters", None, ], logLik_unit: np.ndarray | None = None, estimation_scale: bool = False, ): if isinstance(theta, PanelParameters): self._theta = [ {k: v.copy() if v is not None else None for k, v in t.items()} for t in theta._theta ] self.estimation_scale = theta.estimation_scale self._validate_none_consistency() self._validate_df_consistency() self._logLik_unit = ( theta.logLik_unit.copy() if logLik_unit is None else self._format_logLik_unit(logLik_unit, len(self._theta)) ) else: self._theta = self._normalize_input(theta) self.estimation_scale = estimation_scale self._validate_none_consistency() self._validate_df_consistency() self._logLik_unit = self._format_logLik_unit(logLik_unit, len(self._theta)) self._logLik = self._logLik_unit.sum(axis=1) shared_df = self._theta[0]["shared"] unit_df = self._theta[0]["unit_specific"] if shared_df is not None: self._canonical_shared_param_names = list(shared_df.index) else: self._canonical_shared_param_names = [] if unit_df is not None: self._canonical_unit_param_names = list(unit_df.index) else: self._canonical_unit_param_names = [] self._canonical_param_names = list( set(self._canonical_shared_param_names + self._canonical_unit_param_names) ) def _normalize_input( self, theta: None | dict[str, pd.DataFrame | None] | list[dict[str, pd.DataFrame | None]], ) -> list[dict[str, pd.DataFrame | None]]: """ Normalize input to list of dicts with valid DataFrames or None. Checks that all dictionaries have the keys "shared" and "unit_specific" and that all values are None or pd.DataFrames. """ if theta is None: return [] if isinstance(theta, dict): theta = [theta] if not isinstance(theta, list): raise TypeError("theta must be a dictionary or a list of dictionaries") for i, t in enumerate(theta): keys = set(t.keys()) if keys != {"shared", "unit_specific"}: raise ValueError( f"Each parameter dictionary must have exactly the keys 'shared' and 'unit_specific'. " f"Found keys {keys} in item {i}." ) if not all(isinstance(v, (pd.DataFrame, type(None))) for v in t.values()): raise TypeError( f"All values in each dictionary must be None or pd.DataFrames. " f"Found values {t.values()} of type {type(t.values())} in item {i}." ) if t["shared"] is not None: t["shared"] = t["shared"].astype(float) if t["unit_specific"] is not None: t["unit_specific"] = t["unit_specific"].astype(float) return theta.copy() def _validate_none_consistency(self): """ Sets internal flags for whether all or only some 'shared'/'unit_specific' are None. """ shared_none = [t["shared"] is None for t in self._theta] unit_none = [t["unit_specific"] is None for t in self._theta] some_shared_none = any(shared_none) and not all(shared_none) some_unit_specific_none = any(unit_none) and not all(unit_none) if some_shared_none: raise ValueError( "Some, but not all, shared parameters are None. This is not supported." ) if some_unit_specific_none: raise ValueError( "Some, but not all, unit-specific parameters are None. This is not supported." ) def _format_logLik_unit( self, ll_unit: np.ndarray | None, n_reps: int ) -> np.ndarray: """Standardize logLik dimensions.""" # Determine n_units from the first valid unit_specific dataframe n_units = 0 if n_reps > 0 and self._theta[0]["unit_specific"] is not None: n_units = self._theta[0]["unit_specific"].shape[1] if ll_unit is None: return np.full((n_reps, n_units), np.nan) ll_unit = np.array(ll_unit, dtype=float) if ll_unit.ndim == 1 and n_reps == 1: return ll_unit.reshape(1, -1) if ll_unit.shape != (n_reps, n_units): # Allow shape (n_reps, 0) if n_units is 0 if n_units == 0 and ll_unit.size == 0: return np.empty((n_reps, 0)) raise ValueError( f"logLik_unit shape mismatch: {ll_unit.shape} vs ({n_reps}, {n_units})" ) return ll_unit def _validate_df_consistency(self): """ Ensure all replicates have consistent data frames: - Shared parameters must have the same index and exactly one column. - Unit-specific parameters must have the same index and columns. - If a parameter is in shared, it must not be in unit-specific and vice-versa. """ if not self.theta: return ref = self.theta[0] if ref["shared"] is not None: ref_s_idx = ref["shared"].index ref_s_cols = ref["shared"].columns if len(ref_s_cols) != 1: raise ValueError("Shared parameters must have exactly one column.") else: ref_s_idx = [] if ref["unit_specific"] is not None: ref_u_idx = ref["unit_specific"].index ref_u_cols = ref["unit_specific"].columns else: ref_u_idx = [] ref_u_cols = [] shared_param_names = set(ref_s_idx) unit_param_names = set(ref_u_idx) overlap = shared_param_names.intersection(unit_param_names) if overlap: raise ValueError( f"Parameter name(s) found in both shared and unit-specific parameters: {sorted(overlap)}" ) for i, t in enumerate(self.theta[1:], 1): if t["shared"] is not None: if not t["shared"].index.equals(ref_s_idx): raise ValueError( f"Shared parameter index mismatch at replicate {i}." ) if t["unit_specific"] is not None: if not t["unit_specific"].index.equals(ref_u_idx): raise ValueError(f"Unit parameter index mismatch at replicate {i}.") if not t["unit_specific"].columns.equals(ref_u_cols): raise ValueError(f"Unit columns mismatch at replicate {i}.") @property def logLik(self) -> np.ndarray: return self._logLik @logLik.setter def logLik(self, value): # We generally don't set full logLik directly for panels, but strictly: # We can't infer unit contribution, so this setter is ambiguous # unless we just broadcast/reset. For now, assume read-only derived. pass @property def logLik_unit(self) -> np.ndarray: return self._logLik_unit @logLik_unit.setter def logLik_unit(self, value): self._logLik_unit = self._format_logLik_unit(value, len(self.theta)) self._logLik = self._logLik_unit.sum(axis=1) @property def theta(self): return self._theta.copy() @theta.setter def theta( self, value: dict[str, pd.DataFrame | None] | list[dict[str, pd.DataFrame | None]], ): self._theta = self._normalize_input(value) self._validate_none_consistency() self._validate_df_consistency() n_reps = len(value) self._logLik_unit = self._format_logLik_unit(None, n_reps) self._logLik = self._logLik_unit.sum(axis=1)
[docs] def num_replicates(self) -> int: return len(self._theta)
[docs] def get_param_names(self) -> list[str]: return self._canonical_param_names
[docs] def get_shared_param_names(self) -> list[str]: return self._canonical_shared_param_names
[docs] def get_unit_param_names(self) -> list[str]: return self._canonical_unit_param_names
[docs] def get_unit_names(self) -> list[str]: """ Return the list of unit names from the first replicate's unit_specific DataFrame. """ if not self._theta: return [] first = self._theta[0] unit_specific_df = first.get("unit_specific") if unit_specific_df is not None: return list(unit_specific_df.columns) return []
[docs] def subset(self, indices: Union[int, list[int], slice]) -> "PanelParameters": if isinstance(indices, int): indices = [indices] # Slicing handled by list/array slicing logic if isinstance(indices, slice): sub_theta = self._theta[indices] sub_ll = self._logLik_unit[indices] else: sub_theta = [self._theta[i] for i in indices] sub_ll = self._logLik_unit[indices] return PanelParameters( sub_theta, logLik_unit=sub_ll, estimation_scale=self.estimation_scale )
[docs] def to_jax_array( self, param_names: list[str], unit_names: list[str] | None = None, **kwargs ) -> jax.Array: reps = len(self._theta) if reps == 0: return jnp.empty((0, 0, 0)) # Infer unit names if needed if unit_names is None: if self._theta[0]["unit_specific"] is not None: unit_names = list(self._theta[0]["unit_specific"].columns) else: raise ValueError( "unit_names required when no unit_specific parameters exist" ) n_units = len(unit_names) n_params = len(param_names) # Identify source of each parameter ref = self._theta[0] if ref["shared"] is not None: shared_keys = set(ref["shared"].index) else: shared_keys = set() if ref["unit_specific"] is not None: specific_keys = set(ref["unit_specific"].index) else: specific_keys = set() shared_idx = [] shared_names_list = [] specific_idx = [] specific_names_list = [] for p_idx, p_name in enumerate(param_names): if p_name in specific_keys: specific_idx.append(p_idx) specific_names_list.append(p_name) elif p_name in shared_keys: shared_idx.append(p_idx) shared_names_list.append(p_name) else: raise KeyError(f"Parameter '{p_name}' not found.") full_array = np.zeros((reps, n_units, n_params)) s_col = ref["shared"].columns[0] if ref["shared"] is not None else None for j, t in enumerate(self._theta): if specific_names_list and t["unit_specific"] is not None: try: spec_values = ( t["unit_specific"] .loc[specific_names_list, unit_names] .to_numpy() ) full_array[j][:, specific_idx] = spec_values.T except KeyError as e: missing = [ p for p in specific_names_list if p not in t["unit_specific"].index ] if missing: raise KeyError(f"Unit mismatch for parameter {missing[0]}") else: raise e if shared_names_list and t["shared"] is not None: shared_values = t["shared"].loc[shared_names_list, s_col].to_numpy() full_array[j][:, shared_idx] = np.broadcast_to( shared_values, (n_units, len(shared_idx)) ) return jnp.array(full_array)
[docs] def transform( self, par_trans: ParTrans, direction: Literal["to_est", "from_est"] | None = None, ): auto = direction is None if auto: direction = "from_est" if self.estimation_scale else "to_est" if (direction == "to_est" and not self.estimation_scale) or ( direction == "from_est" and self.estimation_scale ): self._theta = par_trans.panel_transform_list( self._theta, direction=direction ) self.estimation_scale = not self.estimation_scale
[docs] def prune(self, n: int = 1, refill: bool = True) -> None: if not self._theta: return # Sort by total log likelihood top_indices = self._logLik.argsort()[-n:][::-1] top_theta = [self._theta[i] for i in top_indices] top_ll_unit = self._logLik_unit[top_indices] if refill: n_reps = len(self._theta) repeats = (n_reps + n - 1) // n self._theta = (top_theta * repeats)[:n_reps] self._logLik_unit = np.tile(top_ll_unit, (repeats, 1))[:n_reps] else: self._theta = top_theta self._logLik_unit = top_ll_unit self._logLik = self._logLik_unit.sum(axis=1)
[docs] def mix_and_match(self) -> None: """ Sorts unit-specific parameters and shared parameters in descending order of unit log-likelihood and shared log-likelihood, respectively, then combines them to form new parameter sets. The nth best parameter for a given unit or for the shared parameters is placed in the nth parameter set. """ unit_names = self.get_unit_names() if not self._theta: return # Rank by shared logLik (total) shared_ranks = self._logLik.argsort()[::-1] # Rank by unit-specific logLik unit_ranks = {} for u_idx, unit in enumerate(unit_names): unit_ranks[unit] = self._logLik_unit[:, u_idx].argsort()[::-1] new_theta = [] new_ll_unit = np.zeros_like(self._logLik_unit) for i in range(len(self._theta)): # 1. Best shared params for this position s_idx = shared_ranks[i] best_shared = ( self._theta[s_idx]["shared"].copy() if self._theta[s_idx]["shared"] is not None else None ) # 2. Best unit params for each unit for this position new_u_data = {} for u_idx, unit in enumerate(unit_names): u_best_idx = unit_ranks[unit][i] # Copy the logLik for this unit/replicate combo new_ll_unit[i, u_idx] = self._logLik_unit[u_best_idx, u_idx] # Extract the unit specific column src_df = self._theta[u_best_idx]["unit_specific"] if src_df is not None and not src_df.empty and unit in src_df.columns: new_u_data[unit] = src_df[unit].copy() # Construct new unit dataframe if new_u_data: # Use index from the first theta (guaranteed consistent) if self._theta[0]["unit_specific"] is not None: new_u_df = pd.DataFrame( new_u_data, index=self._theta[0]["unit_specific"].index ) else: new_u_df = None else: new_u_df = None new_theta.append({"shared": best_shared, "unit_specific": new_u_df}) self._theta = new_theta self._logLik_unit = new_ll_unit self._logLik = new_ll_unit.sum(axis=1)
def to_list(self) -> list[dict[str, pd.DataFrame | None]]: return self._theta.copy() @overload def __getitem__(self, index: int) -> dict[str, pd.DataFrame | None]: ... @overload def __getitem__(self, index: slice | list[int]) -> "PanelParameters": ... def __getitem__( self, index: int | slice | list[int] ) -> dict[str, pd.DataFrame | None] | "PanelParameters": if isinstance(index, int): return self._theta[index] return self.subset(index) def __iter__(self): return iter(self._theta) def __len__(self): return len(self._theta) def __mul__(self, n: int) -> "PanelParameters": """Replicate the parameter set n times.""" if not isinstance(n, int): return NotImplemented if n < 0: raise ValueError("n must be non-negative") # Replicate the internal list of dicts new_theta = self._theta * n # Replicate the logLik array if self._logLik_unit.size > 0: new_ll_unit = np.tile(self._logLik_unit, (n, 1)) else: # Handle edge case of empty params or 0 replicates n_cols = self._logLik_unit.shape[1] if self._logLik_unit.ndim > 1 else 0 new_ll_unit = np.empty((len(new_theta), n_cols)) return PanelParameters( new_theta, logLik_unit=new_ll_unit, estimation_scale=self.estimation_scale ) def __rmul__(self, n: int) -> "PanelParameters": """Support left multiplication (e.g. 5 * params).""" return self.__mul__(n) def __copy__(self): # Create a new instance without calling __init__ to skip validation overhead cls = self.__class__ new_obj = cls.__new__(cls) new_obj._theta = list(self._theta) new_obj._logLik_unit = self._logLik_unit new_obj._logLik = self._logLik new_obj.estimation_scale = self.estimation_scale new_obj._canonical_shared_param_names = self._canonical_shared_param_names new_obj._canonical_unit_param_names = self._canonical_unit_param_names new_obj._canonical_param_names = self._canonical_param_names return new_obj def __deepcopy__(self, memo): # Create a new instance without calling __init__ to skip validation overhead cls = self.__class__ new_obj = cls.__new__(cls) memo[id(self)] = new_obj new_obj._theta = copy.deepcopy(self._theta, memo) new_obj._logLik_unit = copy.deepcopy(self._logLik_unit, memo) new_obj._logLik = copy.deepcopy(self._logLik, memo) new_obj.estimation_scale = self.estimation_scale new_obj._canonical_shared_param_names = copy.deepcopy( self._canonical_shared_param_names, memo ) new_obj._canonical_unit_param_names = copy.deepcopy( self._canonical_unit_param_names, memo ) new_obj._canonical_param_names = copy.deepcopy( self._canonical_param_names, memo ) return new_obj def __eq__(self, other) -> bool: """ Check structural equality with another PanelParameters object. Two PanelParameters are equal if they have the same canonical parameter names, estimation scale, log-likelihoods, and per-replicate parameter DataFrames. """ if not isinstance(other, type(self)): return False if self.estimation_scale != other.estimation_scale: return False if self._canonical_shared_param_names != other._canonical_shared_param_names: return False if self._canonical_unit_param_names != other._canonical_unit_param_names: return False if not np.array_equal(self._logLik_unit, other._logLik_unit, equal_nan=True): return False if len(self._theta) != len(other._theta): return False for t1, t2 in zip(self._theta, other._theta): for key in ("shared", "unit_specific"): df1 = t1.get(key) df2 = t2.get(key) if (df1 is None) != (df2 is None): return False if df1 is not None: try: assert df2 is not None pd.testing.assert_frame_equal(df1, df2, check_dtype=True) except AssertionError: return False return True @staticmethod def merge(*param_objs: "PanelParameters") -> "PanelParameters": """ Merge replications from an arbitrary number of PanelParameters objects into a single PanelParameters object. All objects must have the same canonical parameter names, unit names, and estimation scale. Usage: `merged = PanelParameters.merge(p1, p2, p3, ...)` """ if len(param_objs) == 0: raise ValueError("At least one PanelParameters object must be provided.") first = param_objs[0] for obj in param_objs: if not isinstance(obj, type(first)): raise TypeError("All merged objects must be of type PanelParameters.") if obj._canonical_shared_param_names != first._canonical_shared_param_names: raise ValueError( "All PanelParameters objects must have the same canonical shared parameter names." ) if obj._canonical_unit_param_names != first._canonical_unit_param_names: raise ValueError( "All PanelParameters objects must have the same canonical unit parameter names." ) if obj.estimation_scale != first.estimation_scale: raise ValueError( "All PanelParameters objects must have the same estimation scale." ) if obj.get_unit_names() != first.get_unit_names(): raise ValueError( "All PanelParameters objects must have the same unit names." ) all_theta = [] all_logLik_unit = [] for obj in param_objs: all_theta.extend(obj._theta) all_logLik_unit.append(obj._logLik_unit) merged_logLik_unit = ( np.concatenate(all_logLik_unit, axis=0) if all_logLik_unit else np.array([]) ) return PanelParameters( all_theta, logLik_unit=merged_logLik_unit, estimation_scale=first.estimation_scale, )