Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Schedule free optimizer support #2631

Merged
merged 6 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 225 additions & 0 deletions examples/by_feature/schedule_free.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import evaluate
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed

from accelerate import Accelerator, DistributedType
from accelerate.utils import is_schedulefree_available


if is_schedulefree_available():
import schedulefree
else:
raise ImportError(
"This example requires the `schedulefree` library. Please install it with `pip install schedulefree`"
)


########################################################################
# This is a fully working simple example to use Accelerate and Facebook's
# scheduler-free optimizer: https://github.com/facebookresearch/schedule_free/
#
# This example trains a Bert base model on GLUE MRPC
# in any of the following settings (with the same script):
# - single CPU or single GPU
# - multi GPUS (using PyTorch distributed mode)
# - (multi) TPUs
# - fp16 (mixed-precision) or fp32 (normal precision)
#
# To run it in each of these various modes, follow the instructions
# in the readme for examples:
# https://github.com/huggingface/accelerate/tree/main/examples
#
########################################################################


MAX_GPU_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32


def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
"""
Creates a set of `DataLoader`s for the `glue` dataset,
using "bert-base-cased" as the tokenizer.

Args:
accelerator (`Accelerator`):
An `Accelerator` object
batch_size (`int`, *optional*):
The batch size for the train and validation DataLoaders.
"""
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
datasets = load_dataset("glue", "mrpc")

def tokenize_function(examples):
# max_length=None => use the model max length (it's actually the default)
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
return outputs

# Apply the method we just defined to all the examples in all the splits of the dataset
# starting with the main process first:
with accelerator.main_process_first():
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)

# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
# transformers library
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

def collate_fn(examples):
# For Torchxla, it's best to pad everything to the same length or training will be very slow.
max_length = 128 if accelerator.distributed_type == DistributedType.XLA else None
# When using mixed precision we want round multiples of 8/16
if accelerator.mixed_precision == "fp8":
pad_to_multiple_of = 16
elif accelerator.mixed_precision != "no":
pad_to_multiple_of = 8
else:
pad_to_multiple_of = None

return tokenizer.pad(
examples,
padding="longest",
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
)

# Instantiate dataloaders.
train_dataloader = DataLoader(
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
)
eval_dataloader = DataLoader(
tokenized_datasets["validation"],
shuffle=False,
collate_fn=collate_fn,
batch_size=EVAL_BATCH_SIZE,
drop_last=(accelerator.mixed_precision == "fp8"),
)

return train_dataloader, eval_dataloader


# For testing only


if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
from accelerate.test_utils.training import mocked_dataloaders

get_dataloaders = mocked_dataloaders # noqa: F811


def training_function(config, args):
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
seed = int(config["seed"])
batch_size = int(config["batch_size"])

metric = evaluate.load("glue", "mrpc")

# If the batch size is too big we use gradient accumulation
gradient_accumulation_steps = 1
if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.XLA:
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
batch_size = MAX_GPU_BATCH_SIZE

set_seed(seed)
train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

# We could avoid this line since the accelerator is set with `device_placement=True` (default value).
# Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer
# creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that).
model = model.to(accelerator.device)
# Instantiate optimizer with warmup steps
optimizer = schedulefree.AdamWScheduleFree(
model.parameters(),
lr=lr,
warmup_steps=100,
)

# Prepare everything
# There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
# prepare method.

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)

# Now we train the model
for epoch in range(num_epochs):
model.train()
optimizer.train()
for step, batch in enumerate(train_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
outputs = model(**batch)
loss = outputs.loss
loss = loss / gradient_accumulation_steps
accelerator.backward(loss)
if step % gradient_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

model.eval()
optimizer.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True`.
batch.to(accelerator.device)
with torch.no_grad():
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
metric.add_batch(
predictions=predictions,
references=references,
)

eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)


def main():
parser = argparse.ArgumentParser(description="Simple example of training script.")
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16", "fp8"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
)
parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate # used to be installed in Amazon SageMaker environment
evaluate
datasets==2.3.2
datasets==2.3.2
schedulefree
12 changes: 12 additions & 0 deletions src/accelerate/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def zero_grad(self, set_to_none=None):
raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.")
self.optimizer.zero_grad()

def train(self):
"""
Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
"""
return self.optimizer.train()

def eval(self):
"""
Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
"""
return self.optimizer.eval()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding a @property for self.optimizer.training too? I don't think we also need a setter for this, as train() and eval() should be enough.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That gets to be optimizer-specific, so not a fan of it unless its downstreamed, as they currently don't have that: https://github.com/facebookresearch/schedule_free/blob/main/schedulefree/adamw_schedulefree.py#L86

(Otherwise I'd agree, yes that's a good idea)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, good point.

def step(self, closure=None):
if (
not self.gradient_state.is_xla_gradients_synced
Expand Down
8 changes: 8 additions & 0 deletions src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
is_npu_available,
is_pandas_available,
is_pippy_available,
is_schedulefree_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
Expand Down Expand Up @@ -213,6 +214,13 @@ def require_timm(test_case):
return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case)


def require_schedulefree(test_case):
"""
Decorator marking a test that requires schedulefree. These tests are skipped when they are not.
"""
return unittest.skipUnless(is_schedulefree_available(), "test requires the schedulefree library")(test_case)


def require_bnb(test_case):
"""
Decorator marking a test that requires bitsandbytes. These tests are skipped when they are not.
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
is_pynvml_available,
is_rich_available,
is_sagemaker_available,
is_schedulefree_available,
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def is_msamp_available():
return _is_package_available("msamp", "ms-amp")


def is_schedulefree_available():
return _is_package_available("schedulefree")


def is_transformer_engine_available():
return _is_package_available("transformer_engine")

Expand Down
7 changes: 7 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
require_huggingface_suite,
require_multi_gpu,
require_pippy,
require_schedulefree,
require_trackers,
run_command,
slow,
Expand All @@ -47,6 +48,7 @@
"local_sgd.py",
"multi_process_metrics.py",
"memory.py",
"schedule_free.py",
"automatic_gradient_accumulation.py",
"fsdp_with_peak_mem_tracking.py",
"deepspeed_with_config_support.py",
Expand Down Expand Up @@ -216,6 +218,11 @@ def test_multi_process_metrics(self):
testargs = ["examples/by_feature/multi_process_metrics.py"]
run_command(self.launch_args + testargs)

@require_schedulefree
def test_schedulefree(self):
testargs = ["examples/by_feature/schedule_free.py"]
run_command(self.launch_args + testargs)

@require_trackers
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
def test_tracking(self):
Expand Down
Loading