diff --git a/jupyter_server_fileid/manager.py b/jupyter_server_fileid/manager.py index bff2d77..e6d2d24 100644 --- a/jupyter_server_fileid/manager.py +++ b/jupyter_server_fileid/manager.py @@ -78,6 +78,32 @@ def _validate_db_path(self, proposal): def _uuid() -> str: return str(uuid.uuid4()) + def _normalize_path(self, path: str) -> str: + """Accepts an API path and returns a filesystem path, i.e. one prefixed + by root_dir and uses os.path.sep.""" + # use commonprefix instead of commonpath, since root_dir may not be a + # absolute POSIX path. + if os.path.commonprefix([self.root_dir, path]) != self.root_dir: + path = os.path.join(self.root_dir, path) + + return path + + def _from_normalized_path(self, path: Optional[str]) -> Optional[str]: + """Accepts a filesystem path and returns an API path, i.e. one relative + to root_dir and uses forward slashes as the path separator. Returns + `None` if the given path is None or is not relative to root_dir.""" + if path is None: + return None + + if os.path.commonprefix([self.root_dir, path]) != self.root_dir: + return None + + relpath = os.path.relpath(path, self.root_dir) + # always use forward slashes to delimit children + relpath = relpath.replace(os.path.sep, "/") + + return relpath + @abstractmethod def index(self, path: str) -> Optional[str]: """Returns the file ID for the file corresponding to `path`. @@ -98,9 +124,14 @@ def get_id(self, path: str) -> Optional[str]: @abstractmethod def get_path(self, id: str) -> Optional[str]: - """Retrieves the file path associated with the given file ID. + """ + Accepts a file ID and returns the API path to that file. Returns None if + the file ID does not exist. - Returns None if the file ID does not exist. + Notes + ----- + - See `_from_normalized_path()` for implementation details on how to + convert a filesystem path to an API path. """ pass @@ -191,11 +222,13 @@ def __init__(self, *args, **kwargs): self.con.commit() def _create(self, path: str) -> str: + path = self._normalize_path(path) id = self._uuid() self.con.execute("INSERT INTO Files (id, path) VALUES (?, ?)", (id, path)) return id def index(self, path: str) -> str: + path = self._normalize_path(path) row = self.con.execute("SELECT id FROM Files WHERE path = ?", (path,)).fetchone() existing_id = row and row[0] @@ -208,14 +241,18 @@ def index(self, path: str) -> str: return id def get_id(self, path: str) -> Optional[str]: + path = self._normalize_path(path) row = self.con.execute("SELECT id FROM Files WHERE path = ?", (path,)).fetchone() return row and row[0] def get_path(self, id: str) -> Optional[str]: row = self.con.execute("SELECT path FROM Files WHERE id = ?", (id,)).fetchone() - return row and row[0] + path = row and row[0] + return self._from_normalized_path(path) def move(self, old_path: str, new_path: str) -> None: + old_path = self._normalize_path(old_path) + new_path = self._normalize_path(new_path) row = self.con.execute("SELECT id FROM Files WHERE path = ?", (old_path,)).fetchone() id = row and row[0] @@ -228,11 +265,14 @@ def move(self, old_path: str, new_path: str) -> None: return id def copy(self, from_path: str, to_path: str) -> Optional[str]: + from_path = self._normalize_path(from_path) + to_path = self._normalize_path(to_path) id = self._create(to_path) self.con.commit() return id def delete(self, path: str) -> None: + path = self._normalize_path(path) self.con.execute("DELETE FROM Files WHERE path = ?", (path,)) self.con.commit() @@ -325,6 +365,13 @@ def __init__(self, *args, **kwargs): self.con.execute("CREATE INDEX IF NOT EXISTS ix_Files_is_dir ON Files (is_dir)") self.con.commit() + def _normalize_path(self, path): + path = super()._normalize_path(path) + path = os.path.normcase(path) + path = os.path.normpath(path) + + return path + def _index_all(self): """Recursively indexes all directories under the server root.""" self._index_dir_recursively(self.root_dir, self._stat(self.root_dir)) @@ -491,14 +538,6 @@ def _sync_file(self, path, stat_info): return id - def _normalize_path(self, path): - """Normalizes a given file path.""" - if not os.path.isabs(path): - path = os.path.join(self.root_dir, path) - path = os.path.normcase(path) - path = os.path.normpath(path) - return path - def _parse_raw_stat(self, raw_stat): """Accepts an `os.stat_result` object and returns a `StatStruct` object.""" @@ -665,13 +704,8 @@ def get_path(self, id): if ino != stat_info.ino or not self._check_timestamps(stat_info): return None - # if path is not relative to `self.root_dir`, return None. - if os.path.commonpath([self.root_dir, path]) != self.root_dir: - return None - - # finally, convert the path to a relative one. - path = os.path.relpath(path, self.root_dir) - return path + # finally, convert the path to a relative one and return it + return self._from_normalized_path(path) def _move_recursive(self, old_path, new_path): """Updates path of all indexed files prefixed with `old_path` and diff --git a/jupyter_server_fileid/pytest_plugin.py b/jupyter_server_fileid/pytest_plugin.py index 7a3492a..cc07e29 100644 --- a/jupyter_server_fileid/pytest_plugin.py +++ b/jupyter_server_fileid/pytest_plugin.py @@ -40,14 +40,26 @@ def fid_manager(fid_db_path, jp_root_dir): return fid_manager +@pytest.fixture +def arbitrary_fid_manager(fid_db_path, jp_root_dir): + """Fixture returning a test-configured instance of `ArbitraryFileIdManager`.""" + arbitrary_fid_manager = ArbitraryFileIdManager(db_path=fid_db_path, root_dir=str(jp_root_dir)) + arbitrary_fid_manager.con.execute("PRAGMA journal_mode = OFF") + return arbitrary_fid_manager + + @pytest.fixture(params=["local", "arbitrary"]) -def any_fid_manager(request, fid_db_path, jp_root_dir): +def any_fid_manager_class(request): """Parametrized fixture that runs the test with each of the default File ID manager implementations.""" class_by_param = {"local": LocalFileIdManager, "arbitrary": ArbitraryFileIdManager} + return class_by_param[request.param] - fid_manager = class_by_param[request.param](db_path=fid_db_path, root_dir=str(jp_root_dir)) - fid_manager.con.execute("PRAGMA journal_mode = OFF") # type: ignore[attr-defined] + +@pytest.fixture +def any_fid_manager(any_fid_manager_class, fid_db_path, jp_root_dir): + fid_manager = any_fid_manager_class(db_path=fid_db_path, root_dir=str(jp_root_dir)) + fid_manager.con.execute("PRAGMA journal_mode = OFF") return fid_manager diff --git a/tests/test_manager.py b/tests/test_manager.py index 84b5671..00ef7b9 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,3 +1,4 @@ +import ntpath import os from unittest.mock import patch @@ -78,25 +79,51 @@ def get_path_nosync(fid_manager, id): def test_validates_root_dir(fid_db_path): - rel_root_dir = root_dir = os.path.join("some", "rel", "path") + root_dir = "s3://bucket" with pytest.raises(TraitError, match="must be an absolute path"): - LocalFileIdManager(root_dir=rel_root_dir, db_path=fid_db_path) + LocalFileIdManager(root_dir=root_dir, db_path=fid_db_path) # root_dir can be relative for ArbitraryFileIdManager instances (and None) - afm = ArbitraryFileIdManager(root_dir=rel_root_dir, db_path=fid_db_path) - assert afm.root_dir == rel_root_dir + afm = ArbitraryFileIdManager(root_dir=root_dir, db_path=fid_db_path) + assert afm.root_dir == root_dir afm2 = ArbitraryFileIdManager(root_dir=None, db_path=fid_db_path) assert afm2.root_dir is None -def test_validates_db_path(jp_root_dir): +def test_validates_db_path(jp_root_dir, any_fid_manager_class): with pytest.raises(TraitError, match="must be an absolute path"): - LocalFileIdManager(root_dir=str(jp_root_dir), db_path=os.path.join("some", "rel", "path")) - with pytest.raises(TraitError, match="must be an absolute path"): - ArbitraryFileIdManager( + any_fid_manager_class( root_dir=str(jp_root_dir), db_path=os.path.join("some", "rel", "path") ) +def test_different_roots( + any_fid_manager_class, fid_db_path, jp_root_dir, test_path, test_path_child +): + """Assert that default FIM implementations assign the same file the same + file ID agnostic of contents manager root.""" + fid_manager_1 = any_fid_manager_class(db_path=fid_db_path, root_dir=str(jp_root_dir)) + fid_manager_2 = any_fid_manager_class( + db_path=fid_db_path, root_dir=str(jp_root_dir / test_path) + ) + + id_1 = fid_manager_1.index(test_path_child) + id_2 = fid_manager_2.index(os.path.basename(test_path_child)) + + assert id_1 == id_2 + + +def test_different_roots_arbitrary(fid_db_path): + """Assert that ArbitraryFileIdManager assigns the same file the same file ID + agnostic of contents manager root, even if non-local.""" + manager_1 = ArbitraryFileIdManager(db_path=fid_db_path, root_dir="s3://bucket") + manager_2 = ArbitraryFileIdManager(db_path=fid_db_path, root_dir="s3://bucket/folder") + + id_1 = manager_1.index("folder/child") + id_2 = manager_2.index("child") + + assert id_1 == id_2 + + def test_index(any_fid_manager, test_path): id = any_fid_manager.index(test_path) assert id is not None @@ -250,6 +277,33 @@ def test_get_id_oob_move_new_file_at_old_path(fid_manager, old_path, new_path, f assert fid_manager.get_id(other_path) == other_id +def test_get_path_arbitrary_preserves_path(arbitrary_fid_manager): + """Tests whether ArbitraryFileIdManager always preserves the file paths it + receives.""" + path = "AbCd.txt" + id = arbitrary_fid_manager.index(path) + assert path == arbitrary_fid_manager.get_path(id) + + +@patch("os.path.sep", new="\\") +@patch("os.path.relpath", new=ntpath.relpath) +@patch("os.path.normpath", new=ntpath.normpath) +@patch("os.path.join", new=ntpath.join) +def test_get_path_returns_api_path(jp_root_dir, fid_db_path): + """Tests whether get_path() method always returns an API path, i.e. one + relative to the server root and one delimited by forward slashes (even if + os.path.sep = "\\").""" + test_path = "a\\b\\c" + expected_path = "a/b/c" + manager = ArbitraryFileIdManager( + root_dir=ntpath.join("c:", ntpath.normpath(str(jp_root_dir))), db_path=fid_db_path + ) + + id = manager.index(test_path) + path = manager.get_path(id) + assert path == expected_path + + def test_get_path_oob_move(fid_manager, old_path, new_path, fs_helpers): id = fid_manager.index(old_path) fs_helpers.move(old_path, new_path)