Skip to content

Commit

Permalink
[Relay] Allow Primitive functions to carry virtual device annotations…
Browse files Browse the repository at this point in the history
… in PlanDevices (apache#12095)

* [Relay] Allow Primitive function to carry virtual device annotations in PlanDevices

Previously Primitive=1 functions not analyzed and calls to such were completely
unconstrained. With this change at least any virtual device annotation on the function
are respected and accounted for in calls, even though the body is not analyzed.

This may help with piggy-backing on PlanDevices for doing memory scope analysis, since
it is now possible to express cross-scope functions on Primitive functions. However
I believe there are other issues to deal with in addition to this one.

* - comments

* - also canonicalize targets

When including virtual device annotations in test relay programs the
annotation will typically use a target which was used as an input to
the make_compilation_config helper, but due to various canonicalization
make not be pointer equal to the final structurally equal target which ends
up inside the constructed CompilationConfig. However VirtualDevices use
pointer equality when comparing their target field.

So make sure the notion of CanonicalVirtualDevice also accounts for canonical
targets.

* - update unit test to reflect the Ardreno example

* - trivial cleanup
  • Loading branch information
mbs-octoml authored and Mikael Sevenier committed Jul 26, 2022
1 parent 0c59647 commit 941f609
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 57 deletions.
6 changes: 6 additions & 0 deletions include/tvm/target/compilation_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ class CompilationConfigNode : public Object {
*/
Optional<Target> FindPrimitiveTargetForKind(const std::string& kind_name) const;

/*!
* \brief Returns a \p Target structurally equal to \p target, however prefer a structually equal
* known host or primitive target if the configuration has one.
*/
Target CanonicalTarget(const Target& target) const;

/*!
* \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained
* fields, however:
Expand Down
11 changes: 7 additions & 4 deletions src/relay/transforms/device_domains.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,13 @@ void DeviceDomains::SetDefault(DeviceDomainPtr domain,
ICHECK(!default_virtual_device->IsFullyUnconstrained());
domain = Lookup(domain);
if (domain->args_and_result_.empty()) {
DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(
domain, MakeFirstOrderDomain(config_->CanonicalVirtualDevice(
VirtualDevice::Default(domain->virtual_device_, default_virtual_device))));
ICHECK_NOTNULL(defaulted_domain_ptr);
DeviceDomainPtr default_domain = MakeFirstOrderDomain(config_->CanonicalVirtualDevice(
VirtualDevice::Default(domain->virtual_device_, default_virtual_device)));
DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(domain, default_domain);
ICHECK(defaulted_domain_ptr != nullptr) << "domain:" << std::endl
<< ToString(domain) << std::endl
<< "default domain:" << std::endl
<< ToString(default_domain);
} else {
for (const auto& sub_domain : domain->args_and_result_) {
SetDefault(sub_domain, default_virtual_device);
Expand Down
71 changes: 31 additions & 40 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -553,18 +553,9 @@ class DeviceAnalyzer : public MixedModeVisitor {
}

void VisitExpr_(const FunctionNode* function_node) final {
// No need to step into fused primitive functions as they are lowered individually according
// to the devices of all their call sites.
if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
return;
}

auto function = GetRef<Function>(function_node);
auto func_domain = domains_->DomainFor(function); // higher-order

// The function body domain must match the function result domain.
domains_->UnifyExprExact(function_node->body,
func_domain->function_result()); // may be higher-order
ICHECK_EQ(func_domain->function_arity(), function_node->params.size());

VLOG(2) << "initial function domain:" << std::endl
<< domains_->ToString(func_domain) << std::endl
Expand All @@ -573,39 +564,33 @@ class DeviceAnalyzer : public MixedModeVisitor {
<< "for function:" << std::endl
<< PrettyPrint(function);

ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
for (size_t i = 0; i < function_node->params.size(); ++i) {
// The parameter domains must match the function argument domains.
domains_->UnifyExprExact(function_node->params[i],
func_domain->function_param(i)); // may be higher-order
VisitExpr(function_node->params[i]);
// The function body domain must match the function result domain.
domains_->UnifyExprExact(function_node->body,
func_domain->function_result()); // may be higher-order
if (!function_node->virtual_device()->IsFullyUnconstrained()) {
// The function body domain must match any existing virtual device annotation.
domains_->UnifyExprExact(function_node->body,
domains_->ForVirtualDevice(function_node->body->checked_type(),
function_node->virtual_device()));
}

// If the function already has VirtualDevice attributes then we can further constrain the
// function's domain to match them.
if (!function_node->virtual_device()->IsFullyUnconstrained()) {
std::vector<DeviceDomainPtr> args_and_result;
for (auto param : function_node->params) {
args_and_result.emplace_back(
domains_->ForVirtualDevice(param->checked_type(), param->virtual_device()));
}
args_and_result.emplace_back(domains_->ForVirtualDevice(function_node->body->checked_type(),
function_node->virtual_device()));
auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result));
if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order
// TODO(mbs): Proper diagnostics.
LOG(FATAL) << "Function VirtualDevices are incompatible with its \"on_device\" annotation. "
"Function:"
<< std::endl
<< PrettyPrint(function) << std::endl
<< "with function virtual devices:" << std::endl
<< domains_->ToString(func_domain) << std::endl
<< "and annotation virtual devices:" << std::endl
<< domains_->ToString(annotation_domain);
for (size_t i = 0; i < function_node->params.size(); ++i) {
const auto& param = function_node->params[i];
// The parameter domain must match the function argument domain.
domains_->UnifyExprExact(param,
func_domain->function_param(i)); // may be higher-order
if (!param->virtual_device()->IsFullyUnconstrained()) {
// The parameter domain must match any existing virtual device annotation.
domains_->UnifyExprExact(
param, domains_->ForVirtualDevice(param->checked_type(), param->virtual_device()));
}
VisitExpr(param);
}

VisitExpr(function_node->body);
// No need to step into the body of Primitive functions.
if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
VisitExpr(function_node->body);
}

VLOG(2) << "final function domain:" << std::endl
<< domains_->ToString(func_domain) << std::endl
Expand Down Expand Up @@ -839,10 +824,16 @@ class DeviceDefaulter : public ExprVisitor {
// For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*)
// above. But for calls to primitives we may still need to force free domains to be
// defaulted.
VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain);
VLOG(2) << "before defaulting callee:" << std::endl
<< PrettyPrint(call_node->op) << std::endl
<< "of domain:" << std::endl
<< domains_->ToString(func_domain);
domains_->SetResultDefaultThenParams(func_domain,
domains_->config()->default_primitive_virtual_device);
VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain);
VLOG(2) << "after defaulting callee:" << std::endl
<< PrettyPrint(call_node->op) << std::endl
<< "of domain:" << std::endl
<< domains_->ToString(func_domain);
}
return ExprVisitor::VisitExpr_(call_node);
}
Expand Down
46 changes: 36 additions & 10 deletions src/target/compilation_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,43 @@ Optional<Target> CompilationConfigNode::FindPrimitiveTargetForKind(
return *itr;
}

Target CompilationConfigNode::CanonicalTarget(const Target& target) const {
// Fast path -- object identity.
if (target == host_target) {
return target;
}
for (const auto& primitive_target : primitive_targets) {
if (target == primitive_target) {
return target;
}
}
// Slow path -- structural equality. We have so few targets it does not seem worth building an
// index.
if (StructuralEqual()(target, host_target)) {
return host_target;
}
for (const auto& primitive_target : primitive_targets) {
if (StructuralEqual()(target, primitive_target)) {
return primitive_target;
}
}
// No match.
return target;
}

VirtualDevice CompilationConfigNode::CanonicalVirtualDevice(
const VirtualDevice& virtual_device) const {
if (virtual_device->target.defined()) {
return virtual_device_cache_.Unique(virtual_device);
}
DLDeviceType device_type = virtual_device->device_type();
// TODO(mbs): Proper diagnostics.
CHECK(device_type != kInvalidDeviceType)
<< "VirtualDevice annotations must include at least a device_type";
Target target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type());
Target target = virtual_device->target;
if (target.defined()) {
target = CanonicalTarget(target);
} else {
// Find the (unique) target matching the device's device type.
// TODO(mbs): Proper diagnostics.
CHECK(device_type != kInvalidDeviceType)
<< "VirtualDevice annotations must include at least a device_type";
target = FindPrimitiveTargetForDeviceOrFail(device_type);
}
return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id,
target, virtual_device->memory_scope));
}
Expand Down Expand Up @@ -222,9 +249,8 @@ void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
// Establish the default primitive VirtualDevice, choosing a known Target to match the device
// type. We do not create a default target, it must already exist as a primitive target.
//
default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice(
default_primitive_device_type,
/*virtual_device_id=*/0, FindPrimitiveTargetForDeviceOrFail(default_primitive_device_type)));
default_primitive_virtual_device = CanonicalVirtualDevice(
VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0));

