Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ To initialize your workspace, first run the `graphrag init` command.
graphrag init
```

When prompted, specify the default chat and embedding models you would like to use in your config.

This will create two files, `.env` and `settings.yaml`, and a directory `input`, in the current directory.

- `input` Location of text files to process with `graphrag`.
Expand Down
4 changes: 0 additions & 4 deletions packages/graphrag/graphrag/api/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ async def build_index(
config: GraphRagConfig,
method: IndexingMethod | str = IndexingMethod.Standard,
is_update_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
additional_context: dict[str, Any] | None = None,
verbose: bool = False,
Expand Down Expand Up @@ -67,9 +66,6 @@ async def build_index(

outputs: list[PipelineRunResult] = []

if memory_profile:
logger.warning("New pipeline does not yet support memory profiling.")

logger.info("Initializing indexing pipeline...")
# todo: this could propagate out to the cli for better clarity, but will be a breaking api change
method = _get_method(method, is_update_run)
Expand Down
6 changes: 0 additions & 6 deletions packages/graphrag/graphrag/cli/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def index_cli(
root_dir: Path,
method: IndexingMethod,
verbose: bool,
memprofile: bool,
cache: bool,
dry_run: bool,
skip_validation: bool,
Expand All @@ -55,7 +54,6 @@ def index_cli(
method=method,
is_update_run=False,
verbose=verbose,
memprofile=memprofile,
cache=cache,
dry_run=dry_run,
skip_validation=skip_validation,
Expand All @@ -66,7 +64,6 @@ def update_cli(
root_dir: Path,
method: IndexingMethod,
verbose: bool,
memprofile: bool,
cache: bool,
skip_validation: bool,
):
Expand All @@ -80,7 +77,6 @@ def update_cli(
method=method,
is_update_run=True,
verbose=verbose,
memprofile=memprofile,
cache=cache,
dry_run=False,
skip_validation=skip_validation,
Expand All @@ -92,7 +88,6 @@ def _run_index(
method,
is_update_run,
verbose,
memprofile,
cache,
dry_run,
skip_validation,
Expand Down Expand Up @@ -129,7 +124,6 @@ def _run_index(
config=config,
method=method,
is_update_run=is_update_run,
memory_profile=memprofile,
callbacks=[ConsoleWorkflowCallbacks(verbose=verbose)],
verbose=verbose,
)
Expand Down
11 changes: 8 additions & 3 deletions packages/graphrag/graphrag/cli/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
logger = logging.getLogger(__name__)


def initialize_project_at(path: Path, force: bool) -> None:
def initialize_project_at(
path: Path, force: bool, model: str, embedding_model: str
) -> None:
"""
Initialize the project at the given path.

Expand Down Expand Up @@ -64,8 +66,11 @@ def initialize_project_at(path: Path, force: bool) -> None:
root / (graphrag_config_defaults.input.storage.base_dir or "input")
).resolve()
input_path.mkdir(parents=True, exist_ok=True)

settings_yaml.write_text(INIT_YAML, encoding="utf-8", errors="strict")
# using replace with custom tokens instead of format here because we have a placeholder for GRAPHRAG_API_KEY that is used later for .env overlay
formatted = INIT_YAML.replace("<DEFAULT_CHAT_MODEL>", model).replace(
"<DEFAULT_EMBEDDING_MODEL>", embedding_model
)
settings_yaml.write_text(formatted, encoding="utf-8", errors="strict")

dotenv = root / ".env"
if not dotenv.exists() or force:
Expand Down
34 changes: 20 additions & 14 deletions packages/graphrag/graphrag/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

import typer

from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.defaults import (
DEFAULT_CHAT_MODEL,
DEFAULT_EMBEDDING_MODEL,
graphrag_config_defaults,
)
from graphrag.config.enums import IndexingMethod, SearchMethod
from graphrag.prompt_tune.defaults import LIMIT, MAX_TOKEN_COUNT, N_SUBSET_MAX, K
from graphrag.prompt_tune.types import DocSelectionType
Expand Down Expand Up @@ -104,6 +108,18 @@ def _initialize_cli(
resolve_path=True,
autocompletion=ROOT_AUTOCOMPLETE,
),
model: str = typer.Option(
DEFAULT_CHAT_MODEL,
"--model",
"-m",
prompt="Specify the default chat model to use",
),
embedding_model: str = typer.Option(
DEFAULT_EMBEDDING_MODEL,
"--embedding",
"-e",
prompt="Specify the default embedding model to use",
),
force: bool = typer.Option(
False,
"--force",
Expand All @@ -114,7 +130,9 @@ def _initialize_cli(
"""Generate a default configuration file."""
from graphrag.cli.initialize import initialize_project_at

initialize_project_at(path=root, force=force)
initialize_project_at(
path=root, force=force, model=model, embedding_model=embedding_model
)


@app.command("index")
Expand Down Expand Up @@ -143,11 +161,6 @@ def _index_cli(
"-v",
help="Run the indexing pipeline with verbose logging",
),
memprofile: bool = typer.Option(
False,
"--memprofile",
help="Run the indexing pipeline with memory profiling",
),
dry_run: bool = typer.Option(
False,
"--dry-run",
Expand All @@ -173,7 +186,6 @@ def _index_cli(
index_cli(
root_dir=root,
verbose=verbose,
memprofile=memprofile,
cache=cache,
dry_run=dry_run,
skip_validation=skip_validation,
Expand Down Expand Up @@ -207,11 +219,6 @@ def _update_cli(
"-v",
help="Run the indexing pipeline with verbose logging.",
),
memprofile: bool = typer.Option(
False,
"--memprofile",
help="Run the indexing pipeline with memory profiling.",
),
cache: bool = typer.Option(
True,
"--cache/--no-cache",
Expand All @@ -233,7 +240,6 @@ def _update_cli(
update_cli(
root_dir=root,
verbose=verbose,
memprofile=memprofile,
cache=cache,
skip_validation=skip_validation,
method=method,
Expand Down
2 changes: 1 addition & 1 deletion packages/graphrag/graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class LanguageModelDefaults:
n: int = 1
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
request_timeout: float = 180.0
request_timeout: float = 600.0
api_base: None = None
api_version: None = None
deployment_name: None = None
Expand Down
9 changes: 0 additions & 9 deletions packages/graphrag/graphrag/config/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,6 @@ def __init__(self, llm_type: str) -> None:
super().__init__(msg)


class LanguageModelConfigMissingError(ValueError):
"""Missing model configuration error."""

def __init__(self, key: str = "") -> None:
"""Init method definition."""
msg = f'A {key} model configuration is required. Please rerun `graphrag init` and set models["{key}"] in settings.yaml.'
super().__init__(msg)


class ConflictingSettingsError(ValueError):
"""Missing model configuration error."""

Expand Down
4 changes: 2 additions & 2 deletions packages/graphrag/graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
auth_type: {defs.DEFAULT_CHAT_MODEL_AUTH_TYPE.value} # or azure_managed_identity
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file, or remove if managed identity
model: {defs.DEFAULT_CHAT_MODEL}
model: <DEFAULT_CHAT_MODEL>
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
model_supports_json: true # recommended if this is available for your model.
Expand All @@ -37,7 +37,7 @@
model_provider: {defs.DEFAULT_MODEL_PROVIDER}
auth_type: {defs.DEFAULT_EMBEDDING_MODEL_AUTH_TYPE.value}
api_key: ${{GRAPHRAG_API_KEY}}
model: {defs.DEFAULT_EMBEDDING_MODEL}
model: <DEFAULT_EMBEDDING_MODEL>
# api_base: https://<instance>.openai.azure.com
# api_version: 2024-05-01-preview
concurrent_requests: {language_model_defaults.concurrent_requests}
Expand Down
20 changes: 0 additions & 20 deletions packages/graphrag/graphrag/config/models/graph_rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import VectorStoreType
from graphrag.config.errors import LanguageModelConfigMissingError
from graphrag.config.models.basic_search_config import BasicSearchConfig
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.chunking_config import ChunkingConfig
Expand Down Expand Up @@ -58,24 +57,6 @@ def __str__(self):
default=graphrag_config_defaults.models,
)

def _validate_models(self) -> None:
"""Validate the models configuration.

Ensure both a default chat model and default embedding model
have been defined. Other models may also be defined but
defaults are required for the time being as places of the
code fallback to default model configs instead
of specifying a specific model.

TODO: Don't fallback to default models elsewhere in the code.
Forcing code to specify a model to use and allowing for any
names for model configurations.
"""
if defs.DEFAULT_CHAT_MODEL_ID not in self.models:
raise LanguageModelConfigMissingError(defs.DEFAULT_CHAT_MODEL_ID)
if defs.DEFAULT_EMBEDDING_MODEL_ID not in self.models:
raise LanguageModelConfigMissingError(defs.DEFAULT_EMBEDDING_MODEL_ID)

def _validate_retry_services(self) -> None:
"""Validate the retry services configuration."""
retry_factory = RetryFactory()
Expand Down Expand Up @@ -329,7 +310,6 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig:
@model_validator(mode="after")
def _validate_model(self):
"""Validate the model configuration."""
self._validate_models()
self._validate_input_pattern()
self._validate_input_base_dir()
self._validate_reporting_base_dir()
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"period",
"size"
],
"max_runtime": 1200,
"max_runtime": 2000,
"expected_artifacts": ["community_reports.parquet"]
},
"create_final_text_units": {
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"period",
"size"
],
"max_runtime": 1200,
"max_runtime": 2000,
"expected_artifacts": ["community_reports.parquet"]
},
"create_final_text_units": {
Expand Down