Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust][Diagnostics] Add initial boilerplate for Rust diagnostic interface. #6656

Merged
merged 32 commits into from
Nov 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
e6f992e
Add initial boilerplate for Rust diagnostic interface.
jroesch Oct 9, 2020
4fe35b0
Codespan example almost working
jroesch Oct 10, 2020
dfacf9e
WIP
jroesch Oct 13, 2020
e827660
Hacking on Rust inside of TVM
jroesch Oct 13, 2020
a1a4f3e
Borrow code from Egg
jroesch Oct 13, 2020
52177cc
Update CMake and delete old API
jroesch Oct 15, 2020
c5b4061
Fix Linux build
jroesch Oct 15, 2020
29754ae
Clean up exporting to show off new diagnostics
jroesch Oct 16, 2020
39c90da
Improve Rust bindings
jroesch Oct 16, 2020
c1b994c
Fix calling
jroesch Oct 16, 2020
7ea0c34
Fix
jroesch Oct 16, 2020
46c46ad
Rust Diagnostics work
jroesch Oct 16, 2020
8f219a6
Remove type checker
jroesch Oct 16, 2020
6f28414
Format and cleanup
jroesch Oct 16, 2020
af518e1
Fix the extension code
jroesch Oct 16, 2020
beb8f1c
More cleanup
jroesch Oct 16, 2020
657c708
Fix some CR
jroesch Oct 20, 2020
06cdc47
Add docs and address feedback
jroesch Oct 20, 2020
db6b355
WIP more improvments
jroesch Oct 21, 2020
9aa1a09
Update cmake/modules/RustExt.cmake
jroesch Oct 24, 2020
7e038e0
Update rust/tvm/src/ir/diagnostics/mod.rs
jroesch Oct 24, 2020
9a0e727
Clean up PR
jroesch Oct 24, 2020
d086193
Format all
jroesch Oct 24, 2020
62f3e39
Remove dead comment
jroesch Oct 24, 2020
0b5645e
Code review comments and apache headers
jroesch Oct 26, 2020
2f778db
Purge test file
jroesch Oct 26, 2020
e8fd9a5
Update cmake/modules/LLVM.cmake
jroesch Oct 30, 2020
e92adcc
Format Rust
jroesch Oct 30, 2020
5731a60
Add TK's suggestion
jroesch Oct 30, 2020
0065966
More CR and cleanup
jroesch Oct 31, 2020
5f2ad03
Fix tyck line
jroesch Oct 31, 2020
9700d81
Format
jroesch Oct 31, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF)
tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF)
tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF)
tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF)
tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF)

# include directories
include_directories(${CMAKE_INCLUDE_PATH})
Expand Down Expand Up @@ -352,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/Git.cmake)
include(cmake/modules/LibInfo.cmake)
include(cmake/modules/RustExt.cmake)

include(CheckCXXCompilerFlag)
if(NOT MSVC)
Expand Down
9 changes: 8 additions & 1 deletion cmake/modules/LLVM.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@
# under the License.

# LLVM rules
add_definitions(-DDMLC_USE_FOPEN64=0)
# Due to LLVM debug symbols you can sometimes face linking issues on
# certain compiler, platform combinations if you don't set NDEBUG.
#
# See https://github.com/imageworks/OpenShadingLanguage/issues/1069
# for more discussion.
add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1)
jroesch marked this conversation as resolved.
Show resolved Hide resolved
# TODO(@jroesch, @tkonolige): if we actually use targets we can do this.
# target_compile_definitions(tvm PRIVATE NDEBUG=1)

# Test if ${USE_LLVM} is not an explicit boolean false
# It may be a boolean or a string
Expand Down
43 changes: 43 additions & 0 deletions cmake/modules/RustExt.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_RUST_EXT)
set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust")
set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target")

if(USE_RUST_EXT STREQUAL "STATIC")
set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a")
elseif(USE_RUST_EXT STREQUAL "DYNAMIC")
set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so")
else()
message(FATAL_ERROR "invalid setting for USE_RUST_EXT, STATIC, DYNAMIC or OFF")
endif()

add_custom_command(
OUTPUT "${COMPILER_EXT_PATH}"
COMMAND cargo build --release
MAIN_DEPENDENCY "${RUST_SRC_DIR}"
WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext")

add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}")

# TODO(@jroesch, @tkonolige): move this to CMake target
# target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE)
list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH})

