Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Update GlobalVar name in AttachGlobalSymbol #17202

Merged
merged 4 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions include/tvm/ir/analysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/analysis.h
*
* Analysis routines that must function across multiple IR types for
* correctness. For example, identifying unused functions, when both TIR
*
*/
#ifndef TVM_IR_ANALYSIS_H_
#define TVM_IR_ANALYSIS_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/runtime/container/array.h>

namespace tvm {
namespace ir {

class CalleeCollector {
public:
/* \brief Functor to be registered for IR types
*
* Should be implemented for each `BaseFunc` subclass.
* Implementation should call `CalleeCollector::Mark` for each
* `GlobalVar` in the function.
*/
using FType = NodeFunctor<void(const ObjectRef&, CalleeCollector*)>;
TVM_DLL static FType& vtable() {
static FType inst;
return inst;
}

virtual ~CalleeCollector() {}

/* \brief Collect the GlobalVar in a function */
virtual void Mark(GlobalVar gvar) = 0;
};

Map<GlobalVar, Array<GlobalVar>> CollectCallMap(const IRModule& mod);

} // namespace ir
} // namespace tvm

#endif // TVM_IR_ANALYSIS_H_
57 changes: 57 additions & 0 deletions include/tvm/ir/replace_global_var.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/replace_global_var.h
*
* \brief A utility to replace GlobalVar instances across all TVM IR
* types in an IRMdoule.
*/
#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_
#define TVM_IR_REPLACE_GLOBAL_VAR_H_

#include <tvm/ir/module.h>

namespace tvm {
namespace transform {

/*!
* \brief Replace GlobalVar instances across any IR type.
*
* \param mod The module to update
*
* \param replacements The map, where each entry maps from an old
* `GlobalVar` to the new `GlobalVar` that should replace it.
*
* \return The updated IRModule
*/
TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements);

struct GlobalVarReplacer {
using FType = NodeFunctor<BaseFunc(const ObjectRef&, Map<GlobalVar, GlobalVar>)>;
TVM_DLL static FType& vtable() {
static FType inst;
return inst;
}
};

} // namespace transform
} // namespace tvm

#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_
3 changes: 3 additions & 0 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=unused-import
"""Common data structures across all IR variants."""

from . import diagnostics, instrument, transform
from .adt import Constructor, TypeData
from .affine_type import TensorAffineType, TupleAffineType
Expand Down Expand Up @@ -61,3 +62,5 @@
TypeVar,
)
from .type_relation import TypeCall, TypeRelation

from . import analysis
22 changes: 22 additions & 0 deletions python/tvm/ir/_ffi_analysis_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.ir.analysis"""

import tvm._ffi


tvm._ffi._init_api("ir.analysis", __name__)
44 changes: 44 additions & 0 deletions python/tvm/ir/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=unused-import

"""Common analysis across all IR variants."""

from typing import Dict, List

import tvm
from . import _ffi_analysis_api as _ffi


def collect_call_map(
module: "tvm.ir.IRModule",
) -> Dict["tvm.ir.GlobalVar", List["tvm.ir.GlobalVar"]]:
"""Collect the call map of a module

Parameters
----------
module: tvm.ir.IRModule
The module to inspect

Returns
-------
call_map: Dict[tvm.ir.GlobalVar, List[tvm.ir.GlobalVar]]
A map from functions to the subroutines they call.

"""
return _ffi.CollectCallMap(module)
49 changes: 49 additions & 0 deletions src/ir/analysis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/ir/analysis.cc
* \brief Analysis functions that must span multiple IR types
*/
#include <tvm/ir/analysis.h>

#include "../support/ordered_set.h"

namespace tvm {
namespace ir {

Map<GlobalVar, Array<GlobalVar>> CollectCallMap(const IRModule& mod) {
struct CalleeCollectorImpl : CalleeCollector {
void Mark(GlobalVar gvar) override { gvars.push_back(gvar); }
support::OrderedSet<GlobalVar> gvars;
};

Map<GlobalVar, Array<GlobalVar>> call_map;
for (const auto& [gvar, base_func] : mod->functions) {
CalleeCollectorImpl collector;
CalleeCollector::vtable()(base_func, &collector);
call_map.Set(gvar, Array<GlobalVar>{collector.gvars.begin(), collector.gvars.end()});
}
return call_map;
}

TVM_REGISTER_GLOBAL("ir.analysis.CollectCallMap").set_body_typed(CollectCallMap);

} // namespace ir
} // namespace tvm
63 changes: 63 additions & 0 deletions src/ir/replace_global_var.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/ir/replace_global_var.cc
* \brief IRModule transform to replace GlobalVar instances across any IR type.
*/

#include <tvm/ir/replace_global_var.h>

#include <vector>

namespace tvm {
namespace transform {

IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
std::vector<GlobalVar> to_remove;
IRModule updates;

const auto& vtable = GlobalVarReplacer::vtable();

for (const auto& [old_gvar, old_func] : mod->functions) {
auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar);
auto new_func = vtable(old_func, replacements);

if (!new_gvar.same_as(old_gvar)) {
to_remove.push_back(old_gvar);
}
if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) {
updates->Add(new_gvar, new_func);
}
}

if (to_remove.size() || updates->functions.size()) {
auto write_ptr = mod.CopyOnWrite();
for (const auto& old_gvar : to_remove) {
write_ptr->Remove(old_gvar);
}
write_ptr->Update(updates);
}
return mod;
}

TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar);

} // namespace transform
} // namespace tvm
56 changes: 56 additions & 0 deletions src/relax/analysis/collect_call_map.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
*
* \file src/relax/analysis/collect_call_map.cc
*
* \brief Collect cross-IR call graph
*/

#include <tvm/ir/analysis.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr_functor.h>

namespace tvm {
namespace relax {

namespace {
using ir::CalleeCollector;

struct Visitor : ExprVisitor {
explicit Visitor(CalleeCollector* collector) : collector(collector) {}
CalleeCollector* collector;
void VisitExpr_(const GlobalVarNode* node) override { collector->Mark(GetRef<GlobalVar>(node)); }
};

} // namespace

TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable)
.set_dispatch<relax::FunctionNode>([](const ObjectRef& func, CalleeCollector* collector) {
Visitor visitor{collector};
visitor(Downcast<Function>(func));
});

TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable)
.set_dispatch<relax::ExternFuncNode>([](const ObjectRef& func, CalleeCollector* collector) {});

} // namespace relax
} // namespace tvm
Loading
Loading