Skip to content

Commit

Permalink
Merge pull request #46 from naseemap47/resume
Browse files Browse the repository at this point in the history
Resume
  • Loading branch information
naseemap47 authored Sep 24, 2023
2 parents 0e8a725 + 73c0e39 commit d3a0557
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ You can train your **YOLO-NAS** model with **Single Command Line**
```
python3 train.py --data /dir/dataset/data.yaml --batch 6 --epoch 100 --model yolo_nas_m --size 640
```
### If your training ends in 65th epoch (total 100 epochs), now you can start from 65th epoch and complete your 100 epochs training.
**Example:**
```
python3 train.py --data /dir/dataset/data.yaml --batch 6 --epoch 100 --model yolo_nas_m --size 640 \
--weight runs/train2/ckpt_latest.pth --resume
```

## 📺 Inference
You can Inference your **YOLO-NAS** model with **Single Command Line**
Expand Down
32 changes: 22 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,25 @@

s_time = time.time()


if args['name'] is None:
name = 'train'
else:
name = args['name']
n = 0
while True:
if not os.path.exists(os.path.join('runs', f'{name}{n}')):
name = f'{name}{n}'
os.makedirs(os.path.join('runs', name))
print(f"[INFO] Checkpoints saved in \033[1m{os.path.join('runs', name)}\033[0m")
break
else:
n += 1


if args['resume']:
name = os.path.split(args['weight'])[0].split('/')[-1]
else:
n = 0
while True:
if not os.path.exists(os.path.join('runs', f'{name}{n}')):
name = f'{name}{n}'
os.makedirs(os.path.join('runs', name))
break
else:
n += 1

print(f"[INFO] Checkpoints saved in \033[1m{os.path.join('runs', name)}\033[0m")
# Training on GPU or CPU
if args['cpu']:
print('[INFO] Training on \033[1mCPU\033[0m')
Expand Down Expand Up @@ -184,6 +189,13 @@
"metric_to_watch": 'mAP@0.50'
}

# to Resume Training
if args['resume']:
train_params['resume'] = True

# Print Training Params
print('[INFO] Training Params:\n', train_params)

trainer.train(
model=model,
training_params=train_params,
Expand Down

0 comments on commit d3a0557

Please sign in to comment.