diff --git a/oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp b/oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp
new file mode 100644
index 00000000000..852c2b9cd2f
--- /dev/null
+++ b/oneflow/core/autograd/gradient_funcs/noncontiguous_binary_op.cpp
@@ -0,0 +1,91 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include <string>
+#include "oneflow/core/common/just.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/framework/op_expr_grad_function.h"
+#include "oneflow/core/framework/op_builder.h"
+#include "oneflow/core/framework/op_expr.h"
+#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
+#include "oneflow/core/functional/functional.h"
+#include "oneflow/core/functional/functional_api.yaml.h"
+
+namespace oneflow {
+namespace one {
+
+struct NonContiguousBinaryOpCaptureState : public AutoGradCaptureState {
+  bool lhs_requires_grad = false;
+  bool rhs_requires_grad = false;
+  std::string op = "add";
+  bool inplace = false;
+};
+
+class NonContiguousBinaryOp : public OpExprGradFunction<NonContiguousBinaryOpCaptureState> {
+ public:
+  Maybe<void> Init(const OpExpr& op) override;
+  Maybe<void> Capture(NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& inputs,
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
+  Maybe<void> Apply(const NonContiguousBinaryOpCaptureState* ctx, const TensorTuple& out_grads,
+                    TensorTuple* in_grads) const override;
+
+ private:
+  AttrMap base_attrs_;
+};
+
+Maybe<void> NonContiguousBinaryOp::Init(const OpExpr& op) {
+  const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
+  CHECK_NOTNULL_OR_RETURN(fw_op_expr);  // NOLINT(maybe-need-error-msg)
+  base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> NonContiguousBinaryOp::Capture(NonContiguousBinaryOpCaptureState* ctx,
+                                           const TensorTuple& inputs, const TensorTuple& outputs,
+                                           const AttrMap& attrs) const {
+  ctx->lhs_requires_grad = inputs.at(0)->requires_grad();
+  ctx->rhs_requires_grad = inputs.at(1)->requires_grad();
+  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }
+
+  ComposedAttrMap composed_attrs(attrs, base_attrs_);
+  ctx->inplace = JUST(composed_attrs.GetAttr<bool>("inplace"));
+  ctx->op = JUST(composed_attrs.GetAttr<std::string>("op"));
+  if (ctx->inplace && ctx->rhs_requires_grad) {
+    CHECK_OR_RETURN(ctx->op == "add" || ctx->op == "sub")
+        << "when inplace and rhs requires grad, op should be add/sub";
+  }
+  ctx->SaveTensorForBackward(inputs.at(0));
+  ctx->SaveTensorForBackward(inputs.at(1));
+  return Maybe<void>::Ok();
+}
+
+Maybe<void> NonContiguousBinaryOp::Apply(const NonContiguousBinaryOpCaptureState* ctx,
+                                         const TensorTuple& out_grads,
+                                         TensorTuple* in_grads) const {
+  if (!ctx->lhs_requires_grad && !ctx->rhs_requires_grad) { return Maybe<void>::Ok(); }
+  CHECK_EQ_OR_RETURN(out_grads.size(), 1);  // NOLINT(maybe-need-error-msg)
+  in_grads->resize(2);
+  auto lhs = ctx->SavedTensors().at(0);
+  auto rhs = ctx->SavedTensors().at(1);
+  auto ret =
+      JUST(functional::NonContiguousBinaryOpGrad(out_grads.at(0), lhs, rhs, ctx->op, false));
+  if (ctx->lhs_requires_grad) in_grads->at(0) = ret->at(0);
+  if (ctx->rhs_requires_grad) in_grads->at(1) = ret->at(1);
+  return Maybe<void>::Ok();
+}
+
+REGISTER_OP_EXPR_GRAD_FUNCTION("noncontiguous_binary_op", NonContiguousBinaryOp);
+
+}  // namespace one
+}  // namespace oneflow
diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index 7b13c3e33ef..71ef2673ac7 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -2696,6 +2696,14 @@
   signature: 'Tensor (Tensor y, Tensor dy, Float scale=0.35355) => FusedScaleMaskBiasSoftmaxGrad'
   bind_python: False
 
+- name: "noncontiguous_binary_op"
+  signature: 'Tensor (Tensor lhs, Tensor rhs, String op="add", Bool inplace=False) => NonContiguousBinaryOp'
+  bind_python: True
+
+- name: "noncontiguous_binary_op_grad"
+  signature: 'TensorTuple (Tensor dy, Tensor lhs, Tensor rhs, String op="add", Bool inplace=False) => NonContiguousBinaryOpGrad'
+  bind_python: False
+
 - name: "fused_get_center_dist"
   signature: "Tensor (Tensor b1_x1, Tensor b1_x2, Tensor b2_x1, Tensor b2_x2, Tensor b1_y1, Tensor b1_y2, Tensor b2_y1, Tensor b2_y2) => FusedCenter"
   bind_python: True
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index 8fb71e103d3..79a9dc2593b 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -5345,6 +5345,55 @@ class FusedScaleMaskBiasSoftmaxGradFunctor {
   std::shared_ptr<OpExpr> op_;
 };
 
+class NonContiguousBinaryOpFunctor {
+ public:
+  NonContiguousBinaryOpFunctor() {
+    op_ = CHECK_JUST(
+        one::OpBuilder("noncontiguous_binary_op").Input("lhs").Input("rhs").Output("y").Build());
+  }
+
+  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& lhs, const std::shared_ptr<Tensor>& rhs,
+                           const std::string& op, const bool& inplace = false) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("op", "inplace");
+    attrs.SetAllAttrs(op, inplace);
+    if (inplace) {
+      std::shared_ptr<TensorTuple> outputs = std::make_shared<TensorTuple>(1);
+      outputs->at(0) = lhs;
+      JUST(OpInterpUtil::Dispatch(*op_, {lhs, rhs}, outputs.get(), attrs));
+      return outputs->at(0);
+    }
+    return OpInterpUtil::Dispatch<Tensor>(*op_, {lhs, rhs}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
+class NonContiguousBinaryOpGradFunctor {
+ public:
+  NonContiguousBinaryOpGradFunctor() {
+    op_ = CHECK_JUST(one::OpBuilder("noncontiguous_binary_op_grad")
+                         .Input("dy")
+                         .Input("lhs")
+                         .Input("rhs")
+                         .Output("dlhs")
+                         .Output("drhs")
+                         .Build());
+  }
+
+  Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& dy,
+                                const std::shared_ptr<Tensor>& lhs,
+                                const std::shared_ptr<Tensor>& rhs, const std::string& op,
+                                const bool& inplace = false) const {
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("op", "inplace");
+    attrs.SetAllAttrs(op, inplace);
+    return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, lhs, rhs}, attrs);
+  }
+
+ private:
+  std::shared_ptr<OpExpr> op_;
+};
+
 }  // namespace impl
 
 ONEFLOW_FUNCTION_LIBRARY(m) {
@@ -5478,6 +5527,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
   m.add_functor<impl::SkipRMSNormFunctor>("SkipRMSNorm");
   m.add_functor<impl::FusedScaleMaskBiasSoftmaxFunctor>("FusedScaleMaskBiasSoftmax");
   m.add_functor<impl::FusedScaleMaskBiasSoftmaxGradFunctor>("FusedScaleMaskBiasSoftmaxGrad");
+  m.add_functor<impl::NonContiguousBinaryOpFunctor>("NonContiguousBinaryOp");
+  m.add_functor<impl::NonContiguousBinaryOpGradFunctor>("NonContiguousBinaryOpGrad");
   m.add_functor<impl::MultiTensorYoloV5WeightUpdateFunctor>("MultiTensorYoloV5WeightUpdate");
 }
diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
index db1ea5bfb44..b549b767a63 100644
--- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td
+++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -3336,6 +3336,44 @@ def OneFlow_FusedCodegeexQkvReshapeOp : OneFlow_BaseOp<"fused_codegeex_qkv_resha
   let has_data_type_infer_fn = 1;
 }
 
+def OneFlow_NonContiguousBinaryOp : OneFlow_BaseOp<"noncontiguous_binary_op", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$lhs,
+    OneFlow_Tensor:$rhs
+  );
+  let output = (outs
+    OneFlow_Tensor:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr<StrAttr, "\"add\"">:$op,
+    DefaultValuedAttr<BoolAttr, "false">:$inplace
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_NonContiguousBinaryOpGrad : OneFlow_BaseOp<"noncontiguous_binary_op_grad", [NoMemoryEffect, SupportNonContiguous, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    OneFlow_Tensor:$dy,
+    OneFlow_Tensor:$lhs,
+    OneFlow_Tensor:$rhs
+  );
+  let output = (outs
+    OneFlow_Tensor:$dlhs,
+    OneFlow_Tensor:$drhs
+  );
+  let attrs = (ins
+    DefaultValuedAttr<StrAttr, "\"add\"">:$op,
+    DefaultValuedAttr<BoolAttr, "false">:$inplace
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
 #endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS
diff --git a/oneflow/user/kernels/noncontiguous_binary_op.cu b/oneflow/user/kernels/noncontiguous_binary_op.cu
new file mode 100644
index 00000000000..90e3e4ffa49
--- /dev/null
+++ b/oneflow/user/kernels/noncontiguous_binary_op.cu
@@ -0,0 +1,470 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/device/cuda_util.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/user_op_tensor.h"
+#include "oneflow/core/kernel/kernel_util.h"
+#include "oneflow/core/ep/include/primitive/fast_integer_math.h"
+#include "oneflow/core/cuda/elementwise.cuh"
+
+namespace oneflow {
+
+namespace {
+
+#define MaxDims 6
+#define MAX2(a, b) ((a) > (b)) ? (a) : (b)
+#define MAX3(a, b, c) MAX2(MAX2(a, b), c)
+
+using cuda::elementwise::Packed;
+
+#define DEFINE_BINARY_FUNCTOR(OP, expr)                                                          \
+  template<typename T>                                                                           \
+  struct OP {                                                                                    \
+    __device__ __forceinline__ T operator()(const T& a, const T& b) const { return a expr b; }   \
+  };                                                                                             \
+  template<>                                                                                     \
+  struct OP<half> {                                                                              \
+    __device__ __forceinline__ half operator()(const half& a, const half& b) const {             \
+      return __float2half(__half2float(a) expr __half2float(b));                                 \
+    }                                                                                            \
+  };
+
+DEFINE_BINARY_FUNCTOR(Add, +)
+DEFINE_BINARY_FUNCTOR(Sub, -)
+DEFINE_BINARY_FUNCTOR(Mul, *)
+DEFINE_BINARY_FUNCTOR(Div, /)
+#undef DEFINE_BINARY_FUNCTOR
+
+#define DEFINE_BINARY_OP_GRAD_FUNCTOR(OP, dl_expr, dr_expr)                                       \
+  template<typename T>                                                                            \
+  struct OP##Grad {                                                                               \
+    __device__ __forceinline__ void operator()(const T& dout, const T& a, const T& b, T* da,      \
+                                               T* db) const {                                     \
+      *da = dl_expr dout;                                                                         \
+      *db = dr_expr dout;                                                                         \
+    }                                                                                             \
+  };                                                                                              \
+  template<>                                                                                      \
+  struct OP##Grad<half> {                                                                         \
+    __device__ __forceinline__ void operator()(const half& hdout, const half& ha, const half& hb, \
+                                               half* hda, half* hdb) const {                      \
+      float dout, a, b;                                                                           \
+      dout = __half2float(hdout), a = __half2float(ha), b = __half2float(hb);                     \
+      *hda = __float2half(dl_expr dout);                                                          \
+      *hdb = __float2half(dr_expr dout);                                                          \
+    }                                                                                             \
+  };
+
+DEFINE_BINARY_OP_GRAD_FUNCTOR(Add, 1 *, 1 *)
+DEFINE_BINARY_OP_GRAD_FUNCTOR(Sub, 1 *, -1 *)
+DEFINE_BINARY_OP_GRAD_FUNCTOR(Mul, b*, a*)
+DEFINE_BINARY_OP_GRAD_FUNCTOR(Div, 1 / b*, -a / b / b*)
+#undef DEFINE_BINARY_OP_GRAD_FUNCTOR
+
+template<typename BinaryOp, int pack_size, typename R, typename lhs, typename rhs,
+         typename IndexType, typename Store, typename Loader1, typename Loader2>
+__global__ void noncontiguous_binary_op_kernel(IndexType n_pack, Store y, Loader1 x1, Loader2 x2) {
+  Packed<R, pack_size> pack_y;
+  Packed<lhs, pack_size> pack_x1;
+  Packed<rhs, pack_size> pack_x2;
+  CUDA_1D_KERNEL_LOOP_T(IndexType, i, n_pack) {
+    x1.load(i, &pack_x1);
+    x2.load(i, &pack_x2);
+#pragma unroll
+    for (int j = 0; j < pack_size; ++j)
+      pack_y.elem[j] = BinaryOp()(static_cast<R>(pack_x1.elem[j]),
+                                  static_cast<R>(pack_x2.elem[j]));  // todo: Apply2
+    y.store(i, &pack_y);
+  }
+};
+
+template<typename IndexType, typename Src, typename Dst, int pack_size>
+struct LoadStore {
+  LoadStore(FastIntegerMath<IndexType> fast_integer_math[MaxDims], const int ndims,
+            const int strides[MaxDims], const Src* src, Dst* dst = nullptr,
+            bool is_contiguous = false)
+      : ndims_(ndims), src_(src), dst_(dst), is_contiguous_(is_contiguous) {
+    for (int i = 0; i < ndims; i++) {
+      strides_[i] = static_cast<IndexType>(strides[i]);
+      fast_integer_math_[i] = fast_integer_math[i];
+    }
+  }
+
+  OF_DEVICE_FUNCTION IndexType index2offset(IndexType index) {
+    IndexType offset = 0;
+    IndexType div = 0, mod = 0;
+#pragma unroll
+    for (int dim = ndims_ - 1; dim >= 0; --dim) {
+      if (index == 0) break;
+      fast_integer_math_[dim].divmod(index, &div, &mod);
+      index = div;
+      offset += mod * strides_[dim];
+    }
+    return offset;
+  }
+
+  OF_DEVICE_FUNCTION void load(IndexType idx, Packed<Src, pack_size>* pack) {
+    IndexType offset;
+    if (is_contiguous_)
+      offset = idx * pack_size;
+    else
+      offset = index2offset(idx);
+    *pack = *(reinterpret_cast<const Packed<Src, pack_size>*>(src_ + offset));
+  }
+
+  OF_DEVICE_FUNCTION void store(IndexType idx, Packed<Dst, pack_size>* pack) {
+    IndexType offset;
+    if (is_contiguous_)
+      offset = idx * pack_size;
+    else
+      offset = index2offset(idx);
+    *(reinterpret_cast<Packed<Dst, pack_size>*>(dst_ + offset)) = *pack;
+  }
+
+  int ndims_;
+  int pack_dim_;
+  bool is_contiguous_;
+  const Src* src_;
+  Dst* dst_;
+  IndexType strides_[MaxDims];
+  FastIntegerMath<IndexType> fast_integer_math_[MaxDims];
+};
+
+template<int pack_size, typename BinaryOp, typename R, typename lhs, typename rhs,
+         typename IndexType, typename Store, typename Load1, typename Load2>
+void launch_noncontiguous_binary_op_kernel(cudaStream_t stream, const IndexType n_pack,
+                                           Store& store, Load1& load1, Load2& load2) {
+  int num_blocks = 1, block_size = cuda::elementwise::kBlockSize;
+  cudaError_t err = cuda::elementwise::GetNumBlocks(n_pack, &num_blocks);
+  CHECK(err == cudaSuccess);
+  noncontiguous_binary_op_kernel<BinaryOp, pack_size, R, lhs, rhs>
+      <<<num_blocks, block_size, 0, stream>>>(n_pack, store, load1, load2);
+}
+
+template<int pack_size, typename R, typename lhs, typename rhs, typename IndexType,
+         typename Store, typename Load1, typename Load2>
+void dispatchOp(cudaStream_t stream, const std::string& op, const IndexType n_pack, Store& store,
+                Load1& load1, Load2& load2) {
+  if (op == "add")
+    launch_noncontiguous_binary_op_kernel<pack_size, Add<R>, R, lhs, rhs>(
+        stream, n_pack, store, load1, load2);
+  else if (op == "sub")
+    launch_noncontiguous_binary_op_kernel<pack_size, Sub<R>, R, lhs, rhs>(
+        stream, n_pack, store, load1, load2);
+  else if (op == "mul")
+    launch_noncontiguous_binary_op_kernel<pack_size, Mul<R>, R, lhs, rhs>(
+        stream, n_pack, store, load1, load2);
+  else if (op == "div")
+    launch_noncontiguous_binary_op_kernel<pack_size, Div<R>, R, lhs, rhs>(
+        stream, n_pack, store, load1, load2);
+  else
+    UNIMPLEMENTED_THEN_THROW();
+}
+
+template<int pack_size, typename IndexType, typename R, typename lhs, typename rhs>
+void dispatchInplace(cudaStream_t stream, const bool inplace, const std::string& op,
+                     const int& ndims, const IndexType n_pack, const int sizes[MaxDims],
+                     const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) {
+  typedef FastIntegerMath<IndexType> FastIntegerMathT;
+  FastIntegerMathT fast_integer_math[MaxDims];
+  for (int i = 0; i < ndims; ++i) fast_integer_math[i] = FastIntegerMathT(sizes[i]);
+  if (inplace) {
+    LoadStore<IndexType, lhs, R, pack_size> load_store(fast_integer_math, ndims,
+                                                       strides[0], x1, y);
+    LoadStore<IndexType, rhs, rhs, pack_size> loader2(fast_integer_math, ndims,
+                                                      strides[2], x2);
+    dispatchOp<pack_size, R, lhs, rhs>(stream, op, n_pack, load_store, load_store,
+                                       loader2);
+  } else {
+    LoadStore<IndexType, R, R, pack_size> store(fast_integer_math, ndims,
+                                                strides[0], nullptr, y);
+    LoadStore<IndexType, lhs, lhs, pack_size> loader1(fast_integer_math, ndims,
+                                                      strides[1], x1);
+
+    LoadStore<IndexType, rhs, rhs, pack_size> loader2(fast_integer_math, ndims,
+                                                      strides[2], x2);
+    dispatchOp<pack_size, R, lhs, rhs>(stream, op, n_pack, store, loader1, loader2);
+  }
+}
+
+template<int pack_size, typename R, typename lhs, typename rhs>
+void dispatchIndexType(cudaStream_t stream, const bool inplace, const std::string& op,
+                       const int& ndims, const int64_t& n_pack, const int sizes[MaxDims],
+                       const int strides[][MaxDims], R* y, const lhs* x1, const rhs* x2) {
+  if ((n_pack * pack_size) >> 30 == 0) {
+    int32_t n = (int32_t)n_pack;
+    dispatchInplace<pack_size, int32_t, R, lhs, rhs>(stream, inplace, op, ndims, n, sizes, strides,
+                                                     y, x1, x2);
+  } else
+    dispatchInplace<pack_size, int64_t, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes,
+                                                     strides, y, x1, x2);
+}
+
+template<typename R, typename lhs, typename rhs>
+void dispatchPacksize(cudaStream_t stream, const bool inplace, const std::string& op,
+                      const int& ndims, const int64_t n_pack, int pack_size,
+                      const int sizes[MaxDims], const int strides[][MaxDims], R* y, const lhs* x1,
+                      const rhs* x2) {
+  if (pack_size == 8)
+    dispatchIndexType<8, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,
+                                      x2);
+  else if (pack_size == 4)
+    dispatchIndexType<4, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,
+                                      x2);
+  else if (pack_size == 2)
+    dispatchIndexType<2, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,
+                                      x2);
+  else if (pack_size == 1)
+    dispatchIndexType<1, R, lhs, rhs>(stream, inplace, op, ndims, n_pack, sizes, strides, y, x1,
+                                      x2);
+  else
+    UNIMPLEMENTED();
+}
+}  // namespace
+
+template<typename R, typename lhs, typename rhs>
+class NonContiguousBinaryOpKernel final : public user_op::OpKernel {
+ public:
+  NonContiguousBinaryOpKernel() = default;
+  ~NonContiguousBinaryOpKernel() override = default;
+
+ private:
+  using user_op::OpKernel::Compute;
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
+    const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("lhs", 0);
+    const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("rhs", 0);
+    const std::string op = ctx->Attr<std::string>("op");
ctx->Attr("op"); + const bool inplace = ctx->Attr("inplace"); + int ndims = y->shape_view().NumAxes(); + const ShapeView& shape = y->shape_view(); + int sizes[MaxDims]; + int strides[3][MaxDims]; + + int pack_size = 1; + int64_t elem_cnt = 1; + int max_elem_size = MAX3(GetSizeOfDataType(y->data_type()), GetSizeOfDataType(x1->data_type()), + GetSizeOfDataType(x2->data_type())); + for (int i = 0; i < ndims; ++i) { + sizes[i] = shape.At(i); + elem_cnt *= shape.At(i); + strides[0][i] = y->stride()[i]; + strides[1][i] = x1->stride()[i]; + strides[2][i] = x2->stride()[i]; + if (x1->stride()[i] == 1 && x2->stride()[i] == 1 && y->stride()[i] == 1) { + pack_size = 16 / max_elem_size; + while (pack_size > 1 && sizes[i] % pack_size) pack_size >>= 1; + sizes[i] = sizes[i] / pack_size; + strides[0][i] *= pack_size; + strides[1][i] *= pack_size; + strides[2][i] *= pack_size; + } + } + + dispatchPacksize(ctx->stream()->As()->cuda_stream(), inplace, op, ndims, + elem_cnt / pack_size, pack_size, sizes, strides, y->mut_dptr(), + x1->dptr(), x2->dptr()); + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(dtype, lhs, rhs) \ + REGISTER_USER_KERNEL("noncontiguous_binary_op") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("y", 0) == GetDataType::value) \ + && (user_op::HobDataType("lhs", 0) == GetDataType::value) \ + && (user_op::HobDataType("rhs", 0) == GetDataType::value)); + +// output_type, lhs_type, rhs_type +REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(float, float, float) +REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(half, half, half) +// #if CUDA_VERSION >= 11000 +// REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_KERNEL(nv_bfloat16, nv_bfloat16, nv_bfloat16) +// #endif + +// ------------------------------------- grad kernel ------------------------------------- +template +__global__ void noncontiguous_binary_op_grad_kernel(IndexType n_pack, Loadery dy, Loader1 load1, + Loader2 load2) { + Packed pack_dy; + Packed pack_x1; + Packed pack_x2; + Packed pack_dx1; + Packed pack_dx2; + CUDA_1D_KERNEL_LOOP_T(IndexType, i, n_pack) { + load1.load(i, &pack_x1); + load2.load(i, &pack_x2); + dy.load(i, &pack_dy); +#pragma unroll + for (int j = 0; j < pack_size; ++j) + BinaryOp()(pack_dy.elem[j], pack_x1.elem[j], pack_x2.elem[j], &pack_dx1.elem[j], + &pack_dx2.elem[j]); // todo: Apply2 + load1.store(i, &pack_dx1); + load2.store(i, &pack_dx2); + } +}; + +template +void launch_noncontiguous_binary_op_grad_kernel(cudaStream_t stream, const IndexType n_pack, + Loady& load_y, Load1& load1, Load2& load2) { + int num_blocks = 1, block_size = cuda::elementwise::kBlockSize; + cudaError_t err = cuda::elementwise::GetNumBlocks(n_pack, &num_blocks); + CHECK(err == cudaSuccess); + noncontiguous_binary_op_grad_kernel + <<>>(n_pack, load_y, load1, load2); +} + +template +void dispatchOpGrad(cudaStream_t stream, const std::string& op, const IndexType& n_pack, + Loady& load_y, Load1& load1, Load2& load2) { + if (op == "add") + launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( + stream, n_pack, load_y, load1, load2); + else if (op == "sub") + launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( + stream, n_pack, load_y, load1, load2); + else if (op == "mul") + launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( + stream, n_pack, load_y, load1, load2); + else if (op == "div") + launch_noncontiguous_binary_op_grad_kernel, R, lhs, rhs>( + 
+  else
+    UNIMPLEMENTED_THEN_THROW();
+}
+
+template<int pack_size, typename IndexType, typename R, typename lhs, typename rhs>
+void dispatchLoader(cudaStream_t stream, const std::string& op, const int& ndims,
+                    const IndexType n_pack, const int sizes[MaxDims], const int strides[][MaxDims],
+                    lhs* dx1, rhs* dx2, const R* dy, const lhs* x1, const rhs* x2) {
+  typedef FastIntegerMath<IndexType> FastIntegerMathT;
+  FastIntegerMathT fast_integer_math[MaxDims];
+  for (int i = 0; i < ndims; ++i) fast_integer_math[i] = FastIntegerMathT(sizes[i]);
+  LoadStore<IndexType, R, R, pack_size> load_y(fast_integer_math, ndims,
+                                               strides[0], dy);
+  LoadStore<IndexType, lhs, lhs, pack_size> loader_store1(
+      fast_integer_math, ndims, strides[1], x1, dx1);
+
+  LoadStore<IndexType, rhs, rhs, pack_size> loader_store2(
+      fast_integer_math, ndims, strides[2], x2, dx2);
+  dispatchOpGrad<pack_size, R, lhs, rhs>(stream, op, n_pack, load_y, loader_store1,
+                                         loader_store2);
+}
+
+template<int pack_size, typename R, typename lhs, typename rhs>
+void dispatchIndexTypeGrad(cudaStream_t stream, const std::string& op, const int& ndims,
+                           const int64_t& n_pack, const int sizes[MaxDims],
+                           const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy,
+                           const lhs* x1, const rhs* x2) {
+  if ((n_pack * pack_size) >> 30 == 0) {
+    int32_t n = (int32_t)n_pack;
+    dispatchLoader<pack_size, int32_t, R, lhs, rhs>(stream, op, ndims, n, sizes, strides, dx1, dx2,
+                                                    dy, x1, x2);
+  } else
+    dispatchLoader<pack_size, int64_t, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1,
+                                                    dx2, dy, x1, x2);
+}
+
+template<typename R, typename lhs, typename rhs>
+void dispatchPacksizeGrad(cudaStream_t stream, const std::string& op, const int& ndims,
+                          const int64_t& n_pack, int& pack_size, const int sizes[MaxDims],
+                          const int strides[][MaxDims], lhs* dx1, rhs* dx2, const R* dy,
+                          const lhs* x1, const rhs* x2) {
+  if (pack_size == 8)
+    dispatchIndexTypeGrad<8, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,
+                                          x1, x2);
+  else if (pack_size == 4)
+    dispatchIndexTypeGrad<4, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,
+                                          x1, x2);
+  else if (pack_size == 2)
+    dispatchIndexTypeGrad<2, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,
+                                          x1, x2);
+  else if (pack_size == 1)
+    dispatchIndexTypeGrad<1, R, lhs, rhs>(stream, op, ndims, n_pack, sizes, strides, dx1, dx2, dy,
+                                          x1, x2);
+  else
+    UNIMPLEMENTED();
+}
+
+template<typename R, typename lhs, typename rhs>
+class NonContiguousBinaryOpGradKernel final : public user_op::OpKernel {
+ public:
+  NonContiguousBinaryOpGradKernel() = default;
+  ~NonContiguousBinaryOpGradKernel() override = default;
+
+ private:
+  using user_op::OpKernel::Compute;
+  void Compute(user_op::KernelComputeContext* ctx) const override {
+    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
+    const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("lhs", 0);
+    const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("rhs", 0);
+    user_op::Tensor* dx1 = ctx->Tensor4ArgNameAndIndex("dlhs", 0);
+    user_op::Tensor* dx2 = ctx->Tensor4ArgNameAndIndex("drhs", 0);
+    const std::string op = ctx->Attr<std::string>("op");
+    const bool inplace = ctx->Attr<bool>("inplace");
+    CHECK(inplace == false) << "inplace should be set to `false` to compute gradients.";
+    int ndims = dy->shape_view().NumAxes();
+    const ShapeView& shape = dy->shape_view();
+    int sizes[MaxDims];
+    int strides[3][MaxDims];
+
+    int pack_size = 1;
+    int64_t elem_cnt = 1;
+    int max_elem_size = MAX3(GetSizeOfDataType(dy->data_type()), GetSizeOfDataType(x1->data_type()),
+                             GetSizeOfDataType(x2->data_type()));
+    for (int i = 0; i < ndims; ++i) {
+      sizes[i] = shape.At(i);
+      elem_cnt *= shape.At(i);
+      strides[0][i] = dy->stride()[i];
+      strides[1][i] = x1->stride()[i];
+      strides[2][i] = x2->stride()[i];
+      if (x1->stride()[i] == 1 && x2->stride()[i] == 1 && dy->stride()[i] == 1) {
+        pack_size = 16 / max_elem_size;
+        while (pack_size > 1 && sizes[i] % pack_size) pack_size >>= 1;
+        sizes[i] = sizes[i] / pack_size;
+        strides[0][i] *= pack_size;
+        strides[1][i] *= pack_size;
+        strides[2][i] *= pack_size;
+      }
+    }
+
+    dispatchPacksizeGrad(ctx->stream()->As<ep::CudaStream>()->cuda_stream(), op, ndims,
+                         elem_cnt / pack_size, pack_size, sizes, strides, dx1->mut_dptr<lhs>(),
+                         dx2->mut_dptr<rhs>(), dy->dptr<R>(), x1->dptr<lhs>(), x2->dptr<rhs>());
+  }
+
+  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
+};
+
+#define REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(dtype, lhs, rhs)          \
+  REGISTER_USER_KERNEL("noncontiguous_binary_op_grad")                                     \
+      .SetCreateFn<NonContiguousBinaryOpGradKernel<dtype, lhs, rhs>>()                     \
+      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \
+                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)     \
+                       && (user_op::HobDataType("lhs", 0) == GetDataType<lhs>::value)      \
+                       && (user_op::HobDataType("rhs", 0) == GetDataType<rhs>::value));
+
+// output_type, lhs_type, rhs_type
+REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(float, float, float)
+REGISTER_USER_KERNEL_NONCONTIGUOUS_BINARY_OP_GRAD_KERNEL(half, half, half)
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/noncontiguous_binary_op.cpp b/oneflow/user/ops/noncontiguous_binary_op.cpp
new file mode 100644
index 00000000000..86f0e381dfc
--- /dev/null
+++ b/oneflow/user/ops/noncontiguous_binary_op.cpp
@@ -0,0 +1,95 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/common/data_type.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/common/shape.h"
+#include "oneflow/core/common/shape_view.h"
+#include "oneflow/core/common/stride.h"
+#include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
+
+namespace oneflow {
+
+/*static*/ Maybe<void> NonContiguousBinaryOp::GetSbp(user_op::SbpContext* ctx) {
+  // only support broadcast
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("lhs", 0))
+      .Broadcast(user_op::OpArg("rhs", 0))
+      .Broadcast(user_op::OpArg("y", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> NonContiguousBinaryOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& lhs = ctx->InputShape("lhs", 0);
+  const Shape& rhs = ctx->InputShape("rhs", 0);
+  CHECK_EQ(lhs.NumAxes(), rhs.NumAxes());
+  for (int i = 0; i < lhs.NumAxes(); i++) CHECK_EQ(lhs.At(i), rhs.At(i));
+  ctx->SetOutputShape("y", 0, lhs);
+  const bool inplace = ctx->Attr<bool>("inplace");
+  if (inplace) {
+    ctx->SetOutputStride("y", 0, ctx->InputStride("lhs", 0));
+  } else {  // set contiguous for y if not inplace
+    ctx->SetOutputStride("y", 0, Stride(lhs));
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> NonContiguousBinaryOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> NonContiguousBinaryOp::InferDataType(user_op::InferContext* ctx) {
+  auto lhs = ctx->InputDType("lhs", 0);
+  auto rhs = ctx->InputDType("rhs", 0);
+  ctx->SetOutputDType("y", 0, GetSizeOfDataType(lhs) >= GetSizeOfDataType(rhs) ? lhs : rhs);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> NonContiguousBinaryOpGrad::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("lhs", 0))
+      .Broadcast(user_op::OpArg("rhs", 0))
+      .Broadcast(user_op::OpArg("dy", 0))
+      .Broadcast(user_op::OpArg("dlhs", 0))
+      .Broadcast(user_op::OpArg("drhs", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> NonContiguousBinaryOpGrad::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  const Shape& lhs = ctx->InputShape("lhs", 0);
+  const Shape& rhs = ctx->InputShape("rhs", 0);
+  CHECK_EQ(lhs.NumAxes(), rhs.NumAxes());
+  for (int i = 0; i < lhs.NumAxes(); i++) CHECK_EQ(lhs.At(i), rhs.At(i));
+  ctx->SetOutputShape("dlhs", 0, lhs);
+  ctx->SetOutputStride("dlhs", 0, ctx->InputStride("lhs", 0));
+  ctx->SetOutputShape("drhs", 0, rhs);
+  ctx->SetOutputStride("drhs", 0, ctx->InputStride("rhs", 0));
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> NonContiguousBinaryOpGrad::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> NonContiguousBinaryOpGrad::InferDataType(user_op::InferContext* ctx) {
+  auto lhs = ctx->InputDType("lhs", 0);
+  auto rhs = ctx->InputDType("rhs", 0);
+  ctx->SetOutputDType("dlhs", 0, lhs);
+  ctx->SetOutputDType("drhs", 0, rhs);
+  return Maybe<void>::Ok();
+}
+
+}  // namespace oneflow
diff --git a/python/oneflow/test/modules/test_noncontiguous_binary_op.py b/python/oneflow/test/modules/test_noncontiguous_binary_op.py
new file mode 100644
index 00000000000..783c6762aff
--- /dev/null
+++ b/python/oneflow/test/modules/test_noncontiguous_binary_op.py
@@ -0,0 +1,83 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import unittest
+from collections import OrderedDict
+
+import numpy as np
+import oneflow as flow
+from oneflow.test_utils.test_util import GenArgList
+
+
+def _test_op(test_case, x, y, inplace):
+    ref1 = x + y
+    out1 = flow._C.noncontiguous_binary_op(x, y, op="add", inplace=inplace)
+    test_case.assertTrue(np.allclose(ref1.numpy(), out1.numpy(), rtol=1e-5, atol=1e-5))
+
+    ref2 = x - y
+    out2 = flow._C.noncontiguous_binary_op(x, y, op="sub", inplace=inplace)
+    test_case.assertTrue(np.allclose(ref2.numpy(), out2.numpy(), rtol=1e-5, atol=1e-5))
+
+    ref3 = x * y
+    out3 = flow._C.noncontiguous_binary_op(x, y, op="mul", inplace=inplace)
+    test_case.assertTrue(np.allclose(ref3.numpy(), out3.numpy(), rtol=1e-5, atol=1e-5))
+
+    y = y.abs() + 1e-3  # avoid division by zero
+    ref4 = x / y
+    out4 = flow._C.noncontiguous_binary_op(x, y, op="div", inplace=inplace)
+    print(np.abs(ref4.numpy() - out4.numpy()).max())
+    test_case.assertTrue(np.allclose(ref4.numpy(), out4.numpy(), rtol=1e-3, atol=1e-3))
+
+
+def _test_noncontiguous_binary_op(test_case, dtype, pack_size, ndims, inplace):
+    shape = []
+    for _ in range(ndims - 1):
+        if np.random.uniform(-1, 1) > 0:
+            shape.append(1 << np.random.randint(4, 7))
+        else:
+            shape.append(np.random.randint(20, 100))
+    shape.append(1 << np.random.randint(3, 7) + pack_size)
+    # case 1
+    x = flow.randn(*shape, requires_grad=True).cuda().to(dtype)
+    y = flow.randn(*shape, requires_grad=True).cuda().to(dtype)
+    d1, d2 = np.random.choice(ndims, 2, replace=False)
+    x1 = x.transpose(d1, d2)
+    y1 = y.transpose(d1, d2)
+    _test_op(test_case, x1, y1, inplace)
+
+    # case 2
+    y2 = flow.randn(*shape, requires_grad=True).cuda().to(dtype)
+    shape[d1], shape[d2] = shape[d2], shape[d1]
+    x = flow.randn(*shape, requires_grad=True).cuda().to(dtype)
+    x2 = x.transpose(d1, d2)
+    _test_op(test_case, x2, y2, inplace)
+
+
+@unittest.skipIf(True, "skip test for noncontiguous_binary_op.")
+@flow.unittest.skip_unless_1n1d()
+class TestNonContiguousBinaryOp(flow.unittest.TestCase):
+    def test_noncontiguous_binary_op(test_case):
+        arg_dict = OrderedDict()
+        arg_dict["test_fn"] = [_test_noncontiguous_binary_op]
+        arg_dict["dtype"] = [flow.float16, flow.float32]
+        arg_dict["pack_size"] = [1, 2, 4]
+        arg_dict["ndims"] = [2, 3, 4]
+        arg_dict["inplace"] = [True, False]
+        for arg in GenArgList(arg_dict):
+            arg[0](test_case, *arg[1:])
+
+
+if __name__ == "__main__":
+    unittest.main()
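
Usage sketch: the functional binding added in functional_api.yaml is exposed to Python as flow._C.noncontiguous_binary_op, so it can be exercised directly, mirroring the calls in the test above. This is a minimal example, assuming a CUDA device and two same-shape float tensors made non-contiguous by transpose; the shapes are illustrative only.

    import numpy as np
    import oneflow as flow

    # Two tensors of identical shape; transpose makes them non-contiguous in memory.
    x = flow.randn(4, 8, 16).cuda().transpose(1, 2)
    y = flow.randn(4, 8, 16).cuda().transpose(1, 2)

    # Fused elementwise add that reads the non-contiguous operands directly,
    # instead of materializing contiguous copies first.
    out = flow._C.noncontiguous_binary_op(x, y, op="add", inplace=False)

    # Reference result via the ordinary elementwise path.
    ref = x + y
    assert np.allclose(out.numpy(), ref.numpy(), rtol=1e-5, atol=1e-5)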