[TorchToLinalg] address a dtype mismatch in aten.multinomial lowering #3630

Merged 6 commits on Aug 20, 2024
15 changes: 14 additions & 1 deletion lib/Conversion/TorchToLinalg/Random.cpp
@@ -287,8 +287,16 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
 
     Value initSum = rewriter.create<arith::ConstantOp>(
         loc, f64Ty, rewriter.getF64FloatAttr(0.0));
+    int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
+    if (srcWidth > 64)
+      op->emitWarning("Op bitwidth will be truncated from " +
+                      std::to_string(srcWidth) + " bits to 64 bits.");
     auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       Value input = payloadArgs[0];
+      if (srcWidth < 64)
+        input = b.create<arith::ExtFOp>(loc, f64Ty, input);
+      if (srcWidth > 64)
+        input = b.create<arith::TruncFOp>(loc, f64Ty, input);
       Value result = payloadArgs[1];
       Value nextSum = b.create<arith::AddFOp>(loc, input, result);
       b.create<linalg::YieldOp>(loc, nextSum);
@@ -310,7 +318,7 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
 
       // compute cdf in loop
       Value initCdf = b.create<tensor::EmptyOp>(
-          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
+          loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
       Value cdf =
           b.create<scf::ForOp>(
               loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
@@ -330,6 +338,11 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
                   ind = ValueRange{jIndex, iIndex};
                 }
                 Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
+                if (srcWidth < 64)
+                  currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
+                if (srcWidth > 64)
+                  currWeight =
+                      b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
                 Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
                 Value currCum =
                     b.create<scf::IfOp>(
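For reference, here is a minimal Python sketch (an illustration only, not the TorchToLinalg lowering itself) of the arithmetic the patched code now produces: the weight sum and the CDF are accumulated in float64 regardless of the input element type, which is what the ExtFOp/TruncFOp casts above arrange. The helper name multinomial_reference and the clamp at the end are assumptions made for this sketch.

import torch

def multinomial_reference(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Promote to float64 before reducing, mirroring the ExtFOp/TruncFOp casts.
    w64 = weights.to(torch.float64)
    total = w64.sum(-1, keepdim=True)
    # The CDF is built entirely in float64, matching the f64Ty init tensor above.
    cdf = torch.cumsum(w64 / total, dim=-1)
    # Inverse-CDF sampling with replacement.
    u = torch.rand(*weights.shape[:-1], num_samples, dtype=torch.float64)
    # Clamp guards against u landing past the last CDF entry due to rounding.
    return torch.searchsorted(cdf, u).clamp_(max=weights.shape[-1] - 1)

# Example: a 2-D batch of float32 distributions, 1024 draws per row.
samples = multinomial_reference(torch.rand(10, 100, dtype=torch.float32), 1024)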
6 changes: 4 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2311,6 +2311,8 @@
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
+    "MultinomialModule2D_basic",
+    "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -2339,6 +2341,8 @@
     "MoveDimIntNegativeIndexModule_basic",
     "ReduceL3NormKeepDimModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    # incorrect shape generated by torch.onnx.export (needs an unsqueeze)
+    "MultinomialModule_basic",
     # Failure - onnx_export
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -2842,8 +2846,6 @@
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
    "MaskedFillTensorFloatValueModule_basic",
-    "MultinomialModule_basic",
-    "MultinomialModule2D_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
     "ReduceAnyFloatModule_basic",
51 changes: 39 additions & 12 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class MultinomialModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
+    assert len(sizes) == 1 or len(sizes) == 2
+    init = tu.rand(*sizes).to(dtype=torchdtype).abs()
+    normalized = init / (init.sum(-1, True, dtype=torchdtype))
+    return normalized
 
+
+class MultinomialBase(torch.nn.Module):
+    def _forward(self, x):
+        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
+        return a
+
+
+class MultinomialModule(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -389,20 +399,36 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        return self._forward(x).mean(dtype=torch.double)
 
 
 @register_test_case(module_factory=lambda: MultinomialModule())
 def MultinomialModule_basic(module, tu: TestUtils):
-    x = tu.rand(100).double()
+    x = generate_sample_distr([100], torch.float64, tu)
     module.forward(x)
 
 
-class MultinomialModule2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class MultinomialModule2DF32(MultinomialBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
+
+
+@register_test_case(module_factory=lambda: MultinomialModule2DF32())
+def MultinomialModule2D_F32(module, tu: TestUtils):
+    x = generate_sample_distr([10, 100], torch.float32, tu)
+    module.forward(x)
+
+
+class MultinomialModule2D(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -411,13 +437,14 @@ def __init__(self):
             None,
             ([-1, -1], torch.float64, True),
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
 
 
 @register_test_case(module_factory=lambda: MultinomialModule2D())
 def MultinomialModule2D_basic(module, tu: TestUtils):
-    x = tu.rand(10, 100).double()
+    x = generate_sample_distr([10, 100], torch.float64, tu)
     module.forward(x)
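As a rough illustration of why the tests reduce the samples with mean(dtype=torch.double): with replacement and a large number of draws, the mean of the sampled category indices converges to the expected index under the input distribution, so eager PyTorch and the lowered code should agree to within sampling noise. This sketch only mirrors the spirit of the tests; the tolerance below is an assumption for the sketch, not a value taken from the e2e harness.

import torch

torch.manual_seed(0)
# Build a normalized 1-D distribution, in the spirit of generate_sample_distr.
weights = torch.rand(100, dtype=torch.float32)
weights = weights / weights.sum(-1, keepdim=True)

samples = torch.ops.aten.multinomial(weights, 1024 * 1024, replacement=True)
empirical_mean = samples.mean(dtype=torch.double)
# Expected value of the category index under the input distribution.
expected_mean = (torch.arange(100, dtype=torch.double) * weights.double()).sum()
# 0.5 is an arbitrary slack for sampling noise in this sketch.
assert torch.isclose(empirical_mean, expected_mean, atol=0.5, rtol=0.0)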