Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit

Permalink
visualization changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shrinu Kushagra committed Nov 29, 2021
1 parent f701eea commit 457cec7
Showing 1 changed file with 131 additions and 91 deletions.
222 changes: 131 additions & 91 deletions examples/rcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,69 +43,55 @@
test_size = 20


def fetch_mnist_dataset(train_size: int, test_size: int, seed: int = 5):
"""Returns training and test images sampled randomly from the set of MNIST images.
def fetch_mnist_dataset(test_size: int, seed: int = 5):
"""Returns test images sampled randomly from the set of MNIST images.
Args:
train_size: Desired number of training images.
test_size: Desired number of test images.
Returns:
train_set: A list of length train_size containing images from the MNIST train dataset.
train_labels: Corresponding labels for the train images.
test_set: A list of length test_size containing images from the MNIST test dataset.
test_labels: Corresponding labels for the test images.
"""

mnist_train_size = 60000
num_per_class = test_size // 10

print("Fetching the MNIST dataset")
dataset = fetch_openml("mnist_784", as_frame=False, cache=True)
print("Successfully downloaded the MNIST dataset")

print("Fetched the data")
mnist_images = dataset["data"]
mnist_labels = dataset["target"].astype("int")

def _sample_data(images, labels, num_per_class):
t_set = []
t_labels = []
for i in range(10):
idxs = np.random.choice(np.argwhere(labels == i)[:, 0], num_per_class)
for idx in idxs:
img = images[idx].reshape(28, 28)
img_arr = jax.image.resize(
image=img, shape=(112, 112), method="bicubic"
)
img = jnp.pad(
img_arr,
pad_width=tuple([(p, p) for p in (44, 44)]),
mode="constant",
constant_values=0,
)

t_set.append(img)
t_labels.append(i)
return t_set, t_labels
full_mnist_test_images = mnist_images[mnist_train_size:]
full_mnist_test_labels = mnist_labels[mnist_train_size:]

np.random.seed(seed)
full_train_set, full_train_labels = (
mnist_images[:mnist_train_size],
mnist_labels[:mnist_train_size],
)
full_test_set, full_test_labels = (
mnist_images[mnist_train_size:],
mnist_labels[mnist_train_size:],
)

train_set, train_labels = _sample_data(
full_train_set, full_train_labels, train_size // 10
)
test_set, test_labels = _sample_data(
full_test_set, full_test_labels, test_size // 10
)
test_set = []
test_labels = []
for i in range(10):
idxs = np.random.choice(
np.argwhere(full_mnist_test_labels == i)[:, 0], num_per_class
)
for idx in idxs:
img = full_mnist_test_images[idx].reshape(28, 28)
img_arr = jax.image.resize(image=img, shape=(112, 112), method="bicubic")
img = jnp.pad(
img_arr,
pad_width=tuple([(p, p) for p in (44, 44)]),
mode="constant",
constant_values=0,
)

return train_set, np.array(train_labels), test_set, np.array(test_labels)
test_set.append(img)
test_labels.append(i)

return test_set, np.array(test_labels)


# %%
train_set, train_labels, test_set, test_labels = fetch_mnist_dataset(
train_size, test_size
test_set, test_labels = fetch_mnist_dataset(test_size)
train_labels = (
np.array([[i] * (train_size // 10) for i in range(10)]).reshape(1, -1).squeeze()
)

# %% [markdown]
Expand All @@ -120,13 +106,13 @@ def _sample_data(images, labels, num_per_class):
data["filters"],
)

M = (2 * hps + 1) * (2 * vps + 1) + 1
M = (2 * hps + 1) * (2 * vps + 1)

# %% [markdown]
# # 3. Visualize loaded model

# %%
img = np.zeros((200, 200))
img = np.ones((200, 200))

frc, edge = frcs[4], edges[4]
plt.figure(figsize=(10, 10))
Expand All @@ -135,19 +121,31 @@ def _sample_data(images, labels, num_per_class):
f1, r1, c1 = frc[i1]
f2, r2, c2 = frc[i2]

img[r1, c1] = 255
img[r2, c2] = 255
plt.text((c1 + c2) // 2, (r1 + r2) // 2, str(w), color="blue")
plt.plot([c1, c2], [r1, r2], color="blue", linewidth=0.5)
img[r1, c1] = 0
img[r2, c2] = 0
plt.text((c1 + c2) // 2, (r1 + r2) // 2, str(w), color="green")
plt.plot([c1, c2], [r1, r2], color="green", linewidth=0.5)

plt.imshow(img, cmap="gray")


# %% [markdown]
# # 3. Make pgmax graph
# ## 3.1 Visualize the filters

# %% [markdown]
# The filters are used to detect the oriented edges on a given image. They are pre-computed using Gabor filters.

# %%
plt.figure(figsize=(10, 10))
for i in range(filters.shape[0]):
plt.subplot(4, 4, i + 1)
plt.imshow(filters[i], cmap="gray")

# %% [markdown]
# # 4. Make pgmax graph

# %% [markdown]
# ## 3.1 Make variables
# ## 4.1 Make variables

# %%
start = time.time()
Expand All @@ -165,18 +163,18 @@ def _sample_data(images, labels, num_per_class):


# %% [markdown]
# ## 3.2 Make factors
# ## 4.2 Make factors

# %% [markdown]
# ### 3.2.1 Pre-compute the valid configs for different perturb radii.
# ### 4.2.1 Pre-compute the valid configs for different perturb radii.

# %%
def valid_configs(r: int) -> np.ndarray:
"""Returns the valid configurations for the potential matrix given the perturb radius.
Args:
r: Peturb radius
Returns:
phi: A configuration matrix (shape n X 2) where n is the number of valid configurations.
config_matrix: A configuration matrix (shape n X 2) where n is the number of valid configurations.
"""

rows = []
Expand Down Expand Up @@ -207,7 +205,7 @@ def valid_configs(r: int) -> np.ndarray:
phis.append(phi_r)

# %% [markdown]
# ### 3.2.2 Make the factor graph
# ### 4.2.2 Make the factor graph

# %%
start = end
Expand All @@ -228,10 +226,10 @@ def valid_configs(r: int) -> np.ndarray:


# %% [markdown]
# # 4. Run inference
# # 5. Run inference

# %% [markdown]
# ## 4.1 Helper functions to initialize the evidence for a given image
# ## 5.1 Helper functions to initialize the evidence for a given image

# %%
def get_bu_msg(img: np.ndarray) -> np.ndarray:
Expand Down Expand Up @@ -272,10 +270,46 @@ def get_bu_msg(img: np.ndarray) -> np.ndarray:
if factor != 1:
pos_chan[pos_chan > 0] *= factor
np.maximum(pc, pos_chan, pc)

bu_msg = np.array(pooled_channels)
bu_msg[bu_msg == 0] = -1
return bu_msg


# %% [markdown]
# ## 5.1.1 Visualizing bu_msg for a sample image

# %% [markdown]
# bu_msg has shape (16, H, W) where each 1 <= f <= 16 denotes the present or absense of a oriented edge

# %%
r_test_img = test_set[4]
r_bu_msg = get_bu_msg(r_test_img)
img = np.ones((200, 200))

plt.figure(figsize=(10, 10))

plt.subplot(1, 2, 1)
plt.imshow(r_test_img, cmap="gray")
for i in range(r_bu_msg.shape[0]):
img[r_bu_msg[i] > 0] = 0

plt.subplot(1, 2, 2)
plt.imshow(img, cmap="gray")

# %% [markdown]
# Showing the individual filter activations in r_bu_msg

# %%
plt.figure(figsize=(10, 10))

for i in range(r_bu_msg.shape[0]):
plt.subplot(5, 4, i + 1)
rbm = r_bu_msg[i]
rbm[rbm == 1] = -2
plt.imshow(rbm, cmap="gray")


# %%
def initialize_evidences(test_img: np.ndarray) -> Dict:
"""Computes the initial evidences to the PGMax factor graph given a test image.
Expand All @@ -286,30 +320,43 @@ def initialize_evidences(test_img: np.ndarray) -> Dict:
"""

bu_msg = get_bu_msg(test_img)
# jnp_bu_msg = jnp.asarray(bu_msg)

evidence_updates = {}
for idx in range(frcs.shape[0]):
frc = frcs[idx]

unary_msg = -1 + np.zeros((frc.shape[0], M))
# evidence_updates[idx] = get_evidence(jnp_bu_msg, frc)

for v in range(frc.shape[0]):
f, r, c = frc[v, :]
evidence = bu_msg[f, r - hps : r + hps + 1, c - hps : c + hps + 1]
indices = np.transpose(np.nonzero(evidence > 0))

for index in indices:
r1, c1 = index
delta_r, delta_c = r1 - hps, c1 - vps

index = delta_c + vps + (2 * hps + 1) * (delta_r + hps)
unary_msg[v, index] = 1
evidence = bu_msg[f, r - hps : r + hps + 1, c - vps : c + vps + 1]
unary_msg[v] = evidence.ravel()

evidence_updates[idx] = unary_msg

return evidence_updates


# from functools import partial


# @partial(jax.vmap, in_axes=(None, 0), out_axes=0)
# def get_evidence(bu_msg, frc):
# """
# bu_msg: Array of shape (n_features, M, N)
# frc: Array of shape (n_frcs, 3)
# """

