Skip to content

Commit

Permalink
Skepticleo trainer argparser (Lightning-AI#1023)
Browse files Browse the repository at this point in the history
* Added default parser for trainer and class method to construct trainer from default args

* Removed print statement

* Added test for constructing Trainer from command line args

* Removed extra line

* Removed redundant imports, removed whitespace from empty lines

* Fixed typo

* Updated default parser creation to get class attributes automatically

* Updated default parser creation to get class attributes automatically

* Added method to get default args for trainer

* Trimmed trainer get default args method

* Updated from argparse method to not return trainer with static arguments

* Update trainer get default args to classmethod

* adjustment

* fix

* Fixed variable name

* Update trainer.py

* Update test_trainer.py

* Update trainer.py

* Update tests/trainer/test_trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update trainer.py

* Update test_trainer.py

* Update trainer.py

* Update test_trainer.py

* Update tests/trainer/test_trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update trainer.py

* Update test_trainer.py

Co-authored-by: Mudit Tanwani <mudittanwani@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
3 people authored and tullie committed Apr 3, 2020
1 parent c993221 commit 41b6d3f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 4 deletions.
30 changes: 29 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
import logging as log
from typing import Union, Optional, List, Dict, Tuple, Iterable
from argparse import ArgumentParser

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
profiler: Optional[BaseProfiler] = None,
benchmark: bool = False,
reload_dataloaders_every_epoch: bool = False,
**kwargs
):
r"""
Expand Down Expand Up @@ -627,6 +629,7 @@ def on_train_end(self):

# Transfer params
# Backward compatibility
self.num_nodes = num_nodes
if nb_gpu_nodes is not None:
warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
" and this method will be removed in v0.8.0", DeprecationWarning)
Expand Down Expand Up @@ -747,10 +750,12 @@ def on_train_end(self):
self.weights_save_path = weights_save_path

# accumulated grads
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

# allow int, string and gpu list
self.data_parallel_device_ids = parse_gpu_ids(gpus)
self.gpus = gpus
self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

# tpu state flags
Expand Down Expand Up @@ -797,6 +802,7 @@ def on_train_end(self):
self.row_log_interval = row_log_interval

# how much of the data to use
self.overfit_pct = overfit_pct
self.determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)

Expand All @@ -822,6 +828,28 @@ def slurm_job_id(self) -> int:
job_id = None
return job_id

@classmethod
def default_attributes(cls):
return vars(cls())

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
"""Extend existing argparse by default `Trainer` attributes."""
parser = ArgumentParser(parents=[parent_parser])

trainer_default_params = Trainer.default_attributes()

for arg in trainer_default_params:
parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)

return parser

@classmethod
def from_argparse_args(cls, args):

params = vars(args)
return cls(**params)

def __parse_gpu_ids(self, gpus):
"""Parse GPUs id.
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def print_nan_gradients(self):
log.info(param, param.grad)

def configure_accumulated_gradients(self, accumulate_grad_batches):
self.accumulate_grad_batches = None

if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
Expand Down
22 changes: 21 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import math
import os

import pytest
import torch
import argparse

import tests.models.utils as tutils
from unittest import mock
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
Expand Down Expand Up @@ -600,3 +601,22 @@ def test_end(self, outputs):

model = LightningTestModel(hparams)
Trainer().test(model)

@mock.patch('argparse.ArgumentParser.parse_args',
return_value=argparse.Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
"""Tests default argument parser for Trainer"""
tutils.reset_seed()

# logger file to get meta
logger = tutils.get_test_tube_logger(tmpdir, False)

parser = argparse.ArgumentParser(add_help=False)
args = parser.parse_args()
args.logger = logger

args.max_epochs = 5
trainer = Trainer.from_argparse_args(args)

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5

0 comments on commit 41b6d3f

Please sign in to comment.