1818
1919import os
2020import typing
21- from typing import Iterator
21+ from typing import Iterator , Sequence
2222
2323from etils import epath
2424from orbax .checkpoint ._src .multihost import multihost
2525
2626
27+ _LOCAL_PATH_BASE_NAME = '_local_path_base'
28+ _LOCAL_PART_PREFIX = 'local'
29+
30+
2731# The following is a hack to pass the type checker.
2832if typing .TYPE_CHECKING :
2933 _BasePath = epath .Path
3034else :
3135 _BasePath = object
3236
3337
38+ def create_local_path_base (testclass ) -> epath .Path :
39+ return epath .Path (
40+ testclass .create_tempdir (name = _LOCAL_PATH_BASE_NAME ).full_path
41+ )
42+
43+
44+ def _get_local_part_index (parts : Sequence [str ]) -> int :
45+ for i , part in enumerate (parts ):
46+ if part .startswith (_LOCAL_PART_PREFIX ):
47+ return i
48+ raise ValueError (
49+ f'Did not find a local part ({ _LOCAL_PART_PREFIX } ) in parts: { parts } '
50+ )
51+
52+
3453class LocalPath (_BasePath ):
3554 """A Path implementation for testing process-local paths.
3655
56+ IMPORTANT: Use `create_local_path_base` to create the base path for test
57+ cases.
58+
3759 In the future, this class may more completely provide all functions and
3860 properties of a pathlib Path, but for now, it only provides the minimum
3961 needed to support relevant tests.
@@ -54,9 +76,12 @@ class LocalPath(_BasePath):
5476
5577 def __init__ (self , * parts : epath .PathLike ):
5678 self ._path = epath .Path ('/' .join (os .fspath (p ) for p in parts ))
79+ # Assumes this class will always be constructed on the controller first
80+ # (otherwise this check will return the wrong value).
81+ self ._is_pathways_backend = multihost .is_pathways_backend ()
5782
5883 def __repr__ (self ) -> str :
59- return f'{ self . __class__ . __name__ } ({ self .path } )'
84+ return f'LocalPath ({ self .path } )'
6085
6186 def __str__ (self ) -> str :
6287 return str (self .path )
@@ -67,7 +92,40 @@ def base_path(self) -> epath.Path:
6792
6893 @property
6994 def path (self ) -> epath .Path :
70- return self .base_path / str (f'local_{ multihost .process_index ()} ' )
95+ parts = list (self .base_path .parts )
96+
97+ # Fail if the path is not properly configured. The local part should be
98+ # immediately following the base name.
99+ try :
100+ base_idx = parts .index (_LOCAL_PATH_BASE_NAME )
101+ except ValueError as e :
102+ raise ValueError (
103+ f'Base path for LocalPath must contain { _LOCAL_PATH_BASE_NAME } . Got:'
104+ f' { self .base_path } '
105+ ) from e
106+
107+ if multihost .is_pathways_controller ():
108+ local_part = f'{ _LOCAL_PART_PREFIX } _controller'
109+ else :
110+ local_part = f'{ _LOCAL_PART_PREFIX } _{ multihost .process_index ()} '
111+
112+ try :
113+ # If the local part is already present, potentially replace it with the
114+ # correct local part (e.g. controller vs worker).
115+ local_part_idx = _get_local_part_index (parts )
116+ assert local_part_idx == base_idx + 1
117+ parts [local_part_idx ] = local_part
118+ return epath .Path (* parts )
119+ except ValueError :
120+ pass
121+
122+ # Otherwise, insert following the base part.
123+ parts .insert (base_idx + 1 , local_part )
124+ return epath .Path (* parts )
125+
126+ @property
127+ def parts (self ) -> tuple [str , ...]:
128+ return self .path .parts
71129
72130 def exists (self ) -> bool :
73131 """Returns True if self exists."""
@@ -119,6 +177,14 @@ def unlink(self, missing_ok: bool = False) -> None:
119177 """Remove this file or symbolic link."""
120178 self .path .unlink (missing_ok = missing_ok )
121179
180+ def touch (self , mode : int = 0o666 , exist_ok : bool = False ) -> None :
181+ """Creates the file at this path."""
182+ self .path .touch (exist_ok = exist_ok )
183+
184+ def rename (self , new_path : epath .PathLike ) -> None :
185+ """Renames this file or directory to the given path."""
186+ self .path .rename (new_path )
187+
122188 def write_bytes (self , data : bytes ) -> int :
123189 """Writes content as bytes."""
124190 return self .path .write_bytes (data )
@@ -135,16 +201,16 @@ def write_text(
135201 def as_posix (self ) -> str :
136202 return self .path .as_posix ()
137203
138- def __truediv__ (self , key : epath .PathLike ) -> epath . Path :
139- return self .path / key
204+ def __truediv__ (self , key : epath .PathLike ) -> LocalPath :
205+ return LocalPath ( self .path / key )
140206
141207 @property
142208 def name (self ) -> str :
143209 return self .path .name
144210
145211 @property
146- def parent (self ) -> epath . Path :
147- return self .path .parent
212+ def parent (self ) -> LocalPath :
213+ return LocalPath ( self .path .parent )
148214
149215 def __fspath__ (self ) -> str :
150216 return os .fspath (self .path )
0 commit comments