Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move python code to cpp: eye #7036

Merged
merged 19 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -529,16 +529,11 @@
bind_python: False

- name: "eye"
signature:
[
"Tensor (Scalar n, Scalar m=None, *, DataType dtype=None, Device device=None) => Eye",
]
bind_python: True

- name: "consistent_eye"
signature:
[
"Tensor (Scalar n, Scalar m=None, *, DataType dtype=None, Placement placement, SbpList sbp) => ConsistentEye",
signature: [
"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Device device=None, Bool requires_grad=False) => Eye",
"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, String device, Bool requires_grad=False) => Eye",
"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, SbpList sbp) => Eye",
"Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, Sbp sbp) => Eye",
]
bind_python: True

Expand Down
137 changes: 137 additions & 0 deletions oneflow/core/functional/impl/eye_functor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/common/just.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/common/scalar.h"
#include "oneflow/core/common/throw.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/functional_api.yaml.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/api/common/device.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

api 目录下的头文件只能被 api 目录下的文件 include 的

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是某些情况下编译不过吗?然后我本地以及CI没有遇到这个坑所以过了?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

主要是代码结构上的要求,oneflow api 可以依赖 oneflow 本体,oneflow 本体不应该依赖 oneflow api

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

哦哦,好的,我这个PR引入的问题,也是周泽楷在修吗,还是还没人修,没人修的话,我自己提个fix PR。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯对的 zekai 在修,在把 ParseAndNew 移动到 Device 类里


namespace oneflow {
namespace one {
namespace functional {

namespace impl {

class EyeDevcieFunctor {
doombeaker marked this conversation as resolved.
Show resolved Hide resolved
public:
EyeDevcieFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); }
Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,
const Symbol<DType>& dtype, const Optional<Symbol<Device>>& device,
const bool& requires_grad) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("rows", JUST(rows.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("cols", JUST(cols.value_or(rows).As<int64_t>())));
JUST(attrs.SetAttr<DataType>("dtype", dtype->data_type()));
OpExprInterpContext ctx(attrs);
ctx.device = device;
auto res = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(res->set_requires_grad(requires_grad));
return res;
doombeaker marked this conversation as resolved.
Show resolved Hide resolved
}

private:
std::shared_ptr<OpExpr> op_;
};

class EyeDeviceStrFunctor {
public:
Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,
const Symbol<DType>& dtype, const std::string& device,
const bool& requires_grad) const {
const Symbol<Device>& dev = JUST(DeviceExportUtil::ParseAndNew(device));
return JUST(functional::Eye(rows, cols, dtype, dev, requires_grad));
}
};

class ConsistentEyeSbpListFunctor {
doombeaker marked this conversation as resolved.
Show resolved Hide resolved
public:
ConsistentEyeSbpListFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); }
Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,
const Symbol<DType>& dtype, const bool& requires_grad,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
MutableAttrMap attrs;
CHECK_EQ_OR_RETURN(sbp_tuple.size(), placement->hierarchy()->NumAxes())
<< "len(sbp) == len(placement.hierarchy) required, but "
<< "len(sbp)==" << sbp_tuple.size() << ", "
<< "len(placement.hierarchy)==" << placement->hierarchy()->NumAxes();

FOR_RANGE(int32_t, i, 0, sbp_tuple.size()) {
CHECK_OR_RETURN(sbp_tuple.at(i)->has_broadcast_parallel())
<< "sbp of eye should be broadcast only";
}

JUST(attrs.SetAttr<int64_t>("rows", JUST(rows.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("cols", JUST(cols.value_or(rows).As<int64_t>())));
JUST(attrs.SetAttr<DataType>("dtype", dtype->data_type()));
if (LazyMode::is_enabled()) {
std::vector<std::string> nd_sbp(sbp_tuple.size());
doombeaker marked this conversation as resolved.
Show resolved Hide resolved
{
for (int i = 0; i < sbp_tuple.size(); ++i) {
nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i));
}
}
JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp));
}
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
auto res = JUST(
OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp)));
JUST(res->set_requires_grad(requires_grad));
return res;
doombeaker marked this conversation as resolved.
Show resolved Hide resolved
}

private:
std::shared_ptr<OpExpr> op_;
};

class ConsistentEyeSbpFunctor {
public:
Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,
const Symbol<DType>& dtype, const bool& requires_grad,
const Symbol<ParallelDesc>& placement,
const Symbol<cfg::SbpParallel>& sbp) const {
std::vector<Symbol<cfg::SbpParallel>> sbp_tuple{sbp};
return JUST(functional::Eye(rows, cols, dtype, requires_grad, placement, sbp_tuple));
}
};

} // namespace impl

