Skip to content

Commit

Permalink
Simpler accuracy computation
Browse files Browse the repository at this point in the history
  • Loading branch information
dherrera1911 committed Dec 14, 2024
1 parent f63f04e commit fb519f8
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions docs/source/tutorials/digit_processing.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,13 @@ second order (class covariances) are used to classify samples.

```{code-cell} ipython3
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.metrics import accuracy_score
def get_qda_accuracy(x_train, y_train, x_test, y_test):
"""Fit QDA model to the training data and return the accuracy on the test data."""
qda = QuadraticDiscriminantAnalysis()
qda.fit(x_train, y_train)
y_pred = qda.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)
accuracy = torch.mean(torch.as_tensor(y_pred == y_test.numpy(), dtype=torch.float))
return accuracy
model_list = [pca, lda, sqfa_model]
Expand Down

0 comments on commit fb519f8

Please sign in to comment.