Commit
Commit message:

* add cumsum op's forward definition
* add cumsum forward test case
* cumsum ver3
* remove calculating time
* add cumsum forward gpu implementation
* fix gpu forward error
* change var name
* remove annotation
* add cumsum cpu forward multi-thread support
* add multi-thread annotation
* add cumsum grad definition
* update
* add cumsum cpu backward
* add cumsum cpu backward functor
* add cumsum autograd
* update
* remove user interface
* use random method to test cumsum forward
* add cumsum gpu backward
* add cumsum gpu test
* fix gpu backward bug
* add a 3d cuda kernel try
* Revert "add cumsum gpu test" (This reverts commit 05c31556ba28ecb827b25e54c2f5fa38984e8096.)
* Revert "Revert "add cumsum gpu test"" (This reverts commit 918ee1569863b008c1d419c3528257416cffd840.)
* change nele to ele_cnt
* add test_cumsum.py in oneflow/test/modules
* change original test_cumsum to autotest version
* optimize cumsum for special up_space and down_space
* add two special cu func
* add cumsum doc
* update doc
* update doc
* update code according to bbuf's review
* ditto
* change pin/pout to in_ptr/out_ptr
* remove multi-thread func
* update doc
* use tensor processor
* update by review
* update by review
* update
* update
* auto format by CI
* auto format by CI
* update doc
* update

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
1 parent eabe79e · commit 9869a3f
Showing 10 changed files with 679 additions and 75 deletions.
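For reference, cumsum along a dimension replaces each element with the sum of all elements up to and including it: y[i] = x[0] + x[1] + ... + x[i]. A minimal 1-D sketch of that definition (plain C++, not OneFlow code), using std::partial_sum:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<double> x{1, 2, 3, 4};
  std::vector<double> y(x.size());
  // y[i] = x[0] + x[1] + ... + x[i]
  std::partial_sum(x.begin(), x.end(), y.begin());
  for (double v : y) { std::printf("%g ", v); }  // prints: 1 3 6 10
  return 0;
}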
@@ -152,5 +152,6 @@ oneflow
    decode_onerec,
    read_onerec,
    from_numpy,
+   cumsum,

.. autofunction:: oneflow.relu
@@ -0,0 +1,64 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

struct CumsumCaptureState : public AutoGradCaptureState {
  bool requires_grad = false;
  int64_t dim = 0;
};

class CumsumGrad : public OpExprGradFunction<CumsumCaptureState> {
 public:
  Maybe<void> Init(const OpExpr& op) override {
    const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
    CHECK_NOTNULL_OR_RETURN(fw_op_expr);
    base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
    return Maybe<void>::Ok();
  }

  Maybe<void> Capture(CumsumCaptureState* ctx, const TensorTuple& inputs,
                      const TensorTuple& outputs, const AttrMap& attrs) const override {
    CHECK_EQ_OR_RETURN(inputs.size(), 1);
    ctx->requires_grad = inputs.at(0)->requires_grad();
    if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

    ComposedAttrMap composed_attrs(attrs, base_attrs_);
    ctx->dim = JUST(composed_attrs.GetAttr<int64_t>("dim"));
    return Maybe<void>::Ok();
  }

  Maybe<void> Apply(const CumsumCaptureState* ctx, const TensorTuple& out_grads,
                    TensorTuple* in_grads) const override {
    CHECK_EQ_OR_RETURN(out_grads.size(), 1);
    in_grads->resize(1);
    if (ctx->requires_grad) {
      in_grads->at(0) = JUST(functional::CumsumGrad(out_grads.at(0), ctx->dim));
    }
    return Maybe<void>::Ok();
  }

 private:
  AttrMap base_attrs_;
};

REGISTER_OP_EXPR_GRAD_FUNCTION("cumsum", CumsumGrad);

}  // namespace one
}  // namespace oneflow
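Mathematically, for y = cumsum(x) along a dimension of extent n, each input x[j] feeds every output y[i] with i >= j, so the gradient is dx[j] = dy[j] + dy[j+1] + ... + dy[n-1]: a cumulative sum of the incoming gradient taken in reverse. A minimal 1-D sketch of that identity (independent of functional::CumsumGrad, whose CPU implementation registered for "cumsum_grad" appears in the kernel file below):

#include <cstdio>
#include <numeric>
#include <vector>

// Reverse cumulative sum: dx[j] = dy[j] + dy[j+1] + ... + dy[n-1].
std::vector<double> cumsum_grad_1d(const std::vector<double>& dy) {
  std::vector<double> dx(dy.size());
  std::partial_sum(dy.rbegin(), dy.rend(), dx.rbegin());
  return dx;
}

int main() {
  std::vector<double> dy{1, 1, 1, 1};  // e.g. the gradient flowing from sum(y)
  for (double v : cumsum_grad_1d(dy)) { std::printf("%g ", v); }  // prints: 4 3 2 1
  return 0;
}

Note that cumsum_backward in this commit computes (cs_space - j) * dy[j], which agrees with this identity when the incoming gradient is constant along dim (as in the dy = 1 case above).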
@@ -0,0 +1,129 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"

namespace oneflow {

namespace {
template<typename T>
void cumsum_forward(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space,
                    int64_t cs_down_space, int64_t elem_cnt) {
  std::copy_n(in_ptr, elem_cnt, out_ptr);
  auto* tmp_out_ptr_base = out_ptr;
  auto step = cs_space * cs_down_space;
  for (auto i = 0; i < cs_up_space; i++) {
    for (auto j = 1; j < cs_space; j++) {
      auto* tmp_out_ptr = tmp_out_ptr_base + j * cs_down_space;
      auto* last_tmp_out_ptr = tmp_out_ptr - cs_down_space;
      for (auto k = 0; k < cs_down_space; k++) { tmp_out_ptr[k] += last_tmp_out_ptr[k]; }
    }
    tmp_out_ptr_base += step;
  }
}

template<typename T>
void cumsum_backward(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space,
                     int64_t cs_down_space, int64_t elem_cnt) {
  auto* tmp_in_ptr_base = in_ptr;
  auto* tmp_out_ptr_base = out_ptr;
  auto step = cs_space * cs_down_space;
  for (auto i = 0; i < cs_up_space; i++) {
    for (auto j = 0; j < cs_space; j++) {
      auto* tmp_in_ptr = tmp_in_ptr_base + j * cs_down_space;
      auto* tmp_out_ptr = tmp_out_ptr_base + j * cs_down_space;
      std::fill_n(tmp_out_ptr, cs_down_space, cs_space - j);
      for (auto k = 0; k < cs_down_space; k++) { tmp_out_ptr[k] *= tmp_in_ptr[k]; }
    }
    tmp_in_ptr_base += step;
    tmp_out_ptr_base += step;
  }
}
}  // namespace

template<typename T>
class CpuCumsumKernel final : public user_op::OpKernel {
 public:
  CpuCumsumKernel() = default;
  ~CpuCumsumKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const auto* in = ctx->Tensor4ArgNameAndIndex("in", 0);
    auto elem_cnt = in->shape().elem_cnt();
    // judge whether tensor has 0 size dimension first
    if (!elem_cnt) { return; }

    auto* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    auto dim = ctx->Attr<int64_t>("dim");
    const auto* in_ptr = in->dptr<T>();
    auto* out_ptr = out->mut_dptr<T>();

    // take cumsum's abbreviation as `cs`
    // data partition: cs_up_space|cs_space|cs_down_space
    auto cs_up_space = elem_cnt / in->shape().Count(dim);
    auto cs_space = in->shape().At(dim);
    auto cs_down_space = in->shape().Count(dim + 1);

    cumsum_forward<T>(in_ptr, out_ptr, cs_up_space, cs_space, cs_down_space, elem_cnt);
  }

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_CUMSUM_KERNEL(dtype)                                                   \
  REGISTER_USER_KERNEL("cumsum").SetCreateFn<CpuCumsumKernel<dtype>>().SetIsMatchedHob( \
      (user_op::HobDeviceType() == DeviceType::kCPU)                                    \
      && (user_op::HobDataType("out", 0) == GetDataType<dtype>::value));

REGISTER_CUMSUM_KERNEL(int64_t)
REGISTER_CUMSUM_KERNEL(float)
REGISTER_CUMSUM_KERNEL(double)

template<typename T>
class CpuCumsumGradKernel final : public user_op::OpKernel {
 public:
  CpuCumsumGradKernel() = default;
  ~CpuCumsumGradKernel() = default;

 private:
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
    auto elem_cnt = dy->shape().elem_cnt();
    auto dim = ctx->Attr<int64_t>("dim");
    const auto* dy_ptr = dy->dptr<T>();
    auto* dx_ptr = dx->mut_dptr<T>();

    // data partition: cs_up_space|cs_space|cs_down_space
    auto cs_up_space = elem_cnt / dx->shape().Count(dim);
    auto cs_space = dx->shape().At(dim);
    auto cs_down_space = dx->shape().Count(dim + 1);

    cumsum_backward(dy_ptr, dx_ptr, cs_up_space, cs_space, cs_down_space, elem_cnt);
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_CPU_CUMSUM_GRAD_KERNEL(dtype)                                           \
  REGISTER_USER_KERNEL("cumsum_grad")                                                    \
      .SetCreateFn<CpuCumsumGradKernel<dtype>>()                                         \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU)                    \
                       && (user_op::HobDataType("dx", 0) == GetDataType<dtype>::value));

REGISTER_CPU_CUMSUM_GRAD_KERNEL(float)
REGISTER_CPU_CUMSUM_GRAD_KERNEL(double)

}  // namespace oneflow
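The kernels above index an N-D tensor through three flattened factors around the cumsum axis: cs_up_space (product of the dimensions before dim), cs_space (the extent of dim itself), and cs_down_space (product of the dimensions after dim). A minimal standalone sketch of the same partition on a flat buffer, mirroring the loop structure of cumsum_forward:

#include <cstdio>
#include <vector>

// Cumulative sum over the middle factor of an (up, space, down) layout,
// matching the loop structure of cumsum_forward above.
void cumsum_flat(const std::vector<float>& in, std::vector<float>& out,
                 long up, long space, long down) {
  out = in;  // start from a copy of the input, like std::copy_n in the kernel
  for (long i = 0; i < up; ++i) {
    float* base = out.data() + i * space * down;
    for (long j = 1; j < space; ++j) {
      // add the previous slice along dim into the current one
      for (long k = 0; k < down; ++k) { base[j * down + k] += base[(j - 1) * down + k]; }
    }
  }
}

int main() {
  // A 2x3 tensor, cumsum along dim=1: up=2, space=3, down=1.
  std::vector<float> in{1, 2, 3, 4, 5, 6}, out;
  cumsum_flat(in, out, 2, 3, 1);
  for (float v : out) { std::printf("%g ", v); }  // prints: 1 3 6 4 9 15
  return 0;
}

Because each of the cs_up_space slices is independent, they could be processed in parallel, which is the CPU multi-thread support the commit history mentions adding and later removing.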