Train from --data path/to/dataset.zip feature (ultralytics#4185)
* Train from `--data path/to/dataset.zip` feature

* Update dataset_stats()

* cleanup

* cleanup2
glenn-jocher authored Jul 28, 2021
1 parent 9d7e211 commit 846fec6
Showing 9 changed files with 122 additions and 73 deletions.
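Taken together, the changes below let a zipped dataset (with its data.yaml inside) be passed straight to training: `check_dataset()` now downloads and extracts the archive into `../datasets` and returns the parsed dictionary. A minimal usage sketch, assuming an archive laid out the way the diff expects (the coco128 paths are illustrative, not part of this commit):

```python
# Command line: the new form accepted by --data
#   python train.py --data path/to/coco128_with_yaml.zip

# The same entry point from Python, mirroring what train.py now does internally
from utils.general import check_dataset

data_dict = check_dataset('path/to/coco128_with_yaml.zip')  # downloads/unzips if needed, returns a dict
print(data_dict['nc'], data_dict['names'][:3])               # number of classes and the first few class names
```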
2 changes: 1 addition & 1 deletion data/Argoverse_HD.yaml → data/Argoverse.yaml
@@ -1,6 +1,6 @@
# YOLOv5 🚀 by Ultralytics https://ultralytics.com, licensed under GNU GPL v3.0
# Argoverse-HD dataset (ring-front-center camera) http://www.cs.cmu.edu/~mengtial/proj/streaming/
# Example usage: python train.py --data Argoverse_HD.yaml
# Example usage: python train.py --data Argoverse.yaml
# parent
# ├── yolov5
# └── datasets
2 changes: 1 addition & 1 deletion hubconf.py
@@ -27,7 +27,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo

from models.yolo import Model, attempt_load
from utils.general import check_requirements, set_logging
from utils.google_utils import attempt_download
from utils.downloads import attempt_download
from utils.torch_utils import select_device

file = Path(__file__).absolute()
2 changes: 1 addition & 1 deletion models/experimental.py
@@ -5,7 +5,7 @@
import torch.nn as nn

from models.common import Conv, DWConv
from utils.google_utils import attempt_download
from utils.downloads import attempt_download


class CrossConv(nn.Module):
11 changes: 4 additions & 7 deletions train.py
@@ -35,7 +35,7 @@
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.downloads import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_labels, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, de_parallel
@@ -78,9 +78,9 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
plots = not evolve # create plots
cuda = device.type != 'cpu'
init_seeds(1 + RANK)
with open(data, encoding='ascii', errors='ignore') as f:
data_dict = yaml.safe_load(f)

with torch_distributed_zero_first(RANK):
data_dict = check_dataset(data) # check
train_path, val_path = data_dict['train'], data_dict['val']
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
assert len(names) == nc, f'{len(names)} names found for nc={nc} dataset in {data}' # check
@@ -106,9 +106,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
with torch_distributed_zero_first(RANK):
check_dataset(data_dict) # check
train_path, val_path = data_dict['train'], data_dict['val']

# Freeze
freeze = [] # parameter names to freeze (full or partial)
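In train.py, reading the YAML and checking the dataset collapse into a single `check_dataset()` call made under `torch_distributed_zero_first()`, so in DDP only rank 0 downloads while the other ranks wait, and the duplicate check further down is dropped. A condensed sketch of the new flow (single-process case, dataset name illustrative):

```python
from utils.general import check_dataset
from utils.torch_utils import torch_distributed_zero_first

RANK = -1  # single-GPU / CPU run for illustration
with torch_distributed_zero_first(RANK):       # in DDP, rank 0 enters first; other ranks wait at the barrier
    data_dict = check_dataset('coco128.yaml')  # verify/download the dataset and return the parsed dict
train_path, val_path = data_dict['train'], data_dict['val']
```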
66 changes: 50 additions & 16 deletions utils/datasets.py
@@ -884,11 +884,11 @@ def verify_image_label(args):
return [None, None, None, None, nm, nf, ne, nc, msg]


def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False, profile=False, hub=False):
""" Return dataset statistics dictionary with images and instances counts per split per class
Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', verbose=True)
Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128.zip', verbose=True)
To run in parent directory: export PYTHONPATH="$PWD/yolov5"
Usage1: from utils.datasets import *; dataset_stats('coco128.yaml', autodownload=True)
Usage2: from utils.datasets import *; dataset_stats('../datasets/coco128_with_yaml.zip')
Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
autodownload: Attempt to download dataset if not found locally
@@ -897,46 +897,80 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):

def round_labels(labels):
# Update labels to integer class and 6 decimal place floats
return [[int(c), *[round(x, 6) for x in points]] for c, *points in labels]
return [[int(c), *[round(x, 4) for x in points]] for c, *points in labels]

def unzip(path):
# Unzip data.zip TODO: CONSTRAINT: path/to/abc.zip MUST unzip to 'path/to/abc/'
if str(path).endswith('.zip'): # path is data.zip
assert Path(path).is_file(), f'Error unzipping {path}, file not found'
assert os.system(f'unzip -q {path} -d {path.parent}') == 0, f'Error unzipping {path}'
data_dir = path.with_suffix('') # dataset directory
return True, data_dir, list(data_dir.rglob('*.yaml'))[0] # zipped, data_dir, yaml_path
dir = path.with_suffix('') # dataset directory
return True, str(dir), next(dir.rglob('*.yaml')) # zipped, data_dir, yaml_path
else: # path is data.yaml
return False, None, path

def hub_ops(f, max_dim=1920):
# HUB ops for 1 image 'f'
im = Image.open(f)
r = max_dim / max(im.height, im.width) # ratio
if r < 1.0: # image too large
im = im.resize((int(im.width * r), int(im.height * r)))
im.save(im_dir / Path(f).name, quality=75) # save

zipped, data_dir, yaml_path = unzip(Path(path))
with open(check_file(yaml_path), encoding='ascii', errors='ignore') as f:
data = yaml.safe_load(f) # data dict
if zipped:
data['path'] = data_dir # TODO: should this be dir.resolve()?
check_dataset(data, autodownload) # download dataset if missing
nc = data['nc'] # number of classes
stats = {'nc': nc, 'names': data['names']} # statistics dictionary
hub_dir = Path(data['path'] + ('-hub' if hub else ''))
stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
for split in 'train', 'val', 'test':
if data.get(split) is None:
stats[split] = None # i.e. no test set
continue
x = []
dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
if split == 'train':
cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
dataset = LoadImagesAndLabels(data[split]) # load dataset
for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
x.append(np.bincount(label[:, 0].astype(int), minlength=data['nc']))
x = np.array(x) # shape(128x80)
stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
'per_class': (x > 0).sum(0).tolist()},
'labels': [{str(Path(k).name): round_labels(v.tolist())} for k, v in
zip(dataset.img_files, dataset.labels)]}

if hub:
im_dir = hub_dir / 'images'
im_dir.mkdir(parents=True, exist_ok=True)
for _ in tqdm(ThreadPool(NUM_THREADS).imap(hub_ops, dataset.img_files), total=dataset.n, desc='HUB Ops'):
pass

# Profile
stats_path = hub_dir / 'stats.json'
if profile:
for _ in range(1):
file = stats_path.with_suffix('.npy')
t1 = time.time()
np.save(file, stats)
t2 = time.time()
x = np.load(file, allow_pickle=True)
print(f'stats.npy times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')

file = stats_path.with_suffix('.json')
t1 = time.time()
with open(file, 'w') as f:
json.dump(stats, f) # save stats *.json
t2 = time.time()
with open(file, 'r') as f:
x = json.load(f) # load hyps dict
print(f'stats.json times: {time.time() - t2:.3f}s read, {t2 - t1:.3f}s write')

# Save, print and return
with open(cache_path.with_suffix('.json'), 'w') as f:
json.dump(stats, f) # save stats *.json
if hub:
print(f'Saving {stats_path.resolve()}...')
with open(stats_path, 'w') as f:
json.dump(stats, f) # save stats.json
if verbose:
print(json.dumps(stats, indent=2, sort_keys=False))
# print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
return stats
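The new `profile` and `hub` flags are opt-in: `hub=True` saves copies of the images capped at 1920 px on the long side plus a stats.json into a sibling `<dataset>-hub` directory, while `profile=True` only times reading and writing the stats to .npy and .json. Example calls, a sketch based on the updated signature and docstring (paths illustrative):

```python
from utils.datasets import dataset_stats

stats = dataset_stats('coco128.yaml', autodownload=True)        # stats dict from a data.yaml
stats = dataset_stats('../datasets/coco128_with_yaml.zip')      # stats dict from a zip containing data.yaml
stats = dataset_stats('coco128.yaml', hub=True, profile=True)   # also export resized images + stats.json, and time the I/O
```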
6 changes: 5 additions & 1 deletion utils/google_utils.py → utils/downloads.py
@@ -1,4 +1,4 @@
# Google utils: https://cloud.google.com/storage/docs/reference/libraries
# Download utils

import os
import platform
@@ -115,6 +115,10 @@ def get_token(cookie="./cookie"):
return line.split()[-1]
return ""


# Google utils: https://cloud.google.com/storage/docs/reference/libraries ----------------------------------------------
#
#
# def upload_blob(bucket_name, source_file_name, destination_blob_name):
# # Uploads a file to a bucket
# # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
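utils/google_utils.py becomes utils/downloads.py, so any code importing these helpers directly needs the new module path; the commented-out Google Cloud Storage helpers remain at the bottom of the file. A small sketch of the import change (the weights file name is just an example):

```python
# Before this commit:
#   from utils.google_utils import attempt_download
# After:
from utils.downloads import attempt_download

attempt_download('yolov5s.pt')  # fetch pretrained weights if the file is missing locally
```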
40 changes: 29 additions & 11 deletions utils/general.py
@@ -24,7 +24,7 @@
import torchvision
import yaml

from utils.google_utils import gsutil_getsize
from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness
from utils.torch_utils import init_torch_seeds

@@ -224,16 +224,30 @@ def check_file(file):


def check_dataset(data, autodownload=True):
# Download dataset if not found locally
path = Path(data.get('path', '')) # optional 'path' field
if path:
for k in 'train', 'val', 'test':
if data.get(k): # prepend path
data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]
# Download and/or unzip dataset if not found locally
# Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip

# Download (optional)
extract_dir = ''
if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip
download(data, dir='../datasets', unzip=True, delete=False, curl=False, threads=1)
data = next((Path('../datasets') / Path(data).stem).rglob('*.yaml'))
extract_dir, autodownload = data.parent, False

# Read yaml (optional)
if isinstance(data, (str, Path)):
with open(data, encoding='ascii', errors='ignore') as f:
data = yaml.safe_load(f) # dictionary

# Parse yaml
path = extract_dir or Path(data.get('path') or '') # optional 'path' default to '.'
for k in 'train', 'val', 'test':
if data.get(k): # prepend path
data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]

assert 'nc' in data, "Dataset 'nc' key missing."
if 'names' not in data:
data['names'] = [str(i) for i in range(data['nc'])] # assign class names if missing
data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing
train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]
if val:
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
@@ -256,13 +270,17 @@
else:
raise Exception('Dataset not found.')

return data # dictionary


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):
# Multi-threaded file download and unzip function
# Multi-threaded file download and unzip function, used in data.yaml for autodownload
def download_one(url, dir):
# Download 1 file
f = dir / Path(url).name # filename
if not f.exists():
if Path(url).is_file(): # exists in current path
Path(url).rename(f) # move to dir
elif not f.exists():
print(f'Downloading {url} to {f}...')
if curl:
os.system(f"curl -L '{url}' -o '{f}' --retry 9 -C -") # curl download, retry and resume on fail
@@ -286,7 +304,7 @@ def download_one(url, dir):
pool.close()
pool.join()
else:
for u in tuple(url) if isinstance(url, str) else url:
for u in [url] if isinstance(url, (str, Path)) else url:
download_one(u, dir)


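The last hunk also fixes a subtle bug in `download()`: the old `tuple(url)` splits a single URL string into its characters, so a one-URL call would try to download each character, while the new `[url]` (which also accepts `Path`) iterates the whole string once. A tiny illustration with a hypothetical URL:

```python
from pathlib import Path

url = 'https://example.com/coco128.zip'  # hypothetical URL, for illustration only
print(tuple(url)[:3])                                    # ('h', 't', 't')  -- what the old code iterated over
print([url] if isinstance(url, (str, Path)) else url)    # ['https://example.com/coco128.zip']  -- new behaviour
```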