"""Source code for pypomp.models.dacca: the Dhaka (Dacca) cholera POMP model."""

import os
import csv
import jax
import jax.numpy as jnp
import pandas as pd
from pypomp.core.pomp import Pomp
import jax.scipy.special as jspecial
import numpy as np
from pypomp.core.par_trans import ParTrans
from pypomp.types import (
    StateDict,
    ParamDict,
    CovarDict,
    TimeFloat,
    StepSizeFloat,
    InitialTimeFloat,
    RNGKey,
    ObservationDict,
)

# Default parameter values for the Dhaka cholera model, on the natural
# (untransformed) scale. Key order is significant only in that it is preserved
# from the original definition.
theta = {
    "gamma": 20.8,  # recovery rate
    "epsilon": 19.1,  # rate of waning of immunity for severe infections
    "rho": 0.0,  # rate of waning of immunity for inapparent infections
    "m": 0.06,  # cholera mortality rate
    "c": 1.0,  # fraction of infections that lead to severe infection
    "beta_trend": -0.00498,  # slope of secular trend in transmission
    # seasonal transmission rates bs1..bs6
    **dict(
        zip(
            (f"bs{k}" for k in range(1, 7)),
            (0.747, 6.38, -3.44, 4.23, 3.33, 4.55),
        )
    ),
    "sigma": 3.13,  # environmental noise intensity
    "tau": 0.23,  # measurement error s.d.
    "alpha": 1.0,  # non-linear transmission parameter
    "delta": 0.02,  # mortality rate
    # initial-value parameters (relative compartment sizes at t0)
    "S_0": 0.621,
    "I_0": 0.378,
    "Y_0": 0.0,
    "R1_0": 0.000843,
    "R2_0": 0.000972,
    "R3_0": 1.16e-07,
    # seasonal environmental reservoir parameters omegas1..omegas6: the logs
    # are evaluated through JAX and then stored as plain Python floats
    **{
        f"omegas{k}": float(log_rate)
        for k, log_rate in enumerate(
            jnp.log(jnp.array([0.184, 0.0786, 0.0584, 0.00917, 0.000208, 0.0124])),
            start=1,
        )
    },
}


# Locate the bundled Dhaka CSV files relative to the directory that holds this
# module: <module dir>/../data/dacca/{dacca,covars,covart}.csv
test_dir = os.path.split(os.path.abspath(__file__))[0]
data_dir = os.path.join(test_dir, os.pardir, "data/dacca")

dacca_path, covars_path, covart_path = (
    os.path.join(data_dir, csv_name)
    for csv_name in ("dacca.csv", "covars.csv", "covart.csv")
)

# Observations: the first row of dacca.csv is skipped (presumably a header);
# in every remaining row, field 1 becomes the time index and field 2 the
# "deaths" value. Field 0 is ignored.
with open(dacca_path, "r") as f:
    reader = csv.reader(f)
    next(reader)
    data = [(float(fields[1]), float(fields[2])) for fields in reader]

# Split the (time, value) pairs into two parallel tuples and build the
# single-column observation frame indexed by time.
times, values = zip(*data)
ys = pd.DataFrame(values, index=pd.Index(times), columns=pd.Index(["deaths"]))

# Covariate times: the first row of covart.csv is skipped (presumably a
# header); field 1 of every remaining row is the covariate time.
with open(covart_path, "r") as f:
    reader = csv.reader(f)
    next(reader)
    covart_index_list = [float(row[1]) for row in reader]

# JAX copy of the covariate times, kept for any code that imports it.
covart_index_arr = jnp.array(covart_index_list)

# Covariate values: skip the first row, drop field 0 of each row, and parse the
# remaining nine fields as floats.
with open(covars_path, "r") as f:
    reader = csv.reader(f)
    next(reader)
    covars_data = [[float(value) for value in row[1:]] for row in reader]

covars = pd.DataFrame(
    covars_data,
    # Build the index from the parsed Python floats rather than from the JAX
    # array: JAX arrays default to float32, which would round times near 1890
    # to roughly 1e-4 years and make the index dtype depend on the JAX x64
    # flag. The observation index (ys) is float64, so this keeps both
    # consistent.
    index=np.array(covart_index_list, dtype=float),
    columns=pd.Index(
        [
            "trend",
            "dpopdt",
            "pop",
            "seas1",
            "seas2",
            "seas3",
            "seas4",
            "seas5",
            "seas6",
        ]
    ),
)

# Names of all model parameters: ten rate/shape parameters, six initial-value
# parameters, then the six seasonal transmission coefficients and the six
# seasonal reservoir coefficients.
theta_names = [
    "gamma",
    "m",
    "rho",
    "epsilon",
    "c",
    "beta_trend",
    "sigma",
    "tau",
    "alpha",
    "delta",
    "S_0",
    "I_0",
    "Y_0",
    "R1_0",
    "R2_0",
    "R3_0",
    *(f"bs{i}" for i in range(1, 7)),
    *(f"omegas{i}" for i in range(1, 7)),
]


