
TorchToTMTensor CumSum Lowering Only Supports Constant Axis #2845

Open
renxida opened this issue Jan 31, 2024 · 1 comment

renxida commented Jan 31, 2024

Here,

op, "unimplemented: only constant dim value is supported");

the TorchToTMTensor lowering checks that dim is a TorchConstantInt.

However, ONNX allows axis to be a tensor input, so after import the dim is generally variable rather than a constant.
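For reference, a minimal NumPy sketch of the ONNX CumSum semantics at issue (function name is illustrative): the axis arrives as a runtime tensor, not a compile-time attribute, which is why the importer cannot hand the lowering a constant.

```python
import numpy as np

def onnx_cumsum(x: np.ndarray, axis: np.ndarray) -> np.ndarray:
    # Per the ONNX spec, axis is a 0-d integer tensor input,
    # so a consumer cannot assume it is a compile-time constant.
    return np.cumsum(x, axis=int(axis))

x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
axis = np.array(1, dtype=np.int32)  # axis supplied as a tensor
print(onnx_cumsum(x, axis).tolist())  # -> [[1, 3, 6], [4, 9, 15]]
```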

This causes (at least part of) the lowering failure described in nod-ai/iree-amd-aie#103:

opt-125M.fp32.onnx.mlir:231:12: error: failed to legalize operation 'torch.aten.cumsum' that was explicitly marked illegal
    %228 = torch.operator "onnx.CumSum"(%226, %227) : (!torch.vtensor<[1,6],si64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[1,6],si64>
           ^
opt-125M.fp32.onnx.mlir:231:12: note: see current operation: %287 = "torch.aten.cumsum"(%18, %286, %16) : (!torch.vtensor<[1,6],si64>, !torch.int, !torch.int) -> !torch.vtensor<[1,6],si64> loc("opt-125M.fp32.onnx.mlir":231:12)

renxida commented Jan 31, 2024

It looks like the main reason TorchToTMTensor requires a constant dim is that it uses the dim input to compute the shape of an intermediate value and the output.

SmallVector<Value> accSizes(sizes);
accSizes.erase(accSizes.begin() + dim);
SmallVector<int64_t> accStatic(
    makeShapeTorchCompatible(resultType.getShape()));
accStatic.erase(accStatic.begin() + dim);
Value acc = createZeroInitTensor(rewriter, loc, accSizes, elementType);
Type accType =
    RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
acc = rewriter.create<tensor::CastOp>(loc, accType, acc);

But maybe we can just use the output shape of the AtenCumsumOp being lowered and skip that.