ICHECK(default_primitive_virtual_device.defined());
ICHECK(default_primitive_virtual_device->target.defined());
Expand Down
3 changes: 2 additions & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,8 @@ Optional<Target> TargetNode::GetHost() const {
String TargetNode::ToDebugString() const {
std::ostringstream os;
os << "Target(";
os << "kind='" << kind->name << "'";
os << "id=" << std::hex << reinterpret_cast<size_t>(this);
os << ", kind='" << kind->name << "'";
if (!tag.empty()) {
os << ", tag='" << tag << "'";
}
Expand Down
29 changes: 29 additions & 0 deletions tests/cpp/target/compilation_config_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,29 @@ TEST(CompilationConfig, FindPrimitiveTargetForKind_NotFound) {
ASSERT_FALSE(config->FindPrimitiveTargetForKind("cutlass").defined());
}

TEST(CompilationConfig, CanonicalTarget) {
Target host_target = TestDefaultCpuTarget();
Target cuda_target = TestCudaTarget();
Target cpu_target = TestCpuTarget();
CompilationConfig config = TestCompilationConfig();

{
Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget());
ASSERT_NE(other_cuda_target, cuda_target);
ASSERT_EQ(config->CanonicalTarget(other_cuda_target),
config->FindPrimitiveTargetForKind("cuda"));
}
{
Target other_host_target = TestDefaultCpuTarget();
ASSERT_NE(other_host_target, cuda_target);
ASSERT_EQ(config->CanonicalTarget(other_host_target), config->host_target);
}
{
Target other_target("cuda -max_num_threads=7");
ASSERT_EQ(config->CanonicalTarget(other_target), other_target);
}
}

TEST(CompilationConfig, CanonicalVirtualDevice) {
Target host_target = TestDefaultCpuTarget();
Target cuda_target = TestCudaTarget();
Expand All @@ -306,6 +329,12 @@ TEST(CompilationConfig, CanonicalVirtualDevice) {
EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target)));
EXPECT_EQ(config->CanonicalVirtualDevice(in), actual);
}
{
Target other_cuda_target = Target::WithHost(TestCudaTarget(), TestDefaultCpuTarget());
VirtualDevice in = VirtualDevice(kDLCUDA, -1, other_cuda_target);
VirtualDevice actual = config->CanonicalVirtualDevice(in);
ASSERT_EQ(actual->target, config->FindPrimitiveTargetForKind("cuda"));
}
}

