Skip to content

Commit

Permalink
[COMPILER] Initial compiler infra (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 7be605f commit f773edc
Show file tree
Hide file tree
Showing 39 changed files with 2,148 additions and 119 deletions.
11 changes: 7 additions & 4 deletions nnvm/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ include $(config)

export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -Iinclude -fPIC
CFLAGS += -Itvm/include -Itvm/dlpack/include
CFLAGS += -Itvm/include -Itvm/dlpack/include -Itvm/HalideIR/src

ifdef DMLC_CORE_PATH
CFLAGS += -I$(DMLC_CORE_PATH)/include
Expand All @@ -38,7 +38,7 @@ PLUGIN_OBJ =
include $(NNVM_PLUGINS)

# specify tensor path
.PHONY: clean all test lint doc cython cython3 cyclean
.PHONY: clean all test lint pylint doc cython cython3 cyclean

UNAME_S := $(shell uname -s)

Expand All @@ -55,7 +55,7 @@ endif
all: lib/libnnvm.a lib/libnnvm_top.$(SHARED_LIBRARY_SUFFIX) lib/libnnvm_top_runtime.$(SHARED_LIBRARY_SUFFIX)

SRC = $(wildcard src/*.cc src/c_api/*.cc src/core/*.cc src/pass/*.cc)
SRC_TOP = $(wildcard src/top/*.cc, src/top/*/*.cc src/runtime/*.cc)
SRC_TOP = $(wildcard src/top/*/*.cc src/runtime/*.cc src/compiler/*.cc src/compiler/*/*.cc)
ALL_OBJ = $(patsubst %.cc, build/%.o, $(SRC))
TOP_OBJ = $(patsubst %.cc, build/%.o, $(SRC_TOP))
ALL_DEP = $(ALL_OBJ)
Expand Down Expand Up @@ -90,9 +90,12 @@ cython3:
cyclean:
rm -rf python/nnvm/*/*.so python/nnvm/*/*.dylib python/nnvm/*/*.cpp

lint:
lint: pylint
python dmlc-core/scripts/lint.py nnvm cpp include src

pylint:
pylint python/nnvm --rcfile=$(ROOTDIR)/tests/lint/pylintrc

doc:
doxygen docs/Doxyfile

Expand Down
28 changes: 28 additions & 0 deletions nnvm/include/nnvm/compiler/contrib_op_param.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*!
* Copyright (c) 2017 by Contributors
* \file contrib_op_param.h
* \brief Additional parameters for compiler optimized operators.
*/
#ifndef NNVM_COMPILER_CONTRIB_OP_PARAM_H_
#define NNVM_COMPILER_CONTRIB_OP_PARAM_H_

#include <dmlc/parameter.h>
#include <string>

namespace nnvm {
namespace compiler {

/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
  // Layout of the incoming tensor (free-form layout string).
  std::string src_layout;
  // Layout the operator converts the tensor into.
  std::string dst_layout;

  DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
    // .describe() feeds the auto-generated operator documentation;
    // both fields are required (no default), as before.
    DMLC_DECLARE_FIELD(src_layout)
        .describe("Source layout of the tensor.");
    DMLC_DECLARE_FIELD(dst_layout)
        .describe("Destination layout of the tensor.");
  }
};
} // namespace compiler
} // namespace nnvm

#endif // NNVM_COMPILER_CONTRIB_OP_PARAM_H_
89 changes: 89 additions & 0 deletions nnvm/include/nnvm/compiler/op_attr_types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*!
* Copyright (c) 2017 by Contributors
* \file op_attr_types.h
 * \brief Operator attribute types used by the NNVM compiler.
*/
#ifndef NNVM_COMPILER_OP_ATTR_TYPES_H_
#define NNVM_COMPILER_OP_ATTR_TYPES_H_

#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph.h>
#include <vector>
#include <string>

namespace nnvm {
namespace compiler {

using ::tvm::Array;
using ::tvm::Tensor;
using ::tvm::Schedule;

/*! \brief Operator pattern used in graph fusion */
enum OpPatternKind : int {
  // Elementwise operation
  kElemWise = 0,
  // Broadcast operation
  kBroadcast = 1,
  // Complex operation, can fuse bcast in input/outputs
  // but cannot chain another complex op
  kComplex = 2,
  // Extern operation, cannot fuse anything.
  kExtern = 3
};

/*! \brief The operator pattern attribute; holds an OpPatternKind value. */
using TOpPattern = int;

/*!
 * \brief Computation description interface.
 * \param attrs The attribute of the node.
 * \param inputs The input tensors (placeholders).
 * \return The output description of the tensor.
 */
using FTVMCompute = std::function<
  Array<Tensor>
  (const NodeAttrs& attrs, const Array<Tensor>& inputs)>;

/*!
 * \brief Build the computation schedule for
 *  op whose root is at current op.
 * \param attrs The attribute of the node.
 * \param outs The output tensors.
 * \param target The build target.
 * \return schedule The computation schedule.
 */
using FTVMSchedule = std::function<
  Schedule(const NodeAttrs& attrs,
           const Array<Tensor>& outs,
           const std::string& target)>;

/*!
 * \brief Layout information about an entry, expressed as a free-form
 *  layout string (e.g. "NCHW" — NOTE(review): exact vocabulary is
 *  defined by the passes that consume it; confirm there).
 */
using TLayoutInfo = std::string;

/*!
 * \brief The producer-consumer function of node layout.
 * \param attrs The attribute of the node.
 * \param ilayouts The input layouts that the node requests.
 * \param olayouts The output layouts that the node produces.
 * \return bool The success flag.
 */
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
                                              std::vector<TLayoutInfo> *ilayouts,
                                              std::vector<TLayoutInfo> *olayouts)>;

/*!
 * \brief Transform from normal operator to vectorized operator.
 * \param node The source node.
 * \return Transformed vectorized op.
 */
using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node* node)>;

} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_OP_ATTR_TYPES_H_
57 changes: 57 additions & 0 deletions nnvm/include/nnvm/compiler/packed_func_ext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func_ext.h
 * \brief Extension to enable packed functions for nnvm types
