Move tensor api to cpython part4 #8395
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8395/
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
SHAPE_ONLY_METHODS(repeat, functional::Repeat, "repeat_shape");
SHAPE_ONLY_METHODS(tile, functional::Tile, "shape");
SHAPE_ONLY_METHODS(reshape, functional::Reshape, "shape");
I checked the torch API docs: repeat, view, and reshape should accept either a shape or an int list, while tile should only accept an int list.
tile accepts both forms as well:
In [4]: import torch
In [5]: x = torch.randn(1,1)
In [6]: x.tile(1)
Out[6]: tensor([[-1.2598]])
In [7]: x.tile(1, 1)
Out[7]: tensor([[-1.2598]])
@@ -648,6 +509,60 @@ static PyObject* PyTensorObject_transpose(PyObject* self, PyObject* args, PyObje
END_HANDLE_ERRORS
}

#define ARGS_ONLY_METHODS(func_name, bind_func, param_name, convert, data_type) \
This macro parses the arguments of functions of the form foo(x, *dims).
If the call is tensor.foo(1,2,3), args is a non-empty tuple;
if the call is tensor.foo(param=[1,2,3]), args is an empty tuple, and PyArg_ParseTupleAndKeywords is used to parse the arguments instead.
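The branching described in this comment can be sketched in Python. This is only an illustration under assumptions: `parse_int_list_args` is a hypothetical helper, not part of OneFlow or CPython, and it only mirrors the call forms the macro distinguishes (several positional ints, one positional sequence or int, or a single keyword argument).

```python
from collections.abc import Sequence


def parse_int_list_args(func_name, param_name, args, kwargs):
    """Hypothetical helper: normalize the supported call forms into a list of ints."""
    if len(args) == 0:
        # tensor.foo(param=[1, 2, 3]): args is empty, the value arrives in kwargs.
        if set(kwargs) != {param_name}:
            raise TypeError(f"{func_name}(): expected keyword argument '{param_name}'")
        value = kwargs[param_name]
    else:
        # tensor.foo(1, 2, 3) or tensor.foo((1, 2, 3)): keywords are rejected.
        if kwargs:
            raise TypeError(
                f"{func_name}() got multiple values for argument '{param_name}' "
                "or an invalid argument"
            )
        # One positional argument is the value itself;
        # several positional arguments together form the sequence.
        value = args[0] if len(args) == 1 else args
    if isinstance(value, int):
        return [value]
    if (isinstance(value, Sequence) and not isinstance(value, (str, bytes))
            and all(isinstance(v, int) for v in value)):
        return list(value)
    raise TypeError(
        f"{func_name}(): argument '{param_name}' must be an int or a sequence "
        f"of ints, not {type(value).__name__}"
    )


print(parse_int_list_args("reshape", "shape", (2, 3), {}))           # [2, 3]
print(parse_int_list_args("reshape", "shape", ([2, 3],), {}))        # [2, 3]
print(parse_int_list_args("reshape", "shape", (), {"shape": [2, 3]}))  # [2, 3]
```

All three call forms end up as the same list, which is the property the binding needs before handing the values to the underlying functional call.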
…/Oneflow-Inc/oneflow into move_tensor_api_to_cpython_part4
python/oneflow/framework/tensor.py
Outdated
def _ndim(self):
    return len(self.shape)
# def _ndim(self):
#     return len(self.shape)
These commented-out lines can probably be deleted now.
bind_python: True

- name: "var"
  signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => Variance"
  signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=True, Bool keepdim=False) => Variance"
From what I can see, torch's var only has default values starting from keepdim: https://pytorch.org/docs/stable/generated/torch.var.html?highlight=torch+var#torch.var
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:eq", const_cast<char**>(keywords), &other)) {
  return NULL;
}
if (other == Py_None) { Py_RETURN_FALSE; }
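I have a question here. If self is also Py_None and other is Py_None as well, shouldn't Python return True? If so, is a check on self missing here?
Or is it already guaranteed, before the arguments are passed in, that self will not be Py_None?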
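On this question, a small Python sketch may help. It assumes ordinary Python operator dispatch and uses `FakeTensor`, a hypothetical stand-in that is not OneFlow's Tensor: a type's `__eq__` only runs when one of the operands is an instance of that type, so a `None == None` comparison would not be expected to reach the tensor's method at all.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor type; counts calls to __eq__."""

    def __init__(self):
        self.eq_calls = 0

    def __eq__(self, other):
        self.eq_calls += 1
        if other is None:
            # Mirrors the `other == Py_None` early return in the diff.
            return False
        return NotImplemented


t = FakeTensor()
print(t == None)      # False: FakeTensor.__eq__(t, None) runs
print(None == t)      # False: NoneType declines, then the reflected FakeTensor.__eq__ runs
print(t.eq_calls)     # 2
print(None == None)   # True: handled without touching FakeTensor
print(t.eq_calls)     # 2
```

In both calls that reach `FakeTensor.__eq__`, `self` is the tensor-like object and never `None`, which is consistent with the second reading of the question.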
#define ARGS_ONLY_METHODS(func_name, bind_func, param_name, param_type, convert, int_data_type) \
  static PyObject* PyTensorObject_##func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \
    HANDLE_ERRORS \
    PyObject* shape_obj = NULL; \
    std::vector<int_data_type> shape_vec; \
    int args_size = PyTuple_Size(args); \
    if (args_size == 0) { \
      static const char* keywords[2] = {"" #param_name, NULL}; \
      if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:" #func_name, \
                                       const_cast<char**>(keywords), &shape_obj)) { \
        return NULL; \
      } \
    } else { \
      CHECK_OR_THROW(kwargs == NULL || PyDict_Size(kwargs) <= 0) << #func_name \
          "() got multiple values for argument '" #param_name "' or get invalid argument"; \
    } \
    if (PyTuple_Size(args) == 1) { \
      shape_obj = PyTuple_GetItem(args, 0); \
    } else if (shape_obj == NULL) { \
      shape_obj = args; \
    } \
    CHECK_OR_THROW(PyLong_Check(shape_obj) || functional::PyLongSequenceCheck(shape_obj)) \
        << Error::TypeError() \
        << #func_name "(): argument '" #param_name "' must be " #param_type ", not " \
        << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(shape_obj))); \
    if (PyLong_Check(shape_obj)) { \
      shape_vec.emplace_back(PyLong_AsLongLong(shape_obj)); \
    } else { \
      shape_vec = functional::PyUnpackLongSequence<int_data_type>(shape_obj); \
    } \
    return PyTensor_New(ASSERT_PTR(bind_func(PyTensor_Unpack(self), convert(shape_vec)))); \
    END_HANDLE_ERRORS \
  }
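A large, complex macro like this is probably not a good fit, because macros are hard to inspect and jump into while debugging. I suggest simplifying it. The usual approaches are:
- The simpler one: pull the logic out one more level into a separate function, so that the macro only makes a thin call to that function. See https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/ops/adaptive_max_pool_op.cpp#L84-L121 for reference.
- Use a template with bind_func as a template parameter, so that func_name is no longer needed. (Templates are more complicated, though; this is only a rough idea and may need adjusting when actually written.)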
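This PR does the following:
Some miscellaneous changes:
- [...] (grouping the DIRECT_PASS_FUNC entries together, the UNARY_METHOD entries together, etc.)
- [...] ARGS_ONLY_METHODS, which parses the arguments of functions such as reshape and permute that, apart from the tensor, accept only a single int list parameter. In torch, these methods support several argument forms when bound to a tensor.
For the tensor API definitions before the move, see: https://github.com/Oneflow-Inc/oneflow/blob/dde79e04b01521e65403d1d49fcf1154a6f289fb/python/oneflow/framework/tensor.py
List of changed functions:
- [...] DIRECT_PASS_FUNC
- [...] into DIRECT_PASS_FUNC
- [...] into DIRECT_PASS_FUNC; for background, see Move tensor api to cpython part3 #8342 (comment)
- [...] into DIRECT_PASS_FUNC
- [...] into DIRECT_PASS_FUNC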