# Source code for pypomp.random.poisson

"""
JAX implementation of the single-precision inverse Poisson CDF approximation.

The implementation ports NVIDIA's CURAND `poissinvf` CUDA device routine to
Python so it can be composed with `jax.jit`/`jax.vmap`.  The structure matches
the original algorithm: central-region polynomial approximation, Newton
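iteration fallback, and a final bottom-up / top-down summation when the rate
or the approximated quantile is small.  Unlike a data-dependent `while` loop,
every iterative step here is unrolled with a fixed iteration cap.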
"""

from __future__ import annotations

from typing import Tuple, Any
from functools import partial

import jax
from jax import Array, lax
import jax.numpy as jnp
from jax.scipy import special as jsp_special
import numpy as np
from jax._src import dtypes

from ._dtype_helpers import check_and_canonicalize_user_dtype, _get_available_dtype


# Polynomial coefficient tables consumed by `_central_region`.
#
# Each table is passed to `jnp.polyval`, so entries are ordered from the
# highest-degree coefficient down to the constant term.
# NOTE(review): the values are presumably transcribed from the CUDA reference
# routine named in the module docstring -- verify against it before editing.

# Polynomial in the standardized deviate `s`; `_central_region` combines the
# result as `rm = s + s * (poly(s) * s)`.
_RM_COEFFS: Tuple[float, ...] = (
    2.82298751e-07,
    -2.58136133e-06,
    1.02118025e-05,
    -2.37996199e-05,
    4.05347462e-05,
    -6.63730967e-05,
    1.24762566e-04,
    -2.56970731e-04,
    5.58953132e-04,
    -1.33129194e-03,
    3.70367937e-03,
    -1.38888706e-02,
    1.66666667e-01,
)

# Polynomial in `rm`; its value is added directly to the quantile estimate.
_T_COEFFS: Tuple[float, ...] = (
    1.86386867e-05,
    -2.07319499e-04,
    9.68945100e-04,
    -2.47340054e-03,
    3.79952985e-03,
    -3.86717047e-03,
    3.46960934e-03,
    -4.14125511e-03,
    5.86752093e-03,
    -8.38583787e-03,
    1.32793933e-02,
    -2.77755360e-02,
    3.33333333e-01,
)

# Polynomial in `rm`; its value is divided by `lam` before being added to the
# quantile estimate.
_X_COEFFS: Tuple[float, ...] = (
    -1.45852240e-04,
    1.46121529e-03,
    -6.10328845e-03,
    1.38117964e-02,
    -1.86988746e-02,
    1.68155118e-02,
    -1.33947970e-02,
    1.35698573e-02,
    -1.55377333e-02,
    1.74065334e-02,
    -1.98011178e-02,
)

# The coefficient arrays are materialized once, at import time, as float32
# JAX arrays (kept at float32 for efficiency).  `_central_region` casts them
# to the working dtype on every call, so a float64 computation still starts
# from the single-precision coefficient values.
_RM_COEFFS_ARR = jnp.array(_RM_COEFFS, dtype=jnp.float32)
_T_COEFFS_ARR = jnp.array(_T_COEFFS, dtype=jnp.float32)
_X_COEFFS_ARR = jnp.array(_X_COEFFS, dtype=jnp.float32)


def _central_region(s: Array, lam: Array, dtype) -> Array:
    """
    Evaluate the central-region polynomial approximation of the quantile.

    Args:
        s: Standardized deviate; the caller passes
            `ndtri(u) * rsqrt(lam)`.
        lam: Poisson rate used to scale the approximation.
        dtype: Working dtype the float32 coefficient tables are cast to.

    Returns:
        `floor(lam + (x + t) + lam * rm)`, where `rm`, `t` and `x` are the
        polynomial corrections computed below.
    """

    def _poly(table: Array, arg: Array) -> Array:
        # The tables are stored as float32; evaluate in the working dtype.
        return jnp.polyval(table.astype(dtype), arg)

    # rm = s + s^2 * P_rm(s)
    rm_series = _poly(_RM_COEFFS_ARR, s)
    rm = s + s * (rm_series * s)

    # Both remaining corrections are polynomials in rm; the second one is
    # additionally scaled by 1 / lam.
    t_corr = _poly(_T_COEFFS_ARR, rm)
    x_corr = _poly(_X_COEFFS_ARR, rm) / lam

    estimate = lam + (x_corr + t_corr) + lam * rm
    return jnp.floor(estimate)


def _newton_region(s: Array, lam: Array, dtype) -> Array:
    MAX_LOOPS = 5
    r = jnp.maximum(0.1, 1.0 + s)
    r_prev = r
    first = jnp.array(True, dtype=jnp.bool_)
    counter = 0

    for _ in range(MAX_LOOPS):
        diff = jnp.abs(r - r_prev)
        not_done = jnp.logical_or(first, diff > 1e-5)
        not_max_loops = counter < MAX_LOOPS
        keep_going = jnp.logical_and(not_done, not_max_loops)

        t = jnp.log(r)
        s2 = jnp.sqrt(2.0 * ((1.0 - r) + r * t))
        s2 = jnp.where(r < 1.0, -s2, s2)
        next_r = r - (s2 - s) * s2 / t
        next_r = jnp.maximum(next_r, 0.1 * r)

        # Only update variables if condition is True
        r_new = jnp.where(keep_going, next_r, r)
        r_prev_new = jnp.where(keep_going, r, r_prev)
        first_new = jnp.array(False, dtype=jnp.bool_)
        counter_new = counter + 1

        r, r_prev, first, counter = r_new, r_prev_new, first_new, counter_new

    t = jnp.log(r)
    sqrt_term = jnp.sqrt(2.0 * r * ((1.0 - r) + r * t))
    log_correction = jnp.log(
        sqrt_term / jnp.maximum(jnp.abs(r - 1.0), jnp.finfo(dtype).tiny)
    )
    x = lam * r + log_correction / t
    x -= 0.0218 / (x + 0.065 * lam)
    return jnp.floor(x)


def _bottom_up(u: Array, lam: Array, dtype) -> Array:
    lami = 1.0 / lam

    t0 = jnp.exp(0.5 * lam)
    del0 = jnp.where(u > 0.5, t0 * (1e-6 * t0), 0.0)
    s0 = 1.0 - t0 * (u * t0) + del0

    def unrolled_computation(
        x_init: Array, s0: Array, del0: Array, lami: Array
    ) -> Tuple[Array, Array, Array]:
        MAX_LOOPS = 20

        # Initialize state
        x: Array = x_init
        s: Array = s0
        delta: Array = del0
        t = jnp.array(0.0, dtype=dtype)
        zero = jnp.array(0.0, dtype=dtype)
        one = jnp.array(1.0, dtype=dtype)

        # Track if we are still running (equivalent to cond1)
        active = jnp.array(True)

        # JAX will unroll this loop during compilation
        for _ in range(MAX_LOOPS):
            current_cond = s < zero

            # Determine if we should update in this step
            # We continue only if we were already active AND the condition holds
            keep_going = jnp.logical_and(active, current_cond)

            # Calculate candidates for next step
            x_next: Array = x + one
            t_next: Array = x_next * lami
            delta_next: Array = t_next * delta
            s_next: Array = t_next * s + one

            # Apply updates only if keep_going is True
            x = jnp.where(keep_going, x_next, x)
            s = jnp.where(keep_going, s_next, s)
            delta = jnp.where(keep_going, delta_next, delta)
            t = jnp.where(keep_going, t_next, t)

            # Update the active flag (once it turns False, it stays False)
            active = keep_going

        return x, s, delta

    x_init = jnp.array(0.0, dtype=dtype)
    x, s, delta = unrolled_computation(x_init, s0, del0, lami)

    def top_down_branch(state: Tuple[Array, Array]) -> Array:
        x_val, delta_val = state
        one = jnp.array(1.0, dtype=dtype)
        zero = jnp.array(0.0, dtype=dtype)

        # Setup
        delta_scaled = jnp.array(1e6, dtype=dtype) * delta_val
        t_thresh = jnp.array(1e7, dtype=dtype) * delta_scaled
        delta_scaled = (one - u) * delta_scaled

        # Unrolled first loop (finding x_hi, delta_hi)
        MAX_LOOPS_2 = 20
        x_hi: Array = x_val
        delta_hi: Array = delta_scaled
        for _ in range(MAX_LOOPS_2):
            cond: Array = delta_hi < t_thresh
            x_next: Array = x_hi + one
            delta_next: Array = delta_hi * (x_next * lami)
            x_hi = jnp.where(cond, x_next, x_hi)
            delta_hi = jnp.where(cond, delta_next, delta_hi)

        # Unrolled second loop (finding x_lo)
        MAX_LOOPS_3 = 20
        x_lo: Array = x_hi
        s_lo: Array = delta_hi
        t_lo: Array = one
        for _ in range(MAX_LOOPS_3):
            cond: Array = s_lo > zero
            t_next: Array = t_lo * (x_lo * lami)
            s_next: Array = s_lo - t_next
            x_next: Array = x_lo - one
            x_lo = jnp.where(cond, x_next, x_lo)
            s_lo = jnp.where(cond, s_next, s_lo)
            t_lo = jnp.where(cond, t_next, t_lo)
        return x_lo

    two = jnp.array(2.0, dtype=dtype)
    return lax.cond(
        s < two * delta,
        top_down_branch,
        lambda state: state[0],
        operand=(x, delta),
    )


def _poissoninv_scalar(u: Array, lam: Array, dtype) -> Array:
    """
    Approximate the inverse Poisson CDF for a single `(u, lam)` pair.

    Dispatch:
        * `lam > 4`: compute `s = ndtri(u) * rsqrt(lam)` and use the
          central-region polynomial when `-0.6833501 < s < 1.777993`, the
          Newton iteration when `s > -sqrt(2)` otherwise, and 0 below that.
        * If the estimate so far is `<= 10` (always the case for
          `lam <= 4`), it is replaced by the direct summation result.

    Special cases, applied after the main computation:
        * `lam < 0`                -> NaN
        * `u < 0` or `u > 1`       -> NaN
        * `lam == 0`               -> 0 (Poisson(0) is a point mass at zero)
        * `u == 0`                 -> 0
        * `u == 1` (and `lam > 0`) -> +inf
        * any negative result      -> 0

    Args:
        u: Probability.
        lam: Poisson rate.
        dtype: Working floating-point dtype; both inputs are cast to it.

    Returns:
        The quantile as a scalar of the working dtype.
    """
    u = jnp.asarray(u, dtype=dtype)
    lam = jnp.asarray(lam, dtype=dtype)

    zero = jnp.array(0.0, dtype=dtype)
    one = jnp.array(1.0, dtype=dtype)
    x0 = zero
    sqrt2 = jnp.sqrt(jnp.array(2.0, dtype=dtype))

    # A zero rate is a valid (degenerate) Poisson distribution; only negative
    # rates are invalid.  Both are swapped for a harmless placeholder so the
    # arithmetic below never divides by zero or takes rsqrt of a negative.
    lam_negative = lam < zero
    lam_zero = lam == zero
    lam_safe = jnp.where(lam <= zero, one, lam)

    def large_lambda_case(_: Any) -> Array:
        s = jsp_special.ndtri(u) * lax.rsqrt(lam_safe)

        def central(_: Any) -> Array:
            return _central_region(s, lam_safe, dtype)

        def non_central(_: Any) -> Array:
            return lax.cond(
                s > -sqrt2,
                lambda __: _newton_region(s, lam_safe, dtype),
                lambda __: x0,
                operand=zero,
            )

        in_central_band = jnp.logical_and(
            s > jnp.array(-0.6833501, dtype=dtype),
            s < jnp.array(1.777993, dtype=dtype),
        )
        return lax.cond(in_central_band, central, non_central, operand=zero)

    large_lambda = lam_safe > jnp.array(4.0, dtype=dtype)
    x_large: Array = lax.cond(
        large_lambda,
        large_lambda_case,
        lambda _: x0,
        operand=zero,
    )

    def bottom_up_branch(_: Any) -> Array:
        return _bottom_up(u, lam_safe, dtype)

    # Small estimates (including the x0 placeholder) are recomputed by
    # direct summation.
    bottom_up = x_large <= jnp.array(10.0, dtype=dtype)
    x: Array = lax.cond(
        bottom_up,
        bottom_up_branch,
        lambda _: x_large,
        operand=zero,
    )

    nan = jnp.array(jnp.nan, dtype=dtype)
    inf = jnp.array(jnp.inf, dtype=dtype)

    # Later overrides win, so the order encodes precedence:
    # invalid rate > invalid probability > zero rate > boundary probabilities.
    x = jnp.where(u == zero, zero, x)
    x = jnp.where(u == one, inf, x)
    x = jnp.where(lam_zero, zero, x)
    x = jnp.where(u < zero, nan, x)
    x = jnp.where(u > one, nan, x)
    x = jnp.where(lam_negative, nan, x)
    x = jnp.where(x < zero, zero, x)
    return x


# Batched form of the scalar kernel: maps over axis 0 of `u` and `lam` in
# lockstep, while the third positional argument (`dtype`) is passed through
# unbatched (`in_axes` entry `None`).
_poissoninv_vmap = jax.vmap(_poissoninv_scalar, in_axes=(0, 0, None))


@partial(jax.jit, static_argnames=["dtype"])
def poissoninv(u: Array, lam: Array, dtype=jnp.float32) -> Array:
    """
    Approximate the inverse Poisson CDF elementwise.

    `u` and `lam` are broadcast against each other, flattened, pushed through
    the vectorized scalar kernel, and reshaped back to the broadcast shape.

    Args:
        u: Probabilities (scalar or array) in the interval [0, 1].
        lam: Corresponding Poisson rate(s), must be positive.
        dtype: Floating-point dtype used for the computation (default
            float32).  It is a static argument of the jitted function.

    Returns:
        A `jax.Array` whose shape is the broadcast shape of `u` and `lam`.
    """
    u_bcast, lam_bcast = jnp.broadcast_arrays(u, lam)
    out_shape = u_bcast.shape

    # The kernel is vmapped over a single leading axis, so work on 1-D views.
    flat_out = _poissoninv_vmap(
        u_bcast.reshape(-1),
        lam_bcast.reshape(-1),
        dtype,
    )
    return flat_out.reshape(out_shape)


@partial(jax.jit, static_argnames=["dtype"])
def fast_poisson(key: Array, lam: Array, dtype: np.dtype | None = None) -> Array:
    """
    Generate a Poisson random variable with given rate parameter using an
    approximate inverse CDF method in order to run fast on GPUs.

    Follows the methodology from Giles (2016). We made some ad-hoc
    modifications to the algorithm to improve the speed. In particular, we put
    a cap on how many iterations the Newton-Raphson method and the exact
    inverse CDF method can take, and we adjusted the thresholds for applying
    the exact inverse CDF method. Our implementation of the method does not
    produce exact Poisson random variables, but it is very close to exact.

    Args:
        key: a PRNG key used as the random key.
        lam: rate parameters for the Poisson distribution.
        dtype: optional, an integer dtype for the returned values (default
            int64 if jax_enable_x64 is true, otherwise int32).

    Returns:
        A Poisson random variable with the same shape as `lam`.

    Raises:
        ValueError: if `dtype` is not an integer dtype.

    References:
        * Giles, Michael B. "Algorithm 955: Approximation of the Inverse
          Poisson Cumulative Distribution Function." ACM Transactions on
          Mathematical Software 42, no. 1 (2016): 1–22.
          https://doi.org/10.1145/2699466.
    """
    dtype = check_and_canonicalize_user_dtype(int if dtype is None else dtype)
    assert dtype is not None
    if not dtypes.issubdtype(dtype, np.integer):
        raise ValueError(
            f"dtype argument to `fast_poisson` must be an integer dtype, got {dtype}"
        )
    # NOTE(review): `_get_available_dtype` presumably maps a 64-bit dtype to
    # its 32-bit counterpart when x64 is disabled -- confirm in
    # `_dtype_helpers`.
    dtype = _get_available_dtype(dtype)
    assert dtype is not None

    # Pair the integer output width with a float working precision.
    if dtypes.issubdtype(dtype, np.int64):
        float_dtype = jnp.float64
    else:
        float_dtype = jnp.float32
    float_dtype = _get_available_dtype(float_dtype)
    assert float_dtype is not None

    lam = jnp.asarray(lam)
    shape = lam.shape
    u = jax.random.uniform(key, shape, dtype=float_dtype)

    # Clamp u to be slightly less than 1.0 to avoid inf output
    # Use nextafter to get the largest float < 1.0
    u_max = jnp.nextafter(
        jnp.array(1.0, dtype=float_dtype), jnp.array(0.0, dtype=float_dtype)
    )
    u = jnp.minimum(u, u_max)

    lam_float = lam.astype(float_dtype)
    x = poissoninv(u, lam_float, dtype=float_dtype)

    # Cap the output to a reasonable maximum to prevent overflow
    max_val = lam_float + jnp.array(10.0, dtype=float_dtype) * jnp.sqrt(
        jnp.maximum(lam_float, jnp.array(1.0, dtype=float_dtype))
    )
    x = jnp.minimum(x, max_val)
    return x.astype(dtype)