From 792ee57123db193c6ba4de182a13df25da6f3afe Mon Sep 17 00:00:00 2001 From: ivanz-thinkpad Date: Fri, 31 Jan 2025 23:05:22 +0300 Subject: [PATCH 1/5] optimize split_matrix function by removing duplicate code to the extract_submatrix function, add tests --- .../strassen_matrix_multiplication.py | 20 +++--- divide_and_conquer/tests/__init__.py | 0 .../test_strassen_matrix_multiplication.py | 61 +++++++++++++++++++ 3 files changed, 72 insertions(+), 9 deletions(-) create mode 100644 divide_and_conquer/tests/__init__.py create mode 100644 divide_and_conquer/tests/test_strassen_matrix_multiplication.py diff --git a/divide_and_conquer/strassen_matrix_multiplication.py b/divide_and_conquer/strassen_matrix_multiplication.py index f529a255d2ef..78c2e56fac07 100644 --- a/divide_and_conquer/strassen_matrix_multiplication.py +++ b/divide_and_conquer/strassen_matrix_multiplication.py @@ -49,18 +49,20 @@ def split_matrix(a: list) -> tuple[list, list, list, list]: if len(a) % 2 != 0 or len(a[0]) % 2 != 0: raise Exception("Odd matrices are not supported!") - matrix_length = len(a) - mid = matrix_length // 2 + def extract_submatrix(rows, cols): + return [[a[i][j] for j in cols] for i in rows] - top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)] - bot_right = [ - [a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length) - ] + mid = len(a) // 2 - top_left = [[a[i][j] for j in range(mid)] for i in range(mid)] - bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)] + rows_top, rows_bot = range(mid), range(mid, len(a)) + cols_left, cols_right = range(mid), range(mid, len(a)) - return top_left, top_right, bot_left, bot_right + return ( + extract_submatrix(rows_top, cols_left), # Top-left + extract_submatrix(rows_top, cols_right), # Top-right + extract_submatrix(rows_bot, cols_left), # Bottom-left + extract_submatrix(rows_bot, cols_right), # Bottom-right + ) def matrix_dimensions(matrix: list) -> tuple[int, int]: diff --git a/divide_and_conquer/tests/__init__.py b/divide_and_conquer/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py new file mode 100644 index 000000000000..1a6073f8d29d --- /dev/null +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -0,0 +1,61 @@ +import unittest +from strassen_matrix_multiplication import split_matrix + + +class TestSplitMatrix(unittest.TestCase): + + def test_4x4_matrix(self): + matrix = [ + [4, 3, 2, 4], + [2, 3, 1, 1], + [6, 5, 4, 3], + [8, 4, 1, 6] + ] + expected = ( + [[4, 3], [2, 3]], + [[2, 4], [1, 1]], + [[6, 5], [8, 4]], + [[4, 3], [1, 6]] + ) + self.assertEqual(split_matrix(matrix), expected) + + def test_8x8_matrix(self): + matrix = [ + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6], + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6] + ] + expected = ( + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + ) + self.assertEqual(split_matrix(matrix), expected) + + def test_invalid_odd_matrix(self): + matrix = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ] + with self.assertRaises(Exception): + split_matrix(matrix) + + def test_invalid_non_square_matrix(self): + matrix = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12] + ] + with self.assertRaises(Exception): + split_matrix(matrix) + + +if __name__ == "__main__": + unittest.main() From 30930a6e283daf20884558f21a0f1d26a8bee843 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 20:09:04 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_strassen_matrix_multiplication.py | 26 +++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py index 1a6073f8d29d..240096caab6a 100644 --- a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -3,19 +3,13 @@ class TestSplitMatrix(unittest.TestCase): - def test_4x4_matrix(self): - matrix = [ - [4, 3, 2, 4], - [2, 3, 1, 1], - [6, 5, 4, 3], - [8, 4, 1, 6] - ] + matrix = [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] expected = ( [[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], - [[4, 3], [1, 6]] + [[4, 3], [1, 6]], ) self.assertEqual(split_matrix(matrix), expected) @@ -28,31 +22,23 @@ def test_8x8_matrix(self): [4, 3, 2, 4, 4, 3, 2, 4], [2, 3, 1, 1, 2, 3, 1, 1], [6, 5, 4, 3, 6, 5, 4, 3], - [8, 4, 1, 6, 8, 4, 1, 6] + [8, 4, 1, 6, 8, 4, 1, 6], ] expected = ( [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], ) self.assertEqual(split_matrix(matrix), expected) def test_invalid_odd_matrix(self): - matrix = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ] + matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with self.assertRaises(Exception): split_matrix(matrix) def test_invalid_non_square_matrix(self): - matrix = [ - [1, 2, 3, 4], - [5, 6, 7, 8], - [9, 10, 11, 12] - ] + matrix = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]] with self.assertRaises(Exception): split_matrix(matrix) From 598359efa5ccffbf65bacb7751b6a31dc2f49508 Mon Sep 17 00:00:00 2001 From: ivanz-thinkpad Date: Fri, 31 Jan 2025 23:26:54 +0300 Subject: [PATCH 3/5] fix test file issues --- .../test_strassen_matrix_multiplication.py | 121 +++++++++--------- 1 file changed, 60 insertions(+), 61 deletions(-) diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py index 1a6073f8d29d..4e2bf6515fc9 100644 --- a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -1,61 +1,60 @@ -import unittest -from strassen_matrix_multiplication import split_matrix - - -class TestSplitMatrix(unittest.TestCase): - - def test_4x4_matrix(self): - matrix = [ - [4, 3, 2, 4], - [2, 3, 1, 1], - [6, 5, 4, 3], - [8, 4, 1, 6] - ] - expected = ( - [[4, 3], [2, 3]], - [[2, 4], [1, 1]], - [[6, 5], [8, 4]], - [[4, 3], [1, 6]] - ) - self.assertEqual(split_matrix(matrix), expected) - - def test_8x8_matrix(self): - matrix = [ - [4, 3, 2, 4, 4, 3, 2, 4], - [2, 3, 1, 1, 2, 3, 1, 1], - [6, 5, 4, 3, 6, 5, 4, 3], - [8, 4, 1, 6, 8, 4, 1, 6], - [4, 3, 2, 4, 4, 3, 2, 4], - [2, 3, 1, 1, 2, 3, 1, 1], - [6, 5, 4, 3, 6, 5, 4, 3], - [8, 4, 1, 6, 8, 4, 1, 6] - ] - expected = ( - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] - ) - self.assertEqual(split_matrix(matrix), expected) - - def test_invalid_odd_matrix(self): - matrix = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ] - with self.assertRaises(Exception): - split_matrix(matrix) - - def test_invalid_non_square_matrix(self): - matrix = [ - [1, 2, 3, 4], - [5, 6, 7, 8], - [9, 10, 11, 12] - ] - with self.assertRaises(Exception): - split_matrix(matrix) - - -if __name__ == "__main__": - unittest.main() +import pytest +from divide_and_conquer.strassen_matrix_multiplication import split_matrix + + +def test_4x4_matrix(): + matrix = [ + [4, 3, 2, 4], + [2, 3, 1, 1], + [6, 5, 4, 3], + [8, 4, 1, 6] + ] + expected = ( + [[4, 3], [2, 3]], + [[2, 4], [1, 1]], + [[6, 5], [8, 4]], + [[4, 3], [1, 6]] + ) + assert split_matrix(matrix) == expected + + +def test_8x8_matrix(): + matrix = [ + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6], + [4, 3, 2, 4, 4, 3, 2, 4], + [2, 3, 1, 1, 2, 3, 1, 1], + [6, 5, 4, 3, 6, 5, 4, 3], + [8, 4, 1, 6, 8, 4, 1, 6] + ] + expected = ( + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + ) + assert split_matrix(matrix) == expected + + +def test_invalid_odd_matrix(): + matrix = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9] + ] + with pytest.raises(Exception, match="Odd matrices are not supported!"): + split_matrix(matrix) + + +def test_invalid_non_square_matrix(): + matrix = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + [17, 18, 19, 20] + ] + with pytest.raises(Exception, match="Odd matrices are not supported!"): + split_matrix(matrix) From 034d970bc2ad322dcbf8232c189b039d177d2dd9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Jan 2025 20:30:58 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_strassen_matrix_multiplication.py | 26 +++++-------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py index 4e2bf6515fc9..a5e646a75c8b 100644 --- a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -3,18 +3,8 @@ def test_4x4_matrix(): - matrix = [ - [4, 3, 2, 4], - [2, 3, 1, 1], - [6, 5, 4, 3], - [8, 4, 1, 6] - ] - expected = ( - [[4, 3], [2, 3]], - [[2, 4], [1, 1]], - [[6, 5], [8, 4]], - [[4, 3], [1, 6]] - ) + matrix = [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + expected = ([[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], [[4, 3], [1, 6]]) assert split_matrix(matrix) == expected @@ -27,23 +17,19 @@ def test_8x8_matrix(): [4, 3, 2, 4, 4, 3, 2, 4], [2, 3, 1, 1, 2, 3, 1, 1], [6, 5, 4, 3, 6, 5, 4, 3], - [8, 4, 1, 6, 8, 4, 1, 6] + [8, 4, 1, 6, 8, 4, 1, 6], ] expected = ( [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], - [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]] + [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], ) assert split_matrix(matrix) == expected def test_invalid_odd_matrix(): - matrix = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9] - ] + matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with pytest.raises(Exception, match="Odd matrices are not supported!"): split_matrix(matrix) @@ -54,7 +40,7 @@ def test_invalid_non_square_matrix(): [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16], - [17, 18, 19, 20] + [17, 18, 19, 20], ] with pytest.raises(Exception, match="Odd matrices are not supported!"): split_matrix(matrix) From c1d7da9dfab80f26dd0c8ab3805a44c638e10037 Mon Sep 17 00:00:00 2001 From: ivanz-thinkpad Date: Fri, 31 Jan 2025 23:36:43 +0300 Subject: [PATCH 5/5] fix codestyle --- divide_and_conquer/tests/test_strassen_matrix_multiplication.py | 1 + 1 file changed, 1 insertion(+) diff --git a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py index 4e2bf6515fc9..bddb3f940eee 100644 --- a/divide_and_conquer/tests/test_strassen_matrix_multiplication.py +++ b/divide_and_conquer/tests/test_strassen_matrix_multiplication.py @@ -1,4 +1,5 @@ import pytest + from divide_and_conquer.strassen_matrix_multiplication import split_matrix