Skip to content

Commit

Permalink
Merge pull request #40 from tracel-ai/chore/resnet/burn-0.14
Browse files Browse the repository at this point in the history
Update to burn 0.14
  • Loading branch information
nathanielsimard authored Sep 13, 2024
2 parents f2a56b4 + 7ef0474 commit 00dfeac
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 16 deletions.
4 changes: 2 additions & 2 deletions resnet-burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ license = "MIT OR Apache-2.0"

[workspace.dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0", default-features = false }
burn-import = "0.13.0"
burn = { version = "0.14.0", default-features = false }
burn-import = "0.14.0"
dirs = "5.0.1"
serde = { version = "1.0.192", default-features = false, features = [
"derive",
Expand Down
10 changes: 5 additions & 5 deletions resnet-burn/examples/finetune/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ pub struct Normalizer<B: Backend> {
impl<B: Backend> Normalizer<B> {
/// Creates a new normalizer.
pub fn new(device: &Device<B>) -> Self {
let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]);
let mean = Tensor::<B, 1>::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
let std = Tensor::<B, 1>::from_floats(STD, device).reshape([1, 3, 1, 1]);
Self { mean, std }
}

Expand Down Expand Up @@ -117,9 +117,9 @@ impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for Classific

let images = items
.into_iter()
.map(|item| Data::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3])))
.map(|data| Tensor::<B, 3>::from_data(data.convert(), &self.device).permute([2, 0, 1]))
.map(|tensor| tensor / 255) // normalize between [0, 1]
.map(|item| TensorData::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3])))
.map(|data| Tensor::<B, 3>::from_data(data.convert::<B::FloatElem>(), &self.device))
.map(|tensor| tensor.permute([2, 0, 1]) / 255) // normalize between [0, 1]
.collect();

let images = Tensor::stack(images, 0);
Expand Down
3 changes: 1 addition & 2 deletions resnet-burn/examples/finetune/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, threshold: f32)
// Get predicted class names over the specified threshold
let predicted = output.greater_equal_elem(threshold).nonzero()[1]
.to_data()
.value
.iter()
.iter::<B::IntElem>()
.map(|i| CLASSES[i.elem::<i64>() as usize])
.collect::<Vec<_>>();

Expand Down
8 changes: 3 additions & 5 deletions resnet-burn/examples/inference/examples/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use burn::{
backend::NdArray,
module::Module,
record::{FullPrecisionSettings, NamedMpkFileRecorder},
tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
tensor::{backend::Backend, Device, Element, Tensor, TensorData},
};

const MODEL_PATH: &str = "resnet18-ImageNet1k";
Expand All @@ -17,10 +17,8 @@ fn to_tensor<B: Backend, T: Element>(
shape: [usize; 3],
device: &Device<B>,
) -> Tensor<B, 3> {
Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
// permute(2, 0, 1)
.swap_dims(2, 1) // [H, C, W]
.swap_dims(1, 0) // [C, H, W]
Tensor::<B, 3>::from_data(TensorData::new(data, shape).convert::<B::FloatElem>(), device)
.permute([2, 0, 1]) // [C, H, W]
/ 255 // normalize between [0, 1]
}

Expand Down
4 changes: 2 additions & 2 deletions resnet-burn/examples/inference/src/imagenet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pub struct Normalizer<B: Backend> {
impl<B: Backend> Normalizer<B> {
/// Creates a new normalizer.
pub fn new(device: &Device<B>) -> Self {
let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]);
let mean = Tensor::<B, 1>::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
let std = Tensor::<B, 1>::from_floats(STD, device).reshape([1, 3, 1, 1]);
Self { mean, std }
}

Expand Down

0 comments on commit 00dfeac

Please sign in to comment.