Commit

Feat: Add BigQuery cache support; Fix: IPython rendering bug when using a TTY (#65)

Co-authored-by: bindipankhudi <bindi@airbyte.com>
Co-authored-by: Aaron Steers <aj@airbyte.io>
3 people authored Feb 25, 2024
1 parent 797e657 commit 6cfc3b5
Showing 9 changed files with 688 additions and 72 deletions.
2 changes: 2 additions & 0 deletions airbyte/__init__.py
@@ -7,6 +7,7 @@

from airbyte import caches, datasets, registry, secrets
from airbyte._factories.connector_factories import get_source
+from airbyte.caches.bigquery import BigQueryCache
from airbyte.caches.duckdb import DuckDBCache
from airbyte.caches.factories import get_default_cache, new_local_cache
from airbyte.datasets import CachedDataset
@@ -29,6 +30,7 @@
"get_source",
"new_local_cache",
# Classes
"BigQueryCache",
"CachedDataset",
"DuckDBCache",
"ReadResult",
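
The new export makes the cache usable straight from the package root. A sketch of the intended usage, following the existing quickstart pattern (get_source / check / select_all_streams / read), which may differ slightly by version; the project ID, dataset, and key path are placeholders, and source-faker is only a stand-in:

import airbyte as ab

cache = ab.BigQueryCache(
    project_name="my-gcp-project",                      # placeholder GCP project ID
    dataset_name="airbyte_raw",                         # default dataset name
    credentials_path="/path/to/service_account.json",   # placeholder key file
)

source = ab.get_source("source-faker", config={"count": 1_000}, install_if_missing=True)
source.check()
source.select_all_streams()

result = source.read(cache=cache)  # records are written to the BigQuery dataset above
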
19 changes: 6 additions & 13 deletions airbyte/_processors/sql/base.py
@@ -79,7 +79,6 @@ class SqlProcessorBase(RecordProcessor):

# Constructor:

-@final # We don't want subclasses to have to override the constructor.
def __init__(
self,
cache: CacheBase,
@@ -435,9 +434,6 @@ def _create_table(
column_definition_str: str,
primary_keys: list[str] | None = None,
) -> None:
-if DEBUG_MODE:
-assert table_name not in self._get_tables_list(), f"Table {table_name} already exists."
-
if primary_keys:
pk_str = ", ".join(primary_keys)
column_definition_str += f",\n PRIMARY KEY ({pk_str})"
@@ -448,11 +444,6 @@ def _create_table(
)
"""
_ = self._execute_sql(cmd)
-if DEBUG_MODE:
-tables_list = self._get_tables_list()
-assert (
-table_name in tables_list
-), f"Table {table_name} was not created. Found: {tables_list}"

def _normalize_column_name(
self,
Expand Down Expand Up @@ -804,8 +795,8 @@ def _merge_temp_table_to_final_table(
columns = {self._quote_identifier(c) for c in self._get_sql_column_definitions(stream_name)}
pk_columns = {self._quote_identifier(c) for c in self._get_primary_keys(stream_name)}
non_pk_columns = columns - pk_columns
join_clause = "{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
set_clause = "{nl} ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)
self._execute_sql(
f"""
MERGE INTO {self._fully_qualified(final_table_name)} final
@@ -908,12 +899,14 @@ def _emulated_merge_temp_table_to_final_table(
conn.execute(update_stmt)
conn.execute(insert_new_records_stmt)

@final
def _table_exists(
self,
table_name: str,
) -> bool:
"""Return true if the given table exists."""
"""Return true if the given table exists.
Subclasses may override this method to provide a more efficient implementation.
"""
return table_name in self._get_tables_list()

@abc.abstractmethod
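
The last two hunks above turn the clause separators into f-strings (so the newline in {nl} is actually interpolated into the generated SQL) and add a comma between SET-clause entries. A standalone sketch of what the corrected clause-building yields, with hypothetical column names (set iteration order may vary):

nl = "\n"
pk_columns = {"id"}
non_pk_columns = {"name", "updated_at"}

join_clause = f"{nl} AND ".join(f"tmp.{pk_col} = final.{pk_col}" for pk_col in pk_columns)
set_clause = f"{nl} , ".join(f"{col} = tmp.{col}" for col in non_pk_columns)

print(join_clause)
# tmp.id = final.id
print(set_clause)
# name = tmp.name
#  , updated_at = tmp.updated_at
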
204 changes: 204 additions & 0 deletions airbyte/_processors/sql/bigquery.py
@@ -0,0 +1,204 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A BigQuery implementation of the cache."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, final

import sqlalchemy
from google.api_core.exceptions import NotFound
from google.cloud import bigquery
from google.oauth2 import service_account
from overrides import overrides

from airbyte import exceptions as exc
from airbyte._processors.file.jsonl import JsonlWriter
from airbyte._processors.sql.base import SqlProcessorBase
from airbyte.telemetry import CacheTelemetryInfo
from airbyte.types import SQLTypeConverter


if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector

from airbyte._processors.file.base import FileWriterBase
from airbyte.caches.base import CacheBase
from airbyte.caches.bigquery import BigQueryCache


class BigQueryTypeConverter(SQLTypeConverter):
"""A class to convert types for BigQuery."""

@overrides
def to_sql_type(
self,
json_schema_property_def: dict[str, str | dict | list],
) -> sqlalchemy.types.TypeEngine:
"""Convert a value to a SQL type.
We first call the parent class method to get the type. Then if the type is VARCHAR or
BIGINT, we replace it with respective BigQuery types.
"""
sql_type = super().to_sql_type(json_schema_property_def)
# to-do: replace hardcoded return types with some sort of snowflake Variant equivalent
if isinstance(sql_type, sqlalchemy.types.VARCHAR):
return "String"
if isinstance(sql_type, sqlalchemy.types.BIGINT):
return "INT64"

return sql_type.__class__.__name__


class BigQuerySqlProcessor(SqlProcessorBase):
"""A BigQuery implementation of the cache."""

file_writer_class = JsonlWriter
type_converter_class = BigQueryTypeConverter
supports_merge_insert = True

cache: BigQueryCache

def __init__(self, cache: CacheBase, file_writer: FileWriterBase | None = None) -> None:
self._credentials: service_account.Credentials | None = None
self._schema_exists: bool | None = None
super().__init__(cache, file_writer)

@final
@overrides
def _fully_qualified(
self,
table_name: str,
) -> str:
"""Return the fully qualified name of the given table."""
return f"`{self.cache.schema_name}`.`{table_name!s}`"

@final
@overrides
def _quote_identifier(self, identifier: str) -> str:
"""Return the identifier name as is. BigQuery does not require quoting identifiers"""
return f"{identifier}"

@final
@overrides
def _get_telemetry_info(self) -> CacheTelemetryInfo:
return CacheTelemetryInfo("bigquery")

def _write_files_to_new_table(
self,
files: list[Path],
stream_name: str,
batch_id: str,
) -> str:
"""Write a file(s) to a new table.
This is a generic implementation, which can be overridden by subclasses
to improve performance.
"""
temp_table_name = self._create_table_for_loading(stream_name, batch_id)

# Specify the table ID (in the format `project_id.dataset_id.table_id`)
table_id = f"{self.cache.project_name}.{self.cache.dataset_name}.{temp_table_name}"

# Initialize a BigQuery client
client = bigquery.Client(credentials=self._get_credentials())

for file_path in files:
with Path.open(file_path, "rb") as source_file:
load_job = client.load_table_from_file( # Make an API request
file_obj=source_file,
destination=table_id,
job_config=bigquery.LoadJobConfig(
source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON,
schema=[
bigquery.SchemaField(name, field_type=str(type_))
for name, type_ in self._get_sql_column_definitions(
stream_name=stream_name
).items()
],
),
)
_ = load_job.result() # Wait for the job to complete

return temp_table_name

def _ensure_schema_exists(
self,
) -> None:
"""Ensure the target schema exists.
We override the default implementation because BigQuery is very slow at scanning schemas.
This implementation simply calls "CREATE SCHEMA IF NOT EXISTS" and ignores any errors.
"""
if self._schema_exists:
return

sql = f"CREATE SCHEMA IF NOT EXISTS {self.cache.schema_name}"
try:
self._execute_sql(sql)
except Exception as ex:
# Ignore schema exists errors.
if "already exists" not in str(ex):
raise

self._schema_exists = True

def _get_credentials(self) -> service_account.Credentials:
"""Return the GCP credentials."""
if self._credentials is None:
self._credentials = service_account.Credentials.from_service_account_file(
self.cache.credentials_path
)

return self._credentials

def _table_exists(
self,
table_name: str,
) -> bool:
"""Return true if the given table exists.
We override the default implementation because BigQuery is very slow at scanning tables.
"""
client = bigquery.Client(credentials=self._get_credentials())
table_id = f"{self.cache.project_name}.{self.cache.dataset_name}.{table_name}"
try:
client.get_table(table_id)
except NotFound:
return False

except ValueError as ex:
raise exc.AirbyteLibInputError(
message="Invalid project name or dataset name.",
context={
"table_id": table_id,
"table_name": table_name,
"project_name": self.cache.project_name,
"dataset_name": self.cache.dataset_name,
},
) from ex

return True

@final
@overrides
def _get_tables_list(
self,
) -> list[str]:
"""Get the list of available tables in the schema.
For bigquery, {schema_name}.{table_name} is returned, so we need to
strip the schema name in front of the table name, if it exists.
Warning: This method is slow for BigQuery, as it needs to scan all tables in the dataset.
It has been observed to take 30+ seconds in some cases.
"""
with self.get_sql_connection() as conn:
inspector: Inspector = sqlalchemy.inspect(conn)
tables = inspector.get_table_names(schema=self.cache.schema_name)
schema_prefix = f"{self.cache.schema_name}."
return [
table.replace(schema_prefix, "", 1) if table.startswith(schema_prefix) else table
for table in tables
]
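
As the docstrings above note, listing every table in a BigQuery dataset can take 30+ seconds, so _table_exists checks a single table directly instead. The same pattern in isolation, using only the google-cloud-bigquery client; the project, dataset, and table IDs are placeholders:

from google.api_core.exceptions import NotFound
from google.cloud import bigquery


def table_exists(client: bigquery.Client, table_id: str) -> bool:
    """Return True if `project.dataset.table` exists, using one metadata GET instead of a dataset scan."""
    try:
        client.get_table(table_id)
    except NotFound:
        return False
    return True


client = bigquery.Client()  # falls back to application-default credentials
print(table_exists(client, "my-gcp-project.airbyte_raw.users"))
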
7 changes: 5 additions & 2 deletions airbyte/caches/_catalog_manager.py
@@ -4,12 +4,13 @@
from __future__ import annotations

import json
+from datetime import datetime
from typing import TYPE_CHECKING, Callable

+from pytz import utc
from sqlalchemy import Column, DateTime, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
-from sqlalchemy.sql import func

from airbyte_protocol.models import (
AirbyteStateMessage,
@@ -50,7 +51,9 @@ class StreamState(Base): # type: ignore[valid-type,misc]
stream_name = Column(String)
table_name = Column(String, primary_key=True)
state_json = Column(String)
-last_updated = Column(DateTime(timezone=True), onupdate=func.now(), default=func.now())
+last_updated = Column(
+DateTime(timezone=True), onupdate=datetime.now(utc), default=datetime.now(utc)
+)


class CatalogManager:
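
The last_updated column now gets its timestamps from Python (datetime.now(utc)) rather than from the database via func.now(). A standalone sketch of the same column style on a hypothetical model, not part of this commit; note that SQLAlchemy also accepts a callable here, which re-evaluates the timestamp on every insert or update instead of fixing it when the class is defined:

from datetime import datetime

from pytz import utc
from sqlalchemy import Column, DateTime, String
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class ExampleRow(Base):  # hypothetical table, for illustration only
    __tablename__ = "example_rows"

    id = Column(String, primary_key=True)
    last_updated = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(utc),   # evaluated at each insert
        onupdate=lambda: datetime.now(utc),  # evaluated at each update
    )
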
38 changes: 38 additions & 0 deletions airbyte/caches/bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A BigQuery implementation of the cache."""

from __future__ import annotations

import urllib

from overrides import overrides

from airbyte._processors.sql.bigquery import BigQuerySqlProcessor
from airbyte.caches.base import (
CacheBase,
)


class BigQueryCache(CacheBase):
"""The BigQuery cache implementation."""

project_name: str
dataset_name: str = "airbyte_raw"
credentials_path: str

_sql_processor_class: type[BigQuerySqlProcessor] = BigQuerySqlProcessor

def __post_init__(self) -> None:
"""Initialize the BigQuery cache."""
self.schema_name = self.dataset_name

@overrides
def get_database_name(self) -> str:
"""Return the name of the database. For BigQuery, this is the project name."""
return self.project_name

@overrides
def get_sql_alchemy_url(self) -> str:
"""Return the SQLAlchemy URL to use."""
credentials_path_encoded = urllib.parse.quote(self.credentials_path)
return f"bigquery://{self.project_name!s}?credentials_path={credentials_path_encoded}"
7 changes: 5 additions & 2 deletions airbyte/progress.py
@@ -4,6 +4,7 @@
from __future__ import annotations

import datetime
+import importlib
import math
import sys
import time
@@ -25,10 +26,12 @@

ipy_display: ModuleType | None
try:
-IS_NOTEBOOK = True
-from IPython import display as ipy_display # type: ignore # noqa: PGH003
+# Default to IS_NOTEBOOK=False if a TTY is detected.
+IS_NOTEBOOK = not sys.stdout.isatty()
+ipy_display = importlib.import_module("IPython.display")

except ImportError:
# If IPython is not installed, then we're definitely not in a notebook.
ipy_display = None
IS_NOTEBOOK = False

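
The rendering fix hinges on sys.stdout.isatty(): it returns True in a real terminal, but False inside Jupyter/IPython kernels (which replace stdout with their own stream) and when output is piped, so a TTY now disables the notebook-style IPython rendering. A tiny check you can run in either environment:

import sys

# Plain terminal:           True   (IS_NOTEBOOK becomes False; plain-text progress is used)
# Jupyter / IPython kernel: False  (IS_NOTEBOOK stays True; IPython.display rendering is used)
print(sys.stdout.isatty())
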
