Source code for pypomp.models.linear_gaussian

"""This module implements a linear Gaussian model for POMP."""

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from pypomp.core.pomp import Pomp
from pypomp.core.par_trans import ParTrans
from pypomp.types import (
    StateDict,
    ParamDict,
    CovarDict,
    TimeFloat,
    StepSizeFloat,
    RNGKey,
    ObservationDict,
    InitialTimeFloat,
)


def _get_thetas(theta):
    """Unpack the 16 scalar parameters into the four 2x2 model matrices.

    For each matrix name ``M`` in ``"ACQR"``, the entries ``M1, M2, M3, M4``
    are read in that order and reshaped row-major, so ``M1, M2`` form the
    first row and ``M3, M4`` the second row.

    Parameters
    ----------
    theta : mapping
        Must contain the keys ``A1..A4``, ``C1..C4``, ``Q1..Q4`` and ``R1..R4``.

    Returns
    -------
    tuple of jax.Array
        ``(A, C, Q, R)``, each of shape ``(2, 2)``.
    """

    def _matrix(prefix):
        # Row-major reshape: [[p1, p2], [p3, p4]].
        entries = [theta[f"{prefix}{idx}"] for idx in range(1, 5)]
        return jnp.array(entries).reshape(2, 2)

    return tuple(_matrix(prefix) for prefix in "ACQR")


def _transform_thetas(A, C, Q, R):
    """Flatten the four model matrices into one 1-D parameter vector.

    Each matrix is flattened row-major and the pieces are concatenated in
    the order A, C, Q, R. This yields 16 entries for 2x2 inputs, in the same
    order as the ``A1..A4, C1..C4, Q1..Q4, R1..R4`` parameter names.
    """
    pieces = [matrix.flatten() for matrix in (A, C, Q, R)]
    return jnp.concatenate(pieces)


# TODO: Add custom starting position.
def _rinit(
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t0: InitialTimeFloat,
):
    """Draw the initial latent state.

    The state is sampled from a zero-mean bivariate normal whose covariance
    is the process-noise matrix ``Q`` unpacked from ``theta_``. ``covars``
    and ``t0`` are accepted for interface compatibility but not used.

    Returns
    -------
    dict
        ``{"X1": ..., "X2": ...}`` holding the two state components.
    """
    _, _, Q, _ = _get_thetas(theta_)
    # The zero mean is hard-coded; see the TODO above about a custom start.
    x0 = jax.random.multivariate_normal(key=key, mean=jnp.array([0, 0]), cov=Q)
    return dict(X1=x0[0], X2=x0[1])


def _rproc(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
    dt: StepSizeFloat,
):
    """Advance the latent state by one transition.

    Samples ``X_next ~ N(A @ X, Q)``, with ``A`` and ``Q`` unpacked from
    ``theta_``. ``covars``, ``t`` and ``dt`` are not used, so each call
    applies exactly one transition regardless of ``dt``.

    Returns
    -------
    dict
        ``{"X1": ..., "X2": ...}`` for the new state.
    """
    A, _, Q, _ = _get_thetas(theta_)
    x_prev = jnp.array([X_["X1"], X_["X2"]])
    x_next = jax.random.multivariate_normal(key=key, mean=A @ x_prev, cov=Q)
    return dict(X1=x_next[0], X2=x_next[1])


def _dmeas(
    Y_: ObservationDict,
    X_: StateDict,
    theta_: ParamDict,
    covars: CovarDict,
    t: TimeFloat,
):
    """Log-density of an observation given the latent state.

    The measurement model is ``Y ~ N(C @ X, R)``, the same model that
    ``_rmeas`` samples from, with ``C`` and ``R`` unpacked from ``theta_``.
    ``covars`` and ``t`` are not used.

    Returns
    -------
    jax.Array
        Scalar log-density of ``(Y1, Y2)`` under that normal distribution.
    """
    _, C, _, R = _get_thetas(theta_)
    X_array = jnp.array([X_["X1"], X_["X2"]])
    Y_array = jnp.array([Y_["Y1"], Y_["Y2"]])
    # The mean must be C @ X (not X) so the density matches the simulator in
    # _rmeas; the two only coincide when C is the identity.
    return jax.scipy.stats.multivariate_normal.logpdf(Y_array, C @ X_array, R)


def _rmeas(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
):
    """Simulate an observation given the latent state.

    Samples ``Y ~ N(C @ X, R)``, with ``C`` and ``R`` unpacked from
    ``theta_``. ``covars`` and ``t`` are not used.

    Returns
    -------
    jax.Array
        A length-2 array (an array, not a dict). The entries are presumably
        ordered as the ``Y1, Y2`` observation columns; confirm against Pomp.
    """
    _, C, _, R = _get_thetas(theta_)
    state = jnp.array([X_["X1"], X_["X2"]])
    mean_obs = C @ state
    return jax.random.multivariate_normal(key=key, mean=mean_obs, cov=R)


