diff --git a/oneflow/api/python/functional/common.cpp b/oneflow/api/python/functional/common.cpp index 1a5b39088f9..f525483789e 100644 --- a/oneflow/api/python/functional/common.cpp +++ b/oneflow/api/python/functional/common.cpp @@ -19,6 +19,7 @@ limitations under the License. #include #include "oneflow/api/python/functional/indexing.h" +#include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" @@ -195,18 +196,20 @@ Maybe>> PyUnpackSbpParallelSequence(PyObject* ob // Tensor index bool PyTensorIndexCheck(PyObject* obj) { - return PySlice_Check(obj) || PyLong_Check(obj) || obj == Py_Ellipsis || obj == Py_None - || PyTensorCheck(obj) || PySequence_Check(obj) || PyUnicode_Check(obj); + return PySlice_Check(obj) || PyLong_Check(obj) || numpy::PyArrayCheckLongScalar(obj) + || obj == Py_Ellipsis || obj == Py_None || PyTensorCheck(obj) || PySequence_Check(obj) + || PyUnicode_Check(obj); } Maybe PyUnpackTensorIndex(PyObject* obj) { auto tensor_index = std::make_shared(); // Obvious single-entry cases. - if (PySlice_Check(obj) // NOLINT - || PyLong_Check(obj) // NOLINT - || obj == Py_Ellipsis // NOLINT - || obj == Py_None // NOLINT - || PyTensorCheck(obj) // NOLINT - || !PySequence_Check(obj) // NOLINT + if (PySlice_Check(obj) // NOLINT + || PyLong_Check(obj) // NOLINT + || numpy::PyArrayCheckLongScalar(obj) // NOLINT + || obj == Py_Ellipsis // NOLINT + || obj == Py_None // NOLINT + || PyTensorCheck(obj) // NOLINT + || !PySequence_Check(obj) // NOLINT || PyUnicode_Check(obj)) { tensor_index->emplace_back(*JUST(detail::UnpackIndexItem(obj))); return tensor_index; diff --git a/oneflow/api/python/functional/indexing.cpp b/oneflow/api/python/functional/indexing.cpp index 467f7a526e7..d92cb23f86d 100644 --- a/oneflow/api/python/functional/indexing.cpp +++ b/oneflow/api/python/functional/indexing.cpp @@ -65,6 +65,8 @@ Maybe InferScalarType(PyObject* object) { return DataType::kInt64; } else if (PyArray_Check(object)) { return numpy::GetOFDataTypeFromNpArray(reinterpret_cast(object)); + } else if (PyArray_CheckScalar(object)) { + return numpy::NumpyTypeToOFDataType(PyArray_DescrFromScalar(object)->type_num); } else if (PySequence_Check(object)) { int64_t length = PySequence_Length(object); CHECK_GT_OR_RETURN(length, 0) << "Index should not be empty."; @@ -86,13 +88,20 @@ Maybe InferScalarType(PyObject* object) { Maybe ParseScalar(PyObject* object, char* data, const DataType& dtype) { if (dtype == DataType::kInt64) { - CHECK_OR_RETURN(PyLong_Check(object)) << "Expected a long value."; + CHECK_OR_RETURN(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object)) + << "Expected a long value."; *(reinterpret_cast(data)) = PyLong_AsLongLong(object); return Maybe::Ok(); + } else if (dtype == DataType::kInt32) { + CHECK_OR_RETURN(PyLong_Check(object) || numpy::PyArrayCheckLongScalar(object)) + << "Expected a long value."; + *(reinterpret_cast(data)) = PyLong_AsLongLong(object); + return Maybe::Ok(); } else if (dtype == DataType::kUInt8 || dtype == DataType::kBool) { - CHECK_OR_RETURN(PyBool_Check(object) || PyLong_Check(object)) + CHECK_OR_RETURN(PyBool_Check(object) || PyLong_Check(object) + || numpy::PyArrayCheckLongScalar(object)) << "Expected a boolean or long value."; - if (PyBool_Check(object)) { + if (PyBool_Check(object) || numpy::PyArrayCheckBoolScalar(object)) { *(reinterpret_cast(data)) = (object == Py_True); } else { int64_t value = PyLong_AsLongLong(object); @@ -193,6 +202,8 @@ Maybe UnpackIndexItem(PyObject* object) { return std::make_shared(start, end, step); } else if (PyLong_Check(object) && object != Py_False && object != Py_True) { return std::make_shared(static_cast(PyLong_AsLongLong(object))); + } else if (numpy::PyArrayCheckLongScalar(object)) { + return std::make_shared(static_cast(PyLong_AsLongLong(object))); } else if (object == Py_False || object == Py_True) { return std::make_shared(object == Py_True); } else if (object == Py_None) { diff --git a/oneflow/api/python/functional/python_arg.cpp b/oneflow/api/python/functional/python_arg.cpp index dadeab426ef..ac48f1513d4 100644 --- a/oneflow/api/python/functional/python_arg.cpp +++ b/oneflow/api/python/functional/python_arg.cpp @@ -17,6 +17,7 @@ limitations under the License. #include "oneflow/api/python/functional/python_arg.h" #include "oneflow/api/python/functional/common.h" #include "oneflow/api/python/functional/indexing.h" +#include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/device.h" @@ -198,21 +199,25 @@ Maybe PythonArg::TypeCheck(ValueType type) const { case kUINT32: case kINT64: case kUINT64: - case kBOOL: return PyLong_Check(object_); + case kBOOL: return PyLong_Check(object_) || numpy::PyArrayCheckLongScalar(object_); case kINT32_LIST: case kUINT32_LIST: case kINT64_LIST: case kUINT64_LIST: case kBOOL_LIST: return PyLongSequenceCheck(object_) || (size_ > 0 && PyLong_Check(object_)); case kFLOAT: - case kDOUBLE: return PyFloat_Check(object_) || PyLong_Check(object_); + case kDOUBLE: + return PyFloat_Check(object_) || PyLong_Check(object_) + || numpy::PyArrayCheckFloatScalar(object_) || numpy::PyArrayCheckLongScalar(object_); case kFLOAT_LIST: case kDOUBLE_LIST: return PyFloatSquenceCheck(object_) || (size_ > 0 && (PyFloat_Check(object_) || PyLong_Check(object_))); case kSTRING: return PyStringCheck(object_); case kSTRING_LIST: return PyStringSequenceCheck(object_); - case kSCALAR: return PyScalarCheck(object_); + case kSCALAR: + return PyScalarCheck(object_) || numpy::PyArrayCheckLongScalar(object_) + || numpy::PyArrayCheckFloatScalar(object_); case kTENSOR: case kTENSOR_REF: return PyTensorCheck(object_); case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_); diff --git a/oneflow/api/python/utils/tensor_utils.h b/oneflow/api/python/utils/tensor_utils.h index af41cedcb3f..d2c9411cfef 100644 --- a/oneflow/api/python/utils/tensor_utils.h +++ b/oneflow/api/python/utils/tensor_utils.h @@ -31,7 +31,6 @@ limitations under the License. #include "oneflow/core/register/ofblob.h" #include "oneflow/core/common/blocking_then_busy.h" #include "oneflow/core/vm/virtual_machine.h" -#include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/foreign_lock_helper.h" namespace py = pybind11; diff --git a/oneflow/extension/python/numpy.cpp b/oneflow/extension/python/numpy.cpp index 99a313cad32..2efee8e3fb7 100644 --- a/oneflow/extension/python/numpy.cpp +++ b/oneflow/extension/python/numpy.cpp @@ -90,6 +90,18 @@ std::vector OFStrideToNumpyStride(const StrideVector& fixed_vec, const D return result; } +bool PyArrayCheckLongScalar(PyObject* obj) { + return PyArray_CheckScalar(obj) && PyDataType_ISINTEGER(PyArray_DescrFromScalar(obj)); +} + +bool PyArrayCheckFloatScalar(PyObject* obj) { + return PyArray_CheckScalar(obj) && PyDataType_ISFLOAT(PyArray_DescrFromScalar(obj)); +} + +bool PyArrayCheckBoolScalar(PyObject* obj) { + return PyArray_CheckScalar(obj) && PyDataType_ISBOOL(PyArray_DescrFromScalar(obj)); +} + // Executing any numpy c api before _import_array() results in segfault // NOTE: this InitNumpyCAPI() works because of `PY_ARRAY_UNIQUE_SYMBOL` // defined in numpy_internal.h diff --git a/oneflow/extension/python/numpy_internal.h b/oneflow/extension/python/numpy_internal.h index 7a0375c9875..22e4a6694ce 100644 --- a/oneflow/extension/python/numpy_internal.h +++ b/oneflow/extension/python/numpy_internal.h @@ -62,6 +62,12 @@ std::vector OFShapeToNumpyShape(const DimVector& fixed_vec); std::vector OFStrideToNumpyStride(const StrideVector& fixed_vec, const DataType data_type); +bool PyArrayCheckLongScalar(PyObject* obj); + +bool PyArrayCheckFloatScalar(PyObject* obj); + +bool PyArrayCheckBoolScalar(PyObject* obj); + Maybe InitNumpyCAPI(); } // namespace numpy diff --git a/python/oneflow/test/tensor/test_tensor_indexing.py b/python/oneflow/test/tensor/test_tensor_indexing.py index fdaa3e799c2..cb238798e13 100644 --- a/python/oneflow/test/tensor/test_tensor_indexing.py +++ b/python/oneflow/test/tensor/test_tensor_indexing.py @@ -23,6 +23,67 @@ import oneflow.unittest +def _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar): + x = flow.Tensor(numpy_x) + + # basic_slice + test_case.assertTrue(np.allclose(numpy_x[np_scalar(1)], x[np_scalar(1)].numpy())) + test_case.assertTrue(np.allclose(numpy_x[np_scalar(-2)], x[np_scalar(-2)].numpy())) + test_case.assertTrue( + np.allclose( + numpy_x[np_scalar(0), np_scalar(1)], x[np_scalar(0), np_scalar(1)].numpy() + ) + ) + test_case.assertTrue( + np.allclose( + numpy_x[(np_scalar(0), np_scalar(1))], + x[(np_scalar(0), np_scalar(1))].numpy(), + ) + ) + test_case.assertTrue( + np.allclose( + numpy_x[((np_scalar(0), np_scalar(1)))], + x[((np_scalar(0), np_scalar(1)))].numpy(), + ) + ) + + +def _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar): + x = flow.Tensor(numpy_x) + + # advance indexing + test_case.assertTrue( + np.allclose( + numpy_x[[np_scalar(0), np_scalar(1)]], + x[[np_scalar(0), np_scalar(1)]].numpy(), + ) + ) + test_case.assertTrue( + np.allclose( + numpy_x[[np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)]], + x[[np_scalar(0), np_scalar(1)], [np_scalar(1), np_scalar(0)]].numpy(), + ) + ) + test_case.assertTrue( + np.allclose( + numpy_x[ + [ + [np_scalar(0), np_scalar(1)], + [np_scalar(0), np_scalar(1)], + [np_scalar(1), np_scalar(0)], + ] + ], + x[ + [ + [np_scalar(0), np_scalar(1)], + [np_scalar(0), np_scalar(1)], + [np_scalar(1), np_scalar(0)], + ] + ].numpy(), + ) + ) + + def _test_basic_slice(test_case, numpy_x): x = flow.tensor(numpy_x) @@ -278,6 +339,28 @@ def test_combining_indexing(test_case): numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) _test_combining_indexing(test_case, numpy_x) + def test_numpy_scalar_indexing(test_case): + for np_scalar in [np.int8, np.int16, np.int32, np.int64]: + numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) + _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar) + + numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) + _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar) + + numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) + _test_numpy_scalar_indexing(test_case, numpy_x, np_scalar) + + # TODO: add np.int16 when advance indexing supports np.int16 mapping + for np_scalar in [np.int32, np.int64]: + numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) + _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar) + + numpy_x = np.arange(0, 360, 1).reshape([3, 4, 5, 6]).astype(np.float32) + _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar) + + numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) + _test_numpy_scalar_advance_indexing(test_case, numpy_x, np_scalar) + def test_mask_getitem(test_case): numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) _test_mask_getitem(test_case, numpy_x)