Skip to content

Commit

Permalink
Fix compiled SGD
Browse files Browse the repository at this point in the history
  • Loading branch information
Tony-Y committed Jan 18, 2025
1 parent 039c9a8 commit 9c0a48a
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions examples/cifar10/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ def compile_functions():
test_iter_loss_fn = torch.compile(test_iter_loss_fn, mode="reduce-overhead")


def init_momentum_buffer(optimizer):
for group in optimizer.param_groups:
if group["momentum"] != 0:
for p in group["params"]:
state = optimizer.state[p]
if state.get("momentum_buffer") is None:
state["momentum_buffer"] = torch.zeros_like(p.data)


def main(args=None):
# Training settings
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example')
Expand Down Expand Up @@ -297,6 +306,8 @@ def main(args=None):
period=args.warmup_period)

if args.compile:
if args.algorithm == 'sgd':
init_momentum_buffer(optimizer)
compile_functions()

best_acc = 0.0
Expand Down

0 comments on commit 9c0a48a

Please sign in to comment.