diff --git a/arraymanagement/nodes/__init__.py b/arraymanagement/nodes/__init__.py
index 3d9db51..9fcf117 100644
--- a/arraymanagement/nodes/__init__.py
+++ b/arraymanagement/nodes/__init__.py
@@ -154,3 +154,73 @@ def descendants(self, ignore_groups=False):
             descendants = [x for x in descendants if not x.is_group]
         return descendants
 
+def store_select(pandas_store, key, where=None, **kwargs):
+    # pandas >= 0.13 expects string expressions; rewrite the old
+    # tuple/dict where clauses before they reach HDFStore.select
+    if "0.12" not in pd.__version__ and isinstance(where, list):
+        where = [parse_back_compat(x) for x in where]
+    # we used to accidentally pass series into start/stop
+    if 'start' in kwargs:
+        kwargs['start'] = int(kwargs['start'])
+    if 'stop' in kwargs:
+        kwargs['stop'] = int(kwargs['stop'])
+    return pandas_store.select(key, where=where, **kwargs)
+
+
+def parse_back_compat(w, op=None, value=None):
+    """allow backward compatibility for passed arguments (from pandas)"""
+    import warnings
+    from pandas.computation.pytables import Expr
+    from pandas.compat import string_types
+    from datetime import datetime, timedelta
+    import numpy as np
+
+    if isinstance(w, dict):
+        w, op, value = w.get('field'), w.get('op'), w.get('value')
+        if not isinstance(w, string_types):
+            raise TypeError(
+                "where must be passed as a string if op/value are passed")
+        warnings.warn("passing a dict to Expr is deprecated, "
+                      "pass the where as a single string",
+                      DeprecationWarning)
+    if isinstance(w, tuple):
+        if len(w) == 2:
+            w, value = w
+            op = '=='
+        elif len(w) == 3:
+            w, op, value = w
+        warnings.warn("passing a tuple into Expr is deprecated, "
+                      "pass the where as a single string",
+                      DeprecationWarning)
+
+    if op is not None:
+        if not isinstance(w, string_types):
+            raise TypeError(
+                "where must be passed as a string if op/value are passed")
+
+        if isinstance(op, Expr):
+            raise TypeError("invalid op passed, must be a string")
+        w = "{0}{1}".format(w, op)
+        if value is not None:
+            if isinstance(value, Expr):
+                raise TypeError("invalid value passed, must be a string")
+
+            # stringify with quotes these values
+            def convert(v):
+                if isinstance(v, (string_types, datetime, np.datetime64,
+                                  timedelta, np.timedelta64)) \
+                        or hasattr(v, 'timetuple'):
+                    return "'{0}'".format(str(v))
+                return v
+
+            if isinstance(value, (list, tuple)):
+                value = [convert(v) for v in value]
+            else:
+                value = convert(value)
+
+            w = "{0}{1}".format(w, value)
+
+        warnings.warn("passing multiple values to Expr is deprecated, "
+                      "pass the where as a single string",
+                      DeprecationWarning)
+
+    return w
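An illustrative sketch (not part of the patch) of what the parse_back_compat shim above produces for the legacy where-clause forms; the field names and values are invented. Two-tuples default to '==', three-tuples keep their operator, and string values are quoted to match the pandas >= 0.13 expression syntax (each call also emits a DeprecationWarning):

    from arraymanagement.nodes import parse_back_compat

    parse_back_compat(('hashval', 'deadbeef'))             # "hashval=='deadbeef'"
    parse_back_compat(('start_date', '<=', '2014-01-01'))  # "start_date<='2014-01-01'"
    parse_back_compat({'field': 'end_row', 'op': '>', 'value': 100})  # "end_row>100"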
diff --git a/arraymanagement/nodes/hdfnodes.py b/arraymanagement/nodes/hdfnodes.py
index ebd134a..539d102 100644
--- a/arraymanagement/nodes/hdfnodes.py
+++ b/arraymanagement/nodes/hdfnodes.py
@@ -16,7 +16,7 @@ from ..exceptions import ArrayManagementException
 from ..pathutils import dirsplit
 
-from . import Node
+from . import Node, store_select
 
 import logging
 import math
 logger = logging.getLogger(__name__)
@@ -60,8 +60,8 @@ def put(self, key, value, format='fixed', append=False, min_itemsize={}):
 
 class HDFDataSetMixin(object):
     def select(self, *args, **kwargs):
-        return self.store.select(self.localpath, *args, **kwargs)
-
+        return store_select(self.store, self.localpath, *args, **kwargs)
+
     def append(self, *args, **kwargs):
         return self.store.append(self.localpath, *args, **kwargs)
 
@@ -307,7 +307,7 @@ class PyTables(Node):
     def __init__(self, context, localpath="/"):
         super(PyTables, self).__init__(context)
         self.localpath = localpath
-        self.handle = tables.File(self.absolute_file_path)
+        self.handle = tables.File(self.absolute_file_path, mode="a")
         if self.localpath == "/":
             children = [x._v_pathname for x in self.handle.listNodes(self.localpath)]
             if children == ['/__data__']:
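A minimal usage sketch (not part of the patch) of the new select path: store_select coerces start/stop with int() because cache lookups hand over one-element Series, which HDFStore.select rejects as row bounds. The file and key names are invented, and this assumes the pandas 0.12/0.13 era this codebase targets, where int() on a one-element Series returns its scalar:

    import pandas as pd
    from arraymanagement.nodes import store_select

    store = pd.HDFStore('demo.h5')
    store.append('demo', pd.DataFrame({'values': range(10)}))  # 'table' format, so start/stop apply
    start, stop = pd.Series([2]), pd.Series([5])  # the shape a cache lookup hands back
    df = store_select(store, 'demo', start=start, stop=stop)   # rows 2..4
    store.close()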
diff --git a/arraymanagement/nodes/sqlcaching.py b/arraymanagement/nodes/sqlcaching.py
index afe15f0..539ba90 100644
--- a/arraymanagement/nodes/sqlcaching.py
+++ b/arraymanagement/nodes/sqlcaching.py
@@ -13,7 +13,7 @@
     write_pandas, override_hdf_types,
     )
 
-from arraymanagement.nodes.hdfnodes import Node
+from arraymanagement.nodes.hdfnodes import Node, store_select
 from arraymanagement.nodes.sql import query_info
 
 from sqlalchemy.sql.expression import bindparam, tuple_
@@ -83,7 +83,7 @@ def init_from_file(self):
 
     def query_min_itemsize(self):
         try:
-            min_itemsize = self.store.select('min_itemsize')
+            min_itemsize = store_select(self.store, 'min_itemsize')
         except KeyError:
             return None
         return min_itemsize.to_dict()
@@ -151,7 +151,7 @@ def cache_info(self, query_params):
         param_dict = self.parameter_dict(query_params)
         query = param_dict.items()
         try:
-            result = self.store.select('cache_spec', where=query)
+            result = store_select(self.store, 'cache_spec', where=query)
         except KeyError:
             return None
         if result is None:
@@ -272,10 +272,13 @@ def _single_select(self, **kwargs):
             self.cache_data(query_params)
             cache_info = self.cache_info(query_params)
         start_row, end_row = cache_info
+        # convert these one-element series to ints
+        start_row = start_row[0]
+        end_row = end_row[0]
         if not where:
             where = None
-        result = self.store.select(self.localpath, where=where,
-                                   start=start_row, stop=end_row)
+        result = store_select(self.store, self.localpath,
+                              where=where, start=start_row, stop=end_row)
         return result
 
     def repr_data(self):
         repr_data = super(DumbParameterizedQueryTable, self).repr_data()
@@ -310,8 +313,8 @@ def select(self, **kwargs):
         start_row, end_row = cache_info
         if not where:
             where = None
-        result = self.store.select(self.localpath, where=where,
-                                   start=start_row, stop=end_row)
+        result = store_select(self.store, self.localpath,
+                              where=where, start=start_row, stop=end_row)
         return result
 
     def filter_sql(self, **kwargs):
@@ -337,7 +340,8 @@ def cache_info(self, query_params):
         data = self.parameter_dict(query_params)
         hashval = gethashval(data)
         try:
-            result = self.store.select('cache_spec', where=[('hashval', hashval)])
+            result = store_select(self.store, 'cache_spec',
+                                  where=[('hashval', hashval)])
         except KeyError:
             return None
         if result is None:
@@ -368,11 +372,10 @@ def select(self, query_filter, where=None):
         if cache_info is None:
             self.cache_data(query_filter)
             cache_info = self.cache_info(query_filter)
+
         start_row, end_row = cache_info
-        if not where:
-            where = None
-        result = self.store.select(self.localpath, where=where,
-                                   start=start_row, stop=end_row)
+        result = store_select(self.store, self.localpath, where=where,
+                              start=start_row, stop=end_row)
         return result
 
     def cache_query(self, query_filter):
@@ -401,10 +404,13 @@ def store_cache_spec(self, query_filter, start_row, end_row):
         write_pandas(self.store, 'cache_spec', data, {}, 1.1,
                      replace=False)
 
+
     def cache_info(self, query_filter):
         hashval = self.gethashval(query_filter)
         try:
-            result = self.store.select('cache_spec', where=[('hashval', hashval)])
+            # rewriting where statement for 0.13 pandas style
+            result = store_select(self.store, 'cache_spec',
+                                  where=[('hashval', hashval)])
         except KeyError:
             return None
         if result is None:
@@ -443,14 +449,13 @@ def init_from_file(self):
             setattr(self, name, column(name))
 
     def select(self, query_filter, where=None, **kwargs):
-
         ignore_cache = kwargs.get('IgnoreCache', None)
 
         if ignore_cache:
             query = self.compiled_query(query_filter, kwargs)
             return query
-
-        if 'date' not in kwargs.keys():
+        dateKeys = [k for k in kwargs.keys() if 'date' in k]
+        if not dateKeys:
             #no dates in query
 
             fs = FlexibleSqlCaching(self)
@@ -461,11 +466,9 @@ def select(self, query_filter, where=None, **kwargs):
             return result
 
         else:
-            dateKeys = [k for k in kwargs.keys() if 'date' in k]
             dateKeys = sorted(dateKeys)
             start_date, end_date = kwargs[dateKeys[0]], kwargs[dateKeys[1]]
-
             result = self.cache_info(query_filter, start_date, end_date)
 
             if result is None:
@@ -497,11 +500,11 @@ def store_cache_spec(self, query_filter, start_row, end_row, start_date, end_dat
     def cache_info(self, query_filter, start_date, end_date):
         hashval = self.gethashval(query_filter)
         try:
-            # print self.store['/cache_spec']
+            # result = store_select(self.store, 'cache_spec',
+            #                       where=[('hashval', hashval),
+            #                              ('start_date', start_date)])
-            result = self.store.select('cache_spec', where=[('hashval', hashval),
-                                                            ('start_date',start_date)])
 
             start_date = pd.Timestamp(start_date)
             end_date = pd.Timestamp(end_date)
@@ -559,7 +562,6 @@ def cache_data(self, query_params, start_date, end_date):
                 break;
 
         all_query = and_(query_params,column(col_date) >=start_date,
                          column(col_date) <= end_date)
-
         q = self.cache_query(all_query)
         log.debug(str(q))
@@ -579,7 +581,6 @@ def cache_data(self, query_params, start_date, end_date):
                                           db_string_types=db_string_types,
                                           db_datetime_types=db_datetime_types
                                           )
-
         self.min_itemsize = min_itemsize
         self.finalize_min_itemsize()
         overrides = self.col_types
@@ -589,6 +590,7 @@ def cache_data(self, query_params, start_date, end_date):
             starting_row = self.table.nrows
         except AttributeError:
             starting_row = 0
+
         write_pandas_hdf_from_cursor(self.store, self.localpath, cur,
                                      columns, self.min_itemsize,
                                      dtype_overrides=overrides,
@@ -599,19 +601,17 @@ def cache_data(self, query_params, start_date, end_date):
             ending_row = self.table.nrows
         except AttributeError:
             ending_row = 0
-
         self.store_cache_spec(query_params, starting_row, ending_row,
                               start_date, end_date)
 
     def munge_tables(self, hashval, start_date, end_date):
         store = self.store
-        store.select('cache_spec', where=[('hashval', hashval)])
+        # store.select('cache_spec', where=[('hashval', hashval)])
 
         store['/cache_spec'][['start_date','end_date']].sort(['start_date'])
-
-        df_min = store.select('cache_spec', where=[('start_date', '<=', start_date)]).reset_index()
-        df_max = store.select('cache_spec', where=[('end_date', '<=', end_date)]).reset_index()
+        df_min = store_select(store, 'cache_spec', where=[('start_date', '<=', start_date)]).reset_index()
+        df_max = store_select(store, 'cache_spec', where=[('end_date', '<=', end_date)]).reset_index()
 
         df_total = df_min.append(df_max)
         df_total.drop_duplicates('_end_row',inplace=True)
 
@@ -623,8 +623,7 @@ def munge_tables(self, hashval, start_date, end_date):
         for s in ss_vals:
             start_row = s[0]
             end_row = s[1]
-
-            temp = store.select(self.localpath,
+            temp = store_select(store, self.localpath,
                                 start=start_row, stop=end_row)
 
             temp.head()
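For the same reason, _single_select above unwraps the cached row bounds with [0] before calling store_select. A small sketch with invented values; _end_row appears in the hunks above, while _start_row is an assumed companion column:

    import pandas as pd

    spec = pd.DataFrame({'_start_row': [0], '_end_row': [1024]})
    start_row, end_row = spec['_start_row'], spec['_end_row']  # one-element Series
    start_row, end_row = start_row[0], end_row[0]              # plain ints: 0, 1024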
diff --git a/example/datalib/config.py b/example/datalib/config.py
index 20715e0..efab997 100644
--- a/example/datalib/config.py
+++ b/example/datalib/config.py
@@ -2,6 +2,8 @@
 from arraymanagement.nodes.csvnodes import PandasCSVNode
 from arraymanagement.nodes.hdfnodes import PandasHDFNode, PyTables
 from arraymanagement.nodes.sql import SimpleQueryTable
+from arraymanagement.nodes.sqlcaching import YamlSqlDateCaching
+
 
 global_config = dict(
     is_dataset = False,
@@ -14,6 +16,7 @@
         ('*.hdf5' , PandasHDFNode),
         ('*.h5' , PandasHDFNode),
         ('*.sql' , SimpleQueryTable),
+        ("*.yaml", YamlSqlDateCaching),
         ])
     )
diff --git a/example/sqlviews/example_no_dates_not_entities.yaml b/example/sqlviews/example_no_dates_not_entities.yaml
index 449f34c..f452099 100644
--- a/example/sqlviews/example_no_dates_not_entities.yaml
+++ b/example/sqlviews/example_no_dates_not_entities.yaml
@@ -1,7 +1,6 @@
 SQL:
     # Query for EOD data for list of entities
     eod_stock:
-        type: 'conditional'
         conditionals:
 
         query: >
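With the config change above, any *.yaml file in the example datalib resolves to a YamlSqlDateCaching node. The matching sketched here with fnmatch is an assumption about how the pattern list is applied, not code from this repo:

    from fnmatch import fnmatch
    from arraymanagement.nodes.sql import SimpleQueryTable
    from arraymanagement.nodes.sqlcaching import YamlSqlDateCaching

    loaders = [('*.sql', SimpleQueryTable), ('*.yaml', YamlSqlDateCaching)]
    fname = 'example_no_dates_not_entities.yaml'
    node_cls = next(cls for pat, cls in loaders if fnmatch(fname, pat))
    assert node_cls is YamlSqlDateCaching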
diff --git a/tests/node_test.py b/tests/node_test.py
index 87b1975..06e65f2 100644
--- a/tests/node_test.py
+++ b/tests/node_test.py
@@ -9,54 +9,54 @@
 
 from arraymanagement.client import ArrayClient
 
-# def setup_module():
-#     basepath = join(dirname(dirname(__file__)), 'example')
-#     client = ArrayClient(basepath)
-#     client.clear_disk_cache()
-#
-# def teardown_module():
-#     basepath = join(dirname(dirname(__file__)), 'example')
-#     client = ArrayClient(basepath)
-#     client.clear_disk_cache()
-#
-# def test_csv_node():
-#     basepath = join(dirname(dirname(__file__)), 'example')
-#     client = ArrayClient(basepath)
-#     node = client.get_node('/csvs/sample')
-#     data = node.get()
-#     #better check later
-#     assert data.shape == (73,2)
-#
-# def test_hdf_node():
-#     basepath = join(dirname(dirname(__file__)), 'example')
-#     client = ArrayClient(basepath)
-#     node = client.get_node('/pandashdf5/data')
-#     assert 'sample' in node.keys()
-#     node = node.get_node('sample')
-#     data = node.select()
-#     assert data.shape == (73,2)
-#
-# def test_custom_node():
-#     basepath = join(dirname(dirname(__file__)), 'example')
-#     client = ArrayClient(basepath)
-#     node = client.get_node('/custom/sample2')
-#     data1 = node.select()
-#     node = client.get_node('/custom/sample')
-#     data2 = node.get()
-#     assert data2.iloc[2]['values'] == 2
-#     assert data1.iloc[2]['values'] == 4
-#
-#
-# def test_csv_node():
-#     basepath = join(dirname(dirname(__file__)), 'example')
-#     client = ArrayClient(basepath)
-#     node = client.get_node('/customcsvs/sample')
-#     data1 = node.get()
-#     node = client.get_node('/customcsvs/sample2')
-#     data2 = node.select()
-#     node = client.get_node('/customcsvs/sample_pipe')
-#     data3 = node.select()
-#     #better check later
+def setup_module():
+    basepath = join(dirname(dirname(__file__)), 'example')
+    client = ArrayClient(basepath)
+    client.clear_disk_cache()
+
+def teardown_module():
+    basepath = join(dirname(dirname(__file__)), 'example')
+    client = ArrayClient(basepath)
+    client.clear_disk_cache()
+
+def test_csv_node():
+    basepath = join(dirname(dirname(__file__)), 'example')
+    client = ArrayClient(basepath)
+    node = client.get_node('/csvs/sample')
+    data = node.get()
+    #better check later
+    assert data.shape == (73,2)
+
+def test_hdf_node():
+    basepath = join(dirname(dirname(__file__)), 'example')
+    client = ArrayClient(basepath)
+    node = client.get_node('/pandashdf5/data')
+    assert 'sample' in node.keys()
+    node = node.get_node('sample')
+    data = node.select()
+    assert data.shape == (73,2)
+
+def test_custom_node():
+    basepath = join(dirname(dirname(__file__)), 'example')
+    client = ArrayClient(basepath)
+    node = client.get_node('/custom/sample2')
+    data1 = node.select()
+    node = client.get_node('/custom/sample')
+    data2 = node.get()
+    assert data2.iloc[2]['values'] == 2
+    assert data1.iloc[2]['values'] == 4
+
+
+def test_customcsv_node():
+    basepath = join(dirname(dirname(__file__)), 'example')
+    client = ArrayClient(basepath)
+    node = client.get_node('/customcsvs/sample')
+    data1 = node.get()
+    node = client.get_node('/customcsvs/sample2')
+    data2 = node.select()
+    node = client.get_node('/customcsvs/sample_pipe')
+    data3 = node.select()
+    #better check later
 
 def test_sql_yaml_cache():
     basepath = join(dirname(dirname(__file__)), 'example')