Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCAN/Refactor] Refactor scan interface, enable fix point analysis. #47

Merged
merged 1 commit into from
Feb 20, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
/*!
* \brief Construct new tensors by scan over scan_axis.
*
* \param scan_axis The iteration representing the scan.
* \param init The intialize tensor of first K steps.
* \param update The update tensor indicated the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param name The optional name of the tensor.
*/
Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name = "scan");
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ enum AttachType : int {
kNone = 0,
kRoot = 1,
kInline = 2,
kScope = 3
kInlinedAlready = 3,
kScope = 4,
kScanUpdate = 5
};

/*! \brief IterVar type */
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
virtual Array<Expr> output_shape(size_t i) const = 0;

static constexpr const char* _type_key = "Operation";
};

// Implementations of inline functions
Expand Down
9 changes: 7 additions & 2 deletions python/tvm/addon/nvcc_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
import tempfile
import subprocess

def compile_source(code, target="cubin"):
def compile_source(code, target="cubin", options=None):
"""Compile cuda code with NVCC from env.

Parameters
----------
code : str
The cuda code.

target: str
target : str
The target format

options : str
The additional options

Return
------
cubin : bytearray
Expand All @@ -32,6 +35,8 @@ def compile_source(code, target="cubin"):
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-o", path_target]
if options:
cmd += options
cmd += [path_code]
args = ' '.join(cmd)

Expand Down
11 changes: 4 additions & 7 deletions python/tvm/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"):
return op_node.output(0)


def scan(axis, init, update, state_placeholder, name="scan"):
def scan(init, update, state_placeholder, name="scan"):
"""Construct new tensors by scanning over axis.

Parameters
----------
axis: IterVar
The scanning axis.

init: Tensor or list of Tensor
The initial condition of first init.shape[0] timestamps

Expand All @@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"):
# The following code is equivalent to numpy.cumsum
m = tvm.Var("m")
n = tvm.Var("n")
t = tvm.IterVar((1, m), name="t")
X = tvm.placeholder((m, n), name="X")
s_state = tvm.placeholder((m, n))
s_init = tvm.compute((1, n), lambda _, i: X[0, i])
s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
res = tvm.scan(t, s_init, s_update, s_state)
s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
res = tvm.scan(s_init, s_update, s_state)
"""
if isinstance(init, _tensor.Tensor):
init = [init]
Expand All @@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"):
state_placeholder = [state_placeholder]
if len(init) != len(update) or len(init) != len(state_placeholder):
raise ValueError("init, update, state_placeholder must have same length")
axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name)
op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
res = [op.output(i) for i in range(len(update))]
return (res[0] if len(res) == 1 else res)
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
# lowering
# normalize schedule first
sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);

} // namespace schedule
Expand Down
10 changes: 9 additions & 1 deletion src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,15 @@ IntSet Union(const Array<IntSet>& set) {
if (set.size() == 1) return set[0];
Interval x = set[0].cover_interval().as<IntervalSet>()->i;
for (size_t i = 1; i < set.size(); ++i) {
x.include(set[i].cover_interval().as<IntervalSet>()->i);
IntSet s = set[i].cover_interval();
const Interval& y = s.as<IntervalSet>()->i;
if (can_prove(x.max + 1 >= y.min)) {
x.max = y.max;
} else if (can_prove(y.max + 1 >= x.min)) {
x.min = y.min;
} else {
x.include(y);
}
}
return IntervalSet::make(x);
}
Expand Down
28 changes: 15 additions & 13 deletions src/lang/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ Operation PlaceholderOpNode::make(std::string name,
return Operation(n);
}



Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
return PlaceholderOpNode::make(name, shape, dtype).output(0);
}
Expand Down Expand Up @@ -162,24 +160,25 @@ Operation ScanOpNode::make(std::string name,
<< " scan_axis.dom.min + scan_axis.dom.extent";
CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
<< "The dimension of init need to match state_placeholder";
CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
<< "The update.ndim need to be state_placeholder.ndim - 1";
for (size_t k = 0; k < update[i].ndim(); ++k) {
CHECK(prove_equal(
update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k + 1;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
update[i]->shape[k], state_placeholder[i]->shape[k]));
if (k != 0) {
// setup spatial axis
std::ostringstream spatial_name;
spatial_name << name << ".out" << i << ".i" << k;
n->spatial_axis_.push_back(
IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
spatial_name.str()));
}
}
for (size_t k = 1; k < init[i].ndim(); ++k) {
CHECK(prove_equal(
init[i]->shape[k], state_placeholder[i]->shape[k]));
}
}

n->name = name;
n->scan_axis = axis;
n->init = init;
Expand All @@ -188,11 +187,14 @@ Operation ScanOpNode::make(std::string name,
return Operation(n);
}

Array<Tensor> scan(IterVar scan_axis,
Array<Tensor> init,
Array<Tensor> scan(Array<Tensor> init,
Array<Tensor> update,
Array<Tensor> state_placeholder,
std::string name) {
IterVar scan_axis(
Range::make_with_min_extent(
init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
name + ".idx");
Operation op = ScanOpNode::make(
name, scan_axis, init, update, state_placeholder);
Array<Tensor> res;
Expand Down
4 changes: 3 additions & 1 deletion src/pass/inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ Stmt Inline(Stmt stmt,
Expr body) {
CHECK_EQ(f->num_outputs(), 1)
<< "can only inline output single value operation";
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
Stmt ret = IRInline(f, args, body).Mutate(stmt);
if (ret.same_as(stmt)) return ret;
return ConvertSSA(ret);
}
} // namespace ir
} // namespace tvm
Loading