Skip to content

Commit

Permalink
Add dryRun parameter for query API
Browse files Browse the repository at this point in the history
And split ibis.py by data source
  • Loading branch information
grieve54706 committed Jun 4, 2024
1 parent 4e24567 commit 64e6327
Show file tree
Hide file tree
Showing 12 changed files with 241 additions and 87 deletions.
24 changes: 24 additions & 0 deletions ibis-server/app/logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from functools import wraps

from app.config import get_config

Expand All @@ -7,3 +8,26 @@

def get_logger(name):
return logging.getLogger(name)


def log_dto(f):
logger = get_logger('app.routers.ibis')

@wraps(f)
def wrapper(*args, **kwargs):
logger.debug(f'DTO: {kwargs["dto"]}')
return f(*args, **kwargs)

return wrapper


def log_rewritten(f):
logger = get_logger('app.mdl.rewriter')

@wraps(f)
def wrapper(*args, **kwargs):
rs = f(*args, **kwargs)
logger.debug(f'Rewritten SQL: {rs}')
return rs

return wrapper
2 changes: 1 addition & 1 deletion ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import orjson

from app.config import get_config
from app.util import log_rewritten
from app.logger import log_rewritten

wren_engine_endpoint = get_config().wren_engine_endpoint

Expand Down
30 changes: 30 additions & 0 deletions ibis-server/app/model/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from json import loads

from app.mdl.rewriter import rewrite
from app.model.data_source import DataSource, ConnectionInfo


class Coordinator:
def __init__(self, data_source: DataSource, connection_info: ConnectionInfo, manifest_str: str):
self.data_source = data_source
self.connection = self.data_source.get_connection(connection_info)
self.manifest_str = manifest_str

def query(self, sql) -> dict:
rewritten_sql = rewrite(self.manifest_str, sql)
return self._to_json(self.connection.sql(rewritten_sql, dialect='trino').to_pandas())

def dry_run(self, sql):
try:
rewritten_sql = rewrite(self.manifest_str, sql)
self.connection.sql(rewritten_sql, dialect='trino')
return {"status": "success"}
except Exception as e:
return {"status": "failure", "message": str(e)}

@staticmethod
def _to_json(df):
json_obj = loads(df.to_json(orient='split'))
del json_obj['index']
json_obj['dtypes'] = df.dtypes.apply(lambda x: x.name).to_dict()
return json_obj
33 changes: 0 additions & 33 deletions ibis-server/app/routers/ibis.py

This file was deleted.

13 changes: 13 additions & 0 deletions ibis-server/app/routers/ibis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from fastapi import APIRouter

import app.routers.ibis.bigquery as bigquery
import app.routers.ibis.postgres as postgres
import app.routers.ibis.snowflake as snowflake

prefix = "/v2/ibis"

router = APIRouter(prefix=prefix)

router.include_router(bigquery.router)
router.include_router(postgres.router)
router.include_router(snowflake.router)
21 changes: 21 additions & 0 deletions ibis-server/app/routers/ibis/bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Annotated

from fastapi import APIRouter, Query

from app.logger import log_dto
from app.model.data_source import DataSource
from app.model.dto import BigQueryDTO
from app.model.coordinator import Coordinator

router = APIRouter(prefix='/bigquery', tags=['bigquery'])

data_source = DataSource.bigquery


