
Official implementation of GPX: Gaussian Process Regression with Interpretable Sample-wise Feature Weights (published in IEEE TNNLS)


yuyay/gpx


GPX

GPX example on California housing dataset

GPX is a Gaussian process regression model that outputs, for each sample, the contribution of each feature to the prediction. This repository implements the method described in the following paper:
Yuya Yoshikawa and Tomoharu Iwata. "Gaussian Process Regression With Interpretable Sample-Wise Feature Weights." IEEE Transactions on Neural Networks and Learning Systems (2021).

GPX has the following characteristics:

  • High accuracy: GPX achieves predictive accuracy comparable to that of standard Gaussian process regression models.
  • Explainability: GPX outputs feature contributions, together with their uncertainty, for each sample. The paper reports that these contributions are more appropriate, both qualitatively and quantitatively, than those produced by existing explanation methods such as LIME and SHAP.
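
As a rough orientation, the "feature contributions" mentioned above can be read through the locally linear form that the paper describes. The sketch below is our summary rather than a formula taken from this README: the symbols (d_Z for the number of simplified features, sigma^2 for the noise variance) and the Gaussian noise term are notational assumptions, and the kernel and inference details are left to the paper.

```latex
% Sketch of the locally linear model behind GPX (notation assumed here).
% x_i: original input, z_i: simplified input, w(x_i): sample-wise weight vector
% whose components are given Gaussian process priors.
\begin{align}
  y_i    &= \mathbf{w}(\mathbf{x}_i)^{\top} \mathbf{z}_i + \epsilon_i,
         & \epsilon_i &\sim \mathcal{N}(0, \sigma^2), \\
  c_{ij} &= w_j(\mathbf{x}_i)\, z_{ij},
         & j &= 1, \dots, d_Z .
\end{align}
```

Under this reading, c_{ij} is the contribution of simplified feature j to the prediction for sample i, and the contributions of one sample sum to its noise-free prediction.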

Installation

The pytorch-gpx package is on PyPI. Simply run:

pip install pytorch-gpx

Or clone the repository and run:

pip install .

Usage

The pytorch-gpx package provides a scikit-learn-like API for training, prediction, and evaluation of GPX models.

from sklearn.metrics import mean_squared_error
from gpx import GPXRegressor

# Training
# X_tr: input data (numpy array) with shape (n_samples, n_X_features)
# y_tr: target variables (numpy array) with shape (n_samples,)
# Z_tr: simplified input data (numpy array) with shape (n_samples, n_Z_features).
#       Passing X_tr itself as Z_tr is fine.
model = GPXRegressor().fit(X_tr, y_tr, Z_tr)

# Prediction
# y_mean: the posterior mean of the target variables
# y_cov: the posterior variance of the target variables
# w_mean: the posterior mean of the weights
# w_cov: the posterior variance of the weights
y_mean, y_cov, w_mean, w_cov = model.predict(X_te, Z_te, return_weights=True)

# Evaluation
mse = mean_squared_error(y_te, y_mean)
print("Test MSE = {}".format(mse))
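
The weights returned by `predict` can be turned into per-feature contributions. The following is a minimal sketch rather than part of the package: it assumes that `w_mean` has the same shape as `Z_te`, namely (n_samples, n_Z_features), and that a contribution is the element-wise product of a weight and the corresponding simplified feature, as in a locally linear model. The helper names `feature_contributions` and `top_features` are hypothetical, and synthetic arrays stand in for the model outputs so that the snippet runs on its own.

```python
import numpy as np


def feature_contributions(w_mean, Z):
    """Element-wise product of sample-wise weights and simplified inputs.

    Assumes both arrays have shape (n_samples, n_Z_features).
    """
    w_mean = np.asarray(w_mean, dtype=float)
    Z = np.asarray(Z, dtype=float)
    if w_mean.shape != Z.shape:
        raise ValueError(
            "w_mean and Z must share a shape, got {} and {}".format(w_mean.shape, Z.shape)
        )
    return w_mean * Z


def top_features(contrib, k=3):
    """Indices of the k largest-magnitude contributions per sample, largest first."""
    order = np.argsort(-np.abs(contrib), axis=1)
    return order[:, :k]


# Synthetic stand-ins for the outputs of model.predict(..., return_weights=True).
rng = np.random.default_rng(0)
Z_te = rng.normal(size=(5, 4))    # simplified inputs: 5 samples, 4 features
w_mean = rng.normal(size=(5, 4))  # assumed shape of the posterior mean weights

contrib = feature_contributions(w_mean, Z_te)
y_from_contrib = contrib.sum(axis=1)  # per-sample sum of contributions
top2 = top_features(contrib, k=2)

print(contrib.shape, y_from_contrib.shape, top2.shape)
```

With a fitted model, `w_mean` and `Z_te` from the usage example would replace the synthetic arrays. The shape of `w_cov` is not specified here, so inspect it before using it to attach uncertainty to the contributions.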

For more usage examples, please see the below.

Citation

If you use this repo, please cite the following paper.

@article{yoshikawa2021gpx,
  title={Gaussian Process Regression With Interpretable Sample-Wise Feature Weights},
  author={Yoshikawa, Yuya and Iwata, Tomoharu},
  journal={IEEE Transactions on Neural Networks and Learning Systems},
  year={2021},
  publisher={IEEE}
}

License

Please see LICENSE.txt.

Acknowledgment

This work was supported by the Japan Society for the Promotion of Science (JSPS) KAKENHI under Grant 18K18112.