Fix/adapters 2 #83
Logs the architecture of the model and the distributed components.
Now both training and translation use `--gpu_rank 0`. Closes #82
LGTM.
I have complaints on the cosmetic side (I'm really not into type hints when declaring vars), but nothing that serious.
mammoth/opts.py (outdated)
@@ -548,7 +554,7 @@ def _add_train_general_opts(parser):
     type=float,
     default=[0.3],
     nargs='+',
-    help="Dropout probability; applied in LSTM stacks.",
+    help="Dropout probability; applied in LSTM stacks. (Probably legacy?)",
should be, since we're now delegating this to xtransformers
@@ -27,7 +36,7 @@ def _validate_adapters(cls, opts):
     """Parse corpora specified in data field of YAML file."""
     if not opts.adapters:
         return
-    adapter_opts = yaml.safe_load(opts.adapters)
+    adapter_opts = yaml_or_dict(opts.adapters, name='opts.adapters')
Not sure about this `name` thing — I get that you want to display a less cryptic message, but in practice we devs are the only ones who'll ever get to see that TypeError.
...and you're using that function once? The stack trace should be unambiguous enough.
I forgot to apply this to the other similar locations where `yaml.safe_load` is applied. Fixed now.
This crashed when loading opts.adapters from a checkpoint. I'm not sure why those other usages didn't crash earlier, but better safe than sorry.
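The helper itself isn't shown in the diff. A minimal sketch of what a `yaml_or_dict` along these lines might look like, assuming it only needs to handle an already-parsed dict (as restored from a checkpoint) or a YAML string:

```python
import yaml


def yaml_or_dict(value, name: str) -> dict:
    """Accept either an already-parsed dict or a YAML string.

    Sketch: checkpoints can store opts.adapters as a dict, while fresh
    configs pass it in as a YAML string, so both cases are handled.
    """
    if isinstance(value, dict):
        return value
    if isinstance(value, str):
        return yaml.safe_load(value)
    raise TypeError(
        f'Expected {name} to be a dict or a YAML string, '
        f'got {type(value).__name__}'
    )
```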
@@ -338,7 +380,12 @@ def build_model(
     )

     model.to(device)
-    # logger.info(model)
+    if opts.log_model_structure:
+        logger.info(model)
feels like this should be debug rather than info, but ok, given your new flags
Yes, this should be `logger.debug`. However, we still haven't fixed the way we set up logging (AFAIK), so debugs are never visible.
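For reference, a sketch of making debug-level records visible with stdlib logging alone (the actual mammoth logging setup isn't shown here, and the logger name is an assumption):

```python
import logging

# Assumption: debug records are dropped because the logger and handler
# levels default to WARNING/INFO. Lowering both makes logger.debug(...)
# output actually appear.
logger = logging.getLogger('mammoth')
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)

logger.debug('debug output is now visible')
```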
mammoth/model_builder.py (outdated)
    for component in task_queue_manager.get_my_distributed_components():
        logger.info(component)
    for name, p in model.named_parameters():
        print(f'{p.requires_grad} {name}')
why direct print here?
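Presumably the parameter dump should go through the logger too; a drop-in sketch for the two print lines above, assuming the module-level logger used elsewhere in model_builder.py:

```python
# Sketch: use the module logger instead of print, so the output respects
# the configured handlers and formatting.
for name, p in model.named_parameters():
    logger.info(f'{p.requires_grad} {name}')
```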
@@ -314,13 +340,28 @@ def build_model(
     device = torch.device("cpu")
     logger.info(device)

+    enc_adapters_by_name: Optional[Dict[str, Adapter]] = build_adapters(
The type hint at the var declaration is a bit much. Maybe worth a comment instead? You already have a type hint for the return value of build_adapters.
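The suggested alternative, as a sketch: since `build_adapters` already annotates its return value, the assignment can stay bare, with at most a trailing comment (arguments elided):

```python
# build_adapters's return annotation already documents the type,
# so no hint is needed at the assignment; a comment suffices.
enc_adapters_by_name = build_adapters(...)  # Optional[Dict[str, Adapter]]
```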
        component for component in my_components
        if isinstance(component, distributed_xcoder_class)
    ]
    attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]] = defaultdict(dict)
Still not sold on the type hint, but it gets a pass for the nested structure, I suppose... not pretty.
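A shape comment could carry the same information without the inline hint (a sketch; the inner key is assumed to identify the xcoder):

```python
from collections import defaultdict

# Shape: {layer_stack_index: {xcoder_id: AdaptedAttentionLayers}}
attention_layer_blocks = defaultdict(dict)
```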
        single_task: if a task_id string is given, the built model contains only the components necessary for that task.
        token_embs: to tie encoder and decoder embeddings, pass existing embeddings here.
    """
) -> Optional[Dict[str, Adapter]]:
docstring this.
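A sketch of what the requested docstring might cover; the parameter list is elided in the diff, so the signature here is hypothetical:

```python
def build_adapters(side, opts, task_queue_manager):  # hypothetical signature
    """Create AdapterLayer objects and Adapter objects from the adapter opts.

    Returns:
        A dict mapping adapter name to Adapter, or None when no adapters
        are configured.
    """
```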
""" | ||
) -> Optional[Dict[str, Adapter]]: | ||
# Create AdapterLayer objects and Adapter objects | ||
adapters_by_name: Optional[Dict[str, Adapter]] |
Not into type-hint declarations.
@@ -190,25 +152,89 @@ def build_xcoder(
         )
     else:
         raise ValueError(f'Unrecognized adapter_type {adapter_opts["adapter_type"]}')
+    layer_stack_index = adapter_params['layer_stack_index']
I would directly `get` from this dict on line 159, but sure.
mammoth/model_builder.py (outdated)
    distributed_xcoder_class: type
    if side == Side.encoder:
        distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
    else:
So e.g. `side=None` maps to the decoder blocks? Feels like a wasted opportunity to halt and catch fire.
According to the type hint, `side` is not optional. But sure, let's double-check.
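The halt-and-catch-fire variant would be an explicit branch per enum member (a sketch; the decoder class name is assumed by analogy with the encoder one):

```python
if side == Side.encoder:
    distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
elif side == Side.decoder:
    # decoder class name assumed to mirror the encoder variant
    distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
else:
    raise ValueError(f'Unsupported side: {side!r}')
```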
On an unrelated note, it might be helpful to have a logging_opts, just like we have a model_opts and a data_opts.
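Mirroring the existing `_add_train_general_opts` pattern in mammoth/opts.py, such a group might look like this sketch (only --log_model_structure is confirmed by this PR; the group name is hypothetical):

```python
def _add_logging_opts(parser):
    """Group logging-related options, analogous to model_opts and data_opts."""
    group = parser.add_argument_group('Logging')
    group.add_argument(
        '--log_model_structure',
        action='store_true',
        help='Log the model architecture and the distributed components.',
    )
```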
I've configured …
If you want to enforce type hinting in this project, I think you'd need to incorporate the relevant testing in the actions workflow.
Some more fixes to adapters.
--log_model_structure for debugging architecture issues.