Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove tmp buffer of cumprod cpu backward kernel #8369

Merged
merged 4 commits into from
Jun 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 44 additions & 67 deletions oneflow/user/kernels/cum_backward_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,88 +18,65 @@ limitations under the License.

namespace oneflow {
namespace {
// O(n) cumprod backward, formula: cumsum(flip(dY * Y)) / X.
// Special care is needed when the input contains at least one zero.
// CumProd backward, formula: flip(cumsum(flip(dY * Y))) / X.
template<typename T>
void CumProdBackward(const T* dy_ptr, T* dx_ptr, const T* output_ptr, const T* input_ptr,
const int64_t up_space, const int64_t space, const int64_t down_space,
const int64_t elem_cnt) {
const auto step = space * down_space;
for (size_t i = 0; i < up_space; i++) {
// two-dims buffer for 0 elem index
std::vector<size_t> cumsum_zeros_number(space * down_space, 0);
auto* cumsum_zeros_number_ptr = cumsum_zeros_number.data();
const size_t base_ptr_offset = step * i;
const T* input_ptr_base = input_ptr + base_ptr_offset;
const T* output_ptr_base = output_ptr + base_ptr_offset;
const T* dy_ptr_base = dy_ptr + base_ptr_offset;
T* dx_ptr_base = dx_ptr + base_ptr_offset;

// Use dx as a temporary buffer holding the running count of zero elements in the input.
for (size_t j = 0; j < space; j++) {
const size_t ptr_offset = j * down_space;
auto* tmp_input_ptr = input_ptr + ptr_offset;
auto* tmp_cumsum_zeros_number_ptr = cumsum_zeros_number_ptr + ptr_offset;
auto* last_tmp_cumsum_zeros_number_ptr = tmp_cumsum_zeros_number_ptr - down_space;
for (auto k = 0; k < down_space; k++) {
int is_zero = tmp_input_ptr[k] == 0 ? 1 : 0;
tmp_cumsum_zeros_number_ptr[k] =
is_zero + (j == 0 ? 0 : last_tmp_cumsum_zeros_number_ptr[k]);
}
}
{
// for k < z (where z is the index of the first zero)
std::vector<T> reverse_cumsum(down_space, 0);
for (size_t j = 0; j < space; j++) {
const size_t ptr_offset = (space - j - 1) * down_space;
auto* tmp_cumsum_zeros_number_ptr = cumsum_zeros_number_ptr + ptr_offset;
auto* tmp_dy_ptr = dy_ptr + ptr_offset;
auto* tmp_dx_ptr = dx_ptr + ptr_offset;
auto* tmp_output_ptr = output_ptr + ptr_offset;
auto* tmp_input_ptr = input_ptr + ptr_offset;
for (auto k = 0; k < down_space; k++) {
if (tmp_cumsum_zeros_number_ptr[k] > 0) { continue; }
reverse_cumsum[k] += tmp_output_ptr[k] * tmp_dy_ptr[k];
tmp_dx_ptr[k] = reverse_cumsum[k] / tmp_input_ptr[k];
}
auto* cur_input_ptr = input_ptr_base + ptr_offset;

auto* cumsum_zeros_number_ptr = dx_ptr_base + ptr_offset;
auto* last_cumsum_zeros_number_ptr = cumsum_zeros_number_ptr - down_space;
for (size_t k = 0; k < down_space; k++) {
int is_zero = cur_input_ptr[k] == 0 ? 1 : 0;
cumsum_zeros_number_ptr[k] = is_zero + (j == 0 ? 0 : last_cumsum_zeros_number_ptr[k]);
}
}
{
// for k == z
std::vector<size_t> first_zero(down_space, space);
for (size_t j = 0; j < space; j++) {
auto* tmp_cumsum_zeros_number_ptr = cumsum_zeros_number_ptr + j * down_space;
for (size_t k = 0; k < down_space; k++) {
if (tmp_cumsum_zeros_number_ptr[k] == 1 && first_zero[k] == space) { first_zero[k] = j; }
}
}
// compute along row
std::vector<T> cumsum_buffer(down_space, 0);
for (size_t k = 0; k < down_space; k++) {
auto* tmp_input_down_offset_ptr = input_ptr + k;
auto* tmp_output_down_offset_ptr = output_ptr + k;
auto* tmp_dy_down_offset_ptr = dy_ptr + k;
auto* tmp_cumsum_zero_number_down_offset_ptr = cumsum_zeros_number_ptr + k;

size_t first_zero_index = first_zero[k];
if (first_zero_index == space) { continue; }
auto cumprod_before_first_zero =
first_zero_index == 0
? 1
: *(tmp_output_down_offset_ptr + (first_zero_index - 1) * down_space);
auto cumprod = 1;
for (size_t j = first_zero_index; j < space; j++) {
const size_t ptr_offset = j * down_space;
auto tmp_dy = *(tmp_dy_down_offset_ptr + ptr_offset);
auto tmp_input = *(tmp_input_down_offset_ptr + ptr_offset);
auto tmp_cumsum_zero_number = *(tmp_cumsum_zero_number_down_offset_ptr + ptr_offset);
if (tmp_cumsum_zero_number != 1) { continue; }
if (j != first_zero_index) { cumprod *= tmp_input; }
cumsum_buffer[k] += cumprod_before_first_zero * tmp_dy * cumprod;
for (size_t j = 0; j < down_space; j++) {
auto* cumsum_zeros_number_ptr = j + dx_ptr_base;
size_t first_zero_index = space;
// Find index of first zero in input.
for (size_t k = 0; k < space; k++) {
if (cumsum_zeros_number_ptr[j + k * down_space] == 1) {
first_zero_index = k;
break;
}
}
for (size_t j = 0; j < down_space; j++) {
*(dx_ptr + first_zero[j] * down_space) = cumsum_buffer[j];
// Let z be the index of the first zero element in the input.
// For elements whose index is less than z, the gradient is computed as below:
T reverse_cumsum = 0;
for (size_t k = 0; k < first_zero_index; k++) {
const size_t data_offset = j + (first_zero_index - k - 1) * down_space;
reverse_cumsum += output_ptr_base[data_offset] * dy_ptr_base[data_offset];
dx_ptr_base[data_offset] = reverse_cumsum / input_ptr_base[data_offset];
liufengwei0103 marked this conversation as resolved.
Show resolved Hide resolved
}
// For the element at index z, the gradient is computed as below:
if (first_zero_index == space) { continue; }
T cumprod = 1;
T cumsum = 0;
T cumprod_before_first_zero =
first_zero_index == 0 ? 1 : output_ptr_base[(first_zero_index - 1) * down_space];
for (size_t k = first_zero_index; k < space; k++) {
const size_t data_offset = j + k * down_space;
// Reset dx entries that still hold a temporary zero count (>= 1) back to 0.
if (dx_ptr_base[data_offset] >= 1) { dx_ptr_base[data_offset] = 0; }
if (k != first_zero_index) { cumprod *= input_ptr_base[data_offset]; }
cumsum += cumprod_before_first_zero * dy_ptr_base[data_offset] * cumprod;
}
dx_ptr_base[j + first_zero_index * down_space] = cumsum;
}

input_ptr += step;
output_ptr += step;
dy_ptr += step;
dx_ptr += step;
}
}
} // namespace
Expand Down