TEST(CompilationConfig, CanonicalVirtualDevice_NoDevice) {
Expand Down
4 changes: 4 additions & 0 deletions tests/cpp/target/virtual_device_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ TEST(VirtualDeviceCache, Memoized) {
VirtualDeviceCache cache;
Target target_a = Target("cuda");
Target target_b = Target("llvm");
Target target_c = Target("cuda");
VirtualDevice virtual_device_a = cache.Make(kDLCUDA, 3, target_a, "local");
VirtualDevice virtual_device_b = cache.Make(kDLCPU, 1, target_b, "global");

Expand All @@ -115,6 +116,9 @@ TEST(VirtualDeviceCache, Memoized) {
EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), virtual_device_a);
EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), virtual_device_a);
EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), virtual_device_a);
EXPECT_EQ(cache.Make(kDLCUDA, 3, Target("cuda"), "local"), virtual_device_a);
EXPECT_NE(cache.Make(kDLCUDA, 3, Target("cuda -max_threads_per_block=4096"), "local"),
virtual_device_a);
}

} // namespace
Expand Down
55 changes: 53 additions & 2 deletions tests/python/relay/test_pass_plan_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
CPU_SCOPE_A = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA")
CPU_SCOPE_B = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB")

GPU_SCOPE_GLOBAL = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global")
GPU_SCOPE_TEXTURE = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET, memory_scope="global.texture")

CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int})

core = tvm.IRModule()
Expand All @@ -57,7 +60,7 @@

def rewrite_and_assert(in_mod, expected_mod):
"""Manually run the pass and assert it's structurally equals to the expected."""
config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET)
config = tvm.target.make_compilation_config(CTXT, TARGETS)
actual_mod = relay.transform.InferType()(in_mod)
actual_mod = relay.transform.PlanDevices(config)(actual_mod)
actual_mod = relay.transform.InferType()(actual_mod)
Expand Down Expand Up @@ -1774,11 +1777,59 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
metatable,
)

config = tvm.target.make_compilation_config(CTXT, TARGETS, HOST_TARGET)
config = tvm.target.make_compilation_config(CTXT, TARGETS)
actual_mod = relay.transform.InferType()(input())
actual_mod = relay.transform.PlanDevices(config)(actual_mod)
relay.transform.InferType()(actual_mod)


def test_primitive():
"""Annotations on Primitive functions should be accepted, even though the body
of the Primitive function is not considered during PlanDevices."""
metatable = {
"VirtualDevice": [
GPU_SCOPE_GLOBAL,
GPU_SCOPE_TEXTURE,
]
}

mod = tvm.parser.parse(
"""
#[version = "0.0.5"]
def @main(%data1: Tensor[(1, 32, 40, 40), float32],
%data2: Tensor[(1, 32, 40, 40), float32]) {
%0 = fn (%a, Primitive=1) {
layout_transform(%a, src_layout="NCHW", dst_layout="NCHW4c")
};
%1 = %0(%data1);
%3 = %0(%data2);
%5 = fn (%a {virtual_device=meta[VirtualDevice][0]},
%b {virtual_device=meta[VirtualDevice][0]},
virtual_device=meta[VirtualDevice][1],
Primitive=1) {
add(%a, %b)
};
%6 = %5(%1, %3);
%10 = fn (%a,
virtual_device=meta[VirtualDevice][0],
Primitive=1) {
layout_transform(%a, src_layout="NCHW4c", dst_layout="NCHW")
};
%10(%6)
}
""",
"from_string",
None,
metatable,
)
print(mod)

config = tvm.target.make_compilation_config(CTXT, GPU_TARGET)
mod = relay.transform.InferType()(mod)
# PlanDevices should succeed.
mod = relay.transform.PlanDevices(config)(mod)
print(mod)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 941f609

Please sign in to comment.