Permutation extended
Extended permutation support in the integration (see more details in csarofeen/pytorch#1601). This update allows us to better support permutation propagation on tensors, specifically for binary ops with inputs of different ranks. Our goal is to avoid permuting tensors unless absolutely necessary. We try to preserve the permutation propagation rule in ATen, with some known limitations at this time.

The idea in this implementation is the same as in our existing code, which is to permute input/output tensors outside of codegen. For a simplified binary op scenario: `output = binaryOp(input0, input1)`

1. In a simple case where `input0` and `input1` come with the same rank & permutation order, our output would preserve the same permutation;
2. For cases where `input0` and `input1` come with different ranks but with **compatible** permutation, the tensor with the higher rank dictates the permutation of the output;
3. For cases where `input0` and `input1` come with different ranks but with **incompatible** permutations, permutation propagation fails and the output tensor will be contiguous.

**Compatible** permutation means that we can permute the higher-rank tensor to contiguous format, and then apply a second permutation to the lower-rank tensor to match their axes. This check is implemented in `MemoryFormat::broadcastToRank(int lower_rank)`.
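The rank-change check can be sketched in Python (a hedged illustration of the logic, not the actual C++ implementation; `broadcast_to_rank` is a hypothetical helper mirroring `MemoryFormat::broadcastToRank` on explicit stride-order lists):

```python
from typing import List, Optional

def broadcast_to_rank(stride_order: List[int], rank: int) -> Optional[List[int]]:
    # stride_order[i] is the logical axis that is the i-th fastest in memory,
    # e.g. contiguous rank-4 is [3, 2, 1, 0], channels-last is [1, 3, 2, 0].
    cur_rank = len(stride_order)
    order = list(stride_order)
    if cur_rank < rank:
        # growing rank: shift existing axes up, then append the new broadcast
        # axes as the slowest dimensions
        diff = rank - cur_rank
        order = [a + diff for a in order] + list(range(diff - 1, -1, -1))
    elif cur_rank > rank:
        # shrinking rank: the `rank` fastest axes must all be trailing logical
        # axes, otherwise the permutation cannot survive broadcasting
        collapsed = cur_rank - rank
        kept = order[:rank]
        if any(a < collapsed for a in kept):
            return None  # incompatible permutation
        order = [a - collapsed for a in kept]
    return order
```

For instance, channels-last `[1, 3, 2, 0]` shrinks to rank 3 as `[0, 2, 1]`, but returns `None` at rank 2, which is exactly the incompatible situation in example 4 below.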

Some concrete examples (note that we comply with eager propagation in cases 1-3, but diverge in behavior for cases 4 and 5):
1. different rank & same permutation
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(h, w, c).cuda().permute([2, 0, 1])  # stride (1, wc, c)
    out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c) preserving memory format of t0
```
2. different rank & compatible permutation
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(c, h, w).cuda()  # stride (hw, w, 1)
    out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c) preserving memory format of t0
```
3. different rank & compatible permutation with broadcasting
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(c).cuda().unsqueeze(-1).unsqueeze(-1)  # stride (1, 1, 1)
    out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c) preserving memory format of t0
```
4. different rank & incompatible permutation
```
    t0 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    t1 = torch.randn(h, w).cuda()  # stride (w, 1)
    jit_out = scripted_add(t0, t1)  # stride (hwc, wc, c, 1)  # nvfuser outputs contiguous tensor
    eager_out = eager_add(t0, t1)  # stride (hwc, 1, wc, c)  # TensorIterator preserves memory format of LHS operand
```
5. different rank & incompatible permutation
```
    t0 = torch.randn(c, h, w).cuda()  # stride (hw, w, 1)
    t1 = torch.randn(b, h, w, c).cuda().permute([0, 3, 1, 2])  # stride (hwc, 1, wc, c)
    jit_out = scripted_add(t0, t1)  # stride (hwc, 1, wc, c)  # nvfuser preserves memory format of the highest-rank tensor
    eager_out = eager_add(t0, t1)  # stride (hwc, hw, w, 1)  # TensorIterator preserves memory format of LHS operand
```
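The stride-order notation used throughout the diff below can be recovered from the strides in the examples above. A small hedged Python helper (not part of the PR) makes the mapping explicit:

```python
def stride_order_from_strides(strides):
    # Axes sorted from fastest (smallest stride) to slowest: the result is the
    # stride order used by MemoryFormat, e.g. channels-last -> [1, 3, 2, 0].
    return sorted(range(len(strides)), key=lambda i: strides[i])

# With b, h, w, c = 2, 4, 5, 3:
# t0.permute([0, 3, 1, 2]) has strides (60, 1, 15, 3), i.e. order [1, 3, 2, 0]
# a contiguous (c, h, w) tensor has strides (20, 5, 1), i.e. order [2, 1, 0]
```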
Pull Request resolved: pytorch/pytorch#76563
Approved by: https://github.com/kevinstephano, https://github.com/ngimel
jjsjann123 authored and pytorchmergebot committed May 2, 2022
1 parent 108b113 commit a01b572
Showing 1 changed file with 218 additions and 29 deletions.
247 changes: 218 additions & 29 deletions parser.cpp
@@ -139,17 +139,33 @@ struct MemoryFormat {
// e.g. for a channels-last tensor, permutation_ would be (n-1)123...(n-2);
// Note: we are omitting the leading '0' when applicable, and this
// encoding only works with rank < 10
// see [ Note: MemoryFormat and Stride Order ]
size_t permutation_ = 0;

// default to non-permuted tensor
MemoryFormat() = default;

// [ Note: MemoryFormat and Stride Order ]
// stride_order is extracted from
// `TensorType::stride_properties()::stride_index_`, it describes the
// index of axes from fastest to slowest.
// For a 4d tensor, if we have stride_order = {x0, x1, x2, x3}, the i-th
// fastest dimension would be stride_order[i].
//
// Look at comment for c10::Stride in aten/src/ATen/core/jit_type.h
// e.g. for rank 4 non-permuted tensor, stride_order would be {3, 2, 1, 0}
// for rank 4 channels last tensor, stride_order would be {1, 3, 2, 0}
//
// eg0. for a rank 4 non-permuted tensor, stride_order would be {3, 2, 1, 0}:
// the fastest dimension is axis-3, the next one is axis-2, etc., so it's a
// non-permuted tensor.
// it is encoded as permutation_ = 3210 (we special-case it to 0)
//
// eg1. for a rank 4 channels-last tensor, stride_order would be {1, 3, 2, 0}:
// the fastest dimension is axis-1, the next one is axis-3, then axis-2, and
// then axis-0. So this is a channels-last tensor (with logical shape NCHW).
// it is encoded as permutation_ = 1320
//
// eg2. for a rank 4 permuted tensor, stride_order can be {0, 3, 2, 1};
// it is encoded as permutation_ = 321 (omitting the leading '0')
void setPermutation(const std::vector<int>& stride_order) {
int rank = stride_order.size();
TORCH_INTERNAL_ASSERT(
@@ -158,20 +174,111 @@ struct MemoryFormat {
// storing stride_order in `permuted_order_` for a simpler life, so we don't
// have to decode `permutation_` when we want to apply/restore the permutation.
permuted_order_ = stride_order;
bool has_permutation_ = false;
bool has_permutation = false;
permutation_ = 0;
for (const auto i : c10::irange(rank)) {
permutation_ = permutation_ * 10 + stride_order[i];
if (!has_permutation_ && stride_order[i] != rank - 1 - i) {
has_permutation_ = true;
if (!has_permutation && stride_order[i] != rank - 1 - i) {
has_permutation = true;
}
}

// special case permutation_ to reflect non-permuted tensor
if (!has_permutation_) {
if (!has_permutation) {
permutation_ = 0;
}
}

// returns the stride order for the given MemoryFormat encoding permutation_
//
// see details for encoding in [ Note: MemoryFormat and Stride Order ]
std::vector<int> toStrideOrder() const {
std::vector<int> stride_order;
// return empty vector for no permutation
if (hasPermutation()) {
// be generous with reserved space
stride_order.reserve(10);
bool encountered_zero = false;
size_t permutation = permutation_;
while (permutation != 0) {
int order = static_cast<int>(permutation % 10);
permutation /= 10;
if (order == 0) {
encountered_zero = true;
}
stride_order.push_back(order);
}
if (!encountered_zero) {
// in case leading '0' is omitted, push it back
stride_order.push_back(0);
}
// since we use push_back, our stride_order is reversed.
std::reverse(stride_order.begin(), stride_order.end());
}
return stride_order;
}
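A hedged Python sketch of the digit encoding written by `setPermutation` and decoded by `toStrideOrder` above (the helper names here are hypothetical, not the PR's API):

```python
from typing import List

def encode_permutation(stride_order: List[int]) -> int:
    # pack the stride order into decimal digits; a non-permuted tensor is
    # special-cased to 0 (this encoding only works for rank < 10)
    rank = len(stride_order)
    if stride_order == list(range(rank - 1, -1, -1)):
        return 0
    code = 0
    for axis in stride_order:
        code = code * 10 + axis
    return code

def decode_permutation(code: int) -> List[int]:
    # inverse of encode_permutation; restores an omitted leading '0'
    if code == 0:
        return []  # no permutation
    order = []
    seen_zero = False
    while code != 0:
        digit = code % 10
        code //= 10
        seen_zero = seen_zero or digit == 0
        order.append(digit)
    if not seen_zero:
        order.append(0)
    order.reverse()
    return order
```

Channels-last `[1, 3, 2, 0]` encodes to `1320`; `[0, 3, 2, 1]` encodes to `321` because the leading zero falls out of the arithmetic, and the decoder pushes it back.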

// returns c10::nullopt when it's not safe to broadcast the current
// permutation to the given rank
c10::optional<MemoryFormat> broadcastToRank(size_t rank) const {
auto ret = Contiguous();
if (hasPermutation()) {
auto stride_order = toStrideOrder();
auto cur_rank = stride_order.size();
// no op for (cur_rank == 0) || (cur_rank == rank)
if (cur_rank < rank) {
// broadcasting to a higher rank can be done by:
// 1. incrementing all existing stride order entries by rank_diff;
// 2. pushing back decrementing elements starting at rank_diff - 1;
// where rank_diff = rank - cur_rank
//
// see [ Note: MemoryFormat and Stride Order]
// e.g.
// taking broadcasted bias for channels last as an example
// stride_order = {0, 2, 1} broadcasted to rank == 4 would give us
// rank_diff = 4 - 3 = 1
// take step 1 -> {1, 3, 2}
// take step 2 -> {1, 3, 2, 0}
int rank_diff = static_cast<int>(rank - cur_rank);
for (auto& val : stride_order) {
val += rank_diff;
}
for (int i = rank_diff - 1; i >= 0; i--) {
stride_order.push_back(i);
}
} else if (cur_rank > rank) {
// shrink permutation to a lower rank. We can simply discard the higher-rank
// entries of the stride order only when no lower-rank axis is permuted into
// them, because in those instances we can't obey broadcasting semantics
// while preserving the permutation. We check the stride order and ensure
// that the lower `rank` entries are all permuted within the lower rank.
// Afterwards, we update stride_order by decrementing each entry by
// collapsed_ranks to reflect the correct stride order.
//
// see [ Note: MemoryFormat and Stride Order ]
// e.g. for rank 4 channels-last {1, 3, 2, 0}:
// 1. the format can safely shrink to rank 3, since every element of
// {1, 3, 2} >= (4-3); we ditch the last (4-3) entries and decrement each
// remaining element by (4-3), which gives us {0, 2, 1};
// 2. but when we shrink it to rank 2, we have {1, 3} where 1 < (4-2),
// which can't be handled, so we return c10::nullopt.
int collapsed_ranks = static_cast<int>(cur_rank - rank);
for (size_t i = 0; i < rank; i++) {
if (stride_order[i] < collapsed_ranks) {
// illegal collapsing, return c10::nullopt
return c10::nullopt;
}
// update collapsed stride_order
stride_order[i] -= collapsed_ranks;
}
// discard higher rank stride order.
stride_order.resize(rank);
}
ret.setPermutation(stride_order);
}
return ret;
}

// returns non-permuted format
static MemoryFormat Contiguous() {
return MemoryFormat();
@@ -295,20 +402,29 @@ class ValueHolder {
// returns Val in target format if it exists, otherwise, transpose an existing
// copy and add that to bookkeeping.
CgValue maybeConvertValue(const MemoryFormat& format) {
auto iter_val = vals_.find(format);
if (iter_val != vals_.end()) {
return iter_val->second;
}
// patching scalar (tensor), memory format doesn't carry meaning and should
// just return the value as-is.
if (!is_tensor_view_ || rank() == 0) {
auto cur_rank = rank();
// for a scalar tensor, where cur_rank == 0, memory format doesn't carry
// meaning and we should just return the value as-is; same for a non-tensor,
// where cur_rank == -1
if (cur_rank <= 0) {
return std::get<1>(getEntry());
}
MemoryFormat format_s;
CgValue value_s = nullptr;
std::tie(format_s, value_s) = getEntry();
auto val = convertValue(format, format_s, value_s);
vals_[format] = val;

auto opt_format_d = format.broadcastToRank(static_cast<size_t>(cur_rank));
TORCH_INTERNAL_ASSERT(
opt_format_d.has_value(),
"maybeConvertValue requested for illegal permutation");
MemoryFormat format_d = opt_format_d.value();

auto iter_val = vals_.find(format_d);
if (iter_val != vals_.end()) {
return iter_val->second;
}
auto val = convertValue(format_d, format_s, value_s);
vals_[format_d] = val;
return val;
}

@@ -455,6 +571,79 @@ std::pair<MemoryFormat, std::list<CgValue>> getConsistentValues(
return std::make_pair(format, list_val);
}

// iterate through all vals and return the output MemoryFormat and copies of
// vals.
// 1. When `forced_format == c10::nullopt`, the target MemoryFormat is taken
// from the format of the first val in `vals`; this is to achieve coherent
// behavior with eager's TensorIterator;
// 2. The target can be overwritten via specifying `forced_format`.
//
// Note: take `Values&` by reference, since `maybeConvertValue` needs to modify
// the entry and we want that to be updated in `value_map_`
template <class... Values>
std::pair<MemoryFormat, std::list<CgValue>> getPWFormatValues(
c10::optional<MemoryFormat> forced_format,
Values&... vals) {
MemoryFormat format;
if (forced_format.has_value()) {
format = forced_format.value();
} else {
// get maximum rank on vals
std::vector<MemoryFormat> formats;
std::vector<int> ranks;
auto max_rank_func = [&ranks](const ValueHolder& val, int rank = 0) {
int v_rank = val.rank();
ranks.push_back(v_rank);
return std::max(rank, v_rank);
};
int max_rank = iterate(max_rank_func, vals...);

// go through all permutations, keeping consistency with TensorIterator
// behavior: the first tensor with the highest rank dictates the output
// permutation
auto format_func = [&formats, &max_rank](
const ValueHolder& val,
MemoryFormat f = MemoryFormat::Contiguous()) {
auto cur_format = std::get<0>(val.getEntry());
formats.push_back(cur_format);
return val.rank() == max_rank ? cur_format : f;
};
format = iterate(format_func, vals...);

// we need to do a pair-wise comparison to ensure that all permutations are
// compatible, since a permutation could have changed semantics among
// broadcasted tensors. Consider a pointwise operation between three
// tensors: [N, C, H, W] + [C, H, W] + [H, W]
for (size_t i = 0; i < formats.size() && format.hasPermutation(); i++) {
for (size_t j = 0; j < formats.size(); j++) {
// don't compare scalar tensor or scalar
if (ranks[i] <= 0 || ranks[j] <= 0 || i == j) {
continue;
}
size_t lower_rank = std::min(ranks[i], ranks[j]);
auto i_format = formats[i].broadcastToRank(lower_rank);
auto j_format = formats[j].broadcastToRank(lower_rank);

// breaks permutation if any:
// 1. i_format can't be broadcasted to lower_rank;
// 2. j_format can't be broadcasted to lower_rank;
if (!i_format.has_value() || !j_format.has_value()) {
format = MemoryFormat::Contiguous();
}
}
}
}

auto convert_func = [format](
ValueHolder& val, std::list<CgValue> list_val = {}) {
list_val.push_front(val.maybeConvertValue(format));
return list_val;
};
auto list_val = iterate(convert_func, vals...);

return std::make_pair(format, list_val);
}
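The format-resolution loop above can be sketched in Python (hypothetical helpers, not the PR's API; `shrink_order` stands in for the shrink direction of `MemoryFormat::broadcastToRank`, and stride orders are explicit lists with the fastest axis first):

```python
from typing import List, Optional

def shrink_order(order: List[int], rank: int) -> Optional[List[int]]:
    # shrink direction of broadcastToRank: fails when a fast axis would be
    # collapsed away by broadcasting
    collapsed = len(order) - rank
    kept = order[:rank]
    if any(a < collapsed for a in kept):
        return None
    return [a - collapsed for a in kept]

def resolve_output_order(orders: List[List[int]]) -> List[int]:
    # the first input with the highest rank dictates the candidate output
    # permutation, matching TensorIterator-like behavior
    max_rank = max(len(o) for o in orders)
    fmt = next(o for o in orders if len(o) == max_rank)
    # pair-wise check: if any pair cannot be broadcast to their common lower
    # rank, fall back to a contiguous output
    for i, oi in enumerate(orders):
        for j, oj in enumerate(orders):
            if i == j:
                continue
            lower = min(len(oi), len(oj))
            if shrink_order(oi, lower) is None or shrink_order(oj, lower) is None:
                return list(range(max_rank - 1, -1, -1))  # contiguous
    return fmt
```

Under these assumptions, example 2 from the commit message resolves to the channels-last order of the rank-4 input, while example 4 falls back to a contiguous output.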

typedef void (
*ParseFuncPtr)(const Node*, std::unordered_map<size_t, ValueHolder>&);
typedef bool (*MergeQueryFuncPtr)(const Node*);
@@ -742,7 +931,7 @@ class IrParser {
// TODO: handle scaling factor when it's not constant 1;
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -783,7 +972,7 @@

MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -839,7 +1028,7 @@

MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -888,7 +1077,7 @@

MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -1097,7 +1286,7 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -1155,8 +1344,8 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
MemoryFormat::Contiguous(),
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()],
value_map[node->inputs()[2]->unique()]);
@@ -1186,8 +1375,8 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
MemoryFormat::Contiguous(),
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()],
value_map[node->inputs()[2]->unique()]);
@@ -1215,7 +1404,7 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()],
@@ -1316,7 +1505,7 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
MemoryFormat::Contiguous(),
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()],
@@ -2461,7 +2650,7 @@
} else {
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -2520,7 +2709,7 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
@@ -2554,7 +2743,7 @@
{
MemoryFormat format;
std::list<Val*> list_val;
std::tie(format, list_val) = getConsistentValues(
std::tie(format, list_val) = getPWFormatValues(
c10::nullopt,
value_map[node->inputs()[0]->unique()],
value_map[node->inputs()[1]->unique()]);
