diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index b327c26fa..464777bbf 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -2087,7 +2087,7 @@ ] }, "torch.Tensor.nanquantile": { - "Matcher": "GenericMatcher", + "Matcher": "NanquantileMatcher", "paddle_api": "paddle.Tensor.nanquantile", "min_input_args": 1, "args_list": [ diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 68844c824..8849fa6ce 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4078,3 +4078,20 @@ def generate_code(self, kwargs): return ast.parse( "paddle.utils.cpp_extension.setup({})".format(self.kwargs_to_str(kwargs)) ) + + +class NanquantileMatcher(BaseMatcher): + def generate_code(self, kwargs): + new_kwargs = {} + kwargs_change = self.api_mapping["kwargs_change"] + for k in list(kwargs.keys()): + if k in kwargs_change: + new_kwargs[kwargs_change[k]] = kwargs[k] + else: + new_kwargs[k] = kwargs[k] + if "q" in k: + if "tensor" in kwargs[k] and "[" in kwargs[k]: + new_kwargs[k] = "{}.tolist()".format(kwargs[k]) + + code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs)) + return code diff --git a/tests/test_Tensor_nanquantile.py b/tests/test_Tensor_nanquantile.py index a8e024ca5..a220bd260 100644 --- a/tests/test_Tensor_nanquantile.py +++ b/tests/test_Tensor_nanquantile.py @@ -100,3 +100,14 @@ def test_case_7(): unsupport=True, reason="Paddle not support this parameter", ) + + +def test_multi_q(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[float('nan'), 1.02, 2.21, 3.333,30, float('nan')]], dtype=torch.float64) + result = x.nanquantile(q=torch.tensor([0.3, 0.4], dtype=torch.float64), dim=1, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"])