Skip to content

Commit

Permalink
infershaped autogen (PR #1), test=develop (PaddlePaddle#39405)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 authored Feb 9, 2022
1 parent 1bd7a14 commit b3e049f
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 32 deletions.
3 changes: 2 additions & 1 deletion paddle/infrt/naive/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ cc_library(infrt_naive SRCS meta_tensor.cc
infershaped/infershaped_kernel_launchers.cc
)

cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt)
cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt)
30 changes: 14 additions & 16 deletions paddle/infrt/naive/infershaped/elementwise_add.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"

// This file contains a example of the infershape ElementwiseAdd kernel.
// Some of the following code should be generated from PTEN by script.
Expand All @@ -32,39 +33,36 @@ static void ElementwiseAddInferShape(const MetaTensor& a,
*c->mutable_shape() = a.shape();
}

static void ElementwiseAdd(const tensor::DenseHostTensor& a,
static void ElementwiseAdd(tensor::DenseHostTensor* /*Context*/,
const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c) {}

// TODO(zhiqiang) This class should be generated by a script offline.
class ElementwiseAddLauncher : public InferShapedKernelLauncher {
template <typename KernelFunc,
KernelFunc kernel,
typename InferShapedFunc,
InferShapedFunc infershape>
class KernelLauncher : public InferShapedKernelLauncher {
public:
static const uint16_t input_tensor_indices[2];
static const uint16_t num_input_tensors{2};
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true};

void Invoke(host_context::KernelFrame* frame) override {
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if (infershape_kernel_frame_builder.IsEmpty()) {
CreateKernelFrameForInferShape(frame);
}
if (turn_on_infer_shape_cache) {
if (IsShapeChanged(input_tensor_indices, num_input_tensors)) {
INFRT_KERNEL(ElementwiseAddInferShape)
(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
if (!turn_on_infer_shape_cache || IsShapeChanged(num_input_tensors)) {
::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
&infershape_kernel_frame_builder);
BuildInferShapeCache(num_input_tensors);
}
} else {
INFRT_KERNEL(ElementwiseAddInferShape)(&infershape_kernel_frame_builder);
BuildInferShapeCache(input_tensor_indices, num_input_tensors);
}

INFRT_KERNEL(ElementwiseAdd)(frame);
::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
}
};

const uint16_t ElementwiseAddLauncher::input_tensor_indices[2] = {0, 1};

} // namespace naive
} // namespace infrt
14 changes: 14 additions & 0 deletions paddle/infrt/naive/infershaped/infershape_launchers_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,24 @@
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"
#include "paddle/infrt/naive/infershaped/infershaped_utils.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace naive {

namespace {
static void ElementwiseAddTest(const tensor::DenseHostTensor& a,
const tensor::DenseHostTensor& b,
tensor::DenseHostTensor* c);
}

TEST(utils, registry) {
constexpr uint8_t count =
InferShapeHelper<decltype(&ElementwiseAddTest)>::count;
CHECK_EQ(count, 2U);
}

