diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index e383646949..f3a8b4a146 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -1155,14 +1155,14 @@ def _default_compiledirname() -> str: return safe -def _filter_base_compiledir(path: Path) -> Path: +def _filter_base_compiledir(path: str | Path) -> Path: # Expand '~' in path - return path.expanduser() + return Path(path).expanduser() -def _filter_compiledir(path: Path) -> Path: +def _filter_compiledir(path: str | Path) -> Path: # Expand '~' in path - path = path.expanduser() + path = Path(path).expanduser() # Turn path into the 'real' path. This ensures that: # 1. There is no relative path, which would fail e.g. when trying to # import modules from the compile dir. diff --git a/tests/test_config.py b/tests/test_config.py index 59e261294d..65705c6988 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,6 +3,7 @@ import configparser as stdlib_configparser import io import pickle +from pathlib import Path import pytest @@ -19,6 +20,18 @@ def _create_test_config(): ) +def test_base_compiledir_str(tmp_path: Path): + base_compiledir = tmp_path + assert ( + configdefaults._filter_base_compiledir(str(base_compiledir)) == base_compiledir + ) + + +def test_compiledir_str(tmp_path: Path): + compiledir = tmp_path + assert configdefaults._filter_compiledir(str(compiledir)) == compiledir + + def test_invalid_default(): # Ensure an invalid default value found in the PyTensor code only causes # a crash if it is not overridden by the user.