Minibatching for SparseCore Embedding Lookups#
Overview#
Minibatching is a feature in the SparseCore embedding library for handling input batches that are too large to be processed by SparseCore in a single pass due to on-chip memory constraints.
The SparseCore hardware has limits on the number of embedding IDs it can process in one operation for a single table partition, constraints dictated by the size of its on-chip memory (SPMEM):
max_ids_per_partition: The maximum number of embedding IDs.
max_unique_ids_per_partition: The maximum number of unique embedding IDs.
When an input batch requires more lookups for a given partition than these limits allow, minibatching breaks the batch down into smaller pieces, called minibatches. These minibatches are then processed sequentially on the SparseCore, and their results are accumulated to produce the final output for the original batch.
This mechanism offers several advantages:
Prevents errors caused by exceeding hardware ID limits.
Allows users to train with larger logical batch sizes, even if parts of the input data are highly skewed (i.e., some samples require many more embedding ID lookups than others).
Does not affect model quality, as it is designed to be mathematically equivalent to processing the full batch in a single pass.
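The equivalence claim can be checked numerically for a sum combiner. The sketch below is a minimal illustration using plain Python lists; the table, IDs, and the three-way contiguous split are hypothetical (the library itself splits by hash bucket rather than by position), and results are compared with a tolerance because floating-point addition order can change the last bits.

```python
import math
import random

random.seed(0)
DIM = 4
# Hypothetical embedding table: 100 rows of dimension 4.
table = [[random.uniform(-1, 1) for _ in range(DIM)] for _ in range(100)]
# IDs looked up by one sample (duplicates allowed).
ids = [random.randrange(100) for _ in range(30)]


def sum_rows(id_list):
    """Sum-combine the table rows selected by id_list."""
    out = [0.0] * DIM
    for i in id_list:
        for d in range(DIM):
            out[d] += table[i][d]
    return out


# Single pass over the whole batch.
full = sum_rows(ids)

# Minibatched: three pieces processed sequentially, partial results accumulated.
pieces = [ids[0:10], ids[10:20], ids[20:30]]
accumulated = [0.0] * DIM
for piece in pieces:
    partial = sum_rows(piece)
    accumulated = [a + p for a, p in zip(accumulated, partial)]

matches = all(math.isclose(a, f, abs_tol=1e-9) for a, f in zip(accumulated, full))
print(matches)
```

Because summation is associative up to rounding, any partition of the IDs into minibatches yields the same accumulated activation.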
Minibatching is particularly useful when the number of embedding IDs that need
to be looked up for a single SparseCore core or partition exceeds limits defined
by max_ids_per_partition or max_unique_ids_per_partition, which are
parameters automatically tuned by Feedback-Directed Optimization (FDO).
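To make the two limits concrete, here is a small sketch of how a batch could be checked against them. The function name `exceeds_limits` and the modulo mapping from ID to partition are assumptions made for illustration only; the library's actual partitioning and check are not shown in this document.

```python
from collections import defaultdict


def exceeds_limits(ids, num_partitions, max_ids_per_partition,
                   max_unique_ids_per_partition):
    """Return True if any partition receives too many IDs or unique IDs."""
    per_partition = defaultdict(list)
    for embedding_id in ids:
        # Assumption: IDs map to partitions by modulo; the real mapping may differ.
        per_partition[embedding_id % num_partitions].append(embedding_id)
    for partition_ids in per_partition.values():
        if len(partition_ids) > max_ids_per_partition:
            return True
        if len(set(partition_ids)) > max_unique_ids_per_partition:
            return True
    return False


# Partition 0 receives [0, 4, 8, 8, 12]: 5 IDs, 4 of them unique.
batch = [0, 4, 8, 8, 12, 1, 2, 3]
print(exceeds_limits(batch, 4, max_ids_per_partition=4,
                     max_unique_ids_per_partition=4))
print(exceeds_limits(batch, 4, max_ids_per_partition=8,
                     max_unique_ids_per_partition=8))
```

The first call exceeds the total-ID limit for partition 0, while the second stays within both limits.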
Enabling Minibatching#
Minibatching can be enabled by setting the enable_minibatching flag to True
in various APIs.
Flax Embedding Layer#
If you are using the Flax Embedding Layer, you can
enable minibatching by passing enable_minibatching=True at layer initialization:
import jax

from jax_tpu_embedding.sparsecore.lib.flax.linen import embed as flax_embed
from jax_tpu_embedding.sparsecore.lib.nn import embedding

embed_layer = flax_embed.SparseCoreEmbed(
    feature_specs=features,
    enable_minibatching=True,
    mesh=mesh,
)
variables = embed_layer.init(jax.random.PRNGKey(0), ...)

# For multi-host minibatching, create all_reduce_interface
all_reduce_interface = embedding.get_all_reduce_interface(...)

# Preprocess inputs
preprocessed_inputs = embed_layer.preprocess_inputs(
    step, features, features_weights, all_reduce_interface)

# During forward pass
activations = embed_layer.apply(variables, preprocessed_inputs)

# During backward pass
updated_variables = embed_layer.apply_gradient(
    gradients,
    preprocessed_inputs,
)
Low-level API#
If using the embedding module directly, pass enable_minibatching=True to
preprocess_sparse_dense_matmul_input and tpu_sparse_dense_matmul:
from jax_tpu_embedding.sparsecore.lib.nn import embedding

preprocessed_input, _ = embedding.preprocess_sparse_dense_matmul_input(
    ...,
    enable_minibatching=True,
    all_reduce_interface=all_reduce_interface,
)

activations = embedding.tpu_sparse_dense_matmul(
    preprocessed_input,
    embedding_vars,
    ...,
    enable_minibatching=True,
)
Note that for multi-host minibatching, you need to initialize and pass an
all_reduce_interface object to preprocess_sparse_dense_matmul_input. This
can be obtained via embedding.get_all_reduce_interface(...).
Performance Considerations#
Enabling minibatching introduces some overhead:
Preprocessing: The bucketization and cross-host synchronization steps add latency to input preprocessing.
Execution: Executing SparseCore lookups in a loop for multiple minibatches increases TPU execution time compared to a single lookup for the entire batch.
Communication: In multi-host settings, gRPC-based AllReduce operations introduce communication overhead.
Despite this overhead, minibatching is essential for stability when dealing with
large or skewed inputs that would otherwise exceed hardware limits for
max_ids_per_partition or max_unique_ids_per_partition. If your model runs
without exceeding these limits and does not report ID dropping with minibatching
disabled, you might achieve better performance by leaving it disabled. However,
if you experience ID dropping or errors due to these limits, enabling
minibatching is the recommended solution.
How it works#
When minibatching is enabled, the input preprocessing pipeline performs the following steps:
1. Check for minibatching requirement: For each table in the input batch, the preprocessing step checks if the number of embedding IDs or unique embedding IDs destined for any SparseCore partition exceeds the limits (max_ids_per_partition and max_unique_ids_per_partition). If these limits are exceeded for any table, that table is marked as requiring minibatching on the current host.
2. Cross-host synchronization: In a multi-host environment, if at least one host requires minibatching for any table, all hosts must agree to use minibatching for the current step. This is achieved via a cross-host AllReduce operation implemented using gRPC, which aggregates the minibatching requirement status from all hosts. If any host requires minibatching, all hosts will proceed with it.
3. Bucketization: If minibatching is required, all embedding IDs in tables that require minibatching are assigned to one of 64 buckets based on a hash of the embedding ID.
4. Minibatch creation: The 64 buckets are grouped into minibatches. The goal is to create minibatches such that each minibatch fits within the memory constraints of the SparseCore. The division of buckets into minibatches is determined by another AllReduce operation across hosts, ensuring all hosts use the same minibatching strategy for the current step. This division is represented by a bitmask called MinibatchingSplit.
5. Sequential processing: During the embedding lookup (forward pass) and gradient update (backward pass), if minibatching is active (num_minibatches > 1), the SparseCore operation (sparse_dense_matmul) is executed in a loop, once for each minibatch. In the forward pass, embedding lookups are accumulated into the activation tensors based on the feature’s combiner (e.g., ‘sum’). In the backward pass, gradients are computed for each minibatch and applied sequentially to update the embedding tables in-place using the configured optimizer.
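The bucketization, minibatch creation, and sequential processing steps above can be sketched in a few lines of plain Python. Everything specific here is an assumption for illustration: the multiplicative hash in `bucket_of`, the bitmask encoding (bit b set means "end a minibatch after bucket b"), and the helper names `split_to_groups` and `minibatched_sum` are hypothetical, since this document does not specify the library's hash function or the exact layout of MinibatchingSplit.

```python
NUM_BUCKETS = 64


def bucket_of(embedding_id):
    """Assign an ID to one of 64 buckets (assumed hash, not the library's)."""
    # Knuth-style multiplicative hash; keep the top 6 bits of a 32-bit product.
    return ((embedding_id * 2654435761) & 0xFFFFFFFF) >> 26


def split_to_groups(split_mask):
    """Group buckets into minibatches.

    Assumed encoding: bit b set means a minibatch ends after bucket b.
    """
    groups, current = [], []
    for b in range(NUM_BUCKETS):
        current.append(b)
        if b == NUM_BUCKETS - 1 or (split_mask >> b) & 1:
            groups.append(current)
            current = []
    return groups


def minibatched_sum(ids, table, split_mask):
    """Process one minibatch at a time and accumulate a sum-combined result."""
    total = 0.0
    for group in split_to_groups(split_mask):
        wanted = set(group)
        minibatch = [i for i in ids if bucket_of(i) in wanted]
        total += sum(table[i] for i in minibatch)
    return total


# Toy scalar "embeddings" so the accumulated result is easy to verify.
table = {i: float(i) for i in range(200)}
ids = list(range(0, 200, 3))
split_mask = (1 << 15) | (1 << 31) | (1 << 47)

print(len(split_to_groups(split_mask)))
print(minibatched_sum(ids, table, split_mask) == sum(table[i] for i in ids))
```

Under this assumed encoding, OR-ing the bitmasks from two hosts yields the union of their split points, so the combined split is at least as fine as each host's own request.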
Cross-Host Synchronization Flow#
As mentioned in steps 2 and 4 above, cross-host synchronization is performed
using a gRPC-based AllReduce operation. This is used first to
synchronize a boolean minibatching_required flag, and then to synchronize the
MinibatchingSplit bitmask across all hosts. Both reductions use a logical OR
operation. The synchronization within AllReduce involves two main
phases:
Local Reduction: Within each host, all participating threads (K) call InitializeOrUpdateState, and their values are combined (OR-ed) into a single host-level value. The last thread to contribute its value is responsible for initiating the global reduction, as at this point, all local contributions are guaranteed to be aggregated into the host-level value.
Global Reduction: The locally-reduced values from all hosts (N) are combined via an all-to-all gRPC exchange into a single globally-reduced value, which is then made available to all threads on all hosts.
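Setting threads and RPCs aside, the arithmetic of the two phases is just a nested OR. The sketch below is a purely functional illustration; `local_reduce` and `global_reduce` are hypothetical names, and values are represented as integers so the same code covers both the boolean flag (0 or 1) and a bitmask.

```python
from functools import reduce
from operator import or_


def local_reduce(thread_values):
    """Phase 1: OR the values contributed by the K threads of one host."""
    return reduce(or_, thread_values, 0)


def global_reduce(host_values):
    """Phase 2: OR the locally-reduced values from all N hosts."""
    return reduce(or_, host_values, 0)


# minibatching_required flags: only one thread on host 1 needs minibatching.
flags = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
required = bool(global_reduce([local_reduce(h) for h in flags]))
print(required)

# Bitmasks reduce the same way: the result is the union of all set bits.
masks = [[0b0001, 0b0100], [0b0100], [0b1000, 0b0000]]
combined = global_reduce([local_reduce(h) for h in masks])
print(bin(combined))
```

Because OR is associative, commutative, and idempotent, the order in which thread and peer contributions arrive does not affect the final value.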
The diagram below illustrates this flow, showing how SendLocalData executes in
parallel with the ContributeData RPC handler.
![digraph G {
node [shape=box, style="rounded,filled", fillcolor=white, fontsize=10];
edge [fontsize=9];
compound=true;
newrank=true;
rankdir=TB;
nodesep=0.3;
ranksep=0.6;
label="AllReduce Synchronization Flow for Minibatching";
labelloc=t;
fontsize=16;
Host_i_T [label="Threads 0..k-1", shape=ellipse];
Host_j_T [label="Threads 0..k-1", shape=ellipse];
subgraph cluster_host_i {
label = "Host i";
bgcolor="#E6F0FA";
i_Init [label="1. InitializeOrUpdateState\n(K threads reduce to local value L_i)"];
i_Send [label="2a. SendLocalData(L_i)\n(Last contributing thread sends L_i to peers)"];
i_Handler [label="2b. ContributeData(L_peer)\n(RPC handler reduces peer data into L_i)", fillcolor="#FEF9E7"];
i_GetRes [label="3. GetResult\n(K threads wait for local and peer data, then read result)"];
i_Init -> i_Send -> i_GetRes;
i_Handler -> i_GetRes [style=dashed, label="contributes to result"];
}
subgraph cluster_host_j {
label = "Host j";
bgcolor="#F5E6FA";
j_Init [label="1. InitializeOrUpdateState\n(K threads reduce to local value L_j)"];
j_Send [label="2a. SendLocalData(L_j)\n(Last contributing thread sends L_j to peers)"];
j_Handler [label="2b. ContributeData(L_peer)\n(RPC handler reduces peer data into L_j)", fillcolor="#FEF9E7"];
j_GetRes [label="3. GetResult\n(K threads wait for local and peer data, then read result)"];
j_Init -> j_Send -> j_GetRes;
j_Handler -> j_GetRes [style=dashed, label="contributes to result"];
}
Host_i_T -> i_Init;
i_GetRes -> Host_i_T;
Host_j_T -> j_Init;
j_GetRes -> Host_j_T;
// RPCs
i_Send -> j_Handler [label="gRPC", style=dashed, constraint=false];
j_Send -> i_Handler [label="gRPC", style=dashed, constraint=false];
// Ranks
{rank=same; Host_i_T; Host_j_T;}
{rank=same; i_Init; j_Init;}
{rank=same; i_Send; j_Send; i_Handler; j_Handler;}
{rank=same; i_GetRes; j_GetRes;}
}](../_images/graphviz-c6913ff8d78b1c1db3aa53a5ed47fe2903e4f50f.png)
Explanation of Global Reduction Parallelism#
Stage 1 InitializeOrUpdateState performs local reduction among the K threads on a host.
The last contributing thread, identified as the one that causes
local_contributions_counter to reach zero, emerges with the host’s
locally-reduced value (e.g., L_i for Host i).
Stage 2 is the global reduction, which involves parallel send and receive operations:
2a. SendLocalData: The last contributing thread on Host i calls SendLocalData, which sends L_i to Host j (and all other peers) via asynchronous gRPC calls. This function initiates the sends but does not wait for responses.
2b. ContributeData: This is an RPC handler running on Host i’s gRPC server. When Host j calls SendLocalData, its RPC arrives at Host i and invokes ContributeData(L_j). This handler incorporates L_j into Host i’s state via an OR-reduction and decrements a counter tracking pending contributions from peers.
Parallelism: Because SendLocalData sends RPCs asynchronously, and ContributeData is an RPC handler that reacts to incoming RPCs, these two operations occur concurrently. Host i can be sending L_i to Host j at the same time as its ContributeData handler is processing L_k received from Host k.
Stage 3: Synchronization and Result Retrieval:
Synchronization occurs when threads call GetResult, which blocks until
both of the following conditions are met:
All K local threads have completed Stage 1 (local_reduction_countdown).
All N-1 peers have sent their data, which is processed by ContributeData on the local gRPC server (global_values_countdown).
Once both local and global reduction are complete, GetResult unblocks and
provides the final result to all waiting threads.
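The three stages can be simulated in a single process with Python threads. This is a toy model, not the library's implementation: the `HostState` class and its snake_case methods are hypothetical stand-ins for InitializeOrUpdateState, ContributeData, and GetResult, and the "send" step calls the peer's handler directly instead of issuing an asynchronous gRPC request.

```python
import threading


class HostState:
    """Per-host reduction state guarded by one condition variable."""

    def __init__(self, num_threads, num_hosts):
        self._cond = threading.Condition()
        self._local = 0                         # OR of this host's thread values
        self._peers = 0                         # OR of values received from peers
        self._local_countdown = num_threads     # threads still to contribute
        self._global_countdown = num_hosts - 1  # peers still to contribute

    def initialize_or_update(self, value):
        """Stage 1: return the local value if this was the last contributor."""
        with self._cond:
            self._local |= value
            self._local_countdown -= 1
            self._cond.notify_all()
            return self._local if self._local_countdown == 0 else None

    def contribute_data(self, peer_value):
        """Stage 2b: handler that folds in one peer's locally-reduced value."""
        with self._cond:
            self._peers |= peer_value
            self._global_countdown -= 1
            self._cond.notify_all()

    def get_result(self):
        """Stage 3: block until local and global reductions are complete."""
        with self._cond:
            self._cond.wait_for(
                lambda: self._local_countdown == 0
                and self._global_countdown == 0)
            return self._local | self._peers


def worker(host_index, value, hosts, results):
    me = hosts[host_index]
    local = me.initialize_or_update(value)
    if local is not None:
        # Stage 2a stand-in: direct calls instead of asynchronous gRPC sends.
        for peer_index, peer in enumerate(hosts):
            if peer_index != host_index:
                peer.contribute_data(local)
    results.append(me.get_result())


def run(values_per_host):
    num_hosts = len(values_per_host)
    hosts = [HostState(len(v), num_hosts) for v in values_per_host]
    results = []
    threads = [
        threading.Thread(target=worker, args=(h, value, hosts, results))
        for h, values in enumerate(values_per_host)
        for value in values
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run([[0, 0], [0, 4], [1, 0]]))
```

Every thread on every host observes the same OR-reduced value, regardless of the order in which threads finish Stage 1 or peer contributions arrive.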