Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick 2.1]Add op read_file and decode_jpeg #32686

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "correlation_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
Expand Down
114 changes: 114 additions & 0 deletions paddle/fluid/operators/decode_jpeg_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

template <typename T>
class CPUDecodeJpegKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// TODO(LieLinJiang): add cpu implement.
PADDLE_THROW(platform::errors::Unimplemented(
"DecodeJpeg op only supports GPU now."));
}
};

class DecodeJpegOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DecodeJpeg");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DecodeJpeg");

auto mode = ctx->Attrs().Get<std::string>("mode");
std::vector<int> out_dims;

if (mode == "unchanged") {
out_dims = {-1, -1, -1};
} else if (mode == "gray") {
out_dims = {1, -1, -1};
} else if (mode == "rgb") {
out_dims = {3, -1, -1};
} else {
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU: ", mode));
}

ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const framework::Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
if (var_name == "X") {
return expected_kernel_type;
}

return framework::OpKernelType(tensor.type(), tensor.place(),
tensor.layout());
}
};

class DecodeJpegOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"A one dimensional uint8 tensor containing the raw bytes "
"of the JPEG image. It is a tensor with rank 1.");
AddOutput("Out", "The output tensor of DecodeJpeg op");
AddComment(R"DOC(
This operator decodes a JPEG image into a 3 dimensional RGB Tensor
or 1 dimensional Gray Tensor. Optionally converts the image to the
desired format. The values of the output tensor are uint8 between 0
and 255.
)DOC");
AddAttr<std::string>(
"mode",
"(string, default \"unchanged\"), The read mode used "
"for optionally converting the image, can be \"unchanged\" "
",\"gray\" , \"rgb\" .")
.SetDefault("unchanged");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
decode_jpeg, ops::DecodeJpegOp, ops::DecodeJpegOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

REGISTER_OP_CPU_KERNEL(decode_jpeg, ops::CPUDecodeJpegKernel<uint8_t>)
138 changes: 138 additions & 0 deletions paddle/fluid/operators/decode_jpeg_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PADDLE_WITH_HIP

#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/dynload/nvjpeg.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"

namespace paddle {
namespace operators {

static cudaStream_t nvjpeg_stream = nullptr;
static nvjpegHandle_t nvjpeg_handle = nullptr;

void InitNvjpegImage(nvjpegImage_t* img) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
img->channel[c] = nullptr;
img->pitch[c] = 0;
}
}

template <typename T>
class GPUDecodeJpegKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Create nvJPEG handle
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status =
platform::dynload::nvjpegCreateSimple(&nvjpeg_handle);

PADDLE_ENFORCE_EQ(create_status, NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegCreateSimple failed: ",
create_status));
}

nvjpegJpegState_t nvjpeg_state;
nvjpegStatus_t state_status =
platform::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);

PADDLE_ENFORCE_EQ(state_status, NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegJpegStateCreate failed: ",
state_status));

int components;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];

auto* x = ctx.Input<framework::Tensor>("X");
auto* x_data = x->data<T>();

nvjpegStatus_t info_status = platform::dynload::nvjpegGetImageInfo(
nvjpeg_handle, x_data, (size_t)x->numel(), &components, &subsampling,
widths, heights);

PADDLE_ENFORCE_EQ(
info_status, NVJPEG_STATUS_SUCCESS,
platform::errors::Fatal("nvjpegGetImageInfo failed: ", info_status));

int width = widths[0];
int height = heights[0];

nvjpegOutputFormat_t output_format;
int output_components;

auto mode = ctx.Attr<std::string>("mode");
if (mode == "unchanged") {
if (components == 1) {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (components == 3) {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}
} else if (mode == "gray") {
output_format = NVJPEG_OUTPUT_Y;
output_components = 1;
} else if (mode == "rgb") {
output_format = NVJPEG_OUTPUT_RGB;
output_components = 3;
} else {
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state);
PADDLE_THROW(platform::errors::Fatal(
"The provided mode is not supported for JPEG files on GPU"));
}

nvjpegImage_t out_image;
InitNvjpegImage(&out_image);

// create nvjpeg stream
if (nvjpeg_stream == nullptr) {
cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking);
}

int sz = widths[0] * heights[0];

auto* out = ctx.Output<framework::LoDTensor>("Out");
std::vector<int64_t> out_shape = {output_components, height, width};
out->Resize(framework::make_ddim(out_shape));

T* data = out->mutable_data<T>(ctx.GetPlace());

for (int c = 0; c < output_components; c++) {
out_image.channel[c] = data + c * sz;
out_image.pitch[c] = width;
}

nvjpegStatus_t decode_status = platform::dynload::nvjpegDecode(
nvjpeg_handle, nvjpeg_state, x_data, x->numel(), output_format,
&out_image, nvjpeg_stream);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(decode_jpeg, ops::GPUDecodeJpegKernel<uint8_t>)

#endif
92 changes: 92 additions & 0 deletions paddle/fluid/operators/read_file_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <fstream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {

template <typename T>
class CPUReadFileKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto filename = ctx.Attr<std::string>("filename");

std::ifstream input(filename.c_str(),
std::ios::in | std::ios::binary | std::ios::ate);
std::streamsize file_size = input.tellg();

input.seekg(0, std::ios::beg);

auto* out = ctx.Output<framework::LoDTensor>("Out");
std::vector<int64_t> out_shape = {file_size};
out->Resize(framework::make_ddim(out_shape));

uint8_t* data = out->mutable_data<T>(ctx.GetPlace());

input.read(reinterpret_cast<char*>(data), file_size);
}
};

class ReadFileOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of ReadFileOp is null."));

auto out_dims = std::vector<int>(1, -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::UINT8,
platform::CPUPlace());
}
};

class ReadFileOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The output tensor of ReadFile op");
AddComment(R"DOC(
This operator read a file.
)DOC");
AddAttr<std::string>("filename", "Path of the file to be readed.")
.SetDefault({});
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
read_file, ops::ReadFileOp, ops::ReadFileOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)

REGISTER_OP_CPU_KERNEL(read_file, ops::CPUReadFileKernel<uint8_t>)
2 changes: 1 addition & 1 deletion paddle/fluid/platform/dynload/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)

list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc)
list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc nvjpeg.cc)

if (WITH_ROCM)
list(APPEND HIP_SRCS rocblas.cc miopen.cc hiprand.cc)
Expand Down
Loading