-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add op read_file and decode_jpeg (#32564)
* add op read_file and decode_jpeg
- Loading branch information
1 parent
7a73692
commit b22f6d6
Showing
11 changed files
with
607 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <fstream> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/framework/generator.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/platform/dynload/nvjpeg.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class CPUDecodeJpegKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
// TODO(LieLinJiang): add cpu implement. | ||
PADDLE_THROW(platform::errors::Unimplemented( | ||
"DecodeJpeg op only supports GPU now.")); | ||
} | ||
}; | ||
|
||
class DecodeJpegOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DecodeJpeg"); | ||
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "DecodeJpeg"); | ||
|
||
auto mode = ctx->Attrs().Get<std::string>("mode"); | ||
std::vector<int> out_dims; | ||
|
||
if (mode == "unchanged") { | ||
out_dims = {-1, -1, -1}; | ||
} else if (mode == "gray") { | ||
out_dims = {1, -1, -1}; | ||
} else if (mode == "rgb") { | ||
out_dims = {3, -1, -1}; | ||
} else { | ||
PADDLE_THROW(platform::errors::Fatal( | ||
"The provided mode is not supported for JPEG files on GPU: ", mode)); | ||
} | ||
|
||
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType( | ||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); | ||
} | ||
|
||
framework::OpKernelType GetKernelTypeForVar( | ||
const std::string& var_name, const framework::Tensor& tensor, | ||
const framework::OpKernelType& expected_kernel_type) const { | ||
if (var_name == "X") { | ||
return expected_kernel_type; | ||
} | ||
|
||
return framework::OpKernelType(tensor.type(), tensor.place(), | ||
tensor.layout()); | ||
} | ||
}; | ||
|
||
/**
 * Proto maker for decode_jpeg: declares input "X", output "Out", the op
 * comment and the string attribute "mode" (default "unchanged").
 */
class DecodeJpegOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Help texts are kept in named locals; adjacent literals concatenate to
    // exactly the same strings the op has always exposed.
    const char* input_doc =
        "A one dimensional uint8 tensor containing the raw bytes of the "
        "JPEG image. It is a tensor with rank 1.";
    const char* output_doc = "The output tensor of DecodeJpeg op";
    const char* mode_doc =
        "(string, default \"unchanged\"), The read mode used for optionally "
        "converting the image, can be \"unchanged\" ,\"gray\" , \"rgb\" .";

    AddInput("X", input_doc);
    AddOutput("Out", output_doc);
    AddComment(R"DOC(
This operator decodes a JPEG image into a 3 dimensional RGB Tensor
or 1 dimensional Gray Tensor. Optionally converts the image to the
desired format. The values of the output tensor are uint8 between 0
and 255.
)DOC");
    AddAttr<std::string>("mode", mode_doc).SetDefault("unchanged");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OPERATOR( | ||
decode_jpeg, ops::DecodeJpegOp, ops::DecodeJpegOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>) | ||
|
||
REGISTER_OP_CPU_KERNEL(decode_jpeg, ops::CPUDecodeJpegKernel<uint8_t>) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef PADDLE_WITH_HIP | ||
|
||
#include <string> | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/platform/dynload/nvjpeg.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
#include "paddle/fluid/platform/stream/cuda_stream.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
static cudaStream_t nvjpeg_stream = nullptr; | ||
static nvjpegHandle_t nvjpeg_handle = nullptr; | ||
|
||
void InitNvjpegImage(nvjpegImage_t* img) { | ||
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { | ||
img->channel[c] = nullptr; | ||
img->pitch[c] = 0; | ||
} | ||
} | ||
|
||
template <typename T> | ||
class GPUDecodeJpegKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
// Create nvJPEG handle | ||
if (nvjpeg_handle == nullptr) { | ||
nvjpegStatus_t create_status = | ||
platform::dynload::nvjpegCreateSimple(&nvjpeg_handle); | ||
|
||
PADDLE_ENFORCE_EQ(create_status, NVJPEG_STATUS_SUCCESS, | ||
platform::errors::Fatal("nvjpegCreateSimple failed: ", | ||
create_status)); | ||
} | ||
|
||
nvjpegJpegState_t nvjpeg_state; | ||
nvjpegStatus_t state_status = | ||
platform::dynload::nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); | ||
|
||
PADDLE_ENFORCE_EQ(state_status, NVJPEG_STATUS_SUCCESS, | ||
platform::errors::Fatal("nvjpegJpegStateCreate failed: ", | ||
state_status)); | ||
|
||
int components; | ||
nvjpegChromaSubsampling_t subsampling; | ||
int widths[NVJPEG_MAX_COMPONENT]; | ||
int heights[NVJPEG_MAX_COMPONENT]; | ||
|
||
auto* x = ctx.Input<framework::Tensor>("X"); | ||
auto* x_data = x->data<T>(); | ||
|
||
nvjpegStatus_t info_status = platform::dynload::nvjpegGetImageInfo( | ||
nvjpeg_handle, x_data, (size_t)x->numel(), &components, &subsampling, | ||
widths, heights); | ||
|
||
PADDLE_ENFORCE_EQ( | ||
info_status, NVJPEG_STATUS_SUCCESS, | ||
platform::errors::Fatal("nvjpegGetImageInfo failed: ", info_status)); | ||
|
||
int width = widths[0]; | ||
int height = heights[0]; | ||
|
||
nvjpegOutputFormat_t output_format; | ||
int output_components; | ||
|
||
auto mode = ctx.Attr<std::string>("mode"); | ||
if (mode == "unchanged") { | ||
if (components == 1) { | ||
output_format = NVJPEG_OUTPUT_Y; | ||
output_components = 1; | ||
} else if (components == 3) { | ||
output_format = NVJPEG_OUTPUT_RGB; | ||
output_components = 3; | ||
} else { | ||
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state); | ||
PADDLE_THROW(platform::errors::Fatal( | ||
"The provided mode is not supported for JPEG files on GPU")); | ||
} | ||
} else if (mode == "gray") { | ||
output_format = NVJPEG_OUTPUT_Y; | ||
output_components = 1; | ||
} else if (mode == "rgb") { | ||
output_format = NVJPEG_OUTPUT_RGB; | ||
output_components = 3; | ||
} else { | ||
platform::dynload::nvjpegJpegStateDestroy(nvjpeg_state); | ||
PADDLE_THROW(platform::errors::Fatal( | ||
"The provided mode is not supported for JPEG files on GPU")); | ||
} | ||
|
||
nvjpegImage_t out_image; | ||
InitNvjpegImage(&out_image); | ||
|
||
// create nvjpeg stream | ||
if (nvjpeg_stream == nullptr) { | ||
cudaStreamCreateWithFlags(&nvjpeg_stream, cudaStreamNonBlocking); | ||
} | ||
|
||
int sz = widths[0] * heights[0]; | ||
|
||
auto* out = ctx.Output<framework::LoDTensor>("Out"); | ||
std::vector<int64_t> out_shape = {output_components, height, width}; | ||
out->Resize(framework::make_ddim(out_shape)); | ||
|
||
T* data = out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
for (int c = 0; c < output_components; c++) { | ||
out_image.channel[c] = data + c * sz; | ||
out_image.pitch[c] = width; | ||
} | ||
|
||
nvjpegStatus_t decode_status = platform::dynload::nvjpegDecode( | ||
nvjpeg_handle, nvjpeg_state, x_data, x->numel(), output_format, | ||
&out_image, nvjpeg_stream); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_CUDA_KERNEL(decode_jpeg, ops::GPUDecodeJpegKernel<uint8_t>) | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include <fstream> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "paddle/fluid/framework/generator.h" | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/framework/operator.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
class CPUReadFileKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& ctx) const override { | ||
auto filename = ctx.Attr<std::string>("filename"); | ||
|
||
std::ifstream input(filename.c_str(), | ||
std::ios::in | std::ios::binary | std::ios::ate); | ||
std::streamsize file_size = input.tellg(); | ||
|
||
input.seekg(0, std::ios::beg); | ||
|
||
auto* out = ctx.Output<framework::LoDTensor>("Out"); | ||
std::vector<int64_t> out_shape = {file_size}; | ||
out->Resize(framework::make_ddim(out_shape)); | ||
|
||
uint8_t* data = out->mutable_data<T>(ctx.GetPlace()); | ||
|
||
input.read(reinterpret_cast<char*>(data), file_size); | ||
} | ||
}; | ||
|
||
class ReadFileOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, | ||
platform::errors::InvalidArgument( | ||
"Output(Out) of ReadFileOp is null.")); | ||
|
||
auto out_dims = std::vector<int>(1, -1); | ||
ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); | ||
} | ||
|
||
protected: | ||
framework::OpKernelType GetExpectedKernelType( | ||
const framework::ExecutionContext& ctx) const override { | ||
return framework::OpKernelType(framework::proto::VarType::UINT8, | ||
platform::CPUPlace()); | ||
} | ||
}; | ||
|
||
/**
 * Proto maker for read_file: declares output "Out", the op comment and the
 * string attribute "filename", which defaults to an empty string.
 */
class ReadFileOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    const char* output_doc = "The output tensor of ReadFile op";
    const char* filename_doc = "Path of the file to be readed.";

    AddOutput("Out", output_doc);
    AddComment(R"DOC(
This operator read a file.
)DOC");
    AddAttr<std::string>("filename", filename_doc)
        .SetDefault(std::string());
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
|
||
REGISTER_OPERATOR( | ||
read_file, ops::ReadFileOp, ops::ReadFileOpMaker, | ||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>) | ||
|
||
REGISTER_OP_CPU_KERNEL(read_file, ops::CPUReadFileKernel<uint8_t>) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.