Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Hackathon 3rd No.2 ] add paddle.iinfo #45321

Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
204fd69
not finished yet
OccupyMars2025 Aug 2, 2022
4008fc1
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task2-add-iinfo
OccupyMars2025 Aug 4, 2022
645247e
Update pybind.cc
OccupyMars2025 Aug 22, 2022
8da6c55
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task2-add-iinfo
OccupyMars2025 Aug 22, 2022
17e54e0
delete unrelated files
OccupyMars2025 Aug 22, 2022
89ad9f8
Update pybind.cc
OccupyMars2025 Aug 22, 2022
3e4c1fc
import iinfo
OccupyMars2025 Aug 22, 2022
dbea60b
Create test_iinfo.py
OccupyMars2025 Aug 22, 2022
802ad53
Update test_iinfo.py
OccupyMars2025 Aug 22, 2022
2d2ffd0
Update test_iinfo.py
OccupyMars2025 Aug 22, 2022
1cd38bd
more properties
OccupyMars2025 Aug 22, 2022
5589aa8
add finfo
OccupyMars2025 Aug 22, 2022
c2ba90f
Update test_iinfo_and_finfo.py
OccupyMars2025 Aug 22, 2022
f67e3d7
Update test_iinfo_and_finfo.py
OccupyMars2025 Aug 22, 2022
e44f818
Update pybind.cc
OccupyMars2025 Aug 22, 2022
c5fa801
Update pybind.cc
OccupyMars2025 Aug 22, 2022
74ab8d5
resolution not equal
OccupyMars2025 Aug 22, 2022
9b318ab
Update test_iinfo_and_finfo.py
OccupyMars2025 Aug 22, 2022
94ef052
type check
OccupyMars2025 Aug 22, 2022
169add9
Update test_iinfo_and_finfo.py
OccupyMars2025 Aug 22, 2022
a3bb736
check dtype property
OccupyMars2025 Aug 22, 2022
75484ec
Update pybind.cc
OccupyMars2025 Aug 22, 2022
4a724e9
add only iinfo
OccupyMars2025 Aug 27, 2022
145ee68
Update test_iinfo_and_finfo.py
OccupyMars2025 Aug 27, 2022
01e7c7f
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task1-add-finf…
OccupyMars2025 Aug 27, 2022
b78dc08
Update test_iinfo_and_finfo.py
OccupyMars2025 Aug 28, 2022
f61de70
add python wrapper and __repr__
OccupyMars2025 Aug 28, 2022
5c6f96f
modify __repr__ method
OccupyMars2025 Aug 28, 2022
ae1d259
add docstring
OccupyMars2025 Aug 28, 2022
eb54abe
add iinfo to __all__ list
OccupyMars2025 Aug 29, 2022
0d352f5
Update dtype.py
OccupyMars2025 Aug 30, 2022
dd92f76
test pre-commit
OccupyMars2025 Aug 30, 2022
afe4e96
Merge branch 'PaddlePaddle:develop' into hackathon-3rd-task1-add-finf…
OccupyMars2025 Sep 4, 2022
cf221c9
Update core.py
OccupyMars2025 Sep 4, 2022
05dd1a7
Update dtype.py
OccupyMars2025 Sep 6, 2022
a013bdf
Update dtype.py
OccupyMars2025 Sep 7, 2022
bbdb3b5
Update dtype.py
OccupyMars2025 Sep 7, 2022
32ea2fa
Update dtype.py
OccupyMars2025 Sep 7, 2022
623179f
for punctuation mark
Ligoml Sep 7, 2022
9113d68
Update grad_node_info.h
OccupyMars2025 Sep 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
Expand Down Expand Up @@ -346,6 +347,52 @@ bool IsCompiledWithDIST() {
#endif
}

struct iinfo {
int64_t min, max;
int bits;
std::string dtype;

explicit iinfo(const framework::proto::VarType::Type &type) {
switch (type) {
case framework::proto::VarType::INT16:
min = std::numeric_limits<int16_t>::min();
max = std::numeric_limits<int16_t>::max();
bits = 16;
dtype = "int16";
break;
case framework::proto::VarType::INT32:
min = std::numeric_limits<int32_t>::min();
max = std::numeric_limits<int32_t>::max();
bits = 32;
dtype = "int32";
break;
case framework::proto::VarType::INT64:
min = std::numeric_limits<int64_t>::min();
max = std::numeric_limits<int64_t>::max();
bits = 64;
dtype = "int64";
break;
case framework::proto::VarType::INT8:
min = std::numeric_limits<int8_t>::min();
max = std::numeric_limits<int8_t>::max();
bits = 8;
dtype = "int8";
break;
case framework::proto::VarType::UINT8:
min = std::numeric_limits<uint8_t>::min();
max = std::numeric_limits<uint8_t>::max();
bits = 8;
dtype = "uint8";
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"the argument of paddle.iinfo can only be paddle.int8, "
"paddle.int16, paddle.int32, paddle.int64, or paddle.uint8"));
break;
}
}
};

