Move Tensor.__setitem__ and global related api to Python/C api #8375
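A rough sketch of the user-facing API surface this PR moves to the Python/C layer (to_global, global_to_global, local_to_global, to_local, and Tensor.__setitem__). The snippet below is illustrative only, assumes a single-rank CPU placement, and is not taken from this PR:

```python
import oneflow as flow

# Illustrative single-rank placement; any valid placement/sbp works.
placement = flow.placement("cpu", ranks=[0])

x = flow.tensor([1.0, 2.0, 3.0, 4.0])                         # local tensor
g = x.to_global(placement=placement, sbp=flow.sbp.broadcast)  # local -> global
g2 = g.to_global(sbp=flow.sbp.broadcast)                      # global -> global (placement defaults to g's)
y = g2.to_local()                                             # global -> local

x[0] = 10.0                       # __setitem__ with a scalar
x[1:3] = flow.tensor([7.0, 8.0])  # __setitem__ with a tensor
```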

Merged on Jun 15, 2022 (29 commits)

Commits
ad99fb4  add local_to_global, global_to_global, to_global. global_to_global st… (marigoold, Jun 6, 2022)
9714170  fix bug of global_to_global (marigoold, Jun 6, 2022)
9dcec22  remove python api (marigoold, Jun 6, 2022)
c0dc06f  add setitem (marigoold, Jun 6, 2022)
ff77bf9  remove local_to_global sbp pack, format code (marigoold, Jun 6, 2022)
7b7a7d3  format code (marigoold, Jun 6, 2022)
5e52e9c  remove redundant code (marigoold, Jun 6, 2022)
1f09070  add error msg, refine check of to_global (marigoold, Jun 6, 2022)
08a1064  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 7, 2022)
32832c7  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 8, 2022)
35dc2de  fix bug of check (marigoold, Jun 8, 2022)
7a3a8da  add error msg (marigoold, Jun 8, 2022)
5eadac1  fix clang static check error (marigoold, Jun 8, 2022)
15cd3e7  remove useless api in tensor.py, remove redundant code, remove useles… (marigoold, Jun 8, 2022)
02026cd  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 8, 2022)
14336a3  add to_local (marigoold, Jun 9, 2022)
cc95e36  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 10, 2022)
5c6da7d  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 13, 2022)
c9b56f7  fix wrong exception type in unittest for to_local exception message (marigoold, Jun 13, 2022)
3792593  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 13, 2022)
56a7202  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 14, 2022)
e14e795  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 14, 2022)
7abbeb8  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 14, 2022)
cdac22e  modify variable name (marigoold, Jun 15, 2022)
8bafe1b  Merge branch 'move_setitem_to_cpython' of github.com:Oneflow-Inc/onef… (marigoold, Jun 15, 2022)
546ee86  auto format by CI (oneflow-ci-bot, Jun 15, 2022)
0a9717e  Merge branch 'master' into move_setitem_to_cpython (marigoold, Jun 15, 2022)
04b7a8d  add global / local check in setitem (marigoold, Jun 15, 2022)
c1dd6a3  Merge branch 'master' into move_setitem_to_cpython (mergify[bot], Jun 15, 2022)
13 changes: 2 additions & 11 deletions oneflow/api/python/framework/tensor.cpp
@@ -125,26 +125,17 @@ static PyObject* PyTensorObject_subscript(PyObject* self, PyObject* item) {
END_HANDLE_ERRORS
}

static int PyTensorObject_ass_subscript(PyObject* self, PyObject* item, PyObject* value) {
HANDLE_ERRORS
const auto& p = PyTensor_Unpack(self);
const auto& v = PyTensor_Unpack(value);
functional::PythonArg arg(item);
ASSERT(functional::TensorSetItem(p, arg.As<functional::TensorIndex>(), v));
return 0;
END_HANDLE_ERRORS_RET(-1)
}

static PySequenceMethods PyTensorObject_as_sequence = {
(lenfunc)PyTensorObject_length, NULL, /*sq_concat*/
NULL, /*sq_repeat*/
(ssizeargfunc)PyTensorObject_getitem, /*sq_item*/
};

extern int PyTensorObject_setitem(PyObject*, PyObject*, PyObject*);
static PyMappingMethods PyTensorObject_as_mapping = {
(lenfunc)PyTensorObject_length,
(binaryfunc)PyTensorObject_subscript,
(objobjargproc)PyTensorObject_ass_subscript,
(objobjargproc)PyTensorObject_setitem,
};

static PyObject* PyTensorObject_storage_offset(PyObject* self, PyObject* unused) {
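With this hunk, the inline `PyTensorObject_ass_subscript` is dropped and the `mp_ass_subscript` slot of `PyTensorObject_as_mapping` now points to `PyTensorObject_setitem`, declared `extern` here and implemented in `tensor_functions.cpp`. From Python the observable behavior should be unchanged; a minimal sanity check (assuming standard CPython slot-wrapper behavior, not part of this PR):

```python
import oneflow as flow

t = flow.zeros(4)
t[1] = 5   # dispatches through the C-level mp_ass_subscript slot
print(t)   # roughly: tensor([0., 5., 0., 0.], dtype=oneflow.float32)

# __setitem__ should now be a C slot wrapper rather than a Python-level function.
print(type(flow.Tensor.__setitem__))   # e.g. <class 'wrapper_descriptor'>
```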
168 changes: 168 additions & 0 deletions oneflow/api/python/framework/tensor_functions.cpp
@@ -632,6 +632,168 @@ static PyObject* PyTensorObject_transpose(PyObject* self, PyObject* args, PyObje
END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_local_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(tensor->is_local()) << Error::RuntimeError() << "input must be a local tensor";
PyObject* placement_obj = Py_None;
PyObject* sbp_obj = Py_None;
bool check_meta = true;
static const char* keywords[4] = {"placement", "sbp", "check_meta", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$O!:local_to_global",
const_cast<char**>(keywords), &placement_obj, &sbp_obj,
&PyBool_Type, &check_meta)) {
return NULL;
};

CHECK_OR_THROW(placement_obj != Py_None && sbp_obj != Py_None) << Error::InvalidValueError(
"Converting a local tensor to global tensor must have placement and sbp parameters.");
CHECK_OR_THROW(functional::PyParallelDescCheck(placement_obj))
<< Error::TypeError() << "Invalid parameter placement with type "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));

std::vector<Symbol<SbpParallel>> sbp;
if (functional::PySbpParallelCheck(sbp_obj)) {
sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj));
} else {
CHECK_OR_THROW(functional::PySbpParallelSequenceCheck(sbp_obj))
<< Error::TypeError() << "Invalid parameter sbp with type "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(sbp_obj)));
sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);
}
return PyTensor_New(ASSERT_PTR(functional::ToConsistent(
tensor, functional::PyUnpackParallelDesc(placement_obj), sbp, {}, check_meta)));
END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_global_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(tensor->is_consistent())
<< Error::RuntimeError() << "input must be a global tensor";
PyObject* placement_obj = Py_None;
PyObject* sbp_obj = Py_None;
PyObject* grad_sbp_obj = Py_None;
Symbol<ParallelDesc> placement;
std::vector<Symbol<SbpParallel>> sbp;
std::vector<Symbol<SbpParallel>> grad_sbp;
bool check_meta = false;
static const char* keywords[5] = {"placement", "sbp", "grad_sbp", "check_meta", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO$OO!:global_to_global",
const_cast<char**>(keywords), &placement_obj, &sbp_obj,
&grad_sbp_obj, &PyBool_Type, &check_meta)) {
return NULL;
};

// sbp
CHECK_OR_THROW(sbp_obj == Py_None || functional::PySbpParallelCheck(sbp_obj)
|| functional::PySbpParallelSequenceCheck(sbp_obj))
<< Error::TypeError()
<< "sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp";
if (functional::PySbpParallelCheck(sbp_obj)) {
sbp.emplace_back(functional::PyUnpackSbpParallel(sbp_obj));
} else if (functional::PySbpParallelSequenceCheck(sbp_obj)) {
sbp = functional::PyUnpackSbpParallelSequence(sbp_obj);
} else {
for (int32_t i = 0; i < ASSERT(tensor->nd_sbp())->sbp_parallel_size(); i++)
sbp.emplace_back(ASSERT(tensor->nd_sbp())->sbp_parallel(i));
}

// placement
CHECK_OR_THROW(placement_obj == Py_None || functional::PyParallelDescCheck(placement_obj))
<< Error::TypeError() << "Invalid parameter placement with type "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(placement_obj)));
if (placement_obj == Py_None) {
placement = ASSERT(tensor->parallel_desc());
} else {
placement = functional::PyUnpackParallelDesc(placement_obj);
}

// grad_sbp
CHECK_OR_THROW(grad_sbp_obj == Py_None || functional::PySbpParallelCheck(grad_sbp_obj)
|| functional::PySbpParallelSequenceCheck(grad_sbp_obj))
<< Error::TypeError()
<< "grad_sbp parameter must be type of oneflow.sbp.sbp or list/tuple of oneflow.sbp.sbp";
if (functional::PySbpParallelCheck(grad_sbp_obj)) {
grad_sbp.emplace_back(functional::PyUnpackSbpParallel(grad_sbp_obj));
} else if (functional::PySbpParallelSequenceCheck(grad_sbp_obj)) {
grad_sbp = functional::PyUnpackSbpParallelSequence(grad_sbp_obj);
}
return PyTensor_New(
ASSERT_PTR(functional::ToConsistent(tensor, placement, sbp, grad_sbp, check_meta)));
END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_to_global(PyObject* self, PyObject* args, PyObject* kwargs) {
HANDLE_ERRORS
const auto& tensor = PyTensor_Unpack(self);
PyObject* result = NULL;
if (tensor->is_consistent())
result = PyTensorObject_global_to_global(self, args, kwargs);
else {
result = PyTensorObject_local_to_global(self, args, kwargs);
}
if (PyErr_Occurred()) { throw py::error_already_set(); }
return result;

END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_to_local(PyObject* self, PyObject* unused) {
HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self);
CHECK_OR_THROW(tensor->is_consistent())
<< Error::RuntimeError() << "Expected global tensor for to_local but got local tensor!";
return PyTensor_New(ASSERT_PTR(functional::ConsistentToLocal(tensor)));
END_HANDLE_ERRORS
}

int PyTensorObject_setitem(PyObject* self, PyObject* item, PyObject* value) {
HANDLE_ERRORS
auto tensor = PyTensor_Unpack(self);
std::shared_ptr<Tensor> value_tensor;
CHECK_OR_THROW(functional::PyTensorIndexCheck(item))
<< Error::TypeError() << "tensor_setitem(): argument 'index' must be index, not "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(item)));
CHECK_OR_THROW(functional::PyScalarCheck(value) || PyTensor_Check(value))
<< Error::TypeError() << "tensor_setitem(): argument 'value' must be tensor or scalar, not "
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(value)));

if (tensor->is_consistent()) {
Symbol<ParallelDesc> placement = ASSERT(tensor->parallel_desc());
auto ndsbp = ASSERT(tensor->nd_sbp());
std::vector<Symbol<SbpParallel>> sbp(ndsbp->sbp_parallel_size(),
ASSERT(MakeBroadcastSbpParallel()));
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::ConsistentConstant({1}, value_scalar, tensor->dtype(), placement, sbp));
} else {
value_tensor = PyTensor_Unpack(value);
Contributor review comment: It would be best to add a check here; the value tensor should also be consistent (global).

CHECK_OR_THROW(value_tensor->is_consistent())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a global tensor when self is global";
value_tensor = ASSERT_PTR(functional::ToConsistent(value_tensor, placement, sbp, {}, true));
}
} else {
if (functional::PyScalarCheck(value)) {
Scalar value_scalar = functional::PyUnpackScalar(value);
value_tensor = ASSERT_PTR(
functional::Constant({1}, value_scalar, tensor->dtype(), ASSERT(tensor->device())));
} else {
value_tensor = PyTensor_Unpack(value);
CHECK_OR_THROW(value_tensor->is_local())
<< Error::RuntimeError()
<< "tensor_setitem(): value must be a local tensor when self is local";
Optional<Symbol<Device>> device = ASSERT(tensor->device());
value_tensor = ASSERT_PTR(functional::To(value_tensor, device, value_tensor->dtype(), false));
Contributor review comment: Same here; need to check that the value tensor is local.

}
}
ASSERT(functional::TensorSetItem(tensor, functional::PyUnpackTensorIndex(item), value_tensor));
return 0;
END_HANDLE_ERRORS_RET(-1)
}

PyMethodDef PyTensorObject_extra_methods[] = {
{"byte", PyTensorObject_byte, METH_NOARGS, NULL},
{"size", (PyCFunction)PyTensorObject_size, METH_VARARGS | METH_KEYWORDS, NULL},
@@ -655,6 +817,12 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"half", PyTensorObject_half, METH_NOARGS, NULL},
{"float", PyTensorObject_float, METH_NOARGS, NULL},
{"double", PyTensorObject_double, METH_NOARGS, NULL},
{"local_to_global", (PyCFunction)PyTensorObject_local_to_global, METH_VARARGS | METH_KEYWORDS,
NULL},
{"global_to_global", (PyCFunction)PyTensorObject_global_to_global, METH_VARARGS | METH_KEYWORDS,
NULL},
{"to_local", PyTensorObject_to_local, METH_NOARGS, NULL},
{"to_global", (PyCFunction)PyTensorObject_to_global, METH_VARARGS | METH_KEYWORDS, NULL},
{"cpu", PyTensorObject_cpu, METH_NOARGS, NULL},
{"cuda", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
{"var", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL},
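The new `PyTensorObject_setitem` above requires the assigned value to match the tensor's global/local nature: scalars are materialized as constants, a tensor value must be global when `self` is global (it is then rebalanced to broadcast sbp over `self`'s placement) and local when `self` is local (it is then moved to `self`'s device). A usage sketch, again assuming a single-rank placement purely for illustration:

```python
import oneflow as flow

placement = flow.placement("cpu", ranks=[0])   # illustrative placement
g = flow.ones(4).to_global(placement=placement, sbp=flow.sbp.broadcast)

g[0] = 3.0                                     # scalar: built as a global constant internally
g[1:3] = flow.ones(2).to_global(placement=placement, sbp=flow.sbp.broadcast)  # global value

try:
    g[1] = flow.ones(1)                        # local value assigned into a global tensor
except RuntimeError as e:
    print(e)  # "tensor_setitem(): value must be a global tensor when self is global"
```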
51 changes: 0 additions & 51 deletions python/oneflow/framework/tensor.py
@@ -71,30 +71,6 @@ def _backward(self, gradient=None, retain_graph=False, create_graph=False):
flow._oneflow_internal.nn.graph.AddTensorAsGraphLoss(self)


def _setitem(self, key, value):
if self.is_global:
if isinstance(value, (int, float)):
value = flow._C.global_constant(
[1],
value,
dtype=self.dtype,
placement=self.placement,
sbp=[flow.sbp.broadcast,] * len(self.sbp),
)
else:
value = value.to_global(
self.placement, sbp=[flow.sbp.broadcast,] * len(self.sbp)
)
else:
if isinstance(value, (int, float)):
value = flow._C.constant([1], value, dtype=self.dtype, device=self.device)
else:
value = value.to(device=self.device)

flow._C.tensor_setitem(self, key, value)
return self


def _str(self):
return self.__repr__()

@@ -641,10 +617,6 @@ def _triu(self, diagonal=0):
return flow.triu(self, diagonal=diagonal)


def _to_local(self):
return flow.to_local(self)


def _relu(self):
return flow._C.relu(self)

@@ -920,24 +892,6 @@ def _to(self, *args, **kwargs):
return flow._C.to(self, *new_args, **kwargs)


def _local_to_global(self, placement=None, sbp=None, *, check_meta=True):
return flow.local_to_global(self, placement, sbp, check_meta)


def _global_to_global(
self, placement=None, sbp=None, *, grad_sbp=None, check_meta=False
):
return flow.global_to_global(self, placement, sbp, grad_sbp, check_meta)


def _to_global(self, placement=None, sbp=None, **kwargs):
return flow.to_global(self, placement, sbp, **kwargs)


def _to_local(self):
return flow.to_local(self)


def _tolist(self):
if self.numel() == 1 and self.ndim == 0:
return self.item()
@@ -1144,7 +1098,6 @@ def RegisterMethods():
Tensor.sub = _sub
Tensor.sub_ = _sub_inplace
Tensor.backward = _backward
Tensor.__setitem__ = _setitem
Tensor.__str__ = _str
Tensor.__repr__ = _repr
Tensor.__bool__ = is_nonzero
@@ -1176,9 +1129,6 @@ def RegisterMethods():
Tensor.new_zeros = _new_zeros
Tensor.where = _where
Tensor.norm = _norm
Tensor.local_to_global = _local_to_global
Tensor.global_to_global = _global_to_global
Tensor.to_global = _to_global
Tensor.repeat = _repeat
Tensor.repeat_interleave = _repeat_interleave
Tensor.tile = _tile
@@ -1189,7 +1139,6 @@ def RegisterMethods():
Tensor.masked_select = _masked_select
Tensor.eq = _eq
Tensor.item = _item
Tensor.to_local = _to_local
Tensor.sort = _sort
Tensor.type_as = _type_as
Tensor.tolist = _tolist
@@ -64,7 +64,7 @@ def test_global_to_global_with_invalid_split_axis(test_case):
@flow.unittest.skip_unless_1n1d()
def test_call_to_local_for_local_tensor(test_case):
x = flow.tensor([1, 2, 3, 4])
with test_case.assertRaises(AssertionError) as ctx:
with test_case.assertRaises(RuntimeError) as ctx:
y = x.to_local()
test_case.assertTrue(
"Expected global tensor for to_local but got local tensor!"
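Since `to_local` is now checked on the C side with `CHECK_OR_THROW(...) << Error::RuntimeError()`, calling it on a local tensor surfaces as a `RuntimeError` rather than the `AssertionError` the old Python wrapper produced, which is what the updated assertion expects. A minimal reproduction, not part of the test suite:

```python
import oneflow as flow

x = flow.tensor([1, 2, 3, 4])   # plain local tensor
try:
    x.to_local()
except RuntimeError as e:
    # message contains: "Expected global tensor for to_local but got local tensor!"
    print(e)
```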