Skip to content

Commit

Permalink
refactor: implement __getnewargs__ method for DPPath subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
caic99 committed Nov 28, 2024
1 parent 12daca1 commit a1506af
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ def __new__(cls, path: str, mode: str = "r"):
raise FileNotFoundError(f"{path} not found")
return super().__new__(cls)

def __getnewargs__(self):
return (self.path, self.mode)

@abstractmethod
def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down Expand Up @@ -116,6 +113,10 @@ def is_file(self) -> bool:
def is_dir(self) -> bool:
"""Check if self is directory."""

@abstractmethod
def __getnewargs__(self):
"""Return the arguments to be passed to __new__ when unpickling an instance."""

@abstractmethod
def __truediv__(self, key: str) -> "DPPath":
"""Used for / operator."""
Expand Down Expand Up @@ -174,6 +175,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
else:
self.path = Path(path)

def __getnewargs__(self):
return (self.path, self.mode)

def load_numpy(self) -> np.ndarray:
"""Load NumPy array.
Expand Down Expand Up @@ -307,6 +311,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
# h5 path: default is the root path
self._name = s[1] if len(s) > 1 else "/"

def __getnewargs__(self):
return (self.root_path, self.mode)

@classmethod
@lru_cache(None)
def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
Expand Down

0 comments on commit a1506af

Please sign in to comment.