Skip to content

Commit

Permalink
support cli overrides for loading graphrag config
Browse files Browse the repository at this point in the history
  • Loading branch information
dworthen committed Jan 17, 2025
1 parent 416074d commit d355c02
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 98 deletions.
20 changes: 10 additions & 10 deletions graphrag/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def index_cli(
output_dir: Path | None,
):
"""Run the pipeline with the given config."""
config = load_config(root_dir, config_filepath)
cli_overrides = {}
if output_dir:
cli_overrides["storage.base_dir"] = str(output_dir)
cli_overrides["reporting.base_dir"] = str(output_dir)
config = load_config(root_dir, config_filepath, cli_overrides)

_run_index(
config=config,
Expand All @@ -86,7 +90,6 @@ def index_cli(
logger=logger,
dry_run=dry_run,
skip_validation=skip_validation,
output_dir=output_dir,
)


Expand All @@ -101,7 +104,11 @@ def update_cli(
output_dir: Path | None,
):
"""Run the pipeline with the given config."""
config = load_config(root_dir, config_filepath)
cli_overrides = {}
if output_dir:
cli_overrides["storage.base_dir"] = str(output_dir)
cli_overrides["reporting.base_dir"] = str(output_dir)
config = load_config(root_dir, config_filepath, cli_overrides)

# Check if update storage exist, if not configure it with default values
if not config.update_index_storage:
Expand All @@ -122,7 +129,6 @@ def update_cli(
logger=logger,
dry_run=False,
skip_validation=skip_validation,
output_dir=output_dir,
)


Expand All @@ -135,17 +141,11 @@ def _run_index(
logger,
dry_run,
skip_validation,
output_dir,
):
progress_logger = LoggerFactory().create_logger(logger)
info, error, success = _logger(progress_logger)
run_id = resume or time.strftime("%Y%m%d-%H%M%S")

config.storage.base_dir = str(output_dir) if output_dir else config.storage.base_dir
config.reporting.base_dir = (
str(output_dir) if output_dir else config.reporting.base_dir
)

if not cache:
config.cache.type = CacheType.none

Expand Down
24 changes: 16 additions & 8 deletions graphrag/cli/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ def run_global_search(
Loads index files required for global search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
cli_overrides = {}
if data_dir:
cli_overrides["storage.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)

dataframe_dict = _resolve_output_files(
config=config,
Expand Down Expand Up @@ -117,8 +119,10 @@ def run_local_search(
Loads index files required for local search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
cli_overrides = {}
if data_dir:
cli_overrides["storage.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)

dataframe_dict = _resolve_output_files(
config=config,
Expand Down Expand Up @@ -207,8 +211,10 @@ def run_drift_search(
Loads index files required for local search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
cli_overrides = {}
if data_dir:
cli_overrides["storage.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)

dataframe_dict = _resolve_output_files(
config=config,
Expand Down Expand Up @@ -291,8 +297,10 @@ def run_basic_search(
Loads index files required for basic search and calls the Query API.
"""
root = root_dir.resolve()
config = load_config(root, config_filepath)
config.storage.base_dir = str(data_dir) if data_dir else config.storage.base_dir
cli_overrides = {}
if data_dir:
cli_overrides["storage.base_dir"] = str(data_dir)
config = load_config(root, config_filepath, cli_overrides)

dataframe_dict = _resolve_output_files(
config=config,
Expand Down
120 changes: 40 additions & 80 deletions graphrag/config/load_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from string import Template
from typing import Any
Expand All @@ -19,83 +18,6 @@
_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]


class _ConfigTextParser(ABC):
"""Abstract base class for parsing configuration text."""

@abstractmethod
def parse(self, text: str) -> dict[str, Any]:
"""Parse configuration text."""
raise NotImplementedError


class _ConfigYamlParser(_ConfigTextParser):
    """Configuration parser for YAML text."""

    def parse(self, text: str) -> dict[str, Any]:
        """Parse yaml configuration text.

        Parameters
        ----------
        text : str
            The yaml configuration text.

        Returns
        -------
        dict[str, Any]
            The parsed configuration.
        """
        # safe_load is used so the text is never parsed with the full loader.
        parsed = yaml.safe_load(text)
        return parsed


class _ConfigJsonParser(_ConfigTextParser):
    """Configuration parser for JSON text."""

    def parse(self, text: str) -> dict[str, Any]:
        """Parse json configuration text.

        Parameters
        ----------
        text : str
            The json configuration text.

        Returns
        -------
        dict[str, Any]
            The parsed configuration.
        """
        parsed = json.loads(text)
        return parsed


def _get_config_parser(file_extension: str) -> _ConfigTextParser:
    """Select the configuration parser that matches a file extension.

    Parameters
    ----------
    file_extension : str
        The file extension, including the leading dot
        (".yaml", ".yml" or ".json").

    Returns
    -------
    _ConfigTextParser
        A parser instance for the given extension.

    Raises
    ------
    ValueError
        If the file extension is not supported.
    """
    if file_extension in (".yaml", ".yml"):
        return _ConfigYamlParser()
    if file_extension == ".json":
        return _ConfigJsonParser()

    # Anything else is a format this loader does not understand.
    msg = f"Unable to parse config. Unsupported file extension: {file_extension}"
    raise ValueError(msg)


def _search_for_config_in_root_dir(root: str | Path) -> Path | None:
"""Resolve the config path from the given root directory.
Expand Down Expand Up @@ -190,9 +112,41 @@ def _get_config_path(root_dir: Path, config_filepath: Path | None) -> Path:
return config_path


def _apply_overrides(data: dict[str, Any], overrides: dict[str, Any]) -> None:
"""Apply the overrides to the raw configuration."""
for key, value in overrides.items():
keys = key.split(".")
target = data
current_path = keys[0]
for k in keys[:-1]:
current_path += f".{k}"
target_obj = target.get(k, {})
if not isinstance(target_obj, dict):
msg = f"Cannot override non-dict value: data[{current_path}] is not a dict."
raise TypeError(msg)
target[k] = target_obj
target = target[k]
target[keys[-1]] = value


def _parse(file_extension: str, contents: str) -> dict[str, Any]:
"""Parse configuration."""
match file_extension:
case ".yaml" | ".yml":
return yaml.safe_load(contents)
case ".json":
return json.loads(contents)
case _:
msg = (
f"Unable to parse config. Unsupported file extension: {file_extension}"
)
raise ValueError(msg)


def load_config(
root_dir: Path,
config_filepath: Path | None = None,
cli_overrides: dict[str, Any] | None = None,
) -> GraphRagConfig:
"""Load configuration from a file.
Expand All @@ -203,6 +157,9 @@ def load_config(
config_filepath : str | None
The path to the config file.
If None, searches for config file in root.
cli_overrides : dict[str, Any] | None
A flat dictionary of cli overrides.
Example: {'storage.base_dir': 'override_value'}
Returns
-------
Expand All @@ -215,6 +172,8 @@ def load_config(
If the config file is not found.
ValueError
If the config file extension is not supported.
TypeError
If applying cli overrides to the config fails.
KeyError
If config file references a non-existent environment variable.
ValidationError
Expand All @@ -224,8 +183,9 @@ def load_config(
config_path = _get_config_path(root, config_filepath)
_load_dotenv(config_path)
config_extension = config_path.suffix
config_parser = _get_config_parser(config_extension)
config_text = config_path.read_text(encoding="utf-8")
config_text = _parse_env_variables(config_text)
config_data = config_parser.parse(config_text)
config_data = _parse(config_extension, config_text)
if cli_overrides:
_apply_overrides(config_data, cli_overrides)
return create_graphrag_config(config_data, root_dir=str(root))
14 changes: 14 additions & 0 deletions tests/unit/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,20 @@ def test_load_minimal_config() -> None:
assert_graphrag_configs(actual, expected)


@mock.patch.dict(os.environ, {"CUSTOM_API_KEY": FAKE_API_KEY}, clear=True)
def test_load_config_with_cli_overrides() -> None:
    """Check that a ``storage.base_dir`` CLI override is reflected in the loaded config."""
    fixture_root = (Path(__file__).parent / "fixtures" / "minimal_config").resolve()
    override_dir = "some_output_dir"

    # Expected config: the defaults for this root, with the storage base dir
    # set to the override path joined onto the root directory.
    wanted = get_default_graphrag_config(str(fixture_root))
    wanted.storage.base_dir = str(fixture_root / override_dir)

    loaded = load_config(
        root_dir=fixture_root,
        cli_overrides={"storage.base_dir": override_dir},
    )

    assert_graphrag_configs(loaded, wanted)


def test_load_config_missing_env_vars() -> None:
cwd = Path(__file__).parent
root_dir = (cwd / "fixtures" / "minimal_config_missing_env_var").resolve()
Expand Down

0 comments on commit d355c02

Please sign in to comment.