support copy_ on different devices and dtypes (#9888)
Support copy_ when the input and output differ in both device and dtype.

---------

Signed-off-by: daquexian <daquexian566@gmail.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
3 people authored Feb 28, 2023
1 parent a872fb8 commit d6038c6
Showing 5 changed files with 66 additions and 65 deletions.
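
For context, a minimal usage sketch of what this change enables (illustrative shapes and names; the second part assumes a CUDA device is available):

import oneflow as flow

# In-place copy between tensors of different dtypes: copy_ converts the
# source to the destination's dtype before assigning.
x = flow.randn(4, 12).to(flow.int)
y = flow.randn(4, 12)
y.copy_(x)  # y now holds x's integer values, stored as float32

# With this commit, source and destination may also sit on different devices;
# the dispatched copy op performs the transfer.
if flow.cuda.is_available():
    y_cuda = flow.randn(4, 12).to("cuda")
    y_cuda.copy_(x)  # CPU int32 source into a CUDA float32 destination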
24 changes: 14 additions & 10 deletions oneflow/api/python/functional/tensor_api.cpp
@@ -23,6 +23,7 @@ limitations under the License.
#include "oneflow/api/python/functional/tensor_api.yaml.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/framework/mutable_attr_map.h"
#include "oneflow/core/framework/stream.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
@@ -226,17 +227,20 @@ class GlobalTensorWithShapeGenericCtorFunctor {
class AssignLocalTensorFunctor {
public:
AssignLocalTensorFunctor() {
op_ = CHECK_JUST(one::OpBuilder("assign").Input("ref").Input("value").Build());
op_ = CHECK_JUST(one::OpBuilder("copy").Input("in").Output("out").Build());
}
Maybe<void> operator()(const std::shared_ptr<one::Tensor>& ref,
const std::shared_ptr<one::Tensor>& value) const {
// JUST(CheckInplaceValid(ref)); // align check to torch
CHECK_OR_RETURN(ref->is_local() && value->is_local())
<< "Both ref and value must be local tensor.";
std::shared_ptr<one::Tensor> src = value;
if (ref->dtype() != src->dtype()) { src = JUST(To(src, ref->dtype(), false)); }
JUST(OpInterpUtil::Dispatch<TensorTuple>(*op_, {ref, src}));
return Maybe<void>::Ok();
Maybe<void> operator()(const std::shared_ptr<one::Tensor>& y,
const std::shared_ptr<one::Tensor>& x) const {
// JUST(CheckInplaceValid(y)); // align check to torch
CHECK_OR_RETURN(y->is_local() && x->is_local()) << "Both x and y must be local tensor.";
std::shared_ptr<one::Tensor> src = x;
if (y->dtype() != src->dtype()) { src = JUST(To(src, y->dtype(), false)); }

auto device = JUST(y->device());
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("device", "pin_memory");
attrs.SetAllAttrs(device, false);
TensorTuple outputs{y};
return OpInterpUtil::Dispatch(*op_, {x}, &outputs, attrs);
}

private:
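For reference, this functor backs flow._C.assign_local_tensor, which _copy in python/oneflow/framework/tensor.py (next file) calls for local tensors. A rough sketch of a call that now works, with hypothetical tensors:

import oneflow as flow

dst = flow.zeros(2, 3)                   # float32 destination
src = flow.ones(2, 3, dtype=flow.int)    # int32 source

# The functor casts src to dst's dtype, then dispatches the "copy" op with
# dst's device as the "device" attr, writing the result into dst in place.
flow._C.assign_local_tensor(dst, src)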
60 changes: 19 additions & 41 deletions python/oneflow/framework/tensor.py
@@ -298,50 +298,28 @@ def _copy_from_numpy_to_eager_local_tensor(eager_local_tensor, np_arr):


def _copy(self, other: Union[Tensor, np.ndarray]):
# Possibility 1: self and other are tensors on the same device/placement and have the same sbp.
if isinstance(other, Tensor):
if self.is_global:
assert (
other.is_global
), "Only global tensor can be assigned to global tensor."
if self.placement == other.placement and self.sbp == other.sbp:
flow._C.assign_local_tensor(self.to_local(), other.to_local())
return
else:
assert (
not other.is_global
), "Only local tensor can be assigned to local tensor."
if self.device == other.device:
other = flow._C.broadcast_like(other, self)
if not self.is_contiguous():
# NOTE: slice_update support non-contiguous input tensor
with flow.no_grad():
self[...] = other
else:
flow._C.assign_local_tensor(self, other)
return

# Possibility 2: `other` is a numpy array, or `self` and `other` are tensors on different devices/placements.
# In this case, we run boxing through cpu to avoid extra gpu memory usage.
if isinstance(other, np.ndarray):
other = flow.from_numpy(other)
elif not isinstance(other, Tensor):
other = flow.tensor(other)
other = other.to(self.dtype)
if self.is_global:
self_cpu_placement = flow.placement("cpu", self.placement.ranks)
if isinstance(other, Tensor):
assert other.is_global, "Only global tensor can be assigned to global tensor."
if not (self.sbp == other.sbp and self.placement == other.placement):
other_cpu_placement = flow.placement("cpu", other.placement.ranks)
other = other.to_global(placement=other_cpu_placement).to_global(
placement=self_cpu_placement, sbp=self.sbp
)
else:
other = flow.tensor(
other, dtype=self.dtype, placement=self_cpu_placement, sbp=self.sbp
)
_copy_from_numpy_to_eager_local_tensor(
self.to_local(), other.to_local().numpy()
)
other = other.to_global(placement=other_cpu_placement)
self_cpu_placement = flow.placement("cpu", self.placement.ranks)
other = other.to_global(placement=self_cpu_placement, sbp=self.sbp)
flow._C.assign_local_tensor(self.to_local(), other.to_local())
else:
if isinstance(other, Tensor):
other = other.numpy()

_copy_from_numpy_to_eager_local_tensor(self, other)
assert other.is_local, "Only local tensor can be assigned to local tensor."
other = flow._C.broadcast_like(other, self)
if not self.is_contiguous():
# NOTE: slice_update support non-contiguous input tensor
with flow.no_grad():
self[...] = other
else:
flow._C.assign_local_tensor(self, other)


def _format(self, format_spec):
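Because the hunk above interleaves removed and added lines, the new global branch of _copy, pieced together from the added lines, reads roughly as follows (a reconstruction; it may differ slightly from the file):

if self.is_global:
    assert other.is_global, "Only global tensor can be assigned to global tensor."
    if not (self.sbp == other.sbp and self.placement == other.placement):
        # Box `other` through CPU placements so both sides agree on
        # placement and sbp before the local assignment.
        other_cpu_placement = flow.placement("cpu", other.placement.ranks)
        other = other.to_global(placement=other_cpu_placement)
        self_cpu_placement = flow.placement("cpu", self.placement.ranks)
        other = other.to_global(placement=self_cpu_placement, sbp=self.sbp)
    flow._C.assign_local_tensor(self.to_local(), other.to_local())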
30 changes: 23 additions & 7 deletions python/oneflow/test/modules/test_copy.py
@@ -62,15 +62,31 @@ def test_copy_fp16(test_case):
x.copy_(a)
test_case.assertTrue(np.array_equal(x.numpy(), a))

@flow.unittest.skip_unless_1n1d()
def test_tensor_inplace_copy_with_diff_dtype(test_case):
np_arr = np.random.randn(4, 12)
x = flow.tensor(np_arr)
y = flow.tensor(np_arr, dtype=flow.int)
x = flow.randn(4, 12).to(flow.int)
y = flow.randn(4, 12)
y.copy_(x)
a = ori_torch.tensor(np_arr)
b = ori_torch.tensor(np_arr, dtype=ori_torch.int)
test_case.assertTrue(np.array_equal(y.numpy(), b.cpu().numpy()))
test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))

@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
def test_tensor_inplace_copy_with_diff_dtype_and_device(test_case):
x = flow.randn(4, 12).to(flow.int)
y = flow.randn(4, 12).to("cuda")
y.copy_(x)
test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))

@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
def test_global_tensor_inplace_copy_with_diff_dtype_and_device(test_case):
x = (
flow.randn(4, 12)
.to(flow.int)
.to_global(placement=flow.placement.all("cpu"), sbp=flow.sbp.broadcast)
)
y = flow.randn(4, 12).to_global(
placement=flow.placement.all("cuda"), sbp=flow.sbp.broadcast
)
y.copy_(x)
test_case.assertTrue(np.array_equal(y.numpy(), x.numpy()))


if __name__ == "__main__":
7 changes: 0 additions & 7 deletions python/oneflow/test/tensor/test_global_tensor.py
@@ -227,13 +227,6 @@ def test_copy(test_case):
x.copy_(y)
test_case.assertTrue(np.array_equal(x.numpy(), y.numpy()))

x = flow.zeros(
4, 6, placement=flow.placement("cuda", [0, 1]), sbp=flow.sbp.broadcast
)
y = np.ones((4, 6), dtype=np.float32)
x.copy_(y)
test_case.assertTrue(np.array_equal(x.numpy(), y))


if __name__ == "__main__":
unittest.main()
@@ -981,6 +981,16 @@ def __init__(self, name, pytorch, oneflow):
k: v.detach() for (k, v) in oneflow_state_dict.items()
}
already_global = any([v.is_global for v in oneflow_state_dict.values()])
if is_global() and already_global:
for k, v in state_dict.items():
if k not in oneflow_state_dict:
continue
of_state = oneflow_state_dict[k]
if of_state.is_global:
state_dict[k] = flow.tensor(
v, sbp=of_state.sbp, placement=of_state.placement
)

oneflow.load_state_dict(state_dict, strict=False)

if is_global():
