diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp
index 63eebb8a2806..aa4ec91d7da5 100644
--- a/lib/Conversion/TorchToLinalg/Random.cpp
+++ b/lib/Conversion/TorchToLinalg/Random.cpp
@@ -287,8 +287,16 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp>
 
     Value initSum = rewriter.create<arith::ConstantOp>(
         loc, f64Ty, rewriter.getF64FloatAttr(0.0));
+    int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
+    if (srcWidth > 64)
+      op->emitWarning("Op bitwidth will be truncated from " +
+                      std::to_string(srcWidth) + " bits to 64 bits.");
     auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
       Value input = payloadArgs[0];
+      if (srcWidth < 64)
+        input = b.create<arith::ExtFOp>(loc, f64Ty, input);
+      if (srcWidth > 64)
+        input = b.create<arith::TruncFOp>(loc, f64Ty, input);
       Value result = payloadArgs[1];
       Value nextSum = b.create<arith::AddFOp>(loc, input, result);
       b.create<linalg::YieldOp>(loc, nextSum);
@@ -310,7 +318,7 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp>
 
           // compute cdf in loop
           Value initCdf = b.create<tensor::EmptyOp>(
-              loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
+              loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
           Value cdf =
               b.create<scf::ForOp>(
                   loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
@@ -330,6 +338,11 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp>
               ind = ValueRange{jIndex, iIndex};
             }
             Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
+            if (srcWidth < 64)
+              currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
+            if (srcWidth > 64)
+              currWeight =
+                  b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
             Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
             Value currCum = b.create<arith::AddFOp>(
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e487c12a345f..48b077853fc4 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2311,6 +2311,8 @@
     "ElementwiseLog2IntModule_basic",
     "ElementwiseFminModule_basic",
     "ElementwiseFmaxModule_basic",
+    "MultinomialModule2D_basic",
+    "MultinomialModule2D_F32",
     "PixelShuffleModuleStaticRank4Float32_basic",
     "ReflectionPad1dModule2dInput_Right",
     "ReflectionPad1dModule2dInput_basic",
@@ -2339,6 +2341,8 @@
     "MoveDimIntNegativeIndexModule_basic",
     "ReduceL3NormKeepDimModule_basic",
     "ViewSizeFromOtherTensor_basic",
+    # incorrect shape generated by torch.onnx.export (needs an unsqueeze)
+    "MultinomialModule_basic",
     # Failure - onnx_export
     "AdaptiveAvgPool1dGeneralDynamic_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
@@ -2842,8 +2846,6 @@
     "ElementwiseUnaryIntModule_basic",
     "ElementwiseFloatTensorGtIntTensorModule_basic",
     "MaskedFillTensorFloatValueModule_basic",
-    "MultinomialModule_basic",
-    "MultinomialModule2D_basic",
     "NativeDropoutTrainModule_basic",
     "NativeDropoutTrainStaticShapeModule_basic",
     "ReduceAnyFloatModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
index e8e4275730ca..24d5c7be025c 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class MultinomialModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
+    assert len(sizes) == 1 or len(sizes) == 2
+    init = tu.rand(*sizes).to(dtype=torchdtype).abs()
+    normalized = init / (init.sum(-1, True, dtype=torchdtype))
+    return normalized
+
+
+class MultinomialBase(torch.nn.Module):
+    def _forward(self, x):
+        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
+        return a
+
 
+class MultinomialModule(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -389,20 +399,36 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        return self._forward(x).mean(dtype=torch.double)
 
 
 @register_test_case(module_factory=lambda: MultinomialModule())
 def MultinomialModule_basic(module, tu: TestUtils):
-    x = tu.rand(100).double()
+    x = generate_sample_distr([100], torch.float64, tu)
     module.forward(x)
 
 
-class MultinomialModule2D(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+class MultinomialModule2DF32(MultinomialBase):
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
+
+@register_test_case(module_factory=lambda: MultinomialModule2DF32())
+def MultinomialModule2D_F32(module, tu: TestUtils):
+    x = generate_sample_distr([10, 100], torch.float32, tu)
+    module.forward(x)
+
+
+class MultinomialModule2D(MultinomialBase):
     @export
     @annotate_args(
         [
@@ -411,13 +437,14 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
-        return a.mean(dtype=torch.double)
+        # note: this should really call mean(-1)
+        # for some reason, doing this causes a torchscript numerics error?
+        return self._forward(x).mean(dtype=torch.double)
 
 
 @register_test_case(module_factory=lambda: MultinomialModule2D())
 def MultinomialModule2D_basic(module, tu: TestUtils):
-    x = tu.rand(10, 100).double()
+    x = generate_sample_distr([10, 100], torch.float64, tu)
     module.forward(x)
 
 
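For context on the Random.cpp change: the lowering previously accumulated the weight sum and CDF in the input's element type, so non-f64 weights lost precision over many accumulations. The patch widens every weight to f64 first (`arith.extf` for narrower types, `arith.truncf` plus a warning for wider ones). A rough Python model of the resulting inverse-CDF sampling, purely for illustration (`multinomial_f64` is a hypothetical name, and the real lowering emits linalg/scf ops rather than cumsum/searchsorted):

```python
import torch


def multinomial_f64(weights: torch.Tensor, num_samples: int) -> torch.Tensor:
    # Widen to f64 up front, mirroring the extf/truncf the lowering now
    # inserts before its sum and CDF computations.
    w = weights.to(torch.float64)
    # Normalized cumulative distribution over the last dimension.
    cdf = torch.cumsum(w / w.sum(-1, keepdim=True), dim=-1)
    # Inverse-CDF sampling: draw uniforms and find the first category whose
    # cumulative mass exceeds the draw (the job of the lowering's scf.for scan).
    u = torch.rand(*weights.shape[:-1], num_samples, dtype=torch.float64)
    return torch.searchsorted(cdf, u)


print(multinomial_f64(torch.tensor([0.1, 0.2, 0.7], dtype=torch.float32), 8))
```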
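And a plain-PyTorch sanity check of what the reworked tests assert: `generate_sample_distr` builds rows that sum to 1, and with replacement and ~1M draws per row the mean sampled index should approach each row's expectation. The `rtol` below is an illustrative tolerance, not the harness's comparison, and `.mean(-1)` is fine here since the TorchScript issue noted in the test comments does not apply outside the harness:

```python
import torch

# Same construction as the patch's generate_sample_distr helper:
# nonnegative weights, normalized so each row sums to 1.
init = torch.rand(10, 100, dtype=torch.float32).abs()
x = init / init.sum(-1, keepdim=True, dtype=torch.float32)

samples = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)

# The mean sampled index per row should approach sum_i(i * p_i).
expected = (torch.arange(100, dtype=torch.float64) * x.double()).sum(-1)
print(torch.allclose(samples.double().mean(-1), expected, rtol=1e-2))
```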