Add load dataset #253

Merged
merged 3 commits on Dec 10, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -64,6 +64,7 @@ For other use cases, first [Install pgai](#installation) in Timescale Cloud, a p
* [Cohere](./docs/cohere.md) - configure pgai for Cohere, then use the model to tokenize, embed, chat complete, classify, and rerank.
* [Voyage AI](./docs/voyageai.md) - configure pgai for Voyage AI, then use the model to embed.
- Leverage LLMs for data processing tasks such as classification, summarization, and data enrichment ([see the OpenAI example](/docs/openai.md)).
- Load datasets from Hugging Face into your database with [ai.load_dataset](/docs/load_dataset_from_huggingface.md).



@@ -178,7 +179,7 @@ You can use pgai to integrate AI from the following providers:
- [Llama 3 (via Ollama)](/docs/ollama.md)
- [Voyage AI](/docs/voyageai.md)

Learn how to [moderate](/docs/moderate.md) content directly in the database using triggers and background jobs.
Learn how to [moderate](/docs/moderate.md) content directly in the database using triggers and background jobs. To get started, [load datasets directly from Hugging Face](/docs/load_dataset_from_huggingface.md) into your database.

### Automatically create and sync LLM embeddings for your data

107 changes: 107 additions & 0 deletions docs/load_dataset_from_huggingface.md
@@ -0,0 +1,107 @@
# Load dataset from Hugging Face

The `ai.load_dataset` function allows you to load datasets from Hugging Face's datasets library directly into your PostgreSQL database.

## Example Usage

```sql
select ai.load_dataset('squad');

select * from squad limit 10;
```

## Parameters
| Name | Type | Default | Required | Description |
|---------------|---------|-------------|----------|----------------------------------------------------------------------------------------------------|
| name          | text    | -           | ✔        | The name of the dataset on Hugging Face (e.g., 'squad' or 'glue')                                    |
| config_name | text | - | ✖ | The specific configuration of the dataset to load. See [Hugging Face documentation](https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations) for more information. |
| split | text | - | ✖ | The split of the dataset to load (e.g., 'train', 'test', 'validation'). Defaults to all splits. |
| schema_name | text | 'public' | ✖ | The PostgreSQL schema where the table will be created |
| table_name | text | - | ✖ | The name of the table to create. If null, will use the dataset name |
| if_table_exists | text  | 'error'     | ✖        | Behavior when the table already exists: 'error' (raise an error), 'append' (add rows), 'drop' (drop the table and recreate it) |
| field_types | jsonb | - | ✖ | Custom PostgreSQL data types for columns as a JSONB dictionary from name to type. |
| batch_size | int | 5000 | ✖ | Number of rows to insert in each batch |
| max_batches | int | null | ✖ | Maximum number of batches to load. Null means load all |
| kwargs | jsonb | - | ✖ | Additional arguments passed to the Hugging Face dataset loading function |
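
For example, assuming the keys in `kwargs` are forwarded verbatim to `datasets.load_dataset()` (which is what the loader implementation below suggests), you could pin a specific dataset revision. The `revision` key is a standard `datasets.load_dataset()` option, not something defined by pgai, so treat this as an illustrative sketch:

```sql
-- Illustrative sketch: pass extra loader options through the kwargs parameter.
-- Assumes the JSONB keys map directly onto datasets.load_dataset() arguments.
SELECT ai.load_dataset(
    name   => 'squad',
    kwargs => '{"revision": "main"}'::jsonb
);
```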

## Returns

Returns the number of rows loaded into the database (bigint).
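
For instance, the count can be captured directly in a query (a minimal sketch reusing the `squad` example):

```sql
-- Load a small sample and alias the returned row count.
SELECT ai.load_dataset('squad', batch_size => 100, max_batches => 1) AS rows_loaded;
```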

## Using Multiple Transactions

The `ai.load_dataset` function loads all data in a single transaction. However, when loading a large dataset, it is sometimes useful to split the work across multiple transactions.
For this purpose, we provide the `ai.load_dataset_multi_txn` procedure. It is similar to `ai.load_dataset`, but it lets you specify the number of batches between commits
using the `commit_every_n_batches` parameter.

```sql
CALL ai.load_dataset_multi_txn('squad', commit_every_n_batches => 10);
```

## Examples

1. Basic usage - Load the entire 'squad' dataset:

```sql
SELECT ai.load_dataset('squad');
```

The data is loaded into a table named `squad`.

2. Load a small subset of the 'squad' dataset:

```sql
SELECT ai.load_dataset('squad', batch_size => 100, max_batches => 1);
```

3. Load the entire 'squad' dataset using multiple transactions:

```sql
CALL ai.load_dataset_multi_txn('squad', commit_every_n_batches => 100);
```

4. Load specific configuration and split:

```sql
SELECT ai.load_dataset(
name => 'glue',
config_name => 'mrpc',
split => 'train'
);
```

5. Load with custom table name and field types:

```sql
SELECT ai.load_dataset(
name => 'glue',
config_name => 'mrpc',
table_name => 'mrpc',
field_types => '{"sentence1": "text", "sentence2": "text"}'::jsonb
);
```

6. Pre-create the table and load data into it:

```sql
CREATE TABLE squad (
id TEXT,
title TEXT,
context TEXT,
question TEXT,
answers JSONB
);

SELECT ai.load_dataset(
name => 'squad',
table_name => 'squad',
if_table_exists => 'append'
);
```

## Notes

- The function requires an active internet connection to download datasets from Hugging Face.
- Large datasets may take significant time to load depending on size and connection speed.
- The function automatically maps Hugging Face dataset types to appropriate PostgreSQL data types unless overridden by `field_types`.
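
To check which PostgreSQL types were chosen for a loaded table, you can query the standard catalog views (plain PostgreSQL, nothing pgai-specific):

```sql
-- Inspect the column types that ai.load_dataset created for the squad table.
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = 'public'
  AND table_name = 'squad'
ORDER BY ordinal_position;
```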
253 changes: 253 additions & 0 deletions projects/extension/ai/load_dataset.py
@@ -0,0 +1,253 @@
import json
import datasets
from typing import Optional, Dict, Any

from .utils import get_guc_value

GUC_DATASET_CACHE_DIR = "ai.dataset_cache_dir"


def byte_size(s):
return len(s.encode("utf-8"))


def get_default_column_type(dtype: str) -> str:
# Default type mapping from dtypes to PostgreSQL types
type_mapping = {
"string": "TEXT",
"dict": "JSONB",
"list": "JSONB",
"int64": "INT8",
"int32": "INT4",
"int16": "INT2",
"int8": "INT2",
"float64": "FLOAT8",
"float32": "FLOAT4",
"float16": "FLOAT4",
"bool": "BOOLEAN",
}

if dtype.startswith("timestamp"):
return "TIMESTAMPTZ"
else:
return type_mapping.get(dtype.lower(), "TEXT")


def get_column_info(
dataset: datasets.Dataset, field_types: Optional[Dict[str, str]]
) -> tuple[Dict[str, str], Dict[str, Any], str]:
# Extract types from features
column_dtypes = {name: feature.dtype for name, feature in dataset.features.items()}
# Prepare column types, using field_types if provided, otherwise use inferred types
column_pgtypes = {}
for name, py_type in column_dtypes.items():
# Use custom type if provided, otherwise map from python type
column_pgtypes[name] = (
field_types.get(name)
if field_types and name in field_types
else get_default_column_type(str(py_type))
)
column_names = ", ".join(f'"{name}"' for name in column_dtypes.keys())
return column_pgtypes, column_dtypes, column_names


def create_table(
plpy: Any,
name: str,
config_name: Optional[str],
schema: str,
table_name: Optional[str],
column_types: Dict[str, str],
if_table_exists: str,
) -> str:
# Generate default table name if not provided
if table_name is None:
# Handle potential nested dataset names (e.g., "huggingface/dataset")
base_name = name.split("/")[-1]
# Add config name to table name if present
if config_name:
base_name = f"{base_name}_{config_name}"
# Replace any non-alphanumeric characters with underscore
table_name = "".join(c if c.isalnum() else "_" for c in base_name.lower())

# Check table name length - PostgreSQL has a 63 character limit for identifiers
if byte_size(table_name) > 63:
# Find the last underscore before the 63 character limit
last_underscore = table_name[:63].rstrip("_").rfind("_")
if last_underscore > 0:
table_name = table_name[:last_underscore]
else:
# If no underscore found, just truncate
table_name = table_name[:63]
else:
# table_name is provided by the user
# Check table name length - PostgreSQL has a 63 character limit for identifiers
if byte_size(table_name) > 63:
plpy.error(
f"Table name '{table_name}' exceeds PostgreSQL's 63 character limit"
)

# Construct fully qualified table name
plan = plpy.prepare(
"""
SELECT pg_catalog.format('%I.%I', $1, $2) as qualified_table_name
""",
["text", "text"],
)
result = plan.execute([schema, table_name], 1)
qualified_table = result[0]["qualified_table_name"]

# Check if table exists
result = plpy.execute(
f"""
SELECT pg_catalog.to_regclass('{qualified_table}')::text as friendly_table_name
"""
)
friendly_table_name = result[0]["friendly_table_name"]
table_exists = friendly_table_name is not None

if table_exists:
if if_table_exists == "drop":
plpy.notice(f"dropping and recreating table {friendly_table_name}")
plpy.execute(f"DROP TABLE IF EXISTS {qualified_table}")
elif if_table_exists == "error":
plpy.error(
f"Table {friendly_table_name} already exists. Set if_table_exists to 'drop' to replace it or 'append' to add to it."
)
elif if_table_exists == "append":
plpy.notice(f"adding data to the existing {friendly_table_name} table")
return qualified_table
else:
plpy.error(f"Unsupported if_table_exists value: {if_table_exists}")
else:
plpy.notice(f"creating table {friendly_table_name}")

column_type_def = ", ".join(
f'"{name}" {col_type}' for name, col_type in column_types.items()
)

# Create table
plpy.execute(f"CREATE TABLE {qualified_table} ({column_type_def})")
return qualified_table


def load_dataset(
plpy: Any,
# Dataset loading parameters
name: str,
config_name: Optional[str] = None,
split: Optional[str] = None,
# Database target parameters
schema: str = "public",
table_name: Optional[str] = None,
if_table_exists: str = "error",
# Advanced options
field_types: Optional[Dict[str, str]] = None,
batch_size: int = 5000,
max_batches: Optional[int] = None,
commit_every_n_batches: Optional[int] = None,
# Additional dataset loading options
**kwargs: Dict[str, Any],
) -> int:
"""
Load a dataset into PostgreSQL database using plpy with batch UNNEST operations.

Args:
# Dataset loading parameters
name: Name of the dataset
config_name: Configuration name to load. Some datasets have multiple configurations
(versions or subsets) available. See: https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations
split: Dataset split to load (defaults to all splits)
cache_dir: Directory to cache downloaded datasets (default: None)

# Database target parameters
schema: Target schema name (default: "public")
table_name: Target table name (default: derived from dataset name)
drop_if_exists: If True, drop existing table; if False, error if table exists (default: False)

# Advanced options
field_types: Optional dictionary of field names to PostgreSQL types
batch_size: Number of rows to insert in each batch (default: 5000)

# Additional dataset loading options
**kwargs: Additional keyword arguments passed to datasets.load_dataset()

Returns:
Number of rows loaded
"""

cache_dir = get_guc_value(plpy, GUC_DATASET_CACHE_DIR, None)

# Load dataset using Hugging Face datasets library
ds = datasets.load_dataset(
name, config_name, split=split, cache_dir=cache_dir, streaming=True, **kwargs
)
if isinstance(ds, datasets.IterableDatasetDict):
datasetdict = ds
elif isinstance(ds, datasets.IterableDataset):
datasetdict = {split: ds}
else:
plpy.error(
f"Unsupported dataset type: {type(ds)}. Only datasets.IterableDatasetDict and datasets.IterableDataset are supported."
)

first_dataset = next(iter(datasetdict.values()))
column_pgtypes, column_dtypes, column_names = get_column_info(
first_dataset, field_types
)
qualified_table = create_table(
plpy, name, config_name, schema, table_name, column_pgtypes, if_table_exists
)

# Prepare the UNNEST parameters and INSERT statement once
unnest_params = []
type_params = []
for i, (col_name, col_type) in enumerate(column_pgtypes.items(), 1):
unnest_params.append(f"${i}::{col_type}[]")
type_params.append(f"{col_type}[]")

insert_sql = f"""
INSERT INTO {qualified_table} ({column_names})
SELECT * FROM unnest({', '.join(unnest_params)})
"""
insert_plan = plpy.prepare(insert_sql, type_params)

num_rows = 0
batch_count = 0
batches_since_commit = 0
for split, dataset in datasetdict.items():
# Process data in batches using dataset iteration
batched_dataset = dataset.batch(batch_size=batch_size)
for batch in batched_dataset:
if max_batches and batch_count >= max_batches:
break

batch_arrays = [[] for _ in column_dtypes]
for i, (col_name, py_type) in enumerate(column_dtypes.items()):
type_str = str(py_type).lower()
array_values = batch[col_name]

if type_str in ("dict", "list"):
batch_arrays[i] = [json.dumps(value) for value in array_values]
elif type_str in ("int64", "int32", "int16", "int8"):
batch_arrays[i] = [int(value) for value in array_values]
elif type_str in ("float64", "float32", "float16"):
batch_arrays[i] = [float(value) for value in array_values]
else:
batch_arrays[i] = array_values

insert_plan.execute(batch_arrays)
num_rows += len(batch_arrays[0])
batch_count += 1
batches_since_commit += 1
plpy.debug(
f"inserted {num_rows} rows using {batch_count} batches into {qualified_table} so far..."
)
if (
commit_every_n_batches
and batches_since_commit >= commit_every_n_batches
):
plpy.commit()
batches_since_commit = 0

return num_rows