Feature and Table Stacking

This document is only relevant to JAX SparseCore developers or advanced users. The features described here are optional and not required to use the JAX SparseCore API.

Introduction

Feature stacking and table stacking are two related but distinct techniques for improving the efficiency of models with many embedding tables and features.

  • Feature Stacking: Multiple features that reference the same table can be combined into a single lookup by concatenating their samples. This results in fewer, larger lookups, which is generally more efficient.

  • Table Stacking: Multiple small tables can be combined into a single stacked table. This also results in fewer, larger lookups and can be more memory-efficient for storing the sharded tables.
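To make "fewer, larger lookups" concrete, the sketch below counts lookups per step for the three features defined later in this document (feature_spec_a on table_a; feature_spec_b and feature_spec_c sharing table_b). It assumes a simplified model of one lookup per feature without stacking, one per table with feature stacking, and one per stack with table stacking; the stack name stack_ab is hypothetical.

```python
# Feature-to-table assignment mirroring the FeatureSpecs later in this document.
feature_to_table = {
    "feature_spec_a": "table_a",
    "feature_spec_b": "table_b",
    "feature_spec_c": "table_b",
}
# Hypothetical grouping: both tables placed in one stack.
table_to_stack = {"table_a": "stack_ab", "table_b": "stack_ab"}


def count_lookups(feature_to_table, table_to_stack=None):
    """Counts lookups under a simplified one-lookup-per-group model."""
    tables = set(feature_to_table.values())
    if table_to_stack is None:
        # Feature stacking: features sharing a table become one lookup.
        return len(tables)
    # Table stacking: tables sharing a stack become one lookup.
    return len({table_to_stack[t] for t in tables})


no_stacking = len(feature_to_table)  # one lookup per feature
feature_stacking = count_lookups(feature_to_table)
table_and_feature_stacking = count_lookups(feature_to_table, table_to_stack)
print(no_stacking, feature_stacking, table_and_feature_stacking)  # 3 2 1
```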

Note

In what follows, the term “row” refers to a training sample (example) and “column” refers to an embedding ID within a given sample. Each column value indexes into the vocabulary of an embedding table.
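The row/column terminology can be illustrated with a toy lookup. This minimal sketch assumes combiner='sum' (the combiner used by the TableSpecs in this document): each row is a sample, each value in it is an embedding ID, and the row's activation is the sum of the embeddings it references. The table and batch values are made up for illustration.

```python
# A toy embedding table: vocabulary_size=4, embedding_dim=2.
table = [
    [0.0, 1.0],
    [1.0, 0.0],
    [2.0, 2.0],
    [3.0, 1.0],
]

# A ragged batch of 3 rows (samples); each value is an embedding ID ("column").
batch = [
    [0, 2],     # row 0 looks up vocabulary entries 0 and 2
    [3],        # row 1 looks up entry 3
    [1, 1, 2],  # row 2 looks up entry 1 twice and entry 2 once
]


def sum_combine(table, batch):
    """Returns one activation per row by summing the looked-up embeddings."""
    dim = len(table[0])
    activations = []
    for ids in batch:
        acc = [0.0] * dim
        for embedding_id in ids:
            for d in range(dim):
                acc[d] += table[embedding_id][d]
        activations.append(acc)
    return activations


print(sum_combine(table, batch))  # [[2.0, 3.0], [3.0, 1.0], [4.0, 2.0]]
```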

By default, feature stacking is performed implicitly by preprocess_sparse_dense_matmul_input, and activations are unstacked by tpu_sparse_dense_matmul. Similarly, gradients are stacked by tpu_sparse_dense_matmul_grad. If you need to handle stacked activations or gradients manually, you can set perform_unstacking=False in tpu_sparse_dense_matmul or perform_stacking=False in tpu_sparse_dense_matmul_grad.
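If you handle stacked activations yourself, unstacking amounts to slicing each feature's rows back out of the stacked batch using its row offset. The sketch below is a simplified illustration, not the library's implementation: it concatenates whole batches and ignores the SPLIT_THEN_STACK interleaving described in the note below. The helpers stack_features and unstack are hypothetical, and strings stand in for activation vectors.

```python
def stack_features(feature_batches):
    """Concatenates per-feature sample lists along the batch dimension.

    Returns the stacked batch and each feature's row offset into it.
    """
    stacked, row_offsets = [], {}
    for name, rows in feature_batches.items():
        row_offsets[name] = len(stacked)
        stacked.extend(rows)
    return stacked, row_offsets


def unstack(stacked_activations, row_offsets, batch_sizes):
    """Slices per-feature activations back out of a stacked result."""
    return {
        name: stacked_activations[offset:offset + batch_sizes[name]]
        for name, offset in row_offsets.items()
    }


# Two features that share one table, each with a batch of 2 samples.
features = {
    "feature_spec_b": [[5], [7, 9]],
    "feature_spec_c": [[1, 2], [3]],
}
stacked, offsets = stack_features(features)
print(stacked)  # [[5], [7, 9], [1, 2], [3]]
print(offsets)  # {'feature_spec_b': 0, 'feature_spec_c': 2}

# Stand-in for per-row activations produced by one stacked lookup.
activations = [f"act{i}" for i in range(len(stacked))]
sizes = {name: len(rows) for name, rows in features.items()}
print(unstack(activations, offsets, sizes))
# {'feature_spec_b': ['act0', 'act1'], 'feature_spec_c': ['act2', 'act3']}
```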

Feature Stacking

Stacking multiple features means stacking their samples along the batch (sample) dimension. This is recorded in the FeatureIdTransformation structure using these fields:

  • Row Offset: records the feature's offset along the stacked batch dimension.

  • Col Shift: a rotation of the vocabulary across the embedding table shards, which distributes hot embedding IDs evenly.
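The purpose of the column shift can be shown with a toy sharding model. This sketch assumes that an ID is placed on shard (id + col_shift) % num_shards; the library's actual layout may differ. If ID 0 is the hottest ID for each of three stacked features, no shift sends every hot lookup to the same shard, while distinct shifts spread them out.

```python
NUM_SHARDS = 4  # e.g., one shard per SparseCore (assumed for illustration)


def shard_for_id(embedding_id, col_shift, num_shards=NUM_SHARDS):
    """Hypothetical mod-sharding with a per-feature rotation (col_shift)."""
    return (embedding_id + col_shift) % num_shards


# Suppose ID 0 is the hottest ID for each of three stacked features.
hot_id = 0
without_shift = [shard_for_id(hot_id, col_shift=0) for _ in range(3)]
with_shift = [shard_for_id(hot_id, col_shift=k) for k in range(3)]
print(without_shift)  # [0, 0, 0]: every hot lookup hits shard 0
print(with_shift)     # [0, 1, 2]: hot lookups spread over three shards
```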

Note

The sample dimension is interleaved when using SPLIT_THEN_STACK (the default and only strategy). Because the stacked samples are split along the batch dimension, this interleaving helps distribute embedding IDs evenly across SparseCores during embedding lookup and update.
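The interleaving can be pictured with a simplified model of what the strategy name suggests: each feature's batch is split across SparseCores first, and the slices are then stacked per core. In this sketch, split_then_stack and the contrasting stack_then_split are hypothetical helpers, and the per-SparseCore split is an assumption. With interleaving, every SparseCore receives samples from every feature; naive concatenation would give each core all samples of a single feature.

```python
def split_then_stack(feature_batches, num_sparsecores):
    """Splits each feature's batch across SparseCores, then stacks per core.

    Assumes every batch size is divisible by num_sparsecores.
    """
    per_core = [[] for _ in range(num_sparsecores)]
    for rows in feature_batches.values():
        slice_size = len(rows) // num_sparsecores
        for core in range(num_sparsecores):
            start = core * slice_size
            per_core[core].extend(rows[start:start + slice_size])
    return per_core


def stack_then_split(feature_batches, num_sparsecores):
    """Contrast only: concatenates all features first, then splits evenly."""
    stacked = [row for rows in feature_batches.values() for row in rows]
    slice_size = len(stacked) // num_sparsecores
    return [stacked[i * slice_size:(i + 1) * slice_size]
            for i in range(num_sparsecores)]


features = {
    "f1": ["f1_s0", "f1_s1", "f1_s2", "f1_s3"],
    "f2": ["f2_s0", "f2_s1", "f2_s2", "f2_s3"],
}
interleaved = split_then_stack(features, num_sparsecores=2)
naive = stack_then_split(features, num_sparsecores=2)
print(interleaved)
# [['f1_s0', 'f1_s1', 'f2_s0', 'f2_s1'], ['f1_s2', 'f1_s3', 'f2_s2', 'f2_s3']]
print(naive)
# [['f1_s0', 'f1_s1', 'f1_s2', 'f1_s3'], ['f2_s0', 'f2_s1', 'f2_s2', 'f2_s3']]
```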

Table Stacking (Optional)

Table stacking can decrease training time by combining smaller embedding tables into larger ones, thereby reducing the number of embedding table lookups in the forward pass and updates in the backward pass. To use table stacking, define the TableSpec and FeatureSpec objects as usual and then call auto_stack_tables, which updates the feature specs and the referenced table specs with the required stacking information. All downstream training APIs read the feature specs and account for stacking as necessary, so you do not need to do anything special with regard to stacking when preparing the inputs.

For instance, first define TableSpecs for the embedding tables:

```python
import jax
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec

table_spec_a = embedding_spec.TableSpec(
    vocabulary_size=64,
    embedding_dim=12,
    # jax.nn.initializers-style initializer, called with (key, shape, dtype).
    initializer=jax.nn.initializers.zeros,
    optimizer=embedding_spec.SGDOptimizerSpec(),
    combiner='sum',
    name='table_a',
    max_ids_per_partition=16,
    max_unique_ids_per_partition=16,
)
table_spec_b = embedding_spec.TableSpec(
    vocabulary_size=120,
    embedding_dim=10,
    initializer=jax.nn.initializers.zeros,
    optimizer=embedding_spec.SGDOptimizerSpec(),
    combiner='sum',
    name='table_b',
    max_ids_per_partition=16,
    max_unique_ids_per_partition=16,
)
```

Next, define the FeatureSpecs that use these tables:

```python
feature_specs = [
  embedding_spec.FeatureSpec(
      table_spec=table_spec_a,
      input_shape=(16, 1),
      output_shape=(
          16,
          table_spec_a.embedding_dim,
      ),
      name='feature_spec_a',
  ),
  embedding_spec.FeatureSpec(
      table_spec=table_spec_b,
      input_shape=(16, 1),
      output_shape=(
          16,
          table_spec_b.embedding_dim,
      ),
      name='feature_spec_b',
  ),
  embedding_spec.FeatureSpec(
      table_spec=table_spec_b,
      input_shape=(16, 1),
      output_shape=(
          16,
          table_spec_b.embedding_dim,
      ),
      name='feature_spec_c',
  ),
]
```

If you want to use table stacking, call auto_stack_tables() as follows:

```python
import jax
from jax_tpu_embedding.sparsecore.lib.nn import embedding

# Optional: only needed if you want to stack tables.
embedding.auto_stack_tables(
    feature_specs,
    global_device_count=jax.device_count(),
    num_sc_per_device=4,  # 4 for TPU v5p, 2 for TPU v6e
)
# Required: populates feature-stacking information when more than one feature
# uses the same table, and performs basic validation of the feature specs.
embedding.prepare_feature_specs_for_training(
    feature_specs,
    global_device_count=jax.device_count(),
    num_sc_per_device=4,  # 4 for TPU v5p, 2 for TPU v6e
)
```
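Conceptually, stacking appends the vocabularies of the stacked tables, so an ID from a later table is looked up at that table's starting row in the stacked table. This sketch is simplified: it uses the vocabulary sizes of table_a and table_b above, orders tables alphabetically (as the API description below states), and assumes no vocabulary padding, which the library may apply. The helper stack_vocabularies is hypothetical.

```python
def stack_vocabularies(vocab_sizes):
    """Returns each table's starting row in the stacked table and the total.

    Simplified: tables are appended in alphabetical order of their names,
    and no vocabulary padding is applied.
    """
    offsets, total = {}, 0
    for name in sorted(vocab_sizes):
        offsets[name] = total
        total += vocab_sizes[name]
    return offsets, total


offsets, stacked_vocab = stack_vocabularies({"table_b": 120, "table_a": 64})
print(offsets, stacked_vocab)  # {'table_a': 0, 'table_b': 64} 184

# ID 5 of table_b maps to row 64 + 5 of the stacked table.
print(offsets["table_b"] + 5)  # 69
```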

There is also an API to manually stack tables: stack_tables().

API

auto_stack_tables(features, global_device_count, num_sc_per_device, rotation=None, stack_to_max_ids_per_partition=<function get_default_limits>, stack_to_max_unique_ids_per_partition=<function get_default_limits>, *, use_short_stack_names=True, activation_mem_bytes_limit=2097152)

Creates new feature specs based on auto stacking logic.

All tables with the same dimensions, optimizer, and combiner are stacked together. The tables are stacked in alphabetical order of their names. The new feature specs have updated table specs with the fields related to the stacking setup. The features are updated in place with the new table specs.
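The grouping rule described above can be sketched as follows. The Table tuple is a hypothetical stand-in for the stacking-relevant TableSpec fields, and the sketch ignores activation_mem_bytes_limit and any other checks the real function performs; it only shows tables being grouped by (dimension, optimizer, combiner) and ordered by name within each group.

```python
from collections import defaultdict
from typing import NamedTuple


class Table(NamedTuple):
    """Hypothetical stand-in for the stacking-relevant TableSpec fields."""
    name: str
    embedding_dim: int
    optimizer: str
    combiner: str


def group_for_stacking(tables):
    """Groups table names by (embedding_dim, optimizer, combiner).

    Tables are visited in alphabetical order, so each group lists its
    tables in the order they would be stacked.
    """
    groups = defaultdict(list)
    for table in sorted(tables, key=lambda t: t.name):
        key = (table.embedding_dim, table.optimizer, table.combiner)
        groups[key].append(table.name)
    return dict(groups)


tables = [
    Table("table_d", 16, "adagrad", "sum"),
    Table("table_b", 16, "sgd", "sum"),
    Table("table_c", 32, "sgd", "sum"),
    Table("table_a", 16, "sgd", "sum"),
]
for key, names in group_for_stacking(tables).items():
    print(key, names)
# (16, 'sgd', 'sum') ['table_a', 'table_b']
# (32, 'sgd', 'sum') ['table_c']
# (16, 'adagrad', 'sum') ['table_d']
```

Only groups with more than one table actually produce a stack; table_c and table_d would remain unstacked here.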

Parameters:
  • features (FeatureSpec | Sequence[FeatureSpec] | Mapping[str, FeatureSpec]) – The input features.

  • global_device_count (int) – The number of global devices (chips). Typically mesh.size.

  • num_sc_per_device (int) – The number of SparseCores per device.

  • rotation (int | None) – The shard rotation factor for each stacked table. If None, defaults to num_sc_per_device. Default: None.

  • stack_to_max_ids_per_partition (Callable[[str, int], int]) – Override the max_ids_per_partition for each stack.

  • stack_to_max_unique_ids_per_partition (Callable[[str, int], int]) – Override the max_unique_ids_per_partition for each stack.

  • use_short_stack_names (bool) – If True, a hash will be appended to the stack name to avoid long names. Otherwise, the stack name will be the concatenation of the table names.

  • activation_mem_bytes_limit (int) – If the activation memory usage would exceed this limit, the table is not stacked. Default: 2097152 (2 MiB).

Return type:

None
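The stack_to_max_ids_per_partition and stack_to_max_unique_ids_per_partition parameters take a callable of type Callable[[str, int], int] that is keyed by stack name. This sketch shows one way to write such a callable with per-stack overrides. The stack name stack_ab and the fallback of 256 are hypothetical values, and because this document does not describe the int argument, the sketch ignores it.

```python
# Hypothetical per-stack limits; stack names are illustrative only.
MAX_IDS_OVERRIDES = {"stack_ab": 64}
DEFAULT_MAX_IDS = 256  # assumed fallback, not necessarily the library default


def max_ids_per_partition(stack_name: str, unused_int_arg: int) -> int:
    """Matches Callable[[str, int], int]; the int argument is ignored here."""
    return MAX_IDS_OVERRIDES.get(stack_name, DEFAULT_MAX_IDS)


print(max_ids_per_partition("stack_ab", 0))     # 64
print(max_ids_per_partition("other_stack", 0))  # 256
```

A function like this would be passed as stack_to_max_ids_per_partition=max_ids_per_partition to auto_stack_tables or stack_tables.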

stack_tables(features, table_names, global_device_count, num_sc_per_device, rotation=None, stack_to_max_ids_per_partition=<function get_default_limits>, stack_to_max_unique_ids_per_partition=<function get_default_limits>, stack_name=None, fail_on_excess_padding=False)

Creates new feature specs based on specified stacking groups.

Checks that the tables in the groups have the same dimension, optimizer, and combiner, then creates new feature specs whose table specs carry the fields related to the stacking setup. The features are updated in place with the new table specs.

Parameters:
  • features (FeatureSpec | Sequence[FeatureSpec] | Mapping[str, FeatureSpec]) – The input features.

  • table_names (Sequence[str]) – A list of table names to be stacked.

  • global_device_count (int) – The number of global devices (chips). Typically mesh.size.

  • num_sc_per_device (int) – The number of SparseCores per device.

  • rotation (int | None) – The shard rotation factor for each stacked table. If None, defaults to num_sc_per_device. Default: None.

  • stack_to_max_ids_per_partition (Callable[[str, int], int]) – Override the max_ids_per_partition for each stack.

  • stack_to_max_unique_ids_per_partition (Callable[[str, int], int]) – Override the max_unique_ids_per_partition for each stack.

  • stack_name (str | None) – A unique name for the table stack. If None, a default name will be chosen.

  • fail_on_excess_padding (bool) – If True, raises an error if the embedding dimensions of the tables to stack would lead to excessive padding (i.e., if they do not match after each is rounded up to the nearest multiple of 8).

Return type:

None
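The excess-padding condition for fail_on_excess_padding reduces to simple arithmetic: round each embedding dimension up to a multiple of 8 and check whether the results agree. This sketch implements only that check as described above (the helper names are hypothetical). With the example tables from this document, 12 and 10 both round up to 16, so they would not trigger the error.

```python
def round_up_to_multiple_of_8(n: int) -> int:
    """Rounds n up to the nearest multiple of 8 (ceiling division)."""
    return -(-n // 8) * 8


def would_pad_excessively(embedding_dims: list[int]) -> bool:
    """True if the dims differ after rounding each up to a multiple of 8."""
    return len({round_up_to_multiple_of_8(d) for d in embedding_dims}) > 1


print(would_pad_excessively([12, 10]))  # False: both round up to 16
print(would_pad_excessively([12, 20]))  # True: 16 vs. 24
```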
