Skip to content

Commit

Permalink
[Relay] PlanDevices supports 'free' on_device annotations (apache#9693)
Browse files Browse the repository at this point in the history
* [Relay] PlanDevices supports 'free' on_device annotations

This is in support of apache#9613, which allows PlanDevices to be run
after lowering so as to flow memory constraints in and
out of PrimFuncs. That requires a way to insert device_copies
when the memory scopes chosen during separate lowering of fused
primitive functions clash, but otherwise avoid device_copies when
scopes can be chosen so as to avoid them.

We support that by generalizing the "on_device" annotation to
allow the device constraint to be independently controlled for
its 'body' and 'result'.

# Standard user annotation: body is constrained to S
on_device(body, S)

# Used by PlanDevices to 'fix' expression to S
# (was is_fixed=True)
on_device(body, S, constrain_result=True)

# Used by PlanDevices to indicate a device_copy can be
# inserted if necessary.
on_device(body, S, constrain_body=False)

# Supported, but currently has no use.
on_device(body, S, constrain_result=True, constrain_body=False)

A few extra odds 'n ends collected along the way:
 - Some CallLowered cleanup which I found useful.
 - The usual extra debugging output needed as I debugged.
   In return I removed some particularly verbose logging I'd
   added while tracking down unexpected object copies.
 - Cleanup warnings from clang-12 as I touch files.

* [checkpoint] unused var
  • Loading branch information
mbs-octoml authored and baoxinqi committed Dec 27, 2021
1 parent 50d45b0 commit 6e3d992
Show file tree
Hide file tree
Showing 28 changed files with 831 additions and 418 deletions.
4 changes: 3 additions & 1 deletion include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call.");
TVM_ATTR_FIELD(metadata)
.describe("Metadata attached to the lowered function call.")
.set_default(Map<String, ObjectRef>());
}
};

Expand Down
62 changes: 33 additions & 29 deletions include/tvm/relay/attrs/on_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file tvm/relay/attrs/on_device.h
* \brief Attribute for the on device annotation.
* \brief Attribute for the "on_device" annotation (ie operator).
*/
#ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_
#define TVM_RELAY_ATTRS_ON_DEVICE_H_
Expand All @@ -33,9 +33,9 @@ namespace tvm {
namespace relay {

/*!
* \brief Attributes for the "on_device" special operator.
* \brief Attributes for the "on_device" annotation (ie operator).
*
* The Relay call (aka 'annotation'):
* The Relay call:
* \code
* on_device(sub_expr, se_scope=S)
* \endcode
Expand All @@ -54,44 +54,48 @@ namespace relay {
* multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z)
* \endcode
*
* The Relay call
* \code
* on_device(sub_expr, se_scope=S, is_fixed=True)
* \endcode
* is similar to the above, however the annotation itself must appear in an expression on the
* same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not
* insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs.
* However it is needed by the \p PlanDevices pass to fully specify the results of device planning
* so that the pass is idempotent.
*
* E.g.: The following program is equivalent to the above:
* \code
* let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True)
* multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z)
* \endcode
* The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored
* on the GPU.
* The \p constrain_body (default true) and \p constrain_result (default false) fields can be
* used by passes for finer-grained control over how the \p SEScope constraint should be applied.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
/*!
* \brief (Virtual) \p SEScope on which the result of the argument expression should be stored.
* \brief The \p SEScope constraint to apply to the body, result, or both body and result
* of the "on_device" call.
*/
SEScope se_scope = SEScope::FullyUnconstrained();

/*!
* \brief If false (the default), the result of the "on_device" call is not constrained to be
* \p se_scope.
*/
bool constrain_result = false;

/*!
* \brief If true (the default), the body of the "on_device" call is constrained to be \p
* se_scope.
*/
bool constrain_body = true;

/*!
* \brief Returns true if both the body and result are constrained.
*/
bool is_fixed() const { return constrain_result && constrain_body; }

/*!
* \brief If true, the result \p SEScope must also be \p se_scope, and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
* \brief Returns true if only the body is constrained (the 'normal' case).
*/
bool is_fixed = false;
bool is_normal() const { return !constrain_result && constrain_body; }

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(se_scope)
.describe("The (virtual) device and scope holding the expression result.")
.describe("The (virtual) device to constrain to.")
.set_default(SEScope::FullyUnconstrained());
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
TVM_ATTR_FIELD(constrain_result)
.describe("Whether the constraint applies to the overall expression")
.set_default(false);
TVM_ATTR_FIELD(constrain_body)
.describe("Whether the constraint applies to the body sub-expression.")
.set_default(true);
}
};

