Skip to content

Commit

Permalink
[MODULE/REFACTOR] Introduce Module for AOT and runtime linking.
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 24, 2017
1 parent 8f240ee commit 2a9a0a1
Show file tree
Hide file tree
Showing 65 changed files with 2,480 additions and 1,149 deletions.
18 changes: 13 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ include $(config)
# specify tensor path
.PHONY: clean all test doc

all: lib/libtvm.a lib/libtvm.so
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a

LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a

SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
RUNTIME_DEP = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))

ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)

export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
Expand Down Expand Up @@ -77,15 +82,18 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@


lib/libtvm.a: $(ALL_DEP)
lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

lib/libtvm.so: $(ALL_DEP)
lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)

$(LIB_HALIDE_IR): LIBHALIDEIR

LIBHALIDEIR:
Expand Down
74 changes: 8 additions & 66 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
@@ -1,85 +1,27 @@
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2017 by Contributors
* \file api_registry.h
* \brief This file defines the TVM API registry.
*
* The API registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_API(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
* \brief This files include necessary headers to
* be used to register an global API function.
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_

#include <dmlc/base.h>
#include <string>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./packed_func_ext.h"

namespace tvm {

/*! \brief Utility to register API. */
class APIRegistry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)

private:
/*! \brief name of the function */
std::string name_;
};
#include "./runtime/registry.h"

/*!
* \brief Get API function by name.
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
*
* \param name The name of the function.
* \return the corresponding API function.
* \note It is really PackedFunc::GetGlobal under the hood.
*/
inline PackedFunc GetAPIFunc(const std::string& name) {
return PackedFunc::GetGlobal(name);
}

#define _TVM_REGISTER_VAR_DEF_ \
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_

/*!
* \brief Register API function globally.
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) \
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
::tvm::APIRegistry::__REGISTER__(#OpName)
} // namespace tvm
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)

#endif // TVM_API_REGISTRY_H_
42 changes: 11 additions & 31 deletions include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
#include "./base.h"
#include "./expr.h"
#include "./lowered_func.h"
#include "./api_registry.h"
#include "./runtime/packed_func.h"


namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
Expand All @@ -22,41 +22,21 @@ using runtime::TVMArgs;
using runtime::TVMRetValue;

/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be build
* \param device_funcs The additional device functions
* \return A packed function representing the func.
*/
PackedFunc BuildStackVM(
LoweredFunc func,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);

/*!
* \brief Build a LLVM VM function, this is still beta
* \param func The LoweredFunc to be build
* \return A packed function representing the func.
*/
PackedFunc BuildLLVM(LoweredFunc func);

/*!
* \brief Build a CUDA function with NVRTC
* \brief Build a module from array of lowered function.
* \param funcs The functions to be built.
* \param target The target to be built.
* \return The builded module.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
* \note Calls global API function "_codegen_build_" + target
*/
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode);
runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target);

/*!
* \brief Build a OpenCL function.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
* \param target The target to be queried.
* \return Whether target is enabled.
*/
PackedFunc BuildOpenCL(Array<LoweredFunc> fsplits, std::string host_mode);
bool TargetEnabled(const std::string& target);

} // namespace codegen
} // namespace tvm
Expand Down
19 changes: 5 additions & 14 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,14 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* int tvm_call_global(name, TVMValue* args) {
* PackedFunc f = PackedFunc::GetGlobal(name);
* f (args, type_code_of(args), len(args));
* int tvm_call_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_global = "tvm_call_global";
/*!
* \brief See pesudo code
*
* int tvm_call_device(name, TVMValue* args) {
* PackedFunc df = CodeGenEnv->GetDevice(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_device = "tvm_call_device";
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
* \brief See pesudo code
*
Expand Down
7 changes: 5 additions & 2 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,15 @@ Stmt LiftAllocate(Stmt stmt);
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
Expand All @@ -167,7 +170,7 @@ Stmt LiftAllocate(Stmt stmt);
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
int num_unpacked_args);

/*!
* \brief Count number of undefined vars in f.
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/lowered_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ class LoweredFuncNode : public FunctionBaseNode {
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief Whether this function is packed function */
bool is_packed_func{true};
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
Expand All @@ -88,6 +90,7 @@ class LoweredFuncNode : public FunctionBaseNode {
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("body", &body);
}

Expand Down
Loading

0 comments on commit 2a9a0a1

Please sign in to comment.