Skip to content

Commit

Permalink
Fix dtype mismatch in arith ops, improve tests and onnx export path
Browse files Browse the repository at this point in the history
  • Loading branch information
zjgarvey committed Aug 13, 2024
1 parent d3695a9 commit dbfb527
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 16 deletions.
13 changes: 12 additions & 1 deletion lib/Conversion/TorchToLinalg/Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,15 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {

Value initSum = rewriter.create<arith::ConstantOp>(
loc, f64Ty, rewriter.getF64FloatAttr(0.0));
int64_t srcWidth = cast<mlir::FloatType>(elemTy).getWidth();
if (srcWidth > 64)
op->emitWarning("Op bitwidth will be truncated from " + std::to_string(srcWidth) + " bits to 64 bits.");
auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
Value input = payloadArgs[0];
if (srcWidth < 64)
input = b.create<arith::ExtFOp>(loc, f64Ty, input);
if (srcWidth > 64)
input = b.create<arith::TruncFOp>(loc, f64Ty, input);
Value result = payloadArgs[1];
Value nextSum = b.create<arith::AddFOp>(loc, input, result);
b.create<linalg::YieldOp>(loc, nextSum);
Expand All @@ -310,7 +317,7 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {

// compute cdf in loop
Value initCdf = b.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), elemTy);
loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty);
Value cdf =
b.create<scf::ForOp>(
loc, cstZero, numCategories, cstOne, ValueRange{initCdf},
Expand All @@ -330,6 +337,10 @@ class ConvertAtenMultinomialOp : public OpConversionPattern<AtenMultinomialOp> {
ind = ValueRange{jIndex, iIndex};
}
Value currWeight = b.create<tensor::ExtractOp>(loc, self, ind);
if (srcWidth < 64)
currWeight = b.create<arith::ExtFOp>(loc, f64Ty, currWeight);
if (srcWidth > 64)
currWeight = b.create<arith::TruncFOp>(loc, f64Ty, currWeight);
Value currMass = b.create<arith::DivFOp>(loc, currWeight, sum);
Value currCum =
b.create<scf::IfOp>(
Expand Down
6 changes: 4 additions & 2 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,8 @@
"ElementwiseLog2IntModule_basic",
"ElementwiseFminModule_basic",
"ElementwiseFmaxModule_basic",
"MultinomialModule2D_F32",
"MultinomialModule2D_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"ReflectionPad1dModule2dInput_Right",
"ReflectionPad1dModule2dInput_basic",
Expand Down Expand Up @@ -2328,6 +2330,8 @@
"MoveDimIntNegativeIndexModule_basic",
"ReduceL3NormKeepDimModule_basic",
"ViewSizeFromOtherTensor_basic",
# incorrect shape generated by torch.onnx.export (needs an unsqueeze)
"MultinomialModule_basic",
# Failure - onnx_export
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
Expand Down Expand Up @@ -2826,8 +2830,6 @@
"ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"MaskedFillTensorFloatValueModule_basic",
"MultinomialModule_basic",
"MultinomialModule2D_basic",
"NativeDropoutTrainModule_basic",
"NativeDropoutTrainStaticShapeModule_basic",
"ReduceAnyFloatModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def convert_onnx(model, inputs):
for index, arg in enumerate(inputs):
shape = map(lambda d: d if d >= 0 else 1, arg.shape)
shape = tuple(shape)
examples.append(torch.zeros(size=shape, dtype=arg.dtype))
examples.append(torch.ones(size=shape, dtype=arg.dtype))

input_name = "input_{}".format(index)
input_names.append(input_name)
Expand Down Expand Up @@ -150,6 +150,7 @@ def __init__(

def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any:
example_args = convert_annotations_to_placeholders(program.forward)
print(example_args)
onnx_module = convert_onnx(program, example_args)
backend_module = _module_lowering(
verbose, OutputType.get(self.output_type), onnx_module
Expand Down
47 changes: 35 additions & 12 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,20 @@ def BernoulliPModule_basic(module, tu: TestUtils):
# ==============================================================================


class MultinomialModule(torch.nn.Module):
def __init__(self):
super().__init__()
def generate_sample_distr(sizes: list[int], torchdtype, tu: TestUtils):
    """Build a random 1-D or 2-D tensor of the given dtype whose last
    dimension is normalized to sum to 1, i.e. a valid sampling distribution
    for `aten.multinomial`."""
    assert len(sizes) in (1, 2)
    # Non-negative raw weights in the requested dtype.
    weights = tu.rand(*sizes).to(dtype=torchdtype).abs()
    # Normalize each row (last dim) so the weights form a probability vector.
    row_totals = weights.sum(-1, True, dtype=torchdtype)
    return weights / row_totals


class MultinomialBase(torch.nn.Module):
    """Shared sampling step for the multinomial e2e tests: draw a large,
    fixed number of samples (with replacement) from the distribution `x`."""

    def _forward(self, x):
        # 1024 * 1024 draws keeps the sample mean statistically stable
        # across runs, which is what the concrete test cases compare.
        num_samples = 1024 * 1024
        return torch.ops.aten.multinomial(x, num_samples, replacement=True)


class MultinomialModule(MultinomialBase):
@export
@annotate_args(
[
Expand All @@ -389,20 +399,34 @@ def __init__(self):
]
)
def forward(self, x):
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
return a.mean(dtype=torch.double)
return self._forward(x).mean(dtype=torch.double)


@register_test_case(module_factory=lambda: MultinomialModule())
def MultinomialModule_basic(module, tu: TestUtils):
    # Fix: removed the dead store `x = tu.rand(100).double()` that was
    # immediately overwritten below — besides being dead code, it needlessly
    # consumed RNG state and could perturb test reproducibility.
    # Use a properly normalized 100-way distribution as multinomial input.
    x = generate_sample_distr([100], torch.float64, tu)
    module.forward(x)


class MultinomialModule2D(torch.nn.Module):
def __init__(self):
super().__init__()
class MultinomialModule2DF32(MultinomialBase):
    """float32 variant of the 2-D multinomial test module: samples each
    row's distribution, then reduces the samples to a per-row mean."""

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        samples = self._forward(x)
        # Collapse the sample dimension to a float32 mean per row so the
        # e2e harness compares a stable statistic rather than raw samples.
        return samples.mean(-1, dtype=torch.float32)


@register_test_case(module_factory=lambda: MultinomialModule2DF32())
def MultinomialModule2D_F32(module, tu: TestUtils):
    # Ten rows, each a normalized 100-way float32 distribution.
    probs = generate_sample_distr([10, 100], torch.float32, tu)
    module.forward(probs)


class MultinomialModule2D(MultinomialBase):
@export
@annotate_args(
[
Expand All @@ -411,13 +435,12 @@ def __init__(self):
]
)
def forward(self, x):
a = torch.ops.aten.multinomial(x, 1024 * 1024, replacement=True)
return a.mean(dtype=torch.double)
return self._forward(x).mean(-1, dtype=torch.double)


@register_test_case(module_factory=lambda: MultinomialModule2D())
def MultinomialModule2D_basic(module, tu: TestUtils):
    # Fix: removed the dead store `x = tu.rand(10, 100).double()` that was
    # immediately overwritten below — dead code that also consumed RNG state
    # and could perturb test reproducibility.
    # Ten rows, each a normalized 100-way float64 distribution.
    x = generate_sample_distr([10, 100], torch.float64, tu)
    module.forward(x)


Expand Down

0 comments on commit dbfb527

Please sign in to comment.