Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit

Permalink
addressing all comments. adding support for training with 100 templates
Browse files Browse the repository at this point in the history
  • Loading branch information
Shrinu Kushagra committed Nov 30, 2021
1 parent 457cec7 commit e803909
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
Binary file added examples/example_data/rcn_100.npz
Binary file not shown.
File renamed without changes.
24 changes: 14 additions & 10 deletions examples/rcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
# # 1. Load the data
# %%
hps, vps = 12, 12

# Use train_size = 100 if you have a gpu with atleast 8Gbs memory.
# Recommend that jax is installed with cuda enabled for this option.
train_size = 20
test_size = 20

Expand Down Expand Up @@ -98,7 +101,9 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5):
# # 2. Load the model

# %%
data = np.load("example_data/rcn.npz", allow_pickle=True, encoding="latin1")
data = np.load(
f"example_data/rcn_{train_size}.npz", allow_pickle=True, encoding="latin1"
)
frcs, edges, suppression_masks, filters = (
data["frcs"],
data["edges"],
Expand All @@ -113,7 +118,7 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5):

# %%
img = np.ones((200, 200))

pad = 44
frc, edge = frcs[4], edges[4]
plt.figure(figsize=(10, 10))
for e in edge:
Expand All @@ -123,10 +128,10 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5):

img[r1, c1] = 0
img[r2, c2] = 0
plt.text((c1 + c2) // 2, (r1 + r2) // 2, str(w), color="green")
plt.plot([c1, c2], [r1, r2], color="green", linewidth=0.5)
plt.text((c1 + c2) // 2 - pad, (r1 + r2) // 2 - pad, str(w), color="green")
plt.plot([c1 - pad, c2 - pad], [r1 - pad, r2 - pad], color="green", linewidth=0.5)

plt.imshow(img, cmap="gray")
plt.imshow(img[pad : 200 - pad, pad : 200 - pad], cmap="gray")


# %% [markdown]
Expand Down Expand Up @@ -412,6 +417,7 @@ def initialize_evidences(test_img: np.ndarray) -> Dict:
pred_idxs = np.argmax(scores, axis=1)
n_plots = [0, 5, 10]
for ii, pred_idx in enumerate(n_plots):
plt.subplot(len(n_plots), 1, 1 + ii)

map_states = map_states_dict[pred_idx]
map_state = map_states[pred_idxs[pred_idx]]
Expand All @@ -423,10 +429,8 @@ def initialize_evidences(test_img: np.ndarray) -> Dict:

delta_r, delta_c = -hps + idx // (2 * vps + 1), -vps + idx % (2 * vps + 1)
rd, cd = r + delta_r, c + delta_c
imgs[ii, rd, cd] = 0
imgs[ii, rd, cd] = 255
plt.plot(cd, rd, "r.")

plt.subplot(len(n_plots), 2, 1 + 2 * ii)
# plt.imshow(imgs[ii, :, :])
plt.imshow(test_set[pred_idx], cmap="gray")

plt.subplot(len(n_plots), 2, 2 + 2 * ii)
plt.imshow(imgs[ii, :, :], cmap="gray")

0 comments on commit e803909

Please sign in to comment.