diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index a3d6975c00a95..3e3330fa4378f 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -2,12 +2,24 @@ High level interface to PyTables for reading and writing pandas data structures to disk """ +from contextlib import suppress import copy from datetime import date, tzinfo import itertools import os import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from textwrap import dedent +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) import warnings import numpy as np @@ -202,12 +214,10 @@ def _tables(): # set the file open policy # return the file open policy; this changes as of pytables 3.1 # depending on the HDF5 version - try: + with suppress(AttributeError): _table_file_open_policy_is_strict = ( tables.file._FILE_OPEN_POLICY == "strict" ) - except AttributeError: - pass return _table_mod @@ -423,10 +433,8 @@ def read_hdf( except (ValueError, TypeError, KeyError): if not isinstance(path_or_buf, HDFStore): # if there is an error, close the store if we opened it. - try: + with suppress(AttributeError): store.close() - except AttributeError: - pass raise @@ -667,12 +675,10 @@ def open(self, mode: str = "a", **kwargs): tables = _tables() if self._mode != mode: - # if we are changing a write mode to read, ok if self._mode in ["a", "w"] and mode in ["r", "r+"]: pass elif mode in ["w"]: - # this would truncate, raise here if self.is_open: raise PossibleDataLossError( @@ -691,41 +697,14 @@ def open(self, mode: str = "a", **kwargs): self._complevel, self._complib, fletcher32=self._fletcher32 ) - try: - self._handle = tables.open_file(self._path, self._mode, **kwargs) - except OSError as err: # pragma: no cover - if "can not be written" in str(err): - print(f"Opening {self._path} in read-only mode") - self._handle = tables.open_file(self._path, "r", **kwargs) - else: - raise - - except ValueError as err: - - # trap PyTables >= 3.1 FILE_OPEN_POLICY exception - # to provide an updated message - if "FILE_OPEN_POLICY" in str(err): - hdf_version = tables.get_hdf5_version() - err = ValueError( - f"PyTables [{tables.__version__}] no longer supports " - "opening multiple files\n" - "even in read-only mode on this HDF5 version " - f"[{hdf_version}]. You can accept this\n" - "and not open the same file multiple times at once,\n" - "upgrade the HDF5 version, or downgrade to PyTables 3.0.0 " - "which allows\n" - "files to be opened multiple times at once\n" - ) - - raise err - - except Exception as err: + if _table_file_open_policy_is_strict and self.is_open: + msg = ( + "Cannot open HDF5 file, which is already opened, " + "even in read-only mode." + ) + raise ValueError(msg) - # trying to read from a non-existent file causes an error which - # is not part of IOError, make it one - if self._mode == "r" and "Unable to open/create file" in str(err): - raise OSError(str(err)) from err - raise + self._handle = tables.open_file(self._path, self._mode, **kwargs) def close(self): """ @@ -763,10 +742,8 @@ def flush(self, fsync: bool = False): if self._handle is not None: self._handle.flush() if fsync: - try: + with suppress(OSError): os.fsync(self._handle.fileno()) - except OSError: - pass def get(self, key: str): """ @@ -814,20 +791,20 @@ def select( Parameters ---------- key : str - Object being retrieved from file. - where : list, default None - List of Term (or convertible) objects, optional. 
- start : int, default None - Row number to start selection. + Object being retrieved from file. + where : list or None + List of Term (or convertible) objects, optional. + start : int or None + Row number to start selection. stop : int, default None - Row number to stop selection. - columns : list, default None - A list of columns that if not None, will limit the return columns. - iterator : bool, default False - Returns an iterator. - chunksize : int, default None - Number or rows to include in iteration, return an iterator. - auto_close : bool, default False + Row number to stop selection. + columns : list or None + A list of columns that if not None, will limit the return columns. + iterator : bool, default False + Returns an iterator. + chunksize : int or None + Number of rows to include in iteration, return an iterator. + auto_close : bool, default False Should automatically close the store when finished. Returns @@ -1090,17 +1067,14 @@ def put( Table format. Write as a PyTables Table structure which may perform worse but allow more flexible operations like searching / selecting subsets of the data. - append : bool, default False - This will force Table format, append the input data to the - existing. + append : bool, default False + This will force Table format, append the input data to the existing. data_columns : list, default None - List of columns to create as data columns, or True to - use all columns. See `here + List of columns to create as data columns, or True to use all columns. + See `here <https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#query-via-data-columns>`__. encoding : str, default None Provide an encoding for strings. - dropna : bool, default False, do not write an ALL nan row to - The store settable by the option 'io.hdf.dropna_table'. track_times : bool, default True Parameter is propagated to 'create_table' method of 'PyTables'. If set to False it enables to have the same h5 files (same hashes) @@ -1521,11 +1495,12 @@ def copy( Parameters ---------- - propindexes : bool, default True + propindexes : bool, default True Restore indexes in copied file. - keys : list of keys to include in the copy (defaults to all) - overwrite : overwrite (remove and replace) existing nodes in the - new store (default is True) + keys : list, optional + List of keys to include in the copy (defaults to all). + overwrite : bool, default True + Whether to overwrite (remove and replace) existing nodes in the new store. 
mode, complib, complevel, fletcher32 same as in HDFStore.__init__ Returns @@ -1648,7 +1623,6 @@ def error(t): # infer the pt from the passed value if pt is None: if value is None: - _tables() assert _table_mod is not None # for mypy if getattr(group, "table", None) or isinstance( @@ -1680,10 +1654,8 @@ def error(t): # existing node (and must be a table) if tt is None: - # if we are a writer, determine the tt if value is not None: - if pt == "series_table": index = getattr(value, "index", None) if index is not None: @@ -1735,38 +1707,12 @@ def _write_to_group( errors: str = "strict", track_times: bool = True, ): - group = self.get_node(key) - - # we make this assertion for mypy; the get_node call will already - # have raised if this is incorrect - assert self._handle is not None - - # remove the node if we are not appending - if group is not None and not append: - self._handle.remove_node(group, recursive=True) - group = None - # we don't want to store a table node at all if our object is 0-len # as there are not dtypes if getattr(value, "empty", None) and (format == "table" or append): return - if group is None: - paths = key.split("/") - - # recursively create the groups - path = "/" - for p in paths: - if not len(p): - continue - new_path = path - if not path.endswith("/"): - new_path += "/" - new_path += p - group = self.get_node(new_path) - if group is None: - group = self._handle.create_group(path, p) - path = new_path + group = self._identify_group(key, append) s = self._create_storer(group, format, value, encoding=encoding, errors=errors) if append: @@ -1807,6 +1753,45 @@ def _read_group(self, group: "Node"): s.infer_axes() return s.read() + def _identify_group(self, key: str, append: bool) -> "Node": + """Identify HDF5 group based on key, delete/create group if needed.""" + group = self.get_node(key) + + # we make this assertion for mypy; the get_node call will already + # have raised if this is incorrect + assert self._handle is not None + + # remove the node if we are not appending + if group is not None and not append: + self._handle.remove_node(group, recursive=True) + group = None + + if group is None: + group = self._create_nodes_and_group(key) + + return group + + def _create_nodes_and_group(self, key: str) -> "Node": + """Create nodes from key and return the created group.""" + # assertion for mypy + assert self._handle is not None + + paths = key.split("/") + # recursively create the groups + path = "/" + for p in paths: + if not len(p): + continue + new_path = path + if not path.endswith("/"): + new_path += "/" + new_path += p + group = self.get_node(new_path) + if group is None: + group = self._handle.create_group(path, p) + path = new_path + return group + class TableIterator: """ @@ -1875,11 +1860,9 @@ def __init__( self.auto_close = auto_close def __iter__(self): - # iterate current = self.start while current < self.stop: - stop = min(current + self.chunksize, self.stop) value = self.func(None, None, self.coordinates[current:stop]) current = stop @@ -1895,7 +1878,6 @@ def close(self): self.store.close() def get_result(self, coordinates: bool = False): - # return the actual iterator if self.chunksize is not None: if not isinstance(self.s, Table): @@ -2094,7 +2076,6 @@ def maybe_set_size(self, min_itemsize=None): with an integer size """ if _ensure_decoded(self.kind) == "string": - if isinstance(min_itemsize, dict): min_itemsize = min_itemsize.get(self.name) @@ -2152,7 +2133,6 @@ def update_info(self, info): existing_value = idx.get(key) if key in idx and value is not None and 
existing_value != value: - # frequency/name just warn if key in ["freq", "index_name"]: ws = attribute_conflict_doc % (key, existing_value, value) @@ -2345,10 +2325,8 @@ def _get_atom(cls, values: ArrayLike) -> "Col": atom = cls.get_atom_timedelta64(shape) elif is_complex_dtype(dtype): atom = _tables().ComplexCol(itemsize=itemsize, shape=shape[0]) - elif is_string_dtype(dtype): atom = cls.get_atom_string(shape, itemsize) - else: atom = cls.get_atom_data(shape, kind=dtype.name) @@ -2454,7 +2432,6 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): # reverse converts if dtype == "datetime64": - # recreate with tz if indicated converted = _set_tz(converted, tz, coerce=True) @@ -2471,7 +2448,6 @@ def convert(self, values: np.ndarray, nan_rep, encoding: str, errors: str): ) elif meta == "category": - # we have a categorical categories = metadata codes = converted.ravel() @@ -2826,7 +2802,6 @@ def read_array( ret = node[start:stop] if dtype == "datetime64": - # reconstruct a timezone if indicated tz = getattr(attrs, "tz", None) ret = _set_tz(ret, tz, coerce=True) @@ -3011,11 +2986,9 @@ def write_array(self, key: str, value: ArrayLike, items: Optional[Index] = None) atom = None if self._filters is not None: - try: + with suppress(ValueError): # get the atom for this datatype atom = _tables().Atom.from_dtype(value.dtype) - except ValueError: - pass if atom is not None: # We only get here if self._filters is non-None and @@ -3032,7 +3005,6 @@ def write_array(self, key: str, value: ArrayLike, items: Optional[Index] = None) self.write_array_empty(key, value) elif value.dtype.type == np.object_: - # infer the type, warn if we have a non-string type here (for # performance) inferred_type = lib.infer_dtype(value, skipna=False) @@ -3716,7 +3688,6 @@ def validate_data_columns(self, data_columns, min_itemsize, non_index_axes): # if min_itemsize is a dict, add the keys (exclude 'values') if isinstance(min_itemsize, dict): - existing_data_columns = set(data_columns) data_columns = list(data_columns) # ensure we do not modify data_columns.extend( @@ -4152,7 +4123,6 @@ def read_column( # find the axes for a in self.axes: if column == a.name: - if not a.is_data_indexable: raise ValueError( f"column [{column}] can not be extracted individually; " @@ -4278,9 +4248,7 @@ def write_data(self, chunksize: Optional[int], dropna: bool = False): # if dropna==True, then drop ALL nan rows masks = [] if dropna: - for a in self.values_axes: - # figure the mask: only do if we can successfully process this # column, otherwise ignore the mask mask = isna(a.data).all(axis=0) @@ -4860,7 +4828,6 @@ def _unconvert_index( def _maybe_convert_for_string_atom( name: str, block, existing_col, min_itemsize, nan_rep, encoding, errors ): - if not block.is_object: return block.values @@ -4893,7 +4860,6 @@ def _maybe_convert_for_string_atom( # we cannot serialize this data, so report an exception on a column # by column basis for i in range(len(block.shape[0])): - col = block.iget(i) inferred_type = lib.infer_dtype(col, skipna=False) if inferred_type != "string": @@ -5018,7 +4984,7 @@ def _need_convert(kind: str) -> bool: return False -def _maybe_adjust_name(name: str, version) -> str: +def _maybe_adjust_name(name: str, version: Sequence[int]) -> str: """ Prior to 0.10.1, we named values blocks like: values_block_0 an the name values_0, adjust the given name if necessary. 
@@ -5032,14 +4998,14 @@ def _maybe_adjust_name(name: str, version) -> str: ------- str """ - try: - if version[0] == 0 and version[1] <= 10 and version[2] == 0: - m = re.search(r"values_block_(\d+)", name) - if m: - grp = m.groups()[0] - name = f"values_{grp}" - except IndexError: - pass + if isinstance(version, str) or len(version) < 3: + raise ValueError("Version is incorrect, expected sequence of 3 integers.") + + if version[0] == 0 and version[1] <= 10 and version[2] == 0: + m = re.search(r"values_block_(\d+)", name) + if m: + grp = m.groups()[0] + name = f"values_{grp}" return name @@ -5129,7 +5095,7 @@ def __init__( if is_list_like(where): # see if we have a passed coordinate like - try: + with suppress(ValueError): inferred = lib.infer_dtype(where, skipna=False) if inferred == "integer" or inferred == "boolean": where = np.asarray(where) @@ -5149,9 +5115,6 @@ def __init__( ) self.coordinates = where - except ValueError: - pass - if self.coordinates is None: self.terms = self.generate(where) @@ -5172,15 +5135,16 @@ def generate(self, where): # raise a nice message, suggesting that the user should use # data_columns qkeys = ",".join(q.keys()) - raise ValueError( - f"The passed where expression: {where}\n" - " contains an invalid variable reference\n" - " all of the variable references must be a " - "reference to\n" - " an axis (e.g. 'index' or 'columns'), or a " - "data_column\n" - f" The currently defined references are: {qkeys}\n" - ) from err + msg = dedent( + f"""\ + The passed where expression: {where} + contains an invalid variable reference + all of the variable references must be a reference to + an axis (e.g. 'index' or 'columns'), or a data_column + The currently defined references are: {qkeys} + """ + ) + raise ValueError(msg) from err def select(self): """ diff --git a/pandas/tests/io/pytables/test_store.py b/pandas/tests/io/pytables/test_store.py index c1938db12a0bc..1e1c9e91faa4b 100644 --- a/pandas/tests/io/pytables/test_store.py +++ b/pandas/tests/io/pytables/test_store.py @@ -49,6 +49,7 @@ HDFStore, PossibleDataLossError, Term, + _maybe_adjust_name, read_hdf, ) @@ -4921,3 +4922,10 @@ def test_unsuppored_hdf_file_error(self, datapath): with pytest.raises(ValueError, match=message): pd.read_hdf(data_path) + + +@pytest.mark.parametrize("bad_version", [(1, 2), (1,), [], "12", "123"]) +def test_maybe_adjust_name_bad_version_raises(bad_version): + msg = "Version is incorrect, expected sequence of 3 integers" + with pytest.raises(ValueError, match=msg): + _maybe_adjust_name("values_block_0", version=bad_version)
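
Note on the try/except cleanups throughout this patch: contextlib.suppress(E) is semantically equivalent to a try/except-E-pass block, so these rewrites are behavior-preserving. A minimal sketch of the equivalence (illustrative only, not part of the patch):

    from contextlib import suppress

    d = {}

    # before: swallow the error with an empty handler
    try:
        del d["missing"]
    except KeyError:
        pass

    # after: identical behavior, one level less nesting
    with suppress(KeyError):
        del d["missing"]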
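
The new guard in _maybe_adjust_name replaces the old silent IndexError handling with explicit validation. Expected behavior, sketched from the logic shown in the diff above (return values derived from the code, not verified against a build):

    _maybe_adjust_name("values_block_0", version=(0, 10, 0))  # -> "values_0"
    _maybe_adjust_name("values_block_0", version=(1, 1, 0))   # -> "values_block_0" (unchanged)
    _maybe_adjust_name("values_block_0", version="123")       # raises ValueError (str is rejected)
    _maybe_adjust_name("values_block_0", version=(1, 2))      # raises ValueError (fewer than 3 parts)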