
Fix/adapters 2 #83

Merged · 6 commits merged into main from fix/adapters-2 on Dec 9, 2024

Conversation

@Waino (Collaborator) commented Dec 2, 2024

Some more fixes to adapters:

  • Create adapters before the StackXCoder, to ensure there is no accidental duplication.
  • Prevent the LoRA adapter from taking the wrapped layer as its child.
  • Add a --log_model_structure flag for debugging architecture issues.

@Waino requested a review from @TimotheeMickus on December 2, 2024 at 11:18
@Waino mentioned this pull request on Dec 2, 2024
Waino added 2 commits on December 2, 2024 at 19:26:
Now both training and translation use `--gpu_rank 0`.

Closes #82
@TimotheeMickus (Collaborator) left a comment

LGTM.

I have complaints on the cosmetic side (I'm really not into type hints when declaring vars) but nothing that serious

mammoth/opts.py (Outdated)
@@ -548,7 +554,7 @@ def _add_train_general_opts(parser):
         type=float,
         default=[0.3],
         nargs='+',
-        help="Dropout probability; applied in LSTM stacks.",
+        help="Dropout probability; applied in LSTM stacks. (Probably legacy?)",
@TimotheeMickus (Collaborator):

should be, since we're now delegating this to xtransformers

@@ -27,7 +36,7 @@ def _validate_adapters(cls, opts):
         """Parse corpora specified in data field of YAML file."""
         if not opts.adapters:
             return
-        adapter_opts = yaml.safe_load(opts.adapters)
+        adapter_opts = yaml_or_dict(opts.adapters, name='opts.adapters')
@TimotheeMickus (Collaborator):

not sure about this name thing — I get that you want to display a less cryptic message, but in practice we devs are the only ones who'll ever get to see that TypeError ... and you're using that function once? the stacktrace should be unambiguous enough.

@Waino (Collaborator, author):

I forgot to apply this to the other similar locations where yaml.safe_load is applied. Fixed now.

This crashed when loading opts.adapters from a checkpoint. I'm not sure why those other usages didn't crash earlier, but better safe than sorry.
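
A minimal sketch of what such a helper might look like (the actual yaml_or_dict in mammoth may differ): it accepts either a YAML string or an already-parsed dict, as happens when opts are restored from a checkpoint, and names the offending option otherwise.

from typing import Any

import yaml


def yaml_or_dict(value: Any, name: str) -> dict:
    """Parse a YAML string, or pass through a dict that was already parsed
    (e.g. options restored from a checkpoint)."""
    if isinstance(value, dict):
        return value
    if isinstance(value, str):
        return yaml.safe_load(value)
    raise TypeError(f'Expected a YAML string or dict for {name}, got {type(value)}')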

@@ -338,7 +380,12 @@ def build_model(
     )

     model.to(device)
-    # logger.info(model)
+    if opts.log_model_structure:
+        logger.info(model)
@TimotheeMickus (Collaborator):

feels like this should be debug rather than info, but ok, given your new flags

@Waino (Collaborator, author):

Yes, this should be logger.debug. However, we still haven't fixed the way we set up logging (AFAIK), so debug messages are never visible.
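
A generic illustration of that point (standard library logging, not mammoth's actual logging setup): debug messages only show up once both the logger and its handler are set to DEBUG.

import logging

logger = logging.getLogger('mammoth')
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter('[%(asctime)s %(levelname)s] %(message)s'))
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)  # without this, logger.debug(...) calls are dropped

logger.debug('model structure dump would now be visible')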

for component in task_queue_manager.get_my_distributed_components():
    logger.info(component)
for name, p in model.named_parameters():
    print(f'{p.requires_grad} {name}')
@TimotheeMickus (Collaborator):

why direct print here?
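
The alternative implied by the question, sketched (reusing the model and logger objects from the snippet above):

# Route the per-parameter report through the logger as well, so it respects
# the configured handlers and formatting instead of writing straight to stdout.
for name, p in model.named_parameters():
    logger.info(f'requires_grad={p.requires_grad} {name}')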

@@ -314,13 +340,28 @@ def build_model(
device = torch.device("cpu")
logger.info(device)

enc_adapters_by_name: Optional[Dict[str, Adapter]] = build_adapters(
@TimotheeMickus (Collaborator):

The type hint at the var declaration is a bit much. Maybe worth a comment instead? You already have a type hint for the return value of build_adapters.
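
A self-contained toy example of the two styles being debated; with a return-type annotation on the builder function, a type checker already infers the variable's type, so the annotation at the assignment is optional.

from typing import Dict, Optional


def build_things() -> Optional[Dict[str, int]]:
    return {'a': 1}


# Style 1: rely on the return annotation; checkers infer Optional[Dict[str, int]].
things = build_things()

# Style 2: repeat the annotation at the assignment (the style used in this PR).
things_explicit: Optional[Dict[str, int]] = build_things()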

    component for component in my_components
    if isinstance(component, distributed_xcoder_class)
]
attention_layer_blocks: Dict[int, Dict[str, AdaptedAttentionLayers]] = defaultdict(dict)
@TimotheeMickus (Collaborator):

Still not sold on the type hint, but it gets a pass for the nested structure, I suppose... not pretty.

single_task: if a task_id string is given, the built model contains only the components necessary for that task.
token_embs: to tie encoder and decoder embeddings, pass existing embeddings here.
"""
) -> Optional[Dict[str, Adapter]]:
@TimotheeMickus (Collaborator):

docstring this.
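
A hypothetical stub of the requested docstring. Only the return annotation and the nearby "# Create AdapterLayer objects and Adapter objects" comment are taken from the diff; the parameter names below are assumptions and may not match the real build_adapters signature.

from typing import Dict, Optional


def build_adapters(side, opts, task_queue_manager, single_task=None) -> Optional[Dict[str, 'Adapter']]:
    """Create AdapterLayer objects and group them into Adapter objects.

    Args:
        side: build encoder-side or decoder-side adapters (assumed parameter).
        opts: parsed options; adapter configuration comes from opts.adapters (assumed).
        task_queue_manager: source of this device's distributed components (assumed).
        single_task: if given, build only the adapters needed for that task (assumed).

    Returns:
        A dict mapping adapter names to Adapter objects, or None when no
        adapters are configured.
    """
    raise NotImplementedError('illustrative stub only')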

"""
) -> Optional[Dict[str, Adapter]]:
# Create AdapterLayer objects and Adapter objects
adapters_by_name: Optional[Dict[str, Adapter]]
@TimotheeMickus (Collaborator):

Not into type hints on declarations.

@@ -190,25 +152,89 @@ def build_xcoder(
        )
    else:
        raise ValueError(f'Unrecognized adapter_type {adapter_opts["adapter_type"]}')
    layer_stack_index = adapter_params['layer_stack_index']
@TimotheeMickus (Collaborator):

I would get this directly from the dict on line 159, but sure.

distributed_xcoder_class: type
if side == Side.encoder:
    distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
else:
@TimotheeMickus (Collaborator):

so e.g. side=None maps to the decoder blocks? feels like a wasted opportunity to halt and catch fire.

@Waino (Collaborator, author):

According to the type hint, side is not optional. But sure, let's double-check.
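
For reference, the stricter dispatch the reviewer is suggesting, sketched; the decoder class name below is assumed by analogy with DistributedEncoderAttentionLayersBlock and may not match the actual codebase.

distributed_xcoder_class: type
if side == Side.encoder:
    distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
elif side == Side.decoder:
    distributed_xcoder_class = DistributedDecoderAttentionLayersBlock  # assumed name
else:
    # halt and catch fire instead of silently treating anything else as a decoder
    raise ValueError(f'Unexpected side: {side}')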

@TimotheeMickus (Collaborator) commented:

On an unrelated note, it might be helpful to have a logging_opts, just like we have a model_opts and a data_opts.
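
A hypothetical sketch of such a group, following the _add_train_general_opts pattern shown earlier in this thread; the function name and plain-argparse calls are illustrative rather than mammoth's actual option API.

def _add_logging_opts(parser):
    """Collect logging-related flags in one group, like model_opts and data_opts."""
    group = parser.add_argument_group('Logging')
    group.add_argument(
        '--log_model_structure',
        action='store_true',
        help='Log the model structure, distributed components, and per-parameter '
             'requires_grad status after building the model.',
    )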

@Waino (Collaborator, author) commented Dec 9, 2024

> I'm really not into type hints when declaring vars

I've configured nvim to do linting and type checking automatically, and in most projects I aim for zero type-checking errors (in mammoth, this is far from the case). I've developed a habit of type hinting a lot because of this.

@Waino merged commit 56ac097 into main on Dec 9, 2024
2 checks passed
@Waino deleted the fix/adapters-2 branch on December 9, 2024 at 10:42
@TimotheeMickus (Collaborator) commented:

> I'm really not into type hints when declaring vars
>
> I've configured nvim to do linting and type checking automatically, and in most projects I aim for zero type-checking errors (in mammoth, this is far from the case). I've developed a habit of type hinting a lot because of this.

If you want to enforce type hinting in this project, I think you'd need to incorporate the relevant checks into the Actions workflow.
