4 changes: 4 additions & 0 deletions .semversioner/next-release/major-20250909010205372690.json
@@ -0,0 +1,4 @@
{
"type": "major",
"description": "Remove document filtering option."
}
1 change: 0 additions & 1 deletion docs/config/yaml.md
@@ -87,7 +87,6 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th
- `file_type` **text|csv|json** - The type of input data to load. Default is `text`
- `encoding` **str** - The encoding of the input file. Default is `utf-8`
- `file_pattern` **str** - A regex to match input files. Default is `.*\.csv$`, `.*\.txt$`, or `.*\.json$` depending on the specified `file_type`, but you can customize it if needed.
- `file_filter` **dict** - Key/value pairs to filter. Default is None.
- `text_column` **str** - (CSV/JSON only) The text column name. If unset we expect a column named `text`.
- `title_column` **str** - (CSV/JSON only) The title column name, filename will be used if unset.
- `metadata` **list[str]** - (CSV/JSON only) The additional document attributes fields to keep.
1 change: 0 additions & 1 deletion graphrag/config/defaults.py
@@ -262,7 +262,6 @@ class InputDefaults:
file_type: ClassVar[InputFileType] = InputFileType.text
encoding: str = "utf-8"
file_pattern: str = ""
file_filter: None = None
text_column: str = "text"
title_column: None = None
metadata: None = None
4 changes: 0 additions & 4 deletions graphrag/config/models/input_config.py
@@ -32,10 +32,6 @@ class InputConfig(BaseModel):
description="The input file pattern to use.",
default=graphrag_config_defaults.input.file_pattern,
)
file_filter: dict[str, str] | None = Field(
description="The optional file filter for the input files.",
default=graphrag_config_defaults.input.file_filter,
)
text_column: str = Field(
description="The input text column to use.",
default=graphrag_config_defaults.input.text_column,
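For illustration, a minimal standalone sketch (not the actual graphrag model) of what the input config reduces to after this change. The class name and hard-coded defaults are placeholders; the real model pulls its defaults from `graphrag_config_defaults`.

```python
from pydantic import BaseModel, Field


class InputConfigSketch(BaseModel):
    """Toy stand-in for the slimmed-down InputConfig: no file_filter field."""

    # Hard-coded defaults here; graphrag reads these from graphrag_config_defaults.
    file_pattern: str = Field(
        description="The input file pattern to use.",
        default=r".*\.txt$",
    )
    text_column: str = Field(
        description="The input text column to use.",
        default="text",
    )


# With Pydantic's default "ignore extra" behavior, a leftover file_filter=
# keyword is silently dropped in this sketch (graphrag's own settings loader
# may behave differently).
config = InputConfigSketch(file_filter={"source": "news"})
print(config.model_dump())
```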
11 changes: 1 addition & 10 deletions graphrag/index/input/csv.py
@@ -22,19 +22,10 @@ async def load_csv(
"""Load csv inputs from a directory."""
logger.info("Loading csv files from %s", config.storage.base_dir)

async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:
group = {}
async def load_file(path: str) -> pd.DataFrame:
buffer = BytesIO(await storage.get(path, as_bytes=True))
data = pd.read_csv(buffer, encoding=config.encoding)
additional_keys = group.keys()
if len(additional_keys) > 0:
data[[*additional_keys]] = data.apply(
lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
)

data = process_data_columns(data, config, path)

creation_date = await storage.get_creation_date(path)
data["creation_date"] = data.apply(lambda _: creation_date, axis=1)

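A rough sketch of what the csv loader body boils down to now: read the bytes straight into a frame, with no per-group columns injected. The function name, the bytes argument (standing in for the storage read), and the sample data are all illustrative.

```python
from io import BytesIO

import pandas as pd


def load_csv_rows(raw: bytes, encoding: str = "utf-8") -> pd.DataFrame:
    """Read the csv directly; no extra columns are derived from regex groups anymore."""
    return pd.read_csv(BytesIO(raw), encoding=encoding)


print(load_csv_rows(b"text,title\nMerry Christmas!,christmas\n"))
```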
12 changes: 1 addition & 11 deletions graphrag/index/input/json.py
@@ -22,23 +22,13 @@ async def load_json(
"""Load json inputs from a directory."""
logger.info("Loading json files from %s", config.storage.base_dir)

async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:
group = {}
async def load_file(path: str) -> pd.DataFrame:
text = await storage.get(path, encoding=config.encoding)
as_json = json.loads(text)
# json file could just be a single object, or an array of objects
rows = as_json if isinstance(as_json, list) else [as_json]
data = pd.DataFrame(rows)

additional_keys = group.keys()
if len(additional_keys) > 0:
data[[*additional_keys]] = data.apply(
lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
)

data = process_data_columns(data, config, path)

creation_date = await storage.get_creation_date(path)
data["creation_date"] = data.apply(lambda _: creation_date, axis=1)

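Similarly for json, a standalone sketch of the single-object/array handling that remains once the group-column injection is gone (the helper name and inline samples are illustrative).

```python
import json

import pandas as pd


def load_json_rows(text: str) -> pd.DataFrame:
    """A json file may be a single object or an array of objects; both become rows."""
    as_json = json.loads(text)
    rows = as_json if isinstance(as_json, list) else [as_json]
    return pd.DataFrame(rows)


print(load_json_rows('{"title": "christmas", "text": "Merry Christmas!"}'))
print(load_json_rows('[{"text": "a"}, {"text": "b"}]'))
```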
6 changes: 2 additions & 4 deletions graphrag/index/input/text.py
@@ -22,11 +22,9 @@ async def load_text(
) -> pd.DataFrame:
"""Load text inputs from a directory."""

async def load_file(path: str, group: dict | None = None) -> pd.DataFrame:
if group is None:
group = {}
async def load_file(path: str) -> pd.DataFrame:
text = await storage.get(path, encoding=config.encoding)
new_item = {**group, "text": text}
new_item = {"text": text}
new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
new_item["title"] = str(Path(path).name)
new_item["creation_date"] = await storage.get_creation_date(path)
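A standalone sketch of the row the simplified text loader now builds per file. Here `hashlib` stands in for graphrag's `gen_sha512_hash` helper, and the text and creation date are passed as plain arguments instead of being fetched from storage.

```python
import hashlib
from pathlib import Path


def build_text_row(path: str, text: str, creation_date: str) -> dict:
    """One row per file: the text, a content-hash id, the filename, and a creation date."""
    new_item = {"text": text}
    # hashlib stands in for gen_sha512_hash; no merged regex groups feed the id anymore.
    new_item["id"] = hashlib.sha512(text.encode("utf-8")).hexdigest()
    new_item["title"] = str(Path(path).name)
    new_item["creation_date"] = creation_date
    return new_item


print(build_text_row("input/christmas.txt", "Merry Christmas!", "2025-09-09"))
```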
11 changes: 3 additions & 8 deletions graphrag/index/input/util.py
@@ -22,22 +22,17 @@ async def load_files(
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load files from storage and apply a loader function."""
files = list(
storage.find(
re.compile(config.file_pattern),
file_filter=config.file_filter,
)
)
files = list(storage.find(re.compile(config.file_pattern)))

if len(files) == 0:
msg = f"No {config.file_type} files found in {config.storage.base_dir}"
raise ValueError(msg)

files_loaded = []

for file, group in files:
for file in files:
try:
files_loaded.append(await loader(file, group))
files_loaded.append(await loader(file))
except Exception as e: # noqa: BLE001 (catching Exception is fine here)
logger.warning("Warning! Error loading file %s. Skipping...", file)
logger.warning("Error: %s", e)
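The load/skip loop itself is unchanged apart from the signatures: `find()` yields plain paths and the loader takes only a path. A self-contained sketch with a dummy loader and finder (all names here are illustrative, not graphrag APIs):

```python
import asyncio
import logging
import re
from collections.abc import Awaitable, Callable, Iterator

import pandas as pd

logger = logging.getLogger(__name__)


async def load_files_sketch(
    loader: Callable[[str], Awaitable[pd.DataFrame]],
    find: Callable[[re.Pattern[str]], Iterator[str]],
    file_pattern: str,
) -> pd.DataFrame:
    """find() yields plain paths; the loader takes only a path."""
    files = list(find(re.compile(file_pattern)))
    if len(files) == 0:
        msg = "No input files found"
        raise ValueError(msg)

    files_loaded = []
    for file in files:
        try:
            files_loaded.append(await loader(file))
        except Exception as e:  # noqa: BLE001
            logger.warning("Error loading file %s. Skipping... (%s)", file, e)
    return pd.concat(files_loaded, ignore_index=True)


async def _demo() -> None:
    async def loader(path: str) -> pd.DataFrame:
        return pd.DataFrame({"path": [path]})

    def find(pattern: re.Pattern[str]) -> Iterator[str]:
        return iter(p for p in ["a.txt", "b.txt"] if pattern.search(p))

    print(await load_files_sketch(loader, find, r".*\.txt$"))


asyncio.run(_demo())
```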
29 changes: 7 additions & 22 deletions graphrag/storage/blob_pipeline_storage.py
@@ -101,15 +101,13 @@ def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find blobs in a container using a file pattern, as well as a custom filter function.
) -> Iterator[str]:
"""Find blobs in a container using a file pattern.

Params:
base_dir: The name of the base container.
file_pattern: The file pattern to use.
file_filter: A dictionary of key-value pairs to filter the blobs.
max_count: The maximum number of blobs to return. If -1, all blobs are returned.

Returns
@@ -131,14 +129,6 @@ def _blobname(blob_name: str) -> str:
blob_name = blob_name[1:]
return blob_name

def item_filter(item: dict[str, Any]) -> bool:
if file_filter is None:
return True

return all(
re.search(value, item[key]) for key, value in file_filter.items()
)

try:
container_client = self._blob_service_client.get_container_client(
self._container_name
@@ -151,14 +141,10 @@ def item_filter(item: dict[str, Any]) -> bool:
for blob in all_blobs:
match = file_pattern.search(blob.name)
if match and blob.name.startswith(base_dir):
group = match.groupdict()
if item_filter(group):
yield (_blobname(blob.name), group)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
yield _blobname(blob.name)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
logger.debug(
@@ -169,10 +155,9 @@ def item_filter(item: dict[str, Any]) -> bool:
)
except Exception: # noqa: BLE001
logger.warning(
"Error finding blobs: base_dir=%s, file_pattern=%s, file_filter=%s",
"Error finding blobs: base_dir=%s, file_pattern=%s",
base_dir,
file_pattern,
file_filter,
)

async def get(
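For the blob backend, a standalone sketch of the simplified matching loop, with a plain list of names standing in for the Azure container client; the `groupdict()`/`item_filter` branch is simply gone.

```python
import re
from collections.abc import Iterator


def find_blobs_sketch(
    blob_names: list[str],
    file_pattern: re.Pattern[str],
    base_dir: str = "",
    max_count: int = -1,
) -> Iterator[str]:
    """Yield matching blob names only; no per-blob key/value filtering."""
    num_loaded = 0
    for name in blob_names:
        if file_pattern.search(name) and name.startswith(base_dir):
            yield name[1:] if name.startswith("/") else name
            num_loaded += 1
            if 0 < max_count <= num_loaded:
                break


names = ["input/christmas.txt", "input/notes.md", "test.txt"]
print(list(find_blobs_sketch(names, re.compile(r".*\.txt$"), base_dir="input")))
```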
31 changes: 7 additions & 24 deletions graphrag/storage/cosmosdb_pipeline_storage.py
@@ -121,15 +121,13 @@ def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find documents in a Cosmos DB container using a file pattern regex and custom file filter (optional).
) -> Iterator[str]:
"""Find documents in a Cosmos DB container using a file pattern regex.

Params:
base_dir: The name of the base directory (not used in Cosmos DB context).
file_pattern: The file pattern to use.
file_filter: A dictionary of key-value pairs to filter the documents.
max_count: The maximum number of documents to return. If -1, all documents are returned.

Returns
@@ -145,23 +143,12 @@ def find(
if not self._database_client or not self._container_client:
return

def item_filter(item: dict[str, Any]) -> bool:
if file_filter is None:
return True
return all(
re.search(value, item.get(key, ""))
for key, value in file_filter.items()
)

try:
query = "SELECT * FROM c WHERE RegexMatch(c.id, @pattern)"
parameters: list[dict[str, Any]] = [
{"name": "@pattern", "value": file_pattern.pattern}
]
if file_filter:
for key, value in file_filter.items():
query += f" AND c.{key} = @{key}"
parameters.append({"name": f"@{key}", "value": value})

items = list(
self._container_client.query_items(
query=query,
@@ -177,14 +164,10 @@ def item_filter(item: dict[str, Any]) -> bool:
for item in items:
match = file_pattern.search(item["id"])
if match:
group = match.groupdict()
if item_filter(group):
yield (item["id"], group)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
yield item["id"]
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1

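For Cosmos DB, the query construction is now fixed. A small sketch (helper name is illustrative) of what gets sent, with no `AND c.<key> = @<key>` clauses appended from a filter dict:

```python
import re
from typing import Any


def build_cosmos_query(file_pattern: re.Pattern[str]) -> tuple[str, list[dict[str, Any]]]:
    """Build the fixed RegexMatch query; nothing is appended per filter key anymore."""
    query = "SELECT * FROM c WHERE RegexMatch(c.id, @pattern)"
    parameters: list[dict[str, Any]] = [
        {"name": "@pattern", "value": file_pattern.pattern}
    ]
    return query, parameters


print(build_cosmos_query(re.compile(r".*\.json$")))
```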
31 changes: 9 additions & 22 deletions graphrag/storage/file_pipeline_storage.py
@@ -41,18 +41,9 @@ def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find files in the storage using a file pattern, as well as a custom filter function."""

def item_filter(item: dict[str, Any]) -> bool:
if file_filter is None:
return True
return all(
re.search(value, item[key]) for key, value in file_filter.items()
)

) -> Iterator[str]:
"""Find files in the storage using a file pattern."""
search_path = Path(self._root_dir) / (base_dir or "")
logger.info(
"search %s for files matching %s", search_path, file_pattern.pattern
@@ -64,17 +55,13 @@ def item_filter(item: dict[str, Any]) -> bool:
for file in all_files:
match = file_pattern.search(f"{file}")
if match:
group = match.groupdict()
if item_filter(group):
filename = f"{file}".replace(self._root_dir, "")
if filename.startswith(os.sep):
filename = filename[1:]
yield (filename, group)
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
filename = f"{file}".replace(self._root_dir, "")
if filename.startswith(os.sep):
filename = filename[1:]
yield filename
num_loaded += 1
if max_count > 0 and num_loaded >= max_count:
break
else:
num_filtered += 1
logger.debug(
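A standalone sketch of the new single-string `find` contract for the file backend, with `rglob` standing in for the real traversal helper (which this diff doesn't show):

```python
import os
import re
from collections.abc import Iterator
from pathlib import Path


def find_files_sketch(
    root_dir: str,
    file_pattern: re.Pattern[str],
    base_dir: str | None = None,
    max_count: int = -1,
) -> Iterator[str]:
    """Yield root-relative paths matching the pattern; no groupdict, no item_filter."""
    search_path = Path(root_dir) / (base_dir or "")
    num_loaded = 0
    for file in search_path.rglob("*"):  # rglob stands in for the real traversal
        if file_pattern.search(f"{file}"):
            filename = f"{file}".replace(root_dir, "")
            if filename.startswith(os.sep):
                filename = filename[1:]
            yield filename
            num_loaded += 1
            if 0 < max_count <= num_loaded:
                break


# Example: list(find_files_sketch("/tmp/data", re.compile(r".*\.txt$"), base_dir="input"))
```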
5 changes: 2 additions & 3 deletions graphrag/storage/pipeline_storage.py
@@ -18,10 +18,9 @@ def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
"""Find files in the storage using a file pattern, as well as a custom filter function."""
) -> Iterator[str]:
"""Find files in the storage using a file pattern."""

@abstractmethod
async def get(
7 changes: 0 additions & 7 deletions tests/integration/storage/test_blob_pipeline_storage.py
@@ -21,7 +21,6 @@ async def test_find():
items = list(
storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$"))
)
items = [item[0] for item in items]
assert items == []

await storage.set(
@@ -30,12 +29,10 @@
items = list(
storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$"))
)
items = [item[0] for item in items]
assert items == ["input/christmas.txt"]

await storage.set("test.txt", "Hello, World!", encoding="utf-8")
items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
items = [item[0] for item in items]
assert items == ["input/christmas.txt", "test.txt"]

output = await storage.get("test.txt")
@@ -57,7 +54,6 @@ async def test_dotprefix():
try:
await storage.set("input/christmas.txt", "Merry Christmas!", encoding="utf-8")
items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
items = [item[0] for item in items]
assert items == ["input/christmas.txt"]
finally:
storage._delete_container() # noqa: SLF001
@@ -91,20 +87,17 @@ async def test_child():
storage = parent.child("input")
await storage.set("christmas.txt", "Merry Christmas!", encoding="utf-8")
items = list(storage.find(re.compile(r".*\.txt$")))
items = [item[0] for item in items]
assert items == ["christmas.txt"]

await storage.set("test.txt", "Hello, World!", encoding="utf-8")
items = list(storage.find(re.compile(r".*\.txt$")))
items = [item[0] for item in items]
print("FOUND", items)
assert items == ["christmas.txt", "test.txt"]

output = await storage.get("test.txt")
assert output == "Hello, World!"

items = list(parent.find(re.compile(r".*\.txt$")))
items = [item[0] for item in items]
print("FOUND ITEMS", items)
assert items == ["input/christmas.txt", "input/test.txt"]
finally:
3 changes: 0 additions & 3 deletions tests/integration/storage/test_cosmosdb_storage.py
@@ -30,7 +30,6 @@ async def test_find():
try:
try:
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == []

json_content = {
@@ -40,15 +39,13 @@
"christmas.json", json.dumps(json_content), encoding="utf-8"
)
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == ["christmas.json"]

json_content = {
"content": "Hello, World!",
}
await storage.set("test.json", json.dumps(json_content), encoding="utf-8")
items = list(storage.find(file_pattern=re.compile(r".*\.json$")))
items = [item[0] for item in items]
assert items == ["christmas.json", "test.json"]

output = await storage.get("test.json")
3 changes: 1 addition & 2 deletions tests/integration/storage/test_factory.py
@@ -117,9 +117,8 @@ def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
) -> Iterator[str]:
return iter([])

async def get(
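For custom storage implementations registered through the factory, the contract to satisfy is now an `Iterator[str]`. A toy in-memory backend sketch (names illustrative) showing the shape callers rely on:

```python
import re
from collections.abc import Iterator


class InMemoryStorageSketch:
    """Toy storage demonstrating the new find contract a custom backend implements."""

    def __init__(self) -> None:
        self._files: dict[str, str] = {}

    def set(self, key: str, value: str) -> None:
        self._files[key] = value

    def find(
        self,
        file_pattern: re.Pattern[str],
        base_dir: str | None = None,
        max_count: int = -1,
    ) -> Iterator[str]:
        # Yield plain keys -- no (key, group) tuples, no file_filter parameter.
        num_loaded = 0
        for key in self._files:
            if file_pattern.search(key):
                yield key
                num_loaded += 1
                if 0 < max_count <= num_loaded:
                    break


storage = InMemoryStorageSketch()
storage.set("input/christmas.txt", "Merry Christmas!")
print(list(storage.find(re.compile(r".*\.txt$"))))
```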