[LANG] Include buffer semantics, introduce pylint (#11)
* [LANG] Include buffer semantics, introduce pylint

* Refactor inline add support for buffer indexing

* fix doc
tqchen authored Jan 13, 2017
1 parent 69a80cc commit 0992873
Showing 26 changed files with 357 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -37,7 +37,7 @@ LIBHALIDEIR:
+ cd HalideIR; make lib/libHalideIR.a ; cd $(ROOTDIR)

lint:
-	python2 dmlc-core/scripts/lint.py tvm cpp include src
+	python2 dmlc-core/scripts/lint.py tvm all include src python

doc:
doxygen docs/Doxyfile
2 changes: 1 addition & 1 deletion dmlc-core
Submodule dmlc-core updated 1 file
+1 −1 include/dmlc/json.h
98 changes: 98 additions & 0 deletions include/tvm/buffer.h
@@ -0,0 +1,98 @@

/*!
* Copyright (c) 2016 by Contributors
* \file buffer.h
* \brief Symbolic n-dimensional array, to represent a memory buffer.
*/
#ifndef TVM_BUFFER_H_
#define TVM_BUFFER_H_

#include <tvm/container.h>
#include <string>

#include "./base.h"
#include "./expr.h"

namespace tvm {

// Internal node container Buffer
class BufferNode;
/*!
* \brief Buffer is a symbolic n-dimensional array structure.
*  It is a composition of primitive symbolic types,
*  used to specify the input/output structure of the program.
*/
class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief construct a new buffer based on shape and strides.
*/
explicit Buffer(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "buffer");
/*!
* \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer.
* \return The load expression.
*/
Expr MakeLoad(Array<Expr> index) const;
/*!
* \brief Generate a store statement.
* \param index The index to the buffer.
* \param value The value to be stored.
* \return The store statement.
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;
};

/*! \brief Node to represent a buffer */
class BufferNode : public Node {
public:
/*! \brief optional name of the buffer */
std::string name;
/*! \brief The pointer to the head of the data */
Var ptr;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
* \brief The strides of each dimension
* This can be an empty array, indicating the array is contiguous
*/
Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
// Maybe need more information (alignment) later
/*! \brief constructor */
BufferNode() {}

void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("ptr", &ptr);
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
}

static Buffer make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype);

static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
};

inline const BufferNode* Buffer::operator->() const {
return static_cast<const BufferNode*>(node_.get());
}

} // namespace tvm
#endif // TVM_BUFFER_H_
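
Editor's note: the Buffer() helper added to python/tvm/function.py later in this commit constructs this node from Python. A minimal sketch, assuming function.py's helpers are re-exported at package level (as the wildcard-import pylint note in tvm/__init__.py suggests):

    import tvm

    # Shape extents are symbolic expressions, not concrete integers.
    n = tvm.Var("n")
    m = tvm.Var("m")

    # strides defaults to an empty array, i.e. contiguous layout (see
    # BufferNode::strides); ptr defaults to a freshly created handle Var.
    A = tvm.Buffer((n, m), dtype="float32", name="A")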
19 changes: 16 additions & 3 deletions include/tvm/ir_pass.h
@@ -13,6 +13,7 @@
#include <unordered_map>
#include <vector>
#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"

namespace tvm {
@@ -56,10 +57,22 @@ Stmt ConvertSSA(Stmt stmt);
*
* \note All the passes in this file use SSA form and output SSA form.
*/
-Stmt Inline(FunctionRef f,
+Stmt Inline(Stmt stmt,
+            FunctionRef f,
             Array<Var> args,
-            Expr body,
-            Stmt stmt);
+            Expr body);


/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
*
* \param stmt The stmt to be transformed.
* \param extern_buffer Map that specifies the external
*  buffer assignment of inputs and outputs.
*/
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);

} // namespace ir
} // namespace tvm
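
Editor's note: a hedged sketch of how this pass could be driven from Python once registered — the _pass_ prefix routes it into tvm.ir_pass via _init_function_module (shown further below); the binding name and the dict-to-Map conversion are assumptions:

    import tvm
    from tvm import ir_pass

    n = tvm.Var("n")
    A = tvm.placeholder((n,), name="A")
    Abuf = tvm.Buffer((n,), dtype="float32", name="Abuf")
    # Given a lowered statement `stmt` reading/writing A, bind A to Abuf and
    # rewrite its multi-dimensional accesses into flat Load/Store:
    # flat = ir_pass.StorageFlatten(stmt, {A: Abuf})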
1 change: 1 addition & 0 deletions python/tvm/__init__.py
@@ -1,3 +1,4 @@
# pylint: disable=redefined-builtin, wildcard-import
"""C++ backend related python scripts"""
from __future__ import absolute_import as _abs
from ._ctypes._api import register_node
3 changes: 1 addition & 2 deletions python/tvm/_base.py
@@ -1,10 +1,9 @@
# coding: utf-8
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, no-member
""" ctypes library of nnvm and helper functions """
from __future__ import absolute_import

import sys
import os
import ctypes
import numpy as np
from . import libinfo
15 changes: 9 additions & 6 deletions python/tvm/_ctypes/_api.py
@@ -1,5 +1,6 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-arguments, too-many-lines
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs

@@ -14,6 +15,7 @@
from .. import _function_internal

class ArgVariant(ctypes.Union):
"""ArgVariant in C API"""
_fields_ = [("v_long", ctypes.c_long),
("v_double", ctypes.c_double),
("v_str", ctypes.c_char_p),
@@ -30,8 +32,8 @@ class ArgVariant(ctypes.Union):

def _return_node(x):
handle = x.v_handle
-if not isinstance(handle, ctypes.c_void_p):
-    handle = ctypes.c_void_p(handle)
+if not isinstance(handle, NodeHandle):
+    handle = NodeHandle(handle)
ret_val = ArgVariant()
ret_typeid = ctypes.c_int()
ret_success = ctypes.c_int()
Expand All @@ -47,7 +49,7 @@ def _return_node(x):
kLong: lambda x: x.v_long,
kDouble: lambda x: x.v_double,
kStr: lambda x: py_str(x.v_str),
-kNodeHandle: lambda x: _return_node(x)
+kNodeHandle: _return_node
}
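
Editor's note: ArgVariant is a tagged union — every field aliases the same storage, and RET_SWITCH picks the accessor from the separately returned type id. A self-contained illustration of that aliasing with plain ctypes:

    import ctypes

    class Variant(ctypes.Union):
        """All fields share one storage location, as in ArgVariant."""
        _fields_ = [("v_long", ctypes.c_long),
                    ("v_double", ctypes.c_double)]

    v = Variant()
    v.v_long = 42
    print(v.v_long)   # 42
    # Reading v.v_double now would reinterpret those same bytes, which is
    # why the type id must be checked before choosing a field.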

class SliceBase(object):
@@ -251,6 +253,7 @@ def register_node(type_key=None):
"""
if isinstance(type_key, str):
def register(cls):
"""internal register function"""
NODE_TYPE[type_key] = cls
return cls
return register
@@ -273,9 +276,9 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace],
"_schedule_" : sys.modules["%s.schedule" % root_namespace]
"_make_": sys.modules["%s.make" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
}

for name in op_names:
12 changes: 12 additions & 0 deletions python/tvm/collections.py
@@ -1,3 +1,4 @@
# pylint: disable=protected-access, no-member
"""Collection structure in the high level DSL."""
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
@@ -6,6 +7,7 @@

@register_node
class Array(NodeBase):
"""Array container of TVM"""
def __getitem__(self, i):
if i >= len(self):
raise IndexError("array index out of range")
@@ -19,13 +21,15 @@ def __repr__(self):

@register_node
class Map(NodeBase):
"""Map container of TVM"""
def __getitem__(self, k):
return _function_internal._MapGetItem(self, k)

def __contains__(self, k):
return _function_internal._MapCount(self, k) != 0

def items(self):
"""Get the items from the map"""
akvs = _function_internal._MapItems(self)
return [(akvs[i], akvs[i+1]) for i in range(0, len(akvs), 2)]

@@ -38,9 +42,17 @@ def __repr__(self):

@register_node
class Range(NodeBase):
"""Represent range in TVM"""
pass


@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable."""
pass


@register_node
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
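
Editor's note: a short usage sketch of these wrappers, assuming convert (imported from _ctypes._api in function.py) is exposed at package level to build an Array from a Python list:

    import tvm

    arr = tvm.convert([1, 2, 3])   # yields a tvm.collections.Array
    print(len(arr), arr[0])        # __getitem__ range-checks the index

    # Map instances usually come back from the C++ side; items() unpacks the
    # flat key/value sequence returned by _MapItems into Python tuples.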
3 changes: 2 additions & 1 deletion python/tvm/expr.py
@@ -1,3 +1,4 @@
# pylint: disable=protected-access, no-member, missing-docstring
from __future__ import absolute_import as _abs
from ._ctypes._api import NodeBase, register_node
from . import make as _make
@@ -174,7 +175,7 @@ class Call(Expr):
Halide = 3
Intrinsic = 4
PureIntrinsic = 5
-pass


@register_node
class Let(Expr):
57 changes: 49 additions & 8 deletions python/tvm/function.py
@@ -1,5 +1,8 @@
# pylint: disable=protected-access, no-member, invalid-name
# pylint: disable=redefined-builtin, undefined-variable
"""Functions defined in TVM."""
from __future__ import absolute_import as _abs
-from numbers import Number as _Number, Integral as _Integral
+from numbers import Integral as _Integral
from ._ctypes._api import _init_function_module, convert
from . import _function_internal
from . import make as _make
@@ -8,6 +11,7 @@

int32 = "int32"
float32 = "float32"
handle = "handle"

def const(value, dtype=None):
"""construct a constant"""
@@ -65,7 +69,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)


-def placeholder(shape, dtype = None, name="placeholder"):
+def placeholder(shape, dtype=None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
@@ -84,6 +88,7 @@ def placeholder(shape, dtype = None, name="placeholder"):
tensor: tensor.Tensor
The created tensor
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
return _function_internal._Placeholder(
shape, dtype, name)
@@ -111,8 +116,7 @@ def compute(shape, fcompute, name="compute"):
tensor: tensor.Tensor
The created tensor
"""
-if isinstance(shape, _expr.Expr):
-    shape = (shape, )
+shape = (shape,) if isinstance(shape, _expr.Expr) else shape

ndim = len(shape)
arg_names = fcompute.__code__.co_varnames
@@ -125,7 +129,44 @@
op_node = _function_internal._ComputeOp(
name, dim_var, body)
return _function_internal._Tensor(
-    shape, name, body.dtype, op_node, 0)
+    shape, body.dtype, op_node, 0)
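
Editor's note: placeholder and compute only declare the computation symbolically; nothing is evaluated until scheduling and lowering. A minimal sketch (call-style element access A(i) on tensors is an assumption about Tensor's accessor):

    import tvm

    n = tvm.Var("n")
    A = tvm.placeholder((n,), name="A")   # a bare Expr shape is wrapped to (shape,)
    # fcompute receives one index variable per entry of `shape`.
    B = tvm.compute((n,), lambda i: A(i) + 1, name="B")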


def Buffer(shape, dtype=None,
name="buffer", ptr=None,
strides=None):
"""Create a new buffer
Parameters
----------
shape : tuple of Expr
The shape of the buffer.
dtype : str, optional
The data type of the buffer.
name : str, optional
The name of the buffer.
ptr : Var, optional
The data pointer in the buffer.
strides : array of Expr, optional
The strides of the buffer.
Returns
-------
buffer : Buffer
The created buffer
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if ptr is None:
ptr = Var(name, "handle")

return _function_internal._Buffer(
name, ptr, shape, strides, dtype)


def IterVar(dom, name='iter', thread_tag=''):
@@ -170,7 +211,7 @@ def sum(expr, rdom):
The reduction domain.
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Add", expr, rdom)
x = _make.Reduce("Add", expr, rdom)
return x


@@ -186,7 +227,7 @@ def min(expr, rdom):
The reduction domain.
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Min", expr, rdom)
x = _make.Reduce("Min", expr, rdom)
return x


@@ -202,7 +243,7 @@ def max(expr, rdom):
The reduction domainx
"""
rdom = rdom if isinstance(rdom, list) else [rdom]
x = _make.Reduce("Max", expr, rdom)
x = _make.Reduce("Max", expr, rdom)
return x
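
Editor's note: all three reducers follow one pattern — a single IterVar is promoted to a one-element list and wrapped with the expression in a Reduce node. A hedged sketch of a row sum (passing the domain to IterVar as a (begin, end) pair is an assumption):

    import tvm

    n = tvm.Var("n")
    m = tvm.Var("m")
    A = tvm.placeholder((n, m), name="A")
    k = tvm.IterVar((0, m), name="k")   # reduction axis over [0, m)
    # Sum each row of A over k; `rdom=k` is wrapped into [k] internally.
    B = tvm.compute((n,), lambda i: tvm.sum(A(i, k), rdom=k), name="B")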

