"""This module implements a linear Gaussian model for POMP."""
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from pypomp.core.pomp import Pomp
from pypomp.core.par_trans import ParTrans
from pypomp.types import (
StateDict,
ParamDict,
CovarDict,
TimeFloat,
StepSizeFloat,
RNGKey,
ObservationDict,
InitialTimeFloat,
)
def _get_thetas(theta):
    """Rebuild the four 2x2 model matrices from a flat parameter dict.

    Each matrix ``M`` in ``"ACQR"`` is read from keys ``M1``..``M4`` and
    reshaped row-major into 2x2, so ``M1``/``M4`` are the diagonal entries.

    Returns
    -------
    tuple
        ``(A, C, Q, R)`` as JAX arrays of shape (2, 2).
    """

    def _matrix(prefix):
        entries = [theta[f"{prefix}{idx}"] for idx in range(1, 5)]
        return jnp.array(entries).reshape(2, 2)

    return tuple(_matrix(prefix) for prefix in "ACQR")
def _transform_thetas(A, C, Q, R):
    """Flatten the model matrices into one vector, in ``A, C, Q, R`` order.

    Each matrix is flattened row-major, which is the inverse of the reshape
    performed by ``_get_thetas``.
    """
    pieces = [mat.flatten() for mat in (A, C, Q, R)]
    return jnp.concatenate(pieces)
# TODO: Allow a custom initial-state mean (currently fixed at zero).
def _rinit(
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t0: InitialTimeFloat,
):
    """Draw the initial latent state.

    The state ``(X1, X2)`` is sampled from a bivariate normal with zero mean
    and covariance ``Q`` (built from ``Q1``..``Q4`` in ``theta_``).
    ``covars`` and ``t0`` are not used by this model.
    """
    state_cov = _get_thetas(theta_)[2]
    draw = jax.random.multivariate_normal(
        key=key, mean=jnp.array([0, 0]), cov=state_cov
    )
    return {"X1": draw[0], "X2": draw[1]}
def _rproc(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
    dt: StepSizeFloat,
):
    """Advance the latent state one step: ``X_next ~ N(A @ X, Q)``.

    ``covars``, ``t`` and ``dt`` are not used by this model.

    Returns
    -------
    dict
        New state with keys ``"X1"`` and ``"X2"``.
    """
    transition, _, state_cov, _ = _get_thetas(theta_)
    current = jnp.array([X_["X1"], X_["X2"]])
    proposal_mean = transition @ current
    nxt = jax.random.multivariate_normal(key=key, mean=proposal_mean, cov=state_cov)
    return {"X1": nxt[0], "X2": nxt[1]}
def _dmeas(
    Y_: ObservationDict,
    X_: StateDict,
    theta_: ParamDict,
    covars: CovarDict,
    t: TimeFloat,
):
    """Log-density of an observation given the latent state.

    Evaluates ``log N(Y; C @ X, R)``, i.e. the same measurement model that
    ``_rmeas`` samples from. ``covars`` and ``t`` are not used.

    Returns
    -------
    jax.Array
        Scalar log-likelihood of ``(Y1, Y2)`` given ``(X1, X2)``.
    """
    _, C, _, R = _get_thetas(theta_)
    X_array = jnp.array([X_["X1"], X_["X2"]])
    Y_array = jnp.array([Y_["Y1"], Y_["Y2"]])
    # The mean must include the measurement matrix C so the density agrees
    # with _rmeas (mean=C @ X); using X alone is only correct when C = I.
    return jax.scipy.stats.multivariate_normal.logpdf(Y_array, C @ X_array, R)
def _rmeas(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
):
    """Simulate one observation given the latent state: ``Y ~ N(C @ X, R)``.

    Unlike the state callbacks, this returns a length-2 JAX array rather than
    a dict. It is presumably ordered as ``(Y1, Y2)`` to match the observation
    columns; confirm this against Pomp's rmeas contract. ``covars`` and ``t``
    are not used.
    """
    _, measurement, _, meas_cov = _get_thetas(theta_)
    latent = jnp.array([X_["X1"], X_["X2"]])
    observed_mean = measurement @ latent
    return jax.random.multivariate_normal(key=key, mean=observed_mean, cov=meas_cov)