using namespace impl;

ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<EyeDevcieFunctor, EyeDeviceStrFunctor, ConsistentEyeSbpListFunctor,
ConsistentEyeSbpFunctor>("Eye");
};

} // namespace functional
} // namespace one
} // namespace oneflow
50 changes: 0 additions & 50 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -596,54 +596,6 @@ class TransposeFunctor {
std::shared_ptr<OpExpr> op_;
};

class EyeFunctor {
public:
EyeFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); }
Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("rows", JUST(rows.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("cols", JUST(cols.value_or(rows).As<int64_t>())));
JUST(attrs.SetAttr<DataType>("dtype", dtype ? JUST(dtype)->data_type() : DataType::kFloat));
OpExprInterpContext ctx(attrs);
ctx.device = device;
return OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx);
}

private:
std::shared_ptr<OpExpr> op_;
};

class ConsistentEyeFunctor {
public:
ConsistentEyeFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); }
Maybe<Tensor> operator()(const Scalar& rows, const Optional<Scalar>& cols,
const Optional<Symbol<DType>>& dtype,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple) const {
JUST(CheckDeviceIdsIsValid(placement));
MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("rows", JUST(rows.As<int64_t>())));
JUST(attrs.SetAttr<int64_t>("cols", JUST(cols.value_or(rows).As<int64_t>())));
JUST(attrs.SetAttr<DataType>("dtype", dtype ? JUST(dtype)->data_type() : DataType::kFloat));
if (LazyMode::is_enabled()) {
std::vector<std::string> nd_sbp(sbp_tuple.size());
{
for (int i = 0; i < sbp_tuple.size(); ++i) {
nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i));
}
}
JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp));
}
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp));
}

private:
std::shared_ptr<OpExpr> op_;
};

class Transpose2dimFunctor {
public:
Transpose2dimFunctor() {
Expand Down Expand Up @@ -1744,8 +1696,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<ReduceMaxGlobalStageGradFunctor>("ReduceMaxGlobalStageGrad");
m.add_functor<TransposeFunctor>("Transpose");
m.add_functor<TransposeFunctor>("Permute");
m.add_functor<EyeFunctor>("Eye");
m.add_functor<ConsistentEyeFunctor>("ConsistentEye");
m.add_functor<Transpose2dimFunctor>("Transpose2dim");
m.add_functor<ArangeFunctor, Arange2Functor>("Arange");
m.add_functor<ConsistentArangeFunctor, ConsistentArange2Functor>("ConsistentArange");
Expand Down
2 changes: 1 addition & 1 deletion python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def is_deprecated(func_or_class):
from oneflow._C import read_onerec
from oneflow._C import decode_onerec
from oneflow._C import dot
from oneflow._C import eye


from . import sbp
Expand Down Expand Up @@ -337,7 +338,6 @@ def atexit_hook(hook):
from oneflow.nn.modules.slice import slice_update_op as slice_update
from oneflow.nn.modules.slice import logical_slice_assign_op as logical_slice_assign
from oneflow.nn.modules.sort import sort_op as sort
from oneflow.nn.modules.eye import eye_op as eye
from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer
from oneflow.nn.modules.tensor_buffer import (
tensor_buffer_to_tensor_op as tensor_buffer_to_tensor,
Expand Down
37 changes: 37 additions & 0 deletions python/oneflow/framework/docstr/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,3 +1337,40 @@
oneflow.Size([3, 4, 2, 5])
""",
)

add_docstr(
oneflow.eye,
"""oneflow.eye(n, m, *, device=None, requires_grad=False, placement=None, sbp) -> Tensor

This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere.

Args:
n (int): the number of rows.
m (int, optional): the number of colums with default being n. Defaults to None.

Keyword args:
device(Union[flow.device, str], optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor.
doombeaker marked this conversation as resolved.
Show resolved Hide resolved
requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`.
placement(oneflow._oneflow_internal.placement, optional): The placement attribute allows you to specify which physical device the tensor is stored on.
sbp(Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]], optional): When creating a consistent tensor, specify the SBP of the tensor.

Returns:
oneflow.Tensor: The result tensor with ones on the diagonal and zeros elsewhere.

For example:

.. code-block:: python

>>> import oneflow as flow
>>> out = flow.eye(3, 3)
>>> out
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=oneflow.float32)
>>> out = flow.eye(3, 3, device="cuda")
>>> out
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], device='cuda:0', dtype=oneflow.float32)
""",
)
82 changes: 0 additions & 82 deletions python/oneflow/nn/modules/eye.py

This file was deleted.