Bug fix & Add python serialization API (apache#10)
* Delete C++ UT hack since Python is ready

* Add ndarray.non_empty

* Update Serialization python API
jcf94 authored and merrymercy committed Jun 20, 2020
1 parent 6b21dc6 commit e52135f
Showing 14 changed files with 408 additions and 172 deletions.
23 changes: 23 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
@@ -384,6 +384,29 @@ TVM_DLL int TVMFuncListGlobalNames(int* out_size, const char*** out_array);
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_bits,
int dtype_lanes, int device_type, int device_id, TVMArrayHandle* out);

/*!
* \brief Allocate an nd-array's memory of non-empty values,
* including space of shape, of given spec.
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimensions of the array.
* \param dtype_code The type code of the dtype
* \param dtype_bits The number of bits of dtype
* \param dtype_lanes The number of lanes in the dtype.
* \param device_type The device type of context
* \param device_id The device id of context.
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAllocNonEmpty(const tvm_index_t* shape,
int ndim,
int dtype_code,
int dtype_bits,
int dtype_lanes,
int device_type,
int device_id,
TVMArrayHandle* out);

/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
12 changes: 11 additions & 1 deletion include/tvm/runtime/ndarray.h
@@ -138,7 +138,17 @@ class NDArray : public ObjectRef {
* \param ctx The context of the Array.
* \return The created Array
*/
TVM_DLL static NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
TVM_DLL static NDArray Empty(std::vector<int64_t> shape,
DLDataType dtype, DLContext ctx);
/*!
* \brief Create an NDArray with non-empty values.
* \param shape The shape of the new array.
* \param dtype The data type of the new array.
* \param ctx The context of the Array.
* \return The created Array
*/
TVM_DLL static NDArray NonEmpty(std::vector<int64_t> shape,
DLDataType dtype, DLContext ctx);
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
1 change: 1 addition & 0 deletions python/tvm/ansor/__init__.py
@@ -22,3 +22,4 @@
from .task import auto_schedule
from .measure import MeasureInput, LocalBuilder, LocalRunner
from .cost_model import RandomModel
from .serialization import LogToFile, LogReader, best_measure_pair_in_file
12 changes: 12 additions & 0 deletions python/tvm/ansor/compute_dag.py
@@ -78,3 +78,15 @@ def print_python_code_from_state(self, state):
str : Str
"""
return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state)

def infer_bound_from_state(self, state):
"""
Parameters
----------
state : State
Returns
-------
state : State
"""
return _ffi_api.ComputeDAGInferBoundFromState(self, state)
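
The new wrapper forwards to the ComputeDAGInferBoundFromState packed function registered in src/ansor/compute_dag.cc further down in this diff. A minimal sketch of how it composes with the existing printer; how the dag and state objects are built is assumed here and is not part of this change:

from tvm import ansor

def debug_print_with_bounds(dag, state):
    # `dag` is an ansor.ComputeDAG and `state` a schedule state produced by the
    # search; their construction is assumed, it lies outside this diff.
    bounded = dag.infer_bound_from_state(state)       # fill in concrete loop extents
    return dag.print_python_code_from_state(bounded)  # reuse the existing printer
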
8 changes: 6 additions & 2 deletions python/tvm/ansor/measure.py
@@ -44,6 +44,10 @@
logger = logging.getLogger('ansor')


@tvm._ffi.register_object("ansor.MeasureCallback")
class MeasureCallback(Object):
""" The base class of measurement callback functions. """

@tvm._ffi.register_object("ansor.MeasureInput")
class MeasureInput(Object):
"""
@@ -332,7 +336,7 @@ def timed_func():

if error_no == 0:
try:
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
build_res.args]
ctx.sync()

@@ -390,7 +394,7 @@ def timed_func(inp, build_res):

if error_no == 0:
try:
args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
args = [ndarray.non_empty(get_const_tuple(x.shape), x.dtype, ctx) for x in
build_res.args]
ctx.sync()

98 changes: 98 additions & 0 deletions python/tvm/ansor/serialization.py
@@ -0,0 +1,98 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import
""" ... """
import numpy as np

import tvm._ffi
from tvm.runtime import Object

from .measure import MeasureCallback, MeasureErrorNo

from . import _ffi_api


@tvm._ffi.register_object("ansor.LogToFile")
class LogToFile(MeasureCallback):
"""
Parameters
----------
filename : Str
"""

def __init__(self, filename="ansor_tuning.json"):
self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)


@tvm._ffi.register_object("ansor.LogReader")
class LogReader(Object):
""" Reader of the measurement log file.

Parameters
----------
filename : str
"""
def __init__(self, filename="ansor_tuning.json"):
self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)

def read_lines(self, max_size=-1, skip_size=0):
""" Read up to max_size (MeasureInput, MeasureResult) records, skipping the first skip_size; max_size=-1 means read all. """
inputs, results = _ffi_api.LogReaderReadLines(
self, max_size, skip_size)
return inputs, results

def __iter__(self):
while True:
ret = _ffi_api.LogReaderReadNext(self)
if ret is None or not len(ret):
break
yield ret[0], ret[1] # (input, result)


def best_measure_pair_in_file(filename, workload_key=None, target=None):
""" Return best results form log file
Parameters
----------
filename : Str
workload_key : Str
target : Str
Returns
-------
inp : MeasureInput
res : MeasureResult
"""
log_reader = LogReader(filename)
best_cost = 1e30
best_inp = None
best_res = None

for inp, res in log_reader:
if res.error_no != MeasureErrorNo.NO_ERROR:
continue
if workload_key and inp.task.workload_key != workload_key:
continue
if target and inp.task.target.target_name != target.target_name:
continue

costs = []
for value in res.costs:
costs.append(value.value)
cost = np.mean(costs)
if cost < best_cost:
best_cost = cost
best_inp = inp
best_res = res

return best_inp, best_res
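
Taken together with the __init__.py export above, the new serialization API can be exercised roughly as follows. This is a sketch only: it assumes ansor_tuning.json already exists from an earlier tuning run, and the tuner-side plumbing that would consume the LogToFile callback is not part of this diff.

from tvm import ansor

log_file = "ansor_tuning.json"

# Callback to be handed to the tuner so that every measurement is appended to the log.
log_to_file = ansor.LogToFile(log_file)

# Replay the log: iterate over (MeasureInput, MeasureResult) pairs ...
reader = ansor.LogReader(log_file)
for inp, res in reader:
    print(inp.task.workload_key, [v.value for v in res.costs])

# ... or read a batch of records at once ...
inputs, results = reader.read_lines(max_size=100)

# ... or directly pick the best record in the file.
best_inp, best_res = ansor.best_measure_pair_in_file(log_file)
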
33 changes: 33 additions & 0 deletions python/tvm/runtime/ndarray.py
@@ -279,6 +279,39 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
return _make_array(handle, False, False)


def non_empty(shape, dtype="float32", ctx=context(1, 0)):
"""Create an non-empty array given shape and device
Parameters
----------
shape : tuple of int
The shape of the array
dtype : type or str
The data type of the array.
ctx : TVMContext
The context of the array
Returns
-------
arr : tvm.nd.NDArray
The array tvm supported.
"""
shape = c_array(tvm_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = TVMArrayHandle()
dtype = DataType(dtype)
check_call(_LIB.TVMArrayAllocNonEmpty(
shape, ndim,
ctypes.c_int(dtype.type_code),
ctypes.c_int(dtype.bits),
ctypes.c_int(dtype.lanes),
ctx.device_type,
ctx.device_id,
ctypes.byref(handle)))
return _make_array(handle, False, False)

def from_dlpack(dltensor):
"""Produce an array from a DLPack tensor without memory copy.
Retrieves the underlying DLPack tensor's pointer to create an array from the
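
The Python wrapper mirrors tvm.nd.empty but calls the new TVMArrayAllocNonEmpty entry point, so the measurement code above gets argument buffers whose contents are pre-filled rather than left uninitialized. A minimal sketch, assuming a local CPU context and that the function is reachable as tvm.nd.non_empty like its empty counterpart:

import tvm

ctx = tvm.cpu(0)

# Existing behaviour: contents are uninitialized.
a = tvm.nd.empty((3, 4), dtype="float32", ctx=ctx)

# Added in this commit: contents are filled with non-empty values, which is what
# the measurement code now allocates for kernel arguments before running them.
b = tvm.nd.non_empty((3, 4), dtype="float32", ctx=ctx)
print(b.shape, b.dtype)
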
5 changes: 5 additions & 0 deletions src/ansor/compute_dag.cc
@@ -1271,5 +1271,10 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState")
return dag.PrintStepsAsPython(state->transform_steps);
});

TVM_REGISTER_GLOBAL("ansor.ComputeDAGInferBoundFromState")
.set_body_typed([](const ComputeDAG& dag, const State& state) {
return dag.ReplayAndInferBound(state->transform_steps);
});

} // namespace ansor
} // namespace tvm
71 changes: 40 additions & 31 deletions src/ansor/search_policy/meta_tile_rewrite_policy.h
@@ -1,91 +1,100 @@
/*!
* Copyright (c) 2020 by Contributors
* \file ansor/meta_tile_rewrite_policy.h
* \brief A search policy that search with meta tiling structure and random rewrite
* \brief A search policy that search with meta tiling structure and random
* rewrite
*/
#ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_
#define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_

#include <vector>
#include <set>
#include <string>
#include <utility>
#include <unordered_set>
#include <set>
#include "search_policy.h"
#include <utility>
#include <vector>

#include "../cost_model/cost_model.h"
#include "../utils.h"

#include "search_policy.h"

namespace tvm {
namespace ansor {

/*! Multi stage search policy */
class MetaTileRewritePolicyNode: public SearchPolicyNode {
class MetaTileRewritePolicyNode : public SearchPolicyNode {
public:
CostModel program_cost_model;

/* this->params is used to store the following arguments
* int evolutionary_search_population // The population size for evolutionary search
* int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search
* int evolutionary_search_num_iters; // The number of iterations for evolutionary search
* double local_mutation_use_measured_ratio; // The maximum percentage of measured states in the initial
* // population for evolutionary search
* double eps_greedy; // Always allocate this percentage of measurements to random sampled states
* str cpu_multi_level_tiling_structure // The structure of multi-level tiling for CPU
* str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU
* int evolutionary_search_population
* The population size for evolutionary search
* int evolutionary_search_mutation_prob
* The probability of mutation for evolutionary search
* int evolutionary_search_num_iters
* The number of iterations for evolutionary search
* double local_mutation_use_measured_ratio
* The maximum percentage of measured states in the initial population
* for evolutionary search
* double eps_greedy
* Always allocate this percentage of measurements to random sampled states
* str cpu_multi_level_tiling_structure
* The structure of multi-level tiling for CPU
* str gpu_multi_level_tiling_structure
* The structure of multi-level tiling for GPU
*/
Map<std::string, ObjectRef> params;

static SearchPolicy make(CostModel program_cost_model,
Map<std::string, ObjectRef> params,
int seed);
Map<std::string, ObjectRef> params, int seed);

// Search and make n_trials measurements
// Return the best state
State Search(SearchTask task, int n_trials,
int early_stopping, int num_measure_per_iter,
int verbose, ProgramMeasurer measurer) final;
State Search(SearchTask task, int n_trials, int early_stopping,
int num_measure_per_iter, int verbose,
ProgramMeasurer measurer) final;

// Continue search. This is used by JointTuner
std::pair<Array<MeasureInput>, Array<MeasureResult> > ContinueSearchOneRound(
SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final;
SearchTask task, int num_measure, int verbose,
ProgramMeasurer measurer) final;

static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy";
static constexpr const char* _type_key = "ansor.MetaTileRewritePolicy";
static const std::vector<int> auto_unroll_configs;

TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode);

SearchTask cur_task_; // The current task
SearchTask cur_task_; // The current task

friend class MetaTileRewritePolicyNodeTest; // Hack friend class for UT
protected:
// Pick states from best states and random states with eps-greedy policy
void PickStatesWithEpsGreedy(std::vector<MeasureInput>* inputs,
const std::vector<State>& best_states,
const std::vector<State>& random_states, int remaining_n_trials);
const std::vector<State>& random_states,
int remaining_n_trials);

private:
// Run one round of the search pipeline
void SearchOneRound(std::vector<State>* best_states,
int num_random_states, std::vector<State>* random_states);
void SearchOneRound(std::vector<State>* best_states, int num_random_states,
std::vector<State>* random_states);

// Synthesize meta tiling structure without tile size
void SynthesizeMetaStructure(std::vector<State>* out_states);

// Sample init population
void SampleInitPopulation(const std::vector<State>& meta_structures,
int out_size, std::vector<State>* out_states);
int out_size, std::vector<State>* out_states);

// Perform evolutionary search
void EvolutionarySearch(const std::vector<State>& init_population,
int num_best_states, std::vector<State>* best_states);
int num_best_states, std::vector<State>* best_states);

SplitFactorizationMemo split_memo_; // Memorize split space for Split
std::mt19937 rand_gen_; // Random generator
int verbose_; // Verbose level (0 means silent)
int num_measure_per_iter_; // The number of states to measure per iteration
int num_measure_per_iter_; // The number of states to measure per iteration

// The set of the already measured states. We store the string format for redundancy check
// The set of the already measured states. We store the string format for
// redundancy check
std::unordered_set<std::string> measured_states_set_;

// The array of already measured states.
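
For reference, the parameter names documented in the comment block above correspond to keys of the params map passed to MetaTileRewritePolicyNode::make. Sketched as a Python dict below; the values are placeholders for illustration only and are not defaults taken from this commit:

# Placeholder values, for illustration only; the real defaults are set elsewhere.
meta_tile_rewrite_params = {
    "evolutionary_search_population": 2048,
    "evolutionary_search_mutation_prob": 0.85,
    "evolutionary_search_num_iters": 10,
    "local_mutation_use_measured_ratio": 0.2,
    "eps_greedy": 0.05,
    "cpu_multi_level_tiling_structure": "SSRSRS",
    "gpu_multi_level_tiling_structure": "SSSRRSRS",
}
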