|
| 1 | +from typing import List, Tuple |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | + |
| 6 | +def grid_indices( |
| 7 | + tensor: torch.Tensor, |
| 8 | + size: Tuple[int, int] = (8, 8), |
| 9 | + x_extent: Tuple[float, float] = (0.0, 1.0), |
| 10 | + y_extent: Tuple[float, float] = (0.0, 1.0), |
| 11 | +) -> List[List[torch.Tensor]]: |
| 12 | + """ |
| 13 | + Create grid cells of a specified size for an irregular grid. |
| 14 | + """ |
| 15 | + |
| 16 | + assert tensor.dim() == 2 and tensor.size(1) == 2 |
| 17 | + x_coords = ((tensor[:, 0] - x_extent[0]) / (x_extent[1] - x_extent[0])) * size[1] |
| 18 | + y_coords = ((tensor[:, 1] - y_extent[0]) / (y_extent[1] - y_extent[0])) * size[0] |
| 19 | + |
| 20 | + x_list = [] |
| 21 | + for x in range(size[1]): |
| 22 | + y_list = [] |
| 23 | + for y in range(size[0]): |
| 24 | + in_bounds_x = torch.logical_and(x <= x_coords, x_coords <= x + 1) |
| 25 | + in_bounds_y = torch.logical_and(y <= y_coords, y_coords <= y + 1) |
| 26 | + in_bounds_indices = torch.where( |
| 27 | + torch.logical_and(in_bounds_x, in_bounds_y) |
| 28 | + )[0] |
| 29 | + y_list.append(in_bounds_indices) |
| 30 | + x_list.append(y_list) |
| 31 | + return x_list |
| 32 | + |
| 33 | + |
| 34 | +def normalize_grid( |
| 35 | + x: torch.Tensor, |
| 36 | + min_percentile: float = 0.01, |
| 37 | + max_percentile: float = 0.99, |
| 38 | + relative_margin: float = 0.1, |
| 39 | +) -> torch.Tensor: |
| 40 | + """ |
| 41 | + Remove outliers and rescale grid to [0,1]. |
| 42 | + """ |
| 43 | + |
| 44 | + assert x.dim() == 2 and x.size(1) == 2 |
| 45 | + mins = torch.quantile(x, min_percentile, dim=0) |
| 46 | + maxs = torch.quantile(x, max_percentile, dim=0) |
| 47 | + |
| 48 | + # add margins |
| 49 | + mins = mins - relative_margin * (maxs - mins) |
| 50 | + maxs = maxs + relative_margin * (maxs - mins) |
| 51 | + |
| 52 | + clipped = torch.max(torch.min(x, maxs), mins) |
| 53 | + clipped = clipped - clipped.min(0)[0] |
| 54 | + return clipped / clipped.max(0)[0] |
| 55 | + |
| 56 | + |
| 57 | +def extract_grid_vectors( |
| 58 | + grid: List[List[torch.Tensor]], |
| 59 | + activations: torch.Tensor, |
| 60 | + size: Tuple[int, int] = (8, 8), |
| 61 | + min_density: int = 8, |
| 62 | +) -> Tuple[torch.Tensor, List[Tuple[int, int]]]: |
| 63 | + """ |
| 64 | + Create direction vectors. |
| 65 | + """ |
| 66 | + |
| 67 | + cell_coords = [] |
| 68 | + average_activations = [] |
| 69 | + for x in range(size[1]): |
| 70 | + for y in range(size[0]): |
| 71 | + indices = grid[x][y] |
| 72 | + if len(indices) >= min_density: |
| 73 | + average_activations.append(torch.mean(activations[indices], 0)) |
| 74 | + cell_coords.append((x, y)) |
| 75 | + return torch.stack(average_activations), cell_coords |
| 76 | + |
| 77 | + |
| 78 | +def create_atlas_vectors( |
| 79 | + tensor: torch.Tensor, |
| 80 | + activations: torch.Tensor, |
| 81 | + size: Tuple[int, int] = (8, 8), |
| 82 | + min_density: int = 8, |
| 83 | + normalize: bool = True, |
| 84 | +) -> Tuple[torch.Tensor, List[Tuple[int, int]]]: |
| 85 | + """ |
| 86 | + Create direction vectors by splitting an irregular grid into cells. |
| 87 | + """ |
| 88 | + |
| 89 | + assert tensor.dim() == 2 and tensor.size(1) == 2 |
| 90 | + if normalize: |
| 91 | + tensor = normalize_grid(tensor) |
| 92 | + indices = grid_indices(tensor, size) |
| 93 | + grid_vecs, vec_coords = extract_grid_vectors( |
| 94 | + indices, activations, size, min_density |
| 95 | + ) |
| 96 | + return grid_vecs, vec_coords |
| 97 | + |
| 98 | + |
| 99 | +def create_atlas( |
| 100 | + cells: List[torch.Tensor], |
| 101 | + coords: List[List[torch.Tensor]], |
| 102 | + grid_size: Tuple[int, int] = (8, 8), |
| 103 | +) -> torch.Tensor: |
| 104 | + cell_h, cell_w = cells[0].shape[2:] |
| 105 | + canvas = torch.ones(1, 3, cell_h * grid_size[0], cell_w * grid_size[1]) |
| 106 | + for i, img in enumerate(cells): |
| 107 | + y = int(coords[i][0]) |
| 108 | + x = int(coords[i][1]) |
| 109 | + canvas[ |
| 110 | + ..., |
| 111 | + (grid_size[0] - x - 1) * cell_h : (grid_size[0] - x) * cell_h, |
| 112 | + y * cell_w : (y + 1) * cell_w, |
| 113 | + ] = img |
| 114 | + return canvas |
0 commit comments