Feedback Directed Optimization


Feedback Directed Optimization (FDO) can be used to adjust the static input buffer parameters (e.g., limits) as training progresses. It provides a framework that automates estimating and updating these parameters for the next training epoch.

Architecture

Each process records the observed input parameters in a local FDO client during training and periodically publishes the recorded parameters from local memory to persistent storage. Currently there is an implementation for file storage, but users are encouraged to create a custom FDOClient implementation to better integrate with their training infrastructure.
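The record, publish, and load cycle described above can be sketched with a small stand-in client. This is a hypothetical illustration, not the library's FDOClient interface: the class name `ToyFileClient`, the one-file-per-process layout, the `stats_<id>.npz` naming, and the choice to aggregate by elementwise maximum are all assumptions made for the example.

```python
import glob
import os
import tempfile

import numpy as np


class ToyFileClient:
    """Hypothetical stand-in for an FDO client (not the library's class)."""

    def __init__(self, base_dir, process_id):
        self._base_dir = base_dir
        self._process_id = process_id
        self._local = {}  # stats accumulated in local memory

    def record(self, stats):
        # Keep the elementwise maximum seen so far for each named statistic.
        for name, value in stats.items():
            value = np.asarray(value)
            prev = self._local.get(name)
            self._local[name] = value if prev is None else np.maximum(prev, value)

    def publish(self):
        # Write this process's accumulated stats to its own file.
        path = os.path.join(self._base_dir, f"stats_{self._process_id}.npz")
        np.savez(path, **self._local)

    def load(self):
        # Merge the files published by all processes, again by maximum.
        merged = {}
        for path in sorted(glob.glob(os.path.join(self._base_dir, "stats_*.npz"))):
            with np.load(path) as data:
                for name in data.files:
                    value = data[name]
                    prev = merged.get(name)
                    merged[name] = value if prev is None else np.maximum(prev, value)
        return merged


# Simulate two processes sharing one storage directory.
with tempfile.TemporaryDirectory() as tmp:
    clients = [ToyFileClient(tmp, pid) for pid in range(2)]
    clients[0].record({"max_ids": np.array([3, 7])})
    clients[0].record({"max_ids": np.array([5, 2])})
    clients[1].record({"max_ids": np.array([4, 9])})
    for client in clients:
        client.publish()
    merged = clients[0].load()
```

In a real multi-process job, a barrier between `publish` and `load` would be needed so that every process sees the same set of files; the single-process simulation here sidesteps that.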

At some frequency, usually at the end of an epoch, all processes synchronize, load the limits from storage, and update the feature specs to use the new limits. Once the limits in the feature specs change, the jitted train step needs to be recompiled. This happens automatically if the feature specs object is a static argument to the train function: as the JAX documentation on JIT compilation explains, a change in a static argument triggers recompilation.
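The recompilation behavior can be observed in isolation with a toy example. The sketch below assumes a hypothetical hashable `ToySpec` in place of the real feature specs and counts how often the function body is traced; it only illustrates how `jax.jit` treats static arguments, not how the library represents its specs.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)  # frozen dataclasses are hashable
class ToySpec:
    """Hypothetical stand-in for a feature spec carrying a limit."""
    max_ids: int


trace_count = 0


def step(spec, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces
    return x[: spec.max_ids].sum()


jit_step = jax.jit(step, static_argnums=0)

x = jnp.arange(8.0)
first = jit_step(ToySpec(max_ids=4), x)   # traces and compiles
again = jit_step(ToySpec(max_ids=4), x)   # equal spec: cached, no retrace
second = jit_step(ToySpec(max_ids=6), x)  # changed spec: traces again
```

Static arguments are compared by hash and equality, so a spec object that is mutated in place without its hash or equality changing would not be expected to trigger a retrace in this toy setup.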

All the above steps are customizable and a typical flow looks as follows:

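# Import JAX, the embedding library, and the FDO client
import jax
import jax.experimental.multihost_utils
from jax_tpu_embedding.sparsecore.lib.fdo import file_fdo_client
from jax_tpu_embedding.sparsecore.lib.nn import embedding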

# Train function: make sure it takes the feature specs as a static argument so
# that when they change, jit triggers a recompilation.
def train_step(
    feature_spec: Nested[FeatureSpec],
    ...
):
    ...

jit_train_step = jax.jit(train_step, static_argnums=0, ...)

# Create an instance of fdo client
fdo_dir = '/tmp/fdo_dumps/'
fdo_client = file_fdo_client.NPZFileFDOClient(fdo_dir)

for step in range(100):
    # Preprocess the inputs; this also returns the observed stats
    preprocessed_inputs, stats = embedding.preprocess_sparse_dense_matmul_input(
        ...
    )
    # Record the stats returned by inputs preprocessing
    fdo_client.record(stats)
    jit_train_step(...)
    # At some frequency, publish and update stats.
    if step % 10 == 0:
        fdo_client.publish()
        # Add a barrier here so that all processes finish publishing their stats
        # before any process loads them, so that every process reads the same data.
        jax.experimental.multihost_utils.sync_global_devices("FDO_publish_barrier")
        # Load FDO stats dumps and update feature specs.
        loaded_stats: embedding.PreprocessSparseDenseMatmulStats = fdo_client.load()
        # Any custom code to adjust the stats can go here.
        # `transform` is a user-defined function to modify the FDO statistics.
        # For example:
        def transform(stats):
            # Alter stats as desired:
            # stats.required_buffer_size_per_sc *= 2
            return stats
        updated_stats = transform(loaded_stats)
        # Update the feature specs.
        embedding.update_preprocessing_parameters(
            feature_specs, updated_stats, num_sc_per_device
        )
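As one possible `transform`, the loaded statistics could be padded with some headroom before they are applied. The sketch below is an assumption-laden illustration: `ToyStats` is a hypothetical stand-in for the loaded stats object, and it assumes `required_buffer_size_per_sc` maps table names to NumPy arrays. The 25% factor is arbitrary.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass
class ToyStats:
    """Hypothetical stand-in for the loaded FDO statistics."""
    # Assumed layout: table name -> observed buffer sizes.
    required_buffer_size_per_sc: dict[str, np.ndarray]


def add_headroom(stats, factor=1.25):
    """Returns a copy of `stats` with buffer sizes scaled up and rounded up."""
    padded = {
        name: np.ceil(sizes * factor).astype(np.int32)
        for name, sizes in stats.required_buffer_size_per_sc.items()
    }
    return dataclasses.replace(stats, required_buffer_size_per_sc=padded)


observed = ToyStats({"table_a": np.array([100, 240])})
padded = add_headroom(observed)
```

Returning a new object rather than mutating the loaded one keeps the observed values available for logging or comparison.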

API