JAX time series classification #105

Open · wants to merge 2 commits into main from nlp-time-series-classif

Conversation

@mtsokol commented Nov 20, 2024

This PR adds a tutorial covering a JAX time series classification task on the FordA dataset from the UCR/UEA archive. It follows one of the Keras tutorials. It's still missing the narrative.

@jakevdp jakevdp self-assigned this Nov 20, 2024
@mtsokol force-pushed the nlp-time-series-classif branch 2 times, most recently from 0e0f205 to a7d3a6f on November 21, 2024
@mtsokol (Author) commented Nov 21, 2024

It's ready for a review!

@jakevdp jakevdp changed the title NLP: JAX time series classification JAX time series classification Nov 21, 2024
@jakevdp (Collaborator) left a comment

Looks good – it's a bit sparse on the description of what the code is doing, but we can probably address that in a later pass.

@jakevdp (Collaborator) commented Nov 21, 2024

Please resolve the conflicts, and then we can merge this. Thanks!

@mtsokol (Author) commented Nov 22, 2024

> Please resolve the conflicts, and then we can merge this. Thanks!

Rebased!

I also added a couple of descriptions before some of the code cells, so there's more narrative.

@pavithraes (Contributor) left a comment

@mtsokol - Thank you for opening this PR. I've added a few more suggestions for narration.

Please verify the accuracy of the statements. I've tried to maintain formatting, but these are still GitHub suggestions, so you may need to adjust line breaks as needed.

# Time series classification with JAX

In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.
We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).

Suggested change
We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).
We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/), which contains measurements of engine noise captured by a motor sensor.

mtsokol (Author): Done!

The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.
Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem.

Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures.

Suggested change
Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures.
Although convolution models are mainly associated with image processing, they are also useful for time series data because they can extract temporal structures.

mtsokol (Author): Done!
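
For intuition, here is a minimal sketch of a 1D convolution applied along the time axis; the shapes and layer sizes are illustrative, not taken from the tutorial:

```python
import jax.numpy as jnp
from flax import nnx

# A 1D convolution slides a small kernel along the time axis, so each
# output step summarizes a local window of the signal: a learned
# temporal feature.
conv = nnx.Conv(in_features=1, out_features=8, kernel_size=(3,), rngs=nnx.Rngs(0))

x = jnp.ones((16, 500, 1))  # (batch, time steps, channels)
y = conv(x)
print(y.shape)  # (16, 500, 8): 8 temporal feature maps per sample
```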

Comment on lines 25 to 40
```{code-cell} ipython3
# Required packages
# !pip install -U jax flax optax
# !pip install -U grain tqdm requests matplotlib
```

## Tools overview

Here's a list of key packages that belong to JAX AI stack:

- [JAX](https://github.com/jax-ml/jax) will be used for array computations.
- [Flax](https://github.com/google/flax) for constructing neural networks.
- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- [Grain](https://github.com/google/grain/) will be be used to define data sources.
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.


I think moving the installation cell after introducing the libraries has a nicer flow. :)
I've also modified the title and made the list have a consistent narrative.

Suggested change
```{code-cell} ipython3
# Required packages
# !pip install -U jax flax optax
# !pip install -U grain tqdm requests matplotlib
```
## Tools overview
Here's a list of key packages that belong to JAX AI stack:
- [JAX](https://github.com/jax-ml/jax) will be used for array computations.
- [Flax](https://github.com/google/flax) for constructing neural networks.
- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- [Grain](https://github.com/google/grain/) will be be used to define data sources.
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
## Tools overview and setup
Here's a list of key packages that belong to the JAX AI stack required for this tutorial:
- [JAX](https://github.com/jax-ml/jax) for array computations.
- [Flax](https://github.com/google/flax) for constructing neural networks.
- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.
- [Grain](https://github.com/google/grain/) to define data sources.
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
We'll start by installing and importing these packages.
```{code-cell} ipython3
# Required packages
# !pip install -U jax flax optax
# !pip install -U grain tqdm requests matplotlib
```

mtsokol (Author): Done!

Comment on lines 20 to 21
The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.
Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem.

Suggested change
The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.
Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem.
We need to assess if an engine is malfunctioning based on the recorded noises it generates.
Each sample comprises noise measurements across time, together with a "yes/no" label, so this is a binary classification problem.

mtsokol (Author): Done!

Comment on lines 55 to 56
We load dataset files into NumPy arrays, add singleton dimention to take into
the account convolution features, and change `-1` label to `0` value:

Suggested change
We load dataset files into NumPy arrays, add singleton dimention to take into
the account convolution features, and change `-1` label to `0` value:
We load the dataset files into NumPy arrays, add a singleton dimension to take convolution features into account, and change the `-1` label to `0` (so that the expected values are `0` and `1`):

mtsokol (Author): Done!
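
As a rough sketch of that step, where the `FordA_TRAIN.tsv` file name and tab-separated layout are assumptions based on the UCR archive's distribution format:

```python
import numpy as np

def load_forda(path):
    # Each row holds the label in the first column, followed by the
    # time series values.
    data = np.loadtxt(path, delimiter="\t")
    x, y = data[:, 1:], data[:, 0].astype(int)
    x = x[..., np.newaxis]       # singleton channel axis: (samples, time, 1)
    y = np.where(y == -1, 0, y)  # relabel -1 as 0 so labels are {0, 1}
    return x, y

x_train, y_train = load_forda("FordA_TRAIN.tsv")
```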

optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))
```


Suggested change
We'll define a loss and logits computation function using
Optax's [`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels).

mtsokol (Author): Done!
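
Given the code context below (the function ends with `.mean()` and `return loss, logits`), the full function plausibly looks like this sketch; the name `compute_losses_and_logits` and the argument names are assumptions:

```python
import optax
from flax import nnx

def compute_losses_and_logits(model: nnx.Module, batch, labels):
    # Forward pass, then softmax cross-entropy against integer class
    # labels, averaged over the batch.
    logits = model(batch)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits
```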

).mean()
return loss, logits
```


Suggested change
We'll now define the training and evaluation step functions.
The loss and logits from both functions will be used for calculating accuracy metrics.
For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update the model's parameters using our optimizer.
Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla) for performant execution across different hardware accelerators like GPUs and TPUs.

mtsokol (Author): Done!
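
A sketch of what those step functions typically look like in Flax NNX; the function and argument names are illustrative, and `optimizer.update(grads)` assumes the NNX API as of this PR:

```python
from flax import nnx

@nnx.jit  # JIT-compile the whole step with XLA
def train_step(model, optimizer, batch, labels):
    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch, labels)
    optimizer.update(grads)  # apply the Adam update to the parameters
    return loss, logits

@nnx.jit
def eval_step(model, batch, labels):
    # Evaluation reuses the same loss computation, without gradients.
    return compute_losses_and_logits(model, batch, labels)
```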

"test_accuracy": [],
}
```


Suggested change
We can now train the CNN model.
We'll evaluate the model’s performance on the test set after each epoch,
and print the metrics: total loss and accuracy.

mtsokol (Author): Done!
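
The epoch loop might be sketched like this, where `train_loader`, `test_loader`, and the metric printing are illustrative stand-ins for the tutorial's code:

```python
from tqdm import tqdm

num_epochs = 10  # illustrative

for epoch in range(num_epochs):
    # One pass over the training data, with a tqdm progress bar.
    for batch, labels in tqdm(train_loader, desc=f"train epoch {epoch}"):
        train_step(model, optimizer, batch, labels)

    # Evaluate on the test set after each epoch.
    correct, total, total_loss = 0, 0, 0.0
    for batch, labels in test_loader:
        loss, logits = eval_step(model, batch, labels)
        total_loss += loss.item() * len(labels)
        correct += (logits.argmax(axis=-1) == labels).sum().item()
        total += len(labels)
    print(f"epoch {epoch}: loss={total_loss / total:.4f}, "
          f"accuracy={correct / total:.2%}")
```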

train_one_epoch(epoch)
evaluate_model(epoch)
```


Suggested change
Finally, let's visualize the loss and accuracy with Matplotlib.

mtsokol (Author): Done!
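
A sketch of the plotting cell, assuming history dicts with keys like `test_loss` and `test_accuracy` as in the snippet above:

```python
import matplotlib.pyplot as plt

# Plot the per-epoch test loss and accuracy side by side.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(eval_metrics_history["test_loss"])
ax1.set_xlabel("epoch")
ax1.set_ylabel("loss")
ax1.set_title("Test loss")
ax2.plot(eval_metrics_history["test_accuracy"])
ax2.set_xlabel("epoch")
ax2.set_ylabel("accuracy")
ax2.set_title("Test accuracy")
plt.show()
```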

Comment on lines 345 to 346
For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax)
library which provides checkpointing and persistence utilities.

Suggested change
For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax)
library which provides checkpointing and persistence utilities.
For model early stopping and selecting the best model, you can check out [Orbax](https://github.com/google/orbax),
a library which provides checkpointing and persistence utilities.

mtsokol (Author): Done!
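
For completeness, a minimal sketch of what that could look like; the checkpoint path and the `best_accuracy` bookkeeping are illustrative, and this assumes Orbax's `StandardCheckpointer`:

```python
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()

# Inside the epoch loop: persist the model state whenever test accuracy
# improves, so the best model can be restored later.
if test_accuracy > best_accuracy:
    best_accuracy = test_accuracy
    _, state = nnx.split(model)  # extract the model's parameters/state
    checkpointer.save("/tmp/forda_best_model", state, force=True)
```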