TEST(ElementwiseAdd, registry) {
InferShapedKernelRegistry registry;
RegisterInferShapeLaunchers(&registry);
Expand All @@ -35,6 +48,7 @@ TEST(ElementwiseAdd, registry) {
tensor::DenseHostTensor c({2, 8}, GetDType<float>());

host_context::KernelFrameBuilder kernel_frame_builder;
kernel_frame_builder.AddArgument(new host_context::Value(0));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(a)));
kernel_frame_builder.AddArgument(new host_context::Value(std::move(b)));
kernel_frame_builder.SetResults({new host_context::Value(std::move(c))});
Expand Down
17 changes: 7 additions & 10 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace naive {
void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
host_context::KernelFrame* frame) {
for (host_context::Value* value :
frame->GetValues(0, frame->GetNumElements())) {
frame->GetValues(1, frame->GetNumElements() - 1)) {
// TODO(Superjomn) To extend this.
if (value->is_type<tensor::DenseHostTensor>()) {
values.emplace_back(MetaTensor{&value->get<tensor::DenseHostTensor>()});
Expand All @@ -32,27 +32,24 @@ void InferShapedKernelLauncher::CreateKernelFrameForInferShape(
}

void InferShapedKernelLauncher::BuildInferShapeCache(
const uint16_t* input_indices, const uint16_t num_inputs) {
const uint16_t num_inputs) {
tensor_shape_cache.resize(num_inputs);
for (uint16_t i = 0; i < num_inputs; i++) {
tensor_shape_cache[i] =
infershape_kernel_frame_builder.GetArgAt(input_indices[i])
->get<MetaTensor>()
.shape();
infershape_kernel_frame_builder.GetArgAt(i)->get<MetaTensor>().shape();
}
}

bool InferShapedKernelLauncher::IsShapeChanged(
const uint16_t* input_indices, const uint16_t num_inputs) const {
const uint16_t num_inputs) const {
if (tensor_shape_cache.empty() && !infershape_kernel_frame_builder.IsEmpty())
return true;

bool changed = false;
for (uint16_t i = 0; i < num_inputs && !changed; i++) {
changed = changed || (tensor_shape_cache[i] !=
infershape_kernel_frame_builder
.GetArgAt<MetaTensor>(input_indices[i])
.shape());
changed = changed ||
(tensor_shape_cache[i] !=
infershape_kernel_frame_builder.GetArgAt<MetaTensor>(i).shape());
}
return changed;
}
Expand Down
6 changes: 2 additions & 4 deletions paddle/infrt/naive/infershaped/infershaped_kernel_launcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@ struct InferShapedKernelLauncher {

//! Build or update the infer-shape cache using the latest shape from
//! InferShapeFrame.
void BuildInferShapeCache(const uint16_t* input_indices,
const uint16_t num_inputs);
void BuildInferShapeCache(const uint16_t num_inputs);

//! Compare the latest shape with the shape cache.
bool IsShapeChanged(const uint16_t* input_indices,
const uint16_t num_inputs) const;
bool IsShapeChanged(const uint16_t num_inputs) const;

// values to hold the TensorMeta.
llvm::SmallVector<host_context::ValueRef, 3> values;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@
// limitations under the License.

#include "paddle/infrt/naive/infershaped/infershaped_kernel_launchers.h"

#include "paddle/infrt/naive/infershaped/elementwise_add.h"
#include "paddle/infrt/naive/infershaped/infershaped_registry.h"

namespace infrt {
namespace naive {

using ElementwiseAddLauncher =
KernelLauncher<decltype(&ElementwiseAdd),
&ElementwiseAdd,
decltype(&ElementwiseAddInferShape),
&ElementwiseAddInferShape>;

void RegisterInferShapeLaunchers(InferShapedKernelRegistry* registry) {
registry->AddKernel("elementwise_add",
INFERSHAPED_KERNEL_CREATOR(ElementwiseAddLauncher));
Expand Down
77 changes: 77 additions & 0 deletions paddle/infrt/naive/infershaped/infershaped_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include "paddle/infrt/tensor/dense_host_tensor.h"

namespace infrt {
namespace naive {
namespace infershaped {

using KeyType = const tensor::DenseHostTensor&;
using CountType = uint8_t;

constexpr CountType value(std::true_type) { return 1; }

constexpr CountType value(std::false_type) { return 0; }

template <typename T>
constexpr CountType value() {
return value(std::integral_constant<bool, std::is_same<T, KeyType>::value>{});
}

template <typename FirstArg>
constexpr CountType count(CountType num) {
return num;
}

template <typename FirstArg>
constexpr CountType count() {
return 0;
}

template <>
constexpr CountType count<KeyType>(CountType num) {
return num + 1;
}

template <>
constexpr CountType count<KeyType>() {
return 1;
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count(CountType num) {
return count<SecondArg, RestOfArgs...>(num + value<FirstArg>());
}

template <typename FirstArg, typename SecondArg, typename... RestOfArgs>
constexpr CountType count() {
return count<SecondArg, RestOfArgs...>(value<FirstArg>());
}

} // namespace infershaped

template <typename F>
struct InferShapeHelper;

template <typename Return, typename... Args>
struct InferShapeHelper<Return (*)(Args...)> {
static constexpr int count = infershaped::count<Args...>();
};

} // namespace naive
} // namespace infrt

0 comments on commit b3e049f

Please sign in to comment.