def _to_est(theta: ParamDict) -> ParamDict:
    """Map parameters from the natural scale to the estimation scale.

    The entries with suffix 1 and 4 (the diagonal of each 2x2 matrix A, C, Q,
    R) are log-transformed. All other entries are copied unchanged. The input
    dict is not modified; a new dict with the same key order is returned.
    """
    logged = {
        f"{mat}{idx}": jnp.log(theta[f"{mat}{idx}"])
        for mat in "ACQR"
        for idx in (1, 4)
    }
    return {**theta, **logged}
def _from_est(theta: ParamDict) -> ParamDict:
    """Map parameters from the estimation scale back to the natural scale.

    This is the inverse of the log transform on the suffix-1 and suffix-4
    entries (the matrix diagonals): those are exponentiated, and every other
    entry is copied unchanged. The input dict is not modified.
    """
    restored = {
        f"{mat}{idx}": jnp.exp(theta[f"{mat}{idx}"])
        for mat in "ACQR"
        for idx in (1, 4)
    }
    return {**theta, **restored}
def LG(
    T: int = 4,
    A: jax.Array = jnp.array(
        [[jnp.cos(0.2), -jnp.sin(0.2)], [jnp.sin(0.2), jnp.cos(0.2)]]
    ),
    C: jax.Array = jnp.eye(2),
    Q: jax.Array = jnp.array([[1, 2e-2], [2e-2, 1]]) / 100,
    R: jax.Array = jnp.array([[1, 0.1], [0.1, 1]]) / 10,
    key: jax.Array = jax.random.key(111),
) -> Pomp:
    """
    Initialize a Pomp object with the linear Gaussian model.

    The two-dimensional latent state starts from ``N(0, Q)`` at ``t0 = 0`` and
    is advanced once per unit time as ``X_t ~ N(A X_{t-1}, Q)``. Observations
    are simulated as ``Y_t ~ N(C X_t, R)`` at times ``1, ..., T``.

    Parameters
    ----------
    T : int, optional
        The number of observation times to generate data for. Defaults to 4.
    A : jax.Array, optional
        The 2x2 state transition matrix. Defaults to a rotation by 0.2
        radians, ``[[cos 0.2, -sin 0.2], [sin 0.2, cos 0.2]]``.
    C : jax.Array, optional
        The 2x2 measurement matrix. Defaults to the identity matrix.
    Q : jax.Array, optional
        The 2x2 covariance matrix of the state noise. Defaults to
        ``[[1, 0.02], [0.02, 1]] / 100``.
    R : jax.Array, optional
        The 2x2 covariance matrix of the measurement noise. Defaults to
        ``[[1, 0.1], [0.1, 1]] / 10``.
    key : jax.Array, optional
        The random key used to generate the data. Defaults to
        jax.random.key(111).

    Returns
    -------
    Pomp
        A Pomp object initialized with the linear Gaussian model parameters
        and one simulated dataset.

    Notes
    -----
    The matrices are stored in ``theta`` as the flat keys ``A1``..``A4``,
    ``C1``..``C4``, ``Q1``..``Q4`` and ``R1``..``R4`` (row-major). The
    parameter transformation log-transforms the diagonal entries (suffixes 1
    and 4) for estimation.
    """
    theta_names = [f"{name}{i}" for name in "ACQR" for i in range(1, 5)]
    theta = dict(zip(theta_names, _transform_thetas(A, C, Q, R).tolist()))
    # Placeholder observations: only the index (observation times 1..T) and
    # column names matter, because the data are replaced by simulate() below.
    ys_temp = pd.DataFrame(
        0, index=np.arange(1, T + 1, dtype=float), columns=pd.Index(["Y1", "Y2"])
    )
    LG_obj_temp = Pomp(
        rinit=_rinit,
        rproc=_rproc,
        dmeas=_dmeas,
        rmeas=_rmeas,
        ys=ys_temp,
        t0=0.0,
        nstep=1,
        dt=None,
        theta=theta,
        covars=None,
        statenames=["X1", "X2"],
        par_trans=ParTrans(to_est=_to_est, from_est=_from_est),
    )
    LG_obj = LG_obj_temp.simulate(key=key, nsim=1, as_pomp=True)
    # Narrows the type for static checkers; simulate(as_pomp=True) is
    # expected to return a Pomp.
    assert isinstance(LG_obj, Pomp)
    return LG_obj