Skip to content

Commit

Permalink
Merge branch 'master' into add_check_of_placement_constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
oneflow-ci-bot authored Dec 14, 2021
2 parents 5b0914f + 95355aa commit c432c52
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 21 deletions.
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ oneflow
set_printoptions,
decode_onerec,
read_onerec,
from_numpy,

.. autofunction:: oneflow.relu
.. autofunction:: oneflow.env.get_rank
57 changes: 55 additions & 2 deletions oneflow/api/python/functional/tensor_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include <Python.h>
#include <memory>

#include "oneflow/api/python/utils/tensor_utils.h"
#include "oneflow/api/python/functional/common.h"
Expand Down Expand Up @@ -59,9 +60,10 @@ class TensorWithDataFunctor {

const auto& other = JUST(PyUnpackTensor(data));
return MakeTensorFromOtherTensor(other, dtype, device, requires_grad);
} else {
// Make tensor from python sequence or numpy array.
return MakeLocalTensorFromData(data, dtype, device, requires_grad);
}
// Make tensor from python sequence or numpy array.
return MakeLocalTensorFromData(data, dtype, device, requires_grad);
}
};

Expand Down Expand Up @@ -213,6 +215,56 @@ class AssignLocalTensorFunctor {
std::shared_ptr<OpExpr> op_;
};

class LocalTensorSharedNumpyDataFunctor {
public:
LocalTensorSharedNumpyDataFunctor() {}
Maybe<Tensor> operator()(PyObject* obj) const {
if (!PyArray_Check(obj)) {
return Error::TypeError() << "expected np.ndarray, but got " << Py_TYPE(obj)->tp_name;
}
auto* array = reinterpret_cast<PyArrayObject*>(obj);

// Build TensorMeta
int32_t dim = PyArray_NDIM(array);
const npy_intp* dims_ptr = PyArray_SHAPE(array);
const auto shape = std::make_shared<Shape>(DimVector(dims_ptr, dims_ptr + dim));
DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(array));
Symbol<Device> device = JUST(Device::New("cpu"));
const npy_intp* stride_ptr = PyArray_STRIDES(array);
const auto stride = std::make_shared<Stride>(DimVector(stride_ptr, stride_ptr + dim));
auto tensor_meta = std::make_shared<MirroredTensorMeta>(shape, data_type, device, stride, 0);

// Build TensorBuffer
const auto& Free = [obj](char* dptr) {
py::gil_scoped_acquire gil;
Py_DECREF(obj);
};
Py_INCREF(obj); // make TensorBuffer hold ndarray
void* data_ptr = PyArray_DATA(array);
auto array_size_in_bytes = PyArray_NBYTES(array);
auto tensor_data = std::make_shared<vm::TensorStorage>();
tensor_data->set_blob_dptr(
std::unique_ptr<char, std::function<void(char*)>>(static_cast<char*>(data_ptr), Free),
array_size_in_bytes);

// Build TensorStorage: decrease ndarray reference count before releasing
auto tensor_storage = std::make_shared<TensorStorage>(tensor_data);

// Build Tensor
auto tensor_impl = std::make_shared<EagerMirroredTensorImpl>(tensor_meta, tensor_storage,
/*requires_grad=*/false,
/*ls_leaf=*/true);

// Init blob
JUST(tensor_impl->InitEagerBlobObject(JUST(GetLocalDepObject4Device(*device))));
JUST(tensor_impl->eager_blob_object())->set_last_used_device(device);
JUST(JUST(tensor_impl->eager_blob_object())->TryInitBlob());
JUST(tensor_impl->eager_blob_object())->mut_blob()->reset_dptr(static_cast<char*>(data_ptr));
std::shared_ptr<Tensor> out(new MirroredTensor(tensor_impl));
return out;
}
};

} // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
Expand All @@ -226,6 +278,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::TensorWithShapeCtorFunctor>("TensorWithShapeCtor");
m.add_functor<impl::ConsistentTensorWithShapeCtorFunctor>("ConsistentTensorWithShapeCtor");
m.add_functor<impl::AssignLocalTensorFunctor>("AssignLocalTensorFunctor");
m.add_functor<impl::LocalTensorSharedNumpyDataFunctor>("LocalTensorSharedNumpyData");
}

} // namespace functional
Expand Down
35 changes: 20 additions & 15 deletions oneflow/api/python/functional/tensor_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,30 @@

- name: "tensor"
signature: [
"Tensor (PyObject* data, *, DataType dtype=None, Device device=None,
Bool requires_grad=False) => TensorWithData",
"Tensor (PyObject* data, *, DataType dtype=None, Placement placement,
SbpList sbp, Bool requires_grad=False) => ConsistentTensorWithData",
]
"Tensor (PyObject* data, *, DataType dtype=None, Device device=None,
Bool requires_grad=False) => TensorWithData",
"Tensor (PyObject* data, *, DataType dtype=None, Placement placement,
SbpList sbp, Bool requires_grad=False) => ConsistentTensorWithData",
]
bind_python: True

