Fix: Resolve issue where BigQuery caches fail to load on streams without a primary key, or when a table rename is required #122

Merged
merged 20 commits on Mar 19, 2024
33 changes: 33 additions & 0 deletions airbyte/_processors/sql/bigquery.py
@@ -196,3 +196,36 @@ def _get_tables_list(
             table.replace(schema_prefix, "", 1) if table.startswith(schema_prefix) else table
             for table in tables
         ]
+
+    def _swap_temp_table_with_final_table(
+        self,
+        stream_name: str,
+        temp_table_name: str,
+        final_table_name: str,
+    ) -> None:
+        """Swap the temp table with the main one, dropping the old version of the 'final' table.
+
+        The BigQuery RENAME implementation requires that the table schema (dataset) is named in the
+        first part of the ALTER statement, but not in the second part.
+
+        For example, BigQuery expects this format:
+
+            ALTER TABLE my_schema.my_old_table_name RENAME TO my_new_table_name;
+        """
+        if final_table_name is None:
+            raise exc.AirbyteLibInternalError(message="Arg 'final_table_name' cannot be None.")
+        if temp_table_name is None:
+            raise exc.AirbyteLibInternalError(message="Arg 'temp_table_name' cannot be None.")
+
+        _ = stream_name
+        deletion_name = f"{final_table_name}_deleteme"
+        commands = "\n".join(
+            [
+                f"ALTER TABLE {self._fully_qualified(final_table_name)} "
+                f"RENAME TO {deletion_name};",
+                f"ALTER TABLE {self._fully_qualified(temp_table_name)} "
+                f"RENAME TO {final_table_name};",
+                f"DROP TABLE {self._fully_qualified(deletion_name)};",
+            ]
+        )
+        self._execute_sql(commands)
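
For illustration, if the final table were users in dataset my_dataset and the temp table were users_tmp (placeholder names, and assuming _fully_qualified() renders dataset.table), the generated batch would be:

    ALTER TABLE my_dataset.users RENAME TO users_deleteme;
    ALTER TABLE my_dataset.users_tmp RENAME TO users;
    DROP TABLE my_dataset.users_deleteme;

Only the first name in each RENAME is dataset-qualified, which is exactly the BigQuery constraint the docstring describes.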
48 changes: 48 additions & 0 deletions airbyte/_util/google_secrets.py
@@ -0,0 +1,48 @@
+"""Helpers for accessing Google secrets."""
+
+from __future__ import annotations
+
+import json
+import os
+
+from google.cloud import secretmanager
+
+
+def get_gcp_secret(
+    project_name: str,
+    secret_name: str,
+) -> str:
+    """Try to get a GCP secret from the environment, or raise an error.
+
+    We assume that the Google service account credentials file contents are stored in the
+    environment variable GCP_GSM_CREDENTIALS. If this environment variable is not set, we raise an
+    error. Otherwise, we use the Google Secret Manager API to fetch the secret with the given name.
+    """
+    if "GCP_GSM_CREDENTIALS" not in os.environ:
+        raise EnvironmentError(  # noqa: TRY003, UP024
+            "GCP_GSM_CREDENTIALS env variable not set, can't fetch secrets. Make sure they are set "
+            "up as described: "
+            "https://github.com/airbytehq/airbyte/blob/master/airbyte-ci/connectors/ci_credentials/"
+            "README.md#get-gsm-access"
+        )
+
+    # load secrets from GSM using the GCP_GSM_CREDENTIALS env variable
+    secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
+        json.loads(os.environ["GCP_GSM_CREDENTIALS"])
+    )
+    return secret_client.access_secret_version(
+        name=f"projects/{project_name}/secrets/{secret_name}/versions/latest"
+    ).payload.data.decode("UTF-8")
+
+
+def get_gcp_secret_json(
+    project_name: str,
+    secret_name: str,
+) -> dict:
+    """Get a JSON GCP secret and return as a dict.
+
+    We assume that the Google service account credentials file contents are stored in the
+    environment variable GCP_GSM_CREDENTIALS. If this environment variable is not set, we raise an
+    error. Otherwise, we use the Google Secret Manager API to fetch the secret with the given name.
+    """
+    return json.loads(get_gcp_secret(project_name=project_name, secret_name=secret_name))
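
A minimal usage sketch of the new helper (assumes GCP_GSM_CREDENTIALS is exported; the project and secret names below are placeholders, not values from this PR):

    from airbyte._util.google_secrets import get_gcp_secret_json

    config: dict = get_gcp_secret_json(
        project_name="my-gcp-project",  # hypothetical project
        secret_name="MY_CONNECTOR_CREDS",  # hypothetical secret
    )
    print(sorted(config.keys()))  # inspect which config fields the secret provides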
13 changes: 10 additions & 3 deletions airbyte/caches/bigquery.py
@@ -18,8 +18,10 @@
 from __future__ import annotations
 
 import urllib
+from typing import Any
 
 from overrides import overrides
+from pydantic import root_validator
 
 from airbyte._processors.sql.bigquery import BigQuerySqlProcessor
 from airbyte.caches.base import (
@@ -41,9 +43,14 @@ class BigQueryCache(CacheBase):
 
     _sql_processor_class: type[BigQuerySqlProcessor] = BigQuerySqlProcessor
 
-    def __post_init__(self) -> None:
-        """Initialize the BigQuery cache."""
-        self.schema_name = self.dataset_name
+    @root_validator(pre=True)
+    @classmethod
+    def set_schema_name(cls, values: dict[str, Any]) -> dict[str, Any]:
+        dataset_name = values.get("dataset_name")
+        if dataset_name is None:
+            raise ValueError("dataset_name must be defined")  # noqa: TRY003
+        values["schema_name"] = dataset_name
+        return values
 
     @overrides
     def get_database_name(self) -> str:
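
A minimal sketch of the behavior this validator enforces, using the same constructor fields the new conftest fixture below passes (all values here are placeholders):

    from airbyte.caches.bigquery import BigQueryCache

    cache = BigQueryCache(
        credentials_path="/path/to/creds.json",  # hypothetical path
        project_name="my-gcp-project",  # hypothetical project
        dataset_name="my_dataset",  # hypothetical dataset
    )
    assert cache.schema_name == "my_dataset"  # copied from dataset_name by set_schema_name()

Because the validator runs with pre=True, schema_name is populated during pydantic validation at construction time, replacing the __post_init__ hook, which pydantic models do not invoke.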
20 changes: 7 additions & 13 deletions examples/run_bigquery_faker.py
@@ -7,29 +7,20 @@
 
 from __future__ import annotations
 
-import json
-import os
 import tempfile
 import warnings
 
-from google.cloud import secretmanager
-
 import airbyte as ab
+from airbyte._util.google_secrets import get_gcp_secret_json
 from airbyte.caches.bigquery import BigQueryCache
 
 
 warnings.filterwarnings("ignore", message="Cannot create BigQuery Storage client")
 
 
-# load secrets from GSM using the GCP_GSM_CREDENTIALS env variable
-secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
-    json.loads(os.environ["GCP_GSM_CREDENTIALS"])
-)
-
-bigquery_destination_secret = json.loads(
-    secret_client.access_secret_version(
-        name="projects/dataline-integration-testing/secrets/SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS/versions/latest"
-    ).payload.data.decode("UTF-8")
+bigquery_destination_secret = get_gcp_secret_json(
+    project_name="dataline-integration-testing",
+    secret_name="SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS",
 )
 
 
@@ -55,6 +46,9 @@ def main() -> None:
 
     result = source.read(cache)
 
+    # Read a second time to make sure table swaps and incremental are working.
+    result = source.read(cache)
+
     for name, records in result.streams.items():
         print(f"Stream {name}: {len(records)} records")
 
33 changes: 8 additions & 25 deletions examples/run_integ_test_source.py
@@ -11,14 +11,13 @@
 """
 from __future__ import annotations
 
-import json
-import os
 import sys
-from typing import Any
 
-from google.cloud import secretmanager
-
 import airbyte as ab
+from airbyte._util.google_secrets import get_gcp_secret_json
 
 
+GCP_SECRETS_PROJECT_NAME = "dataline-integration-testing"
+
+
 def get_secret_name(connector_name: str) -> str:
@@ -35,31 +34,15 @@ def get_secret_name(connector_name: str) -> str:
     return f"SECRET_{connector_name.upper()}_CREDS"
 
 
-def get_integ_test_config(secret_name: str) -> dict[str, Any]:
-    if "GCP_GSM_CREDENTIALS" not in os.environ:
-        raise Exception(  # noqa: TRY002, TRY003
-            f"GCP_GSM_CREDENTIALS env var not set, can't fetch secrets for '{connector_name}'. "
-            "Make sure they are set up as described: "
-            "https://github.com/airbytehq/airbyte/blob/master/airbyte-ci/connectors/ci_credentials/"
-            "README.md#get-gsm-access"
-        )
-
-    secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
-        json.loads(os.environ["GCP_GSM_CREDENTIALS"])
-    )
-    return json.loads(
-        secret_client.access_secret_version(
-            name=f"projects/dataline-integration-testing/secrets/{secret_name}/versions/latest"
-        ).payload.data.decode("UTF-8")
-    )
-
-
 def main(
     connector_name: str,
     secret_name: str | None,
     streams: list[str] | None,
 ) -> None:
-    config = get_integ_test_config(secret_name)
+    config = get_gcp_secret_json(
+        secret_name=secret_name,
+        project_name=GCP_SECRETS_PROJECT_NAME,
+    )
     source = ab.get_source(
         connector_name,
         config=config,
17 changes: 4 additions & 13 deletions examples/run_snowflake_faker.py
@@ -1,12 +1,8 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 from __future__ import annotations
 
-import json
-import os
-
-from google.cloud import secretmanager
-
 import airbyte as ab
+from airbyte._util.google_secrets import get_gcp_secret_json
 from airbyte.caches import SnowflakeCache
 
 
@@ -16,14 +12,9 @@
     install_if_missing=True,
 )
 
-# load secrets from GSM using the GCP_GSM_CREDENTIALS env variable
-secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
-    json.loads(os.environ["GCP_GSM_CREDENTIALS"])
-)
-secret = json.loads(
-    secret_client.access_secret_version(
-        name="projects/dataline-integration-testing/secrets/AIRBYTE_LIB_SNOWFLAKE_CREDS/versions/latest"
-    ).payload.data.decode("UTF-8")
+secret = get_gcp_secret_json(
+    project_name="dataline-integration-testing",
+    secret_name="AIRBYTE_LIB_SNOWFLAKE_CREDS",
 )
 
 cache = SnowflakeCache(
92 changes: 83 additions & 9 deletions tests/conftest.py
@@ -2,6 +2,7 @@
 """Global pytest fixtures."""
 from __future__ import annotations
 
+from contextlib import suppress
 import json
 import logging
 import os
@@ -11,6 +12,10 @@
 import time
 
 import ulid
+from airbyte._util.google_secrets import get_gcp_secret
+from airbyte.caches.base import CacheBase
+from airbyte.caches.bigquery import BigQueryCache
+from airbyte.caches.duckdb import DuckDBCache
 from airbyte.caches.snowflake import SnowflakeCache
 
 import docker
@@ -22,6 +27,8 @@
 from sqlalchemy import create_engine
 
 from airbyte.caches import PostgresCache
+from airbyte.caches.util import new_local_cache
+from airbyte.sources.base import as_temp_files
 
 logger = logging.getLogger(__name__)
 
@@ -32,6 +39,23 @@
 
 LOCAL_TEST_REGISTRY_URL = "./tests/integration_tests/fixtures/registry.json"
 
+AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing"
+
+
+def get_ci_secret(
+    secret_name,
+    project_name: str = AIRBYTE_INTERNAL_GCP_PROJECT,
+) -> str:
+    return get_gcp_secret(project_name=project_name, secret_name=secret_name)
+
+
+def get_ci_secret_json(
+    secret_name,
+    project_name: str = AIRBYTE_INTERNAL_GCP_PROJECT,
+) -> dict:
+    return json.loads(get_ci_secret(secret_name=secret_name, project_name=project_name))
+
+
 
 def pytest_collection_modifyitems(items: list[Item]) -> None:
     """Override default pytest behavior, sorting our tests in a sensible execution order.
Expand Down Expand Up @@ -174,15 +198,8 @@ def new_pg_cache(pg_dsn):

@pytest.fixture
def new_snowflake_cache():
if "GCP_GSM_CREDENTIALS" not in os.environ:
raise Exception("GCP_GSM_CREDENTIALS env variable not set, can't fetch secrets for Snowflake. Make sure they are set up as described: https://github.com/airbytehq/airbyte/blob/master/airbyte-ci/connectors/ci_credentials/README.md#get-gsm-access")
secret_client = secretmanager.SecretManagerServiceClient.from_service_account_info(
json.loads(os.environ["GCP_GSM_CREDENTIALS"])
)
secret = json.loads(
secret_client.access_secret_version(
name="projects/dataline-integration-testing/secrets/AIRBYTE_LIB_SNOWFLAKE_CREDS/versions/latest"
).payload.data.decode("UTF-8")
secret = get_ci_secret_json(
"AIRBYTE_LIB_SNOWFLAKE_CREDS",
)
config = SnowflakeCache(
account=secret["account"],
@@ -201,6 +218,30 @@ def new_snowflake_cache():
         connection.execute(f"DROP SCHEMA IF EXISTS {config.schema_name}")
 
 
+@pytest.fixture
+@pytest.mark.requires_creds
+def new_bigquery_cache():
+    dest_bigquery_config = get_ci_secret_json(
+        "SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS"
+    )
+
+    dataset_name = f"test_deleteme_{str(ulid.ULID()).lower()[-6:]}"
+    credentials_json = dest_bigquery_config["credentials_json"]
+    with as_temp_files([credentials_json]) as (credentials_path,):
+        cache = BigQueryCache(
+            credentials_path=credentials_path,
+            project_name=dest_bigquery_config["project_id"],
+            dataset_name=dataset_name
+        )
+        yield cache
+
+        url = cache.get_sql_alchemy_url()
+        engine = create_engine(url)
+        with suppress(Exception):
+            with engine.begin() as connection:
+                connection.execute(f"DROP SCHEMA IF EXISTS {cache.schema_name}")
+
+
 @pytest.fixture(autouse=True)
 def source_test_registry(monkeypatch):
     """
@@ -248,3 +289,36 @@ def source_test_installation():
     yield
 
     shutil.rmtree(venv_dir)
+
+
+@pytest.fixture(scope="function")
+def new_duckdb_cache() -> DuckDBCache:
+    return new_local_cache()
+
+
+@pytest.fixture(scope="function")
+def new_generic_cache(request) -> CacheBase:
+    """This is a placeholder fixture that will be overridden by pytest_generate_tests()."""
+    return request.getfixturevalue(request.param)
+
+
+def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
+    """Override default pytest behavior, parameterizing our tests based on the available cache types.
+
+    This is useful for running the same tests with different cache types, to ensure that the tests
+    can pass across all cache types.
+    """
+    all_cache_type_fixtures: dict[str, str] = {
+        "BigQuery": "new_bigquery_cache",
+        "DuckDB": "new_duckdb_cache",
+        "Postgres": "new_pg_cache",
+        "Snowflake": "new_snowflake_cache",
+    }
+    if "new_generic_cache" in metafunc.fixturenames:
+        metafunc.parametrize(
+            "new_generic_cache",
+            all_cache_type_fixtures.values(),
+            ids=all_cache_type_fixtures.keys(),
+            indirect=True,
+            scope="function",
+        )
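
With this parameterization in place, any test that requests new_generic_cache runs once per cache type. A hypothetical example (not part of this diff):

    def test_cache_smoke(new_generic_cache) -> None:
        # pytest_generate_tests() expands this into BigQuery, DuckDB,
        # Postgres, and Snowflake variants, one per fixture above.
        assert new_generic_cache is not None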