JAX time series classification #105
base: main
Conversation
Force-pushed 0e0f205 to a7d3a6f
It's ready for a review!
Looks good – it's a bit sparse on the description of what the code is doing, but we can probably address that in a later pass.
Please resolve the conflicts, and then we can merge this. Thanks!
Force-pushed a7d3a6f to fcc6b33
Rebased! I also added a couple of descriptions before some of the code cells, so there's more narrative.
@mtsokol - Thank you for opening this PR. I've added a few more suggestions for narration based on:
- Keras: Timeseries classification from scratch
- Tensorflow: TFDS for Jax and PyTorch
- Flax: MNIST tutorial
Please verify the accuracy of the statements. I've tried to maintain formatting, but these are still GitHub-suggestions, so you may need to update line breaks and such as needed.
# Time series classification with JAX

In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.
We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).
Suggested change:
- We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).
+ We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/), which contains measurements of engine noise captured by a motor sensor.
Done!
The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.
Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem.

Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures.
Suggested change:
- Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures.
+ Although convolution models are mainly associated with image processing, they are also useful for time series data because they can extract temporal structures.
Done!
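As a side note on the claim about temporal structures (this sketch is not part of the tutorial; the signal and kernel values are made up for illustration), a 1-D convolution slides a small kernel along the time axis and responds to local temporal patterns:

```python
import numpy as np

# A toy 1-D signal: flat, then a sudden step (e.g. a change in engine noise).
signal = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])

# A difference kernel responds to local changes over time, much like a
# learned 1-D convolution filter can respond to temporal structure.
kernel = np.array([1.0, -1.0])

# "valid" cross-correlation: each output is a dot product of the kernel
# with a sliding window of the signal.
response = np.correlate(signal, kernel, mode="valid")
print(response)  # nonzero only where the signal changes: [ 0.  0. -1.  0.  0.]
```

A CNN learns many such kernels from data instead of hand-picking them.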
```{code-cell} ipython3
# Required packages
# !pip install -U jax flax optax
# !pip install -U grain tqdm requests matplotlib
```

## Tools overview

Here's a list of key packages that belong to JAX AI stack:

- [JAX](https://github.com/jax-ml/jax) will be used for array computations.
- [Flax](https://github.com/google/flax) for constructing neural networks.
- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- [Grain](https://github.com/google/grain/) will be used to define data sources.
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
I think moving the installation cell after introducing the libraries has a nicer flow. :)
I've also modified the title and made the list have a consistent narrative.
Suggested change:
- ```{code-cell} ipython3
- # Required packages
- # !pip install -U jax flax optax
- # !pip install -U grain tqdm requests matplotlib
- ```
- ## Tools overview
- Here's a list of key packages that belong to JAX AI stack:
- - [JAX](https://github.com/jax-ml/jax) will be used for array computations.
- - [Flax](https://github.com/google/flax) for constructing neural networks.
- - [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- - [Grain](https://github.com/google/grain/) will be used to define data sources.
- - [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
+ ## Tools overview and setup
+ Here's a list of key packages that belong to the JAX AI stack required for this tutorial:
+ - [JAX](https://github.com/jax-ml/jax) for array computations.
+ - [Flax](https://github.com/google/flax) for constructing neural networks.
+ - [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
+ - [Grain](https://github.com/google/grain/) to define data sources.
+ - [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
+ We'll start by installing and importing these packages.
+ ```{code-cell} ipython3
+ # Required packages
+ # !pip install -U jax flax optax
+ # !pip install -U grain tqdm requests matplotlib
+ ```
Done!
The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.
Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem.
Suggested change:
- The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.
- Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem.
+ We need to assess if an engine is malfunctioning based on the recorded noises it generates.
+ Each sample comprises noise measurements across time, together with a "yes/no" label, so this is a binary classification problem.
Done!
We load dataset files into NumPy arrays, add singleton dimention to take into
the account convolution features, and change `-1` label to `0` value:
Suggested change:
- We load dataset files into NumPy arrays, add singleton dimention to take into
- the account convolution features, and change `-1` label to `0` value:
+ We load dataset files into NumPy arrays, add a singleton dimension to take convolution features into account, and change the `-1` label to `0` (so that the expected values are `0` and `1`):
Done!
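For reference, the preprocessing described above can be sketched in plain NumPy; the array shapes and values here are illustrative stand-ins, not the real FordA data:

```python
import numpy as np

# Illustrative stand-in for loaded data: 4 samples, 8 time steps each,
# with labels in {-1, 1} as stored in the FordA files.
x = np.zeros((4, 8))
y = np.array([-1, 1, -1, 1])

# Add a trailing singleton channel dimension so that 1-D convolutions
# see arrays shaped (batch, time, features).
x = x[..., np.newaxis]

# Remap the -1 label to 0 so labels are {0, 1}.
y = np.where(y == -1, 0, y)

print(x.shape)  # (4, 8, 1)
print(y)        # [0 1 0 1]
```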
optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))
```
We'll define a loss and logits computation function using
Optax's [`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels).
Done!
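For readers unfamiliar with this loss, here is a NumPy sketch of the quantity that `softmax_cross_entropy_with_integer_labels` computes; it's an illustrative re-implementation with made-up logits, not Optax's actual code:

```python
import numpy as np

def softmax_cross_entropy_with_integer_labels(logits, labels):
    # Numerically stable log-softmax: subtract the per-row max before exp.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Negative log-probability assigned to the true class of each sample.
    return -log_probs[np.arange(len(labels)), labels]

# Two samples of a binary problem: logits per class, plus integer labels.
logits = np.array([[2.0, 0.5], [0.1, 1.9]])
labels = np.array([0, 1])
loss = softmax_cross_entropy_with_integer_labels(logits, labels).mean()
print(loss)  # small, since both samples favor the correct class
```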
  ).mean()
  return loss, logits
```
We'll now define the training and evaluation step functions.
The loss and logits from both functions will be used for calculating accuracy metrics.
For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update the model's parameters using our optimizer.
Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla) for performant execution across different hardware accelerators like GPUs and TPUs.
Done!
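The tutorial's training step relies on Flax and Optax; as a framework-free illustration of what a value-and-grad update does, here is a hand-derived gradient step for a toy one-parameter model (the model, data, and learning rate are made up for this sketch):

```python
import numpy as np

# Toy data and a single-parameter linear model y_hat = w * x;
# the true relationship is y = 2 * x.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 4.0, 6.0])
w = 0.0

def loss_and_grad(w):
    # Mean squared error and its analytic gradient w.r.t. w. This pairing
    # of (loss value, gradient) is what nnx.value_and_grad produces
    # automatically for the real model.
    residual = w * x - y
    loss = (residual ** 2).mean()
    grad = 2.0 * (residual * x).mean()
    return loss, grad

learning_rate = 0.1
for _ in range(50):
    loss, grad = loss_and_grad(w)
    w -= learning_rate * grad  # the optimizer's parameter update step
print(w)  # approaches the true slope, 2.0
```

In the notebook, Optax's Adam plays the role of the plain gradient-descent update shown here.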
  "test_accuracy": [],
}
```
We can now train the CNN model.
We'll evaluate the model's performance on the test set after each epoch,
and print the metrics: total loss and accuracy.
Done!
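As a small aside on the accuracy metric mentioned above, a NumPy sketch with made-up logits shows how predictions and accuracy are typically derived from model outputs:

```python
import numpy as np

# Logits for 5 samples of a binary problem, plus their true labels.
logits = np.array([[2.0, 0.1],
                   [0.2, 1.5],
                   [1.0, 0.9],
                   [0.1, 2.2],
                   [3.0, 0.5]])
labels = np.array([0, 1, 1, 1, 0])

# The predicted class is the argmax over logits; accuracy is the mean
# agreement between predictions and labels.
predictions = logits.argmax(axis=-1)
accuracy = (predictions == labels).mean()
print(accuracy)  # 0.8 here: 4 of 5 predictions match
```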
  train_one_epoch(epoch)
  evaluate_model(epoch)
```
Finally, let's visualize the loss and accuracy with Matplotlib.
Done!
For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax)
library which provides checkpointing and persistence utilities.
Suggested change:
- For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax)
- library which provides checkpointing and persistence utilities.
+ For model early stopping and selecting the best model, you can check out [Orbax](https://github.com/google/orbax),
+ a library which provides checkpointing and persistence utilities.
Done!
This PR adds a tutorial on JAX time series classification using the FordA dataset from the UCR/UEA archive. It follows one of the Keras tutorials. It's still missing the narrative.