From b7e33eeb6a7eec084b78ce0ff9159e743ac2751e Mon Sep 17 00:00:00 2001
From: Achal Shah
Date: Thu, 2 Jun 2022 13:33:16 -0700
Subject: [PATCH] tests for pg and mysql

Signed-off-by: Achal Shah
---
 sdk/python/feast/infra/registry_stores/sql.py |  18 +-
 .../registration/test_sql_registry.py         | 304 +++++++++++++++++-
 setup.py                                      |   1 +
 3 files changed, 311 insertions(+), 12 deletions(-)

diff --git a/sdk/python/feast/infra/registry_stores/sql.py b/sdk/python/feast/infra/registry_stores/sql.py
index 3ba3ed0196..9d19c44439 100644
--- a/sdk/python/feast/infra/registry_stores/sql.py
+++ b/sdk/python/feast/infra/registry_stores/sql.py
@@ -53,7 +53,7 @@
 entities = Table(
     "entities",
     metadata,
-    Column("entity_id", String, primary_key=True),
+    Column("entity_id", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("entity_proto", LargeBinary, nullable=False),
 )
@@ -61,7 +61,7 @@
 data_sources = Table(
     "data_sources",
     metadata,
-    Column("data_source_name", String, primary_key=True),
+    Column("data_source_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("data_source_proto", LargeBinary, nullable=False),
 )
@@ -69,7 +69,7 @@
 feature_views = Table(
     "feature_views",
     metadata,
-    Column("feature_view_name", String, primary_key=True),
+    Column("feature_view_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("materialized_intervals", LargeBinary, nullable=True),
     Column("feature_view_proto", LargeBinary, nullable=False),
@@ -78,7 +78,7 @@
 request_feature_views = Table(
     "request_feature_views",
     metadata,
-    Column("feature_view_name", String, primary_key=True),
+    Column("feature_view_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("feature_view_proto", LargeBinary, nullable=False),
 )
@@ -86,7 +86,7 @@
 on_demand_feature_views = Table(
     "on_demand_feature_views",
     metadata,
-    Column("feature_view_name", String, primary_key=True),
+    Column("feature_view_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("feature_view_proto", LargeBinary, nullable=False),
 )
@@ -94,7 +94,7 @@
 feature_user_metadata = Table(
     "feature_metadata",
     metadata,
-    Column("feature_name", String, primary_key=True),
+    Column("feature_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("feature_metadata_binary", LargeBinary, nullable=False),
 )
@@ -102,7 +102,7 @@
 feature_services = Table(
     "feature_services",
     metadata,
-    Column("feature_service_name", String, primary_key=True),
+    Column("feature_service_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("feature_service_proto", LargeBinary, nullable=False),
 )
@@ -110,7 +110,7 @@
 saved_datasets = Table(
     "saved_datasets",
     metadata,
-    Column("saved_dataset_name", String, primary_key=True),
+    Column("saved_dataset_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("saved_dataset_proto", LargeBinary, nullable=False),
 )
@@ -118,7 +118,7 @@
 validation_references = Table(
     "validation_references",
     metadata,
-    Column("validation_reference_name", String, primary_key=True),
+    Column("validation_reference_name", String(50), primary_key=True),
     Column("last_updated_timestamp", BigInteger, nullable=False),
     Column("validation_reference_proto", LargeBinary, nullable=False),
 )
diff --git a/sdk/python/tests/integration/registration/test_sql_registry.py b/sdk/python/tests/integration/registration/test_sql_registry.py
index 34b01f8149..f966b45758 100644
--- a/sdk/python/tests/integration/registration/test_sql_registry.py
+++ b/sdk/python/tests/integration/registration/test_sql_registry.py
@@ -16,10 +16,11 @@
 
 import pandas as pd
 import pytest
+from pytest_lazyfixture import lazy_fixture
 from testcontainers.core.container import DockerContainer
 from testcontainers.core.waiting_utils import wait_for_logs
 
-from feast import FileSource
+from feast import Feature, FileSource, RequestSource
 from feast.data_format import ParquetFormat
 from feast.entity import Entity
 from feast.feature_view import FeatureView
@@ -27,7 +28,8 @@
 from feast.infra.registry_stores.sql import SqlRegistry
 from feast.on_demand_feature_view import on_demand_feature_view
 from feast.repo_config import RegistryConfig
-from feast.types import Array, Bytes, Float32, Int64, String
+from feast.types import Array, Bytes, Float32, Int32, Int64, String
+from feast.value_type import ValueType
 
 POSTGRES_USER = "test"
 POSTGRES_PASSWORD = "test"
@@ -38,7 +40,7 @@
 
 
 @pytest.fixture(scope="session")
-def sql_registry():
+def pg_registry():
     container = (
         DockerContainer("postgres:latest")
         .with_exposed_ports(5432)
@@ -66,6 +68,40 @@ def sql_registry():
     container.stop()
 
 
+@pytest.fixture(scope="session")
+def mysql_registry():
+    container = (
+        DockerContainer("mysql:latest")
+        .with_exposed_ports(3306)
+        .with_env("MYSQL_RANDOM_ROOT_PASSWORD", "true")
+        .with_env("MYSQL_USER", POSTGRES_USER)
+        .with_env("MYSQL_PASSWORD", POSTGRES_PASSWORD)
+        .with_env("MYSQL_DATABASE", POSTGRES_DB)
+    )
+
+    container.start()
+
+    # Regex predicate: match any server version, since the image tag is `latest`.
+    log_string_to_wait_for = r"/usr/sbin/mysqld: ready for connections\. Version: '\d+(\.\d+){1,2}'\s+socket: '/var/run/mysqld/mysqld\.sock'\s+port: 3306"
+    waited = wait_for_logs(
+        container=container, predicate=log_string_to_wait_for, timeout=30, interval=10,
+    )
+    logger.info("Waited for %s seconds until mysql container was up", waited)
+    container_port = container.get_exposed_port(3306)
+
+    registry_config = RegistryConfig(
+        registry_type="sql",
+        path=f"mysql+mysqldb://{POSTGRES_USER}:{POSTGRES_PASSWORD}@127.0.0.1:{container_port}/{POSTGRES_DB}",
+    )
+
+    yield SqlRegistry(registry_config, None)
+
+    container.stop()
+
+
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
 def test_apply_entity_success(sql_registry):
     entity = Entity(
         name="driver_car_id", description="Car driver id", tags={"team": "matchmaking"},
@@ -103,6 +139,9 @@ def test_apply_entity_success(sql_registry):
 
 
 @pytest.mark.integration
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
 def test_apply_entity_integration(sql_registry):
     entity = Entity(
         name="driver_car_id", description="Car driver id", tags={"team": "matchmaking"},
@@ -135,6 +174,9 @@ def test_apply_entity_integration(sql_registry):
     sql_registry.teardown()
 
 
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
 def test_apply_feature_view_success(sql_registry):
     # Create Feature Views
     batch_source = FileSource(
@@ -203,6 +245,9 @@ def test_apply_feature_view_success(sql_registry):
     sql_registry.teardown()
 
 
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
 def test_apply_on_demand_feature_view_success(sql_registry):
     # Create Feature Views
     driver_stats = FileSource(
@@ -268,3 +313,256 @@ def location_features_from_push(inputs: pd.DataFrame) -> pd.DataFrame:
     assert len(feature_views) == 0
 
     sql_registry.teardown()
+
+
+# TODO(kevjumba): remove this in feast 0.23 when deprecating
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
+@pytest.mark.parametrize(
+    "request_source_schema",
+    [[Field(name="my_input_1", dtype=Int32)], {"my_input_1": ValueType.INT32}],
+)
+def test_modify_feature_views_success(sql_registry, request_source_schema):
+    # Create Feature Views
+    batch_source = FileSource(
+        file_format=ParquetFormat(),
+        path="file://feast/*",
+        timestamp_field="ts_col",
+        created_timestamp_column="timestamp",
+    )
+
+    request_source = RequestSource(name="request_source", schema=request_source_schema,)
+
+    entity = Entity(name="fs1_my_entity_1", join_keys=["test"])
+
+    fv1 = FeatureView(
+        name="my_feature_view_1",
+        schema=[Field(name="fs1_my_feature_1", dtype=Int64)],
+        entities=[entity],
+        tags={"team": "matchmaking"},
+        batch_source=batch_source,
+        ttl=timedelta(minutes=5),
+    )
+
+    @on_demand_feature_view(
+        features=[
+            Feature(name="odfv1_my_feature_1", dtype=ValueType.STRING),
+            Feature(name="odfv1_my_feature_2", dtype=ValueType.INT32),
+        ],
+        sources=[request_source],
+    )
+    def odfv1(feature_df: pd.DataFrame) -> pd.DataFrame:
+        data = pd.DataFrame()
+        data["odfv1_my_feature_1"] = feature_df["my_input_1"].astype("category")
+        data["odfv1_my_feature_2"] = feature_df["my_input_1"].astype("int32")
+        return data
+
+    project = "project"
+
+    # Register Feature Views
+    sql_registry.apply_feature_view(odfv1, project)
+    sql_registry.apply_feature_view(fv1, project)
+
+    # Modify odfv by changing a single feature dtype
+    @on_demand_feature_view(
+        features=[
+            Feature(name="odfv1_my_feature_1", dtype=ValueType.FLOAT),
+            Feature(name="odfv1_my_feature_2", dtype=ValueType.INT32),
+        ],
+        sources=[request_source],
+    )
+    def odfv1(feature_df: pd.DataFrame) -> pd.DataFrame:
+        data = pd.DataFrame()
+        data["odfv1_my_feature_1"] = feature_df["my_input_1"].astype("float")
+        data["odfv1_my_feature_2"] = feature_df["my_input_1"].astype("int32")
+        return data
+
+    # Apply the modified odfv
+    sql_registry.apply_feature_view(odfv1, project)
+
+    # Check odfv
+    on_demand_feature_views = sql_registry.list_on_demand_feature_views(project)
+
+    assert (
+        len(on_demand_feature_views) == 1
+        and on_demand_feature_views[0].name == "odfv1"
+        and on_demand_feature_views[0].features[0].name == "odfv1_my_feature_1"
+        and on_demand_feature_views[0].features[0].dtype == Float32
+        and on_demand_feature_views[0].features[1].name == "odfv1_my_feature_2"
+        and on_demand_feature_views[0].features[1].dtype == Int32
+    )
+    request_schema = on_demand_feature_views[0].get_request_data_schema()
+    assert (
+        list(request_schema.keys())[0] == "my_input_1"
+        and list(request_schema.values())[0] == ValueType.INT32
+    )
+
+    feature_view = sql_registry.get_on_demand_feature_view("odfv1", project)
+    assert (
+        feature_view.name == "odfv1"
+        and feature_view.features[0].name == "odfv1_my_feature_1"
+        and feature_view.features[0].dtype == Float32
+        and feature_view.features[1].name == "odfv1_my_feature_2"
+        and feature_view.features[1].dtype == Int32
+    )
+    request_schema = feature_view.get_request_data_schema()
+    assert (
+        list(request_schema.keys())[0] == "my_input_1"
+        and list(request_schema.values())[0] == ValueType.INT32
+    )
+
+    # Make sure fv1 is untouched
+    feature_views = sql_registry.list_feature_views(project)
+
+    # List Feature Views
+    assert (
+        len(feature_views) == 1
+        and feature_views[0].name == "my_feature_view_1"
+        and feature_views[0].features[0].name == "fs1_my_feature_1"
+        and feature_views[0].features[0].dtype == Int64
+        and feature_views[0].entities[0] == "fs1_my_entity_1"
+    )
+
+    feature_view = sql_registry.get_feature_view("my_feature_view_1", project)
+    assert (
+        feature_view.name == "my_feature_view_1"
+        and feature_view.features[0].name == "fs1_my_feature_1"
+        and feature_view.features[0].dtype == Int64
+        and feature_view.entities[0] == "fs1_my_entity_1"
+    )
+
+    sql_registry.teardown()
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
+def test_apply_feature_view_integration(sql_registry):
+    # Create Feature Views
+    batch_source = FileSource(
+        file_format=ParquetFormat(),
+        path="file://feast/*",
+        timestamp_field="ts_col",
+        created_timestamp_column="timestamp",
+    )
+
+    entity = Entity(name="fs1_my_entity_1", join_keys=["test"])
+
+    fv1 = FeatureView(
+        name="my_feature_view_1",
+        schema=[
+            Field(name="fs1_my_feature_1", dtype=Int64),
+            Field(name="fs1_my_feature_2", dtype=String),
+            Field(name="fs1_my_feature_3", dtype=Array(String)),
+            Field(name="fs1_my_feature_4", dtype=Array(Bytes)),
+        ],
+        entities=[entity],
+        tags={"team": "matchmaking"},
+        batch_source=batch_source,
+        ttl=timedelta(minutes=5),
+    )
+
+    project = "project"
+
+    # Register Feature View
+    sql_registry.apply_feature_view(fv1, project)
+
+    feature_views = sql_registry.list_feature_views(project)
+
+    # List Feature Views
+    assert (
+        len(feature_views) == 1
+        and feature_views[0].name == "my_feature_view_1"
+        and feature_views[0].features[0].name == "fs1_my_feature_1"
+        and feature_views[0].features[0].dtype == Int64
+        and feature_views[0].features[1].name == "fs1_my_feature_2"
+        and feature_views[0].features[1].dtype == String
+        and feature_views[0].features[2].name == "fs1_my_feature_3"
+        and feature_views[0].features[2].dtype == Array(String)
+        and feature_views[0].features[3].name == "fs1_my_feature_4"
+        and feature_views[0].features[3].dtype == Array(Bytes)
+        and feature_views[0].entities[0] == "fs1_my_entity_1"
+    )
+
+    feature_view = sql_registry.get_feature_view("my_feature_view_1", project)
+    assert (
+        feature_view.name == "my_feature_view_1"
+        and feature_view.features[0].name == "fs1_my_feature_1"
+        and feature_view.features[0].dtype == Int64
+        and feature_view.features[1].name == "fs1_my_feature_2"
+        and feature_view.features[1].dtype == String
+        and feature_view.features[2].name == "fs1_my_feature_3"
+        and feature_view.features[2].dtype == Array(String)
+        and feature_view.features[3].name == "fs1_my_feature_4"
+        and feature_view.features[3].dtype == Array(Bytes)
+        and feature_view.entities[0] == "fs1_my_entity_1"
+    )
+
+    sql_registry.delete_feature_view("my_feature_view_1", project)
+    feature_views = sql_registry.list_feature_views(project)
+    assert len(feature_views) == 0
+
+    sql_registry.teardown()
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "sql_registry", [lazy_fixture("mysql_registry"), lazy_fixture("pg_registry")],
+)
+def test_apply_data_source(sql_registry):
+    # Create Feature Views
+    batch_source = FileSource(
+        name="test_source",
+        file_format=ParquetFormat(),
+        path="file://feast/*",
+        timestamp_field="ts_col",
+        created_timestamp_column="timestamp",
+    )
+
+    entity = Entity(name="fs1_my_entity_1", join_keys=["test"])
+
+    fv1 = FeatureView(
+        name="my_feature_view_1",
+        schema=[
+            Field(name="fs1_my_feature_1", dtype=Int64),
+            Field(name="fs1_my_feature_2", dtype=String),
+            Field(name="fs1_my_feature_3", dtype=Array(String)),
+            Field(name="fs1_my_feature_4", dtype=Array(Bytes)),
+        ],
+        entities=[entity],
+        tags={"team": "matchmaking"},
+        batch_source=batch_source,
+        ttl=timedelta(minutes=5),
+    )
+
+    project = "project"
+
+    # Register data source and feature view
+    sql_registry.apply_data_source(batch_source, project, commit=False)
+    sql_registry.apply_feature_view(fv1, project, commit=True)
+
+    registry_feature_views = sql_registry.list_feature_views(project)
+    registry_data_sources = sql_registry.list_data_sources(project)
+    assert len(registry_feature_views) == 1
+    assert len(registry_data_sources) == 1
+    registry_feature_view = registry_feature_views[0]
+    assert registry_feature_view.batch_source == batch_source
+    registry_data_source = registry_data_sources[0]
+    assert registry_data_source == batch_source
+
+    # Check that change to batch source propagates
+    batch_source.timestamp_field = "new_ts_col"
+    sql_registry.apply_data_source(batch_source, project, commit=False)
+    sql_registry.apply_feature_view(fv1, project, commit=True)
+    registry_feature_views = sql_registry.list_feature_views(project)
+    registry_data_sources = sql_registry.list_data_sources(project)
+    assert len(registry_feature_views) == 1
+    assert len(registry_data_sources) == 1
+    registry_feature_view = registry_feature_views[0]
+    assert registry_feature_view.batch_source == batch_source
+    registry_batch_source = sql_registry.list_data_sources(project)[0]
+    assert registry_batch_source == batch_source
+
+    sql_registry.teardown()
diff --git a/setup.py b/setup.py
index 16409de120..a9499924eb 100644
--- a/setup.py
+++ b/setup.py
@@ -132,6 +132,7 @@
         "moto",
         "mypy==0.931",
         "mypy-protobuf==3.1",
+        "mysqlclient",
         "avro==1.10.0",
         "gcsfs>=0.4.0,<=2022.01.0",
         "urllib3>=1.25.4,<2",