HFTracer.trace can now take callables and torch.nn.Module #18457
Conversation
… and torch.nn.Module in general
The documentation is not available anymore as the PR was closed or merged.
Nice! Some small comments!
src/transformers/utils/fx.py (Outdated)
        The dummy inputs needed to handle data-dependent control-flow if `root` is not a
        [`~transformers.PreTrainedModel`]. It can also be used when `root` is a
        [`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
    infer_concrete_args_from_dummy_inputs (`bool`, *optional*, defaults to `True`):
Suggested change:
-    infer_concrete_args_from_dummy_inputs (`bool`, *optional*, defaults to `True`):
+    infer_concrete_args_from_dummy_inputs (`bool`, defaults to `True`):
No, the argument is optional since it has a default (the user does not have to provide it). Please read the writing documentation guide.
I was unaware of this. I mistakenly thought that `bool, *optional*` means `Optional[bool]` in `typing` nomenclature. This seems weird to me, as every argument with a default becomes optional now? Anyway, perhaps a chat we can have someplace else.
All arguments that have a default are optional, yes. That is the definition of an optional argument: an argument you do not need to provide. The fact that the `typing` module decided to (badly) reuse that word for something else does not change that.
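To illustrate the distinction being discussed, a hypothetical snippet (not from this PR):

```python
from typing import Optional


# In `typing` nomenclature, Optional[bool] only means "bool or None"; the
# caller still has to pass the argument explicitly.
def f(flag: Optional[bool]) -> bool:
    return bool(flag)


# In the documentation sense used here, an *optional* argument is simply one
# with a default value, whatever its type annotation says.
def g(flag: bool = True) -> bool:
    return flag
```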
    if concrete_args is None:
        concrete_args = {}

    sig = inspect.signature(root.forward)
    if dummy_inputs is not None and infer_concrete_args_from_dummy_inputs:
        concrete_args.update({p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs})
What happens if the parameter doesn't have a default?
Then tracing will most likely fail afterwards. I added a check to fail early, as you suggested.
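A rough sketch of what such an early check could look like (hypothetical, not the actual diff from this PR; the helper name is made up):

```python
import inspect


def check_dummy_inputs_cover_required_args(forward, dummy_inputs):
    """Fail early when a forward argument has neither a dummy input nor a default."""
    sig = inspect.signature(forward)
    for name, param in sig.parameters.items():
        if name not in dummy_inputs and param.default is inspect.Parameter.empty:
            raise ValueError(
                f"Cannot infer a concrete value for `{name}`: it has no default and no "
                "dummy input was provided for it."
            )
```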
Thanks for adding this!
What does this PR do?
This PR makes it possible to use the `HFTracer` "meta-tracing" features to trace any Python callable / `torch.nn.Module`.

For `transformers.PreTrainedModel`s, the method `HFTracer._generate_dummy_inputs` already takes care of creating the dummy inputs needed to handle data-dependent control flow in the forward pass. Now, the user can specify `dummy_inputs` directly to the `HFTracer.trace` method in order to be able to trace things other than `transformers.PreTrainedModel`s. This is useful for pattern matching, for instance.

This becomes possible:
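For example, something along these lines, as a minimal sketch assuming the interface described in this PR (the module, shapes, and values are made up):

```python
import torch
from transformers.utils.fx import HFTracer


class Doubler(torch.nn.Module):
    def forward(self, x, scale=2.0):
        # Data-dependent control flow that plain torch.fx symbolic tracing
        # cannot resolve without a concrete example input.
        if x.sum() > 0:
            return x * scale
        return x


tracer = HFTracer()
# `x` is provided as a dummy input; `scale` is not, so by default it is
# treated as a concrete arg with its default value.
graph = tracer.trace(Doubler(), dummy_inputs={"x": torch.ones(2, 3)})
```

The resulting `graph` is a regular `torch.fx.Graph` that can then be wrapped in a `torch.fx.GraphModule` as usual.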
By default, if `dummy_inputs` is specified, every argument to `root` that is not in `dummy_inputs` will be considered a concrete arg (and thus added to `concrete_args`). You can disable that by setting `infer_concrete_args_from_dummy_inputs` to `False`. This is useful if you want to provide custom dummy inputs for some inputs while still letting `HFTracer._generate_dummy_inputs` do the work for the other inputs (provided that `root` is a `transformers.PreTrainedModel`, since only this case is supported for automatic dummy input generation).
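A hedged sketch of that combination (the checkpoint and shapes here are placeholders, not taken from the PR):

```python
import torch
from transformers import BertModel
from transformers.utils.fx import HFTracer

model = BertModel.from_pretrained("bert-base-uncased")

tracer = HFTracer()
# Only `input_ids` gets a custom dummy input. Because
# `infer_concrete_args_from_dummy_inputs` is False, the other forward
# arguments are not turned into concrete args, so the tracer can still
# generate dummy inputs for them automatically.
graph = tracer.trace(
    model,
    dummy_inputs={"input_ids": torch.zeros(1, 8, dtype=torch.long)},
    infer_concrete_args_from_dummy_inputs=False,
)
```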