Skip to content

Commit

Permalink
add save_load_checkpoint option in helloworld (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
msftsw authored Feb 8, 2022
1 parent 64666cb commit ca6f018
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tutel/examples/helloworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# Command-line options for the example run (only part of the option list is
# visible in this excerpt).
parser.add_argument('--l_aux_wt', type=float, default=0.0)
parser.add_argument('--a2a_ffn_overlap_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
# Opt-in flag (off by default): when given, the script tries to load a
# checkpoint right after the model is built and writes one at the end.
parser.add_argument('--save_load_checkpoint', default=False, action='store_true')
args = parser.parse_args()

parallel_env = system_init.init_data_model_parallel()
Expand Down Expand Up @@ -85,6 +86,13 @@ def forward(self, input):
model = ExampleModel()
dist_print(model)

# Optional resume step, enabled with --save_load_checkpoint.
if args.save_load_checkpoint:
    # The file name embeds this process's global rank and the global size, so
    # each process reads and writes its own file. `checkpoint_path` is reused
    # by the save step at the end of the script.
    # NOTE: the 'hellworld' spelling is kept as-is so that checkpoints written
    # by earlier runs are still found.
    checkpoint_path = './distributed-hellworld-%d-in-%d.ckpt' % (parallel_env.global_rank, parallel_env.global_size)
    if os.path.exists(checkpoint_path):
        # Deserialize onto CPU first: without map_location, torch.load tries to
        # restore tensors onto the device they were saved from and fails if
        # that device is not available. load_state_dict then copies the values
        # into the model's existing parameters.
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    else:
        # A missing file is not an error (e.g. first run): continue with the
        # freshly constructed model.
        print('Checkpoint not loaded: file `%s` is not found' % checkpoint_path)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

torch.manual_seed(0)
Expand Down Expand Up @@ -124,3 +132,6 @@ def forward(self, input):

# NOTE(review): the divisor 10 presumably matches the number of steps that
# were accumulated into average_time earlier in the script -- confirm there.
average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)

# With --save_load_checkpoint, write the model's state_dict to the same
# path the load step used (`checkpoint_path` is assigned in that step, under
# the same flag). Only the model state is saved; the optimizer state is not.
if args.save_load_checkpoint:
    torch.save(model.state_dict(), checkpoint_path)

0 comments on commit ca6f018

Please sign in to comment.