Skip to content

Commit

Permalink
Context variables: Allow users to inject a dictionary of context vari…
Browse files Browse the repository at this point in the history
…ables upon starting a dtale process.

These variables can be referenced in query strings by prefixing the name with @.
This can simplify complex queries immensely.
 - Context variables: PR tweaks and added front-end tests
 - Context variables: added test covering case where fetching them returns an error
 - Context variables: accidentally had key/value flipped in lodash object map, fixed it
 - Context variables: updated documentation for filtering
  • Loading branch information
phillipdupuis authored and Andrew Schonfeld committed Feb 21, 2020
1 parent f8d9529 commit 97634bc
Show file tree
Hide file tree
Showing 13 changed files with 360 additions and 27 deletions.
20 changes: 19 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,25 @@ View all the columns & their data types as well as individual details of each co
|float|![](https://raw.githubusercontent.com/aschonfeld/dtale-media/master/images/Describe_float.png)||

#### Filter
Apply a simple pandas `query` to your data (link to pandas documentation included in popup)
Apply a simple pandas `query` to your data (link to pandas documentation included in popup)

Context Variables are user-defined values passed in via the `context_vars` argument to `dtale.show()`; they can be referenced in filters by prefixing the variable name with '@'.

For example, here is how you can use context variables in a pandas query:
```python
import pandas as pd

df = pd.DataFrame([
dict(name='Joe', age=7),
dict(name='Bob', age=23),
dict(name='Ann', age=45),
dict(name='Cat', age=88),
])
two_oldest_ages = df['age'].nlargest(2)
df.query('age in @two_oldest_ages')
```
And here is how you would pass that context variable to D-Tale: `dtale.show(df, context_vars=dict(two_oldest_ages=two_oldest_ages))`


|Editing|Result|
|--------|:------:|
Expand Down
7 changes: 5 additions & 2 deletions dtale/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def is_port_in_use(port):


def show(data=None, host=None, port=None, name=None, debug=False, subprocess=True, data_loader=None,
reaper_on=True, open_browser=False, notebook=False, force=False, **kwargs):
reaper_on=True, open_browser=False, notebook=False, force=False, context_vars=None, **kwargs):
"""
Entry point for kicking off D-Tale :class:`flask:flask.Flask` process from python process
Expand Down Expand Up @@ -423,6 +423,9 @@ def show(data=None, host=None, port=None, name=None, debug=False, subprocess=Tru
:param force: if true, this will force the D-Tale instance to run on the specified host/port by killing any
other process running at that location
:type force: bool, optional
:param context_vars: a dictionary of the variables that will be available for use in user-defined expressions,
such as filters
:type context_vars: dict, optional
:Example:
Expand All @@ -440,7 +443,7 @@ def show(data=None, host=None, port=None, name=None, debug=False, subprocess=Tru

initialize_process_props(host, port, force)
url = build_url(ACTIVE_PORT, ACTIVE_HOST)
instance = startup(url, data=data, data_loader=data_loader, name=name)
instance = startup(url, data=data, data_loader=data_loader, name=name, context_vars=context_vars)
is_active = not running_with_flask_debug() and is_up(url)
if is_active:
def _start():
Expand Down
4 changes: 2 additions & 2 deletions dtale/dash_application/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dtale.utils import (classify_type, dict_merge, divide_chunks,
flatten_lists, get_dtypes, make_list,
make_timeout_request, run_query)
from dtale.views import DATA
from dtale.views import DATA, CONTEXT_VARIABLES
from dtale.views import build_chart as build_chart_data


Expand Down Expand Up @@ -720,7 +720,7 @@ def build_figure_data(data_id, chart_type=None, query=None, x=None, y=None, z=No
rolling_comp=rolling_comp)):
return None

data = run_query(DATA[data_id], query)
data = run_query(DATA[data_id], query, CONTEXT_VARIABLES[data_id])
chart_kwargs = dict(group_col=group, agg=agg, allow_duplicates=chart_type == 'scatter', rolling_win=window,
rolling_comp=rolling_comp)
if chart_type in ZAXIS_CHARTS:
Expand Down
4 changes: 2 additions & 2 deletions dtale/dash_application/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
show_input_handler,
show_yaxis_ranges)
from dtale.utils import dict_merge, make_list, run_query
from dtale.views import DATA
from dtale.views import DATA, CONTEXT_VARIABLES

logger = getLogger(__name__)

Expand Down Expand Up @@ -139,7 +139,7 @@ def query_input(query, pathname, curr_query):
:rtype: tuple of (str, str, str)
"""
try:
run_query(DATA[get_data_id(pathname)], query)
run_query(DATA[get_data_id(pathname)], query, CONTEXT_VARIABLES[get_data_id(pathname)])
return query, {'line-height': 'inherit'}, ''
except BaseException as ex:
return curr_query, {'line-height': 'inherit', 'background-color': 'pink'}, str(ex)
Expand Down
14 changes: 9 additions & 5 deletions dtale/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def sort_df_for_grid(df, params):
return df.sort_index()


