Skip to content

Commit

Permalink
some util template function for instruction attribute (PaddlePaddle#37)
Browse files Browse the repository at this point in the history
* some util template function for instruction attribute

* add unit test

* move explicit template specialization to cc file and merge PopulateAttrValueProtoD0,PopulateAttrValueProtoD1

* add note::F16 with platform::float16
  • Loading branch information
CtfGo authored Aug 27, 2021
1 parent edb4fe8 commit 00c4edb
Show file tree
Hide file tree
Showing 8 changed files with 595 additions and 0 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/compiler/piano/note/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@ target_compile_options(note_proto PUBLIC "-Wno-extra")

cc_library(note_ir SRCS instruction.cc function.cc module.cc DEPS note_opcode note_proto piano_data_description)
cc_test(note_ir_test SRCS note_ir_test.cc DEPS note_ir)

cc_library(note_template_util SRCS element_type_util.cc populate_attribute_value.cc DEPS note_proto)
cc_test(note_element_type_util_test SRCS element_type_util_test.cc DEPS note_template_util)
cc_test(note_populate_attribute_value_test SRCS populate_attribute_value_test.cc DEPS note_template_util)
34 changes: 34 additions & 0 deletions paddle/fluid/compiler/piano/note/attribute_key_defs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

namespace paddle {
namespace piano {
namespace note {

// In this file, we define the attribute key name for
// specific note instructions. Format of name as bellow
// `k(p0).(p1)`
// -`p0`: the instruction name
// -`p1` the specific attribute name

// literal value of Constant instruction
constexpr char kConstantValue[] = "kConstant.Value";
// `dimensions_alignment` of Broadcast instruction
constexpr char kBroadcastAlignment[] = "kBroadcast.Alignment";

} // namespace note
} // namespace piano
} // namespace paddle
83 changes: 83 additions & 0 deletions paddle/fluid/compiler/piano/note/element_type_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/note/element_type_util.h"

namespace paddle {
namespace piano {
namespace note {

template <>
ElementTypeProto NativeToElementTypeProto<bool>() {
return B1;
}

template <>
ElementTypeProto NativeToElementTypeProto<int8_t>() {
return S8;
}

template <>
ElementTypeProto NativeToElementTypeProto<int16_t>() {
return S16;
}

template <>
ElementTypeProto NativeToElementTypeProto<int32_t>() {
return S32;
}

template <>
ElementTypeProto NativeToElementTypeProto<int64_t>() {
return S64;
}

template <>
ElementTypeProto NativeToElementTypeProto<uint8_t>() {
return U8;
}

template <>
ElementTypeProto NativeToElementTypeProto<uint16_t>() {
return U16;
}

template <>
ElementTypeProto NativeToElementTypeProto<uint32_t>() {
return U32;
}

template <>
ElementTypeProto NativeToElementTypeProto<uint64_t>() {
return U64;
}

template <>
ElementTypeProto NativeToElementTypeProto<platform::float16>() {
return F16;
}

template <>
ElementTypeProto NativeToElementTypeProto<float>() {
return F32;
}

template <>
ElementTypeProto NativeToElementTypeProto<double>() {
return F64;
}

} // namespace note
} // namespace piano
} // namespace paddle
119 changes: 119 additions & 0 deletions paddle/fluid/compiler/piano/note/element_type_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace piano {
namespace note {

// Map the given template parameter data type (eg, float)
// to the corresponding element proto type (eg, F32).
template <typename NativeT>
ElementTypeProto NativeToElementTypeProto() {
static_assert(!std::is_same<NativeT, NativeT>::value,
"Cannot map this native type to a proto type.");
return note::INVALID_ELEMENT_TYPE;
}

// Declarations of specializations for each native type
// which correspond to a ElementTypeProto.
template <>
ElementTypeProto NativeToElementTypeProto<bool>();
template <>
ElementTypeProto NativeToElementTypeProto<int8_t>();
template <>
ElementTypeProto NativeToElementTypeProto<int16_t>();
template <>
ElementTypeProto NativeToElementTypeProto<int32_t>();
template <>
ElementTypeProto NativeToElementTypeProto<int64_t>();
template <>
ElementTypeProto NativeToElementTypeProto<uint8_t>();
template <>
ElementTypeProto NativeToElementTypeProto<uint16_t>();
template <>
ElementTypeProto NativeToElementTypeProto<uint32_t>();
template <>
ElementTypeProto NativeToElementTypeProto<uint64_t>();
template <>
ElementTypeProto NativeToElementTypeProto<platform::float16>();
template <>
ElementTypeProto NativeToElementTypeProto<float>();
template <>
ElementTypeProto NativeToElementTypeProto<double>();

// Map the given element proto type (eg, F32)
// to the corresponding native data type (eg, float).
template <ElementTypeProto>
struct ElementTypeProtoToNativeT;

// Declarations of specializations for each ElementTypeProto
// which correspond to a native type
template <>
struct ElementTypeProtoToNativeT<B1> {
using type = bool;
};
template <>
struct ElementTypeProtoToNativeT<S8> {
using type = int8_t;
};
template <>
struct ElementTypeProtoToNativeT<S16> {
using type = int16_t;
};
template <>
struct ElementTypeProtoToNativeT<S32> {
using type = int32_t;
};
template <>
struct ElementTypeProtoToNativeT<S64> {
using type = int64_t;
};
template <>
struct ElementTypeProtoToNativeT<U8> {
using type = uint8_t;
};
template <>
struct ElementTypeProtoToNativeT<U16> {
using type = uint16_t;
};
template <>
struct ElementTypeProtoToNativeT<U32> {
using type = uint32_t;
};
template <>
struct ElementTypeProtoToNativeT<U64> {
using type = uint64_t;
};
template <>
struct ElementTypeProtoToNativeT<F16> {
using type = platform::float16;
};
template <>
struct ElementTypeProtoToNativeT<F32> {
using type = float;
};
template <>
struct ElementTypeProtoToNativeT<F64> {
using type = double;
};

} // namespace note
} // namespace piano
} // namespace paddle
67 changes: 67 additions & 0 deletions paddle/fluid/compiler/piano/note/element_type_util_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/compiler/piano/note/element_type_util.h"
#include <type_traits>
#include "gtest/gtest.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"