add_definitions(-DRUST_COMPILER_EXT=1)
endif()
2 changes: 0 additions & 2 deletions include/tvm/parser/source_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ class SourceMap : public ObjectRef {

TVM_DLL SourceMap() : SourceMap({}) {}

TVM_DLL static SourceMap Global();

void Add(const Source& source);

SourceMapNode* operator->() {
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/diagnostics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_renderer():
return _ffi_api.GetRenderer()


@tvm.register_func("diagnostics.override_renderer")
def override_renderer(render_func):
"""
Sets a custom renderer for diagnostics.
Expand Down
1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ members = [
"tvm-graph-rt/tests/test_tvm_dso",
"tvm-graph-rt/tests/test_wasm32",
"tvm-graph-rt/tests/test_nn",
"compiler-ext",
]
32 changes: 32 additions & 0 deletions rust/compiler-ext/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[package]
name = "compiler-ext"
version = "0.1.0"
authors = ["TVM Contributors"]
edition = "2018"

[lib]
crate-type = ["staticlib", "cdylib"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tvm = { path = "../tvm", default-features = false, features = ["static-linking"] }
log = "*"
jroesch marked this conversation as resolved.
Show resolved Hide resolved
env_logger = "*"
35 changes: 35 additions & 0 deletions rust/compiler-ext/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use env_logger;
use tvm::export;

fn diagnostics() -> Result<(), tvm::Error> {
tvm::ir::diagnostics::codespan::init()
}

export!(diagnostics);

#[no_mangle]
extern "C" fn compiler_ext_initialize() -> i32 {
let _ = env_logger::try_init();
tvm_export("rust_ext").expect("failed to initialize the Rust compiler extensions.");
log::debug!("Loaded the Rust compiler extension.");
return 0;
}
15 changes: 11 additions & 4 deletions rust/tvm-rt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,26 @@ categories = ["api-bindings", "science"]
authors = ["TVM Contributors"]
edition = "2018"

[features]
default = ["dynamic-linking"]
dynamic-linking = ["tvm-sys/bindings"]
static-linking = []
blas = ["ndarray/blas"]

[dependencies]
thiserror = "^1.0"
ndarray = "0.12"
num-traits = "0.2"
tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] }
tvm-macros = { version = "0.1", path = "../tvm-macros" }
paste = "0.1"
mashup = "0.1"
once_cell = "^1.3.1"
memoffset = "0.5.6"

[dependencies.tvm-sys]
version = "0.1"
default-features = false
path = "../tvm-sys/"

[dev-dependencies]
anyhow = "^1.0"

[features]
blas = ["ndarray/blas"]
37 changes: 37 additions & 0 deletions rust/tvm-rt/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

use std::convert::{TryFrom, TryInto};
use std::iter::{IntoIterator, Iterator};
use std::marker::PhantomData;

use crate::errors::Error;
Expand Down Expand Up @@ -81,6 +82,42 @@ impl<T: IsObjectRef> Array<T> {
}
}

pub struct IntoIter<T: IsObjectRef> {
array: Array<T>,
pos: isize,
size: isize,
}

impl<T: IsObjectRef> Iterator for IntoIter<T> {
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
if self.pos < self.size {
let item =
self.array.get(self.pos)
.expect("Can not index as in-bounds position after bounds checking.\nNote: this error can only be do to an uncaught issue with API bindings.");
self.pos += 1;
Some(item)
} else {
None
}
}
}

impl<T: IsObjectRef> IntoIterator for Array<T> {
type Item = T;
type IntoIter = IntoIter<T>;

fn into_iter(self) -> Self::IntoIter {
let size = self.len() as isize;
IntoIter {
array: self,
pos: 0,
size: size,
}
}
}

impl<T: IsObjectRef> From<Array<T>> for ArgValue<'static> {
fn from(array: Array<T>) -> ArgValue<'static> {
array.object.into()
Expand Down
17 changes: 17 additions & 0 deletions rust/tvm-rt/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ pub enum Error {
Infallible(#[from] std::convert::Infallible),
#[error("a panic occurred while executing a Rust packed function")]
Panic,
#[error(
"one or more error diagnostics were emitted, please check diagnostic render for output."
)]
DiagnosticError(String),
#[error("{0}")]
Raw(String),
}

impl Error {
pub fn from_raw_tvm(raw: &str) -> Error {
let err_header = raw.find(":").unwrap_or(0);
let (err_ty, err_content) = raw.split_at(err_header);
match err_ty {
"DiagnosticError" => Error::DiagnosticError((&err_content[1..]).into()),
_ => Error::Raw(raw.into()),
}
}
}

impl Error {
Expand Down
35 changes: 19 additions & 16 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,27 @@ impl Function {
let mut ret_val = ffi::TVMValue { v_int64: 0 };
let mut ret_type_code = 0i32;

check_call!(ffi::TVMFuncCall(
self.handle,
values.as_mut_ptr() as *mut ffi::TVMValue,
type_codes.as_mut_ptr() as *mut c_int,
num_args as c_int,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _
));
let ret_code = unsafe {
ffi::TVMFuncCall(
self.handle,
values.as_mut_ptr() as *mut ffi::TVMValue,
type_codes.as_mut_ptr() as *mut c_int,
num_args as c_int,
&mut ret_val as *mut _,
&mut ret_type_code as *mut _,
)
};

if ret_code != 0 {
let raw_error = crate::get_last_error();
let error = match Error::from_raw_tvm(raw_error) {
Error::Raw(string) => Error::CallFailed(string),
jroesch marked this conversation as resolved.
Show resolved Hide resolved
e => e,
};
return Err(error);
}

let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32);
match rv {
RetValue::ObjectHandle(object) => {
let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap();
// println!("after wrapped call: {}", optr.count());
crate::object::ObjectPtr::leak(optr);
}
_ => {}
};

Ok(rv)
}
Expand Down
14 changes: 7 additions & 7 deletions rust/tvm-rt/src/object/object_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ impl Object {
/// By using associated constants and generics we can provide a
/// type indexed abstraction over allocating objects with the
/// correct index and deleter.
pub fn base_object<T: IsObject>() -> Object {
pub fn base<T: IsObject>() -> Object {
let index = Object::get_type_index::<T>();
Object::new(index, delete::<T>)
}
Expand Down Expand Up @@ -351,15 +351,15 @@ mod tests {

#[test]
fn test_new_object() -> anyhow::Result<()> {
let object = Object::base_object::<Object>();
let object = Object::base::<Object>();
let ptr = ObjectPtr::new(object);
assert_eq!(ptr.count(), 1);
Ok(())
}

#[test]
fn test_leak() -> anyhow::Result<()> {
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let ptr = ObjectPtr::new(Object::base::<Object>());
assert_eq!(ptr.count(), 1);
let object = ObjectPtr::leak(ptr);
assert_eq!(object.count(), 1);
Expand All @@ -368,7 +368,7 @@ mod tests {

#[test]
fn test_clone() -> anyhow::Result<()> {
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let ptr = ObjectPtr::new(Object::base::<Object>());
assert_eq!(ptr.count(), 1);
let ptr2 = ptr.clone();
assert_eq!(ptr2.count(), 2);
Expand All @@ -379,7 +379,7 @@ mod tests {

#[test]
fn roundtrip_retvalue() -> Result<()> {
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let ptr = ObjectPtr::new(Object::base::<Object>());
assert_eq!(ptr.count(), 1);
let ret_value: RetValue = ptr.clone().into();
let ptr2: ObjectPtr<Object> = ret_value.try_into()?;
Expand All @@ -401,7 +401,7 @@ mod tests {

#[test]
fn roundtrip_argvalue() -> Result<()> {
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let ptr = ObjectPtr::new(Object::base::<Object>());
assert_eq!(ptr.count(), 1);
let ptr_clone = ptr.clone();
assert_eq!(ptr.count(), 2);
Expand Down Expand Up @@ -435,7 +435,7 @@ mod tests {
fn test_ref_count_boundary3() {
use super::*;
use crate::function::{register, Function};
let ptr = ObjectPtr::new(Object::base_object::<Object>());
let ptr = ObjectPtr::new(Object::base::<Object>());
assert_eq!(ptr.count(), 1);
let stay = ptr.clone();
assert_eq!(ptr.count(), 2);
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl From<std::string::String> for String {
fn from(s: std::string::String) -> Self {
let size = s.len() as u64;
let data = Box::into_raw(s.into_boxed_str()).cast();
let base = Object::base_object::<StringObj>();
let base = Object::base::<StringObj>();
StringObj { base, data, size }.into()
}
}
Expand All @@ -47,7 +47,7 @@ impl From<&'static str> for String {
fn from(s: &'static str) -> Self {
let size = s.len() as u64;
let data = s.as_bytes().as_ptr();
let base = Object::base_object::<StringObj>();
let base = Object::base::<StringObj>();
StringObj { base, data, size }.into()
}
}
Expand Down
Loading