Skip to content

Commit

Permalink
feat(api): database schemas migration to new API (#10436)
Browse files Browse the repository at this point in the history
* fix(log): log crashes if expired or not authenticated

* fix lint and rison

* add tests

* more tests

* perm fix

* fix test not found

* JS lint

* fix Jest test
  • Loading branch information
dpgaspar authored Jul 29, 2020
1 parent 0aad9c6 commit 671461d
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ describe('TableSelector', () => {
});

describe('fetchSchemas', () => {
const FETCH_SCHEMAS_GLOB = 'glob:*/superset/schemas/*/*/';
const FETCH_SCHEMAS_GLOB = 'glob:*/api/v1/database/*/schemas/?q=(force:!*)';
afterEach(fetchMock.resetHistory);
afterAll(fetchMock.reset);

it('should fetch schema options', () => {
const schemaOptions = {
schemas: ['main', 'erf', 'superset'],
result: ['main', 'erf', 'superset'],
};
fetchMock.get(FETCH_SCHEMAS_GLOB, schemaOptions, {
overwriteRoutes: true,
Expand Down
7 changes: 5 additions & 2 deletions superset-frontend/src/components/TableSelector.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,13 @@ export default class TableSelector extends React.PureComponent {
const actualDbId = dbId || this.props.dbId;
if (actualDbId) {
this.setState({ schemaLoading: true });
const endpoint = `/superset/schemas/${actualDbId}/${forceRefresh}/`;
const queryParams = rison.encode({
force: Boolean(forceRefresh),
});
const endpoint = `/api/v1/database/${actualDbId}/schemas/?q=${queryParams}`;
return SupersetClient.get({ endpoint })
.then(({ json }) => {
const schemaOptions = json.schemas.map(s => ({
const schemaOptions = json.result.map(s => ({
value: s,
label: s,
title: s,
Expand Down
79 changes: 74 additions & 5 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@

from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import NoSuchTableError, SQLAlchemyError
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError

from superset import event_logger, security_manager
from superset import event_logger
from superset.databases.decorators import check_datasource_access
from superset.databases.schemas import (
database_schemas_query_schema,
DatabaseSchemaResponseSchema,
SchemasResponseSchema,
SelectStarResponseSchema,
TableMetadataResponseSchema,
)
from superset.extensions import security_manager
from superset.models.core import Database
from superset.typing import FlaskResponse
from superset.utils.core import error_msg_from_exception
Expand Down Expand Up @@ -125,9 +128,16 @@ def get_table_metadata(
class DatabaseRestApi(BaseSupersetModelRestApi):
datamodel = SQLAInterface(Database)

include_route_methods = {"get_list", "table_metadata", "select_star", "schemas"}
include_route_methods = {
"all_schemas",
"get_list",
"table_metadata",
"select_star",
"schemas",
}
class_permission_name = "DatabaseView"
method_permission_name = {
"all_schemas": "list",
"get_list": "list",
"table_metadata": "list",
"select_star": "list",
Expand All @@ -154,25 +164,83 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
"backend",
"function_names",
]
list_select_columns = list_columns + ["extra", "sqlalchemy_uri", "password"]
# Removes the local limit for the page size
max_page_size = -1
validators_columns = {"sqlalchemy_uri": sqlalchemy_uri_validator}

openapi_spec_tag = "Database"
apispec_parameter_schemas = {
"database_schemas_query_schema": database_schemas_query_schema,
"get_schemas_schema": get_schemas_schema,
}
openapi_spec_tag = "Database"
openapi_spec_component_schemas = (
DatabaseSchemaResponseSchema,
TableMetadataResponseSchema,
SelectStarResponseSchema,
SchemasResponseSchema,
)

@expose("/<int:pk>/schemas/")
@protect()
@safe
@rison(database_schemas_query_schema)
@statsd_metrics
def schemas(self, pk: int, **kwargs: Any) -> FlaskResponse:
""" Get all schemas from a database
---
get:
description: Get all schemas from a database
parameters:
- in: path
schema:
type: integer
name: pk
description: The database id
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/database_schemas_query_schema'
responses:
200:
description: A List of all schemas from the database
content:
application/json:
schema:
$ref: "#/components/schemas/SchemasResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
500:
$ref: '#/components/responses/500'
"""
database = self.datamodel.get(pk, self._base_filters)
if not database:
return self.response_404()
try:
schemas = database.get_all_schema_names(
cache=database.schema_cache_enabled,
cache_timeout=database.schema_cache_timeout,
force=kwargs["rison"].get("force", False),
)
schemas = security_manager.get_schemas_accessible_by_user(database, schemas)
return self.response(200, result=schemas)
except OperationalError:
return self.response(
500, message="There was an error connecting to the database"
)

@expose("/<int:pk>/table/<table_name>/<schema_name>/", methods=["GET"])
@protect()
@check_datasource_access
@safe
@event_logger.log_this
@statsd_metrics
def table_metadata(
self, database: Database, table_name: str, schema_name: str
) -> FlaskResponse:
Expand Down Expand Up @@ -229,6 +297,7 @@ def table_metadata(
@check_datasource_access
@safe
@event_logger.log_this
@statsd_metrics
def select_star(
self, database: Database, table_name: str, schema_name: Optional[str] = None
) -> FlaskResponse:
Expand Down Expand Up @@ -286,7 +355,7 @@ def select_star(
@safe
@statsd_metrics
@rison(get_schemas_schema)
def schemas(self, **kwargs: Any) -> FlaskResponse:
def all_schemas(self, **kwargs: Any) -> FlaskResponse:
"""Get all schemas
---
get:
Expand Down
9 changes: 9 additions & 0 deletions superset/databases/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
# under the License.
from marshmallow import fields, Schema

database_schemas_query_schema = {
"type": "object",
"properties": {"force": {"type": "boolean"}},
}


class TableMetadataOptionsResponseSchema(Schema):
deferrable = fields.Bool()
Expand Down Expand Up @@ -79,6 +84,10 @@ class SelectStarResponseSchema(Schema):
result = fields.String(description="SQL select star")


class SchemasResponseSchema(Schema):
result = fields.List(fields.String(description="A database schema name"))


class DatabaseSchemaObjectResponseSchema(Schema):
value = fields.String(description="Schema name")
text = fields.String(description="Schema display name")
Expand Down
3 changes: 3 additions & 0 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,9 @@ def save_or_overwrite_slice( # pylint: disable=too-many-arguments,too-many-loca
def schemas( # pylint: disable=no-self-use
self, db_id: int, force_refresh: str = "false"
) -> FlaskResponse:
logger.warning(
"This API endpoint is deprecated and will be removed in version 1.0.0"
)
db_id = int(db_id)
database = db.session.query(models.Database).get(db_id)
if database:
Expand Down
71 changes: 60 additions & 11 deletions tests/database_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class TestDatabaseApi(SupersetTestCase):
def test_get_items(self):
"""
Database API: Test get items
Database API: Test get items
"""
self.login(username="admin")
uri = "api/v1/database/"
Expand Down Expand Up @@ -63,6 +63,9 @@ def test_get_items(self):
self.assertEqual(list(response["result"][0].keys()), expected_columns)

def test_get_items_filter(self):
"""
Database API: Test get items with filter
"""
fake_db = (
db.session.query(Database).filter_by(database_name="fake_db_100").one()
)
Expand Down Expand Up @@ -92,7 +95,7 @@ def test_get_items_filter(self):

def test_get_items_not_allowed(self):
"""
Database API: Test get items not allowed
Database API: Test get items not allowed
"""
self.login(username="gamma")
uri = f"api/v1/database/"
Expand All @@ -103,7 +106,7 @@ def test_get_items_not_allowed(self):

def test_get_table_metadata(self):
"""
Database API: Test get table metadata info
Database API: Test get table metadata info
"""
example_db = get_example_database()
self.login(username="admin")
Expand All @@ -117,7 +120,7 @@ def test_get_table_metadata(self):

def test_get_invalid_database_table_metadata(self):
"""
Database API: Test get invalid database from table metadata
Database API: Test get invalid database from table metadata
"""
database_id = 1000
self.login(username="admin")
Expand All @@ -131,7 +134,7 @@ def test_get_invalid_database_table_metadata(self):

def test_get_invalid_table_table_metadata(self):
"""
Database API: Test get invalid table from table metadata
Database API: Test get invalid table from table metadata
"""
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/wrong_table/null/"
Expand All @@ -141,7 +144,7 @@ def test_get_invalid_table_table_metadata(self):

def test_get_table_metadata_no_db_permission(self):
"""
Database API: Test get table metadata from not permitted db
Database API: Test get table metadata from not permitted db
"""
self.login(username="gamma")
example_db = get_example_database()
Expand All @@ -151,7 +154,7 @@ def test_get_table_metadata_no_db_permission(self):

def test_get_select_star(self):
"""
Database API: Test get select star
Database API: Test get select star
"""
self.login(username="admin")
example_db = get_example_database()
Expand All @@ -163,7 +166,7 @@ def test_get_select_star(self):

def test_get_select_star_not_allowed(self):
"""
Database API: Test get select star not allowed
Database API: Test get select star not allowed
"""
self.login(username="gamma")
example_db = get_example_database()
Expand All @@ -173,7 +176,7 @@ def test_get_select_star_not_allowed(self):

def test_get_select_star_datasource_access(self):
"""
Database API: Test get select star with datasource access
Database API: Test get select star with datasource access
"""
session = db.session
table = SqlaTable(
Expand Down Expand Up @@ -201,7 +204,7 @@ def test_get_select_star_datasource_access(self):

def test_get_select_star_not_found_database(self):
"""
Database API: Test get select star not found database
Database API: Test get select star not found database
"""
self.login(username="admin")
max_id = db.session.query(func.max(Database.id)).scalar()
Expand All @@ -211,7 +214,7 @@ def test_get_select_star_not_found_database(self):

def test_get_select_star_not_found_table(self):
"""
Database API: Test get select star not found database
Database API: Test get select star not found database
"""
self.login(username="admin")
example_db = get_example_database()
Expand All @@ -223,6 +226,9 @@ def test_get_select_star_not_found_table(self):
self.assertEqual(rv.status_code, 404)

def test_schemas(self):
"""
Database API: Test get select star not found database
"""
self.login("admin")
dbs = db.session.query(Database).all()
schemas = []
Expand Down Expand Up @@ -254,8 +260,51 @@ def test_schemas(self):

@mock.patch("superset.security_manager.get_schemas_accessible_by_user")
def test_schemas_no_access(self, mock_get_schemas_accessible_by_user):
"""
Database API: Test all schemas with no access
"""
mock_get_schemas_accessible_by_user.return_value = []
self.login("admin")
rv = self.client.get("api/v1/database/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(0, response["count"])

def test_database_schemas(self):
"""
Database API: Test database schemas
"""
self.login("admin")
database = db.session.query(Database).first()
schemas = database.get_all_schema_names()

rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])

rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, response["result"])

def test_database_schemas_not_found(self):
"""
Database API: Test database schemas not found
"""
self.logout()
self.login(username="gamma")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/schemas/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)

def test_database_schemas_invalid_query(self):
"""
Database API: Test database schemas with invalid query
"""
self.login("admin")
database = db.session.query(Database).first()
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)

0 comments on commit 671461d

Please sign in to comment.