Merge pull request #7 from yuzhou03/wwl/modify
modify-init_dist_training_env
yuzhou03 authored Feb 24, 2023
2 parents 3377ba1 + 2c3193b commit 1e82264
Showing 8 changed files with 62 additions and 62 deletions.
@@ -17,7 +17,7 @@
        "../../")))  # the benchmarks directory
# local libraries
import config
-from driver import Event, dist_pytorch, check
+from driver import Event, dist_pytorch
from driver.helper import InitHelper, get_finished_info

# TODO: import the relevant modules, methods, and variables. Keep the names consistent; the implementations may differ.
@@ -34,16 +34,17 @@
def main() -> Tuple[Any, Any]:
    global logger
-
-    # init
    init_helper = InitHelper(config)
-    model_driver = init_helper.init_driver(config.name)  # the model name `name` is added in _base.py
-    config.local_rank = init_helper.get_local_rank()
-    logger = model_driver.logger
-
+    model_driver = init_helper.init_driver()  # the model name `name` is added in _base.py
+    config = model_driver.config
    dist_pytorch.init_dist_training_env(config)
-    check.check_config(config)

-    dist_pytorch.barrier()
+    dist_pytorch.barrier(config.vendor)
    model_driver.event(Event.INIT_START)

+    # logger
+    logger = model_driver.logger
    init_start_time = logger.previous_log_time

    # TODO: get the seed
@@ -78,9 +79,9 @@ def main() -> Tuple[Any, Any]:
    training_state._trainer = trainer

    # set up the distributed environment, trainer init()
-    dist_pytorch.barrier()
+    dist_pytorch.barrier(config.vendor)
    trainer.init()
-    dist_pytorch.barrier()
+    dist_pytorch.barrier(config.vendor)

    # evaluation statistics
    init_evaluation_start = time.time()
@@ -111,7 +112,7 @@ def main() -> Tuple[Any, Any]:
    training_state.init_time = (init_end_time - init_start_time) / 1e+3

    # TRAIN_START
-    dist_pytorch.barrier()
+    dist_pytorch.barrier(config.vendor)
    model_driver.event(Event.TRAIN_START)
    raw_train_start_time = logger.previous_log_time
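Taken together, the new `main()` hands local-rank detection and config checking to `InitHelper` and threads the vendor string through each barrier. Below is a minimal sketch of the resulting flow, condensed from the hunks above; it is not the complete run_pretraining script, and the `global config` line is added here only to make the condensed form self-consistent:

# Sketch only: condensed from the hunks above, assuming the driver package layout.
import config                                 # the model's config module
from driver import Event, dist_pytorch
from driver.helper import InitHelper


def main():
    global config, logger
    init_helper = InitHelper(config)          # now also sets local_rank and runs check_config
    model_driver = init_helper.init_driver()  # model name is read from config.name
    config = model_driver.config
    dist_pytorch.init_dist_training_env(config)
    dist_pytorch.barrier(config.vendor)       # barrier is vendor-aware after this PR
    model_driver.event(Event.INIT_START)
    logger = model_driver.logger              # logger is fetched after INIT_START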
2 changes: 1 addition & 1 deletion training/benchmarks/cpm/pytorch/config/_base.py
@@ -75,7 +75,7 @@

## load and save args
# Path to a directory containing a model checkpoint.
-init_checkpoint: str = None
+init_checkpoint = "cpm_model_states_medium.pt"

# Output directory to save checkpoints to.
save: str = None
2 changes: 1 addition & 1 deletion training/benchmarks/cpm/pytorch/run_pretraining.py
@@ -44,7 +44,7 @@ def main():
    logger = cpm_driver.logger
    dist_pytorch.init_dist_training_env(config)

-    check.check_config(config, "cpm_model_states_medium.pt")
+    check.check_config(config)

    dist_pytorch.barrier()
    cpm_driver.event(Event.INIT_START)
9 changes: 4 additions & 5 deletions training/benchmarks/driver/check.py
@@ -19,7 +19,7 @@ def get_config_arg(config, name):
    return None


-def check_config(config, model_pt_file):
+def check_config(config):
    print(
        "device: {} n_device: {}, distributed training: {}, 16-bits training: {}"
        .format(config.device, config.n_device, config.local_rank != -1,
@@ -41,12 +41,11 @@ def check_config(config, model_pt_file):
        config.eval_data = ospath.join(data_dir, eval_data)

    init_checkpoint = get_config_arg(config, "init_checkpoint")
-    if init_checkpoint is None:
-        config.init_checkpoint = ospath.join(data_dir, model_pt_file)
-    else:
-        config.init_checkpoint = init_checkpoint
+    if init_checkpoint is not None:
+        config.init_checkpoint = ospath.join(data_dir, config.init_checkpoint)

    if config.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(config.gradient_accumulation_steps))
    return config
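With this contract, each model's `_base.py` owns its checkpoint filename and `check_config` only prefixes it with the data directory. A small sketch of the new path resolution; the `SimpleNamespace` stand-in and the `/data/cpm` directory are illustrative (real configs are modules):

from os import path as ospath
from types import SimpleNamespace

# Illustrative config stand-in; the filename mirrors the new cpm default.
config = SimpleNamespace(data_dir="/data/cpm",
                         init_checkpoint="cpm_model_states_medium.pt")

init_checkpoint = getattr(config, "init_checkpoint", None)  # as get_config_arg does
if init_checkpoint is not None:
    config.init_checkpoint = ospath.join(config.data_dir, config.init_checkpoint)

print(config.init_checkpoint)  # /data/cpm/cpm_model_states_medium.pt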
48 changes: 24 additions & 24 deletions training/benchmarks/driver/dist_pytorch.py
@@ -91,17 +91,6 @@ def setup_seeds(master_seed, epochs, device):
    return worker_seeds, shuffling_seeds


-def barrier():
-    """
-    Works as a temporary distributed barrier, currently pytorch
-    doesn't implement barrier for NCCL backend.
-    Calls all_reduce on dummy tensor and synchronizes with GPU.
-    """
-    if torch.distributed.is_available() and torch.distributed.is_initialized():
-        torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
-        torch.cuda.synchronize()
-
-
def get_rank(default=0):
    """
    Gets distributed rank or returns zero if distributed is not initialized.
@@ -113,16 +102,17 @@ def get_rank(default=0):
    return rank


-def get_world_size():
+def get_world_size(vendor="nvidia"):
    """
    Gets total number of distributed workers or returns one if distributed is
    not initialized.
    """
-    if torch.distributed.is_available() and torch.distributed.is_initialized():
-        world_size = torch.distributed.get_world_size()
-    else:
-        world_size = 1
-    return world_size
+    if vendor == "nvidia":
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            world_size = torch.distributed.get_world_size()
+        else:
+            world_size = 1
+        return world_size


def main_proc_print(*args, **kwargs):
@@ -146,14 +136,22 @@ def set_device(cuda, local_rank):
    return device


+def barrier(vendor):
+    """
+    Works as a temporary distributed barrier, currently pytorch
+    doesn't implement barrier for NCCL backend.
+    Calls all_reduce on dummy tensor and synchronizes with GPU.
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        if vendor == "nvidia":
+            torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
+            torch.cuda.synchronize()
+
+
def init_dist_training_env(config):
-    ''' TODO: Support other accelarators. '''
-    if config.local_rank == -1:
-        config.device = torch.device("cuda")
-        config.n_device = torch.cuda.device_count()
-    else:
+    if config.vendor == "nvidia":
        torch.cuda.set_device(config.local_rank)
-        config.device = torch.device("cuda", config.local_rank)
        host_addr_full = 'tcp://' + os.environ[
            "MASTER_ADDR"] + ':' + os.environ["MASTER_PORT"]
        rank = int(os.environ["RANK"])
@@ -162,9 +160,11 @@ def init_dist_training_env(config):
            init_method=host_addr_full,
            rank=rank,
            world_size=world_size)
+        config.device = torch.device("cuda", config.local_rank)
        config.n_device = torch.distributed.get_world_size()
-    return
+    else:
+        raise Exception("config.vendor should be right.")


def global_batch_size(config):
    return config.train_batch_size * config.n_device
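After this change, `init_dist_training_env` only handles `config.vendor == "nvidia"` and raises for anything else, and `barrier`/`get_world_size` take the vendor explicitly. A minimal single-process sketch of driving the reworked API; the rendezvous values and the `SimpleNamespace` config are illustrative, and a CUDA device is assumed:

import os
from types import SimpleNamespace

from driver import dist_pytorch

# Illustrative single-node rendezvous; a real launcher exports these variables.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

# Hypothetical config object; real benchmark configs carry many more fields.
config = SimpleNamespace(vendor="nvidia", local_rank=0, train_batch_size=32)

dist_pytorch.init_dist_training_env(config)  # sets config.device and config.n_device
dist_pytorch.barrier(config.vendor)          # dummy all-reduce + synchronize, nvidia only
print(dist_pytorch.get_world_size(config.vendor))
print(dist_pytorch.global_batch_size(config))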
35 changes: 18 additions & 17 deletions training/benchmarks/driver/helper.py
@@ -7,7 +7,7 @@
import time
import numpy as np
import torch
-from driver import perf_logger, Driver
+from driver import perf_logger, Driver, check
import driver


@@ -18,17 +18,29 @@ class InitHelper:

    def __init__(self, config: object) -> None:
        self.config = config
+        self.update_local_rank()
+        self.config = check.check_config(self.config)
+
+    def init_driver(self) -> Driver:
+        """
+        params:
+            name: model name
+        """
+        config = self.config
+        model_driver = Driver(config, config.mutable_params)
+        model_driver.setup_config(argparse.ArgumentParser(config.name))
+        model_driver.setup_modules(driver, globals(), locals())
+        return model_driver

    def get_logger(self) -> perf_logger.PerfLogger:
        """get logger for FlagPerf"""
        return perf_logger.PerfLogger.get_default_logger(
            rank=self.config.local_rank)

-    def get_local_rank(self) -> int:
-        """get local rank"""
-        if self.config.use_env and 'LOCAL_RANK' in os.environ:
-            return int(os.environ['LOCAL_RANK'])
-        return 0
+    def update_local_rank(self) -> int:
+        """set local rank"""
+        if 'LOCAL_RANK' in os.environ:
+            self.config.local_rank = int(os.environ['LOCAL_RANK'])

    def set_seed(self, seed: int, vendor: str):
        """set seed"""
@@ -45,17 +57,6 @@ def set_seed(self, seed: int, vendor: str):
            # TODO: other vendors set the seed here; extend as needed
            pass

-    def init_driver(self, name: str, vendor: str = None) -> Driver:
-        """
-        params:
-            name: driver name
-            vendor: vendor name
-        """
-        config = self.config
-        model_driver = Driver(config, config.mutable_params)
-        model_driver.setup_config(argparse.ArgumentParser(name))
-        model_driver.setup_modules(driver, globals(), locals())
-        return model_driver


def get_finished_info(start_time: int, state: object, do_train: bool,
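`InitHelper.__init__` now does the setup the callers used to do by hand: it reads `LOCAL_RANK` from the environment and passes the config through `check.check_config`. A short sketch of the effect, assuming a config module that already satisfies `check_config` (the `LOCAL_RANK` value is illustrative):

import os

import config                        # a model config module, as in the benchmarks
from driver.helper import InitHelper

os.environ["LOCAL_RANK"] = "3"       # illustrative; normally exported by the launcher

init_helper = InitHelper(config)     # runs update_local_rank() and check_config()
assert config.local_rank == 3        # picked up from the environment
model_driver = init_helper.init_driver()  # model name comes from config.name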
2 changes: 1 addition & 1 deletion training/benchmarks/glm/pytorch/config/_base.py
@@ -10,7 +10,7 @@
train_data: str = "ReCoRD/glm_train_eval_hdf5_sparse/train_hdf5/train_sparse.hdf5"
eval_data: str = "ReCoRD/glm_train_eval_hdf5_sparse/eval_hdf5/eval_sparse.hdf5"
output_dir: str = ""
-init_checkpoint: str = None
+init_checkpoint = "blocklm-large-blank/200000/mp_rank_00_model_states.pt"

# =========================================================
# Model
3 changes: 1 addition & 2 deletions training/benchmarks/glm/pytorch/run_pretraining.py
@@ -40,8 +40,7 @@ def main():

    dist_pytorch.init_dist_training_env(config)

-    check.check_config(
-        config, "blocklm-large-blank/200000/mp_rank_00_model_states.pt")
+    check.check_config(config)

    dist_pytorch.barrier()
    glm_driver.event(Event.INIT_START)
