Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VM] Minor refactor for C++ memory alloc #7413

Merged
merged 9 commits into from
Feb 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 0 additions & 73 deletions src/relay/op/device_copy.cc

This file was deleted.

83 changes: 61 additions & 22 deletions src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@
* \brief Operators for manifest shape-aware memory allocation in Relay.
*/

#include "memory.h"

#include <tvm/node/node.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/data_type.h>
#include <tvm/topi/elemwise.h>

#include <vector>

#include "../../transforms/infer_layout_utils.h"
#include "../op_common.h"
#include "../type_relations.h"
#include "tvm/relay/attrs/device_copy.h"

namespace tvm {
namespace relay {
Expand All @@ -42,15 +48,16 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
// The passing value in attrs and args doesn't seem super great.
// We should consider a better solution, i.e the type relation
// being able to see the arguments as well?
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
.set_body_typed([](Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint) {
auto attrs = make_object<AllocStorageAttrs>();
attrs->dtype = dtype_hint;
attrs->device_id = ctx.device_id;
attrs->device_type = ctx.device_type;
static const Op& op = Op::Get("memory.alloc_storage");
return Call(op, {size, alignment}, Attrs(attrs), {});
});
Expr AllocStorage(Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint) {
auto attrs = make_object<AllocStorageAttrs>();
attrs->dtype = dtype_hint;
attrs->device_id = ctx.device_id;
attrs->device_type = ctx.device_type;
static const Op& op = Op::Get("memory.alloc_storage");
return Call(op, {size, alignment}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage").set_body_typed(AllocStorage);

bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down Expand Up @@ -90,19 +97,20 @@ RELAY_REGISTER_OP("memory.alloc_storage")
return {topi::identity(inputs[0])};
});

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
.set_body_typed([](Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
if (assert_shape.defined()) {
attrs->assert_shape = assert_shape;
} else {
attrs->const_shape = Downcast<Constant>(shape);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return Call(op, {storage, offset, shape}, Attrs(attrs), {});
});
Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
if (assert_shape.defined()) {
attrs->assert_shape = assert_shape;
} else {
attrs->const_shape = Downcast<Constant>(shape);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return Call(op, {storage, offset, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor").set_body_typed(AllocTensor);

std::vector<int64_t> FromConstShape(Constant konst) {
runtime::NDArray shape = konst->data;
Expand Down Expand Up @@ -299,5 +307,36 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.ToTupleType")
return ToTupleType(t, std::vector<Expr>(array.begin(), array.end()));
});

// relay.device_copy
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);

Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type) {
auto attrs = make_object<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.device_copy").set_body_typed(DeviceCopy);

RELAY_REGISTER_OP("device_copy")
.describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
.set_support_level(10)
.add_type_rel("Identity", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attr<FTVMCompute>("FTVMCompute",
[](const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_dtype) -> Array<te::Tensor> {
return {topi::identity(inputs[0])};
});

} // namespace relay
} // namespace tvm
46 changes: 46 additions & 0 deletions src/relay/op/memory/memory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/op/memory/memory.h
* \brief Operators for memory related operations in Relay.
*/

#ifndef TVM_RELAY_OP_MEMORY_MEMORY_H_
#define TVM_RELAY_OP_MEMORY_MEMORY_H_

#include <vector>

#include "tvm/relay/expr.h"

namespace tvm {
namespace relay {

Expr AllocStorage(Expr size, Expr alignment, TVMContext ctx, DataType dtype_hint);
Expr DeviceCopy(Expr data, int src_dev_type, int dst_dev_type);
Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape);
Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
std::vector<TensorType> FlattenTupleType(const Type& type);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_OP_MEMORY_MEMORY_H_
6 changes: 5 additions & 1 deletion src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,11 @@ Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& input
return ReduceCompute(attrs, inputs, out_type, topi::prod);
}

RELAY_REGISTER_REDUCE_OP("prod")
TVM_REGISTER_GLOBAL("relay.op._make.prod").set_body_typed(Prod);

RELAY_REGISTER_OP("prod")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.describe(R"code(Computes the products of array elements over given axes.

Example::
Expand Down
49 changes: 29 additions & 20 deletions src/relay/op/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
* \brief Dialect operators for Relay VM.
*/

#include "vm.h"

#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/attrs/vm.h>
#include <tvm/relay/expr.h>
Expand All @@ -30,6 +32,8 @@
#include <tvm/runtime/data_type.h>
#include <tvm/topi/elemwise.h>

#include <utility>

#include "../../transforms/infer_layout_utils.h"
#include "../op_common.h"
#include "../type_relations.h"
Expand All @@ -52,20 +56,23 @@ RELAY_REGISTER_OP("vm.shape_of")
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
Expr ShapeOf(Expr expr) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = DataType::Int(64);
static const Op& op = Op::Get("vm.shape_of");
return Call(op, {expr}, Attrs(attrs), {});
});
}

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed(ShapeOf);

Expr ShapeFunc(Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("vm.shape_func");
auto attrs = make_object<ShapeFuncAttrs>();
attrs->is_input = is_input;
return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.shape_func")
.set_body_typed([](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("vm.shape_func");
auto attrs = make_object<ShapeFuncAttrs>();
attrs->is_input = is_input;
return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
});
TVM_REGISTER_GLOBAL("relay.op.vm.shape_func").set_body_typed(ShapeFunc);

bool ShapeFuncRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down Expand Up @@ -162,10 +169,11 @@ bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
return true;
}

TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op")
.set_body_typed([](Expr func, Expr inputs, Expr outputs) {
return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
});
Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs) {
return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
}

TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op").set_body_typed(InvokeTVMOp);

RELAY_REGISTER_OP("vm.invoke_tvm_op")
.describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
Expand Down Expand Up @@ -212,13 +220,14 @@ RELAY_REGISTER_OP("vm.reshape_tensor")
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor")
.set_body_typed([](Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
});
Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor").set_body_typed(ReshapeTensor);

} // namespace relay
} // namespace tvm
40 changes: 40 additions & 0 deletions src/relay/op/vm/vm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/op/vm/vm.h
* \brief Dialect operators for Relay VM.
*/
#ifndef TVM_RELAY_OP_VM_VM_H_
#define TVM_RELAY_OP_VM_VM_H_

#include "tvm/relay/expr.h"

namespace tvm {
namespace relay {

Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs);
Expr ShapeFunc(Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input);
Expr ShapeOf(Expr expr);
Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_OP_VM_VM_H_
Loading