# State variables produced by the initializer and the process simulators.
statenames = ["S", "I", "Y", "Mn", "R1", "R2", "R3", "count"]
# State variables passed to Pomp as accumulator variables.
accumvars = ["Mn"]


def _rinit(theta_: ParamDict, key: RNGKey, covars: CovarDict, t0: InitialTimeFloat):
    S_0 = theta_["S_0"]
    I_0 = theta_["I_0"]
    Y_0 = theta_["Y_0"]
    R0 = jnp.array([theta_[f"R{i}_0"] for i in range(1, 4)])

    total_sum = S_0 + I_0 + Y_0 + jnp.sum(R0)
    pop = covars["pop"]

    S = pop * S_0 / total_sum
    I = pop * I_0 / total_sum
    Y = pop * Y_0 / total_sum
    R = pop * R0 / total_sum

    Mn = 0.0
    count = 0.0
    return {
        "S": S,
        "I": I,
        "Y": Y,
        "Mn": Mn,
        "R1": R[0],
        "R2": R[1],
        "R3": R[2],
        "count": count,
    }


def _rproc(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
    dt: StepSizeFloat,
):
    S = X_["S"]
    I = X_["I"]
    Y = X_["Y"]
    deaths = X_["Mn"]
    pts = jnp.array([X_["R1"], X_["R2"], X_["R3"]])
    count = X_["count"]
    trend = covars["trend"]
    dpopdt = covars["dpopdt"]
    pop = covars["pop"]
    seas = jnp.array([covars[f"seas{i}"] for i in range(1, 7)])
    gamma = theta_["gamma"]
    deltaI = theta_["m"]
    rho = theta_["rho"]
    eps = theta_["epsilon"]
    clin = theta_["c"]
    beta_trend = theta_["beta_trend"]
    sd_beta = theta_["sigma"]
    alpha = theta_["alpha"]
    delta = theta_["delta"]
    omegas = jnp.array([theta_[f"omegas{i}"] for i in range(1, 7)])
    bs = jnp.array([theta_[f"bs{i}"] for i in range(1, 7)])

    nrstage = 3
    std = jnp.sqrt(dt)

    neps = eps * nrstage  # rate
    passages = jnp.zeros(nrstage + 1)

    # Get current time step values
    beta = jnp.exp(beta_trend * trend + jnp.dot(bs, seas))
    omega = jnp.exp(jnp.dot(omegas, seas))

    subkey, key = jax.random.split(key)
    dw = jax.random.normal(subkey) * std

    effI = (I / pop) ** alpha
    births = dpopdt + delta * pop
    passages = passages.at[0].set(gamma * I)
    ideaths = delta * I
    disease = deltaI * I
    ydeaths = delta * Y
    wanings = rho * Y

    rdeaths = pts * delta
    passages = passages.at[1:].set(pts * neps)

    infections = (omega + (beta + sd_beta * dw / dt) * effI) * S
    sdeaths = delta * S

    S += (births - infections - sdeaths + passages[nrstage] + wanings) * dt
    I += (clin * infections - disease - ideaths - passages[0]) * dt
    Y += ((1 - clin) * infections - ydeaths - wanings) * dt

    pts = pts + (passages[:-1] - passages[1:] - rdeaths) * dt

    deaths = deaths + disease * dt

    count = count + jnp.any(jnp.hstack([jnp.array([S, I, Y, deaths]), pts]) < 0)

    S = jnp.clip(S, 0)
    I = jnp.clip(I, 0)
    Y = jnp.clip(Y, 0)
    pts = jnp.clip(pts, 0)
    deaths = jnp.clip(deaths, 0)

    return {
        "S": S,
        "I": I,
        "Y": Y,
        "Mn": deaths,
        "R1": pts[0],
        "R2": pts[1],
        "R3": pts[2],
        "count": count,
    }


