Move tensor api to cpython part3 #8342
Changes from 62 commits
@@ -241,7 +241,6 @@ DIRECT_PASS_FUNC(PyTensorObject_div, functional::div)
DIRECT_PASS_FUNC(PyTensorObject_div_, functional::div_)
DIRECT_PASS_FUNC(PyTensorObject_mul, functional::mul)
DIRECT_PASS_FUNC(PyTensorObject_mul_, functional::mul_)
DIRECT_PASS_FUNC(PyTensorObject_sub, functional::sub)
DIRECT_PASS_FUNC(PyTensorObject_fmod, functional::fmod)
DIRECT_PASS_FUNC(PyTensorObject_logical_and, functional::logical_and)
DIRECT_PASS_FUNC(PyTensorObject_logical_or, functional::logical_or)
@@ -253,8 +252,36 @@ DIRECT_PASS_FUNC(PyTensorObject_bmm, functional::batch_matmul)
DIRECT_PASS_FUNC(PyTensorObject_argmax, functional::argmax)
DIRECT_PASS_FUNC(PyTensorObject_argmin, functional::argmin)
DIRECT_PASS_FUNC(PyTensorObject_amin, functional::amin)
DIRECT_PASS_FUNC(PyTensorObject_amax, functional::amax)
DIRECT_PASS_FUNC(PyTensorObject_addcmul, functional::addcmul)
DIRECT_PASS_FUNC(PyTensorObject_addcmul_, functional::addcmul_)
DIRECT_PASS_FUNC(PyTensorObject_clip, functional::clip)
DIRECT_PASS_FUNC(PyTensorObject_clip_, functional::clip_)
DIRECT_PASS_FUNC(PyTensorObject_clamp, functional::clamp)
DIRECT_PASS_FUNC(PyTensorObject_clamp_, functional::clamp_)
DIRECT_PASS_FUNC(PyTensorObject_flatten, functional::flatten)
DIRECT_PASS_FUNC(PyTensorObject_in_top_k, functional::in_top_k)
DIRECT_PASS_FUNC(PyTensorObject_index_select, functional::index_select)
DIRECT_PASS_FUNC(PyTensorObject_maximum, functional::maximum)
DIRECT_PASS_FUNC(PyTensorObject_minimum, functional::minimum)
DIRECT_PASS_FUNC(PyTensorObject_tril, functional::tril)
DIRECT_PASS_FUNC(PyTensorObject_triu, functional::triu)
DIRECT_PASS_FUNC(PyTensorObject_softmax, functional::softmax)
DIRECT_PASS_FUNC(PyTensorObject_log_softmax, functional::log_softmax)
DIRECT_PASS_FUNC(PyTensorObject_roll, functional::roll)
DIRECT_PASS_FUNC(PyTensorObject_unbind, functional::unbind)
DIRECT_PASS_FUNC(PyTensorObject_squeeze, functional::squeeze)
DIRECT_PASS_FUNC(PyTensorObject_swapaxes, functional::swapaxes)
DIRECT_PASS_FUNC(PyTensorObject_swapdims, functional::swapdims)
DIRECT_PASS_FUNC(PyTensorObject_unfold, functional::unfold_tensor)
DIRECT_PASS_FUNC(PyTensorObject_unsqueeze, functional::unsqueeze)
DIRECT_PASS_FUNC(PyTensorObject_max, functional::max)
DIRECT_PASS_FUNC(PyTensorObject_min, functional::min)
DIRECT_PASS_FUNC(PyTensorObject_median, functional::median)
DIRECT_PASS_FUNC(PyTensorObject_pow, functional::pow)
Review comment: The parameter name here is also misaligned with PyTorch: tensor.py uses b while PyTorch uses exponent, but functional_api.yaml is already correct, so I changed it directly.
DIRECT_PASS_FUNC(PyTensorObject_chunk, functional::chunk)
Review comment: The parameter default value in tensor.py is incorrect here.
DIRECT_PASS_FUNC(PyTensorObject_narrow, functional::narrow)
Review comment: The parameter name in tensor.py is incorrect here.
Review comment: It seems quite a few interfaces are still not fully aligned.
DIRECT_PASS_FUNC(PyTensorObject_masked_fill, functional::masked_fill)
Review comment: The parameter name in tensor.py is incorrect here.
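The DIRECT_PASS_FUNC definition sits outside this hunk. A plausible expansion, inferred from the REDUCE_FUNC fallthrough path shown further down (the actual definition may differ in detail): each generated method prepends self to the positional arguments and forwards them, together with kwargs, to the corresponding functional binding.

#define DIRECT_PASS_FUNC(func_name, bind_func)                                  \
  static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \
    HANDLE_ERRORS                                                                \
    PyObjectPtr concat_args(concat_self(self, args));                            \
    PyObject* result = bind_func(NULL, concat_args.get(), kwargs);               \
    if (PyErr_Occurred()) { throw py::error_already_set(); }                     \
    return result;                                                               \
    END_HANDLE_ERRORS                                                            \
  }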
// functions that are parsed at the Python C API layer
static PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) {
@@ -370,19 +397,6 @@ static PyObject* PyTensorObject_matmul(PyObject* self, PyObject* args, PyObject*
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_sub_(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* other = NULL;
  static const char* keywords[2] = {"other", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:sub_", const_cast<char**>(keywords), &other)) {
    return NULL;
  }
  PyObject* result = PyTensorObject_nb_inplace_sub(self, other);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_reshape(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* shape = args;
@@ -411,6 +425,139 @@ static PyObject* PyTensorObject_reshape_as(PyObject* self, PyObject* args, PyObj
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_cpu(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  Optional<std::string> device = "cpu";
  return PyTensor_New(ASSERT_PTR(functional::To(PyTensor_Unpack(self), device, NullOpt, false)));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_cuda(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* device_obj = Py_None;
  static const char* keywords[2] = {"device", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:cuda", const_cast<char**>(keywords),
                                   &device_obj)) {
    return NULL;
  }
  auto tensor = PyTensor_Unpack(self);
  if (functional::PyDeviceCheck(device_obj)) {
    Optional<Symbol<Device>> device = functional::PyUnpackDevice(device_obj);
    return PyTensor_New(ASSERT_PTR(functional::To(tensor, device, NullOpt, false)));
  }
  Optional<std::string> device_str;
  if (device_obj == Py_None) {
    device_str = "cuda";
  } else if (PyLong_Check(device_obj)) {
    device_str = "cuda:" + std::to_string(PyLong_AsLongLong(device_obj));
  }
  return PyTensor_New(ASSERT_PTR(functional::To(tensor, device_str, tensor->dtype(), false)));
  END_HANDLE_ERRORS
}
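Read directly off the parsing above, the device argument accepts three forms; a comment summary for reference:

// Dispatch of tensor.cuda(device), as implemented above:
//   t.cuda()       -> device_obj == Py_None     -> To(tensor, "cuda", dtype, false)
//   t.cuda(1)      -> PyLong_Check(device_obj)  -> To(tensor, "cuda:1", dtype, false)
//   t.cuda(device) -> functional::PyDeviceCheck -> To(tensor, Symbol<Device>, NullOpt, false)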
static PyObject* PyTensorObject_var(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dim_obj = Py_None;
  PyObject* unbiased_obj = Py_True;
  PyObject* keepdim_obj = Py_False;
  static const char* keywords[4] = {"dim", "unbiased", "keepdim", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO!O!:var", const_cast<char**>(keywords),
                                   &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type,
                                   &keepdim_obj)) {
    return NULL;
  }
  bool unbiased = unbiased_obj == Py_True;
  bool keepdim = keepdim_obj == Py_True;
  CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj)
                 || functional::PyLongSequenceCheck(dim_obj))
      << Error::TypeError() << "var(): argument 'dim' must be int32 list, not "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj)));
  auto tensor = PyTensor_Unpack(self);
  if (dim_obj == Py_None) {
    return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, NullOpt, unbiased, keepdim)));
  }
  std::vector<int32_t> dim;
  if (PyLong_Check(dim_obj)) {
    dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj)));
    return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim)));
  }
  dim = functional::PyUnpackLongSequence<int32_t>(dim_obj);
  return PyTensor_New(ASSERT_PTR(functional::Variance(tensor, dim, unbiased, keepdim)));
  END_HANDLE_ERRORS
}
static PyObject* PyTensorObject_std(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  PyObject* dim_obj = Py_None;
  PyObject* unbiased_obj = Py_True;
  PyObject* keepdim_obj = Py_False;
  static const char* keywords[4] = {"dim", "unbiased", "keepdim", NULL};
  if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO!O!:std", const_cast<char**>(keywords),
                                   &dim_obj, &PyBool_Type, &unbiased_obj, &PyBool_Type,
                                   &keepdim_obj)) {
    return NULL;
  }
  bool unbiased = unbiased_obj == Py_True;
  bool keepdim = keepdim_obj == Py_True;
  CHECK_OR_THROW(dim_obj == Py_None || PyLong_Check(dim_obj)
                 || functional::PyLongSequenceCheck(dim_obj))
      << Error::TypeError() << "std(): argument 'dim' must be int32 list, not "
      << functional::PyStringAsString(PyObject_Str((PyObject*)Py_TYPE(dim_obj)));
  auto tensor = PyTensor_Unpack(self);
  if (dim_obj == Py_None) {
    return PyTensor_New(
        ASSERT_PTR(functional::StandardDeviation(tensor, NullOpt, unbiased, keepdim)));
  }
  std::vector<int32_t> dim;
  if (PyLong_Check(dim_obj)) {
    dim.emplace_back(static_cast<int32_t>(PyLong_AsLong(dim_obj)));
    return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim)));
  }
  dim = functional::PyUnpackLongSequence<int32_t>(dim_obj);
  return PyTensor_New(ASSERT_PTR(functional::StandardDeviation(tensor, dim, unbiased, keepdim)));
  END_HANDLE_ERRORS
}
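var and std handle dim identically; a comment summary derived from the parsing above:

// Dispatch of tensor.var / tensor.std, read off the code above:
//   t.var()                     -> Variance(t, NullOpt, /*unbiased=*/true, /*keepdim=*/false)
//   t.var(1)                    -> dim = {1}
//   t.var((0, 2), keepdim=True) -> dim = {0, 2}, keepdim = true
// Any other dim type fails CHECK_OR_THROW with a TypeError.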
static PyObject* PyTensorObject_softplus(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  PyObjectPtr concat_args(PyTuple_Pack(1, self));
  PyObject* result = functional::softplus(NULL, concat_args.get(), NULL);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}
Review comment: In tensor.py this function accepts no extra arguments, but functional_api.yaml additionally has beta and threshold, so it is not placed in DIRECT_PASS_FUNC.
Review comment: Here I would prefer to add beta and threshold as well.
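If beta and threshold are added, a minimal sketch of the keyword-accepting form, assuming the functional::softplus binding already parses beta and threshold as functional_api.yaml suggests (the method-table registration below would also need METH_VARARGS | METH_KEYWORDS):

static PyObject* PyTensorObject_softplus(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  // Prepend self and forward beta/threshold (if any) to the functional binding.
  PyObjectPtr concat_args(concat_self(self, args));
  PyObject* result = functional::softplus(NULL, concat_args.get(), kwargs);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}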
static PyObject* PyTensorObject_relu(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), false)));
  END_HANDLE_ERRORS
}

static PyObject* PyTensorObject_relu_(PyObject* self, PyObject* unused) {
  HANDLE_ERRORS
  return PyTensor_New(ASSERT_PTR(functional::Relu(PyTensor_Unpack(self), true)));
  END_HANDLE_ERRORS
}

#define REDUCE_FUNC(func_name, bind_func, whole_func)                            \
  static PyObject* func_name(PyObject* self, PyObject* args, PyObject* kwargs) { \
    HANDLE_ERRORS                                                                \
    if ((args == NULL || PyTuple_Size(args) == 0)                                \
        && (kwargs == NULL || PyDict_Size(kwargs) == 0)) {                       \
Review comment: There are this many checks because testing showed that even a plain tensor.sum() call can leave kwargs non-NULL, so extra guard conditions were added. The args == NULL check may not be necessary, since args is always a tuple?
Review comment: Better to keep the args == NULL check.
      return PyTensor_New(ASSERT_PTR(whole_func(PyTensor_Unpack(self))));        \
    }                                                                             \
    PyObjectPtr concat_args(concat_self(self, args));                             \
    PyObject* result = bind_func(NULL, concat_args.get(), kwargs);                \
    if (PyErr_Occurred()) { throw py::error_already_set(); }                      \
    return result;                                                                \
    END_HANDLE_ERRORS                                                             \
  }
REDUCE_FUNC(PyTensorObject_any, functional::reduce_any, functional::ReduceAnyWhole)
REDUCE_FUNC(PyTensorObject_all, functional::reduce_all, functional::ReduceAllWhole)
REDUCE_FUNC(PyTensorObject_sum, functional::reduce_sum, functional::ReduceSumWhole)
REDUCE_FUNC(PyTensorObject_mean, functional::reduce_mean, functional::ReduceMeanWhole)
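For reference, manually expanding REDUCE_FUNC for sum gives (whitespace mine):

static PyObject* PyTensorObject_sum(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_ERRORS
  // tensor.sum() with no arguments takes the whole-tensor fast path; kwargs can
  // be a non-NULL empty dict, hence the PyDict_Size guard.
  if ((args == NULL || PyTuple_Size(args) == 0)
      && (kwargs == NULL || PyDict_Size(kwargs) == 0)) {
    return PyTensor_New(ASSERT_PTR(functional::ReduceSumWhole(PyTensor_Unpack(self))));
  }
  // Otherwise prepend self and defer to the parsed functional binding.
  PyObjectPtr concat_args(concat_self(self, args));
  PyObject* result = functional::reduce_sum(NULL, concat_args.get(), kwargs);
  if (PyErr_Occurred()) { throw py::error_already_set(); }
  return result;
  END_HANDLE_ERRORS
}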
#define DATATYPE_FUNC(func_name, dtype)                          \
  static PyObject* func_name(PyObject* self, PyObject* unused) { \
    HANDLE_ERRORS                                                \

@@ -421,6 +568,7 @@ static PyObject* PyTensorObject_reshape_as(PyObject* self, PyObject* args, PyObj

DATATYPE_FUNC(PyTensorObject_int, DType::Int32());
DATATYPE_FUNC(PyTensorObject_long, DType::Int64());
DATATYPE_FUNC(PyTensorObject_half, DType::Float16());
DATATYPE_FUNC(PyTensorObject_float, DType::Float());
DATATYPE_FUNC(PyTensorObject_double, DType::Double());
@@ -499,12 +647,23 @@ PyMethodDef PyTensorObject_extra_methods[] = {
    {"diagonal", (PyCFunction)PyTensorObject_diagonal, METH_VARARGS | METH_KEYWORDS, NULL},
    {"addcmul", (PyCFunction)PyTensorObject_addcmul, METH_VARARGS | METH_KEYWORDS, NULL},
    {"addcmul_", (PyCFunction)PyTensorObject_addcmul_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"sub_", (PyCFunction)PyTensorObject_sub_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"matmul", (PyCFunction)PyTensorObject_matmul, METH_VARARGS | METH_KEYWORDS, NULL},
    {"int", PyTensorObject_int, METH_NOARGS, NULL},
    {"long", PyTensorObject_long, METH_NOARGS, NULL},
    {"half", PyTensorObject_half, METH_NOARGS, NULL},
    {"float", PyTensorObject_float, METH_NOARGS, NULL},
    {"double", PyTensorObject_double, METH_NOARGS, NULL},
    {"cpu", PyTensorObject_cpu, METH_NOARGS, NULL},
    {"cuda", (PyCFunction)PyTensorObject_cuda, METH_VARARGS | METH_KEYWORDS, NULL},
    {"var", (PyCFunction)PyTensorObject_var, METH_VARARGS | METH_KEYWORDS, NULL},
    {"std", (PyCFunction)PyTensorObject_std, METH_VARARGS | METH_KEYWORDS, NULL},
    {"softplus", PyTensorObject_softplus, METH_NOARGS, NULL},
    {"relu", PyTensorObject_relu, METH_NOARGS, NULL},
    {"relu_", PyTensorObject_relu_, METH_NOARGS, NULL},
    {"all", (PyCFunction)PyTensorObject_all, METH_VARARGS | METH_KEYWORDS, NULL},
    {"any", (PyCFunction)PyTensorObject_any, METH_VARARGS | METH_KEYWORDS, NULL},
    {"sum", (PyCFunction)PyTensorObject_sum, METH_VARARGS | METH_KEYWORDS, NULL},
    {"mean", (PyCFunction)PyTensorObject_mean, METH_VARARGS | METH_KEYWORDS, NULL},
    // macro DIRECT_PASS_FUNC
    {"floor_divide", (PyCFunction)PyTensorObject_floor_divide, METH_VARARGS | METH_KEYWORDS, NULL},

@@ -515,7 +674,6 @@ PyMethodDef PyTensorObject_extra_methods[] = {
    {"div_", (PyCFunction)PyTensorObject_div_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"mul", (PyCFunction)PyTensorObject_mul, METH_VARARGS | METH_KEYWORDS, NULL},
    {"mul_", (PyCFunction)PyTensorObject_mul_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"sub", (PyCFunction)PyTensorObject_sub, METH_VARARGS | METH_KEYWORDS, NULL},
    {"fmod", (PyCFunction)PyTensorObject_fmod, METH_VARARGS | METH_KEYWORDS, NULL},
    {"logical_and", (PyCFunction)PyTensorObject_logical_and, METH_VARARGS | METH_KEYWORDS, NULL},
    {"logical_or", (PyCFunction)PyTensorObject_logical_or, METH_VARARGS | METH_KEYWORDS, NULL},
@@ -524,6 +682,34 @@ PyMethodDef PyTensorObject_extra_methods[] = {
    {"ne", (PyCFunction)PyTensorObject_ne, METH_VARARGS | METH_KEYWORDS, NULL},
    {"lt", (PyCFunction)PyTensorObject_lt, METH_VARARGS | METH_KEYWORDS, NULL},
    {"le", (PyCFunction)PyTensorObject_le, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clip", (PyCFunction)PyTensorObject_clip, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clip_", (PyCFunction)PyTensorObject_clip_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clamp", (PyCFunction)PyTensorObject_clamp, METH_VARARGS | METH_KEYWORDS, NULL},
    {"clamp_", (PyCFunction)PyTensorObject_clamp_, METH_VARARGS | METH_KEYWORDS, NULL},
    {"flatten", (PyCFunction)PyTensorObject_flatten, METH_VARARGS | METH_KEYWORDS, NULL},
    {"in_top_k", (PyCFunction)PyTensorObject_in_top_k, METH_VARARGS | METH_KEYWORDS, NULL},
    {"index_select", (PyCFunction)PyTensorObject_index_select, METH_VARARGS | METH_KEYWORDS, NULL},
    {"maximum", (PyCFunction)PyTensorObject_maximum, METH_VARARGS | METH_KEYWORDS, NULL},
    {"minimum", (PyCFunction)PyTensorObject_minimum, METH_VARARGS | METH_KEYWORDS, NULL},
    {"tril", (PyCFunction)PyTensorObject_tril, METH_VARARGS | METH_KEYWORDS, NULL},
    {"triu", (PyCFunction)PyTensorObject_triu, METH_VARARGS | METH_KEYWORDS, NULL},
    {"softmax", (PyCFunction)PyTensorObject_softmax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"log_softmax", (PyCFunction)PyTensorObject_log_softmax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"roll", (PyCFunction)PyTensorObject_roll, METH_VARARGS | METH_KEYWORDS, NULL},
    {"unbind", (PyCFunction)PyTensorObject_unbind, METH_VARARGS | METH_KEYWORDS, NULL},
    {"squeeze", (PyCFunction)PyTensorObject_squeeze, METH_VARARGS | METH_KEYWORDS, NULL},
    {"swapaxes", (PyCFunction)PyTensorObject_swapaxes, METH_VARARGS | METH_KEYWORDS, NULL},
    {"amax", (PyCFunction)PyTensorObject_amax, METH_VARARGS | METH_KEYWORDS, NULL},
    {"swapdims", (PyCFunction)PyTensorObject_swapdims, METH_VARARGS | METH_KEYWORDS, NULL},
    {"unfold", (PyCFunction)PyTensorObject_unfold, METH_VARARGS | METH_KEYWORDS, NULL},
    {"unsqueeze", (PyCFunction)PyTensorObject_unsqueeze, METH_VARARGS | METH_KEYWORDS, NULL},
    {"max", (PyCFunction)PyTensorObject_max, METH_VARARGS | METH_KEYWORDS, NULL},
    {"min", (PyCFunction)PyTensorObject_min, METH_VARARGS | METH_KEYWORDS, NULL},
    {"median", (PyCFunction)PyTensorObject_median, METH_VARARGS | METH_KEYWORDS, NULL},
    {"pow", (PyCFunction)PyTensorObject_pow, METH_VARARGS | METH_KEYWORDS, NULL},
    {"chunk", (PyCFunction)PyTensorObject_chunk, METH_VARARGS | METH_KEYWORDS, NULL},
    {"narrow", (PyCFunction)PyTensorObject_narrow, METH_VARARGS | METH_KEYWORDS, NULL},
    {"masked_fill", (PyCFunction)PyTensorObject_masked_fill, METH_VARARGS | METH_KEYWORDS, NULL},

    // macro UNARY_METHOD
    {"abs", PyTensorObject_abs, METH_NOARGS, NULL},
@@ -458,6 +458,7 @@ class ReduceSumWholeFunctor {
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
    MutableAttrMap attrs;
    const int32_t naxis = x->ndim();
    if (naxis == 0) { return x; }  // for 0-dim Tensor
Review comment: test_cuda_0dim fails during backward, so a special case is added here (a 0-dim tensor has no axes to reduce over, so the input is returned unchanged).
    std::vector<int32_t> axis(naxis);
    std::iota(axis.begin(), axis.end(), 0);
    JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
@@ -1952,7 +1953,12 @@ class StandardDeviationFunctor {
   Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& input,
                            const Optional<std::vector<int32_t>>& dim,
                            const Optional<bool>& unbiased, const Optional<bool>& keepdim) const {
-    std::vector<int32_t> axis = *JUST(CheckAxis(*JUST(dim), input->ndim()));
+    std::vector<int32_t> axis;
+    if (!dim) {
+      for (int i = 0; i < input->ndim(); i++) { axis.emplace_back(i); }
+    } else {
+      axis = *JUST(CheckAxis(*JUST(dim), input->ndim()));
+    }
Review comment: Previously, calling tensor.std() would fail on this check; that bug is fixed here.
    bool unbias = true;
    bool keepdims = false;
    if (unbiased.has_value()) { unbias = JUST(unbiased); }
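A small style note on the fix above: the default-axis branch could mirror the std::iota pattern used in ReduceSumWholeFunctor; an equivalent sketch, with no behavior change:

std::vector<int32_t> axis;
if (!dim) {
  axis.resize(input->ndim());
  std::iota(axis.begin(), axis.end(), 0);  // reduce over all axes by default
} else {
  axis = *JUST(CheckAxis(*JUST(dim), input->ndim()));
}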
Review comment: The parameter name in tensor.py is incorrect here: it is self, y, while functional_api.yaml is consistent with PyTorch and uses input, other.
https://github.com/pytorch/pytorch/blob/4858c56334aa2b09b1ba10d0a3547ef01edda363/aten/src/ATen/native/native_functions.yaml#L8119