static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
Expand Down Expand Up @@ -555,6 +602,21 @@ PYBIND11_MODULE(core_noavx, m) {

BindException(&m);

py::class_<iinfo>(m, "iinfo")
.def(py::init<const framework::proto::VarType::Type &>())
.def_readonly("min", &iinfo::min)
.def_readonly("max", &iinfo::max)
.def_readonly("bits", &iinfo::bits)
.def_readonly("dtype", &iinfo::dtype)
.def("__repr__", [](const iinfo &a) {
std::ostringstream oss;
oss << "paddle.iinfo(min=" << a.min;
oss << ", max=" << a.max;
oss << ", bits=" << a.bits;
oss << ", dtype=" << a.dtype << ")";
return oss.str();
});

m.def("set_num_threads", &platform::SetNumThreads);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我自己测试发现用 "oss <<"的方法才能正确显示浮点数,而 std::to_string() 会将很小的正浮点数显示为0


m.def("disable_signal_handler", &DisableSignalHandler);
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from .fluid.dataset import * # noqa: F401
from .fluid.lazy_init import LazyGuard # noqa: F401

from .framework.dtype import iinfo # noqa: F401
from .framework.dtype import dtype as dtype # noqa: F401
from .framework.dtype import uint8 # noqa: F401
from .framework.dtype import int8 # noqa: F401
Expand Down Expand Up @@ -386,6 +387,7 @@
disable_static()

__all__ = [ # noqa
'iinfo',
'dtype',
'uint8',
'int8',
Expand Down
45 changes: 45 additions & 0 deletions python/paddle/fluid/tests/unittests/test_iinfo_and_finfo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

OccupyMars2025 marked this conversation as resolved.
Show resolved Hide resolved
import paddle
import unittest
import numpy as np


class TestIInfoAndFInfoAPI(unittest.TestCase):

def test_invalid_input(self):
for dtype in [
paddle.float16, paddle.float32, paddle.float64, paddle.bfloat16,
paddle.complex64, paddle.complex128, paddle.bool
]:
with self.assertRaises(ValueError):
_ = paddle.iinfo(dtype)

def test_iinfo(self):
for paddle_dtype, np_dtype in [(paddle.int64, np.int64),
(paddle.int32, np.int32),
(paddle.int16, np.int16),
(paddle.int8, np.int8),
(paddle.uint8, np.uint8)]:
xinfo = paddle.iinfo(paddle_dtype)
xninfo = np.iinfo(np_dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertEqual(xinfo.max, xninfo.max)
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.dtype, xninfo.dtype)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是进行整数的比较,用 "=="运算符就可以比较,感觉在这里还是可以用self.assertEqual


if __name__ == '__main__':
unittest.main()
36 changes: 35 additions & 1 deletion python/paddle/framework/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from ..fluid.core import VarDesc
from ..fluid.core import iinfo as core_iinfo

dtype = VarDesc.VarType
dtype.__qualname__ = "dtype"
Expand All @@ -34,4 +35,37 @@

bool = VarDesc.VarType.BOOL

__all__ = []

def iinfo(dtype):
"""

paddle.iinfo is a function that returns an object that represents the numerical properties of
an integer paddle.dtype.
This is similar to `numpy.iinfo <https://numpy.org/doc/stable/reference/generated/numpy.iinfo.html#numpy-iinfo>`_.

Args:
dtype(paddle.dtype): One of paddle.uint8, paddle.int8, paddle.int16, paddle.int32, and paddle.int64.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明确说明支持哪些类型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改了

Returns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

这里如果想让四个属性分行展示,需要修改一下格式,参考:
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

An iinfo object, which has the following 4 attributes:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

明确说明返回的对象包含哪些属性,针对每个属性,解释清楚其含义

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

- min: int, The smallest representable integer number.
- max: int, The largest representable integer number.
- bits: int, The number of bits occupied by the type.
- dtype: str, The string name of the argument dtype.

Examples:
.. code-block:: python

import paddle

iinfo_uint8 = paddle.iinfo(paddle.uint8)
print(iinfo_uint8)
# paddle.iinfo(min=0, max=255, bits=8, dtype=uint8)
print(iinfo_uint8.min) # 0
print(iinfo_uint8.max) # 255
print(iinfo_uint8.bits) # 8
print(iinfo_uint8.dtype) # uint8

"""
return core_iinfo(dtype)