Skip to content

Commit

Permalink
Make g.user attribute access safe for public users (#14287)
Browse files Browse the repository at this point in the history
  • Loading branch information
robdiciuccio authored Apr 26, 2021
1 parent 7ff35df commit 6875a1a
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 25 deletions.
2 changes: 1 addition & 1 deletion superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,7 @@ def favorite_status(self, **kwargs: Any) -> Response:
charts = ChartDAO.find_by_ids(requested_ids)
if not charts:
return self.response_404()
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.id)
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.get_id())
res = [
{"id": request_id, "value": request_id in favorited_chart_ids}
for request_id in requested_ids
Expand Down
2 changes: 1 addition & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _try_json_readsha( # pylint: disable=unused-argument
# from flask import g, request
# def GET_FEATURE_FLAGS_FUNC(feature_flags_dict: Dict[str, bool]) -> Dict[str, bool]:
# if hasattr(g, "user") and g.user.is_active:
# feature_flags_dict['some_feature'] = g.user and g.user.id == 5
# feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
# return feature_flags_dict
GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None

Expand Down
4 changes: 3 additions & 1 deletion superset/dashboards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,9 @@ def favorite_status(self, **kwargs: Any) -> Response:
dashboards = DashboardDAO.find_by_ids(requested_ids)
if not dashboards:
return self.response_404()
favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards, g.user.id)
favorited_dashboard_ids = DashboardDAO.favorited_ids(
dashboards, g.user.get_id()
)
res = [
{"id": request_id, "value": request_id in favorited_dashboard_ids}
for request_id in requested_ids
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ def estimate_query_cost(
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")

user_name = g.user.username if g.user else None
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()

Expand Down
6 changes: 3 additions & 3 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:

if hasattr(g, "user") and g.user:
if add_to_cache_keys:
self.cache_key_wrapper(g.user.id)
return g.user.id
self.cache_key_wrapper(g.user.get_id())
return g.user.get_id()
return None

def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
Expand All @@ -154,7 +154,7 @@ def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
:returns: The username
"""

if g.user:
if g.user and hasattr(g.user, "username"):
if add_to_cache_keys:
self.cache_key_wrapper(g.user.username)
return g.user.username
Expand Down
4 changes: 2 additions & 2 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def user_view_menu_names(self, permission_name: str) -> Set[str]:
view_menu_names = (
base_query.join(assoc_user_role)
.join(self.user_model)
.filter(self.user_model.id == g.user.id)
.filter(self.user_model.id == g.user.get_id())
.filter(self.permission_model.name == permission_name)
).all()
return {s.name for s in view_menu_names}
Expand Down Expand Up @@ -1044,7 +1044,7 @@ def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:

user_roles = (
self.get_session.query(assoc_user_role.c.role_id)
.filter(assoc_user_role.c.user_id == g.user.id)
.filter(assoc_user_role.c.user_id == g.user.get_id())
.subquery()
)
regular_filter_roles = (
Expand Down
2 changes: 1 addition & 1 deletion superset/sql_validators/presto_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def validate(
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
user_name = g.user.username if g.user else None
user_name = g.user.username if g.user and hasattr(g.user, "username") else None
parsed_query = ParsedQuery(sql)
statements = parsed_query.get_statements()

Expand Down
5 changes: 4 additions & 1 deletion superset/views/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@ def apply(self, query: Query, value: Any) -> Query:
if security_manager.current_user is None:
return query
users_favorite_query = db.session.query(FavStar.obj_id).filter(
and_(FavStar.user_id == g.user.id, FavStar.class_name == self.class_name)
and_(
FavStar.user_id == g.user.get_id(),
FavStar.class_name == self.class_name,
)
)
if value:
return query.filter(and_(self.model.id.in_(users_favorite_query)))
Expand Down
4 changes: 2 additions & 2 deletions superset/views/base_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def pre_load(self, data: Dict[Any, Any]) -> None:
@staticmethod
def set_owners(instance: Model, owners: List[int]) -> None:
owner_objs = list()
if g.user.id not in owners:
owners.append(g.user.id)
if g.user.get_id() not in owners:
owners.append(g.user.get_id())
for owner_id in owners:
user = current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model
Expand Down
27 changes: 18 additions & 9 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,8 +1067,10 @@ def get_datasource_label(ds_name: utils.DatasourceName) -> str:
views = [vn for vn in views if substr_parsed in get_datasource_label(vn)]

if not schema_parsed and database.default_schemas:
user_schema = g.user.email.split("@")[0]
valid_schemas = set(database.default_schemas + [user_schema])
user_schemas = (
[g.user.email.split("@")[0]] if hasattr(g.user, "email") else []
)
valid_schemas = set(database.default_schemas + user_schemas)

tables = [tn for tn in tables if tn.schema in valid_schemas]
views = [vn for vn in views if vn.schema in valid_schemas]
Expand Down Expand Up @@ -1261,7 +1263,9 @@ def testconn( # pylint: disable=too-many-return-statements,no-self-use
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)

username = g.user.username if g.user is not None else None
username = (
g.user.username if g.user and hasattr(g.user, "username") else None
)
engine = database.get_sqla_engine(user_name=username)

with closing(engine.raw_connection()) as conn:
Expand Down Expand Up @@ -1515,7 +1519,7 @@ def user_slices( # pylint: disable=no-self-use
) -> FlaskResponse:
"""List of slices a user owns, created, modified or faved"""
if not user_id:
user_id = g.user.id
user_id = g.user.get_id()

owner_ids_query = (
db.session.query(Slice.id)
Expand Down Expand Up @@ -1567,7 +1571,7 @@ def created_slices( # pylint: disable=no-self-use
) -> FlaskResponse:
"""List of slices created by this user"""
if not user_id:
user_id = g.user.id
user_id = g.user.get_id()
qry = (
db.session.query(Slice)
.filter(or_(Slice.created_by_fk == user_id, Slice.changed_by_fk == user_id))
Expand Down Expand Up @@ -1595,7 +1599,7 @@ def fave_slices( # pylint: disable=no-self-use
) -> FlaskResponse:
"""Favorite slices for a user"""
if not user_id:
user_id = g.user.id
user_id = g.user.get_id()
qry = (
db.session.query(Slice, FavStar.dttm)
.join(
Expand Down Expand Up @@ -1779,8 +1783,9 @@ def publish( # pylint: disable=no-self-use

edit_perm = is_owner(dash, g.user) or admin_role in get_user_roles()
if not edit_perm:
username = g.user.username if hasattr(g.user, "username") else "user"
return json_error_response(
f'ERROR: "{g.user.username}" cannot alter '
f'ERROR: "{username}" cannot alter '
f'dashboard "{dash.dashboard_title}"',
status=403,
)
Expand Down Expand Up @@ -2304,7 +2309,9 @@ def _sql_json_async( # pylint: disable=too-many-arguments
rendered_query,
return_results=False,
store_results=not query.select_as_cta,
user_name=g.user.username if g.user else None,
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=expand_data,
log_params=log_params,
Expand Down Expand Up @@ -2376,7 +2383,9 @@ def _sql_json_sync(
rendered_query,
return_results=True,
store_results=store_results,
user_name=g.user.username if g.user else None,
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=expand_data,
log_params=log_params,
)
Expand Down
4 changes: 2 additions & 2 deletions superset/views/database/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def form_post(self, form: CsvToDatabaseForm) -> Response:
sqla_table = SqlaTable(table_name=csv_table.table)
sqla_table.database = expore_database
sqla_table.database_id = database.id
sqla_table.user_id = g.user.id
sqla_table.user_id = g.user.get_id()
sqla_table.schema = csv_table.schema
sqla_table.fetch_metadata()
db.session.add(sqla_table)
Expand Down Expand Up @@ -360,7 +360,7 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response:
sqla_table = SqlaTable(table_name=excel_table.table)
sqla_table.database = expore_database
sqla_table.database_id = database.id
sqla_table.user_id = g.user.id
sqla_table.user_id = g.user.get_id()
sqla_table.schema = excel_table.schema
sqla_table.fetch_metadata()
db.session.add(sqla_table)
Expand Down
2 changes: 1 addition & 1 deletion superset/views/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,4 +297,4 @@ class SqlLab(BaseSupersetView):
@has_access
def my_queries(self) -> FlaskResponse: # pylint: disable=no-self-use
"""Assigns a list of found users to the given role."""
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.id))
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.get_id()))

0 comments on commit 6875a1a

Please sign in to comment.