Skip to content

Commit

Permalink
docs(book-&-examples): modify book and examples with new prelude mo…
Browse files Browse the repository at this point in the history
…dule (#1372)
  • Loading branch information
bioinformatist authored Feb 28, 2024
1 parent 57887e7 commit 330552a
Show file tree
Hide file tree
Showing 52 changed files with 299 additions and 258 deletions.
2 changes: 1 addition & 1 deletion burn-book/src/basic-workflow/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ at `examples/guide/` [directory](https://github.com/tracel-ai/burn/tree/main/exa
```rust , ignore
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
prelude::*,
};

pub struct MnistBatcher<B: Backend> {
Expand Down
4 changes: 1 addition & 3 deletions burn-book/src/basic-workflow/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,12 @@ Let us start by defining our model struct in a new file `src/model.rs`.

```rust , ignore
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
},
tensor::{backend::Backend, Tensor},
prelude::*,
};

#[derive(Module, Debug)]
Expand Down
1 change: 0 additions & 1 deletion burn-book/src/building-blocks/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ to your types, allowing you to define default values with ease. Additionally, al
serialized, reducing potential bugs when upgrading versions and improving reproducibility.

```rust , ignore
#[derive(Config)]
use burn::config::Config;

#[derive(Config)]
Expand Down
1 change: 0 additions & 1 deletion burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ derive function only generates the necessary methods to essentially act as a par
your type, it makes no assumptions about how the forward pass is declared.

```rust, ignore
use burn::nn;
use burn::module::Module;
use burn::tensor::backend::Backend;
Expand Down
31 changes: 31 additions & 0 deletions burn-book/src/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,37 @@ While the previous example is somewhat trivial, the upcoming
basic workflow section will walk you through a much more relevant example for
deep learning applications.

## Using `prelude`

Burn comes with a variety of things in its core library.
When creating a new model or using an existing one for inference,
you may need to import every single component you used, which could be a little verbose.

To address it, a `prelude` module is provided, allowing you to easily import commonly used structs and macros as a group:

```rust, ignore
use burn::prelude::*;
```

which is equal to:

```rust, ignore
use burn::{
config::Config,
module::Module,
nn,
tensor::{
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
},
};
```

<div class="warning">

For the sake of simplicity, the subsequent chapters of this book will all use this form of importing. However, this does not include the content in the [Building Blocks](./building-blocks) chapter, as explicit importing aids users in grasping the usage of particular structures and macros.

</div>

## Running examples

Many additional Burn examples available in the
Expand Down
3 changes: 1 addition & 2 deletions burn-book/src/import/pytorch-model.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ something like this:

```rust
use burn::{
module::Module,
nn::conv::{Conv2d, Conv2dConfig},
tensor::{backend::Backend, Tensor},
prelude::*,
};

#[derive(Module, Debug)]
Expand Down
12 changes: 6 additions & 6 deletions burn-book/src/saving-and-loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` whic
// Save model in MessagePack format with full precision
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
```

Note that the file extension is automatically handled by the recorder depending on the one you
Expand All @@ -23,8 +23,8 @@ Now that you have a trained model saved to your disk, you can easily load it in
// Load model in full precision from MessagePack file
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
model
.load_file(model_path, &recorder, device)
.expect("Should be able to load the model weights from the provided file");
.load_file(model_path, &recorder, device)
.expect("Should be able to load the model weights from the provided file");
```

