From 677c76215b3a5bec936accad134d0780852cdc72 Mon Sep 17 00:00:00 2001 From: huwei02 <53012141+huwei02@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:01:03 +0800 Subject: [PATCH] add unzip op (#183) --- .../ps/table/common_graph_table.cc | 8 + paddle/fluid/framework/data_feed.cc | 8 +- paddle/fluid/framework/data_feed.h | 5 + paddle/fluid/operators/unity_build_rule.cmake | 4 + paddle/fluid/operators/unzip_op.cc | 179 ++++++++++++++++++ paddle/fluid/operators/unzip_op.cu | 110 +++++++++++ paddle/fluid/operators/unzip_op.h | 42 ++++ python/paddle/fluid/contrib/layers/nn.py | 60 ++++++ .../fluid/tests/unittests/test_unzip_op.py | 53 ++++++ 9 files changed, 467 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/unzip_op.cc create mode 100644 paddle/fluid/operators/unzip_op.cu create mode 100644 paddle/fluid/operators/unzip_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_unzip_op.py diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 2b545047c3bfe..c059e1a3ff99e 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -494,6 +494,8 @@ void GraphTable::export_partition_files(int idx, std::string file_path) { for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); } +#endif + void GraphTable::clear_graph(int idx) { for (auto p : edge_shards[idx]) { p->clear(); @@ -506,6 +508,7 @@ void GraphTable::clear_graph(int idx) { } } +#ifdef PADDLE_WITH_HETERPS void GraphTable::release_graph() { // Before releasing graph, prepare for sampling ids and embedding keys. 
build_graph_type_keys(); @@ -545,6 +548,7 @@ void GraphTable::release_graph_node() { feature_shrink_to_fit(); } } +#endif void GraphTable::clear_edge_shard() { VLOG(0) << "begin clear edge shard"; @@ -590,6 +594,7 @@ void GraphTable::clear_feature_shard() { VLOG(0) << "finish clear feature shard"; } +#ifdef PADDLE_WITH_HETERPS void GraphTable::feature_shrink_to_fit() { std::vector> tasks; for (auto &type_shards : feature_shards) { @@ -619,6 +624,8 @@ void GraphTable::merge_feature_shard() { feature_shards.resize(1); } +#endif + void GraphTable::clear_graph() { VLOG(0) << "begin clear_graph"; clear_edge_shard(); @@ -626,6 +633,7 @@ void GraphTable::clear_graph() { VLOG(0) << "finish clear_graph"; } +#ifdef PADDLE_WITH_HETERPS int32_t GraphTable::load_next_partition(int idx) { if (next_partition >= static_cast(partitions[idx].size())) { VLOG(0) << "partition iteration is done"; diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index f1b7d696a4ec0..dd4e3139f9dcf 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -2116,11 +2116,15 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) void SlotRecordInMemoryDataFeed::InitGraphResource() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.AllocResource(thread_id_, feed_vec_); +#endif } void SlotRecordInMemoryDataFeed::InitGraphTrainResource() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.AllocTrainResource(thread_id_); +#endif } #endif @@ -2702,11 +2706,11 @@ int SlotRecordInMemoryDataFeed::Next() { #endif } -#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) void SlotRecordInMemoryDataFeed::DoWalkandSage() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.DoWalkandSage(); -} #endif +} #if 
defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS) void SlotRecordInMemoryDataFeed::BuildSlotBatchGPU(const int ins_num) { diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 77bce79338161..045e1202ffa37 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -1137,6 +1137,11 @@ class DataFeed { virtual void SetDeviceKeys(std::vector* device_keys, int type) { gpu_graph_data_generator_.SetDeviceKeys(device_keys, type); } + virtual bool get_epoch_finish() { +#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) + return gpu_graph_data_generator_.get_epoch_finish(); +#else + return false; #endif virtual void SetGpuGraphMode(int gpu_graph_mode) { diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 891cb40ab28df..f1d629b5c86bd 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -64,6 +64,7 @@ register_unity_group( cudnn_lstm_op.cc cumsum_op.cc cvm_op.cc + unzip_op.cc data_norm_op.cc deformable_conv_op.cc deformable_conv_v1_op.cc @@ -402,6 +403,7 @@ register_unity_group( ctc_align_op.cu cumsum_op.cu cvm_op.cu + unzip_op.cu data_norm_op.cu deformable_conv_op.cu deformable_conv_v1_op.cu @@ -579,3 +581,5 @@ register_unity_group(cu expand_op.cu) register_unity_group(cu matmul_v2_op.cu) register_unity_group(cu top_k_v2_op.cu) register_unity_group(cu set_value_op.cu) +register_unity_group(cu unzip.cu) +register_unity_group(cc unzip.cc) diff --git a/paddle/fluid/operators/unzip_op.cc b/paddle/fluid/operators/unzip_op.cc new file mode 100644 index 0000000000000..2f7c5b468d564 --- /dev/null +++ b/paddle/fluid/operators/unzip_op.cc @@ -0,0 +1,179 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/unzip_op.h" + +#include + +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class unzipOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unzip"); + OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "unzip"); + + auto x_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2UL, + platform::errors::InvalidArgument( + "Input(X)'s rank should be 2, but got %d", x_dims.size())); + + auto lod_dims = ctx->GetInputDim("lod"); + PADDLE_ENFORCE_EQ( + lod_dims.size(), + 1UL, + platform::errors::InvalidArgument( + "Input(lod)'s rank should be 1, but got %d", lod_dims.size())); + + ctx->SetOutputDim("Y", {lod_dims[0] - 1, x_dims[1]}); + } + + protected: + // Explicitly set that the data type of computation kernel of + // unzip + // is determined by its input "X".
+ framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class unzipGradientOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "unzipGradient"); + OP_INOUT_CHECK(ctx->HasInput("lod"), "Input", "lod", "unzipGradient"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), + "Input", + framework::GradVarName("Y"), + "unzipGradient"); + OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), + "Output", + framework::GradVarName("X"), + "unzipGradient"); + + auto x_dims = ctx->GetInputDim("X"); + auto lod_dims = ctx->GetInputDim("lod"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + platform::errors::InvalidArgument( + "Expect Input(X)'s rank == 2, but got %d", x_dims.size())); + PADDLE_ENFORCE_EQ( + dy_dims.size(), + 2, + platform::errors::InvalidArgument( + "Expect Input(Y@Grad)'s rank == 2, but got %d", dy_dims.size())); + PADDLE_ENFORCE_EQ( + lod_dims.size(), + 1, + platform::errors::InvalidArgument( + "Expect Input(lod)'s rank == 1, but got %d", lod_dims.size())); + + PADDLE_ENFORCE_EQ( + x_dims[1], + dy_dims[1], + platform::errors::InvalidArgument( + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal, X is %d, Y@Grad is %d", + x_dims[1], + dy_dims[1])); + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of + // unzip_grad + // is determined by its input GradVarName("Y").
+ framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Y")), + ctx.device_context()); + } +}; + +class unzipOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[M x N]," + " where M is the number of non-empty lod segments and N is the embedding dim. "); + AddInput("lod", + "(Tensor), a 1-D Tensor with shape [K]"); + AddOutput("Y", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[K-1 x N]."); + AddComment(R"DOC( +unzip Operator. +)DOC"); + } +}; + +template +class unzipGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("unzip_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("lod", this->Input("lod")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +// The forward kernel reads "lod" values, so unzip has no no-need-buffer inputs. +DECLARE_NO_NEED_BUFFER_VARS_INFERER(unzipGradNoNeedBufferVarInferer, "X"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unzip, + ops::unzipOp, + ops::unzipOpMaker, + /* no NoNeedBufferVarsInferer here: the forward kernel dereferences "lod" */ + ops::unzipGradOpMaker, + ops::unzipGradOpMaker); + +REGISTER_OPERATOR(unzip_grad, + ops::unzipGradientOp, + ops::unzipGradNoNeedBufferVarInferer); + +REGISTER_OP_CPU_KERNEL(unzip, ops::unzipOpKernel, ops::unzipOpKernel); + +REGISTER_OP_CPU_KERNEL(unzip_grad, + ops::unzipGradOpKernel, + ops::unzipGradOpKernel); diff --git a/paddle/fluid/operators/unzip_op.cu b/paddle/fluid/operators/unzip_op.cu new file mode 100644 index
0000000000000..375eb9f1016aa --- /dev/null +++ b/paddle/fluid/operators/unzip_op.cu @@ -0,0 +1,110 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/operators/unzip_op.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" + +namespace paddle { +namespace operators { + +using platform::PADDLE_CUDA_NUM_THREADS; +using LoDTensor = framework::LoDTensor; + +template /* Fills Y (n elements, col_size per row): row r copies the next unused row of X when lod[r + 1] - lod[r] > 0, otherwise it is set to 0. */ +__global__ void unzipKernel(const T* X, + const LodType* lod, + T* Y, + size_t col_size, + size_t n) { + CUDA_KERNEL_LOOP(i, n) { + int lod_idx = i / col_size; /* output row that element i belongs to */ + if ((lod[lod_idx + 1] - lod[lod_idx]) > 0) { + assert((lod[lod_idx + 1] - lod[lod_idx]) == col_size); + int x_idx = 0; /* count of non-empty segments before lod_idx, used as the source row in X */ + for (int j = 0; j < lod_idx; ++j) { + if ((lod[j + 1] - lod[j]) > 0) { + x_idx++; + } + } + Y[i] = X[x_idx * col_size + (i % col_size)]; + } else { + Y[i] = 0; /* empty segment: the whole output row is zero */ + } + } +} + +template /* Op kernel for "unzip": reads inputs X and lod, allocates output Y and launches unzipKernel on the context's stream. */ +class unzipCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* x = context.Input("X"); + const T* x_data = x->data(); + + const auto* lod = context.Input("lod"); + const LodType* lod_data = lod->data(); + + auto col_size = x->dims()[1]; + auto row_size = lod->dims()[0] - 1; /* K offsets delimit K-1 output rows */ + auto y_numel = col_size * row_size; + + auto* y = context.Output("Y"); + T* y_data =
y->mutable_data(context.GetPlace()); + + // Input X carries no lod information of its own; the offsets come from input "lod". + auto stream = context.template device_context().stream(); + unzipKernel<<<(y_numel + PADDLE_CUDA_NUM_THREADS - 1) / + PADDLE_CUDA_NUM_THREADS, + PADDLE_CUDA_NUM_THREADS, + 0, + stream>>>( + x_data, lod_data, y_data, col_size, y_numel); + } +}; + +template /* Gradient kernel placeholder: Compute always throws Unimplemented. */ +class unzipGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(phi::errors::Unimplemented("unzip_grad is unimplemented")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(unzip, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel, + ops::unzipCUDAKernel); + +REGISTER_OP_CUDA_KERNEL(unzip_grad, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel, + ops::unzipGradCUDAKernel); diff --git a/paddle/fluid/operators/unzip_op.h b/paddle/fluid/operators/unzip_op.h new file mode 100644 index 0000000000000..3ab14f2cd038d --- /dev/null +++ b/paddle/fluid/operators/unzip_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class unzipOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(phi::errors::Unimplemented("unzip is unimplemented")); + } +}; + +template +class unzipGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_THROW(phi::errors::Unimplemented("unzip_grad is unimplemented")); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py index abb6e1ed887eb..a48add72e2a5b 100644 --- a/python/paddle/fluid/contrib/layers/nn.py +++ b/python/paddle/fluid/contrib/layers/nn.py @@ -62,6 +62,7 @@ 'correlation', 'fused_bn_add_act', 'fused_seqpool_cvm', + 'unzip' ] @@ -2155,3 +2156,62 @@ def pow2_decay_with_linear_warmup( }, ) return lr + + +def unzip(input, lod): + r""" + + **unzip layers** + + unzip 'input' according to 'lod' + + Args: + input (Variable): The zipped input, 2-D LodTensor with shape [N, M]. + lod (Variable): The original lod of unzipped input, 1-D LodTensor with shape[K]. + + Returns: + Variable: The original unzipped tensor, 2-D LodTensor with shape[K-1, M]. + + Examples: + + ..
code-block:: python + + import paddle.fluid as fluid + input_np = np.array([ + [1.0, 2.0, 3.0, 4.0], + [10.0, 20.0, 30.0, 40.0], + [100.0, 200.0, 300.0, 400.0] + ]) + lod_np = np.array([0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12]) + input = paddle.to_tensor(input_np, "float64") + lod = paddle.to_tensor(lod_np, "int64") + + unzipped_input = fluid.contrib.layers.unzip(input, lod) + ''' + unzipped_input is [ + [1.0, 2.0, 3.0, 4.0], + [0.0, 0.0, 0.0, 0.0], + [10.0, 20.0, 30.0, 40.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [100.0, 200.0, 300.0, 400.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0] + ] + ''' + """ + helper = LayerHelper('unzip', **locals()) + out = helper.create_variable(dtype=input.dtype) + check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64', 'int32', 'bool', 'int64'], + 'unzip') + check_variable_and_dtype(lod, 'lod', ['int32', 'int64'], + 'unzip') + helper.append_op(type='unzip', + inputs={ + 'X': [input], + 'lod': [lod] + }, + outputs={'Y': [out]}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_unzip_op.py b/python/paddle/fluid/tests/unittests/test_unzip_op.py new file mode 100644 index 0000000000000..c93e6ea5639f0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unzip_op.py @@ -0,0 +1,53 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import numpy as np +from math import log +from math import exp +from op_test import OpTest +import unittest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + +class TestUnzipOp(OpTest): + """ + Test unzip op: each row of X fills the output row of a non-empty lod segment, and empty segments produce all-zero rows. + """ + + def setUp(self): + self.op_type = "unzip" + self.__class__.op_type = "unzip" + self.__class__.no_need_check_grad = True + + input = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0], [100.0, 200.0, 300.0, 400.0]] + lod = [0, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12] + self.inputs = { + 'X': np.array(input).astype("float64"), + 'lod': np.array(lod).astype("int32") + } + out = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0], [10.0, 20.0, 30.0, 40.0], [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [100.0, 200.0, 300.0, 400.0], [0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]] + self.outputs = {'Y': np.array(out, dtype=float)} + + def test_check_output(self): + paddle.enable_static() + if core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + self.check_output(place) + + +if __name__ == '__main__': + unittest.main()