Skip to content

Commit

Permalink
[Unity] Basic StructInfo Analysis and Expr construction (#13916)
Browse files Browse the repository at this point in the history
[Unity] Basic StructInfo Analysis and Expr construction.

This PR adds struct info analysis and expr support:
the logic to construct the IR nodes and to perform
struct-info-related analysis.

Testcases are added to cover the IR node construction
and related struct info analysis checks.

Co-authored-by: Tianqi Chen <tianqi.tchen@gmail.com>
Co-authored-by: Altan Haan <altanh@cs.washington.edu>
Co-authored-by: Andrew Liu <andrewlliu@gmail.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Jiawei Liu <jaway.liu@gmail.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com>
Co-authored-by: masahi <masahi129@gmail.com>
Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>
Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Steven S. Lyubomirsky <slyubomirsky@octoml.ai>
Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com>
Co-authored-by: Yixin Dong <ubospica@gmail.com>
Co-authored-by: Yong Wu <yongcale@gmail.com>
Co-authored-by: Ziheng Jiang <ziheng@apache.org>
  • Loading branch information
17 people committed Feb 24, 2023
1 parent 76cc9f7 commit fa561c8
Show file tree
Hide file tree
Showing 29 changed files with 5,198 additions and 14 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/support/*.cc
src/script/*.cc
src/relax/ir/*.cc
src/relax/analysis/*.cc
src/relax/backend/vm/*.cc
)

Expand Down
3 changes: 2 additions & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ class PrimType : public Type {
/*!
* \brief Constructor
* \param dtype The corresponding dtype.
* \param span The span
*/
TVM_DLL explicit PrimType(runtime::DataType dtype);
TVM_DLL explicit PrimType(runtime::DataType dtype, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
Expand Down
252 changes: 252 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relax/analysis.h
* \brief The set of Relax specific analysis on IR.
*/
#ifndef TVM_RELAX_ANALYSIS_H_
#define TVM_RELAX_ANALYSIS_H_

#include <tvm/arith/analyzer.h>
#include <tvm/ir/diagnostic.h>
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/tir/function.h>

#include <functional>
#include <utility>

namespace tvm {
namespace relax {
//-----------------------------------
// Shape expression analysis
//----------------------------------
/*!
* \brief Check whether two symbolic shape arrays can be proven equal to each other.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param ana The analyzer used for integer analysis.
* \return true if the two shape arrays are proven equal, false otherwise.
*
* \note This function does a best-effort proof, which means that
* if the result is false, it is still possible that the
* two shapes are equal to each other at runtime.
*/
TVM_DLL bool CanProveShapeEqual(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs,
arith::Analyzer* ana);

/*!
* \brief Check whether two symbolic shape expressions can be proven equal to each other.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param ana The analyzer used for integer analysis.
* \return true if the two shape expressions are proven equal, false otherwise.
*
* \note This function does a best-effort proof, which means that
* if the result is false, it is still possible that the
* two shapes are equal to each other at runtime.
*/
TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana);

//-----------------------------------
// Foundational StructInfo analysis
//-----------------------------------
/*!
* \brief Get the static type that corresponds to a given struct info.
* \param info The struct info.
* \return The corresponding static type.
*/
TVM_DLL Type GetStaticType(const StructInfo& info);

/*!
* \brief Get the struct info that corresponds to a given static type.
* \param type The input type.
* \return The corresponding struct info.
*/
TVM_DLL StructInfo StructInfoFromType(const Type& type);

/*!
* \brief Erase the info to a corresponding more coarse-grained
* struct info that is still well-defined (with all the vars in scope).
*
* When we are returning a StructInfo to another scope,
* it is important to remember that the StructInfo may carry
* dependencies on vars that are not defined in the other scope.
*
* In such cases, it is important to call EraseToWellDefined to get
* another StructInfo that **only** contains the vars that are defined
* in the target scope.
*
* For example, consider the following function
*
* \code
*
* @R.function
* def f(x: R.Tensor[(n, m)]):
* k = tir.Var("k", "int64")
* v0 = opaque_fn(x)
* v1 = match_cast(v0, R.Tensor[(n, k)])
* v2 : R.Tensor[(n + 1, k + 2)] = pad(v1)
* return v2
*
* \endcode
*
* In the above code, the return value v2 has shape `(n + 1, k + 2)`.
* However, at the level of the function signature, only n and m are defined;
* k is undefined here.
*
* When we call EraseToWellDefined(R.Tensor[(n + 1, k + 2)], f_shape_var_map={n: n, m: m}),
* we will obtain R.Tensor(ndim=2), which is an erased info that does not depend
* on k (which is undefined from the parameter signature).
*
* However, if we call EraseToWellDefined(R.Tensor[(n + 1, m)], f_shape_var_map={n: n, m: m}),
* then the return value will be R.Tensor[(n + 1, m)], because both n and m are defined.
*
* We can also make these var maps return a different expression.
* For example, EraseToWellDefined(R.Tensor[(n + 1, m)], f_shape_var_map={n: 2, m: m})
* will give us R.Tensor[(3, m)], where n gets replaced by 2.
*
* Use this function in the following scenarios:
* - Decide the struct_info of an expr with sub-scopes, such as If, SeqExpr.
* - Decide the deduced return struct_info of a function that can be fully decided by params.
*
* \param info The struct info.
* \param f_shape_var_map callback function to specify
* whether a symbolic shape var is defined and the value it maps to,
* return nullopt if var is undefined.
* \param f_var_map callback function to specify
* whether a var is defined in the target scope and the value it maps to,
* return nullopt if var is undefined.
* \param ana Optional context analyzer to prove symbolic expression equality.
*
* \return The corresponding erased struct info.
*/
TVM_DLL StructInfo
EraseToWellDefined(const StructInfo& info,
std::function<Optional<PrimExpr>(const tir::Var& var)> f_shape_var_map = nullptr,
std::function<Optional<Expr>(const Var& var)> f_var_map = nullptr,
arith::Analyzer* ana = nullptr);

/*!
* \brief EraseToWellDefined variant that takes maps instead of callbacks.
* \param info The struct info.
* \param shape_var_map map to specify
* whether a symbolic shape var is defined and the value it maps to.
* NOTE(review): a var that is absent from the map is presumably treated
* as undefined -- confirm against the implementation.
* \param var_map map to specify
* whether a var is defined in the target scope and the value it maps to.
* NOTE(review): a var that is absent from the map is presumably treated
* as undefined -- confirm against the implementation.
* \param ana Optional context analyzer to prove symbolic expression equality.
*
* \return The corresponding erased struct info.
*/
TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, Map<tir::Var, PrimExpr> shape_var_map,
Map<Var, Expr> var_map, arith::Analyzer* ana = nullptr);

/*!
* \brief Fine-grained result of base check.
*
* This analysis comes with different levels of checking failures
* that can help to customize the compilation decisions.
*
* For a given pair of lhs_struct_info and rhs_struct_info, we adopt
* the following terminology:
* - LSet = {value | value matches lhs_struct_info}
* - RSet = {value | value matches rhs_struct_info}
*
* See the definition of each level below.
*/
enum class BaseCheckResult {
/*!
* \brief The two value sets have no intersection at all: Intersect(LSet, RSet) = empty
*/
kFailL0 = 0,
/*!
* \brief LSet is not a superset of RSet by only looking at static information.
*
* \note This level will trigger a static type checking error when lhs is param and rhs is arg.
*/
kFailL1 = 1,
/*!
* \brief LSet is not a superset of RSet because of a mismatch in value information.
*
* If lhs is FuncStructInfo, then an L1-level mismatch in its params
* is categorized as an L2-level mismatch for lhs.
*
* Design considerations for functions:
* - (a) We want to be able to erase type/value in function signature
* when we unify function struct info and preserve simpler representations.
* - (b) We automatically insert match_cast at function boundary, so
* we can erase (int)->int argument as (object)->int.
* The input shape/type mismatch will be detected by runtime checks at function boundary.
* This behavior is also consistent with the PackedFunc behavior.
*
* \note This level means there is no problem with the statically known information.
* It is OK for the checker to do best effort and return this value.
*/
kFailL2 = 2,
/*! \brief LSet is a superset of RSet. */
kPass = 3
};

/*!
* \brief Run a base check to see if base subsumes derived.
*
* This function returns a fine-grained base-check result that
* distinguishes the reasons of failure.
*
* \param base The base struct info.
* \param derived The derived struct info.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return The fine-grained check result; see BaseCheckResult for the meaning of each level.
*
* \sa BaseCheckResult
*/
TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived,
arith::Analyzer* ana = nullptr);

/*!
* \brief Check the relation of two struct infos to see whether one subsumes the other.
*
* \param base The base struct info.
* \param derived The derived struct info.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return Whether the relation holds.
*/
TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
arith::Analyzer* ana = nullptr);

