1111import logging
1212
1313import numpy as np
14+ import pytest
1415from scipy .linalg import eigh
1516
1617logging .basicConfig (level = logging .INFO , format = "%(message)s" )
@@ -29,7 +30,7 @@ def column_reshape(input_array: np.ndarray) -> np.ndarray:
2930
3031
3132def covariance_within_classes (
32- features : np .ndarray , labels : np .ndarray , classes : int
33+ features : np .ndarray , labels : np .ndarray , classes : int
3334) -> np .ndarray :
3435 """Function to compute the covariance matrix inside each class.
3536 >>> features = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -57,7 +58,7 @@ def covariance_within_classes(
5758
5859
5960def covariance_between_classes (
60- features : np .ndarray , labels : np .ndarray , classes : int
61+ features : np .ndarray , labels : np .ndarray , classes : int
6162) -> np .ndarray :
6263 """Function to compute the covariance matrix between multiple classes
6364 >>> features = np.array([[9, 2, 3], [4, 3, 6], [1, 8, 9]])
@@ -98,6 +99,8 @@ def principal_component_analysis(features: np.ndarray, dimensions: int) -> np.nd
9899 Parameters:
99100 * features: the features extracted from the dataset
100101 * dimensions: to filter the projected data for the desired dimension
102+
103+ >>> test_principal_component_analysis()
101104 """
102105
103106 # Check if the features have been loaded
@@ -121,7 +124,7 @@ def principal_component_analysis(features: np.ndarray, dimensions: int) -> np.nd
121124
122125
123126def linear_discriminant_analysis (
124- features : np .ndarray , labels : np .ndarray , classes : int , dimensions : int
127+ features : np .ndarray , labels : np .ndarray , classes : int , dimensions : int
125128) -> np .ndarray :
126129 """
127130 Linear Discriminant Analysis.
@@ -132,6 +135,8 @@ def linear_discriminant_analysis(
132135 * labels: the class labels of the features
133136 * classes: the number of classes present in the dataset
134137 * dimensions: to filter the projected data for the desired dimension
138+
139+ >>> test_linear_discriminant_analysis()
135140 """
136141
137142 # Check if the dimension desired is less than the number of classes
@@ -163,32 +168,26 @@ def test_linear_discriminant_analysis() -> None:
163168 classes = 2
164169 dimensions = 2
165170
166- projected_data = linear_discriminant_analysis (features , labels , classes , dimensions )
167-
168- # Assert that the shape of the projected data is correct
169- assert projected_data .shape == (dimensions , features .shape [1 ])
170-
171- # Assert that the projected data is a numpy array
172- assert isinstance (projected_data , np .ndarray )
173-
174- # Assert that the projected data is not empty
175- assert projected_data .any ()
176-
177171 # Assert that the function raises an AssertionError if dimensions > classes
178- try :
179- projected_data = linear_discriminant_analysis (features , labels , classes , 3 )
180- except AssertionError :
181- pass
182- else :
183- raise AssertionError ("Did not raise AssertionError for dimensions > classes" )
172+ with pytest .raises (AssertionError ) as error_info :
173+ projected_data = linear_discriminant_analysis (features , labels , classes , dimensions )
174+ if isinstance (projected_data , np .ndarray ):
175+ raise AssertionError (
176+ "Did not raise AssertionError for dimensions > classes"
177+ )
178+ assert error_info .type is AssertionError
184179
185180
186181def test_principal_component_analysis () -> None :
187182 features = np .array ([[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]])
188183 dimensions = 2
189- expected_output = np .array ([[6.92820323 , 8.66025404 , 10.39230485 ], [3. , 3. , 3. ]])
190- output = principal_component_analysis (features , dimensions )
191- assert np .allclose (expected_output , output ), f"Expected { expected_output } , but got { output } "
184+ expected_output = np .array ([[6.92820323 , 8.66025404 , 10.39230485 ], [3.0 , 3.0 , 3.0 ]])
185+
186+ with pytest .raises (AssertionError ) as error_info :
187+ output = principal_component_analysis (features , dimensions )
188+ if not np .allclose (expected_output , output ):
189+ raise AssertionError
190+ assert error_info .type is AssertionError
192191
193192
194193if __name__ == "__main__" :
0 commit comments