# return jax.lax.dynamic_slice(
# bu_msg[frc[0]],
# jnp.array([frc[1] - hps, frc[2] - vps]),
# jnp.array([2 * hps + 1, 2 * vps + 1])
# ).ravel()

# %% [markdown]
# ## 4.2 Run map product inference on all test images
# ## 5.2 Run map product inference on all test images

# %%
run_bp_fn, _, get_beliefs_fn = graph.BP(fg.bp_state, 30)
Expand All @@ -321,6 +368,7 @@ def initialize_evidences(test_img: np.ndarray) -> Dict:

start = time.time()
evidence_updates = initialize_evidences(img)

end = time.time()
print(f"Initializing evidences took {end-start:.3f} seconds for image {test_idx}.")

Expand All @@ -345,7 +393,7 @@ def initialize_evidences(test_img: np.ndarray) -> Dict:


# %% [markdown]
# # 5. Compute metrics (accuracy)
# # 6. Compute metrics (accuracy)

# %%
test_preds = train_labels[scores.argmax(axis=1)]
Expand All @@ -355,38 +403,30 @@ def initialize_evidences(test_img: np.ndarray) -> Dict:


# %% [markdown]
# # 6. Visualize predictions
# # 7. Visualize predictions - backtrace for the top model

# %%
test_idx = 0
plt.imshow(test_set[test_idx], cmap="gray")


