Skip to content

Commit

Permalink
Immediate checkpoint serialization (ultralytics#9437)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Mar 31, 2024
1 parent 5e8cd02 commit e13e6d1
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions ultralytics/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,32 +477,37 @@ def _do_train(self, world_size=1):

def save_model(self):
"""Save model training checkpoints with additional metadata."""
import io
import pandas as pd # scope for faster startup

metrics = {**self.metrics, **{"fitness": self.fitness}}
results = {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
ckpt = {
"epoch": self.epoch,
"best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(),
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(),
"train_args": vars(self.args), # save as dict
"train_metrics": metrics,
"train_results": results,
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
}

# Save last and best
torch.save(ckpt, self.last)
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
buffer = io.BytesIO()
torch.save(
{
"epoch": self.epoch,
"best_fitness": self.best_fitness,
"model": deepcopy(de_parallel(self.model)).half(),
"ema": deepcopy(self.ema.ema).half(),
"updates": self.ema.updates,
"optimizer": self.optimizer.state_dict(),
"train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)",
"docs": "https://docs.ultralytics.com",
},
buffer,
)
serialized_ckpt = buffer.getvalue() # get the serialized content to save

# Save checkpoints
self.last.write_bytes(serialized_ckpt) # save last.pt
if self.best_fitness == self.fitness:
self.best.write_bytes(self.last.read_bytes()) # copy last.pt to best.pt
self.best.write_bytes(serialized_ckpt) # save best.pt
if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(self.last.read_bytes()) # copy last.pt to i.e. epoch3.pt
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'

@staticmethod
def get_dataset(data):
Expand Down

0 comments on commit e13e6d1

Please sign in to comment.