Skip to content

Commit

Permalink
#399: better titles for groups in charts
Browse files Browse the repository at this point in the history
  • Loading branch information
aschonfeld committed Jan 25, 2021
1 parent 1e921c7 commit c30fc07
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 69 deletions.
108 changes: 83 additions & 25 deletions dtale/charts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd

from dtale.query import run_query
from dtale.query import build_col_key, run_query
from dtale.utils import (
ChartBuildingError,
classify_type,
Expand Down Expand Up @@ -163,46 +163,91 @@ def group_filter_handler(col_def, group_val, group_classifier):
if len(col_def_segs) > 1:
col, freq = col_def_segs
if group_val == "nan":
return "{col} != {col}".format(col=col)
return "{col} != {col}".format(col=build_col_key(col)), "{}: NaN".format(
col
)
if freq == "WD":
return "{}.dt.dayofweek == {}".format(col, group_val)
return (
"{}.dt.dayofweek == {}".format(build_col_key(col), group_val),
"{}.dt.dayofweek: {}".format(col, group_val),
)
elif freq == "H2":
return "{}.dt.hour == {}".format(col, group_val)
return (
"{}.dt.hour == {}".format(build_col_key(col), group_val),
"{}.dt.hour: {}".format(col, group_val),
)
elif freq == "H":
ts_val = pd.Timestamp(group_val)
return "{col}.dt.date == '{day}' and {col}.dt.hour == {hour}".format(
col=col, day=ts_val.strftime("%Y%m%d"), hour=ts_val.hour
day = ts_val.strftime("%Y%m%d")
hour = ts_val.hour
return (
"{col}.dt.date == '{day}' and {col}.dt.hour == {hour}".format(
col=build_col_key(col), day=day, hour=hour
),
"{col}.dt.date: {day}, {col}.dt.hour: {hour}".format(
col=col, day=day, hour=hour
),
)
elif freq == "D":
ts_val = convert_date_val_to_date(group_val)
return "{col}.dt.date == '{day}'".format(
col=col, day=ts_val.strftime("%Y%m%d")
day = ts_val.strftime("%Y%m%d")
return (
"{col}.dt.date == '{day}'".format(col=build_col_key(col), day=day),
"{}.dt.date: {}".format(col, day),
)
elif freq == "W":
ts_val = convert_date_val_to_date(group_val)
return "{col}.dt.year == {year} and {col}.dt.week == {week}".format(
col=col, year=ts_val.year, week=ts_val.week
return (
"{col}.dt.year == {year} and {col}.dt.week == {week}".format(
col=build_col_key(col), year=ts_val.year, week=ts_val.week
),
"{col}.dt.year: {year}, {col}.dt.week: {week}".format(
col=col, year=ts_val.year, week=ts_val.week
),
)
elif freq == "M":
ts_val = convert_date_val_to_date(group_val)
return "{col}.dt.year == {year} and {col}.dt.month == {month}".format(
col=col, year=ts_val.year, month=ts_val.month
return (
"{col}.dt.year == {year} and {col}.dt.month == {month}".format(
col=build_col_key(col), year=ts_val.year, month=ts_val.month
),
"{col}.dt.year: {year}, {col}.dt.month: {month}".format(
col=col, year=ts_val.year, month=ts_val.month
),
)
elif freq == "Q":
ts_val = convert_date_val_to_date(group_val)
return "{col}.dt.year == {year} and {col}.dt.quarter == {quarter}".format(
col=col, year=ts_val.year, quarter=ts_val.quarter
return (
"{col}.dt.year == {year} and {col}.dt.quarter == {quarter}".format(
col=build_col_key(col), year=ts_val.year, quarter=ts_val.quarter
),
"{col}.dt.year: {year}, {col}.dt.quarter: {quarter}".format(
col=col, year=ts_val.year, quarter=ts_val.quarter
),
)
elif freq == "Y":
ts_val = convert_date_val_to_date(group_val)
return "{col}.dt.year == {year}".format(col=col, year=ts_val.year)
return (
"{col}.dt.year == {year}".format(
col=build_col_key(col), year=ts_val.year
),
"{}.dt.year: {}".format(col, ts_val.year),
)
if group_val == "nan":
return "{col} != {col}".format(col=col_def)
return "{col} != {col}".format(col=build_col_key(col_def)), "{}: NaN".format(
col_def
)
if group_classifier in ["I", "F"]:
return "{col} == {val}".format(col=col_def, val=group_val)
return (
"{col} == {val}".format(col=build_col_key(col_def), val=group_val),
"{}: {}".format(col_def, group_val),
)
if group_classifier == "D":
group_val = convert_date_val_to_date(group_val).strftime("%Y%m%d")
return "{col} == '{val}'".format(col=col_def, val=group_val)
return (
"{col} == '{val}'".format(col=build_col_key(col_def), val=group_val),
"{}: {}".format(col_def, group_val),
)


def build_group_inputs_filter(df, group_inputs):
Expand All @@ -215,11 +260,17 @@ def _group_filter(group_val):

def _full_filter():
    """Yield one ``(filter, label)`` pair per group-value mapping.

    For each ``group_val`` dict, every per-column result from
    ``_group_filter`` is split into its query fragment and its
    human-readable label; fragments are joined with ``and`` and labels
    with ``, `` so callers get both the pandas query string and a
    display title for the group.
    """
    for group_val in group_inputs:
        filter_vals, label_vals = [], []
        # _group_filter yields (query_fragment, display_label) tuples;
        # collect the two halves separately so they can be joined with
        # different separators.
        for fv, lv in _group_filter(group_val):
            filter_vals.append(fv)
            label_vals.append(lv)
        yield " and ".join(filter_vals), ", ".join(label_vals)

filters = list(_full_filter())
return "({})".format(") or (".join(filters))
full_filters, full_labels = [], []
for ff, fl in _full_filter():
full_filters.append(ff)
full_labels.append(fl)
return ("({})".format(") or (".join(full_filters)), ", ".join(full_labels))


def retrieve_chart_data(df, *args, **kwargs):
Expand Down Expand Up @@ -247,7 +298,7 @@ def retrieve_chart_data(df, *args, **kwargs):
all_data = pd.concat(all_data, axis=1)
all_code = ["chart_data = pd.concat(["] + all_code + ["], axis=1)"]
if len(make_list(kwargs.get("group_val"))):
filters = build_group_inputs_filter(all_data, kwargs["group_val"])
filters, labels = build_group_inputs_filter(all_data, kwargs["group_val"])
all_data = run_query(all_data, filters)
all_code.append(
"chart_data = chart_data.query({})".format(triple_quote(filters))
Expand Down Expand Up @@ -592,8 +643,15 @@ def _group_filter():
gc, group_fmts[gc](gv, as_string=True), classifier
)

group_filter = " and ".join(list(_group_filter()))
yield group_filter, data_f.format_lists(grp)
final_group_filter, final_group_label = [], []
for gf, gl in _group_filter():
final_group_filter.append(gf)
final_group_label.append(gl)
group_filter = " and ".join(final_group_filter)
group_label = "({})".format(", ".join(final_group_label))
data = data_f.format_lists(grp)
data["_filter_"] = group_filter
yield group_label, data

if animate_by is not None:
frame_fmt = find_dtype_formatter(
Expand Down
31 changes: 19 additions & 12 deletions dtale/dash_application/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ def scatter_builder(
axes_builder,
wrapper,
group=None,
group_filter=None,
z=None,
agg=None,
animate_by=None,
Expand Down Expand Up @@ -739,7 +740,7 @@ def build_frame(frame, frame_name):
return wrapper(
graph_wrapper(figure=figure_cfg, modal=modal),
group_filter=dict_merge(
dict(y=y_val), {} if group is None else dict(group=group)
dict(y=y_val), {} if group_filter is None else dict(group=group_filter)
),
)

Expand Down Expand Up @@ -1112,7 +1113,7 @@ def bar_builder(
},
modal=kwargs.get("modal", False),
),
group_filter=dict(group=series_key),
group_filter=dict(group=series.pop("_filter_")),
)
for series_key, series in data["data"].items()
]
Expand Down Expand Up @@ -1322,7 +1323,7 @@ def line_func(s):
},
modal=inputs.get("modal", False),
),
group_filter=dict(group=series_key),
group_filter=dict(group=series.pop("_filter_")),
)
for series_key, series in data["data"].items()
]
Expand Down Expand Up @@ -1521,7 +1522,9 @@ def build_pies():
),
group_filter=dict_merge(
dict(y=y2),
{} if series_key == "all" else dict(group=series_key),
{}
if series_key == "all"
else dict(group=series.pop("_filter_")),
),
)
if len(negative_values):
Expand Down Expand Up @@ -2018,7 +2021,7 @@ def treemap_builder(data_id, export=False, **inputs):
)
)

