Embedding Specification#
See Google’s page about Embedding for a definition and examples.
Terminology#
(Embedding) Table: The lower-dimensional representations of sparse/categorical data. For each token in the vocabulary we have a vector with a size of the embedding dimension.
Embedding ID (Token): Represents an element of the embedding vocabulary.
Vocabulary Size: The total number of unique embedding IDs. This is the number of rows in the embedding table.
Embedding Dimension: The size of the lower dimensional space for the embeddings. This is the number of columns in the embedding table.
Sample (Example): Represents a single training example with multiple tokens.
Feature (Input): Represents a collection of samples.
Max Sequence Length: Defines the maximum number of tokens that a sample can have in a given feature.
Weight/Gain: The weight of each Embedding ID in a given sample.
Combiner: The aggregation function for combining the embeddings for a given sample. For instance, sum or mean.
(Feature) Activations: The weighted aggregation calculated with the Combiner for each sample in a given Input Feature.
(Feature) Gradients: Gradients of the loss function with respect to the feature activations.
(Embedding Table) Optimizer: The update function for the Model parameters and Embedding Table.
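To make these terms concrete, here is a minimal sketch in plain Python that does not touch the SparseCore API. The table values, the IDs, the weights, and the helper name `combine` are all made up for illustration, and the way `mean` handles weights (dividing by the sum of the weights) is an assumption rather than something this page specifies.

```python
# Toy embedding table: vocabulary_size = 4 rows, embedding_dim = 2 columns.
table = [
    [0.0, 1.0],  # embedding ID 0
    [2.0, 3.0],  # embedding ID 1
    [4.0, 5.0],  # embedding ID 2
    [6.0, 7.0],  # embedding ID 3
]
vocabulary_size = len(table)
embedding_dim = len(table[0])

# One feature holding two samples. Each sample is a list of embedding IDs
# with one weight (gain) per ID.
feature_ids = [[0, 2], [1, 2, 3]]
feature_weights = [[1.0, 1.0], [0.5, 1.0, 2.0]]


def combine(ids, weights, combiner="sum"):
    """Hypothetical helper: weighted aggregation of the rows for one sample."""
    pooled = [0.0] * embedding_dim
    for token, weight in zip(ids, weights):
        for col in range(embedding_dim):
            pooled[col] += weight * table[token][col]
    if combiner == "mean":
        # Assumption: the mean divides by the sum of the weights.
        total = sum(weights)
        pooled = [value / total for value in pooled]
    return pooled


# Feature activations: one vector of size embedding_dim per sample.
activations = [combine(i, w) for i, w in zip(feature_ids, feature_weights)]
```

Each entry of `activations` has `embedding_dim` elements, one entry per sample in the feature.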
API#
- class TableSpec(*, name, vocabulary_size, embedding_dim, initializer, optimizer, combiner, max_ids_per_partition=256, max_unique_ids_per_partition=256, suggested_coo_buffer_size_per_device=None, quantization_config=None, _stacked_table_spec=None, _setting_in_stack=None)#
Specifies one embedding table.
TableSpec is virtually immutable (for jax.jit) using eq=True and unsafe_hash=True, but has frozen=False to allow in-place updates when preparing for feature stacking or table stacking. See [dataclass doc](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) for more information.
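The effect of these dataclass flags can be seen with the standard library alone. The sketch below uses a hypothetical `ToySpec` class, not the real TableSpec, to show that `eq=True` with `unsafe_hash=True` makes instances comparable and hashable by field values, while `frozen=False` still permits in-place updates.

```python
import dataclasses


@dataclasses.dataclass(eq=True, unsafe_hash=True, frozen=False)
class ToySpec:
    """Hypothetical stand-in for a spec class; not the real TableSpec."""

    name: str
    vocabulary_size: int


a = ToySpec(name="table_a", vocabulary_size=128)
b = ToySpec(name="table_a", vocabulary_size=128)

# Equal field values give equal objects and equal hashes, so a spec can be
# used wherever a hashable value is required (for example as a dict key).
same_before = (a == b) and (hash(a) == hash(b))

# frozen=False: fields can still be mutated in place. The hash follows the
# field values, so mutating a spec after it has been used as a key is unsafe.
b.vocabulary_size = 256
same_after = a == b
```

This is why the spec is described as only "virtually" immutable: nothing in Python prevents a field assignment, so the caller is responsible for not mutating a spec once something depends on its hash.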
- Parameters:
name (str)
vocabulary_size (int)
embedding_dim (int)
initializer (Initializer)
optimizer (OptimizerSpec)
combiner (str)
max_ids_per_partition (int)
max_unique_ids_per_partition (int)
suggested_coo_buffer_size_per_device (int | None)
quantization_config (QuantizationConfig | None)
_stacked_table_spec (StackedTableSpec | None)
_setting_in_stack (TableSettingInStack | None)
- combiner: str#
The aggregation function to compute activations for each sample. For example, sum or mean.
- embedding_dim: int#
The number of columns in the embedding table.
- initializer: Initializer#
An initializer for the embedding table. See jax.nn.initializers for more details.
- max_ids_per_partition: int = 256#
The maximum number of embedding IDs that can be packed into a single partition.
- max_unique_ids_per_partition: int = 256#
The maximum number of unique embedding IDs that can be packed into a single partition.
- name: str#
Name of the table.
- optimizer: OptimizerSpec#
An optimizer for the embedding table.
- quantization_config: QuantizationConfig | None = None#
The quantization config (min, max, num_buckets), which specifies the float range and the number of discrete integer buckets to use for quantization.
- property setting_in_stack: TableSettingInStack#
Returns the setting of this table in the stack.
- property stacked_table_spec: StackedTableSpec#
Returns the stacked table spec which this table belongs to.
- suggested_coo_buffer_size_per_device: int | None = None#
The minimum size of the input buffer that the preprocessing should try to create.
- vocabulary_size: int#
The total number of unique embedding IDs. This is the number of rows in the embedding table.
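The difference between max_ids_per_partition and max_unique_ids_per_partition is easiest to see by counting. The sketch below is illustrative only: this page does not say how IDs are assigned to partitions, so the modulo rule and the value of `num_partitions` are assumptions chosen to keep the example small.

```python
from collections import defaultdict

num_partitions = 4  # Hypothetical partition count.
batch_ids = [3, 7, 7, 8, 12, 3, 5]

ids_per_partition = defaultdict(list)
for token in batch_ids:
    # Assumption for illustration only: IDs are assigned by modulo.
    ids_per_partition[token % num_partitions].append(token)

# Total IDs per partition (compare against max_ids_per_partition).
total_counts = {p: len(ids) for p, ids in ids_per_partition.items()}
# Distinct IDs per partition (compare against max_unique_ids_per_partition).
unique_counts = {p: len(set(ids)) for p, ids in ids_per_partition.items()}
```

Under these assumptions, a partition that receives the same ID several times counts every occurrence toward the first limit but only once toward the second.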
- class FeatureSpec(*, name, table_spec, input_shape, output_shape, _id_transformation=None)#
Specification for one embedding feature.
Notes
FeatureSpec is virtually immutable (for jax.jit()) using eq=True and unsafe_hash=True, but has frozen=False to allow in-place updates when preparing for feature stacking or table stacking. See [dataclass doc](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) for more information.
Warning
For all other purposes use embedding.update_preprocessing_parameters to maintain consistency between features, tables and stacked tables.
- Parameters:
name (str)
table_spec (TableSpec)
input_shape (Sequence[int])
output_shape (Sequence[int])
_id_transformation (FeatureIdTransformation | None)
- property id_transformation: FeatureIdTransformation#
Returns the transformation to apply to the input feature ids.
- input_shape: Sequence[int]#
The shape of the input jax array: [global_batch_size, feature_valency]. The second element can be omitted for ragged input.
- name: str#
Name of the feature.
- output_shape: Sequence[int]#
The expected shape of the output activation jax array: [global_batch_size, embedding_dim].
Multivalent (Unordered/Pooled) Features#
For multivalent features, each sample is represented by an unordered set of embedding IDs.
The embeddings corresponding to these IDs are aggregated or “pooled” into a single embedding
vector for the sample. This is done using the combiner (e.g., sum, mean) specified in the TableSpec.
For example, if a sample has IDs [10, 21, 32] and the combiner is mean, the output activation
will be mean(embedding(10), embedding(21), embedding(32)).
The input shape for a batch of such features is [batch_size, max_ids_per_sample], where
max_ids_per_sample is the valency. The output shape is [batch_size, embedding_dim].
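The pooling described above can be sketched in plain Python. The table rows are invented values, and `mean_pool` is a hypothetical helper rather than part of the API; the second sample lists the same IDs in a different order to show that the result does not depend on ordering.

```python
embedding_dim = 2
# Hypothetical table stored as a dict so only the rows used here are defined.
table = {10: [1.0, 2.0], 21: [3.0, 4.0], 32: [5.0, 9.0]}

# Input shape [batch_size, max_ids_per_sample] = [2, 3].
batch_ids = [[10, 21, 32], [32, 21, 10]]


def mean_pool(ids):
    """Average the embedding rows of one sample, column by column."""
    rows = [table[token] for token in ids]
    return [sum(column) / len(rows) for column in zip(*rows)]


# Output shape [batch_size, embedding_dim] = [2, 2].
activations = [mean_pool(ids) for ids in batch_ids]
```

Both samples pool to the same vector, which is the sense in which multivalent features are unordered.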
Sequence (Ordered/Concatenated) Features#
For sequence features, each sample is an ordered sequence of items, where each item can be one or more embedding IDs. The embeddings for each item in the sequence are computed and then concatenated to form the final output.
To handle sequence features, you will need to flatten the sequence dimension into the batch dimension before passing the features to the embedding layer. You can then reshape the output back to recover the sequence dimension. This is equivalent to concatenating the embeddings for each item in the sequence.
import jax.numpy as jnp

# `inputs` and `embed_layer` are assumed to be defined elsewhere.
# inputs shape: [batch_size, sequence_length, valency]
batch_size, sequence_length, valency = inputs.shape
# 1. Flatten the sequence dimension into the batch dimension
flattened_input = jnp.reshape(inputs, (batch_size * sequence_length, valency))
# 2. Perform the embedding lookup and combinations (if valency > 1)
flattened_output = embed_layer(flattened_input)
# flattened_output shape: [batch_size * sequence_length, embedding_dim]
embedding_dim = flattened_output.shape[-1]
# 3. Reshape the output back to the original sequence shape
output = jnp.reshape(flattened_output, (batch_size, sequence_length, embedding_dim))
If you have variable sequence lengths, you will need to pad your inputs to a
max_sequence_length.
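Padding can be sketched as follows. This page does not say which ID to pad with or how padded positions should be ignored downstream, so the reserved `pad_id` of 0 and the companion `mask` are assumptions made for illustration.

```python
max_sequence_length = 4
pad_id = 0  # Assumption: ID 0 is reserved for padding in this toy setup.

# Three samples with different sequence lengths.
sequences = [[5, 9], [7, 3, 8, 2], [4]]

padded = []
mask = []
for seq in sequences:
    kept = seq[:max_sequence_length]  # Truncate anything too long.
    n_pad = max_sequence_length - len(kept)
    padded.append(kept + [pad_id] * n_pad)
    # 1.0 marks a real item, 0.0 marks a padded position.
    mask.append([1.0] * len(kept) + [0.0] * n_pad)
```

After padding, every sample has exactly `max_sequence_length` items, so the batch has a fixed shape and the flatten and reshape steps apply unchanged. The mask can be used to zero out or skip the activations of padded positions.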
Optimizers#
See the Optimizers page for more details on the available optimizers and how to configure them.
Flax Embedding Layer#
Flax is the most commonly used JAX neural network library. The JAX SparseCore API provides a Flax layer that uses the primitive APIs to support large embeddings.
Flax comes in two flavors: Linen (now deprecated) and the more recent NNX. The Flax project provides a guide for migrating from Linen to NNX. The SparseCore project provides both Linen and NNX layers for large embedding models that can be used without modification or extension. These layers are built on the primitive API, use the same Embedding Specification objects to configure the embedding, and accept inputs from the preprocessing API.
You can find the Linen module here: linen.embed.SparseCoreEmbed.
The newer NNX module is here: nnx.embed.SparseCoreEmbed.
Caveats#
Caveat 1: As with the primitive API and due to the size of embedding tables, the embedding tables are updated in-place during the gradient calculation. As such, gradients of the embeddings can’t be extracted in the same way as they are with dense layers.
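A toy illustration of what an in-place update means in practice is sketched below. It is plain Python, not the SparseCore implementation: the SGD-style rule, the learning rate, and the helper `apply_gradient_in_place` are assumptions chosen to show that the rows are modified directly and no per-row gradient is handed back to the caller.

```python
learning_rate = 0.5
table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]

# One sample with IDs [0, 2] and a sum combiner, so the activation gradient
# flows unchanged to each looked-up row.
sample_ids = [0, 2]
activation_grad = [0.5, -1.0]


def apply_gradient_in_place(table, ids, grad, lr):
    """Hypothetical fused step: updates rows directly and returns nothing."""
    for token in ids:
        for col, g in enumerate(grad):
            table[token][col] -= lr * g


result = apply_gradient_in_place(table, sample_ids, activation_grad, learning_rate)
```

Only the rows that were looked up change, and the function has no return value to inspect. This contrasts with a dense layer, where the gradient of each parameter is typically returned as its own array before the optimizer applies it.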