Expand Down
24 changes: 15 additions & 9 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,35 @@ def _make_se_scope(device):
raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))


def on_device(data, device, is_fixed=False):
"""Annotates an expression with the device type on which its result should be stored.
def on_device(body, device, constrain_result=False, constrain_body=True):
"""Annotates a body expression with device constraints. The constraint influences
how the body is compiled, where the body is evaluated, and where the result of
evaluation is stored.
Note that the defaults for the constrain_body and constrain_result parameters should
almost never need to be overridden by the user. These parameters are exposed here
to help unit tests exercise the PlanDevices pass machinery.
Parameters
----------
data : tvm.relay.Expr
body : tvm.relay.Expr
The expression to be annotated.
device : Union[:py:class:`Device`, str]
The device to annotate with.
is_fixed : bool
If false (the default), a device_copy
If true, the annotation does not imply a device_copy may be inserted to
reconcile the device of the data argument with the device for the context of the
annotated expression.
constrain_result : bool
If false (the default), the result of the on_device is not constrained to be on device.
constrain_body : bool
If true (the default), the body of the on_device is constrained to be on device.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.OnDevice(data, _make_se_scope(device), is_fixed)
return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body)


def function_on_device(function, param_devices, result_device):
Expand Down
11 changes: 5 additions & 6 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Expand All @@ -521,6 +515,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
#if TVM_LOG_DEBUG
for (const Type& type_arg : op->type_args) {
args.push_back(Print(type_arg));
}
#endif
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
Expand Down
57 changes: 40 additions & 17 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/tir/function.h>

#include <algorithm>
#include <string>

namespace tvm {
Expand All @@ -36,49 +37,71 @@ static const char* kSemVer = "0.0.5";
Doc TextPrinter::PrintMod(const IRModule& mod) {
Doc doc;
int counter = 0;

// We'll print in alphabetical order to make a/b diffs easier to work with.

// type definitions
std::vector<GlobalTypeVar> tyvars;
for (const auto& kv : mod->type_definitions) {
tyvars.emplace_back(kv.first);
}
std::sort(tyvars.begin(), tyvars.end(),
[](const GlobalTypeVar& left, const GlobalTypeVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& tyvar : tyvars) {
if (counter++ != 0) {
doc << Doc::NewLine();
}
doc << relay_text_printer_.Print(kv.second);
doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
doc << Doc::NewLine();
}

// functions
std::vector<GlobalVar> vars;
for (const auto& kv : mod->functions) {
if (kv.second.as<relay::FunctionNode>()) {
vars.emplace_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& var : vars) {
const BaseFunc& base_func = mod->functions[var];
if (base_func.as<relay::FunctionNode>()) {
relay_text_printer_.dg_ =
relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
}
if (counter++ != 0) {
doc << Doc::NewLine();
}
if (kv.second.as<relay::FunctionNode>()) {
if (base_func.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
os << "def @" << var->name_hint;
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
} else if (base_func.as<tir::PrimFuncNode>()) {
doc << "@" << var->name_hint;
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
}
doc << Doc::NewLine();
}

#if TVM_LOG_DEBUG
// attributes
// TODO(mbs): Make this official, including support from parser.
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
doc << "attributes {" << Doc::NewLine();
std::vector<String> keys;
for (const auto& kv : mod->attrs->dict) {
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
keys.emplace_back(kv.first);
}
std::sort(keys.begin(), keys.end());
doc << "attributes {" << Doc::NewLine();
for (const auto& key : keys) {
doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
}
doc << "}" << Doc::NewLine();
}
#endif

return doc;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ class ConvertAddToSubtract : public MixedModeMutator {

// Since we are replacing the Relay function with a call to a TIR function, we must use the
// call_lowered op.
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", call->attrs);
return CallLowered(std::move(new_global_var), call->args,
std::move(Attrs(call_lowered_attrs)), call->type_args, call->span);
CallLoweredAttrs attrs;
attrs.metadata.Set("relay_attrs", call->attrs);
ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic";
return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span);
}
}

Expand All @@ -144,5 +144,4 @@ transform::Pass RelayToTIR() {
} // namespace example_target_hooks
} // namespace contrib
} // namespace relay

} // namespace tvm
42 changes: 21 additions & 21 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
* to the TIR implementation, and attributes to attach to the call to identify it as
* a TIR call.
*/
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Array<Type> type_args, Span span,
Target target) {
Expr MakeLoweredCall(Function func, Array<Expr> visited_args, Span span, Target target) {
CCacheKey key = CCacheKey(func, target);
CachedFunc cfunc = compiler_->Lower(key, module_name_);
ICHECK(cfunc.defined());
Expand Down Expand Up @@ -623,16 +622,16 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target);
this->process_fn_(func_with_metadata);

auto call_lowered_attrs = make_object<CallLoweredAttrs>();
CallLoweredAttrs call_lowered_attrs;

// Non-External Relay Function
// TODO(mbs): "reshape" cleanup.
if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) {
call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1));
}