def _to_est(theta: ParamDict) -> ParamDict:
    """Map natural-scale parameters to the estimation scale.

    The diagonal entries (suffixes 1 and 4) of A, C, Q and R are replaced by
    their natural log. All other entries are copied unchanged. A new dict is
    returned and ``theta`` is not modified. A missing diagonal key raises
    ``KeyError``.

    NOTE(review): ``jnp.log`` of a non-positive entry gives NaN or -inf. That
    is fine for the variances in Q and R, but the diagonals of A and C may
    legitimately be negative. Confirm this transform is intended for them.
    """
    transformed = dict(theta)
    for param in (f"{mat}{idx}" for mat in "ACQR" for idx in (1, 4)):
        transformed[param] = jnp.log(theta[param])
    return transformed


def _from_est(theta: ParamDict) -> ParamDict:
    """Map estimation-scale parameters back to the natural scale.

    The diagonal entries (suffixes 1 and 4) of A, C, Q and R are
    exponentiated. All other entries are copied unchanged. A new dict is
    returned and ``theta`` is not modified. A missing diagonal key raises
    ``KeyError``. Because of the exponential, those diagonal entries always
    come back strictly positive.
    """
    restored = dict(theta)
    for param in (f"{mat}{idx}" for mat in "ACQR" for idx in (1, 4)):
        restored[param] = jnp.exp(theta[param])
    return restored


def LG(
    T: int = 4,
    A: jax.Array = jnp.array(
        [[jnp.cos(0.2), -jnp.sin(0.2)], [jnp.sin(0.2), jnp.cos(0.2)]]
    ),
    C: jax.Array = jnp.eye(2),
    Q: jax.Array = jnp.array([[1, 2e-2], [2e-2, 1]]) / 100,
    R: jax.Array = jnp.array([[1, 0.1], [0.1, 1]]) / 10,
    key: jax.Array = jax.random.key(111),
) -> Pomp:
    """
    Initialize a Pomp object with the linear Gaussian model.

    The latent state evolves as ``X_t ~ N(A @ X_{t-1}, Q)`` starting from
    ``X_0 ~ N(0, Q)``. Observations follow ``Y_t ~ N(C @ X_t, R)``. One
    realization is simulated and stored as the returned object's data.

    Parameters
    ----------
    T : int, optional
        The number of observation times. Observations are placed at times
        1, 2, ..., T with initial time t0 = 0. Defaults to 4.
    A : jax.Array, optional
        The 2x2 transition matrix. Defaults to a rotation by 0.2 radians,
        [[cos(0.2), -sin(0.2)], [sin(0.2), cos(0.2)]].
    C : jax.Array, optional
        The 2x2 measurement matrix. Defaults to the identity matrix.
    Q : jax.Array, optional
        The 2x2 covariance matrix of the state noise. Defaults to
        [[1, 0.02], [0.02, 1]] / 100.
    R : jax.Array, optional
        The 2x2 covariance matrix of the measurement noise. Defaults to
        [[1, 0.1], [0.1, 1]] / 10.
    key : jax.Array, optional
        The random key used to generate the data. Defaults to
        jax.random.key(111).

    Returns
    -------
    Pomp
        A Pomp object with parameters ``A1..A4, C1..C4, Q1..Q4, R1..R4``,
        taken row-major from the matrices, and one simulated data set.
        Its parameter transformation log-transforms the diagonal entries
        (suffixes 1 and 4) of each matrix for estimation.

    Raises
    ------
    ValueError
        If any of A, C, Q, R is not of shape (2, 2).
    TypeError
        If simulation does not return a Pomp object.
    """
    # Accept anything array-like (e.g. nested lists or NumPy arrays).
    A, C, Q, R = (jnp.asarray(M) for M in (A, C, Q, R))
    # The model is hard-wired to a 2-D state and observation. A wrong shape
    # would otherwise be silently truncated or padded by the zip() below.
    for label, M in zip("ACQR", (A, C, Q, R)):
        if M.shape != (2, 2):
            raise ValueError(f"{label} must have shape (2, 2); got {M.shape}")

    theta_names = [f"{name}{i}" for name in "ACQR" for i in range(1, 5)]
    theta = dict(zip(theta_names, _transform_thetas(A, C, Q, R).tolist()))
    # Placeholder observations; they are replaced by the simulation below.
    ys_temp = pd.DataFrame(
        0, index=np.arange(1, T + 1, dtype=float), columns=pd.Index(["Y1", "Y2"])
    )
    LG_obj_temp = Pomp(
        rinit=_rinit,
        rproc=_rproc,
        dmeas=_dmeas,
        rmeas=_rmeas,
        ys=ys_temp,
        t0=0.0,
        nstep=1,
        dt=None,
        theta=theta,
        covars=None,
        statenames=["X1", "X2"],
        par_trans=ParTrans(to_est=_to_est, from_est=_from_est),
    )
    LG_obj = LG_obj_temp.simulate(key=key, nsim=1, as_pomp=True)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(LG_obj, Pomp):
        raise TypeError(
            f"simulate(as_pomp=True) returned {type(LG_obj).__name__}, expected Pomp"
        )
    return LG_obj