From f0ffa2bd736a90b62ed2ccbc31cc0dff71d703a9 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 22 May 2025 06:10:11 -0400 Subject: [PATCH 01/10] Add base Postgres vector writer, CloudSQL vector writer and refactor. --- .../schemaio-expansion-service/build.gradle | 2 + .../apache_beam/internal/gcp/auth_test.py | 1 + .../apache_beam/ml/gcp/recommendations_ai.py | 2 +- .../apache_beam/ml/rag/ingestion/alloydb.py | 783 ++------------- .../ml/rag/ingestion/alloydb_it_test.py | 912 +----------------- .../apache_beam/ml/rag/ingestion/cloudsql.py | 220 +++++ .../ml/rag/ingestion/cloudsql_it_test.py | 223 +++++ .../ml/rag/ingestion/jdbc_common.py | 78 ++ .../apache_beam/ml/rag/ingestion/postgres.py | 209 ++++ .../ml/rag/ingestion/postgres_common.py | 484 ++++++++++ .../ml/rag/ingestion/postgres_it_test.py | 902 +++++++++++++++++ .../ml/rag/ingestion/test_utils.py | 105 ++ sdks/python/setup.py | 1 + 13 files changed, 2325 insertions(+), 1597 deletions(-) create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/jdbc_common.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/postgres.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/postgres_it_test.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/test_utils.py diff --git a/sdks/java/extensions/schemaio-expansion-service/build.gradle b/sdks/java/extensions/schemaio-expansion-service/build.gradle index c2128fb73c3d..15873d58e615 100644 --- a/sdks/java/extensions/schemaio-expansion-service/build.gradle +++ b/sdks/java/extensions/schemaio-expansion-service/build.gradle @@ -62,6 +62,8 @@ dependencies { permitUnusedDeclared 'com.microsoft.sqlserver:mssql-jdbc:12.2.0.jre11' // BEAM-11761 implementation 'com.google.cloud:alloydb-jdbc-connector:1.2.0' permitUnusedDeclared 'com.google.cloud:alloydb-jdbc-connector:1.2.0' + implementation 'com.google.cloud.sql:postgres-socket-factory:1.25.0' + permitUnusedDeclared 'com.google.cloud.sql:postgres-socket-factory:1.25.0' testImplementation library.java.junit testImplementation library.java.mockito_core runtimeOnly ("org.xerial:sqlite-jdbc:3.49.1.0") diff --git a/sdks/python/apache_beam/internal/gcp/auth_test.py b/sdks/python/apache_beam/internal/gcp/auth_test.py index 654d8e815a50..fe16acc3c089 100644 --- a/sdks/python/apache_beam/internal/gcp/auth_test.py +++ b/sdks/python/apache_beam/internal/gcp/auth_test.py @@ -25,6 +25,7 @@ try: import google.auth as gauth + import google_auth_httplib2 # pylint: disable=unused-import except ImportError: gauth = None # type: ignore diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py index 935ca690adc9..1bce097b6046 100644 --- a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py +++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py @@ -26,6 +26,7 @@ from google.api_core.retry import Retry +from cachetools.func import ttl_cache from apache_beam import pvalue from apache_beam.metrics import Metrics from apache_beam.options.pipeline_options import GoogleCloudOptions @@ -34,7 +35,6 @@ from apache_beam.transforms import PTransform from apache_beam.transforms.util import GroupIntoBatches from apache_beam.utils import retry -from cachetools.func import ttl_cache # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports try: 
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py index a4654b032a86..229c3e2bd99b 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb.py @@ -14,39 +14,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import logging from dataclasses import dataclass -from dataclasses import field from typing import Any -from typing import Callable from typing import Dict from typing import List -from typing import Literal -from typing import NamedTuple from typing import Optional -from typing import Type -from typing import Union -import apache_beam as beam -from apache_beam.coders import registry -from apache_beam.coders.row_coder import RowCoder -from apache_beam.io.jdbc import WriteToJdbc -from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig -from apache_beam.ml.rag.types import Chunk - -_LOGGER = logging.getLogger(__name__) +from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder +from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution @dataclass class AlloyDBLanguageConnectorConfig: - """Configuration options for AlloyDB Java language connector. + """Configuration options for AlloyDB language connector. Contains all parameters needed to configure a connection using the AlloyDB Java connector via JDBC. For details see https://github.com/GoogleCloudPlatform/alloydb-java-connector/blob/main/docs/jdbc.md Attributes: + username: Database username. + password: Database password. Can be empty string when using IAM. database_name: Name of the database to connect to. instance_name: Fully qualified instance. Format: 'projects/<project>/locations/<location>/clusters/<cluster>/instances @@ -60,7 +52,13 @@ class AlloyDBLanguageConnectorConfig: delegated impersonation. admin_service_endpoint: Optional custom API service endpoint. quota_project: Optional project ID for quota and billing. + connection_properties: Optional JDBC connection properties dict. + Example: {'ssl': 'true'} + additional_properties: Additional properties to be added to the JDBC + url. Example: {'someProperty': 'true'} """ + username: str + password: str database_name: str instance_name: str ip_type: str = "PRIVATE" @@ -69,6 +67,8 @@ class AlloyDBLanguageConnectorConfig: delegates: Optional[List[str]] = None admin_service_endpoint: Optional[str] = None quota_project: Optional[str] = None + connection_properties: Optional[Dict[str, str]] = None + additional_properties: Optional[Dict[str, Any]] = None def to_jdbc_url(self) -> str: """Convert options to a properly formatted JDBC URL. @@ -101,674 +101,42 @@ def to_jdbc_url(self) -> str: if self.quota_project: properties["alloydbQuotaProject"] = self.quota_project + if self.additional_properties: + properties.update(self.additional_properties) + property_string = "&".join(f"{k}={v}" for k, v in properties.items()) return url + property_string - -@dataclass -class AlloyDBConnectionConfig: - """Configuration for AlloyDB database connection. - - Provides connection details and options for connecting to an AlloyDB - instance. - - Attributes: - jdbc_url: JDBC URL for the AlloyDB instance.
- Example: 'jdbc:postgresql://host:port/database' - username: Database username. - password: Database password. - connection_properties: Optional JDBC connection properties dict. - Example: {'ssl': 'true'} - connection_init_sqls: Optional list of SQL statements to execute when - connection is established. - autosharding: Enable automatic re-sharding of bundles to scale the - number of shards with workers. - max_connections: Optional number of connections in the pool. - Use negative for no limit. - write_batch_size: Optional write batch size for bulk operations. - additional_jdbc_args: Additional arguments that will be passed to - WriteToJdbc. These may include 'driver_jars', 'expansion_service', - 'classpath', etc. See full set of args at - :class:`~apache_beam.io.jdbc.WriteToJdbc` - - Example: - >>> config = AlloyDBConnectionConfig( - ... jdbc_url='jdbc:postgresql://localhost:5432/mydb', - ... username='user', - ... password='pass', - ... connection_properties={'ssl': 'true'}, - ... max_connections=10 - ... ) - """ - jdbc_url: str - username: str - password: str - connection_properties: Optional[Dict[str, str]] = None - connection_init_sqls: Optional[List[str]] = None - autosharding: Optional[bool] = None - max_connections: Optional[int] = None - write_batch_size: Optional[int] = None - additional_jdbc_args: Dict[str, Any] = field(default_factory=dict) - - @classmethod - def with_language_connector( - cls, - connector_options: AlloyDBLanguageConnectorConfig, - username: str, - password: str, - connection_properties: Optional[Dict[str, str]] = None, - connection_init_sqls: Optional[List[str]] = None, - autosharding: Optional[bool] = None, - max_connections: Optional[int] = None, - write_batch_size: Optional[int] = None) -> 'AlloyDBConnectionConfig': - """Create AlloyDBConnectionConfig using the AlloyDB language connector. - - Args: - connector_options: AlloyDB language connector configuration options. - username: Database username. For IAM auth, this should be the IAM - user email. - password: Database password. Can be empty string when using IAM - auth. - connection_properties: Additional JDBC connection properties. - connection_init_sqls: SQL statements to execute on connection. - autosharding: Enable autosharding. - max_connections: Max connections in pool. - write_batch_size: Write batch size. - - Returns: - Configured AlloyDBConnectionConfig instance. - - Example: - >>> options = AlloyDBLanguageConnectorConfig( - ... database_name="mydb", - ... instance_name="projects/my-project/locations/us-central1\ - .... /clusters/my-cluster/instances/my-instance", - ... ip_type="PUBLIC", - ... enable_iam_auth=True - ... ) - """ - return cls( - jdbc_url=connector_options.to_jdbc_url(), - username=username, - password=password, - connection_properties=connection_properties, - connection_init_sqls=connection_init_sqls, - autosharding=autosharding, - max_connections=max_connections, - write_batch_size=write_batch_size, - additional_jdbc_args={ - 'classpath': [ - "org.postgresql:postgresql:42.2.16", - "com.google.cloud:alloydb-jdbc-connector:1.2.0" - ] - }) - - -@dataclass -class ConflictResolution: - """Specification for how to handle conflicts during insert. - - Configures conflict handling behavior when inserting records that may - violate unique constraints. - - Attributes: - on_conflict_fields: Field(s) that determine uniqueness. Can be a single - field name or list of field names for composite constraints. - action: How to handle conflicts - either "UPDATE" or "IGNORE". 
- UPDATE: Updates existing record with new values. - IGNORE: Skips conflicting records. - update_fields: Optional list of fields to update on conflict. If None, - all non-conflict fields are updated. - - Examples: - Simple primary key: - >>> ConflictResolution("id") - - Composite key with specific update fields: - >>> ConflictResolution( - ... on_conflict_fields=["source", "timestamp"], - ... action="UPDATE", - ... update_fields=["embedding", "content"] - ... ) - - Ignore conflicts: - >>> ConflictResolution( - ... on_conflict_fields="id", - ... action="IGNORE" - ... ) - """ - on_conflict_fields: Union[str, List[str]] - action: Literal["UPDATE", "IGNORE"] = "UPDATE" - update_fields: Optional[List[str]] = None - - def maybe_set_default_update_fields(self, columns: List[str]): - if self.action != "UPDATE": - return - if self.update_fields is not None: - return - - conflict_fields = ([self.on_conflict_fields] if isinstance( - self.on_conflict_fields, str) else self.on_conflict_fields) - self.update_fields = [col for col in columns if col not in conflict_fields] - - def get_conflict_clause(self) -> str: - """Get conflict clause with update fields.""" - conflict_fields = [self.on_conflict_fields] \ - if isinstance(self.on_conflict_fields, str) \ - else self.on_conflict_fields - - if self.action == "IGNORE": - conflict_fields_string = f"({', '.join(conflict_fields)})" \ - if len(conflict_fields) > 0 else "" - return f"ON CONFLICT {conflict_fields_string} DO NOTHING" - - # update_fields should be set by query builder before this is called - assert self.update_fields is not None, \ - "update_fields must be set before generating conflict clause" - updates = [f"{field} = EXCLUDED.{field}" for field in self.update_fields] - return f"ON CONFLICT " \ - f"({', '.join(conflict_fields)}) DO UPDATE SET {', '.join(updates)}" - - -def chunk_embedding_fn(chunk: Chunk) -> str: - """Convert embedding to PostgreSQL array string. - - Formats dense embedding as a PostgreSQL-compatible array string. - Example: [1.0, 2.0] -> '{1.0,2.0}' - - Args: - chunk: Input Chunk object. - - Returns: - str: PostgreSQL array string representation of the embedding. - - Raises: - ValueError: If chunk has no dense embedding. - """ - if chunk.embedding is None or chunk.embedding.dense_embedding is None: - raise ValueError(f'Expected chunk to contain embedding. {chunk}') - return '{' + ','.join(str(x) for x in chunk.embedding.dense_embedding) + '}' - - -def chunk_content_fn(chunk: Chunk) -> str: - """Extract content text from chunk. - - Args: - chunk: Input Chunk object. - - Returns: - str: The chunk's content text. - """ - if chunk.content.text is None: - raise ValueError(f'Expected chunk to contain content. {chunk}') - return chunk.content.text - - -def chunk_metadata_fn(chunk: Chunk) -> str: - """Extract metadata from chunk as JSON string. - - Args: - chunk: Input Chunk object. - - Returns: - str: JSON string representation of the chunk's metadata. - """ - return json.dumps(chunk.metadata) - - -@dataclass -class ColumnSpec: - """Specification for mapping Chunk fields to SQL columns for insertion. - - Defines how to extract and format values from Chunks into database columns, - handling the full pipeline from Python value to SQL insertion. - - The insertion process works as follows: - - value_fn extracts a value from the Chunk and formats it as needed - - The value is stored in a NamedTuple field with the specified python_type - - During SQL insertion, the value is bound to a ? 
placeholder - - Attributes: - column_name: The column name in the database table. - python_type: Python type for the NamedTuple field that will hold the - value. Must be compatible with must be compatible with - :class:`~apache_beam.coders.row_coder.RowCoder`. - value_fn: Function to extract and format the value from a Chunk. - Takes a Chunk and returns a value of python_type. - sql_typecast: Optional SQL type cast to append to the ? placeholder. - Common examples: - - "::float[]" for vector arrays - - "::jsonb" for JSON data - - Examples: - Basic text column (uses standard JDBC type mapping): - >>> ColumnSpec.text( - ... column_name="content", - ... value_fn=lambda chunk: chunk.content.text - ... ) - # Results in: INSERT INTO table (content) VALUES (?) - - Vector column with explicit array casting: - >>> ColumnSpec.vector( - ... column_name="embedding", - ... value_fn=lambda chunk: '{' + - ... ','.join(map(str, chunk.embedding.dense_embedding)) + '}' - ... ) - # Results in: INSERT INTO table (embedding) VALUES (?::float[]) - # The value_fn formats [1.0, 2.0] as '{1.0,2.0}' for PostgreSQL array - - Timestamp from metadata with explicit casting: - >>> ColumnSpec( - ... column_name="created_at", - ... python_type=str, - ... value_fn=lambda chunk: chunk.metadata.get("timestamp"), - ... sql_typecast="::timestamp" - ... ) - # Results in: INSERT INTO table (created_at) VALUES (?::timestamp) - # Allows inserting string timestamps with proper PostgreSQL casting - - Factory Methods: - text: Creates a text column specification (no type cast). - integer: Creates an integer column specification (no type cast). - float: Creates a float column specification (no type cast). - vector: Creates a vector column specification with float[] casting. - jsonb: Creates a JSONB column specification with jsonb casting. - """ - column_name: str - python_type: Type - value_fn: Callable[[Chunk], Any] - sql_typecast: Optional[str] = None - - @property - def placeholder(self) -> str: - """Get SQL placeholder with optional typecast.""" - return f"?{self.sql_typecast or ''}" - - @classmethod - def text( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': - """Create a text column specification.""" - return cls(column_name, str, value_fn) - - @classmethod - def integer( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': - """Create an integer column specification.""" - return cls(column_name, int, value_fn) - - @classmethod - def float( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': - """Create a float column specification.""" - return cls(column_name, float, value_fn) - - @classmethod - def vector( - cls, - column_name: str, - value_fn: Callable[[Chunk], Any] = chunk_embedding_fn) -> 'ColumnSpec': - """Create a vector column specification.""" - return cls(column_name, str, value_fn, "::float[]") - - @classmethod - def jsonb( - cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': - """Create a JSONB column specification.""" - return cls(column_name, str, value_fn, "::jsonb") - - -MetadataSpec = Union[ColumnSpec, Dict[str, ColumnSpec]] - - -def chunk_id_fn(chunk: Chunk) -> str: - """Extract ID from chunk. - - Args: - chunk: Input Chunk object. - - Returns: - str: The chunk's ID. 
- """ - return chunk.id - - -class _AlloyDBQueryBuilder: - def __init__( - self, - table_name: str, - *, - column_specs: List[ColumnSpec], - conflict_resolution: Optional[ConflictResolution] = None): - """Builds SQL queries for writing Chunks with Embeddings to AlloyDB. - """ - self.table_name = table_name - - self.column_specs = column_specs - self.conflict_resolution = conflict_resolution - - # Validate no duplicate column names - names = [col.column_name for col in self.column_specs] - duplicates = set(name for name in names if names.count(name) > 1) - if duplicates: - raise ValueError(f"Duplicate column names found: {duplicates}") - - # Create NamedTuple type - fields = [(col.column_name, col.python_type) for col in self.column_specs] - type_name = f"VectorRecord_{table_name}" - self.record_type = NamedTuple(type_name, fields) # type: ignore - - # Register coder - registry.register_coder(self.record_type, RowCoder) - - # Set default update fields to all non-conflict fields if update fields are - # not specified - if self.conflict_resolution: - self.conflict_resolution.maybe_set_default_update_fields( - [col.column_name for col in self.column_specs if col.column_name]) - - def build_insert(self) -> str: - """Build INSERT query with proper type casting.""" - # Get column names and placeholders - fields = [col.column_name for col in self.column_specs] - placeholders = [col.placeholder for col in self.column_specs] - - # Build base query - query = f""" - INSERT INTO {self.table_name} - ({', '.join(fields)}) - VALUES ({', '.join(placeholders)}) - """ - - # Add conflict handling if configured - if self.conflict_resolution: - query += f" {self.conflict_resolution.get_conflict_clause()}" - - _LOGGER.info("Query with placeholders %s", query) - return query - - def create_converter(self) -> Callable[[Chunk], NamedTuple]: - """Creates a function to convert Chunks to records.""" - def convert(chunk: Chunk) -> self.record_type: # type: ignore - return self.record_type( - **{col.column_name: col.value_fn(chunk) - for col in self.column_specs}) # type: ignore - - return convert - - -class ColumnSpecsBuilder: - """Builder for :class:`.ColumnSpec`'s with chainable methods.""" - def __init__(self): - self._specs: List[ColumnSpec] = [] - - @staticmethod - def with_defaults() -> 'ColumnSpecsBuilder': - """Add all default column specifications.""" - return ( - ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). - with_content_spec().with_metadata_spec()) - - def with_id_spec( - self, - column_name: str = "id", - python_type: Type = str, - convert_fn: Optional[Callable[[str], Any]] = None, - sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': - """Add ID :class:`.ColumnSpec` with optional type and conversion. - - Args: - column_name: Name for the ID column (defaults to "id") - python_type: Python type for the column (defaults to str) - convert_fn: Optional function to convert the chunk ID - If None, uses ID as-is - sql_typecast: Optional SQL type cast - - Returns: - Self for method chaining - - Example: - >>> builder.with_id_spec( - ... column_name="doc_id", - ... python_type=int, - ... convert_fn=lambda id: int(id.split('_')[1]) - ... 
) - """ - def value_fn(chunk: Chunk) -> Any: - value = chunk.id - return convert_fn(value) if convert_fn else value - - self._specs.append( - ColumnSpec( - column_name=column_name, - python_type=python_type, - value_fn=value_fn, - sql_typecast=sql_typecast)) - return self - - def with_content_spec( - self, - column_name: str = "content", - python_type: Type = str, - convert_fn: Optional[Callable[[str], Any]] = None, - sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': - """Add content :class:`.ColumnSpec` with optional type and conversion. - - Args: - column_name: Name for the content column (defaults to "content") - python_type: Python type for the column (defaults to str) - convert_fn: Optional function to convert the content text - If None, uses content text as-is - sql_typecast: Optional SQL type cast - - Returns: - Self for method chaining - - Example: - >>> builder.with_content_spec( - ... column_name="content_length", - ... python_type=int, - ... convert_fn=len # Store content length instead of content - ... ) - """ - def value_fn(chunk: Chunk) -> Any: - if chunk.content.text is None: - raise ValueError(f'Expected chunk to contain content. {chunk}') - value = chunk.content.text - return convert_fn(value) if convert_fn else value - - self._specs.append( - ColumnSpec( - column_name=column_name, - python_type=python_type, - value_fn=value_fn, - sql_typecast=sql_typecast)) - return self - - def with_metadata_spec( - self, - column_name: str = "metadata", - python_type: Type = str, - convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None, - sql_typecast: Optional[str] = "::jsonb") -> 'ColumnSpecsBuilder': - """Add metadata :class:`.ColumnSpec` with optional type and conversion. - - Args: - column_name: Name for the metadata column (defaults to "metadata") - python_type: Python type for the column (defaults to str) - convert_fn: Optional function to convert the metadata dictionary - If None and python_type is str, converts to JSON string - sql_typecast: Optional SQL type cast (defaults to "::jsonb") - - Returns: - Self for method chaining - - Example: - >>> builder.with_metadata_spec( - ... column_name="meta_tags", - ... python_type=list, - ... convert_fn=lambda meta: list(meta.keys()), - ... sql_typecast="::text[]" - ... ) - """ - def value_fn(chunk: Chunk) -> Any: - if convert_fn: - return convert_fn(chunk.metadata) - return json.dumps( - chunk.metadata) if python_type == str else chunk.metadata - - self._specs.append( - ColumnSpec( - column_name=column_name, - python_type=python_type, - value_fn=value_fn, - sql_typecast=sql_typecast)) - return self - - def with_embedding_spec( - self, - column_name: str = "embedding", - convert_fn: Optional[Callable[[List[float]], Any]] = None - ) -> 'ColumnSpecsBuilder': - """Add embedding :class:`.ColumnSpec` with optional conversion. - - Args: - column_name: Name for the embedding column (defaults to "embedding") - convert_fn: Optional function to convert the dense embedding values - If None, uses default PostgreSQL array format - - Returns: - Self for method chaining - - Example: - >>> builder.with_embedding_spec( - ... column_name="embedding_vector", - ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" - ... for x in values) + '}' - ... ) - """ - def value_fn(chunk: Chunk) -> Any: - if chunk.embedding is None or chunk.embedding.dense_embedding is None: - raise ValueError(f'Expected chunk to contain embedding. 
{chunk}') - values = chunk.embedding.dense_embedding - if convert_fn: - return convert_fn(values) - return '{' + ','.join(str(x) for x in values) + '}' - - self._specs.append( - ColumnSpec.vector(column_name=column_name, value_fn=value_fn)) - return self - - def add_metadata_field( - self, - field: str, - python_type: Type, - column_name: Optional[str] = None, - convert_fn: Optional[Callable[[Any], Any]] = None, - default: Any = None, - sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': - """""Add a :class:`.ColumnSpec` that extracts and converts a field from - chunk metadata. - - Args: - field: Key to extract from chunk metadata - python_type: Python type for the column (e.g. str, int, float) - column_name: Name for the column (defaults to metadata field name) - convert_fn: Optional function to convert the extracted value to - desired type. If None, value is used as-is - default: Default value if field is missing from metadata - sql_typecast: Optional SQL type cast (e.g. "::timestamp") - - Returns: - Self for chaining - - Examples: - Simple string field: - >>> builder.add_metadata_field("source", str) - - Integer with default: - >>> builder.add_metadata_field( - ... field="count", - ... python_type=int, - ... column_name="item_count", - ... default=0 - ... ) - - Float with conversion and default: - >>> builder.add_metadata_field( - ... field="confidence", - ... python_type=intfloat, - ... convert_fn=lambda x: round(float(x), 2), - ... default=0.0 - ... ) - - Timestamp with conversion and type cast: - >>> builder.add_metadata_field( - ... field="created_at", - ... python_type=intstr, - ... convert_fn=lambda ts: ts.replace('T', ' '), - ... sql_typecast="::timestamp" - ... ) - """ - name = column_name or field - - def value_fn(chunk: Chunk) -> Any: - value = chunk.metadata.get(field, default) - if value is not None and convert_fn is not None: - value = convert_fn(value) - return value - - spec = ColumnSpec( - column_name=name, - python_type=python_type, - value_fn=value_fn, - sql_typecast=sql_typecast) - - self._specs.append(spec) - return self - - def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder': - """Add a custom :class:`.ColumnSpec` to the builder. - - Use this method when you need complete control over the :class:`.ColumnSpec` - , including custom value extraction and type handling. - - Args: - spec: A :class:`.ColumnSpec` instance defining the column name, type, - value extraction, and optional SQL type casting. - - Returns: - Self for method chaining - - Examples: - Custom text column from chunk metadata: - >>> builder.add_custom_column_spec( - ... ColumnSpec.text( - ... name="source_and_id", - ... value_fn=lambda chunk: \ - ... f"{chunk.metadata.get('source')}_{chunk.id}" - ... ) - ... 
) - """ - self._specs.append(spec) - return self - - def build(self) -> List[ColumnSpec]: - """Build the final list of column specifications.""" - return self._specs.copy() + def to_connection_config(self): + return ConnectionConfig( + jdbc_url=self.to_jdbc_url(), + username=self.username, + password=self.password, + connection_properties=self.connection_properties, + additional_jdbc_args=self.additional_jdbc_args()) + + def additional_jdbc_args(self) -> Dict[str, List[Any]]: + return { + 'classpath': [ + "org.postgresql:postgresql:42.2.16", + "com.google.cloud:alloydb-jdbc-connector:1.2.0" + ] + } -class AlloyDBVectorWriterConfig(VectorDatabaseWriteConfig): +class AlloyDBVectorWriterConfig(PostgresVectorWriterConfig): def __init__( self, - connection_config: AlloyDBConnectionConfig, + connection_config: AlloyDBLanguageConnectorConfig, table_name: str, *, # pylint: disable=dangerous-default-value + write_config: WriteConfig = WriteConfig(), column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build( ), conflict_resolution: Optional[ConflictResolution] = ConflictResolution( on_conflict_fields=[], action='IGNORE')): - """Configuration for writing vectors to AlloyDB using managed transforms. + """Configuration for writing vectors to AlloyDB. Supports flexible schema configuration through column specifications and conflict resolution strategies. @@ -776,68 +144,61 @@ def __init__( Args: connection_config: AlloyDB connection configuration. table_name: Target table name. - column_specs: Column specifications. If None, uses default Chunk schema. - Use ColumnSpecsBuilder to construct the specifications. - conflict_resolution: Optional strategy for handling insert conflicts. - ON CONFLICT DO NOTHING by default. + write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control + batch sizes, authosharding, etc. + column_specs: + Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how + embeddings and metadata are written a database + schema. If None, uses default Chunk schema. + conflict_resolution: Optional + :class:`~.postgres_common.ConflictResolution` + strategy for handling insert conflicts. ON CONFLICT DO NOTHING by + default. Examples: Basic usage with default schema: + >>> config = AlloyDBVectorWriterConfig( ... connection_config=AlloyDBConnectionConfig(...), ... table_name='embeddings' ... ) + Simple case with default schema: + + >>> config = PostgresVectorWriterConfig( + ... connection_config=ConnectionConfig(...), + ... table_name='embeddings' + ... ) + Custom schema with metadata fields: + >>> specs = (ColumnSpecsBuilder() - ... .with_id_spec() + ... .with_id_spec(column_name="my_id_column") ... .with_embedding_spec(column_name="embedding_vec") - ... .add_metadata_field("source") + ... .add_metadata_field(field="source", column_name="src") ... .add_metadata_field( ... "timestamp", ... column_name="created_at", ... sql_typecast="::timestamp" ... ) ... .build()) + + Minimal schema (only ID + embedding written) + + >>> column_specs = (ColumnSpecsBuilder() + ... .with_id_spec() + ... .with_embedding_spec() + ... .build()) + >>> config = AlloyDBVectorWriterConfig( ... connection_config=AlloyDBConnectionConfig(...), ... table_name='embeddings', ... column_specs=specs ... 
) """ - self.connection_config = connection_config - # NamedTuple is created and registered here during pipeline construction - self.query_builder = _AlloyDBQueryBuilder( - table_name, + super().__init__( + connection_config=connection_config.to_connection_config(), + write_config=write_config, + table_name=table_name, column_specs=column_specs, conflict_resolution=conflict_resolution) - - def create_write_transform(self) -> beam.PTransform: - return _WriteToAlloyDBVectorDatabase(self) - - -class _WriteToAlloyDBVectorDatabase(beam.PTransform): - """Implementation of BigQuery vector database write. """ - def __init__(self, config: AlloyDBVectorWriterConfig): - self.config = config - - def expand(self, pcoll: beam.PCollection[Chunk]): - return ( - pcoll - | "Convert to Records" >> beam.Map( - self.config.query_builder.create_converter()) - | "Write to AlloyDB" >> WriteToJdbc( - table_name=self.config.query_builder.table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.config.connection_config.jdbc_url, - username=self.config.connection_config.username, - password=self.config.connection_config.password, - statement=self.config.query_builder.build_insert(), - connection_properties=self.config.connection_config. - connection_properties, - connection_init_sqls=self.config.connection_config. - connection_init_sqls, - autosharding=self.config.connection_config.autosharding, - max_connections=self.config.connection_config.max_connections, - write_batch_size=self.config.connection_config.write_batch_size, - **self.config.connection_config.additional_jdbc_args)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index 6939d09e2bad..cc7db95a1d07 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -15,137 +15,26 @@ # limitations under the License. 
# -import hashlib -import json import logging import os import secrets import time import unittest -from typing import List -from typing import NamedTuple import psycopg2 import pytest import apache_beam as beam -from apache_beam.coders import registry -from apache_beam.coders.row_coder import RowCoder from apache_beam.io.jdbc import ReadFromJdbc -from apache_beam.ml.rag.ingestion.alloydb import AlloyDBConnectionConfig +from apache_beam.ml.rag.ingestion import test_utils from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig -from apache_beam.ml.rag.ingestion.alloydb import ColumnSpec -from apache_beam.ml.rag.ingestion.alloydb import ColumnSpecsBuilder -from apache_beam.ml.rag.ingestion.alloydb import ConflictResolution -from apache_beam.ml.rag.ingestion.alloydb import chunk_embedding_fn -from apache_beam.ml.rag.types import Chunk -from apache_beam.ml.rag.types import Content -from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to -TestRow = NamedTuple( - 'TestRow', - [('id', str), ('embedding', List[float]), ('content', str), - ('metadata', str)]) -registry.register_coder(TestRow, RowCoder) - -CustomSpecsRow = NamedTuple( - 'CustomSpecsRow', - [ - ('custom_id', str), # For id_spec test - ('embedding_vec', List[float]), # For embedding_spec test - ('content_col', str), # For content_spec test - ('metadata', str) - ]) -registry.register_coder(CustomSpecsRow, RowCoder) - -MetadataConflictRow = NamedTuple( - 'MetadataConflictRow', - [ - ('id', str), - ('source', str), # For metadata_spec and composite key - ('timestamp', str), # For metadata_spec and composite key - ('content', str), - ('embedding', List[float]), - ('metadata', str) - ]) -registry.register_coder(MetadataConflictRow, RowCoder) - _LOGGER = logging.getLogger(__name__) -VECTOR_SIZE = 768 - - -def row_to_chunk(row) -> Chunk: - # Parse embedding string back to float list - embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')] - return Chunk( - id=row.id, - content=Content(text=row.content if hasattr(row, 'content') else None), - embedding=Embedding(dense_embedding=embedding_list), - metadata=json.loads(row.metadata) if hasattr(row, 'metadata') else {}) - - -class ChunkTestUtils: - """Helper functions for generating test Chunks.""" - @staticmethod - def from_seed(seed: int, content_prefix: str, seed_multiplier: int) -> Chunk: - """Creates a deterministic Chunk from a seed value.""" - return Chunk( - id=f"id_{seed}", - content=Content(text=f"{content_prefix}{seed}"), - embedding=Embedding( - dense_embedding=[ - float(seed + i * seed_multiplier) / 100 - for i in range(VECTOR_SIZE) - ]), - metadata={"seed": str(seed)}) - - @staticmethod - def get_expected_values( - range_start: int, - range_end: int, - content_prefix: str = "Testval", - seed_multiplier: int = 1) -> List[Chunk]: - """Returns a range of test Chunks.""" - return [ - ChunkTestUtils.from_seed(i, content_prefix, seed_multiplier) - for i in range(range_start, range_end) - ] - - -class HashingFn(beam.CombineFn): - """Hashing function for verification.""" - def create_accumulator(self): - return [] - - def add_input(self, accumulator, input): - # Hash based on content like TestRow's SelectNameFn - accumulator.append(input.content.text if 
input.content.text else "") - return accumulator - - def merge_accumulators(self, accumulators): - merged = [] - for acc in accumulators: - merged.extend(acc) - return merged - - def extract_output(self, accumulator): - sorted_values = sorted(accumulator) - return hashlib.md5(''.join(sorted_values).encode()).hexdigest() - - -def generate_expected_hash(num_records: int) -> str: - chunks = ChunkTestUtils.get_expected_values(0, num_records) - values = sorted( - chunk.content.text if chunk.content.text else "" for chunk in chunks) - return hashlib.md5(''.join(values).encode()).hexdigest() - - -def key_on_id(chunk): - return (int(chunk.id.split('_')[1]), chunk) @pytest.mark.uses_gcp_java_expansion_service @@ -156,8 +45,8 @@ def key_on_id(chunk): @unittest.skipUnless( os.environ.get('ALLOYDB_PASSWORD'), "ALLOYDB_PASSWORD environment var is not provided") -class AlloyDBVectorWriterConfigTest(unittest.TestCase): - ALLOYDB_TABLE_PREFIX = 'python_rag_alloydb_' +class AlloydbVectorWriterConfigTest(unittest.TestCase): + POSTGRES_TABLE_PREFIX = 'python_rag_postgres_' @classmethod def setUpClass(cls): @@ -189,18 +78,11 @@ def skip_if_dataflow_runner(self): def setUp(self): self.write_test_pipeline = TestPipeline(is_integration_test=True) self.read_test_pipeline = TestPipeline(is_integration_test=True) - self.write_test_pipeline2 = TestPipeline(is_integration_test=True) - self.read_test_pipeline2 = TestPipeline(is_integration_test=True) self._runner = type(self.read_test_pipeline.runner).__name__ - self.default_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ + self.default_table_name = "default_embeddings" + f"{self.POSTGRES_TABLE_PREFIX}" \ f"{self.table_suffix}" - self.default_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ - f"{self.table_suffix}" - self.custom_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ - f"_custom_{self.table_suffix}" - self.metadata_conflicts_table = f"{self.ALLOYDB_TABLE_PREFIX}" \ - f"_meta_conf_{self.table_suffix}" self.jdbc_url = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}' @@ -210,32 +92,11 @@ def setUp(self): f""" CREATE TABLE {self.default_table_name} ( id TEXT PRIMARY KEY, - embedding VECTOR({VECTOR_SIZE}), + embedding VECTOR({test_utils.VECTOR_SIZE}), content TEXT, metadata JSONB ) """) - cursor.execute( - f""" - CREATE TABLE {self.custom_table_name} ( - custom_id TEXT PRIMARY KEY, - embedding_vec VECTOR(2), - content_col TEXT, - metadata JSONB - ) - """) - cursor.execute( - f""" - CREATE TABLE {self.metadata_conflicts_table} ( - id TEXT, - source TEXT, - timestamp TIMESTAMP, - content TEXT, - embedding VECTOR(2), - PRIMARY KEY (id), - UNIQUE (source, timestamp) - ) - """) _LOGGER = logging.getLogger(__name__) _LOGGER.info("Created table %s", self.default_table_name) @@ -243,8 +104,6 @@ def tearDown(self): # Drop test table with self.conn.cursor() as cursor: cursor.execute(f"DROP TABLE IF EXISTS {self.default_table_name}") - cursor.execute(f"DROP TABLE IF EXISTS {self.custom_table_name}") - cursor.execute(f"DROP TABLE IF EXISTS {self.metadata_conflicts_table}") _LOGGER = logging.getLogger(__name__) _LOGGER.info("Dropped table %s", self.default_table_name) @@ -253,111 +112,32 @@ def tearDownClass(cls): if hasattr(cls, 'conn'): cls.conn.close() - def test_default_schema(self): - """Test basic write with default schema.""" - jdbc_url = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}' - connection_config = AlloyDBConnectionConfig( - jdbc_url=jdbc_url, username=self.username, password=self.password) - - config = AlloyDBVectorWriterConfig( - 
connection_config=connection_config, table_name=self.default_table_name) - - # Create test chunks - num_records = 1500 - sample_size = min(500, num_records // 2) - # Generate test chunks - chunks = ChunkTestUtils.get_expected_values(0, num_records) - - # Run pipeline and verify - self.write_test_pipeline.not_use_test_runner_api = True - - with self.write_test_pipeline as p: - _ = (p | beam.Create(chunks) | config.create_write_transform()) - - self.read_test_pipeline.not_use_test_runner_api = True - # Read pipeline to verify - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(content AS VARCHAR(255)), - CAST(embedding AS text), - CAST(metadata AS text) - FROM {self.default_table_name} - """ - - # Read and verify pipeline - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - count_result = rows | "Count All" >> beam.combiners.Count.Globally() - assert_that(count_result, equal_to([num_records]), label='count_check') - - chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk)) - chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(HashingFn()) - assert_that( - chunk_hashes, - equal_to([generate_expected_hash(num_records)]), - label='hash_check') - - # Sample validation - first_n = ( - chunks - | "Key on Index" >> beam.Map(key_on_id) - | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of( - sample_size, key=lambda x: x[0], reverse=True) - | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) - expected_first_n = ChunkTestUtils.get_expected_values(0, sample_size) - assert_that( - first_n, - equal_to([expected_first_n]), - label=f"first_{sample_size}_check") - - last_n = ( - chunks - | "Key on Index 2" >> beam.Map(key_on_id) - | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of( - sample_size, key=lambda x: x[0]) - | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) - expected_last_n = ChunkTestUtils.get_expected_values( - num_records - sample_size, num_records)[::-1] - assert_that( - last_n, - equal_to([expected_last_n]), - label=f"last_{sample_size}_check") - def test_language_connector(self): """Test language connector.""" self.skip_if_dataflow_runner() - connector_options = AlloyDBLanguageConnectorConfig( + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, database_name=self.database, instance_name="projects/apache-beam-testing/locations/us-central1/\ clusters/testing-psc/instances/testing-psc-1", ip_type="PSC") - connection_config = AlloyDBConnectionConfig.with_language_connector( - connector_options=connector_options, - username=self.username, - password=self.password) - config = AlloyDBVectorWriterConfig( + writer_config = AlloyDBVectorWriterConfig( connection_config=connection_config, table_name=self.default_table_name) # Create test chunks num_records = 150 sample_size = min(500, num_records // 2) - chunks = ChunkTestUtils.get_expected_values(0, num_records) + chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) self.write_test_pipeline.not_use_test_runner_api = True with self.write_test_pipeline as p: - _ = (p | beam.Create(chunks) | config.create_write_transform()) + _ = ( + p + | beam.Create(chunks) + | VectorDatabaseWriteTransform(writer_config)) self.read_test_pipeline.not_use_test_runner_api = True read_query = f""" @@ -375,33 +155,32 @@ def 
test_language_connector(self): | ReadFromJdbc( table_name=self.default_table_name, driver_class_name="org.postgresql.Driver", - jdbc_url=connector_options.to_jdbc_url(), + jdbc_url=connection_config.to_connection_config().jdbc_url, username=self.username, password=self.password, query=read_query, - classpath=[ - "org.postgresql:postgresql:42.2.16", - "com.google.cloud:alloydb-jdbc-connector:1.2.0" - ])) + classpath=connection_config.additional_jdbc_args()['classpath'])) count_result = rows | "Count All" >> beam.combiners.Count.Globally() assert_that(count_result, equal_to([num_records]), label='count_check') - chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk)) - chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally(HashingFn()) + chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)) + chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally( + test_utils.HashingFn()) assert_that( chunk_hashes, - equal_to([generate_expected_hash(num_records)]), + equal_to([test_utils.generate_expected_hash(num_records)]), label='hash_check') # Sample validation first_n = ( chunks - | "Key on Index" >> beam.Map(key_on_id) + | "Key on Index" >> beam.Map(test_utils.key_on_id) | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of( sample_size, key=lambda x: x[0], reverse=True) | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) - expected_first_n = ChunkTestUtils.get_expected_values(0, sample_size) + expected_first_n = test_utils.ChunkTestUtils.get_expected_values( + 0, sample_size) assert_that( first_n, equal_to([expected_first_n]), @@ -409,654 +188,17 @@ def test_language_connector(self): last_n = ( chunks - | "Key on Index 2" >> beam.Map(key_on_id) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of( sample_size, key=lambda x: x[0]) | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) - expected_last_n = ChunkTestUtils.get_expected_values( + expected_last_n = test_utils.ChunkTestUtils.get_expected_values( num_records - sample_size, num_records)[::-1] assert_that( last_n, equal_to([expected_last_n]), label=f"last_{sample_size}_check") - def test_custom_specs(self): - """Test custom specifications for ID, embedding, and content.""" - self.skip_if_dataflow_runner() - num_records = 20 - - specs = ( - ColumnSpecsBuilder().add_custom_column_spec( - ColumnSpec.text( - column_name="custom_id", - value_fn=lambda chunk: - f"timestamp_{chunk.metadata.get('timestamp', '')}") - ).add_custom_column_spec( - ColumnSpec.vector( - column_name="embedding_vec", - value_fn=chunk_embedding_fn)).add_custom_column_spec( - ColumnSpec.text( - column_name="content_col", - value_fn=lambda chunk: - f"{len(chunk.content.text)}:{chunk.content.text}")). 
- with_metadata_spec().build()) - - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - - writer_config = AlloyDBVectorWriterConfig( - connection_config=connection_config, - table_name=self.custom_table_name, - column_specs=specs) - - # Generate test chunks - test_chunks = [ - Chunk( - id=str(i), - content=Content(text=f"content_{i}"), - embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), - metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) - for i in range(num_records) - ] - - # Write pipeline - self.write_test_pipeline.not_use_test_runner_api = True - with self.write_test_pipeline as p: - _ = ( - p | beam.Create(test_chunks) | writer_config.create_write_transform()) - - # Read and verify - read_query = f""" - SELECT - CAST(custom_id AS VARCHAR(255)), - CAST(embedding_vec AS text), - CAST(content_col AS VARCHAR(255)), - CAST(metadata AS text) - FROM {self.custom_table_name} - ORDER BY custom_id - """ - - # Convert BeamRow back to Chunk - def custom_row_to_chunk(row): - # Extract timestamp from custom_id - timestamp = row.custom_id.split('timestamp_')[1] - # Extract index from timestamp - i = int(timestamp.split('T')[1][:2]) - - # Parse embedding vector - embedding_list = [ - float(x) for x in row.embedding_vec.strip('[]').split(',') - ] - - # Extract content from length-prefixed format - content = row.content_col.split(':', 1)[1] - - return Chunk( - id=str(i), - content=Content(text=content), - embedding=Embedding(dense_embedding=embedding_list), - metadata=json.loads(row.metadata)) - - self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.custom_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - # Verify count - count_result = rows | "Count All" >> beam.combiners.Count.Globally() - assert_that(count_result, equal_to([num_records]), label='count_check') - - chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) - assert_that(chunks, equal_to(test_chunks), label='chunks_check') - - def test_defaults_with_args_specs(self): - """Test custom specifications for ID, embedding, and content.""" - self.skip_if_dataflow_runner() - num_records = 20 - - specs = ( - ColumnSpecsBuilder().with_id_spec( - column_name="custom_id", - python_type=int, - convert_fn=lambda x: int(x), - sql_typecast="::text").with_content_spec( - column_name="content_col", - convert_fn=lambda x: f"{len(x)}:{x}", - ).with_embedding_spec( - column_name="embedding_vec").with_metadata_spec().build()) - - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - - writer_config = AlloyDBVectorWriterConfig( - connection_config=connection_config, - table_name=self.custom_table_name, - column_specs=specs) - - # Generate test chunks - test_chunks = [ - Chunk( - id=str(i), - content=Content(text=f"content_{i}"), - embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), - metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) - for i in range(num_records) - ] - - # Write pipeline - self.write_test_pipeline.not_use_test_runner_api = True - with self.write_test_pipeline as p: - _ = ( - p | beam.Create(test_chunks) | writer_config.create_write_transform()) - - # Read and verify - read_query = f""" - SELECT - CAST(custom_id AS VARCHAR(255)), - CAST(embedding_vec AS text), - CAST(content_col AS 
VARCHAR(255)), - CAST(metadata AS text) - FROM {self.custom_table_name} - ORDER BY custom_id - """ - - # Convert BeamRow back to Chunk - def custom_row_to_chunk(row): - # Parse embedding vector - embedding_list = [ - float(x) for x in row.embedding_vec.strip('[]').split(',') - ] - - # Extract content from length-prefixed format - content = row.content_col.split(':', 1)[1] - - return Chunk( - id=row.custom_id, - content=Content(text=content), - embedding=Embedding(dense_embedding=embedding_list), - metadata=json.loads(row.metadata)) - - self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.custom_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - # Verify count - count_result = rows | "Count All" >> beam.combiners.Count.Globally() - assert_that(count_result, equal_to([num_records]), label='count_check') - - chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) - assert_that(chunks, equal_to(test_chunks), label='chunks_check') - - def test_default_id_embedding_specs(self): - """Test with only default id and embedding specs, others set to None.""" - self.skip_if_dataflow_runner() - num_records = 20 - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - specs = ( - ColumnSpecsBuilder().with_id_spec() # Use default id spec - .with_embedding_spec() # Use default embedding spec - .build()) - - writer_config = AlloyDBVectorWriterConfig( - connection_config=connection_config, - table_name=self.default_table_name, - column_specs=specs) - - # Generate test chunks - test_chunks = ChunkTestUtils.get_expected_values(0, num_records) - - # Write pipeline - self.write_test_pipeline.not_use_test_runner_api = True - with self.write_test_pipeline as p: - _ = ( - p | beam.Create(test_chunks) | writer_config.create_write_transform()) - - # Read and verify only id and embedding - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(embedding AS text) - FROM {self.default_table_name} - ORDER BY id - """ - - self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = rows | "To Chunks" >> beam.Map(row_to_chunk) - - # Create expected chunks with None values - expected_chunks = ChunkTestUtils.get_expected_values(0, num_records) - for chunk in expected_chunks: - chunk.content.text = None - chunk.metadata = {} - - assert_that(chunks, equal_to(expected_chunks), label='chunks_check') - - def test_metadata_spec_and_conflicts(self): - """Test metadata specification and conflict resolution.""" - self.skip_if_dataflow_runner() - num_records = 20 - - specs = ( - ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). 
- with_content_spec().add_metadata_field( - field="source", - column_name="source", - python_type=str, - sql_typecast=None # Plain text field - ).add_metadata_field( - field="timestamp", python_type=str, - sql_typecast="::timestamp").build()) - - # Conflict resolution on source+timestamp - conflict_resolution = ConflictResolution( - on_conflict_fields=["source", "timestamp"], - action="UPDATE", - update_fields=["embedding", "content"]) - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - writer_config = AlloyDBVectorWriterConfig( - connection_config=connection_config, - table_name=self.metadata_conflicts_table, - column_specs=specs, - conflict_resolution=conflict_resolution) - - # Generate initial test chunks - initial_chunks = [ - Chunk( - id=str(i), - content=Content(text=f"content_{i}"), - embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), - metadata={ - "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" - }) for i in range(num_records) - ] - - # Write initial chunks - self.write_test_pipeline.not_use_test_runner_api = True - with self.write_test_pipeline as p: - _ = ( - p | "Write Initial" >> beam.Create(initial_chunks) - | writer_config.create_write_transform()) - - # Generate conflicting chunks (same source+timestamp, different content) - conflicting_chunks = [ - Chunk( - id=f"new_{i}", - content=Content(text=f"updated_content_{i}"), - embedding=Embedding( - dense_embedding=[float(i) * 2, float(i + 1) * 2]), - metadata={ - "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" - }) for i in range(num_records) - ] - - # Write conflicting chunks - self.write_test_pipeline2.not_use_test_runner_api = True - with self.write_test_pipeline2 as p: - _ = ( - p | "Write Conflicts" >> beam.Create(conflicting_chunks) - | writer_config.create_write_transform()) - - # Read and verify - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(embedding AS text), - CAST(content AS VARCHAR(255)), - CAST(source AS VARCHAR(255)), - CAST(timestamp AS VARCHAR(255)) - FROM {self.metadata_conflicts_table} - ORDER BY timestamp, id - """ - - # Expected chunks after conflict resolution - expected_chunks = [ - Chunk( - id=str(i), - content=Content(text=f"updated_content_{i}"), - embedding=Embedding( - dense_embedding=[float(i) * 2, float(i + 1) * 2]), - metadata={ - "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" - }) for i in range(num_records) - ] - - def metadata_row_to_chunk(row): - return Chunk( - id=row.id, - content=Content(text=row.content), - embedding=Embedding( - dense_embedding=[ - float(x) for x in row.embedding.strip('[]').split(',') - ]), - metadata={ - "source": row.source, - "timestamp": row.timestamp.replace(' ', 'T') - }) - - self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.metadata_conflicts_table, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk) - assert_that(chunks, equal_to(expected_chunks), label='chunks_check') - - def test_conflict_resolution_update(self): - """Test conflict resolution with UPDATE action.""" - self.skip_if_dataflow_runner() - num_records = 20 - - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - - conflict_resolution = 
ConflictResolution( - on_conflict_fields="id", - action="UPDATE", - update_fields=["embedding", "content"]) - - config = AlloyDBVectorWriterConfig( - connection_config=connection_config, - table_name=self.default_table_name, - conflict_resolution=conflict_resolution) - - # Generate initial test chunks - test_chunks = ChunkTestUtils.get_expected_values(0, num_records) - self.write_test_pipeline.not_use_test_runner_api = True - # Insert initial test chunks - with self.write_test_pipeline as p: - _ = ( - p - | "Create initial chunks" >> beam.Create(test_chunks) - | "Write initial chunks" >> config.create_write_transform()) - - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(content AS VARCHAR(255)), - CAST(embedding AS text), - CAST(metadata AS text) - FROM {self.default_table_name} - ORDER BY id desc - """ - self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = ( - rows - | "To Chunks" >> beam.Map(row_to_chunk) - | "Key on Index" >> beam.Map(key_on_id) - | "Get First 500" >> beam.transforms.combiners.Top.Of( - num_records, key=lambda x: x[0], reverse=True) - | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) - assert_that( - chunks, equal_to([test_chunks]), label='original_chunks_check') - - updated_chunks = ChunkTestUtils.get_expected_values( - 0, num_records, content_prefix="Newcontent", seed_multiplier=2) - self.write_test_pipeline2.not_use_test_runner_api = True - with self.write_test_pipeline2 as p: - _ = ( - p - | "Create updated Chunks" >> beam.Create(updated_chunks) - | "Write updated Chunks" >> config.create_write_transform()) - self.read_test_pipeline2.not_use_test_runner_api = True - with self.read_test_pipeline2 as p: - rows = ( - p - | "Read Updated chunks" >> ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = ( - rows - | "To Chunks 2" >> beam.Map(row_to_chunk) - | "Key on Index 2" >> beam.Map(key_on_id) - | "Get First 500 2" >> beam.transforms.combiners.Top.Of( - num_records, key=lambda x: x[0], reverse=True) - | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) - assert_that( - chunks, equal_to([updated_chunks]), label='updated_chunks_check') - - def test_conflict_resolution_default_ignore(self): - """Test conflict resolution with default.""" - self.skip_if_dataflow_runner() - num_records = 20 - - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - - config = AlloyDBVectorWriterConfig( - connection_config=connection_config, table_name=self.default_table_name) - - # Generate initial test chunks - test_chunks = ChunkTestUtils.get_expected_values(0, num_records) - self.write_test_pipeline.not_use_test_runner_api = True - # Insert initial test chunks - with self.write_test_pipeline as p: - _ = ( - p - | "Create initial chunks" >> beam.Create(test_chunks) - | "Write initial chunks" >> config.create_write_transform()) - - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(content AS VARCHAR(255)), - CAST(embedding AS text), - CAST(metadata AS text) - FROM {self.default_table_name} - ORDER BY id desc - """ - 
self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = ( - rows - | "To Chunks" >> beam.Map(row_to_chunk) - | "Key on Index" >> beam.Map(key_on_id) - | "Get First 500" >> beam.transforms.combiners.Top.Of( - num_records, key=lambda x: x[0], reverse=True) - | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) - assert_that( - chunks, equal_to([test_chunks]), label='original_chunks_check') - - updated_chunks = ChunkTestUtils.get_expected_values( - 0, num_records, content_prefix="Newcontent", seed_multiplier=2) - self.write_test_pipeline2.not_use_test_runner_api = True - with self.write_test_pipeline2 as p: - _ = ( - p - | "Create updated Chunks" >> beam.Create(updated_chunks) - | "Write updated Chunks" >> config.create_write_transform()) - self.read_test_pipeline2.not_use_test_runner_api = True - with self.read_test_pipeline2 as p: - rows = ( - p - | "Read Updated chunks" >> ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = ( - rows - | "To Chunks 2" >> beam.Map(row_to_chunk) - | "Key on Index 2" >> beam.Map(key_on_id) - | "Get First 500 2" >> beam.transforms.combiners.Top.Of( - num_records, key=lambda x: x[0], reverse=True) - | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) - assert_that(chunks, equal_to([test_chunks]), label='updated_chunks_check') - - def test_conflict_resolution_default_update_fields(self): - """Test conflict resolution with default update fields (all non-conflict - fields).""" - self.skip_if_dataflow_runner() - num_records = 20 - - connection_config = AlloyDBConnectionConfig( - jdbc_url=self.jdbc_url, username=self.username, password=self.password) - - # Create a conflict resolution with only the conflict field specified - # No update_fields specified - should default to all non-conflict fields - conflict_resolution = ConflictResolution( - on_conflict_fields="id", action="UPDATE") - - config = AlloyDBVectorWriterConfig( - connection_config=connection_config, - table_name=self.default_table_name, - conflict_resolution=conflict_resolution) - - # Generate initial test chunks - test_chunks = ChunkTestUtils.get_expected_values(0, num_records) - self.write_test_pipeline.not_use_test_runner_api = True - - # Insert initial test chunks - with self.write_test_pipeline as p: - _ = ( - p - | "Create initial chunks" >> beam.Create(test_chunks) - | "Write initial chunks" >> config.create_write_transform()) - - # Verify initial data was written correctly - read_query = f""" - SELECT - CAST(id AS VARCHAR(255)), - CAST(content AS VARCHAR(255)), - CAST(embedding AS text), - CAST(metadata AS text) - FROM {self.default_table_name} - ORDER BY id desc - """ - self.read_test_pipeline.not_use_test_runner_api = True - with self.read_test_pipeline as p: - rows = ( - p - | ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = ( - rows - | "To Chunks" >> beam.Map(row_to_chunk) - | "Key on Index" >> beam.Map(key_on_id) - | "Get First 500" >> beam.transforms.combiners.Top.Of( - num_records, key=lambda x: x[0], 
reverse=True) - | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) - assert_that( - chunks, equal_to([test_chunks]), label='original_chunks_check') - - # Create updated chunks with same IDs but different content, embedding, and - # metadata - updated_chunks = [] - for i in range(num_records): - original_chunk = test_chunks[i] - updated_chunk = Chunk( - id=original_chunk.id, - content=Content(text=f"Updated content {i}"), - embedding=Embedding( - dense_embedding=[float(i * 2), float(i * 2 + 1)] + [0.0] * - (VECTOR_SIZE - 2)), - metadata={ - "updated": "true", "timestamp": "2024-02-25" - }) - updated_chunks.append(updated_chunk) - - # Write updated chunks - should update all non-conflict fields - self.write_test_pipeline2.not_use_test_runner_api = True - with self.write_test_pipeline2 as p: - _ = ( - p - | "Create updated Chunks" >> beam.Create(updated_chunks) - | "Write updated Chunks" >> config.create_write_transform()) - - # Read and verify that all non-conflict fields were updated - self.read_test_pipeline2.not_use_test_runner_api = True - with self.read_test_pipeline2 as p: - rows = ( - p - | "Read Updated chunks" >> ReadFromJdbc( - table_name=self.default_table_name, - driver_class_name="org.postgresql.Driver", - jdbc_url=self.jdbc_url, - username=self.username, - password=self.password, - query=read_query)) - - chunks = ( - rows - | "To Chunks 2" >> beam.Map(row_to_chunk) - | "Key on Index 2" >> beam.Map(key_on_id) - | "Get First 500 2" >> beam.transforms.combiners.Top.Of( - num_records, key=lambda x: x[0], reverse=True) - | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) - - # Verify that all non-conflict fields were updated - assert_that( - chunks, equal_to([updated_chunks]), label='updated_chunks_check') - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py new file mode 100644 index 000000000000..69ead961a763 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql.py @@ -0,0 +1,220 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
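For orientation, a minimal sketch of how the writer added in this module is wired into a pipeline; the instance, database, and table names below are placeholders, and the full version appears in the integration test later in this patch:

>>> from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform
>>> from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLPostgresVectorWriterConfig
>>> from apache_beam.ml.rag.ingestion.cloudsql import LanguageConnectorConfig
>>> connection = LanguageConnectorConfig(
...     username='postgres',
...     password='<password>',
...     database_name='postgres',
...     instance_name='my-project:us-central1:my-instance')
>>> writer_config = CloudSQLPostgresVectorWriterConfig(
...     connection_config=connection, table_name='embeddings')
>>> # Inside a pipeline: chunks | VectorDatabaseWriteTransform(writer_config)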
+
+from dataclasses import asdict
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+
+from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig
+from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
+from apache_beam.ml.rag.ingestion.postgres import ColumnSpecsBuilder
+from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig
+from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
+from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
+
+
+@dataclass
+class LanguageConnectorConfig:
+  """Configuration options for the CloudSQL Java language connector.
+
+  Sets the parameters used to connect to a CloudSQL instance through the
+  Java language connector. For details see
+  https://github.com/GoogleCloudPlatform/cloud-sql-jdbc-socket-factory/blob/main/docs/jdbc.md
+
+  Attributes:
+    username: Database username.
+    password: Database password. Can be an empty string when using IAM.
+    database_name: Name of the database to connect to.
+    instance_name: Instance connection name. Format:
+      '<project>:<region>:<instance>'
+    ip_types: Preferred order of IP types used to connect, given as a list
+      of strings.
+    enable_iam_auth: Whether to enable IAM authentication. Defaults to False.
+    target_principal: Optional service account to impersonate for the
+      connection.
+    delegates: Optional list of service accounts for delegated
+      impersonation.
+    quota_project: Optional project ID for quota and billing.
+    connection_properties: Optional JDBC connection properties dict.
+      Example: {'ssl': 'true'}
+    additional_properties: Additional properties to be added to the JDBC
+      url. Example: {'someProperty': 'true'}
+  """
+  username: str
+  password: str
+  database_name: str
+  instance_name: str
+  ip_types: Optional[List[str]] = None
+  enable_iam_auth: bool = False
+  target_principal: Optional[str] = None
+  delegates: Optional[List[str]] = None
+  quota_project: Optional[str] = None
+  connection_properties: Optional[Dict[str, str]] = None
+  additional_properties: Optional[Dict[str, Any]] = None
+
+  def _base_jdbc_properties(self) -> Dict[str, Any]:
+    properties = {"cloudSqlInstance": self.instance_name}
+
+    if self.ip_types:
+      properties["ipTypes"] = ",".join(self.ip_types)
+
+    if self.enable_iam_auth:
+      properties["enableIamAuth"] = "true"
+
+    if self.target_principal:
+      properties["cloudSqlTargetPrincipal"] = self.target_principal
+
+    if self.delegates:
+      properties["cloudSqlDelegates"] = ",".join(self.delegates)
+
+    if self.quota_project:
+      properties["cloudSqlAdminQuotaProject"] = self.quota_project
+
+    if self.additional_properties:
+      properties.update(self.additional_properties)
+
+    return properties
+
+  def _build_jdbc_url(self, socketFactory, database_type):
+    url = f"jdbc:{database_type}:///{self.database_name}?"
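+    # Illustrative example (all values are placeholders): with
+    # database_name='postgres', instance_name='project:region:instance' and
+    # the Postgres socket factory, the finished URL is roughly
+    #   jdbc:postgresql:///postgres?cloudSqlInstance=project:region:instance
+    #   &socketFactory=com.google.cloud.sql.postgres.SocketFactory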
+
+    properties = self._base_jdbc_properties()
+    properties['socketFactory'] = socketFactory
+
+    property_string = "&".join(f"{k}={v}" for k, v in properties.items())
+    return url + property_string
+
+  def to_connection_config(self):
+    return ConnectionConfig(
+        jdbc_url=self.to_jdbc_url(),
+        username=self.username,
+        password=self.password,
+        connection_properties=self.connection_properties,
+        additional_jdbc_args=self.additional_jdbc_args())
+
+  def additional_jdbc_args(self) -> Dict[str, List[Any]]:
+    return {}
+
+
+@dataclass
+class _PostgresConnectorConfig(LanguageConnectorConfig):
+  def to_jdbc_url(self) -> str:
+    """Convert options to a properly formatted JDBC URL.
+
+    Returns:
+      JDBC URL string configured with all options.
+    """
+    return self._build_jdbc_url(
+        socketFactory="com.google.cloud.sql.postgres.SocketFactory",
+        database_type="postgresql")
+
+  def additional_jdbc_args(self) -> Dict[str, List[Any]]:
+    return {
+        'classpath': [
+            "org.postgresql:postgresql:42.2.16",
+            "com.google.cloud.sql:postgres-socket-factory:1.25.0"
+        ]
+    }
+
+  @classmethod
+  def from_base_config(cls, config: LanguageConnectorConfig):
+    return cls(**asdict(config))
+
+
+class CloudSQLPostgresVectorWriterConfig(PostgresVectorWriterConfig):
+  def __init__(
+      self,
+      connection_config: LanguageConnectorConfig,
+      table_name: str,
+      *,
+      # pylint: disable=dangerous-default-value
+      write_config: WriteConfig = WriteConfig(),
+      column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build(
+      ),
+      conflict_resolution: Optional[ConflictResolution] = ConflictResolution(
+          on_conflict_fields=[], action='IGNORE')):
+    """Configuration for writing vectors to CloudSQL Postgres.
+
+    Supports flexible schema configuration through column specifications and
+    conflict resolution strategies.
+
+    Args:
+      connection_config: :class:`LanguageConnectorConfig`.
+      table_name: Target table name.
+      write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control
+        batch sizes, autosharding, etc.
+      column_specs:
+        Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how
+        embeddings and metadata are written to a database
+        schema. If None, uses the default Chunk schema.
+      conflict_resolution: Optional
+        :class:`~.postgres_common.ConflictResolution`
+        strategy for handling insert conflicts. ON CONFLICT DO NOTHING by
+        default.
+
+    Examples:
+      Basic usage with default schema:
+
+      >>> config = CloudSQLPostgresVectorWriterConfig(
+      ...     connection_config=LanguageConnectorConfig(...),
+      ...     table_name='embeddings'
+      ... )
+
+      Custom schema with metadata fields:
+
+      >>> specs = (ColumnSpecsBuilder()
+      ...     .with_id_spec(column_name="my_id_column")
+      ...     .with_embedding_spec(column_name="embedding_vec")
+      ...     .add_metadata_field(field="source", python_type=str,
+      ...                         column_name="src")
+      ...     .add_metadata_field(
+      ...         "timestamp",
+      ...         str,
+      ...         column_name="created_at",
+      ...         sql_typecast="::timestamp"
+      ...     )
+      ...     .build())
+
+      Minimal schema (only ID + embedding written):
+
+      >>> column_specs = (ColumnSpecsBuilder()
+      ...     .with_id_spec()
+      ...     .with_embedding_spec()
+      ...     .build())
+
+      >>> config = CloudSQLPostgresVectorWriterConfig(
+      ...     connection_config=LanguageConnectorConfig(...),
+      ...     table_name='embeddings',
+      ...     column_specs=specs
+      ...
) + """ + self.connector_config = _PostgresConnectorConfig.from_base_config( + connection_config) + super().__init__( + connection_config=self.connector_config.to_connection_config(), + write_config=write_config, + table_name=table_name, + column_specs=column_specs, + conflict_resolution=conflict_resolution) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py new file mode 100644 index 000000000000..959e4cadb137 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/cloudsql_it_test.py @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import os +import secrets +import time +import unittest + +import pytest +import sqlalchemy +from google.cloud.sql.connector import Connector +from sqlalchemy import text + +import apache_beam as beam +from apache_beam.io.jdbc import ReadFromJdbc +from apache_beam.ml.rag.ingestion import test_utils +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform +from apache_beam.ml.rag.ingestion.cloudsql import CloudSQLPostgresVectorWriterConfig +from apache_beam.ml.rag.ingestion.cloudsql import LanguageConnectorConfig +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + + +@pytest.mark.uses_gcp_java_expansion_service +@unittest.skipUnless( + os.environ.get('EXPANSION_JARS'), + "EXPANSION_JARS environment var is not provided, " + "indicating that jars have not been built") +@unittest.skipUnless( + os.environ.get('ALLOYDB_PASSWORD'), + "ALLOYDB_PASSWORD environment var is not provided") +class CloudSQLPostgresVectorWriterConfigTest(unittest.TestCase): + POSTGRES_TABLE_PREFIX = 'python_rag_postgres_' + + @classmethod + def _create_engine(cls): + """Create SQLAlchemy engine using Cloud SQL connector.""" + def getconn(): + conn = cls.connector.connect( + cls.instance_uri, + "pg8000", + user=cls.username, + password=cls.password, + db=cls.database, + ) + return conn + + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=getconn, + ) + return engine + + @classmethod + def setUpClass(cls): + cls.database = os.environ.get('POSTGRES_DATABASE', 'postgres') + cls.username = os.environ.get('POSTGRES_USERNAME', 'postgres') + if not os.environ.get('ALLOYDB_PASSWORD'): + raise ValueError('ALLOYDB_PASSWORD env not set') + cls.password = os.environ.get('ALLOYDB_PASSWORD') + cls.instance_uri = os.environ.get( + 'POSTGRES_INSTANCE_URI', + 'apache-beam-testing:us-central1:beam-integration-tests') + + # Create unique table name suffix + cls.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3)) + + # Setup database connection + cls.connector = Connector(refresh_strategy="LAZY") 
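+    # The pg8000/SQLAlchemy engine below is only used by the test for table
+    # setup and teardown (CREATE/DROP TABLE); the pipeline under test
+    # connects through the JDBC URL built by the writer's Java language
+    # connector configuration.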
+ cls.engine = cls._create_engine() + + def skip_if_dataflow_runner(self): + if self._runner and "dataflowrunner" in self._runner.lower(): + self.skipTest( + "Skipping some tests on Dataflow Runner to avoid bloat and timeouts") + + def setUp(self): + self.write_test_pipeline = TestPipeline(is_integration_test=True) + self.read_test_pipeline = TestPipeline(is_integration_test=True) + self._runner = type(self.read_test_pipeline.runner).__name__ + + self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ + f"{self.table_suffix}" + + # Create test table + with self.engine.connect() as connection: + connection.execute( + text( + f""" + CREATE TABLE {self.default_table_name} ( + id TEXT PRIMARY KEY, + embedding VECTOR({test_utils.VECTOR_SIZE}), + content TEXT, + metadata JSONB + ) + """)) + connection.commit() + _LOGGER = logging.getLogger(__name__) + _LOGGER.info("Created table %s", self.default_table_name) + + def tearDown(self): + # Drop test table + with self.engine.connect() as connection: + connection.execute( + text(f"DROP TABLE IF EXISTS {self.default_table_name}")) + connection.commit() + _LOGGER = logging.getLogger(__name__) + _LOGGER.info("Dropped table %s", self.default_table_name) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, 'connector'): + cls.connector.close() + if hasattr(cls, 'engine'): + cls.engine.dispose() + + def test_language_connector(self): + """Test language connector.""" + self.skip_if_dataflow_runner() + + connection_config = LanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name=self.instance_uri) + writer_config = CloudSQLPostgresVectorWriterConfig( + connection_config=connection_config, table_name=self.default_table_name) + + # Create test chunks + num_records = 150 + sample_size = min(500, num_records // 2) + chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + self.write_test_pipeline.not_use_test_runner_api = True + + with self.write_test_pipeline as p: + _ = ( + p + | beam.Create(chunks) + | VectorDatabaseWriteTransform(writer_config)) + + self.read_test_pipeline.not_use_test_runner_api = True + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + """ + + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=writer_config.connector_config.to_connection_config( + ).jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=writer_config.connector_config.additional_jdbc_args() + ['classpath'])) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)) + chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally( + test_utils.HashingFn()) + assert_that( + chunk_hashes, + equal_to([test_utils.generate_expected_hash(num_records)]), + label='hash_check') + + # Sample validation + first_n = ( + chunks + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of( + sample_size, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + expected_first_n = test_utils.ChunkTestUtils.get_expected_values( + 0, sample_size) + assert_that( + first_n, 
+ equal_to([expected_first_n]), + label=f"first_{sample_size}_check") + + last_n = ( + chunks + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of( + sample_size, key=lambda x: x[0]) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + expected_last_n = test_utils.ChunkTestUtils.get_expected_values( + num_records - sample_size, num_records)[::-1] + assert_that( + last_n, + equal_to([expected_last_n]), + label=f"last_{sample_size}_check") + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/jdbc_common.py b/sdks/python/apache_beam/ml/rag/ingestion/jdbc_common.py new file mode 100644 index 000000000000..586bb7a4aa65 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/jdbc_common.py @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Dict +from typing import List +from typing import Optional + + +@dataclass +class ConnectionConfig: + """Configuration for connecting to a JDBC database. + + Provides connection details and options for connecting to a database + instance. + + Attributes: + jdbc_url: JDBC URL for the database instance. + Example: 'jdbc:postgresql://host:port/database' + username: Database username. + password: Database password. + connection_properties: Optional JDBC connection properties dict. + Example: {'ssl': 'true'} + connection_init_sqls: Optional list of SQL statements to execute when + connection is established. + additional_jdbc_args: Additional arguments that will be passed to + WriteToJdbc. These may include 'driver_jars', 'expansion_service', + 'classpath', etc. See full set of args at + :class:`~apache_beam.io.jdbc.WriteToJdbc` + + Example: + >>> config = AlloyDBConnectionConfig( + ... jdbc_url='jdbc:postgresql://localhost:5432/mydb', + ... username='user', + ... password='pass', + ... connection_properties={'ssl': 'true'}, + ... max_connections=10 + ... ) + """ + jdbc_url: str + username: str + password: str + connection_properties: Optional[Dict[str, str]] = None + connection_init_sqls: Optional[List[str]] = None + additional_jdbc_args: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class WriteConfig: + """Configuration writing to JDBC database. + + Modifies the write behavior when writing via JdbcIO. + + Attributes: + autosharding: Enable automatic re-sharding of bundles to scale the + number of shards with workers. + max_connections: Optional number of connections in the pool. + Use negative for no limit. + write_batch_size: Optional write batch size for bulk operations. 
+ """ + autosharding: Optional[bool] = None + max_connections: Optional[int] = None + write_batch_size: Optional[int] = None diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres.py new file mode 100644 index 000000000000..045579a73d28 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres.py @@ -0,0 +1,209 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Callable +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import Union + +import apache_beam as beam +from apache_beam.coders import registry +from apache_beam.coders.row_coder import RowCoder +from apache_beam.io.jdbc import WriteToJdbc +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution +from apache_beam.ml.rag.types import Chunk + +_LOGGER = logging.getLogger(__name__) + +MetadataSpec = Union[ColumnSpec, Dict[str, ColumnSpec]] + + +class _PostgresQueryBuilder: + def __init__( + self, + table_name: str, + *, + column_specs: List[ColumnSpec], + conflict_resolution: Optional[ConflictResolution] = None): + """Builds SQL queries for writing Chunks with Embeddings to Postgres. 
+ """ + self.table_name = table_name + + self.column_specs = column_specs + self.conflict_resolution = conflict_resolution + + # Validate no duplicate column names + names = [col.column_name for col in self.column_specs] + duplicates = set(name for name in names if names.count(name) > 1) + if duplicates: + raise ValueError(f"Duplicate column names found: {duplicates}") + + # Create NamedTuple type + fields = [(col.column_name, col.python_type) for col in self.column_specs] + type_name = f"VectorRecord_{table_name}" + self.record_type = NamedTuple(type_name, fields) # type: ignore + + # Register coder + registry.register_coder(self.record_type, RowCoder) + + # Set default update fields to all non-conflict fields if update fields are + # not specified + if self.conflict_resolution: + self.conflict_resolution.maybe_set_default_update_fields( + [col.column_name for col in self.column_specs if col.column_name]) + + def build_insert(self) -> str: + """Build INSERT query with proper type casting.""" + # Get column names and placeholders + fields = [col.column_name for col in self.column_specs] + placeholders = [col.placeholder for col in self.column_specs] + + # Build base query + query = f""" + INSERT INTO {self.table_name} + ({', '.join(fields)}) + VALUES ({', '.join(placeholders)}) + """ + + # Add conflict handling if configured + if self.conflict_resolution: + query += f" {self.conflict_resolution.get_conflict_clause()}" + + _LOGGER.info("Query with placeholders %s", query) + return query + + def create_converter(self) -> Callable[[Chunk], NamedTuple]: + """Creates a function to convert Chunks to records.""" + def convert(chunk: Chunk) -> self.record_type: # type: ignore + return self.record_type( + **{col.column_name: col.value_fn(chunk) + for col in self.column_specs}) # type: ignore + + return convert + + +class PostgresVectorWriterConfig(VectorDatabaseWriteConfig): + def __init__( + self, + connection_config: ConnectionConfig, + table_name: str, + *, + # pylint: disable=dangerous-default-value + write_config: WriteConfig = WriteConfig(), + column_specs: List[ColumnSpec] = ColumnSpecsBuilder.with_defaults().build( + ), + conflict_resolution: Optional[ConflictResolution] = ConflictResolution( + on_conflict_fields=[], action='IGNORE')): + """Configuration for writing vectors to Postgres using jdbc. + + Supports flexible schema configuration through column specifications and + conflict resolution strategies. + + Args: + connection_config: + :class:`~apache_beam.ml.rag.ingestion.jdbc_common.ConnectionConfig`. + table_name: Target table name. + write_config: JdbcIO :class:`~.jdbc_common.WriteConfig` to control + batch sizes, authosharding, etc. + column_specs: + Use :class:`~.postgres_common.ColumnSpecsBuilder` to configure how + embeddings and metadata are written a database + schema. If None, uses default Chunk schema. + conflict_resolution: Optional + :class:`~.postgres_common.ConflictResolution` + strategy for handling insert conflicts. ON CONFLICT DO NOTHING by + default. + + Examples: + Simple case with default schema: + + >>> config = PostgresVectorWriterConfig( + ... connection_config=ConnectionConfig(...), + ... table_name='embeddings' + ... ) + + Custom schema with metadata fields: + + >>> specs = (ColumnSpecsBuilder() + ... .with_id_spec(column_name="my_id_column") + ... .with_embedding_spec(column_name="embedding_vec") + ... .add_metadata_field(field="source", column_name="src") + ... .add_metadata_field( + ... "timestamp", + ... column_name="created_at", + ... 
sql_typecast="::timestamp" + ... ) + ... .build()) + + Minimal schema (only ID + embedding written) + + >>> column_specs = (ColumnSpecsBuilder() + ... .with_id_spec() + ... .with_embedding_spec() + ... .build()) + + >>> config = PostgresVectorWriterConfig( + ... connection_config=ConnectionConfig(...), + ... table_name='embeddings', + ... column_specs=specs + ... ) + """ + self.connection_config = connection_config + self.write_config = write_config + # NamedTuple is created and registered here during pipeline construction + self.query_builder = _PostgresQueryBuilder( + table_name, + column_specs=column_specs, + conflict_resolution=conflict_resolution) + + def create_write_transform(self) -> beam.PTransform: + return _WriteToPostgresVectorDatabase(self) + + +class _WriteToPostgresVectorDatabase(beam.PTransform): + """Implementation of Postgres vector database write. """ + def __init__(self, config: PostgresVectorWriterConfig): + self.config = config + self.query_builder = config.query_builder + self.connection_config = config.connection_config + self.write_config = config.write_config + + def expand(self, pcoll: beam.PCollection[Chunk]): + return ( + pcoll + | + "Convert to Records" >> beam.Map(self.query_builder.create_converter()) + | "Write to Postgres" >> WriteToJdbc( + table_name=self.query_builder.table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.connection_config.jdbc_url, + username=self.connection_config.username, + password=self.connection_config.password, + statement=self.query_builder.build_insert(), + connection_properties=self.connection_config.connection_properties, + connection_init_sqls=self.connection_config.connection_init_sqls, + autosharding=self.write_config.autosharding, + max_connections=self.write_config.max_connections, + write_batch_size=self.write_config.write_batch_size, + **self.connection_config.additional_jdbc_args)) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py new file mode 100644 index 000000000000..eca740a4e9c3 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py @@ -0,0 +1,484 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal +from typing import Optional +from typing import Type +from typing import Union + +from apache_beam.ml.rag.types import Chunk + + +def chunk_embedding_fn(chunk: Chunk) -> str: + """Convert embedding to PostgreSQL array string. + + Formats dense embedding as a PostgreSQL-compatible array string. + Example: [1.0, 2.0] -> '{1.0,2.0}' + + Args: + chunk: Input Chunk object. 
+ + Returns: + str: PostgreSQL array string representation of the embedding. + + Raises: + ValueError: If chunk has no dense embedding. + """ + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f'Expected chunk to contain embedding. {chunk}') + return '{' + ','.join(str(x) for x in chunk.embedding.dense_embedding) + '}' + + +@dataclass +class ColumnSpec: + """Specification for mapping Chunk fields to SQL columns for insertion. + + Defines how to extract and format values from Chunks into database columns, + handling the full pipeline from Python value to SQL insertion. + + The insertion process works as follows: + - value_fn extracts a value from the Chunk and formats it as needed + - The value is stored in a NamedTuple field with the specified python_type + - During SQL insertion, the value is bound to a ? placeholder + + Attributes: + column_name: The column name in the database table. + python_type: Python type for the NamedTuple field that will hold the + value. Must be compatible with must be compatible with + :class:`~apache_beam.coders.row_coder.RowCoder`. + value_fn: Function to extract and format the value from a Chunk. + Takes a Chunk and returns a value of python_type. + sql_typecast: Optional SQL type cast to append to the ? placeholder. + Common examples: + - "::float[]" for vector arrays + - "::jsonb" for JSON data + + Examples: + Basic text column (uses standard JDBC type mapping): + >>> ColumnSpec.text( + ... column_name="content", + ... value_fn=lambda chunk: chunk.content.text + ... ) + # Results in: INSERT INTO table (content) VALUES (?) + + Vector column with explicit array casting: + >>> ColumnSpec.vector( + ... column_name="embedding", + ... value_fn=lambda chunk: '{' + + ... ','.join(map(str, chunk.embedding.dense_embedding)) + '}' + ... ) + # Results in: INSERT INTO table (embedding) VALUES (?::float[]) + # The value_fn formats [1.0, 2.0] as '{1.0,2.0}' for PostgreSQL array + + Timestamp from metadata with explicit casting: + >>> ColumnSpec( + ... column_name="created_at", + ... python_type=str, + ... value_fn=lambda chunk: chunk.metadata.get("timestamp"), + ... sql_typecast="::timestamp" + ... ) + # Results in: INSERT INTO table (created_at) VALUES (?::timestamp) + # Allows inserting string timestamps with proper PostgreSQL casting + + Factory Methods: + text: Creates a text column specification (no type cast). + integer: Creates an integer column specification (no type cast). + float: Creates a float column specification (no type cast). + vector: Creates a vector column specification with float[] casting. + jsonb: Creates a JSONB column specification with jsonb casting. 
+ """ + column_name: str + python_type: Type + value_fn: Callable[[Chunk], Any] + sql_typecast: Optional[str] = None + + @property + def placeholder(self) -> str: + """Get SQL placeholder with optional typecast.""" + return f"?{self.sql_typecast or ''}" + + @classmethod + def text( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create a text column specification.""" + return cls(column_name, str, value_fn) + + @classmethod + def integer( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create an integer column specification.""" + return cls(column_name, int, value_fn) + + @classmethod + def float( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create a float column specification.""" + return cls(column_name, float, value_fn) + + @classmethod + def vector( + cls, + column_name: str, + value_fn: Callable[[Chunk], Any] = chunk_embedding_fn) -> 'ColumnSpec': + """Create a vector column specification.""" + return cls(column_name, str, value_fn, "::float[]") + + @classmethod + def jsonb( + cls, column_name: str, value_fn: Callable[[Chunk], Any]) -> 'ColumnSpec': + """Create a JSONB column specification.""" + return cls(column_name, str, value_fn, "::jsonb") + + +class ColumnSpecsBuilder: + """Builder for :class:`.ColumnSpec`'s with chainable methods.""" + def __init__(self): + self._specs: List[ColumnSpec] = [] + + @staticmethod + def with_defaults() -> 'ColumnSpecsBuilder': + """Add all default column specifications.""" + return ( + ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). + with_content_spec().with_metadata_spec()) + + def with_id_spec( + self, + column_name: str = "id", + python_type: Type = str, + convert_fn: Optional[Callable[[str], Any]] = None, + sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': + """Add ID :class:`.ColumnSpec` with optional type and conversion. + + Args: + column_name: Name for the ID column (defaults to "id") + python_type: Python type for the column (defaults to str) + convert_fn: Optional function to convert the chunk ID + If None, uses ID as-is + sql_typecast: Optional SQL type cast + + Returns: + Self for method chaining + + Example: + >>> builder.with_id_spec( + ... column_name="doc_id", + ... python_type=int, + ... convert_fn=lambda id: int(id.split('_')[1]) + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + value = chunk.id + return convert_fn(value) if convert_fn else value + + self._specs.append( + ColumnSpec( + column_name=column_name, + python_type=python_type, + value_fn=value_fn, + sql_typecast=sql_typecast)) + return self + + def with_content_spec( + self, + column_name: str = "content", + python_type: Type = str, + convert_fn: Optional[Callable[[str], Any]] = None, + sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': + """Add content :class:`.ColumnSpec` with optional type and conversion. + + Args: + column_name: Name for the content column (defaults to "content") + python_type: Python type for the column (defaults to str) + convert_fn: Optional function to convert the content text + If None, uses content text as-is + sql_typecast: Optional SQL type cast + + Returns: + Self for method chaining + + Example: + >>> builder.with_content_spec( + ... column_name="content_length", + ... python_type=int, + ... convert_fn=len # Store content length instead of content + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + if chunk.content.text is None: + raise ValueError(f'Expected chunk to contain content. 
{chunk}') + value = chunk.content.text + return convert_fn(value) if convert_fn else value + + self._specs.append( + ColumnSpec( + column_name=column_name, + python_type=python_type, + value_fn=value_fn, + sql_typecast=sql_typecast)) + return self + + def with_metadata_spec( + self, + column_name: str = "metadata", + python_type: Type = str, + convert_fn: Optional[Callable[[Dict[str, Any]], Any]] = None, + sql_typecast: Optional[str] = "::jsonb") -> 'ColumnSpecsBuilder': + """Add metadata :class:`.ColumnSpec` with optional type and conversion. + + Args: + column_name: Name for the metadata column (defaults to "metadata") + python_type: Python type for the column (defaults to str) + convert_fn: Optional function to convert the metadata dictionary + If None and python_type is str, converts to JSON string + sql_typecast: Optional SQL type cast (defaults to "::jsonb") + + Returns: + Self for method chaining + + Example: + >>> builder.with_metadata_spec( + ... column_name="meta_tags", + ... python_type=list, + ... convert_fn=lambda meta: list(meta.keys()), + ... sql_typecast="::text[]" + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + if convert_fn: + return convert_fn(chunk.metadata) + return json.dumps( + chunk.metadata) if python_type == str else chunk.metadata + + self._specs.append( + ColumnSpec( + column_name=column_name, + python_type=python_type, + value_fn=value_fn, + sql_typecast=sql_typecast)) + return self + + def with_embedding_spec( + self, + column_name: str = "embedding", + convert_fn: Optional[Callable[[List[float]], Any]] = None + ) -> 'ColumnSpecsBuilder': + """Add embedding :class:`.ColumnSpec` with optional conversion. + + Args: + column_name: Name for the embedding column (defaults to "embedding") + convert_fn: Optional function to convert the dense embedding values + If None, uses default PostgreSQL array format + + Returns: + Self for method chaining + + Example: + >>> builder.with_embedding_spec( + ... column_name="embedding_vector", + ... convert_fn=lambda values: '{' + ','.join(f"{x:.4f}" + ... for x in values) + '}' + ... ) + """ + def value_fn(chunk: Chunk) -> Any: + if chunk.embedding is None or chunk.embedding.dense_embedding is None: + raise ValueError(f'Expected chunk to contain embedding. {chunk}') + values = chunk.embedding.dense_embedding + if convert_fn: + return convert_fn(values) + return '{' + ','.join(str(x) for x in values) + '}' + + self._specs.append( + ColumnSpec.vector(column_name=column_name, value_fn=value_fn)) + return self + + def add_metadata_field( + self, + field: str, + python_type: Type, + column_name: Optional[str] = None, + convert_fn: Optional[Callable[[Any], Any]] = None, + default: Any = None, + sql_typecast: Optional[str] = None) -> 'ColumnSpecsBuilder': + """""Add a :class:`.ColumnSpec` that extracts and converts a field from + chunk metadata. + + Args: + field: Key to extract from chunk metadata + python_type: Python type for the column (e.g. str, int, float) + column_name: Name for the column (defaults to metadata field name) + convert_fn: Optional function to convert the extracted value to + desired type. If None, value is used as-is + default: Default value if field is missing from metadata + sql_typecast: Optional SQL type cast (e.g. "::timestamp") + + Returns: + Self for chaining + + Examples: + + Simple string field: + >>> builder.add_metadata_field("source", str) + + Integer with default: + + >>> builder.add_metadata_field( + ... field="count", + ... python_type=int, + ... column_name="item_count", + ... default=0 + ... 
) + + Float with conversion and default: + + >>> builder.add_metadata_field( + ... field="confidence", + ... python_type=intfloat, + ... convert_fn=lambda x: round(float(x), 2), + ... default=0.0 + ... ) + + Timestamp with conversion and type cast: + + >>> builder.add_metadata_field( + ... field="created_at", + ... python_type=intstr, + ... convert_fn=lambda ts: ts.replace('T', ' '), + ... sql_typecast="::timestamp" + ... ) + """ + name = column_name or field + + def value_fn(chunk: Chunk) -> Any: + value = chunk.metadata.get(field, default) + if value is not None and convert_fn is not None: + value = convert_fn(value) + return value + + spec = ColumnSpec( + column_name=name, + python_type=python_type, + value_fn=value_fn, + sql_typecast=sql_typecast) + + self._specs.append(spec) + return self + + def add_custom_column_spec(self, spec: ColumnSpec) -> 'ColumnSpecsBuilder': + """Add a custom :class:`.ColumnSpec` to the builder. + + Use this method when you need complete control over the :class:`.ColumnSpec` + , including custom value extraction and type handling. + + Args: + spec: A :class:`.ColumnSpec` instance defining the column name, type, + value extraction, and optional SQL type casting. + + Returns: + Self for method chaining + + Examples: + Custom text column from chunk metadata: + + >>> builder.add_custom_column_spec( + ... ColumnSpec.text( + ... name="source_and_id", + ... value_fn=lambda chunk: \ + ... f"{chunk.metadata.get('source')}_{chunk.id}" + ... ) + ... ) + """ + self._specs.append(spec) + return self + + def build(self) -> List[ColumnSpec]: + """Build the final list of column specifications.""" + return self._specs.copy() + + +@dataclass +class ConflictResolution: + """Specification for how to handle conflicts during insert. + + Configures conflict handling behavior when inserting records that may + violate unique constraints. + + Attributes: + on_conflict_fields: Field(s) that determine uniqueness. Can be a single + field name or list of field names for composite constraints. + action: How to handle conflicts - either "UPDATE" or "IGNORE". + UPDATE: Updates existing record with new values. + IGNORE: Skips conflicting records. + update_fields: Optional list of fields to update on conflict. If None, + all non-conflict fields are updated. + + Examples: + Simple primary key: + + >>> ConflictResolution("id") + + Composite key with specific update fields: + + >>> ConflictResolution( + ... on_conflict_fields=["source", "timestamp"], + ... action="UPDATE", + ... update_fields=["embedding", "content"] + ... ) + + Ignore conflicts: + + >>> ConflictResolution( + ... on_conflict_fields="id", + ... action="IGNORE" + ... 
) + """ + on_conflict_fields: Union[str, List[str]] + action: Literal["UPDATE", "IGNORE"] = "UPDATE" + update_fields: Optional[List[str]] = None + + def maybe_set_default_update_fields(self, columns: List[str]): + if self.action != "UPDATE": + return + if self.update_fields is not None: + return + + conflict_fields = ([self.on_conflict_fields] if isinstance( + self.on_conflict_fields, str) else self.on_conflict_fields) + self.update_fields = [col for col in columns if col not in conflict_fields] + + def get_conflict_clause(self) -> str: + """Get conflict clause with update fields.""" + conflict_fields = [self.on_conflict_fields] \ + if isinstance(self.on_conflict_fields, str) \ + else self.on_conflict_fields + + if self.action == "IGNORE": + conflict_fields_string = f"({', '.join(conflict_fields)})" \ + if len(conflict_fields) > 0 else "" + return f"ON CONFLICT {conflict_fields_string} DO NOTHING" + + # update_fields should be set by query builder before this is called + assert self.update_fields is not None, \ + "update_fields must be set before generating conflict clause" + updates = [f"{field} = EXCLUDED.{field}" for field in self.update_fields] + return f"ON CONFLICT " \ + f"({', '.join(conflict_fields)}) DO UPDATE SET {', '.join(updates)}" diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_it_test.py new file mode 100644 index 000000000000..adbe28b5d086 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_it_test.py @@ -0,0 +1,902 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
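For reference, a quick sketch of the clause the ConflictResolution above generates once the query builder fills in the default update fields; the column names here are assumed to match the default Chunk schema (id, embedding, content, metadata):

>>> from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution
>>> resolution = ConflictResolution(on_conflict_fields="id", action="UPDATE")
>>> resolution.maybe_set_default_update_fields(
...     ["id", "embedding", "content", "metadata"])
>>> resolution.get_conflict_clause()
'ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, content = EXCLUDED.content, metadata = EXCLUDED.metadata'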
+# + +import json +import logging +import os +import secrets +import time +import unittest +from typing import List +from typing import NamedTuple + +import psycopg2 +import pytest + +import apache_beam as beam +from apache_beam.coders import registry +from apache_beam.coders.row_coder import RowCoder +from apache_beam.io.jdbc import ReadFromJdbc +from apache_beam.ml.rag.ingestion import test_utils +from apache_beam.ml.rag.ingestion.jdbc_common import ConnectionConfig +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.ingestion.postgres import PostgresVectorWriterConfig +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution +from apache_beam.ml.rag.ingestion.postgres_common import chunk_embedding_fn +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +CustomSpecsRow = NamedTuple( + 'CustomSpecsRow', + [ + ('custom_id', str), # For id_spec test + ('embedding_vec', List[float]), # For embedding_spec test + ('content_col', str), # For content_spec test + ('metadata', str) + ]) +registry.register_coder(CustomSpecsRow, RowCoder) + +MetadataConflictRow = NamedTuple( + 'MetadataConflictRow', + [ + ('id', str), + ('source', str), # For metadata_spec and composite key + ('timestamp', str), # For metadata_spec and composite key + ('content', str), + ('embedding', List[float]), + ('metadata', str) + ]) +registry.register_coder(MetadataConflictRow, RowCoder) + + +@pytest.mark.uses_gcp_java_expansion_service +@unittest.skipUnless( + os.environ.get('EXPANSION_JARS'), + "EXPANSION_JARS environment var is not provided, " + "indicating that jars have not been built") +@unittest.skipUnless( + os.environ.get('ALLOYDB_PASSWORD'), + "ALLOYDB_PASSWORD environment var is not provided") +class PostgresVectorWriterConfigTest(unittest.TestCase): + POSTGRES_TABLE_PREFIX = 'python_rag_postgres_' + + @classmethod + def setUpClass(cls): + cls.host = os.environ.get('ALLOYDB_HOST', '10.119.0.22') + cls.port = os.environ.get('ALLOYDB_PORT', '5432') + cls.database = os.environ.get('ALLOYDB_DATABASE', 'postgres') + cls.username = os.environ.get('ALLOYDB_USERNAME', 'postgres') + if not os.environ.get('ALLOYDB_PASSWORD'): + raise ValueError('ALLOYDB_PASSWORD env not set') + cls.password = os.environ.get('ALLOYDB_PASSWORD') + + # Create unique table name suffix + cls.table_suffix = '%d%s' % (int(time.time()), secrets.token_hex(3)) + + # Setup database connection + cls.conn = psycopg2.connect( + host=cls.host, + port=cls.port, + database=cls.database, + user=cls.username, + password=cls.password) + cls.conn.autocommit = True + + def skip_if_dataflow_runner(self): + if self._runner and "dataflowrunner" in self._runner.lower(): + self.skipTest( + "Skipping some tests on Dataflow Runner to avoid bloat and timeouts") + + def setUp(self): + self.write_test_pipeline = TestPipeline(is_integration_test=True) + self.read_test_pipeline = TestPipeline(is_integration_test=True) + self.write_test_pipeline2 = TestPipeline(is_integration_test=True) + self.read_test_pipeline2 = TestPipeline(is_integration_test=True) + self._runner = type(self.read_test_pipeline.runner).__name__ + + 
self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ + f"{self.table_suffix}" + self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ + f"{self.table_suffix}" + self.custom_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ + f"_custom_{self.table_suffix}" + self.metadata_conflicts_table = f"{self.POSTGRES_TABLE_PREFIX}" \ + f"_meta_conf_{self.table_suffix}" + + self.jdbc_url = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}' + + # Create test table + with self.conn.cursor() as cursor: + cursor.execute( + f""" + CREATE TABLE {self.default_table_name} ( + id TEXT PRIMARY KEY, + embedding VECTOR({test_utils.VECTOR_SIZE}), + content TEXT, + metadata JSONB + ) + """) + cursor.execute( + f""" + CREATE TABLE {self.custom_table_name} ( + custom_id TEXT PRIMARY KEY, + embedding_vec VECTOR(2), + content_col TEXT, + metadata JSONB + ) + """) + cursor.execute( + f""" + CREATE TABLE {self.metadata_conflicts_table} ( + id TEXT, + source TEXT, + timestamp TIMESTAMP, + content TEXT, + embedding VECTOR(2), + PRIMARY KEY (id), + UNIQUE (source, timestamp) + ) + """) + _LOGGER = logging.getLogger(__name__) + _LOGGER.info("Created table %s", self.default_table_name) + + def tearDown(self): + # Drop test table + with self.conn.cursor() as cursor: + cursor.execute(f"DROP TABLE IF EXISTS {self.default_table_name}") + cursor.execute(f"DROP TABLE IF EXISTS {self.custom_table_name}") + cursor.execute(f"DROP TABLE IF EXISTS {self.metadata_conflicts_table}") + _LOGGER = logging.getLogger(__name__) + _LOGGER.info("Dropped table %s", self.default_table_name) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, 'conn'): + cls.conn.close() + + def test_default_schema(self): + """Test basic write with default schema.""" + jdbc_url = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}' + connection_config = ConnectionConfig( + jdbc_url=jdbc_url, username=self.username, password=self.password) + + config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + write_config=WriteConfig(write_batch_size=100)) + + # Create test chunks + num_records = 1500 + sample_size = min(500, num_records // 2) + # Generate test chunks + chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + # Run pipeline and verify + self.write_test_pipeline.not_use_test_runner_api = True + + with self.write_test_pipeline as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + self.read_test_pipeline.not_use_test_runner_api = True + # Read pipeline to verify + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + """ + + # Read and verify pipeline + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)) + chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally( + test_utils.HashingFn()) + assert_that( + chunk_hashes, + equal_to([test_utils.generate_expected_hash(num_records)]), + label='hash_check') + + # Sample validation + first_n = ( + chunks + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | 
f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of( + sample_size, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + expected_first_n = test_utils.ChunkTestUtils.get_expected_values( + 0, sample_size) + assert_that( + first_n, + equal_to([expected_first_n]), + label=f"first_{sample_size}_check") + + last_n = ( + chunks + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of( + sample_size, key=lambda x: x[0]) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + expected_last_n = test_utils.ChunkTestUtils.get_expected_values( + num_records - sample_size, num_records)[::-1] + assert_that( + last_n, + equal_to([expected_last_n]), + label=f"last_{sample_size}_check") + + def test_custom_specs(self): + """Test custom specifications for ID, embedding, and content.""" + self.skip_if_dataflow_runner() + num_records = 20 + + specs = ( + ColumnSpecsBuilder().add_custom_column_spec( + ColumnSpec.text( + column_name="custom_id", + value_fn=lambda chunk: + f"timestamp_{chunk.metadata.get('timestamp', '')}") + ).add_custom_column_spec( + ColumnSpec.vector( + column_name="embedding_vec", + value_fn=chunk_embedding_fn)).add_custom_column_spec( + ColumnSpec.text( + column_name="content_col", + value_fn=lambda chunk: + f"{len(chunk.content.text)}:{chunk.content.text}")). + with_metadata_spec().build()) + + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + + writer_config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.custom_table_name, + column_specs=specs) + + # Generate test chunks + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) + for i in range(num_records) + ] + + # Write pipeline + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) | writer_config.create_write_transform()) + + # Read and verify + read_query = f""" + SELECT + CAST(custom_id AS VARCHAR(255)), + CAST(embedding_vec AS text), + CAST(content_col AS VARCHAR(255)), + CAST(metadata AS text) + FROM {self.custom_table_name} + ORDER BY custom_id + """ + + # Convert BeamRow back to Chunk + def custom_row_to_chunk(row): + # Extract timestamp from custom_id + timestamp = row.custom_id.split('timestamp_')[1] + # Extract index from timestamp + i = int(timestamp.split('T')[1][:2]) + + # Parse embedding vector + embedding_list = [ + float(x) for x in row.embedding_vec.strip('[]').split(',') + ] + + # Extract content from length-prefixed format + content = row.content_col.split(':', 1)[1] + + return Chunk( + id=str(i), + content=Content(text=content), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata)) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.custom_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + # Verify count + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) + 
assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + def test_defaults_with_args_specs(self): + """Test custom specifications for ID, embedding, and content.""" + self.skip_if_dataflow_runner() + num_records = 20 + + specs = ( + ColumnSpecsBuilder().with_id_spec( + column_name="custom_id", + python_type=int, + convert_fn=lambda x: int(x), + sql_typecast="::text").with_content_spec( + column_name="content_col", + convert_fn=lambda x: f"{len(x)}:{x}", + ).with_embedding_spec( + column_name="embedding_vec").with_metadata_spec().build()) + + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + + writer_config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.custom_table_name, + column_specs=specs) + + # Generate test chunks + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) + for i in range(num_records) + ] + + # Write pipeline + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) | writer_config.create_write_transform()) + + # Read and verify + read_query = f""" + SELECT + CAST(custom_id AS VARCHAR(255)), + CAST(embedding_vec AS text), + CAST(content_col AS VARCHAR(255)), + CAST(metadata AS text) + FROM {self.custom_table_name} + ORDER BY custom_id + """ + + # Convert BeamRow back to Chunk + def custom_row_to_chunk(row): + # Parse embedding vector + embedding_list = [ + float(x) for x in row.embedding_vec.strip('[]').split(',') + ] + + # Extract content from length-prefixed format + content = row.content_col.split(':', 1)[1] + + return Chunk( + id=row.custom_id, + content=Content(text=content), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata)) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.custom_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + # Verify count + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) + assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + def test_default_id_embedding_specs(self): + """Test with only default id and embedding specs, others set to None.""" + self.skip_if_dataflow_runner() + num_records = 20 + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + specs = ( + ColumnSpecsBuilder().with_id_spec() # Use default id spec + .with_embedding_spec() # Use default embedding spec + .build()) + + writer_config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + column_specs=specs) + + # Generate test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + # Write pipeline + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) | writer_config.create_write_transform()) + + # Read and verify only id and embedding + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(embedding 
AS text) + FROM {self.default_table_name} + ORDER BY id + """ + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk) + + # Create expected chunks with None values + expected_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records) + for chunk in expected_chunks: + chunk.content.text = None + chunk.metadata = {} + + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + def test_metadata_spec_and_conflicts(self): + """Test metadata specification and conflict resolution.""" + self.skip_if_dataflow_runner() + num_records = 20 + + specs = ( + ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). + with_content_spec().add_metadata_field( + field="source", + column_name="source", + python_type=str, + sql_typecast=None # Plain text field + ).add_metadata_field( + field="timestamp", python_type=str, + sql_typecast="::timestamp").build()) + + # Conflict resolution on source+timestamp + conflict_resolution = ConflictResolution( + on_conflict_fields=["source", "timestamp"], + action="UPDATE", + update_fields=["embedding", "content"]) + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + writer_config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.metadata_conflicts_table, + column_specs=specs, + conflict_resolution=conflict_resolution) + + # Generate initial test chunks + initial_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + # Write initial chunks + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | "Write Initial" >> beam.Create(initial_chunks) + | writer_config.create_write_transform()) + + # Generate conflicting chunks (same source+timestamp, different content) + conflicting_chunks = [ + Chunk( + id=f"new_{i}", + content=Content(text=f"updated_content_{i}"), + embedding=Embedding( + dense_embedding=[float(i) * 2, float(i + 1) * 2]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + # Write conflicting chunks + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p | "Write Conflicts" >> beam.Create(conflicting_chunks) + | writer_config.create_write_transform()) + + # Read and verify + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(embedding AS text), + CAST(content AS VARCHAR(255)), + CAST(source AS VARCHAR(255)), + CAST(timestamp AS VARCHAR(255)) + FROM {self.metadata_conflicts_table} + ORDER BY timestamp, id + """ + + # Expected chunks after conflict resolution + expected_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"updated_content_{i}"), + embedding=Embedding( + dense_embedding=[float(i) * 2, float(i + 1) * 2]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + def metadata_row_to_chunk(row): + return Chunk( + id=row.id, + content=Content(text=row.content), + 
embedding=Embedding( + dense_embedding=[ + float(x) for x in row.embedding.strip('[]').split(',') + ]), + metadata={ + "source": row.source, + "timestamp": row.timestamp.replace(' ', 'T') + }) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.metadata_conflicts_table, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk) + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + def test_conflict_resolution_update(self): + """Test conflict resolution with UPDATE action.""" + self.skip_if_dataflow_runner() + num_records = 20 + + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + + conflict_resolution = ConflictResolution( + on_conflict_fields="id", + action="UPDATE", + update_fields=["embedding", "content"]) + + config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + conflict_resolution=conflict_resolution) + + # Generate initial test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + self.write_test_pipeline.not_use_test_runner_api = True + # Insert initial test chunks + with self.write_test_pipeline as p: + _ = ( + p + | "Create initial chunks" >> beam.Create(test_chunks) + | "Write initial chunks" >> config.create_write_transform()) + + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + ORDER BY id desc + """ + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = ( + rows + | "To Chunks" >> beam.Map(test_utils.row_to_chunk) + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | "Get First 500" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([test_chunks]), label='original_chunks_check') + + updated_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records, content_prefix="Newcontent", seed_multiplier=2) + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p + | "Create updated Chunks" >> beam.Create(updated_chunks) + | "Write updated Chunks" >> config.create_write_transform()) + self.read_test_pipeline2.not_use_test_runner_api = True + with self.read_test_pipeline2 as p: + rows = ( + p + | "Read Updated chunks" >> ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = ( + rows + | "To Chunks 2" >> beam.Map(test_utils.row_to_chunk) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | "Get First 500 2" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([updated_chunks]), 
label='updated_chunks_check') + + def test_conflict_resolution_default_ignore(self): + """Test conflict resolution with default.""" + self.skip_if_dataflow_runner() + num_records = 20 + + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + + config = PostgresVectorWriterConfig( + connection_config=connection_config, table_name=self.default_table_name) + + # Generate initial test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + self.write_test_pipeline.not_use_test_runner_api = True + # Insert initial test chunks + with self.write_test_pipeline as p: + _ = ( + p + | "Create initial chunks" >> beam.Create(test_chunks) + | "Write initial chunks" >> config.create_write_transform()) + + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + ORDER BY id desc + """ + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = ( + rows + | "To Chunks" >> beam.Map(test_utils.row_to_chunk) + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | "Get First 500" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([test_chunks]), label='original_chunks_check') + + updated_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records, content_prefix="Newcontent", seed_multiplier=2) + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p + | "Create updated Chunks" >> beam.Create(updated_chunks) + | "Write updated Chunks" >> config.create_write_transform()) + self.read_test_pipeline2.not_use_test_runner_api = True + with self.read_test_pipeline2 as p: + rows = ( + p + | "Read Updated chunks" >> ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = ( + rows + | "To Chunks 2" >> beam.Map(test_utils.row_to_chunk) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | "Get First 500 2" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that(chunks, equal_to([test_chunks]), label='updated_chunks_check') + + def test_conflict_resolution_default_update_fields(self): + """Test conflict resolution with default update fields (all non-conflict + fields).""" + self.skip_if_dataflow_runner() + num_records = 20 + + connection_config = ConnectionConfig( + jdbc_url=self.jdbc_url, username=self.username, password=self.password) + + # Create a conflict resolution with only the conflict field specified + # No update_fields specified - should default to all non-conflict fields + conflict_resolution = ConflictResolution( + on_conflict_fields="id", action="UPDATE") + + config = PostgresVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + conflict_resolution=conflict_resolution) + + # Generate initial test chunks + test_chunks = 
test_utils.ChunkTestUtils.get_expected_values(0, num_records) + self.write_test_pipeline.not_use_test_runner_api = True + + # Insert initial test chunks + with self.write_test_pipeline as p: + _ = ( + p + | "Create initial chunks" >> beam.Create(test_chunks) + | "Write initial chunks" >> config.create_write_transform()) + + # Verify initial data was written correctly + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + ORDER BY id desc + """ + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = ( + rows + | "To Chunks" >> beam.Map(test_utils.row_to_chunk) + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | "Get First 500" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([test_chunks]), label='original_chunks_check') + + # Create updated chunks with same IDs but different content, embedding, and + # metadata + updated_chunks = [] + for i in range(num_records): + original_chunk = test_chunks[i] + updated_chunk = Chunk( + id=original_chunk.id, + content=Content(text=f"Updated content {i}"), + embedding=Embedding( + dense_embedding=[float(i * 2), float(i * 2 + 1)] + [0.0] * + (test_utils.VECTOR_SIZE - 2)), + metadata={ + "updated": "true", "timestamp": "2024-02-25" + }) + updated_chunks.append(updated_chunk) + + # Write updated chunks - should update all non-conflict fields + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p + | "Create updated Chunks" >> beam.Create(updated_chunks) + | "Write updated Chunks" >> config.create_write_transform()) + + # Read and verify that all non-conflict fields were updated + self.read_test_pipeline2.not_use_test_runner_api = True + with self.read_test_pipeline2 as p: + rows = ( + p + | "Read Updated chunks" >> ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=self.jdbc_url, + username=self.username, + password=self.password, + query=read_query)) + + chunks = ( + rows + | "To Chunks 2" >> beam.Map(test_utils.row_to_chunk) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | "Get First 500 2" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + + # Verify that all non-conflict fields were updated + assert_that( + chunks, equal_to([updated_chunks]), label='updated_chunks_check') + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py new file mode 100644 index 000000000000..cd30766a2886 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/test_utils.py @@ -0,0 +1,105 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import hashlib +import json +from typing import List +from typing import NamedTuple + +import apache_beam as beam +from apache_beam.coders import registry +from apache_beam.coders.row_coder import RowCoder +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding + +TestRow = NamedTuple( + 'TestRow', + [('id', str), ('embedding', List[float]), ('content', str), + ('metadata', str)]) +registry.register_coder(TestRow, RowCoder) + +VECTOR_SIZE = 768 + + +def row_to_chunk(row) -> Chunk: + # Parse embedding string back to float list + embedding_list = [float(x) for x in row.embedding.strip('[]').split(',')] + return Chunk( + id=row.id, + content=Content(text=row.content if hasattr(row, 'content') else None), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata) if hasattr(row, 'metadata') else {}) + + +class ChunkTestUtils: + """Helper functions for generating test Chunks.""" + @staticmethod + def from_seed(seed: int, content_prefix: str, seed_multiplier: int) -> Chunk: + """Creates a deterministic Chunk from a seed value.""" + return Chunk( + id=f"id_{seed}", + content=Content(text=f"{content_prefix}{seed}"), + embedding=Embedding( + dense_embedding=[ + float(seed + i * seed_multiplier) / 100 + for i in range(VECTOR_SIZE) + ]), + metadata={"seed": str(seed)}) + + @staticmethod + def get_expected_values( + range_start: int, + range_end: int, + content_prefix: str = "Testval", + seed_multiplier: int = 1) -> List[Chunk]: + """Returns a range of test Chunks.""" + return [ + ChunkTestUtils.from_seed(i, content_prefix, seed_multiplier) + for i in range(range_start, range_end) + ] + + +class HashingFn(beam.CombineFn): + """Hashing function for verification.""" + def create_accumulator(self): + return [] + + def add_input(self, accumulator, input): + accumulator.append(input.content.text if input.content.text else "") + return accumulator + + def merge_accumulators(self, accumulators): + merged = [] + for acc in accumulators: + merged.extend(acc) + return merged + + def extract_output(self, accumulator): + sorted_values = sorted(accumulator) + return hashlib.md5(''.join(sorted_values).encode()).hexdigest() + + +def generate_expected_hash(num_records: int) -> str: + chunks = ChunkTestUtils.get_expected_values(0, num_records) + values = sorted( + chunk.content.text if chunk.content.text else "" for chunk in chunks) + return hashlib.md5(''.join(values).encode()).hexdigest() + + +def key_on_id(chunk): + return (int(chunk.id.split('_')[1]), chunk) diff --git a/sdks/python/setup.py b/sdks/python/setup.py index f3f419cd9849..d7b6e689fcbe 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -411,6 +411,7 @@ def get_portability_package_data(): 'virtualenv-clone>=0.5,<1.0', ], 'test': [ + 'cloud-sql-python-connector[pg8000]>=1.0.0,<2.0.0', 'docstring-parser>=0.15,<1.0', 'freezegun>=0.3.12', 'jinja2>=3.0,<3.2', From 
fbbebed8b9809e1e4bfa322bb8b1aede70625a19 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 9 Jun 2025 13:20:03 -0400 Subject: [PATCH 02/10] Trigger tests. --- .../trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index f1ba03a243ee..455144f02a35 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 5 + "modification": 6 } From 20a90a75a56ccf0169be71b21ff89f7d95a8e845 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 09:51:40 -0400 Subject: [PATCH 03/10] Linter fixes. --- sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py | 2 +- sdks/python/apache_beam/ml/gcp/recommendations_ai.py | 2 +- sdks/python/apache_beam/ml/gcp/videointelligenceml.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py index 417a04c3d2b4..f66bf2e56405 100644 --- a/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py +++ b/sdks/python/apache_beam/io/gcp/datastore/v1new/helper.py @@ -28,6 +28,7 @@ from typing import List from typing import Union +from cachetools.func import ttl_cache from google.api_core import exceptions from google.api_core.gapic_v1 import client_info from google.cloud import environment_vars @@ -36,7 +37,6 @@ from apache_beam.io.gcp.datastore.v1new import types from apache_beam.version import __version__ as beam_version -from cachetools.func import ttl_cache # https://cloud.google.com/datastore/docs/concepts/errors#error_codes _RETRYABLE_DATASTORE_ERRORS = ( diff --git a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py index 1bce097b6046..077fc83bbd07 100644 --- a/sdks/python/apache_beam/ml/gcp/recommendations_ai.py +++ b/sdks/python/apache_beam/ml/gcp/recommendations_ai.py @@ -24,9 +24,9 @@ from typing import Sequence from typing import Tuple +from cachetools.func import ttl_cache from google.api_core.retry import Retry -from cachetools.func import ttl_cache from apache_beam import pvalue from apache_beam.metrics import Metrics from apache_beam.options.pipeline_options import GoogleCloudOptions diff --git a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py index ebd35d2426c0..25fc258b35a1 100644 --- a/sdks/python/apache_beam/ml/gcp/videointelligenceml.py +++ b/sdks/python/apache_beam/ml/gcp/videointelligenceml.py @@ -21,12 +21,13 @@ from typing import Tuple from typing import Union +from cachetools.func import ttl_cache + from apache_beam import typehints from apache_beam.metrics import Metrics from apache_beam.transforms import DoFn from apache_beam.transforms import ParDo from apache_beam.transforms import PTransform -from cachetools.func import ttl_cache try: from google.cloud import videointelligence From 91d6ec23fb5bc60e0558d7841ec5e220a595306c Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 14:31:21 +0000 Subject: [PATCH 04/10] Fix test --- sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index cc7db95a1d07..3eb408711202 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -80,8 +80,7 @@ def setUp(self): self.read_test_pipeline = TestPipeline(is_integration_test=True) self._runner = type(self.read_test_pipeline.runner).__name__ - self.default_table_name = "default_embeddings" - f"{self.POSTGRES_TABLE_PREFIX}" \ + self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ f"{self.table_suffix}" self.jdbc_url = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}' From e64c35e94f432a8a84a662f5112db30d154a8592 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 16:31:59 +0000 Subject: [PATCH 05/10] Add back tests. Update changes.md. Fix unrelated lint. --- CHANGES.md | 1 + sdks/python/apache_beam/ml/gcp/visionml.py | 3 +- .../ml/rag/ingestion/alloydb_it_test.py | 806 +++++++++++++++++- 3 files changed, 808 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 0c126e4087e7..7a98cd53be77 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -90,6 +90,7 @@ * Python: Added JupyterLab 4.x extension compatibility for enhanced notebook integration ([#34495](https://github.com/apache/beam/pull/34495)). * Python: Argument abbreviation is no longer enabled within Beam. If you previously abbreviated arguments (e.g. `--r` for `--runner`), you will now need to specify the whole argument ([#34934](https://github.com/apache/beam/pull/34934)). * Java: Users of ReadFromKafkaViaSDF transform might encounter pipeline graph compatibility issues when updating the pipeline. To mitigate, set the `updateCompatibilityVersion` option to the SDK version used for the original pipeline, example `--updateCompatabilityVersion=2.64.0` +* Python: Updated `AlloyDBVectorWriterConfig` API to align with new `PostgresVectorWriter` transform. 
Heres a quick guide to update your code: ([#35225](https://github.com/apache/beam/issues/35225)) ## Deprecations diff --git a/sdks/python/apache_beam/ml/gcp/visionml.py b/sdks/python/apache_beam/ml/gcp/visionml.py index dd29dd377388..c4ef30710d58 100644 --- a/sdks/python/apache_beam/ml/gcp/visionml.py +++ b/sdks/python/apache_beam/ml/gcp/visionml.py @@ -25,6 +25,8 @@ from typing import Tuple from typing import Union +from cachetools.func import ttl_cache + from apache_beam import typehints from apache_beam.metrics import Metrics from apache_beam.transforms import DoFn @@ -32,7 +34,6 @@ from apache_beam.transforms import ParDo from apache_beam.transforms import PTransform from apache_beam.transforms import util -from cachetools.func import ttl_cache try: from google.cloud import vision diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index 3eb408711202..3c343cbafab5 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -30,6 +30,8 @@ from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder +from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to @@ -78,10 +80,18 @@ def skip_if_dataflow_runner(self): def setUp(self): self.write_test_pipeline = TestPipeline(is_integration_test=True) self.read_test_pipeline = TestPipeline(is_integration_test=True) + self.write_test_pipeline2 = TestPipeline(is_integration_test=True) + self.read_test_pipeline2 = TestPipeline(is_integration_test=True) self._runner = type(self.read_test_pipeline.runner).__name__ - self.default_table_name = f"{self.POSTGRES_TABLE_PREFIX}" \ + self.default_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ f"{self.table_suffix}" + self.default_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ + f"{self.table_suffix}" + self.custom_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ + f"_custom_{self.table_suffix}" + self.metadata_conflicts_table = f"{self.ALLOYDB_TABLE_PREFIX}" \ + f"_meta_conf_{self.table_suffix}" self.jdbc_url = f'jdbc:postgresql://{self.host}:{self.port}/{self.database}' @@ -96,6 +106,27 @@ def setUp(self): metadata JSONB ) """) + cursor.execute( + f""" + CREATE TABLE {self.custom_table_name} ( + custom_id TEXT PRIMARY KEY, + embedding_vec VECTOR(2), + content_col TEXT, + metadata JSONB + ) + """) + cursor.execute( + f""" + CREATE TABLE {self.metadata_conflicts_table} ( + id TEXT, + source TEXT, + timestamp TIMESTAMP, + content TEXT, + embedding VECTOR(2), + PRIMARY KEY (id), + UNIQUE (source, timestamp) + ) + """) _LOGGER = logging.getLogger(__name__) _LOGGER.info("Created table %s", self.default_table_name) @@ -111,6 +142,93 @@ def tearDownClass(cls): if hasattr(cls, 'conn'): cls.conn.close() + def test_default_schema(self): + """Test basic write with default schema.""" + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + 
+ config = AlloyDBVectorWriterConfig( + connection_config=connection_config, table_name=self.default_table_name) + + # Create test chunks + num_records = 1500 + sample_size = min(500, num_records // 2) + # Generate test chunks + chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + # Run pipeline and verify + self.write_test_pipeline.not_use_test_runner_api = True + + with self.write_test_pipeline as p: + _ = (p | beam.Create(chunks) | config.create_write_transform()) + + self.read_test_pipeline.not_use_test_runner_api = True + # Read pipeline to verify + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + """ + + # Read and verify pipeline + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk)) + chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally( + test_utils.HashingFn()) + assert_that( + chunk_hashes, + equal_to([test_utils.generate_expected_hash(num_records)]), + label='hash_check') + + # Sample validation + first_n = ( + chunks + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | f"Get First {sample_size}" >> beam.transforms.combiners.Top.Of( + sample_size, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + expected_first_n = test_utils.ChunkTestUtils.get_expected_values( + 0, sample_size) + assert_that( + first_n, + equal_to([expected_first_n]), + label=f"first_{sample_size}_check") + + last_n = ( + chunks + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | f"Get Last {sample_size}" >> beam.transforms.combiners.Top.Of( + sample_size, key=lambda x: x[0]) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + expected_last_n = test_utils.ChunkTestUtils.get_expected_values( + num_records - sample_size, num_records)[::-1] + assert_that( + last_n, + equal_to([expected_last_n]), + label=f"last_{sample_size}_check") + def test_language_connector(self): """Test language connector.""" self.skip_if_dataflow_runner() @@ -198,6 +316,692 @@ def test_language_connector(self): equal_to([expected_last_n]), label=f"last_{sample_size}_check") + def test_custom_specs(self): + """Test custom specifications for ID, embedding, and content.""" + self.skip_if_dataflow_runner() + num_records = 20 + + specs = ( + ColumnSpecsBuilder().add_custom_column_spec( + ColumnSpec.text( + column_name="custom_id", + value_fn=lambda chunk: + f"timestamp_{chunk.metadata.get('timestamp', '')}") + ).add_custom_column_spec( + ColumnSpec.vector( + column_name="embedding_vec", + value_fn=chunk_embedding_fn)).add_custom_column_spec( + ColumnSpec.text( + column_name="content_col", + value_fn=lambda chunk: + f"{len(chunk.content.text)}:{chunk.content.text}")). 
+ with_metadata_spec().build()) + + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + + writer_config = AlloyDBVectorWriterConfig( + connection_config=connection_config, + table_name=self.custom_table_name, + column_specs=specs) + + # Generate test chunks + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) + for i in range(num_records) + ] + + # Write pipeline + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Read and verify + read_query = f""" + SELECT + CAST(custom_id AS VARCHAR(255)), + CAST(embedding_vec AS text), + CAST(content_col AS VARCHAR(255)), + CAST(metadata AS text) + FROM {self.custom_table_name} + ORDER BY custom_id + """ + + # Convert BeamRow back to Chunk + def custom_row_to_chunk(row): + # Extract timestamp from custom_id + timestamp = row.custom_id.split('timestamp_')[1] + # Extract index from timestamp + i = int(timestamp.split('T')[1][:2]) + + # Parse embedding vector + embedding_list = [ + float(x) for x in row.embedding_vec.strip('[]').split(',') + ] + + # Extract content from length-prefixed format + content = row.content_col.split(':', 1)[1] + + return Chunk( + id=str(i), + content=Content(text=content), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata)) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.custom_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + # Verify count + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) + assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + def test_defaults_with_args_specs(self): + """Test custom specifications for ID, embedding, and content.""" + self.skip_if_dataflow_runner() + num_records = 20 + + specs = ( + ColumnSpecsBuilder().with_id_spec( + column_name="custom_id", + python_type=int, + convert_fn=lambda x: int(x), + sql_typecast="::text").with_content_spec( + column_name="content_col", + convert_fn=lambda x: f"{len(x)}:{x}", + ).with_embedding_spec( + column_name="embedding_vec").with_metadata_spec().build()) + + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + + writer_config = AlloyDBVectorWriterConfig( + connection_config=connection_config, + table_name=self.custom_table_name, + column_specs=specs) + + # Generate test chunks + test_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + 
metadata={"timestamp": f"2024-02-02T{i:02d}:00:00"}) + for i in range(num_records) + ] + + # Write pipeline + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Read and verify + read_query = f""" + SELECT + CAST(custom_id AS VARCHAR(255)), + CAST(embedding_vec AS text), + CAST(content_col AS VARCHAR(255)), + CAST(metadata AS text) + FROM {self.custom_table_name} + ORDER BY custom_id + """ + + # Convert BeamRow back to Chunk + def custom_row_to_chunk(row): + # Parse embedding vector + embedding_list = [ + float(x) for x in row.embedding_vec.strip('[]').split(',') + ] + + # Extract content from length-prefixed format + content = row.content_col.split(':', 1)[1] + + return Chunk( + id=row.custom_id, + content=Content(text=content), + embedding=Embedding(dense_embedding=embedding_list), + metadata=json.loads(row.metadata)) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.custom_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + # Verify count + count_result = rows | "Count All" >> beam.combiners.Count.Globally() + assert_that(count_result, equal_to([num_records]), label='count_check') + + chunks = rows | "To Chunks" >> beam.Map(custom_row_to_chunk) + assert_that(chunks, equal_to(test_chunks), label='chunks_check') + + def test_default_id_embedding_specs(self): + """Test with only default id and embedding specs, others set to None.""" + self.skip_if_dataflow_runner() + num_records = 20 + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + specs = ( + ColumnSpecsBuilder().with_id_spec() # Use default id spec + .with_embedding_spec() # Use default embedding spec + .build()) + + writer_config = AlloyDBVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + column_specs=specs) + + # Generate test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + + # Write pipeline + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | beam.Create(test_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Read and verify only id and embedding + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(embedding AS text) + FROM {self.default_table_name} + ORDER BY id + """ + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = rows | "To Chunks" >> beam.Map(row_to_chunk) + + # Create expected chunks with None values + expected_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records) + for chunk in expected_chunks: + chunk.content.text 
= None + chunk.metadata = {} + + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + def test_metadata_spec_and_conflicts(self): + """Test metadata specification and conflict resolution.""" + self.skip_if_dataflow_runner() + num_records = 20 + + specs = ( + ColumnSpecsBuilder().with_id_spec().with_embedding_spec(). + with_content_spec().add_metadata_field( + field="source", + column_name="source", + python_type=str, + sql_typecast=None # Plain text field + ).add_metadata_field( + field="timestamp", python_type=str, + sql_typecast="::timestamp").build()) + + # Conflict resolution on source+timestamp + conflict_resolution = ConflictResolution( + on_conflict_fields=["source", "timestamp"], + action="UPDATE", + update_fields=["embedding", "content"]) + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + writer_config = AlloyDBVectorWriterConfig( + connection_config=connection_config, + table_name=self.metadata_conflicts_table, + column_specs=specs, + conflict_resolution=conflict_resolution) + + # Generate initial test chunks + initial_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"content_{i}"), + embedding=Embedding(dense_embedding=[float(i), float(i + 1)]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + # Write initial chunks + self.write_test_pipeline.not_use_test_runner_api = True + with self.write_test_pipeline as p: + _ = ( + p | "Write Initial" >> beam.Create(initial_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Generate conflicting chunks (same source+timestamp, different content) + conflicting_chunks = [ + Chunk( + id=f"new_{i}", + content=Content(text=f"updated_content_{i}"), + embedding=Embedding( + dense_embedding=[float(i) * 2, float(i + 1) * 2]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + # Write conflicting chunks + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p | "Write Conflicts" >> beam.Create(conflicting_chunks) + | VectorDatabaseWriteTransform(writer_config)) + + # Read and verify + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(embedding AS text), + CAST(content AS VARCHAR(255)), + CAST(source AS VARCHAR(255)), + CAST(timestamp AS VARCHAR(255)) + FROM {self.metadata_conflicts_table} + ORDER BY timestamp, id + """ + + # Expected chunks after conflict resolution + expected_chunks = [ + Chunk( + id=str(i), + content=Content(text=f"updated_content_{i}"), + embedding=Embedding( + dense_embedding=[float(i) * 2, float(i + 1) * 2]), + metadata={ + "source": "source_A", "timestamp": f"2024-02-02T{i:02d}:00:00" + }) for i in range(num_records) + ] + + def metadata_row_to_chunk(row): + return Chunk( + id=row.id, + content=Content(text=row.content), + embedding=Embedding( + dense_embedding=[ + float(x) for x in row.embedding.strip('[]').split(',') + ]), + metadata={ + "source": row.source, + "timestamp": row.timestamp.replace(' ', 'T') + }) + + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.metadata_conflicts_table, + driver_class_name="org.postgresql.Driver", + 
jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = rows | "To Chunks" >> beam.Map(metadata_row_to_chunk) + assert_that(chunks, equal_to(expected_chunks), label='chunks_check') + + def test_conflict_resolution_update(self): + """Test conflict resolution with UPDATE action.""" + self.skip_if_dataflow_runner() + num_records = 20 + + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + + conflict_resolution = ConflictResolution( + on_conflict_fields="id", + action="UPDATE", + update_fields=["embedding", "content"]) + + config = AlloyDBVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + conflict_resolution=conflict_resolution) + + # Generate initial test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + self.write_test_pipeline.not_use_test_runner_api = True + # Insert initial test chunks + with self.write_test_pipeline as p: + _ = ( + p + | "Create initial chunks" >> beam.Create(test_chunks) + | "Write initial chunks" >> config.create_write_transform()) + + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + ORDER BY id desc + """ + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = ( + rows + | "To Chunks" >> beam.Map(row_to_chunk) + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | "Get First 500" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([test_chunks]), label='original_chunks_check') + + updated_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records, content_prefix="Newcontent", seed_multiplier=2) + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p + | "Create updated Chunks" >> beam.Create(updated_chunks) + | "Write updated Chunks" >> config.create_write_transform()) + self.read_test_pipeline2.not_use_test_runner_api = True + with self.read_test_pipeline2 as p: + rows = ( + p + | "Read Updated chunks" >> ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = ( + rows + | "To Chunks 2" >> beam.Map(row_to_chunk) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | "Get First 500 2" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, 
equal_to([updated_chunks]), label='updated_chunks_check') + + def test_conflict_resolution_default_ignore(self): + """Test conflict resolution with default.""" + self.skip_if_dataflow_runner() + num_records = 20 + + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + + config = AlloyDBVectorWriterConfig( + connection_config=connection_config, table_name=self.default_table_name) + + # Generate initial test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + self.write_test_pipeline.not_use_test_runner_api = True + # Insert initial test chunks + with self.write_test_pipeline as p: + _ = ( + p + | "Create initial chunks" >> beam.Create(test_chunks) + | "Write initial chunks" >> config.create_write_transform()) + + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + ORDER BY id desc + """ + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = ( + rows + | "To Chunks" >> beam.Map(row_to_chunk) + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | "Get First 500" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([test_chunks]), label='original_chunks_check') + + updated_chunks = test_utils.ChunkTestUtils.get_expected_values( + 0, num_records, content_prefix="Newcontent", seed_multiplier=2) + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p + | "Create updated Chunks" >> beam.Create(updated_chunks) + | "Write updated Chunks" >> config.create_write_transform()) + self.read_test_pipeline2.not_use_test_runner_api = True + with self.read_test_pipeline2 as p: + rows = ( + p + | "Read Updated chunks" >> ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = ( + rows + | "To Chunks 2" >> beam.Map(row_to_chunk) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | "Get First 500 2" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that(chunks, equal_to([test_chunks]), label='updated_chunks_check') + + def test_conflict_resolution_default_update_fields(self): + """Test conflict resolution with default update fields (all non-conflict + fields).""" + self.skip_if_dataflow_runner() + num_records = 20 + + connection_config = AlloyDBLanguageConnectorConfig( + username=self.username, + password=self.password, + database_name=self.database, + instance_name="projects/apache-beam-testing/locations/us-central1/\ + 
clusters/testing-psc/instances/testing-psc-1", + ip_type="PSC") + + # Create a conflict resolution with only the conflict field specified + # No update_fields specified - should default to all non-conflict fields + conflict_resolution = ConflictResolution( + on_conflict_fields="id", action="UPDATE") + + config = AlloyDBVectorWriterConfig( + connection_config=connection_config, + table_name=self.default_table_name, + conflict_resolution=conflict_resolution) + + # Generate initial test chunks + test_chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records) + self.write_test_pipeline.not_use_test_runner_api = True + + # Insert initial test chunks + with self.write_test_pipeline as p: + _ = ( + p + | "Create initial chunks" >> beam.Create(test_chunks) + | "Write initial chunks" >> config.create_write_transform()) + + # Verify initial data was written correctly + read_query = f""" + SELECT + CAST(id AS VARCHAR(255)), + CAST(content AS VARCHAR(255)), + CAST(embedding AS text), + CAST(metadata AS text) + FROM {self.default_table_name} + ORDER BY id desc + """ + self.read_test_pipeline.not_use_test_runner_api = True + with self.read_test_pipeline as p: + rows = ( + p + | ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = ( + rows + | "To Chunks" >> beam.Map(row_to_chunk) + | "Key on Index" >> beam.Map(test_utils.key_on_id) + | "Get First 500" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 1" >> beam.Map(lambda xs: [x[1] for x in xs])) + assert_that( + chunks, equal_to([test_chunks]), label='original_chunks_check') + + # Create updated chunks with same IDs but different content, embedding, and + # metadata + updated_chunks = [] + for i in range(num_records): + original_chunk = test_chunks[i] + updated_chunk = Chunk( + id=original_chunk.id, + content=Content(text=f"Updated content {i}"), + embedding=Embedding( + dense_embedding=[float(i * 2), float(i * 2 + 1)] + [0.0] * + (test_utils.VECTOR_SIZE - 2)), + metadata={ + "updated": "true", "timestamp": "2024-02-25" + }) + updated_chunks.append(updated_chunk) + + # Write updated chunks - should update all non-conflict fields + self.write_test_pipeline2.not_use_test_runner_api = True + with self.write_test_pipeline2 as p: + _ = ( + p + | "Create updated Chunks" >> beam.Create(updated_chunks) + | "Write updated Chunks" >> config.create_write_transform()) + + # Read and verify that all non-conflict fields were updated + self.read_test_pipeline2.not_use_test_runner_api = True + with self.read_test_pipeline2 as p: + rows = ( + p + | "Read Updated chunks" >> ReadFromJdbc( + table_name=self.default_table_name, + driver_class_name="org.postgresql.Driver", + jdbc_url=connection_config.to_connection_config().jdbc_url, + username=self.username, + password=self.password, + query=read_query, + classpath=connection_config.additional_jdbc_args()['classpath'])) + + chunks = ( + rows + | "To Chunks 2" >> beam.Map(row_to_chunk) + | "Key on Index 2" >> beam.Map(test_utils.key_on_id) + | "Get First 500 2" >> beam.transforms.combiners.Top.Of( + num_records, key=lambda x: x[0], reverse=True) + | "Remove Keys 2" >> beam.Map(lambda xs: [x[1] for x in xs])) + + # Verify that all non-conflict fields were updated + assert_that( + chunks, 
equal_to([updated_chunks]), label='updated_chunks_check') + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) From 92fef0c3fc33580574aef2d40bb178c2652f8ab4 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 16:34:37 +0000 Subject: [PATCH 06/10] Drop test tables. --- sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index 3c343cbafab5..343b1d90e985 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -134,6 +134,8 @@ def tearDown(self): # Drop test table with self.conn.cursor() as cursor: cursor.execute(f"DROP TABLE IF EXISTS {self.default_table_name}") + cursor.execute(f"DROP TABLE IF EXISTS {self.custom_table_name}") + cursor.execute(f"DROP TABLE IF EXISTS {self.metadata_conflicts_table}") _LOGGER = logging.getLogger(__name__) _LOGGER.info("Dropped table %s", self.default_table_name) From e013bed504e36fac340fe4cd05d8f5d1d64cacc7 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 17:23:07 +0000 Subject: [PATCH 07/10] Fix test --- sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index 343b1d90e985..f49071e74afe 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -48,7 +48,7 @@ os.environ.get('ALLOYDB_PASSWORD'), "ALLOYDB_PASSWORD environment var is not provided") class AlloydbVectorWriterConfigTest(unittest.TestCase): - POSTGRES_TABLE_PREFIX = 'python_rag_postgres_' + ALLOYDB_TABLE_PREFIX = 'python_rag_postgres_' @classmethod def setUpClass(cls): From 1d634b7374865d02cba6b9c024a8fc3bb4cf12fe Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 18:01:19 +0000 Subject: [PATCH 08/10] Fix tests. 
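
A note on the helper switch in the hunks that follow: the AlloyDB tests reuse the shared utilities in apache_beam/ml/rag/ingestion/test_utils.py added earlier in this series. As a rough, illustrative sketch only (not part of the patch), this is how those helpers compose: ChunkTestUtils generates deterministic chunks, HashingFn reduces them to a content hash, and generate_expected_hash produces the value to compare against. The pipeline below is simplified and assumes the test_utils module from this series is importable as shown.

    # Illustrative sketch of the test_utils helpers; assumptions are noted inline.
    import apache_beam as beam
    from apache_beam.ml.rag.ingestion import test_utils  # module added in this series
    from apache_beam.testing.test_pipeline import TestPipeline
    from apache_beam.testing.util import assert_that
    from apache_beam.testing.util import equal_to

    num_records = 100
    # Deterministic chunks with content "Testval0" .. "Testval99".
    chunks = test_utils.ChunkTestUtils.get_expected_values(0, num_records)

    with TestPipeline() as p:
        hashes = (
            p
            | beam.Create(chunks)
            # Combine all chunk contents into a single order-independent hash.
            | beam.CombineGlobally(test_utils.HashingFn()))
        # The combined hash must equal the precomputed expected hash.
        assert_that(
            hashes, equal_to([test_utils.generate_expected_hash(num_records)]))
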
--- .../ml/rag/ingestion/alloydb_it_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index f49071e74afe..8985bfcb0a3b 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -196,7 +196,7 @@ def test_default_schema(self): count_result = rows | "Count All" >> beam.combiners.Count.Globally() assert_that(count_result, equal_to([num_records]), label='count_check') - chunks = (rows | "To Chunks" >> beam.Map(row_to_chunk)) + chunks = (rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk)) chunk_hashes = chunks | "Hash Chunks" >> beam.CombineGlobally( test_utils.HashingFn()) assert_that( @@ -567,7 +567,7 @@ def test_default_id_embedding_specs(self): query=read_query, classpath=connection_config.additional_jdbc_args()['classpath'])) - chunks = rows | "To Chunks" >> beam.Map(row_to_chunk) + chunks = rows | "To Chunks" >> beam.Map(test_utils.row_to_chunk) # Create expected chunks with None values expected_chunks = test_utils.ChunkTestUtils.get_expected_values( @@ -759,7 +759,7 @@ def test_conflict_resolution_update(self): chunks = ( rows - | "To Chunks" >> beam.Map(row_to_chunk) + | "To Chunks" >> beam.Map(test_utils.row_to_chunk) | "Key on Index" >> beam.Map(test_utils.key_on_id) | "Get First 500" >> beam.transforms.combiners.Top.Of( num_records, key=lambda x: x[0], reverse=True) @@ -790,7 +790,7 @@ def test_conflict_resolution_update(self): chunks = ( rows - | "To Chunks 2" >> beam.Map(row_to_chunk) + | "To Chunks 2" >> beam.Map(test_utils.row_to_chunk) | "Key on Index 2" >> beam.Map(test_utils.key_on_id) | "Get First 500 2" >> beam.transforms.combiners.Top.Of( num_records, key=lambda x: x[0], reverse=True) @@ -848,7 +848,7 @@ def test_conflict_resolution_default_ignore(self): chunks = ( rows - | "To Chunks" >> beam.Map(row_to_chunk) + | "To Chunks" >> beam.Map(test_utils.row_to_chunk) | "Key on Index" >> beam.Map(test_utils.key_on_id) | "Get First 500" >> beam.transforms.combiners.Top.Of( num_records, key=lambda x: x[0], reverse=True) @@ -879,7 +879,7 @@ def test_conflict_resolution_default_ignore(self): chunks = ( rows - | "To Chunks 2" >> beam.Map(row_to_chunk) + | "To Chunks 2" >> beam.Map(test_utils.row_to_chunk) | "Key on Index 2" >> beam.Map(test_utils.key_on_id) | "Get First 500 2" >> beam.transforms.combiners.Top.Of( num_records, key=lambda x: x[0], reverse=True) @@ -946,7 +946,7 @@ def test_conflict_resolution_default_update_fields(self): chunks = ( rows - | "To Chunks" >> beam.Map(row_to_chunk) + | "To Chunks" >> beam.Map(test_utils.row_to_chunk) | "Key on Index" >> beam.Map(test_utils.key_on_id) | "Get First 500" >> beam.transforms.combiners.Top.Of( num_records, key=lambda x: x[0], reverse=True) @@ -994,7 +994,7 @@ def test_conflict_resolution_default_update_fields(self): chunks = ( rows - | "To Chunks 2" >> beam.Map(row_to_chunk) + | "To Chunks 2" >> beam.Map(test_utils.row_to_chunk) | "Key on Index 2" >> beam.Map(test_utils.key_on_id) | "Get First 500 2" >> beam.transforms.combiners.Top.Of( num_records, key=lambda x: x[0], reverse=True) From c477b2629e22fca01910e92e860a47b27865127c Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 18:39:50 +0000 Subject: [PATCH 09/10] Fix test. 
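Define the CustomSpecsRow and MetadataConflictRow NamedTuples and register them with RowCoder, adding the json, typing, coder, and Chunk/Content/Embedding imports the custom-spec and metadata-conflict tests rely on.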
--- .../ml/rag/ingestion/alloydb_it_test.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index 8985bfcb0a3b..a077fc7b8685 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -15,16 +15,21 @@ # limitations under the License. # +import json import logging import os import secrets import time import unittest +from typing import List +from typing import NamedTuple import psycopg2 import pytest import apache_beam as beam +from apache_beam.coders import registry +from apache_beam.coders.row_coder import RowCoder from apache_beam.io.jdbc import ReadFromJdbc from apache_beam.ml.rag.ingestion import test_utils from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig @@ -32,12 +37,37 @@ from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to _LOGGER = logging.getLogger(__name__) +CustomSpecsRow = NamedTuple( + 'CustomSpecsRow', + [ + ('custom_id', str), # For id_spec test + ('embedding_vec', List[float]), # For embedding_spec test + ('content_col', str), # For content_spec test + ('metadata', str) + ]) +registry.register_coder(CustomSpecsRow, RowCoder) + +MetadataConflictRow = NamedTuple( + 'MetadataConflictRow', + [ + ('id', str), + ('source', str), # For metadata_spec and composite key + ('timestamp', str), # For metadata_spec and composite key + ('content', str), + ('embedding', List[float]), + ('metadata', str) + ]) +registry.register_coder(MetadataConflictRow, RowCoder) + @pytest.mark.uses_gcp_java_expansion_service @unittest.skipUnless( From ae9aa55b3db7a22ae763f6527828f26c76afef71 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Jun 2025 18:49:26 +0000 Subject: [PATCH 10/10] Fix tests. 
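Import ColumnSpec and chunk_embedding_fn from postgres_common, drop the duplicated default_table_name assignment in setUp, and reuse the module-level _LOGGER instead of re-creating it in setUp and tearDown.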
--- sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py index a077fc7b8685..ce98de19a1de 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/alloydb_it_test.py @@ -35,8 +35,10 @@ from apache_beam.ml.rag.ingestion.alloydb import AlloyDBLanguageConnectorConfig from apache_beam.ml.rag.ingestion.alloydb import AlloyDBVectorWriterConfig from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteTransform +from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder from apache_beam.ml.rag.ingestion.postgres_common import ConflictResolution +from apache_beam.ml.rag.ingestion.postgres_common import chunk_embedding_fn from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.types import Content from apache_beam.ml.rag.types import Embedding @@ -114,8 +116,6 @@ def setUp(self): self.read_test_pipeline2 = TestPipeline(is_integration_test=True) self._runner = type(self.read_test_pipeline.runner).__name__ - self.default_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ - f"{self.table_suffix}" self.default_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ f"{self.table_suffix}" self.custom_table_name = f"{self.ALLOYDB_TABLE_PREFIX}" \ @@ -157,7 +157,6 @@ def setUp(self): UNIQUE (source, timestamp) ) """) - _LOGGER = logging.getLogger(__name__) _LOGGER.info("Created table %s", self.default_table_name) def tearDown(self): @@ -166,7 +165,6 @@ def tearDown(self): cursor.execute(f"DROP TABLE IF EXISTS {self.default_table_name}") cursor.execute(f"DROP TABLE IF EXISTS {self.custom_table_name}") cursor.execute(f"DROP TABLE IF EXISTS {self.metadata_conflicts_table}") - _LOGGER = logging.getLogger(__name__) _LOGGER.info("Dropped table %s", self.default_table_name) @classmethod