|
| 1 | +#include <torch/torch.h> |
1 | 2 | #include "core/conversion/converters/converter_util.h"
|
2 | 3 | #include "core/conversion/converters/converters.h"
|
3 | 4 | #include "core/util/prelude.h"
|
@@ -72,6 +73,60 @@ auto mm_registrations TRTORCH_UNUSED =
|
72 | 73 |
|
73 | 74 | LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
|
74 | 75 | return true;
|
| 76 | + }}) |
| 77 | + .pattern( |
| 78 | + {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> (Tensor)", |
| 79 | + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { |
| 80 | + auto self = args[0].ITensorOrFreeze(ctx); |
| 81 | + auto mat1 = args[1].ITensorOrFreeze(ctx); |
| 82 | + auto mat2 = args[2].ITensorOrFreeze(ctx); |
| 83 | + auto beta = args[4].unwrapToScalar().to<float>(); |
| 84 | + auto betaTensor = tensor_to_const(ctx, torch::tensor({beta})); |
| 85 | + auto alpha = args[5].unwrapToScalar().to<float>(); |
| 86 | + auto alphaTensor = tensor_to_const(ctx, torch::tensor({alpha})); |
| 87 | + |
| 88 | + // Ensure self and other tensors have same nbDims by expanding the dimensions (from 0 axis) if |
| 89 | + // necessary. |
| 90 | + if (mat1->getDimensions().nbDims < mat2->getDimensions().nbDims) { |
| 91 | + mat1 = addPadding(ctx, n, mat1, mat2->getDimensions().nbDims, false, false); |
| 92 | + } else { |
| 93 | + mat2 = addPadding(ctx, n, mat2, mat1->getDimensions().nbDims, false, false); |
| 94 | + } |
| 95 | + |
| 96 | + auto mat2_dims = mat2->getDimensions(); |
| 97 | + nvinfer1::Dims transposed_mat2_dims; |
| 98 | + for (int i = mat2_dims.nbDims - 1; i >= 0; i--) { |
| 99 | + transposed_mat2_dims.d[i] = mat2_dims.d[mat2_dims.nbDims - 1 - i]; |
| 100 | + } |
| 101 | + auto shuffle_layer = ctx->net->addShuffle(*mat2); |
| 102 | + shuffle_layer->setReshapeDimensions(transposed_mat2_dims); |
| 103 | + mat2 = shuffle_layer->getOutput(0); |
| 104 | + |
| 105 | + auto mm_layer = ctx->net->addMatrixMultiply( |
| 106 | + *mat1, nvinfer1::MatrixOperation::kNONE, *mat2, nvinfer1::MatrixOperation::kNONE); |
| 107 | + TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication layer in node: " << *n); |
| 108 | + auto mm_scale_layer = add_elementwise( |
| 109 | + ctx, |
| 110 | + nvinfer1::ElementWiseOperation::kPROD, |
| 111 | + mm_layer->getOutput(0), |
| 112 | + alphaTensor, |
| 113 | + util::node_info(n) + "_alphaScale"); |
| 114 | + TRTORCH_CHECK(mm_scale_layer, "Unable to create alpha scaling layer in node: " << *n); |
| 115 | + auto beta_scale_layer = add_elementwise( |
| 116 | + ctx, nvinfer1::ElementWiseOperation::kPROD, self, betaTensor, util::node_info(n) + "_betaScale"); |
| 117 | + TRTORCH_CHECK(beta_scale_layer, "Unable to create beta scaling layer in node: " << *n); |
| 118 | + auto add_mm_layer = add_elementwise( |
| 119 | + ctx, |
| 120 | + nvinfer1::ElementWiseOperation::kSUM, |
| 121 | + beta_scale_layer->getOutput(0), |
| 122 | + mm_scale_layer->getOutput(0), |
| 123 | + util::node_info(n)); |
| 124 | + TRTORCH_CHECK(add_mm_layer, "Unable to create addmm layer in node: " << *n); |
| 125 | + |
| 126 | + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], add_mm_layer->getOutput(0)); |
| 127 | + |
| 128 | + LOG_DEBUG("[AddMM layer] Output tensor shape: " << out_tensor->getDimensions()); |
| 129 | + return true; |
75 | 130 | }});
|
76 | 131 | } // namespace
|
77 | 132 | } // namespace impl
|
|
0 commit comments