JAX TPU Embedding documentation
JAX SparseCore provides support for using the SparseCore accelerators present in TPU generations starting with TPU v5. SparseCores are specialized processors designed to accelerate workloads with sparse data access patterns, particularly the large-scale embedding lookups common in deep learning recommendation models and other domains.
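To make that access pattern concrete, here is a minimal sketch of a sum-combined embedding lookup written in plain JAX with `jnp.take` and `jax.ops.segment_sum`. It does not use SparseCore or the jax-tpu-embedding API; the function name `embedding_lookup_sum` and the flattened ids / segment-ids batch layout are illustrative assumptions. It shows the gather-then-reduce shape of the workload: each sample touches only a few rows of a potentially huge table, which is the kind of sparse access SparseCores are designed to accelerate.

```python
import jax
import jax.numpy as jnp


def embedding_lookup_sum(table, ids, segment_ids, num_samples, weights=None):
    """Hypothetical helper: gather rows of `table` for each id and sum them per sample.

    table:       [vocab_size, dim] embedding table
    ids:         [nnz] flattened feature ids across the whole batch
    segment_ids: [nnz] sample index (0..num_samples-1) each id belongs to
    weights:     optional [nnz] per-id weights (treated as 1.0 when omitted)
    """
    # Sparse gather: only the referenced rows are read, giving [nnz, dim].
    rows = jnp.take(table, ids, axis=0)
    if weights is not None:
        rows = rows * weights[:, None]
    # Reduce the gathered rows into one vector per sample: [num_samples, dim].
    return jax.ops.segment_sum(rows, segment_ids, num_segments=num_samples)


# num_samples fixes the output shape, so it must be static under jit.
lookup = jax.jit(embedding_lookup_sum, static_argnums=(3,))

# Toy example: vocab_size=4, dim=3.
table = jnp.arange(12.0).reshape(4, 3)
# Sample 0 looks up ids [1, 3]; sample 1 looks up id [2].
ids = jnp.array([1, 3, 2])
segment_ids = jnp.array([0, 0, 1])
out = lookup(table, ids, segment_ids, 2)
# out == [[12., 14., 16.], [6., 7., 8.]]
```

On generic hardware the gather and the segment reduction run as ordinary XLA operations; routing lookups like these onto SparseCore is the job of the library, not of this sketch.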
Installation
pip install jax-tpu-embedding