Fix/web examples #2691

Merged: 4 commits, Jan 13, 2025
103 changes: 31 additions & 72 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b179e368c5404e871176155262cfd5221ae0ed60" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b179e368c5404e871176155262cfd5221ae0ed60" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
3 changes: 2 additions & 1 deletion crates/burn-import/Cargo.toml
@@ -20,7 +20,8 @@ onnx = []
 pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"]
 
 [dependencies]
-burn = { path = "../burn", version = "0.16.0", features = ["ndarray"] }
+burn = { path = "../burn", version = "0.16.0", default-features = false, features = ["std"]}
+burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false }
 onnx-ir = { path = "../onnx-ir", version = "0.16.0" }
 candle-core = { workspace = true }
 derive-new = { workspace = true }
2 changes: 1 addition & 1 deletion crates/burn-import/src/burn/graph.rs
@@ -50,7 +50,7 @@ pub struct BurnGraph<PS: PrecisionSettings> {
 }
 
 // The backend used for recording.
-type Backend = burn::backend::ndarray::NdArray;
+type Backend = burn_ndarray::NdArray;
 
 impl<PS: PrecisionSettings> BurnGraph<PS> {
     /// Register a new operation node into the graph.
3 changes: 1 addition & 2 deletions crates/burn-import/src/burn/node/base.rs
@@ -17,13 +17,12 @@ use super::{
     unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
-use burn::backend::NdArray;
 use burn::record::PrecisionSettings;
 use proc_macro2::TokenStream;
 use serde::Serialize;
 
 /// Backend used for serialization.
-pub type SerializationBackend = NdArray<f32>;
+pub type SerializationBackend = burn_ndarray::NdArray<f32>;
 
 /// Codegen trait that should be implemented by all [node](Node) entries.
 pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {
4 changes: 2 additions & 2 deletions crates/burn-jit/Cargo.toml
@@ -26,14 +26,14 @@ export_tests = [
 ]
 fusion = ["burn-fusion"]
 fusion-experimental = ["fusion"]
-std = ["cubecl/std", "burn-tensor/std"]
 
+std = ["cubecl/std"]
 template = []
 
 [dependencies]
 burn-common = { path = "../burn-common", version = "0.16.0" }
 burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
+burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [
     "cubecl",
     "repr",
 ] }
3 changes: 1 addition & 2 deletions crates/burn-tensor/Cargo.toml
@@ -16,7 +16,7 @@ cubecl = ["dep:cubecl"]
 cubecl-cuda = ["cubecl", "cubecl/cuda"]
 cubecl-hip = ["cubecl", "cubecl/hip"]
 cubecl-wgpu = ["cubecl", "cubecl/wgpu"]
-default = ["std", "repr"]
+default = ["std", "repr", "burn-common/rayon"]
 doc = ["default"]
 experimental-named-tensor = []
 export_tests = ["burn-tensor-testgen", "cubecl"]
@@ -26,7 +26,6 @@ std = [
     "half/std",
     "num-traits/std",
     "burn-common/std",
-    "burn-common/rayon",
     "colored",
 ]
 
2 changes: 1 addition & 1 deletion crates/burn-wgpu/Cargo.toml
@@ -26,7 +26,7 @@ cubecl = { workspace = true, features = ["wgpu"] }
 
 burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true }
 burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false }
-burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [
+burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [
     "cubecl-wgpu",
 ] }
 
3 changes: 1 addition & 2 deletions examples/image-classification-web/Cargo.toml
@@ -17,7 +17,6 @@ half_precision = []
 burn = { path = "../../crates/burn", version = "0.16.0", default-features = false, features = [
     "ndarray", "wgpu",
 ] }
-cubecl-runtime = { version = "0.3.0", features = ["channel-mpsc"] } # missing feature flags
 burn-candle = { path = "../../crates/burn-candle", version = "0.16.0", default-features = false }
 
 log = { workspace = true }
@@ -35,4 +34,4 @@ js-sys = "0.3"
 
 [build-dependencies]
 # Used to generate code from ONNX model
-burn-import = { path = "../../crates/burn-import" }
+burn-import = { path = "../../crates/burn-import", default-features = false, features = ["onnx"]}
3 changes: 1 addition & 2 deletions examples/mnist-inference-web/Cargo.toml
@@ -13,11 +13,10 @@ crate-type = ["cdylib"]
 default = ["ndarray"]
 
 ndarray = ["burn/ndarray"]
-wgpu = ["burn/wgpu", "cubecl-runtime"]
+wgpu = ["burn/wgpu"]
 
 [dependencies]
 burn = { path = "../../crates/burn", default-features = false }
-cubecl-runtime = { version = "0.3.0", optional = true, features = ["channel-mpsc"] } # missing feature flag
 serde = { workspace = true }
 console_error_panic_hook = { workspace = true }
 
4 changes: 2 additions & 2 deletions examples/mnist-inference-web/src/state.rs
@@ -5,7 +5,7 @@ use burn::{
 };
 
 #[cfg(feature = "wgpu")]
-use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn::backend::wgpu::{init_setup_async, AutoGraphicsApi, Wgpu, WgpuDevice};
 
 #[cfg(feature = "wgpu")]
 pub type Backend = Wgpu<f32, i32>;
@@ -18,7 +18,7 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin");
 /// Builds and loads trained parameters into the model.
 pub async fn build_and_load_model() -> Model<Backend> {
     #[cfg(feature = "wgpu")]
-    init_async::<AutoGraphicsApi>(&WgpuDevice::default(), Default::default()).await;
+    init_setup_async::<AutoGraphicsApi>(&WgpuDevice::default(), Default::default()).await;
 
     let model: Model<Backend> = Model::new(&Default::default());
     let record = BinBytesRecorder::<FullPrecisionSettings>::default()