diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index 627bb85862f3..e93400f6d637 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2025-02-15T19:40:24.687837 +// Generated at 2025-05-09T10:31:17.078676 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -516,11 +516,6 @@ def run_build(node_type) { script: "./${jenkins_scripts_root}/s3.py --action upload --bucket ${s3_bucket} --prefix ${s3_prefix}/cpu --items build/libtvm.so build/libtvm_runtime.so build/config.cmake build/libtvm_allvisible.so build/cpptest build/build.ninja build/CMakeFiles/rules.ninja", label: 'Upload artifacts to S3', ) - - ci_setup(ci_cpu) - // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" - // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch - sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: 'Rust build and test') }) } } diff --git a/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 b/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 index c84b0c48a29f..50e47f9bbfaf 100644 --- a/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 +++ b/ci/jenkins/templates/cpu_jenkinsfile.groovy.j2 @@ -32,10 +32,6 @@ cmake_build(ci_cpu, 'build') make_cpp_tests(ci_cpu, 'build') {{ m.upload_artifacts(tag='cpu', filenames=tvm_lib + tvm_allvisible + cpptest) }} - ci_setup(ci_cpu) - // sh "${docker_run} ${ci_cpu} ./tests/scripts/task_golang.sh" - // TODO(@jroesch): need to resolve CI issue will turn back on in follow up patch - sh (script: "${docker_run} ${ci_cpu} ./tests/scripts/task_rust.sh", label: 'Rust build and test') {% endcall %} {% set test_method_names = [] %} diff --git a/docker/README.md b/docker/README.md index ecc6e7948957..7d3fd22dc911 100644 --- a/docker/README.md +++ b/docker/README.md @@ -130,9 +130,3 @@ tasks. ```bash ./docker/ci_build.sh ci_gpu make -C docs html ``` - -- build golang test suite. - - ```bash - ./docker/build.sh ci_cpu tests/scripts/task_golang.sh - ``` diff --git a/golang/Makefile b/golang/Makefile deleted file mode 100644 index 76ac371b628d..000000000000 --- a/golang/Makefile +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -.PHONY: clean all - -TVM_BASE = $(CURDIR)/../ -TARGET = gotvm -LIBS = -lm -ldl -NATIVE_SRC = tvm_runtime_pack.cc - -GOPATH=$(CURDIR)/gopath -GOPATHDIR=${GOPATH}/src/${TARGET}/ -CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" -CGO_CXXFLAGS="-std=c++17 -DDMLC_USE_LOGGING_LIBRARY= -DTVM_USE_LIBBACKTRACE=0" -CGO_CFLAGS="-I${TVM_BASE}" -CGO_LDFLAGS="-ldl -lm" - -all: - @mkdir gopath 2>/dev/null || true - @mkdir gopath/src 2>/dev/null || true - @mkdir gopath/src/$(TARGET) 2>/dev/null || true - @cp src/$(TARGET).cc gopath/src/$(TARGET) - @cp src/$(TARGET).h gopath/src/$(TARGET) - @cp src/$(NATIVE_SRC) gopath/src/$(TARGET) - @cp src/*.go gopath/src/$(TARGET) - @export GOPATH=$(GOPATH); \ - export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ - export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \ - export CGO_CFLAGS=$(CGO_CFLAGS); \ - export CGO_LDFLAGS=$(CGO_LDFLAGS); \ - (cd $(GOPATHDIR) && go clean -cache \ - && golint && go build -o $(TARGET).a \ - && go install) - @find . -name gotvm.a - @#mkdir gopath/doc 2>/dev/null || true - @#godoc -html -goroot gopath/ gotvm | grep -v "for documentation on the gotvm command" > gopath/doc/gotvm.html - @#echo "Run 'godoc -http=:6060 -goroot=./gopath' for documentation" - -samples: all - cp gopath/pkg/linux_amd64/gotvm.a sample/ -rfa - make -C sample - -tests: all - @(cd sample; python3 deploy.py) - @export GOPATH=$(GOPATH); \ - export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ - export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \ - export CGO_CFLAGS=$(CGO_CFLAGS); \ - export CGO_LDFLAGS=$(CGO_LDFLAGS); \ - (cd $(GOPATHDIR) \ - && cp ../../../sample/deploy.so . \ - && go test -v) - -clean: - @if [ -d $(GOPATHDIR) ] ; then \ - export GOPATH=$(GOPATH); \ - export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ - export CGO_CFLAGS=$(CGO_CFLAGS); \ - export CGO_LDFLAGS=$(CGO_LDFLAGS); \ - (cd $(GOPATHDIR) && go clean -cache); fi - @rm -rf gopath - @make -C sample clean - -lint: - @(cd src; golint) - @python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.cc - @python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.h diff --git a/golang/README.md b/golang/README.md deleted file mode 100644 index ee3ea8cc2e98..000000000000 --- a/golang/README.md +++ /dev/null @@ -1,126 +0,0 @@ - - - - - - - - - - - - - - - - - -# gotvm - Golang Frontend for TVM Runtime - -This folder contain golang interface for TVM runtime. It brings TVM runtime to Golang. - -- It enable c runtime api of tvm exposed to golang. -- It enables module loading (lib, graph and params) and inference operations. - -## Installation - -### Requirements - -- go compiler (https://golang.org/) version 0.10 or above. - -### Modules - -- src - Module that generates golang package corresponding to the c runtime api exposed from tvm source tree. - This process build golang package _gotvm.a_ - -- samples - Sample golang reference application to inference through gotvm package. - -### Build - -Once the Requirements are installed - -To build _gotvm_ package - -```bash -make -``` - -To build and run internal tests - -```bash -make tests -``` - -To build sample apps. - -```bash -make samples -``` - -## Run - -To Demonstrates sample TVM module compilation using python and deploy via golang. -```bash -./simple -``` - -To deploy a realtime module with lib, graph and param. -```bash -python3 gen_mobilenet_lib.py - -./complex -``` - -To demonstrate go function closure conversion to packed function handle. - -```bash -./pack_func_convert -``` - -To demonstrate a packed function handle given as an argument. - -```bash -./pack_func_handle_arg -``` - -To register go function with runtime as a global function. - -```bash -./pack_func_register -``` - -To demonstrate function closure passed as argument to a function call. - -```bash -./pack_func_closure_arg -``` - -To demonstrate function closure returned from a packed function. - -```bash -./pack_func_closure_return -``` - -## Documentation -gotvm.go is documented with sufficient information about gotvm package. -A html version documentation can be accessed by running below command after building runtime. - -```bash -godoc -http=:6060 -goroot=./gopath -``` -After above command try http://127.0.0.1:6060 from any browser. - -Also please refer to the sample applications under sample folder. - -## Docker -Docker setup may need below additions for dependencies and environment preparation. - -Please refer ```docker/install/ubuntu_install_golang.sh``` for the packages dependencies. - -go compiler 1.10 on ubuntu doesn't install on standard path, hence an explicit export may be needed as shown below. - -```bash -export PATH="/usr/lib/go-1.10/bin:$PATH" -``` diff --git a/golang/sample/Makefile b/golang/sample/Makefile deleted file mode 100644 index fd738b6f979f..000000000000 --- a/golang/sample/Makefile +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -.PHONY: clean all - -SOURCES=$(wildcard *.go) -EXECUTABLE=$(patsubst %.go, %, $(SOURCES)) - -all: $(EXECUTABLE) - @golint - @python3 deploy.py - -%: %.o - @go tool link -linkmode external -extld "g++" -extldflags "-ldl" -o $@ $< - -%.o: %.go - @go tool compile -pack -o $@ $< - -clean: - @rm -f $(EXECUTABLE) *.so *.o *.a *.json *.params diff --git a/golang/sample/complex.go b/golang/sample/complex.go deleted file mode 100644 index c048207b8b5e..000000000000 --- a/golang/sample/complex.go +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application deployment over tvm. - * \file complex.go - */ - -package main - -import ( - "fmt" - "io/ioutil" - "math/rand" - "./gotvm" - "runtime" -) - -// NNVM compiled model paths. -const ( - modLib = "./mobilenet.so" - modJSON = "./mobilenet.json" - modParams = "./mobilenet.params" -) - -// main -func main() { - defer runtime.GC() - // Welcome - fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion) - fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion) - - // Query global functions available - funcNames, err := gotvm.FuncListGlobalNames() - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Global Functions:%v\n", funcNames) - - // Import tvm module (so) - modp, err := gotvm.LoadModuleFromFile(modLib) - if err != nil { - fmt.Print(err) - fmt.Printf("Please copy tvm compiled modules here and update the sample.go accordingly.\n") - fmt.Printf("You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut\n") - return - } - fmt.Printf("Module Imported:%p\n", modp) - bytes, err := ioutil.ReadFile(modJSON) - if err != nil { - fmt.Print(err) - return - } - jsonStr := string(bytes) - - // Load module on tvm runtime - call tvm.graph_executor.create - funp, err := gotvm.GetGlobalFunction("tvm.graph_executor.create") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Calling tvm.graph_executor.create\n") - // Call function - graphrt, err := funp.Invoke(jsonStr, modp, (int64)(gotvm.KDLCPU), (int64)(0)) - if err != nil { - fmt.Print(err) - return - } - graphmod := graphrt.AsModule() - fmt.Printf("Graph executor Created\n") - - // Array allocation attributes - tshapeIn := []int64{1, 224, 224, 3} - tshapeOut := []int64{1, 1001} - - // Allocate input Array - inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) - if err != nil { - fmt.Print(err) - return - } - - // Allocate output Array - out, err := gotvm.Empty(tshapeOut) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Input and Output Arrays allocated\n") - - // Get module function from graph executor : load_params - // Read params - bytes, err = ioutil.ReadFile(modParams) - if err != nil { - fmt.Print(err) - } - - // Load Params - funp, err = graphmod.GetFunction("load_params") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Func load_params:%p\n", funp) - - // Call function - _, err = funp.Invoke(bytes) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Module params loaded\n") - - // Set some data in input Array - inSlice := make([]float32, (224 * 224 * 3)) - rand.Seed(10) - rand.Shuffle(len(inSlice), func(i, j int) {inSlice[i], - inSlice[j] = rand.Float32(), - rand.Float32() }) - inX.CopyFrom(inSlice) - - // Set Input - funp, err = graphmod.GetFunction("set_input") - if err != nil { - fmt.Print(err) - return - } - - // Call function - _, err = funp.Invoke("input", inX) - if err != nil { - fmt.Print(err) - return - } - - fmt.Printf("Module input is set\n") - - // Run - funp, err = graphmod.GetFunction("run") - if err != nil { - fmt.Print(err) - return - } - - // Call function - _, err = funp.Invoke() - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Module Executed \n") - - // Call runtime function get_output - funp, err = graphmod.GetFunction("get_output") - if err != nil { - fmt.Print(err) - return - } - - // Call function - _, err = funp.Invoke(int64(0), out) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Got Module Output \n") - - // Print results - outIntf, _ := out.AsSlice() - outSlice := outIntf.([]float32) - fmt.Printf("Result:%v\n", outSlice[:10]) -} diff --git a/golang/sample/pack_func_closure_arg.go b/golang/sample/pack_func_closure_arg.go deleted file mode 100644 index ff2d1e2754c4..000000000000 --- a/golang/sample/pack_func_closure_arg.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate go-closure given to a packed function argument. - * \file pack_func_closure_arg.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - - -// sampleFunctionArg receives a Packed Function handle and calls it. -func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - // Call Packed Function - retVal, err = pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) - return -} - -// main -func main() { - // Not passing a function name implicitely - // picks the name from reflection as "main.sampleDunctionArg" - gotvm.RegisterFunction(sampleFunctionArg); - fmt.Printf("Registered: sampleFunctionArg\n") - - // Get registered global function. - funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("GetGlobalFunction: main.sampleFunctionArg - Success\n") - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - // Call function - result, err := funp.Invoke(funccall, 30, 50) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Invoked sampleFunctionArg with function closure arg : Result:%v\n", result.AsInt64()) -} diff --git a/golang/sample/pack_func_closure_return.go b/golang/sample/pack_func_closure_return.go deleted file mode 100644 index e010b9395361..000000000000 --- a/golang/sample/pack_func_closure_return.go +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate go-closure returned from a callback function. - * \file pack_func_closure_return.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - -// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. -func sampleFunctionCb(args ...*gotvm.Value) (retVal interface{}, err error) { - funccall := func (cargs ...*gotvm.Value) (fret interface{}, ferr error) { - for _, v := range cargs { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := cargs[0].AsInt64() - val2 := cargs[1].AsInt64() - fret = int64(val1+val2) - return - } - retVal = funccall - return -} - -// main -func main() { - // Not passing a function name implicitely - // picks the name from reflection as "main.sampleDunctionCb" - gotvm.RegisterFunction(sampleFunctionCb); - fmt.Printf("Registered: sampleFunctionCb\n") - - // Get registered global function - funp, err := gotvm.GetGlobalFunction("main.sampleFunctionCb") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("GetGlobalFunction: main.sampleFunctionCb - Success\n") - - // Call function - result, err := funp.Invoke() - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Invoked main.sampleFunctionCb via Function handle\n") - - pfunc := result.AsFunction() - fmt.Printf("Function Handle received via Packed Function call:%T - %v \n", pfunc, pfunc) - - pfuncRet, err := pfunc.Invoke(30, 40) - fmt.Printf("Invoked closure inside sampleFunctionCb result:%v\n", pfuncRet.AsInt64()) -} diff --git a/golang/sample/pack_func_convert.go b/golang/sample/pack_func_convert.go deleted file mode 100644 index b6d1fbf24d46..000000000000 --- a/golang/sample/pack_func_convert.go +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate function conversion to packed function. - * \file pack_func_convert.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - -// sampleCb is a simple golang callback function like C = A + B. -func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return -} - -// main -func main() { - // Welcome - - // Simple convert to a packed function - fhandle, err := gotvm.ConvertFunction(sampleCb) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Converted function\n") - - retVal, err := fhandle.Invoke(10, 20) - fmt.Printf("Invoke Completed\n") - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Result:%v\n", retVal.AsInt64()) -} diff --git a/golang/sample/pack_func_handle_arg.go b/golang/sample/pack_func_handle_arg.go deleted file mode 100644 index d5a3f074946e..000000000000 --- a/golang/sample/pack_func_handle_arg.go +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate converted packed - * function handle passed to another packed function. - * \file pack_func_handle_arg.go - */ - -package main - -import ( - "fmt" - "./gotvm" -) - -// sampleCb is a simple golang callback function like C = A + B. -func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return -} - -// sampleFunctionArg receives a Packed Function handle and calls it. -func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - - // Call Packed Function - retVal, err = pfunc.Invoke(args[1], args[2]) - return -} - -// main -func main() { - // Simple convert to a packed function - fhandle, err := gotvm.ConvertFunction(sampleCb) - if err != nil { - fmt.Print(err) - return - } - - gotvm.RegisterFunction(sampleFunctionArg); - fmt.Printf("Registered: sampleFunctionArg\n") - - funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg") - if err != nil { - fmt.Print(err) - return - } - - retVal, err := funp.Invoke(fhandle, 10, 20) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("Result:%v\n", retVal.AsInt64()) -} diff --git a/golang/sample/pack_func_register.go b/golang/sample/pack_func_register.go deleted file mode 100644 index ac4ea438dbef..000000000000 --- a/golang/sample/pack_func_register.go +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application to demonstrate function register into TVM global functions. - * \file pack_func_register.go - */ - -package main - -import ( - "fmt" - "./gotvm" - "strings" -) - -// sampleCb is a simple golang callback function like C = A + B. -func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { - for _, v := range args { - fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) - } - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return -} - -// main -func main() { - // Register sampleCb with TVM packed function system and call and check Global Function List. - gotvm.RegisterFunction(sampleCb, "sampleCb"); - // Query global functions available - funcNames, err := gotvm.FuncListGlobalNames() - if err != nil { - fmt.Print(err) - return - } - - found := 0 - for ii := range (funcNames) { - if strings.Compare(funcNames[ii], "sampleCb") == 0 { - found = 1 - } - } - if found == 0 { - fmt.Printf("Function registerd but, not listed\n") - return - } - - - // Get "sampleCb" and verify the call. - funp, err := gotvm.GetGlobalFunction("sampleCb") - if err != nil { - fmt.Print(err) - return - } - - // Call function - result, err := funp.Invoke((int64)(10), (int64)(20)) - if err != nil { - fmt.Print(err) - return - } - fmt.Printf("sampleCb result: %v\n", result.AsInt64()) -} diff --git a/golang/sample/simple.go b/golang/sample/simple.go deleted file mode 100644 index 7bb503db4598..000000000000 --- a/golang/sample/simple.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Sample golang application deployment over tvm. - * \file simple.go - */ - -package main - -import ( - "fmt" - "runtime" - "./gotvm" - "math/rand" -) - -// NNVM compiled model paths. -const ( - modLib = "./deploy.so" -) - -// main -func main() { - // Welcome - defer runtime.GC() - fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion) - fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion) - - // Import tvm module (so) - modp, _ := gotvm.LoadModuleFromFile(modLib) - fmt.Printf("Module Imported\n") - - - // Allocate Array for inputs and outputs. - // Allocation by explicit type and device. - tshapeIn := []int64{4} - inX, _ := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) - - // Default allocation on CPU - inY, _ := gotvm.Empty(tshapeIn, "float32") - - // Default allocation to type "float32" and on CPU - out, _ := gotvm.Empty(tshapeIn) - fmt.Printf("Input and Output Arrays allocated\n") - - // Fill Input Data : inX , inY - inXSlice := make([]float32, 4) - inYSlice := make([]float32, 4) - for i := range inXSlice { - inXSlice[i] = rand.Float32() - inYSlice[i] = rand.Float32() - } - - - // Copy the data on target memory through runtime CopyFrom api. - inX.CopyFrom(inXSlice) - inY.CopyFrom(inYSlice) - fmt.Printf("X: %v\n", inXSlice) - fmt.Printf("Y: %v\n", inYSlice) - - // Get function "myadd" - funp, _ := modp.GetFunction("myadd") - - // Call function - funp.Invoke(inX, inY, out) - fmt.Printf("Module function myadd executed\n") - - // Get the output tensor as an interface holding a slice through runtime CopyTo api. - outSlice, _ := out.AsSlice() - - // Print results - fmt.Printf("Result:%v\n", outSlice.([]float32)) -} diff --git a/golang/src/array_test.go b/golang/src/array_test.go deleted file mode 100644 index a2636a8b0f20..000000000000 --- a/golang/src/array_test.go +++ /dev/null @@ -1,614 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file array_test.go - */ - - -package gotvm - -import ( - "testing" - "unsafe" - "math/rand" -) - -// Create an array and check size. -func TestArrayCreateSize(t *testing.T) { - _, err := Empty([]int64{4}) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 5, 6}) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{}) - if err == nil { - t.Error("Expected err for empty Array created, but didn't got !!") - return - } -} - -// Check array creation via various different arguments. -func TestArrayCreateArgs(t *testing.T) { - _, err := Empty([]int64{4, 2}, "float32", CPU(0)) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 2}, "float32") - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 2}, CPU(0)) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = Empty([]int64{4, 2}, CPU(0), "float32") - if err != nil { - t.Error(err.Error()) - return - } -} - -// Create an array and check the NDim. -func TestArrayNDim(t *testing.T) { - arr, err := Empty([]int64{4, 5, 6}) - if err != nil { - t.Error(err.Error()) - return - } - - if 3 != arr.GetNdim() { - t.Errorf("GetNdim failed Expected: 3 Got :%v\n", arr.GetNdim()) - return - } -} - -// Create an array and check Shape. -func TestArrayShape(t *testing.T) { - arr, err := Empty([]int64{4, 5, 6}) - if err != nil { - t.Error(err.Error()) - return - } - - shape := arr.GetShape() - if len(shape) != 3 { - t.Errorf("Shape slice expected: 3 Got :%v\n", len(shape)) - return - } - - if shape[0] != 4 || shape[1] != 5 || shape[2] != 6 { - t.Errorf("Shape values expected {4, 5, 6} Got : %v\n", shape); - return - } -} - -// Create an array and check created Device. -func TestArrayDevice(t *testing.T) { - // TODO: Could some test cases for other targets - arr, err := Empty([]int64{4}, CPU(0)) - if err != nil { - t.Error(err.Error()) - return - } - - dev := arr.GetDevice() - if dev.DeviceType != KDLCPU { - t.Errorf("Dev DeviceType expected: %v Got :%v\n", KDLCPU, dev.DeviceType) - return - } - if dev.DeviceID != 0 { - t.Errorf("Dev DeviceID expected: %v Got :%v\n", KDLCPU, dev.DeviceID) - return - } - - arr, err = Empty([]int64{4}, CPU(2)) - if err != nil { - t.Error(err.Error()) - return - } - - dev = arr.GetDevice() - if dev.DeviceType != KDLCPU { - t.Errorf("Dev DeviceType expected: %v Got :%v\n", KDLCPU, dev.DeviceType) - return - } - if dev.DeviceID != 2 { - t.Errorf("Dev DeviceID expected: %v Got :%v\n", KDLCPU, dev.DeviceID) - return - } -} - -// Create array of different dtypes and check dtypes. -func TestArrayDType(t *testing.T) { - for _, dtype := range []string{"int8", "int16", "int32", "int64", - "uint8", "uint16", "uint32", "uint64", - "float32", "float64"} { - arr, err := Empty([]int64{4}, dtype) - if err != nil { - t.Error(err.Error()) - return - } - - if dtype != arr.GetDType() { - t.Errorf("Dtype expected: %v Got :%v\n", dtype, arr.GetDType()) - return - } - } -} - -// Copy Int8 data to created Array and verify. -func TestArrayCopySliceInt8(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int8") - - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen) - rand.Read(bdata) - data := (*[1<<31]int8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []int8: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - - dataRet := ret.([]int8) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Int16 data to created Array and verify. -func TestArrayCopySliceInt16(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int16") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*2) - rand.Read(bdata) - data := (*[1<<31]int16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - switch ret.(type) { - case []int16: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - - dataRet := ret.([]int16) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Int32 data to created Array and verify. -func TestArrayCopySliceInt32(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int32") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*4) - rand.Read(bdata) - data := (*[1<<31]int32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []int32: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]int32) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Int64 data to created Array and verify. -func TestArrayCopySliceInt64(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "int64") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*8) - rand.Read(bdata) - data := (*[1<<31]int64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []int64: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]int64) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt8 data to created Array and verify. -func TestArrayCopySliceUInt8(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint8") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen) - rand.Read(bdata) - data := (*[1<<31]uint8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint8: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint8) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt16 data to created Array and verify. -func TestArrayCopySliceUInt16(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint16") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*2) - rand.Read(bdata) - data := (*[1<<31]uint16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint16: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint16) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt32 data to created Array and verify. -func TestArrayCopySliceUInt32(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint32") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*4) - rand.Read(bdata) - data := (*[1<<31]uint32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint32: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint32) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy UInt64 data to created Array and verify. -func TestArrayCopySliceUInt64(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "uint64") - if err != nil { - t.Error(err.Error()) - return - } - - bdata := make([]byte, dlen*8) - rand.Read(bdata) - data := (*[1<<31]uint64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []uint64: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]uint64) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} - -// Copy Float32 data to created Array and verify. -func TestArrayCopySliceFloat32(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "float32") - if err != nil { - t.Error(err.Error()) - return - } - - data := make([]float32, dlen) - - for i := range data { - data[i] = rand.Float32() - } - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []float32: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]float32) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v \nGot :%v \n", data, dataRet) - return - } - } -} - -// Copy Float64 data to created Array and verify. -func TestArrayCopySliceFloat64(t *testing.T) { - dlen := int64(32) - arr, err := Empty([]int64{4, dlen/4}, "float64") - if err != nil { - t.Error(err.Error()) - return - } - - data := make([]float64, dlen) - - for i := range data { - data[i] = rand.Float64() - } - - err = arr.CopyFrom(data) - if err != nil { - t.Error(err.Error()) - return - } - - ret, err := arr.AsSlice() - if err != nil { - t.Error(err.Error()) - return - } - - switch ret.(type) { - case []float64: - default: - t.Errorf("Expected : %T but got :%T\n", data, ret) - return - } - dataRet := ret.([]float64) - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v\n", data, dataRet) - return - } - } -} diff --git a/golang/src/bytearray.go b/golang/src/bytearray.go deleted file mode 100644 index 4dcecef4a9b7..000000000000 --- a/golang/src/bytearray.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMByteArray interface. - * \file bytearray.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" -) - -// ByteArray type wraps the TVMByteArray of C runtime API. -// -// This can be used to hold raw data like params of a model. -type ByteArray uintptr - -// nativeCPtr returns the type freed unitptr for ByteArray. -func (tbytearray ByteArray) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(tbytearray) - return -} - -// SetData is used to intialize ByteArray from a golang string object. -// -// This method initialize both data and data size of the underlaying object. -// This function handles freeing old data object if any before allocating new. -// -// `val` is the golang string object from which the ByteArray is initialized. -func (tbytearray ByteArray) setData(val string) { - bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data - if bufPtr == (*C.char)(C.NULL) { - C.free(unsafe.Pointer(bufPtr)) - } - - ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data = C.CString(val) - ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size = C.ulong(len(val)) -} - -// getData returns the golang byte slice corresponding to the ByteArray. -func (tbytearray ByteArray) getData() (retVal []byte) { - val := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data - blen := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size - retVal = C.GoBytes(unsafe.Pointer(val), C.int(blen)) - return -} - -// newByteArray initilizes the native TVMByteArray object with given byte slice -// -//`val` is the golang byte array used to initialize. -// -// returns newly created ByteArray. -func newByteArray(val []byte) (retVal ByteArray) { - handle := ByteArray(C.malloc(C.sizeof_TVMByteArray)) - ((*C.TVMByteArray)(unsafe.Pointer(handle))).data = (*C.char)(C.NULL) - ((*C.TVMByteArray)(unsafe.Pointer(handle))).size = 0 - handle.setData(string(val)) - retVal = handle - return -} - -// deleteTVMByteArray releases the allocated native object of ByteArray. -// -// This delete handles freeing of underlaying native data object too. -func (tbytearray ByteArray) deleteTVMByteArray() { - bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data - C.free(unsafe.Pointer(bufPtr)) - C.free(unsafe.Pointer(tbytearray.nativeCPtr())) -} diff --git a/golang/src/bytearray_test.go b/golang/src/bytearray_test.go deleted file mode 100644 index c4047c50a605..000000000000 --- a/golang/src/bytearray_test.go +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file bytearray_test.go - */ - - -package gotvm - -import ( - "testing" - "math/rand" -) - -// Check ByteArray creation from byte slice and verify the data. -func TestByteArrayGet(t *testing.T) { - data := make([]byte, 1024) - rand.Read(data) - - barr := newByteArray(data) - dataRet := barr.getData() - if len(data) != len(dataRet) { - t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) - return - } - for i := range data { - if data[i] != dataRet[i] { - t.Errorf("Data expected: %v Got :%v at : %v\n", data[i], dataRet[i], i) - return - } - } -} diff --git a/golang/src/device.go b/golang/src/device.go deleted file mode 100644 index 2918cf6a0f0f..000000000000 --- a/golang/src/device.go +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for Device interface - * \file device.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -// KDLCPU is golang enum correspond to TVM device type kDLCPU. -var KDLCPU = int32(C.kDLCPU) -// kDLCUDA is golang enum correspond to TVM device type kDLCUDA. -var kDLCUDA = int32(C.kDLCUDA) -// kDLCUDAHost is golang enum correspond to TVM device type kDLCUDAHost. -var kDLCUDAHost = int32(C.kDLCUDAHost) -// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL. -var KDLOpenCL = int32(C.kDLOpenCL) -// KDLMetal is golang enum correspond to TVM device type kDLMetal. -var KDLMetal = int32(C.kDLMetal) -// KDLVPI is golang enum correspond to TVM device type kDLVPI. -var KDLVPI = int32(C.kDLVPI) -// KDLROCM is golang enum correspond to TVM device type kDLROCM. -var KDLROCM = int32(C.kDLROCM) -// KDLVulkan is golang enum correspond to TVM device type kDLVulkan. -var KDLVulkan = int32(C.kDLVulkan) -// KExtDev is golang enum correspond to TVM device type kDLExtDev. -var KExtDev = int32(C.kDLExtDev) - -// Device dtype corresponding to Device aka DLDevice -type Device struct { - DeviceType int32 - DeviceID int32 -} - -// CPU returns the Device object for CPU target on given index -func CPU(index int32) Device { - return Device{KDLCPU, index} -} - -// CUDA returns the Device object for CUDA target on given index -func CUDA(index int32) Device { - return Device{kDLCUDA, index} -} - -// CUDAHost returns the Device object for CUDAHost target on given index -func CUDAHost(index int32) Device { - return Device{kDLCUDAHost, index} -} - -// OpenCL returns the Device object for OpenCL target on given index -func OpenCL(index int32) Device { - return Device{KDLOpenCL, index} -} - -// Metal returns the Device object for Metal target on given index -func Metal(index int32) Device { - return Device{KDLMetal, index} -} - -// VPI returns the Device object for VPI target on given index -func VPI(index int32) Device { - return Device{KDLVPI, index} -} - -// ROCM returns the Device object for ROCM target on given index -func ROCM(index int32) Device { - return Device{KDLROCM, index} -} - -// Vulkan returns the Device object for Vulkan target on given index -func Vulkan(index int32) Device { - return Device{KDLVulkan, index} -} diff --git a/golang/src/error.go b/golang/src/error.go deleted file mode 100644 index edd8116a3612..000000000000 --- a/golang/src/error.go +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for error related API interface. - * \file error.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" -) - -// getTVMLastError returns the detailed error string for any api called in TVM runtime. -// -// This is useful when any api returns non zero value. -// -// Returns golang string for the corresponding native error message. -func getTVMLastError() (retVal string) { - errStr := C.TVMGetLastError() - retVal = C.GoString(errStr) - return -} - -func setTVMLastError(errStr string) { - cstr := C.CString(errStr) - C.TVMAPISetLastError(cstr) - C.free(unsafe.Pointer(cstr)) -} diff --git a/golang/src/error_test.go b/golang/src/error_test.go deleted file mode 100644 index 3fe912db110e..000000000000 --- a/golang/src/error_test.go +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file error_test.go - */ - - -package gotvm - -import ( - "testing" - "strings" -) - -// Check err receiving from TVM global function. -func TestErrorTest(t *testing.T) { - _, err := LoadModuleFromFile("dummy.so") - if err == nil { - t.Error("Expected an error, but not received\n") - return - } - - errStr := err.Error() - if !(strings.Contains(errStr, string("cannot open shared object"))) { - t.Error("Ah! TVM didn't report an error\n") - } -} diff --git a/golang/src/function.go b/golang/src/function.go deleted file mode 100644 index 7b1c5d27d429..000000000000 --- a/golang/src/function.go +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMFunction interface. - * \file function.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" - "encoding/binary" - "errors" - "runtime" - "reflect" - "fmt" -) - -// Function type in golang hold pointer for the TVMFunction handle. -type Function uintptr - -// nativeCPtr returns type freed uintptr for the Function. -func (tvmfunction Function) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(tvmfunction) - return -} - -// Invoke calls the TVM packed function referred by the handle with given arguments. -func (tvmfunction *Function) Invoke(args ...interface{}) (retVal *Value, err error) { - funccall := func (fargs ...interface{}) (*Value, error) { - return callNativeFunction(tvmfunction, fargs) - } - // Check is any args are contain any ValueArray - // Possible is it's a args forward from one packed function to another. - valueArrayFound := false - for ii := range args { - switch args[ii].(type) { - case []*Value: - valueArrayFound = true - } - } - - if !valueArrayFound { - return funccall(args...) - } - if len(args) != 1 { - err = fmt.Errorf("Not supported if packed function args are a mix of []Value and other types") - return - } - - valArray := args[0].([]*Value) - if len(valArray) > 0 { - newArgs := make([]interface{}, len(valArray)) - for ii := range valArray { - newVal := newTVMValue() - newVal.moveFrom(valArray[ii]) - newArgs[ii] = newVal - } - - return funccall(newArgs...) - } - return funccall() -} - -// FuncListGlobalNames is used to query global callable packed function names from TVM. -// -// returns slice of string holding function names and error if any. -func FuncListGlobalNames() (retVal []string, err error) { - var str string - ret := (int32)(C._TVMFuncListGlobalNames(unsafe.Pointer((&str)))) - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - str = goStringFromNative(*(*string)(unsafe.Pointer(&str))) - bin := binary.LittleEndian - size := bin.Uint64([]byte(str[:8])) - str = str[8:] - retVal = make([]string, size) - for i := range retVal { - len := bin.Uint64([]byte(str[:8])) - str = str[8:] - retVal[i] = str[:len] - str = str[len:] - } - return -} - -// GetGlobalFunction is to get handle to the given global function name. -// -// `funcname` is the name of global packed function. -// -// returns a function closure with signature -// func (args ...interface{}) (interface{}, error) and error if any. -// -// The closure function can be used to call Function with arguments directly. -// -// Variadic arguments can be any type which can be embed into Value. -func GetGlobalFunction(funcname string) (retVal *Function, err error) { - var funp uintptr - - cfuncname := C.CString(funcname) - ret := (int32)(C.TVMFuncGetGlobal(cfuncname, - (*C.TVMFunctionHandle)(unsafe.Pointer(&funp)))) - C.free(unsafe.Pointer(cfuncname)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - handle := new(Function) - *handle = Function(funp) - finalizer := func(fhandle *Function) { - nativeTVMFuncFree(fhandle) - fhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// callNativeFunction is routine which calls gotvm native wrapper with given arguments. -// -// `handle` is the handle for Function. -// -// `args` are the variadic arguments to the Function. -// -// returns the interface for the return value from TVM if any and error if any. -func callNativeFunction(handle *Function, args []interface{}) (retVal *Value, err error) { - argsIn := make([]*Value, len(args)) - var typeCodes []int32 - if len(args) != 0 { - typeCodes = make([]int32, len(args)) - } else { - typeCodes = make([]int32, 1) - } - - for ii := range args { - argsIn[ii] = newTVMValue() - if typeCodes[ii], err = argsIn[ii].setValue(args[ii]); err != nil { - return - } - } - - retVal = newTVMValue() - argsOut := []*Value{retVal} - retTypeCode := KNull - err = nativeTVMFuncCall(handle, argsIn, typeCodes, argsOut, &retTypeCode) - if err != nil { - retVal = nil - return - } - retVal.isLocal = false - retVal.dtype = retTypeCode - return -} - -// nativeTVMFuncFree free the function handle allocated in TVM runtime. -// -// `funp` is the Function handle to be freed. -func nativeTVMFuncFree(funp *Function) (retVal int32) { - retVal = (int32) (C.TVMFuncFree(C.TVMFunctionHandle(funp.nativeCPtr()))) - return -} - -// nativeToGoSlice converts native TVMValue array to Golang slice of TVMValue -// -// -func nativeToGoSlice(nargValues (*C.void), argValues []*Value, typeCodes []int32) { - for ii := range argValues { - C._TVMValueNativeGet(unsafe.Pointer(argValues[ii].nativeCPtr()), - unsafe.Pointer(nargValues), - C.int(int32(ii))) - argValues[ii].dtype = typeCodes[ii] - } -} - -// nativeFromGoSlice converts golang slice of TVMValue to native TVMValue array. -// -// -func nativeFromGoSlice(argValues []*Value) (nptr (*C.void)) { - nargValues := ((uintptr)(C.malloc(C.ulong(C.sizeof_TVMValue * len(argValues))))) - for ii := range argValues { - C._TVMValueNativeSet(unsafe.Pointer(nargValues), - unsafe.Pointer(argValues[ii].nativeCPtr()), - C.int(int32(ii))) - } - nptr = (*C.void)(unsafe.Pointer(nargValues)) - return -} - -// nativeTVMFuncCall executes the function with given arguments -// -// `funp` Function handle to the packed function. -// -// `argValues` is the slice of Value which are arguments to the packed function. -// -// `typeCodes` is the alice of argument type codes corresponding to argValues. -// -// `retValues` is return argument which is slice of return values from the packed function. -// -// `retTypeCode` is int32 holding type codes for retValue -// -// Returns err indicating native error if any. -func nativeTVMFuncCall(funp *Function, argValues []*Value, typeCodes []int32, - retValues []*Value, retTypeCode *int32) (err error) { - nargValues := nativeFromGoSlice(argValues) - nretValues := nativeFromGoSlice(retValues) - result := (int32)(C.TVMFuncCall(C.TVMFunctionHandle(*funp), - (*C.TVMValue)(unsafe.Pointer(nargValues)), - (*C.int)(unsafe.Pointer(&(typeCodes[0]))), - C.int(len(argValues)), - (*C.TVMValue)(unsafe.Pointer(nretValues)), - (*C.int)(unsafe.Pointer(retTypeCode)))) - nativeToGoSlice(nargValues, argValues, typeCodes) - nativeToGoSlice(nretValues, retValues, (*[1<<31] int32)(unsafe.Pointer(retTypeCode))[:1:1]) - C.free(unsafe.Pointer(nargValues)) - C.free(unsafe.Pointer(nretValues)) - - if result != 0 { - err = errors.New(getTVMLastError()) - } - return -} - -// goCallBack is a structure holding the go callback function pointer. -// This wrapping is necessary as cgo doesn't support -// passing golang functions type conversion to native. -type goCallBack struct { - cb func (args ...*Value) (interface{}, error) -} - -//export goTVMCallback -func goTVMCallback(args C.native_voidp, typeCodes C.native_voidp, numArgs int32, - retArg C.native_voidp, resourceHandle C.native_voidp) (ret int32){ - fcb := (*goCallBack)(resourceHandle) - // Make Value Sice from native TVMValue pointer. - argValues := make([]*Value, numArgs) - - for ii := range argValues { - argValues[ii] = newTVMValue() - argValues[ii].isLocal = false - } - - // Prepare arguments for golang callback function - nativeToGoSlice((*C.void)(unsafe.Pointer(args)), argValues, - (*[1<<31] int32)(unsafe.Pointer(typeCodes))[:numArgs:numArgs]) - cbargs := argValues - - // Execute the callback - retVal, err := fcb.cb(cbargs...) - if err != nil { - errStr := err.Error() - setTVMLastError(errStr) - return -1 - } - - // It's possible a packed function directly return - // the return value of another packed function. - // - // Inside a packed func : - // ```return pfunc.Invoke(args)``` - // - // In this case pfunc returns nil which is - // returned as an interface holding nil *Value. - // Which becomes a valid retVal holding nil *Value. - isRetNull := false - switch retVal.(type) { - case *Value: - pRet := retVal.(*Value) - if pRet == nil { - isRetNull = true - } - } - - // Handle return value from callback function - if retVal != nil && !isRetNull { - var retTypeCode int32 - retValues := []*Value{newTVMValue()} - - retTypeCode, err = retValues[0].setValue(retVal) - if err != nil { - errStr := err.Error() - setTVMLastError(errStr) - return -1 - } - nretValues := nativeFromGoSlice(retValues) - - // Handle KStr, KBytes: Local finalizers shouldn't try freeing them. - retValues[0].isLocal = false - - apiRet := (int32) (C.TVMCFuncSetReturn(C.TVMRetValueHandle(retArg), - (*C.TVMValue)(unsafe.Pointer(nretValues)), - (*C.int)(unsafe.Pointer(&retTypeCode)), 1)) - C.free(unsafe.Pointer(nretValues)) - if apiRet != 0 { - errStr := string("TVMCFuncSetReturn failed ") - setTVMLastError(errStr) - } - } - return -} - -// ConvertFunction converts given golang function to TVM packed function. -// -// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})``` -// -// Returns Function handle and err if any. -func ConvertFunction(args ...interface{}) (retVal *Function, err error) { - function := args[0].(func (args ...*Value) (interface{}, error)) - fcb := &goCallBack{cb:function} - var funp uintptr - - result := (int32) (C._ConvertFunction(unsafe.Pointer(fcb), - unsafe.Pointer(&funp))) - if result != 0 { - err = errors.New(getTVMLastError()) - } - - handle := new(Function) - *handle = Function(funp) - finalizer := func(fhandle *Function) { - nativeTVMFuncFree(fhandle) - fhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// RegisterFunction registers the golang func in TVM runtime global space. -// -// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})``` -// -// `args[1]` Optional argument of function name with which it will be registered. -// If not passed we use function name from reflection. -// -// Returns err indicating native error if any. -func RegisterFunction(args ...interface{}) (err error) { - fhandle, err := ConvertFunction(args...) - if err != nil { - return - } - - funcname := runtime.FuncForPC(reflect.ValueOf(args[0]).Pointer()).Name() - if len(args) > 1 { - funcname = args[1].(string) - } - - cfuncname := C.CString(funcname) - result := (int32) (C.TVMFuncRegisterGlobal(cfuncname, - C.TVMFunctionHandle(*fhandle), - 0)); // Override = False - C.free(unsafe.Pointer(cfuncname)) - if result != 0 { - err = errors.New(getTVMLastError()) - } - // Clear the finalizer as we don't need to control it anymore. - runtime.SetFinalizer(fhandle, nil) - return -} diff --git a/golang/src/function_test.go b/golang/src/function_test.go deleted file mode 100644 index 0830d16419a2..000000000000 --- a/golang/src/function_test.go +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file function_test.go - */ - -package gotvm - -import ( - "testing" - "reflect" - "math/rand" - "strings" - "fmt" -) - -// Check global function list API -func TestFunctionGlobals(t *testing.T) { - funcNames, err := FuncListGlobalNames() - if err != nil { - t.Error(err.Error()) - return - } - if len(funcNames) < 1 { - t.Errorf("Global Function names received:%v\n", funcNames) - } -} - -// Check GetFunction API -func TestFunctionGlobalGet(t *testing.T) { - funp, err := GetGlobalFunction("tvm.graph_executor.create") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(funp).Kind() != reflect.Ptr { - t.Error("Function type mis matched\n") - return - } -} - -func TestFunctionModuleGet(t *testing.T) { - modp, err := LoadModuleFromFile("./deploy.so") - if err != nil { - t.Error(err.Error()) - return - } - funp, err := modp.GetFunction("myadd") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(funp).Kind() != reflect.Ptr { - t.Error("Function type mis matched\n") - return - } - - dlen := int64(1024) - shape := []int64{dlen} - inX, _ := Empty(shape) - inY, _ := Empty(shape) - out, _ := Empty(shape) - dataX := make([]float32, (dlen)) - dataY := make([]float32, (dlen)) - outExpected := make([]float32, (dlen)) - - for i := range dataX { - dataX[i] = rand.Float32() - dataY[i] = rand.Float32() - outExpected[i] = dataX[i] + dataY[i] - } - - inX.CopyFrom(dataX) - inY.CopyFrom(dataY) - - funp.Invoke(inX, inY, out) - outi, _ := out.AsSlice() - outSlice := outi.([]float32) - if len(outSlice) != len(outExpected) { - t.Errorf("Data expected Len: %v Got :%v\n", len(outExpected), len(outSlice)) - return - } - for i := range outSlice { - if outExpected[i] != outSlice[i] { - t.Errorf("Data expected: %v Got :%v at index %v\n", outExpected[i], outSlice[i], i) - return - } - } -} - -// Check FunctionConvert API -func TestFunctionConvert(t *testing.T) { - sampleCb := func (args ...*Value) (retVal interface{}, err error) { - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - fhandle, err := ConvertFunction(sampleCb) - if err != nil { - t.Error(err.Error()) - return - } - - retVal, err := fhandle.Invoke(10, 20) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsInt64() != int64(30) { - t.Errorf("Expected result :30 got:%v\n", retVal.AsInt64()) - return - } -} - -func TestFunctionError(t *testing.T) { - sampleCb := func (args ...*Value) (retVal interface{}, err error) { - err = fmt.Errorf("Sample Error XYZABC"); - return - } - - fhandle, err := ConvertFunction(sampleCb) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = fhandle.Invoke() - if err == nil { - t.Error("Expected error but didn't received\n") - return - } - - if !strings.Contains(err.Error(), string("Sample Error XYZABC")) { - t.Errorf("Expected Error should contain :\"Sample Error XYZABC\" got :%v\n", err.Error()) - } -} - -// Check FunctionRegister -func TestFunctionRegister(t *testing.T) { - sampleCb := func (args ...*Value) (retVal interface{}, err error) { - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - RegisterFunction(sampleCb, "TestFunctionRegister.sampleCb"); - // Query global functions available - funcNames, err := FuncListGlobalNames() - if err != nil { - t.Error(err.Error()) - return - } - - found := 0 - for ii := range (funcNames) { - if strings.Compare(funcNames[ii], "TestFunctionRegister.sampleCb") == 0 { - found = 1 - } - } - if found == 0 { - t.Error("Registered function not found in global function list.") - return - } - - // Get "sampleCb" and verify the call. - funp, err := GetGlobalFunction("TestFunctionRegister.sampleCb") - if err != nil { - t.Error(err.Error()) - return - } - - // Call function - result, err := funp.Invoke((int64)(10), (int64)(20)) - if err != nil { - t.Error(err.Error()) - return - } - if result.AsInt64() != int64(30) { - t.Errorf("Expected result :30 got:%v\n", result.AsInt64()) - return - } -} - -// Check packed function receiving go-closure as argument. -func TestFunctionClosureArg(t *testing.T) { - // sampleFunctionArg receives a Packed Function handle and calls it. - sampleFunctionArg := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - - // Call Packed Function by Value - ret, err := pfunc.Invoke(args[1], args[2]) - if err != nil { - return - } - - // Call Packed Function with extracted values - ret1, err := pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) - if err != nil { - return - } - if ret1.AsInt64() != ret.AsInt64() { - err = fmt.Errorf("Invoke with int64 didn't match with Value") - return - } - retVal = ret - return - } - - RegisterFunction(sampleFunctionArg, "TestFunctionClosureArg.sampleFunctionArg"); - funp, err := GetGlobalFunction("TestFunctionClosureArg.sampleFunctionArg") - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - val1 := args[0].AsInt64() - val2 := args[1].AsInt64() - retVal = int64(val1+val2) - return - } - - // Call function - result, err := funp.Invoke(funccall, 30, 50) - if err != nil { - t.Error(err.Error()) - return - } - - if result.AsInt64() != int64(80) { - t.Errorf("Expected result :80 got:%v\n", result.AsInt64()) - return - } -} - -// Check packed function returning a go-closure. -func TestFunctionClosureReturn(t *testing.T) { - // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. - sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { - funccall := func (cargs ...*Value) (fret interface{}, ferr error) { - val1 := cargs[0].AsInt64() - val2 := cargs[1].AsInt64() - fret = int64(val1+val2) - return - } - retVal = funccall - return - } - - RegisterFunction(sampleFunctionCb, "TestFunctionClosureReturn.sampleFunctionCb"); - funp, err := GetGlobalFunction("TestFunctionClosureReturn.sampleFunctionCb") - if err != nil { - t.Error(err.Error()) - return - } - - // Call function - result, err := funp.Invoke() - if err != nil { - t.Error(err.Error()) - return - } - - pfunc := result.AsFunction() - pfuncRet, err := pfunc.Invoke(30, 40) - if err != nil { - t.Error(err.Error()) - return - } - if pfuncRet.AsInt64() != int64(70) { - t.Errorf("Expected result :70 got:%v\n", pfuncRet.AsInt64()) - return - } -} - -// Check packed function with no arguments and no return values. -func TestFunctionNoArgsReturns(t *testing.T) { - sampleFunction := func (args ...*Value) (retVal interface{}, err error) { - return - } - - fhandle, err := ConvertFunction(sampleFunction) - if err != nil { - t.Error(err.Error()) - return - } - - _, err = fhandle.Invoke() - if err != nil { - t.Error(err.Error()) - return - } -} - -// Check packed function returning a go-closure with no arg and returns. -func TestFunctionNoArgsReturns2(t *testing.T) { - // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. - sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { - funccall := func (cargs ...*Value) (fret interface{}, ferr error) { - return - } - retVal = funccall - return - } - - funp, err := ConvertFunction(sampleFunctionCb) - if err != nil { - t.Error(err.Error()) - return - } - - // Call function - result, err := funp.Invoke() - if err != nil { - t.Error(err.Error()) - return - } - - pfunc := result.AsFunction() - _, err = pfunc.Invoke() - if err != nil { - t.Error(err.Error()) - return - } -} diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc deleted file mode 100644 index d8919dafbfcb..000000000000 --- a/golang/src/gotvm.cc +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm native interface definition - * \file gotvm.cxx - */ - -// Standard includes -#include -#include -#include -#include -#include -#include - -// golang string compatible definition -typedef struct { - char* p; - int n; -} _gostring_; -#include - -#ifdef __cplusplus -extern "C" { -#endif - -// TVM runtime C interface -#include -#include - -/*! - * \brief Convert native char array to _gostring_ structure. - * _gostring_ structure represents the same memory footprint as golang string object. - * - * \param p is char pointer to a char array. - * \param l is the size of the char array. this method exclusively need length as - * its possible to have a bytearray in a string. - * - * \return _gostring_ object corresponding to native char array. - * Caller is responsible to free the memory block allocated here. - */ -static _gostring_ _native_to_gostring(const char* p, size_t l) { - _gostring_ ret; - ret.p = reinterpret_cast(malloc(l)); - if (NULL == ret.p) { - ret.n = 0; - return ret; - } - memcpy(ret.p, p, l); - ret.n = l; - return ret; -} - -/*! - * \brief embeds a 64bit uint value inside a string to serialize the data. - * - * \param s is string object. - * \param off is the offset in the string object. - * \param v is the uint64_t value which need to embed into given string. - */ -static void putuint64(std::string* s, size_t off, uint64_t v) { - for (int i = 0; i < 8; i++) { - (*s)[off + i] = (v >> (i * 8)) & 0xff; - } -} - -// TVM runtime C interface wrappers - -/*! - * \brief Native interface to query TVM_VERSION in golang string format. - * - * \return char pointer to TVM-VERSION - */ -const char* _TVM_VERSION(void) { - const char* version = TVM_VERSION; - return version; -} - -/*! - * \brief Native interface for getting TVMGlobal function list. - * - * \param names return by argument to return the function names. - * We wrap all strings into single string joined by (len+string) - * which is unpacked and processed in golang. - * - * \return c_runtime_api return status. - */ -int _TVMFuncListGlobalNames(_gostring_* names) { - int names_size; - char** names_array; - int result; - - result = TVMFuncListGlobalNames(&names_size, (char const***)&names_array); - if (result) { - return result; - } - - size_t tot = 8; - for (int ii = 0; ii < names_size; ++ii) { - tot += 8 + strlen(names_array[ii]); - } - - std::string str; - str.resize(tot); - putuint64(&str, 0, names_size); - size_t off = 8; - for (int64_t ii = 0; ii < names_size; ++ii) { - putuint64(&str, off, strlen(names_array[ii])); - off += 8; - str.replace(off, strlen(names_array[ii]), names_array[ii]); - off += strlen(names_array[ii]); - } - *names = _native_to_gostring(str.data(), str.size()); - if (str.size() != names->n) { - TVMAPISetLastError("malloc failed during _native_to_gostring"); - result = 1; - } - return result; -} - -// Helpers for TVMValue - -/*! - * \brief Native helper to copy TVMValue from golang slice to native array. - * this helper is need as underlying memory for golang slice is not continuous. - * - * \param to_ptr is the native pointer of TVMValue array. - * \param from_ptr pointer to TVMValue in golang slice. - * \param array index in native array. - */ -void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { - TVMValue* from_p = reinterpret_cast(from_ptr); - TVMValue* to_p = reinterpret_cast(to_ptr); - memcpy(to_p + ind, from_p, sizeof(TVMValue)); -} - -/*! - * \brief Native helper to copy TVMValue from golang slice to native array. - * this helper is need as underlying memory for golang slice is not continuous. - * - * \param to_ptr pointer to TVMValue in golang slice. - * \param from_ptr is the native pointer of TVMValue array. - * \param array index in native array. - */ -void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { - TVMValue* from_p = reinterpret_cast(from_ptr); - TVMValue* to_p = reinterpret_cast(to_ptr); - memcpy(to_p, from_p + ind, sizeof(TVMValue)); -} - -extern int goTVMCallback(void*, void*, int, void*, void*); - -/*! - * \brief _TVMCallback is the TVM runtime callback function for ffi::Functiontion system. - * - * \param args is an array of TVMValue - * \param type_codes is an array of int - * \param num_args is int representing number of in arguments - * \param ret is the return value handle to set the packed function return. - * \param resource_handle is the golang private data pointer. - * - * \returns the error status as TVM_DLL - */ -int _TVMCallback(TVMValue* args, int* type_codes, int num_args, TVMRetValueHandle ret, - void* resource_handle) { - return goTVMCallback(args, type_codes, num_args, ret, resource_handle); -} - -/*! - * _TVMPackedCFuncFinalizer is finalizer for packed function system. - * - */ -void _TVMPackedCFuncFinalizer(void* resource_handle) { return; } - -/*! - * /brief _ConvertFunction creates a packed function for with given resource handle. - * - * /param fptr is the pointer to golang resource handle. - * /param *fhandle is the return argument holding packed function. - * - * /return is an int indicating the return status. - */ -int _ConvertFunction(void* fptr, TVMFunctionHandle* fhandle) { - int ret = TVMFuncCreateFromCFunc(_TVMCallback, fptr, _TVMPackedCFuncFinalizer, fhandle); - return ret; -} - -#ifdef __cplusplus -} -#endif diff --git a/golang/src/gotvm.go b/golang/src/gotvm.go deleted file mode 100644 index 072d9cce4619..000000000000 --- a/golang/src/gotvm.go +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file gotvm.go - */ - - -// Package gotvm is TVM runtime interface definition for golang. -// -// Application need to import this package to access the c_runtime_api exposed by TVM. -package gotvm - -//#include "gotvm.h" -import "C" - -// DLPackVersion is the dlpack version of tvm runtime. -var DLPackVersion = int(C.DLPACK_VERSION) -// TVMVersion is the TVM runtime version. -var TVMVersion = getTVMVersion() - -func getTVMVersion() (retStr string) { - retStr = C.GoString(C._TVM_VERSION()) - return -} diff --git a/golang/src/gotvm.h b/golang/src/gotvm.h deleted file mode 100644 index a053e39bd79a..000000000000 --- a/golang/src/gotvm.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm native interface declaration. - * \file gotvm.h - * - * These declarations are in cgo interface definition while calling API - * across golang and native C boundaries. - */ - -#ifndef GOTVM_GOTVM_H_ -#define GOTVM_GOTVM_H_ - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include -#include -#include - -// Some type definitions for golang "C" -typedef void* native_voidp; - -// Version -extern char* _TVM_VERSION(void); - -// Wrappers : For incompatible cgo API. -// To handle array of strings wrapped into __gostring__ -extern int _TVMFuncListGlobalNames(void*); -// To handle TVMValue slice to/from native sequential TVMValue array. -extern void _TVMValueNativeSet(void* to, void* from, int index); -extern void _TVMValueNativeGet(void* to, void* from, int index); - -// Callbacks -extern int _ConvertFunction(void* fptr, void* funp); - -#ifdef __cplusplus -} -#endif -#endif // GOTVM_GOTVM_H_ diff --git a/golang/src/gotvm_test.go b/golang/src/gotvm_test.go deleted file mode 100644 index 271b1899897b..000000000000 --- a/golang/src/gotvm_test.go +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file gotvm_test.go - */ - - -package gotvm - -import ( - "testing" - "reflect" -) - -// Check TVMVersion API -func TestTVMVersion(t *testing.T) { - if len(TVMVersion) == 0 { - t.Error("TVMVersion not set\n") - } - if reflect.TypeOf(TVMVersion).Kind() != reflect.String { - t.Error("TVMVersion type mismatch\n") - } -} - -// Check DLPackVersion API -func TestDLPackVersion(t *testing.T) { - if reflect.TypeOf(DLPackVersion).Kind() != reflect.Int { - t.Error("TVMVersion type mismatch\n") - } -} diff --git a/golang/src/module.go b/golang/src/module.go deleted file mode 100644 index 8ac09e369cae..000000000000 --- a/golang/src/module.go +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMModule interface. - * \file module.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "errors" - "runtime" - "unsafe" -) - -// Module type in golang hold pointer for the TVMModule handle. -// -// Module initialization happen through TVMModLoadFromFile api in TVM runtime. -type Module uintptr - -// nativeCPtr returns type freed uintptr for the Module. -func (tvmmodule *Module) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(*tvmmodule) - return -} - -// LoadModuleFromFile loads the given module in TVM runtime. -// -// `modpath` is the path to tvm module. -// -// `args` is an optional arguments of ["dll", "dylib", "dso", "so"] with default value "so" -// -// returns pointer to Module and err or if any. -func LoadModuleFromFile(modpath string, args ...interface{}) (retVal *Module, err error) { - modtype := "so" - if len(args) > 0 { - modtype = args[0].(string) - } - var modp uintptr - - cmodpath := C.CString(modpath) - cmodtype := C.CString(modtype) - - ret := (int32)(C.TVMModLoadFromFile(cmodpath, - cmodtype, - (*C.TVMModuleHandle)(unsafe.Pointer(&modp)))) - - C.free(unsafe.Pointer(cmodpath)) - C.free(unsafe.Pointer(cmodtype)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - handle := new(Module) - *handle = Module(modp) - finalizer := func(mhandle *Module) { - nativeTVMModFree(mhandle) - mhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// nativeTVMModFree free the module handle allocated in TVM runtime. -// -// `modp` is the Module handle to be freed. -func nativeTVMModFree(modp *Module) (retVal int32) { - retVal = (int32) (C.TVMModFree(C.TVMModuleHandle(modp.nativeCPtr()))) - return -} - -// GetFunction returns the function pointer from the module for given function name. -// -// `tvmmodule` is handle for Module -// -// `funcname` function name in module. -// -// `args` variadic args of `queryImport` -// -// returns function closure with signature -// func (args ...interface{}) (interface{}, error) and error if any. -// -// The closure function can be used to call Function with arguments directly. -// -// Variadic arguments can be any type which can be embed into Value. -func (tvmmodule *Module) GetFunction ( - funcname string, args ...interface{}) ( - retVal *Function, err error){ - queryImports := int32(1) - if len(args) > 0 { - queryImports = int32(args[1].(int)) - } - - var funp uintptr - cfuncname := C.CString(funcname) - ret := (int32)(C.TVMModGetFunction((C.TVMModuleHandle)(*tvmmodule), - cfuncname, - C.int(queryImports), - (*C.TVMFunctionHandle)(unsafe.Pointer(&funp)))) - C.free(unsafe.Pointer(cfuncname)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - - handle := new(Function) - *handle = Function(funp) - finalizer := func(fhandle *Function) { - nativeTVMFuncFree(fhandle) - fhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} diff --git a/golang/src/module_test.go b/golang/src/module_test.go deleted file mode 100644 index 7e18a86c5b3a..000000000000 --- a/golang/src/module_test.go +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file module_test.go - */ - - -package gotvm - -import ( - "testing" - "reflect" -) - -// Check module loading - dll -func TestModuleTestLoad1(t *testing.T) { - // dll - mod, err := LoadModuleFromFile("./deploy.so", "dll") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading - dylib -func TestModuleTestLoad2(t *testing.T) { - // dylib - mod, err := LoadModuleFromFile("./deploy.so", "dylib") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -func TestModuleTestLoad3(t *testing.T) { - // dso - mod, err := LoadModuleFromFile("./deploy.so", "dso") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading - so -func TestModuleTestLoad4(t *testing.T) { - // so - mod, err := LoadModuleFromFile("./deploy.so", "so") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading - default (so) -func TestModuleTestLoad5(t *testing.T) { - // default type as so - mod, err := LoadModuleFromFile("./deploy.so") - if err != nil { - t.Error(err.Error()) - return - } - if reflect.TypeOf(mod).Kind() != reflect.Ptr { - t.Error("Module type mis matched\n") - return - } -} - -// Check module loading err -func TestModuleTestLoadErr(t *testing.T) { - // Unknown file should return error - _, err := LoadModuleFromFile("xyzabc.so") - if err == nil { - t.Error("Expected an error, but not received\n") - return - } -} diff --git a/golang/src/ndarray.go b/golang/src/ndarray.go deleted file mode 100644 index b1e71aef56bd..000000000000 --- a/golang/src/ndarray.go +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMArray aka DLTensor - * \file ndarray.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" - "fmt" - "errors" - "runtime" - "reflect" -) - -// Array type in golang hold pointer for the TVMArray object from dlpack. -// -// Array initialization happen through Empty api -type Array uintptr - -// nativeCPtr returns type freed uintptr for the Array. -func (parray Array) nativeCPtr() (retVal uintptr) { - retVal = (uintptr)(parray) - return -} - -func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) { - ret := C.TVMArrayCopyFromBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())), - data, - C.ulong(datalen)) - if ret != 0 { - err = errors.New(getTVMLastError()) - } - return -} - -// CopyFrom copies given golang data slice into Array. -// -// `val` is interface homding a slice of Array data type. -// -// returns err is any. -// TOD: Use reflections for better handling -func (parray Array) CopyFrom(val interface{}) (err error) { - var data unsafe.Pointer - var datalen int - dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype - - switch val.(type) { - case []int8: - sliceVal := val.([]int8) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []int16: - sliceVal := val.([]int16) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []int32: - sliceVal := val.([]int32) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []int64: - sliceVal := val.([]int64) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint8: - sliceVal := val.([]uint8) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint16: - sliceVal := val.([]uint16) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint32: - sliceVal := val.([]uint32) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []uint64: - sliceVal := val.([]uint64) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []float32: - sliceVal := val.([]float32) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - case []float64: - sliceVal := val.([]float64) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - return parray.nativeCopyFrom(data, datalen) - default: - err = fmt.Errorf("Given type not supported : %v", reflect.TypeOf(val)) - return - } - return -} - -func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){ - ret := C.TVMArrayCopyToBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())), - unsafe.Pointer(data), - C.ulong(datalen)) - - if ret != 0 { - err = errors.New(getTVMLastError()) - } - return -} - -// AsSlice returns the unitptr of for the data inside Array. -// -// returns the slice of array inside Array and err of any. -// TOD: Use reflections for better handling -func (parray Array) AsSlice() (retVal interface{}, err error) { - shape := parray.GetShape() - size := int64(1) - var data unsafe.Pointer - var datalen int - - for ii := range shape { - size *= shape[ii] - } - dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype - - switch parray.GetDType() { - case "int8": - sliceVal := make([]int8, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "int16": - sliceVal := make([]int16, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "int32": - sliceVal := make([]int32, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "int64": - sliceVal := make([]int64, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint8": - sliceVal := make([]uint8, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint16": - sliceVal := make([]uint16, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint32": - sliceVal := make([]uint32, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "uint64": - sliceVal := make([]uint64, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "float32": - sliceVal := make([]float32, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - case "float64": - sliceVal := make([]float64, size) - data = unsafe.Pointer(&sliceVal[0]) - datalen = len(sliceVal) * int(dtype.bits / 8) - err = parray.nativeCopyTo(data, datalen) - retVal = sliceVal - default: - err = fmt.Errorf("Given type not supported : %v", parray.GetDType()) - return - } - return -} - -// GetNdim returns the number of dimentions in Array -func (parray Array) GetNdim() (retVal int32) { - retVal = int32(((*C.DLTensor)(unsafe.Pointer(parray))).ndim) - return -} - -// GetShape returns the number of dimentions in Array -func (parray Array) GetShape() (retVal []int64) { - shapePtr := (*C.int64_t)(((*C.DLTensor)(unsafe.Pointer(parray))).shape) - ndim := parray.GetNdim() - - shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim] - retVal = make([]int64, ndim) - copy(retVal, shapeSlice) - return -} - -// GetDType returns the number of dimentions in Array -func (parray Array) GetDType() (retVal string) { - ret := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype - retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret))) - return -} - -// GetDevice returns the number of dimentions in Array -func (parray Array) GetDevice() (retVal Device) { - ret := ((*C.DLTensor)(unsafe.Pointer(parray))).device - retVal = *(*Device)(unsafe.Pointer(&ret)) - return -} - -// nativeTVMArrayAlloc is used to allocate TVMArray from given attributes. -// -// `shape` is int64 slice holding shape of the Array to be created. -// -// `ndim` is the rank of the Array to be created. -// -// `dtypeCode`, `dtypeBits` and `dtypeLanes` describe the data type in Array. -// -// `deviceType` indicates the device on whose memory the Array to allocated. -// -// `deviceID` indicates device index if multiple devices of same type present. -// -// return argument holding native pointer to newly created Array and error is any. -func nativeTVMArrayAlloc(shape []int64, ndim int32, - dtypeCode int32, dtypeBits int32, dtypeLanes int32, - deviceType int32, deviceID int32) (retVal uintptr, err error) { - ret := (int32)(C.TVMArrayAlloc((*C.long)(&(shape[0])), - C.int(ndim), - C.int(dtypeCode), - C.int(dtypeBits), - C.int(dtypeLanes), - C.int(deviceType), - C.int(deviceID), - (*C.TVMArrayHandle)(unsafe.Pointer(&retVal)))) - if ret != 0 { - err = errors.New(getTVMLastError()) - return - } - return -} - -// Empty is used to allocate TVM empty array of given epecification. -// -// `shape` is int64 slice holding shape of the Array -// -// `args` is variadic args for -// -// `args[0]` is string for data type. Default value is 'float32' -// -// `args[1]` is Device. Default value is '{KDLCPU, 0}' -// -// returns pointer to Array on successful execution and error if any. -func Empty(shape []int64, args ...interface{}) (parray *Array, err error) { - typeName := "float32" - dev := Device{KDLCPU, 0} - - if len(shape) < 1 { - err = fmt.Errorf("Invalid shape for Array creation: %v", len(shape)) - return - } - - for i, val := range args { - switch val.(type) { - case string: - typeName = args[i].(string) - case Device: - dev = args[i].(Device) - default: - err = fmt.Errorf("Invalid Optional Argument Type: %T", val) - return - } - } - - tvmType, err := dtypeToTVMType(typeName) - if err != nil { - return - } - ndim := int32(len(shape)) - newArray, err := nativeTVMArrayAlloc(shape, ndim, int32(tvmType.code), - int32(tvmType.bits), int32(tvmType.lanes), - dev.DeviceType, dev.DeviceID) - if err != nil { - return - } - handle := new(Array) - *handle = Array(newArray) - - finalizer := func (ahandle *Array) { - nativeTVMArrayFree(*ahandle) - ahandle = nil - } - runtime.SetFinalizer(handle, finalizer) - parray = handle - return -} - -// nativeTVMArrayFree is used to release the Array. -// -// `parray` is the Array handle. -// -// `ret` indicates the status of this api execution. -func nativeTVMArrayFree(parray Array) (retVal int32) { - retVal = (int32)(C.TVMArrayFree((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())))) - return -} diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc deleted file mode 100644 index 475abf5a3e36..000000000000 --- a/golang/src/tvm_runtime_pack.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief This is an all in one TVM runtime file. - * \file tvm_runtime_pack.cc - */ -#include "src/runtime/c_runtime_api.cc" -#include "src/runtime/container.cc" -#include "src/runtime/cpu_device_api.cc" -#include "src/runtime/file_utils.cc" -#include "src/runtime/library_module.cc" -#include "src/runtime/logging.cc" -#include "src/runtime/module.cc" -#include "src/runtime/ndarray.cc" -#include "src/runtime/object.cc" -#include "src/runtime/registry.cc" -#include "src/runtime/thread_pool.cc" -#include "src/runtime/threading_backend.cc" -#include "src/runtime/workspace_pool.cc" - -// NOTE: all the files after this are optional modules -// that you can include remove, depending on how much feature you use. - -// Likely we only need to enable one of the following -// If you use Module::Load, use dso_module -// For system packed library, use system_lib_module -#include "src/runtime/dso_library.cc" -#include "src/runtime/system_library.cc" - -// Graph executor -#include "src/runtime/memory/memory_manager.cc" - -// Uncomment the following lines to enable RPC -// #include "../../src/runtime/rpc/rpc_session.cc" -// #include "../../src/runtime/rpc/rpc_event_impl.cc" -// #include "../../src/runtime/rpc/rpc_server_env.cc" - -// These macros enables the device API when uncommented. -#define TVM_CUDA_RUNTIME 1 -#define TVM_METAL_RUNTIME 1 -#define TVM_OPENCL_RUNTIME 1 - -// Uncomment the following lines to enable Metal -// #include "../../src/runtime/metal/metal_device_api.mm" -// #include "../../src/runtime/metal/metal_module.mm" - -// Uncomment the following lines to enable CUDA -// #include "../../src/runtime/cuda/cuda_device_api.cc" -// #include "../../src/runtime/cuda/cuda_module.cc" - -// Uncomment the following lines to enable OpenCL -// #include "../../src/runtime/opencl/opencl_device_api.cc" -// #include "../../src/runtime/opencl/opencl_module.cc" -// #include "../src/runtime/source_utils.cc" diff --git a/golang/src/type.go b/golang/src/type.go deleted file mode 100644 index 6202e0baa875..000000000000 --- a/golang/src/type.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package for TVMType interface - * \file type.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "fmt" -) - -// pTVMType corresponding to data types. -type pTVMType struct { - code uint8 - bits uint8 - lanes uint16 -} - -// data type to pTVMType mapping -var dtypeMap = map[string] pTVMType { - "int8": pTVMType{0, 8, 1}, - "int16": pTVMType{0, 16, 1}, - "int32": pTVMType{0, 32, 1}, - "int64": pTVMType{0, 64, 1}, - "uint8": pTVMType{1, 8, 1}, - "uint16": pTVMType{1, 16, 1}, - "uint32": pTVMType{1, 32, 1}, - "uint64": pTVMType{1, 64, 1}, - "float32": pTVMType{2, 32, 1}, - "float64": pTVMType{2, 64, 1}, -} - -// dtypeFromTVMType return the pTVMType corresponding to given dtype -// -// `dtype` string for the given data type. -func dtypeFromTVMType(tvmtype pTVMType) (retVal string, err error) { - for k, v := range dtypeMap { - if v.code == tvmtype.code && v.bits == tvmtype.bits && v.lanes == tvmtype.lanes { - retVal = k - return - } - } - - err = fmt.Errorf("Cannot map TVMType:%v to dtype", tvmtype) - return -} - -// dtypeToTVMType return the pTVMType corresponding to given dtype -// -// `dtype` string for the given data type. -func dtypeToTVMType(args ...interface{}) (tvmtype pTVMType, err error) { - dtype := args[0].(string) - lanes := 1 - - if len(args) == 2 { - lanes = args[1].(int) - } - - for k, v := range dtypeMap { - if k == dtype { - tvmtype = v - tvmtype.lanes = uint16(lanes) - return - } - } - err = fmt.Errorf("Cannot map dtype:%v to TVMType", dtype) - return -} diff --git a/golang/src/utils.go b/golang/src/utils.go deleted file mode 100644 index 2da4138a1e66..000000000000 --- a/golang/src/utils.go +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for common utilities - * \file utils.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "unsafe" -) - -// Native string map for go string -type nativeGoString struct { p uintptr; n int32 } - -func goStringFromNative (s string) (retStr string) { - p := *(*nativeGoString)(unsafe.Pointer(&s)) - retStr = string((*[0x7fffffff]byte)(unsafe.Pointer(p.p))[:p.n]) - C.free(unsafe.Pointer(p.p)) - return -} diff --git a/golang/src/value.go b/golang/src/value.go deleted file mode 100644 index 450cf4866ab0..000000000000 --- a/golang/src/value.go +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package source for TVMValue interface - * \file value.go - */ - -package gotvm - -//#include "gotvm.h" -import "C" - -import ( - "fmt" - "runtime" - "unsafe" -) - -// KHandle is golang type code for TVM enum kTVMOpaqueHandle. -var KHandle = int32(C.kTVMOpaqueHandle) -// KNull is golang type code for TVM kTVMNullptr. -var KNull = int32(C.kTVMNullptr) -// KTVMType is golang type code for TVM kTVMDataType. -var KTVMType = int32(C.kTVMDataType) -// KDLDevice is golang type code for TVM kDLDevice. -var KDLDevice = int32(C.kDLDevice) -// KArrayHandle is golang type code for TVM kTVMDLTensorHandle. -var KArrayHandle = int32(C.kTVMDLTensorHandle) -// KObjectHandle is golang type code for TVM kTVMObjectHandle. -var KObjectHandle = int32(C.kTVMObjectHandle) -// KModuleHandle is gonag type code for TVM kTVMModuleHandle. -var KModuleHandle = int32(C.kTVMModuleHandle) -// KFuncHandle is gonalg type code for TVM kTVMPackedFuncHandle. -var KFuncHandle = int32(C.kTVMPackedFuncHandle) -// KStr is golang type code for TVM kTVMStr. -var KStr = int32(C.kTVMStr) -// KBytes is golang type code for TVM kTVMBytes. -var KBytes = int32(C.kTVMBytes) -// KNDArrayContainer is golang typecode for kTVMNDArrayHandle. -var KNDArrayContainer = int32(C.kTVMNDArrayHandle) -// KExtBegin is golang enum corresponding to TVM kTVMExtBegin. -var KExtBegin = int32(C.kTVMExtBegin) -// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst. -var KNNVMFirst = int32(C.kTVMNNVMFirst) -// KNNVMLast is golang enum corresponding to TVM kNNVMLast. -var KNNVMLast = int32(C.kTVMNNVMLast) -// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd. -var KExtReserveEnd = int32(C.kTVMExtReserveEnd) -// KExtEnd is golang enum corresponding to TVM kExtEnd. -var KExtEnd = int32(C.kTVMExtEnd) -// KDLInt is golang type code for TVM kDLInt. -var KDLInt = int32(C.kDLInt) -// KDLUInt is golang type code for TVM kDLUInt. -var KDLUInt = int32(C.kDLUInt) -// KDLFloat is golang type code for TVM kDLFloat. -var KDLFloat = int32(C.kDLFloat) - -// Value Typemap for union exposed by TVM runtime API. -// -// gotvm maps it to a uintptr and then dynamically allocates memory by newTVMValue method. -type Value struct { - nptr uintptr - dtype int32 - isLocal bool -} - -// AsInt64 returns the int64 value inside the Value. -func (tvmval *Value) AsInt64() (retVal int64) { - retVal = tvmval.getVInt64() - return -} - -// AsFloat64 returns the Float64 value inside the Value. -func (tvmval *Value) AsFloat64() (retVal float64) { - retVal = tvmval.getVFloat64() - return -} - -// AsModule returns the Module inside the Value. -func (tvmval *Value) AsModule() (retVal *Module) { - mhandle := tvmval.getVMHandle() - retVal = &mhandle - return -} - -// AsFunction returns the Function inside the Value. -func (tvmval *Value) AsFunction() (retVal *Function) { - fhandle := tvmval.getVFHandle() - retVal = &fhandle - - return -} - -// AsBytes returns the byte slice value inside the Value. -func (tvmval *Value) AsBytes() (retVal []byte) { - retVal = tvmval.getVBHandle().getData() - return -} - -// AsStr returns the golang string in the Value. -func (tvmval *Value) AsStr() (retVal string) { - str := tvmval.getVStr() - retVal = str - return -} - -// nativeCPtr return the unitptr corresponding to Value type. -func (tvmval *Value) nativeCPtr() (ret uintptr) { - ret = (uintptr)(tvmval.nptr) - return -} - -// moveFrom copies the tvmval from other Value object. -func (tvmval *Value) moveFrom(fromval *Value) () { - C.memcpy(unsafe.Pointer(tvmval.nativeCPtr()), - unsafe.Pointer(fromval.nativeCPtr()), - C.sizeof_TVMValue) - - // Move the dtype too. - tvmval.dtype = fromval.dtype - fromval.dtype = KNull - return -} - -// setVInt64 initializes the Value object with given int64 value. -// -// `val` is the int64 value to initialize the Value -func (tvmval *Value) setVInt64(val int64) { - valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = C.int64_t(val) - tvmval.dtype = KDLInt - return -} - - -// getVInt64 returns the int64 value inside the Value. -func (tvmval *Value) getVInt64() (retVal int64) { - valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = int64(*valp) - return -} - -// setVFloat64 initializes the Value object with given float64 value. -// -// `val` is the float64 value to initialize the Value. -func (tvmval *Value) setVFloat64(val float64) { - valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = C.double(val) - tvmval.dtype = KDLFloat - return -} - -// getVFloat64 returns the float64 value inside Value. -func (tvmval *Value) getVFloat64() (retVal float64) { - valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = float64(*valp) - return -} - -// setVHandle initializes the handle inside the Value. -// -// Can be used to store any uintptr type object like -// module handle, function handle and any object's nativeCPtr. -// -// `val` is the uintptr type of given handle. -func (tvmval *Value) setVHandle(val uintptr) { - valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = (*C.void)(unsafe.Pointer(val)) -} - -// getVHandle returns the uintptr handle -func (tvmval *Value) getVHandle() (retVal uintptr) { - valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = uintptr(unsafe.Pointer(*valp)) - return -} - -// setVStr intializes the Value with given golang string object. -// -// `val` is the golang string object used to initialize the Value. -func (tvmval *Value) setVStr(val string) { - valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) - *valp = C.CString(val) - tvmval.dtype = KStr - return -} - - -// getVStr returns the golang string for the native string inside Value. -func (tvmval *Value) getVStr() (retVal string) { - valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) - retVal = C.GoString(*valp) - return -} - -// unSetVStr release the memory allocated in setVStr -func (tvmval *Value) unSetVStr() { - valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) - C.free(unsafe.Pointer(*valp)) - tvmval.dtype = KNull -} - -// setVAHandle is used to set Array handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead too. -// This is a wrapper to accept Array directly. -func (tvmval *Value) setVAHandle(ptvmarray Array) { - tvmval.setVHandle(ptvmarray.nativeCPtr()) - tvmval.dtype = KArrayHandle - return -} - -// getVAHandle is used to get Array handle in Value. -func (tvmval *Value) getVAHandle() (retVal Array) { - retVal = (Array)(tvmval.getVHandle()) - return -} - -// setVMHandle is used to set Module handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead too. -// This is a wrapper to accept Module directly. -func (tvmval *Value) setVMHandle(tvmmodule Module) { - tvmval.setVHandle(tvmmodule.nativeCPtr()) - tvmval.dtype = KModuleHandle - return -} - -// getVMHandle is used to get Module handle in Value. -func (tvmval *Value) getVMHandle() (retVal Module) { - retVal = (Module)(tvmval.getVHandle()) - return -} - -// setVFHandle is used to set Function handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead. -// This is a wrapper to accept Function directly. -func (tvmval *Value) setVFHandle(tvmfunction Function) { - tvmval.setVHandle(tvmfunction.nativeCPtr()) - tvmval.dtype = KFuncHandle - return -} - -// getVFHandle is used to get Function handle in Value. -func (tvmval *Value) getVFHandle() (retVal Function) { - retVal = (Function)(tvmval.getVHandle()) - return -} - -// setVBHandle is used to set ByteArray handle in Value. -// -// Application can call the setVHandle with nativeCPtr instead. -// This is a wrapper to accept ByteArray directly. -func (tvmval *Value) setVBHandle(tbytearray ByteArray) { - tvmval.setVHandle(tbytearray.nativeCPtr()) - tvmval.dtype = KBytes - return -} - -// getVBHandle is used to get ByteArray handle in Value. -func (tvmval *Value) getVBHandle() (retVal ByteArray) { - retVal = (ByteArray)(tvmval.getVHandle()) - return -} - -// setValue is used to set the given value in Value. -// -// `val` is value of types accepted by Value container or native union. -func (tvmval *Value) setValue(val interface{}) (retVal int32, err error) { - retVal = KNull - switch val.(type) { - case string: - tvmval.setVStr(val.(string)) - case uint8: - tvmval.setVInt64(int64(val.(uint8))) - case uint16: - tvmval.setVInt64(int64(val.(uint16))) - case uint32: - tvmval.setVInt64(int64(val.(uint32))) - case uint64: - tvmval.setVInt64(int64(val.(uint64))) - case int: - tvmval.setVInt64(int64(val.(int))) - case int8: - tvmval.setVInt64(int64(val.(int8))) - case int16: - tvmval.setVInt64(int64(val.(int16))) - case int32: - tvmval.setVInt64(int64(val.(int32))) - case int64: - tvmval.setVInt64(val.(int64)) - case float32: - tvmval.setVFloat64(float64(val.(float32))) - case float64: - tvmval.setVFloat64(val.(float64)) - case *Module: - tvmval.setVMHandle(*(val.(*Module))) - case *Function: - tvmval.setVFHandle(*(val.(*Function))) - case *ByteArray: - tvmval.setVBHandle(*(val.(*ByteArray))) - case []byte: - barray := newByteArray(val.([]byte)) - tvmval.setVBHandle(barray) - case *Array: - tvmval.setVAHandle(*(val.(*Array))) - case func (args ...*Value) (interface{}, error): - fhandle, apierr := ConvertFunction(val) - if apierr != nil { - err = fmt.Errorf("Given value Type not defined for Value: %v : %T", val, val); - return - } - tvmval.setVFHandle(*fhandle) - - // Clear the finalizer as we don't need to control it anymore. - runtime.SetFinalizer(fhandle, nil) - case *Value: - tvmval.moveFrom(val.(*Value)) - case Value: - fromval := val.(Value) - tvmval.moveFrom(&fromval) - default: - err = fmt.Errorf("Given value Type not defined for Value: %v : %T", val, val); - } - retVal = tvmval.dtype - return -} - -// newTVMValue initialize the TVMValue native object. -// -// This is intended to use as intermediate type between native and golang types. -// Allocated from FuncCall or Callback to handle conversions. -func newTVMValue() (retVal *Value) { - handle := new(Value) - - handle.nptr = (uintptr(C.malloc(C.sizeof_TVMValue))) - handle.dtype = KNull - handle.isLocal = true - finalizer := func(vhandle *Value) { - vhandle.deleteTVMValue() - vhandle = nil - } - runtime.SetFinalizer(handle, finalizer) - retVal = handle - return -} - -// deleteTVMValue free the native Value object which is allocated in newTVMValue. -func (tvmval Value) deleteTVMValue() { - if tvmval.isLocal == true { - if tvmval.dtype == KStr { - tvmval.unSetVStr() - } - if tvmval.dtype == KBytes { - tvmval.getVBHandle().deleteTVMByteArray() - } - } - - C.free(unsafe.Pointer(tvmval.nativeCPtr())) -} diff --git a/golang/src/value_test.go b/golang/src/value_test.go deleted file mode 100644 index ba502254cd20..000000000000 --- a/golang/src/value_test.go +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief gotvm package - * \file value_test.go - */ - -package gotvm - -import ( - "testing" - "math/rand" - "strings" -) - -// Check Int64 Value looping via packed function calling another packed function. -func TestValueLoopInt64(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Int63() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - if retVal.AsInt64() != result { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -// Check Int32 Value looping via packed function calling another packed function. -func TestValueLoopInt32(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Int31() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsInt64() != int64(result) { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -// Check Float32 Value looping via packed function calling another packed function. -func TestValueLoopFloat32(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Float32() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsFloat64() != float64(result) { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -// Check Float64 Value looping via packed function calling another packed function. -func TestValueLoopFloat64(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - newArgs := args[1:] - // Call Packed Function by Value - return pfunc.Invoke(newArgs) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0] - return - } - - result := rand.Float64() - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - if retVal.AsFloat64() != result { - t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) - return - } -} - -func TestValueLoopString(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - argStr := args[1].AsStr() - // Call Packed Function by Value - return pfunc.Invoke(argStr) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0].AsStr() - return - } - - retVal, err := fhandle.Invoke(funccall, "TestString") - if err != nil { - t.Error(err.Error()) - return - } - - vStr := retVal.AsStr() - if strings.Compare(vStr, string("TestString")) != 0 { - t.Errorf("Expected : %v got:%v\n", string("TestString"), vStr) - return - } -} - -// Check []byte Value looping via packed function calling another packed function. -func TestValueLoopByteSlice(t *testing.T) { - // Receive a function Handle and argument and echo the Value on the handle. - sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { - // Reveive Packed Function Handle - pfunc := args[0].AsFunction() - argBytes := args[1].AsBytes() - // Call Packed Function by Value - return pfunc.Invoke(argBytes) - } - - fhandle, err := ConvertFunction(sampleFunctionLoop) - if err != nil { - t.Error(err.Error()) - return - } - - // funccall is a simple golang callback function like C = A + B. - funccall := func (args ...*Value) (retVal interface{}, err error) { - retVal = args[0].AsBytes() - return - } - - result := make([]byte, 1024) - rand.Read(result) - retVal, err := fhandle.Invoke(funccall, result) - if err != nil { - t.Error(err.Error()) - return - } - - received := retVal.AsBytes() - if len(result) != len(received) { - t.Errorf("Data expected Len: %v Got :%v\n", len(result), len(received)) - return - } - for i := range result { - if result[i] != received[i] { - t.Errorf("Data expected: %v Got :%v at index %v\n", result[i], received[i], i) - return - } - } -} diff --git a/rust/.gitignore b/rust/.gitignore deleted file mode 100644 index 0cc660650780..000000000000 --- a/rust/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -target/ -*.rs.bk -Cargo.lock -c_runtime_api.rs diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml deleted file mode 100644 index 95936dc4dec8..000000000000 --- a/rust/.rustfmt.toml +++ /dev/null @@ -1,31 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -max_width = 100 -hard_tabs = false -tab_spaces = 4 -newline_style = "Auto" -use_small_heuristics = "Default" -reorder_imports = true -reorder_modules = true -remove_nested_parens = true -fn_params_layout = "Tall" -edition = "2018" -merge_derives = true -use_try_shorthand = false -use_field_init_shorthand = false -force_explicit_abi = true diff --git a/rust/Cargo.toml b/rust/Cargo.toml deleted file mode 100644 index 26b4398b427d..000000000000 --- a/rust/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[workspace] -members = [ - "tvm-sys", - "tvm-macros", - "tvm-rt" -] diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml deleted file mode 100644 index 4300cb3f1dcb..000000000000 --- a/rust/tvm-macros/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-macros" -version = "0.1.1-alpha" -license = "Apache-2.0" -description = "Procedural macros of the TVM crate." -repository = "https://github.com/apache/tvm" -readme = "README.md" -keywords = ["tvm"] -authors = ["TVM Contributors"] -edition = "2018" - -[lib] -proc-macro = true - -[dependencies] -goblin = "^0.2" -proc-macro2 = "^1.0" -quote = "^1.0" -syn = { version = "1.0.48", features = ["full", "parsing", "extra-traits"] } -proc-macro-error = "^1.0" diff --git a/rust/tvm-macros/README.md b/rust/tvm-macros/README.md deleted file mode 100644 index 8a7c4b301524..000000000000 --- a/rust/tvm-macros/README.md +++ /dev/null @@ -1,20 +0,0 @@ - - - - - - - - - - - - - - - - - -# tvm-macros - -The procedural macro implementations for TVM crates, see `tvm` crate for more documentation. diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs deleted file mode 100644 index 146f9d4d6bc6..000000000000 --- a/rust/tvm-macros/src/external.rs +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -use proc_macro2::Span; -use proc_macro_error::abort; -use quote::quote; -use syn::parse::{Parse, ParseStream, Result}; - -use syn::{ - token::Semi, Attribute, FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, - Signature, Type, Visibility, -}; - -struct ExternalItem { - attrs: Vec, - visibility: Visibility, - sig: Signature, -} - -impl Parse for ExternalItem { - fn parse(input: ParseStream) -> Result { - let item = ExternalItem { - attrs: input.call(Attribute::parse_outer)?, - visibility: input.parse()?, - sig: input.parse()?, - }; - let _semi: Semi = input.parse()?; - Ok(item) - } -} - -struct External { - visibility: Visibility, - tvm_name: String, - ident: Ident, - generics: Generics, - inputs: Vec, - ret_type: ReturnType, -} - -impl Parse for External { - fn parse(input: ParseStream) -> Result { - let method: ExternalItem = input.parse()?; - let visibility = method.visibility; - assert_eq!(method.attrs.len(), 1); - let sig = method.sig; - let tvm_name = method.attrs[0].parse_meta()?; - let tvm_name = match tvm_name { - Meta::List(meta_list) => { - let name = meta_list.path.get_ident().expect("name"); - assert_eq!(name.to_string(), "name".to_string()); - match meta_list.nested.first() { - Some(NestedMeta::Lit(Lit::Str(lit))) => lit.value(), - _ => panic!(), - } - } - _ => panic!(), - }; - - let ident = sig.ident; - let generics = sig.generics; - let inputs = sig - .inputs - .iter() - .cloned() - .map(|param| param.clone()) - .collect(); - let ret_type = sig.output; - - Ok(External { - visibility, - tvm_name, - ident, - generics, - inputs, - ret_type, - }) - } -} - -struct ExternalInput { - externs: Vec, -} - -impl Parse for ExternalInput { - fn parse(input: ParseStream) -> Result { - let mut externs: Vec = Vec::new(); - - loop { - if input.is_empty() { - break; - } - externs.push(input.parse()?); - } - - Ok(ExternalInput { externs }) - } -} - -pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let ext_input = syn::parse_macro_input!(input as ExternalInput); - - let tvm_rt_crate = crate::util::get_tvm_rt_crate(); - - let result_type = quote! { #tvm_rt_crate::function::Result }; - - let mut items = Vec::new(); - - for external in &ext_input.externs { - let visibility = &external.visibility; - let name = &external.ident; - let global_name = format!("global_{}", external.ident); - let global_name = Ident::new(&global_name, Span::call_site()); - let ext_name = &external.tvm_name; - - let ty_params: Vec = external - .generics - .params - .iter() - .map(|ty_param| match ty_param { - syn::GenericParam::Type(param) => param.clone(), - _ => abort! { ty_param, - "Only supports type parameters." - }, - }) - .collect(); - - let args = &external.inputs; - - let (args, tys): (Vec, Vec) = args - .iter() - .map(|arg| match arg { - FnArg::Typed(pat_type) => match &*pat_type.pat { - Pat::Ident(pat_ident) => { - let ident: Ident = pat_ident.ident.clone(); - let ty: Type = *pat_type.ty.clone(); - (ident, ty) - } - _ => abort! { pat_type, - "Only supports type parameters." - }, - }, - pat => abort! { - pat, "invalid pattern type for function"; - - note = "{:?} is not allowed here", pat; - }, - }) - .unzip(); - - let ret_type = match &external.ret_type { - ReturnType::Type(_, rtype) => *rtype.clone(), - ReturnType::Default => syn::parse_str::("()").unwrap(), - }; - - let global = quote! { - #[allow(non_upper_case_globals)] - static #global_name: ::once_cell::sync::Lazy<#tvm_rt_crate::Function> = - ::once_cell::sync::Lazy::new(|| { - #tvm_rt_crate::Function::get(#ext_name) - .expect(concat!("unable to load external function", stringify!(#ext_name), "from TVM registry.")) - }); - }; - - items.push(global); - - let wrapper = quote! { - #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { - let func_ref: #tvm_rt_crate::Function = #global_name.clone(); - let func_ref: Box #result_type<#ret_type>> = func_ref.into(); - let res: #ret_type = func_ref(#(#args),*)?; - Ok(res) - } - }; - - items.push(wrapper); - } - - proc_macro::TokenStream::from(quote! { - #(#items - )* - }) -} diff --git a/rust/tvm-macros/src/import_module.rs b/rust/tvm-macros/src/import_module.rs deleted file mode 100644 index bebf73b2528f..000000000000 --- a/rust/tvm-macros/src/import_module.rs +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -use quote::quote; -use std::{fs::File, io::Read}; -use syn::parse::{Parse, ParseStream, Result}; -use syn::LitStr; - -use std::path::PathBuf; - -struct ImportModule { - importing_file: LitStr, -} - -impl Parse for ImportModule { - fn parse(input: ParseStream) -> Result { - let importing_file: LitStr = input.parse()?; - Ok(ImportModule { importing_file }) - } -} - -pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let import_module_args = syn::parse_macro_input!(input as ImportModule); - - let manifest = - std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); - - let mut path = PathBuf::new(); - path.push(manifest); - path = path.join(import_module_args.importing_file.value()); - - let mut fd = File::open(&path) - .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); - let mut buffer = Vec::new(); - fd.read_to_end(&mut buffer).unwrap(); - - let fn_names = match goblin::Object::parse(&buffer).unwrap() { - goblin::Object::Elf(elf) => elf - .syms - .iter() - .filter_map(|s| { - if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { - return None; - } - match elf.strtab.get(s.st_name) { - Some(Ok(name)) if name != "" => { - Some(syn::Ident::new(name, proc_macro2::Span::call_site())) - } - _ => None, - } - }) - .collect::>(), - goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { - obj.symbols() - .filter_map(|s| match s { - Ok((name, ref nlist)) - if nlist.is_global() - && nlist.n_sect != 0 - && !name.ends_with("tvm_module_ctx") => - { - Some(syn::Ident::new( - if name.starts_with('_') { - // Mach objects prepend a _ to globals. - &name[1..] - } else { - &name - }, - proc_macro2::Span::call_site(), - )) - } - _ => None, - }) - .collect::>() - } - _ => panic!("Unsupported object format."), - }; - - let extern_fns = quote! { - mod ext { - extern "C" { - #( - pub(super) fn #fn_names( - args: *const tvm_graph_rt::ffi::TVMValue, - type_codes: *const std::os::raw::c_int, - num_args: std::os::raw::c_int - ) -> std::os::raw::c_int; - )* - } - } - }; - - let fns = quote! { - use tvm_graph_rt::{ffi::TVMValue, ArgValue, RetValue, FuncCallError}; - #extern_fns - - #( - pub fn #fn_names(args: &[ArgValue]) -> Result { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = unsafe { - ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) - }; - if exit_code == 0 { - Ok(RetValue::default()) - } else { - Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) - } - } - )* - }; - - proc_macro::TokenStream::from(fns) -} diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs deleted file mode 100644 index e563a57f149e..000000000000 --- a/rust/tvm-macros/src/lib.rs +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use proc_macro::TokenStream; -use proc_macro_error::proc_macro_error; - -mod external; -mod import_module; -mod object; -mod util; - -#[proc_macro] -pub fn import_module(input: TokenStream) -> TokenStream { - import_module::macro_impl(input) -} - -#[proc_macro_error] -#[proc_macro_derive(Object, attributes(base, ref_name, type_key, no_derive))] -pub fn macro_impl(input: TokenStream) -> TokenStream { - // let input = proc_macro2::TokenStream::from(input); - TokenStream::from(object::macro_impl(input)) -} - -#[proc_macro_error] -#[proc_macro] -pub fn external(input: TokenStream) -> TokenStream { - external::macro_impl(input) -} diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs deleted file mode 100644 index 4134da5fe6d9..000000000000 --- a/rust/tvm-macros/src/object.rs +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::quote; -use syn::DeriveInput; -use syn::Ident; - -use crate::util::*; - -pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { - let tvm_rt_crate = get_tvm_rt_crate(); - let result = quote! { #tvm_rt_crate::function::Result }; - let error = quote! { #tvm_rt_crate::errors::Error }; - let derive_input = syn::parse_macro_input!(input as DeriveInput); - let payload_id = derive_input.ident.clone(); - - let type_key = get_attr(&derive_input, "type_key") - .map(attr_to_str) - .expect("Failed to get type_key"); - - let derive = get_attr(&derive_input, "no_derive") - .map(|_| false) - .unwrap_or(true); - - let ref_id = get_attr(&derive_input, "ref_name") - .map(|a| Ident::new(attr_to_str(a).value().as_str(), Span::call_site())) - .unwrap_or_else(|| { - let id = payload_id.to_string(); - let suffixes = ["Node", "Obj"]; - if let Some(suf) = suffixes - .iter() - .find(|&suf| id.len() > suf.len() && id.ends_with(suf)) - { - Ident::new(&id[..id.len() - suf.len()], payload_id.span()) - } else { - panic!( - "Either 'ref_name' must be given, or the struct name must end one of {:?}", - suffixes - ) - } - }); - - let base_tokens = match &derive_input.data { - syn::Data::Struct(s) => s.fields.iter().next().and_then(|f| { - let (base_id, base_ty) = (f.ident.clone()?, f.ty.clone()); - if base_id == "base" { - // The transitive case of subtyping - Some(quote! { - impl AsRef for #payload_id - where #base_ty: AsRef - { - fn as_ref(&self) -> &O { - self.#base_id.as_ref() - } - } - }) - } else { - None - } - }), - _ => panic!("derive only works for structs"), - }; - - let ref_derives = if derive { - quote! { #[derive(Debug, Clone)]} - } else { - quote! { #[derive(Clone)] } - }; - - let mut expanded = quote! { - unsafe impl #tvm_rt_crate::object::IsObject for #payload_id { - const TYPE_KEY: &'static str = #type_key; - } - - // a silly AsRef impl is necessary for subtyping to work - impl AsRef<#payload_id> for #payload_id { - fn as_ref(&self) -> &Self { - self - } - } - - #ref_derives - pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); - - impl #tvm_rt_crate::object::IsObjectRef for #ref_id { - type Object = #payload_id; - - fn as_ptr(&self) -> Option<&#tvm_rt_crate::object::ObjectPtr> { - self.0.as_ref() - } - - fn into_ptr(self) -> Option<#tvm_rt_crate::object::ObjectPtr> { - self.0 - } - - fn from_ptr(object_ptr: Option<#tvm_rt_crate::object::ObjectPtr>) -> Self { - #ref_id(object_ptr) - } - } - - impl std::ops::Deref for #ref_id { - type Target = #payload_id; - - fn deref(&self) -> &Self::Target { - self.0.as_ref().unwrap() - } - } - - impl std::convert::From<#payload_id> for #ref_id { - fn from(payload: #payload_id) -> Self { - let ptr = #tvm_rt_crate::object::ObjectPtr::new(payload); - #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr)) - } - } - - impl std::convert::From<#tvm_rt_crate::object::ObjectPtr<#payload_id>> for #ref_id { - fn from(ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id>) -> Self { - #tvm_rt_crate::object::IsObjectRef::from_ptr(Some(ptr)) - } - } - - impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { - type Error = #error; - - fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> { - use std::convert::TryInto; - let ptr: #tvm_rt_crate::object::ObjectPtr<#payload_id> = ret_val.try_into()?; - Ok(ptr.into()) - } - } - - impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> { - fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> { - use std::ffi::c_void; - let object_ptr = &object_ref.0; - match object_ptr { - None => { - #tvm_rt_crate::ArgValue:: - ObjectHandle(std::ptr::null::() as *mut c_void) - } - Some(value) => value.into() - } - } - } - - impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #error; - - fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> { - use std::convert::TryInto; - let optr = arg_value.try_into()?; - Ok(#ref_id(Some(optr))) - } - } - - - impl From<#ref_id> for #tvm_rt_crate::RetValue { - fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue { - use std::ffi::c_void; - let object_ptr = &object_ref.0; - match object_ptr { - None => { - #tvm_rt_crate::RetValue::ObjectHandle(std::ptr::null::() as *mut c_void) - } - Some(value) => value.clone().into() - } - } - } - }; - - expanded.extend(base_tokens); - - if derive { - let derives = quote! { - impl std::hash::Hash for #ref_id { - fn hash(&self, state: &mut H) { - self.0.hash(state) - } - } - - impl std::cmp::PartialEq for #ref_id { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } - } - - impl std::cmp::Eq for #ref_id {} - }; - - expanded.extend(derives); - } - - TokenStream::from(expanded) -} diff --git a/rust/tvm-macros/src/util.rs b/rust/tvm-macros/src/util.rs deleted file mode 100644 index b02e3f69b671..000000000000 --- a/rust/tvm-macros/src/util.rs +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use proc_macro2::TokenStream; -use quote::quote; -use std::env; - -pub fn get_tvm_rt_crate() -> TokenStream { - if env::var("CARGO_PKG_NAME").unwrap() == "tvm-rt" { - quote!(crate) - } else { - quote!(tvm_rt) - } -} - -pub(crate) fn get_attr<'a>( - derive_input: &'a syn::DeriveInput, - name: &str, -) -> Option<&'a syn::Attribute> { - derive_input.attrs.iter().find(|a| a.path.is_ident(name)) -} - -pub(crate) fn attr_to_str(attr: &syn::Attribute) -> syn::LitStr { - match attr.parse_meta() { - Ok(syn::Meta::NameValue(syn::MetaNameValue { - lit: syn::Lit::Str(s), - .. - })) => s, - Ok(m) => panic!("Expected a string literal, got {:?}", m), - Err(e) => panic!("{}", e), - } -} diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore deleted file mode 100644 index 2430329c78b6..000000000000 --- a/rust/tvm-rt/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -target -**/*.rs.bk -Cargo.lock -/tests/basics/add_* -/examples/resnet/deploy_* -/examples/resnet/*.png -/examples/resnet/synset.* diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml deleted file mode 100644 index cb8c560c3efa..000000000000 --- a/rust/tvm-rt/Cargo.toml +++ /dev/null @@ -1,95 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-rt" -version = "0.1.0-alpha" -license = "Apache-2.0" -description = "Rust bindings for the TVM runtime API." -repository = "https://github.com/apache/tvm" -homepage = "https://github.com/apache/tvm" -readme = "README.md" -keywords = ["rust", "tvm"] -categories = ["api-bindings", "science"] -authors = ["TVM Contributors"] -edition = "2018" - -[features] -default = ["dynamic-linking"] -dynamic-linking = ["tvm-sys/dynamic-linking"] -static-linking = ["tvm-sys/static-linking"] -standalone = ["tvm-sys/runtime-only"] -runtime-only = ["tvm-sys/runtime-only"] -blas = ["ndarray/blas"] -# Enabling any of the following features is like setting the value to "ON" in config.cmake. -use-cuda = ["tvm-sys/use-cuda"] -use-opencl = ["tvm-sys/use-opencl"] -use-vulkan = ["tvm-sys/use-vulkan"] -use-metal = ["tvm-sys/use-metal"] -use-rocm = ["tvm-sys/use-rocm"] -use-hexagon-device = ["tvm-sys/use-hexagon-device"] -use-rpc = ["tvm-sys/use-rpc"] -use-threads = ["tvm-sys/use-threads"] -use-llvm = ["tvm-sys/use-llvm"] -use-stackvm-runtime = ["tvm-sys/use-stackvm-runtime"] -use-openmp = ["tvm-sys/use-openmp"] -use-rtti = ["tvm-sys/use-rtti"] -use-mscv-mt = ["tvm-sys/use-mscv-mt"] -use-install-dev = ["tvm-sys/use-install-dev"] -hide-private-symbols = ["tvm-sys/hide-private-symbols"] -use-fallback-stl-map = ["tvm-sys/use-fallback-stl-map"] -use-index-default-i64 = ["tvm-sys/use-index-default-i64"] -use-tf-tvmdsoop = ["tvm-sys/use-tf-tvmdsoop"] -use-byodt-posit = ["tvm-sys/use-byodt-posit"] -use-mkl = ["tvm-sys/use-mkl"] -use-mkldnn = ["tvm-sys/use-mkldnn"] -use-dnnl-codegen = ["tvm-sys/use-dnnl-codegen"] -use-cudnn = ["tvm-sys/use-cudnn"] -use-cublas = ["tvm-sys/use-cublas"] -use-thrust = ["tvm-sys/use-thrust"] -use-miopen = ["tvm-sys/use-miopen"] -use-rocblas = ["tvm-sys/use-rocblas"] -use-sort = ["tvm-sys/use-sort"] -use-nnpack = ["tvm-sys/use-nnpack"] -use-random = ["tvm-sys/use-random"] -use-cpp-rpc = ["tvm-sys/use-cpp-rpc"] -use-tflite = ["tvm-sys/use-tflite"] -use-coreml = ["tvm-sys/use-coreml"] -use-target-onnx = ["tvm-sys/use-target-onnx"] -use-arm-compute-lib = ["tvm-sys/use-arm-compute-lib"] -use-arm-compute-lib-graph-runtime = ["tvm-sys/use-arm-compute-lib-graph-runtime"] -use-tensorrt-codegen = ["tvm-sys/use-tensorrt-codegen"] -use-tensorrt-runtime = ["tvm-sys/use-tensorrt-runtime"] -build-static-runtime = ["tvm-sys/build-static-runtime"] - -[dependencies] -thiserror = "^1.0" -ndarray = "0.12" -num-traits = "0.2" -tvm-macros = { version = "0.1.1-alpha", path = "../tvm-macros" } -paste = "0.1" -mashup = "0.1" -once_cell = "^1.3.1" -memoffset = "0.5.6" - -[dependencies.tvm-sys] -version = "0.1.1-alpha" -default-features = false -path = "../tvm-sys/" - -[dev-dependencies] -anyhow = "^1.0" diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md deleted file mode 100644 index 58b1f8a30a39..000000000000 --- a/rust/tvm-rt/README.md +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - - - - - - - - - - - - -# TVM Runtime Support - -This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/tvm) runtime, -see [here](https://github.com/apache/tvm/blob/main/rust/tvm/README.md) for more details. - -## What Does This Crate Offer? - -TVM is an end-to-end deep learning compiler which takes high level machine learning -models or tensor computations and lowers them into executable code for a variety -of heterogenous devices (e.g., CPU, GPU). - -This crate provides access to the APIs for manipulating runtime data structures, -as well as TVM's cross-language Object system which functions similarly to systems -such as COM, enabling cross-language interoperability. - -## Installations - -Please follow TVM [installation](https://tvm.apache.org/docs/install/index.html) instructions, -`export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. - -### Example of registering a cross-language closure. - -One can use `register!` macro to expose a Rust closure with arguments which implement `TryFrom` -and return types which implement `Into`. Once registered with TVM these functions can be -accessed via Python or C++, or any other language which implements the TVM packed function convention -see the offcial documentation for more information. - -```rust -use tvm_rt::{ArgValue, RetValue}; -use tvm_rt::function::{Function, Result, register}; - -fn sum(x: i64, y: i64, z: i64) -> i64 { - x + y + z -} - -fn main() { - register(sum, "mysum".to_owned()).unwrap(); - let func = Function::get("mysum").unwrap(); - let boxed_fn: Box Result> = func.into(); - let ret = boxed_fn(10, 20, 30).unwrap(); - assert_eq!(ret, 60); -} -``` diff --git a/rust/tvm-rt/src/device.rs b/rust/tvm-rt/src/device.rs deleted file mode 100644 index b1cb58cd54cf..000000000000 --- a/rust/tvm-rt/src/device.rs +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::os::raw::c_void; -use std::ptr; - -use crate::errors::Error; - -use tvm_sys::ffi; - -pub use tvm_sys::device::*; - -trait DeviceExt { - /// Checks whether the device exists or not. - fn exist(&self) -> bool; - fn sync(&self) -> Result<(), Error>; - fn max_threads_per_block(&self) -> isize; - fn warp_size(&self) -> isize; - fn max_shared_memory_per_block(&self) -> isize; - fn compute_version(&self) -> isize; - fn device_name(&self) -> isize; - fn max_clock_rate(&self) -> isize; - fn multi_processor_count(&self) -> isize; - fn max_thread_dimensions(&self) -> isize; -} - -macro_rules! impl_device_attrs { - ($(($attr_name:ident, $attr_kind:expr));+) => { - $( - fn $attr_name(&self) -> isize { - get_device_attr(self.device_type as i32, self.device_id as i32, 0) - .expect("should not fail") as isize - } - - )+ - }; -} - -crate::external! { - #[name("runtime.GetDeviceAttr")] - fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32; -} - -impl DeviceExt for Device { - fn exist(&self) -> bool { - let exists = get_device_attr(self.device_type as i32, self.device_id as i32, 0) - .expect("should not fail"); - - exists != 0 - } - - /// Synchronize the device stream. - fn sync(&self) -> Result<(), Error> { - check_call!(ffi::TVMSynchronize( - self.device_type as i32, - self.device_id as i32, - ptr::null_mut() as *mut c_void - )); - Ok(()) - } - - impl_device_attrs!((max_threads_per_block, 1); - (warp_size, 2); - (max_shared_memory_per_block, 3); - (compute_version, 4); - (device_name, 5); - (max_clock_rate, 6); - (multi_processor_count, 7); - (max_thread_dimensions, 8)); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn sync() { - let dev = Device::cpu(0); - assert!(dev.sync().is_ok()) - } -} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs deleted file mode 100644 index 31ce385ef662..000000000000 --- a/rust/tvm-rt/src/errors.rs +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::DataType; -use thiserror::Error; - -#[derive(Debug, Error)] -#[error("Function was not set in `function::Builder`")] -pub struct FunctionNotFoundError; - -#[derive(Debug, Error)] -#[error("Expected type `{expected}` but found `{actual}`")] -pub struct TypeMismatchError { - pub expected: String, - pub actual: String, -} - -#[derive(Debug, Error)] -pub enum NDArrayError { - #[error("Cannot convert from an empty array.")] - EmptyArray, - #[error("Invalid datatype when attempting to convert ndarray.")] - InvalidDatatype(#[from] tvm_sys::datatype::ParseDataTypeError), - #[error("a shape error occurred in the Rust ndarray library")] - ShapeError(#[from] ndarray::ShapeError), - #[error("Expected type `{expected}` but found `{actual}`")] - DataTypeMismatch { - expected: DataType, - actual: DataType, - }, -} - -#[derive(Debug, Error)] -pub enum Error { - #[error("{0}")] - Downcast(#[from] tvm_sys::errors::ValueDowncastError), - #[error("raw pointer passed across boundary was null")] - Null, - #[error("failed to load module due to invalid path {0}")] - ModuleLoadPath(String), - #[error("failed to convert String into CString due to embedded nul character")] - ToCString(#[from] std::ffi::NulError), - #[error("failed to convert CString into String")] - FromCString(#[from] std::ffi::IntoStringError), - #[error("Handle `{0}` is null.")] - NullHandle(String), - #[error("{0}")] - NDArray(#[from] NDArrayError), - #[error("{0}")] - CallFailed(String), - #[error("this case will never occur")] - Infallible(#[from] std::convert::Infallible), - #[error("a panic occurred while executing a Rust packed function")] - Panic, - #[error( - "one or more error diagnostics were emitted, please check diagnostic render for output." - )] - DiagnosticError(String), - #[error("{0}")] - Raw(String), -} - -impl Error { - pub fn from_raw_tvm(raw: &str) -> Error { - let err_header = raw.find(":").unwrap_or(0); - let (err_ty, err_content) = raw.split_at(err_header); - match err_ty { - "DiagnosticError" => Error::DiagnosticError((&err_content[1..]).into()), - _ => Error::Raw(raw.into()), - } - } -} - -impl Error { - pub fn downcast(actual_type: String, expected_type: &'static str) -> Error { - Self::Downcast(tvm_sys::errors::ValueDowncastError { - actual_type, - expected_type, - }) - } -} diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs deleted file mode 100644 index 62474e6650d4..000000000000 --- a/rust/tvm-rt/src/function.rs +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module provides an idiomatic Rust API for creating and working with TVM functions. -//! -//! For calling an already registered TVM function use [`function::Builder`] -//! To register a TVM packed function from Rust side either -//! use [`function::register`] or the macro [`register_global_func`]. -//! -//! See the tests and examples repository for more examples. - -use std::convert::{TryFrom, TryInto}; -use std::sync::Arc; -use std::{ - ffi::CString, - os::raw::{c_char, c_int}, - ptr, str, -}; - -use crate::errors::Error; - -pub use super::to_function::{RawArgs, ToFunction, Typed}; -use crate::object::AsArgValue; -pub use tvm_sys::{ffi, ArgValue, RetValue}; - -pub type Result = std::result::Result; - -#[derive(Debug, Hash)] -struct FunctionPtr { - handle: ffi::TVMFunctionHandle, -} - -// NB(@jroesch): I think this is ok, need to double check, -// if not we should mutex the pointer or move to Rc. -unsafe impl Send for FunctionPtr {} -unsafe impl Sync for FunctionPtr {} - -impl FunctionPtr { - fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { - FunctionPtr { handle } - } -} - -impl Drop for FunctionPtr { - fn drop(&mut self) { - check_call!(ffi::TVMFuncFree(self.handle)); - } -} - -/// An owned thread-safe version of `tvm::PackedFunc` for consumption in Rust. -#[derive(Debug, Hash)] -pub struct Function { - inner: Arc, -} - -impl Function { - pub(crate) fn from_raw(handle: ffi::TVMFunctionHandle) -> Self { - Function { - inner: Arc::new(FunctionPtr::from_raw(handle)), - } - } - - pub unsafe fn null() -> Self { - Function::from_raw(std::ptr::null_mut()) - } - - /// For a given function, it returns a function by name. - pub fn get>(name: S) -> Option { - let name = CString::new(name.as_ref()).unwrap(); - let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; - - check_call!(ffi::TVMFuncGetGlobal( - name.as_ptr() as *const c_char, - &mut handle as *mut _ - )); - - if handle.is_null() { - None - } else { - Some(Function::from_raw(handle)) - } - } - - pub fn get_boxed(name: S) -> Option> - where - S: AsRef, - F: ?Sized, - Self: Into>, - { - Self::get(name).map(|f| f.into()) - } - - /// Returns the underlying TVM function handle. - pub fn handle(&self) -> ffi::TVMFunctionHandle { - self.inner.handle - } - - /// Calls the function that created from `Builder`. - pub fn invoke<'a>(&self, arg_buf: Vec>) -> Result { - let num_args = arg_buf.len(); - let (mut values, mut type_codes): (Vec, Vec) = - arg_buf.into_iter().map(|arg| arg.to_tvm_value()).unzip(); - - let mut ret_val = ffi::TVMValue { v_int64: 0 }; - let mut ret_type_code = 0i32; - - let ret_code = unsafe { - ffi::TVMFuncCall( - self.handle(), - values.as_mut_ptr() as *mut ffi::TVMValue, - type_codes.as_mut_ptr() as *mut c_int, - num_args as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _, - ) - }; - - if ret_code != 0 { - let raw_error = crate::get_last_error(); - let error = match Error::from_raw_tvm(raw_error) { - Error::Raw(string) => Error::CallFailed(string), - e => e, - }; - return Err(error); - } - - let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); - - Ok(rv) - } -} - -macro_rules! impl_to_fn { - () => { impl_to_fn!(@impl); }; - ($t:ident, $($ts:ident,)*) => { impl_to_fn!(@impl $t, $($ts,)*); impl_to_fn!($($ts,)*); }; - (@impl $($t:ident,)*) => { - impl From for Box Result> - where - Error: From, - Out: TryFrom, - $($t: for<'a> AsArgValue<'a>),* - { - fn from(func: Function) -> Self { - #[allow(non_snake_case)] - Box::new(move |$($t : $t),*| { - let args = vec![ $((&$t).as_arg_value()),* ]; - Ok(func.invoke(args)?.try_into()?) - }) - } - } - }; -} - -impl_to_fn!(T1, T2, T3, T4, T5, T6,); - -impl Clone for Function { - fn clone(&self) -> Function { - Function { - inner: self.inner.clone(), - } - } -} - -impl From for RetValue { - fn from(func: Function) -> RetValue { - RetValue::FuncHandle(func.handle()) - } -} - -impl TryFrom for Function { - type Error = Error; - - fn try_from(ret_value: RetValue) -> Result { - match ret_value { - RetValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), - _ => Err(Error::downcast( - format!("{:?}", ret_value), - "FunctionHandle", - )), - } - } -} - -impl<'a> From<&'a Function> for ArgValue<'a> { - fn from(func: &'a Function) -> ArgValue<'a> { - if func.handle().is_null() { - ArgValue::Null - } else { - ArgValue::FuncHandle(func.handle()) - } - } -} - -impl<'a> TryFrom> for Function { - type Error = Error; - - fn try_from(arg_value: ArgValue<'a>) -> Result { - match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::from_raw(handle)), - _ => Err(Error::downcast( - format!("{:?}", arg_value), - "FunctionHandle", - )), - } - } -} - -impl<'a> TryFrom<&ArgValue<'a>> for Function { - type Error = Error; - - fn try_from(arg_value: &ArgValue<'a>) -> Result { - match arg_value { - ArgValue::FuncHandle(handle) => Ok(Function::from_raw(*handle)), - _ => Err(Error::downcast( - format!("{:?}", arg_value), - "FunctionHandle", - )), - } - } -} - -/// Registers a Rust function with an arbitrary type signature in -/// the TVM registry. -/// -/// -/// A function is convertible if and only if its arguments and return types are convertible -/// to and from TVM values respectively. -/// -/// Use [`register_override`] if control of overriding existing global TVM function -/// is required, this function will panic if a function is already registered. -/// -/// ## Example -/// -/// ``` -/// # use tvm_rt::{ArgValue, RetValue}; -/// # use tvm_rt::function::{Function, Result, register}; -/// -/// fn sum(x: i64, y: i64, z: i64) -> i64 { -/// x + y + z -/// } -/// -/// register(sum, "mysum".to_owned()).unwrap(); -/// let func = Function::get("mysum").unwrap(); -/// let boxed_fn: Box Result> = func.into(); -/// let ret = boxed_fn(10, 20, 30).unwrap(); -/// assert_eq!(ret, 60); -/// ``` -pub fn register>(f: F, name: S) -> Result<()> -where - F: ToFunction, - F: Typed, -{ - register_override(f, name, false) -} - -/// Register a function with explicit control over whether to override an existing registration or not. -/// -/// See `register` for more details on how to use the registration API. -pub fn register_override>(f: F, name: S, override_: bool) -> Result<()> -where - F: ToFunction, - F: Typed, -{ - let func = f.to_function(); - let name = name.into(); - // Not sure about this code - let handle = func.handle(); - let name = CString::new(name)?; - check_call!(ffi::TVMFuncRegisterGlobal( - name.into_raw(), - handle, - override_ as c_int - )); - - Ok(()) -} - -pub fn register_untyped>( - f: for<'a> fn(Vec>) -> Result, - name: S, - override_: bool, -) -> Result<()> { - //TODO(@jroesch): can we unify the untpyed and typed registration functions. - let func = ToFunction::::to_function(f); - let name = name.into(); - // Not sure about this code - let handle = func.handle(); - let name = CString::new(name)?; - check_call!(ffi::TVMFuncRegisterGlobal( - name.into_raw(), - handle, - override_ as c_int - )); - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::function::Function; - - static CANARY: &str = "runtime.ModuleLoadFromFile"; - - #[test] - fn get_fn() { - assert!(Function::get(CANARY).is_some()); - assert!(Function::get("does not exists!").is_none()); - } - - #[test] - fn register_and_call_closure0() { - use crate::function; - use function::Result; - - fn constfn() -> i64 { - return 10; - } - - function::register_override(constfn, "constfn".to_owned(), true).unwrap(); - - let func = Function::get_boxed:: Result, _>("constfn").unwrap(); - let ret = func().unwrap(); - assert_eq!(ret, 10); - } - - #[test] - fn register_and_call_closure1() { - use crate::function::{self}; - - fn ident(x: i64) -> i64 { - return x; - } - - function::register_override(ident, "ident".to_owned(), true).unwrap(); - let func = Function::get_boxed:: Result, _>("ident").unwrap(); - assert_eq!(func(60).unwrap(), 60); - } -} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs deleted file mode 100644 index 921117abaee4..000000000000 --- a/rust/tvm-rt/src/lib.rs +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! [TVM](https://github.com/apache/tvm) is a compiler stack for deep learning systems. -//! -//! This crate provides an idiomatic Rust API for TVM runtime. -//! -//! The TVM runtime API contains the data structures used by higher-level TVM executors. -//! Specifically it exposes the basic types such as NDArray, as well as the more general object system. -//! The TVM object system enables cross-language interoperability including that of closures for all -//! supported languages including C++, and Python. - -// Macro to check the return call to TVM runtime shared library. - -#[macro_export] -macro_rules! tvm_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - Err($crate::get_last_error().into()) - } else { - Ok(()) - } - }}; -} - -#[macro_export] -macro_rules! check_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - panic!("{}", $crate::get_last_error()); - } - }}; -} - -// Define all sumodules. -pub mod device; -pub mod errors; -pub mod function; -pub mod module; -pub mod ndarray; -pub mod object; -pub mod string; -mod to_function; - -pub use object::*; -pub use string::*; - -use std::{ - ffi::{CStr, CString}, - str, -}; - -pub use crate::{ - device::{Device, DeviceType}, - errors::*, - function::Function, - module::Module, - ndarray::NDArray, -}; - -pub use function::{ArgValue, RetValue}; -pub use tvm_sys::byte_array::ByteArray; -pub use tvm_sys::datatype::DataType; -use tvm_sys::ffi; - -pub use tvm_macros::external; - -/// Gets the last error message. -pub fn get_last_error() -> &'static str { - unsafe { - match CStr::from_ptr(ffi::TVMGetLastError()).to_str() { - Ok(s) => s, - Err(_) => "Invalid UTF-8 message", - } - } -} - -pub(crate) fn set_last_error(err: &E) { - let c_string = CString::new(err.to_string()).unwrap(); - unsafe { - ffi::TVMAPISetLastError(c_string.as_ptr()); - } -} - -/// Outputs the current TVM version. -pub fn version() -> &'static str { - match str::from_utf8(ffi::TVM_VERSION) { - Ok(s) => s, - Err(_) => "Invalid UTF-8 string", - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ByteArray, DataType, Device}; - use std::{convert::TryInto, str::FromStr}; - - #[test] - fn print_version() { - println!("TVM version: {}", version()); - } - - #[test] - fn set_error() { - let err = errors::NDArrayError::EmptyArray; - set_last_error(&err); - assert_eq!( - get_last_error().trim(), - errors::NDArrayError::EmptyArray.to_string() - ); - } - - // todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. - // #[test] - // fn bytearray() { - // let w = vec![1u8, 2, 3, 4, 5]; - // let v = ByteArray::from(w.as_slice()); - // let tvm: ByteArray = RetValue::from(v).try_into().unwrap(); - // assert_eq!( - // tvm.data(), - // w.iter().copied().collect::>().as_slice() - // ); - // } - - #[test] - fn ty() { - let t = DataType::from_str("int32").unwrap(); - let tvm: DataType = RetValue::from(t).try_into().unwrap(); - assert_eq!(tvm, t); - } - - #[test] - fn device() { - let c = Device::from_str("cuda").unwrap(); - let tvm: Device = RetValue::from(c).try_into().unwrap(); - assert_eq!(tvm, c); - } -} diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs deleted file mode 100644 index 754ebf44262e..000000000000 --- a/rust/tvm-rt/src/module.rs +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! Provides the [`Module`] type and methods for working with runtime TVM modules. - -use std::{ - ffi::CString, - os::raw::{c_char, c_int}, - path::Path, - ptr, -}; - -use crate::object::Object; -use tvm_macros::Object; -use tvm_sys::ffi; - -use crate::errors::Error; -use crate::String as TString; -use crate::{errors, function::Function}; - -/// Wrapper around TVM module handle which contains an entry function. -/// The entry function can be applied to an imported module through [`entry_func`]. -/// -/// [`entry_func`]:struct.Module.html#method.entry_func -#[repr(C)] -#[derive(Object, Debug)] -#[ref_name = "Module"] -#[type_key = "runtime.Module"] -pub struct ModuleNode { - base: Object, -} - -crate::external! { - #[name("runtime.RuntimeEnabled")] - fn runtime_enabled(target: CString) -> bool; - - #[name("runtime.ModuleLoadFromFile")] - fn load_from_file(file_name: CString, format: CString) -> Module; - - #[name("runtime.ModuleSaveToFile")] - fn save_to_file(module: Module, name: TString, fmt: TString); - - // TODO(@jroesch): we need to refactor this - #[name("tvm.relax.module_export_library")] - fn export_library(module: Module, file_name: TString); -} - -impl Module { - pub fn default_fn(&mut self) -> Result { - self.get_function("default", true) - } - - /// Gets a function by name from a registered module. - pub fn get_function(&self, name: &str, query_import: bool) -> Result { - let name = CString::new(name)?; - let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; - - check_call!(ffi::TVMModGetFunction( - self.handle(), - name.as_ptr() as *const c_char, - query_import as c_int, - &mut fhandle as *mut _ - )); - - if fhandle.is_null() { - return Err(errors::Error::NullHandle(name.into_string()?.to_string())); - } - - Ok(Function::from_raw(fhandle)) - } - - /// Imports a dependent module such as `.ptx` for cuda gpu. - pub fn import_module(&self, dependent_module: Module) { - check_call!(ffi::TVMModImport(self.handle(), dependent_module.handle())) - } - - /// Loads a module shared library from path. - pub fn load>(path: &P) -> Result { - let ext = CString::new( - path.as_ref() - .extension() - .unwrap_or_else(|| std::ffi::OsStr::new("")) - .to_str() - .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, - )?; - - let cpath = CString::new( - path.as_ref() - .to_str() - .ok_or_else(|| Error::ModuleLoadPath(path.as_ref().display().to_string()))?, - )?; - - let module = load_from_file(cpath, ext)?; - Ok(module) - } - - pub fn save_to_file(&self, name: String, fmt: String) -> Result<(), Error> { - save_to_file(self.clone(), name.into(), fmt.into()) - } - - pub fn export_library(&self, name: String) -> Result<(), Error> { - export_library(self.clone(), name.into()) - } - - /// Checks if a target device is enabled for a module. - pub fn enabled(&self, target: &str) -> bool { - let target = CString::new(target).unwrap(); - runtime_enabled(target).unwrap() - } - - /// Returns the underlying module handle. - pub unsafe fn handle(&self) -> ffi::TVMModuleHandle { - self.0.clone().unwrap().into_raw() as *mut _ - } -} diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs deleted file mode 100644 index dd3882a098e2..000000000000 --- a/rust/tvm-rt/src/ndarray.rs +++ /dev/null @@ -1,515 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module implements the [`NDArray`] type for working with *TVM tensors* or -//! coverting from a Rust's ndarray to TVM `NDArray`. -//! -//! One can create an empty NDArray given the shape, device and dtype using [`empty`]. -//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. -//! To copy an NDArray to different device use [`copy_to_device`]. -//! -//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: -//! -//! # Example -//! -//! ``` -//! # use tvm_rt::{NDArray, DataType, Device}; -//! # use ndarray::{Array, ArrayD}; -//! # use std::str::FromStr; -//! use std::convert::TryFrom; -//! -//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) -//! .unwrap() -//! .into_dyn(); // Rust's ndarray -//! let nd = NDArray::from_rust_ndarray(&a, Device::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); -//! assert_eq!(nd.shape(), &[2, 2]); -//! let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); -//! assert!(rnd.all_close(&a, 1e-8f32)); -//! ``` -//! -//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ -//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer -//! [`copy_to_device`]:struct.NDArray.html#method.copy_to_device - -use std::ffi::c_void; -use std::{borrow::Cow, convert::TryInto}; -use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; - -use mem::size_of; -use tvm_macros::Object; -use tvm_sys::ffi::DLTensor; -use tvm_sys::{ffi, ByteArray, DataType, Device}; - -use ndarray::{Array, ArrayD}; -use num_traits::Num; - -use crate::errors::NDArrayError; - -use crate::object::{Object, ObjectPtr, ObjectRef}; - -/// See the [`module-level documentation`](../ndarray/index.html) for more details. -#[repr(C)] -#[derive(Object, Debug)] -#[ref_name = "NDArray"] -#[type_key = "runtime.NDArray"] -pub struct NDArrayContainer { - base: Object, - // Container Base - dl_tensor: DLTensor, - manager_ctx: *mut c_void, - shape: ObjectRef, -} - -impl NDArrayContainer { - pub(crate) fn from_raw(handle: ffi::TVMArrayHandle) -> Option> { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - let base_ptr = unsafe { (handle as *mut i8).offset(-base_offset) }; - let object_ptr = ObjectPtr::from_raw(base_ptr.cast()); - object_ptr.map(|ptr| { - ptr.downcast::() - .expect("we know this is an NDArray container") - }) - } - - pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut NDArrayContainer - where - NDArrayContainer: 'a, - { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - unsafe { - &mut *std::mem::ManuallyDrop::new(object_ptr) - .ptr - .as_ptr() - .cast::() - .offset(base_offset) - .cast::() - } - } - - pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr) -> *mut NDArrayContainer - where - NDArrayContainer: 'a, - { - let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize; - unsafe { - object_ptr - .ptr - .as_ptr() - .cast::() - .offset(base_offset) - .cast::() - } - } -} - -fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> { - if std::mem::size_of::() == 64 { - debug_assert!(slice.iter().all(|&x| x >= 0)); - let shape: &[usize] = unsafe { std::mem::transmute(slice) }; - Cow::Borrowed(shape) - } else { - let shape: Vec = slice - .iter() - .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot fit into usize: {}", x))) - .collect(); - Cow::Owned(shape) - } -} - -impl NDArray { - pub(crate) fn _from_raw(handle: ffi::TVMArrayHandle) -> Self { - let ptr = NDArrayContainer::from_raw(handle); - NDArray(ptr) - } - - // I think these should be marked as unsafe functions? projecting a reference is bad news. - pub fn as_dltensor(&self) -> &DLTensor { - &self.dl_tensor - } - - pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { - unsafe { std::mem::transmute(self.as_dltensor()) } - } - - pub fn is_view(&self) -> bool { - false - } - - /// Returns the shape of the NDArray. - pub fn shape(&self) -> &[i64] { - let arr = self.as_dltensor(); - if arr.shape.is_null() || arr.data.is_null() { - &[] - } else { - unsafe { slice::from_raw_parts(arr.shape, self.ndim()) } - } - } - - /// Returns the shape of the NDArray as a &[usize] - /// - /// On 64-bit platforms, this is zero-cost and uses the shape from the DLTensor. - /// On other platforms, this copies into a buffer. - pub fn shape_usize(&self) -> Cow<[usize]> { - cow_usize(self.shape()) - } - - /// Returns the strides of the underlying NDArray. - pub fn strides(&self) -> Option<&[i64]> { - let arr = self.as_dltensor(); - if arr.strides.is_null() { - None - } else { - Some(unsafe { slice::from_raw_parts(arr.strides, self.ndim()) }) - } - } - - /// Returns the strides of the NDArray as a &[usize] - /// - /// On 64-bit platforms, this is zero-cost and uses the strides from the DLTensor. - /// On other platforms, this copies into a buffer. - pub fn strides_usize(&self) -> Option> { - self.strides().map(cow_usize) - } - - /// Returns true if the tensor is empty - pub fn is_empty(&self) -> bool { - self.as_dltensor().data.is_null() - } - - /// Returns the total number of entries of the NDArray. - pub fn len(&self) -> usize { - let len: i64 = self.shape().iter().product(); - usize::try_from(len).unwrap_or_else(|_| panic!("bad len: {}", len)) - } - - /// Returns the total bytes taken up by the data. - /// This is equal to `nd.len() * nd.dtype().itemsize` - pub fn size(&self) -> usize { - self.len() * self.dtype().itemsize - } - - /// Returns the device which the NDArray was defined. - pub fn device(&self) -> Device { - self.as_dltensor().device.into() - } - - /// Returns the type of the entries of the NDArray. - pub fn dtype(&self) -> DataType { - self.as_dltensor().dtype.into() - } - - /// Returns the number of dimensions of the NDArray. - pub fn ndim(&self) -> usize { - self.as_dltensor() - .ndim - .try_into() - .expect("number of dimensions must always be positive") - } - - /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> bool { - match self.strides() { - None => true, - Some(strides) => { - // NDArrayError::MissingShape in case shape is not determined - self.shape() - .iter() - .zip(strides) - .rfold( - (true, 1), - |(is_contig, expected_stride), (shape, stride)| { - ( - is_contig && *stride == expected_stride, - expected_stride * shape, - ) - }, - ) - .0 - } - } - } - - pub fn byte_offset(&self) -> isize { - self.as_dltensor().byte_offset as isize - } - - /// Flattens the NDArray to a `Vec` of the same type in cpu. - /// - /// ## Example - /// - /// ``` - /// # use tvm_rt::{Device, DataType, NDArray}; - /// # use std::str::FromStr; - /// let mut shape = [4]; - /// let mut data = vec![1i32, 2, 3, 4]; - /// let dev = Device::cpu(0); - /// let mut ndarray = NDArray::empty(&mut shape, dev, DataType::from_str("int32").unwrap()); - /// ndarray.copy_from_buffer(&mut data); - /// assert_eq!(ndarray.shape(), shape); - /// assert_eq!(ndarray.to_vec::().unwrap(), data); - /// ``` - pub fn to_vec(&self) -> Result, NDArrayError> { - let n = self.size() / size_of::(); - let mut vec: Vec = Vec::with_capacity(n); - - let ptr = vec.as_mut_ptr(); - let slice = unsafe { slice::from_raw_parts_mut(ptr, n) }; - self.copy_to_buffer(slice); - - unsafe { vec.set_len(n) }; - Ok(vec) - } - - /// Converts the NDArray to [`ByteArray`]. - pub fn to_bytearray(&self) -> Result { - let v = self.to_vec::()?; - Ok(ByteArray::from(v)) - } - - /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. - /// - /// ## Example - /// - /// ``` - /// # use tvm_rt::{Device, DataType, NDArray}; - /// # use std::str::FromStr; - /// let shape = &mut [2]; - /// let mut data = vec![1f32, 2.0]; - /// let dev = Device::cpu(0); - /// let mut ndarray = NDArray::empty(shape, dev, DataType::from_str("int32").unwrap()); - /// ndarray.copy_from_buffer(&mut data); - /// ``` - /// - /// *Note*: if something goes wrong during the copy, it will panic - /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. - pub fn copy_from_buffer(&mut self, data: &[T]) { - check_call!(ffi::TVMArrayCopyFromBytes( - self.as_raw_dltensor(), - data.as_ptr() as *mut _, - (data.len() * mem::size_of::()) as _, - )); - } - - pub fn copy_to_buffer(&self, data: &mut [T]) { - assert_eq!(self.size(), data.len() * size_of::()); - check_call!(ffi::TVMArrayCopyToBytes( - self.as_raw_dltensor(), - data.as_ptr() as *mut _, - self.size() as _, - )); - } - - pub fn fill_from_iter(&mut self, iter: I) - where - T: Num32, - I: ExactSizeIterator, - { - assert!(self.is_contiguous()); - assert_eq!(self.size(), size_of::() * iter.len()); - let mut ptr: *mut T = self.as_dltensor().data.cast(); - iter.for_each(|x| unsafe { - ptr.write(x); - ptr = ptr.add(1); - }) - } - - /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result { - if self.dtype() != target.dtype() { - return Err(NDArrayError::DataTypeMismatch { - expected: self.dtype(), - actual: target.dtype(), - }); - } - - check_call!(ffi::TVMArrayCopyFromTo( - self.as_raw_dltensor(), - target.as_raw_dltensor(), - ptr::null_mut() as ffi::TVMStreamHandle - )); - - Ok(target) - } - - /// Copies the NDArray to a target device. - pub fn copy_to_device(&self, target: &Device) -> Result { - let tmp = NDArray::empty(self.shape(), *target, self.dtype()); - let copy = self.copy_to_ndarray(tmp)?; - Ok(copy) - } - - /// Converts a Rust's ndarray to TVM NDArray. - pub fn from_rust_ndarray( - input_nd: &ArrayD, - dev: Device, - dtype: DataType, - ) -> Result { - let shape: Vec = input_nd.shape().iter().map(|&x| x as i64).collect(); - let mut nd = NDArray::empty(&shape, dev, dtype); - nd.fill_from_iter(input_nd.iter().copied()); - Ok(nd) - } - - /// Allocates and creates an empty NDArray given the shape, device and dtype. - pub fn empty(shape: &[i64], dev: Device, dtype: DataType) -> NDArray { - let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; - let dtype: tvm_sys::ffi::DLDataType = dtype.into(); - check_call!(ffi::TVMArrayAlloc( - shape.as_ptr(), - shape.len() as c_int, - i32::from(dtype.code) as c_int, - i32::from(dtype.bits) as c_int, - i32::from(dtype.lanes) as c_int, - dev.device_type as c_int, - dev.device_id as c_int, - &mut handle as *mut _, - )); - let ptr = NDArrayContainer::from_raw(handle) - .map(|o| o.downcast().expect("this should never fail")); - NDArray(ptr) - } - - pub fn zeroed(self) -> NDArray { - unsafe { - let dltensor = self.as_raw_dltensor(); - let bytes_ptr: *mut u8 = std::mem::transmute((*dltensor).data); - println!("size {}", self.size()); - std::ptr::write_bytes(bytes_ptr, 0, self.size()); - self - } - } -} - -macro_rules! impl_from_ndarray_rustndarray { - ($type:ty, $type_name:tt) => { - impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { - type Error = NDArrayError; - - fn try_from(nd: &NDArray) -> Result, Self::Error> { - assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape_usize(), - nd.to_vec::<$type>()?, - )?) - } - } - - impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { - type Error = NDArrayError; - - fn try_from(nd: &mut NDArray) -> Result, Self::Error> { - assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); - Ok(Array::from_shape_vec( - &*nd.shape_usize(), - nd.to_vec::<$type>()?, - )?) - } - } - }; -} - -impl_from_ndarray_rustndarray!(i32, "int"); -impl_from_ndarray_rustndarray!(u32, "uint"); -impl_from_ndarray_rustndarray!(f32, "float"); - -mod sealed { - /// Private trait to prevent other traits from being implemeneted in downstream crates. - pub trait Sealed {} -} - -/// A trait for the supported 32-bits numerical types in frontend. -pub trait Num32: Num + sealed::Sealed { - const BITS: u8 = 32; -} - -macro_rules! impl_num32 { - ($($type:ty),+) => { - $( - impl sealed::Sealed for $type {} - impl Num32 for $type {} - )+ - }; -} - -impl_num32!(i32, u32, f32); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn basics() { - let shape = &[1, 2, 3]; - let dev = Device::cpu(0); - println!("before empty"); - let ndarray = NDArray::empty(shape, dev, DataType::from_str("int32").unwrap()); - println!("after empty"); - assert_eq!(ndarray.shape(), shape); - assert_eq!(ndarray.len(), shape.iter().product::() as usize); - assert_eq!(ndarray.ndim(), 3); - assert!(ndarray.strides().is_none()); - assert_eq!(ndarray.byte_offset(), 0); - } - - #[test] - fn copy() { - let shape = &[4]; - let data = vec![1i32, 2, 3, 4]; - let dev = Device::cpu(0); - let mut ndarray = NDArray::empty(shape, dev, DataType::int(32, 1)).zeroed(); - assert_eq!(ndarray.to_vec::().unwrap(), vec![0, 0, 0, 0]); - ndarray.copy_from_buffer(&data); - assert_eq!(ndarray.shape(), shape); - assert_eq!(ndarray.to_vec::().unwrap(), data); - assert_eq!(ndarray.ndim(), 1); - assert!(ndarray.is_contiguous()); - assert_eq!(ndarray.byte_offset(), 0); - let shape = vec![4]; - let e = NDArray::empty(&shape, Device::cpu(0), DataType::from_str("int32").unwrap()); - let nd = ndarray.copy_to_ndarray(e); - assert!(nd.is_ok()); - assert_eq!(nd.unwrap().to_vec::().unwrap(), data); - } - - /// This occasionally panics on macOS: https://github.com/rust-lang/rust/issues/71397 - #[test] - #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - fn copy_wrong_dtype() { - let shape = vec![4]; - let mut data = vec![1f32, 2., 3., 4.]; - let dev = Device::cpu(0); - let mut nd_float = NDArray::empty(&shape, dev, DataType::from_str("float32").unwrap()); - nd_float.copy_from_buffer(&mut data); - let empty_int = NDArray::empty(&shape, dev, DataType::from_str("int32").unwrap()); - nd_float.copy_to_ndarray(empty_int).unwrap(); - } - - #[test] - fn rust_ndarray() { - let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) - .unwrap() - .into_dyn(); - let nd = - NDArray::from_rust_ndarray(&a, Device::cpu(0), DataType::from_str("float32").unwrap()) - .unwrap(); - assert_eq!(nd.shape(), &[2, 2]); - let rnd: ArrayD = ArrayD::try_from(&nd).unwrap(); - assert!(rnd.all_close(&a, 1e-8f32)); - } -} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs deleted file mode 100644 index f5832fcb3ab8..000000000000 --- a/rust/tvm-rt/src/object/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::convert::TryFrom; -use std::ffi::CString; - -use crate::errors::Error; -use crate::external; - -use tvm_sys::{ArgValue, RetValue}; - -mod object_ptr; - -pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef}; - -pub trait AsArgValue<'a> { - fn as_arg_value(&'a self) -> ArgValue<'a>; -} - -impl<'a, T: 'static> AsArgValue<'a> for T -where - &'a T: Into>, -{ - fn as_arg_value(&'a self) -> ArgValue<'a> { - self.into() - } -} - -// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we -// can't because of coherence rules. Instead, we generate them in the macro, and -// add what we can (including Into instead of From) as subtraits. -// We also add named conversions for clarity -pub trait IsObjectRef: - Sized - + Clone - + Into - + for<'a> AsArgValue<'a> - + TryFrom - + for<'a> TryFrom, Error = Error> - + std::fmt::Debug -{ - type Object: IsObject; - fn as_ptr(&self) -> Option<&ObjectPtr>; - fn into_ptr(self) -> Option>; - fn from_ptr(object_ptr: Option>) -> Self; - - fn null() -> Self { - Self::from_ptr(None) - } - - fn into_arg_value<'a>(&'a self) -> ArgValue<'a> { - self.as_arg_value() - } - - fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result { - Self::try_from(arg_value) - } - - fn into_ret_value<'a>(self) -> RetValue { - self.into() - } - - fn from_ret_value<'a>(ret_value: RetValue) -> Result { - Self::try_from(ret_value) - } - - fn upcast(self) -> U - where - U: IsObjectRef, - Self::Object: AsRef, - { - let ptr = self.into_ptr().map(ObjectPtr::upcast); - U::from_ptr(ptr) - } - - fn downcast(self) -> Result - where - U: IsObjectRef, - U::Object: AsRef, - { - let ptr = self.into_ptr().map(ObjectPtr::downcast); - let ptr = ptr.transpose()?; - Ok(U::from_ptr(ptr)) - } -} - -external! { - #[name("ir.DebugPrint")] - pub fn debug_print(object: ObjectRef) -> CString; - #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; - #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; -} diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs deleted file mode 100644 index 09d6068f1a88..000000000000 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ /dev/null @@ -1,555 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::convert::TryFrom; -use std::ffi::CString; -use std::fmt; -use std::os::raw::c_char; -use std::ptr::NonNull; -use std::sync::atomic::AtomicI32; - -use tvm_macros::Object; -use tvm_sys::ffi::{ - self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeIndex2Key, TVMObjectTypeKey2Index, -}; -use tvm_sys::{ArgValue, RetValue}; - -use crate::errors::Error; - -type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); - -/// A TVM intrusive smart pointer header, in TVM all FFI compatible types -/// start with an Object as their first field. The base object tracks -/// a type_index which is an index into the runtime type information -/// table, an atomic reference count, and a customized deleter which -/// will be invoked when the reference count is zero. -/// -#[derive(Debug, Object)] -#[ref_name = "ObjectRef"] -#[type_key = "runtime.Object"] -#[repr(C)] -pub struct Object { - /// The index into TVM's runtime type information table. - pub(self) type_index: u32, - // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. - // NB: in general we should not touch this in Rust. - /// The reference count of the smart pointer. - pub(self) ref_count: AtomicI32, - /// The deleter function which is used to deallocate the underlying data - /// when the reference count is zero. This field must always be set for - /// all objects. - /// - /// The common use case is ensuring that the allocator which allocated the - /// data is also the one that deletes it. - pub(self) fdeleter: Deleter, -} - -/// The default deleter for objects allocated in Rust, we use a bit of -/// trait magic here to get a monomorphized deleter for each object -/// "subtype". -/// -/// This function just converts the pointer to the correct type -/// and reconstructs a Box which then is dropped to deallocate -/// the underlying allocation. -unsafe extern "C" fn delete(object: *mut Object) { - let typed_object: *mut T = object as *mut T; - let boxed: Box = Box::from_raw(typed_object); - drop(boxed); -} - -fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { - let mut is_derived = 0; - crate::check_call!(ffi::TVMObjectDerivedFrom( - child_type_index, - parent_type_index, - &mut is_derived - )); - - if is_derived == 0 { - false - } else { - true - } -} - -impl Object { - fn new(type_index: u32, deleter: Deleter) -> Object { - Object { - type_index, - // NB(@jroesch): I believe it is sound to use Rust atomics - // in conjunction with C++ atomics given the memory model - // is nearly identical. - // - // Of course these are famous last words which I may later - // regret. - ref_count: AtomicI32::new(0), - fdeleter: deleter, - } - } - - fn get_type_key(&self) -> String { - let mut cstring: *mut c_char = std::ptr::null_mut(); - unsafe { - if TVMObjectTypeIndex2Key(self.type_index, &mut cstring as *mut _) != 0 { - panic!("{}", crate::get_last_error()); - } - return CString::from_raw(cstring) - .into_string() - .expect("type keys should be valid utf-8"); - } - } - - fn get_type_index() -> u32 { - let type_key = T::TYPE_KEY; - let cstring = CString::new(type_key).expect("type key must not contain null characters"); - - // TODO(@jroesch): look into TVMObjectTypeKey2Index. - if type_key == "runtime.Object" { - return 0; - } else { - let mut index = 0; - unsafe { - if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 { - panic!("{}", crate::get_last_error()) - } - } - return index; - } - } - - pub fn count(&self) -> i32 { - // need to do atomic read in C++ - // ABI compatible atomics is funky/hard. - self.ref_count.load(std::sync::atomic::Ordering::Relaxed) - } - - /// Allocates a base object value for an object subtype of type T. - /// By using associated constants and generics we can provide a - /// type indexed abstraction over allocating objects with the - /// correct index and deleter. - pub fn base() -> Object { - let index = Object::get_type_index::(); - Object::new(index, delete::) - } - - /// Increases the object's reference count by one. - pub(self) fn inc_ref(&self) { - let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void; - unsafe { - assert_eq!(TVMObjectRetain(raw_ptr), 0); - } - } - - /// Decreases the object's reference count by one. - pub(self) fn dec_ref(&self) { - let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void; - unsafe { - assert_eq!(TVMObjectFree(raw_ptr), 0); - } - } -} - -/// An unsafe trait which should be implemented for an object -/// subtype. -/// -/// The trait contains the type key needed to compute the type -/// index, a method for accessing the base object given the -/// subtype, and a typed delete method which is specialized -/// to the subtype. -pub unsafe trait IsObject: AsRef + std::fmt::Debug { - const TYPE_KEY: &'static str; -} - -/// A smart pointer for types which implement IsObject. -/// This type directly corresponds to TVM's C++ type ObjectPtr. -/// -/// See object.h for more details. -#[repr(C)] -pub struct ObjectPtr { - pub ptr: NonNull, -} - -impl ObjectPtr { - pub fn from_raw(object_ptr: *mut Object) -> Option> { - let non_null = NonNull::new(object_ptr); - non_null.map(|ptr| { - debug_assert!(unsafe { ptr.as_ref().count() } >= 0); - ObjectPtr { ptr } - }) - } -} - -impl Clone for ObjectPtr { - fn clone(&self) -> Self { - unsafe { self.ptr.as_ref().as_ref().inc_ref() } - ObjectPtr { ptr: self.ptr } - } -} - -impl Drop for ObjectPtr { - fn drop(&mut self) { - unsafe { self.ptr.as_ref().as_ref().dec_ref() } - } -} - -impl ObjectPtr { - pub fn leak<'a>(object_ptr: ObjectPtr) -> &'a mut T - where - T: 'a, - { - unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr() } - } - - pub fn new(object: T) -> ObjectPtr { - object.as_ref().inc_ref(); - let object_ptr = Box::new(object); - let object_ptr = Box::leak(object_ptr); - let ptr = NonNull::from(object_ptr); - ObjectPtr { ptr } - } - - pub fn count(&self) -> i32 { - // need to do atomic read in C++ - // ABI compatible atomics is funky/hard. - self.as_ref() - .ref_count - .load(std::sync::atomic::Ordering::Relaxed) - } - - /// This method avoid running the destructor on self once it's dropped, so we don't accidentally release the memory - unsafe fn cast(self) -> ObjectPtr { - let ptr = self.ptr.cast(); - std::mem::forget(self); - ObjectPtr { ptr } - } - - pub fn upcast(self) -> ObjectPtr - where - U: IsObject, - T: AsRef, - { - unsafe { self.cast() } - } - - pub fn downcast(self) -> Result, Error> - where - U: IsObject + AsRef, - { - let child_index = Object::get_type_index::(); - let object_index = self.as_ref().type_index; - - let is_derived = if child_index == object_index { - true - } else { - // TODO(@jroesch): write tests - derived_from(object_index, child_index) - }; - - if is_derived { - Ok(unsafe { self.cast() }) - } else { - let type_key = self.as_ref().get_type_key(); - Err(Error::downcast(type_key.into(), U::TYPE_KEY)) - } - } - - pub unsafe fn into_raw(self) -> *mut T { - self.ptr.as_ptr() - } - - pub unsafe fn as_ptr(&self) -> *mut T { - self.ptr.as_ptr() - } -} - -impl std::ops::Deref for ObjectPtr { - type Target = T; - - fn deref(&self) -> &Self::Target { - unsafe { self.ptr.as_ref() } - } -} - -impl fmt::Debug for ObjectPtr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - use std::ops::Deref; - write!(f, "{:?}", self.deref()) - } -} - -impl<'a, T: IsObject> From> for RetValue { - fn from(object_ptr: ObjectPtr) -> RetValue { - let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void; - assert!(!raw_object_ptr.is_null()); - RetValue::ObjectHandle(raw_object_ptr) - } -} - -impl<'a, T: IsObject> TryFrom for ObjectPtr { - type Error = Error; - - fn try_from(ret_value: RetValue) -> Result, Self::Error> { - use crate::ffi::DLTensor; - use crate::ndarray::NDArrayContainer; - - match ret_value { - RetValue::ObjectHandle(handle) | RetValue::ModuleHandle(handle) => { - let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.downcast() - } - RetValue::NDArrayHandle(handle) => { - let optr: ObjectPtr = - NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - debug_assert!(optr.count() >= 1); - optr.upcast::().downcast() - } - _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)), - } - } -} - -impl<'a, T: IsObject> From<&'a ObjectPtr> for ArgValue<'a> { - fn from(object_ptr: &'a ObjectPtr) -> ArgValue<'a> { - debug_assert!(object_ptr.count() >= 1); - let object_ptr = object_ptr.clone().upcast::(); - match T::TYPE_KEY { - "runtime.NDArray" => { - use crate::ndarray::NDArrayContainer; - let dcast_ptr = object_ptr.downcast().unwrap(); - let raw_ptr = NDArrayContainer::as_mut_ptr(&dcast_ptr) as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::NDArrayHandle(raw_ptr) - } - "runtime.Module" => { - let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::ModuleHandle(raw_ptr) - } - _ => { - let raw_ptr = unsafe { object_ptr.as_ptr() } as *mut std::ffi::c_void; - assert!(!raw_ptr.is_null()); - ArgValue::ObjectHandle(raw_ptr) - } - } - } -} - -impl<'a, T: IsObject> TryFrom> for ObjectPtr { - type Error = Error; - - fn try_from(arg_value: ArgValue<'a>) -> Result, Self::Error> { - use crate::ffi::DLTensor; - use crate::ndarray::NDArrayContainer; - - match arg_value { - ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => { - let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?; - optr.inc_ref(); - // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must - // bump the reference count by one. - assert!(optr.count() >= 1); - optr.downcast() - } - ArgValue::NDArrayHandle(handle) => { - let optr = - NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?; - // We are building an owned, ref-counted view into the underlying ArgValue, in order to be safe we must - // bump the reference count by one. - assert!(optr.count() >= 1); - // TODO(@jroesch): figure out if there is a more optimal way to do this - let object = optr.upcast::(); - object.inc_ref(); - object.downcast() - } - _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), - } - } -} - -impl std::hash::Hash for ObjectPtr { - fn hash(&self, state: &mut H) { - state.write_i64( - super::structural_hash(ObjectRef(Some(self.clone().upcast())), false).unwrap(), - ) - } -} - -impl PartialEq for ObjectPtr { - fn eq(&self, other: &Self) -> bool { - let lhs = ObjectRef(Some(self.clone().upcast())); - let rhs = ObjectRef(Some(other.clone().upcast())); - super::structural_equal(lhs, rhs, false, false).unwrap() - } -} - -impl Eq for ObjectPtr {} - -#[cfg(test)] -mod tests { - use super::{Object, ObjectPtr}; - use anyhow::{ensure, Result}; - use std::convert::TryInto; - use tvm_sys::{ArgValue, RetValue}; - - #[test] - fn test_new_object() -> anyhow::Result<()> { - let object = Object::base::(); - let ptr = ObjectPtr::new(object); - assert_eq!(ptr.count(), 1); - Ok(()) - } - - #[test] - fn test_leak() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let object = ObjectPtr::leak(ptr); - assert_eq!(object.count(), 1); - Ok(()) - } - - #[test] - fn test_clone() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ptr2 = ptr.clone(); - assert_eq!(ptr2.count(), 2); - drop(ptr); - assert_eq!(ptr2.count(), 1); - Ok(()) - } - - #[test] - fn roundtrip_retvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ret_value: RetValue = ptr.clone().into(); - let ptr2: ObjectPtr = ret_value.try_into()?; - assert_eq!(ptr.count(), ptr2.count()); - assert_eq!(ptr.count(), 2); - ensure!( - ptr.type_index == ptr2.type_index, - "type indices do not match" - ); - ensure!( - ptr.fdeleter == ptr2.fdeleter, - "objects have different deleters" - ); - // After dropping the second pointer we should only see only refcount. - drop(ptr2); - assert_eq!(ptr.count(), 1); - Ok(()) - } - - #[test] - fn roundtrip_argvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base::()); - assert_eq!(ptr.count(), 1); - let ptr_clone = ptr.clone(); - assert_eq!(ptr.count(), 2); - let arg_value: ArgValue = (&ptr_clone).into(); - assert_eq!(ptr.count(), 2); - let ptr2: ObjectPtr = arg_value.try_into()?; - assert_eq!(ptr2.count(), 3); - assert_eq!(ptr.count(), ptr2.count()); - drop(ptr_clone); - assert_eq!(ptr.count(), 2); - ensure!( - ptr.type_index == ptr2.type_index, - "type indices do not match" - ); - ensure!( - ptr.fdeleter == ptr2.fdeleter, - "objects have different deleters" - ); - // After dropping the second pointer we should only see only refcount. - drop(ptr2); - assert_eq!(ptr.count(), 1); - Ok(()) - } - - fn test_fn_raw<'a>( - mut args: crate::to_function::ArgList<'a>, - ) -> crate::function::Result { - let v: ArgValue = args.remove(0); - let v2: ArgValue = args.remove(0); - // assert_eq!(o.count(), 2); - let o: ObjectPtr = v.try_into().unwrap(); - assert_eq!(o.count(), 2); - let o2: ObjectPtr = v2.try_into().unwrap(); - assert_eq!(o2.count(), 3); - drop(o2); - assert_eq!(o.count(), 2); - Ok(o.into()) - } - - #[test] - fn test_ref_count_raw_fn() { - use super::*; - use crate::function::{register_untyped, Function}; - let ptr = ObjectPtr::new(Object::base::()); - // Call the function without the wrapping for TVM. - assert_eq!(ptr.count(), 1); - let same = test_fn_raw(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = same.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - - register_untyped(test_fn_raw, "test_fn_raw", true).unwrap(); - let raw_func = Function::get("test_fn_raw").unwrap(); - let output = raw_func.invoke(vec![(&ptr).into(), (&ptr).into()]).unwrap(); - let output: ObjectPtr = output.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - } - - fn test_fn_typed(o: ObjectPtr, o2: ObjectPtr) -> ObjectPtr { - assert_eq!(o.count(), 3); - assert_eq!(o2.count(), 3); - drop(o2); - assert_eq!(o.count(), 2); - return o; - } - - #[test] - fn test_ref_count_typed() { - use super::*; - use crate::function::{register, Function}; - let ptr = ObjectPtr::new(Object::base::()); - // Call the function without the wrapping for TVM. - assert_eq!(ptr.count(), 1); - let output = test_fn_typed(ptr.clone(), ptr.clone()); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - - register(test_fn_typed, "test_fn_typed").unwrap(); - let typed_func = Function::get("test_fn_typed").unwrap(); - let output = typed_func - .invoke(vec![(&ptr).into(), (&ptr).into()]) - .unwrap(); - let output: ObjectPtr = output.try_into().unwrap(); - assert_eq!(output.count(), 2); - drop(output); - assert_eq!(ptr.count(), 1); - } -} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs deleted file mode 100644 index e61afaf7399b..000000000000 --- a/rust/tvm-rt/src/string.rs +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::cmp::{Ordering, PartialEq}; -use std::hash::{Hash, Hasher}; - -use super::Object; - -use tvm_macros::Object; - -#[repr(C)] -#[derive(Object, Debug)] -#[ref_name = "String"] -#[type_key = "runtime.String"] -#[no_derive] -pub struct StringObj { - base: Object, - data: *const u8, - size: u64, -} - -impl From for String { - fn from(s: std::string::String) -> Self { - let size = s.len() as u64; - let data = Box::into_raw(s.into_boxed_str()).cast(); - let base = Object::base::(); - StringObj { base, data, size }.into() - } -} - -impl From<&'static str> for String { - fn from(s: &'static str) -> Self { - let size = s.len() as u64; - let data = s.as_bytes().as_ptr(); - let base = Object::base::(); - StringObj { base, data, size }.into() - } -} - -impl AsRef<[u8]> for String { - fn as_ref(&self) -> &[u8] { - self.as_bytes() - } -} - -impl std::fmt::Display for String { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.to_string_lossy().fmt(f) - } -} - -impl String { - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - pub fn len(&self) -> usize { - self.size as usize - } - - pub fn as_bytes(&self) -> &[u8] { - unsafe { std::slice::from_raw_parts(self.data, self.len()) } - } - - pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> { - std::str::from_utf8(self.as_bytes()) - } - - pub fn to_string_lossy(&self) -> std::borrow::Cow { - std::string::String::from_utf8_lossy(self.as_bytes()) - } -} - -impl> PartialEq for String { - fn eq(&self, other: &T) -> bool { - self.as_bytes() == other.as_ref() - } -} - -impl> PartialOrd for String { - fn partial_cmp(&self, other: &T) -> Option { - self.as_bytes().partial_cmp(other.as_ref()) - } -} - -impl Eq for String {} - -impl Ord for String { - fn cmp(&self, other: &Self) -> Ordering { - self.as_bytes().cmp(other.as_bytes()) - } -} - -impl Hash for String { - fn hash(&self, state: &mut H) { - self.as_bytes().hash(state); - } -} - -impl std::fmt::Debug for String { - fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_fmt(format_args!("{:?}", self.to_string_lossy())) - } -} - -#[cfg(test)] -mod tests { - use super::String; - use crate::object::debug_print; - use crate::IsObjectRef; - use anyhow::{ensure, Result}; - - #[test] - fn test_string_debug() -> Result<()> { - let s = String::from("foo"); - let object_ref = s.upcast(); - println!("about to call"); - let string = debug_print(object_ref)?; - println!("after call"); - ensure!( - string.into_string().expect("is cstring").contains("foo"), - "string content is invalid" - ); - Ok(()) - } -} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs deleted file mode 100644 index 67fbfc996af0..000000000000 --- a/rust/tvm-rt/src/to_function.rs +++ /dev/null @@ -1,337 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This module provides an idiomatic Rust API for creating and working with TVM functions. -//! -//! For calling an already registered TVM function use [`function::Builder`] -//! To register a TVM packed function from Rust side either -//! use [`function::register`] or the macro [`register_global_func`]. -//! -//! See the tests and examples repository for more examples. - -use std::convert::{TryFrom, TryInto}; -use std::{ - os::raw::{c_int, c_void}, - ptr, slice, -}; - -use super::{function::Result, Function}; -use crate::errors::Error; - -pub use tvm_sys::{ffi, ArgValue, RetValue}; - -/// A trait representing whether the function arguments -/// and return type can be assigned to a TVM packed function. -/// -/// By splitting the conversion to function into two traits -/// we are able to improve error reporting, by splitting the -/// conversion of inputs and outputs to this trait. -/// -/// And the implementation of it to `ToFunction`. - -pub type ArgList<'a> = Vec>; - -pub enum Args<'a, I> { - Typed(I), - Raw(ArgList<'a>), -} - -pub trait Typed { - fn args<'arg>(i: Vec>) -> Result>; - fn ret(o: O) -> Result; -} - -pub trait ToFunction: Sized { - type Handle; - - fn into_raw(self) -> *mut Self::Handle; - - fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result - where - Self: Typed; - - fn drop(handle: *mut Self::Handle); - - fn to_function(self) -> Function - where - Self: Typed, - { - let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; - let resource_handle = self.into_raw(); - - check_call!(ffi::TVMFuncCreateFromCFunc( - Some(Self::tvm_callback), - resource_handle as *mut _, - Some(Self::tvm_finalizer), - &mut fhandle as *mut ffi::TVMFunctionHandle, - )); - - Function::from_raw(fhandle) - } - - /// The callback function which is wrapped converted by TVM - /// into a packed function stored in fhandle. - unsafe extern "C" fn tvm_callback( - args: *mut ffi::TVMValue, - type_codes: *mut c_int, - num_args: c_int, - ret: ffi::TVMRetValueHandle, - resource_handle: *mut c_void, - ) -> c_int - where - Self: Typed, - { - #![allow(unused_assignments, unused_unsafe)] - let result = std::panic::catch_unwind(|| { - // turning off the incorrect linter complaints - let len = num_args as usize; - let args_list = slice::from_raw_parts_mut(args, len); - let type_codes_list = slice::from_raw_parts_mut(type_codes, len); - let mut local_args: Vec = Vec::new(); - let mut value = ffi::TVMValue { v_int64: 0 }; - let mut tcode = 0; - let resource_handle = resource_handle as *mut Self::Handle; - for i in 0..len { - value = args_list[i]; - tcode = type_codes_list[i]; - // TODO(@jroesch): I believe it is sound to disable this specialized move rule. - // - // This is used in C++ to deal with moving an RValue or reference to a return value - // directly so you can skip copying. - // - // I believe this is not needed as the move directly occurs into the Rust function. - - // if tcode == ffi::TVMArgTypeCode_kTVMObjectHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMObjectRValueRefArg as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMPackedFuncHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMModuleHandle as c_int - // || tcode == ffi::TVMArgTypeCode_kTVMNDArrayHandle as c_int - // { - // check_call!(ffi::TVMCbArgToReturn( - // &mut value as *mut _, - // &mut tcode as *mut _ - // )); - // } - let arg_value = ArgValue::from_tvm_value(value, tcode as u32); - local_args.push(arg_value); - } - - let rv = match Self::call(resource_handle, local_args) { - Ok(v) => v, - Err(msg) => { - return Err(msg); - } - }; - - // TODO(@jroesch): clean up the handling of the is dec_ref - match rv.clone().try_into() as Result> { - Err(_) => {} - Ok(v) => drop(v), - }; - - let (mut ret_val, ret_tcode) = rv.to_tvm_value(); - let mut ret_type_code = ret_tcode as c_int; - - check_call!(ffi::TVMCFuncSetReturn( - ret, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _, - 1 as c_int - )); - - Ok(()) - }); - - // Here we handle either a panic or true error to isolate - // the unwinding as it will cause issues if we allow Rust - // to unwind over C++ boundary without care. - match result { - Err(_) => { - // TODO(@jroesch): figure out how to improve error here. - crate::set_last_error(&Error::Panic); - return -1; - } - Ok(inner_res) => match inner_res { - Err(err) => { - crate::set_last_error(&err); - return -1; - } - Ok(()) => return 0, - }, - } - } - - /// The finalizer which is invoked when the packed function's - /// reference count is zero. - unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { - let handle = std::mem::transmute(fhandle); - Self::drop(handle) - } -} - -pub struct RawArgs; - -impl Typed for for<'a> fn(Vec>) -> Result { - fn args<'arg>(args: Vec>) -> Result> { - Ok(Args::Raw(args)) - } - - fn ret(o: RetValue) -> Result { - Ok(o) - } -} - -impl ToFunction for for<'arg> fn(Vec>) -> Result { - type Handle = for<'arg> fn(Vec>) -> Result; - - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(self); - Box::into_raw(ptr) - } - - fn call<'arg>(handle: *mut Self::Handle, args: Vec>) -> Result { - unsafe { - let func = *handle; - func(args) - } - } - - fn drop(_: *mut Self::Handle) {} -} - -/// A helper trait which correctly captures the complex conversion and lifetime semantics needed -/// to coerce an ordinary Rust value into `ArgValue`. -pub trait TryFromArgValue: TryFrom { - fn from_arg_value(f: F) -> std::result::Result; -} - -impl<'a, T> TryFromArgValue> for T -where - Self: TryFrom>, - Error: From<>>::Error>, -{ - fn from_arg_value(f: ArgValue<'a>) -> std::result::Result { - Ok(TryFrom::try_from(f)?) - } -} - -macro_rules! impl_typed_and_to_function { - ($len:literal; $($t:ident),*) => { - impl Typed<($($t,)*), Out> for Fun - where - Fun: Fn($($t),*) -> Out, - Out: TryInto, - Error: From, - $( for<'a> $t: TryFromArgValue>, )* - { - #[allow(non_snake_case, unused_variables, unused_mut)] - fn args<'arg>(args: Vec>) -> Result> { - if args.len() != $len { - return Err(Error::CallFailed(format!("{} expected {} arguments, got {}.\n", - std::any::type_name::(), - $len, args.len()))) - } - let mut args = args.into_iter(); - $(let $t = TryFromArgValue::from_arg_value(args.next().unwrap())?;)* - Ok(Args::Typed(($($t,)*))) - } - - fn ret(out: Out) -> Result { - out.try_into().map_err(|e| e.into()) - } - } - - - impl ToFunction<($($t,)*), Out> for Fun - where - Fun: Fn($($t,)*) -> Out + 'static - { - type Handle = Box Out + 'static>; - - fn into_raw(self) -> *mut Self::Handle { - let ptr: Box = Box::new(Box::new(self)); - Box::into_raw(ptr) - } - - #[allow(non_snake_case)] - fn call<'a>(handle: *mut Self::Handle, args: Vec>) -> Result - where - Fun: Typed<($($t,)*), Out> - { - let ($($t,)*) = match Fun::args(args)? { - Args::Raw(_) => panic!("impossible case"), - Args::Typed(typed) => typed, - }; - - let fn_ptr = unsafe { &*handle }; - let out = fn_ptr($($t),*); - Fun::ret(out) - } - - fn drop(ptr: *mut Self::Handle) { - let bx = unsafe { Box::from_raw(ptr) }; - std::mem::drop(bx) - } - } - } -} - -impl_typed_and_to_function!(0;); -impl_typed_and_to_function!(1; A); -impl_typed_and_to_function!(2; A, B); -impl_typed_and_to_function!(3; A, B, C); -impl_typed_and_to_function!(4; A, B, C, D); -impl_typed_and_to_function!(5; A, B, C, D, E); -impl_typed_and_to_function!(6; A, B, C, D, E, F); -impl_typed_and_to_function!(7; A, B, C, D, E, F, G); -impl_typed_and_to_function!(8; A, B, C, D, E, F, G, H); - -#[cfg(test)] -mod tests { - use super::*; - - fn call<'a, F, I, O>(f: F, args: Vec>) -> Result - where - F: ToFunction, - F: Typed, - { - F::call(f.into_raw(), args) - } - - #[test] - fn test_to_function0() { - fn zero() -> i32 { - 10 - } - let _ = zero.to_function(); - let good = call(zero, vec![]).unwrap(); - assert_eq!(i32::try_from(good).unwrap(), 10); - let bad = call(zero, vec![1.into()]).unwrap_err(); - assert!(matches!(bad, Error::CallFailed(..))); - } - - #[test] - fn test_to_function2() { - fn two_arg(i: i32, j: i32) -> i32 { - i + j - } - let good = call(two_arg, vec![3.into(), 4.into()]).unwrap(); - assert_eq!(i32::try_from(good).unwrap(), 7); - } -} diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml deleted file mode 100644 index 03e1d4e13d55..000000000000 --- a/rust/tvm-sys/Cargo.toml +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-sys" -version = "0.1.1-alpha" -authors = ["TVM Contributors"] -license = "Apache-2.0" -edition = "2018" -description = "Low level bindings to TVM's cross language API." - -[features] -default = ["dynamic-linking"] -static-linking = [] -dynamic-linking = [] -runtime-only = [] -# Enabling any of the following features is like setting the value to "ON" in config.cmake. -use-cuda = [] -use-opencl = [] -use-vulkan = [] -use-metal = [] -use-rocm = [] -use-hexagon-device = [] -use-rpc = [] -use-threads = [] -use-llvm = [] -use-stackvm-runtime = [] -use-openmp = [] -use-rtti = [] -use-mscv-mt = [] -use-install-dev = [] -hide-private-symbols = [] -use-fallback-stl-map = [] -use-index-default-i64 = [] -use-tf-tvmdsoop = [] -use-byodt-posit = [] -use-mkl = [] -use-mkldnn = [] -use-dnnl-codegen = [] -use-cudnn = [] -use-cublas = [] -use-thrust = [] -use-miopen = [] -use-rocblas = [] -use-sort = [] -use-nnpack = [] -use-random = [] -use-cpp-rpc = [] -use-tflite = [] -use-coreml = [] -use-target-onnx = [] -use-arm-compute-lib = [] -use-arm-compute-lib-graph-runtime = [] -use-tensorrt-codegen = [] -use-tensorrt-runtime = [] -build-static-runtime = [] - -[dependencies] -thiserror = "^1.0" -anyhow = "^1.0" -ndarray = "0.12" -enumn = "^0.1" - -[build-dependencies] -bindgen = { version="0.57", default-features = false, features = ["runtime"] } -anyhow = "^1.0" -tvm-build = "0.2.4" diff --git a/rust/tvm-sys/README.md b/rust/tvm-sys/README.md deleted file mode 100644 index 735a9431aa33..000000000000 --- a/rust/tvm-sys/README.md +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - - - -# tvm-sys - -The low level bindings to TVM's C APIs for interacting with the runtime, -the cross-language object system, and packed function API. - -These will generate bindings to TVM, if you set `TVM_HOME` variable before -building it will instruct the bindings to use your source tree, if not the -crate will use `tvm-build` in order to build a sandboxed version of the library. - -This feature is intended to simplify the installation for brand new TVM users -by trying to automate the build process as much as possible. diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs deleted file mode 100644 index 2f30afb4b0ab..000000000000 --- a/rust/tvm-sys/build.rs +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern crate bindgen; - -use std::{ - path::{Path, PathBuf}, - str::FromStr, -}; - -use anyhow::{Context, Result}; -use tvm_build::{BuildConfig, CMakeSetting}; - -/// The necessary information for detecting a TVM installation. -struct TVMInstall { - source_path: PathBuf, - build_path: PathBuf, -} - -/// Find the TVM install using the provided path. -fn find_using_tvm_path>(tvm_path: P) -> Result { - Ok(TVMInstall { - source_path: tvm_path.as_ref().into(), - build_path: tvm_path.as_ref().into(), - }) -} - -#[allow(unused)] -fn if_unset, V: AsRef>(k: K, v: V) -> Result<()> { - match std::env::var(k.as_ref()) { - Ok(other) if other != "" => { - println!( - "cargo:warning=Using existing environment variable setting {:?}={:?}", - k.as_ref(), - v.as_ref() - ); - } - _ => std::env::set_var(k, v), - } - - Ok(()) -} - -/// Find a TVM installation using TVM build by either first installing or detecting. -fn find_using_tvm_build() -> Result { - let mut build_config = BuildConfig::default(); - build_config.repository = Some("https://github.com/apache/tvm".to_string()); - build_config.branch = Some(option_env!("TVM_BRANCH").unwrap_or("main").into()); - - if cfg!(feature = "use-cuda") { - build_config.settings.use_cuda = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-opencl") { - build_config.settings.use_opencl = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-vulkan") { - build_config.settings.use_vulkan = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-rocm") { - build_config.settings.use_rocm = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-metal") { - build_config.settings.use_metal = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-hexagon-device") { - build_config.settings.use_hexagon_device = Some(true); - } - if cfg!(feature = "use-rpc") { - build_config.settings.use_rpc = Some(true); - } - if cfg!(feature = "use-threads") { - build_config.settings.use_threads = Some(true); - } - if cfg!(feature = "use-llvm") { - build_config.settings.use_llvm = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-stackvm-runtime") { - build_config.settings.use_stackvm_runtime = Some(true); - } - if cfg!(feature = "use-graph-runtime") { - build_config.settings.use_graph_runtime = Some(true); - } - if cfg!(feature = "use-graph-runtime-debug") { - build_config.settings.use_graph_runtime_debug = Some(true); - } - if cfg!(feature = "use-openmp") { - build_config.settings.use_openmp = Some(true); - } - if cfg!(feature = "use-rtti") { - build_config.settings.use_rtti = Some(true); - } - if cfg!(feature = "use-mscv-mt") { - build_config.settings.use_mscv_mt = Some(true); - } - if cfg!(feature = "use-install-dev") { - build_config.settings.use_install_dev = Some(true); - } - if cfg!(feature = "hide_private-symbols") { - build_config.settings.hide_private_symbols = Some(true); - } - if cfg!(feature = "use-fallback-stl-map") { - build_config.settings.use_fallback_stl_map = Some(true); - } - if cfg!(feature = "use-index_default-i64") { - build_config.settings.use_index_default_i64 = Some(true); - } - if cfg!(feature = "use-tf-tvmdsoop") { - build_config.settings.use_tf_tvmdsoop = Some(true); - } - if cfg!(feature = "use-byodt-posit") { - build_config.settings.use_byodt_posit = Some(true); - } - if cfg!(feature = "use-mkl") { - build_config.settings.use_mkl = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-mkldnn") { - build_config.settings.use_mkldnn = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-dnnl-codegen") { - build_config.settings.use_dnnl_codegen = Some(true); - } - if cfg!(feature = "use-cudnn") { - build_config.settings.use_cudnn = Some(true); - } - if cfg!(feature = "use-cublas") { - build_config.settings.use_cublas = Some(true); - } - if cfg!(feature = "use-thrust") { - build_config.settings.use_thrust = Some(true); - } - if cfg!(feature = "use-miopen") { - build_config.settings.use_miopen = Some(true); - } - if cfg!(feature = "use-rocblas") { - build_config.settings.use_rocblas = Some(true); - } - if cfg!(feature = "use-sort") { - build_config.settings.use_sort = Some(true); - } - if cfg!(feature = "use-nnpack") { - build_config.settings.use_nnpack = Some(true); - } - if cfg!(feature = "use-random") { - build_config.settings.use_random = Some(true); - } - if cfg!(feature = "use-cpp-rpc") { - build_config.settings.use_cpp_rpc = Some(true); - } - if cfg!(feature = "use-tflite") { - build_config.settings.use_tflite = Some(true); - } - if cfg!(feature = "use-coreml") { - build_config.settings.use_coreml = Some(true); - } - if cfg!(feature = "use-target-onnx") { - build_config.settings.use_target_onnx = Some(true); - } - if cfg!(feature = "use-arm-compute-lib") { - build_config.settings.use_arm_compute_lib = Some(true); - } - if cfg!(feature = "use-arm-compute-lib-graph-runtime") { - build_config.settings.use_arm_compute_lib_graph_runtime = CMakeSetting::from_str("on").ok(); - } - if cfg!(feature = "use-tensorrt-codegen") { - build_config.settings.use_tensorrt_codegen = Some(true); - } - if cfg!(feature = "use-tensorrt-runtime") { - build_config.settings.use_tensorrt_runtime = CMakeSetting::from_str("on").ok(); - } - if cfg!(any( - feature = "static-linking", - feature = "build-static-runtime" - )) { - build_config.settings.build_static_runtime = Some(true); - } - - let build_result = tvm_build::build(build_config)?; - let source_path = build_result.revision.source_path(); - let build_path = build_result.revision.build_path(); - Ok(TVMInstall { - source_path, - build_path, - }) -} - -fn main() -> Result<()> { - let TVMInstall { - source_path, - build_path, - } = match option_env!("TVM_HOME") { - Some(tvm_path) if tvm_path != "" => find_using_tvm_path(tvm_path), - _ => find_using_tvm_build(), - }?; - - // If the TVM_HOME environment variable changed, the LLVM_CONFIG_PATH environment variable - // changed or the source headers have changed we need to rebuild the Rust bindings. - println!("cargo:rerun-if-env-changed=TVM_HOME"); - println!("cargo:rerun-if-env-changed=LLVM_CONFIG_PATH"); - println!("cargo:rerun-if-changed={}/include", source_path.display()); - - let library_name = if cfg!(feature = "runtime-only") { - "tvm_runtime" - } else { - "tvm" - }; - - match &std::env::var("CARGO_CFG_TARGET_ARCH") - .expect("CARGO_CFG_TARGET_ARCH must be set by CARGO")[..] - { - "wasm32" => {} - _ => { - if cfg!(feature = "static-linking") { - println!("cargo:rustc-link-lib=static={}", library_name); - // TODO(@jroesch): move this to tvm-build as library_path? - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } - - if cfg!(feature = "dynamic-linking") { - println!("cargo:rustc-link-lib=dylib={}", library_name); - println!( - "cargo:rustc-link-search=native={}/build", - build_path.display() - ); - } - } - }; - - let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h"); - let backend_api = source_path.join("include/tvm/runtime/c_backend_api.h"); - let source_path = source_path.display().to_string(); - let dlpack_include = format!("-I{}/3rdparty/dlpack/include/", source_path); - let tvm_include = format!("-I{}/include/", source_path); - - let out_file = PathBuf::from(std::env::var("OUT_DIR")?).join("c_runtime_api.rs"); - - // @see rust-bindgen#550 for `blacklist_type` - bindgen::Builder::default() - .header(runtime_api.display().to_string()) - .header(backend_api.display().to_string()) - .clang_arg(dlpack_include) - .clang_arg(tvm_include) - .blacklist_type("max_align_t") - .layout_tests(false) - .derive_partialeq(true) - .derive_eq(true) - .derive_default(true) - .generate() - .map_err(|()| { - anyhow::anyhow!("bindgen failed to generate the Rust bindings for the C API") - })? - .write_to_file(out_file) - .context("failed to write the generated Rust binding to disk")?; - - Ok(()) -} diff --git a/rust/tvm-sys/src/array.rs b/rust/tvm-sys/src/array.rs deleted file mode 100644 index 92208303e89c..000000000000 --- a/rust/tvm-sys/src/array.rs +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - mem, - os::raw::{c_int, c_void}, -}; - -use crate::ffi::{ - DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLDevice, - DLDeviceType_kDLCPU, DLTensor, -}; - -/// `From` conversions to `DLTensor` for `ndarray::Array`. -/// Takes a reference to the `ndarray` since `DLTensor` is not owned. -macro_rules! impl_dltensor_from_ndarray { - ($type:ty, $typecode:expr) => { - impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { - fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { - DLTensor { - data: arr.as_mut_ptr() as *mut c_void, - device: DLDevice { - device_type: DLDeviceType_kDLCPU, - device_id: 0, - }, - ndim: arr.ndim() as c_int, - dtype: DLDataType { - code: $typecode as u8, - bits: 8 * mem::size_of::<$type>() as u8, - lanes: 1, - }, - shape: arr.shape().as_ptr() as *const i64 as *mut i64, - strides: arr.strides().as_ptr() as *const i64 as *mut i64, - byte_offset: 0, - ..Default::default() - } - } - } - }; -} - -impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); -impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); -impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); -impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/tvm-sys/src/byte_array.rs b/rust/tvm-sys/src/byte_array.rs deleted file mode 100644 index 2903a81d9c36..000000000000 --- a/rust/tvm-sys/src/byte_array.rs +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -use std::convert::TryFrom; - -use crate::errors::ValueDowncastError; -use crate::ffi::{TVMByteArray, TVMByteArrayFree}; -use crate::{ArgValue, RetValue}; - -/// A newtype wrapping a raw TVM byte-array. -/// -/// ## Example -/// -/// ``` -/// let v = b"hello"; -/// let barr = tvm_sys::ByteArray::from(&v); -/// assert_eq!(barr.len(), v.len()); -/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); -/// ``` -pub enum ByteArray { - Rust(TVMByteArray), - External(TVMByteArray), -} - -impl Drop for ByteArray { - fn drop(&mut self) { - match self { - ByteArray::Rust(bytes) => { - let ptr = bytes.data; - let len = bytes.size as _; - let cap = bytes.size as _; - let data: Vec = unsafe { Vec::from_raw_parts(ptr as _, len, cap) }; - drop(data); - } - ByteArray::External(byte_array) => unsafe { - if TVMByteArrayFree(byte_array as _) != 0 { - panic!("error"); - } - }, - } - } -} - -impl ByteArray { - /// Gets the underlying byte-array - pub fn data(&self) -> &[u8] { - match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => unsafe { - std::slice::from_raw_parts(byte_array.data as *const u8, byte_array.size as _) - }, - } - } - - /// Gets the length of the underlying byte-array - pub fn len(&self) -> usize { - match self { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => byte_array.size as _, - } - } - - /// Converts the underlying byte-array to `Vec` - pub fn to_vec(&self) -> Vec { - self.data().to_vec() - } - - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -impl>> From for ByteArray { - fn from(arg: T) -> Self { - let mut incoming_bytes: Vec = arg.into(); - let mut bytes = Vec::with_capacity(incoming_bytes.len()); - bytes.append(&mut incoming_bytes); - - let mut bytes = std::mem::ManuallyDrop::new(bytes); - let ptr = bytes.as_mut_ptr(); - assert_eq!(bytes.len(), bytes.capacity()); - ByteArray::Rust(TVMByteArray { - data: ptr as _, - size: bytes.len() as _, - }) - } -} - -impl<'a> From<&'a ByteArray> for ArgValue<'a> { - fn from(val: &'a ByteArray) -> ArgValue<'a> { - match val { - ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { - ArgValue::Bytes(byte_array) - } - } - } -} - -// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership. -// impl From for RetValue { -// fn from(val: ByteArray) -> RetValue { -// match val { -// ByteArray::Rust(byte_array) | ByteArray::External(byte_array) => { -// // TODO(@jroesch): This requires a little more work, going to land narratives -// RetValue::Bytes(byte_array) -// } -// } -// } -// } - -impl TryFrom for ByteArray { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - match val { - RetValue::Bytes(array) => Ok(ByteArray::External(array)), - _ => Err(ValueDowncastError { - expected_type: "ByteArray", - actual_type: format!("{:?}", val), - }), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn convert() { - let v = vec![1u8, 2, 3]; - let barr = ByteArray::from(v.to_vec()); - assert_eq!(barr.len(), v.len()); - assert_eq!(barr.to_vec(), vec![1u8, 2, 3]); - let v = b"hello"; - let barr = ByteArray::from(v.to_vec()); - assert_eq!(barr.len(), v.len()); - assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]); - } -} diff --git a/rust/tvm-sys/src/datatype.rs b/rust/tvm-sys/src/datatype.rs deleted file mode 100644 index 5f7e0c3a3b60..000000000000 --- a/rust/tvm-sys/src/datatype.rs +++ /dev/null @@ -1,214 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::any::TypeId; -use std::convert::TryFrom; -use std::str::FromStr; - -use crate::ffi::DLDataType; -use crate::packed_func::RetValue; - -use thiserror::Error; - -const DL_INT_CODE: u8 = 0; -const DL_UINT_CODE: u8 = 1; -const DL_FLOAT_CODE: u8 = 2; -const DL_HANDLE: u8 = 3; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[repr(C)] -pub struct DataType { - code: u8, - bits: u8, - lanes: u16, -} - -impl DataType { - pub const fn new(code: u8, bits: u8, lanes: u16) -> DataType { - DataType { code, bits, lanes } - } - - /// Returns the number of bytes occupied by an element of this `DataType`. - pub fn itemsize(&self) -> usize { - (self.bits as usize * self.lanes as usize) >> 3 - } - - /// Returns whether this `DataType` represents primitive type `T`. - pub fn is_type(&self) -> bool { - if self.lanes != 1 { - return false; - } - let typ = TypeId::of::(); - (typ == TypeId::of::() && self.code == DL_INT_CODE && self.bits == 32) - || (typ == TypeId::of::() && self.code == DL_INT_CODE && self.bits == 64) - || (typ == TypeId::of::() && self.code == DL_UINT_CODE && self.bits == 32) - || (typ == TypeId::of::() && self.code == DL_UINT_CODE && self.bits == 64) - || (typ == TypeId::of::() && self.code == DL_FLOAT_CODE && self.bits == 32) - || (typ == TypeId::of::() && self.code == DL_FLOAT_CODE && self.bits == 64) - } - - pub fn code(&self) -> usize { - self.code as usize - } - - pub fn bits(&self) -> usize { - self.bits as usize - } - - pub fn lanes(&self) -> usize { - self.lanes as usize - } - - pub const fn int(bits: u8, lanes: u16) -> DataType { - DataType::new(DL_INT_CODE, bits, lanes) - } - - pub const fn float(bits: u8, lanes: u16) -> DataType { - DataType::new(DL_FLOAT_CODE, bits, lanes) - } - - pub const fn float32() -> DataType { - Self::float(32, 1) - } - - pub const fn uint(bits: u8, lanes: u16) -> DataType { - DataType::new(DL_UINT_CODE, bits, lanes) - } -} - -impl<'a> From<&'a DataType> for DLDataType { - fn from(dtype: &'a DataType) -> Self { - Self { - code: dtype.code as u8, - bits: dtype.bits as u8, - lanes: dtype.lanes as u16, - } - } -} - -impl From for DataType { - fn from(dtype: DLDataType) -> Self { - Self { - code: dtype.code, - bits: dtype.bits, - lanes: dtype.lanes, - } - } -} - -impl From for DLDataType { - fn from(dtype: DataType) -> Self { - Self { - code: dtype.code, - bits: dtype.bits, - lanes: dtype.lanes, - } - } -} - -#[derive(Debug, Error)] -pub enum ParseDataTypeError { - #[error("invalid number: {0}")] - InvalidNumber(std::num::ParseIntError), - #[error("missing data type specifier (e.g., int32, float64)")] - MissingDataType, - #[error("unknown type: {0}")] - UnknownType(String), -} - -/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}` -/// such as "int32", "float32" or with lane "float32x1". -impl FromStr for DataType { - type Err = ParseDataTypeError; - - fn from_str(type_str: &str) -> Result { - use ParseDataTypeError::*; - - if type_str == "bool" { - return Ok(DataType::new(1, 1, 1)); - } - - let mut type_lanes = type_str.split('x'); - let typ = type_lanes.next().ok_or(MissingDataType)?; - let lanes = type_lanes - .next() - .map(|l| ::from_str_radix(l, 10)) - .unwrap_or(Ok(1)) - .map_err(InvalidNumber)?; - let (type_name, bits) = match typ.find(char::is_numeric) { - Some(idx) => { - let (name, bits_str) = typ.split_at(idx); - ( - name, - u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?, - ) - } - None => (typ, 32), - }; - - let type_code = match type_name { - "int" => DL_INT_CODE, - "uint" => DL_UINT_CODE, - "float" => DL_FLOAT_CODE, - "handle" => DL_HANDLE, - _ => return Err(UnknownType(type_name.to_string())), - }; - - Ok(DataType::new(type_code, bits, lanes)) - } -} - -impl std::fmt::Display for DataType { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - if self.bits == 1 && self.lanes == 1 { - return write!(f, "bool"); - } - let mut type_str = match self.code { - DL_INT_CODE => "int", - DL_UINT_CODE => "uint", - DL_FLOAT_CODE => "float", - DL_HANDLE => "handle", - _ => "unknown", - } - .to_string(); - - type_str += &self.bits.to_string(); - if self.lanes > 1 { - type_str += &format!("x{}", self.lanes); - } - f.write_str(&type_str) - } -} - -impl From for RetValue { - fn from(dt: DataType) -> RetValue { - RetValue::DataType((&dt).into()) - } -} - -impl TryFrom for DataType { - type Error = anyhow::Error; - fn try_from(ret_value: RetValue) -> anyhow::Result { - match ret_value { - RetValue::DataType(dt) => Ok(dt.into()), - // TODO(@jroesch): improve - _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), - } - } -} diff --git a/rust/tvm-sys/src/device.rs b/rust/tvm-sys/src/device.rs deleted file mode 100644 index 0344983c1622..000000000000 --- a/rust/tvm-sys/src/device.rs +++ /dev/null @@ -1,294 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! Provides [`Device`] and related device queries. -//! -//! Create a new device for device type and device id. -//! -//! # Example -//! -//! ``` -//! # use tvm_sys::{DeviceType, Device}; -//! let cpu = DeviceType::from("cpu"); -//! let dev = Device::new(cpu , 0); -//! let cpu0 = Device::cpu(0); -//! assert_eq!(dev, cpu0); -//! ``` -//! -//! Or from a supported device name. -//! -//! ``` -//! use tvm_sys::Device; -//! let cpu0 = Device::from("cpu"); -//! println!("{}", cpu0); -//! ``` - -use std::convert::TryFrom; -use std::fmt::{self, Display, Formatter}; -use std::str::FromStr; - -use crate::ffi::{self, *}; -use crate::packed_func::{ArgValue, RetValue}; - -use anyhow::Result; -use enumn::N; -use thiserror::Error; - -/// Device type represents the set of devices supported by -/// [TVM](https://github.com/apache/tvm). -/// -/// ## Example -/// -/// ``` -/// use tvm_sys::DeviceType; -/// let cpu = DeviceType::from("cpu"); -/// println!("device is: {}", cpu); -///``` - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, N)] -#[repr(i64)] -pub enum DeviceType { - CPU = 1, - CUDA = 2, - CUDAHost = 3, - OpenCL = 4, - Vulkan = 7, - Metal = 8, - VPI = 9, - ROCM = 10, - ExtDev = 12, -} - -impl Default for DeviceType { - /// default device is cpu. - fn default() -> Self { - DeviceType::CPU - } -} - -impl From for ffi::DLDeviceType { - fn from(device_type: DeviceType) -> Self { - device_type as Self - } -} - -impl From for DeviceType { - fn from(device_type: ffi::DLDeviceType) -> Self { - Self::n(device_type as _).expect("invalid enumeration value for ffi::DLDeviceType") - } -} - -impl Display for DeviceType { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!( - f, - "{}", - match self { - DeviceType::CPU => "cpu", - DeviceType::CUDA => "cuda", - DeviceType::CUDAHost => "cuda_host", - DeviceType::OpenCL => "opencl", - DeviceType::Vulkan => "vulkan", - DeviceType::Metal => "metal", - DeviceType::VPI => "vpi", - DeviceType::ROCM => "rocm", - DeviceType::ExtDev => "ext_device", - // DeviceType(_) => "rpc", - } - ) - } -} - -impl<'a> From<&'a str> for DeviceType { - fn from(type_str: &'a str) -> Self { - match type_str { - "cpu" => DeviceType::CPU, - "llvm" => DeviceType::CPU, - "cuda" => DeviceType::CUDA, - "nvptx" => DeviceType::CUDA, - "cl" => DeviceType::OpenCL, - "opencl" => DeviceType::OpenCL, - "metal" => DeviceType::Metal, - "vpi" => DeviceType::VPI, - "rocm" => DeviceType::ROCM, - _ => panic!("{:?} not supported!", type_str), - } - } -} - -impl<'a> From<&DeviceType> for ArgValue<'a> { - fn from(dev: &DeviceType) -> Self { - Self::Int(*dev as _) - } -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub struct Device { - pub device_type: DeviceType, - pub device_id: usize, -} - -impl Device { - pub fn new(device_type: DeviceType, device_id: usize) -> Device { - Device { - device_type, - device_id, - } - } -} - -impl<'a> From<&'a Device> for DLDevice { - fn from(dev: &'a Device) -> Self { - Self { - device_type: dev.device_type.into(), - device_id: dev.device_id as i32, - } - } -} - -impl Default for Device { - fn default() -> Self { - Self { - device_type: DLDeviceType_kDLCPU.into(), - device_id: 0, - } - } -} - -#[derive(Debug, Error)] -#[error("unsupported device: {0}")] -pub struct UnsupportedDeviceError(String); - -macro_rules! impl_tvm_device { - ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { - /// Creates a Device from a string (e.g., "cpu", "cuda", "ext_dev") - impl FromStr for Device { - type Err = UnsupportedDeviceError; - fn from_str(type_str: &str) -> Result { - Ok(Self { - device_type: match type_str { - $( $( stringify!($dev_name) )|+ => $dev_type.into()),+, - _ => return Err(UnsupportedDeviceError(type_str.to_string())), - }, - device_id: 0, - }) - } - } - - impl Device { - $( - $( - pub fn $dev_name(device_id: usize) -> Self { - Self { - device_type: $dev_type.into(), - device_id: device_id, - } - } - )+ - )+ - } - }; -} - -impl_tvm_device!( - DLDeviceType_kDLCPU: [cpu, llvm], - DLDeviceType_kDLCUDA: [cuda, nvptx], - DLDeviceType_kDLOpenCL: [cl], - DLDeviceType_kDLMetal: [metal], - DLDeviceType_kDLVPI: [vpi], - DLDeviceType_kDLROCM: [rocm], - DLDeviceType_kDLExtDev: [ext_dev] -); - -impl<'a> From<&'a str> for Device { - fn from(target: &str) -> Self { - Device::new(DeviceType::from(target), 0) - } -} - -impl From for Device { - fn from(dev: ffi::DLDevice) -> Self { - Device { - device_type: DeviceType::from(dev.device_type), - device_id: dev.device_id as usize, - } - } -} - -impl From for ffi::DLDevice { - fn from(dev: Device) -> Self { - ffi::DLDevice { - device_type: dev.device_type.into(), - device_id: dev.device_id as i32, - } - } -} - -impl Display for Device { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "{}({})", self.device_type, self.device_id) - } -} - -impl<'a> From<&'a Device> for ArgValue<'a> { - fn from(dev: &'a Device) -> Self { - DLDevice::from(dev).into() - } -} - -impl<'a> From for ArgValue<'a> { - fn from(dev: Device) -> Self { - DLDevice::from(dev).into() - } -} - -impl From for RetValue { - fn from(ret_value: Device) -> RetValue { - RetValue::Device(ret_value.into()) - } -} - -impl TryFrom for Device { - type Error = anyhow::Error; - fn try_from(ret_value: RetValue) -> anyhow::Result { - match ret_value { - RetValue::Device(dt) => Ok(dt.into()), - // TODO(@jroesch): improve - _ => Err(anyhow::anyhow!("unable to convert datatype from ...")), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn device() { - let dev = Device::cpu(0); - println!("device: {}", dev); - let default_dev = Device::new(DeviceType::CPU, 0); - assert_eq!(dev.clone(), default_dev); - assert_ne!(dev, Device::cuda(0)); - - let str_dev = Device::new(DeviceType::CUDA, 0); - assert_eq!(str_dev.clone(), str_dev); - assert_ne!(str_dev, Device::new(DeviceType::CPU, 0)); - } -} diff --git a/rust/tvm-sys/src/errors.rs b/rust/tvm-sys/src/errors.rs deleted file mode 100644 index 54fe261ec37e..000000000000 --- a/rust/tvm-sys/src/errors.rs +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use thiserror::Error; - -#[derive(Error, Debug)] -#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")] -pub struct ValueDowncastError { - pub actual_type: String, - pub expected_type: &'static str, -} - -#[derive(Error, Debug)] -#[error("Function call `{context:?}` returned error: {message:?}")] -pub struct FuncCallError { - context: String, - message: String, -} - -impl FuncCallError { - pub fn get_with_context(context: String) -> Self { - Self { - context, - message: unsafe { std::ffi::CStr::from_ptr(crate::ffi::TVMGetLastError()) } - .to_str() - .expect("failed while attempting to retrieve the TVM error message") - .to_owned(), - } - } -} diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs deleted file mode 100644 index f9ac3b461c69..000000000000 --- a/rust/tvm-sys/src/lib.rs +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -//! This crate contains the minimal interface over TVM's -//! C runtime API. -//! -//! These common bindings are useful to both runtimes -//! written in Rust, as well as higher level API bindings. -//! -//! See the `tvm-rt` or `tvm` crates for full bindings to -//! the TVM API. - -/// The low-level C runtime FFI API for TVM. -pub mod ffi { - #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] - - use std::os::raw::{c_char, c_int, c_void}; - - include!(concat!(env!("OUT_DIR"), "/c_runtime_api.rs")); - - pub type BackendPackedCFunc = extern "C" fn( - args: *const TVMValue, - type_codes: *const c_int, - num_args: c_int, - out_ret_value: *mut TVMValue, - out_ret_tcode: *mut u32, - resource_handle: *mut c_void, - ) -> c_int; -} - -pub mod array; -pub mod byte_array; -pub mod datatype; -pub mod device; -pub mod errors; -#[macro_use] -pub mod packed_func; -pub mod value; - -pub use byte_array::ByteArray; -pub use datatype::DataType; -pub use device::{Device, DeviceType}; -pub use errors::*; -pub use packed_func::{ArgValue, RetValue}; - -impl std::convert::TryFrom> for RetValue -where - RetValue: std::convert::TryFrom, - E: From<>::Error>, -{ - type Error = E; - - fn try_from(val: Result) -> Result { - val.and_then(|t| RetValue::try_from(t).map_err(|e| e.into())) - } -} diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs deleted file mode 100644 index 3d78ce52d621..000000000000 --- a/rust/tvm-sys/src/packed_func.rs +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{ - convert::TryFrom, - ffi::{CStr, CString}, - os::raw::{c_char, c_void}, -}; - -use crate::{errors::ValueDowncastError, ffi::*}; - -pub use crate::ffi::TVMValue; - -pub trait PackedFunc: - Fn(&[ArgValue]) -> Result + Send + Sync -{ -} - -impl PackedFunc for T where - T: Fn(&[ArgValue]) -> Result + Send + Sync -{ -} - -/// Calls a packed function and returns a `RetValue`. -/// -/// # Example -/// -/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` -#[macro_export] -macro_rules! call_packed { - ($fn:expr, $($args:expr),+) => { - $fn(&[$($args.into(),)+]) - }; - ($fn:expr) => { - $fn(&Vec::new()) - }; -} - -/// Constructs a derivative of a TVMPodValue. -macro_rules! TVMPODValue { - { - $(#[$m:meta])+ - $name:ident $(<$a:lifetime>)? { - $($extra_variant:ident ( $variant_type:ty ) ),+ $(,)? - }, - match $value:ident { - $($tvm_type:ident => { $from_tvm_type:expr })+ - }, - match &self { - $($self_type:ident ( $val:ident ) => { $from_self_type:expr })+ - } - $(,)? - } => { - $(#[$m])+ - #[derive(Clone, Debug)] - pub enum $name $(<$a>)? { - Int(i64), - UInt(i64), - Float(f64), - Bool(bool), - Null, - DataType(DLDataType), - String(*mut c_char), - Device(DLDevice), - Handle(*mut c_void), - ArrayHandle(TVMArrayHandle), - ObjectHandle(*mut c_void), - ModuleHandle(TVMModuleHandle), - FuncHandle(TVMFunctionHandle), - NDArrayHandle(*mut c_void), - $($extra_variant($variant_type)),+ - } - - impl $(<$a>)? $name $(<$a>)? { - pub fn from_tvm_value($value: TVMValue, type_code: u32) -> Self { - use $name::*; - #[allow(non_upper_case_globals)] - unsafe { - match type_code as _ { - DLDataTypeCode_kDLInt => Int($value.v_int64), - DLDataTypeCode_kDLUInt => UInt($value.v_int64), - DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0), - TVMArgTypeCode_kTVMNullptr => Null, - TVMArgTypeCode_kTVMDataType => DataType($value.v_type), - TVMArgTypeCode_kDLDevice => Device($value.v_device), - TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), - TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), - TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), - TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)), - TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), - TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), - TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), - $( $tvm_type => { $from_tvm_type } ),+ - _ => unimplemented!("{}", type_code), - } - } - } - - pub fn to_tvm_value(&self) -> (TVMValue, TVMArgTypeCode) { - use $name::*; - match self { - Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), - UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), - Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool), - Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), - DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), - Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), - String(val) => { - ( - TVMValue { v_handle: *val as *mut c_void }, - TVMArgTypeCode_kTVMStr, - ) - } - Handle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMOpaqueHandle), - ArrayHandle(val) => { - ( - TVMValue { v_handle: *val as *const _ as *mut c_void }, - TVMArgTypeCode_kTVMNDArrayHandle, - ) - }, - ObjectHandle(val) => (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMObjectHandle), - ModuleHandle(val) => - (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMModuleHandle), - FuncHandle(val) => ( - TVMValue { v_handle: *val }, - TVMArgTypeCode_kTVMPackedFuncHandle - ), - NDArrayHandle(val) => - (TVMValue { v_handle: *val }, TVMArgTypeCode_kTVMNDArrayHandle), - $( $self_type($val) => { $from_self_type } ),+ - } - } - } - } -} - -TVMPODValue! { - /// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way - /// to obtain a `ArgValue` is automatically via `call_packed!`. - ArgValue<'a> { - Bytes(&'a TVMByteArray), - Str(&'a CStr), - }, - match value { - TVMArgTypeCode_kTVMBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) } - TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) } - }, - match &self { - Bytes(val) => { - (TVMValue { v_handle: *val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes) - } - Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMArgTypeCode_kTVMStr) } - } -} - -TVMPODValue! { - /// An owned TVMPODValue. Can be converted from a variety of primitive and object types. - /// Can be downcasted using `try_from` if it contains the desired type. - /// - /// # Example - /// - /// ``` - /// use std::convert::{TryFrom, TryInto}; - /// use tvm_sys::RetValue; - /// - /// let a = 42u32; - /// let b: u32 = tvm_sys::RetValue::from(a).try_into().unwrap(); - /// - /// let s = "hello, world!"; - /// let t: RetValue = s.to_string().into(); - /// assert_eq!(String::try_from(t).unwrap(), s); - /// ``` - RetValue { - Bytes(TVMByteArray), - Str(&'static CStr), - }, - match value { - TVMArgTypeCode_kTVMBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) } - TVMArgTypeCode_kTVMStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) } - }, - match &self { - Bytes(val) => - { (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMArgTypeCode_kTVMBytes ) } - Str(val) => - { (TVMValue { v_str: val.as_ptr() }, TVMArgTypeCode_kTVMStr ) } - } -} - -#[macro_export] -macro_rules! try_downcast { - ($val:ident -> $into:ty, $( |$pat:pat| { $converter:expr } ),+ ) => { - match $val { - $( $pat => { Ok($converter) } )+ - _ => Err($crate::errors::ValueDowncastError { - actual_type: format!("{:?}", $val), - expected_type: stringify!($into), - }), - } - }; -} - -/// Creates a conversion to a `ArgValue` for a primitive type and DLDataTypeCode. -macro_rules! impl_pod_value { - ($variant:ident, $inner_ty:ty, [ $( $type:ty ),+ ] ) => { - $( - impl<'a> From<$type> for ArgValue<'a> { - fn from(val: $type) -> Self { - Self::$variant(val as $inner_ty) - } - } - - impl<'a> From<&'a $type> for ArgValue<'a> { - fn from(val: &'a $type) -> Self { - Self::$variant(*val as $inner_ty) - } - } - - impl<'a> TryFrom> for $type { - type Error = $crate::errors::ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { val as $type }) - } - } - - impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type { - type Error = $crate::errors::ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result { - try_downcast!(val -> $type, |ArgValue::$variant(val)| { *val as $type }) - } - } - - impl From<$type> for RetValue { - fn from(val: $type) -> Self { - Self::$variant(val as $inner_ty) - } - } - - impl TryFrom for $type { - type Error = $crate::errors::ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> $type, |RetValue::$variant(val)| { val as $type }) - } - } - )+ - }; -} - -impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); -impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); -impl_pod_value!(Float, f64, [f32, f64]); -impl_pod_value!(Bool, bool, [bool]); -impl_pod_value!(DataType, DLDataType, [DLDataType]); -impl_pod_value!(Device, DLDevice, [DLDevice]); - -impl<'a> From<&'a str> for ArgValue<'a> { - fn from(s: &'a str) -> Self { - Self::String(CString::new(s).unwrap().into_raw()) - } -} - -impl<'a> From for ArgValue<'a> { - fn from(s: String) -> Self { - Self::String(CString::new(s).unwrap().into_raw()) - } -} - -impl<'a> From<&'a CStr> for ArgValue<'a> { - fn from(s: &'a CStr) -> Self { - Self::Str(s) - } -} - -impl<'a> From<&'a CString> for ArgValue<'a> { - fn from(s: &'a CString) -> Self { - Self::String(s.as_ptr() as _) - } -} - -impl<'a> From<&'a TVMByteArray> for ArgValue<'a> { - fn from(s: &'a TVMByteArray) -> Self { - Self::Bytes(s) - } -} - -impl<'a> TryFrom> for &'a str { - type Error = ValueDowncastError; - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) - } -} - -impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for &'v str { - type Error = ValueDowncastError; - fn try_from(val: &'a ArgValue<'v>) -> Result { - try_downcast!(val -> &str, |ArgValue::Str(s)| { s.to_str().unwrap() }) - } -} - -/// Converts an unspecialized handle to a ArgValue. -impl<'a, T> From<*const T> for ArgValue<'a> { - fn from(ptr: *const T) -> Self { - Self::Handle(ptr as *mut c_void) - } -} - -/// Converts an unspecialized mutable handle to a ArgValue. -impl<'a, T> From<*mut T> for ArgValue<'a> { - fn from(ptr: *mut T) -> Self { - Self::Handle(ptr as *mut c_void) - } -} - -impl<'a> From<&'a mut DLTensor> for ArgValue<'a> { - fn from(arr: &'a mut DLTensor) -> Self { - Self::ArrayHandle(arr as *mut DLTensor) - } -} - -impl<'a> From<&'a DLTensor> for ArgValue<'a> { - fn from(arr: &'a DLTensor) -> Self { - Self::ArrayHandle(arr as *const _ as *mut DLTensor) - } -} - -impl TryFrom for String { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!( - val -> String, - |RetValue::String(s)| { unsafe { CString::from_raw(s).into_string().unwrap() }}, - |RetValue::Str(s)| { s.to_str().unwrap().to_string() } - ) - } -} - -impl From for RetValue { - fn from(s: String) -> Self { - Self::String(std::ffi::CString::new(s).unwrap().into_raw()) - } -} - -impl From for RetValue { - fn from(arr: TVMByteArray) -> Self { - Self::Bytes(arr) - } -} - -impl TryFrom for TVMByteArray { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> TVMByteArray, |RetValue::Bytes(val)| { val }) - } -} - -impl Default for RetValue { - fn default() -> Self { - Self::Int(0) - } -} - -impl TryFrom for std::ffi::CString { - type Error = ValueDowncastError; - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> std::ffi::CString, - |RetValue::Str(val)| { val.into() }) - } -} - -impl From<()> for RetValue { - fn from(_: ()) -> Self { - RetValue::Null - } -} - -impl TryFrom for () { - type Error = ValueDowncastError; - - fn try_from(val: RetValue) -> Result<(), Self::Error> { - try_downcast!(val -> bool, - |RetValue::Null| { () }) - } -} diff --git a/rust/tvm-sys/src/value.rs b/rust/tvm-sys/src/value.rs deleted file mode 100644 index 9c987af4cef6..000000000000 --- a/rust/tvm-sys/src/value.rs +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::str::FromStr; - -use crate::ffi::*; - -use thiserror::Error; - -macro_rules! impl_pod_tvm_value { - ($field:ident, $field_ty:ty, $( $ty:ty ),+) => { - $( - impl From<$ty> for TVMValue { - fn from(val: $ty) -> Self { - TVMValue { $field: val as $field_ty } - } - } - - impl From for $ty { - fn from(val: TVMValue) -> Self { - unsafe { val.$field as $ty } - } - } - )+ - }; - ($field:ident, $ty:ty) => { - impl_pod_tvm_value!($field, $ty, $ty); - } -} - -impl_pod_tvm_value!(v_int64, i64, i8, u8, i16, u16, i32, u32, i64, u64, isize, usize); -impl_pod_tvm_value!(v_float64, f64, f32, f64); -impl_pod_tvm_value!(v_type, DLDataType); -impl_pod_tvm_value!(v_device, DLDevice); - -#[derive(Debug, Error)] -#[error("unsupported device: {0}")] -pub struct UnsupportedDeviceError(String); - -macro_rules! impl_tvm_device { - ( $( $dev_type:ident : [ $( $dev_name:ident ),+ ] ),+ ) => { - /// Creates a DLDevice from a string (e.g., "cpu", "cuda", "ext_dev") - impl FromStr for DLDevice { - type Err = UnsupportedDeviceError; - fn from_str(type_str: &str) -> Result { - Ok(Self { - device_type: match type_str { - $( $( stringify!($dev_name) )|+ => $dev_type ),+, - _ => return Err(UnsupportedDeviceError(type_str.to_string())), - }, - device_id: 0, - }) - } - } - - impl DLDevice { - $( - $( - pub fn $dev_name(device_id: usize) -> Self { - Self { - device_type: $dev_type, - device_id: device_id as i32, - } - } - )+ - )+ - } - }; -} - -impl_tvm_device!( - DLDeviceType_kDLCPU: [cpu, llvm], - DLDeviceType_kDLCUDA: [cuda, nvptx], - DLDeviceType_kDLOpenCL: [cl], - DLDeviceType_kDLMetal: [metal], - DLDeviceType_kDLVPI: [vpi], - DLDeviceType_kDLROCM: [rocm], - DLDeviceType_kDLExtDev: [ext_dev] -); diff --git a/tests/lint/rust_format.sh b/tests/lint/rust_format.sh deleted file mode 100755 index bed7ad976ea6..000000000000 --- a/tests/lint/rust_format.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -TVM_HOME="$(git rev-parse --show-toplevel)" -RUST_DIR="$TVM_HOME/rust" - -if [[ "$1" == "-i" ]]; then - INPLACE_FORMAT=1 - shift 1 -else - INPLACE_FORMAT=0 -fi - -cd $RUST_DIR - -if [[ ${INPLACE_FORMAT} -eq 1 ]]; then - cargo fmt -else - cargo fmt -- --check -fi diff --git a/tests/scripts/task_lint.sh b/tests/scripts/task_lint.sh index 3b270b21f60a..6a6a2171bd1f 100755 --- a/tests/scripts/task_lint.sh +++ b/tests/scripts/task_lint.sh @@ -72,9 +72,6 @@ function shard2 { echo "clang-format check..." tests/lint/git-clang-format.sh - echo "Rust check..." - tests/lint/rust_format.sh - echo "Docker check..." tests/lint/docker-format.sh } diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index b217f692b4a3..7b58658bd7c7 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -161,12 +161,6 @@ npm install npm run typedoc cd .. -# Rust doc -cd rust -# Temp disable rust doc build -# cargo doc --workspace --no-deps -cd .. - # Prepare the doc dir rm -rf _docs mv docs/_build/html _docs @@ -174,7 +168,6 @@ rm -f _docs/.buildinfo mkdir -p _docs/reference/api mv docs/doxygen/html _docs/reference/api/doxygen mv jvm/core/target/site/apidocs _docs/reference/api/javadoc -# mv rust/target/doc _docs/api/rust mv web/dist/docs _docs/reference/api/typedoc git rev-parse HEAD > _docs/commit_hash