def _rproc_gamma(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
    dt: StepSizeFloat,
):
    S = X_["S"]
    I = X_["I"]
    Y = X_["Y"]
    deaths = X_["Mn"]
    pts = jnp.array([X_["R1"], X_["R2"], X_["R3"]])
    count = X_["count"]
    trend = covars["trend"]
    dpopdt = covars["dpopdt"]
    pop = covars["pop"]
    seas = jnp.array([covars[f"seas{i}"] for i in range(1, 7)])
    gamma = theta_["gamma"]
    deltaI = theta_["m"]
    rho = theta_["rho"]
    eps = theta_["epsilon"]
    clin = theta_["c"]
    beta_trend = theta_["beta_trend"]
    sd_beta = theta_["sigma"]
    alpha = theta_["alpha"]
    delta = theta_["delta"]
    omegas = jnp.array([theta_[f"omegas{i}"] for i in range(1, 7)])
    bs = jnp.array([theta_[f"bs{i}"] for i in range(1, 7)])

    nrstage = 3
    # std = jnp.sqrt(dt)

    neps = eps * nrstage  # rate
    passages = jnp.zeros(nrstage + 1)

    # Get current time step values
    beta = jnp.exp(beta_trend * trend + jnp.dot(bs, seas))
    omega = jnp.exp(jnp.dot(omegas, seas))

    subkey, key = jax.random.split(key)
    # dw = jax.random.normal(subkey) * std

    effI = (I / pop) ** alpha
    births = dpopdt + delta * pop
    passages = passages.at[0].set(gamma * I)
    ideaths = delta * I
    disease = deltaI * I
    ydeaths = delta * Y
    wanings = rho * Y

    rdeaths = pts * delta
    passages = passages.at[1:].set(pts * neps)

    """
    # old code: perturb = sd_beta * dw / dt, where dw is a standard normal
        rproc does the above
    # this function draws from a gamma white noise process 
            Gamma(shape=dt/sigma**2, scale=sigma**2)
    # with gamma noise, want the mean to be dt, 
            and the variance to be sd_beta**2 * dt,
            before dividing by dt to yield multiplicative noise by 1
    """

    perturb = jax.random.gamma(subkey, dt / sd_beta**2) * sd_beta**2 / dt
    infections = (omega + beta * perturb * effI) * S

    sdeaths = delta * S

    S += (births - infections - sdeaths + passages[nrstage] + wanings) * dt
    I += (clin * infections - disease - ideaths - passages[0]) * dt
    Y += ((1 - clin) * infections - ydeaths - wanings) * dt

    pts = pts + (passages[:-1] - passages[1:] - rdeaths) * dt

    deaths = deaths + disease * dt

    count = count + jnp.any(jnp.hstack([jnp.array([S, I, Y, deaths]), pts]) < 0)

    S = jnp.clip(S, 0)
    I = jnp.clip(I, 0)
    Y = jnp.clip(Y, 0)
    pts = jnp.clip(pts, 0)
    deaths = jnp.clip(deaths, 0)

    return {
        "S": S,
        "I": I,
        "Y": Y,
        "Mn": deaths,
        "R1": pts[0],
        "R2": pts[1],
        "R3": pts[2],
        "count": count,
    }


def _dmeas_helper(y, deaths, v, tol, ltol):
    return jnp.logaddexp(
        jax.scipy.stats.norm.logpdf(y, loc=deaths, scale=v + tol), ltol
    ).reshape(-1)


def _dmeas_helper_tol(y, deaths, v, tol, ltol):
    return jnp.array([ltol])


def _dmeas(
    Y_: ObservationDict,
    X_: StateDict,
    theta_: ParamDict,
    covars: CovarDict,
    t: TimeFloat,
):
    """Log-density of the observed deaths given the state, as a scalar.

    The observation ``Y_["deaths"]`` is modelled as normal with mean
    ``X_["Mn"]`` and standard deviation ``tau * X_["Mn"]``, floored at a
    tolerance of 1e-18. If that standard deviation is not finite, or the state
    has recorded a violation (``count > 0``), the log of the tolerance is
    returned instead.

    ``covars`` and ``t`` are accepted but not used.
    """
    tol = 1.0e-18
    ltol = jnp.log(tol)

    deaths = X_["Mn"]
    v = theta_["tau"] * deaths  # measurement s.d. scales with the mean
    y = Y_["deaths"]

    # Fall back to the tolerance when the s.d. is non-finite or the process
    # step has flagged a negative state.
    unusable = jnp.logical_or(jnp.logical_not(jnp.isfinite(v)), X_["count"] > 0)
    result = jax.lax.cond(
        unusable,
        _dmeas_helper_tol,
        _dmeas_helper,
        y,
        deaths,
        v,
        tol,
        ltol,
    )
    # Both branches return a length-1 array; collapse it to a 0-d value.
    return jnp.reshape(result, ())


def _rmeas(
    X_: StateDict,
    theta_: ParamDict,
    key: RNGKey,
    covars: CovarDict,
    t: TimeFloat,
):
    deaths = X_["Mn"]
    tau = theta_["tau"]
    v = tau * deaths
    return jax.random.normal(key) * v + deaths


