Skip to content

Commit

Permalink
Feature visualization improvements 32 (#3947)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Jul 9, 2021
1 parent dabad57 commit 248504c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
t1 = time_synchronized()
pred = model(img,
augment=augment,
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]
visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0]

# Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
Expand Down
11 changes: 6 additions & 5 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import yaml
from PIL import Image, ImageDraw, ImageFont

from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.general import xywh2xyxy, xyxy2xywh
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -447,7 +447,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
"""
x: Features to be visualized
module_type: Module type
Expand All @@ -460,13 +460,14 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec
if height > 1 and width > 1:
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename

plt.figure(tight_layout=True)
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
ax = ax.ravel()
plt.subplots_adjust(wspace=0.05, hspace=0.05)
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')

print(f'Saving {save_dir / f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300)
plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')

0 comments on commit 248504c

Please sign in to comment.