"""This module implements a linear Gaussian model for POMP."""
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from pypomp.core.pomp import Pomp
from pypomp.core.par_trans import ParTrans
from pypomp.types import (
StateDict,
ParamDict,
CovarDict,
TimeFloat,
StepSizeFloat,
RNGKey,
ObservationDict,
InitialTimeFloat,
)
def _get_thetas(theta):
    """Unpack a flat parameter mapping into the model matrices.

    Parameters
    ----------
    theta : mapping
        Must provide the entries "A11", "A12", "A21", "A22", "C11", "C12",
        "C21", "C22", "Q11", "Q12", "Q22", "R11", "R12" and "R22".

    Returns
    -------
    tuple
        ``(A, C, Q, R)``, each a 2x2 array. ``A`` and ``C`` are filled row
        by row from their four entries. ``Q`` and ``R`` are symmetric and
        built from their three stored entries, with the diagonal floored at
        1e-12 and the off-diagonal clipped to +/- 0.999 * sqrt(m11 * m22).
    """

    def _square(prefix):
        # Row-major 2x2 matrix from the four "<prefix>ij" entries.
        entries = [theta[f"{prefix}{ij}"] for ij in ("11", "12", "21", "22")]
        return jnp.array(entries).reshape(2, 2)

    def _bounded_symmetric(diag_1, off_diag, diag_2):
        # Floor both diagonal entries at a tiny positive value.
        var_1 = jnp.maximum(diag_1, 1e-12)
        var_2 = jnp.maximum(diag_2, 1e-12)
        # Keep |off-diagonal| strictly below sqrt(var_1 * var_2) so the
        # determinant of the resulting matrix stays positive.
        bound = 0.999 * jnp.sqrt(var_1 * var_2)
        cov_12 = jnp.clip(off_diag, -bound, bound)
        return jnp.array([[var_1, cov_12], [cov_12, var_2]])

    transition = _square("A")
    measurement = _square("C")
    state_cov = _bounded_symmetric(theta["Q11"], theta["Q12"], theta["Q22"])
    meas_cov = _bounded_symmetric(theta["R11"], theta["R12"], theta["R22"])
    return transition, measurement, state_cov, meas_cov
def _transform_thetas(A, C, Q, R):
    """Pack the four model matrices into one flat parameter vector.

    The layout is: the flattened entries of ``A``, the flattened entries of
    ``C``, then the [0, 0], [0, 1] and [1, 1] entries of ``Q`` followed by
    the same three entries of ``R``. For 2x2 inputs this gives 14 values.
    """

    def _stored_entries(mat):
        # Only three entries are kept per covariance; [1, 0] is not stored.
        return jnp.array([mat[0, 0], mat[0, 1], mat[1, 1]])

    pieces = [A.flatten(), C.flatten()]
    pieces.extend(_stored_entries(cov) for cov in (Q, R))
    return jnp.concatenate(pieces)
# TODO: Add custom starting position.
def _rinit(
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t0: InitialTimeFloat,
):
    """Draw the initial latent state.

    The state is sampled from a bivariate normal with zero mean and the
    covariance ``Q`` built from ``theta_``. ``covars`` and ``t0`` are
    accepted but not used.

    Returns
    -------
    dict
        The sampled state under the keys "X1" and "X2".
    """
    # Only the state-noise covariance is needed here.
    _, _, state_cov, _ = _get_thetas(theta_)
    origin = jnp.array([0, 0])
    draw = jax.random.multivariate_normal(key=key, mean=origin, cov=state_cov)
    return {"X1": draw[0], "X2": draw[1]}
def _rproc(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
    dt: StepSizeFloat,
):
    """Advance the latent state by one step.

    The next state is sampled from a bivariate normal with mean ``A @ X``
    and covariance ``Q``, where ``X`` is the current state and ``A``, ``Q``
    come from ``theta_``. ``covars``, ``t`` and ``dt`` are accepted but not
    used.

    Returns
    -------
    dict
        The new state under the keys "X1" and "X2".
    """
    transition, _, state_cov, _ = _get_thetas(theta_)
    current = jnp.array([X_["X1"], X_["X2"]])
    predicted_mean = transition @ current
    draw = jax.random.multivariate_normal(
        key=key, mean=predicted_mean, cov=state_cov
    )
    return {"X1": draw[0], "X2": draw[1]}
def _dmeas(
    Y_: ObservationDict,
    X_: StateDict,
    theta_: ParamDict,
    covars: CovarDict,
    t: TimeFloat,
):
    """Log-density of an observation given the latent state.

    The observation ``(Y1, Y2)`` is evaluated under a bivariate normal with
    mean ``C @ X`` and covariance ``R``, which is the same distribution that
    `_rmeas` samples from. ``covars`` and ``t`` are accepted but not used.

    Parameters
    ----------
    Y_ : ObservationDict
        Observation with the entries "Y1" and "Y2".
    X_ : StateDict
        Latent state with the entries "X1" and "X2".
    theta_ : ParamDict
        Model parameters, unpacked with `_get_thetas`.

    Returns
    -------
    The log-density value returned by
    ``jax.scipy.stats.multivariate_normal.logpdf``.
    """
    _, C, _, R = _get_thetas(theta_)
    X_array = jnp.array([X_["X1"], X_["X2"]])
    Y_array = jnp.array([Y_["Y1"], Y_["Y2"]])
    # The mean must go through the measurement matrix C; using X directly
    # is only correct when C is the identity and disagrees with _rmeas
    # otherwise.
    return jax.scipy.stats.multivariate_normal.logpdf(Y_array, C @ X_array, R)
