Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add prod dtype #7932

Merged
merged 66 commits into from
Apr 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
38b7955
add prod dtype
simonJJJ Mar 30, 2022
1cd9dc1
refine
simonJJJ Mar 30, 2022
fc5d648
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 30, 2022
580a30e
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 30, 2022
0168951
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 30, 2022
d1234e9
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 30, 2022
e1c6046
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 30, 2022
97687f1
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
a03e923
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
06b2ff2
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
c9a921e
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
f2cbc6c
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
76b9da5
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
74ed332
Merge branch 'master' into add_prod_dtype
mergify[bot] Mar 31, 2022
b73a807
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 2, 2022
87adbc4
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 2, 2022
858a23d
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 2, 2022
6fe5783
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 2, 2022
d6711c7
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 2, 2022
617d3ca
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 2, 2022
a3a99fd
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 3, 2022
8e14312
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 3, 2022
b43b71a
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 5, 2022
d83995c
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 5, 2022
7f359c6
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 6, 2022
12e4b9d
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 6, 2022
a660038
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 6, 2022
760d877
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 7, 2022
bf345ea
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 7, 2022
a6978d0
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 7, 2022
04456a1
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 7, 2022
819259a
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 7, 2022
c8e141e
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 8, 2022
dbdef54
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 8, 2022
b537a98
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 8, 2022
28234fd
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 8, 2022
1885c83
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 9, 2022
44a63da
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 9, 2022
f6fad5c
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 9, 2022
b34a5fe
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 9, 2022
d6628b0
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 10, 2022
a6daa02
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 10, 2022
f815dcc
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 10, 2022
aa45a44
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 11, 2022
68041a9
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 11, 2022
a4460b3
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 12, 2022
1c788ac
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 13, 2022
5b471d3
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 13, 2022
67ea177
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
2f072cd
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
e8baf03
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
7fa423e
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
207ad8b
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
1bb2c94
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
e73356a
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
70364c7
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 14, 2022
857e07f
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 15, 2022
93b44b3
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 15, 2022
6cfa839
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 15, 2022
8fae502
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 15, 2022
5faa4bd
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 15, 2022
3e33d4b
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 16, 2022
11135bd
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 17, 2022
604f3fa
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 17, 2022
2913d6f
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 17, 2022
187842e
Merge branch 'master' into add_prod_dtype
mergify[bot] Apr 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@
bind_python: True

- name: "reduce_prod"
signature: "Tensor (Tensor x, Int32List axis, Bool keepdims=False) => ReduceProd"
signature: "Tensor (Tensor x, Int32List axis, Bool keepdims=False, *, DataType dtype=None) => ReduceProd"
bind_python: True

- name: "reduce_min_device_stage"
Expand Down
14 changes: 8 additions & 6 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,27 +658,29 @@ class ReduceProdFunctor {
one::OpBuilder("reduce_prod").Input("input_tensor").Output("output_tensor").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const std::vector<int32_t>& axis,
const bool& keepdims) const {
const bool& keepdims, const Optional<Symbol<DType>>& dtype) const {
MutableAttrMap attrs;
std::shared_ptr<one::Tensor> tensor = x;
if (dtype.has_value() && (dtype != x->dtype())) { tensor = JUST(Cast(tensor, JUST(dtype))); }
TensorProcessor tensor_processor;
Symbol<DType> lowest_dtype;
if (DType::priority_order[x->dtype()->data_type()]
if (DType::priority_order[tensor->dtype()->data_type()]
== DType::priority_order[DType::Bool()->data_type()]) {
lowest_dtype = DType::Int64();
} else {
lowest_dtype = x->dtype();
lowest_dtype = tensor->dtype();
}
JUST(tensor_processor.AddInputs({x}, lowest_dtype).Apply());
JUST(tensor_processor.AddInputs({tensor}, lowest_dtype).Apply());
TensorTuple input_tuple = JUST(tensor_processor.GetInputs());
if (axis.empty()) {
std::vector<int32_t> reduce_axis(x->shape()->NumAxes());
std::vector<int32_t> reduce_axis(tensor->shape()->NumAxes());
std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", reduce_axis));
} else {
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
}
JUST(attrs.SetAttr<bool>("keepdims", keepdims));
return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs);
return JUST(OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs));
}

private:
Expand Down
4 changes: 2 additions & 2 deletions python/oneflow/nn/modules/reduce_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def mean_op(input, dim=None, keepdim=False):
return flow._C.reduce_mean(input, axis=axis_checked, keepdims=keepdim)


def prod_op(input, dim=None, keepdim=False):
def prod_op(input, dim=None, keepdim=False, *, dtype=None):
axis_checked = _check_axis(dim, input.shape)
if len(axis_checked) == 0:
return input
return flow._C.reduce_prod(input, axis_checked, keepdim)
return flow._C.reduce_prod(input, axis_checked, keepdim, dtype=dtype)


def all_op(input, dim=None, keepdim=False):
Expand Down
10 changes: 10 additions & 0 deletions python/oneflow/test/modules/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def test_reduce_prod_bool_without_dim(test_case):

return y

@autotest(auto_backward=False, check_graph=False)
def test_reduce_prod_with_dtype(test_case):
device = random_device()
ndim = random(1, 5).to(int)
x = random_tensor(ndim=ndim, low=1.0, high=4.0, requires_grad=False).to(device)
dim = random(0, ndim).to(int)
y = torch.prod(x, dim, dtype=torch.int32)

return y


if __name__ == "__main__":
unittest.main()