Feature and Table Stacking
This document is only relevant to JAX SparseCore devs or advanced users. The features described here are optional and not required to use the JAX SparseCore API.
Introduction
Feature Stacking and Table Stacking are two related but distinct optimizations that can improve the efficiency of models with many tables and features.
Feature Stacking: Multiple features that reference the same table can be combined into a single lookup by combining the samples. This results in fewer, larger lookups which is generally more efficient.
Table Stacking: Multiple small tables can be combined into a single stacked table. This results in fewer, larger lookups as well and can also be more memory efficient for storing the sharded tables.
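As a rough illustration of the table stacking idea, the sketch below concatenates the rows of two small tables and shifts each table's embedding IDs by the number of rows that precede it. The helper names and vocabulary sizes are hypothetical. The sketch deliberately ignores padding and shard rotation, which a real implementation may apply.

```python
# Toy model of table stacking: rows of several small tables are concatenated
# into one stacked table, and each table's IDs are shifted by a row offset.
# Padding and shard rotation are intentionally ignored (an assumption).

def build_vocab_offsets(vocab_sizes):
    """Returns (first stacked row per table, total rows in the stacked table)."""
    offsets = {}
    next_row = 0
    for name, size in vocab_sizes.items():
        offsets[name] = next_row
        next_row += size
    return offsets, next_row


# Illustrative vocabulary sizes for two hypothetical tables.
vocab_sizes = {"table_a": 64, "table_b": 120}
offsets, stacked_vocab = build_vocab_offsets(vocab_sizes)


def to_stacked_id(table_name, embedding_id):
    """Maps an ID local to one table to its row in the stacked table."""
    return offsets[table_name] + embedding_id


print(offsets)                      # {'table_a': 0, 'table_b': 64}
print(stacked_vocab)                # 184
print(to_stacked_id("table_b", 5))  # 69
```

With this layout, lookups into either table become lookups into one larger table, which is where the "fewer, larger lookups" benefit comes from.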
Note
In what follows, the term “row” refers to a training sample/example and “column” refers to an embedding ID in a given sample. The column maps to a vocabulary in an embedding table.
By default, feature stacking is performed implicitly by
preprocess_sparse_dense_matmul_input, and activations are unstacked by
tpu_sparse_dense_matmul. Similarly, gradients are stacked by
tpu_sparse_dense_matmul_grad. If you need to handle stacked activations or
gradients manually, you can set perform_unstacking=False in
tpu_sparse_dense_matmul or perform_stacking=False in
tpu_sparse_dense_matmul_grad.
Feature Stacking
Stacking multiple features requires stacking along the batch/sample dimension. This is recorded in the FeatureIdTransformation structure using these fields:
Row Offset: records the offset along the batch dimension.
Col Shift: rotation of the vocabulary across the embedding table shards to distribute hot embedding IDs evenly.
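The two fields can be pictured with a simplified model. The class and helper functions below are hypothetical stand-ins and are not the library's FeatureIdTransformation. The sketch assumes embedding IDs are assigned to shards round-robin, and it ignores any interleaving of the sample dimension.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyIdTransformation:
    """Hypothetical stand-in holding only the two fields described above."""
    row_offset: int  # where this feature's samples start in the stacked batch
    col_shift: int   # rotation applied when assigning IDs to shards


def stacked_row(sample_index, transform):
    # Position of a sample inside the stacked batch (no interleaving modeled).
    return transform.row_offset + sample_index


def shard_for_id(embedding_id, transform, num_shards):
    # Assumes round-robin (mod) sharding; the shift rotates the assignment.
    return (embedding_id + transform.col_shift) % num_shards


NUM_SHARDS = 8  # illustrative value
feature_x = ToyIdTransformation(row_offset=0, col_shift=0)
feature_y = ToyIdTransformation(row_offset=16, col_shift=4)

# Sample 3 of each feature lands at a different row of the stacked batch.
print(stacked_row(3, feature_x), stacked_row(3, feature_y))  # 3 19

# A hot ID 0 in both vocabularies is served by different shards.
print(shard_for_id(0, feature_x, NUM_SHARDS),
      shard_for_id(0, feature_y, NUM_SHARDS))  # 0 4
```

Under these assumptions, low IDs that tend to be hot in several vocabularies no longer pile onto the same shard.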
Note
The sample dimension is interleaved when using SPLIT_THEN_STACK (the default and only strategy), because the stacked samples are split along the batch dimension. This interleaving helps distribute embedding IDs evenly across SparseCores during embedding lookup and update.
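One plausible reading of the interleaving is sketched below: each feature's samples are first split into equal chunks, and the chunks with the same index are then stacked together. The function name is hypothetical, and the sketch assumes every batch divides evenly by the number of splits.

```python
def split_then_stack(feature_batches, num_splits):
    """Splits each feature's samples into chunks, then stacks them per chunk."""
    stacked = []
    for split in range(num_splits):
        for samples in feature_batches:
            chunk = len(samples) // num_splits  # assumes even divisibility
            stacked.extend(samples[split * chunk:(split + 1) * chunk])
    return stacked


# Two hypothetical features with four samples each, split two ways.
feature_a = ["a0", "a1", "a2", "a3"]
feature_b = ["b0", "b1", "b2", "b3"]

print(split_then_stack([feature_a, feature_b], num_splits=2))
# ['a0', 'a1', 'b0', 'b1', 'a2', 'a3', 'b2', 'b3']
```

Each half of the result contains samples from both features, so a consumer that takes one contiguous slice of the stacked batch sees a mix of features rather than a single one.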
Table Stacking (Optional)
Table stacking can decrease training time by combining smaller embedding
tables into larger ones, thereby reducing the number of embedding table lookups
in the forward pass and updates in the backward pass. To use table stacking,
define the TableSpec and FeatureSpec objects as usual, then call
auto_stack_tables, which updates the feature specs and the table specs they
reference with the required stacking information. All downstream training APIs
refer to the feature specs and account for stacking as necessary, so you do not
need to do anything special with regard to stacking when preparing the inputs.
For instance, define TableSpecs for the embedding tables:
table_spec_a = embedding_spec.TableSpec(
    vocabulary_size=64,
    embedding_dim=12,
    initializer=lambda: jnp.zeros((128, 16), dtype=jnp.float32),
    optimizer=embedding_spec.SGDOptimizerSpec(),
    combiner='sum',
    name='table_a',
    max_ids_per_partition=16,
    max_unique_ids_per_partition=16,
)
table_spec_b = embedding_spec.TableSpec(
    vocabulary_size=120,
    embedding_dim=10,
    initializer=lambda: jnp.zeros((128, 16), dtype=jnp.float32),
    optimizer=embedding_spec.SGDOptimizerSpec(),
    combiner='sum',
    name='table_b',
    max_ids_per_partition=16,
    max_unique_ids_per_partition=16,
)
Define the FeatureSpecs that would use these tables:
feature_specs = [
    embedding_spec.FeatureSpec(
        table_spec=table_spec_a,
        input_shape=(16, 1),
        output_shape=(
            16,
            table_spec_a.embedding_dim,
        ),
        name='feature_spec_a',
    ),
    embedding_spec.FeatureSpec(
        table_spec=table_spec_b,
        input_shape=(16, 1),
        output_shape=(
            16,
            table_spec_b.embedding_dim,
        ),
        name='feature_spec_b',
    ),
    embedding_spec.FeatureSpec(
        table_spec=table_spec_b,
        input_shape=(16, 1),
        output_shape=(
            16,
            table_spec_b.embedding_dim,
        ),
        name='feature_spec_c',
    ),
]
To use table stacking, call auto_stack_tables() as follows:
import jax
from jax_tpu_embedding.sparsecore.lib import embedding

# Optional, only needed if you want to stack tables.
embedding.auto_stack_tables(
    feature_specs,
    global_device_count=jax.device_count(),
    num_sc_per_device=4,  # 4 for TPU v5, 2 for TPU v6e
)

# Required. This populates feature stacking information when more than one
# feature uses the same table. It also performs some basic validation of the
# feature specs.
embedding.prepare_feature_specs_for_training(
    feature_specs,
    global_device_count=jax.device_count(),
    num_sc_per_device=4,  # 4 for TPU v5, 2 for TPU v6e
)
There is also an API to manually stack tables: stack_tables().
API
- auto_stack_tables(features, global_device_count, num_sc_per_device, rotation=None, stack_to_max_ids_per_partition=<function get_default_limits>, stack_to_max_unique_ids_per_partition=<function get_default_limits>, *, use_short_stack_names=True, activation_mem_bytes_limit=2097152)#
Creates new feature specs based on auto stacking logic.
All tables with the same dimensions, optimizer, and combiner are stacked together. The tables are stacked in alphabetical order of their names. The new feature specs have updated table specs with the fields needed for the stacking setup. The features are updated in place with the new table specs.
- Parameters:
features (FeatureSpec | Sequence[FeatureSpec] | Mapping[str, FeatureSpec]) – The input features.
global_device_count (int) – The number of global devices (chips). Typically mesh.size.
num_sc_per_device (int) – The number of SparseCores per device.
rotation (int | None) – The shard rotation factor for each stacked table. If None, sets to num_sc_per_device. Default: None.
stack_to_max_ids_per_partition (Callable[[str, int], int]) – Override the max_ids_per_partition for each stack.
stack_to_max_unique_ids_per_partition (Callable[[str, int], int]) – Override the max_unique_ids_per_partition for each stack.
use_short_stack_names (bool) – If True, a hash will be appended to the stack name to avoid long names. Otherwise, the stack name will be the concatenation of the table names.
activation_mem_bytes_limit – If the activation memory usage is larger than this limit, the table will not be stacked. Default: 2097152 bytes (2 MiB).
- Return type:
None
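The grouping rule described for auto_stack_tables can be approximated as follows. This is a simplified model and not the library's code. It assumes "same dimensions" means equal embedding dimensions after rounding up to a multiple of 8, represents tables as plain dictionaries with hypothetical values, and does not model activation_mem_bytes_limit.

```python
from collections import defaultdict


def round_up_to_multiple_of_8(dim):
    # Assumption: dimensions are compared after padding to a multiple of 8.
    return -(-dim // 8) * 8


def group_tables_for_stacking(tables):
    """Groups table names by (padded dim, optimizer, combiner), sorted by name."""
    groups = defaultdict(list)
    for table in tables:
        key = (
            round_up_to_multiple_of_8(table["embedding_dim"]),
            table["optimizer"],
            table["combiner"],
        )
        groups[key].append(table["name"])
    return {key: sorted(names) for key, names in groups.items()}


# Hypothetical tables described by plain dictionaries.
tables = [
    {"name": "table_b", "embedding_dim": 10, "optimizer": "sgd", "combiner": "sum"},
    {"name": "table_a", "embedding_dim": 12, "optimizer": "sgd", "combiner": "sum"},
    {"name": "table_c", "embedding_dim": 32, "optimizer": "sgd", "combiner": "sum"},
    {"name": "table_d", "embedding_dim": 12, "optimizer": "sgd", "combiner": "mean"},
]

groups = group_tables_for_stacking(tables)
for key, names in groups.items():
    print(key, names)
# (16, 'sgd', 'sum') ['table_a', 'table_b']
# (32, 'sgd', 'sum') ['table_c']
# (16, 'sgd', 'mean') ['table_d']
```

Only the first group contains more than one table, so under these assumptions it is the only group that would actually be stacked.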
- stack_tables(features, table_names, global_device_count, num_sc_per_device, rotation=None, stack_to_max_ids_per_partition=<function get_default_limits>, stack_to_max_unique_ids_per_partition=<function get_default_limits>, stack_name=None, fail_on_excess_padding=False)#
Creates new feature specs based on specified stacking groups.
Checks that the tables in the groups have the same dim, optimizer, and combiner. Then creates new feature specs whose updated table specs carry the fields needed for the stacking setup. The features are updated in place with the new table specs.
- Parameters:
features (FeatureSpec | Sequence[FeatureSpec] | Mapping[str, FeatureSpec]) – The input features.
table_names (Sequence[str]) – A list of table names to be stacked.
global_device_count (int) – The number of global devices (chips). Typically mesh.size.
num_sc_per_device (int) – The number of SparseCores per device.
rotation (int | None) – The shard rotation factor for each stacked table. If None, sets to num_sc_per_device. Default: None.
stack_to_max_ids_per_partition (Callable[[str, int], int]) – Override the max_ids_per_partition for each stack.
stack_to_max_unique_ids_per_partition (Callable[[str, int], int]) – Override the max_unique_ids_per_partition for each stack.
stack_name (str | None) – A unique name for the table stack. If None, a default name will be chosen.
fail_on_excess_padding (bool) – If True, raises an error if the embedding dimensions of the tables to stack would lead to excessive padding (i.e. do not match when rounded up to the nearest multiple of 8 values).
- Return type:
None
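The condition described for fail_on_excess_padding can be checked by hand before calling stack_tables. The helpers below are hypothetical and only mirror the rule as stated in the parameter description: embedding dimensions are rounded up to the nearest multiple of 8 and then compared.

```python
def padded_dim(embedding_dim, multiple=8):
    """Rounds an embedding dimension up to the nearest multiple of `multiple`."""
    return ((embedding_dim + multiple - 1) // multiple) * multiple


def has_excess_padding(embedding_dims):
    """True if the padded dimensions of the tables to stack do not all match."""
    return len({padded_dim(dim) for dim in embedding_dims}) > 1


print(padded_dim(12), padded_dim(10))  # 16 16
print(has_excess_padding([12, 10]))    # False
print(has_excess_padding([12, 20]))    # True
```

Dimensions 12 and 10 both pad to 16 and therefore match, while 12 and 20 pad to 16 and 24, which is the situation the flag is described as rejecting.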