def _rmeas(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
):
    """Sample an observation given the latent state.

    The draw comes from a bivariate normal with mean ``C @ X`` and
    covariance ``R``, where ``C`` and ``R`` are built from ``theta_``.
    ``covars`` and ``t`` are accepted but not used.

    Returns
    -------
    The sampled array itself (not a dict).
    """
    _, measurement, _, meas_cov = _get_thetas(theta_)
    state = jnp.array([X_["X1"], X_["X2"]])
    expected_obs = measurement @ state
    return jax.random.multivariate_normal(
        key=key, mean=expected_obs, cov=meas_cov
    )
def _to_est(theta: ParamDict) -> ParamDict:
    """Map parameters to the estimation scale.

    The "11" and "22" entries of A, C, Q and R are replaced by their natural
    logarithms; every other entry is copied through unchanged. The input
    mapping is not modified.
    """
    diagonal_keys = [
        f"{matrix}{position}" for matrix in "ACQR" for position in ("11", "22")
    ]
    logged = {entry: jnp.log(theta[entry]) for entry in diagonal_keys}
    return {**theta, **logged}
def _from_est(theta: ParamDict) -> ParamDict:
    """Map parameters from the estimation scale back to the natural scale.

    The "11" and "22" entries of A, C, Q and R are replaced by their
    exponentials; every other entry is copied through unchanged. The input
    mapping is not modified.
    """
    diagonal_keys = [
        f"{matrix}{position}" for matrix in "ACQR" for position in ("11", "22")
    ]
    restored = {entry: jnp.exp(theta[entry]) for entry in diagonal_keys}
    return {**theta, **restored}
# Public constructor for the linear Gaussian model.
def LG(
    T: int = 4,
    A: jax.Array = jnp.array(
        [[jnp.cos(0.2), -jnp.sin(0.2)], [jnp.sin(0.2), jnp.cos(0.2)]]
    ),
    C: jax.Array = jnp.eye(2),
    Q: jax.Array = jnp.array([[1, 2e-2], [2e-2, 1]]) / 100,
    R: jax.Array = jnp.array([[1, 0.1], [0.1, 1]]) / 10,
    key: jax.Array = jax.random.key(1),
) -> Pomp:
    """
    Initialize a Pomp object with the linear Gaussian model.

    A temporary Pomp object is built around an all-zero placeholder data
    frame with ``T`` rows (times 1.0, ..., T) and columns "Y1" and "Y2".
    Its ``simulate`` method is then called once with ``key``, and the
    resulting Pomp object is returned.

    Parameters
    ----------
    T : int, optional
        The number of time steps to generate data for. Defaults to 4.
    A : jax.Array, optional
        The 2x2 transition matrix.
    C : jax.Array, optional
        The 2x2 measurement matrix.
    Q : jax.Array, optional
        The 2x2 covariance matrix of the state noise.
    R : jax.Array, optional
        The 2x2 covariance matrix of the measurement noise.
    key : jax.Array, optional
        The random key used to generate the data.

    Returns
    -------
    A Pomp object initialized with the linear Gaussian model parameters and the generated data.

    Raises
    ------
    ValueError
        If any of ``A``, ``C``, ``Q``, ``R`` is not of shape (2, 2), or if
        ``Q`` or ``R`` is not symmetric or not positive-definite.
    """
    # The parameter names below assume 2x2 matrices. Without this check a
    # differently shaped input would be flattened and silently misaligned
    # (or truncated) when zipped against theta_names.
    for name, mat in [("A", A), ("C", C), ("Q", Q), ("R", R)]:
        if np.shape(mat) != (2, 2):
            raise ValueError(
                f"Matrix {name} must have shape (2, 2), got {np.shape(mat)}."
            )

    # Validate covariance matrices Q and R
    for name, mat in [("Q", Q), ("R", R)]:
        mat_np = np.asarray(mat)
        if not np.allclose(mat_np, mat_np.T, atol=1e-8, rtol=1e-5):
            raise ValueError(f"Covariance matrix {name} must be symmetric.")
        try:
            # Cholesky succeeds only for positive-definite input.
            np.linalg.cholesky(mat_np)
        except np.linalg.LinAlgError as err:
            raise ValueError(
                f"Covariance matrix {name} must be positive-definite."
            ) from err

    theta_names = [
        "A11",
        "A12",
        "A21",
        "A22",
        "C11",
        "C12",
        "C21",
        "C22",
        "Q11",
        "Q12",
        "Q22",
        "R11",
        "R12",
        "R22",
    ]
    theta = dict(zip(theta_names, _transform_thetas(A, C, Q, R).tolist()))

    # Placeholder observations; they only fix the observation times and
    # column names for the temporary object that simulate() is called on.
    ys_temp = pd.DataFrame(
        0, index=np.arange(1, T + 1, dtype=float), columns=pd.Index(["Y1", "Y2"])
    )
    LG_obj_temp = Pomp(
        rinit=_rinit,
        rproc=_rproc,
        dmeas=_dmeas,
        rmeas=_rmeas,
        ys=ys_temp,
        t0=0.0,
        nstep=1,
        dt=None,
        theta=theta,
        covars=None,
        statenames=["X1", "X2"],
        par_trans=ParTrans(to_est=_to_est, from_est=_from_est),
    )
    LG_obj = LG_obj_temp.simulate(key=key, nsim=1, as_pomp=True)
    # Narrow the type of simulate()'s return value for type checkers.
    assert isinstance(LG_obj, Pomp)
    return LG_obj