[torch-frontend] fix math custom op's lowering (#488)
as title
qingyunqu authored Dec 2, 2024
1 parent 40ed4f3 commit 6152f75
Showing 3 changed files with 24 additions and 6 deletions.
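In short: ConvertGenericCustomOp now folds constant str operands into the custom call's attribute dictionary (alongside the existing int/float handling), ConvertMathOp forwards both operands of the binary ops AtenCopysignTensorOp and AtenLdexpTensorOp instead of only self, and the tests are extended accordingly (including renaming a copy-pasted test_exp2 to test_copysign).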
File 1 of 3: C++ conversion patterns (ConvertGenericCustomOp, ConvertMathOp)
@@ -1051,6 +1051,14 @@ class ConvertGenericCustomOp : public OpConversionPattern<OperatorOp> {
           return rewriter.notifyMatchFailure(
               op, "only support constant float input");
         }
+      } else if (isa<Torch::StringType>(op.getOperand(i).getType())) {
+        std::string value;
+        if (matchPattern(op.getOperand(i), m_TorchConstantStr(value))) {
+          bufferAttrs.push_back(rewriter.getStringAttr(value));
+        } else {
+          return rewriter.notifyMatchFailure(op,
+                                             "only support constant str input");
+        }
       } else if (isa<Torch::ValueTensorType>(op.getOperand(i).getType())) {
         bufferArgs.push_back(adaptor.getOperands()[i]);
       } else {
@@ -1244,6 +1252,15 @@ class ConvertMathOp : public OpConversionPattern<AtenOpT> {
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Value input = adaptor.getSelf();
+
+    SmallVector<Value> bufferArgs;
+    if constexpr (std::is_same_v<AtenOpT, AtenCopysignTensorOp> ||
+                  std::is_same_v<AtenOpT, AtenLdexpTensorOp>) {
+      bufferArgs.push_back(adaptor.getSelf());
+      bufferArgs.push_back(adaptor.getOther());
+    } else {
+      bufferArgs.push_back(adaptor.getSelf());
+    }
     Type resultType =
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getResult().getType());
@@ -1259,7 +1276,7 @@ class ConvertMathOp : public OpConversionPattern<AtenOpT> {
         rewriter.getDictionaryAttr(byteir_attrs));

     auto customCallOp = rewriter.create<stablehlo::CustomCallOp>(
-        op->getLoc(), TypeRange{resultType}, ValueRange{input},
+        op->getLoc(), TypeRange{resultType}, bufferArgs,
         ArrayRef<NamedAttribute>(attrs));
     rewriter.replaceOp(op, customCallOp->getResults());
     return success();
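The second and third hunks are the actual fix: copysign and ldexp are binary ops, so the lowering must hand both self and other to the stablehlo.custom_call rather than the single ValueRange{input}. A minimal eager-mode sketch of the two ops' semantics (plain PyTorch, not part of this commit):

import torch

x = torch.tensor([1.0, -2.0, 3.0])
y = torch.tensor([-1.0, 1.0, -1.0])

# copysign(x, y) keeps the magnitude of x and takes the sign of y,
# so it genuinely consumes both operands.
print(torch.copysign(x, y))  # tensor([-1.,  2., -3.])

# ldexp(x, y) computes x * 2**y; it also needs both operands.
print(torch.ldexp(x, y))     # tensor([ 0.5000, -4.0000,  1.5000])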
File 2 of 3: Python test for generic custom ops (custom::add)
@@ -27,16 +27,16 @@ def test_triton_custom_op():
 # ==============================================================================

 @torch.library.custom_op("custom::add", mutates_args=())
-def custom_add(a: torch.Tensor, c: int, b: torch.Tensor) -> torch.Tensor:
+def custom_add(a: torch.Tensor, c: int, d: str, b: torch.Tensor) -> torch.Tensor:
     return a + b + c

 @torch.library.register_fake("custom::add")
-def custom_add_fake_impl(a, c, b):
+def custom_add_fake_impl(a, c, d, b):
     return a + b

 class CustomAddMod(torch.nn.Module):
     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return custom_add(x, 2, y)
+        return custom_add(x, 2, "add", y)

 def test_custom_op():
     example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
@@ -45,7 +45,7 @@ def test_custom_op():
     module = compile_dynamo_model(prog, "stablehlo", backend_legal_ops=GENERIC_CUSTOM_OPS+["custom.add"], verbose=True)
     print(module.operation.get_asm())
     assert "stablehlo.custom_call @custom.add" in module.operation.get_asm()
-    assert "byteir_attrs = {custom_attrs = [2]}" in module.operation.get_asm()
+    assert 'byteir_attrs = {custom_attrs = [2, "add"]}' in module.operation.get_asm()

 if __name__ == "__main__":
     test_custom_op()
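For reference, a quick eager-mode check of the updated op, a sketch assuming the custom_add registration from the diff above has been executed in the same session: the new str argument is folded into byteir_attrs at compile time and must not affect the numeric result.

import torch

x = torch.randn(10, 10)
y = torch.randn(10, 10)

# "add" travels as a StringAttr in custom_attrs; numerically the op is
# still a + b + c.
out = custom_add(x, 2, "add", y)
assert torch.allclose(out, x + y + 2)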
File 3 of 3: Python test for math custom ops
@@ -5,6 +5,7 @@
 def custom_test_helper(model, inputs, custom_op_name):
     mlir_module = compile(model, inputs, "stablehlo", backend_legal_ops=MATH_CUSTOM_OPS)
     mlir_str = mlir_module.operation.get_asm()
+    print(mlir_str)
     compare_str = "stablehlo.custom_call @{}".format(custom_op_name)
     assert compare_str in mlir_str

@@ -32,7 +33,7 @@ class CopysignModule(torch.nn.Module):
     def forward(self, x, y):
         return torch.copysign(x, y)

-def test_exp2():
+def test_copysign():
     custom_test_helper(CopysignModule(), [torch.rand(3, 4), torch.rand(3, 4)], "math.copysign")

 # ==============================================================================