koaning/scikit-mdn

scikit-mdn

A mixture density network, implemented in PyTorch, for scikit-learn

This project started during a livestream for the probabl outreach effort on YouTube. If you want to watch the relevant livestreams, they can be found here and here.

Usage

To get this tool working locally you will first need to install it:

python -m pip install scikit-mdn

Then you can use it in your code. Here is a small demo.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from skmdn import MixtureDensityEstimator

# Generate dataset
n_samples = 1000
X_full, _ = make_moons(n_samples=n_samples, noise=0.1)
X = X_full[:, 0].reshape(-1, 1)  # Use only the first column as input
Y = X_full[:, 1].reshape(-1, 1)  # Predict the second column

# Add some noise to Y to make the problem more suitable for MDN
Y += 0.1 * np.random.randn(n_samples, 1)

# Fit the model
mdn = MixtureDensityEstimator()
mdn.fit(X, Y)

# Predict the means and some quantiles on the train set
means, quantiles = mdn.predict(X, quantiles=[0.01, 0.1, 0.9, 0.99], resolution=100000)

# Plot the data, the outer quantiles (orange), the inner quantiles (green) and the means (red)
plt.scatter(X, Y)
plt.scatter(X, quantiles[:, 0], color='orange')
plt.scatter(X, quantiles[:, 1], color='green')
plt.scatter(X, quantiles[:, 2], color='green')
plt.scatter(X, quantiles[:, 3], color='orange')
plt.scatter(X, means, color='red')
plt.show()

This is what the chart looks like:

Example chart
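
For some intuition about the quantiles in the demo: a mixture density network describes the target, for each input, as a mixture of Gaussians. The sketch below shows one way to read quantiles off such a mixture, by evaluating its CDF on a grid and picking the first grid point that reaches each requested level. It uses only the Python standard library. The function names, the grid bounds and the example mixture are made up for illustration, and whether the resolution argument of predict works like the grid size here is an assumption, not something taken from the library.

```python
from statistics import NormalDist


def mixture_cdf(y, weights, means, stds):
    """CDF of a 1D Gaussian mixture evaluated at y."""
    return sum(w * NormalDist(m, s).cdf(y) for w, m, s in zip(weights, means, stds))


def mixture_quantiles(weights, means, stds, quantiles, resolution=1000):
    """Approximate mixture quantiles by scanning the CDF on a grid (hypothetical helper)."""
    # Cover every component generously: five standard deviations on each side.
    lo = min(m - 5 * s for m, s in zip(means, stds))
    hi = max(m + 5 * s for m, s in zip(means, stds))
    grid = [lo + (hi - lo) * i / (resolution - 1) for i in range(resolution)]
    cdf = [mixture_cdf(y, weights, means, stds) for y in grid]
    # For each level q, take the first grid point whose CDF value reaches q.
    return [next((y for y, c in zip(grid, cdf) if c >= q), grid[-1]) for q in quantiles]


# A made-up bimodal mixture, loosely similar to one vertical slice of the moons data.
weights, means, stds = [0.5, 0.5], [-1.0, 1.0], [0.2, 0.2]
qs = mixture_quantiles(weights, means, stds, [0.1, 0.5, 0.9])
print(qs)  # roughly [-1.17, 0.0, 1.17]
```

A finer grid gives more precise quantiles at the cost of more CDF evaluations, which is the usual trade-off with this kind of grid search.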

Regularisation

There is a weight_decay parameter that lets you apply regularisation to the weights of the network. On the moons example, the effect of this is pretty clear.
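
To illustrate what a weight decay term does in general, here is a minimal pure-Python sketch that fits a single slope with gradient descent, once without and once with decay. This is not the library's code: the function fit_slope, the toy data and the chosen values are hypothetical, and the sketch assumes the common convention in which weight_decay * w is added to the gradient at every step.

```python
def fit_slope(xs, ys, weight_decay=0.0, lr=0.01, steps=500):
    """Fit y ~ w * x by gradient descent on the mean squared error (hypothetical helper)."""
    w = 0.0
    n = len(xs)
    for _ in range(steps):
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        # Weight decay: an extra pull of the weight towards zero.
        grad += weight_decay * w
        w -= lr * grad
    return w


# Toy data with a slope of about 2.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [0.1, 1.9, 4.2, 5.8, 8.1]

w_plain = fit_slope(xs, ys)
w_decay = fit_slope(xs, ys, weight_decay=5.0)
print(w_plain, w_decay)  # the decayed slope is noticeably smaller
```

The same mechanism applied to every weight of a network favours smaller weights and therefore smoother fits, which is the kind of effect to look for on the moons example.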

API Documentation

The API documentation is hosted on GitHub Pages:

https://koaning.github.io/scikit-mdn/

More depth

If you would like a glimpse of the internals, you may want to play around with the mdn.ipynb notebook, which contains a Jupyter widget.

Example chart

Extra resources
