pypomp.core.learning_rate.LearningRate.to_array

LearningRate.to_array(param_names: list[str], M: int) → Array

Convert the learning rates into a JAX array of shape (M, n_params), where n_params = len(param_names).

Parameters:
  • param_names (list[str]) – List of parameter names in canonical order.

  • M (int) – Number of iterations in the training schedule.

Returns:

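A 2D array in which each column holds the learning-rate schedule for one parameter (column j corresponds to param_names[j]) and row m holds the learning rates used at iteration m.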

Return type:

jax.Array
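The following is a minimal sketch of the documented conversion, not pypomp's implementation. It assumes a hypothetical helper, to_array_sketch, and a hypothetical rates dict that maps each parameter name to either a scalar (a constant rate) or a length-M sequence (a per-iteration schedule). How LearningRate actually stores its rates is not described on this page.

```python
import jax.numpy as jnp


def to_array_sketch(rates: dict, param_names: list[str], M: int):
    """Hypothetical stand-in for LearningRate.to_array (illustration only).

    Assumes `rates` maps each parameter name to either a scalar
    (constant learning rate) or a length-M sequence (a schedule).
    """
    columns = []
    for name in param_names:
        r = jnp.asarray(rates[name], dtype=jnp.float32)
        if r.ndim == 0:
            # Scalar rate: repeat the same value at every iteration.
            r = jnp.full((M,), r)
        elif r.shape != (M,):
            raise ValueError(
                f"schedule for {name!r} must have length {M}, got shape {r.shape}"
            )
        columns.append(r)
    # Stack the per-parameter schedules as columns, in param_names order,
    # giving an array of shape (M, n_params).
    return jnp.stack(columns, axis=1)


# Example: one constant rate and one decaying schedule over M = 3 iterations.
lr = to_array_sketch({"beta": 0.1, "sigma": [0.5, 0.25, 0.125]}, ["beta", "sigma"], M=3)
print(lr.shape)  # (3, 2)
```

Stacking along axis=1 is what makes each column one parameter's schedule. As a result, a row such as lr[m] holds every parameter's rate at iteration m, which is the slice a per-iteration update would multiply elementwise against a gradient vector ordered the same way as param_names.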