namespace paddle {
namespace piano {
namespace note {

TEST(NativeToElementTypeProtoTest, Basic) {
ASSERT_EQ(B1, NativeToElementTypeProto<bool>());
ASSERT_EQ(S8, NativeToElementTypeProto<int8_t>());
ASSERT_EQ(S16, NativeToElementTypeProto<int16_t>());
ASSERT_EQ(S32, NativeToElementTypeProto<int32_t>());
ASSERT_EQ(S64, NativeToElementTypeProto<int64_t>());
ASSERT_EQ(U8, NativeToElementTypeProto<uint8_t>());
ASSERT_EQ(U16, NativeToElementTypeProto<uint16_t>());
ASSERT_EQ(U32, NativeToElementTypeProto<uint32_t>());
ASSERT_EQ(U64, NativeToElementTypeProto<uint64_t>());
ASSERT_EQ(F16, NativeToElementTypeProto<platform::float16>());
ASSERT_EQ(F32, NativeToElementTypeProto<float>());
ASSERT_EQ(F64, NativeToElementTypeProto<double>());
}

TEST(ElementTypeProtoToNativeTTest, Basic) {
ASSERT_TRUE((std::is_same<bool, ElementTypeProtoToNativeT<B1>::type>::value));
ASSERT_TRUE(
(std::is_same<int8_t, ElementTypeProtoToNativeT<S8>::type>::value));
ASSERT_TRUE(
(std::is_same<int16_t, ElementTypeProtoToNativeT<S16>::type>::value));
ASSERT_TRUE(
(std::is_same<int32_t, ElementTypeProtoToNativeT<S32>::type>::value));
ASSERT_TRUE(
(std::is_same<int64_t, ElementTypeProtoToNativeT<S64>::type>::value));
ASSERT_TRUE(
(std::is_same<uint8_t, ElementTypeProtoToNativeT<U8>::type>::value));
ASSERT_TRUE(
(std::is_same<uint16_t, ElementTypeProtoToNativeT<U16>::type>::value));
ASSERT_TRUE(
(std::is_same<uint32_t, ElementTypeProtoToNativeT<U32>::type>::value));
ASSERT_TRUE(
(std::is_same<uint64_t, ElementTypeProtoToNativeT<U64>::type>::value));
ASSERT_TRUE((std::is_same<platform::float16,
ElementTypeProtoToNativeT<F16>::type>::value));
ASSERT_TRUE(
(std::is_same<float, ElementTypeProtoToNativeT<F32>::type>::value));
ASSERT_TRUE(
(std::is_same<double, ElementTypeProtoToNativeT<F64>::type>::value));
}

} // namespace note
} // namespace piano
} // namespace paddle
Loading

0 comments on commit 00c4edb

Please sign in to comment.