Feat support host memory #9928

Merged: 24 commits merged into master from feat_support_host_memory_in_lazy_mode on Mar 29, 2023

Changes from 21 commits

Commits (24)
7674b43
feat_support_host_memory_in_lazy_mode
clackhan Mar 2, 2023
28b91fa
refine
clackhan Mar 2, 2023
aae225d
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
clackhan Mar 2, 2023
3c2d286
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Mar 6, 2023
c48f40d
compatible_eager_and_lazy
clackhan Mar 8, 2023
24b9bcb
del useless code
clackhan Mar 8, 2023
5cdc9fa
Merge branch 'feat_support_host_memory_in_lazy_mode' of https://githu…
clackhan Mar 8, 2023
89609a9
optimize code
clackhan Mar 9, 2023
cff151f
refine
clackhan Mar 9, 2023
aeae1e4
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Mar 9, 2023
54b097e
refine
clackhan Mar 9, 2023
1bf742f
refine
clackhan Mar 9, 2023
7e254ac
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
clackhan Mar 9, 2023
dd69263
fix static check error
clackhan Mar 9, 2023
e2887b4
Merge branch 'feat_support_host_memory_in_lazy_mode' of https://githu…
clackhan Mar 9, 2023
4bb813e
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
clackhan Mar 15, 2023
84373c4
deal comments
clackhan Mar 15, 2023
163cb22
reslove comments
clackhan Mar 20, 2023
94cd8a8
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
clackhan Mar 20, 2023
585a920
reslove comments
clackhan Mar 24, 2023
977682d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
clackhan Mar 24, 2023
eab40e3
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
clackhan Mar 29, 2023
32d650d
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
mergify[bot] Mar 29, 2023
330ab9e
Merge branch 'master' into feat_support_host_memory_in_lazy_mode
mergify[bot] Mar 29, 2023
20 changes: 14 additions & 6 deletions oneflow/core/framework/global_tensor_infer_cache.cpp
@@ -159,18 +159,26 @@ Maybe<Operator> MakeOp(const UserOpExpr& user_op_expr, const AttrMap& attrs,
return JUST(ConstructOp(op_conf, device_type));
}

Maybe<void> CheckInputParallelDescIdentical(const GlobalTensorMetaInferArgs& infer_args) {
Maybe<void> CheckInputParallelDescIdentical(const GlobalTensorMetaInferArgs& infer_args,
const UserOpExpr& user_op_expr) {
if (infer_args.input_global_tensor_metas().empty()) { return Maybe<void>::Ok(); }
const auto& first_parallel_desc =
infer_args.input_global_tensor_metas().begin()->tensor_meta()->parallel_desc();
Symbol<ParallelDesc> default_parallel_desc;
for (int i = 0; i < infer_args.input_global_tensor_metas().size(); ++i) {
if (user_op_expr.IsHostMemoryInput(i)) { continue; }
default_parallel_desc =
JUST(VectorAt(infer_args.input_global_tensor_metas(), i)).tensor_meta()->parallel_desc();
break;
}

for (int i = 0; i < infer_args.input_global_tensor_metas().size(); ++i) {
if (user_op_expr.IsHostMemoryInput(i)) { continue; }
CHECK_OR_RETURN(
first_parallel_desc
default_parallel_desc
== JUST(VectorAt(infer_args.input_global_tensor_metas(), i)).tensor_meta()->parallel_desc())
<< Error::RuntimeError()
<< "Expected all tensors to be on the same placement, but found "
"at least two placements, "
<< *JUST(PlacementToString(first_parallel_desc)) << " (positional 0) and "
<< *JUST(PlacementToString(default_parallel_desc)) << " (positional 0) and "
<< *JUST(PlacementToString(JUST(VectorAt(infer_args.input_global_tensor_metas(), i))
.tensor_meta()
->parallel_desc()))
@@ -256,7 +264,7 @@ class UserOpExprDeviceAndStreamInferContext final : public user_op::DeviceAndStr
CHECK_GT_OR_RETURN(infer_args.input_global_tensor_metas().size(), 0); // NOLINT
Symbol<ParallelDesc> parallel_desc =
infer_args.input_global_tensor_metas()[0].tensor_meta()->parallel_desc();
JUST(CheckInputParallelDescIdentical(infer_args));
JUST(CheckInputParallelDescIdentical(infer_args, user_op_expr));
JUST(CheckIsDeviceSupportedByOp(*parallel_desc, user_op_expr.op_type_name()));
std::vector<OpArgMutGlobalTensorMeta> output_mut_metas(user_op_expr.output_size());
{
6 changes: 4 additions & 2 deletions oneflow/core/framework/local_tensor_infer_cache.cpp
@@ -33,8 +33,10 @@ Maybe<void> CheckIsDeviceSupportedByOp(const Device& device, const std::string&
}

Maybe<void> CheckInputDeviceIdentical(const LocalTensorMetaInferArgs& infer_args,
Symbol<Device> default_device) {
Symbol<Device> default_device,
const UserOpExpr& user_op_expr) {
for (int i = 0; i < infer_args.input_local_tensor_metas().size(); ++i) {
if (user_op_expr.IsHostMemoryInput(i)) { continue; }
CHECK_OR_RETURN(default_device
== JUST(VectorAt(infer_args.input_local_tensor_metas(), i))->device())
<< Error::RuntimeError()
@@ -158,7 +160,7 @@ Maybe<void> LocalTensorMetaInferArgs::InitInputLocalTensorMetas(const TensorTupl
/* static */ Maybe<const LocalTensorInferResult> LocalTensorInferCache::Infer(
const UserOpExpr& user_op_expr, const LocalTensorMetaInferArgs& infer_args) {
const auto& default_device = infer_args.default_device();
JUST(CheckInputDeviceIdentical(infer_args, default_device));
JUST(CheckInputDeviceIdentical(infer_args, default_device, user_op_expr));
JUST(CheckIsDeviceSupportedByOp(*default_device, user_op_expr.op_type_name()));

auto result = std::make_unique<LocalTensorInferResult>(user_op_expr.output_size());
13 changes: 11 additions & 2 deletions oneflow/core/framework/op_expr.cpp
@@ -26,6 +26,7 @@ limitations under the License.
#include "oneflow/core/framework/global_tensor_infer_cache.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/user/kernels/stateful_opkernel.h"
#include "oneflow/core/common/container_util.h"

namespace oneflow {
namespace one {
@@ -531,8 +532,8 @@ UserOpExpr::UserOpExpr(const std::string& op_name, UserOpConf&& proto, const Att
base_attrs_(base_attrs) {}

Maybe<void> UserOpExpr::Init(const std::shared_ptr<const UserOpExpr>& self) {
const auto* registry =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_proto_.op_type_name());
const auto& op_type_name = op_proto_.op_type_name();
const auto* registry = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);
CHECK_NOTNULL_OR_RETURN(registry);
logical_tensor_desc_infer_fn_ = registry->logical_tensor_desc_infer_fn;
CHECK_OR_RETURN(static_cast<bool>(logical_tensor_desc_infer_fn_))
@@ -548,6 +549,14 @@ Maybe<void> UserOpExpr::Init(const std::shared_ptr<const UserOpExpr>& self) {
}
local_tensor_infer_cache_.reset(new LocalTensorInferCache(self));
global_tensor_infer_cache_.reset(new GlobalTensorInferCache(self));
const auto& indexed_input_pairs = this->indexed_input_pairs();
for (int32_t i = 0; i < indexed_input_pairs.size(); ++i) {
const auto& input_pair = JUST(VectorAt(indexed_input_pairs, i));
if (user_op::UserOpHostMemoryInputRegistry::Get().IsHostMemoryInput4Op(
op_type_name, input_pair.first, input_pair.second)) {
host_memory_input_ids_.emplace_back(i);
}
}
return Maybe<void>::Ok();
}

6 changes: 6 additions & 0 deletions oneflow/core/framework/op_expr.h
@@ -155,6 +155,11 @@ class UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {
return device_and_stream_infer_fn_;
}

bool IsHostMemoryInput(int32_t input_index) const {
return std::find(host_memory_input_ids_.begin(), host_memory_input_ids_.end(), input_index)
!= host_memory_input_ids_.end();
}

Maybe<void> InferPhysicalTensorDesc(
const AttrMap& attrs, const std::string& device_tag,
const std::function<const TensorMeta*(int32_t)>& TensorMeta4InputIndex,
@@ -186,6 +191,7 @@ class UserOpExpr final : public BuiltinOpExprImpl<UserOpConf> {
mutable HashMap<Symbol<Stream>, std::shared_ptr<StatefulOpKernel>> stream2kernel_;
std::shared_ptr<LocalTensorInferCache> local_tensor_infer_cache_;
std::shared_ptr<GlobalTensorInferCache> global_tensor_infer_cache_;
small_vector<int32_t> host_memory_input_ids_;
};

class GlobalToGlobalOpExpr : public OpExpr {
@@ -44,8 +44,13 @@ namespace one {
namespace {

Maybe<Symbol<ParallelDesc>> GetParallelDesc(const TensorTuple& inputs,
const OpExprInterpContext& ctx) {
if (!inputs.empty()) { return inputs.at(0)->parallel_desc(); }
const OpExprInterpContext& ctx,
const UserOpExpr& user_op_expr) {
if (!inputs.empty()) {
for (int32_t i = 0; i < inputs.size(); ++i) {
if (!user_op_expr.IsHostMemoryInput(i)) { return inputs.at(i)->parallel_desc(); }
}
}
return JUST(ctx.parallel_desc);
}

@@ -110,7 +115,7 @@ auto* GetBoxingOutput =
Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
TensorTuple* outputs, const OpExprInterpContext& ctx) {
CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size());
const auto& parallel_desc = JUST(GetParallelDesc(inputs, ctx));
const auto& parallel_desc = JUST(GetParallelDesc(inputs, ctx, user_op_expr));
std::shared_ptr<const GlobalTensorInferResult> result;
NonRecursiveMetaInfoConsistencyCheckScope scope;
if (inputs.empty()) {
@@ -159,10 +164,16 @@ Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
const auto& infered_input_meta = result->input_tensor_metas().at(i);
const auto& input_parallel_desc = JUST(input->parallel_desc());
CHECK_OR_RETURN(input_parallel_desc == infered_input_meta->parallel_desc());
if (input_parallel_desc->parallel_num() != 1
&& infered_input_meta->nd_sbp() != JUST(input->nd_sbp())) {
input = JUST(GetBoxingOutput(input, infered_input_meta->nd_sbp(),
infered_input_meta->parallel_desc(), parallel_id.has_value()));
bool is_host_input = user_op_expr.IsHostMemoryInput(i);
Symbol<ParallelDesc> dst_parallel_desc =
is_host_input
? JUST(ReplaceDeviceType(infered_input_meta->parallel_desc(), DeviceType::kCPU))
: infered_input_meta->parallel_desc();
Comment on lines +168 to +171 — clackhan (PR author):

When an op input is a HostMemory input, the device type of the boxing output's parallel_desc is set to cpu.

if ((input_parallel_desc->parallel_num() != 1
&& infered_input_meta->nd_sbp() != JUST(input->nd_sbp()))
|| input_parallel_desc->device_type() != dst_parallel_desc->device_type()) {
input = JUST(GetBoxingOutput(input, infered_input_meta->nd_sbp(), dst_parallel_desc,
parallel_id.has_value()));
boxing_outputs.emplace_back(input);
}
const auto& local_tensor = JUST(input->cur_rank_phy_tensor());
@@ -53,15 +53,18 @@ Maybe<Symbol<Device>> RawGetDefaultCpuDevice() { return Device::New("cpu"); }

constexpr auto* GetDefaultCpuDevice = DECORATE(&RawGetDefaultCpuDevice, ThreadLocal);

Maybe<Symbol<Device>> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx) {
if (inputs.empty()) {
if (ctx.device.has_value()) {
return JUST(ctx.device);
} else {
return GetDefaultCpuDevice();
Maybe<Symbol<Device>> GetDefaultDevice(const TensorTuple& inputs, const OpExprInterpContext& ctx,
const UserOpExpr& user_op_expr) {
if (!inputs.empty()) {
for (int32_t i = 0; i < inputs.size(); ++i) {
if (!user_op_expr.IsHostMemoryInput(i)) { return JUST(inputs.at(i)->device()); }
}
}
return JUST(inputs.at(0)->device());
if (ctx.device.has_value()) {
return JUST(ctx.device);
} else {
return GetDefaultCpuDevice();
}
}

Maybe<EagerLocalTensorImpl*> TensorImpl4Tensor(const std::shared_ptr<Tensor>& tensor) {
@@ -75,7 +78,7 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
TensorTuple* outputs, const OpExprInterpContext& ctx) {
OF_PROFILER_RANGE_GUARD("NaiveInterpret");
CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); // NOLINT
Symbol<Device> default_device = JUST(GetDefaultDevice(inputs, ctx));
Symbol<Device> default_device = JUST(GetDefaultDevice(inputs, ctx, user_op_expr));
const std::shared_ptr<const LocalTensorInferResult> result =
JUST([&]() -> Maybe<const LocalTensorInferResult> {
LocalTensorMetaInferArgs infer_args;
@@ -84,8 +87,17 @@ Maybe<void> NaiveInterpret(const UserOpExpr& user_op_expr, const TensorTuple& in
}());

vm::EagerBlobObjectList input_eager_blob_objects(inputs.size());
// expand lifetime of host_inputs to the end of this function
TensorTuple host_inputs;
for (int i = 0; i < inputs.size(); i++) {
input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object());
if (user_op_expr.IsHostMemoryInput(i)) {
const auto& host_input = JUST(functional::To(
inputs.at(i), Optional<Symbol<Device>>(JUST(GetDefaultCpuDevice())), NullOpt, false));
input_eager_blob_objects.at(i) = JUST(host_input->eager_blob_object());
host_inputs.emplace_back(host_input);
Comment on lines +94 to +97 — clackhan (PR author):

Extend the lifetime of host_input so it is not destructed prematurely.

} else {
input_eager_blob_objects.at(i) = JUST(inputs.at(i)->eager_blob_object());
}
}

const auto& output_tensor_metas = result->output_tensor_metas();
44 changes: 23 additions & 21 deletions oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp
@@ -788,27 +788,29 @@ Maybe<void> LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu
for (int i = 0; i < inputs.size(); ++i) {
const auto& input_tensor = inputs.at(i);
CHECK_EQ_OR_RETURN(is_local, input_tensor->is_local()); // NOLINT(maybe-need-error-msg)
if (is_local) {
CHECK_OR_RETURN(device_tag == JUST(GetDeviceTagOfTensor(input_tensor)))
<< Error::RuntimeError() << "Lazy nn.Graph name: " << graph_name
<< " encountered ERROR in module/op_name: " << new_op_name
<< ". Expected all tensors to be on the same device, but found at least two devices, "
<< JUST(JUST(VectorAt(inputs, 0))->device())->ToString() << " (positional 0) and "
<< JUST(JUST(VectorAt(inputs, i))->device())->ToString() << " (positional " << i
<< ")! Please use tensor.to() to synchronize all the input with the same device.";
} else {
// TODO: Print out all the placement
CHECK_OR_RETURN(parallel_desc->Equals(*JUST(GetParallelDescOfTensor(input_tensor))))
<< Error::RuntimeError() << "Lazy nn.Graph name: " << graph_name
<< " encountered ERROR in module/op_name: " << new_op_name
<< ". Expected all tensors to be on the same placement, but found at least two "
"placements, "
<< *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, 0))->parallel_desc())))
<< " (positional 0) and "
<< *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, i))->parallel_desc())))
<< " (positional " << i
<< ")! Please use tensor.to_global() to synchronize all the input with the same "
"placement.";
if (!op_expr.IsHostMemoryInput(i)) {
if (is_local) {
CHECK_OR_RETURN(device_tag == JUST(GetDeviceTagOfTensor(input_tensor)))
<< Error::RuntimeError() << "Lazy nn.Graph name: " << graph_name
<< " encountered ERROR in module/op_name: " << new_op_name
<< ". Expected all tensors to be on the same device, but found at least two devices, "
<< JUST(JUST(VectorAt(inputs, 0))->device())->ToString() << " (positional 0) and "
<< JUST(JUST(VectorAt(inputs, i))->device())->ToString() << " (positional " << i
<< ")! Please use tensor.to() to synchronize all the input with the same device.";
} else {
// TODO: Print out all the placement
CHECK_OR_RETURN(parallel_desc->Equals(*JUST(GetParallelDescOfTensor(input_tensor))))
<< Error::RuntimeError() << "Lazy nn.Graph name: " << graph_name
<< " encountered ERROR in module/op_name: " << new_op_name
<< ". Expected all tensors to be on the same placement, but found at least two "
"placements, "
<< *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, 0))->parallel_desc())))
<< " (positional 0) and "
<< *JUST(PlacementToString(JUST(JUST(VectorAt(inputs, i))->parallel_desc())))
<< " (positional " << i
<< ")! Please use tensor.to_global() to synchronize all the input with the same "
"placement.";
}
}
const std::string& ibn = op_expr.indexed_ibns().at(i);
std::string lbn = TensorNameScope::Global()->Lookup(input_tensor);
34 changes: 33 additions & 1 deletion oneflow/core/framework/user_op_registry_manager.cpp
@@ -140,6 +140,38 @@ Maybe<bool> UserOpRegistryMgr::IsOpKernelRegistered(const std::string& op_type_n
return false;
}

} // namespace user_op
UserOpHostMemoryInputRegistry& UserOpHostMemoryInputRegistry::Get() {
static UserOpHostMemoryInputRegistry mgr;
return mgr;
}

Maybe<void> UserOpHostMemoryInputRegistry::SetHostMemoryInput4Op(const std::string& op_type_name,
const std::string& arg_name,
int32_t index) {
auto it = op_type_name2host_memory_input_args_.find(op_type_name);
if (it == op_type_name2host_memory_input_args_.end()) {
auto pair = op_type_name2host_memory_input_args_.emplace(
op_type_name, small_vector<std::pair<std::string, int32_t>>());
CHECK_OR_RETURN(pair.second);
it = pair.first;
}
it->second.emplace_back(std::make_pair(arg_name, index));
return Maybe<void>::Ok();
}

bool UserOpHostMemoryInputRegistry::IsHostMemoryInput4Op(const std::string& op_type_name,
const std::string& arg_name,
int32_t index) const {
auto it = op_type_name2host_memory_input_args_.find(op_type_name);
if (it == op_type_name2host_memory_input_args_.end()) { return false; }
return std::find(it->second.begin(), it->second.end(), std::make_pair(arg_name, index))
!= it->second.end();
}

bool UserOpHostMemoryInputRegistry::HasHostMemoryInput(const std::string& op_type_name) const {
return op_type_name2host_memory_input_args_.find(op_type_name)
!= op_type_name2host_memory_input_args_.end();
}

} // namespace user_op
} // namespace oneflow
26 changes: 26 additions & 0 deletions oneflow/core/framework/user_op_registry_manager.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "oneflow/core/framework/user_op_registry.h"
#include "oneflow/core/framework/user_op_kernel_registry.h"
#include "oneflow/core/common/registry_error.h"
#include "oneflow/core/common/op_args_reserved_size.h"

namespace oneflow {

@@ -63,10 +64,35 @@ struct UserOpRegisterTrigger final {
}
};

class UserOpHostMemoryInputRegistry final {
public:
UserOpHostMemoryInputRegistry(UserOpHostMemoryInputRegistry const&) = delete;
UserOpHostMemoryInputRegistry& operator=(UserOpHostMemoryInputRegistry const&) = delete;
~UserOpHostMemoryInputRegistry() = default;

static UserOpHostMemoryInputRegistry& Get();

Maybe<void> SetHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name,
int32_t index);
bool IsHostMemoryInput4Op(const std::string& op_type_name, const std::string& arg_name,
int32_t index) const;

bool HasHostMemoryInput(const std::string& op_type_name) const;

private:
UserOpHostMemoryInputRegistry() {}
HashMap<std::string, small_vector<std::pair<std::string, int32_t>>>
op_type_name2host_memory_input_args_;
};

} // namespace user_op

} // namespace oneflow

#define REGISTER_OP_HOST_MEMORY_INPUT(op_type_name, arg_name, index) \
COMMAND(CHECK_JUST(user_op::UserOpHostMemoryInputRegistry::Get().SetHostMemoryInput4Op( \
op_type_name, arg_name, index)));

#define REGISTER_USER_OP(name) \
static ::oneflow::user_op::UserOpRegisterTrigger<::oneflow::user_op::OpRegistry> OF_PP_CAT( \
g_register_trigger, __COUNTER__) = \
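A hedged illustration of how the new registry could be used. The REGISTER_OP_HOST_MEMORY_INPUT macro and the registry methods below come from this diff; whether the test op registers its "scalar" input exactly this way is an assumption, since that registration site is not shown in the excerpt, and the COMMAND macro's header location is likewise assumed.

// Sketch only: register input "scalar" (index 0) of the test op as a host-memory
// input, then query the registry. Op and argument names are taken from the
// functional_api.yaml/functor changes later in this PR; the registration site is assumed.
#include "oneflow/core/common/util.h"  // assumed location of the COMMAND macro
#include "oneflow/core/framework/user_op_registry_manager.h"

namespace oneflow {

REGISTER_OP_HOST_MEMORY_INPUT("host_scalar_add_by_tensor", "scalar", 0)

void HostMemoryRegistryExample() {
  auto& registry = user_op::UserOpHostMemoryInputRegistry::Get();
  // True once the registration above has run (it executes at static-init time).
  bool is_host = registry.IsHostMemoryInput4Op("host_scalar_add_by_tensor", "scalar", 0);
  // True because the op now has at least one registered host-memory input.
  bool has_any = registry.HasHostMemoryInput("host_scalar_add_by_tensor");
  (void)is_host;
  (void)has_any;
}

}  // namespace oneflow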
5 changes: 5 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -30,6 +30,11 @@
]
bind_python: true

# this api just for test host memory input
- name: "host_scalar_add_by_tensor"
signature: "Tensor (Tensor x, Tensor scalar) => HostScalarAddByTensor"
bind_python: true

- name: "amin"
signature: "Tensor (Tensor input, Int32List[1] dim=None, Bool keepdim=False) => Amin"
bind_python: True
18 changes: 18 additions & 0 deletions oneflow/core/functional/impl/binary_functor.cpp
@@ -742,6 +742,23 @@ class ScalarAddByTensorFunctor : public InplaceableBinaryFunctor {
}
};

// this functor just for test host memory input
class HostScalarAddByTensorFunctor {
public:
HostScalarAddByTensorFunctor() {
op_ = CHECK_JUST(
one::OpBuilder("host_scalar_add_by_tensor").Input("x").Input("scalar").Output("y").Build());
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& scalar) const {
return OpInterpUtil::Dispatch<Tensor>(*op_, {x, scalar});
}

private:
std::shared_ptr<OpExpr> op_;
};

class ScalarSubByTensorFunctor : public BinaryFunctor {
public:
ScalarSubByTensorFunctor() {
@@ -796,6 +813,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BroadcastLessFunctor>("BroadcastLess");
m.add_functor<impl::BroadcastLessEqualFunctor>("BroadcastLessEqual");
m.add_functor<impl::ScalarAddByTensorFunctor>("ScalarAddByTensor");
m.add_functor<impl::HostScalarAddByTensorFunctor>("HostScalarAddByTensor");
m.add_functor<impl::ScalarSubByTensorFunctor>("ScalarSubByTensor");
m.add_functor<impl::ScalarMulByTensorFunctor>("ScalarMulByTensor");
m.add_functor<impl::ScalarDivByTensorFunctor>("ScalarDivByTensor");
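For context, a minimal sketch of calling the test functor from C++ through the functional layer registered above. The functional::HostScalarAddByTensor entry point and the generated header path follow OneFlow's usual code generation for functional_api.yaml; they are assumptions here, not part of the diff.

// Minimal sketch, assuming functional_api.yaml generates
// oneflow::one::functional::HostScalarAddByTensor as usual.
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

Maybe<Tensor> HostScalarAddExample(const std::shared_ptr<Tensor>& x,
                                   const std::shared_ptr<Tensor>& scalar) {
  // "scalar" is the host-memory input: in eager mode the interpreter copies it
  // to CPU before kernel dispatch (see the op_interpreter changes above),
  // while "x" stays on its original device.
  return functional::HostScalarAddByTensor(x, scalar);
}

}  // namespace one
}  // namespace oneflow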