diff --git a/README.md b/README.md index 6ecac31e8..8772f4fa4 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ For other use cases, first [Install pgai](#installation) in Timescale Cloud, a p * [Cohere](./docs/cohere.md) - configure pgai for Cohere, then use the model to tokenize, embed, chat complete, classify, and rerank. * [Voyage AI](./docs/voyageai.md) - configure pgai for Voyage AI, then use the model to embed. - Leverage LLMs for data processing tasks such as classification, summarization, and data enrichment ([see the OpenAI example](/docs/openai.md)). + - Load datasets from Hugging Face into your database with [ai.load_dataset](/docs/load_dataset_from_huggingface.md). @@ -178,7 +179,7 @@ You can use pgai to integrate AI from the following providers: - [Llama 3 (via Ollama)](/docs/ollama.md) - [Voyage AI](/docs/voyageai.md) -Learn how to [moderate](/docs/moderate.md) content directly in the database using triggers and background jobs. +Learn how to [moderate](/docs/moderate.md) content directly in the database using triggers and background jobs. To get started, [load datasets directly from Hugging Face](/docs/load_dataset_from_huggingface.md) into your database. ### Automatically create and sync LLM embeddings for your data diff --git a/docs/load_dataset_from_huggingface.md b/docs/load_dataset_from_huggingface.md new file mode 100644 index 000000000..bcbaada38 --- /dev/null +++ b/docs/load_dataset_from_huggingface.md @@ -0,0 +1,107 @@ +# Load dataset from Hugging Face + +The `ai.load_dataset` function allows you to load datasets from Hugging Face's datasets library directly into your PostgreSQL database. + +## Example Usage + +```sql +select ai.load_dataset('squad'); + +select * from squad limit 10; +``` + +## Parameters +| Name | Type | Default | Required | Description | +|---------------|---------|-------------|----------|----------------------------------------------------------------------------------------------------| +| name | text | - | ✔ | The name of the dataset on Hugging Face (e.g., 'squad', 'glue', etc.) | +| config_name | text | - | ✖ | The specific configuration of the dataset to load. See [Hugging Face documentation](https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations) for more information. | +| split | text | - | ✖ | The split of the dataset to load (e.g., 'train', 'test', 'validation'). Defaults to all splits. | +| schema_name | text | 'public' | ✖ | The PostgreSQL schema where the table will be created | +| table_name | text | - | ✖ | The name of the table to create. If null, will use the dataset name | +| if_table_exists| text | 'error' | ✖ | Behavior when table exists: 'error' (raise error), 'append' (add rows), 'drop' (drop table and recreate) | +| field_types | jsonb | - | ✖ | Custom PostgreSQL data types for columns as a JSONB dictionary from name to type. | +| batch_size | int | 5000 | ✖ | Number of rows to insert in each batch | +| max_batches | int | null | ✖ | Maximum number of batches to load. Null means load all | +| kwargs | jsonb | - | ✖ | Additional arguments passed to the Hugging Face dataset loading function | + +## Returns + +Returns the number of rows loaded into the database (bigint). + +## Using Multiple Transactions + +The `ai.load_dataset` function loads all data in a single transaction. However, to load large dataset, it is sometimes useful to use multiple transactions. +For this purpose, we provide the `ai.load_dataset_multi_txn` procedure. 
That procedure is similar to `ai.load_dataset`, but it allows you to specify the number of batches between commits +using the `commit_every_n_batches` parameter. + +```sql +CALL ai.load_dataset_multi_txn('squad', commit_every_n_batches => 10); +``` + +## Examples + +1. Basic usage - Load the entire 'squad' dataset: + +```sql +SELECT ai.load_dataset('squad'); +``` + +The data is loaded into a table named `squad`. + +2. Load a small subset of the 'squad' dataset: + +```sql +SELECT ai.load_dataset('squad', batch_size => 100, max_batches => 1); +``` + +3. Load the entire 'squad' dataset using multiple transactions: + +```sql +CALL ai.load_dataset_multi_txn('squad', commit_every_n_batches => 100); +``` + +4. Load specific configuration and split: + +```sql +SELECT ai.load_dataset( + name => 'glue', + config_name => 'mrpc', + split => 'train' +); +``` + +5. Load with custom table name and field types: + +```sql +SELECT ai.load_dataset( + name => 'glue', + config_name => 'mrpc', + table_name => 'mrpc', + field_types => '{"sentence1": "text", "sentence2": "text"}'::jsonb +); +``` + +6. Pre-create the table and load data into it: + +```sql + +CREATE TABLE squad ( + id TEXT, + title TEXT, + context TEXT, + question TEXT, + answers JSONB +); + +SELECT ai.load_dataset( + name => 'squad', + table_name => 'squad', + if_table_exists => 'append' +); +``` + +## Notes + +- The function requires an active internet connection to download datasets from Hugging Face. +- Large datasets may take significant time to load depending on size and connection speed. +- The function automatically maps Hugging Face dataset types to appropriate PostgreSQL data types unless overridden by `field_types`. diff --git a/projects/extension/ai/load_dataset.py b/projects/extension/ai/load_dataset.py new file mode 100644 index 000000000..1fb46558b --- /dev/null +++ b/projects/extension/ai/load_dataset.py @@ -0,0 +1,253 @@ +import json +import datasets +from typing import Optional, Dict, Any + +from .utils import get_guc_value + +GUC_DATASET_CACHE_DIR = "ai.dataset_cache_dir" + + +def byte_size(s): + return len(s.encode("utf-8")) + + +def get_default_column_type(dtype: str) -> str: + # Default type mapping from dtypes to PostgreSQL types + type_mapping = { + "string": "TEXT", + "dict": "JSONB", + "list": "JSONB", + "int64": "INT8", + "int32": "INT4", + "int16": "INT2", + "int8": "INT2", + "float64": "FLOAT8", + "float32": "FLOAT4", + "float16": "FLOAT4", + "bool": "BOOLEAN", + } + + if dtype.startswith("timestamp"): + return "TIMESTAMPTZ" + else: + return type_mapping.get(dtype.lower(), "TEXT") + + +def get_column_info( + dataset: datasets.Dataset, field_types: Optional[Dict[str, str]] +) -> tuple[Dict[str, str], Dict[str, Any], str]: + # Extract types from features + column_dtypes = {name: feature.dtype for name, feature in dataset.features.items()} + # Prepare column types, using field_types if provided, otherwise use inferred types + column_pgtypes = {} + for name, py_type in column_dtypes.items(): + # Use custom type if provided, otherwise map from python type + column_pgtypes[name] = ( + field_types.get(name) + if field_types and name in field_types + else get_default_column_type(str(py_type)) + ) + column_names = ", ".join(f'"{name}"' for name in column_dtypes.keys()) + return column_pgtypes, column_dtypes, column_names + + +def create_table( + plpy: Any, + name: str, + config_name: Optional[str], + schema: str, + table_name: Optional[str], + column_types: Dict[str, str], + if_table_exists: str, +) -> str: + # Generate default 
table name if not provided + if table_name is None: + # Handle potential nested dataset names (e.g., "huggingface/dataset") + base_name = name.split("/")[-1] + # Add config name to table name if present + if config_name: + base_name = f"{base_name}_{config_name}" + # Replace any non-alphanumeric characters with underscore + table_name = "".join(c if c.isalnum() else "_" for c in base_name.lower()) + + # Check table name length - PostgreSQL has a 63 character limit for identifiers + if byte_size(table_name) > 63: + # Find the last underscore before the 63 character limit + last_underscore = table_name[:63].rstrip("_").rfind("_") + if last_underscore > 0: + table_name = table_name[:last_underscore] + else: + # If no underscore found, just truncate + table_name = table_name[:63] + else: + # table_name is provided by the user + # Check table name length - PostgreSQL has a 63 character limit for identifiers + if byte_size(table_name) > 63: + plpy.error( + f"Table name '{table_name}' exceeds PostgreSQL's 63 character limit" + ) + + # Construct fully qualified table name + plan = plpy.prepare( + """ + SELECT pg_catalog.format('%I.%I', $1, $2) as qualified_table_name + """, + ["text", "text"], + ) + result = plan.execute([schema, table_name], 1) + qualified_table = result[0]["qualified_table_name"] + + # Check if table exists + result = plpy.execute( + f""" + SELECT pg_catalog.to_regclass('{qualified_table}')::text as friendly_table_name + """ + ) + friendly_table_name = result[0]["friendly_table_name"] + table_exists = friendly_table_name is not None + + if table_exists: + if if_table_exists == "drop": + plpy.notice(f"dropping and recreating table {friendly_table_name}") + plpy.execute(f"DROP TABLE IF EXISTS {qualified_table}") + elif if_table_exists == "error": + plpy.error( + f"Table {friendly_table_name} already exists. Set if_table_exists to 'drop' to replace it or 'append' to add to it." + ) + elif if_table_exists == "append": + plpy.notice(f"adding data to the existing {friendly_table_name} table") + return qualified_table + else: + plpy.error(f"Unsupported if_table_exists value: {if_table_exists}") + else: + plpy.notice(f"creating table {friendly_table_name}") + + column_type_def = ", ".join( + f'"{name}" {col_type}' for name, col_type in column_types.items() + ) + + # Create table + plpy.execute(f"CREATE TABLE {qualified_table} ({column_type_def})") + return qualified_table + + +def load_dataset( + plpy: Any, + # Dataset loading parameters + name: str, + config_name: Optional[str] = None, + split: Optional[str] = None, + # Database target parameters + schema: str = "public", + table_name: Optional[str] = None, + if_table_exists: str = "error", + # Advanced options + field_types: Optional[Dict[str, str]] = None, + batch_size: int = 5000, + max_batches: Optional[int] = None, + commit_every_n_batches: Optional[int] = None, + # Additional dataset loading options + **kwargs: Dict[str, Any], +) -> int: + """ + Load a dataset into PostgreSQL database using plpy with batch UNNEST operations. + + Args: + # Dataset loading parameters + name: Name of the dataset + config_name: Configuration name to load. Some datasets have multiple configurations + (versions or subsets) available. 
See: https://huggingface.co/docs/datasets/v2.20.0/en/load_hub#configurations + split: Dataset split to load (defaults to all splits) + cache_dir: Directory to cache downloaded datasets (default: None) + + # Database target parameters + schema: Target schema name (default: "public") + table_name: Target table name (default: derived from dataset name) + drop_if_exists: If True, drop existing table; if False, error if table exists (default: False) + + # Advanced options + field_types: Optional dictionary of field names to PostgreSQL types + batch_size: Number of rows to insert in each batch (default: 5000) + + # Additional dataset loading options + **kwargs: Additional keyword arguments passed to datasets.load_dataset() + + Returns: + Number of rows loaded + """ + + cache_dir = get_guc_value(plpy, GUC_DATASET_CACHE_DIR, None) + + # Load dataset using Hugging Face datasets library + ds = datasets.load_dataset( + name, config_name, split=split, cache_dir=cache_dir, streaming=True, **kwargs + ) + if isinstance(ds, datasets.IterableDatasetDict): + datasetdict = ds + elif isinstance(ds, datasets.IterableDataset): + datasetdict = {split: ds} + else: + plpy.error( + f"Unsupported dataset type: {type(ds)}. Only datasets.IterableDatasetDict and datasets.IterableDataset are supported." + ) + + first_dataset = next(iter(datasetdict.values())) + column_pgtypes, column_dtypes, column_names = get_column_info( + first_dataset, field_types + ) + qualified_table = create_table( + plpy, name, config_name, schema, table_name, column_pgtypes, if_table_exists + ) + + # Prepare the UNNEST parameters and INSERT statement once + unnest_params = [] + type_params = [] + for i, (col_name, col_type) in enumerate(column_pgtypes.items(), 1): + unnest_params.append(f"${i}::{col_type}[]") + type_params.append(f"{col_type}[]") + + insert_sql = f""" + INSERT INTO {qualified_table} ({column_names}) + SELECT * FROM unnest({', '.join(unnest_params)}) + """ + insert_plan = plpy.prepare(insert_sql, type_params) + + num_rows = 0 + batch_count = 0 + batches_since_commit = 0 + for split, dataset in datasetdict.items(): + # Process data in batches using dataset iteration + batched_dataset = dataset.batch(batch_size=batch_size) + for batch in batched_dataset: + if max_batches and batch_count >= max_batches: + break + + batch_arrays = [[] for _ in column_dtypes] + for i, (col_name, py_type) in enumerate(column_dtypes.items()): + type_str = str(py_type).lower() + array_values = batch[col_name] + + if type_str in ("dict", "list"): + batch_arrays[i] = [json.dumps(value) for value in array_values] + elif type_str in ("int64", "int32", "int16", "int8"): + batch_arrays[i] = [int(value) for value in array_values] + elif type_str in ("float64", "float32", "float16"): + batch_arrays[i] = [float(value) for value in array_values] + else: + batch_arrays[i] = array_values + + insert_plan.execute(batch_arrays) + num_rows += len(batch_arrays[0]) + batch_count += 1 + batches_since_commit += 1 + plpy.debug( + f"inserted {num_rows} rows using {batch_count} batches into {qualified_table} so far..." 
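+                # progress message only; plpy.debug output surfaces when
+                # client_min_messages (or log_min_messages) is set to a debug level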
+ ) + if ( + commit_every_n_batches + and batches_since_commit >= commit_every_n_batches + ): + plpy.commit() + batches_since_commit = 0 + + return num_rows diff --git a/projects/extension/ai/secrets.py b/projects/extension/ai/secrets.py index 15674d2a5..e889a443e 100644 --- a/projects/extension/ai/secrets.py +++ b/projects/extension/ai/secrets.py @@ -6,6 +6,8 @@ import httpx from backoff._typing import Details +from .utils import get_guc_value + GUC_SECRETS_MANAGER_URL = "ai.external_functions_executor_url" GUC_SECRET_ENV_ENABLED = "ai.secret_env_enabled" @@ -45,17 +47,6 @@ def get_secret( return secret -def get_guc_value(plpy, setting: str, default: str) -> str: - plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"]) - result = plan.execute([setting], 1) - val: str | None = None - if len(result) != 0: - val = result[0]["val"] - if val is None: - val = default - return val - - def check_secret_permissions(plpy, secret_name: str) -> bool: # check if the user has access to all secrets plan = plpy.prepare( diff --git a/projects/extension/ai/utils.py b/projects/extension/ai/utils.py new file mode 100644 index 000000000..dc0acfa7b --- /dev/null +++ b/projects/extension/ai/utils.py @@ -0,0 +1,9 @@ +def get_guc_value(plpy, setting: str, default: str) -> str: + plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"]) + result = plan.execute([setting], 1) + val: str | None = None + if len(result) != 0: + val = result[0]["val"] + if val is None: + val = default + return val diff --git a/projects/extension/ai/vectorizer.py b/projects/extension/ai/vectorizer.py index d7e92fc0c..5806cc806 100644 --- a/projects/extension/ai/vectorizer.py +++ b/projects/extension/ai/vectorizer.py @@ -5,6 +5,8 @@ import httpx from backoff._typing import Details +from .utils import get_guc_value + GUC_VECTORIZER_URL = "ai.external_functions_executor_url" DEFAULT_VECTORIZER_URL = "http://localhost:8000" @@ -12,17 +14,6 @@ DEFAULT_VECTORIZER_PATH = "/api/v1/events" -def get_guc_value(plpy, setting: str, default: str) -> str: - plan = plpy.prepare("select pg_catalog.current_setting($1, true) as val", ["text"]) - result = plan.execute([setting], 1) - val: str | None = None - if len(result) != 0: - val = result[0]["val"] - if val is None: - val = default - return val - - def execute_vectorizer(plpy, vectorizer_id: int) -> None: plan = plpy.prepare( """ diff --git a/projects/extension/build.py b/projects/extension/build.py index efa8375ca..95ebf5a14 100755 --- a/projects/extension/build.py +++ b/projects/extension/build.py @@ -377,7 +377,7 @@ def clean_sql() -> None: def postgres_bin_dir() -> Path: bin_dir = os.getenv("PG_BIN") - if bin_dir: + if Path(bin_dir).is_dir(): return Path(bin_dir).resolve() else: bin_dir = Path(f"/usr/lib/postgresql/{pg_major()}/bin") diff --git a/projects/extension/justfile b/projects/extension/justfile index fde4664b0..f80046934 100644 --- a/projects/extension/justfile +++ b/projects/extension/justfile @@ -1,5 +1,5 @@ PG_MAJOR := env("PG_MAJOR", "17") -PG_BIN := "/usr/lib/postgresql/" + PG_MAJOR + "/bin" +PG_BIN := env("PG_BIN", "/usr/lib/postgresql/" + PG_MAJOR + "/bin") # Show list of recipes default: diff --git a/projects/extension/requirements.txt b/projects/extension/requirements.txt index 9e0eddded..e1297133e 100644 --- a/projects/extension/requirements.txt +++ b/projects/extension/requirements.txt @@ -4,4 +4,5 @@ ollama==0.2.1 anthropic==0.29.0 cohere==5.5.8 backoff==2.2.1 -voyageai==0.3.1 \ No newline at end of file +voyageai==0.3.1 
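+# datasets is used by ai.load_dataset / ai.load_dataset_multi_txn to stream Hugging Face datasets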
+datasets==3.1.0 diff --git a/projects/extension/setup.cfg b/projects/extension/setup.cfg index f742fd699..8ebb91ee2 100644 --- a/projects/extension/setup.cfg +++ b/projects/extension/setup.cfg @@ -13,4 +13,5 @@ install_requires = anthropic==0.29.0 cohere==5.5.8 backoff==2.2.1 - voyageai==0.3.1 \ No newline at end of file + voyageai==0.3.1 + datasets==3.1.0 diff --git a/projects/extension/sql/idempotent/016-load_dataset.sql b/projects/extension/sql/idempotent/016-load_dataset.sql new file mode 100644 index 000000000..58b753a3b --- /dev/null +++ b/projects/extension/sql/idempotent/016-load_dataset.sql @@ -0,0 +1,90 @@ +create or replace procedure ai.load_dataset_multi_txn +( name text +, config_name text default null +, split text default null +, schema_name name default 'public' +, table_name name default null +, if_table_exists text default 'error' +, field_types jsonb default null +, batch_size int default 5000 +, max_batches int default null +, commit_every_n_batches int default 1 +, kwargs jsonb default '{}' +) +as $python$ + #ADD-PYTHON-LIB-DIR + import ai.load_dataset + import json + + # Convert kwargs from json string to dict + kwargs_dict = {} + if kwargs: + kwargs_dict = json.loads(kwargs) + + # Convert field_types from json string to dict + field_types_dict = None + if field_types: + field_types_dict = json.loads(field_types) + + + num_rows = ai.load_dataset.load_dataset( + plpy, + name=name, + config_name=config_name, + split=split, + schema=schema_name, + table_name=table_name, + if_table_exists=if_table_exists, + field_types=field_types_dict, + batch_size=batch_size, + max_batches=max_batches, + commit_every_n_batches=commit_every_n_batches, + **kwargs_dict + ) +$python$ +language plpython3u security invoker; + +create or replace function ai.load_dataset +( name text +, config_name text default null +, split text default null +, schema_name name default 'public' +, table_name name default null +, if_table_exists text default 'error' +, field_types jsonb default null +, batch_size int default 5000 +, max_batches int default null +, kwargs jsonb default '{}' +) returns bigint +as $python$ + #ADD-PYTHON-LIB-DIR + import ai.load_dataset + import json + + # Convert kwargs from json string to dict + kwargs_dict = {} + if kwargs: + kwargs_dict = json.loads(kwargs) + + # Convert field_types from json string to dict + field_types_dict = None + if field_types: + field_types_dict = json.loads(field_types) + + return ai.load_dataset.load_dataset( + plpy, + name=name, + config_name=config_name, + split=split, + schema=schema_name, + table_name=table_name, + if_table_exists=if_table_exists, + field_types=field_types_dict, + batch_size=batch_size, + max_batches=max_batches, + commit_every_n_batches=None, + **kwargs_dict + ) +$python$ +language plpython3u volatile security invoker +set search_path to pg_catalog, pg_temp; \ No newline at end of file diff --git a/projects/extension/tests/contents/output16.expected b/projects/extension/tests/contents/output16.expected index aff71f6f8..4ea8e82c8 100644 --- a/projects/extension/tests/contents/output16.expected +++ b/projects/extension/tests/contents/output16.expected @@ -34,6 +34,8 @@ CREATE EXTENSION function ai.indexing_diskann(integer,text,integer,integer,double precision,integer,integer,boolean) function ai.indexing_hnsw(integer,text,integer,integer,boolean) function ai.indexing_none() + function ai.load_dataset_multi_txn(text,text,text,name,name,text,jsonb,integer,integer,integer,jsonb) + function 
ai.load_dataset(text,text,text,name,name,text,jsonb,integer,integer,jsonb) function ai.ollama_chat_complete(text,jsonb,text,text,jsonb) function ai.ollama_embed(text,text,text,text,jsonb) function ai.ollama_generate(text,text,text,bytea[],text,jsonb,text,text,integer[]) @@ -90,7 +92,7 @@ CREATE EXTENSION table ai.vectorizer_errors view ai.secret_permissions view ai.vectorizer_status -(86 rows) +(88 rows) Table "ai._secret_permissions" Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description diff --git a/projects/extension/tests/contents/output17.expected b/projects/extension/tests/contents/output17.expected index 9cae4ffe4..8fdcd4d50 100644 --- a/projects/extension/tests/contents/output17.expected +++ b/projects/extension/tests/contents/output17.expected @@ -34,6 +34,8 @@ CREATE EXTENSION function ai.indexing_diskann(integer,text,integer,integer,double precision,integer,integer,boolean) function ai.indexing_hnsw(integer,text,integer,integer,boolean) function ai.indexing_none() + function ai.load_dataset_multi_txn(text,text,text,name,name,text,jsonb,integer,integer,integer,jsonb) + function ai.load_dataset(text,text,text,name,name,text,jsonb,integer,integer,jsonb) function ai.ollama_chat_complete(text,jsonb,text,text,jsonb) function ai.ollama_embed(text,text,text,text,jsonb) function ai.ollama_generate(text,text,text,bytea[],text,jsonb,text,text,integer[]) @@ -104,7 +106,7 @@ CREATE EXTENSION type ai.vectorizer_status[] view ai.secret_permissions view ai.vectorizer_status -(100 rows) +(102 rows) Table "ai._secret_permissions" Column | Type | Collation | Nullable | Default | Storage | Compression | Stats target | Description diff --git a/projects/extension/tests/privileges/function.expected b/projects/extension/tests/privileges/function.expected index 3534bf657..97e05366f 100644 --- a/projects/extension/tests/privileges/function.expected +++ b/projects/extension/tests/privileges/function.expected @@ -220,6 +220,14 @@ f | bob | execute | no | ai | indexing_none() f | fred | execute | no | ai | indexing_none() f | jill | execute | YES | ai | indexing_none() + f | alice | execute | YES | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb) + f | bob | execute | no | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb) + f | fred | execute | no | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb) + f | jill | execute | YES | ai | load_dataset(name text, config_name text, split text, schema_name name, table_name name, if_table_exists text, field_types jsonb, batch_size integer, max_batches integer, kwargs jsonb) + p | alice | execute | YES | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN commit_every_n_batches integer, IN kwargs jsonb) + p | bob | execute | no | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN 
commit_every_n_batches integer, IN kwargs jsonb) + p | fred | execute | no | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN commit_every_n_batches integer, IN kwargs jsonb) + p | jill | execute | YES | ai | load_dataset_multi_txn(IN name text, IN config_name text, IN split text, IN schema_name name, IN table_name name, IN if_table_exists text, IN field_types jsonb, IN batch_size integer, IN max_batches integer, IN commit_every_n_batches integer, IN kwargs jsonb) f | alice | execute | YES | ai | ollama_chat_complete(model text, messages jsonb, host text, keep_alive text, chat_options jsonb) f | bob | execute | no | ai | ollama_chat_complete(model text, messages jsonb, host text, keep_alive text, chat_options jsonb) f | fred | execute | no | ai | ollama_chat_complete(model text, messages jsonb, host text, keep_alive text, chat_options jsonb) @@ -312,5 +320,5 @@ f | bob | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) f | fred | execute | no | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) f | jill | execute | YES | ai | voyageai_embed(model text, input_texts text[], input_type text, api_key text, api_key_name text) -(312 rows) +(320 rows) diff --git a/projects/extension/tests/test_load_dataset.py b/projects/extension/tests/test_load_dataset.py new file mode 100644 index 000000000..c7be46fed --- /dev/null +++ b/projects/extension/tests/test_load_dataset.py @@ -0,0 +1,195 @@ +import datetime +import os + +import psycopg +import pytest + + +# skip tests in this module if disabled +enable_load_dataset_tests = os.getenv("ENABLE_LOAD_DATASET_TESTS") +if not enable_load_dataset_tests or enable_load_dataset_tests == "0": + pytest.skip(allow_module_level=True) + + +@pytest.fixture() +def cur() -> psycopg.Cursor: + with psycopg.connect("postgres://test@127.0.0.1:5432/test") as con: + with con.cursor() as cur: + yield cur + + +def test_load_dataset(cur): + # load everything + cur.execute( + """ + select ai.load_dataset('rotten_tomatoes') + """, + ) + actual = cur.fetchone()[0] + assert actual == 10662 + + cur.execute("select count(*) from public.rotten_tomatoes") + assert cur.fetchone()[0] == actual + + cur.execute( + "select column_name, data_type from information_schema.columns where table_name = 'rotten_tomatoes' order by ordinal_position" + ) + assert cur.fetchall() == [("text", "text"), ("label", "bigint")] + + # test append and explicit split + cur.execute( + """ + select ai.load_dataset('rotten_tomatoes', split=>'test', if_table_exists=>'append', batch_size=>2, max_batches=>1) + """, + ) + actual = cur.fetchone()[0] + assert actual == 2 + + cur.execute("select count(*) from public.rotten_tomatoes") + assert cur.fetchone()[0] == 10662 + 2 + + # test drop + cur.execute( + """ + select ai.load_dataset('rotten_tomatoes', split=>'test', if_table_exists=>'drop', batch_size=>2, max_batches=>1) + """, + ) + actual = cur.fetchone()[0] + assert actual == 2 + + cur.execute("select count(*) from public.rotten_tomatoes") + assert cur.fetchone()[0] == 2 + + # test error + with pytest.raises(Exception): + cur.execute( + """ + select ai.load_dataset('rotten_tomatoes', split=>'test', if_table_exists=>'error') + """, + ) + + +def test_load_dataset_with_field_types(cur): + cur.execute( + """ + select 
ai.load_dataset('rotten_tomatoes', schema_name=>'public', table_name=>'rotten_tomatoes2', field_types=>'{"label": "int"}'::jsonb, batch_size=>100, max_batches=>1) + """, + ) + actual = cur.fetchone()[0] + assert actual == 100 + + cur.execute("select count(*) from public.rotten_tomatoes2") + assert cur.fetchone()[0] == actual + + cur.execute( + "select column_name, data_type from information_schema.columns where table_name = 'rotten_tomatoes2' order by ordinal_position" + ) + assert cur.fetchall() == [("text", "text"), ("label", "integer")] + + +def test_load_dataset_with_field_with_max_batches_and_timestamp(cur): + cur.execute( + """ + select ai.load_dataset('Weijie1996/load_timeseries', batch_size=>2, max_batches=>1) + """, + ) + actual = cur.fetchone()[0] + assert actual == 2 + + cur.execute("select count(*) from public.load_timeseries") + assert cur.fetchone()[0] == actual + + cur.execute( + "select column_name, data_type from information_schema.columns where table_name = 'load_timeseries' order by ordinal_position" + ) + assert cur.fetchall() == [ + ("id", "text"), + ("datetime", "timestamp with time zone"), + ("target", "double precision"), + ("category", "text"), + ] + + cur.execute("select datetime from public.load_timeseries limit 1") + assert cur.fetchone()[0] == datetime.datetime( + 2015, 5, 21, 15, 45, tzinfo=datetime.timezone.utc + ) + + +def test_load_dataset_with_commit_every_n_batches(cur): + cur.execute( + "select xact_commit from pg_stat_database where datname = current_database()" + ) + original_txn_count = cur.fetchone()[0] + + # autocommit=True allows us to commit inside the procedure + with psycopg.connect( + "postgres://test@127.0.0.1:5432/test", autocommit=True + ) as con2: + with con2.cursor() as cur2: + cur2.execute( + """ + call ai.load_dataset_multi_txn('Chendi/NYC_TAXI_FARE_CLEANED', table_name=>'nyc_taxi_fare_cleaned_multi_txn', batch_size=>2, max_batches=>10, commit_every_n_batches=>1) + """, + prepare=False, + ) + con2.commit() + + cur.execute("select count(*) from public.nyc_taxi_fare_cleaned_multi_txn") + assert cur.fetchone()[0] == 20 + + with psycopg.connect("postgres://test@127.0.0.1:5432/test") as con3: + with con3.cursor() as cur3: + cur3.execute( + "select xact_commit from pg_stat_database where datname = current_database()" + ) + new_txn_count = cur3.fetchone()[0] + + assert new_txn_count > original_txn_count + 10 + + +def test_load_dataset_other_datasets(cur): + # test nyc taxi fare cleaned - timestamp mislabeled as text, force timestamp + cur.execute(""" + select ai.load_dataset('Chendi/NYC_TAXI_FARE_CLEANED', batch_size=>2, max_batches=>1, field_types=>'{"pickup_datetime": "timestamp with time zone"}'::jsonb) + """) + actual = cur.fetchone()[0] + assert actual == 2 + + cur.execute( + "select column_name, data_type from information_schema.columns where table_name = 'nyc_taxi_fare_cleaned' order by ordinal_position" + ) + assert cur.fetchall() == [ + ("fare_amount", "double precision"), + ("pickup_datetime", "timestamp with time zone"), + ("pickup_longitude", "double precision"), + ("pickup_latitude", "double precision"), + ("dropoff_longitude", "double precision"), + ("dropoff_latitude", "double precision"), + ("passenger_count", "bigint"), + ] + + cur.execute("select pickup_datetime from public.nyc_taxi_fare_cleaned limit 1") + assert cur.fetchone()[0] == datetime.datetime( + 2009, 6, 15, 17, 26, 21, tzinfo=datetime.timezone.utc + ) + + # dataset with sequence column -- become a jsonb column + cur.execute(""" + select 
ai.load_dataset('tppllm/nyc-taxi-description', batch_size=>2, max_batches=>1) + """) + actual = cur.fetchone()[0] + assert actual == 2 + + cur.execute( + "select column_name, data_type from information_schema.columns where table_name = 'nyc_taxi_description' order by ordinal_position" + ) + assert cur.fetchall() == [ + ("dim_process", "bigint"), + ("seq_idx", "bigint"), + ("seq_len", "bigint"), + ("time_since_start", "jsonb"), + ("time_since_last_event", "jsonb"), + ("type_event", "jsonb"), + ("type_text", "jsonb"), + ("description", "text"), + ]
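+
+
+# Hypothetical extra check (not part of the original suite): a minimal sketch verifying
+# that max_batches caps the number of rows loaded. It reuses the 'rotten_tomatoes'
+# dataset from the tests above; the table name 'rotten_tomatoes_capped' is chosen here
+# purely for illustration.
+def test_load_dataset_max_batches_caps_rows(cur):
+    cur.execute(
+        """
+        select ai.load_dataset('rotten_tomatoes', table_name=>'rotten_tomatoes_capped', batch_size=>3, max_batches=>2)
+        """,
+    )
+    # two batches of three rows each should be loaded and reported
+    assert cur.fetchone()[0] == 6
+
+    cur.execute("select count(*) from public.rotten_tomatoes_capped")
+    assert cur.fetchone()[0] == 6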