Skip to content

Commit

Permalink
[Relay] Prepare for merging context_analysis.cc and device_annotation…
Browse files Browse the repository at this point in the history
….cc (apache#9077)

* [Relay] Prepare for merging context_analysis.cc and device_annotation.cc

- Improve construction and deconstruction of "on_device" and "device_copy" calls since they will be center stage.
- Move "device_copy" support out of memory.h into own module to mirror "on_device".
- Clearing out some DLOG -> VLOG changes I found helped me debug.
- Clearing out some whitespace-only changes I accumulated.

* [checkpoint] Address Christopher's comments.

Some stray py formatting changes snuck in since I just run black . at the root.
  • Loading branch information
mbs-octoml authored and ylc committed Jan 13, 2022
1 parent 970e0fb commit 97d642f
Show file tree
Hide file tree
Showing 26 changed files with 709 additions and 99 deletions.
28 changes: 25 additions & 3 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,37 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Attributes for the "on_device" operator.
*
* The relay call
* \code
* on_device(expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
*
* See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level
* companion.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
int device_type;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which argument expression should be evaluated. */
int device_type = kInvalidDeviceType;
/*!
* \brief If true, the result device must also be \p device_type and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
*/
bool is_fixed = false;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe("The virutal device/context type that an expression is annotated with.")
.describe("The type of the virtual device which should hold the expression result.")
.set_default(0);
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
.set_default(false);
}
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
// TODO(mbs): Should be TargetDevice.
int dst_dev_type;
int src_dev_type;

Expand Down
66 changes: 66 additions & 0 deletions include/tvm/relay/attrs/function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/attrs/function.h
* \brief Attributes for Relay Functions which don't make sense on PrimFuncs.
*/
#ifndef TVM_RELAY_ATTRS_FUNCTION_H_
#define TVM_RELAY_ATTRS_FUNCTION_H_

namespace tvm {
namespace relay {
/*!
* \brief Attributes for Relay function definitions which capture the devices for the
* function parameters and result.
*
* See also OnDeviceAttrs in include/tvm/relay/attrs/annotation.h for the companion "on_device"
* call attributes.
*/
struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
/*! \brief Device type on which each of the function's arguments already resides. */
Array<Integer> param_device_types;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which function body should be evaluated. */
int result_device_type = kInvalidDeviceType;

TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") {
TVM_ATTR_FIELD(param_device_types)
.describe("The type of the virtual device which holds each function parameters.");
TVM_ATTR_FIELD(result_device_type)
.describe("The type of the virtual device which will hold the function's result.")
.set_default(0);
}
};

namespace attr {

/*!
* \brief Device annotations for function parameters and results.
*
* Type: FunctionOnDeviceAttrs
*/
constexpr static const char* kFunctionAttrsKey = "on_device";

} // namespace attr

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_ATTRS_FUNCTION_H_
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
* of the graph and processes them iteratively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,8 @@ TVM_DLL Pass RelayToTIRTargetHook();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param target_host The target used by the host for compliation.
* \param targets The device type and target pairs for compliation.
* \param target_host The target used by the host for compilation.
* \param targets The device type and target pairs for compilation.
*
* \return The pass.
*/
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array, container representing a contigious sequence of ObjectRefs.
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
* Array implements in-place copy-on-write semantics.
*
Expand Down
26 changes: 16 additions & 10 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,19 @@
#include <vector>

namespace tvm {
namespace runtime {

typedef DLDevice Device;
// alias DLDevice
using Device = DLDevice;

// A 'null' device type, does not correspond to any DLDeviceType enum.
// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case
// as a singleton target map indexed by the invalid DLDeviceType '0'.
constexpr DLDeviceType kNullDeviceType = static_cast<DLDeviceType>(0);

// An 'invalid' device type, does not correspond to any DLDeviceType enum.
constexpr DLDeviceType kInvalidDeviceType = static_cast<DLDeviceType>(-1);

namespace runtime {

/*!
* \brief Managed NDArray.
Expand Down Expand Up @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}

} // namespace runtime

// alias Device
using tvm::runtime::Device;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::runtime::Device> {
std::size_t operator()(const tvm::runtime::Device& dev) const {
struct hash<tvm::Device> {
std::size_t operator()(const tvm::Device& dev) const {
return ((dev.device_id << 8) | dev.device_type);
}
};

template <>
struct equal_to<tvm::runtime::Device> {
bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const {
struct equal_to<tvm::Device> {
bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const {
return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id);
}
};
Expand Down
56 changes: 43 additions & 13 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,62 @@
from .. import op as reg


def on_device(data, device):
"""Annotate an expression with a certain device type.
def _device_to_int(device):
if isinstance(device, _Device):
return device.device_type
if isinstance(device, str):
return _nd.device(device).device_type
raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))


def on_device(data, device, is_fixed=False):
"""Annotates an expression with the device type on which its result should be stored.
Parameters
----------
data : tvm.relay.Expr
The expression to be annotated.
device : Union[:py:class:`Device`, str]
The device type to annotate.
The device to annotate with. Only the device's type is significant.
is_fixed : bool
If false (the default), a device_copy
If true, the annotation does not imply a device_copy may be inserted to
reconcile the device of the data argument with the device for the context of the
annotated expression.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
if isinstance(device, _Device):
device = device.device_type
elif isinstance(device, str):
device = _nd.device(device).device_type
else:
raise ValueError(
"device is expected to be the type of Device or "
"str, but received %s" % (type(device))
)
return _make.on_device(data, device)
return _make.on_device(data, _device_to_int(device), is_fixed)


def function_on_device(function, param_devices, result_device):
"""Annotates a Relay function with the device types on which its parameters and result should
be stored.
Parameters
----------
function : tvm.relay.Function
The function to be annotated.
param_devices : Array[Union[:py:class:`Device`, str]]
The devices for each parameter. Only the device types are significant.
result_device: Union[:py:class:`Device`, str]
The device for the function result. Only the device type is significant.
Returns
-------
result : tvm.rleay.Function
The annotated function.
"""
return _make.function_on_device(
function, [_device_to_int(d) for d in param_devices], _device_to_int(result_device)
)


def stop_fusion(data):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def MergeCompilerRegions():

def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def hexagon(cpu_ver="v66", **kwargs):

# LLVM target string
def create_llvm_target(cpu_ver, config):
""" Create LLVM target string. """
"""Create LLVM target string."""

target = " -mtriple=hexagon"
mcpu = " -mcpu=hexagon" + cpu_ver
Expand All @@ -547,7 +547,7 @@ def create_target_features(config):

# Simulator options string
def create_sim_options(cpu_ver, config):
""" Create simulator option string. """
"""Create simulator option string."""

def validate_hvx_length(codegen_hvx, sim_options):
if sim_options and "--hvx_length" in sim_options:
Expand Down Expand Up @@ -606,7 +606,7 @@ def validate_hvx_length(codegen_hvx, sim_options):

# LLVM options string
def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument
""" Create LLVM options string. """
"""Create LLVM options string."""

llvm_options = config["llvm_options"]

Expand All @@ -620,7 +620,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument

# TVM target attributes string
def create_tvm_options(cpu_ver, config): # pylint: disable=unused-argument
""" Create TVM target features string. """
"""Create TVM target features string."""

features = {
"link_params": "link-params",
Expand Down
7 changes: 5 additions & 2 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
/*!
* \file src/node/structural_equal.cc
*/
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
Expand Down Expand Up @@ -119,8 +120,10 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n"
<< "lhs = " << lhs << "\nrhs = " << rhs;
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
<< PrettyPrint(lhs) << std::endl
<< "and rhs:" << std::endl
<< PrettyPrint(rhs);
}
return result;
}
Expand Down
8 changes: 3 additions & 5 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,11 +439,10 @@ class LowerTensorExprMutator : public ExprMutator {
}

// Non-External Relay Function
DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n"
<< PrettyPrint(func);
VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func);
CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);
DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";
VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";

// Collect all the lowered functions produced for this primitive function.
Map<GlobalVar, tir::PrimFunc> prim_fns;
Expand All @@ -452,8 +451,7 @@ class LowerTensorExprMutator : public ExprMutator {
CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
all_prim_fn_vars.push_back(prim_fn.first);
DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first)
<< "'";
VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'";
}

// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/vm/inline_primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,13 @@ struct PrimitiveInliner : ExprMutator {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false);
VLOG(1) << "Before inlining primitives: " << global << std::endl << PrettyPrint(func);

func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false);
VLOG(1) << "After inlining primitives: " << global << std::endl << PrettyPrint(func);
}
}
return module_;
Expand Down
Loading

0 comments on commit 97d642f

Please sign in to comment.