Source code for pypomp.core.rw_sigma

import jax
import jax.numpy as jnp
import numpy as np


[docs] class RWSigma: """ Represents the random walk standard deviation for the parameters of a model used in the Iterated Filtering 2 (IF2) algorithm. """ sigmas: dict[str, float] """Dictionary mapping parameter names to sigma values.""" init_names: list[str] """List of parameter names that are considered initial parameters.""" not_init_names: list[str] """List of parameter names that are not considered initial parameters.""" all_names: list[str] """List of all parameter names.""" def __init__(self, sigmas: dict[str, float], init_names: list[str] = []): self.sigmas, self.init_names, self.not_init_names, self.all_names = ( self._validate_attributes(sigmas, init_names) ) def _validate_attributes( self, sigmas: dict[str, float], init_names: list[str] ) -> tuple[dict[str, float], list[str], list[str], list[str]]: """ Validates the attributes of the RWSigma object and returns prepared attributes. """ if not isinstance(sigmas, dict): raise ValueError("sigmas must be a dictionary") for param_name, value in sigmas.items(): if isinstance(value, (int, np.number, jax.Array)) and not isinstance( value, bool ): try: sigmas[param_name] = float(value) except (TypeError, ValueError): pass if not isinstance(sigmas[param_name], float): raise ValueError( f"Value for parameter '{param_name}' in sigmas dictionary must be a float: " f"got {type(sigmas[param_name]).__name__}" ) if not isinstance(init_names, list): raise ValueError("init_names must be a list") if not all(isinstance(param_name, str) for param_name in init_names): raise ValueError("All values in init_names list must be strings") if not all(param_name in sigmas.keys() for param_name in init_names): raise ValueError("All init_names names must be in sigmas dictionary") if len(init_names) != len(set(init_names)): raise ValueError("Duplicate names found in init_names") if not all(sigmas[param_name] >= 0 for param_name in sigmas.keys()): raise ValueError("All values in sigmas dictionary must be non-negative") not_init_names = [name for name in sigmas.keys() if name not in init_names] if len(not_init_names) != len(set(not_init_names)): raise ValueError("Duplicate names found in not_init_names") all_names = not_init_names + init_names if len(all_names) != len(set(all_names)): raise ValueError("Duplicate names found in all_names") return sigmas, init_names, not_init_names, all_names def _return_arrays( self, param_names: list[str] | None = None ) -> tuple[jax.Array, jax.Array]: """ Returns the sigmas and sigmas_init arrays. If param_names is provided, only returns the arrays if the parameter names in the object match those in the param_names argument. Returns: sigmas_array: Array of sigmas for non-initial parameters. Shape (d,). Contains 0 for initial parameters. sigmas_init_array: Array of sigmas for initial parameters. Shape (d,). Contains 0 for non-initial parameters. """ if param_names is None: param_names = self.all_names else: if not ( all(param_name in self.all_names for param_name in param_names) and all(param_name in param_names for param_name in self.all_names) ): raise ValueError("All param_names must be in all_names and vice versa") all_sigmas_array = jnp.array( [self.sigmas[param_name] for param_name in param_names] ) not_init_mask = jnp.array( [ 1 if param_name in self.not_init_names else 0 for param_name in param_names ] ) init_mask = jnp.array( [1 if param_name in self.init_names else 0 for param_name in param_names] ) sigmas_array = all_sigmas_array * not_init_mask sigmas_init_array = all_sigmas_array * init_mask return sigmas_array, sigmas_init_array
[docs] def cool(self, factor: float) -> None: """ Reduces all sigmas by multiplying them by the specified factor in place. Args: factor (float): Value by which to multiply each sigma. Returns: None """ if not (0 <= factor <= 1): raise ValueError("factor should be between 0 and 1") for key in self.sigmas: self.sigmas[key] *= factor
def __setitem__(self, param_name: str, value: float) -> None: """ Set the value of a sigma for a given parameter name using the indexing syntax. Args: param_name (str): The name of the parameter whose sigma value you wish to set. value (float): The new sigma value. Raises: KeyError: If param_name is not found in sigmas. TypeError: If value cannot be coerced to a float. ValueError: If the value is negative. """ if param_name not in self.sigmas: raise KeyError(f"Parameter '{param_name}' not found in sigmas.") try: value = float(value) except (TypeError, ValueError): raise TypeError( "Sigma value must be a float or numeric type that can be coerced to float." ) if value < 0: raise ValueError("Sigma value must be non-negative.") self.sigmas[param_name] = value def __eq__(self, other) -> bool: """ Check equality with another RWSigma object. Two RWSigma instances are equal if they have the same sigmas and init_names. """ if not isinstance(other, type(self)): return False if self.sigmas != other.sigmas: return False if self.init_names != other.init_names: return False return True