add top level seeding to cifar (#151)
dakinggg authored Feb 11, 2023
1 parent 47d5086 commit 5fde16d
Showing 1 changed file with 2 additions and 1 deletion.
examples/cifar/main.py: 3 changes (2 additions & 1 deletion)
@@ -11,7 +11,7 @@
 from composer.callbacks import LRMonitor, MemoryMonitor, SpeedMonitor
 from composer.loggers import ProgressBarLogger, WandBLogger
 from composer.optim import DecoupledSGDW, MultiStepWithWarmupScheduler
-from composer.utils import dist
+from composer.utils import dist, reproducibility
 from omegaconf import OmegaConf
 
 from examples.cifar.data import build_cifar10_dataspec
@@ -29,6 +29,7 @@ def build_logger(name: str, kwargs: Dict):
 
 
 def main(config):
+    reproducibility.seed_all(config.seed)
     if config.grad_accum == 'auto' and not torch.cuda.is_available():
         raise ValueError(
             'grad_accum="auto" requires training with a GPU; please specify grad_accum as an integer'
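For readers browsing this change, here is a minimal, self-contained sketch of the pattern the commit introduces: seed every random number generator once at the top of main(), before any dataloaders or model weights are built. Composer's reproducibility.seed_all covers Python's random module, NumPy, and PyTorch; the standalone script, the literal seed value, and the inline config below are illustrative assumptions, not code from this repository.

from composer.utils import reproducibility
from omegaconf import OmegaConf


def main(config):
    # Seed Python's `random`, NumPy, and PyTorch RNGs before anything that
    # consumes randomness (dataloader shuffling, weight init, augmentation).
    reproducibility.seed_all(config.seed)
    # ... build dataloaders, model, and Trainer as in examples/cifar/main.py


if __name__ == '__main__':
    # Hypothetical config for illustration; the real script builds its
    # config with OmegaConf and reads a top-level `seed` key.
    cfg = OmegaConf.create({'seed': 17})
    main(cfg)

Seeding at the top of the entry point, rather than later on, is what "top level seeding" in the commit title refers to: everything constructed afterwards draws from the already-seeded generators, so repeated runs with the same seed see the same data order and initial weights.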
