Getting Started

Table of Contents

• Installing Dependencies: Setup instructions.
• A Short Tutorial: Writing an experiment from scratch.
• Launching Training: Launching training with cloud TPUs.
• Next Steps: Where to go next.

Installing Dependencies

The installation steps depend on your machine's hardware and on whether you plan to develop AXLearn.

Pre-requisites

If you use an Intel (x86) machine, we recommend installing in a virtual environment, e.g. with conda.

conda create -n axlearn python=3.10
conda activate axlearn

If you use an Apple Silicon machine, please follow these instructions instead:

For Apple Silicon machines, we will install native versions of Python and Python packages using Miniforge.

We need Xcode to build packages like tensorstore. Please install Xcode from the App Store if you haven't already.

# Install the arm64 version of Miniforge3 + Python 3.10.
curl -L -o miniforge.sh https://github.com/conda-forge/miniforge/releases/download/24.7.1-0/Miniforge3-24.7.1-0-MacOSX-arm64.sh
bash miniforge.sh -u

# Create a conda environment.
conda create -n axlearn python=3.10
conda activate axlearn

# Install tensorflow following https://developer.apple.com/metal/tensorflow-plugin.
conda install -c apple tensorflow-deps

# If you do NOT have bazel installed.
# Note that the default ./oss_scripts/install_bazel.sh installs the x86 version.
brew install bazelisk

# Manually build tensorflow-text until a collaborator build is available.
# This was tested using clang version 15 - you may get non-working wheels with earlier versions of clang.
mkdir ~/builds && git clone https://github.com/tensorflow/text.git ~/builds/text
# Install tensorflow prior to building.
pip install 'tensorflow==2.16.1'
cd ~/builds/text && git checkout 0f9f6df5b4da19bc7a734ba05fc4fa12bccbedbe

# Build tensorflow-text.
./oss_scripts/run_build.sh
pip install ./tensorflow_text-2.16.1-cp310-cp310-macosx_*_arm64.whl

Installation (User)

This section is for users who want to use AXLearn as a package rather than develop it.

To install on Intel (x86) machines, simply run:

pip install 'axlearn[core]'

To install on Apple Silicon machines, make sure you have followed the required pre-requisites above. Then, install using:

pip install 'axlearn[core,apple-silicon]'

By default, AXLearn comes with tooling to launch jobs to Google Cloud Platform (GCP). To install them, run:

pip install 'axlearn[gcp]'

Installation (Developer)

This section is for users who intend to develop AXLearn, e.g. by submitting PRs.

Instead of installing from pip, please fork the repo first, and then clone the fork.

# Clone your fork of the repo.
git clone https://github.com/<username>/axlearn
cd axlearn

In order to iterate locally and run tests, install the package in editable mode along with dev dependencies:

pip install -e '.[core,dev]'

If you intend to launch jobs to GCP, install gcp dependencies:

pip install -e '.[gcp]'

We also recommend setting up pre-commit hooks to run some CI checks locally:

pre-commit install --hook-type pre-commit

These checks will run automatically when you git commit, but you can also run pre-commit directly (please refer to the pre-commit docs for more information):

pre-commit run -a

We use pytype for static type checking:

# This can take a while, so we exclude it from pre-commit.
pytype -j auto .

To run tests (please refer to the pytest docs for more information):

pytest axlearn/common/config_test.py

# To set logging level:
pytest --log-cli-level=INFO axlearn/common/config_test.py

# To test a specific pattern of tests:
pytest axlearn/common/config_test.py -k "test_invalid"

# Run tests with 4 processes and specific markers:
pytest -n 4 -v -m "not (gs_login or tpu)" axlearn/common/

A Short Tutorial

This section walks through writing an AXLearn experiment from scratch. For the impatient, skip ahead to launching training to start training a model with an existing recipe.

AXLearn experiments have a standard anatomy. At a high level, we will need to define:

  • The model architecture.
  • The training and evaluation data.
  • The evaluation metrics.
  • The optimizer.

AXLearn comes with many reusable building blocks for assembling an experiment.

Let's walk through training ResNet on ImageNet as an example. All of the code is available under axlearn/experiments/vision/resnet.

If you plan to follow along with the code, create a new file axlearn/experiments/tutorial.py with the following skeleton:

# Imports to be added here...

def resnet_imagenet_trainer():
    # Code to be added here...
    ...

Model Architecture

Since ImageNet is a classification task, we will start with an image classification model. Thankfully, a skeleton already exists in the vision.image_classification module:

+from axlearn.vision import image_classification

def resnet_imagenet_trainer():
+    # Construct an image classifier config.
+    model_cfg = image_classification.ImageClassificationModel.default_config()

To use this model, let's look at the definition of ImageClassificationModel.Config. You may notice that it has a couple of required fields:

class ImageClassificationModel(BaseModel):
    """An image classification model."""

    @config_class
    class Config(BaseLayer.Config):
        backbone: Required[InstantiableConfig] = REQUIRED
        num_classes: Required[int] = REQUIRED
        classifier: InstantiableConfig = Linear.default_config().set(
            bias=True,
            param_partition_spec=("model", None),
        )
        dropout: Dropout.Config = Dropout.default_config()
        metric: InstantiableConfig = ClassificationMetric.default_config()

  • The first is the backbone, which is the underlying model architecture we want to use for computing the image embeddings.
  • The second is num_classes, which is the number of class labels in our dataset.

Since we are interested in ResNet on ImageNet, we can modify the above code like so:

from axlearn.vision import image_classification
+from axlearn.vision import resnet

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    model_cfg = image_classification.ImageClassificationModel.default_config()
+   model_cfg.backbone = resnet.ResNet.resnet50_config()
+   model_cfg.num_classes = 1000

This will use a vanilla ResNet-50 backbone for classifying the 1000 classes in ImageNet. ImageClassificationModel will handle the rest, such as extracting embeddings from the backbone, computing the logits, and computing the loss and other metrics.

We can further customize the model by setting additional configs. For example, ResNet models commonly use He initialization, which draws weights with variance proportional to 2/fan:

+import math
+from axlearn.common import param_init
from axlearn.vision import image_classification
from axlearn.vision import resnet

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    model_cfg = image_classification.ImageClassificationModel.default_config()
    model_cfg.backbone = resnet.ResNet.resnet50_config()
    model_cfg.num_classes = 1000
+   model_cfg.param_init = param_init.DefaultInitializer.default_config().set(
+       init_by_param_name={
+           param_init.PARAM_REGEXP_WEIGHT: param_init.WeightInitializer.default_config().set(
+               fan="fan_out",
+               distribution="normal",
+               scale=math.sqrt(2),
+           )
+       }
+   )

If you want, you can also customize the ResNet backbone:

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    model_cfg = image_classification.ImageClassificationModel.default_config()
-   model_cfg.backbone = resnet.ResNet.resnet50_config()
+   model_cfg.backbone = resnet.ResNet.resnet50_config().set(
+       hidden_dim=123,
+       num_blocks_per_stage=[1, 2, 3],
+       ...
+   )
    ...

Or, you can just as easily switch to a different backbone:

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    model_cfg = image_classification.ImageClassificationModel.default_config()
-   model_cfg.backbone = resnet.ResNet.resnet50_config()
+   model_cfg.backbone = resnet.ResNet.resnet101_config()
    ...

You can refer to the ResNet.Config definition for more details on the ResNet config API.

For simplicity, we will use the default values of our ResNet-50 backbone.


Training and Evaluation Data

Next, we will define an input pipeline for reading ImageNet. As of writing, the most well-supported[1] option uses TensorFlow Datasets (TFDS), which may already be familiar to you (if not, that's completely fine).

As before, we can leverage existing building blocks in AXLearn. This time we can reuse the ImagenetInput under the vision.input_image module:

from axlearn.vision import image_classification
from axlearn.vision import resnet
+from axlearn.vision import input_image

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

+   # Construct an input pipeline config.
+   input_cfg = input_image.ImagenetInput.default_config()

Taking a peek under the hood, we can see that ImagenetInput.default_config() constructs a tfds_dataset, and applies some standard processing like cropping and resizing:

def default_config(cls):
    cfg = super().default_config()  # type: ImagenetInput.Config
    cfg.source = config_for_function(input_tf_data.tfds_dataset).set(
        dataset_name="imagenet2012",
        shuffle_buffer_size=1024,  # to be tuned.
    )
    cfg.processor = config_for_function(_process_example).set(
        image_size=(224, 224),
        eval_resize=None,
        augment_name=None,
    )
    cfg.batcher.set(pad_example_fn=pad_with_negative_labels)  # pylint: disable=no-member
    return cfg

Like before, let's inspect the config API inherited from its parent class Input:

class Input(Module):
    """A Module to generate input batches with tf.data.Dataset.

    This input module contains three components:

    * source: generates the raw examples
    * processor: processes examples (potentially splitting and merging examples)
    * batcher: converts a stream of examples to a stream of batches

    This structure allows the users to replace source but reuse processor/batcher, e.g.,
    for inference.
    """

    @config_class
    class Config(Module.Config):
        """Configures Input."""

        is_training: Required[bool] = REQUIRED
        # TODO(xianzhi): consider unifying `source`, `processor` and `batcher` with a single
        # BuildDatasetFn.
        # A config that instantiates to a BuildDatasetFn. The result dataset will contain
        # a stream of examples representing one epoch of the source dataset.
        source: Required[InstantiableConfig] = REQUIRED
        # A config that instantiates to a DatasetToDatasetFn, which processes examples from
        # the source dataset and generates the example dataset to be batched, potentially
        # splitting and merging examples.
        processor: Required[InstantiableConfig] = REQUIRED
        # A config that instantiates to a DatasetToDatasetFn, which performs batching of examples.
        batcher: InstantiableConfig = config_for_function(batch)

The main fields are:

  • is_training: a bool indicating whether the dataset is used for training[2].
  • The source: a function[3] that returns a dataset (in this case, a tf.data.Dataset).
  • The processor: a function that takes a dataset and outputs another dataset.
  • The batcher: a function that takes a dataset and outputs a batched dataset (this one comes with a reasonable default).

The ImagenetInput.default_config() fills in these configs for you, using reasonable defaults in the context of ImageNet processing.
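
These components can also be replaced with your own. As a hedged sketch (my_resize_processor is a hypothetical factory written purely for illustration, and we assume the DatasetToDatasetFn alias is importable from input_tf_data), a custom processor might look like:

import tensorflow as tf
from axlearn.common import config, input_tf_data

def my_resize_processor(*, image_size) -> input_tf_data.DatasetToDatasetFn:
    """Hypothetical processor factory: resizes each example's image."""
    def process(ds: tf.data.Dataset) -> tf.data.Dataset:
        def example_fn(example):
            example["image"] = tf.image.resize(example["image"], image_size)
            return example
        return ds.map(example_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return process

# Plugged in the same way as the built-in processor:
# input_cfg.processor = config.config_for_function(my_resize_processor).set(image_size=(224, 224))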

Note that each of source, processor, and batcher are themselves configs. This allows us to configure their properties with minimal changes:

from axlearn.vision import image_classification
from axlearn.vision import resnet
from axlearn.vision import input_image
+from axlearn.common import input_tf_data
+from axlearn.common import config

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

    # Construct an input pipeline config.
    input_cfg = input_image.ImagenetInput.default_config()
+   input_cfg.source.set(
+       split="train",
+       read_config=config.config_for_function(input_tf_data.tfds_read_config).set(
+           decode_parallelism=128,
+       ),
+   )
+   input_cfg.batcher.global_batch_size = 1024

Above, we independently set source.split = "train" to read the "train" data split, and set batcher.global_batch_size = 1024 to configure the batch size across all training hosts.

For efficiency reasons, we also set the decode_parallelism when reading the dataset.

You may be wondering what config.config_for_function is. We will cover more details in concepts, but at a high level, it dynamically generates a config from a function signature (in this case input_tf_data.tfds_read_config). This allows arbitrary functions to interoperate with the config system.

This also gives insight into the config API for input_tf_data.tfds_read_config -- it is simply the arguments of the function itself:

def tfds_read_config(
    *,
    is_training: bool,
    num_shards: Optional[int] = None,
    shard_index: Optional[int] = None,
    read_parallelism: int = 1,
    decode_parallelism: int = 32,
) -> tfds.ReadConfig:

As you can see from the above example, we've configured the decode_parallelism parameter to be 128.
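
To build intuition, here is a minimal sketch of how config_for_function behaves with an ordinary function (scale below is a hypothetical function written just for this example):

from axlearn.common import config

def scale(x: float = 1.0, *, factor: float = 2.0) -> float:
    # A plain function; its arguments become the config's fields.
    return x * factor

scale_cfg = config.config_for_function(scale).set(x=2.0, factor=3.0)
# Instantiating the config calls `scale` with the configured arguments.
assert scale_cfg.instantiate() == 6.0

This is the same mechanism that lets us set decode_parallelism on the tfds_read_config config above.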

As another example, we can switch to the newer ImageNetV2 dataset with the following changes:

from axlearn.vision import image_classification
from axlearn.vision import resnet
from axlearn.vision import input_image
from axlearn.common import input_tf_data
from axlearn.common import config

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

    # Construct an input pipeline config.
    input_cfg = input_image.ImagenetInput.default_config()
    input_cfg.source.set(
        split="train",
        read_config=config.config_for_function(input_tf_data.tfds_read_config).set(
            decode_parallelism=128,
        ),
    )
    input_cfg.batcher.global_batch_size = 1024

+   # Swaps the dataset name to the newer ImageNetV2.
+   input_cfg.source.dataset_name = "imagenet_v2/matched-frequency"

This will apply the same input processing as before, except to imagenet_v2 instead of imagenet2012. Hopefully, this provides some basic intuition about the design of the config system, which is prevalent in the AXLearn codebase.

For now, let's stick to the original imagenet2012.


Evaluation Metrics

A crucial part of any experiment is evaluating how well the model is learning.

AXLearn provides an evaler implementation called SpmdEvaler, which has the following config API:

class SpmdEvaler(Module):
    """An evaler implementation that supports partitioning of computation and data with GSPMD."""

    @config_class
    class Config(Module.Config):
        """Configures SpmdEvaler."""

        # The input source.
        input: Required[InstantiableConfig] = REQUIRED
        # A summary writer to log tagged summary values.
        summary_writer: InstantiableConfig = summary_writer.SummaryWriter.default_config()
        # Run this evaler according to this policy.
        eval_policy: InstantiableConfig = config_for_function(every_n_steps_policy)
        # Which evaluation iters to trace with the profiler each time the evaler is run.
        # Each trace will cover one full evaluation batch.
        # Traces will run for at most 3 unique steps.
        trace_at_iters: Sequence[int] = []
        # Cast float inputs and parameters to this dtype for the evaluation step.
        # If None, do not cast.
        eval_dtype: Optional[jnp.dtype] = None
        # The evaler metric_calculator to compute summaries.
        metric_calculator: BaseMetricCalculator.Config = ModelSummaryAccumulator.default_config()
        # If not None, writes input batches and `metric_calculator` forward outputs.
        output_writer: Optional[BaseOutputWriter.Config] = None

As we can see, we only need to provide an input source to use it. Among other things, it already comes with a summary_writer to log evaluation metrics, and a basic metric_calculator to compute metrics for the summary writer.

In fact, we already have most of the pieces ready:

  • We have an input config which can read imagenet2012. We just need to tweak it slightly to read the "validation" split, instead of the "train" split.
  • We have a classification model ImageClassificationModel, which already comes with basic metrics like accuracy, perplexity and cross_entropy_loss.
  • The summary_writer is already capable of logging these metrics to tensorboard.

By now, you may already have an idea of how to construct the evaler input:

from axlearn.vision import image_classification
from axlearn.vision import resnet
from axlearn.vision import input_image
from axlearn.common import input_tf_data
from axlearn.common import config

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

    # Construct an input pipeline config.
    input_cfg = input_image.ImagenetInput.default_config()
    input_cfg.source.set(
        split="train",
        read_config=config.config_for_function(input_tf_data.tfds_read_config).set(
            decode_parallelism=128,
        ),
    )
    input_cfg.batcher.global_batch_size = 1024

+   # Construct an input pipeline config for evaluation.
+   eval_input_cfg = input_image.ImagenetInput.default_config()
+   eval_input_cfg.source.split = "validation"
+   eval_input_cfg.batcher.global_batch_size = 80

Because we often do not want to shuffle the eval dataset, we also disable shuffling using the utility input_tf_data.disable_shuffle_recursively.

from axlearn.vision import image_classification
from axlearn.vision import resnet
from axlearn.vision import input_image
from axlearn.common import input_tf_data
from axlearn.common import config

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

    # Construct an input pipeline config.
    input_cfg = input_image.ImagenetInput.default_config()
    input_cfg.source.set(
        split="train",
        read_config=config.config_for_function(input_tf_data.tfds_read_config).set(
            decode_parallelism=128,
        ),
    )
    input_cfg.batcher.global_batch_size = 1024

    # Construct an input pipeline config for evaluation.
    eval_input_cfg = input_image.ImagenetInput.default_config()
    eval_input_cfg.source.split = "validation"
    eval_input_cfg.batcher.global_batch_size = 80
+   input_tf_data.disable_shuffle_recursively(eval_input_cfg)

We can then construct the SpmdEvaler config, which takes the eval input as a child:

from axlearn.vision import image_classification
from axlearn.vision import resnet
from axlearn.vision import input_image
from axlearn.common import input_tf_data
from axlearn.common import config
+from axlearn.common import evaler

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

    # Construct an input pipeline config.
    input_cfg = input_image.ImagenetInput.default_config()
    input_cfg.source.set(
        split="train",
        read_config=config.config_for_function(input_tf_data.tfds_read_config).set(
            decode_parallelism=128,
        ),
    )
    input_cfg.batcher.global_batch_size = 1024

    # Construct an input pipeline config for evaluation.
    eval_input_cfg = input_image.ImagenetInput.default_config()
    eval_input_cfg.source.split = "validation"
    eval_input_cfg.batcher.global_batch_size = 80
    input_tf_data.disable_shuffle_recursively(eval_input_cfg)

+   # Construct the evaler.
+   evaler_cfg = evaler.SpmdEvaler.default_config()
+   evaler_cfg.input = eval_input_cfg

A minor note is that the default evaler runs every step, which may be more frequent than we'd like. We can tweak the eval rate by customizing the eval policy:

from axlearn.vision import image_classification
from axlearn.vision import resnet
from axlearn.vision import input_image
from axlearn.common import input_tf_data
from axlearn.common import config
from axlearn.common import evaler

def resnet_imagenet_trainer():
    # Construct an image classifier config.
    ...

    # Construct an input pipeline config.
    input_cfg = input_image.ImagenetInput.default_config()
    input_cfg.source.set(
        split="train",
        read_config=config.config_for_function(input_tf_data.tfds_read_config).set(
            decode_parallelism=128,
        ),
    )
    input_cfg.batcher.global_batch_size = 1024

    # Construct an input pipeline config for evaluation.
    eval_input_cfg = input_image.ImagenetInput.default_config()
    eval_input_cfg.source.split = "validation"
    eval_input_cfg.batcher.global_batch_size = 80
    input_tf_data.disable_shuffle_recursively(eval_input_cfg)

    # Construct the evaler.
    evaler_cfg = evaler.SpmdEvaler.default_config()
    evaler_cfg.input = eval_input_cfg
+   evaler_cfg.eval_policy.n = 12_510  # Eval roughly every 10 epochs.

This will cause the evaler to run every 12,510 steps instead. Setting eval_policy.n works because eval_policy is a config_for_function(every_n_steps_policy) config, so n is simply that function's argument. As for the number: ImageNet-1k has roughly 1.28M training images, so at a global batch size of 1024 one epoch is about 1,251 steps, and 12,510 steps is roughly 10 epochs.


Optimizer

Next, we will need to define an optimizer. AXLearn comes with a variety of default implementations in common.optimizers. For this example, we'll use standard Stochastic Gradient Descent (SGD) with a weight decay of 1e-4 and momentum of 0.9, mostly following the original paper.

from axlearn.common import config, optimizers

def resnet_imagenet_trainer():
    # Other code redacted for simplicity...

    # Construct the optimizer config.
    optimizer_cfg = config.config_for_function(optimizers.sgd_optimizer).set(
        decouple_weight_decay=False,  # Set to False to match Torch behavior.
        momentum=0.9,
        weight_decay=1e-4,
    )

As before, we use the config.config_for_function utility to dynamically generate a config from the optimizers.sgd_optimizer function signature:

def sgd_optimizer(
    learning_rate: schedule.Schedule,
    *,
    decouple_weight_decay: bool,
    momentum: float = 0,
    weight_decay: float = 0,
    weight_decay_per_param_scale: Optional[Callable[[NestedOptParam], Any]] = None,
) -> PartitionedGradientTransformation:

Among these parameters is the learning_rate, which we haven't configured yet. Out of the box, AXLearn comes with a variety of learning rate schedules in common.schedule. An appropriate schedule for our use case (and the batch size we intend to use) is a linear warmup to a peak learning rate of 0.4 followed by a cosine decay:

from axlearn.common import config, optimizers
+from axlearn.common import schedule

def resnet_imagenet_trainer():
    # Other code redacted for simplicity...

    # Construct the optimizer config.
+   learning_rate = config.config_for_function(schedule.cosine_with_linear_warmup).set(
+       peak_lr=0.4,
+       max_step=112_590,  # Roughly 90 epochs.
+       warmup_steps=6_255,  # Roughly 5 epochs.
+   )
    optimizer_cfg = config.config_for_function(optimizers.sgd_optimizer).set(
+       learning_rate=learning_rate,
        decouple_weight_decay=False,  # Set to False to match Torch behavior.
        momentum=0.9,
        weight_decay=1e-4,
    )
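
For intuition about what this schedule computes, here is a rough, hypothetical re-implementation. It assumes the common definition of linear warmup followed by cosine decay to zero; the exact form used by schedule.cosine_with_linear_warmup may differ, so treat this purely as a mental model:

import math

def cosine_with_linear_warmup_sketch(step: int, *, peak_lr: float, max_step: int, warmup_steps: int) -> float:
    if step < warmup_steps:
        # Linear warmup from 0 to peak_lr.
        return peak_lr * step / warmup_steps
    # Cosine decay from peak_lr toward 0 over the remaining steps.
    progress = (step - warmup_steps) / (max_step - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * progress))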

One caveat of applying weight decay naively is that we regularize all parameters globally. Empirically, we find that regularizing the BatchNorm[4] parameters hurts model performance, so we exclude them from weight decay:

from axlearn.common import config, optimizers
from axlearn.common import schedule

def resnet_imagenet_trainer():
    # Other code redacted for simplicity...

    # Construct the optimizer config.
    learning_rate = config.config_for_function(schedule.cosine_with_linear_warmup).set(
        peak_lr=0.4,
        max_step=112_590,  # Roughly 90 epochs.
        warmup_steps=6_255,  # Roughly 5 epochs.
    )
+    per_param_decay = config.config_for_function(optimizers.per_param_scale_by_path).set(
+        description="weight_decay_scale",
+        scale_by_path=[
+            (".*norm.*", 0),  # Exclude the norm parameters from weight decay.
+        ],
+    )
    optimizer_cfg = config.config_for_function(optimizers.sgd_optimizer).set(
        learning_rate=learning_rate,
        decouple_weight_decay=False,  # Set to False to match Torch behavior.
        momentum=0.9,
        weight_decay=1e-4,
+       weight_decay_per_param_scale=per_param_decay,
    )

Putting Everything Together

We are now ready to put all the pieces together! The glue that makes everything work is the SpmdTrainer. As the name implies, it runs the main training loop using the components that we've defined above.

Once again, we can get an idea of how to use this component by inspecting its config API:

class SpmdTrainer(Module):
    """A trainer implementation that supports partitioning of computation and data with GSPMD."""

    @config_class
    # pylint: disable-next=too-many-instance-attributes
    class Config(Module.Config):
        """Configures SpmdTrainer."""

        # The input source.
        input: Required[InstantiableConfig] = REQUIRED
        # A summary writer to log tagged summary values.
        summary_writer: BaseWriter.Config = SummaryWriter.default_config()
        # The trainer root dir.
        # By default, checkpoints will be written under {dir}/checkpoints/
        # and summaries will be written under {dir}/summaries.
        dir: Required[str] = REQUIRED
        # If not None, initializes trainer states according to the given config.
        # This is only applied if we aren't restoring from an existing checkpoint.
        init_state_builder: Optional[TrainerStateBuilder.Config] = None
        # The maximum number of steps.
        max_step: Union[int, float] = math.inf
        # The device mesh shape in the form of a tuple of ints.
        # Must have the same length as mesh_axis_names.
        mesh_shape: Required[Sequence[int]] = REQUIRED
        # The mesh axis names. The names can be referenced in ParameterSpec.mesh_axes.
        mesh_axis_names: Required[Sequence[str]] = REQUIRED
        # Subset of mesh axis names over which the leaves of the input batch are sharded.
        batch_axis_names: Union[str, Sequence[str]] = "data"
        # The model config.
        model: Required[BaseModel.Config] = REQUIRED
        # The learner config.
        learner: Required[Learner.Config] = REQUIRED
        # The checkpointer config.
        checkpointer: Checkpointer.Config = Checkpointer.default_config()
        # A dict of evaler names to configs, each name must be non-empty.
        evalers: Dict[str, SpmdEvaler.Config] = {}
        # If True, saves the input iterator in checkpoints.
        #
        # It is OK to change this option for existing models, since our checkpoint restoration
        # logic can handle legacy checkpoints with a different value of `save_input_iterator`.
        #
        # WARNING: many input processing ops are stateful and cannot be saved, e.g.,
        # FailedPreconditionError: RandomUniformInt is stateful.
        # FailedPreconditionError: ReduceDataset is stateful.
        # FailedPreconditionError: SentencepieceOp is stateful.
        save_input_iterator: bool = False
        # At which steps to start profiler tracing.
        # Currently each trace will cover 3 consecutive training steps.
        # The start steps must therefore be at least 3 steps apart from each other.
        start_trace_steps: Sequence[int] = []
        # By default, only trace on host 0.
        start_trace_process_indices: Union[Literal["all"], Sequence[int]] = [0]
        # Prune empty state updates.
        prune_empty_state_updates: bool = True
        # Cast float inputs and model parameters to this dtype for the train step.
        # If None, we do not cast.
        train_dtype: Optional[jnp.dtype] = None

This particular module is quite complex, but feel free to refer to the inline documentation and comments for more details. We can also construct it in a familiar fashion:

from axlearn.common import trainer, learner

def resnet_imagenet_trainer():
    # Other code redacted for simplicity...

    # Construct the trainer config.
    trainer_cfg = trainer.SpmdTrainer.default_config()
    trainer_cfg.model = model_cfg
    trainer_cfg.input = input_cfg
    trainer_cfg.evalers = {"eval_validation": evaler_cfg}
    trainer_cfg.learner = learner.Learner.default_config().set(optimizer=optimizer_cfg)
    trainer_cfg.max_step = 112_590  # Roughly 90 epochs.

This code block plugs in the model, input, evalers, and other components that we have already defined.

Note that we have wrapped the optimizer with a Learner.Config. The Learner internally uses the optimizer to update model params, and acts as an intermediary to the trainer.
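
The learner is a config like any other, so additional knobs can be set on it in the same way. As a sketch (the ema.decay field below is an assumption drawn from other AXLearn experiment configs; check the learner module for the exact API):

from axlearn.common import learner

learner_cfg = learner.Learner.default_config().set(optimizer=optimizer_cfg)
# Hypothetical extra knob: track an exponential moving average of model weights.
# learner_cfg.ema.decay = 0.9999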

For some basic book-keeping, we also configure the frequency of checkpoints and summaries:

from axlearn.common import trainer, learner

def resnet_imagenet_trainer():
    # Other code redacted for simplicity...

    # Construct the trainer config.
    trainer_cfg = trainer.SpmdTrainer.default_config()
    trainer_cfg.model = model_cfg
    trainer_cfg.input = input_cfg
    trainer_cfg.evalers = {"eval_validation": evaler_cfg}
    trainer_cfg.learner = learner.Learner.default_config().set(optimizer=optimizer_cfg)
    trainer_cfg.max_step = 112_590  # Roughly 90 epochs.

+   # Define checkpoint frequency and summary frequency.
+   trainer_cfg.checkpointer.save_policy.n = 12_510
+   trainer_cfg.checkpointer.keep_every_n_steps = 12_510
+   trainer_cfg.summary_writer.write_every_n_steps = 100

The last step is to expose the trainer_cfg in a format that is understood by the AXLearn CLI. We'll cover this in more detail in a later section, but in short, the AXLearn CLI looks for a "registry" of trainers via named_trainer_configs:

from axlearn.common import trainer, learner

def resnet_imagenet_trainer():
    # Other code redacted for simplicity...

    # Construct the trainer config.
    trainer_cfg = trainer.SpmdTrainer.default_config()
    trainer_cfg.model = model_cfg
    trainer_cfg.input = input_cfg
    trainer_cfg.evalers = {"eval_validation": evaler_cfg}
    trainer_cfg.learner = learner.Learner.default_config().set(optimizer=optimizer_cfg)
    trainer_cfg.max_step = 112_590  # Roughly 90 epochs.

    # Define checkpoint frequency and summary frequency.
    trainer_cfg.checkpointer.save_policy.n = 12_510
    trainer_cfg.checkpointer.keep_every_n_steps = 12_510
    trainer_cfg.summary_writer.write_every_n_steps = 100

+   # Return the final trainer config.
+   return trainer_cfg.set(name="resnet_imagenet")


+# Expose the trainer configs to the AXLearn CLI.
+def named_trainer_configs():
+   # Return a mapping from name(s) to trainer config function(s).
+   return {"ResNet-50": resnet_imagenet_trainer}

Testing

While building a trainer config with Python code allows us to reuse configuration logic, a downside is that the indirection makes it hard to see the full effect of a change, especially when updating shared logic. To address this, the golden_config_test generates the full configuration of each registered trainer and puts it under axlearn/experiments/testdata. For example, you can see the full trainer config of the ResNet-50 experiment on ImageNet here. This is especially useful for catching unintended changes to experiment configurations during refactoring.

To generate the golden configs for your own trainer(s), update the golden_config_test and run:

pytest -n auto axlearn/experiments/golden_config_test.py --update

For more details on golden configs, please see concepts.

Before launching experiments into the cloud, it's also recommended to write unit tests to catch failures early. For some examples, we refer the reader to unit tests under axlearn/experiments/vision/resnet.
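
As a minimal sketch of what such a test might look like for our tutorial config (the file path and assertions are illustrative, not an existing AXLearn test):

# axlearn/experiments/tutorial_test.py (hypothetical)
from absl.testing import absltest

from axlearn.experiments import tutorial

class TutorialConfigTest(absltest.TestCase):
    def test_trainer_config_builds(self):
        # Building the config should not require any accelerators or data.
        trainer_cfg = tutorial.resnet_imagenet_trainer()
        self.assertEqual(trainer_cfg.name, "resnet_imagenet")
        self.assertEqual(trainer_cfg.model.num_classes, 1000)

if __name__ == "__main__":
    absltest.main()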


Summary

Congratulations on getting this far! Hopefully, you now have a taste of how to build experiments with AXLearn. Granted, this was a fairly quick overview of what AXLearn has to offer, and some of the content may still feel abstract or foreign. For more details on the config system or other concepts, please refer to the concepts page.

The following section will cover how to launch your experiment in the cloud.


Launching Training

AXLearn comes with tooling for provisioning and launching training on public clouds. This section will guide you with launching training on a Google Cloud TPU.

Pre-requisites

We assume you have:

  1. gcloud set up, following e.g. https://cloud.google.com/sdk/docs/install.
  2. A Google Cloud Platform (GCP) project. To set up a brand new GCP project with the basic resources needed, please run this script.
  3. TPU quota in your project. To request TPU quota, please follow these instructions.
  4. At least one Google Cloud Storage (GCS) bucket.

Preparing the CLI

Please follow the instructions in the CLI docs to setup the CLI.

We assume that you are launching from a working directory that contains a pyproject.toml or setup.py (for instance, if you cloned the repo, you should have one already). If not, you can create a minimal pyproject.toml:

[project]
name = "my_project"
version = "0.0.1"
dependencies = ["axlearn"]

[project.optional-dependencies]
tpu = ["axlearn[tpu]"]

Launching a Command

We can now leverage the AXLearn infrastructure to launch commands on arbitrary TPU configurations.

First, make sure you have authenticated to GCP:

# Authenticate to GCP.
axlearn gcp auth

We can then test a simple v4-8 command:

# Run a dummy command on v4-8.
# Note: the "'...'" quotes are important.
axlearn gcp tpu start --name=$USER-test --tpu_type=v4-8 -- python3 -c "'import jax; print(jax.devices())'"

This provisions a v4-8 TPU, installs axlearn on it, and runs the python3 command that comes after --. As the job is running, any logs from the command will be synced to GCS. Once the job is completed, the TPU resources will be torn down.


Launching an Experiment

To launch an actual experiment, we must first define an experiment module that AXLearn understands how to parse.

There are two aspects to this:

  1. By default, AXLearn looks under the axlearn/experiments directory for experiment modules, so we should define it there.
  2. Experiment modules must expose a function named_trainer_configs which returns a dictionary with experiment names as keys and TrainerConfigFns as values. As the name implies, a TrainerConfigFn is a function that simply returns a trainer config, similar to the one constructed above[5].

We've already packaged the ResNet on ImageNet example for you, which can be launched via:

OUTPUT_DIR=gs://path/to/$USER/experiments/resnet50-$(date +%F)
DATA_DIR=gs://path/to/tensorflow_datasets

axlearn gcp tpu start --tpu_type=v4-8 --output_dir=$OUTPUT_DIR -- \
    python3 -m axlearn.common.launch_trainer_main \
    --module=vision.resnet.imagenet_trainer --config=ResNet-50 \
    --trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR --jax_backend=tpu

If you have been following along with the code, assuming you have a file axlearn/experiments/tutorial.py, you can also launch your own experiment with:

OUTPUT_DIR=gs://path/to/$USER/experiments/resnet50-$(date +%F)
DATA_DIR=gs://path/to/tensorflow_datasets

axlearn gcp tpu start --tpu_type=v4-8 --output_dir=$OUTPUT_DIR -- \
    python3 -m axlearn.common.launch_trainer_main \
-   --module=vision.resnet.imagenet_trainer --config=ResNet-50 \
+   --module=tutorial --config=ResNet-50 \
    --trainer_dir=$OUTPUT_DIR --data_dir=$DATA_DIR --jax_backend=tpu

Both commands are similar to the one from the previous section except we run the trainer defined by --module and --config instead of simply printing jax.devices().

The OUTPUT_DIR defines where to emit training artifacts like checkpoints and summaries, and the DATA_DIR defines where to look for datasets.

To view tensorboard, point to the OUTPUT_DIR:

tensorboard --logdir=$OUTPUT_DIR

Or, if VertexAI is configured, you should also see a VertexAI Tensorboard link.


Launching via Bastion

In an organization setting, it's typical to launch jobs from a centralized system, which can:

  1. Constantly monitor and restart jobs as necessary.
  2. Queue and schedule jobs based on priority.

AXLearn provides such an orchestrator called the "bastion", which can run in GCP with minimal dependencies.

It is often recommended to launch from the bastion. Please see the infrastructure docs for instructions on how to set it up.


Next Steps

As a next step, we encourage you to read some of the AXLearn Concepts if you have not already. While this document covers much of "how" to run experiments, the next section aims to explain the "why" behind the AXLearn design.

Footnotes

  1. Note that it's possible to use other types of data processing pipelines (e.g. torch dataloaders). AXLearn is designed to be an open system.

  2. This affects things like whether we should shuffle the dataset or not.

  3. More accurately, it is a config that instantiates to such a function, but more on that in concepts.

  4. BatchNorm is used throughout the ResNet architecture by default. We did not need to configure it explicitly.

  5. One common question is: why return a TrainerConfigFn instead of, say, the trainer config itself? The reason is that a TrainerConfigFn allows us to defer the construction of a trainer config until we need to use it. When your project has many experiments, the cost of building all trainer configs can be non-trivial (such as when running golden config tests).