Learning Rate

class pypomp.core.learning_rate.LearningRate(rates: Mapping[str, float | list[float] | ndarray])

Bases: object

Represent the learning rate schedule for model parameters during training.

This class stores a learning rate for each parameter. Each rate can be either a constant value or a time-varying schedule (a 1D array of length M, with one value per iteration). The class also provides utility methods that generate common decay schedules: cosine, geometric, and linear decay.

Parameters:

rates (Mapping[str, Union[float, list[float], np.ndarray]]) – Learning rates keyed by parameter name. Each value can be a single float (a constant rate), or a list of floats or a 1D numpy array (a time-varying schedule).

Examples

>>> import numpy as np
>>> import pypomp as pp
>>> rates = pp.LearningRate({"beta": 0.1, "rho": 0.01})
>>> rates = pp.LearningRate({"beta": [0.1, 0.2], "rho": [0.01, 0.02]})
>>> rates = pp.LearningRate({"beta": np.array([0.1, 0.2]), "rho": np.array([0.01, 0.02])})

Attributes

rates

Dictionary mapping parameter names to learning rate values or schedules.

Methods

to_array(param_names, M)

Convert the learning rates into a JAX array of shape (M, n_params).
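A minimal sketch of what this conversion plausibly does, assuming constant rates are repeated M times and list or array schedules must already have length M. The helper `rates_to_array` is hypothetical (not part of pypomp), and it returns a NumPy array, whereas the real method returns a JAX array:

```python
import numpy as np


def rates_to_array(rates, param_names, M):
    """Hypothetical stand-in for LearningRate.to_array (not pypomp's code).

    Returns an (M, n_params) array whose column j is the schedule for
    param_names[j]. Scalars are repeated M times; lists and arrays must
    already be 1D with length M.
    """
    columns = []
    for name in param_names:
        value = np.asarray(rates[name], dtype=float)
        if value.ndim == 0:
            # Constant rate: repeat the scalar for every iteration.
            column = np.full(M, float(value))
        elif value.shape == (M,):
            # Time-varying schedule: use it as given.
            column = value
        else:
            raise ValueError(
                f"rate for {name!r} must be a scalar or have shape ({M},), "
                f"got {value.shape}"
            )
        columns.append(column)
    # Place the per-parameter schedules side by side: shape (M, n_params).
    return np.stack(columns, axis=1)


arr = rates_to_array({"beta": 0.1, "rho": [0.01, 0.02, 0.03]}, ["beta", "rho"], 3)
print(arr.shape)  # (3, 2)
```

Column order follows `param_names`, so each row holds the rates for one iteration.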

cosine_decay(final_factor, M)

Apply a cosine cooling schedule to all current rates.
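The exact cosine formula is not stated here. A common form, assumed purely for illustration, scales each rate by a factor that follows half a cosine wave from 1.0 at the first iteration down to final_factor at iteration M - 1. The helper `cosine_factors` is hypothetical:

```python
import numpy as np


def cosine_factors(final_factor, M):
    """Hypothetical cosine cooling multipliers (assumed form, not pypomp's code).

    factor_t = final_factor + (1 - final_factor) * 0.5 * (1 + cos(pi * t / (M - 1)))
    for t = 0, ..., M - 1, so factor_0 = 1.0 and factor_{M-1} = final_factor.
    """
    if M == 1:
        # A single iteration keeps the undecayed rate.
        return np.ones(1)
    t = np.arange(M)
    return final_factor + (1.0 - final_factor) * 0.5 * (1.0 + np.cos(np.pi * t / (M - 1)))


eta_0 = 0.1
# Multiply the initial rate by the cooling factors to get a length-M schedule.
schedule = eta_0 * cosine_factors(final_factor=0.1, M=5)
```

Compared with linear decay, this shape keeps the rate near its initial value for longer and flattens out again near final_factor.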

geometric_decay(decay_rate, M)

Apply a geometric decay schedule: eta_t = eta_0 * (decay_rate ^ t).
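A sketch of the stated formula eta_t = eta_0 * (decay_rate ^ t), assuming t runs over 0, ..., M - 1 so that the first iteration uses the undecayed rate. `geometric_schedule` is a hypothetical helper, not a pypomp function:

```python
import numpy as np


def geometric_schedule(eta_0, decay_rate, M):
    """Hypothetical helper: eta_t = eta_0 * decay_rate**t for t = 0, ..., M - 1."""
    t = np.arange(M)
    return eta_0 * decay_rate ** t


s = geometric_schedule(0.1, 0.5, 4)  # values: 0.1, 0.05, 0.025, 0.0125
```

Each step multiplies the previous rate by decay_rate, so a decay_rate close to 1 gives slow cooling.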

linear_decay(final_factor, M)

Apply a linear decay schedule in which the multiplier on each current rate decreases linearly from 1.0 down to final_factor.
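A sketch of the linear variant, assuming the multiplier falls evenly from 1.0 at the first iteration to final_factor at iteration M - 1 and is applied to every current rate. `linear_factors` is a hypothetical helper, not a pypomp function:

```python
import numpy as np


def linear_factors(final_factor, M):
    """Hypothetical linear decay multipliers from 1.0 down to final_factor."""
    # np.linspace includes both endpoints, giving exactly M evenly spaced values.
    return np.linspace(1.0, final_factor, M)


rates = {"beta": 0.1, "rho": 0.01}
factors = linear_factors(final_factor=0.2, M=5)
# Scale every parameter's constant rate by the same multiplier sequence.
schedules = {name: eta * factors for name, eta in rates.items()}
```

Because every parameter shares the same multiplier sequence, the ratio between any two rates stays fixed across iterations.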