Reorganize source code. (apache#14)
ZihengJiang authored and YuchenJin committed Jan 26, 2022
1 parent 039aa74 commit 284358d
Showing 18 changed files with 438 additions and 242 deletions.
include/tvm/relax/builder.h → include/tvm/relax/vm/exec_builder.h
@@ -18,20 +18,20 @@
*/

/*!
- * \file tvm/relax/builder.h
+ * \file tvm/relax/vm/exec_builder.h
* \brief
*/
-#ifndef TVM_RELAX_BUILDER_H_
-#define TVM_RELAX_BUILDER_H_
+#ifndef TVM_RELAX_EXEC_BUILDER_H_
+#define TVM_RELAX_EXEC_BUILDER_H_

#include <tvm/ir/expr.h>
#include <tvm/node/reflection.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>

#include "./vm/bytecode.h"
#include "./vm/executable.h"
#include "./bytecode.h"
#include "./executable.h"

namespace tvm {
namespace relax {
@@ -102,4 +102,4 @@ class ExecBuilder : public ObjectRef {
} // namespace relax
} // namespace tvm

-#endif  // TVM_RELAX_BUILDER_H_
+#endif  // TVM_RELAX_EXEC_BUILDER_H_
1 change: 1 addition & 0 deletions python/tvm/relax/__init__.py
@@ -23,6 +23,7 @@
from . import op
from . import parser
from . import analysis
+from . import transform


# Expr
20 changes: 0 additions & 20 deletions python/tvm/relax/analysis/analysis.py
@@ -37,23 +37,3 @@ def post_order_visit(expr, fvisit):
        The visitor function to be applied.
    """
    return _ffi_api.post_order_visit(expr, fvisit)
-
-def fma_rewrite(expr):
-    """Perform fused multiply add rewriting in dataflow blocks.
-
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
-    """
-    return _ffi_api.fma_rewrite(expr)
-
-def explicit_memory_rewrite(expr):
-    """Perform explicit memory allocation for call_dps in dataflow blocks.
-
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
-    """
-    return _ffi_api.explicit_memory_rewrite(expr)
20 changes: 20 additions & 0 deletions python/tvm/relax/transform/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin
+"""Relax IR transform. """
+
+from .transform import *
18 changes: 18 additions & 0 deletions python/tvm/relax/transform/_ffi_api.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+import tvm._ffi
+
+tvm._ffi._init_api("relax.transform", __name__)
39 changes: 39 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -0,0 +1,39 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return
+# pylint: disable=unidiomatic-typecheck
+from . import _ffi_api
+
+def fma_rewrite(expr):
+    """Perform fused multiply add rewriting in dataflow blocks.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression.
+    """
+    return _ffi_api.fma_rewrite(expr)
+
+def explicit_memory_rewrite(expr):
+    """Perform explicit memory allocation for call_dps in dataflow blocks.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression.
+    """
+    return _ffi_api.explicit_memory_rewrite(expr)
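
The two wrappers above surface through the new tvm.relax.transform namespace (re-exported by the `from . import transform` line added to python/tvm/relax/__init__.py). A minimal usage sketch, assuming a Relax expression `f` built elsewhere; how `f` is obtained (for example from the relax parser) is illustrative and not part of this commit:

    # Hedged sketch: `f` is assumed to be a Relax expression built elsewhere.
    from tvm import relax

    fused = relax.transform.fma_rewrite(f)                # mul + add -> ewise_fma
    lowered = relax.transform.explicit_memory_rewrite(f)  # call_dps -> explicit alloc_tensor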
File renamed without changes.
115 changes: 1 addition & 114 deletions src/relax/expr_functor.cc → src/relax/ir/expr_functor.cc
@@ -18,7 +18,7 @@
*/

/*!
- * \file src/relay/expr_functor.cc
+ * \file src/relax/expr_functor.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
@@ -29,10 +29,6 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relax/type.h>
-#include <stack>
-#include <tvm/tir/op.h>
-
-#include "../relay/transforms/pattern_utils.h"

namespace tvm {
namespace relax {
@@ -415,114 +411,5 @@ Expr DataflowMutator::LookupVar(Var var) {
    return irbuilder_->LookupVar(var);
  }
}
-
-
-// ==================
-// EwiseFMARewriter
-// Example:
-// x0 = mul(a, b)
-// z0 = add(x0, c)
-// -->
-// z0 = ewise_fma(a, b, c)
-
-// Example 2:
-// Question: do we want to support this?
-// x0 = mul(a, add(k, b))
-// z0 = add(x0, c)
-// -->
-// lv0 = add(k, b)
-// z0 = ewise_fma(a, lv0, c)
-
-class EwiseFMARewriter : public DataflowMutator {
-  Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
-    static const Op& add_op = Op::Get("relax.add");
-    static const Op& multiply_op = Op::Get("relax.multiply");
-    static const Op& ewise_fma_op = Op::Get("relax.ewise_fma");
-
-    // TODO: shape & dtype check
-    const CallNode* op1 = binding->value.as<CallNode>();
-    if (op1 && (op1->op == add_op)) {
-      Expr value = LookupVar(Downcast<Var>(op1->args[0]));
-      const CallNode* op2 = value.as<CallNode>();
-      if (op2 && op2->op == multiply_op) {
-        Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {});
-        return ir_builder->Emit(binding->var, fma_call);
-      }
-    }
-    return ir_builder->Emit(binding);
-  }
-};
-
-Expr FMARewrite(const Expr& e) {
-  return EwiseFMARewriter().Mutate(e);
-}
-
-TVM_REGISTER_GLOBAL("relax.analysis.fma_rewrite")
-.set_body_typed([](Expr expr) {
-  return FMARewrite(expr);
-});
-
-// ==================
-// ExplicitMemMutator
-// Example:
-// y: Tensor[n, m] = rx.call_dps((n, m), op.identity, (x))
-// -->
-// lv0 = rx.call("relax.builtin.alloc_tensor", [n, m])
-// rx.call_packed(op.identity, x, lv0)
-
-class ExplicitMemMutator : public DataflowMutator {
-  Expr ComputeStorageSize(const Expr& shape, const Type& type) const {
-    DynTensorType tensor_type = Downcast<DynTensorType>(type);
-    DataType dtype = DataType(tensor_type->dtype);
-    // Question: what if the dtype of tensor_type is unknown?
-    // Symbolic/static shape case
-    if (auto* shape_expr = shape.as<ShapeExprNode>()) {
-      PrimExpr num = PrimExpr(dtype.bits()) * PrimExpr(dtype.lanes());
-      PrimExpr add = num + 7;
-      PrimExpr ret = 1;
-      for (PrimExpr dim : shape_expr->values) {
-        ret = ret * dim;
-      }
-      ret = ret * (add / PrimExpr(8));
-      return ShapeExpr({ret});
-    }
-    // Fully dynamic shape case
-    // will need to dedup with ComputeStorageInRelay when we upstream
-    Expr prod = relay::Prod(shape, Array<Integer>(nullptr), false, false);
-    Expr num = relay::MakeConstantScalar(DataType::Int(64), dtype.bits() * dtype.lanes());
-    Expr add = relay::Add(num, relay::MakeConstantScalar(DataType::Int(64), 7));
-    Expr div = relay::MakeConstantScalar(DataType::Int(64), 8);
-    Expr ret = relay::Multiply(prod, relay::Divide(add, div));
-    return ret;
-  }
-
-  Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
-    static const Op& call_dps_op = Op::Get("relax.call_dps");
-    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
-
-    const CallNode* op = binding->value.as<CallNode>();
-    if (op && op->op == call_dps_op) {
-      // switch current DataflowBlock to an impure BindingBlock
-      ir_builder->is_dataflow_ = false;
-      ShapeExpr output_shape = Downcast<ShapeExpr>(op->args[0]);
-      Type arg_type = Downcast<Tuple>(op->args[2])->fields[0]->checked_type();
-      Expr output_size = ComputeStorageSize(output_shape, arg_type);
-      Var tensor = ir_builder->Emit(Call(alloc_tensor_op, {op->args[0]}));
-      return ir_builder->Emit(binding->var, Call(op->args[1], {op->args[2], tensor}));
-    }
-    return ir_builder->Emit(binding);
-  }
-};
-
-Expr ExplicitMemRewrite(const Expr& e) {
-  return ExplicitMemMutator().Mutate(e);
-}
-
-TVM_REGISTER_GLOBAL("relax.analysis.explicit_memory_rewrite")
-.set_body_typed([](Expr expr) {
-  return ExplicitMemRewrite(expr);
-});
-
-
} // namespace relax
} // namespace tvm
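
For reference, the ComputeStorageSize helper that leaves this file computes bytes = prod(shape) * ((bits * lanes + 7) / 8), i.e. the element count times the element size rounded up to whole bytes. A hedged Python sketch of the same arithmetic; the function name and signature are illustrative, not part of the commit:

    def storage_size_bytes(shape, bits, lanes=1):
        # prod(shape) elements, each occupying (bits * lanes) bits rounded up to whole bytes
        num_elements = 1
        for dim in shape:
            num_elements *= dim
        return num_elements * ((bits * lanes + 7) // 8)

    # e.g. a (4, 8) float32 tensor: 32 elements * 4 bytes each = 128 bytes
    assert storage_size_bytes((4, 8), bits=32) == 128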
File renamed without changes.
File renamed without changes.
74 changes: 74 additions & 0 deletions src/relax/transform/fma_rewrite.cc
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/fma_rewrite.cc
+ * \brief
+ */
+#include <tvm/relax/expr_functor.h>
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// EwiseFMARewriter
+// Example:
+// x0 = mul(a, b)
+// z0 = add(x0, c)
+// -->
+// z0 = ewise_fma(a, b, c)
+
+// Example 2:
+// Question: do we want to support this?
+// x0 = mul(a, add(k, b))
+// z0 = add(x0, c)
+// -->
+// lv0 = add(k, b)
+// z0 = ewise_fma(a, lv0, c)
+
+class EwiseFMARewriter : public DataflowMutator {
+  Var VisitVarBinding(const VarBinding& binding, IRBuilder& ir_builder) override {
+    static const Op& add_op = Op::Get("relax.add");
+    static const Op& multiply_op = Op::Get("relax.multiply");
+    static const Op& ewise_fma_op = Op::Get("relax.ewise_fma");
+
+    // TODO: shape & dtype check
+    const CallNode* op1 = binding->value.as<CallNode>();
+    if (op1 && (op1->op == add_op)) {
+      Expr value = LookupVar(Downcast<Var>(op1->args[0]));
+      const CallNode* op2 = value.as<CallNode>();
+      if (op2 && op2->op == multiply_op) {
+        Call fma_call = Call(ewise_fma_op, {op2->args[0], op2->args[1], op1->args[1]}, {}, {});
+        return ir_builder->Emit(binding->var, fma_call);
+      }
+    }
+    return ir_builder->Emit(binding);
+  }
+};
+
+Expr FMARewrite(const Expr& e) {
+  return EwiseFMARewriter().Mutate(e);
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.fma_rewrite")
+.set_body_typed([](Expr expr) {
+  return FMARewrite(expr);
+});
+
+} // namespace relax
+} // namespace tvm
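
Because the pass is re-registered under the relax.transform prefix, it lines up with the tvm._ffi._init_api("relax.transform", __name__) call in the new _ffi_api.py, and can also be fetched straight from the global registry. A hedged sketch, assuming a Relax expression `expr` built elsewhere:

    # Hedged sketch: `expr` is assumed to come from, e.g., the relax parser.
    import tvm

    fma_rewrite = tvm.get_global_func("relax.transform.fma_rewrite")
    rewritten = fma_rewrite(expr)  # same entry point the Python wrapper forwards to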