Source code for pypomp.core.par_trans

from typing import Callable, Literal, Mapping, cast
from ..types import ParamDict
import importlib
import pandas as pd
import jax
import jax.numpy as jnp
import numpy as np


[docs] class ParTrans: """ Handles parameter transformations between natural and estimation parameter spaces. """ to_est: Callable[[ParamDict], ParamDict] """The parameter transformation function to the estimation parameter space.""" from_est: Callable[[ParamDict], ParamDict] """The parameter transformation function from the estimation parameter space to the natural parameter space.""" def __init__( self, to_est: Callable[[ParamDict], ParamDict] | None = None, from_est: Callable[[ParamDict], ParamDict] | None = None, ): self.to_est: Callable[[ParamDict], ParamDict] = to_est or _to_est_default self.from_est: Callable[[ParamDict], ParamDict] = from_est or _from_est_default
[docs] def panel_transform( self, theta: dict[str, pd.DataFrame | None], direction: Literal["to_est", "from_est"], ) -> dict[str, pd.DataFrame | None]: """ Transform shared and unit-specific parameters for a single replicate. Input theta contains 'shared' and/or 'unit_specific' DataFrames. """ func = self.to_est if direction == "to_est" else self.from_est # Normalize to empty DFs if None or missing for cleaner logic s_df = theta.get("shared") if s_df is None: s_df = None u_df = theta.get("unit_specific") if u_df is None: u_df = None res: dict[str, pd.DataFrame | None] = {"shared": None, "unit_specific": None} # Pre-calculate shared dictionary (param -> value) s_vals = cast(dict, s_df.iloc[:, 0].to_dict()) if s_df is not None else {} # 1. Transform Shared Parameters if s_df is not None: # Context: Shared values + First unit's specific values (if any) ctx = s_vals.copy() if u_df is not None: first_unit = u_df.columns[0] ctx.update(cast(dict, u_df[first_unit].to_dict())) trans: ParamDict = func(cast(ParamDict, ctx)) # Filter output back to just shared keys new_s_vals = {k: trans[k] for k in s_vals} res["shared"] = pd.DataFrame( list(new_s_vals.values()), index=pd.Index(list(new_s_vals.keys())), columns=pd.Index(["shared"]), ) # 2. Transform Unit-Specific Parameters if u_df is not None: new_u_data = {} for unit in u_df.columns: # Context: Shared values + This unit's specific values ctx = s_vals.copy() ctx.update(cast(dict, u_df[unit].to_dict())) trans = func(cast(ParamDict, ctx)) # Filter output back to specific keys (maintaining order) new_u_data[unit] = [trans[k] for k in u_df.index] res["unit_specific"] = pd.DataFrame(new_u_data, index=u_df.index) return res
[docs] def panel_transform_list( self, theta_list: list[dict[str, pd.DataFrame | None]], direction: Literal["to_est", "from_est"], ) -> list[dict[str, pd.DataFrame | None]]: """ Apply transform to a list of parameter sets. """ return [self.panel_transform(t, direction) for t in theta_list]
[docs] def to_floats( self, theta: Mapping[str, float | jax.Array], direction: Literal["to_est", "from_est"], ) -> dict[str, float]: """ Convert the theta dictionary values from jax.Array to float. """ if direction == "to_est": theta_out = self.to_est(dict(theta)) return {k: float(v) for k, v in theta_out.items()} elif direction == "from_est": theta_out = self.from_est(dict(theta)) return {k: float(v) for k, v in theta_out.items()} else: raise ValueError(f"Invalid direction: {direction}")
[docs] def transform_array( self, param_array: np.ndarray, param_names: list[str], direction: Literal["to_est", "from_est"], ) -> np.ndarray: """ Transform a parameter array to or from the (unconstrained) estimation parameter space. This wrapper converts an array of parameters to a dict, applies the dict-to-dict transformation function, and converts back to an array. Args: param_array: Array of parameter values with shape (..., n_params) param_names: List of parameter names in the same order as the array direction: Direction of transformation ("to_est" or "from_est") Returns: Transformed parameter array with the same shape as input """ if direction not in ["to_est", "from_est"]: raise ValueError(f"Invalid direction: {direction}") transform_fn = self.to_est if direction == "to_est" else self.from_est original_shape = param_array.shape if len(original_shape) == 1: param_array_2d = param_array.reshape(1, -1) else: param_array_2d = param_array.reshape(-1, original_shape[-1]) def transform_single_row(row): param_dict = {name: row[i] for i, name in enumerate(param_names)} transformed_dict = transform_fn(param_dict) return jnp.array([transformed_dict[name] for name in param_names]) transform_vectorized = jax.vmap(transform_single_row) param_jax = jnp.array(param_array_2d) transformed_jax = transform_vectorized(param_jax) transformed_array = np.array(transformed_jax) if len(original_shape) == 1: return transformed_array.reshape(original_shape) else: return transformed_array.reshape(original_shape)
[docs] def transform_panel_traces( self, shared_traces: np.ndarray | None, unit_traces: np.ndarray | None, shared_param_names: list[str], unit_param_names: list[str], unit_names: list[str], direction: Literal["to_est", "from_est"], ) -> tuple[np.ndarray | None, np.ndarray | None]: """ Transform panel traces from estimation space to natural space. For panel models, shared and unit-specific parameters may be interdependent in the transformation, so they need to be transformed together. Args: shared_traces: Array of shared parameter traces, shape (n_reps, n_iters, n_shared+1) where [:, :, 0] is loglik and [:, :, 1:] are shared params unit_traces: Array of unit-specific parameter traces, shape (n_reps, n_iters, n_spec+1, n_units) where [:, :, 0, :] is per-unit loglik and [:, :, 1:, :] are unit-specific params shared_param_names: List of shared parameter names unit_param_names: List of unit-specific parameter names unit_names: List of unit names direction: Direction of transformation ("to_est" or "from_est") Returns: Tuple of (transformed_shared_traces, transformed_unit_traces) with same shapes as inputs """ if direction not in ["to_est", "from_est"]: raise ValueError(f"Invalid direction: {direction}") if shared_traces is None and unit_traces is None: return None, None transform_fn = self.to_est if direction == "to_est" else self.from_est n_shared = len(shared_param_names) n_spec = len(unit_param_names) shared_out = None unit_out = None if shared_traces is not None and n_shared > 0: n_reps, n_iters, _ = shared_traces.shape shared_out = shared_traces.copy() def transform_shared_single(shared_vals, unit_vals_for_context): param_dict = { name: shared_vals[i] for i, name in enumerate(shared_param_names) } if n_spec > 0: param_dict.update( { name: unit_vals_for_context[i] for i, name in enumerate(unit_param_names) } ) transformed = transform_fn(param_dict) return jnp.array([transformed[name] for name in shared_param_names]) transform_shared_vectorized = jax.vmap(jax.vmap(transform_shared_single)) shared_params_only = jnp.array(shared_traces[:, :, 1:]) if unit_traces is not None and n_spec > 0: unit_context = jnp.array(unit_traces[:, :, 1:, 0]) else: unit_context = jnp.zeros((n_reps, n_iters, max(1, n_spec))) transformed_shared = transform_shared_vectorized( shared_params_only, unit_context ) shared_out[:, :, 1:] = np.array(transformed_shared) if unit_traces is not None and n_spec > 0: n_reps, n_iters, _, n_units = unit_traces.shape unit_out = unit_traces.copy() def transform_unit_single(shared_vals_for_context, unit_vals): param_dict = {} if n_shared > 0: param_dict.update( { name: shared_vals_for_context[i] for i, name in enumerate(shared_param_names) } ) param_dict.update( {name: unit_vals[i] for i, name in enumerate(unit_param_names)} ) transformed = transform_fn(param_dict) return jnp.array([transformed[name] for name in unit_param_names]) # At the per-iteration slice, unit_vals has shape (n_spec, n_units), # so we need to vmap over axis=1 (units axis) here. vmap_over_units = jax.vmap(transform_unit_single, in_axes=(None, 1)) vmap_over_iters = jax.vmap(vmap_over_units, in_axes=(0, 0)) transform_unit_vectorized = jax.vmap(vmap_over_iters, in_axes=(0, 0)) if shared_traces is not None and n_shared > 0: shared_context = jnp.array(shared_traces[:, :, 1:]) else: shared_context = jnp.zeros((n_reps, n_iters, max(1, n_shared))) unit_params_only = jnp.array(unit_traces[:, :, 1:, :]) transformed_unit = transform_unit_vectorized( shared_context, unit_params_only ) # transformed shape: (n_reps, n_iters, n_units, n_spec) # target slice shape: (n_reps, n_iters, n_spec, n_units) transformed_unit = jnp.transpose(transformed_unit, (0, 1, 3, 2)) unit_out[:, :, 1:, :] = np.array(transformed_unit) elif unit_traces is not None: unit_out = unit_traces.copy() return shared_out, unit_out
def __eq__(self, other): """ Check equality with another ParTrans object. Two ParTrans instances are equal if they use the same function objects for to_est and from_est. Note that functionally identical lambda functions will not be considered equal unless they are the same object. """ if not isinstance(other, type(self)): return False if self.to_est != other.to_est: return False if self.from_est != other.from_est: return False return True def __getstate__(self): """ Custom pickling method to preserve function identity. Stores module and function names for module-level functions. Lambdas/closures cannot be reliably reconstructed and will fall back to defaults on unpickling. """ state = {} if ( hasattr(self.to_est, "__module__") and hasattr(self.to_est, "__name__") and self.to_est.__module__ is not None ): state["_to_est_module"] = self.to_est.__module__ state["_to_est_name"] = self.to_est.__name__ else: state["_to_est_is_lambda"] = True if ( hasattr(self.from_est, "__module__") and hasattr(self.from_est, "__name__") and self.from_est.__module__ is not None ): state["_from_est_module"] = self.from_est.__module__ state["_from_est_name"] = self.from_est.__name__ else: state["_from_est_is_lambda"] = True return state def __setstate__(self, state): """ Custom unpickling method to reconstruct functions. Reconstructs module-level functions by importing them. Falls back to defaults for lambdas/closures. """ if "_to_est_is_lambda" in state: self.to_est = _to_est_default else: module = importlib.import_module(state["_to_est_module"]) self.to_est = getattr(module, state["_to_est_name"]) if "_from_est_is_lambda" in state: self.from_est = _from_est_default else: module = importlib.import_module(state["_from_est_module"]) self.from_est = getattr(module, state["_from_est_name"])
def _to_est_default( theta: ParamDict, ) -> ParamDict: return dict(theta) def _from_est_default( theta: ParamDict, ) -> ParamDict: return dict(theta)