**Note:** models can be saved in different output formats, just make sure you are using the correct
Expand Down Expand Up @@ -117,8 +117,8 @@ a model as part of your runtime application, first save the model to a binary fi
// Save model in binary format with full precision
let recorder = BinFileRecorder::<FullPrecisionSettings>::new();
model
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
.save_file(model_path, &recorder)
.expect("Should be able to save the model");
```

Then, in your final application, include the model and use the `BinBytesRecorder` to load it.
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pub mod prelude {
config::Config,
module::Module,
nn,
tensor::{backend::Backend, Data, Device, ElementConversion, Tensor},
tensor::{
backend::Backend, Bool, Data, Device, ElementConversion, Float, Int, Shape, Tensor,
},
};
}
22 changes: 14 additions & 8 deletions examples/custom-image-dataset/examples/custom-image-dataset.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::libtorch::{LibTorch, LibTorchDevice};
use burn::backend::Autodiff;
use burn::optim::momentum::MomentumConfig;
use burn::optim::SgdConfig;
use burn::{
backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
},
optim::{momentum::MomentumConfig, SgdConfig},
};
use custom_image_dataset::training::{train, TrainingConfig};

pub fn run() {
Expand All @@ -25,10 +28,13 @@ mod tch_gpu {

#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::backend::Autodiff;
use burn::optim::momentum::MomentumConfig;
use burn::optim::SgdConfig;
use burn::{
backend::{
wgpu::{Wgpu, WgpuDevice},
Autodiff,
},
optim::{momentum::MomentumConfig, SgdConfig},
};
use custom_image_dataset::training::{train, TrainingConfig};

pub fn run() {
Expand Down
2 changes: 1 addition & 1 deletion examples/custom-image-dataset/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use burn::{
dataloader::batcher::Batcher,
dataset::vision::{Annotation, ImageDatasetItem, PixelDepth},
},
tensor::{backend::Backend, Data, Device, ElementConversion, Int, Shape, Tensor},
prelude::*,
};

// CIFAR-10 mean and std values
Expand Down
3 changes: 1 addition & 2 deletions examples/custom-image-dataset/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use burn::{
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{MaxPool2d, MaxPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
},
tensor::{backend::Backend, Device, Tensor},
prelude::*,
};

/// Basic convolutional neural network with VGG-style blocks.
Expand Down
17 changes: 6 additions & 11 deletions examples/custom-image-dataset/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,16 @@ use crate::{
dataset::CIFAR10Loader,
model::Cnn,
};
use burn::data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset};
use burn::train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
};
use burn::{
self,
config::Config,
module::Module,
data::{dataloader::DataLoaderBuilder, dataset::vision::ImageFolderDataset},
nn::loss::CrossEntropyLossConfig,
optim::SgdConfig,
prelude::*,
record::CompactRecorder,
tensor::{
backend::{AutodiffBackend, Backend},
Int, Tensor,
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
},
};

Expand Down
3 changes: 1 addition & 2 deletions examples/custom-renderer/examples/custom-renderer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use burn::backend::wgpu::WgpuDevice;
use burn::backend::{Autodiff, Wgpu};
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};

fn main() {
custom_renderer::run::<Autodiff<Wgpu>>(WgpuDevice::default());
Expand Down
11 changes: 7 additions & 4 deletions examples/custom-renderer/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use burn::data::dataset::vision::MnistDataset;
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
use burn::train::LearnerBuilder;
use burn::{
config::Config, data::dataloader::DataLoaderBuilder, optim::AdamConfig,
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
optim::AdamConfig,
tensor::backend::AutodiffBackend,
train::{
renderer::{MetricState, MetricsRenderer, TrainingProgress},
LearnerBuilder,
},
};
use guide::{data::MnistBatcher, model::ModelConfig};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use burn::backend::wgpu::WgpuDevice;
use burn::backend::{Autodiff, Wgpu};
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};

fn main() {
custom_training_loop::run::<Autodiff<Wgpu>>(WgpuDevice::default());
Expand Down
10 changes: 3 additions & 7 deletions examples/custom-training-loop/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
use std::marker::PhantomData;

use burn::data::dataset::vision::MnistDataset;
use burn::{
config::Config,
data::dataloader::DataLoaderBuilder,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
module::AutodiffModule,
nn::loss::CrossEntropyLoss,
optim::{AdamConfig, GradientsParams, Optimizer},
tensor::{
backend::{AutodiffBackend, Backend},
ElementConversion, Int, Tensor,
},
prelude::*,
tensor::backend::AutodiffBackend,
};
use guide::{
data::{MnistBatch, MnistBatcher},
Expand Down
19 changes: 11 additions & 8 deletions examples/custom-wgpu-kernel/src/backward.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
use crate::FloatTensor;

use super::{AutodiffBackend, Backend};
use burn::backend::autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
use burn::{
backend::{
autodiff::{
checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
grads::Gradients,
ops::{broadcast_shape, Backward, Ops, OpsKind},
Autodiff, NodeID,
},
wgpu::{compute::WgpuRuntime, FloatElement, GraphicsApi, IntElement, JitBackend},
},
tensor::Shape,
};
use burn::backend::wgpu::compute::WgpuRuntime;
use burn::backend::wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend};
use burn::tensor::Shape;

impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
for Autodiff<JitBackend<WgpuRuntime<G, F, I>>>
Expand Down
18 changes: 10 additions & 8 deletions examples/custom-wgpu-kernel/src/forward.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use crate::FloatTensor;

use super::Backend;
use burn::backend::wgpu::{
compute::{DynamicKernel, WgpuRuntime, WorkGroup},
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
use burn::{
backend::wgpu::{
compute::{DynamicKernel, WgpuRuntime, WorkGroup},
kernel::{
build_info, into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource,
},
kernel_wgsl,
tensor::JitTensor,
FloatElement, GraphicsApi, IntElement, JitBackend,
},
kernel_wgsl,
tensor::JitTensor,
FloatElement, GraphicsApi, IntElement, JitBackend,
tensor::Shape,
};
use burn::tensor::Shape;
use derive_new::new;
use std::marker::PhantomData;

Expand Down
9 changes: 5 additions & 4 deletions examples/guide/examples/guide.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use burn::backend::wgpu::AutoGraphicsApi;
use burn::backend::{Autodiff, Wgpu};
use burn::data::dataset::Dataset;
use burn::optim::AdamConfig;
use burn::{
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
use guide::{model::ModelConfig, training::TrainingConfig};

fn main() {
Expand Down
2 changes: 1 addition & 1 deletion examples/guide/src/data.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
prelude::*,
};

pub struct MnistBatcher<B: Backend> {
Expand Down
6 changes: 2 additions & 4 deletions examples/guide/src/inference.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use crate::{data::MnistBatcher, training::TrainingConfig};
use burn::data::dataset::vision::MnistItem;
use burn::{
config::Config,
data::dataloader::batcher::Batcher,
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
record::{CompactRecorder, Recorder},
tensor::backend::Backend,
};

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
Expand Down
4 changes: 1 addition & 3 deletions examples/guide/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, Relu,
},
tensor::{backend::Backend, Tensor},
prelude::*,
};

#[derive(Module, Debug)]
Expand Down
18 changes: 6 additions & 12 deletions examples/guide/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,16 @@ use crate::{
data::{MnistBatch, MnistBatcher},
model::{Model, ModelConfig},
};
use burn::data::dataset::vision::MnistDataset;
use burn::train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
};
use burn::{
self,
config::Config,
data::dataloader::DataLoaderBuilder,
module::Module,
data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
nn::loss::CrossEntropyLossConfig,
optim::AdamConfig,
prelude::*,
record::CompactRecorder,
tensor::{
backend::{AutodiffBackend, Backend},
Int, Tensor,
tensor::backend::AutodiffBackend,
train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
},
};

Expand Down
Loading

0 comments on commit 330552a

Please sign in to comment.