
Commit 09c65ca

Improve regression example (#2405)
1 parent 9ace3a0 commit 09c65ca

11 files changed: +257 −189 lines changed


Cargo.lock: +25 lines (generated file, diff not rendered)

README.md

+9 −10
@@ -325,17 +325,16 @@ WGPU (WebGPU): Cross-Platform GPU Backend 🌐
 Based on the most popular and well-supported Rust graphics library, [WGPU](https://wgpu.rs), this
 backend automatically targets Vulkan, OpenGL, Metal, Direct X11/12, and WebGPU, by using the WebGPU
 shading language [WGSL](https://www.w3.org/TR/WGSL/), or optionally
-[SPIR-V](https://www.khronos.org/spir/) when targeting Vulkan. It can also be compiled to Web Assembly
-to run in the browser while leveraging the GPU, see
+[SPIR-V](https://www.khronos.org/spir/) when targeting Vulkan. It can also be compiled to Web
+Assembly to run in the browser while leveraging the GPU, see
 [this demo](https://antimora.github.io/image-classification/). For more information on the benefits
 of this backend, see [this blog](https://burn.dev/blog/cross-platform-gpu-backend).
 
 The WGPU backend is our first "in-house backend", which means we have complete control over its
 implementation details. It is fully optimized with the
 [performance characteristics mentioned earlier](#performance), as it serves as our research
-playground for a variety of optimizations.
-We've since added CUDA, ROCm and SPIR-V support using the same compiler infrastructure, so a kernel
-written for burn once, can run anywhere.
+playground for a variety of optimizations. We've since added CUDA, ROCm and SPIR-V support using the
+same compiler infrastructure, so a kernel written for burn once, can run anywhere.
 
 See the [WGPU Backend README](./crates/burn-wgpu/README.md) and
 [CUDA Backend README](./crates/burn-cuda/README.md) for more details.
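For readers who have not used the WGPU backend before, here is a minimal sketch of what selecting it looks like in user code. It is not part of this commit; it only assumes the `Wgpu` backend alias and `WgpuDevice` type that `burn` exports when the `wgpu` feature is enabled (the same imports used in the regression example diff below).

```rust
use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::tensor::Tensor;

fn main() {
    // Picks the best available adapter (Vulkan, Metal, DX12, WebGPU, ...).
    let device = WgpuDevice::default();

    // Any tensor created with this backend runs its kernels on the GPU.
    let x = Tensor::<Wgpu, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &device);
    let y = x.clone() + x;
    println!("{y}");
}
```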
@@ -486,9 +485,9 @@ The Burn Book 🔥
 
 To begin working effectively with Burn, it is crucial to understand its key components and
 philosophy. This is why we highly recommend new users to read the first sections of
-[The Burn Book 🔥](https://burn.dev/burn-book/). It provides detailed examples and explanations covering
-every facet of the framework, including building blocks like tensors, modules, and optimizers, all
-the way to advanced usage, like coding your own GPU kernels.
+[The Burn Book 🔥](https://burn.dev/burn-book/). It provides detailed examples and explanations
+covering every facet of the framework, including building blocks like tensors, modules, and
+optimizers, all the way to advanced usage, like coding your own GPU kernels.
 
 > The project is constantly evolving, and we try as much as possible to keep the book up to date
 > with new additions. However, we might miss some details sometimes, so if you see something weird,
@@ -545,8 +544,8 @@ Additional examples:
 
 - [Custom CSV Dataset](./examples/custom-csv-dataset) : Implements a dataset to parse CSV data for a
   regression task.
-- [Regression](./examples/simple-regression) : Trains a simple MLP on the CSV dataset for the
-  regression task.
+- [Regression](./examples/simple-regression) : Trains a simple MLP on the California Housing dataset
+  to predict the median house value for a district.
 - [Custom Image Dataset](./examples/custom-image-dataset) : Trains a simple CNN on custom image
   dataset following a simple folder structure.
 - [Custom Renderer](./examples/custom-renderer) : Implements a custom renderer to display the
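The updated description says the example trains a simple MLP on the California Housing dataset. As a rough illustration of what such a model can look like with Burn's `nn` building blocks, here is a hypothetical sketch, not the example's actual model: the 8 input features match California Housing, while the hidden size of 64 is an arbitrary choice.

```rust
use burn::{
    module::Module,
    nn::{Linear, LinearConfig, Relu},
    tensor::{backend::Backend, Tensor},
};

/// Hypothetical two-layer MLP for the regression task (layer sizes are illustrative).
#[derive(Module, Debug)]
pub struct RegressionModel<B: Backend> {
    hidden: Linear<B>,
    output: Linear<B>,
    activation: Relu,
}

impl<B: Backend> RegressionModel<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            // California Housing has 8 numeric input features; 64 hidden units is arbitrary.
            hidden: LinearConfig::new(8, 64).init(device),
            output: LinearConfig::new(64, 1).init(device),
            activation: Relu::new(),
        }
    }

    /// Maps a `[batch_size, 8]` batch of features to `[batch_size, 1]` predicted house values.
    pub fn forward(&self, features: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.activation.forward(self.hidden.forward(features));
        self.output.forward(x)
    }
}
```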

burn-book/src/examples.md

+2 −2
@@ -74,7 +74,7 @@ The following additional examples are currently available if you want to check t
 | Example | Description |
 | :------ | ----------- |
 | [Custom CSV Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-csv-dataset) | Implements a dataset to parse CSV data for a regression task. |
-| [Regression](https://github.com/tracel-ai/burn/tree/main/examples/simple-regression) | Trains a simple MLP on the CSV dataset for the regression task. |
+| [Regression](https://github.com/tracel-ai/burn/tree/main/examples/simple-regression) | Trains a simple MLP on the California Housing dataset to predict the median house value for a district. |
 | [Custom Image Dataset](https://github.com/tracel-ai/burn/tree/main/examples/custom-image-dataset) | Trains a simple CNN on custom image dataset following a simple folder structure. |
 | [Custom Renderer](https://github.com/tracel-ai/burn/tree/main/examples/custom-renderer) | Implements a custom renderer to display the [`Learner`](./building-blocks/learner.md) progress. |
 | [Image Classification Web](https://github.com/tracel-ai/burn/tree/main/examples/image-classification-web) | Image classification web browser demo using Burn, WGPU and WebAssembly. |

@@ -83,7 +83,7 @@ The following additional examples are currently available if you want to check t
 | [Named Tensor](https://github.com/tracel-ai/burn/tree/main/examples/named-tensor) | Performs operations with the experimental `NamedTensor` feature. |
 | [ONNX Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference) | Imports an ONNX model pre-trained on MNIST to perform inference on a sample image with Burn. |
 | [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
-| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
+| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
 | [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. |
 
 For more information on each example, see their respective `README.md` file. Be sure to check out

(The Text Classification row changes only trailing whitespace in the table cell padding.)

examples/simple-regression/Cargo.toml

+4
@@ -22,3 +22,7 @@ burn = {path = "../../crates/burn", features=["train"]}
 # Serialization
 log = {workspace = true}
 serde = {workspace = true, features = ["std", "derive"]}
+
+# Displaying results
+textplots = "0.8.6"
+rgb = "0.8.27"
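These two dependencies are added to print results in the terminal. Below is a minimal sketch of how they can be combined, assuming the `ColorPlot` API of textplots 0.8; the data points are made up and only stand in for the example's real predictions.

```rust
use rgb::RGB8;
use textplots::{Chart, ColorPlot, Shape};

fn main() {
    // Made-up (index, predicted value) pairs standing in for real inference output.
    let points: Vec<(f32, f32)> = (0..20).map(|i| (i as f32, i as f32 * 0.9 + 1.0)).collect();

    // 120x60 character canvas over the x range [0, 20], drawn in red.
    Chart::new(120, 60, 0.0, 20.0)
        .linecolorplot(&Shape::Points(&points), RGB8 { r: 255, g: 85, b: 85 })
        .display();
}
```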

examples/simple-regression/README.md

+6 −3
@@ -2,9 +2,12 @@
 
 The example shows you how to:
 
-- Define a custom dataset for regression problems. We implement the [Diabetes Toy Dataset](https://huggingface.co/datasets/Jayabalambika/toy-diabetes)
-  from HuggingFace hub. The dataset is also available as part of toy regression datasets in sklearn[datasets](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_diabetes.html).
-- Create a data pipeline from a raw dataset to a batched fast DataLoader with min-max feature scaling.
+- Define a custom dataset for regression problems. We implement the
+  [California Housing Dataset](https://huggingface.co/datasets/gvlassis/california_housing) from
+  HuggingFace hub. The dataset is also available as part of toy regression datasets in
+  sklearn[datasets](https://scikit-learn.org/stable/datasets/real_world.html#california-housing-dataset).
+- Create a data pipeline from a raw dataset to a batched fast DataLoader with min-max feature
+  scaling.
 - Define a Simple NN model for regression using Burn Modules.
 
 > **Note**
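Min-max feature scaling rescales each feature column to the [0, 1] range via (x - min) / (max - min). The sketch below shows that operation on a Burn tensor; it is an illustration assuming a 2-D `[samples, features]` tensor, not the example's actual batcher code.

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Min-max scale each feature column of a `[samples, features]` tensor to [0, 1].
/// Columns with a constant value would divide by zero and need special handling.
fn min_max_scale<B: Backend>(features: Tensor<B, 2>) -> Tensor<B, 2> {
    // Per-column statistics, shape [1, features], broadcast over every row.
    let min = features.clone().min_dim(0);
    let max = features.clone().max_dim(0);
    (features - min.clone()) / (max - min)
}
```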

examples/simple-regression/examples/regression.rs

+19 −23
@@ -1,67 +1,63 @@
+use burn::{backend::Autodiff, tensor::backend::Backend};
+use simple_regression::{inference, training};
+
+static ARTIFACT_DIR: &str = "/tmp/burn-example-regression";
+
 #[cfg(any(
     feature = "ndarray",
     feature = "ndarray-blas-netlib",
     feature = "ndarray-blas-openblas",
     feature = "ndarray-blas-accelerate",
 ))]
 mod ndarray {
-    use burn::backend::{
-        ndarray::{NdArray, NdArrayDevice},
-        Autodiff,
-    };
-    use simple_regression::training;
+    use burn::backend::ndarray::{NdArray, NdArrayDevice};
 
     pub fn run() {
         let device = NdArrayDevice::Cpu;
-        training::run::<Autodiff<NdArray>>(device);
+        super::run::<NdArray>(device.clone());
     }
 }
 
 #[cfg(feature = "tch-gpu")]
 mod tch_gpu {
-    use burn::backend::{
-        libtorch::{LibTorch, LibTorchDevice},
-        Autodiff,
-    };
-    use simple_regression::training;
+    use burn::backend::libtorch::{LibTorch, LibTorchDevice};
 
     pub fn run() {
         #[cfg(not(target_os = "macos"))]
         let device = LibTorchDevice::Cuda(0);
         #[cfg(target_os = "macos")]
         let device = LibTorchDevice::Mps;
 
-        training::run::<Autodiff<LibTorch>>(device);
+        super::run::<LibTorch>(device);
     }
 }
 
 #[cfg(feature = "wgpu")]
 mod wgpu {
-    use burn::backend::{
-        wgpu::{Wgpu, WgpuDevice},
-        Autodiff,
-    };
-    use simple_regression::training;
+    use burn::backend::wgpu::{Wgpu, WgpuDevice};
 
     pub fn run() {
         let device = WgpuDevice::default();
-        training::run::<Autodiff<Wgpu>>(device);
+        super::run::<Wgpu>(device);
     }
 }
 
 #[cfg(feature = "tch-cpu")]
 mod tch_cpu {
-    use burn::backend::{
-        libtorch::{LibTorch, LibTorchDevice},
-        Autodiff,
-    };
+    use burn::backend::libtorch::{LibTorch, LibTorchDevice};
     use simple_regression::training;
     pub fn run() {
         let device = LibTorchDevice::Cpu;
-        training::run::<Autodiff<LibTorch>>(device);
+        super::run::<LibTorch>(device);
     }
 }
 
+/// Train a regression model and predict results on a number of samples.
+pub fn run<B: Backend>(device: B::Device) {
+    training::run::<Autodiff<B>>(ARTIFACT_DIR, device.clone());
+    inference::infer::<B>(ARTIFACT_DIR, device)
+}
+
 fn main() {
     #[cfg(any(
         feature = "ndarray",

(The remainder of the hunk is not shown on this page.)