[Tcp] Add boilerplate for TCP dialect #1375
Conversation
Posting this here for persistence: #1223 -- here is an example of adding the e2e framework for mhlo; we should do something similar for TCP. See the rationale here for the critical importance of e2e testing in this space: https://github.com/llvm/torch-mlir/blob/main/docs/architecture.md#why-so-much-end-to-end-testing
nit: I recommend renaming the commit prefix to remove the all caps, which looks a little "shouty". Just [tcp] as a required prefix should be sufficient to identify every push.
);

let results = (outs
  Tcp_Tensor:$out
What is the relationship between operands and results of unary and binary ops? Do we want them to be exactly the same or just compatible, for some definition of "compatible"?
In MHLO, we allow "compatible" operand and result types via isCompatibleReturnTypes and HLO_CompatibleOperandsAndResultType which are both implemented via isCompatibleForHloTypeInference, which roughly speaking allows things like:
func @dynamism(%arg0: tensor<?xf32>, %arg1: tensor<1xf32>) {
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "mhlo.add"(%arg0, %arg0) : (tensor<?xf32>, tensor<?xf32>) -> tensor<1xf32>
%2 = "mhlo.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<1xf32>) -> tensor<?xf32>
%3 = "mhlo.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<1xf32>) -> tensor<1xf32>
%4 = "mhlo.add"(%arg1, %arg0) : (tensor<1xf32>, tensor<?xf32>) -> tensor<?xf32>
%5 = "mhlo.add"(%arg1, %arg0) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
%6 = "mhlo.add"(%arg1, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<?xf32>
%7 = "mhlo.add"(%arg1, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
return
}
I wasn't around when this functionality was added to MHLO, but I've been told that it was motivated by similar functionality in the TF dialect, which itself was motivated by the desire to support progressive shape refinement during --tf-shape-inference. If MHLO didn't support these relaxed compatibility checks, then --tf-shape-inference would need to introduce casts when doing shape refinement, which was deemed undesirable.
Are we planning to do shape refinement in TCP, or is that expected to be the job of higher-level layers? If it's the former (which I think could be the right choice), then does anyone have a specific shape inference facility in mind that we could use? In the recent ODM, I remember someone said that there are better tools than --tf-shape-inference (which also isn't available upstream, so we cannot use it anyway in TCP), and I'm eager to learn more.
There are a variety of ways to perform shape inference. Linalg does a very good job despite the "exact match of the static type" requirement by having patterns that push the casts around to a fixed-point, which typically results in them being absorbed somewhere.
I wrote up some more general shape inference thoughts elsewhere, though I think a lot of what is described there doesn't apply to TCP/MHLO-like design points.
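As a minimal illustration of the cast-absorption approach described above (a sketch using upstream MLIR's tensor.cast, not code from this PR):

```mlir
// Shape refinement produced a more static type; a cast bridges the gap:
%0 = tensor.cast %arg0 : tensor<?xf32> to tensor<4xf32>
// Canonicalization patterns then push such casts through producers and
// consumers to a fixed-point, where they are typically absorbed or folded.
```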
Thanks for those points. That is very useful to know.
Will address shape inference in a separate PR.
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}

def Tcp_MatmulOp : Tcp_Op<"matmul", [NoSideEffect]> {
I think it would be good to implement verification + type inference for this op right away. This will force us to face interesting questions, in addition to the 2D/3D item which is already discussed. E.g.:
- Do we allow result element type to be different from operand element type? I think we should, if we want to support quantization well.
- If yes, what do we do with type inference? Do we: a) say that matmul doesn't support it, b) add something like a preferred_element_type attribute to enable type inference, or c) infer result element type equal to operand element type and then work around via isCompatibleReturnTypes?
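If option (b) were chosen, a TableGen sketch might look like the following (the attribute name and shape are hypothetical, loosely modeled on similar mechanisms elsewhere; not part of this PR):

```tablegen
def Tcp_MatmulOp : Tcp_Op<"matmul", [NoSideEffect]> {
  let arguments = (ins
    Tcp_Tensor:$in1,
    Tcp_Tensor:$in2,
    // Hypothetical: when present, drives result element type inference;
    // when absent, the result element type matches the operands.
    OptionalAttr<TypeAttr>:$preferred_element_type
  );
  let results = (outs Tcp_Tensor:$out);
}
```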
+1.
Also, why not have separate matmul and batch_matmul ops? My experience is that with these "variant" ops it's much easier to check isa<BatchMatmulOp> than if (getOperand(0).getType().cast<RankedTensorType>().getRank() == 3).
Do we allow result element type to be different from operand element type? I think we should, if we want to support quantization well.
That's right, although I don't want to make the design choices w.r.t. quantization at this point, while that is still being worked on. We plan to have a separate discussion regarding quantization and can address these once that is finalized.
why not have separate matmul and batch_matmul ops?
I'm working on a document to summarize the different design choices for Matmul (2 ops vs 1 op, 3D and 2D cases, etc.). So, I'm removing matmul op from this PR. Will send a separate PR for matmul later.
namespace torch {
namespace tcp {

std::unique_ptr<Pass> createConvertTcpToTosaPass();
This is auto-generated, right? I'm looking here: https://github.com/llvm/llvm-project/blob/e854c17b02f8cd82a303d223ba5f3b0d87579cd7/mlir/tools/mlir-tblgen/PassGen.cpp#L127
I tried removing this and it doesn't auto-generate for me. Not sure if I'm missing something in the cmake files to do this.
Can you check if other passes have this? Maybe you need to do something extra to put it in the torch::tcp namespace?
I think what you pointed to is only generating code that goes into the Passes.h.inc file. For example, that file contains the call to the pass, like the following:
inline void registerConvertTcpToTosa() {
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::torch::tcp::createConvertTcpToTosaPass();
});
}
I think that is what is being generated here.
@silvasean Do you know if this file could be generated automatically?
createConvertTcpToTosaPass needs to be declared here. It is not autogenerated. See e.g. https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Conversion/TorchToArith/TorchToArith.h
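For reference, a minimal self-contained sketch of that hand-written declaration (the Pass class here is an illustration-only stub; the real one lives in mlir/Pass/Pass.h, and the stub definition stands in for the one in TcpToTosa.cpp):

```cpp
#include <memory>

namespace mlir {
class Pass {};  // illustration-only stub; the real class is mlir::Pass

namespace torch {
namespace tcp {

// This declaration must be written by hand; it is not autogenerated.
std::unique_ptr<Pass> createConvertTcpToTosaPass();

// Stub definition so this sketch is self-contained (the real definition
// lives in the pass's .cpp file).
std::unique_ptr<Pass> createConvertTcpToTosaPass() {
  return std::make_unique<Pass>();
}

} // namespace tcp
} // namespace torch
} // namespace mlir
```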
In case anyone is wondering, I switched base branches to main and back (to mlir-tcp) to flush the extraneous commit history that was showing earlier, likely due to a rebase on main (while the base branch is mlir-tcp). When rebasing, we may want to rebase mlir-tcp on main first, then this branch tcp_1 on mlir-tcp, which should hopefully do it cleanly.
Thanks for the review. Addressed the comments. PTAL.
I'm moving some of the stuff discussed here to follow up PRs:
- Linalg lowering
- e2e tests
- matmul op
- shape inference
Thanks for the pointers regarding those.
Minor nits inline, LGTM otherwise.
@@ -0,0 +1,13 @@
#ifndef TORCH_MLIR_DIALECTS_CONVERSION_PASSES_H
All these files need license headers.
Added
namespace mlir {

#define GEN_PASS_DECL_CONVERTTCPTOTOSA
Why do we need this GEN_PASS_DECL_CONVERTTCPTOTOSA? I don't think we do this in the main Torch-MLIR codebase. I think we just have a single PassDetail.h which only needs to be included by the .cpp files.
Aah okay. Thanks for pointing that out.
@silvasean GEN_PASS_CLASSES has been recently deprecated and will be removed in the near future. GEN_PASS_DECL_PASSNAME is the recommended way to approach this moving forward, e.g. see this MLIR-HLO commit for an example.
By the way, functions like createConvertTcpToTosaPass can be autogenerated, but only if you don't specify let constructor in Passes.td, e.g. see this StableHLO PR for an example.
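To illustrate, a Passes.td entry along these lines (the pass name and anchor op here are assumptions, not taken from this PR) gets its create function autogenerated because it omits let constructor:

```tablegen
def ConvertTcpToTosa : Pass<"convert-tcp-to-tosa", "func::FuncOp"> {
  let summary = "Lower TCP ops to TOSA";
  // No `let constructor = ...;` here, so tablegen also emits
  // createConvertTcpToTosaPass() alongside the registration code.
  let dependentDialects = ["tosa::TosaDialect"];
}
```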
Oh neat. I wasn't aware of that.
By the way, functions like createConvertTcpToTosaPass can be autogenerated, but only if you don't specify let constructor in Passes.td, e.g. see openxla/stablehlo#176 for an example.
Good to know that. Thanks.
Force-pushed from b3e8a41 to e626771
Reviewed prose and .td files. Sean will undoubtedly have a better perspective on the conventions for .h/CMakeLists.txt files. At a glance, everything checks out.
// Tcp Type Definitions.
//===----------------------------------------------------------------------===//

def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
The spec draft says AnySignlessIntegerOrIndex. I think AnySignlessInteger is a more recent development, so the spec needs to be updated?
Yes, good point. I have updated the spec to reflect this (w/o index for now). When we add ops that need index, we can update it.
def Tcp_Dialect : Dialect {
  let name = "tcp";
  let cppNamespace = "::mlir::torch::tcp";
::mlir::tcp perhaps, given the plans to be applicable more widely than to PyTorch?
Given that we are bootstrapping TCP under TorchMLIR externals, I assumed it has to use the ::mlir::torch::tcp namespace.
@silvasean Is it okay to use ::mlir::tcp for this instead?
Given the plans to make it more widely applicable, mlir::tcp seems fine.
Thanks for clarifying. Updated it to mlir::tcp.
def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>;

//===----------------------------------------------------------------------===//
// Tcp Operator.
"Tcp Operators"?
Updated it to "Tcp Ops Base", which is more appropriate here.
Super pumped to see this PR landing!! 🎉 🎉 🎉
* Initial boilerplate for TCP with a dummy op.
* Conditional flag to enable TCP.
This PR adds the initial boilerplate necessary for the TCP dialect. It includes a -DTORCH_MLIR_DIALECTS_ENABLE_TCP flag to enable TCP, which will be OFF by default.
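A configuration sketch (the flag name comes from this PR; the generator choice and directory layout are assumptions):

```shell
# Enable the TCP dialect; it is OFF by default.
cmake -GNinja -B build -DTORCH_MLIR_DIALECTS_ENABLE_TCP=ON <source-dir>
```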