Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: cleaned test_cluster.py and added an additional unit test [no issue] #89

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,50 @@

import sys
import unittest
import numpy as np

sys.path.append("../src")

from harmony.matching.cluster import cluster_questions
from harmony.matching.cluster import cluster_questions, perform_kmeans
from harmony.schemas.requests.text import Question

sys.path.append("../src")

class TestCluster(unittest.TestCase):
"""Test class for the cluster.py module."""
def setUp(self):
self.all_questions_real = [Question(question_no="1", question_text="Feeling nervous, anxious, or on edge"),
self.all_questions_real = [Question(question_no="1",
question_text="Feeling nervous, anxious, or on edge"),
Question(question_no="2",
question_text="Not being able to stop or control worrying"),
question_text="Not being able to stop or control "
"worrying"),
Question(question_no="3",
question_text="Little interest or pleasure in doing things"),
Question(question_no="4", question_text="Feeling down, depressed, or hopeless"),
question_text="Little interest or pleasure in doing "
"things"),
Question(question_no="4", question_text="Feeling down, "
"depressed or hopeless"),
Question(question_no="5",
question_text="Trouble falling/staying asleep, sleeping too much"), ]
question_text="Trouble falling/staying asleep, "
"sleeping too much"), ]

def test_cluster(self):
"""Test the entire cluster module."""
clusters_out, score_out = cluster_questions(self.all_questions_real, 2, False)
assert (len(clusters_out) == 5)
assert len(clusters_out) == 5
assert score_out

@unittest.mock.patch("harmony.matching.cluster.KMeans")
def test_perform_kmeans(self, mock_kmeans: unittest.mock.MagicMock):
"""Test the perform_kmeans function in the cluster module."""
mock_kmeans_instance = unittest.mock.Mock()
mock_kmeans.return_value = mock_kmeans_instance
mock_kmeans_instance.fit_predict.return_value = np.array([0, 1, 0, 2, 1])
test_embeddings = np.array([[1,2], [3,4], [1,3], [7,8], [4,5]])

result = perform_kmeans(test_embeddings, num_clusters=3)

mock_kmeans.assert_called_once_with(n_clusters=3)
mock_kmeans_instance.fit_predict.assert_called_once_with(test_embeddings)
np.testing.assert_array_equal(result, np.array([0, 1, 0, 2, 1]))


if __name__ == '__main__':
unittest.main()