def filter_df_for_grid(df, params):
def filter_df_for_grid(df, params, context_vars):
"""
Filter dataframe based on 'filters' property in parameter dictionary. Filter
configuration is of the following shape:
Expand Down Expand Up @@ -505,6 +505,8 @@ def filter_df_for_grid(df, params):
:type df: :class:`pandas:pandas.DataFrame`
:param params: arguments from :attr:`flask:flask.request`
:type params: dict
:param context_vars: a dictionary of the variables that will be available for use in user-defined expressions
:type context_vars: dict
:return: filtering dataframe
:rtype: :class:`pandas:pandas.DataFrame`
"""
Expand Down Expand Up @@ -539,7 +541,7 @@ def filter_df_for_grid(df, params):
df = df[stringified_col.astype(str) == filter_val[1:]]
else:
df = df[stringified_col.astype(str).str.lower().str.contains(filter_val.lower(), na=False)]
df = run_query(df, params.get('query'))
df = run_query(df, params.get('query'), context_vars)
return df


Expand Down Expand Up @@ -759,7 +761,7 @@ def make_timeout_request(target, args=None, kwargs=None, timeout=60):
return results


def run_query(df, query):
def run_query(df, query, context_vars):
"""
Utility function for running :func:`pandas:pandas.DataFrame.query` . This function contains extra logic to
handle when column names contain special characters. Looks like pandas will be handling this in a future
Expand All @@ -771,6 +773,8 @@ def run_query(df, query):
:type df: :class:`pandas:pandas.DataFrame`
:param query: query string
:type query: string
:param context_vars: dictionary of user-defined variables which can be referenced by name in query strings
:type context_vars: dict
:return: filtered dataframe
"""
if (query or '') == '':
Expand All @@ -792,10 +796,10 @@ def run_query(df, query):
inv_replacements = {replacements[k]: k for k in replacements.keys()}
df = df.rename(columns=replacements) # Rename the columns

df = df.query(final_query) # Carry out query
df = df.query(final_query, local_dict=context_vars) # Carry out query
df = df.rename(columns=inv_replacements)
else:
df = df.query(query)
df = df.query(query, local_dict=context_vars)

if not len(df):
raise Exception('query "{}" found no data, please alter'.format(query))
Expand Down
78 changes: 67 additions & 11 deletions dtale/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import webbrowser
from builtins import map, range, str, zip
from logging import getLogger
from collections import defaultdict

from flask import json, redirect, render_template, request

Expand Down Expand Up @@ -32,6 +33,7 @@
SETTINGS = {}
METADATA = {}
IDX_COL = str('dtale_index')
CONTEXT_VARIABLES = defaultdict(dict)


def head_data_id():
Expand Down Expand Up @@ -361,7 +363,7 @@ def format_data(data):
return data, index


def startup(url, data=None, data_loader=None, name=None, data_id=None):
def startup(url, data=None, data_loader=None, name=None, data_id=None, context_vars=None):
"""
Loads and stores data globally
- If data has indexes then it will lock save those columns as locked on the front-end
Expand All @@ -373,8 +375,11 @@ def startup(url, data=None, data_loader=None, name=None, data_id=None):
:param name: string label to apply to your session
:param data_id: integer id assigned to a piece of data viewable in D-Tale, if this is populated then it will
override the data at that id
:param context_vars: a dictionary of the variables that will be available for use in user-defined expressions,
such as filters
:type context_vars: dict, optional
"""
global DATA, DTYPES, SETTINGS, METADATA
global DATA, DTYPES, SETTINGS, METADATA, CONTEXT_VARIABLES

