-
Notifications
You must be signed in to change notification settings - Fork 0
/
gda.py
66 lines (48 loc) · 1.77 KB
/
gda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import numpy as np
from scipy.stats import multivariate_normal
class GDA:
    """Gaussian Discriminant Analysis for binary labels (0 / 1).

    Both classes are modelled as multivariate Gaussians that share a single
    covariance matrix; prediction picks the class with the larger posterior.

    Attributes:
        phi: Proportion of training examples from class 0. Float.
        mu_0: Mean of the class-0 examples. NumPy array.
        mu_1: Mean of the class-1 examples. NumPy array.
        sigma: Covariance matrix shared by both classes. NumPy array.

    Example of usage:
        clf = GDA()
        clf.fit(X_train, y_train)
        clf.predict(X_test)
    """

    def __init__(self):
        # All parameters stay None until fit() has been called.
        self.phi = None
        self.mu_0 = None
        self.mu_1 = None
        self.sigma = None

    def fit(self, X, y):
        """Estimate the Gaussian parameters from labelled data.

        Args:
            X: Training examples of shape (m, n). Array-like.
            y: Training labels (0 or 1) with m entries. Array-like; a column
                of shape (m, 1) is flattened to (m,).
        """
        X = np.asarray(X)
        # Flatten so a column vector of labels still yields a 1-D row mask.
        y = np.asarray(y).ravel()

        is_0 = y == 0
        is_1 = y == 1

        self.phi = np.mean(is_0)
        self.mu_0 = np.mean(X[is_0], axis=0)
        self.mu_1 = np.mean(X[is_1], axis=0)

        # Centre each class on its own mean, then pool the scatter matrices
        # and normalise by the total number of examples.
        centered_0 = X[is_0] - self.mu_0
        centered_1 = X[is_1] - self.mu_1
        pooled_scatter = (centered_0.T @ centered_0) + (centered_1.T @ centered_1)
        self.sigma = pooled_scatter / X.shape[0]

    def predict(self, X):
        """Predict the most probable class for each input.

        Args:
            X: Inputs of shape (m, n). Array-like.

        Returns:
            h_x: Predicted labels (0 or 1) of shape (m,). NumPy array.
        """
        X = np.asarray(X)

        # Log of the class priors. A prior of exactly zero maps to -inf,
        # which is the intended score, so silence the divide warning.
        with np.errstate(divide="ignore"):
            log_p0 = np.log(self.phi)
            log_p1 = np.log(1 - self.phi)

        # Log-likelihood of the inputs under each class-conditional Gaussian.
        # Working in log space matters: far from both means the raw densities
        # underflow to 0.0, the scores tie, and argmax would always answer 0.
        log_px_0 = multivariate_normal.logpdf(X, mean=self.mu_0, cov=self.sigma)
        log_px_1 = multivariate_normal.logpdf(X, mean=self.mu_1, cov=self.sigma)

        # Unnormalised log-posteriors via Bayes rule; the shared evidence
        # term does not affect which class wins, so it is omitted.
        score_0 = log_p0 + log_px_0
        score_1 = log_p1 + log_px_1

        h_x = np.argmax(np.stack([score_0, score_1]), axis=0)
        return h_x