@router.post("/query")
@log_dto
def query(dto: BigQueryDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> dict:
coord = Coordinator(data_source, dto.connection_info, dto.manifest_str)
if dry_run:
return coord.dry_run(dto.sql)
return coord.query(dto.sql)
21 changes: 21 additions & 0 deletions ibis-server/app/routers/ibis/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Annotated

from fastapi import APIRouter, Query

from app.logger import log_dto
from app.model.data_source import DataSource
from app.model.dto import PostgresDTO
from app.model.coordinator import Coordinator

router = APIRouter(prefix='/postgres', tags=['postgres'])

data_source = DataSource.postgres


@router.post("/query")
@log_dto
def query(dto: PostgresDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> dict:
coord = Coordinator(data_source, dto.connection_info, dto.manifest_str)
if dry_run:
return coord.dry_run(dto.sql)
return coord.query(dto.sql)
21 changes: 21 additions & 0 deletions ibis-server/app/routers/ibis/snowflake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Annotated

from fastapi import APIRouter, Query

from app.logger import log_dto
from app.model.data_source import DataSource
from app.model.dto import SnowflakeDTO
from app.model.coordinator import Coordinator

router = APIRouter(prefix='/snowflake', tags=['snowflake'])

data_source = DataSource.snowflake


@router.post("/query")
@log_dto
def query(dto: SnowflakeDTO, dry_run: Annotated[bool, Query(alias="dryRun")] = False) -> dict:
coord = Coordinator(data_source, dto.connection_info, dto.manifest_str)
if dry_run:
return coord.dry_run(dto.sql)
return coord.query(dto.sql)
36 changes: 0 additions & 36 deletions ibis-server/app/util.py

This file was deleted.

43 changes: 37 additions & 6 deletions ibis-server/tests/routers/ibis/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def get_connection_info():
"credentials": os.getenv("TEST_BIG_QUERY_CREDENTIALS_BASE64_JSON")
}

def test_bigquery(self, manifest_str: str):
def test_query(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
Expand All @@ -67,13 +67,13 @@ def test_bigquery(self, manifest_str: str):
assert result['data'][0][0] is not None
assert result['dtypes'] is not None

def test_no_manifest(self):
def test_query_without_manifest(self):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"connectionInfo": connection_info,
"sql": "SELECT * FROM Orders LIMIT 1"
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 422
Expand All @@ -83,7 +83,7 @@ def test_no_manifest(self):
assert result['detail'][0]['loc'] == ['body', 'manifestStr']
assert result['detail'][0]['msg'] == 'Field required'

def test_no_sql(self, manifest_str: str):
def test_query_without_sql(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
Expand All @@ -99,12 +99,12 @@ def test_no_sql(self, manifest_str: str):
assert result['detail'][0]['loc'] == ['body', 'sql']
assert result['detail'][0]['msg'] == 'Field required'

def test_no_connection_info(self, manifest_str: str):
def test_query_without_connection_info(self, manifest_str: str):
response = client.post(
url="/v2/ibis/bigquery/query",
json={
"manifestStr": manifest_str,
"sql": "SELECT * FROM Orders LIMIT 1"
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 422
Expand All @@ -113,3 +113,34 @@ def test_no_connection_info(self, manifest_str: str):
assert result['detail'][0]['type'] == 'missing'
assert result['detail'][0]['loc'] == ['body', 'connectionInfo']
assert result['detail'][0]['msg'] == 'Field required'

def test_query_with_dry_run(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 200
result = response.json()
assert result['status'] == 'success'

def test_query_with_dry_run_and_invalid_sql(self, manifest_str: str):
connection_info = self.get_connection_info()
response = client.post(
url="/v2/ibis/bigquery/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM X'
}
)
assert response.status_code == 200
result = response.json()
assert result['status'] == 'failure'
assert result['message'] is not None
41 changes: 36 additions & 5 deletions ibis-server/tests/routers/ibis/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def to_connection_url(pg: PostgresContainer):
info = TestPostgres.to_connection_info(pg)
return f"postgres://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['database']}"

def test_postgres(self, postgres: PostgresContainer, manifest_str: str):
def test_query(self, postgres: PostgresContainer, manifest_str: str):
connection_info = self.to_connection_info(postgres)
response = client.post(
url="/v2/ibis/postgres/query",
Expand All @@ -98,7 +98,7 @@ def test_postgres(self, postgres: PostgresContainer, manifest_str: str):
assert result['data'][0][0] == 1
assert result['dtypes'] is not None

def test_postgres_with_connection_url(self, postgres: PostgresContainer, manifest_str: str):
def test_query_with_connection_url(self, postgres: PostgresContainer, manifest_str: str):
connection_url = self.to_connection_url(postgres)
response = client.post(
url="/v2/ibis/postgres/query",
Expand All @@ -117,7 +117,7 @@ def test_postgres_with_connection_url(self, postgres: PostgresContainer, manifes
assert result['data'][0][0] == 1
assert result['dtypes'] is not None

def test_no_manifest(self, postgres: PostgresContainer):
def test_query_without_manifest(self, postgres: PostgresContainer):
connection_info = self.to_connection_info(postgres)
response = client.post(
url="/v2/ibis/postgres/query",
Expand All @@ -133,7 +133,7 @@ def test_no_manifest(self, postgres: PostgresContainer):
assert result['detail'][0]['loc'] == ['body', 'manifestStr']
assert result['detail'][0]['msg'] == 'Field required'

def test_no_sql(self, postgres: PostgresContainer, manifest_str: str):
def test_query_without_sql(self, postgres: PostgresContainer, manifest_str: str):
connection_info = self.to_connection_info(postgres)
response = client.post(
url="/v2/ibis/postgres/query",
Expand All @@ -149,7 +149,7 @@ def test_no_sql(self, postgres: PostgresContainer, manifest_str: str):
assert result['detail'][0]['loc'] == ['body', 'sql']
assert result['detail'][0]['msg'] == 'Field required'

def test_no_connection_info(self, manifest_str: str):
def test_query_without_connection_info(self, manifest_str: str):
response = client.post(
url="/v2/ibis/postgres/query",
json={
Expand All @@ -163,3 +163,34 @@ def test_no_connection_info(self, manifest_str: str):
assert result['detail'][0]['type'] == 'missing'
assert result['detail'][0]['loc'] == ['body', 'connectionInfo']
assert result['detail'][0]['msg'] == 'Field required'

def test_query_with_dry_run(self, postgres: PostgresContainer, manifest_str: str):
connection_info = self.to_connection_info(postgres)
response = client.post(
url="/v2/ibis/postgres/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM "Orders" LIMIT 1'
}
)
assert response.status_code == 200
result = response.json()
assert result['status'] == 'success'

def test_query_with_dry_run_and_invalid_sql(self, postgres: PostgresContainer, manifest_str: str):
connection_info = self.to_connection_info(postgres)
response = client.post(
url="/v2/ibis/postgres/query",
params={"dryRun": True},
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": 'SELECT * FROM X'
}
)
assert response.status_code == 200
result = response.json()
assert result['status'] == 'failure'
assert result['message'] is not None
Loading

0 comments on commit 64e6327

Please sign in to comment.