/*!
* \brief Unify two struct infos to their least common ancestor.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param ana Optional context analyzer to prove symbolic expression equality.
* \return The unified struct info.
*/
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);
} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ANALYSIS_H_
43 changes: 41 additions & 2 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <tvm/ir/source_map.h>
#include <tvm/node/node.h>
#include <tvm/relax/type.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/object.h>
Expand All @@ -35,7 +34,47 @@ namespace relax {

using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Id;
/*!
* \brief The unique identifier of variables.
*
* An Id is like a name for a variable,
* except that the id is unique for each Var.
*
* \note Do not create Id directly, they are created in Var.
*/
class IdNode : public Object {
public:
/*!
* \brief The name of the variable.
* This only acts as a hint to the user
* and is not used for equality.
*/
String name_hint;

/*! \brief Expose name_hint to the attribute visitor under the key "name_hint". */
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }

/*! \brief Equality reduction; delegates to the reducer's FreeVarEqualImpl on the two nodes. */
bool SEqualReduce(const IdNode* other, SEqualReducer equal) const {
return equal.FreeVarEqualImpl(this, other);
}

/*! \brief Hash reduction; delegates to the reducer's FreeVarHashImpl on this node. */
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.FreeVarHashImpl(this); }

static constexpr const char* _type_key = "relax.Id";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IdNode, Object);
};

/*!
* \brief Object reference type paired with IdNode
* (see the TVM_DEFINE_OBJECT_REF_METHODS line below).
*/
class Id : public ObjectRef {
public:
/*!
* \brief The constructor.
* \param name_hint The name of the variable.
*/
TVM_DLL explicit Id(String name_hint);

TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};

/*!
* \brief Base type of all structure information.
Expand Down
Loading

0 comments on commit fa561c8

Please sign in to comment.