
HFTracer.trace can now take callables and torch.nn.Module #18457

Merged

Conversation

michaelbenayoun
Member

What does this PR do?

This PR makes it possible to use the HFTracer "meta-tracing" features to trace any Python callable / torch.nn.Module.

For transformers.PreTrainedModels, the HFTracer._generate_dummy_inputs method already takes care of creating the dummy inputs needed to handle data-dependent control flow in the forward pass.

Now, the user can pass dummy_inputs directly to the HFTracer.trace method in order to trace things other than transformers.PreTrainedModels. This is useful for pattern matching, for instance.

This becomes possible:

def f(x, y, z=None):
    temp = x * y
    if z is not None:
        temp += z
    return temp

traced_f = HFTracer().trace(f, dummy_inputs={"x": torch.rand(1, 2), "y": torch.rand(1, 2)})

By default, if dummy_inputs is specified, every argument to root that is not in dummy_inputs will be considered a concrete arg (and thus added to concrete_args). You can disable that by setting infer_concrete_args_from_dummy_inputs to False. This is useful if you want to provide custom dummy inputs for some inputs while still letting HFTracer._generate_dummy_inputs do the work for the other inputs (provided that root is a transformers.PreTrainedModel, since only this case is supported for automatic dummy inputs generation).
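The inference rule described above can be sketched in plain Python. This is a minimal illustration using only `inspect`; `infer_concrete_args` is a hypothetical helper, not part of the transformers API:

```python
import inspect

def infer_concrete_args(fn, dummy_inputs):
    """Sketch of the rule: every parameter of `fn` that is not covered by
    `dummy_inputs` becomes a concrete arg, bound to its default value."""
    sig = inspect.signature(fn)
    return {
        p.name: p.default
        for p in sig.parameters.values()
        if p.name not in dummy_inputs
    }

def f(x, y, z=None):
    temp = x * y
    if z is not None:
        temp += z
    return temp

# x and y are provided as dummy inputs, so only z is made concrete.
print(infer_concrete_args(f, {"x": 0, "y": 0}))  # {'z': None}
```

With `infer_concrete_args_from_dummy_inputs=False`, this step would simply be skipped and only the explicitly provided `concrete_args` would be used.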

@HuggingFaceDocBuilderDev
HuggingFaceDocBuilderDev commented Aug 3, 2022

The documentation is not available anymore as the PR was closed or merged.

@michaelbenayoun michaelbenayoun requested a review from sgugger August 3, 2022 15:51
Contributor

@thomasw21 thomasw21 left a comment


Nice! Some small comments!

The dummy inputs needed to handle data-dependent control-flow if `root` is not a
[`~transformers.PreTrainedModel`]. It can also be used when `root` is a
[`~transformers.PreTrainedModel`] to specify custom dummy inputs for a subset or all the model inputs.
infer_concrete_args_from_dummy_inputs (`bool`, *optional*, defaults to `True`):
Contributor


Suggested change
infer_concrete_args_from_dummy_inputs (`bool`, *optional*, defaults to `True`):
infer_concrete_args_from_dummy_inputs (`bool`, defaults to `True`):

Collaborator


No, the argument is optional since it has a default (the user does not have to provide it). Please read the writing documentation guide.

Contributor


I was unaware of this. I mistakenly thought that `bool, *optional*` meant `Optional[bool]` in typing nomenclature. This seems weird to me, as every argument with a default becomes optional then? Anyway, perhaps a chat we can have someplace else.

Collaborator


All arguments that have a default are optional, yes. That is the definition of an optional argument: an argument you do not need to provide. It's not because the typing module decided to (badly) reuse that word for something else that this will change.

if concrete_args is None:
    concrete_args = {}

sig = inspect.signature(root.forward)
if dummy_inputs is not None and infer_concrete_args_from_dummy_inputs:
    concrete_args.update({p.name: p.default for p in sig.parameters.values() if p.name not in dummy_inputs})
Copy link
Contributor


What happens if the parameter doesn't have default?

Member Author


Then tracing will most likely fail afterwards. I added a check to fail early, as you suggested.
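Such a fail-early check could look like the following. This is a sketch under stated assumptions, not the exact code merged in the PR; the helper name `check_dummy_inputs_cover_required_params` is hypothetical:

```python
import inspect

def check_dummy_inputs_cover_required_params(fn, dummy_inputs):
    """Raise early if a parameter of `fn` has no default value and no dummy
    input was provided for it, instead of failing later during tracing."""
    sig = inspect.signature(fn)
    missing = [
        p.name
        for p in sig.parameters.values()
        if p.name not in dummy_inputs and p.default is inspect.Parameter.empty
    ]
    if missing:
        raise ValueError(
            f"Cannot infer concrete args: parameters {missing} have no default "
            "value and no dummy input was provided for them."
        )

def f(x, y, z=None):
    return x * y if z is None else x * y + z

check_dummy_inputs_cover_required_params(f, {"x": 1, "y": 2})  # passes
# check_dummy_inputs_cover_required_params(f, {"x": 1})  # raises ValueError: y has no default
```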

Collaborator

@sgugger sgugger left a comment


Thanks for adding this!

@michaelbenayoun michaelbenayoun merged commit c74befc into huggingface:main Aug 4, 2022
@michaelbenayoun michaelbenayoun deleted the fx_hf_tracer_dummy_inputs branch August 4, 2022 11:29