- name: "_legacy_tensor_ctor"
signature: [
"Tensor (*, Device device=None) => TensorEmptyCtor",
"Tensor (*, Placement placement, SbpList sbp) => ConsistentTensorEmptyCtor",
"Tensor (Tensor other) => TensorWithOtherCtor",
"Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor",
"Tensor (PyObject* data, *, Placement placement, SbpList sbp) => ConsistentTensorWithDataCtor",
"Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor",
"Tensor (Shape size, *, Placement placement, SbpList sbp) => ConsistentTensorWithShapeCtor",
]
signature:
[
"Tensor (*, Device device=None) => TensorEmptyCtor",
"Tensor (*, Placement placement, SbpList sbp) => ConsistentTensorEmptyCtor",
"Tensor (Tensor other) => TensorWithOtherCtor",
"Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor",
"Tensor (PyObject* data, *, Placement placement, SbpList sbp) => ConsistentTensorWithDataCtor",
"Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor",
"Tensor (Shape size, *, Placement placement, SbpList sbp) => ConsistentTensorWithShapeCtor",
]
bind_python: True

- name: "assign_local_tensor"
signature: "Void (Tensor ref, Tensor value)=> AssignLocalTensorFunctor"
signature: "Void (Tensor ref, Tensor value) => AssignLocalTensorFunctor"
bind_python: True

- name: "from_numpy"
signature: "Tensor (PyObject* obj) => LocalTensorSharedNumpyData"
bind_python: True
2 changes: 1 addition & 1 deletion oneflow/core/framework/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ Maybe<Tensor> BasicView(const std::shared_ptr<Tensor>& input, const Shape& targe
std::make_shared<Shape>(target_shape), input->dtype()->data_type(), device,
std::make_shared<Stride>(target_strides), storage_offset);

JUST(input->has_eager_blob_object());
CHECK_OR_RETURN(JUST(input->has_eager_blob_object()));
// new output tensor
const auto& blob_object = JUST(input->eager_blob_object());
auto tensor_impl = std::make_shared<EagerMirroredTensorImpl>(
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def atexit_hook(hook):

import oneflow._C
from oneflow._C import tensor, batch_gather
from oneflow._C import from_numpy

from oneflow.autograd import grad_enable, no_grad, inference_mode, is_grad_enabled
import oneflow.nn.image
Expand Down
29 changes: 29 additions & 0 deletions python/oneflow/framework/docstr/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,35 @@
""",
)

add_docstr(
oneflow.from_numpy,
r"""
Creates a ``Tensor`` from a ``numpy.ndarray``.
The returned tensor and ndarray share the same memory. Modifications to the tensor
will be reflected in the ndarray and vice versa.
It currently accepts ndarray with dtypes of numpy.float64, numpy.float32, numpy.float16,
numpy.int64, numpy.int32, numpy.int8, numpy.uint8.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> np_arr = np.arange(6).reshape(2, 3)
>>> t = flow.from_numpy(np_arr)
>>> t
tensor([[0, 1, 2],
[3, 4, 5]], dtype=oneflow.int64)
>>> np_arr[0, 0] = -1
>>> t
tensor([[-1, 1, 2],
[ 3, 4, 5]], dtype=oneflow.int64)
""",
)

add_docstr(
oneflow.Tensor.atan2,
r"""
Expand Down
59 changes: 59 additions & 0 deletions python/oneflow/test/modules/test_from_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import random
import unittest

import numpy as np
import oneflow as flow
import oneflow.unittest


@flow.unittest.skip_unless_1n1d()
class TestFromNumpy(flow.unittest.TestCase):
def test_same_data(test_case):
np_arr = np.random.randn(3, 4, 5)
tensor = flow.from_numpy(np_arr)
test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))

np_arr[1:2, 2:3, 3:4] = random.random()
test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))

def test_use_ops(test_case):
np_arr = np.random.randn(3, 4, 5)
tensor = flow.from_numpy(np_arr)
res = tensor ** 2
test_case.assertTrue(np.allclose(np_arr ** 2, res.numpy()))

def test_more_dtype(test_case):
for dtype in [
np.float64,
np.float32,
np.float16,
np.int64,
np.int32,
np.int8,
np.uint8,
]:
np_arr = np.ones((2, 3), dtype=dtype)
tensor = flow.from_numpy(np_arr)
# TODO(wyg): oneflow.float16 do not support to copy from tensor to numpy
if tensor.dtype not in [flow.float16]:
test_case.assertTrue(np.array_equal(np_arr, tensor.numpy()))


if __name__ == "__main__":
unittest.main()
7 changes: 4 additions & 3 deletions python/oneflow/test/modules/test_functional_docstr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import inspect
import os
import unittest
from collections import OrderedDict

from test_util import GenArgList

import oneflow as flow
import oneflow.unittest

from test_util import GenArgList


def _run_functional_doctest(
test_case,
Expand Down Expand Up @@ -59,4 +59,5 @@ def test_functional_docstr(test_case):


if __name__ == "__main__":
flow.set_printoptions(linewidth=80)
unittest.main()

0 comments on commit c432c52

Please sign in to comment.