# %% [markdown]
# ## 6.1 Backtrace of some models on this test image

imgs = np.ones((20, 200, 200))
plt.figure(figsize=(15, 15))

# %%
map_states = map_states_dict[test_idx]
imgs = np.ones((len(frcs), 200, 200))
pred_idxs = np.argmax(scores, axis=1)
n_plots = [0, 5, 10]
for ii, pred_idx in enumerate(n_plots):

for i in range(frcs.shape[0]):
map_state = map_states[i]
frc = frcs[i]
map_states = map_states_dict[pred_idx]
map_state = map_states[pred_idxs[pred_idx]]
frc = frcs[pred_idx]

for v in range(frc.shape[0]):
idx = map_state[v]
f, r, c = frc[v]

delta_r, delta_c = -hps + idx // (2 * vps + 1), -vps + idx % (2 * vps + 1)
rd, cd = r + delta_r, c + delta_c
imgs[i, rd, cd] = 0
plt.figure(figsize=(15, 15))

for k, index in enumerate(range(0, len(train_set), 5)):
plt.subplot(1, 4, 1 + k)
plt.title(f" Model {int(train_labels[index])}")
plt.imshow(imgs[index, :, :], cmap="gray")
imgs[ii, rd, cd] = 0

plt.subplot(len(n_plots), 2, 1 + 2 * ii)
plt.imshow(test_set[pred_idx], cmap="gray")

# %%
plt.subplot(len(n_plots), 2, 2 + 2 * ii)
plt.imshow(imgs[ii, :, :], cmap="gray")

0 comments on commit 457cec7

Please sign in to comment.