diff --git a/CHANGELOG.md b/CHANGELOG.md index ca44a339..b41897d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file. ## [Unreleased] +- fixes a bug where specifying both relative and absolute paths would cause sqlfmt to crash ([#426](https://github.com/tconbeer/sqlfmt/issues/426) - thank you for the issue and fix, [@smcgivern](https://github.com/smcgivern)!) + ## [0.18.1] - 2023-05-10 - fixes a bug when lexing `union distinct` tokens ([#417](https://github.com/tconbeer/sqlfmt/issues/417) - thank you, [@paschmaria](https://github.com/paschmaria)!) diff --git a/src/sqlfmt/config.py b/src/sqlfmt/config.py index bdeffcd1..e73e1a53 100644 --- a/src/sqlfmt/config.py +++ b/src/sqlfmt/config.py @@ -34,7 +34,7 @@ def _get_common_parents(files: List[Path]) -> List[Path]: assert files, "Must provide a list of paths" common_parents: Set[Path] = set() for p in files: - parents = set(p.parents) + parents = set(p.absolute().parents) if p.is_dir(): parents.add(p) if not common_parents: diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index bb0ae2b9..8caabb41 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from typing import Any, List @@ -63,6 +64,29 @@ def test_find_config_file_not_in_tree( assert config_path is None +def test_find_config_file_relative_and_absolute( + tmp_path: Path, files_relpath: List[Path] +) -> None: + # Only check the cases where we are providing more than one path + if len(files_relpath) == 1: + return + + current_dir = os.getcwd() + copy_config_file_to_dst("valid_sqlfmt_config.toml", tmp_path) + + try: + os.chdir(tmp_path) + + files = [tmp_path / files_relpath[0], files_relpath[1]] + search_paths = _get_common_parents(files) + assert tmp_path in search_paths + config_path = _find_config_file(search_paths) + assert config_path + assert config_path == tmp_path / "pyproject.toml" + finally: + os.chdir(current_dir) + + def test_load_config_from_path(tmp_path: Path) -> None: copy_config_file_to_dst("valid_sqlfmt_config.toml", tmp_path) config = _load_config_from_path(tmp_path / "pyproject.toml")