Skip to content

Commit

Permalink
[AutoTVM] fix argument type for curve feature (#3004)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Apr 11, 2019
1 parent 5178506 commit 5a27632
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/autotvm/touch_extractor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,10 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten")
TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Stmt stmt = args[0];
bool take_log = args[1];
int sample_n = args[1];
std::vector<float> ret_feature;

GetCurveSampleFeatureFlatten(stmt, take_log, &ret_feature);
GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);

TVMByteArray arr;
arr.size = sizeof(float) * ret_feature.size();
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_autotvm_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ def test_iter_feature_gemm():
assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:])


def test_curve_feature_gemm():
N = 128

k = tvm.reduce_axis((0, N), 'k')
A = tvm.placeholder((N, N), name='A')
B = tvm.placeholder((N, N), name='B')
C = tvm.compute(
A.shape,
lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
name='C')

s = tvm.create_schedule(C.op)

feas = feature.get_buffer_curve_sample_flatten(s, [A, B, C], sample_n=30)
# sample_n * #buffers * #curves * 2 numbers per curve
assert len(feas) == 30 * 3 * 4 * 2

def test_feature_shape():
"""test the dimensions of flatten feature are the same"""

Expand Down Expand Up @@ -112,4 +129,6 @@ def get_gemm_feature(target):

if __name__ == "__main__":
test_iter_feature_gemm()
test_curve_feature_gemm()
test_feature_shape()

0 comments on commit 5a27632

Please sign in to comment.