From a01b5720cee83afa1bbaf1ee7c65941b03fc6c54 Mon Sep 17 00:00:00 2001
From: jjsjann123
Date: Mon, 2 May 2022 22:09:56 +0000
Subject: [PATCH] Permutation extended

Extended permutation support in integration (see more details in
https://github.com/csarofeen/pytorch/issues/1601).

This update allows us to better support permutation propagation on tensors,
specifically for binary ops with inputs of different ranks. Our goal is to
avoid permuting tensors unless absolutely necessary. We try to preserve the
permutation propagation rule in aten, with some known limitations at this
time.

The idea in this implementation is the same as in our existing code: permute
input/output tensors outside of codegen. For a simplified binary op scenario
`output = binaryOp(input0, input1)`:

1. In the simple case where `input0` and `input1` come with the same rank &
   permutation order, our output preserves the same permutation;
2. Where `input0` and `input1` come with different ranks but **compatible**
   permutations, the tensor with the higher rank dictates the permutation of
   the output;
3. Where `input0` and `input1` come with different ranks but **incompatible**
   permutations, permutation propagation fails and the output tensor is
   contiguous.

By **compatible** permutation, we mean that we can permute the higher-rank
tensor to contiguous format and then apply a second permutation to the
lower-rank tensor to match their axes. This check is implemented in
`MemoryFormat::broadcastToRank(size_t rank)`.

Some concrete examples (note that we comply with eager propagation in cases
1-3, but diverge in behavior in cases 4 and 5):

1. different rank & same permutation
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1])        # stride (1, wc, c)
out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c), preserving memory format of t0
```

2. different rank & compatible permutation
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(c, h, w).cuda()                           # stride (hw, w, 1)
out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c), preserving memory format of t0
```

3. different rank & compatible permutation with broadcasting
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1)     # stride (1, 1, 1)
out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c), preserving memory format of t0
```

4. different rank & incompatible permutation
```
t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
t1 = torch.randn(h, w).cuda()                              # stride (w, 1)
jit_out = scripted_add(t0, t1)   # stride (hwc, wc, c, 1)  # nvfuser outputs contiguous tensor
eager_out = eager_add(t0, t1)    # stride (hwc, 1, wc, c)  # TensorIterator preserves memory format of LHS operand
```

5. different rank & incompatible permutation
```
t0 = torch.randn(c, h, w).cuda()                           # stride (hw, w, 1)
t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
jit_out = scripted_add(t0, t1)   # stride (hwc, 1, wc, c)  # nvfuser preserves memory format of highest-rank tensor
eager_out = eager_add(t0, t1)    # stride (hwc, hw, w, 1)  # TensorIterator preserves memory format of LHS operand
```
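The compatibility check above can be sketched in a few lines of Python. This
is a rough model of the stride-order logic in `MemoryFormat::broadcastToRank`,
not part of the patch itself; the helper name `broadcast_to_rank` is
hypothetical:

```
# Rough sketch of the stride-order compatibility check; broadcast_to_rank is
# a hypothetical stand-in for MemoryFormat::broadcastToRank in this patch.
def broadcast_to_rank(stride_order, rank):
    """Return stride_order adjusted to `rank`, or None when incompatible."""
    cur_rank = len(stride_order)
    if cur_rank < rank:
        # growing: shift existing entries up by rank_diff, then append the
        # new fastest axes rank_diff-1, ..., 0
        rank_diff = rank - cur_rank
        return [s + rank_diff for s in stride_order] + list(
            range(rank_diff - 1, -1, -1))
    if cur_rank > rank:
        # shrinking: every kept entry must stay within the lower rank
        collapsed = cur_rank - rank
        kept = stride_order[:rank]
        if any(s < collapsed for s in kept):
            return None  # a dropped axis is permuted into the kept bits
        return [s - collapsed for s in kept]
    return list(stride_order)

# channels-last {1, 3, 2, 0} shrinks safely to rank 3 (compatible, as in case 2)
assert broadcast_to_rank([1, 3, 2, 0], 3) == [0, 2, 1]
# ... but not to rank 2 (as in case 4), so the output falls back to contiguous
assert broadcast_to_rank([1, 3, 2, 0], 2) is None
# the rank-3 order grows back to channels-last at rank 4
assert broadcast_to_rank([0, 2, 1], 4) == [1, 3, 2, 0]
```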
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
---
 parser.cpp | 247 ++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 218 insertions(+), 29 deletions(-)

diff --git a/parser.cpp b/parser.cpp
index 64fdcae..919fe42 100644
--- a/parser.cpp
+++ b/parser.cpp
@@ -139,17 +139,33 @@ struct MemoryFormat {
   // e.g. for a channels-last tensor, permutation_ would be (n-1)123...(n-2);
   // Note: we are omitting the leading '0' when applicable, and apparently this
   // encoding only works with rank < 10
+  // see [ Note: MemoryFormat and Stride Order ]
   size_t permutation_ = 0; // default to non-permuted tensor
 
   MemoryFormat() = default;
 
+  // [ Note: MemoryFormat and Stride Order ]
   // stride_order is extracted from
   // `TensorType::stride_properties()::stride_index_`; it describes the
   // index of axes from fastest to slowest.
+  // For a 4d tensor with stride_order = {x0, x1, x2, x3}, the i-th fastest
+  // dimension is stride_order[i].
+  //
   // Look at comment for c10::Stride in aten/src/ATen/core/jit_type.h
-  // e.g. for rank 4 non-permuted tensor, stride_order would be {3, 2, 1, 0}
-  // for rank 4 channels last tensor, stride_order would be {1, 3, 2, 0}
+  //
+  // eg0. for a rank 4 non-permuted tensor, stride_order would be {3, 2, 1, 0}:
+  // the fastest dimension is axis-3, the next one is axis-2, etc., so it is a
+  // non-permuted tensor. It would be encoded as permutation_ = 3210 (we
+  // special-case it to 0).
+  //
+  // eg1. for a rank 4 channels-last tensor, stride_order would be {1, 3, 2, 0}:
+  // the fastest dimension is axis-1, followed by axis-3, then axis-2, then
+  // axis-0. So this is a channels-last tensor (logical NCHW shape with NHWC
+  // memory layout). It is encoded as permutation_ = 1320.
+  //
+  // eg2. for a rank 4 permuted tensor, stride_order can be {0, 3, 2, 1};
+  // it is encoded as permutation_ = 321 (omitting the leading '0').
   void setPermutation(const std::vector<int>& stride_order) {
     int rank = stride_order.size();
     TORCH_INTERNAL_ASSERT(
@@ -158,20 +174,111 @@ struct MemoryFormat {
     // storing stride_order in `permuted_order_` for a simpler life, so we don't
     // have to decode `permutation_` when we want to apply/restore permutation_.
     permuted_order_ = stride_order;
-    bool has_permutation_ = false;
+    bool has_permutation = false;
+    permutation_ = 0;
     for (const auto i : c10::irange(rank)) {
       permutation_ = permutation_ * 10 + stride_order[i];
-      if (!has_permutation_ && stride_order[i] != rank - 1 - i) {
-        has_permutation_ = true;
+      if (!has_permutation && stride_order[i] != rank - 1 - i) {
+        has_permutation = true;
       }
     }
     // special case permutation_ to reflect non-permuted tensor
-    if (!has_permutation_) {
+    if (!has_permutation) {
       permutation_ = 0;
     }
   }
 
+  // returns the stride order for the given MemoryFormat encoding permutation_
+  //
+  // see details for the encoding in [ Note: MemoryFormat and Stride Order ]
+  std::vector<int> toStrideOrder() const {
+    std::vector<int> stride_order;
+    // return an empty vector when there is no permutation
+    if (hasPermutation()) {
+      // be generous with reserved space
+      stride_order.reserve(10);
+      bool encountered_zero = false;
+      size_t permutation = permutation_;
+      while (permutation != 0) {
+        int order = static_cast<int>(permutation % 10);
+        permutation /= 10;
+        if (order == 0) {
+          encountered_zero = true;
+        }
+        stride_order.push_back(order);
+      }
+      if (!encountered_zero) {
+        // the leading '0' was omitted by the encoding; push it back
+        stride_order.push_back(0);
+      }
+      // since we use push_back, stride_order is reversed here
+      std::reverse(stride_order.begin(), stride_order.end());
+    }
+    return stride_order;
+  }
+
+  // returns c10::nullopt when it's not safe to broadcast the current
+  // permutation to `rank`
+  c10::optional<MemoryFormat> broadcastToRank(size_t rank) const {
+    auto ret = Contiguous();
+    if (hasPermutation()) {
+      auto stride_order = toStrideOrder();
+      auto cur_rank = stride_order.size();
+      // no-op for (cur_rank == 0) || (cur_rank == rank)
+      if (cur_rank < rank) {
+        // broadcasting to a higher rank can be done by:
+        // 1. incrementing each existing stride order entry by rank_diff;
+        // 2. pushing back decreasing elements rank_diff-1, ..., 0;
+        // where rank_diff = rank - cur_rank
+        //
+        // see [ Note: MemoryFormat and Stride Order ]
+        // e.g. taking a broadcasted bias for channels-last as an example:
+        // stride_order = {0, 2, 1} broadcasted to rank == 4 gives us
+        // rank_diff = 4 - 3 = 1
+        // step 1 -> {1, 3, 2}
+        // step 2 -> {1, 3, 2, 0}
+        int rank_diff = static_cast<int>(rank - cur_rank);
+        for (auto& val : stride_order) {
+          val += rank_diff;
+        }
+        for (int i = rank_diff - 1; i >= 0; i--) {
+          stride_order.push_back(i);
+        }
+      } else if (cur_rank > rank) {
+        // shrink the permutation to a lower rank. We can simply discard the
+        // higher-rank stride order entries only when none of the dropped axes
+        // is permuted into the lower `rank` bits, because in those instances
+        // we can't obey broadcasting semantics while preserving the
+        // permutation. We check the stride order and ensure that the lower
+        // `rank` bits are all permuted within the lower rank. Afterwards, we
+        // update stride_order by decrementing each entry by collapsed_ranks
+        // to reflect the correct stride order.
+        //
+        // see [ Note: MemoryFormat and Stride Order ]
+        // e.g. for rank 4 channels-last {1, 3, 2, 0}:
+        // 1. the format can safely shrink to rank 3, since every element of
+        // {1, 3, 2} is >= (4-3); we ditch the last (4-3) entries and decrement
+        // each remaining element by (4-3), which gives us {0, 2, 1};
+        // 2. but when we shrink it to rank 2, we have {1, 3} where 1 < (4-2),
+        // which can't be handled, so we return c10::nullopt.
+        int collapsed_ranks = static_cast<int>(cur_rank - rank);
+        for (size_t i = 0; i < rank; i++) {
+          if (stride_order[i] < collapsed_ranks) {
+            // illegal collapsing, return c10::nullopt
+            return c10::nullopt;
+          }
+          // update collapsed stride_order
+          stride_order[i] -= collapsed_ranks;
+        }
+        // discard the higher-rank stride order entries
+        stride_order.resize(rank);
+      }
+      ret.setPermutation(stride_order);
+    }
+    return ret;
+  }
+
   // returns non-permuted format
   static MemoryFormat Contiguous() {
     return MemoryFormat();
@@ -295,20 +402,29 @@ class ValueHolder {
   // returns Val in target format if it exists; otherwise, transpose an
   // existing copy and add that to bookkeeping.
   CgValue maybeConvertValue(const MemoryFormat& format) {
-    auto iter_val = vals_.find(format);
-    if (iter_val != vals_.end()) {
-      return iter_val->second;
-    }
-    // patching scalar (tensor), memory format doesn't carry meaning and should
-    // just return the value as-is.
-    if (!is_tensor_view_ || rank() == 0) {
+    auto cur_rank = rank();
+    // for a scalar (tensor), where cur_rank == 0, memory format doesn't carry
+    // meaning, so we just return the value as-is; same for a non-tensor,
+    // where cur_rank == -1
+    if (cur_rank <= 0) {
       return std::get<1>(getEntry());
     }
     MemoryFormat format_s;
     CgValue value_s = nullptr;
     std::tie(format_s, value_s) = getEntry();
-    auto val = convertValue(format, format_s, value_s);
-    vals_[format] = val;
+
+    auto opt_format_d = format.broadcastToRank(static_cast<size_t>(cur_rank));
+    TORCH_INTERNAL_ASSERT(
+        opt_format_d.has_value(),
+        "maybeConvertValue requested for illegal permutation");
+    MemoryFormat format_d = opt_format_d.value();
+
+    auto iter_val = vals_.find(format_d);
+    if (iter_val != vals_.end()) {
+      return iter_val->second;
+    }
+    auto val = convertValue(format_d, format_s, value_s);
+    vals_[format_d] = val;
     return val;
   }
@@ -455,6 +571,79 @@ std::pair<MemoryFormat, std::list<CgValue>> getConsistentValues(
   return std::make_pair(format, list_val);
 }
 
+// iterate through all vals and return the output MemoryFormat and copies of
+// vals.
+// 1. When `forced_format == c10::nullopt`, the target MemoryFormat is taken
+//    from the first value in `vals` with the highest rank; this is to achieve
+//    coherent behavior with eager TensorIterator;
+// 2. The target can be overwritten via specifying `forced_format`.
+//
+// Note: we take `Values&` by reference, since `maybeConvertValue` needs to
+// modify the entry and we want that to be updated in `value_map_`
+template <class... Values>
+std::pair<MemoryFormat, std::list<CgValue>> getPWFormatValues(
+    c10::optional<MemoryFormat> forced_format,
+    Values&... vals) {
+  MemoryFormat format;
+  if (forced_format.has_value()) {
+    format = forced_format.value();
+  } else {
+    // get the maximum rank among vals
+    std::vector<MemoryFormat> formats;
+    std::vector<int> ranks;
+    auto max_rank_func = [&ranks](const ValueHolder& val, int rank = 0) {
+      int v_rank = val.rank();
+      ranks.push_back(v_rank);
+      return std::max(rank, v_rank);
+    };
+    int max_rank = iterate(max_rank_func, vals...);
+
+    // go through all permutations, keeping consistency with TensorIterator
+    // behavior: the first tensor with the highest rank dictates the output
+    // permutation
+    auto format_func = [&formats, &max_rank](
+                           const ValueHolder& val,
+                           MemoryFormat f = MemoryFormat::Contiguous()) {
+      auto cur_format = std::get<0>(val.getEntry());
+      formats.push_back(cur_format);
+      return val.rank() == max_rank ? cur_format : f;
+    };
+    format = iterate(format_func, vals...);
+
+    // we need to do a pair-wise comparison to ensure that all permutations
+    // are compatible, since a permutation could have different semantics
+    // among broadcasted tensors.
+    // Consider a pointwise operation among three tensors:
+    //   [N, C, H, W] + [C, H, W] + [H, W]
+    for (size_t i = 0; i < formats.size() && format.hasPermutation(); i++) {
+      for (size_t j = 0; j < formats.size(); j++) {
+        // don't compare scalar tensors or scalars, and skip self-comparison
+        if (ranks[i] <= 0 || ranks[j] <= 0 || i == j) {
+          continue;
+        }
+        size_t lower_rank = std::min(ranks[i], ranks[j]);
+        auto i_format = formats[i].broadcastToRank(lower_rank);
+        auto j_format = formats[j].broadcastToRank(lower_rank);
+
+        // break the permutation if either:
+        // 1. i_format can't be broadcasted to lower_rank;
+        // 2. j_format can't be broadcasted to lower_rank;
+        if (!i_format.has_value() || !j_format.has_value()) {
+          format = MemoryFormat::Contiguous();
+        }
+      }
+    }
+  }
+
+  auto convert_func = [format](
+                          ValueHolder& val, std::list<CgValue> list_val = {}) {
+    list_val.push_front(val.maybeConvertValue(format));
+    return list_val;
+  };
+  auto list_val = iterate(convert_func, vals...);
+
+  return std::make_pair(format, list_val);
+}
+
 typedef void (
     *ParseFuncPtr)(const Node*, std::unordered_map<size_t, ValueHolder>&);
 typedef bool (*MergeQueryFuncPtr)(const Node*);
@@ -742,7 +931,7 @@ class IrParser {
         // TODO: handle scaling factor when it's not constant 1;
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()]);
@@ -783,7 +972,7 @@ class IrParser {
 
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()]);
@@ -839,7 +1028,7 @@ class IrParser {
 
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
@@ -888,7 +1077,7 @@ class IrParser {
 
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()]);
@@ -1097,7 +1286,7 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()]);
@@ -1155,8 +1344,8 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
-            MemoryFormat::Contiguous(),
+        std::tie(format, list_val) = getPWFormatValues(
+            c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()],
             value_map[node->inputs()[2]->unique()]);
@@ -1186,8 +1375,8 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
-            MemoryFormat::Contiguous(),
+        std::tie(format, list_val) = getPWFormatValues(
+            c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()],
             value_map[node->inputs()[2]->unique()]);
@@ -1215,7 +1404,7 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()],
@@ -1316,7 +1505,7 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             MemoryFormat::Contiguous(),
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()],
@@ -2461,7 +2650,7 @@ class IrParser {
       } else {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()]);
@@ -2520,7 +2709,7 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
             c10::nullopt,
             value_map[node->inputs()[0]->unique()],
             value_map[node->inputs()[1]->unique()]);
@@ -2554,7 +2743,7 @@ class IrParser {
       {
         MemoryFormat format;
         std::list<Val*> list_val;
-        std::tie(format, list_val) = getConsistentValues(
+        std::tie(format, list_val) = getPWFormatValues(
            c10::nullopt,
            value_map[node->inputs()[0]->unique()],
            value_map[node->inputs()[1]->unique()]);
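As a closing illustration of the decimal-digit encoding described in
[ Note: MemoryFormat and Stride Order ], here is a rough Python model of the
encode/decode round trip. It is not part of the patch; `encode` and `decode`
are hypothetical stand-ins for `setPermutation` and `toStrideOrder`:

```
# Rough model of the permutation_ encoding; only valid for rank < 10,
# as noted in the patch. Function names are hypothetical.
def encode(stride_order):
    """Pack a stride order into decimal digits; 0 means non-permuted."""
    rank = len(stride_order)
    if all(s == rank - 1 - i for i, s in enumerate(stride_order)):
        return 0  # special case: non-permuted tensor
    code = 0
    for s in stride_order:
        code = code * 10 + s
    return code

def decode(code):
    """Recover the stride order; the omitted leading '0' is re-appended."""
    if code == 0:
        return []  # non-permuted: no explicit order stored
    order = []
    while code:
        order.append(code % 10)
        code //= 10
    if 0 not in order:
        order.append(0)  # leading '0' was omitted by the encoding
    return order[::-1]

assert encode([3, 2, 1, 0]) == 0     # contiguous rank-4 (special-cased)
assert encode([1, 3, 2, 0]) == 1320  # channels-last rank-4, as in eg1
assert decode(1320) == [1, 3, 2, 0]
assert encode([0, 3, 2, 1]) == 321   # leading '0' drops out, as in eg2
assert decode(321) == [0, 3, 2, 1]
```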