Skip to content

Commit

Permalink
Add feature map visualization (#3804)
Browse files Browse the repository at this point in the history
* Add feature map visualization

Add a feature_visualization function to visualize the mid feature map of the model.

* Update yolo.py

* remove boolean from forward and reorder if statement

* remove print from forward

* General cleanup

* Indent

* Update plots.py

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
Zigars and glenn-jocher authored Jun 28, 2021
1 parent 3974d72 commit 20d45aa
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
6 changes: 5 additions & 1 deletion models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from models.experimental import *
from utils.autoanchor import check_anchor_order
from utils.general import make_divisible, check_file, set_logging
from utils.plots import feature_visualization
from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \
select_device, copy_attr

Expand Down Expand Up @@ -135,7 +136,7 @@ def forward_augment(self, x):
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train

def forward_once(self, x, profile=False):
def forward_once(self, x, profile=False, feature_vis=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
Expand All @@ -153,6 +154,9 @@ def forward_once(self, x, profile=False):

x = m(x) # run
y.append(x if m.i in self.save else None) # save output

if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)

if profile:
logger.info('%.1fms total' % sum(dt))
Expand Down
30 changes: 28 additions & 2 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from utils.general import xywh2xyxy, xyxy2xywh
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -299,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
Expand Down Expand Up @@ -445,3 +446,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):

ax[1].legend()
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(features, module_type, module_idx, n=64):
"""
features: Features to be visualized
module_type: Module type
module_idx: Module layer index within model
n: Maximum number of feature maps to plot
"""
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure(tight_layout=True)
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'

f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)

0 comments on commit 20d45aa

Please sign in to comment.