Skip to content

Optim-wip: Fix & improve nchannels_to_rgb, hue_to_rgb #920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 62 additions & 40 deletions captum/optim/_utils/image/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,65 +110,87 @@ def _dot_cossim(
return dot * torch.clamp(torch.cosine_similarity(x, y, eps=eps), 0.1) ** cossim_pow


def nchannels_to_rgb(x: torch.Tensor, warp: bool = True) -> torch.Tensor:
"""
Convert an NCHW image with n channels into a 3 channel RGB image.
# Handle older versions of PyTorch
# Defined outside of function in order to support JIT
_torch_norm = torch.linalg.norm if torch.__version__ >= "1.9.0" else torch.norm


def hue_to_rgb(
angle: float, device: torch.device = torch.device("cpu"), warp: bool = True
) -> torch.Tensor:
"""
Create an RGB unit vector based on a hue of the input angle.
Args:
x (torch.Tensor): Image tensor to transform into RGB image.
warp (bool, optional): Whether or not to make colors more distinguishable.
angle (float): The hue angle to create an RGB color for.
device (torch.device, optional): The device to create the angle color tensor
on.
Default: torch.device("cpu")
warp (bool, optional): Whether or not to make colors more distinguishable.
Default: True
Returns:
*tensor* RGB image
color_vec (torch.Tensor): A color vector.
"""

def hue_to_rgb(angle: float) -> torch.Tensor:
"""
Create an RGB unit vector based on a hue of the input angle.
"""

angle = angle - 360 * (angle // 360)
colors = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.7071, 0.7071, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.7071, 0.7071],
[0.0, 0.0, 1.0],
[0.7071, 0.0, 0.7071],
]
)
angle = angle - 360 * (angle // 360)
colors = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.7071, 0.7071, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.7071, 0.7071],
[0.0, 0.0, 1.0],
[0.7071, 0.0, 0.7071],
],
device=device,
)

idx = math.floor(angle / 60)
d = (angle - idx * 60) / 60
idx = math.floor(angle / 60)
d = (angle - idx * 60) / 60

if warp:
if warp:
# Idea from: https://github.com/tensorflow/lucid/pull/193
d = (
math.sin(d * math.pi / 2)
if idx % 2 == 0
else 1 - math.sin((1 - d) * math.pi / 2)
)

def adj(x: float) -> float:
return math.sin(x * math.pi / 2)
vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6]
return vec / _torch_norm(vec)

d = adj(d) if idx % 2 == 0 else 1 - adj(1 - d)

vec = (1 - d) * colors[idx] + d * colors[(idx + 1) % 6]
return vec / torch.norm(vec)
def nchannels_to_rgb(
x: torch.Tensor, warp: bool = True, eps: float = 1e-4
) -> torch.Tensor:
"""
Convert an NCHW image with n channels into a 3 channel RGB image.
Args:
x (torch.Tensor): NCHW image tensor to transform into RGB image.
warp (bool, optional): Whether or not to make colors more distinguishable.
Default: True
eps (float, optional): An optional epsilon value.
Default: 1e-4
Returns:
tensor (torch.Tensor): An NCHW RGB image tensor.
"""

assert x.dim() == 4

if (x < 0).any():
x = posneg(x.permute(0, 2, 3, 1), -1).permute(0, 3, 1, 2)

rgb = torch.zeros(1, 3, x.size(2), x.size(3), device=x.device)
nc = x.size(1)
for i in range(nc):
rgb = rgb + x[:, i][:, None, :, :]
rgb = rgb * hue_to_rgb(360 * i / nc).to(device=x.device)[None, :, None, None]

rgb = rgb + torch.ones(x.size(2), x.size(3))[None, None, :, :] * (
torch.sum(x, 1)[:, None] - torch.max(x, 1)[0][:, None]
)
return (rgb / (1e-4 + torch.norm(rgb, dim=1, keepdim=True))) * torch.norm(
x, dim=1, keepdim=True
num_channels = x.size(1)
for i in range(num_channels):
rgb_angle = hue_to_rgb(360 * i / num_channels, device=x.device, warp=warp)
rgb = rgb + (x[:, i][:, None, :, :] * rgb_angle[None, :, None, None])

rgb = rgb + (
torch.ones(1, 1, x.size(2), x.size(3), device=x.device)
* (torch.sum(x, 1) - torch.max(x, 1)[0])[:, None]
)
rgb = rgb / (eps + _torch_norm(rgb, dim=1, keepdim=True))
return rgb * _torch_norm(x, dim=1, keepdim=True)


def weights_to_heatmap_2d(
Expand Down
183 changes: 179 additions & 4 deletions tests/optim/utils/image/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,190 @@ def test_get_neuron_pos_none_y(self) -> None:
self.assertEqual(y, 5)


class TestHueToRGB(BaseTest):
def test_hue_to_rgb_n_groups_4_warp_true(self) -> None:
n_groups = 4
channels = list(range(n_groups))
test_outputs = []
for ch in channels:
output = common.hue_to_rgb(360 * ch / n_groups)
test_outputs.append(output)
test_outputs = torch.stack(test_outputs)
expected_outputs = torch.tensor(
[
[1.0000, 0.0000, 0.0000],
[0.5334, 0.8459, 0.0000],
[0.0000, 0.7071, 0.7071],
[0.5334, 0.0000, 0.8459],
]
)
assertTensorAlmostEqual(self, test_outputs, expected_outputs)

def test_hue_to_rgb_n_groups_4_warp_false(self) -> None:
n_groups = 4
channels = list(range(n_groups))
test_outputs = []
for ch in channels:
output = common.hue_to_rgb(360 * ch / n_groups, warp=False)
test_outputs.append(output)
test_outputs = torch.stack(test_outputs)
expected_outputs = torch.tensor(
[
[1.0000, 0.0000, 0.0000],
[0.3827, 0.9239, 0.0000],
[0.0000, 0.7071, 0.7071],
[0.3827, 0.0000, 0.9239],
]
)
assertTensorAlmostEqual(self, test_outputs, expected_outputs)

def test_hue_to_rgb_n_groups_3_warp_true(self) -> None:
n_groups = 3
channels = list(range(n_groups))
test_outputs = []
for ch in channels:
output = common.hue_to_rgb(360 * ch / n_groups)
test_outputs.append(output)
test_outputs = torch.stack(test_outputs)
expected_outputs = torch.tensor(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
)
assertTensorAlmostEqual(self, test_outputs, expected_outputs, delta=0.0)

def test_hue_to_rgb_n_groups_2_warp_true(self) -> None:
n_groups = 2
channels = list(range(n_groups))
test_outputs = []
for ch in channels:
output = common.hue_to_rgb(360 * ch / n_groups)
test_outputs.append(output)
test_outputs = torch.stack(test_outputs)
expected_outputs = torch.tensor(
[[1.0000, 0.0000, 0.0000], [0.0000, 0.7071, 0.7071]]
)
assertTensorAlmostEqual(self, test_outputs, expected_outputs)

def test_hue_to_rgb_n_groups_2_warp_false(self) -> None:
n_groups = 2
channels = list(range(n_groups))
test_outputs = []
for ch in channels:
output = common.hue_to_rgb(360 * ch / n_groups, warp=False)
test_outputs.append(output)
test_outputs = torch.stack(test_outputs)
expected_outputs = torch.tensor(
[[1.0000, 0.0000, 0.0000], [0.0000, 0.7071, 0.7071]]
)
assertTensorAlmostEqual(self, test_outputs, expected_outputs)


class TestNChannelsToRGB(BaseTest):
def test_nchannels_to_rgb_collapse(self) -> None:
test_input = torch.randn(1, 6, 224, 224)
test_output = common.nchannels_to_rgb(test_input)
self.assertEqual(list(test_output.size()), [1, 3, 224, 224])
test_input = torch.arange(0, 1 * 4 * 4 * 4).view(1, 4, 4, 4).float()
test_output = common.nchannels_to_rgb(test_input, warp=True)
expected_output = torch.tensor(
[
[
[
[30.3782, 31.5489, 32.7147, 33.8773],
[35.0379, 36.1975, 37.3568, 38.5163],
[39.6765, 40.8378, 42.0003, 43.1642],
[44.3296, 45.4967, 46.6655, 47.8360],
],
[
[31.1266, 32.0951, 33.0678, 34.0451],
[35.0270, 36.0137, 37.0051, 38.0011],
[39.0015, 40.0063, 41.0152, 42.0282],
[43.0449, 44.0654, 45.0894, 46.1167],
],
[
[41.1375, 41.8876, 42.6646, 43.4656],
[44.2882, 45.1304, 45.9901, 46.8658],
[47.7561, 48.6597, 49.5754, 50.5023],
[51.4394, 52.3859, 53.3411, 54.3044],
],
]
]
)
assertTensorAlmostEqual(self, test_output, expected_output, delta=0.005)

def test_nchannels_to_rgb_collapse_warp_false(self) -> None:
test_input = torch.arange(0, 1 * 4 * 4 * 4).view(1, 4, 4, 4).float()
test_output = common.nchannels_to_rgb(test_input, warp=False)
expected_output = torch.tensor(
[
[
[
[27.0349, 28.1947, 29.3453, 30.4887],
[31.6266, 32.7605, 33.8914, 35.0201],
[36.1474, 37.2737, 38.3995, 39.5252],
[40.6511, 41.7772, 42.9039, 44.0312],
],
[
[31.8525, 32.8600, 33.8708, 34.8851],
[35.9034, 36.9257, 37.9522, 38.9828],
[40.0175, 41.0561, 42.0987, 43.1451],
[44.1951, 45.2486, 46.3054, 47.3655],
],
[
[42.8781, 43.6494, 44.4480, 45.2710],
[46.1162, 46.9813, 47.8644, 48.7640],
[49.6786, 50.6069, 51.5477, 52.5000],
[53.4629, 54.4355, 55.4172, 56.4071],
],
]
]
)
assertTensorAlmostEqual(self, test_output, expected_output, delta=0.005)

def test_nchannels_to_rgb_increase(self) -> None:
test_input = torch.randn(1, 2, 224, 224)
test_input = torch.arange(0, 1 * 2 * 4 * 4).view(1, 2, 4, 4).float()
test_output = common.nchannels_to_rgb(test_input, warp=True)
expected_output = torch.tensor(
[
[
[
[0.0000, 1.8388, 3.4157, 4.8079],
[6.0713, 7.2442, 8.3524, 9.4137],
[10.4405, 11.4414, 12.4226, 13.3886],
[14.3428, 15.2878, 16.2253, 17.1568],
],
[
[11.3136, 11.9711, 12.5764, 13.1697],
[13.7684, 14.3791, 15.0039, 15.6425],
[16.2941, 16.9572, 17.6306, 18.3131],
[19.0037, 19.7013, 20.4051, 21.1145],
],
[
[11.3136, 11.9711, 12.5764, 13.1697],
[13.7684, 14.3791, 15.0039, 15.6425],
[16.2941, 16.9572, 17.6306, 18.3131],
[19.0037, 19.7013, 20.4051, 21.1145],
],
]
]
)
assertTensorAlmostEqual(self, test_output, expected_output, delta=0.005)

def test_nchannels_to_rgb_cuda(self) -> None:
if not torch.cuda.is_available():
raise unittest.SkipTest(
"Skipping nchannels_to_rgb CUDA test due to not supporting CUDA."
)
test_input = torch.randn(1, 6, 224, 224).cuda()
test_output = common.nchannels_to_rgb(test_input)
self.assertTrue(test_output.is_cuda)
self.assertEqual(list(test_output.size()), [1, 3, 224, 224])

def test_nchannels_to_rgb_jit_module(self) -> None:
if torch.__version__ <= "1.8.0":
raise unittest.SkipTest(
"Skipping nchannels_to_rgb JIT module test due to insufficient Torch"
+ " version."
)
test_input = torch.randn(1, 6, 224, 224)
jit_nchannels_to_rgb = torch.jit.script(common.nchannels_to_rgb)
test_output = jit_nchannels_to_rgb(test_input)
self.assertEqual(list(test_output.size()), [1, 3, 224, 224])


Expand Down