diff --git a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp index 46d7095a9..ac705350b 100644 --- a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp +++ b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp @@ -1150,13 +1150,6 @@ class ConvertAtenUpsampleNearest2dOp : public OpConversionPattern { // TODO: if result have dynamic shape, should lowering to target_mode=scale if (!resultType.hasStaticShape()) return failure(); - if constexpr (std::is_same_v) { - if (!isa(adaptor.getScalesH().getType()) || - !isa(adaptor.getScalesW().getType())) { - // FIXME: check shape inference when scales_h or scales_w is not None. - return failure(); - } - } std::vector byteir_attrs; byteir_attrs.emplace_back(rewriter.getStringAttr("target_mode"), @@ -1183,17 +1176,24 @@ class ConvertAtenUpsampleNearest2dOp : public OpConversionPattern { } }; -// aten.upsample_bilinear2d.vec -class ConvertAtenUpsampleBilinear2dVecOp - : public OpConversionPattern { +// aten.upsample_bilinear2d.vec && aten.upsample_bilinear2d +template +class ConvertAtenUpsampleBilinear2dOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OP::Adaptor; LogicalResult - matchAndRewrite(AtenUpsampleBilinear2dVecOp op, OpAdaptor adaptor, + matchAndRewrite(OP op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value input = adaptor.getInput(); + Value input; + if constexpr (std::is_same_v) { + input = adaptor.getSelf(); + } else { + input = adaptor.getInput(); + } RankedTensorType resultType = cast( - getTypeConverter()->convertType(op.getResult().getType())); + OpConversionPattern::getTypeConverter()->convertType( + op.getResult().getType())); // TODO: if result have dynamic shape, should lowering to target_mode=scale if (!resultType.hasStaticShape()) @@ -1387,8 +1387,13 @@ class ConvertTorchToCustomCall target.addIllegalOp(); patterns.add>( typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); target.addIllegalOp(); - patterns.add(typeConverter, context); + patterns + .add>( + typeConverter, context); } populateMathToCustomCallPattern(target, typeConverter, patterns, diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py b/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py index 8489b590c..b91d6daff 100644 --- a/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py +++ b/frontends/torch-frontend/torch-frontend/python/test/test_torchscript/test_byteir_customcall_ops.py @@ -287,7 +287,7 @@ def test_resize_nearest(): model = torch.jit.trace(UpsampleNearest2dModule1(), inputs) custom_test_helper(model, inputs, "byteir.resize") -class UpsampleBilinear2dModule(torch.nn.Module): +class UpsampleBilinear2dVecModule(torch.nn.Module): def __init__(self): super().__init__() @@ -298,10 +298,10 @@ def forward(self, x): @pytest.mark.mhlo_tools def test_resize_bilinear(): inputs = [tu.randn(3, 3, 10, 20)] - model = torch.jit.trace(UpsampleBilinear2dModule(), inputs) + model = torch.jit.trace(UpsampleBilinear2dVecModule(), inputs) custom_test_helper(model, inputs, "byteir.resize") -class UpsampleBilinear2dModule1(torch.nn.Module): +class UpsampleBilinear2dVecModule1(torch.nn.Module): def __init__(self): super().__init__() @@ -312,7 +312,21 @@ def forward(self, x): @pytest.mark.mhlo_tools def test_resize_bilinear_half_pixel(): inputs = [tu.randn(3, 3, 10, 20)] - model = torch.jit.trace(UpsampleBilinear2dModule1(), inputs) + model = torch.jit.trace(UpsampleBilinear2dVecModule1(), inputs) + custom_test_helper(model, inputs, "byteir.resize") + +class UpsampleBilinear2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + #FIXME: use torch.nn.interpolate to avoid torch.jit.trace + return torch.ops.aten.upsample_bilinear2d(x, (11, 25), True, None, None) + +@pytest.mark.mhlo_tools +def test_resize_bilinear_1(): + inputs = [tu.randn(3, 3, 10, 20)] + model = torch.jit.trace(UpsampleBilinear2dModule(), inputs) custom_test_helper(model, inputs, "byteir.resize") # ==============================================================================