Optimizers#
JAX SparseCore supports several optimizers for training embedding tables. These optimizers
are specified in the TableSpec for each table.
Introduction#
An optimizer is an algorithm that updates a model's trainable parameters, such as its weights,
and in adaptive methods the effective per-parameter learning rate, to reduce the training loss.
JAX SparseCore supports several common optimizers used for training large-scale recommendation
models. You can specify the optimizer for each embedding table via the optimizer field in the TableSpec.
Choosing an Optimizer: OptimizerSpec vs. optax#
When training a model with JAX SparseCore, you will encounter two approaches to optimization:
the ...OptimizerSpec classes provided by this library, and standard optax optimizers.
For training SparseCore embedding tables, we strongly encourage you to use the
...OptimizerSpec classes detailed in this document (e.g., SGDOptimizerSpec,
AdamOptimizerSpec). These specs configure highly efficient optimizers where the
update logic is fused directly into the backward pass on the SparseCore hardware. This avoids
costly round-trips of gradients to the host and provides the best performance.
optax is a powerful and flexible library for JAX, but in a SparseCore model its role is to
optimize any non-embedding parameters your model might have. JAX SparseCore provides a helper
function, create_optimizer_for_sc_model, which applies a given optax optimizer to your
model’s other parameters while ensuring that the specialized ...OptimizerSpec logic is used for
the embedding tables.
API#
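| OptimizerSpec | Base class for the optimizer specs. |
| AdagradMomentumOptimizerSpec | Spec for the Adagrad with Momentum optimizer. |
| AdagradOptimizerSpec | Spec for the Adagrad optimizer. |
| AdamOptimizerSpec | Spec for the Adam optimizer. |
| FTRLOptimizerSpec | Spec for the FTRL optimizer. |
| LaPropOptimizerSpec | Spec for the LaProp optimizer. |
| SGDOptimizerSpec | Spec for the Stochastic Gradient Descent (SGD) optimizer. |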
- class OptimizerSpec(learning_rate)#
Base class for the optimizer specs.
This base class defines the interface for all the optimizer specs.
- Parameters:
learning_rate (float | Callable[..., float | jax.Array])
- learning_rate#
The learning rate for the training variables or embeddings.
- get_hyperparameters(step=None)#
Returns the hyperparameters for the optimizer.
- get_learning_rate(step=None)#
Returns the learning rate for the optimizer.
- abstractmethod get_optimizer_primitive()#
Derived classes should implement this method to return the XLA primitive for the optimizer.
- Return type:
- abstractmethod short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_count()#
Returns the number of slot variables for the optimizer.
- Return type:
int
- slot_variables_initializers()#
Slot variables initializers for the optimizer.
Derived classes should implement this method to return the initializers for the applicable slot variables, if any.
- Returns:
A tuple of initializers for the slot variables.
- Return type:
tuple[Initializer, …]
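The learning_rate parameter accepts either a float or a callable, and get_learning_rate takes an optional step. The sketch below shows one way such a value could be resolved. The helper resolve_learning_rate and the schedule inverse_decay are hypothetical names for illustration only, and the calling convention for the callable (step passed positionally, no arguments when step is None) is an assumption, not the library's documented behavior.

```python
from typing import Callable, Optional, Union

# A learning rate is either a constant or a schedule (a callable).
LearningRate = Union[float, Callable[..., float]]


def resolve_learning_rate(learning_rate: LearningRate,
                          step: Optional[int] = None) -> float:
    """Hypothetical helper: return a constant rate or evaluate a schedule."""
    if callable(learning_rate):
        # Assumption: a schedule receives the step when one is given and is
        # called with no arguments otherwise.
        return learning_rate(step) if step is not None else learning_rate()
    return learning_rate


def inverse_decay(step: int = 0) -> float:
    """Example schedule: 0.1 at step 0, decaying as 1 / (1 + step)."""
    return 0.1 / (1 + step)


constant_rate = resolve_learning_rate(0.01)
scheduled_rate = resolve_learning_rate(inverse_decay, step=9)
```

A constant is returned unchanged, while a schedule is evaluated at the requested step.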
- class AdagradMomentumOptimizerSpec(learning_rate=0.001, momentum=0.9, beta2=1.0, epsilon=1e-10, exponent=2.0, use_nesterov=False, initial_accumulator_value=0.1, initial_momentum_value=0.0)#
Spec for the Adagrad with Momentum optimizer.
An Adagrad with Momentum optimizer is an adaptive optimizer that combines the benefits of both Adagrad and Momentum. It adjusts the learning rate for each embedding variable based on its past gradients, while also incorporating momentum to accelerate convergence.
- Parameters:
- learning_rate#
The learning rate for the training variables or embeddings.
- momentum#
The momentum parameter.
- initial_accumulator_value#
The initial value for the accumulator slot variable.
- initial_momentum_value#
The initial value for the momentum slot variable.
- beta2#
The exponential decay rate for the 2nd moment estimates.
- epsilon#
A small constant for numerical stability.
- exponent#
The exponent for the gradient squared accumulator.
- use_nesterov#
Whether to use Nesterov momentum.
- get_hyperparameters(step=None)#
Returns the hyperparameters for the optimizer.
- Return type:
tuple[Array, …]
- get_optimizer_primitive()#
Derived classes should implement this method to return the XLA primitive for the optimizer.
- Return type:
- short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_initializers()#
Slot variables initializers for the optimizer.
Derived classes should implement this method to return the initializers for the applicable slot variables, if any.
- Returns:
A tuple of initializers for the slot variables.
- Return type:
tuple[Initializer, …]
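To make the roles of the parameters above concrete, here is a minimal scalar sketch of one plausible Adagrad-with-momentum step. This page does not specify the exact update used by the SparseCore kernel, so the formula below is an assumption assembled from the parameter descriptions (a sum accumulator when beta2 is 1.0, an exponential moving average otherwise), and the function name adagrad_momentum_step is hypothetical.

```python
def adagrad_momentum_step(weight, grad, accumulator, momentum_buf,
                          learning_rate=0.001, momentum=0.9, beta2=1.0,
                          epsilon=1e-10, exponent=2.0, use_nesterov=False):
    """One scalar update step. An illustrative assumption, not the kernel."""
    # Second-moment accumulator: plain Adagrad sum when beta2 == 1.0,
    # otherwise an exponential moving average of squared gradients.
    if beta2 == 1.0:
        accumulator = accumulator + grad * grad
    else:
        accumulator = beta2 * accumulator + (1.0 - beta2) * grad * grad

    # Scale the gradient by the accumulator raised to 1 / exponent
    # (exponent=2.0 gives the usual square root).
    scaled_grad = grad / (accumulator + epsilon) ** (1.0 / exponent)

    # Momentum on the scaled gradient, with an optional Nesterov look-ahead.
    momentum_buf = momentum * momentum_buf + scaled_grad
    if use_nesterov:
        update = momentum * momentum_buf + scaled_grad
    else:
        update = momentum_buf

    weight = weight - learning_rate * update
    return weight, accumulator, momentum_buf


# Slots start at initial_accumulator_value=0.1 and initial_momentum_value=0.0.
w, acc, mom = adagrad_momentum_step(1.0, 0.5, accumulator=0.1, momentum_buf=0.0)
```

With momentum set to 0.0 this sketch reduces to a plain Adagrad step, which is a convenient sanity check.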
- class AdagradOptimizerSpec(learning_rate=0.001, initial_accumulator_value=0.1)#
Spec for the Adagrad optimizer.
An Adagrad optimizer is an adaptive optimizer that adjusts the learning rate for each embedding variable based on its past gradients. This helps in reducing the number of steps needed for convergence, especially for sparse data.
- Parameters:
- learning_rate#
The learning rate for the training variables or embeddings.
- initial_accumulator_value#
The initial value for the accumulator slot variable.
- get_optimizer_primitive()#
Derived classes should implement this method to return the XLA primitive for the optimizer.
- Return type:
- short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_initializers()#
Slot variables initializers for the optimizer.
Derived classes should implement this method to return the initializers for the applicable slot variables, if any.
- Returns:
A tuple of initializers for the slot variables.
- Return type:
tuple[Initializer, …]
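The adaptive behavior described above can be illustrated with the textbook Adagrad rule on a single scalar weight. This is a sketch under assumptions: the function name adagrad_step is hypothetical, no epsilon term is used because the spec exposes none, and the fused SparseCore implementation may differ in detail.

```python
import math


def adagrad_step(weight, grad, accumulator, learning_rate=0.001):
    """Textbook Adagrad update for one scalar weight (illustrative only)."""
    # Accumulate the squared gradient seen so far.
    accumulator = accumulator + grad * grad
    # Divide the step by the root of the accumulator, so weights with a
    # large gradient history take smaller steps.
    weight = weight - learning_rate * grad / math.sqrt(accumulator)
    return weight, accumulator


# The accumulator slot starts at initial_accumulator_value (0.1 by default).
w, acc = adagrad_step(1.0, 0.5, accumulator=0.1, learning_rate=0.1)
w2, acc2 = adagrad_step(w, 0.5, accumulator=acc, learning_rate=0.1)
```

Applying the same gradient twice moves the weight less on the second step, because the accumulator has grown.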
- class AdamOptimizerSpec(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)#
Spec for the Adam optimizer.
Adam optimization is a stochastic gradient descent method that is based on adaptive estimation of first-order and second-order moments.
According to Kingma et al., 2014, the method is “computationally efficient, has little memory requirement, invariant to diagonal rescaling of gradients, and is well suited for problems that are large in terms of data/parameters”.
- Parameters:
- learning_rate#
The learning rate for the training variables or embeddings.
- beta_1#
A float value or a constant float tensor, or a callable that takes no arguments and returns the actual value to use. The exponential decay rate for the 1st moment estimates. Defaults to 0.9.
- beta_2#
A float value or a constant float tensor, or a callable that takes no arguments and returns the actual value to use. The exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
- epsilon#
A small constant for numerical stability. Defaults to 1e-8.
- get_hyperparameters(step=None)#
Compute the bias-corrected Adam hyperparameters.
Here we use the bias-corrected parameters from section 2.1 of the original paper:
alpha_t = alpha * sqrt(1 - beta_2^t) / (1 - beta_1^t)
epsilon_hat = epsilon * sqrt(1 + beta_2^t)
- get_optimizer_primitive()#
Derived classes should implement this method to return the XLA primitive for the optimizer.
- Return type:
- short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_initializers()#
Slot variables initializers for the optimizer.
Derived classes should implement this method to return the initializers for the applicable slot variables, if any.
- Returns:
A tuple of initializers for the slot variables.
- Return type:
tuple[Initializer, …]
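The bias-corrected step size alpha_t folds both moment corrections into a single scalar. The sketch below checks this numerically for the case epsilon = 0: the textbook Adam step computed from bias-corrected moments equals alpha_t * m / sqrt(v). The function names are hypothetical, and the sketch deliberately leaves out the epsilon_hat term and the library's own implementation.

```python
import math


def adam_step_size(alpha, beta_1, beta_2, t):
    """alpha_t = alpha * sqrt(1 - beta_2^t) / (1 - beta_1^t)."""
    return alpha * math.sqrt(1.0 - beta_2 ** t) / (1.0 - beta_1 ** t)


def textbook_adam_delta(alpha, beta_1, beta_2, t, m, v):
    """Adam step from bias-corrected moments, with epsilon set to zero."""
    m_hat = m / (1.0 - beta_1 ** t)  # bias-corrected first moment
    v_hat = v / (1.0 - beta_2 ** t)  # bias-corrected second moment
    return alpha * m_hat / math.sqrt(v_hat)


# Raw (uncorrected) moments after some step t; values chosen arbitrarily.
alpha, beta_1, beta_2, t = 0.001, 0.9, 0.999, 10
m, v = 0.02, 0.0004

folded = adam_step_size(alpha, beta_1, beta_2, t) * m / math.sqrt(v)
textbook = textbook_adam_delta(alpha, beta_1, beta_2, t, m, v)
```

Both expressions give the same step, so the optimizer only needs the raw moments plus the precomputed alpha_t.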
- class FTRLOptimizerSpec(learning_rate=0.01, learning_rate_power=-0.5, l1_regularization_strength=0.0, l2_regularization_strength=0.0, beta=0.0, initial_accumulator_value=0.1, initial_linear_value=0.0, multiply_linear_by_learning_rate=False)#
Spec for the FTRL optimizer.
Follow The Regularized Leader (FTRL) is an optimization algorithm developed at Google for click-through rate prediction.
- Parameters:
- learning_rate#
The learning rate.
- learning_rate_power#
A float value, typically -0.5.
- l1_regularization_strength#
A float value, must be greater than or equal to 0.
- l2_regularization_strength#
A float value, must be greater than or equal to 0.
- beta#
A float value.
- initial_accumulator_value#
Initial value for the accumulator slot.
- initial_linear_value#
Initial value for the linear slot.
- multiply_linear_by_learning_rate#
A bool value, if True, multiply the linear slot by the learning rate.
- get_hyperparameters(step=None)#
Returns the FTRL hyperparameters.
- get_optimizer_primitive()#
Derived classes should implement this method to return the XLA primitive for the optimizer.
- Return type:
- short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_initializers()#
Slot variables initializers for the optimizer.
Derived classes should implement this method to return the initializers for the applicable slot variables, if any.
- Returns:
A tuple of initializers for the slot variables.
- Return type:
tuple[Initializer, …]
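As a rough guide to how the FTRL parameters interact, the sketch below implements the per-coordinate FTRL-Proximal update in the form given by McMahan et al. (2013), generalizing the square root to the power -learning_rate_power. This is an assumption about the general algorithm, not the SparseCore kernel: the function name ftrl_step is hypothetical, multiply_linear_by_learning_rate is not modeled, and some implementations scale the L2 term differently.

```python
def ftrl_step(weight, grad, accumulator, linear, learning_rate=0.01,
              learning_rate_power=-0.5, l1=0.0, l2=0.0, beta=0.0):
    """One scalar FTRL-Proximal step (illustrative assumption only)."""
    power = -learning_rate_power  # 0.5 for the default, i.e. a square root
    new_accumulator = accumulator + grad * grad

    # sigma captures how much the per-coordinate learning rate shrank.
    sigma = (new_accumulator ** power - accumulator ** power) / learning_rate
    linear = linear + grad - sigma * weight

    if abs(linear) <= l1:
        # L1 regularization clamps small coordinates to exactly zero.
        weight = 0.0
    else:
        sign = 1.0 if linear > 0 else -1.0
        denom = (beta + new_accumulator ** power) / learning_rate + l2
        weight = (sign * l1 - linear) / denom
    return weight, new_accumulator, linear


# Slots start at initial_accumulator_value=0.1 and initial_linear_value=0.0.
w, acc, z = ftrl_step(0.0, 1.0, accumulator=0.1, linear=0.0)
w_sparse, _, _ = ftrl_step(0.0, 1.0, accumulator=0.1, linear=0.0, l1=2.0)
```

With a sufficiently large l1 the weight stays exactly zero, which is the sparsity-inducing behavior FTRL is known for.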
- class LaPropOptimizerSpec(learning_rate=0.001, b1=0.9, b2=0.95, eps=1e-30, rms_clip_threshold=None, initial_slot_value=0.0)#
Spec for the LaProp optimizer.
LaProp decouples momentum and adaptivity in Adam-style methods, leading to improved speed and stability compared to Adam. https://arxiv.org/abs/2002.04839
- Parameters:
b1 (float)
b2 (float)
eps (float)
rms_clip_threshold (float | None)
initial_slot_value (float)
- learning_rate#
The learning rate for the training variables or embeddings.
- b1#
Decay rate for the exponentially weighted average of gradients.
- b2#
Decay rate for the exponentially weighted average of squared gradients.
- eps#
Term added to the squared gradient to improve numerical stability.
- rms_clip_threshold#
Clipping threshold for RMS.
- initial_slot_value#
Initial value for the slot variables.
- get_decay_rate(step=None)#
Returns the decay rate for the optimizer.
- get_hyperparameters(step=None)#
Returns the LaProp hyperparameters: (learning_rate, b1, decay_rate, eps).
- get_optimizer_primitive()#
Derived classes should implement this method to return the XLA primitive for the optimizer.
- Return type:
- short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_initializers()#
Slot variables initializers for the optimizer.
Derived classes should implement this method to return the initializers for the applicable slot variables, if any.
- Returns:
A tuple of initializers for the slot variables.
- Return type:
tuple[Initializer, …]
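The decoupling of momentum and adaptivity can be seen in a scalar sketch of the LaProp update as described in the paper linked above: the gradient is normalized first, and momentum is then applied to the normalized gradient. This is an illustrative assumption, not the library's kernel. The function name laprop_step is hypothetical, eps is added to the denominator as in the paper, and rms_clip_threshold and the library's decay_rate hyperparameter are not modeled.

```python
import math


def laprop_step(weight, grad, m, n, t, learning_rate=0.001,
                b1=0.9, b2=0.95, eps=1e-30):
    """One scalar LaProp step at 1-based step t (paper formulation)."""
    # Adaptivity: moving average of squared gradients, bias-corrected.
    n = b2 * n + (1.0 - b2) * grad * grad
    n_hat = n / (1.0 - b2 ** t)

    # Momentum is accumulated on the already-normalized gradient.
    m = b1 * m + (1.0 - b1) * grad / (math.sqrt(n_hat) + eps)
    m_hat = m / (1.0 - b1 ** t)

    weight = weight - learning_rate * m_hat
    return weight, m, n


# Slots start at initial_slot_value (0.0 by default).
w, m, n = laprop_step(1.0, 0.5, m=0.0, n=0.0, t=1)
```

On the first step the normalized gradient is just the sign of the gradient, so the weight moves by approximately learning_rate regardless of the gradient's magnitude.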
- class SGDOptimizerSpec(learning_rate=0.001)#
Spec for the Stochastic Gradient Descent (SGD) optimizer.
An iterative optimization method that updates the weights of the embedding variables by taking a step in the direction opposite to the gradient. The step size is controlled by the learning rate. SGD is usually a default choice in a training setup.
- learning_rate#
The learning rate for the training variables or embeddings.
- get_optimizer_primitive()#
Returns the optimizer primitive for the SGD optimizer.
- Return type:
- short_name()#
Implement this method to return a short name for the optimizer.
This short name will be used as part of the identifier for the variables being trained.
- Return type:
str
- slot_variables_initializers()#
SGD does not have any slot variables, hence this returns an empty tuple.
- Return type:
tuple[Initializer, …]
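Because SGD keeps no slot variables, an update only needs the weights and the gradient. The sketch below applies a plain SGD step to an embedding table stored as a list of rows, touching only the rows that received a gradient. The function name sgd_update_rows and the dictionary-of-row-gradients representation are hypothetical simplifications and do not reflect how SparseCore stores or updates tables.

```python
def sgd_update_rows(table, row_grads, learning_rate=0.001):
    """Apply w <- w - learning_rate * g to the rows present in row_grads."""
    for row, grad in row_grads.items():
        # Only rows that were looked up in the batch have gradients.
        table[row] = [w - learning_rate * g for w, g in zip(table[row], grad)]
    return table


# A toy table with three rows of dimension two.
table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
table = sgd_update_rows(table, {1: [0.5, -0.5]}, learning_rate=0.1)
```

Rows 0 and 2 are left untouched, which is why sparse updates stay cheap even for very large tables.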