
[Relax][PyTorch][Bugfix] Update layer_norm converter to support immutable_list for normalized_shape #17330

Merged
1 commit merged into apache:main from mshr-h:fix-layer-norm-arg on Sep 4, 2024

Conversation

mshr-h (Contributor) commented on Sep 3, 2024

With torch==2.4, the normalized_shape argument of torch.nn.functional.layer_norm can be a torch.fx.immutable_collections.immutable_list.
This PR updates the layer_norm converter to cast it to a tuple when normalized_shape is an immutable_list.
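A minimal repro along these lines triggers the failure (the module and input shape are illustrative; the KeyError below shows normalized_shape arriving as [10, 10]):

    import torch
    import torch.nn.functional as F
    from torch import fx

    class LayerNormModel(torch.nn.Module):
        def forward(self, x):
            # With torch==2.4, fx records the [10, 10] list argument as
            # an immutable_list in the traced graph rather than a tuple.
            return F.layer_norm(x, normalized_shape=[10, 10])

    graph_model = fx.symbolic_trace(LayerNormModel())
    # Feeding this graph into TVM's converter then hits the KeyError below:
    #   from tvm.relax.frontend.torch import from_fx
    #   mod = from_fx(graph_model, [((1, 10, 10), "float32")])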

>       verify_model(model, input_info, binding, expected3)

tests/python/relax/test_frontend_from_fx.py:1323: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/relax/test_frontend_from_fx.py:36: in verify_model
    mod = from_fx(graph_model, input_info)
python/tvm/relax/frontend/torch/fx_translator.py:1770: in from_fx
    return TorchFXImporter().from_fx(
python/tvm/relax/frontend/torch/fx_translator.py:1653: in from_fx
    self.env[node] = self.convert_map[func_name](node)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <tvm.relax.frontend.torch.fx_translator.TorchFXImporter object at 0x778d11543da0>, node = layer_norm

    def _layer_norm(self, node: fx.node.Node) -> relax.Var:
        import torch  # type: ignore
        import numpy as np  # type: ignore
    
        x = self.env[node.args[0]]
    
        # functional.layer_norm
        if node.target not in self.named_modules:
            # static or symbolic
            arg = node.args[1]
            if isinstance(arg, tuple):
                value = arg
            else:
                try:
>                   value = self.env[arg]
E                   KeyError: [10, 10]

python/tvm/relax/frontend/torch/fx_translator.py:1084: KeyError
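The fix amounts to a small guard ahead of the existing tuple check; a minimal sketch (the merged change may differ in detail):

    from torch.fx.immutable_collections import immutable_list

    # normalized_shape may be an fx immutable_list (a list subclass)
    # rather than a tuple; cast it so the tuple branch below handles
    # it instead of falling through to the self.env lookup.
    arg = node.args[1]
    if isinstance(arg, immutable_list):
        arg = tuple(arg)
    if isinstance(arg, tuple):
        value = arg
    else:
        value = self.env[arg]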

cc @vinx13 @yongwww @Hzfengsy

@yongwww merged commit e19541d into apache:main on Sep 4, 2024
18 of 19 checks passed
@mshr-h deleted the fix-layer-norm-arg branch on September 5, 2024