def _to_est(theta: ParamDict) -> ParamDict:
    IVP_list = ["S_0", "I_0", "Y_0", "R1_0", "R2_0", "R3_0"]
    IVPs = jnp.array([theta[k] for k in IVP_list])
    IVP_ests = jnp.log(IVPs / jnp.sum(IVPs))
    return {
        "gamma": jnp.log(theta["gamma"]),
        "m": jnp.log(theta["m"]),
        "rho": jnp.log(theta["rho"]),
        "epsilon": jnp.log(theta["epsilon"]),
        "c": jspecial.logit(theta["c"]),
        "beta_trend": theta["beta_trend"] * 100,
        "sigma": jnp.log(theta["sigma"]),
        "tau": jnp.log(theta["tau"]),
        "alpha": jnp.log(theta["alpha"]),
        "delta": jnp.log(theta["delta"]),
        **{k: IVP_ests[i] for i, k in enumerate(IVP_list)},
        **{f"bs{i}": theta[f"bs{i}"] for i in range(1, 7)},
        **{f"omegas{i}": theta[f"omegas{i}"] for i in range(1, 7)},
    }


def _from_est(theta: ParamDict) -> ParamDict:
    IVP_list = ["S_0", "I_0", "Y_0", "R1_0", "R2_0", "R3_0"]
    IVP_ests = jnp.exp(jnp.array([theta[k] for k in IVP_list]))
    IVPs = IVP_ests / jnp.sum(IVP_ests)
    return {
        "gamma": jnp.exp(theta["gamma"]),
        "m": jnp.exp(theta["m"]),
        "rho": jnp.exp(theta["rho"]),
        "epsilon": jnp.exp(theta["epsilon"]),
        "c": jspecial.expit(theta["c"]),
        "beta_trend": theta["beta_trend"] / 100,
        "sigma": jnp.exp(theta["sigma"]),
        "tau": jnp.exp(theta["tau"]),
        "alpha": jnp.exp(theta["alpha"]),
        "delta": jnp.exp(theta["delta"]),
        **{k: IVPs[i] for i, k in enumerate(IVP_list)},
        **{f"bs{i}": theta[f"bs{i}"] for i in range(1, 7)},
        **{f"omegas{i}": theta[f"omegas{i}"] for i in range(1, 7)},
    }


def dhaka(
    dt: float | None = 1 / 240, nstep: int | None = None, gamma: bool = False
) -> Pomp:
    """
    Creates a POMP model for the Dhaka cholera data.

    This function constructs a Partially Observed Markov Process (POMP) model
    for the Dhaka cholera dataset. The model includes a stochastic process for
    the underlying disease dynamics and a measurement model for observed
    deaths.

    Arguments
    ---------
    dt : float, optional
        Time step size for the process model. Determines the number of
        sub-steps per observation interval for the process model.
    nstep : int, optional
        Number of sub-steps per observation interval for the process model.
        If None, uses Euler discretization with the specified step size.
        nstep and dt cannot both be not None; because dt defaults to 1/240,
        pass ``dt=None`` explicitly when supplying nstep.
    gamma : bool, optional
        Indicator for whether gamma white noise should be used in place of
        Gaussian noise. This corresponds to a large-population approximation
        of an overdispersed death process. A warning is printed when True.

    Model Parameters
    ----------------
    gamma : float
        Recovery rate (duration of infectiousness).
    epsilon : float
        Rate of waning of immunity for severe infections.
    rho : float
        Rate of waning of immunity for inapparent infections.
    m : float
        Cholera-specific mortality rate.
    c : float
        Fraction of infections that lead to severe (clinically apparent)
        infection.
    beta_trend : float
        Slope of the secular trend in transmission.
    bs1-bs6 : float
        Seasonal transmission rates (B-spline coefficients).
    sigma : float
        Environmental noise intensity.
    tau : float
        Measurement error standard deviation.
    alpha : float
        Non-linear transmission parameter.
    delta : float
        Natural mortality rate.
    S_0, I_0, Y_0, R1_0, R2_0, R3_0 : float
        Initial value parameters (IVPs) for the model state proportions.
    omegas1-omegas6 : float
        Seasonal environmental reservoir parameters (B-spline coefficients
        for the log-reservoir).

    Returns
    -------
    Pomp
        A POMP model object representing the Dhaka cholera model.

    Raises
    ------
    ValueError
        If both dt and nstep are not None.
    """
    # Choose the process simulator: gamma white noise or Gaussian noise.
    rproc_func = _rproc_gamma if gamma else _rproc
    if gamma:
        print(
            "Warning: Using overdispersed gamma white noise. Ensure this is intended behavior."
        )
    if nstep is not None and dt is not None:
        raise ValueError("Cannot specify both dt and nstep")
    dhaka_obj = Pomp(
        rinit=_rinit,
        rproc=rproc_func,
        dmeas=_dmeas,
        rmeas=_rmeas,
        ys=ys,
        t0=1891.0,
        nstep=nstep,
        dt=dt,
        accumvars=accumvars,
        theta=theta,
        covars=covars,
        statenames=statenames,
        par_trans=ParTrans(to_est=_to_est, from_est=_from_est),
    )
    return dhaka_obj


# Alias so the model constructor is also available under the name ``dacca``.
dacca = dhaka