Checkpointing


Checkpoints are the serialized state of training. To save and restore training state in JAX, the recommended checkpointing library is Orbax.

import orbax.checkpoint as ocp

chkpt_mgr = ocp.CheckpointManager(
    directory='/path/to/checkpoint',
    options=ocp.CheckpointManagerOptions(
        max_to_keep=5,  # Keep only the five most recent checkpoints.
        # ... (other options)
    ),
)
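# Save the train state and the embedding variables as two separate items.
chkpt_mgr.save(
    step,
    args=ocp.args.Composite(
        train_state=ocp.args.PyTreeSave(train_state),
        embedding=ocp.args.PyTreeSave(embedding_variables),
    ),
)
# Saving is asynchronous by default; before restoring or exiting, block
# until any pending save has finished.
chkpt_mgr.wait_until_finished()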
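# Restore both items. init_train_state supplies the target structure for the
# train state; the embedding is restored without a target and converted below.
restored = chkpt_mgr.restore(
    step,
    args=ocp.args.Composite(
        train_state=ocp.args.PyTreeRestore(init_train_state),
        embedding=ocp.args.PyTreeRestore(),
    ),
)
train_state = restored['train_state']
# Rebuild an EmbeddingVariables object for each table from the restored dicts.
embedding_variables = {}
for name, v in restored['embedding'].items():
    embedding_variables[name] = embedding.EmbeddingVariables(
        table=v['table'], slot=v['slot'])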
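The loop at the end of the snippet rebuilds typed embedding containers from the plain nested dict returned for the `embedding` item, which is restored without a target structure. The sketch below isolates that conversion without Orbax. `EmbeddingVariables` here is a hypothetical `NamedTuple` stand-in with the same `table` and `slot` fields, not the library's class, and `restored_embedding` is an assumed example of the dict shape the loop expects (table name mapped to `{'table': ..., 'slot': ...}`).

```python
from typing import Any, NamedTuple

import numpy as np


class EmbeddingVariables(NamedTuple):
    """Hypothetical stand-in for embedding.EmbeddingVariables."""
    table: Any
    slot: Any  # In a real model this may be a structured optimizer slot.


# Assumed shape of restored['embedding']: table name -> {'table', 'slot'}.
restored_embedding = {
    'movie_table': {'table': np.ones((4, 2)), 'slot': np.zeros((4, 2))},
    'user_table': {'table': np.full((3, 2), 2.0), 'slot': np.zeros((3, 2))},
}

# Same conversion as the loop above, written as a dict comprehension.
embedding_variables = {
    name: EmbeddingVariables(table=v['table'], slot=v['slot'])
    for name, v in restored_embedding.items()
}
```

The comprehension is equivalent to the explicit loop; either form works, since the only requirement is that each table's dict is wrapped back into the container type the training step expects.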

The embedding variables are usually saved as a separate checkpoint item so that they can be restored and used to continue training. For a complete example, see shakespeare_jit.