
Commit a086a5b

feat(aten::std|aten::masked_fill): Implement masked_fill; aten::std works for non-bias-corrected cases

Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent fa7d6d9 commit a086a5b

16 files changed: +565 -35 lines

core/conversion/conversion.cpp (+6)

@@ -87,6 +87,9 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
   if (eval) {
     if (!eval.value().isTensor()) {
       LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
+      if (eval.value().isTuple() && eval.value().toTuple()->elements().size() == 1) {
+        eval.value() = {eval.value().toTuple()->elements()[0]};
+      }
     } else {
       LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
     }
@@ -283,6 +286,9 @@ void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, boo
     auto eval = EvaluateNode(ctx, bn);
     if (!eval.value().isTensor()) {
       LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
+      if (eval.value().isTuple() && eval.value().toTuple()->elements().size() == 1) {
+        eval.value() = {eval.value().toTuple()->elements()[0]};
+      }
     } else {
       LOG_DEBUG(
           ctx->logger,
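The added branch unwraps single-element tuples that evaluators can return, so the rest of AddLayer and EvaluateConditionalBlock sees the bare value. A minimal standalone sketch of that unwrapping against the c10::IValue API, with illustrative values (not code from this commit):

    #include <torch/torch.h>

    int main() {
      // An evaluator may hand back a one-element tuple wrapping a tensor
      c10::IValue val(c10::ivalue::Tuple::create({torch::ones({2, 2})}));

      // Same check as above: collapse the tuple to its single element
      if (val.isTuple() && val.toTuple()->elements().size() == 1) {
        val = val.toTuple()->elements()[0];
      }

      // val now holds the tensor directly
      return val.isTensor() ? 0 : 1;
    }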

core/conversion/converters/impl/element_wise.cpp (+24)

@@ -185,6 +185,30 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
+        .pattern({"aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement self - alpha * other
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].unwrapToScalar().to<float>();
+                    auto alpha = args[2].unwrapToScalar().to<float>();
+
+                    auto rhs = other * alpha;
+                    if (1 != rhs) {
+                      auto rhs_tensor = tensor_to_const(ctx, torch::tensor({rhs}));
+                      auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, self, rhs_tensor, util::node_info(n));
+                      TRTORCH_CHECK(sub, "Unable to create sub layer from node: " << *n);
+                      sub->setName(util::node_info(n).c_str());
+                      LOG_DEBUG("Output tensor shape: " << sub->getOutput(0)->getDimensions());
+                      ctx->AssociateValueAndTensor(n->outputs()[0], sub->getOutput(0));
+                      return true;
+                    } else {
+                      LOG_DEBUG("Nothing to be done this layer, passing through input");
+                      LOG_DEBUG("Output tensor shape: " << self->getDimensions());
+
+                      ctx->AssociateValueAndTensor(n->outputs()[0], self);
+                      return true;
+                    }
+                  }})
         .pattern({"aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar "
                   "alpha=1) -> (Tensor(a!))",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
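For reference, the schema handled here computes self - alpha * other with scalar other and alpha, which the converter folds into a single constant before emitting one elementwise kSUB layer. A small libtorch sketch of the expected numerics (values are illustrative, not taken from the commit):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto self = torch::arange(4, torch::kFloat);  // [0, 1, 2, 3]
      float other = 2.0f;
      float alpha = 3.0f;

      // aten::sub.Scalar semantics: self - alpha * other
      auto expected = self - alpha * other;         // [-6, -5, -4, -3]

      // The converter precomputes other * alpha and subtracts it as one constant
      auto folded = self - torch::tensor({other * alpha});

      std::cout << torch::allclose(expected, folded) << std::endl;  // prints 1
      return 0;
    }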

core/conversion/converters/impl/reduce.cpp (+9 -9)

@@ -40,23 +40,23 @@ auto reduce_registrations TRTORCH_UNUSED =
                     auto in_dims = util::toVec(in_tensor->getDimensions());
                     LOG_DEBUG("InDims " << in_dims); // Some abuse of toDim but just for debug info
                     LOG_DEBUG(
-                        "Dim to reduce(original):" << util::toDims(dims)); // Some abuse of toDim but just for debug info
+                        "Dim to reduce (original): " << util::toDims(dims)); // Some abuse of toDim but just for debug info
                     for (size_t i = 0; i < dims.size(); i++) {
                       auto dim_val = dims[i] < 0 ? (in_dims.size() + dims[i]) : dims[i];
                       calculated_dims.push_back(dim_val);
                     }
                     LOG_DEBUG(
-                        "Dim to reduce(converted):"
+                        "Dim to reduce (converted): "
                         << util::toDims(calculated_dims)); // Some abuse of toDim but just for debug info

                     uint32_t axis_mask = 0;
                     for (size_t d = 0; d < calculated_dims.size(); d++) {
                       axis_mask |= 1 << calculated_dims[d];
                     }
-                    LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
+                    LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

                     auto keepdim = args[2].unwrapToBool();
-                    LOG_DEBUG("Keep dims :" << keepdim);
+                    LOG_DEBUG("Keep dims: " << keepdim);
                     LOG_WARNING("Mean converter disregards dtype");
                     auto mean_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kAVG, axis_mask, keepdim);
                     TRTORCH_CHECK(mean_layer, "Unable to create mean layer from node: " << *n);
@@ -106,10 +106,10 @@ auto reduce_registrations TRTORCH_UNUSED =
                     for (size_t d = 0; d < calculated_dims.size(); d++) {
                       axis_mask |= 1 << calculated_dims[d];
                     }
-                    LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
+                    LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

                     auto keepdim = args[2].unwrapToBool();
-                    LOG_DEBUG("Keep dims :" << keepdim);
+                    LOG_DEBUG("Keep dims: " << keepdim);

                     LOG_WARNING("Sum converter disregards dtype");
                     auto sum_layer = ctx->net->addReduce(*in_tensor, nvinfer1::ReduceOperation::kSUM, axis_mask, keepdim);
@@ -145,13 +145,13 @@ auto reduce_registrations TRTORCH_UNUSED =
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto in_tensor = args[0].ITensorOrFreeze(ctx);
                     auto dim = args[1].unwrapToInt();
-                    LOG_DEBUG("Dim to reduce:" << dim); // Some abuse of toDim but just for debug info
+                    LOG_DEBUG("Dim to reduce: " << dim); // Some abuse of toDim but just for debug info

                     uint32_t axis_mask = 1 << dim;
-                    LOG_DEBUG("Axis Mask" << std::bitset<32>(axis_mask));
+                    LOG_DEBUG("Axis Mask: " << std::bitset<32>(axis_mask));

                     auto keepdim = args[2].unwrapToBool();
-                    LOG_DEBUG("Keep dims :" << keepdim);
+                    LOG_DEBUG("Keep dims: " << keepdim);

                     LOG_WARNING("Prod converter disregards dtype");
                     auto prod_layer =
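These converters encode the dimensions to reduce as a TensorRT axis bitmask, one bit per already-normalized dimension index. A tiny standalone sketch of that mask construction (the dimension values are illustrative):

    #include <bitset>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Dims already converted to non-negative indices, as in the converters above
      std::vector<int64_t> calculated_dims = {0, 2};

      uint32_t axis_mask = 0;
      for (size_t d = 0; d < calculated_dims.size(); d++) {
        axis_mask |= 1 << calculated_dims[d];
      }

      // Bits 0 and 2 set: reduce over those axes
      std::cout << "Axis Mask: " << std::bitset<32>(axis_mask) << std::endl;
      return 0;
    }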

core/conversion/converters/impl/select.cpp (+16 -18)

@@ -71,35 +71,34 @@ auto select_registrations TRTORCH_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensor();
+                    auto in = args[0].ITensorOrFreeze(ctx);
                     auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
                     auto axis = args[1].unwrapToInt();
                     axis = axis < 0 ? axis + maxDim : axis;
                     auto ind = (int32_t)args[2].unwrapToInt();

                     // index to access needs to be an at::Tensor
                     at::Tensor indices = torch::tensor({ind}).to(torch::kI32);
-                    auto weights = Weights(ctx, indices);
-
-                    // IConstantLayer to convert indices from Weights to ITensor
-                    auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
-                    TRTORCH_CHECK(const_layer, "Unable to create constant layer from node: " << *n);
-                    auto const_out = const_layer->getOutput(0);
+                    auto const_out = tensor_to_const(ctx, indices);

                     // IGatherLayer takes in input tensor, the indices, and the axis
                     // of input tensor to take indices from
                     auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
                     TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
-                    auto gather_out = gather_layer->getOutput(0);
+                    auto out = gather_layer->getOutput(0);

-                    // IShuffleLayer removes redundant dimensions
-                    auto shuffle_layer = ctx->net->addShuffle(*gather_out);
-                    TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-                    shuffle_layer->setReshapeDimensions(util::squeezeDims(gather_out->getDimensions(), axis));
-                    shuffle_layer->setName(util::node_info(n).c_str());
-                    auto shuffle_out = shuffle_layer->getOutput(0);
+                    LOG_DEBUG("Gather tensor shape: " << out->getDimensions());

-                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out);
+                    if (out->getDimensions().nbDims != 1) {
+                      // IShuffleLayer removes redundant dimensions
+                      auto shuffle_layer = ctx->net->addShuffle(*out);
+                      TRTORCH_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+                      shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), axis));
+                      shuffle_layer->setName(util::node_info(n).c_str());
+                      out = shuffle_layer->getOutput(0);
+                    }
+
+                    out = ctx->AssociateValueAndTensor(n->outputs()[0], out);

                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());

@@ -253,15 +252,14 @@ auto select_registrations TRTORCH_UNUSED =
                  "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)",
                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                    auto self = args[0].ITensorOrFreeze(ctx);
-                   LOG_DEBUG(args[1].unwrapToTensor());
                    auto mask = castITensor(ctx, args[1].ITensorOrFreeze(ctx), nvinfer1::DataType::kBOOL);
+                   mask = addPadding(ctx, n, mask, self->getDimensions().nbDims, false, true);
                    auto val = args[2].unwrapToScalar().to<float>();
-                   LOG_DEBUG(torch::full(util::toVec(self->getDimensions()), val));
                    auto val_t = tensor_to_const(ctx, torch::full(util::toVec(self->getDimensions()), val));

                    TRTORCH_CHECK(util::broadcastable(self->getDimensions(), mask->getDimensions(), /*multidirectional=*/false), "Self and mask tensors are not broadcastable");

-                   auto new_layer = ctx->net->addSelect(*mask, *self, *val_t);
+                   auto new_layer = ctx->net->addSelect(*mask, *val_t, *self);
                    TRTORCH_CHECK(new_layer, "Unable to create layer for aten::masked_fill");

                    new_layer->setName(util::node_info(n).c_str());
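The swapped addSelect arguments match masked_fill semantics: where the mask is true the output takes the fill value, elsewhere it keeps the input element. A minimal libtorch sketch of the behavior being converted (shapes and values are illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto self = torch::arange(6, torch::kFloat).reshape({2, 3});
      auto mask = torch::tensor({{true, false, true}, {false, false, true}});
      float value = -1.0f;

      // aten::masked_fill.Scalar: write value where mask is true, keep self elsewhere
      auto filled = self.masked_fill(mask, value);

      // Equivalent select, mirroring ISelectLayer(condition=mask, then=value, else=self)
      auto selected = torch::where(mask, torch::full_like(self, value), self);

      std::cout << torch::allclose(filled, selected) << std::endl;  // prints 1
      return 0;
    }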

core/conversion/evaluators/aten.cpp (+2 -2)

@@ -573,10 +573,10 @@ auto aten_registrations TRTORCH_UNUSED =
                 auto dtype = args.at(n->input(1)).IValue();
                 auto device = args.at(n->input(2)).IValue();
                 auto tensor = createTensorFromList(*data, *dtype, *device);
-                LOG_DEBUG(tensor);
                 if (tensor.dtype() == at::kByte) {
-                  return tensor.to(at::kInt);
+                  return tensor.to(at::kFloat);
                 }
+                std::cout << tensor << std::endl;
                 return tensor;
               },
               EvalOptions().validSchemas({

core/lowering/lowering.cpp (+2)

@@ -48,6 +48,8 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
   passes::UnpackLogSoftmax(g);
+  passes::UnpackStd(g);
+  passes::UnpackVar(g);
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);

core/lowering/passes/BUILD (+3 -1)

@@ -24,7 +24,9 @@ cc_library(
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_log_softmax.cpp",
-        "unpack_hardswish.cpp"
+        "unpack_hardswish.cpp",
+        "unpack_std.cpp",
+        "unpack_var.cpp",
     ],
     hdrs = [
         "passes.h",

core/lowering/passes/passes.h (+2)

@@ -19,6 +19,8 @@ void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/unpack_std.cpp (new file, +30)

@@ -0,0 +1,30 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackStd(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string std_pattern = R"IR(
+    graph(%1, %dim, %unbiased, %keepdim):
+      %out: Tensor = aten::std(%1, %dim, %unbiased, %keepdim)
+      return (%out))IR";
+  std::string unpacked_pattern = R"IR(
+    graph(%1, %dim, %unbiased, %keepdim):
+      %z: Tensor = aten::var(%1, %dim, %unbiased, %keepdim)
+      %out: Tensor = aten::sqrt(%z)
+      return (%out))IR";
+
+  torch::jit::SubgraphRewriter std_rewriter;
+  std_rewriter.RegisterRewritePattern(std_pattern, unpacked_pattern);
+  std_rewriter.runOnGraph(graph);
+  LOG_GRAPH("Post unpack std: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
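The pattern leans on the identity std(x) = sqrt(var(x)), so only variance needs further lowering. A quick libtorch check of that identity (shape and dim are illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto x = torch::randn({4, 8});

      auto std_direct = x.std({1}, /*unbiased=*/false, /*keepdim=*/false);
      auto std_unpacked = x.var({1}, /*unbiased=*/false, /*keepdim=*/false).sqrt();

      std::cout << torch::allclose(std_direct, std_unpacked) << std::endl;  // prints 1
      return 0;
    }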

core/lowering/passes/unpack_var.cpp (new file, +51)

@@ -0,0 +1,51 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string var_pattern = R"IR(
+    graph(%input, %dim, %unbiased, %keepdim):
+      %out: Tensor = aten::var(%input, %dim, %unbiased, %keepdim)
+      return (%out))IR";
+  std::string unpacked_pattern = R"IR(
+    graph(%input, %dims, %unbiased, %keepdim):
+      %none: None = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %0: int = prim::Constant[value=0]()
+      %1: int = prim::Constant[value=1]()
+      %sqrd: Tensor = aten::mul(%input, %input)
+      %sqrdmean: Tensor = aten::mean(%sqrd, %dims, %keepdim, %none)
+      %mean: Tensor = aten::mean(%input, %dims, %keepdim, %none)
+      %meansqrd: Tensor = aten::mul(%mean, %mean)
+      %var: Tensor = aten::sub(%sqrdmean, %meansqrd, %1)
+      %varout : Tensor = prim::If(%unbiased)
+        block0():
+          %shape: int[] = aten::size(%input)
+          %shapet: Tensor = aten::tensor(%shape, %0, %none, %false)
+          %dim: int = prim::ListUnpack(%dims)
+          %reduceddims: Tensor = aten::select(%shapet, %0, %dim)
+          %numel: Tensor = aten::prod(%reduceddims, %dim, %keepdim, %none)
+          %mul: Tensor = aten::mul(%var, %numel)
+          %sub: Tensor = aten::sub(%numel, %1, %1)
+          %v: Tensor = aten::div(%mul, %sub)
+          -> (%v)
+        block1():
+          -> (%var)
+      return(%varout))IR";
+
+  torch::jit::SubgraphRewriter var_rewriter;
+  var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
+  var_rewriter.runOnGraph(graph);
+  LOG_DEBUG("Post unpack var: " << *graph);
+
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
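The unpacked pattern computes the biased variance as E[x^2] - (E[x])^2 and, in the prim::If branch for the unbiased case, rescales by n / (n - 1) using the size of the reduced dimension. A minimal libtorch sketch of that decomposition (tensor shape and dim are illustrative):

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto x = torch::randn({4, 8});
      int64_t dim = 1;
      auto n = x.size(dim);  // size of the reduced dimension

      // Biased variance via E[x^2] - (E[x])^2, mirroring the rewrite
      auto sqrdmean = (x * x).mean({dim}, /*keepdim=*/false);
      auto mean = x.mean({dim}, /*keepdim=*/false);
      auto var_biased = sqrdmean - mean * mean;

      // Bessel correction: scale by n / (n - 1) for the unbiased estimate
      auto var_unbiased = var_biased * n / (n - 1);

      std::cout << torch::allclose(var_biased, x.var({dim}, /*unbiased=*/false)) << std::endl;   // 1
      std::cout << torch::allclose(var_unbiased, x.var({dim}, /*unbiased=*/true)) << std::endl;  // 1
      return 0;
    }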

core/util/trt_util.h (-2)

@@ -21,8 +21,6 @@ inline std::ostream& operator<<(std::ostream& os, const nvinfer1::TensorFormat&

 inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DataType& dtype) {
   switch (dtype) {
-    case nvinfer1::DataType::kBOOL:
-      return stream << "Bool";
     case nvinfer1::DataType::kFLOAT:
       return stream << "Float32";
     case nvinfer1::DataType::kHALF:
