Skip to content

Commit

Permalink
[NeMo-UX] Adding file-lock to Connector (#9400)
Browse files Browse the repository at this point in the history
* Adding file-lock to Connector

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

* Fixing bug in path in mistral-7b

* Fixing bug with overwrite

* Apply isort and black reformatting

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>

---------

Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
  • Loading branch information
marcromeyn and marcromeyn authored Jun 10, 2024
1 parent 445b9b1 commit 8c58e13
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
4 changes: 3 additions & 1 deletion nemo/collections/llm/gpt/model/mistral_7b.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Mistral7BConfig(GPTConfig):

class Mistral7BModel(GPTModel):
def __init__(self, config: Optional[Mistral7BConfig] = None, tokenizer=None):
_tokenizer = tokenizer or HFMistral7BImporter().tokenizer
_tokenizer = tokenizer or HFMistral7BImporter("mistralai/Mistral-7B-v0.1").tokenizer

super().__init__(config or Mistral7BConfig(), _tokenizer)

Expand All @@ -56,6 +56,8 @@ def apply(self, output_path: Path) -> Path:
self.convert_state(source, target)
self.nemo_save(output_path, trainer)

print(f"Converted Mistral 7B model to Nemo, model saved to {output_path}")

teardown(trainer, target)
del trainer, target

Expand Down
29 changes: 24 additions & 5 deletions nemo/lightning/io/connector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import os
import shutil
from pathlib import Path, PosixPath, WindowsPath
from typing import Generic, Optional, Tuple, TypeVar

import pytorch_lightning as pl
from filelock import FileLock, Timeout

# Dynamically inherit from the correct Path subclass based on the operating system.
if os.name == 'nt':
Expand Down Expand Up @@ -47,6 +49,7 @@ class Connector(BasePath, Generic[SourceT, TargetT]):
"""

default_path = None
LOCK_TIMEOUT = 1200

def init(self) -> TargetT:
raise NotImplementedError()
Expand All @@ -63,13 +66,29 @@ def __new__(cls, *args, **kwargs):

def __call__(self, output_path: Optional[Path] = None, overwrite: bool = False) -> Path:
_output_path = output_path or self.local_path()
lock_path = _output_path.with_suffix(_output_path.suffix + '.lock')
lock = FileLock(lock_path)

if overwrite and _output_path.exists():
shutil.rmtree(_output_path)
# Check if the lock file exists and set overwrite to False if it does
if lock_path.exists():
overwrite = False

if not _output_path.exists():
to_return = self.apply(_output_path)
_output_path = to_return or _output_path
try:
with lock.acquire(timeout=self.LOCK_TIMEOUT):
if overwrite and _output_path.exists():
shutil.rmtree(_output_path)

if not _output_path.exists():
to_return = self.apply(_output_path)
_output_path = to_return or _output_path

except Timeout:
logging.error(f"Timeout occurred while trying to acquire the lock for {_output_path}")
raise

except Exception as e:
logging.error(f"An error occurred: {e}")
raise

return _output_path

Expand Down

0 comments on commit 8c58e13

Please sign in to comment.