
Move tensor api to cpython part4 #8395

Closed
wants to merge 105 commits

Conversation


@marigoold (Contributor) commented Jun 10, 2022

This PR accomplishes the following:

  • Ported the tensor APIs from Python to the Python C API
  • Established the rule for rewriting the tensor API with the CPython API: "only hand-write the Python-side logic, never hand-write the argument-parsing logic". Following this rule, the previously hand-written, relatively complex tensor APIs were deleted and moved into the corresponding macros

Miscellaneous changes:

  • For std and var, changed the default arguments in functional_api.yaml to align with torch
  • Deleted the flip_op.py file and moved its dim-validation logic into the functor
  • Renamed the variables in the type_as test code from tgt_xxx to dst_xxx to make them more conventional
  • Reordered the functions in PyTensorObject_extra_methods in tensor_functions.cpp by category (e.g. the DIRECT_PASS_FUNC ones grouped together, the UNARY_METHOD ones grouped together, and so on)
  • Fixed the incorrect way of returning True/False: Py_RETURN_TRUE/Py_RETURN_FALSE should be used instead of returning Py_True/Py_False directly, which unbalances the reference counts and leads to memory errors (a short C++ sketch of this bug class follows the parsing example below)
  • Added the ARGS_ONLY_METHODS macro to parse functions such as reshape and permute that, besides the tensor, take only a single int list argument. In torch, when these methods are bound to a tensor they accept several argument forms:
    • variable positional arguments, as in x.reshape(1, 2)
    • a keyword argument, as in x.reshape(shape=(1, 2))
    • but not torch.reshape(x, 1, 2)
    • the macro implements argument-parsing logic roughly equivalent to:
def _permute(self, *dims):
    if len(dims) == 1:
        new_dims = dims[0]
        if isinstance(new_dims, int):  # x.permute(2): promote a bare int to (2,)
            new_dims = (new_dims,)
    else:
        new_dims = dims  # x.permute(0, 2, 1): use the varargs tuple directly
    return flow._C.permute(self, new_dims)
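
On the Py_RETURN_TRUE/Py_RETURN_FALSE point above, a minimal C++ sketch of the bug class being fixed (the function names here are hypothetical, not code from this PR):

#include <Python.h>

// Buggy pattern: Py_True is returned as a borrowed reference. Each caller
// later Py_DECREFs the result, so Py_True's reference count drifts downward
// and can eventually corrupt the interpreter.
static PyObject* is_foo_buggy(PyObject* self, PyObject* args) {
  return Py_True;  // missing Py_INCREF(Py_True)
}

// Fixed pattern: Py_RETURN_TRUE effectively expands to
// Py_INCREF(Py_True); return Py_True; so the count stays balanced.
static PyObject* is_foo_fixed(PyObject* self, PyObject* args) {
  Py_RETURN_TRUE;
}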

The tensor API definitions before the port can be found at: https://github.com/Oneflow-Inc/oneflow/blob/dde79e04b01521e65403d1d49fcf1154a6f289fb/python/oneflow/framework/tensor.py

List of changed functions:

  • tensor.T is bound as a property
  • cast, diag, diagonal, matmul, var, std, softplus, split: deleted the hand-written argument-parsing code and moved them to DIRECT_PASS_FUNC
    • cast, diag, diagonal, matmul, split previously had hand-written argument parsing; it was deleted and they were moved to DIRECT_PASS_FUNC
    • var and std had hand-written parsing because the arguments of tensor.var/std differed from those of the functor; they are now aligned with torch and moved to DIRECT_PASS_FUNC; for background see Move tensor api to cpython part3 #8342 (comment)
    • softplus had hand-written parsing because OneFlow's original tensor.softplus lacked the beta and threshold parameters that the functor has, so the port declared those two parameters by hand; see https://github.com/Oneflow-Inc/oneflow/pull/8342/files#r887654844. It is now aligned with torch and can be moved straight to DIRECT_PASS_FUNC
  • eq hand-writes the Python-side logic: if the other operand is None it returns False directly; otherwise it calls LogicalEqual
  • type_as: deleted the CPython-side logic and moved it into the functor
  • everything else was ported directly to DIRECT_PASS_FUNC

marigoold and others added 30 commits June 6, 2022 16:39
@github-actions

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 139.7ms (= 13966.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 160.2ms (= 16018.6ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 160.2ms / 139.7ms)

OneFlow resnet50 time: 84.8ms (= 8482.6ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 100.9ms (= 10087.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.19 (= 100.9ms / 84.8ms)

OneFlow resnet50 time: 57.5ms (= 11499.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 79.1ms (= 15826.5ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.38 (= 79.1ms / 57.5ms)

OneFlow resnet50 time: 44.5ms (= 8901.0ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 80.6ms (= 16115.1ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.81 (= 80.6ms / 44.5ms)

OneFlow resnet50 time: 38.8ms (= 7755.4ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 65.8ms (= 13167.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.70 (= 65.8ms / 38.8ms)

@github-actions

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 139.7ms (= 13974.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 162.4ms (= 16238.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.16 (= 162.4ms / 139.7ms)

OneFlow resnet50 time: 85.2ms (= 8522.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 102.7ms (= 10271.5ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.21 (= 102.7ms / 85.2ms)

OneFlow resnet50 time: 57.3ms (= 11452.4ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 77.2ms (= 15437.1ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.35 (= 77.2ms / 57.3ms)

OneFlow resnet50 time: 44.5ms (= 8905.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 68.8ms (= 13754.7ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.54 (= 68.8ms / 44.5ms)

OneFlow resnet50 time: 39.5ms (= 7891.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.5ms (= 13502.7ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.71 (= 67.5ms / 39.5ms)

@github-actions

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8395/

@github-actions

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.


SHAPE_ONLY_METHODS(repeat, functional::Repeat, "repeat_shape");
SHAPE_ONLY_METHODS(tile, functional::Tile, "shape");
SHAPE_ONLY_METHODS(reshape, functional::Reshape, "shape");
Contributor

Looking at the torch API docs, repeat, view, and reshape should accept either a shape or an int list; tile should only accept an int list.

Contributor Author

> Looking at the torch API docs, repeat, view, and reshape should accept either a shape or an int list; tile should only accept an int list.

tile works too:

In [4]: import torch

In [5]: x = torch.randn(1,1)

In [6]: x.tile(1)
Out[6]: tensor([[-1.2598]])

In [7]: x.tile(1, 1)
Out[7]: tensor([[-1.2598]])

@@ -648,6 +509,60 @@ static PyObject* PyTensorObject_transpose(PyObject* self, PyObject* args, PyObje
END_HANDLE_ERRORS
}

#define ARGS_ONLY_METHODS(func_name, bind_func, param_name, convert, data_type) \
Contributor Author

This macro handles argument parsing for functions of the form foo(x, *dims).
If the call is tensor.foo(1, 2, 3), args is a non-empty tuple;
if the call is tensor.foo(param=[1, 2, 3]), args is an empty tuple, and PyArg_ParseTupleAndKeywords is then used to parse the keyword argument.
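
A stripped-down C++ sketch of that dispatch (foo and param are placeholder names, not the PR's exact code):

static PyObject* PyTensorObject_foo(PyObject* self, PyObject* args, PyObject* kwargs) {
  PyObject* value = NULL;
  if (PyTuple_Size(args) > 0) {
    // tensor.foo(1, 2, 3): the positional arguments arrive as a non-empty tuple.
    value = args;
  } else {
    // tensor.foo(param=[1, 2, 3]): args is empty, so parse the keyword instead.
    static const char* keywords[2] = {"param", NULL};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:foo",
                                     const_cast<char**>(keywords), &value)) {
      return NULL;
    }
  }
  // value now holds either the varargs tuple or the keyword object; the real
  // macro unpacks it into an int vector and calls the bound functor here.
  (void)value;
  Py_RETURN_NONE;  // placeholder return so the sketch is self-contained
}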

@github-actions

github-actions bot commented Nov 1, 2022

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

def _ndim(self):
return len(self.shape)
# def _ndim(self):
# return len(self.shape)
Contributor

These commented-out lines can be deleted now.

oneflow/core/functional/functional_api.yaml
bind_python: True

- name: "var"
signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => Variance"
signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=True, Bool keepdim=False) => Variance"
Contributor

It looks like torch's var only has default arguments starting from keepdim: https://pytorch.org/docs/stable/generated/torch.var.html?highlight=torch+var#torch.var

oneflow/api/python/framework/tensor_functions.cpp (outdated)
@github-actions

github-actions bot commented Nov 1, 2022

Speed stats:

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:eq", const_cast<char**>(keywords), &other)) {
return NULL;
}
if (other == Py_None) { Py_RETURN_FALSE; }
Contributor

A question here, just to check my understanding: if self were also Py_None and other is Py_None, shouldn't Python consider the comparison to return True? So is a check on self missing here?

Or is it guaranteed before this is called that self can never be Py_None?

Comment on lines +517 to +549
#define ARGS_ONLY_METHODS(func_name, bind_func, param_name, param_type, convert, int_data_type) \
static PyObject* PyTensorObject_##func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \
HANDLE_ERRORS \
PyObject* shape_obj = NULL; \
std::vector<int_data_type> shape_vec; \
int args_size = PyTuple_Size(args); \
if (args_size == 0) { \
static const char* keywords[2] = {"" #param_name, NULL}; \
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:" #func_name, \
const_cast<char**>(keywords), &shape_obj)) { \
return NULL; \
} \
} else { \
CHECK_OR_THROW(kwargs == NULL || PyDict_Size(kwargs) <= 0) << #func_name \
"() got multiple values for argument '" #param_name "' or get invalid argument"; \
} \
if (PyTuple_Size(args) == 1) { \
shape_obj = PyTuple_GetItem(args, 0); \
} else if (shape_obj == NULL) { \
shape_obj = args; \
} \
CHECK_OR_THROW(PyLong_Check(shape_obj) || functional::PyLongSequenceCheck(shape_obj)) \
<< Error::TypeError() \
<< #func_name "(): argument '" #param_name "' must be " #param_type ", not " \
<< functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(shape_obj))); \
if (PyLong_Check(shape_obj)) { \
shape_vec.emplace_back(PyLong_AsLongLong(shape_obj)); \
} else { \
shape_vec = functional::PyUnpackLongSequence<int_data_type>(shape_obj); \
} \
return PyTensor_New(ASSERT_PTR(bind_func(PyTensor_Unpack(self), convert(shape_vec)))); \
END_HANDLE_ERRORS \
}
Contributor

Large, complex macros like this are probably not a good fit (macros are hard to inspect and to step through when debugging). I suggest simplifying. Two common approaches:

  1. The simpler one: extract the logic one level down into an ordinary function, and have the macro only thinly wrap a call to it. See https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/ops/adaptive_max_pool_op.cpp#L84-L121 for an example of this pattern.
  2. Use a template, with bind_func as a template parameter, so that func_name is no longer needed. (Templates are more involved; this is only a rough idea and may need adjusting when actually written; see the sketch below.)
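
A rough sketch of what combining both suggestions could look like, under simplified, assumed signatures (ParseIntList, the Impl signature, and ReshapeImpl are illustrative stand-ins, not OneFlow's actual types; overflow and per-item error checks are elided):

#include <Python.h>
#include <vector>

// Suggestion 1: hoist the parsing into an ordinary function that a debugger
// can step through. Accepts x.foo(1, 2), x.foo((1, 2)), or x.foo(shape=(1, 2)).
static bool ParseIntList(PyObject* args, PyObject* kwargs, std::vector<long long>* out) {
  PyObject* obj = NULL;
  Py_ssize_t nargs = PyTuple_Size(args);
  if (nargs > 1) {
    obj = args;                      // varargs: the tuple itself is the int list
  } else if (nargs == 1) {
    obj = PyTuple_GetItem(args, 0);  // single positional: a bare int or a sequence
  } else {
    static const char* keywords[2] = {"shape", NULL};
    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:shape",
                                     const_cast<char**>(keywords), &obj)) {
      return false;
    }
  }
  if (PyLong_Check(obj)) {
    out->push_back(PyLong_AsLongLong(obj));
    return true;
  }
  PyObject* seq = PySequence_Fast(obj, "expected an int or a sequence of ints");
  if (seq == NULL) { return false; }
  for (Py_ssize_t i = 0; i < PySequence_Fast_GET_SIZE(seq); ++i) {
    out->push_back(PyLong_AsLongLong(PySequence_Fast_GET_ITEM(seq, i)));
  }
  Py_DECREF(seq);
  return true;
}

// Suggestion 2: the bound call becomes a template parameter, so no token
// pasting is needed and each tensor method is a one-line instantiation, e.g.
//   ArgsOnlyMethod<ReshapeImpl>  (ReshapeImpl: a hypothetical functor adapter)
template <PyObject* (*Impl)(PyObject* self, const std::vector<long long>&)>
static PyObject* ArgsOnlyMethod(PyObject* self, PyObject* args, PyObject* kwargs) {
  std::vector<long long> vec;
  if (!ParseIntList(args, kwargs, &vec)) { return NULL; }
  return Impl(self, vec);
}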

@github-actions

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@marigoold marigoold closed this Mar 28, 2023
@marigoold marigoold deleted the move_tensor_api_to_cpython_part4 branch April 13, 2023 08:35