diff --git a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp
index e84199048..46d7095a9 100644
--- a/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp
+++ b/frontends/torch-frontend/torch-frontend/lib/Conversion/ConvertTorchToCustomCall.cpp
@@ -1051,6 +1051,14 @@ class ConvertGenericCustomOp : public OpConversionPattern<OperatorOp> {
           return rewriter.notifyMatchFailure(
               op, "only support constant float input");
         }
+      } else if (isa<Torch::StringType>(op.getOperand(i).getType())) {
+        std::string value;
+        if (matchPattern(op.getOperand(i), m_TorchConstantStr(value))) {
+          bufferAttrs.push_back(rewriter.getStringAttr(value));
+        } else {
+          return rewriter.notifyMatchFailure(op,
+                                             "only support constant str input");
+        }
       } else if (isa<Torch::BaseTensorType>(op.getOperand(i).getType())) {
         bufferArgs.push_back(adaptor.getOperands()[i]);
       } else {
@@ -1244,6 +1252,15 @@ class ConvertMathOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
+
+    SmallVector<Value> bufferArgs;
+    if constexpr (std::is_same_v<AtenOpT, AtenCopysignTensorOp> ||
+                  std::is_same_v<AtenOpT, AtenLdexpTensorOp>) {
+      bufferArgs.push_back(adaptor.getSelf());
+      bufferArgs.push_back(adaptor.getOther());
+    } else {
+      bufferArgs.push_back(adaptor.getSelf());
+    }
 
     Type resultType =
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getResult().getType());
@@ -1259,7 +1276,7 @@ class ConvertMathOp : public OpConversionPattern<AtenOpT> {
                        rewriter.getDictionaryAttr(byteir_attrs));
 
     auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
-        op->getLoc(), TypeRange{resultType}, ValueRange{input},
+        op->getLoc(), TypeRange{resultType}, bufferArgs,
         ArrayRef(attrs));
     rewriter.replaceOp(op, customCallOp->getResults());
     return success();
diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_triton_custom_ops.py b/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_custom_ops.py
similarity index 88%
rename from frontends/torch-frontend/torch-frontend/python/test/test_triton_custom_ops.py
rename to frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_custom_ops.py
index e8ed9b92b..739c5087d 100644
--- a/frontends/torch-frontend/torch-frontend/python/test/test_triton_custom_ops.py
+++ b/frontends/torch-frontend/torch-frontend/python/test/test_fximporter/test_custom_ops.py
@@ -27,16 +27,16 @@ def test_triton_custom_op():
 # ==============================================================================
 
 @torch.library.custom_op("custom::add", mutates_args=())
-def custom_add(a: torch.Tensor, c: int, b: torch.Tensor) -> torch.Tensor:
+def custom_add(a: torch.Tensor, c: int, d: str, b: torch.Tensor) -> torch.Tensor:
     return a + b + c
 
 @torch.library.register_fake("custom::add")
-def custom_add_fake_impl(a, c, b):
+def custom_add_fake_impl(a, c, d, b):
     return a + b
 
 class CustomAddMod(torch.nn.Module):
     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return custom_add(x, 2, y)
+        return custom_add(x, 2, "add", y)
 
 def test_custom_op():
     example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
@@ -45,7 +45,7 @@ def test_custom_op():
     module = compile_dynamo_model(prog, "stablehlo", backend_legal_ops=GENERIC_CUSTOM_OPS+["custom.add"], verbose=True)
     print(module.operation.get_asm())
     assert "stablehlo.custom_call @custom.add" in module.operation.get_asm()
-    assert "byteir_attrs = {custom_attrs = [2]}" in module.operation.get_asm()
+    assert 'byteir_attrs = {custom_attrs = [2, "add"]}' in module.operation.get_asm()
 
 if __name__ == "__main__":
     test_custom_op()
diff --git a/frontends/torch-frontend/torch-frontend/python/test/test_math_custom_ops.py b/frontends/torch-frontend/torch-frontend/python/test/test_math_custom_ops.py
index 2c24c8af4..5810875a1 100644
--- a/frontends/torch-frontend/torch-frontend/python/test/test_math_custom_ops.py
+++ b/frontends/torch-frontend/torch-frontend/python/test/test_math_custom_ops.py
@@ -5,6 +5,7 @@
 def custom_test_helper(model, inputs, custom_op_name):
     mlir_module = compile(model, inputs, "stablehlo", backend_legal_ops=MATH_CUSTOM_OPS)
     mlir_str = mlir_module.operation.get_asm()
+    print(mlir_str)
     compare_str = "stablehlo.custom_call @{}".format(custom_op_name)
     assert compare_str in mlir_str
 
@@ -32,7 +33,7 @@ class CopysignModule(torch.nn.Module):
     def forward(self, x, y):
         return torch.copysign(x, y)
 
-def test_exp2():
+def test_copysign():
     custom_test_helper(CopysignModule(), [torch.rand(3, 4), torch.rand(3, 4)], "math.copysign")
 
 # ==============================================================================