|  | 
|  | 1 | +#!/usr/bin/env python3 | 
|  | 2 | +import unittest | 
|  | 3 | + | 
|  | 4 | +import torch | 
|  | 5 | + | 
|  | 6 | +import captum.optim._utils.atlas as atlas | 
|  | 7 | +from tests.helpers.basic import BaseTest, assertTensorAlmostEqual | 
|  | 8 | + | 
|  | 9 | + | 
|  | 10 | +class TestNormalizeGrid(BaseTest): | 
|  | 11 | +    def test_normalize_grid(self) -> None: | 
|  | 12 | +        x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float() | 
|  | 13 | + | 
|  | 14 | +        x_out = atlas.normalize_grid(x) | 
|  | 15 | + | 
|  | 16 | +        x_expected = torch.tensor( | 
|  | 17 | +            [ | 
|  | 18 | +                [0.0000, 0.0000], | 
|  | 19 | +                [0.1250, 0.1250], | 
|  | 20 | +                [0.2500, 0.2500], | 
|  | 21 | +                [0.3750, 0.3750], | 
|  | 22 | +                [0.5000, 0.5000], | 
|  | 23 | +                [0.6250, 0.6250], | 
|  | 24 | +                [0.7500, 0.7500], | 
|  | 25 | +                [0.8750, 0.8750], | 
|  | 26 | +                [1.0000, 1.0000], | 
|  | 27 | +            ] | 
|  | 28 | +        ) | 
|  | 29 | + | 
|  | 30 | +        assertTensorAlmostEqual(self, x_out, x_expected) | 
|  | 31 | + | 
|  | 32 | + | 
|  | 33 | +class TestGridIndices(BaseTest): | 
|  | 34 | +    def test_grid_indices(self) -> None: | 
|  | 35 | +        x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float() | 
|  | 36 | +        x = atlas.normalize_grid(x) | 
|  | 37 | +        x_indices = atlas.grid_indices(x, size=(2, 2)) | 
|  | 38 | + | 
|  | 39 | +        expected_indices = [ | 
|  | 40 | +            [torch.tensor([0, 1, 2, 3, 4]), torch.tensor([4])], | 
|  | 41 | +            [torch.tensor([4]), torch.tensor([4, 5, 6, 7, 8])], | 
|  | 42 | +        ] | 
|  | 43 | + | 
|  | 44 | +        for list1, list2 in zip(x_indices, expected_indices): | 
|  | 45 | +            for t1, t2 in zip(list1, list2): | 
|  | 46 | +                assertTensorAlmostEqual(self, t1, t2) | 
|  | 47 | + | 
|  | 48 | + | 
|  | 49 | +class TestExtractGridVectors(BaseTest): | 
|  | 50 | +    def test_extract_grid_vectors(self) -> None: | 
|  | 51 | +        x_raw = torch.arange(0, 4 * 3 * 3).view(3 * 3, 4).float() | 
|  | 52 | +        x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float() | 
|  | 53 | +        x = atlas.normalize_grid(x) | 
|  | 54 | +        x_indices = atlas.grid_indices(x, size=(2, 2)) | 
|  | 55 | + | 
|  | 56 | +        x_vecs, vec_coords = atlas.extract_grid_vectors( | 
|  | 57 | +            x_indices, x_raw, size=(2, 2), min_density=2 | 
|  | 58 | +        ) | 
|  | 59 | + | 
|  | 60 | +        expected_vecs = torch.tensor([[8.0, 9.0, 10.0, 11.0], [24.0, 25.0, 26.0, 27.0]]) | 
|  | 61 | +        expected_coords = [(0, 0), (1, 1)] | 
|  | 62 | + | 
|  | 63 | +        assertTensorAlmostEqual(self, x_vecs, expected_vecs) | 
|  | 64 | +        self.assertEqual(vec_coords, expected_coords) | 
|  | 65 | + | 
|  | 66 | + | 
|  | 67 | +class TestCreateAtlasVectors(BaseTest): | 
|  | 68 | +    def test_create_atlas_vectors(self) -> None: | 
|  | 69 | +        x_raw = torch.arange(0, 4 * 3 * 3).view(3 * 3, 4).float() | 
|  | 70 | +        x = torch.arange(0, 2 * 3 * 3).view(3 * 3, 2).float() | 
|  | 71 | +        x_vecs, vec_coords = atlas.create_atlas_vectors( | 
|  | 72 | +            x, x_raw, size=(2, 2), min_density=2, normalize=True | 
|  | 73 | +        ) | 
|  | 74 | + | 
|  | 75 | +        expected_vecs = torch.tensor([[8.0, 9.0, 10.0, 11.0], [24.0, 25.0, 26.0, 27.0]]) | 
|  | 76 | +        expected_coords = [(0, 0), (1, 1)] | 
|  | 77 | + | 
|  | 78 | +        assertTensorAlmostEqual(self, x_vecs, expected_vecs) | 
|  | 79 | +        self.assertEqual(vec_coords, expected_coords) | 
|  | 80 | + | 
|  | 81 | + | 
|  | 82 | +class TestCreateAtlas(BaseTest): | 
|  | 83 | +    def test_create_atlas(self) -> None: | 
|  | 84 | +        img_list = [torch.ones(1, 3, 4, 4)] * 2 | 
|  | 85 | +        expected_coords = [(0, 0), (1, 1)] | 
|  | 86 | +        canvas = atlas.create_atlas(img_list, expected_coords, grid_size=(2, 2)) | 
|  | 87 | +        assertTensorAlmostEqual(self, canvas, torch.ones_like(canvas)) | 
|  | 88 | + | 
|  | 89 | + | 
|  | 90 | +if __name__ == "__main__": | 
|  | 91 | +    unittest.main() | 
0 commit comments