if data_loader is not None:
data = data_loader()
Expand Down Expand Up @@ -414,6 +419,7 @@ def startup(url, data=None, data_loader=None, name=None, data_id=None):
SETTINGS[data_id] = dict(locked=curr_locked)
DATA[data_id] = data
DTYPES[data_id] = build_dtypes_state(data, DTYPES.get(data_id, []))
CONTEXT_VARIABLES[data_id] = build_context_variables(data_id, context_vars)
return DtaleData(data_id, url)
else:
raise Exception('data loaded is None!')
Expand All @@ -426,13 +432,14 @@ def cleanup():
:param port: integer string for a D-Tale process's port
:type port: str
"""
global DATA, DTYPES, SETTINGS, METADATA
global DATA, DTYPES, SETTINGS, METADATA, CONTEXT_VARIABLES

# use pop() because in some pytests port is not available
DATA = {}
SETTINGS = {}
DTYPES = {}
METADATA = {}
CONTEXT_VARIABLES = defaultdict(dict)


def base_render_template(template, data_id, **kwargs):
Expand Down Expand Up @@ -804,7 +811,7 @@ def test_filter(data_id):
:return: JSON {success: True/False}
"""
try:
run_query(DATA[data_id], get_str_arg(request, 'query'))
run_query(DATA[data_id], get_str_arg(request, 'query'), CONTEXT_VARIABLES[data_id])
return jsonify(dict(success=True))
except BaseException as e:
return jsonify(dict(error=str(e), traceback=str(traceback.format_exc())))
Expand Down Expand Up @@ -931,7 +938,7 @@ def get_data(data_id):
}
"""
try:
global SETTINGS, DATA, DTYPES
global SETTINGS, DATA, DTYPES, CONTEXT_VARIABLES
data = DATA[data_id]

# this will check for when someone instantiates D-Tale programatically and directly alters the internal
Expand All @@ -957,7 +964,7 @@ def get_data(data_id):
curr_settings = dict_merge(curr_settings, dict(sort=params['sort']))
else:
curr_settings = {k: v for k, v in curr_settings.items() if k != 'sort'}
data = filter_df_for_grid(data, params)
data = filter_df_for_grid(data, params, CONTEXT_VARIABLES[data_id])
if params.get('query') is not None:
curr_settings = dict_merge(curr_settings, dict(query=params['query']))
else:
Expand Down Expand Up @@ -1000,7 +1007,7 @@ def get_histogram(data_id):
col = get_str_arg(request, 'col', 'values')
bins = get_int_arg(request, 'bins', 20)
try:
data = run_query(DATA[data_id], get_str_arg(request, 'query'))
data = run_query(DATA[data_id], get_str_arg(request, 'query'), CONTEXT_VARIABLES[data_id])
selected_col = find_selected_column(data, col)
data = data[~pd.isnull(data[selected_col])][[selected_col]]
hist = np.histogram(data, bins=bins)
Expand Down Expand Up @@ -1028,7 +1035,7 @@ def get_correlations(data_id):
} or {error: 'Exception message', traceback: 'Exception stacktrace'}
"""
try:
data = run_query(DATA[data_id], get_str_arg(request, 'query'))
data = run_query(DATA[data_id], get_str_arg(request, 'query'), CONTEXT_VARIABLES[data_id])
valid_corr_cols = []
valid_date_cols = []
rolling = False
Expand Down Expand Up @@ -1090,7 +1097,7 @@ def get_chart_data(data_id):
} or {error: 'Exception message', traceback: 'Exception stacktrace'}
"""
try:
data = run_query(DATA[data_id], get_str_arg(request, 'query'))
data = run_query(DATA[data_id], get_str_arg(request, 'query'), CONTEXT_VARIABLES[data_id])
x = get_str_arg(request, 'x')
y = get_json_arg(request, 'y')
group_col = get_json_arg(request, 'group')
Expand Down Expand Up @@ -1121,7 +1128,7 @@ def get_correlations_ts(data_id):
} or {error: 'Exception message', traceback: 'Exception stacktrace'}
"""
try:
data = run_query(DATA[data_id], get_str_arg(request, 'query'))
data = run_query(DATA[data_id], get_str_arg(request, 'query'), CONTEXT_VARIABLES[data_id])
cols = get_str_arg(request, 'cols')
cols = json.loads(cols)
date_col = get_str_arg(request, 'dateCol')
Expand Down Expand Up @@ -1175,7 +1182,7 @@ def get_scatter(data_id):
date_col = get_str_arg(request, 'dateCol')
rolling = get_bool_arg(request, 'rolling')
try:
data = run_query(DATA[data_id], get_str_arg(request, 'query'))
data = run_query(DATA[data_id], get_str_arg(request, 'query'), CONTEXT_VARIABLES[data_id])
idx_col = str('index')
y_cols = [cols[1], idx_col]
if rolling:
Expand Down Expand Up @@ -1213,3 +1220,52 @@ def get_scatter(data_id):
return jsonify(data)
except BaseException as e:
return jsonify(dict(error=str(e), traceback=str(traceback.format_exc())))


