Quick Overview for Users#
In this section, we provide a high-level overview of the steps required to utilize the JAX SparseCore API for large embedding models. The specifics of these steps are detailed in separate pages.
Embedding Specification#
Configuring your embedding is done through two primary specs: the TableSpec
and the FeatureSpec.
A FeatureSpec specifies things like the embedding table size (vocabulary and
embedding sizes) and optimizer.
A TableSpec specifies which table an embedding feature uses (multiple features
can use the same table) and the input/output shape of the feature lookup.
Details of the embedding specification are described in the Embedding Specification page.
Input Preprocessing#
The SparseCore accepts sparse inputs (ragged/list of list) packed into a COO format. To convert sparse inputs into this format we provide an API that performs the conversion in a highly efficient way. See Input Preprocessing section for more details.
Embedding Lookup and Gradient Update#
We provide two APIs for using large embeddings. These APIs have essentially identical performance and which you choose depends on your modeling preference.
Flax API: Using a Flax layer is often the preferred choice as it’s a more natural fit for Flax based models. Here, the details of performing the embedding lookup and gradient based weight update are implemented by the SparseCoreEmbed Flax layer. An example of using this can be found in the [Shakespeare on Flax APIs] Colab example.
Primitive API: Using this API, you make direct calls to perform the embedding lookup
tpu_sparse_dense_matmul() and gradient based weight update tpu_sparse_dense_matmul_grad().
Typically this is done using JAX’s shard_map feature inside of a jax.jit() function.
A working example of this can be seen in the Shakespeare on Primitive APIs Colab example.
The primitive API is mostly of interest to JAX SparseCore developers for implementing
higher level features like the Flax and embedding pipelining APIs.
Next Steps#
The above three steps account for the essentials of using the JAX SparseCore embedding API. Once these are in place, additional features can be used to maximize performance and productionize the integration with your model and infrastructure.
FDO: Dynamically adapt TPU buffer sizes to optimize memory usage.
Table and Feature Stacking: Make fewer, larger lookups to reduce memory usage and increase performance.
Embedding Pipelining: Overlap SparseCore and TensorCore compute to maximize TPU performance.
Checkpointing: Save and restore from checkpoints for increased robustness.