Commit
Rewrite the Rust Module API and change some imports that were causing crashes.
jroesch committed Feb 4, 2021
1 parent 618ef9e commit 0eb6aa5
Showing 34 changed files with 297 additions and 170 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
@@ -61,7 +61,7 @@
 from .scope_builder import ScopeBuilder
 
 # Load Memory Passes
-from .transform import memory_plan
+from .transform import memory_alloc, memory_plan
 
 # Required to traverse large programs
 setrecursionlimit(10000)
2 changes: 1 addition & 1 deletion python/tvm/relay/analysis/__init__.py
@@ -26,7 +26,7 @@
 from . import call_graph
 from .call_graph import CallGraph
 
-# Feature
+# # Feature
 from . import feature
 from . import sparse_dense
 
6 changes: 3 additions & 3 deletions python/tvm/relay/analysis/analysis.py
@@ -20,9 +20,9 @@
 This file contains the set of passes for Relay, which exposes an interface for
 configuring the passes and scripting them in Python.
 """
-from tvm.ir import IRModule
-from tvm.relay import transform, build_module
-from tvm.runtime.ndarray import cpu
+from ...ir import IRModule
+from ...relay import transform, build_module
+from ...runtime.ndarray import cpu
 
 from . import _ffi_api
 from .feature import Feature
2 changes: 1 addition & 1 deletion python/tvm/relay/analysis/annotated_regions.py
@@ -17,7 +17,7 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
 """Regions used in Relay."""
 
-from tvm.runtime import Object
+from ...runtime import Object
 from . import _ffi_api
 
 
4 changes: 2 additions & 2 deletions python/tvm/relay/analysis/call_graph.py
@@ -17,8 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
 """Call graph used in Relay."""
 
-from tvm.ir import IRModule
-from tvm.runtime import Object
+from ...ir import IRModule
+from ...runtime import Object
 from ..expr import GlobalVar
 from . import _ffi_api
 
15 changes: 8 additions & 7 deletions python/tvm/relay/analysis/sparse_dense.py
@@ -22,8 +22,8 @@
 """
 from collections import namedtuple
 import numpy as np
-import scipy.sparse as sp
-import tvm
 
+from ... import nd, runtime
 from . import _ffi_api
 
@@ -73,6 +73,7 @@ def process_params(expr, params, block_size, sparsity_threshold):
     ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
         return names of qualified dense weight and the shape in BSR format
     """
+    import scipy.sparse as sp
     memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
     weight_names = _search_dense_op_weight(expr)
     for name in weight_names:
@@ -89,11 +90,11 @@ def process_params(expr, params, block_size, sparsity_threshold):
                 + list(sparse_weight.indices.shape)
                 + list(sparse_weight.indptr.shape)
             )
-            params[name + ".data"] = tvm.nd.array(sparse_weight.data)
-            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
-            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+            params[name + ".data"] = nd.array(sparse_weight.data)
+            params[name + ".indices"] = nd.array(sparse_weight.indices)
+            params[name + ".indptr"] = nd.array(sparse_weight.indptr)
     ret = SparseAnalysisResult(
-        weight_name=tvm.runtime.convert(memo.weight_name),
-        weight_shape=tvm.runtime.convert(memo.weight_shape),
+        weight_name=runtime.convert(memo.weight_name),
+        weight_shape=runtime.convert(memo.weight_shape),
     )
     return ret
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_factory.py
@@ -21,7 +21,7 @@
 from tvm.runtime import ndarray
 
 
-class GraphRuntimeFactoryModule(object):
+class GraphRuntimeFactoryModule:
     """Graph runtime factory module.
     This is a module of graph runtime factory
14 changes: 13 additions & 1 deletion python/tvm/relay/build_module.py
@@ -25,7 +25,7 @@
 
 from tvm.ir.transform import PassContext
 from tvm.tir import expr as tvm_expr
-from .. import nd as _nd, autotvm
+from .. import nd as _nd, autotvm, register_func
 from ..target import Target
 from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
@@ -193,6 +193,18 @@ def get_params(self):
             ret[key] = value.data
         return ret
 
+@register_func("tvm.relay.build")
+def _rust_build_module(mod, target=None, target_host=None, params=None, mod_name="default"):
+    print(mod)
+    print("\n")
+    rt_mod = build(mod, target, target_host, params, mod_name).module
+    print(rt_mod)
+    print(rt_mod["default"])
+    return rt_mod
+
+@register_func("tvm.relay.module_export_library")
+def _module_export(module, file_name):  # fcompile, addons, kwargs?
+    return module.export_library(file_name)
 
 def build(mod, target=None, target_host=None, params=None, mod_name="default"):
     # fmt: off
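
The two functions registered above are the Python half of the new Rust Module API: rather than linking against Python directly, the Rust crate resolves them by name through TVM's packed-function registry. A minimal Rust-side sketch of that lookup, assuming only `tvm_rt::function::Function::get` (argument marshalling is elided):

    use tvm_rt::function::Function;

    fn lookup_relay_build() {
        // The entry exists only after the Python side has imported
        // tvm.relay.build_module, which runs the @register_func decorators above.
        match Function::get("tvm.relay.build") {
            Some(_build) => println!("found packed function: tvm.relay.build"),
            None => println!("tvm.relay.build is not registered yet"),
        }
    }
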
3 changes: 0 additions & 3 deletions python/tvm/relay/frontend/__init__.py
@@ -20,9 +20,6 @@
 Contains the model importers currently defined
 for Relay.
 """
-
-from __future__ import absolute_import
-
 from .mxnet import from_mxnet
 from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var
 from .keras import from_keras
6 changes: 4 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
@@ -914,10 +914,12 @@ def _impl(inputs, attr, params, mod):
 
 
 def _sparse_tensor_dense_matmul():
-    # Sparse utility from scipy
-    from scipy.sparse import csr_matrix
-
     def _impl(inputs, attr, params, mod):
+        # Loading this by default causes TVM to not be loadable from other languages.
+        # Sparse utility from scipy
+        from scipy.sparse import csr_matrix
+
         assert len(inputs) == 4, "There should be 4 input tensors"
 
         indices_tensor = _infer_value(inputs[0], params, mod).asnumpy()
2 changes: 1 addition & 1 deletion python/tvm/relay/op/op.py
@@ -378,7 +378,7 @@ def register_shape_func(op_name, data_dependent, shape_func=None, level=10):
     """
     if not isinstance(data_dependent, list):
         data_dependent = [data_dependent]
-    get(op_name).set_attr("TShapeDataDependent", data_dependent, level)
+    get(op_name).set_attr("TShapeDataDependant", data_dependent, level)
     return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level)


5 changes: 3 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -23,11 +23,12 @@
 import functools
 import warnings
 
+from ...ir import transform as tvm_transform
 import tvm.ir
 from tvm import te
 from tvm.runtime import ndarray as _nd
 
-from tvm import relay
+# from tvm import relay
 from . import _ffi_api
 

@@ -82,7 +83,7 @@ def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None
 
 
 @tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(tvm.ir.transform.Pass):
+class FunctionPass(tvm_transform.Pass):
     """A pass that works on each tvm.relay.Function in a module. A function
     pass class should be created through `function_pass`.
     """
2 changes: 0 additions & 2 deletions python/tvm/topi/cuda/__init__.py
@@ -17,8 +17,6 @@
 
 # pylint: disable=redefined-builtin, wildcard-import
 """CUDA specific declaration and schedules."""
-from __future__ import absolute_import as _abs
-
 from .conv1d import *
 from .conv1d_transpose_ncw import *
 from .conv2d import *
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/sparse.py
@@ -17,7 +17,6 @@
 
 """Sparse operators"""
 import numpy as np
-import scipy.sparse as sp
 
 import tvm
 from tvm import relay, te
@@ -361,6 +360,7 @@
 
 def pad_sparse_matrix(matrix, blocksize):
     """Pad rows of sparse matrix matrix so that they are a multiple of blocksize."""
+    import scipy.sparse as sp
     assert isinstance(matrix, sp.bsr_matrix)
     new_entries = np.zeros(matrix.shape[0], dtype=matrix.indptr.dtype)
     bsr = matrix.blocksize[0]
@@ -397,6 +397,7 @@ def _alter_sparse_dense_layout(_attrs, inputs, _tinfos, _out_type):
     sparse_dense implementation for one that operates on a padded matrix. We
     also pad the matrix.
     """
+    import scipy.sparse as sp
     # TODO(ANSHUMAN87): Handle for sparse_lhs case too
     if (
         isinstance(inputs[1], relay.Constant)
3 changes: 1 addition & 2 deletions python/tvm/topi/testing/conv1d_transpose_ncw_python.py
@@ -17,11 +17,9 @@
 # pylint: disable=unused-variable
 """Transposed 1D convolution in python"""
 import numpy as np
-import scipy
 import tvm.topi.testing
 from tvm.topi.nn.utils import get_pad_tuple1d
 
-
 def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding):
     """Transposed 1D convolution operator in NCW layout.
@@ -51,6 +49,7 @@ def conv1d_transpose_ncw_python(a_np, w_np, stride, padding, output_padding):
         3-D with shape [batch, out_channel, out_width]
     """
+    import scipy
     batch, in_c, in_w = a_np.shape
     _, out_c, filter_w = w_np.shape
     opad = output_padding[0]
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/conv2d_hwcn_python.py
@@ -17,7 +17,6 @@
 # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
 """Convolution in python"""
 import numpy as np
-import scipy.signal
 from tvm.topi.nn.utils import get_pad_tuple
 
 
@@ -45,6 +44,7 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding):
     b_np : np.ndarray
         4-D with shape [out_height, out_width, out_channel, batch]
     """
+    import scipy.signal
     in_height, in_width, in_channel, batch = a_np.shape
     kernel_h, kernel_w, _, num_filter = w_np.shape
     if isinstance(stride, int):
12 changes: 12 additions & 0 deletions rust/tvm-rt/src/map.rs
@@ -107,6 +107,18 @@ where
         let oref: ObjectRef = map_get_item(self.object.clone(), key.upcast())?;
         oref.downcast()
     }
+
+    pub fn empty() -> Self {
+        Self::from_iter(vec![].into_iter())
+    }
+
+    // (@jroesch): I don't think this is a correct implementation.
+    pub fn null() -> Self {
+        Map {
+            object: ObjectRef::null(),
+            _data: PhantomData,
+        }
+    }
 }
 
 pub struct IntoIter<K, V> {
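
A short sketch of how the two new constructors differ in intent (the `tvm_rt` import paths and the crate-root `String` re-export are assumptions here): `empty()` yields a real zero-length map, while `null()` wraps a null `ObjectRef`, which the comment above already flags as dubious.

    use tvm_rt::map::Map;
    use tvm_rt::String as TString;

    fn map_constructors() {
        // A valid, zero-length map backed by an allocated TVM object.
        let empty: Map<TString, TString> = Map::empty();

        // A null handle in Map clothing; usable only as a placeholder,
        // since any real map operation on it would be unsound.
        let placeholder: Map<TString, TString> = Map::null();

        let _ = (empty, placeholder);
    }
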
54 changes: 29 additions & 25 deletions rust/tvm-rt/src/module.rs
@@ -27,20 +27,25 @@ use std::{
 };
 
 use tvm_sys::ffi;
+use tvm_macros::Object;
+use crate::object::{Object, ObjectPtr, IsObjectRef};
 
+use crate::errors::Error;
 use crate::{errors, function::Function};
+use crate::{String as TString};
 
 const ENTRY_FUNC: &str = "__tvm_main__";
 
 /// Wrapper around TVM module handle which contains an entry function.
 /// The entry function can be applied to an imported module through [`entry_func`].
 ///
 /// [`entry_func`]:struct.Module.html#method.entry_func
-#[derive(Debug, Clone)]
-pub struct Module {
-    pub(crate) handle: ffi::TVMModuleHandle,
-    entry_func: Option<Function>,
+#[repr(C)]
+#[derive(Object, Debug)]
+#[ref_name = "Module"]
+#[type_key = "runtime.Module"]
+pub struct ModuleNode {
+    base: Object,
 }

crate::external! {
@@ -49,21 +54,18 @@ crate::external! {
 
     #[name("runtime.ModuleLoadFromFile")]
     fn load_from_file(file_name: CString, format: CString) -> Module;
+
+    #[name("runtime.ModuleSaveToFile")]
+    fn save_to_file(module: Module, name: TString, fmt: TString);
+
+    // TODO(@jroesch): we need to refactor this
+    #[name("tvm.relay.module_export_library")]
+    fn export_library(module: Module, file_name: TString);
 }
 
 impl Module {
-    pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self {
-        Self {
-            handle,
-            entry_func: None,
-        }
-    }
-
     pub fn entry(&mut self) -> Option<Function> {
-        if self.entry_func.is_none() {
-            self.entry_func = self.get_function(ENTRY_FUNC, false).ok();
-        }
-        self.entry_func.clone()
+        panic!()
     }
 
     /// Gets a function by name from a registered module.
@@ -72,7 +74,7 @@ impl Module {
         let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
 
         check_call!(ffi::TVMModGetFunction(
-            self.handle,
+            self.handle(),
             name.as_ptr() as *const c_char,
             query_import as c_int,
             &mut fhandle as *mut _
@@ -87,7 +89,7 @@ impl Module {
 
     /// Imports a dependent module such as `.ptx` for gpu.
     pub fn import_module(&self, dependent_module: Module) {
-        check_call!(ffi::TVMModImport(self.handle, dependent_module.handle))
+        check_call!(ffi::TVMModImport(self.handle(), dependent_module.handle()))
     }
 
     /// Loads a module shared library from path.
@@ -110,6 +112,14 @@
         Ok(module)
     }
 
+    pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> {
+        save_to_file(self.clone(), name.into(), fmt.into())
+    }
+
+    pub fn export_library(&self, name: String) -> Result<(), Error> {
+        export_library(self.clone(), name.into())
+    }
+
     /// Checks if a target device is enabled for a module.
     pub fn enabled(&self, target: &str) -> bool {
         let target = CString::new(target).unwrap();
@@ -118,13 +128,7 @@
     }
 
     /// Returns the underlying module handle.
-    pub fn handle(&self) -> ffi::TVMModuleHandle {
-        self.handle
-    }
-}
-
-impl Drop for Module {
-    fn drop(&mut self) {
-        check_call!(ffi::TVMModFree(self.handle));
+    pub unsafe fn handle(&self) -> ffi::TVMModuleHandle {
+        self.0.clone().unwrap().into_raw() as *mut _
    }
 }
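The rewritten module.rs turns Module into an object reference backed by ModuleNode, so the raw handle is now borrowed unsafely from the underlying object pointer, and saving/exporting are delegated to the packed functions bound above. A minimal usage sketch under those assumptions (file names and the "ll" format string are illustrative; `Module::load` is the constructor documented in this diff):

    use tvm_rt::errors::Error;
    use tvm_rt::module::Module;

    fn save_and_export(path: &str) -> Result<(), Error> {
        // Load a compiled shared library as a runtime module.
        let module = Module::load(&path)?;

        // Persist it via runtime.ModuleSaveToFile...
        module.save_to_file("module.ll".to_string(), "ll".to_string())?;

        // ...and emit a shared library through the Python-side
        // tvm.relay.module_export_library hook from build_module.py.
        module.export_library("module_exported.so".to_string())?;

        Ok(())
    }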
