polish kernel factory and kernel registry
chenwhql committed Oct 21, 2021
1 parent 76a588e commit fb224ab
Showing 6 changed files with 54 additions and 71 deletions.
25 changes: 4 additions & 21 deletions paddle/fluid/framework/operator.cc
@@ -1080,20 +1080,6 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
   this->InferShape(&infer_shape_ctx);
 }
 
-static std::string RuntimeContextDebugString(const RuntimeContext& ctx) {
-  std::stringstream ss;
-  ss << "RuntimeContext(Inputs: ";
-  for (auto& var_pair : ctx.inputs) {
-    ss << var_pair.first << ", ";
-  }
-  ss << "Outputs: ";
-  for (auto& var_pair : ctx.outputs) {
-    ss << var_pair.first << ", ";
-  }
-  ss << ")";
-  return ss.str();
-}
-
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
   // To reduce the elapsed time of HasAttr, we use bool variable to record the
@@ -1144,7 +1130,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // and ROCM backend, the XPU, NPU and MKLDNN will be supported in the second
   // phase
   if (FLAGS_run_pt_kernel &&
-      pten::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
+      pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
     if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
       ChoosePtenKernel(exe_ctx);
     }
@@ -1651,10 +1637,9 @@ void OperatorWithKernel::ParseInputDataType(
     if (t != nullptr) {
       PADDLE_ENFORCE_EQ(
           t->IsInitialized(), true,
-          platform::errors::InvalidArgument(
-              "The Tensor in the %s Op's Input Variable %s(%s) is "
-              "not initialized.",
-              Type(), name, Inputs().at(name).at(i)));
+          platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
+                                            "contains uninitialized Tensor.",
+                                            Type(), name));
       proto::VarType::Type tmp = t->type();
       PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
                      platform::errors::InvalidArgument(
@@ -1789,8 +1774,6 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
 
 pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
     const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
-  VLOG(1) << RuntimeContextDebugString(ctx);
-
   // TODO(chenweihang): now only work for very simple case,
   // many cases need to be deal with later:
   // 1. the input and output are not tensor
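
Editor's note on the RunImpl hunk above: the pten kernel signature and kernel are cached on the operator (hence the pt_kernel_signature_.get() == nullptr check), so the factory query only gates selection on the first run. A minimal standalone sketch of this lazy, cache-once selection pattern follows; the types and names are illustrative stand-ins, not Paddle's:

#include <iostream>
#include <memory>
#include <string>

// Toy stand-in for pten's KernelSignature.
struct KernelSignature {
  std::string name;
};

class ToyOp {
 public:
  void Run() {
    // Choose the new-style kernel on first use only, then reuse the cache.
    if (pt_kernel_signature_ == nullptr) {
      pt_kernel_signature_.reset(new KernelSignature{"sign"});
      std::cout << "chose kernel: " << pt_kernel_signature_->name << "\n";
    }
    std::cout << "run " << pt_kernel_signature_->name << "\n";
  }

 private:
  std::unique_ptr<KernelSignature> pt_kernel_signature_;
};

int main() {
  ToyOp op;
  op.Run();  // selects and runs
  op.Run();  // reuses the cached selection
}
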
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/prepared_operator.cc
@@ -153,7 +153,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
 
   if (FLAGS_run_pt_kernel &&
-      pten::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
+      pten::KernelFactory::Instance().HasCompatiblePtenKernel(op.Type())) {
     auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx);
 
     VLOG(1) << framework::KernelSignatureToString(pt_kernel_signature);
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/op_function_generator.cc
@@ -557,7 +557,7 @@ GenerateOpFunctions() {
     // since only OperatorWithKernel can run in dygraph mode.
     // if the pten lib contains op kernel, we still generate ops method
     if (!all_kernels.count(op_type) &&
-        !pten::KernelFactory::Instance().ContainsKernel(op_type.c_str())) {
+        !pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
       continue;
     }
 
18 changes: 13 additions & 5 deletions paddle/pten/core/kernel_factory.cc
@@ -19,16 +19,24 @@
 
 namespace pten {
 
+uint32_t KernelKey::Hash::operator()(const KernelKey& key) const {
+  uint32_t hash_value = 0;
+  // |----31-20------|---19-12---|---11-8----|---7-0---|
+  // | For extension | DataType | DataLayout | Backend |
+  hash_value |= static_cast<uint8_t>(key.backend());
+  hash_value |=
+      (static_cast<uint8_t>(key.layout()) << KernelKey::kBackendBitLength);
+  hash_value |=
+      (static_cast<uint16_t>(key.dtype())
+       << (KernelKey::kBackendBitLength + KernelKey::kDataTypeBitLength));
+  return hash_value;
+}
+
 KernelFactory& KernelFactory::Instance() {
   static KernelFactory g_op_kernel_factory;
   return g_op_kernel_factory;
 }
 
-bool KernelFactory::ContainsKernel(const char* kernel_name) const {
-  auto iter = kernels_.find(KernelName(kernel_name, ""));
-  return (iter != kernels_.end());
-}
-
 Kernel KernelFactory::SelectKernel(const KernelName& kernel_name,
                                    const KernelKey& kernel_key) const {
   auto iter = kernels_.find(kernel_name);
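
As a sanity check on the new Hash::operator() above, here is a self-contained sketch that packs a key the same way. The field widths follow the diff's bit-layout comment (8-bit Backend in bits 0-7, 4-bit DataLayout in bits 8-11, DataType from bit 12); the enum values and the 4-bit layout width are assumptions for illustration, since only kBackendBitLength = 8 is visible in this diff:

#include <cstdint>
#include <iostream>

// Made-up enum values; only the packing scheme mirrors the diff above.
enum class Backend : uint8_t { UNDEFINED = 0, CPU = 1, CUDA = 2 };
enum class DataLayout : uint8_t { UNDEFINED = 0, NCHW = 1 };
enum class DataType : uint16_t { UNDEFINED = 0, FLOAT32 = 9 };

constexpr int kBackendBitLength = 8;
constexpr int kDataLayoutBitLength = 4;  // assumed; value not shown in the diff

uint32_t HashKernelKey(Backend backend, DataLayout layout, DataType dtype) {
  uint32_t hash_value = 0;
  hash_value |= static_cast<uint8_t>(backend);
  hash_value |= static_cast<uint8_t>(layout) << kBackendBitLength;
  hash_value |= static_cast<uint16_t>(dtype)
                << (kBackendBitLength + kDataLayoutBitLength);
  return hash_value;
}

int main() {
  // CUDA (0x02) | NCHW (0x1 << 8) | FLOAT32 (0x9 << 12)  ->  prints 9102
  std::cout << std::hex
            << HashKernelKey(Backend::CUDA, DataLayout::NCHW,
                             DataType::FLOAT32)
            << "\n";
}
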
77 changes: 34 additions & 43 deletions paddle/pten/core/kernel_factory.h
@@ -17,6 +17,7 @@
 #include <ostream>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 
 #include "paddle/pten/common/backend.h"
@@ -37,10 +38,10 @@ using DataLayout = paddle::experimental::DataLayout;
 /**
  * [ Naming considerations ]
  *
- * The tensor Compute library contains many kernels, and the computation
+ * The tensor operation library contains many kernels, and the computation
  * in each specific scenario is represented by a kernel.
  *
- * We directly named it `Kernel` instead of `OpKernel`, the tensor Compute
+ * We directly named it `Kernel` instead of `OpKernel`, the tensor operation
  * library here and fluid are independent, avoiding developers from
  * misunderstanding the relationship between the two concepts.
  */
@@ -52,10 +53,7 @@ using KernelFn = void (*)(KernelContext* ctx);
 class KernelName final {
  public:
   KernelName(std::string name, std::string overload_name)
-      : name_(std::move(name)), overload_name_(std::move(overload_name)) {
-    hash_value_ = std::hash<std::string>()(name_) ^
-                  (std::hash<std::string>()(overload_name_) << 1);
-  }
+      : name_(std::move(name)), overload_name_(std::move(overload_name)) {}
 
   KernelName(const std::string& kernel_name) {
     ParseNameAndOverloadNameFromString(kernel_name);
@@ -68,24 +66,26 @@ class KernelName final {
 
   const std::string& name() const { return name_; }
   const std::string& overload_name() const { return overload_name_; }
-  size_t hash_value() const { return hash_value_; }
 
   struct Hash {
     size_t operator()(const KernelName& kernel_name) const {
-      return kernel_name.hash_value();
+      return std::hash<std::string>()(kernel_name.name()) ^
+             (std::hash<std::string>()(kernel_name.overload_name()) << 1);
     }
   };
 
+  size_t hash_value() const { return Hash()(*this); }
+
   bool operator<(const KernelName& kernel_name) const {
-    return hash_value_ < kernel_name.hash_value();
+    return hash_value() < kernel_name.hash_value();
   }
 
   bool operator==(const KernelName& kernel_name) const {
-    return hash_value_ == kernel_name.hash_value();
+    return hash_value() == kernel_name.hash_value();
   }
 
   bool operator!=(const KernelName& kernel_name) const {
-    return hash_value_ != kernel_name.hash_value();
+    return hash_value() != kernel_name.hash_value();
   }
 
  private:
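
With the cached hash_value_ member gone, KernelName's ordering and equality now recompute the combined string hash on demand: two std::hash<std::string> calls per comparison, in exchange for members that can no longer go stale. A tiny standalone sketch of the combine step (not the pten class itself):

#include <functional>
#include <iostream>
#include <string>

// hash(name) ^ (hash(overload_name) << 1), as in KernelName::Hash above.
size_t CombineNameHash(const std::string& name, const std::string& overload) {
  return std::hash<std::string>()(name) ^
         (std::hash<std::string>()(overload) << 1);
}

int main() {
  // Different overloads of one kernel name hash to different values.
  std::cout << CombineNameHash("sign", "") << "\n";
  std::cout << CombineNameHash("sign", "host") << "\n";
}

Note that operator== still compares hash values rather than the underlying strings, so this commit changes where the hash is computed, not the (collision-tolerant) equality semantics.
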
@@ -98,57 +98,45 @@ class KernelName final {
       name_ = kernel_name.substr(0, pos);
       overload_name_ = kernel_name.substr(pos + 1, kernel_name.size());
     }
-    hash_value_ = std::hash<std::string>()(name_) ^
-                  (std::hash<std::string>()(overload_name_) << 1);
   }
 
-  // The members cannot be modified except by constructing,
-  // because the hash value need to be re calculated
-  // TODO(chenweihang): use string_view later?
+  // TODO(chenweihang): use string_view to improve performance later
   std::string name_;
   std::string overload_name_;
-  // Avoid calculating Hash value at runtime
-  size_t hash_value_;
 };
 
 class KernelKey {
  public:
   KernelKey() = default;
 
   KernelKey(Backend backend, DataLayout layout, DataType dtype)
-      : backend_(backend), layout_(layout), dtype_(dtype) {
-    // |----31-20------|---19-12---|---11-8----|---7-0---|
-    // | For extension | DataType | DataLayout | Backend |
-
-    hash_value_ = 0;
-    hash_value_ |= static_cast<uint8_t>(backend_);
-    hash_value_ |= (static_cast<uint8_t>(layout_) << kBackendBitLength);
-    hash_value_ |= (static_cast<uint16_t>(dtype_)
-                    << (kBackendBitLength + kDataTypeBitLength));
-  }
+      : backend_(backend), layout_(layout), dtype_(dtype) {}
 
   Backend backend() const { return backend_; }
   DataLayout layout() const { return layout_; }
   DataType dtype() const { return dtype_; }
 
-  uint32_t hash_value() const { return hash_value_; }
+  struct Hash {
+    // Note: Now the number of bits we need does not exceed 32 bits, so there is
+    // no need to use 64 bits. If needed in the future, it can be expanded,
+    // but now we don’t over-design.
+    uint32_t operator()(const KernelKey& key) const;
+  };
+
+  uint32_t hash_value() const { return Hash()(*this); }
 
   bool operator<(const KernelKey& key) const {
-    return hash_value_ < key.hash_value();
+    return hash_value() < key.hash_value();
   }
 
   bool operator==(const KernelKey& key) const {
-    return hash_value_ == key.hash_value();
+    return hash_value() == key.hash_value();
   }
 
   bool operator!=(const KernelKey& key) const {
-    return hash_value_ != key.hash_value();
+    return hash_value() != key.hash_value();
   }
 
-  struct Hash {
-    uint32_t operator()(const KernelKey& key) const { return key.hash_value(); }
-  };
-
  private:
   // In total should be smaller than 32.
   constexpr static int kBackendBitLength = 8;
Expand All @@ -158,12 +146,6 @@ class KernelKey {
Backend backend_{Backend::UNDEFINED};
DataLayout layout_{DataLayout::UNDEFINED};
DataType dtype_{DataType::UNDEFINED};

// Avoid calculating Hash value at runtime.
// Note: Now the number of bits we need does not exceed 32 bits, so there is
// no need to use 64 bits. If needed in the future, it can be expanded,
// but now we don’t over-design.
uint32_t hash_value_;
};

// TODO(chenweihang): how deal with vector<Param>?
@@ -282,7 +264,13 @@ class KernelFactory {
 
   KernelMap& kernels() { return kernels_; }
 
-  bool ContainsKernel(const char* name) const;
+  void InsertCompatibleOpType(const std::string& op_type) {
+    compatible_op_types_.insert(op_type);
+  }
+
+  bool HasCompatiblePtenKernel(const std::string& op_type) const {
+    return compatible_op_types_.count(op_type) > 0;
+  }
 
   const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
                                          const KernelKey& kernel_key) const;
@@ -299,6 +287,9 @@
   KernelFactory() = default;
 
   KernelMap kernels_;
+  // Used to be compatible with the original execution system and
+  // quickly confirm whether the new kernel can be called
+  std::unordered_set<std::string> compatible_op_types_;
 };
 
 /** operator << overload **/
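
The nested Hash functors exist so that KernelName and KernelKey can serve as keys of the factory's unordered maps, where the hasher is a template parameter. A minimal sketch of that pattern with a toy key type (not pten's):

#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct ToyKernelName {
  std::string name;
  std::string overload_name;

  // unordered_map also needs equality; std::equal_to uses this operator.
  bool operator==(const ToyKernelName& other) const {
    return name == other.name && overload_name == other.overload_name;
  }

  struct Hash {
    size_t operator()(const ToyKernelName& k) const {
      return std::hash<std::string>()(k.name) ^
             (std::hash<std::string>()(k.overload_name) << 1);
    }
  };
};

int main() {
  // The nested functor is passed where std::hash would otherwise be used.
  std::unordered_map<ToyKernelName, int, ToyKernelName::Hash> table;
  table[{"sign", ""}] = 1;
  std::cout << table.at({"sign", ""}) << "\n";  // prints 1
}
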
1 change: 1 addition & 0 deletions paddle/pten/core/kernel_registry.h
@@ -149,6 +149,7 @@ struct KernelRegistrar {
     args_parse_fn(kernel_key, kernel.mutable_args_def());
     args_def_fn(&kernel);
 
+    KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
     KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
   }
 };
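
Because KernelRegistrar objects are instantiated at file scope by registration macros, the InsertCompatibleOpType call above runs during static initialization, so compatible_op_types_ is populated before any operator dispatch. A sketch of that flow with a toy macro (the real pten registration macros are not shown in this diff):

#include <iostream>
#include <string>
#include <unordered_set>

class KernelFactory {
 public:
  static KernelFactory& Instance() {
    static KernelFactory factory;
    return factory;
  }
  void InsertCompatibleOpType(const std::string& op_type) {
    compatible_op_types_.insert(op_type);
  }
  bool HasCompatiblePtenKernel(const std::string& op_type) const {
    return compatible_op_types_.count(op_type) > 0;
  }

 private:
  std::unordered_set<std::string> compatible_op_types_;
};

struct KernelRegistrar {
  explicit KernelRegistrar(const char* kernel_name) {
    KernelFactory::Instance().InsertCompatibleOpType(kernel_name);
  }
};

// A file-scope registrar object runs its constructor before main().
#define TOY_REGISTER_KERNEL(name) \
  static KernelRegistrar g_registrar_##name(#name)

TOY_REGISTER_KERNEL(sign);

int main() {
  std::cout << KernelFactory::Instance().HasCompatiblePtenKernel("sign")
            << "\n";  // prints 1
}
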
