diff --git a/models/common.py b/models/common.py index 9764d4c3a6c0..689aa0f3ed7c 100644 --- a/models/common.py +++ b/models/common.py @@ -223,18 +223,18 @@ def forward(self, x): return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) -class autoShape(nn.Module): +class AutoShape(nn.Module): # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS conf = 0.25 # NMS confidence threshold iou = 0.45 # NMS IoU threshold classes = None # (optional list) filter by class def __init__(self, model): - super(autoShape, self).__init__() + super(AutoShape, self).__init__() self.model = model.eval() def autoshape(self): - print('autoShape already enabled, skipping... ') # model already converted to model.autoshape() + print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape() return self @torch.no_grad() diff --git a/models/yolo.py b/models/yolo.py index 314fd806f5e7..06b80032d3d3 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -215,9 +215,9 @@ def nms(self, mode=True): # add or remove NMS module self.model = self.model[:-1] # remove return self - def autoshape(self): # add autoShape module - logger.info('Adding autoShape... ') - m = autoShape(self) # wrap model + def autoshape(self): # add AutoShape module + logger.info('Adding AutoShape... ') + m = AutoShape(self) # wrap model copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes return m