Skip to content

Commit

Permalink
remove tmp buffer of cumprod cpu backward kernel (#8369)
Browse files Browse the repository at this point in the history
* remove tmp buffer of cumprod cpu backward kernel

* refine

* refine

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
liufengwei0103 and mergify[bot] authored Jun 7, 2022
1 parent c07f587 commit f237503
Showing 1 changed file with 44 additions and 67 deletions.
111 changes: 44 additions & 67 deletions oneflow/user/kernels/cum_backward_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,88 +18,65 @@ limitations under the License.

namespace oneflow {
namespace {
// CumProd backward over a tensor viewed as [up_space, space, down_space], where the
// cumulative product runs along the middle `space` axis. Element (i, k, j) lives at
// offset i * space * down_space + k * down_space + j.
//
// For a column (fixed i, j) with no zero in the input:
//   dX = flip(cumsum(flip(dY * Y))) / X.
// For a column whose first zero sits at index z:
//   - k <  z: the same formula restricted to [k, z), because Y is zero from z onwards;
//   - k == z: sum_{m >= z} dY[m] * Y[z - 1] * prod_{z < t <= m} X[t], with Y[-1] taken as 1;
//   - k >  z: 0, because every output that depends on X[k] also contains X[z] == 0.
//
// dy_ptr, output_ptr and input_ptr are only read; every element of dx_ptr is written, so
// dx does not need to be zero-initialised by the caller. No temporary buffer is allocated.
// elem_cnt is not used by the computation and is kept for interface compatibility.
template<typename T>
void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* input_ptr,
                     const int64_t up_space, const int64_t space, const int64_t down_space,
                     const int64_t elem_cnt) {
  (void)elem_cnt;
  const int64_t step = space * down_space;
  for (int64_t i = 0; i < up_space; i++) {
    const int64_t base_ptr_offset = step * i;
    const T* input_ptr_base = input_ptr + base_ptr_offset;
    const T* output_ptr_base = output_ptr + base_ptr_offset;
    const T* dy_ptr_base = dy_ptr + base_ptr_offset;
    T* dx_ptr_base = dx_ptr + base_ptr_offset;

    for (int64_t j = 0; j < down_space; j++) {
      // Column views: the column offset j is applied exactly once here, so every index
      // below is a pure multiple of down_space.
      const T* cur_input_ptr = input_ptr_base + j;
      const T* cur_output_ptr = output_ptr_base + j;
      const T* cur_dy_ptr = dy_ptr_base + j;
      T* cur_dx_ptr = dx_ptr_base + j;

      // Find index of first zero in input; `space` means the column has no zero.
      int64_t first_zero_index = space;
      for (int64_t k = 0; k < space; k++) {
        if (cur_input_ptr[k * down_space] == 0) {
          first_zero_index = k;
          break;
        }
      }

      // Elements before the first zero: reverse cumulative sum of dY * Y divided by X.
      T reverse_cumsum = 0;
      for (int64_t k = first_zero_index - 1; k >= 0; k--) {
        const int64_t data_offset = k * down_space;
        reverse_cumsum += cur_output_ptr[data_offset] * cur_dy_ptr[data_offset];
        cur_dx_ptr[data_offset] = reverse_cumsum / cur_input_ptr[data_offset];
      }

      if (first_zero_index == space) { continue; }

      // Element at the first zero: dividing by X is impossible, so accumulate the
      // products explicitly. Elements after it get a zero gradient.
      const T cumprod_before_first_zero =
          first_zero_index == 0 ? T(1) : cur_output_ptr[(first_zero_index - 1) * down_space];
      T cumprod = 1;
      T cumsum = 0;
      for (int64_t k = first_zero_index; k < space; k++) {
        const int64_t data_offset = k * down_space;
        if (k != first_zero_index) {
          cumprod *= cur_input_ptr[data_offset];
          cur_dx_ptr[data_offset] = 0;
        }
        cumsum += cumprod_before_first_zero * cur_dy_ptr[data_offset] * cumprod;
      }
      cur_dx_ptr[first_zero_index * down_space] = cumsum;
    }
  }
}
}  // namespace
Expand Down

0 comments on commit f237503

Please sign in to comment.