Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] Remove mxnet dependency and re-enable rust example #17293

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion rust/tvm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ You can find the API Documentation [here](https://tvm.apache.org/docs/api/rust/t

The goal of this crate is to provide bindings to both the TVM compiler and runtime
APIs. First train your **Deep Learning** model using any major framework such as
[PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.apache.org/) or [TensorFlow](https://www.tensorflow.org/).
[PyTorch](https://pytorch.org/) or [TensorFlow](https://www.tensorflow.org/).
Then use **TVM** to build and deploy optimized model artifacts on a supported devices such as CPU, GPU, OpenCL and specialized accelerators.

The Rust bindings are composed of a few crates:
Expand Down
2 changes: 1 addition & 1 deletion rust/tvm/examples/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ This end-to-end example shows how to:
* build `Resnet 18` with `tvm` from Python
* use the provided Rust frontend API to test for an input image

To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
To run the example with pretrained resnet weights, first `tvm` and `torchvision` must be installed for the python build. To install torchvision for cpu, run `pip install torch torchvision`
and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html).

* **Build the example**: `cargo build
Expand Down
6 changes: 0 additions & 6 deletions rust/tvm/examples/resnet/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ use anyhow::{Context, Result};
use std::{io::Write, path::Path, process::Command};

fn main() -> Result<()> {
// Currently disabled, as it depends on the no-longer-supported
// mxnet repo to download resnet.

/*
let out_dir = std::env::var("CARGO_MANIFEST_DIR")?;
let python_script = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py");
let synset_txt = concat!(env!("CARGO_MANIFEST_DIR"), "/synset.txt");
Expand Down Expand Up @@ -57,7 +53,5 @@ fn main() -> Result<()> {
);
println!("cargo:rustc-link-search=native={}", out_dir);

*/

Ok(())
}
28 changes: 14 additions & 14 deletions rust/tvm/examples/resnet/src/build_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,18 @@
# under the License.

import argparse
import csv
import logging
from os import path as osp
import sys
import shutil
from os import path as osp

import numpy as np

import torch
import torchvision
import tvm
from tvm import te
from tvm import relay, runtime
from tvm.relay import testing
from tvm.contrib import graph_executor, cc
from PIL import Image
from tvm import relay, runtime
from tvm.contrib import cc, graph_executor
from tvm.contrib.download import download_testdata
from mxnet.gluon.model_zoo.vision import get_model

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
Expand Down Expand Up @@ -64,11 +60,16 @@

def build(target_dir):
"""Compiles resnet18 with TVM"""
# Download the pretrained model in MxNet's format.
block = get_model("resnet18_v1", pretrained=True)
# Download the pretrained model from Torchvision.
weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
torch_model = torchvision.models.resnet18(weights=weights).eval()

input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(torch_model, input_data)
input_infos = [("data", input_data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, input_infos)

shape_dict = {"data": (1, 3, 224, 224)}
mod, params = relay.frontend.from_mxnet(block, shape_dict)
# Add softmax to do classification in last layer.
func = mod["main"]
func = relay.Function(
Expand All @@ -93,7 +94,6 @@ def build(target_dir):

def download_img_labels():
"""Download an image and imagenet1k class labels for test"""
from mxnet.gluon.utils import download

synset_url = "".join(
[
Expand Down
5 changes: 0 additions & 5 deletions rust/tvm/examples/resnet/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ use tvm_rt::graph_rt::GraphRt;
use tvm_rt::*;

fn main() -> anyhow::Result<()> {
// Currently disabled, as it depends on the no-longer-supported
// mxnet repo to download resnet.

/*
let dev = Device::cpu(0);
println!("{}", concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png"));

Expand Down Expand Up @@ -138,7 +134,6 @@ fn main() -> anyhow::Result<()> {
"input image belongs to the class `{}` with probability {}",
label, max_prob
);
*/

Ok(())
}
Loading