def _build_treemap_data(values, labels, name):
def _build_treemap_data(values, labels, name, group_filter):
x, y, width, height = 0.0, 0.0, 100.0, 100.0
normed = squarify.normalize_sizes(values, width, height)
rects = squarify.squarify(normed, x, y, width, height)
Expand Down Expand Up @@ -2065,10 +2068,9 @@ def _build_treemap_data(values, labels, name):
annotations=annotations,
hovermode="closest",
)
group_filter = None
if name != "all":
layout["title"] = name
group_filter = dict(group=name)
group_filter = dict(group=group_filter)
figure_cfg = dict(data=[trace], layout=layout)
base_fig = graph_wrapper(
style={"margin-right": "auto", "margin-left": "auto"},
Expand All @@ -2080,7 +2082,12 @@ def _build_treemap_data(values, labels, name):
return chart_builder(base_fig, group_filter=group_filter)

chart = [
_build_treemap_data(series[treemap_value], series["x"], series_key)
_build_treemap_data(
series[treemap_value],
series["x"],
series_key,
series.pop("_filter_", None),
)
for series_key, series in data["data"].items()
]
code.append(
Expand Down Expand Up @@ -2136,9 +2143,8 @@ def map_builder(data_id, export=False, **inputs):
agg_title = AGGS[props.agg]
title = "{} ({})".format(title, agg_title)
if props.group_val is not None:
title = "{} {}".format(
title, build_group_inputs_filter(raw_data, props.group_val)
)
_, group_label = build_group_inputs_filter(raw_data, props.group_val)
title = "{} ({})".format(title, group_label)
layout = build_layout(
dict(title=title, autosize=True, margin={"l": 0, "r": 0, "b": 0})
)
Expand Down Expand Up @@ -2815,9 +2821,10 @@ def build_chart(data_id=None, data=None, **inputs):
axes_builder,
chart_builder,
group=subgroup,
group_filter=subgroup_cfg.pop("_filter_"),
**kwargs
)
for subgroup in data["data"]
for subgroup, subgroup_cfg in data["data"].items()
]
)
else:
Expand Down
12 changes: 6 additions & 6 deletions dtale/dash_application/drilldown_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
build_selections,
AGGS,
)
from dtale.query import build_query
from dtale.utils import (
classify_type,
dict_merge,
Expand All @@ -29,7 +28,7 @@
make_list,
get_dtypes,
)
from dtale.query import run_query
from dtale.query import build_query, run_query
from dtale.charts.utils import (
MAX_GROUPS,
ZAXIS_CHARTS,
Expand All @@ -56,7 +55,8 @@ def build_histogram(data_id, col, query, point_filter):
query,
global_state.get_context_variables(data_id),
)
data = run_query(data, build_group_inputs_filter(data, [point_filter]))
query, _ = build_group_inputs_filter(data, [point_filter])
data = run_query(data, query)
s = data[~pd.isnull(data[col])][col]
hist_data, hist_labels = np.histogram(s, bins=10)
hist_labels = list(map(lambda x: json_float(x, precision=3), hist_labels[1:]))
Expand Down Expand Up @@ -250,7 +250,7 @@ def load_drilldown_content(
)
return hist_chart, dict(display="none")
else:
xy_query = build_group_inputs_filter(
xy_query, _ = build_group_inputs_filter(
global_state.get_data(data_id),
[point_filter],
)
Expand Down Expand Up @@ -285,7 +285,7 @@ def load_drilldown_content(
)
return hist_chart, dict(display="none")
else:
map_query = build_group_inputs_filter(
map_query, _ = build_group_inputs_filter(
global_state.get_data(data_id),
[point_filter],
)
Expand Down Expand Up @@ -323,7 +323,7 @@ def load_drilldown_content(
)
return hist_chart, dict(display="none")
else:
x_query = build_group_inputs_filter(
x_query, _ = build_group_inputs_filter(
global_state.get_data(data_id),
[point_filter],
)
Expand Down
2 changes: 1 addition & 1 deletion dtale/duplicate_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def remove(self, df):
duplicates = pd.concat(duplicates)
group_filter = None
if self.cfg.get("filter"):
group_filter = build_group_inputs_filter(
group_filter, _ = build_group_inputs_filter(
df, [{col: val for col, val in zip(group, self.cfg["filter"])}]
)
duplicates = run_query(duplicates, group_filter)
Expand Down
8 changes: 7 additions & 1 deletion dtale/query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pandas as pd
from pkg_resources import parse_version

import dtale.global_state as global_state


Expand Down Expand Up @@ -49,7 +52,10 @@ def _load_pct(df):
if (query or "") == "":
return _load_pct(df)

df = df.query(query, local_dict=context_vars or {})
is_pandas25 = parse_version(pd.__version__) >= parse_version("0.25.0")
df = df.query(
query if is_pandas25 else query.replace("`", ""), local_dict=context_vars or {}
)

if not len(df) and not ignore_empty:
raise Exception('query "{}" found no data, please alter'.format(query))
Expand Down
8 changes: 4 additions & 4 deletions tests/dtale/charts/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def test_convert_date_val_to_date():
@pytest.mark.unit
def test_group_filter_handler():
    """group_filter_handler returns a (query, label) tuple; the query
    (element 0) must backtick-quote the column name so pandas>=0.25
    ``DataFrame.query`` accepts columns with spaces/special characters."""
    assert (
        chart_utils.group_filter_handler("date", "2020-01-01", "D")[0]
        == "`date` == '20200101'"
    )
    # epoch-millisecond input is converted to the same date string
    assert (
        chart_utils.group_filter_handler("date", 1577854800000, "D")[0]
        == "`date` == '20200101'"
    )
1 change: 0 additions & 1 deletion tests/dtale/dash/test_custom_geojson.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def test_update_geojson():
params["state"][0]["value"] = "africa_110m.json"
response = c.post("/dtale/charts/_dash-update-component", json=params)
resp_data = response.get_json()["response"]
print(resp_data["output-geojson-upload"]["children"])
assert (
resp_data["output-geojson-upload"]["children"]
== "africa_110m uploaded!"
Expand Down
Loading

0 comments on commit c30fc07

Please sign in to comment.