*/
#ifndef NNVM_COMPILER_PACKED_FUNC_EXT_H_
#define NNVM_COMPILER_PACKED_FUNC_EXT_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <string>
#include <unordered_map>

namespace nnvm {
namespace compiler {

using tvm::runtime::PackedFunc;

using AttrDict = std::unordered_map<std::string, std::string>;

/*!
 * \brief Look up a PackedFunc in the global registry,
 *  failing loudly when no function with that name exists.
 * \param name The registered name of the function.
 * \return Reference to the registered PackedFunc.
 */
inline const PackedFunc& GetPackedFunc(const std::string& name) {
  const PackedFunc* func = tvm::runtime::Registry::Get(name);
  CHECK(func != nullptr) << "Cannot find function " << name << " in registry";
  return *func;
}
} // namespace compiler
} // namespace nnvm

// Enable the graph and symbol object exchange.
namespace tvm {
namespace runtime {

// Extension type codes let Symbol/Graph/AttrDict travel through TVM's
// PackedFunc calling convention.  The codes below must stay unique
// across every registered TVM extension type and stable across the
// C++/Python boundary — NOTE(review): the matching Python-side
// registration is not visible here; confirm it uses the same codes.
template<>
struct extension_class_info<nnvm::Symbol> {
  static const int code = 16;
};

template<>
struct extension_class_info<nnvm::Graph> {
  static const int code = 17;
};

template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
  static const int code = 18;
};
} // namespace runtime
} // namespace tvm
#endif // NNVM_COMPILER_PACKED_FUNC_EXT_H_
12 changes: 12 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>;

/*!
 * \brief Get attribute dictionary from node.
 *
 * \param attrs The attributes of the node.
 * \return The attribute dict.
 * \note Register under "FGetAttrDict"
 */
using FGetAttrDict = std::function<
  std::unordered_map<std::string, std::string>
  (const NodeAttrs& attrs)>;

/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
Expand Down
2 changes: 1 addition & 1 deletion nnvm/include/nnvm/top/README
Original file line number Diff line number Diff line change
@@ -1 +1 @@
NNVM Core Operator Specs
NNVM Core Operator and Compiler
11 changes: 7 additions & 4 deletions nnvm/python/nnvm/_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# coding: utf-8
# pylint: disable=invalid-name
# pylint: disable=invalid-name, unused-import
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import

import sys
import os
import ctypes
import numpy as np
from . import libinfo

__all__ = ['NNNetError']
try:
import tvm
except ImportError:
pass

#----------------------------
# library loading
#----------------------------
Expand Down Expand Up @@ -181,7 +184,7 @@ def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True)
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
if arg_descs[i]:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
Expand Down
5 changes: 3 additions & 2 deletions nnvm/python/nnvm/_ctypes/symbol.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines,
# pylint: disable=len-as-condition, consider-iterating-dictionary
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs

import copy
import ctypes
import sys
from .._base import _LIB
from .._base import c_array, c_str, nn_uint, py_str, string_types
from .._base import c_array, c_str, nn_uint, py_str
from .._base import SymbolHandle, OpHandle
from .._base import check_call, ctypes2docstring
from ..name import NameManager
Expand Down
18 changes: 18 additions & 0 deletions nnvm/python/nnvm/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Namespace for NNVM-TVM compiler toolchain"""
from __future__ import absolute_import

import tvm

from . import build_module
from . build_module import build

from .. import symbol as _symbol
from .. import graph as _graph

from .registry import OpPattern
from .registry import register_compute, register_schedule, register_pattern

from .. import top as _top

tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
tvm.register_extension(_graph.Graph, _graph.Graph)
79 changes: 79 additions & 0 deletions nnvm/python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# pylint: disable=invalid-name
"""Namespace for building operators."""
from __future__ import absolute_import as _abs

import tvm
from . import graph_attr
from .. import graph as _graph

@tvm.register_func("nnvm.compiler.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
return f if isinstance(
f, (tvm.container.Array, tuple, list)) else [f]


@tvm.register_func("nnvm.compiler.build_target")
def _build(funcs, target):
return tvm.build(funcs, target=target)


_move_module = tvm.get_global_func("nnvm.compiler._move_module")


def optimize(graph):
    """Perform graph optimization

    Parameters
    ----------
    graph : Graph
        The graph to be used in lowering.

    Returns
    -------
    graph : Graph
        The optimized execution graph.
    """
    # No optimization passes are wired in yet; the graph is returned
    # unchanged so callers can already depend on this entry point.
    return graph


def build(graph, target, shape, dtype="float32"):
"""Build graph into runtime library.
This is the final step of graph compilation.
Parameters
----------
graph : Graph
The graph to be used in lowering
target : str
The build target
shape : dict of str to tuple
The input shape to the graph
dtype : str or dict of str to str
The input types to the graph
Returns
-------
graph : Graph
The final execution graph.
libmod : tvm.Module
The modue that comes with the execution graph
"""
if not isinstance(target, str):
raise TypeError("require target to be str")
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")

graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
graph = graph_attr.set_shape(graph, shape)
graph = graph_attr.set_dtype(graph, dtype)
graph._set_json_attr("target", target, "str")
graph = graph.apply("InferShape").apply("InferType")
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
libmod = _move_module(graph)
return graph, libmod
Loading

0 comments on commit f773edc

Please sign in to comment.