call_lowered_attrs->metadata.Set("relay_attrs", func->attrs);
call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars);
call_lowered_attrs.metadata.Set("relay_attrs", func->attrs);
call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars);

if (IsDynamic(func->ret_type)) {
// Also lower the companion dynamic shape function.
Expand All @@ -645,24 +644,24 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
// Capture the shape function's global var and parameters 'states' in call
// annotations so calling convention can be recovered.
// TODO(mbs): Shape cleanup.
call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
call_lowered_attrs->metadata.Set("prim_shape_fn_states",
lowered_shape_func->shape_func_param_states);
call_lowered_attrs->metadata.Set(
"prim_shape_fn_num_inputs", Integer(static_cast<int>(lowered_shape_func->inputs.size())));
call_lowered_attrs->metadata.Set(
call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var);
call_lowered_attrs.metadata.Set("prim_shape_fn_states",
lowered_shape_func->shape_func_param_states);
call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs",
Integer(static_cast<int>(lowered_shape_func->inputs.size())));
call_lowered_attrs.metadata.Set(
"prim_shape_fn_num_outputs",
Integer(static_cast<int>(lowered_shape_func->outputs.size())));
Array<GlobalVar> all_prim_shape_fn_vars;
for (const auto& kv : lowered_shape_func->funcs->functions) {
CHECK(kv.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
all_prim_shape_fn_vars.push_back(kv.first);
}
call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars);
}

return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs),
type_args, std::move(span));
return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs),
std::move(span));
}

std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) final {
Expand Down Expand Up @@ -758,12 +757,13 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {
});

ICHECK(!IsDynamic(call_node->checked_type()));
auto call_lowered_attrs = make_object<CallLoweredAttrs>();
call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs);
CallLoweredAttrs call_lowered_attrs;
call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs);

process_fn_(func_with_metadata);
return CallLowered(call_node->op, std::move(new_args), Attrs(std::move(call_lowered_attrs)),
call_node->type_args, call_node->span);
ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs),
call_node->span);
}

// Typical case: call to fused primitive Relay Function.
Expand All @@ -783,8 +783,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

// Lower the primitive function for that target.
Function function = Downcast<Function>(primitive_func);
return MakeLoweredCall(function, std::move(new_args), call_node->type_args, call_node->span,
target);
ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic";
return MakeLoweredCall(function, std::move(new_args), call_node->span, target);
}

IRModule module_;
Expand Down
Loading

0 comments on commit 6e3d992

Please sign in to comment.