From 575cdc5504fecf83712090a0e7299f6b3dc879c3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Aug 2024 08:08:01 -0500 Subject: [PATCH] [Utility][Relax] Implemented InjectDebugCallback transform In general, intermediate values produced while evaluating Relax functions are not visible to an end user. While this provides stronger guarantees to the compiler, it can make debugging difficult. For example, if an end-to-end model is producing incorrect results, it can be difficult to determine which step of the model first introduced an error. This commit implements `relax.transform.InjectDebugCallback`, which adds a `debug_callback` parameter to each externally-exposed function of an `IRModule`. This callback is called with the name and value of each variable binding within the function bodies, allowing error-checking to be added. For example, a binding of `B = R.add(A,A)` would be followed by `debug_callback("B", B)`. --- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 13 + src/relax/transform/inject_debug_callback.cc | 140 +++++++++++ .../test_transform_inject_debug_callback.py | 237 ++++++++++++++++++ 4 files changed, 391 insertions(+) create mode 100644 src/relax/transform/inject_debug_callback.cc create mode 100644 tests/python/relax/test_transform_inject_debug_callback.py diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 1ce864651cd9..fb030be13313 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -47,6 +47,7 @@ FuseTIR, FusionPattern, Gradient, + InjectDebugCallback, InlinePrivateFunctions, KillAfterLastUse, LambdaLift, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 2546284625e9..927f9285f1d7 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -385,6 +385,19 @@ def after(args, fset_param: 
R.Callable([R.Prim('int64'), R.Object])): return _ffi_api.LazySetOutput() +def InjectDebugCallback() -> tvm.ir.transform.Pass: + """A pass that adds a callback that is called after each variable + binding. + + Returns + ------- + ret: tvm.ir.transform.Pass + The pass. + + """ + return _ffi_api.InjectDebugCallback() + + def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass: """A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks. diff --git a/src/relax/transform/inject_debug_callback.cc b/src/relax/transform/inject_debug_callback.cc new file mode 100644 index 000000000000..23cf2490f02d --- /dev/null +++ b/src/relax/transform/inject_debug_callback.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/relax/transform/inject_debug_callback.cc + * \brief Add a callback that is called after each binding + */ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace relax { + +namespace { + +class Mutator : public ExprMutator { + public: + Expr VisitExpr_(const FunctionNode* func) override { + if (!func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { + return GetRef(func); + } + + auto callback_signature = FuncStructInfo::OpaqueFunc(TupleStructInfo(Array{})); + Var debug_callback("debug_callback", callback_signature); + + Array new_params; + new_params.push_back(debug_callback); + for (Var param : func->params) { + new_params.push_back(param); + } + + auto cached = info_; + info_ = PerFunctionInfo{debug_callback}; + auto new_body = VisitWithNewScope(func->body, new_params); + + ICHECK(info_->callback_invocations.empty()); + bool new_purity = + Downcast(func->struct_info_)->purity && !info_->uses_debug_callback; + info_ = cached; + + FuncStructInfo new_sinfo(new_params.Map(GetStructInfo), func->ret_struct_info, new_purity); + + auto new_attrs = func->attrs; + if (auto num_input = func->attrs.GetAttr(attr::kNumInput)) { + new_attrs = + WithAttr(new_attrs, String(attr::kNumInput), runtime::Int(num_input.value()->value + 1)); + } + + return Function(new_params, new_body, func->ret_struct_info, new_purity, new_attrs); + } + + void VisitBinding(const Binding& binding) override { + ExprMutator::VisitBinding(binding); + if (info_ && !binding->var.as()) { + info_->uses_debug_callback = true; + Expr invoke_callback = + Call(info_->debug_callback, {relax::StringImm(binding->var->name_hint()), binding->var}); + if (builder_->CurrentBlockIsDataFlow()) { + info_->callback_invocations.push_back(invoke_callback); + } else { + builder_->Emit(invoke_callback, "_"); + } + } + } + + Expr VisitExpr_(const SeqExprNode* seq_expr) override { + bool made_change = false; + Array new_blocks; + + for (const auto& block : seq_expr->blocks) { + auto 
new_block = VisitBindingBlock(block); + new_blocks.push_back(new_block); + made_change = made_change || !new_block.same_as(block); + + if (info_ && info_->callback_invocations.size()) { + builder_->BeginBindingBlock(); + for (Expr invoke_callback : info_->callback_invocations) { + builder_->Emit(invoke_callback, "_"); + } + new_blocks.push_back(builder_->EndBlock()); + info_->callback_invocations.clear(); + made_change = true; + } + } + + Expr new_body = VisitExpr(seq_expr->body); + made_change = made_change || !new_body.same_as(seq_expr->body); + + if (made_change) { + return SeqExpr(new_blocks, new_body); + } else { + return GetRef(seq_expr); + } + } + + private: + struct PerFunctionInfo { + Var debug_callback; + std::vector callback_invocations; + bool uses_debug_callback = false; + }; + std::optional info_; +}; + +} // namespace + +namespace transform { +Pass InjectDebugCallback() { + auto pass_func = [=](Function func, IRModule, PassContext) -> Function { + return Downcast(Mutator()(func)); + }; + return CreateFunctionPass(pass_func, 0, "InjectDebugCallback", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.InjectDebugCallback").set_body_typed(InjectDebugCallback); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_inject_debug_callback.py b/tests/python/relax/test_transform_inject_debug_callback.py new file mode 100644 index 000000000000..e291c57d7051 --- /dev/null +++ b/tests/python/relax/test_transform_inject_debug_callback.py @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import pytest + +import tvm +import tvm.testing +from tvm.script import ir as I, tir as T, relax as R + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.relax.transform.InjectDebugCallback() + + +class TestSimple(BaseCompare): + """The debug callback is called after each variable binding""" + + @I.ir_module + class Before: + @R.function + def main(): + A = R.const([1.0, 2.0], "float64") + B = R.const([3.0, 4.0], "float64") + C = A + B + return C + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(debug_callback: R.Callable(ret=R.Tuple([]))): + A = R.const([1.0, 2.0], "float64") + debug_callback(R.str("A"), A) + B = R.const([3.0, 4.0], "float64") + debug_callback(R.str("B"), B) + C = A + B + debug_callback(R.str("C"), C) + return C + + +class TestCallbackDelayedUntilAfterDataflow(BaseCompare): + """The debug callback is not inserted within a dataflow block. + + Dataflow blocks may not contain impure calls, and the callback is + impure. 
+ + """ + + @I.ir_module + class Before: + @R.function + def main(): + with R.dataflow(): + A = R.const([1.0, 2.0], "float64") + B = R.const([3.0, 4.0], "float64") + C = A + B + R.output(A, B, C) + + return (A, B, C) + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(debug_callback: R.Callable(ret=R.Tuple([]))): + with R.dataflow(): + A = R.const([1.0, 2.0], "float64") + B = R.const([3.0, 4.0], "float64") + C = A + B + R.output(A, B, C) + debug_callback(R.str("A"), A) + debug_callback(R.str("B"), B) + debug_callback(R.str("C"), C) + return (A, B, C) + + +class TestDelayedCallbacksDoNotIncludeDataflowVar(BaseCompare): + """The delayed callbacks only include non-dataflow variables + + The impure callback must occur after the dataflow block, but + dataflow variables may only be accessed within the dataflow block. + As a result, the callback is skipped for all dataflow vars. + + """ + + @I.ir_module + class Before: + @R.function + def main(): + with R.dataflow(): + A = R.const([1.0, 2.0], "float64") + B = R.const([3.0, 4.0], "float64") + C = A + B + R.output(C) + + return C + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(debug_callback: R.Callable(ret=R.Tuple([]))): + with R.dataflow(): + A = R.const([1.0, 2.0], "float64") + B = R.const([3.0, 4.0], "float64") + C = A + B + R.output(C) + debug_callback(R.str("C"), C) + return C + + +class TestCallbackParameterPreservedNumInputAttribute(BaseCompare): + """The callback function counts as a runtime input + + The `attr::kNumInput` ("num_input") attribute indicates which + parameters are provided at runtime, and which are known at + compile-time, such as model weights. When the debug callback is + inserted, any existing `attr::kNumInput` attributes must be + updated. 
+ + """ + + @I.ir_module + class Before: + @R.function + def main( + activations: R.Tensor([16, 1024], dtype="float16"), + weights: R.Tensor([1024, 1024], dtype="float16"), + bias: R.Tensor([1024], dtype="float16"), + ): + R.func_attr({"num_input": 1}) + after_matmul = R.matmul(activations, weights) + after_bias = R.add(after_matmul, bias) + + return after_bias + + @I.ir_module + class Expected: + @R.function(pure=False) + def main( + debug_callback: R.Callable(ret=R.Tuple([])), + activations: R.Tensor([16, 1024], dtype="float16"), + weights: R.Tensor([1024, 1024], dtype="float16"), + bias: R.Tensor([1024], dtype="float16"), + ): + R.func_attr({"num_input": 2}) + after_matmul = R.matmul(activations, weights) + debug_callback(R.str("after_matmul"), after_matmul) + after_bias = R.add(after_matmul, bias) + debug_callback(R.str("after_bias"), after_bias) + + return after_bias + + +@tvm.testing.parametrize_targets("llvm", "cuda") +def test_inject_debug_check_for_nan(target, dev): + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([2], "float32")): + B = A + R.prim_value(T.float32(1.0)) + C = R.sqrt(B) + D = A + C + return D + + target = tvm.target.Target(target) + if "gpu" in target.keys: + Module = tvm.ir.transform.Sequential( + [ + tvm.relax.transform.LegalizeOps(), + tvm.tir.transform.BindTarget(target), + tvm.tir.transform.DefaultGPUSchedule(), + ] + )(Module) + + built = tvm.relax.build(Module, target) + vm = tvm.relax.VirtualMachine(built, dev) + + # Suppose a function can be called with most inputs, producing + # valid outputs. 
+ np_input = np.array([1.0, 2.0], dtype="float32") + expected = np.sqrt(np_input + 1.0) + np_input + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](tvm_input) + tvm.testing.assert_allclose(expected, tvm_output.numpy()) + + # However, for some inputs, the function produces incorrect values + np_input = np.array([-5.0, 5.0], dtype="float32") + vm["main"](tvm.nd.array(np_input, dev)) + + # We'd like to have some assertion in order to determine where the + # error occurs. However, we only have visibility to the final + # output of the end-to-end function. + + def assert_not_nan(var_name, var_value): + if isinstance(var_value, tvm.runtime.NDArray): + contains_nan = np.isnan(var_value.numpy()).any() + assert not contains_nan, f"Variable {var_name} contained NaN" + + # A callback can be inserted with `InjectDebugCallback`. After + # applying this pass, all externally-exposed functions take a + # callback function as their first parameter. + + Module = tvm.relax.transform.InjectDebugCallback()(Module) + + built = tvm.relax.build(Module, target) + vm = tvm.relax.VirtualMachine(built, dev) + + # The valid inputs can be inspected, and still produce the same + # output. + np_input = np.array([1.0, 2.0], dtype="float32") + expected = np.sqrt(np_input + 1.0) + np_input + tvm_input = tvm.nd.array(np_input, dev) + tvm_output = vm["main"](assert_not_nan, tvm_input) + tvm.testing.assert_allclose(expected, tvm_output.numpy()) + + # However, the invalid inputs can be caught in the debug function + # and inspected. + np_input = np.array([-5.0, 5.0], dtype="float32") + with pytest.raises(AssertionError, match="Variable C contained NaN"): + vm["main"](assert_not_nan, tvm.nd.array(np_input, dev)) + + +if __name__ == "__main__": + tvm.testing.main()