diff --git a/versioneer.py b/versioneer.py index 64fea1c..2b54540 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,3 @@ - # Version: 0.18 """The Versioneer - like a rocketeer, but for versions. @@ -277,6 +276,7 @@ """ from __future__ import print_function + try: import configparser except ImportError: @@ -308,11 +308,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -325,8 +327,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py) + ) except NameError: pass return root @@ -348,6 +352,7 @@ def get(parser, name): if parser.has_option("versioneer", name): return parser.get("versioneer", name) return None + cfg = VersioneerConfig() cfg.VCS = VCS cfg.style = get(parser, "style") or "" @@ -372,17 +377,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -390,10 +396,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -418,7 +427,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, p.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY[ + "git" +] = ''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -993,7 +1004,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". 
If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1002,7 +1013,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1010,19 +1021,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1037,8 +1055,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1046,10 +1063,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1072,17 +1098,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1091,10 +1116,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1105,13 +1132,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -1167,16 +1194,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1205,11 +1238,13 @@ def versions_from_file(filename): contents = f.read() except EnvironmentError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1218,8 +1253,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1251,8 +1285,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % 
(pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1366,11 +1399,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1390,9 +1425,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1415,8 +1454,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1470,9 +1510,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1521,6 +1565,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools @@ -1553,14 +1598,15 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1581,17 +1627,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? 
try: from py2exe.distutils_buildexe import py2exe as _py2exe # py3 except ImportError: @@ -1610,13 +1660,17 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments @@ -1643,8 +1697,10 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -1699,11 +1755,13 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except ( + EnvironmentError, + configparser.NoSectionError, + configparser.NoOptionError, + ) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1712,15 +1770,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: @@ -1762,8 +1823,10 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print( + " appending versionfile_source ('%s') to MANIFEST.in" + % cfg.versionfile_source + ) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: diff --git a/xhistogram/core.py b/xhistogram/core.py index 3470506..8815ab4 100644 --- a/xhistogram/core.py +++ b/xhistogram/core.py @@ -7,18 +7,33 @@ import numpy as np from functools import reduce from collections.abc import Iterable -from .duck_array_ops import ( +from numpy import ( digitize, bincount, reshape, ravel_multi_index, concatenate, broadcast_arrays, + broadcast_to, ) # range is a keyword so save the builtin so they can use it. 
 _range = range
 
+try:
+    import dask.array as dsa
+
+    has_dask = True
+except ImportError:
+    has_dask = False
+
+
+def _any_dask_array(*args):
+    if not has_dask:
+        return False
+    else:
+        return any(isinstance(a, dsa.core.Array) for a in args)
+
 
 def _ensure_correctly_formatted_bins(bins, N_expected):
     # TODO: This could be done better / more robustly
@@ -120,8 +135,8 @@ def _dispatch_bincount(bin_indices, weights, N, hist_shapes, block_size=None):
     return _bincount_loop(bin_indices, weights, N, hist_shapes, block_chunks)
 
 
-def _histogram_2d_vectorized(
-    *args, bins=None, weights=None, density=False, right=False, block_size=None
+def _bincount_2d_vectorized(
+    *args, bins=None, weights=None, block_size=None
 ):
     """Calculate the histogram independently on each row of a 2D array"""
 
@@ -131,14 +146,17 @@ def _histogram_2d_vectorized(
     # consistency checks for inputs
     for a, b in zip(args, bins):
         assert a.ndim == 2
-        assert b.ndim == 1
+        # assert b.ndim == 1
         assert a.shape == a0.shape
     if weights is not None:
         assert weights.shape == a0.shape
 
     nrows, ncols = a0.shape
-    nbins = [len(b) for b in bins]
-    hist_shapes = [nb + 1 for nb in nbins]
+
+    # TODO assuming all bins have the same form here
+    b0 = bins[0]
 
     # a marginally faster implementation would be to use searchsorted,
     # like numpy histogram itself does
@@ -150,14 +168,38 @@ def _histogram_2d_vectorized(
     # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/calibration.py#L592
     # but a better approach would be to use something like _search_sorted_inclusive() in
     # numpy histogram. This is an additional motivation for moving to searchsorted
-    bins = [np.concatenate((b[:-1], b[-1:] + 1e-8)) for b in bins]
+    # TODO wouldn't need ifs if we just promoted all bins to 2D
+
+    if b0.ndim == 1:
+        bins = [np.concatenate((b[:-1], b[-1:] + 1e-8)) for b in bins]
+    elif b0.ndim == 2:
+        bins = [np.concatenate((b[:, :-1], b[:, -1:] + 1e-8), axis=1) for b in bins]
 
     # the maximum possible value of digitize is nbins
     # for right=False:
     #   - 0 corresponds to a < b[0]
     #   - i corresponds to bins[i-1] <= a < b[i]
     #   - nbins corresponds to a >= b[-1]
-    each_bin_indices = [digitize(a, b) for a, b in zip(args, bins)]
+
+    if b0.ndim == 1:
+        nbins = [len(b) for b in bins]
+        hist_shapes = [nb + 1 for nb in nbins]
+
+        each_bin_indices = [digitize(a, b) for a, b in zip(args, bins)]
+    elif b0.ndim == 2:
+        nbins = [b.shape[1] for b in bins]
+        hist_shapes = [nb + 1 for nb in nbins]
+
+        # Apply digitize separately to each row with different bins
+        # TODO check that the resulting indices are correct
+        each_bin_indices = []
+        for a, b in zip(args, bins):
+            each_bin_indices_single_var = np.stack(
+                [digitize(a[row, :], b[row, :]) for row in np.arange(a.shape[0])]
+            )
+            each_bin_indices.append(each_bin_indices_single_var)
+
     # product of the bins gives the joint distribution
     if N_inputs > 1:
         bin_indices = ravel_multi_index(each_bin_indices, hist_shapes)
@@ -178,6 +220,72 @@
     return bin_counts
 
 
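[Editor's note — illustration only, not part of the diff. The new `b0.ndim == 2` branch in `_bincount_2d_vectorized` applies `digitize` row by row so that each row of data gets its own bin edges. A minimal numpy sketch of that behaviour, with made-up values:]

```python
import numpy as np

# two rows of data, each with its own edges (the b0.ndim == 2 case)
a = np.array([[0.1, 0.5, 0.9],
              [0.2, 0.6, 1.0]])
bins = np.array([[0.0, 0.4, 0.8, 1.2],   # edges for row 0
                 [0.0, 0.5, 1.0, 1.5]])  # edges for row 1

# digitize each row against its own edges, then stack the index arrays
per_row = np.stack([np.digitize(a[i], bins[i]) for i in range(a.shape[0])])
print(per_row)  # [[1 2 3]
                #  [1 2 3]]
```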
+def _bincount(*all_arrays, weights=False, axis=None, density=False):
+
+    all_arrays = list(all_arrays)
+
+    weights_array = all_arrays.pop()
+
+    # TODO a more robust way to pass the bins together with the arrays
+    n_args = len(all_arrays) // 2
+    arrays = all_arrays[:n_args]
+    bins = all_arrays[n_args:]
+
+    # broadcasting is necessary for the weights to match the data
+    all_arrays_broadcast = broadcast_arrays(*arrays)
+
+    a0 = all_arrays_broadcast[0]
+
+    do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim)))
+
+    if do_full_array:
+        kept_axes_shape = (1,) * a0.ndim
+    else:
+        kept_axes_shape = tuple(
+            [a0.shape[i] if i not in axis else 1 for i in _range(a0.ndim)]
+        )
+
+    def reshape_input(a):
+        if do_full_array:
+            d = a.ravel()[None, :]
+        else:
+            # reshape the array to 2D
+            # axis 0: preserved axis after histogram
+            # axis 1: calculate histogram along this axis
+            new_pos = tuple(_range(-len(axis), 0))
+            c = np.moveaxis(a, axis, new_pos)
+            split_idx = c.ndim - len(axis)
+            dims_0 = c.shape[:split_idx]
+            # assert dims_0 == kept_axes_shape
+            dims_1 = c.shape[split_idx:]
+            new_dim_0 = np.prod(dims_0)
+            new_dim_1 = np.prod(dims_1)
+            # TODO integer vs float logic here is not robust
+            d = reshape(c, (new_dim_0, new_dim_1))
+        return d
+
+    all_arrays_reshaped = [reshape_input(a) for a in all_arrays_broadcast]
+
+    if any(b.ndim > 1 for b in bins):
+        bins_reshaped = [reshape_input(b) for b in bins]
+    else:
+        bins_reshaped = bins
+
+    if weights:
+        weights_broadcast = broadcast_to(weights_array, a0.shape)
+        weights_reshaped = reshape_input(weights_broadcast)
+    else:
+        weights_reshaped = None
+
+    bin_counts = _bincount_2d_vectorized(
+        *all_arrays_reshaped, bins=bins_reshaped, weights=weights_reshaped
+    )
+
+    final_shape = kept_axes_shape + bin_counts.shape[1:]
+    bin_counts = reshape(bin_counts, final_shape)
+
+    return bin_counts
+
+
 def histogram(
     *args,
     bins=None,
@@ -280,69 +388,110 @@ def histogram(
         ax_positive = ndim + ax
         assert ax_positive < ndim, "axis must be less than ndim"
         axis_normed.append(ax_positive)
-    axis = np.atleast_1d(axis_normed)
+    axis = [int(i) for i in axis_normed]
 
-    do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim)))
-    if do_full_array:
-        kept_axes_shape = None
-    else:
-        kept_axes_shape = tuple([a0.shape[i] for i in _range(a0.ndim) if i not in axis])
+    all_arrays = list(args)
+    n_inputs = len(all_arrays)
 
-    all_args = list(args)
+    # TODO make feeding in weights less janky
     if weights is not None:
-        all_args += [weights]
-    all_args_broadcast = broadcast_arrays(*all_args)
-
-    def reshape_input(a):
-        if do_full_array:
-            d = a.ravel()[None, :]
-        else:
-            # reshape the array to 2D
-            # axis 0: preserved axis after histogram
-            # axis 1: calculate histogram along this axis
-            new_pos = tuple(_range(-len(axis), 0))
-            c = np.moveaxis(a, axis, new_pos)
-            split_idx = c.ndim - len(axis)
-            dims_0 = c.shape[:split_idx]
-            assert dims_0 == kept_axes_shape
-            dims_1 = c.shape[split_idx:]
-            new_dim_0 = np.prod(dims_0)
-            new_dim_1 = np.prod(dims_1)
-            d = reshape(c, (new_dim_0, new_dim_1))
-        return d
+        has_weights = True
+    else:
+        has_weights = False
 
-    all_args_reshaped = [reshape_input(a) for a in all_args_broadcast]
+    dtype = "i8" if not has_weights else weights.dtype
 
-    if weights is not None:
-        weights_reshaped = all_args_reshaped.pop()
-    else:
-        weights_reshaped = None
+    # here I am assuming all the arrays have the same shape
+    # probably needs to be generalized
+    input_indexes = [tuple(_range(a.ndim)) for a in all_arrays]
+    input_index = input_indexes[0]
+    assert all([ii == input_index for ii in input_indexes])
 
     # Some sanity checks and format bins and range correctly
-    bins = _ensure_correctly_formatted_bins(bins, n_inputs)
+    formatted_bins = _ensure_correctly_formatted_bins(bins, n_inputs)
     range = _ensure_correctly_formatted_range(range, n_inputs)
 
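[Editor's note — illustration only, not part of the diff. `reshape_input` in `_bincount` above moves the reduced axes to the back and flattens them into a single trailing axis, keeping the other axes in front. A small numpy sketch with arbitrary shapes:]

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
axis = [0, 2]                        # axes the histogram reduces over

c = np.moveaxis(a, axis, (-2, -1))   # kept axis first -> shape (3, 2, 4)
d = c.reshape(c.shape[0], -1)        # flatten reduced axes -> shape (3, 8)
assert d.shape == (3, 8)
```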
     # histogram_bin_edges triggers computation on dask arrays. It would be possible
     # to write a version of this that doesn't trigger when `range` is provided, but
     # for now let's just use np.histogram_bin_edges
     if is_dask_array:
-        if not all([isinstance(b, np.ndarray) for b in bins]):
+        if not all([isinstance(b, np.ndarray) for b in formatted_bins]):
             raise TypeError(
                 "When using dask arrays, bins must be provided as numpy array(s) of edges"
             )
+        bins = formatted_bins
     else:
-        bins = [
-            np.histogram_bin_edges(a, b, r, weights_reshaped)
-            for a, b, r in zip(all_args_reshaped, bins, range)
-        ]
-
-    bin_counts = _histogram_2d_vectorized(
-        *all_args_reshaped,
-        bins=bins,
-        weights=weights_reshaped,
-        density=density,
-        block_size=block_size,
-    )
+        bins = []
+        for a, b, r in zip(all_arrays, formatted_bins, range):
+            if isinstance(b, np.ndarray):
+                # account for the possibility that bins is a >1d numpy array
+                pass
+            else:
+                b = np.histogram_bin_edges(a, b, r)
+            bins.append(b)
 
+    bincount_kwargs = dict(weights=has_weights, axis=axis, density=density)
+
+    # broadcast bins to match the data with the reduced axes dropped
+    # TODO assumes all bins arguments have the same shape
+    b0 = bins[0]
+    if axis is not None:
+        kept_shape = [a0.shape[i] for i in _range(a0.ndim) if i not in axis]
+    else:
+        kept_shape = []
+    broadcast_bins_shape = kept_shape + list(b0.shape)
+    bins = [np.broadcast_to(b, broadcast_bins_shape) for b in bins]
+
+    # these axes will be reduced away (they stay in the blockwise inputs as singletons)
+    if axis is not None:
+        drop_axes = tuple([ii for ii in input_index if ii in axis])
+    else:
+        drop_axes = input_index
+
+    if _any_dask_array(weights, *all_arrays):
+        # We should be able to just apply the _bincount function to every
+        # block and then sum over all blocks to get the total bin count.
+        # The main challenge is to figure out the chunk shape that will come
+        # out of _bincount. We might also need to add dummy dimensions to sum
+        # over in the _bincount function
+        import dask.array as dsa
+
+        # Important note from blockwise docs
+        # > Any index, like i missing from the output index is interpreted as a contraction...
+        # > In the case of a contraction the passed function should expect an iterable of blocks
+        # > on any array that holds that index.
+        # This means that we need to have all the input indexes present in the output index
+        # However, they will be reduced to singleton (len 1) dimensions
+
+        adjust_chunks = {i: (lambda x: 1) for i in drop_axes}
+
+        new_axes = {
+            max(input_index) + 1 + i: axis_len
+            for i, axis_len in enumerate([len(bin) - 1 for bin in bins])
+        }
+        out_index = input_index + tuple(new_axes)
+
+        blockwise_args = []
+        for arg in all_arrays:
+            blockwise_args.append(arg)
+            blockwise_args.append(input_index)
+
+        # Bins arrays do not contain the axes which will be reduced along
+        # TODO incorrect for >1D bins - how do we know what broadcast axes are on bins here? 
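[Editor's note — illustration only, not part of the diff. The blockwise pattern set up above (keep every input index in the output so no contraction happens, collapse the reduced axes to singleton chunks, then sum them explicitly) can be seen in isolation with a toy per-block reduction; this is a hedged sketch with a dummy lambda, not `_bincount` itself:]

```python
import dask.array as dsa

x = dsa.ones((4, 6), chunks=(2, 3))

# no index is dropped from the output, so blockwise performs no contraction;
# each block collapses to one element, and the block dims are summed afterwards
block_sums = dsa.blockwise(
    lambda a: a.sum(keepdims=True),
    "ij",
    x,
    "ij",
    adjust_chunks={"i": lambda c: 1, "j": lambda c: 1},
    dtype=x.dtype,
)
assert block_sums.shape == (2, 2)          # one element per block
assert block_sums.sum().compute() == 24.0  # matches x.sum()
```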
+ bins_input_index = tuple(new_axes.keys()) + for b in bins: + blockwise_args.append(b) + blockwise_args.append(bins_input_index) + + bin_counts = dsa.blockwise( + _bincount, + out_index, + *blockwise_args, + new_axes=new_axes, + adjust_chunks=adjust_chunks, + meta=np.array((), dtype), + **bincount_kwargs, + ) + # sum over the block dims + bin_counts = bin_counts.sum(drop_axes) + else: + # drop the extra axis used for summing over blocks + bin_counts = _bincount(*(all_arrays + bins + [weights]), **bincount_kwargs).squeeze(drop_axes) if density: # Normalise by dividing by bin counts and areas such that all the @@ -360,11 +509,4 @@ def reshape_input(a): else: h = bin_counts - if h.shape[0] == 1: - assert do_full_array - h = h.squeeze() - else: - final_shape = kept_axes_shape + h.shape[1:] - h = reshape(h, final_shape) - return h, bins diff --git a/xhistogram/duck_array_ops.py b/xhistogram/duck_array_ops.py deleted file mode 100644 index b9e632e..0000000 --- a/xhistogram/duck_array_ops.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Compatibility module defining operations on duck numpy-arrays. -Shamelessly copied from xarray.""" - -import numpy as np - -try: - import dask.array as dsa - - has_dask = True -except ImportError: - has_dask = False - - -def _dask_or_eager_func(name, eager_module=np, list_of_args=False, n_array_args=1): - """Create a function that dispatches to dask for dask array inputs.""" - if has_dask: - - def f(*args, **kwargs): - dispatch_args = args[0] if list_of_args else args - if any(isinstance(a, dsa.Array) for a in dispatch_args[:n_array_args]): - module = dsa - else: - module = eager_module - return getattr(module, name)(*args, **kwargs) - - else: - - def f(*args, **kwargs): - return getattr(eager_module, name)(*args, **kwargs) - - return f - - -digitize = _dask_or_eager_func("digitize") -bincount = _dask_or_eager_func("bincount") -reshape = _dask_or_eager_func("reshape") -concatenate = _dask_or_eager_func("concatenate", list_of_args=True) -broadcast_arrays = _dask_or_eager_func("broadcast_arrays") -ravel_multi_index = _dask_or_eager_func("ravel_multi_index") diff --git a/xhistogram/test/fixtures.py b/xhistogram/test/fixtures.py index 330d07b..d6d6ee8 100644 --- a/xhistogram/test/fixtures.py +++ b/xhistogram/test/fixtures.py @@ -1,5 +1,7 @@ import dask import dask.array as dsa +import numpy as np +import xarray as xr def empty_dask_array(shape, dtype=float, chunks=None): @@ -12,3 +14,9 @@ def raise_if_computed(): a = a.rechunk(chunks) return a + +def example_dataarray(shape=(5, 20)): + data = np.random.randn(*shape) + dims = [f"dim_{i}" for i in range(len(shape))] + da = xr.DataArray(data, dims=dims, name="T") + return da diff --git a/xhistogram/test/test_core.py b/xhistogram/test/test_core.py index e0df0ee..eb58d97 100644 --- a/xhistogram/test/test_core.py +++ b/xhistogram/test/test_core.py @@ -182,6 +182,7 @@ def test_histogram_results_3d_density(): np.testing.assert_allclose(integral, 1.0) +# TODO parametrize this over axes so there is only one assert per test @pytest.mark.parametrize("block_size", [None, 5, "auto"]) @pytest.mark.parametrize("use_dask", [False, True]) def test_histogram_shape(use_dask, block_size): diff --git a/xhistogram/test/test_duck_array_ops.py b/xhistogram/test/test_duck_array_ops.py deleted file mode 100644 index db3cbd6..0000000 --- a/xhistogram/test/test_duck_array_ops.py +++ /dev/null @@ -1,80 +0,0 @@ -import numpy as np -import dask.array as dsa -from ..duck_array_ops import ( - digitize, - bincount, - reshape, - ravel_multi_index, - 
broadcast_arrays, -) -from .fixtures import empty_dask_array -import pytest - - -@pytest.mark.parametrize( - "function, args", - [ - (digitize, [np.random.rand(5, 12), np.linspace(0, 1, 7)]), - (bincount, [np.arange(10)]), - ], -) -def test_eager(function, args): - a = function(*args) - assert isinstance(a, np.ndarray) - - -@pytest.mark.parametrize( - "function, args, kwargs", - [ - (digitize, [empty_dask_array((5, 12)), np.linspace(0, 1, 7)], {}), - (bincount, [empty_dask_array((10,))], {"minlength": 5}), - (reshape, [empty_dask_array((10, 5)), (5, 10)], {}), - (ravel_multi_index, (empty_dask_array((10,)), empty_dask_array((10,))), {}), - ], -) -def test_lazy(function, args, kwargs): - # make sure nothing computes - a = function(*args, **kwargs) - assert isinstance(a, dsa.core.Array) - - -@pytest.mark.parametrize("chunks", [(5, 12), (1, 12), (5, 1)]) -def test_digitize_dask_correct(chunks): - a = np.random.rand(5, 12) - da = dsa.from_array(a, chunks=chunks) - bins = np.linspace(0, 1, 7) - d = digitize(a, bins) - dd = digitize(da, bins) - np.testing.assert_array_equal(d, dd.compute()) - - -def test_ravel_multi_index_correct(): - arr = np.array([[3, 6, 6], [4, 5, 1]]) - expected = np.ravel_multi_index(arr, (7, 6)) - actual = ravel_multi_index(arr, (7, 6)) - np.testing.assert_array_equal(expected, actual) - - expected = np.ravel_multi_index(arr, (7, 6), order="F") - actual = ravel_multi_index(arr, (7, 6), order="F") - np.testing.assert_array_equal(expected, actual) - - -def test_broadcast_arrays_numpy(): - a1 = np.empty((1, 5, 25)) - a2 = np.empty((4, 1, 1)) - - a1b, a2b = broadcast_arrays(a1, a2) - assert a1b.shape == (4, 5, 25) - assert a2b.shape == (4, 5, 25) - - -@pytest.mark.parametrize("d1_chunks", [(5 * (1,), (25,)), ((2, 3), (25,))]) -def test_broadcast_arrays_dask(d1_chunks): - d1 = dsa.empty((5, 25), chunks=d1_chunks) - d2 = dsa.empty((1, 25), chunks=(1, 25)) - - d1b, d2b = broadcast_arrays(d1, d2) - assert d1b.shape == (5, 25) - assert d2b.shape == (5, 25) - assert d1b.chunks == d1_chunks - assert d2b.chunks == d1_chunks diff --git a/xhistogram/test/test_xarray.py b/xhistogram/test/test_xarray.py index 7170537..dace559 100644 --- a/xhistogram/test/test_xarray.py +++ b/xhistogram/test/test_xarray.py @@ -4,6 +4,7 @@ import pandas as pd from itertools import combinations +from .fixtures import example_dataarray from ..xarray import histogram @@ -206,3 +207,55 @@ def test_input_type_check(): np_array = np.arange(100) with pytest.raises(TypeError): histogram(np_array) + + +class TestMultiDimensionalBins: + def test_bin_dataarrays_with_extra_dims(self): + data = xr.DataArray([0], dims=["x"], name="a") + bins = xr.DataArray([[1]], dims=["bad", "a_bin"]) + with pytest.raises(ValueError, match="not present in data"): + histogram(data, dim="x", bins=[bins]) + + def test_bin_dataarrays_without_reduce_dim(self): + data = xr.DataArray([0], dims=["x"], name="a") + bins = xr.DataArray(1) + with pytest.raises(ValueError, match="does not contain"): + histogram(data, dim="x", bins=[bins]) + + # TODO parametrize over ndims? 
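[Editor's note — hypothetical sketch, not part of the diff. Apropos the TODO above, one possible shape for the parametrization, reusing this module's existing `pytest`, `np`, `xr`, `histogram` and `example_dataarray` imports; names here are made up:]

```python
    @pytest.mark.parametrize("bins_ndim", [1, 2])
    def test_bins_da(self, bins_ndim):
        data = example_dataarray(shape=(2, 12))  # name "T", dims dim_0/dim_1
        edges = np.linspace(-4, 4, 8)
        if bins_ndim == 1:
            bins = xr.DataArray(edges, dims="T_bin")
            h = histogram(data, bins=[bins])
            expected, _ = np.histogram(data.values, bins=edges)
        else:
            # edges that differ along the kept dim_0
            bins = xr.DataArray(np.stack([edges, edges + 0.5]), dims=["dim_0", "T_bin"])
            h = histogram(data, dim="dim_1", bins=[bins])
            expected = np.stack(
                [np.histogram(data.values[i], bins=bins.values[i])[0] for i in range(2)]
            )
        np.testing.assert_allclose(expected, h.values)
```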
+    def test_1d_bins_da(self):
+        data_a = example_dataarray(shape=(10, 12))
+        nbins_a = 8
+        bins_a = xr.DataArray(
+            np.linspace(-4, 4, nbins_a + 1), dims=f"{data_a.name}_bin"
+        )
+
+        h = histogram(data_a, bins=[bins_a])
+
+        assert h.shape == (nbins_a,)
+
+        hist, _ = np.histogram(data_a.values, bins=bins_a)
+        np.testing.assert_allclose(hist, h.values)
+
+    def test_2d_bins_da(self):
+        data_a = example_dataarray(shape=(10, 2))
+        nbins_a = 7
+        bins_a = xr.DataArray(
+            [np.linspace(-4, 3, nbins_a + 1), np.linspace(-3, 4, nbins_a + 1)],
+            dims=["dim_1", f"{data_a.name}_bin"],
+        )
+
+        h = histogram(data_a, dim="dim_0", bins=[bins_a])
+
+        assert h.shape == (2, nbins_a)
+
+        def _np_hist(*args, **kwargs):
+            h, _ = np.histogram(*args, **kwargs)
+            return h
+
+        # reducing along dim_0 leaves one histogram per dim_1 position
+        hist = np.stack(
+            [
+                _np_hist(data_a.values[:, 0], bins=bins_a.values[0, :]),
+                _np_hist(data_a.values[:, 1], bins=bins_a.values[1, :]),
+            ]
+        )
+        np.testing.assert_allclose(hist, h.values)
diff --git a/xhistogram/xarray.py b/xhistogram/xarray.py
index a587abc..1a28e85 100644
--- a/xhistogram/xarray.py
+++ b/xhistogram/xarray.py
@@ -29,7 +29,7 @@ def histogram(
         Input data. The number of input arguments determines the dimensionality
         of the histogram. For example, two arguments produce a 2D histogram.
         All args must be aligned and have the same dimensions.
-    bins : int, str or numpy array or a list of ints, strs and/or arrays, optional
+    bins : int, str, numpy array or DataArray, or a list of ints, strs, arrays and/or DataArrays, optional
         If a list, there should be one entry for each item in ``args``.
         The bin specifications are as follows:
 
@@ -37,6 +37,10 @@ def histogram(
         * If str; the method used to automatically calculate the optimal bin
           width for all arguments in ``args``, as defined by numpy `histogram_bin_edges`.
         * If numpy array; the bin edges for all arguments in ``args``.
+        * If xarray DataArray; the bin edges for all arguments in ``args``.
+          The DataArray can be multidimensional, but must contain the output
+          bins dimension (i.e. `[var]_bin`) and must not have any dimensions
+          present in the `dim` argument.
         * If a list of ints, strs and/or arrays; the bin specification as
           above for every argument in ``args``.
@@ -153,6 +157,9 @@ def histogram(
     else:
         weights_data = None
 
+    if isinstance(dim, str):
+        dim = (dim,)
+
     if dim is not None:
         dims_to_keep = [d for d in all_dims_ordered if d not in dim]
         axis = [args_transposed[0].get_axis_num(d) for d in dim]
@@ -160,33 +167,60 @@ def histogram(
         dims_to_keep = []
         axis = None
 
+    # create output dims
+    new_dims = [a.name + bin_dim_suffix for a in args[:N_args]]
+    output_dims = dims_to_keep + new_dims
+
+    # Create bin coordinates
+    bin_coords = []
+    for bin, new_dim in zip(bins, new_dims):
+        if isinstance(bin, xr.DataArray):
+            # align bins that are already DataArrays, dropping the dimensions
+            # that will be reduced along before aligning
+            if dim is None:
+                output_shape = a0
+            else:
+                output_shape = a0.isel(**{d: 0 for d in dim}, drop=True)
+            aligned_bin_coord, _ = xr.align(bin, output_shape, join="exact")
+
+            # check that the correct dimensions exist
+            if new_dim not in aligned_bin_coord.dims:
+                raise ValueError(
+                    f"bins DataArray does not contain dimension {new_dim}"
+                )
+            if any(d not in output_dims for d in aligned_bin_coord.dims):
+                raise ValueError(
+                    "bins DataArray has dimensions not present in data"
+                )
+
+            # transpose so that the [var]_bin dim is last (mirroring the reduce
+            # dims on the data) and rename so the coordinate carries the output
+            # bin dim name
+            bin_coord = aligned_bin_coord.rename(new_dim).transpose(new_dim, ...)
+ else: + bin_coord = xr.DataArray(bin, name=new_dim, dims=(new_dim,), attrs=a0.attrs) + bin_coords.append(bin_coord) + h_data, bins = _histogram( *args_data, weights=weights_data, - bins=bins, + bins=[b.values for b in bin_coords], range=range, axis=axis, density=density, block_size=block_size, ) - # create output dims - new_dims = [a.name + bin_dim_suffix for a in args[:N_args]] - output_dims = dims_to_keep + new_dims + # Adjust bin coords to return positions of bin centres rather than bin edges + def _find_centers(da, dim): + return 0.5 * (da.isel(**{dim: slice(None, -1, None)}) + + da.isel(**{dim: slice(1, None, None)})) - # create new coords - bin_centers = [0.5 * (bin[:-1] + bin[1:]) for bin in bins] - new_coords = { - name: ((name,), bin_center, a.attrs) - for name, bin_center, a in zip(new_dims, bin_centers, args) - } + bin_centers = [ + _find_centers(bin, new_bin_dim) for bin, new_bin_dim in zip(bin_coords, new_dims) + ] # old coords associated with dims old_dim_coords = {name: a0[name] for name in dims_to_keep if name in a_coords} all_coords = {} all_coords.update(old_dim_coords) - all_coords.update(new_coords) + all_coords.update({b.name: b for b in bin_centers}) # add compatible coords if keep_coords: for c in a_coords:
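[Editor's note — usage sketch, not part of the diff. End to end, the feature this PR adds would be used roughly as below, assuming the branch behaves as its new tests expect; `temp` and `edges` are made-up names:]

```python
import numpy as np
import xarray as xr
from xhistogram.xarray import histogram

# temperature on (time, x), with bin edges that differ at each x position
temp = xr.DataArray(np.random.randn(100, 3), dims=["time", "x"], name="temp")
edges = xr.DataArray(
    np.stack([np.linspace(-4, 4, 11) + shift for shift in (0.0, 0.5, 1.0)]),
    dims=["x", "temp_bin"],
)

# one 10-bin histogram per x point, each computed with its own edges
h = histogram(temp, dim=["time"], bins=[edges])
print(h.dims, h.shape)  # ('x', 'temp_bin') (3, 10)
```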