Skip to content

Commit

Permalink
Cleanup + Rust Module Build test (#12)
Browse files Browse the repository at this point in the history
* remove unneeded changes

* remove println

* remove unnecessary changes

* add simple module build test

* fix some tests

* clone

* revert changes to vtahw
  • Loading branch information
hypercubestart authored Feb 4, 2021
1 parent 44d6706 commit 44c738b
Show file tree
Hide file tree
Showing 18 changed files with 48 additions and 41 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/vta-hw
3 changes: 2 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
from .scope_builder import ScopeBuilder

# Load Memory Passes
from .transform import memory_alloc, memory_plan
from .transform import memory_alloc
from .transform import memory_plan

# Required to traverse large programs
setrecursionlimit(10000)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from . import call_graph
from .call_graph import CallGraph

# # Feature
# Feature
from . import feature
from . import sparse_dense

Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
from ...ir import IRModule
from ...relay import transform, build_module
from ...runtime.ndarray import cpu
from tvm.ir import IRModule
from tvm.relay import transform, build_module
from tvm.runtime.ndarray import cpu

from . import _ffi_api
from .feature import Feature
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/analysis/annotated_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Regions used in Relay."""

from ...runtime import Object
from tvm.runtime import Object
from . import _ffi_api


Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/analysis/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-import
"""Call graph used in Relay."""

from ...ir import IRModule
from ...runtime import Object
from tvm.ir import IRModule
from tvm.runtime import Object
from ..expr import GlobalVar
from . import _ffi_api

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.runtime import ndarray


class GraphRuntimeFactoryModule:
class GraphRuntimeFactoryModule(object):
"""Graph runtime factory module.
This is a module of graph runtime factory
Expand Down
4 changes: 0 additions & 4 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,7 @@ def get_params(self):

@register_func("tvm.relay.build")
def _rust_build_module(mod, target=None, target_host=None, params=None, mod_name="default"):
print(mod)
print("\n")
rt_mod = build(mod, target, target_host, params, mod_name).module
print(rt_mod)
print(rt_mod["default"])
return rt_mod

@register_func("tvm.relay.module_export_library")
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relay/frontend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
Contains the model importers currently defined
for Relay.
"""

from __future__ import absolute_import

from .mxnet import from_mxnet
from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var
from .keras import from_keras
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# transformation passes
from .transform import *
from .recast import recast
# from . import memory_alloc
from . import memory_alloc
7 changes: 4 additions & 3 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
"""
import numpy as np

from ... import DataType, register_func, nd, container, cpu
from ...ir.transform import PassContext, module_pass
from . import InferType
from tvm.ir.transform import PassContext, module_pass
from tvm.relay.transform import InferType
from tvm import nd, container
from ..function import Function
from ..expr_functor import ExprVisitor, ExprMutator
from ..scope_builder import ScopeBuilder
from .. import op
from ... import DataType, register_func
from .. import ty, expr
from ..backend import compile_engine
from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type
Expand Down
5 changes: 2 additions & 3 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@
import functools
import warnings

from ...ir import transform as tvm_transform
import tvm.ir
from tvm import te
from tvm.runtime import ndarray as _nd

# from tvm import relay
from tvm import relay
from . import _ffi_api


Expand Down Expand Up @@ -83,7 +82,7 @@ def build_config(opt_level=2, required_pass=None, disabled_pass=None, trace=None


@tvm._ffi.register_object("relay.FunctionPass")
class FunctionPass():
class FunctionPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/topi/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# pylint: disable=redefined-builtin, wildcard-import
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs

from .conv1d import *
from .conv1d_transpose_ncw import *
from .conv2d import *
Expand Down
10 changes: 0 additions & 10 deletions rust/tvm-rt/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@ use crate::errors::Error;
use crate::{errors, function::Function};
use crate::{String as TString};

const ENTRY_FUNC: &str = "__tvm_main__";

/// Wrapper around TVM module handle which contains an entry function.
/// The entry function can be applied to an imported module through [`entry_func`].
///
/// [`entry_func`]:struct.Module.html#method.entry_func
#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "Module"]
Expand All @@ -64,10 +58,6 @@ crate::external! {
}

impl Module {
pub fn entry(&mut self) -> Option<Function> {
panic!()
}

/// Gets a function by name from a registered module.
pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> {
let name = CString::new(name)?;
Expand Down
30 changes: 24 additions & 6 deletions rust/tvm/src/compiler/graph_rt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,29 @@ where P1: AsRef<Path>, P2: AsRef<Path> {
let mut input_module_text = std::string::String::new();
input_file.read_to_string(&mut input_module_text)?;
let input_module = IRModule::parse("name", input_module_text)?;
println!("Before");
// let rt_module = compile_module(config, input_module)?;
// println!("Pointer {:p}\n", rt_module.handle() );
// let output_file =
// std::fs::File::open(output_module.as_ref())?;
// panic!()
Ok(())
}

#[cfg(test)]
mod tests {
use crate::ir::IRModule;
use crate::ir::relay::*;
use crate::DataType;
use anyhow::Result;
use crate::ir::span::Span;
use crate::ir::ty::GlobalTypeVar;
use super::compile_module;
use tvm_rt::IsObjectRef;

#[test]
fn test_module_build() -> Result<()> {
let mut module = IRModule::empty()?;
let x = Var::static_tensor("x".into(), vec![1, 1], DataType::float32());
let params = vec![x.clone()];
let func = Function::simple(params, x);
let module = module.add(GlobalVar::new("main".into(), Span::null()), func)?;

let rtmodule = compile_module(Default::default(), module)?;
Ok(())
}
}
1 change: 0 additions & 1 deletion rust/tvm/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub mod graph_rt;

pub(self) static TVM_LOADED: Lazy<Function> = Lazy::new(|| {
let ver = python::load().unwrap();
println!("version: {}", ver);
python::import("tvm.relay").unwrap();
Function::get("tvm.relay.build").unwrap()
});
2 changes: 0 additions & 2 deletions rust/tvm/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ use pyo3::prelude::*;
pub fn load() -> Result<String, ()> {
let gil = Python::acquire_gil();
let py = gil.python();
// let main_mod = initialize();
//let main_mod = main_mod.as_ref(py);
load_python_tvm_(py).map_err(|e| {
// We can't display Python exceptions via std::fmt::Display,
// so print the error here manually.
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm/tests/basics/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ fn main() {
fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap());
}

fadd.entry()
fadd.get_function("__tvm_main__", false).ok().clone()
.expect("module must have entry point")
.invoke(vec![(&arr).into(), (&arr).into(), (&ret).into()])
.unwrap();
Expand Down

0 comments on commit 44c738b

Please sign in to comment.