Skip to content

Commit

Permalink
Separate global and local TQM responsibilities
Browse files Browse the repository at this point in the history
Naming scheme now consistently uses my_ or get_my_ prefix for local.
  • Loading branch information
Waino committed Feb 12, 2024
1 parent 9ddb8fe commit 4e768e4
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 108 deletions.
180 changes: 100 additions & 80 deletions mammoth/distributed/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def __init__(
tasks: List[TaskSpecs],
accum_count: int,
world_context: WorldContext,
device_context: Optional[DeviceContext] = None,
components_to_gpus=None,
components_to_groups=None,
task_distribution_strategy: Optional[TaskDistributionStrategy] = None,
Expand All @@ -200,13 +199,8 @@ def __init__(
self.accum_count = accum_count[0] if isinstance(accum_count, list) else accum_count
self.task_distribution_strategy = task_distribution_strategy
self.world_context = world_context
self.device_context = device_context
self.uses_adapters = uses_adapters

if self.world_context and self.device_context:
logger.info(f'in task_queue_manager: node_rank {self.node_rank} local_rank {self.local_rank}')
self.device_context.validate(self.world_context)

self.components_to_gpus = components_to_gpus
self.components_to_groups = components_to_groups
self.sampled_task_counts = Counter()
Expand All @@ -219,18 +213,6 @@ def gpus_per_node(self):
def n_nodes(self):
    """Number of nodes participating in the distributed job."""
    world = self.world_context
    return world.n_nodes

@property
def node_rank(self):
if not self.device_context:
raise Exception('Trying to get node_rank of global TQM')
return self.device_context.node_rank

@property
def local_rank(self):
if not self.device_context:
raise Exception('Trying to get local_rank of global TQM')
return self.device_context.local_rank

@classmethod
def from_opts(cls, opts: Namespace, world_context: WorldContext):
n_tasks = len(opts.tasks)
Expand Down Expand Up @@ -320,15 +302,15 @@ def global_to_local(self, node_rank, local_rank, opts):
assert local_rank is not None
task_distribution_strategy = self._get_strategy(node_rank=node_rank, local_rank=local_rank, opts=opts)
device_context = self.world_context.global_to_local(node_rank, local_rank)
return self.__class__(
return LocalTaskQueueManager(
self.tasks,
accum_count=self.accum_count,
world_context=self.world_context,
device_context=device_context,
components_to_gpus=self.components_to_gpus,
components_to_groups=self.components_to_groups,
task_distribution_strategy=task_distribution_strategy,
uses_adapters=self.uses_adapters,
device_context=device_context,
)

def _get_strategy(self, node_rank, local_rank, opts):
Expand Down Expand Up @@ -365,12 +347,8 @@ def __repr__(self):
def _tasks_on_device(self, node_rank, local_rank):
return [task for task in self.tasks if (task.node_rank, task.local_rank) == (node_rank, local_rank)]

def get_tasks(self):
if not self.device_context:
# global mode: return all
return self.tasks
else:
return self._tasks_on_device(self.node_rank, self.local_rank)
def get_all_tasks(self):
    """Return the full task list, regardless of device assignment."""
    all_tasks = self.tasks
    return all_tasks

@staticmethod
def _default_node_gpu(n_tasks, n_nodes, gpus_per_node):
Expand Down Expand Up @@ -445,20 +423,81 @@ def create_all_distributed_groups(
continue
sorted_global_ranks = list(sorted(global_ranks))
min_rank = sorted_global_ranks[0]
# The torch.distributed.new_group function requires that all
# processes in the main group (i.e. all processes that are part of
# the distributed job) enter the function, even if they are not
# going to be members of the group. Additionally, groups should be
# created in the same order in all processes.
group_tpl = (min_rank, new_group_func(sorted_global_ranks))
component_type = key[0]
component_id = key[1:]
self.components_to_groups.setdefault(component_type, OrderedDict())[component_id] = group_tpl

return self.components_to_groups

def get_langs(self, side):
    """Return the language code of the requested side for every task.

    side: either 'src' or 'tgt'; any other value raises ValueError.
    """
    attr_by_side = {'src': 'src_lang', 'tgt': 'tgt_lang'}
    if side not in attr_by_side:
        raise ValueError(f'side "{side}" not in {{src, tgt}}')
    attr = attr_by_side[side]
    return [getattr(task, attr) for task in self.get_all_tasks()]


class LocalTaskQueueManager(TaskQueueManager):
def __init__(
    self,
    tasks: List[TaskSpecs],
    accum_count: int,
    world_context: WorldContext,
    components_to_gpus=None,
    components_to_groups=None,
    task_distribution_strategy: Optional[TaskDistributionStrategy] = None,
    uses_adapters: bool = False,
    device_context: Optional[DeviceContext] = None,
):
    """
    Task queue manager bound to a single device (one process).

    Extends the global TaskQueueManager with a DeviceContext identifying
    this process's node and local GPU, so the get_my_* accessors can
    restrict their results to the tasks scheduled on this device.

    Args:
        tasks: all task specs (not only the ones on this device).
        accum_count: gradient accumulation count.
        world_context: describes the whole distributed job.
        components_to_gpus: precomputed component-to-GPU mapping, passed
            through to the parent.
        components_to_groups: precomputed component-to-group mapping,
            passed through to the parent.
        task_distribution_strategy: strategy used to sample corpus ids.
        uses_adapters: whether adapter modules are in use.
        device_context: identifies this node/GPU. Required here; the
            Optional annotation is kept only for signature compatibility
            with the parent class.

    Raises:
        ValueError: if device_context is None.
    """
    super().__init__(
        tasks=tasks,
        accum_count=accum_count,
        world_context=world_context,
        task_distribution_strategy=task_distribution_strategy,
        uses_adapters=uses_adapters,
        components_to_gpus=components_to_gpus,
        components_to_groups=components_to_groups,
    )

    # Raise instead of assert: input validation must survive `python -O`.
    if device_context is None:
        raise ValueError('LocalTaskQueueManager requires a device_context')
    self.device_context = device_context

    logger.info(f'in task_queue_manager: node_rank {self.node_rank} local_rank {self.local_rank}')
    self.device_context.validate(self.world_context)
    # NOTE: sampled_task_counts is already initialized to Counter() by the
    # parent constructor; re-assigning it here was redundant and is removed.

@property
def node_rank(self):
    """Index of the node on which this process runs."""
    ctx = self.device_context
    return ctx.node_rank

@property
def local_rank(self):
    """Local rank of the GPU assigned to this process on its node."""
    ctx = self.device_context
    return ctx.local_rank

@property
def global_rank(self):
    """Flat rank of this process across the whole job:
    node_rank * gpus_per_node + local_rank.
    """
    node, local = self.node_rank, self.local_rank
    assert node is not None
    assert local is not None
    return node * self.gpus_per_node + local

def get_distributed_groups(
def get_my_distributed_groups(
self,
new_group_func=torch.distributed.new_group,
):
Expand Down Expand Up @@ -502,15 +541,15 @@ def get_distributed_groups(

return my_distributed_groups

def get_grouped_components(self, model):
def get_my_grouped_components(self, model):
"""
Returns nested dict of component_type -> component_id -> nn.Module.
Only components present on this GPU are returned.
Unlike get_distributed_groups, this method also returns components on a single device,
Unlike get_my_distributed_groups, this method also returns components on a single device,
and it does not retrieve communication groups.
"""
if self.components_to_groups is None:
raise Exception('Must call get_distributed_groups first')
raise Exception('Must call get_my_distributed_groups first')

my_grouped_components = {
'encoder': OrderedDict(),
Expand All @@ -524,7 +563,7 @@ def get_grouped_components(self, model):
if not self.world_context.is_distributed():
tasks = self.tasks
else:
tasks = self.get_tasks()
tasks = self.get_my_tasks()

for task in tasks:
# loop over my tasks, getting all the relevant module ids and modules
Expand Down Expand Up @@ -555,21 +594,33 @@ def get_grouped_components(self, model):

return my_grouped_components

# TODO: soon deprecated by #18 Data pipeline refactoring
def get_fields(self, side: str, fields_dict):
"""Returns a list of tuples: (side, lang, component_id, fields)."""
raise RuntimeError
def sample_corpus_ids(self, communication_batch_id: int):
    """Sample one corpus id for this communication batch.

    The single sampled id is repeated accum_count times (one per gradient
    accumulation step), and the running sampled_task_counts tally is
    updated before returning the list.
    """
    sampled = self.task_distribution_strategy.sample_corpus_ids(
        1,
        communication_batch_id,
    )
    corpus_ids = [sampled[0]] * self.accum_count
    self.sampled_task_counts.update(corpus_ids)
    return corpus_ids

# FIXME: merge with below
def get_vocabularies(self, opts: Namespace, side: str):
result = []
for task in self.get_tasks():
lang = self.src_lang if side == 'src' else self.tgt_lang
vocab_path = opts.__getattribute__(f'{side}_vocab')[lang]
result.append((lang, vocab_path))
return result
def get_my_encoders(self, layer_stack_index: int):
    """Encoder component id at the given layer stack, for each task on this device."""
    return [task.encoder_id[layer_stack_index] for task in self.get_my_tasks()]

def get_my_decoders(self, layer_stack_index: int):
    """Decoder component id at the given layer stack, for each task on this device."""
    return [task.decoder_id[layer_stack_index] for task in self.get_my_tasks()]

def get_vocabs(self, side: str, vocabs_dict):
def get_my_src_langs(self):
    """Source language of each task assigned to this device."""
    langs = []
    for task in self.get_my_tasks():
        langs.append(task.src_lang)
    return langs

def get_my_tgt_langs(self):
    """Target language of each task assigned to this device."""
    langs = []
    for task in self.get_my_tasks():
        langs.append(task.tgt_lang)
    return langs

def get_my_generators(self):
    """Key for the generator of each task on this device.

    NOTE(review): this returns target-language codes, identical to
    get_my_tgt_langs — generators appear to be keyed by tgt_lang;
    confirm against the model builder.
    """
    generator_keys = [task.tgt_lang for task in self.get_my_tasks()]
    return generator_keys

def get_my_vocabs(self, side: str, vocabs_dict):
"""Returns a list of tuples: (side, lang, component_id, vocabs).
side: Either 'src' or 'tgt'.
lang: The language code. Vocabularies are language specific.
Expand All @@ -579,7 +630,7 @@ def get_vocabs(self, side: str, vocabs_dict):
seen = set()
result = []
component_id = None # for hysterical raisins
for task in self.get_tasks():
for task in self.get_my_tasks():
if side == 'src':
lang = task.src_lang
else:
Expand All @@ -589,36 +640,5 @@ def get_vocabs(self, side: str, vocabs_dict):
seen.add((side, lang, component_id))
return result

def sample_corpus_ids(self, communication_batch_id: int):
corpus_id = self.task_distribution_strategy.sample_corpus_ids(
1,
communication_batch_id,
)[0]
corpus_ids = [corpus_id for _ in range(self.accum_count)]
self.sampled_task_counts.update(corpus_ids)
return corpus_ids

def get_encoders(self, layer_stack_index: int):
my_encoder_ids = [task.encoder_id[layer_stack_index] for task in self.get_tasks()]
return my_encoder_ids

def get_decoders(self, layer_stack_index: int):
my_decoder_ids = [task.decoder_id[layer_stack_index] for task in self.get_tasks()]
return my_decoder_ids

def get_src_langs(self):
return [task.src_lang for task in self.get_tasks()]

def get_tgt_langs(self):
return [task.tgt_lang for task in self.get_tasks()]

def get_generators(self):
return [task.tgt_lang for task in self.get_tasks()]

def get_langs(self, side):
if side == 'src':
return [task.src_lang for task in self.get_tasks()]
elif side == 'tgt':
return [task.tgt_lang for task in self.get_tasks()]
else:
raise ValueError(f'side "{side}" not in {{src, tgt}}')
def get_my_tasks(self):
    """Tasks assigned to this process's device."""
    node, local = self.node_rank, self.local_rank
    return self._tasks_on_device(node, local)
2 changes: 1 addition & 1 deletion mammoth/inputters/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def from_opts(cls, task_queue_manager, transforms_cls, vocabs_dict, opts, is_tra

def _init_datasets(self):
self.dataset_iterators = dict()
for task in self.task_queue_manager.get_tasks():
for task in self.task_queue_manager.get_my_tasks():
src_vocab = self.vocabs_dict[('src', task.src_lang)]
tgt_vocab = self.vocabs_dict[('tgt', task.tgt_lang)]
# merged_fields = {'src': src_fields['src'], 'tgt': tgt_fields['tgt']}
Expand Down
6 changes: 3 additions & 3 deletions mammoth/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,13 @@ def build_task_specific_model(
generators_md = nn.ModuleDict()

# FIXME: it's getting late and I just want this to compile
for side, lang, _, vocab in task_queue_manager.get_vocabs(side='src', vocabs_dict=vocabs_dict):
for side, lang, _, vocab in task_queue_manager.get_my_vocabs(side='src', vocabs_dict=vocabs_dict):
src_emb = build_src_emb(model_opts, vocab)
src_embs[lang] = src_emb
pluggable_src_emb = PluggableEmbeddings(src_embs)
encoder = build_only_enc(model_opts, pluggable_src_emb, task_queue_manager)

for side, lang, _, vocab in task_queue_manager.get_vocabs(side='tgt', vocabs_dict=vocabs_dict):
for side, lang, _, vocab in task_queue_manager.get_my_vocabs(side='tgt', vocabs_dict=vocabs_dict):
tgt_emb = build_tgt_emb(model_opts, vocab)
tgt_embs[lang] = tgt_emb
generator = build_generator(model_opts, len(vocab), tgt_emb)
Expand Down Expand Up @@ -510,7 +510,7 @@ def create_all_adapters(model, opts, task_queue_manager):
my_dec_adapter_ids = set()
adapter_to_encoder_ids = defaultdict(set)
adapter_to_decoder_ids = defaultdict(set)
for task in task_queue_manager.get_tasks():
for task in task_queue_manager.get_my_tasks():
for adapter_id in task.encoder_adapter_ids:
adapter_id = tuple(adapter_id)
my_enc_adapter_ids.add(adapter_id)
Expand Down
2 changes: 1 addition & 1 deletion mammoth/modules/layer_stack_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def from_opts(cls, opts, embeddings, task_queue_manager, is_on_top=False):
for layer_stack_index, n_layers in enumerate(opts.dec_layers):
is_on_top = layer_stack_index == len(opts.dec_layers) - 1
stacks = nn.ModuleDict()
for module_id in task_queue_manager.get_decoders(layer_stack_index):
for module_id in task_queue_manager.get_my_decoders(layer_stack_index):
if module_id in stacks:
# several tasks using the same layer stack
continue
Expand Down
2 changes: 1 addition & 1 deletion mammoth/modules/layer_stack_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def from_opts(cls, opts, embeddings, task_queue_manager):
for layer_stack_index, n_layers in enumerate(opts.enc_layers):
stacks = nn.ModuleDict()
is_on_top = layer_stack_index == len(opts.enc_layers) - 1
for module_id in task_queue_manager.get_encoders(layer_stack_index):
for module_id in task_queue_manager.get_my_encoders(layer_stack_index):
if module_id in stacks:
# several tasks using the same layer stack
continue
Expand Down
Loading

0 comments on commit 4e768e4

Please sign in to comment.