def build_context_variables(data_id, new_context_vars=None):
    """
    Build and return the dictionary of context variables associated with a process.
    If the names of any new variables are not formatted properly, an exception will be raised.
    New variables will overwrite the values of existing variables if they share the same name.

    A name is accepted only when it is a string made up of letters, digits and underscores whose first
    character is neither an underscore nor a digit, so that it can be referenced as ``@name`` in a query.

    :param data_id: integer string identifier for a D-Tale process's data
    :type data_id: str
    :param new_context_vars: dictionary of name, value pairs for new context variables
    :type new_context_vars: dict, optional
    :returns: dict of the context variables for this process
    :rtype: dict
    :raises SyntaxError: if any name in `new_context_vars` fails the naming rules described above
    """
    global CONTEXT_VARIABLES

    if new_context_vars:
        # only the names need validating; the values may be arbitrary python objects
        for name in new_context_vars:
            if not isinstance(name, str):
                raise SyntaxError('{}, context variables must be a valid string'.format(name))
            elif not name.replace('_', '').isalnum():
                # this branch also rejects the empty string and names made up solely of underscores
                raise SyntaxError('{}, context variables can only contain letters, digits, or underscores'.format(name))
            elif name.startswith('_'):
                raise SyntaxError('{}, context variables can not start with an underscore'.format(name))
            elif name[0].isdigit():
                # names like "1abc" are alphanumeric but are not valid identifiers, so "@1abc" could never
                # be resolved inside a query string; fail up front instead of at query time
                raise SyntaxError('{}, context variables can not start with a digit'.format(name))

    return dict_merge(CONTEXT_VARIABLES[data_id], new_context_vars)


@dtale.route('/context-variables/<data_id>')
def get_context_variables(data_id):
    """
    :class:`flask:flask.Flask` route which hands the front end a read-only, stringified snapshot of the
    context variables stored for a piece of data. Each value is converted with `str` and cut off after
    1000 characters so it can be displayed to the user.

    :param data_id: integer string identifier for a D-Tale process's data
    :type data_id: str
    :return: JSON {context_variables: {name: string representation}, success: True} or
             {error: 'Exception message', traceback: 'Exception stacktrace'}
    """
    global CONTEXT_VARIABLES

    try:
        displayable = {}
        for name, value in CONTEXT_VARIABLES[data_id].items():
            # truncate the string representation so large objects stay presentable in the front-end
            displayable[name] = str(value)[:1000]
        return jsonify(context_variables=displayable, success=True)
    except BaseException as e:
        return jsonify(error=str(e), traceback=str(traceback.format_exc()))
39 changes: 39 additions & 0 deletions static/__tests__/dtale/DataViewer-filter-test.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,43 @@ describe("DataViewer tests", () => {
}, 400);
}, 400);
});

test("DataViewer: filtering, context variables error", done => {
const { DataViewer } = require("../../dtale/DataViewer");
const Filter = require("../../dtale/Filter").default;
const ContextVariables = require("../../dtale/ContextVariables").default;

const store = reduxUtils.createDtaleStore();
buildInnerHTML({ settings: "", dataId: "error" }, store);
const result = mount(
<Provider store={store}>
<DataViewer />
</Provider>,
{
attachTo: document.getElementById("content"),
}
);

setTimeout(() => {
result.update();

//open filter
clickMainMenuButton(result, "Filter");
result.update();
setTimeout(() => {
result.update();
t.equal(
result
.find(Filter)
.find(ContextVariables)
.find(RemovableError)
.find("div.dtale-alert")
.text(),
"Error loading context variables",
"should display error"
);
done();
}, 400);
}, 400);
});
});
Loading

0 comments on commit 97634bc

Please sign in to comment.