Skip to content

Commit

Permalink
Avoid keeping redundant copies of model weights in memory during load (
Browse files Browse the repository at this point in the history
…ggerganov#42)

* don't keep copies of model weights in host memory

* adding type annotation

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
  • Loading branch information
drdaxxy and jongwook authored Sep 23, 2022
1 parent a4fe05a commit f296bcd
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,19 @@
}


def _download(url: str, root: str) -> bytes:
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)

expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
download_target = os.path.join(root, os.path.basename(url))

if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")

if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

Expand All @@ -58,15 +57,15 @@ def _download(url: str, root: str) -> bytes:
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")

return model_bytes
return model_bytes if in_memory else download_target


def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())


def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None) -> Whisper:
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
"""
Load a Whisper ASR model
Expand All @@ -79,28 +78,33 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
Returns
-------
model : Whisper
The Whisper ASR model instance
"""

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

if name in _MODELS:
model_bytes = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/whisper"))
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
elif os.path.isfile(name):
model_bytes = open(name, "rb").read()
checkpoint_file = open(name, "rb").read() if in_memory else name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

with io.BytesIO(model_bytes) as fp:
checkpoint = torch.load(fp, map_location="cpu")
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
checkpoint = torch.load(fp, map_location=device)
del checkpoint_file

dims = ModelDimensions(**checkpoint["dims"])
state_dict = checkpoint["model_state_dict"]
model = Whisper(dims)
model.load_state_dict(state_dict)

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.load_state_dict(checkpoint["model_state_dict"])

return model.to(device)

0 comments on commit f296bcd

Please sign in to comment.