diff --git a/docs/source/tutorials/digit_processing.md b/docs/source/tutorials/digit_processing.md index 2628293..a14c0be 100644 --- a/docs/source/tutorials/digit_processing.md +++ b/docs/source/tutorials/digit_processing.md @@ -173,14 +173,13 @@ second order (class covariances) are used to classify samples. ```{code-cell} ipython3 from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.metrics import accuracy_score def get_qda_accuracy(x_train, y_train, x_test, y_test): """Fit QDA model to the training data and return the accuracy on the test data.""" qda = QuadraticDiscriminantAnalysis() qda.fit(x_train, y_train) y_pred = qda.predict(x_test) - accuracy = accuracy_score(y_test, y_pred) + accuracy = torch.mean(torch.as_tensor(y_pred == y_test.numpy(), dtype=torch.float)) return accuracy model_list = [pca, lda, sqfa_model]