diff --git a/packages/vaex-core/vaex/dataframe.py b/packages/vaex-core/vaex/dataframe.py index 1a05a8aaa7..7fe851c2f4 100644 --- a/packages/vaex-core/vaex/dataframe.py +++ b/packages/vaex-core/vaex/dataframe.py @@ -332,6 +332,16 @@ def is_datetime(self, expression): def is_string(self, expression): return vaex.array_types.is_string_type(self.data_type(expression)) + def is_image(self, expression): + try: + import PIL + except ModuleNotFoundError: + raise RuntimeError("Please install pillow for image support") + if self.data_type(expression) != object: + return False + value = self.dropna(column_names=[expression]).head(1)[expression].values[0] + return hasattr(value, '_repr_png_') + def is_category(self, column): """Returns true if column is a category.""" column = _ensure_string_from_expression(column) @@ -3988,7 +3998,7 @@ def table_part(k1, k2, parts): if columns_sliced is not None and j >= columns_sliced: column_index += 1 # skip over the slice/ellipsis value = values[name][i] - value = _format_value(value) + value = _format_value(value, value_format=format) values_list[column_index+1][1].append(value) # parts += [""] # return values_list @@ -4011,7 +4021,10 @@ def table_part(k1, k2, parts): values_list = dict(values_list) # print(values_list) import tabulate - table_text = str(tabulate.tabulate(values_list, headers="keys", tablefmt=format)) + tablefmt = format + if tablefmt == "html": + tablefmt = "unsafehtml" + table_text = str(tabulate.tabulate(values_list, headers="keys", tablefmt=tablefmt)) # Tabulate 0.8.7+ escapes html :() table_text = table_text.replace('<i style='opacity: 0.6'>', "") table_text = table_text.replace('</i>', "") @@ -4052,7 +4065,7 @@ def table_part(k1, k2, parts): parts += ["{:,}".format(i + k1)] for name in column_names: value = data_parts[name][i] - value = _format_value(value) + value = _format_value(value, value_format=format) parts += ["%r" % value] parts += [""] return parts @@ -4084,7 +4097,7 @@ def _output_css(self): def _repr_mimebundle_(self, include=None, exclude=None, **kwargs): # TODO: optimize, since we use the same data in both versions # TODO: include latex version - return {'text/html':self._head_and_tail_table(format='html'), 'text/plain': self._head_and_tail_table(format='plain')} + return {'html': self._head_and_tail_table(format='html'), 'text/plain': self._head_and_tail_table(format='plain')} def _repr_html_(self): """Representation for Jupyter.""" diff --git a/packages/vaex-core/vaex/expression.py b/packages/vaex-core/vaex/expression.py index a970d89236..173951f194 100644 --- a/packages/vaex-core/vaex/expression.py +++ b/packages/vaex-core/vaex/expression.py @@ -24,7 +24,6 @@ import vaex.serialize from . 
import expresso - try: from StringIO import StringIO except ImportError: @@ -35,7 +34,6 @@ except AttributeError: collectionsAbc = collections - # TODO: repeated from dataframe.py default_shape = 128 PRINT_MAX_COUNT = 10 @@ -43,11 +41,9 @@ expression_namespace = {} expression_namespace['nan'] = np.nan - expression_namespace = {} expression_namespace['nan'] = np.nan - _binary_ops = [ dict(code="+", name='add', op=operator.add), dict(code="in", name='contains', op=operator.contains), @@ -100,7 +96,8 @@ def f(a, b): # print(op, a, b) if isinstance(b, str) and self.dtype.is_datetime: b = np.datetime64(b) - if self.df.is_category(self.expression) and self.df._future_behaviour and not isinstance(b, Expression): + if self.df.is_category(self.expression) and self.df._future_behaviour and not isinstance(b, + Expression): labels = self.df.category_labels(self.expression) if b not in labels: raise ValueError(f'Value {b} not present in {labels}') @@ -137,6 +134,7 @@ def f(a, b): b = f'scalar_datetime("{b}")' expression = '({0} {1} {2})'.format(a.expression, op['code'], b) return Expression(self.ds, expression=expression) + attrs['__%s__' % op['name']] = f if op['name'] in reversable: def f(a, b): @@ -153,6 +151,7 @@ def f(a, b): b = b.expression expression = '({2} {1} {0})'.format(a.expression, op['code'], b) return Expression(self.ds, expression=expression) + attrs['__r%s__' % op['name']] = f wrap(op) @@ -162,7 +161,9 @@ def f(a): self = a expression = '{0}({1})'.format(op['code'], a.expression) return Expression(self.ds, expression=expression) + attrs['__%s__' % op['name']] = f + wrap(op) return type(future_class_name, future_class_parents, attrs) @@ -172,6 +173,7 @@ class DateTime(object): Usually accessed using e.g. `df.birthday.dt.dayofweek` """ + def __init__(self, expression): self.expression = expression @@ -181,6 +183,17 @@ class TimeDelta(object): Usually accessed using e.g. `df.delay.td.days` """ + + def __init__(self, expression): + self.expression = expression + + +class Image(object): + """Image operations + + Operations for images based on PIL/Pillow + """ + def __init__(self, expression): self.expression = expression @@ -190,12 +203,14 @@ class StringOperations(object): Usually accessed using e.g. 
`df.name.str.lower()` """ + def __init__(self, expression): self.expression = expression class StringOperationsPandas(object): """String operations using Pandas Series (much slower)""" + def __init__(self, expression): self.expression = expression @@ -366,6 +381,7 @@ def _assert_struct_dtype(self): class Expression(with_metaclass(Meta)): """Expression class""" + def __init__(self, ds, expression, ast=None, _selection=False): self.ds = ds assert not isinstance(ds, Expression) @@ -471,7 +487,6 @@ def __bool__(self): return expresso.node_to_string(self.ast.left) != expresso.node_to_string(self.ast.comparators[0]) return True - @property def df(self): # lets gradually move to using .df @@ -513,12 +528,14 @@ def to_dask_array(self, chunks="auto"): dtype = self.dtype chunks = da.core.normalize_chunks(chunks, shape=self.shape, dtype=dtype.numpy) name = 'vaex-expression-%s' % str(uuid.uuid1()) + def getitem(df, item): assert len(item) == 1 item = item[0] start, stop, step = item.start, item.stop, item.step assert step in [None, 1] return self.evaluate(start, stop, parallel=False) + dsk = da.core.getem(name, chunks, getitem=getitem, shape=self.shape, dtype=dtype.numpy) dsk[name] = self return da.Array(dsk, name, chunks, dtype=dtype.numpy) @@ -591,7 +608,7 @@ def __getitem__(self, slicer): indices, fields = slicer else: raise NotImplementedError - + if indices != slice(None): expr = self.df[indices][self.expression] else: @@ -618,6 +635,11 @@ def __abs__(self): """Returns the absolute value of the expression""" return self.abs() + @property + def vision(self): + """Gives access to image operations via :py:class:`Image`""" + return Image(self) + @property def dt(self): """Gives access to datetime operations via :py:class:`DateTime`""" @@ -663,9 +685,11 @@ def expand(self, stop=[]): """ stop = _ensure_strings_from_expressions(stop) + def translate(id): if id in self.ds.virtual_columns and id not in stop: return self.ds.virtual_columns[id] + expr = expresso.translate(self.ast, translate) return Expression(self.ds, expr) @@ -684,6 +708,7 @@ def variables(self, ourself=False, expand_virtual=True, include_virtual=True): {'x', 'y'} """ variables = set() + def record(varname): # always do this for selection if self._selection and self.df.has_selection(varname): @@ -694,7 +719,8 @@ def record(varname): if (include_virtual and (varname != self.expression)) or (varname == self.expression and ourself): variables.add(varname) if expand_virtual: - variables.update(self.df[self.df.virtual_columns[varname]].variables(ourself=include_virtual, include_virtual=include_virtual)) + variables.update(self.df[self.df.virtual_columns[varname]].variables(ourself=include_virtual, + include_virtual=include_virtual)) # we usually don't want to record ourself elif varname != self.expression or ourself: variables.add(varname) @@ -705,11 +731,13 @@ def record(varname): variables -= {'df'} for varname in self._ast_slices: if varname in self.df.virtual_columns and varname not in variables: - if (include_virtual and (f"df['{varname}']" != self.expression)) or (f"df['{varname}']" == self.expression and ourself): + if (include_virtual and (f"df['{varname}']" != self.expression)) or ( + f"df['{varname}']" == self.expression and ourself): variables.add(varname) if expand_virtual: if varname in self.df.virtual_columns: - variables |= self.df[self.df.virtual_columns[varname]].variables(ourself=include_virtual, include_virtual=include_virtual) + variables |= self.df[self.df.virtual_columns[varname]].variables(ourself=include_virtual, + 
include_virtual=include_virtual) elif f"df['{varname}']" != self.expression or ourself: variables.add(varname) @@ -741,6 +769,7 @@ def walk(node): if isinstance(obj, FunctionSerializablePickle): obj = obj.f return [node_repr, fname, obj, deps] + return walk(expresso._graph(expression)) def _graphviz(self, dot=None): @@ -748,6 +777,7 @@ def _graphviz(self, dot=None): from graphviz import Graph, Digraph node = self._graph() dot = dot or Digraph(comment=self.expression) + def walk(node): if isinstance(node, six.string_types): dot.node(node, node) @@ -760,6 +790,7 @@ def walk(node): dep_id, dep = walk(dep) dot.edge(node_id, dep_id) return node_id, node + walk(node) return dot @@ -786,30 +817,43 @@ def tolist(self, i1=None, i2=None): def __repr__(self): return self._repr_plain_() - def _repr_plain_(self): + def _repr_mimebundle_(self, include=None, exclude=None, **kwargs): + # TODO: optimize, since we use the same data in both versions + # TODO: include latex version + return {"html": self._repr_html_(), 'text/plain': self._repr_plain_()} + + + def _repr_values(self, value_format='plain'): from .formatting import _format_value - def format(values): + + def format_values(values): for i in range(len(values)): value = values[i] - yield _format_value(value) + yield _format_value(value, value_format=value_format) + colalign = ("right",) * 2 try: N = len(self.ds) if N <= PRINT_MAX_COUNT: - values = format(self.evaluate(0, N)) - values = tabulate.tabulate([[i, k] for i, k in enumerate(values)], tablefmt='plain', colalign=colalign) + values = format_values(self.evaluate(0, N)) + values = [[i, k] for i, k in enumerate(values)] + values = tabulate.tabulate(values, tablefmt='plain', colalign=colalign) else: - values_head = format(self.evaluate(0, PRINT_MAX_COUNT//2)) - values_tail = format(self.evaluate(N - PRINT_MAX_COUNT//2, N)) - values_head = list(zip(range(PRINT_MAX_COUNT//2), values_head)) +\ - list(zip(range(N - PRINT_MAX_COUNT//2, N), values_tail)) - values = tabulate.tabulate([k for k in values_head], tablefmt='plain', colalign=colalign) + values_head = format_values(self.evaluate(0, PRINT_MAX_COUNT // 2)) + values_tail = format_values(self.evaluate(N - PRINT_MAX_COUNT // 2, N)) + values = list(zip(range(PRINT_MAX_COUNT // 2), values_head)) + \ + list(zip(range(N - PRINT_MAX_COUNT // 2, N), values_tail)) + values = tabulate.tabulate([k for k in values], tablefmt='plain', colalign=colalign) values = values.split('\n') width = max(map(len, values)) separator = '\n' + '...'.center(width, ' ') + '\n' - values = "\n".join(values[:PRINT_MAX_COUNT//2]) + separator + "\n".join(values[PRINT_MAX_COUNT//2:]) + '\n' + values = "\n".join(values[:PRINT_MAX_COUNT // 2]) + separator + "\n".join( + values[PRINT_MAX_COUNT // 2:]) + '\n' except Exception as e: values = 'Error evaluating: %r' % e + return values + + def _repr_info(self): expression = self.expression if len(expression) > 60: expression = expression[:57] + '...' @@ -823,11 +867,26 @@ def format(values): state = "expression" line = 'Length: {:,} dtype: {} ({})\n'.format(len(self.ds), dtype, state) info += line - info += '-' * (len(line)-1) + '\n' - info += values + info += '-' * (len(line) - 1) + '\n' + return info + + def _repr_html_(self): + info = self._repr_info() + if self.is_image(): + # TODO set up as plain like other expression + info = info.replace("dtype: object", "dtype: image") + info += self.ds[[self.expression]]._repr_html_() + else: + info += self._repr_values(value_format='html') + return f"
<pre>{info}</pre>
" + + def _repr_plain_(self): + info = self._repr_info() + info += self._repr_values() return info - def count(self, binby=[], limits=None, shape=default_shape, selection=False, delay=False, edges=False, progress=None): + def count(self, binby=[], limits=None, shape=default_shape, selection=False, delay=False, edges=False, + progress=None): '''Shortcut for ds.count(expression, ...), see `Dataset.count`''' kwargs = dict(locals()) del kwargs['self'] @@ -985,6 +1044,7 @@ def value_counts(self, dropna=False, dropnan=False, dropmissing=False, ascending counter_type = counter_type_from_dtype(data_type_item, transient) counters = [None] * self.ds.executor.thread_pool.nthreads + def map(thread_index, i1, i2, selection_masks, blocks): ar = blocks[0] if counters[thread_index] is None: @@ -1001,10 +1061,13 @@ def map(thread_index, i1, i2, selection_masks, blocks): else: counters[thread_index].update(ar) return 0 + def reduce(a, b): - return a+b + return a + b + progressbar = vaex.utils.progressbars(progress, title="value counts") - self.ds.map_reduce(map, reduce, [self.expression], delay=False, progress=progressbar, name='value_counts', info=True, to_numpy=False) + self.ds.map_reduce(map, reduce, [self.expression], delay=False, progress=progressbar, name='value_counts', + info=True, to_numpy=False) counters = [k for k in counters if k is not None] counter = counters[0] for other in counters[1:]: @@ -1072,7 +1135,8 @@ def reduce(a, b): return Series(counts, index=keys) @docsubst - def unique(self, dropna=False, dropnan=False, dropmissing=False, selection=None, axis=None, array_type='list', progress=None, delay=False): + def unique(self, dropna=False, dropnan=False, dropmissing=False, selection=None, axis=None, array_type='list', + progress=None, delay=False): """Returns all unique values. :param dropmissing: do not count missing values @@ -1082,9 +1146,11 @@ def unique(self, dropna=False, dropnan=False, dropmissing=False, selection=None, :param progress: {progress} :param bool array_type: {array_type} """ - return self.ds.unique(self, dropna=dropna, dropnan=dropnan, dropmissing=dropmissing, selection=selection, array_type=array_type, axis=axis, progress=progress, delay=delay) + return self.ds.unique(self, dropna=dropna, dropnan=dropnan, dropmissing=dropmissing, selection=selection, + array_type=array_type, axis=axis, progress=progress, delay=delay) - def nunique(self, dropna=False, dropnan=False, dropmissing=False, selection=None, axis=None, progress=None, delay=False): + def nunique(self, dropna=False, dropnan=False, dropmissing=False, selection=None, axis=None, progress=None, + delay=False): """Counts number of unique values, i.e. `len(df.x.unique()) == df.x.nunique()`. 
:param dropmissing: do not count missing values @@ -1092,16 +1158,20 @@ def nunique(self, dropna=False, dropnan=False, dropmissing=False, selection=None :param dropna: short for any of the above, (see :func:`Expression.isna`) :param bool axis: Axis over which to determine the unique elements (None will flatten arrays or lists) """ + def key_function(): fp = vaex.cache.fingerprint(self.fingerprint(), dropna, dropnan, dropmissing, selection, axis) return f'nunique-{fp}' + @vaex.cache._memoize(key_function=key_function, delay=delay) def f(): - value = self.unique(dropna=dropna, dropnan=dropnan, dropmissing=dropmissing, selection=selection, axis=axis, array_type=None, progress=progress, delay=delay) + value = self.unique(dropna=dropna, dropnan=dropnan, dropmissing=dropmissing, selection=selection, axis=axis, + array_type=None, progress=progress, delay=delay) if delay: return value.then(len) else: return len(value) + return f() def countna(self): @@ -1162,7 +1232,8 @@ def jit_pythran(self, verbose=False): expression = self.expression if expression in self.ds.virtual_columns: expression = self.ds.virtual_columns[self.expression] - all_vars = self.ds.get_column_names(virtual=True, strings=True, hidden=True) + list(self.ds.variables.keys()) + all_vars = self.ds.get_column_names(virtual=True, strings=True, hidden=True) + list( + self.ds.variables.keys()) vaex.expresso.validate_expression(expression, all_vars, funcs, names) names = list(set(names)) types = ", ".join(str(self.ds.data_type(name)) + "[]" for name in names) @@ -1179,7 +1250,9 @@ def f({0}): m.update(code.encode('utf-8')) module_name = "pythranized_" + m.hexdigest() # print(m.hexdigest()) - module_path = pythran.compile_pythrancode(module_name, code, extra_compile_args=["-DBOOST_SIMD", "-march=native"] + [] if verbose else ["-w"]) + module_path = pythran.compile_pythrancode(module_name, code, extra_compile_args=["-DBOOST_SIMD", + "-march=native"] + [] if verbose else [ + "-w"]) module = imp.load_dynamic(module_name, module_path) function_name = "f_" + m.hexdigest() @@ -1187,7 +1260,7 @@ def f({0}): return Expression(self.ds, "{0}({1})".format(function.name, argstring)) finally: - logger.setLevel(log_level) + logger.setLevel(log_level) def _rename(self, old, new, inplace=False): expression = self if inplace else self.copy() @@ -1220,7 +1293,8 @@ def isin(self, values, use_hashmap=True): """ if use_hashmap: # easiest way to create a set is using the vaex dataframe - values = np.array(values, dtype=self.dtype.numpy) # ensure that values are the same dtype as the expression (otherwise the set downcasts at the C++ level during execution) + values = np.array(values, + dtype=self.dtype.numpy) # ensure that values are the same dtype as the expression (otherwise the set downcasts at the C++ level during execution) df_values = vaex.from_arrays(x=values) ordered_set = df_values._set(df_values.x) var = self.df.add_variable('var_isin_ordered_set', ordered_set, unique=True) @@ -1359,6 +1433,7 @@ def try_nan(x): return np.isnan(x) except: return False + mapper_nan_key_mask = np.array([try_nan(k) for k in mapper_keys]) mapper_has_nan = mapper_nan_key_mask.sum() > 0 if mapper_nan_key_mask.sum() > 1: @@ -1391,7 +1466,8 @@ def try_nan(x): if allow_missing: if default_value is not None: value0 = list(mapper.values())[0] - assert np.issubdtype(type(default_value), np.array(value0).dtype), "default value has to be of similar type" + assert np.issubdtype(type(default_value), + np.array(value0).dtype), "default value has to be of similar type" else: if 
only_has_nan: pass # we're good, the hash mapper deals with nan @@ -1428,7 +1504,8 @@ def try_nan(x): key_set_name = df.add_variable('map_key_set', ordered_set, unique=True) choices_name = df.add_variable('map_choices', choices, unique=True) if allow_missing: - expr = '_map({}, {}, {}, use_missing={!r}, axis={!r})'.format(self, key_set_name, choices_name, use_masked_array, axis) + expr = '_map({}, {}, {}, use_missing={!r}, axis={!r})'.format(self, key_set_name, choices_name, + use_masked_array, axis) else: expr = '_map({}, {}, {}, axis={!r})'.format(self, key_set_name, choices_name, axis) return Expression(df, expr) @@ -1440,6 +1517,9 @@ def is_masked(self): def is_string(self): return self.df.is_string(self.expression) + def is_image(self): + return self.df.is_image(self.expression) + class FunctionSerializable(object): pass @@ -1511,6 +1591,7 @@ def __init__(self, expression, arguments, argument_dtypes, return_dtype, verbose else: def placeholder(*args, **kwargs): raise Exception('You chose not to compile this function (locally), but did invoke it') + self.f = placeholder def state_get(self): @@ -1594,18 +1675,20 @@ def compile(self): @fuse() def f({0}): return {1} -'''.format(argstring, self.expression)#, ";".join(conversions)) +'''.format(argstring, self.expression) # , ";".join(conversions)) if self.verbose: print("generated code") print(code) - scope = dict()#cupy=cupy) + scope = dict() # cupy=cupy) exec(code, scope) func = scope['f'] + def wrapper(*args): args = [vaex.array_types.to_numpy(k) for k in args] args = [vaex.utils.to_native_array(arg) if isinstance(arg, np.ndarray) else arg for arg in args] args = [cupy.asarray(arg) if isinstance(arg, np.ndarray) else arg for arg in args] return cupy.asnumpy(func(*args)) + return wrapper @@ -1620,6 +1703,7 @@ def __call__(self, *args, **kwargs): def _apply(self, *args, **kwargs): length = len(args[0]) result = [] + def fix_type(v): # TODO: only when column is str type? 
if isinstance(v, np.str_): @@ -1628,6 +1712,7 @@ def fix_type(v): return v.decode('utf8') else: return v + args = [vaex.array_types.tolist(k) for k in args] for i in range(length): scalar_result = self.f(*[fix_type(k[i]) for k in args], **{key: value[i] for key, value in kwargs.items()}) @@ -1642,15 +1727,17 @@ def __init__(self, dataset, name, f): self.dataset = dataset self.name = name - if not vaex.serialize.can_serialize(f): # if not serializable, assume we can use pickle + if not vaex.serialize.can_serialize(f): # if not serializable, assume we can use pickle f = FunctionSerializablePickle(f) self.f = f def __call__(self, *args, **kwargs): - arg_string = ", ".join([str(k) for k in args] + ['{}={:r}'.format(name, value) for name, value in kwargs.items()]) + arg_string = ", ".join( + [str(k) for k in args] + ['{}={:r}'.format(name, value) for name, value in kwargs.items()]) expression = "{}({})".format(self.name, arg_string) return Expression(self.dataset, expression) + class FunctionBuiltin(object): def __init__(self, dataset, name, **kwargs): @@ -1660,6 +1747,7 @@ def __init__(self, dataset, name, **kwargs): def __call__(self, *args, **kwargs): kwargs = dict(kwargs, **self.kwargs) - arg_string = ", ".join([str(k) for k in args] + ['{}={:r}'.format(name, value) for name, value in kwargs.items()]) + arg_string = ", ".join( + [str(k) for k in args] + ['{}={:r}'.format(name, value) for name, value in kwargs.items()]) expression = "{}({})".format(self.name, arg_string) return Expression(self.dataset, expression) diff --git a/packages/vaex-core/vaex/expresso.py b/packages/vaex-core/vaex/expresso.py index d0d719552a..ee5cef5cb7 100644 --- a/packages/vaex-core/vaex/expresso.py +++ b/packages/vaex-core/vaex/expresso.py @@ -127,6 +127,9 @@ def validate_expression(expr, variable_set, function_set=[], names=None): validate_expression(expr.value, variable_set, function_set, names) elif isinstance(expr, ast_Constant): pass # like True and False + elif isinstance(expr, _ast.Tuple): + for el in expr.elts: + validate_expression(el, variable_set, function_set, names) elif isinstance(expr, _ast.List): for el in expr.elts: validate_expression(el, variable_set, function_set, names) @@ -381,6 +384,9 @@ def visit_Str(self, node): def visit_List(self, node): return "[{}]".format(", ".join([self.visit(k) for k in node.elts])) + def visit_Tuple(self, node): + return "({})".format(" ".join([self.visit(k) + "," for k in node.elts])) + def pow(self, left, right): return "({left} ** {right})".format(left=left, right=right) diff --git a/packages/vaex-core/vaex/formatting.py b/packages/vaex-core/vaex/formatting.py index 4e070783ae..e6c1637311 100644 --- a/packages/vaex-core/vaex/formatting.py +++ b/packages/vaex-core/vaex/formatting.py @@ -1,3 +1,5 @@ +from base64 import b64encode + import numpy as np import numbers import six @@ -6,14 +8,30 @@ from vaex import datatype, struct MAX_LENGTH = 50 +IMAGE_WIDTH = 100 +IMAGE_HEIGHT = 100 + def _trim_string(value): if len(value) > MAX_LENGTH: - value = repr(value[:MAX_LENGTH-3])[:-1] + '...' + value = repr(value[:MAX_LENGTH - 3])[:-1] + '...' 
return value -def _format_value(value): + +def _format_value(value, value_format='plain'): + if value_format == "html": + if hasattr(value, '_repr_png_'): + data = value._repr_png_() + base64_data = b64encode(data) + data_encoded = base64_data.decode('ascii') + url_data = f"data:image/png;base64,{data_encoded}" + plain = f'' + return plain + elif hasattr(value, 'shape') and len(value.shape) > 1: + return _trim_string(str(value).replace('\n', '
')) + + # print("value = ", value, type(value), isinstance(value, numbers.Number)) if isinstance(value, pa.lib.Scalar): if datatype.DataType(value.type).is_struct: @@ -44,16 +62,17 @@ def _format_value(value): tmp = datetime.timedelta(seconds=value / np.timedelta64(1, 's')) ms = tmp.microseconds s = np.mod(tmp.seconds, 60) - m = np.mod(tmp.seconds//60, 60) + m = np.mod(tmp.seconds // 60, 60) h = tmp.seconds // 3600 d = tmp.days if ms: - value = str('%i days %+02i:%02i:%02i.%i' % (d,h,m,s,ms)) + value = str('%i days %+02i:%02i:%02i.%i' % (d, h, m, s, ms)) else: - value = str('%i days %+02i:%02i:%02i' % (d,h,m,s)) + value = str('%i days %+02i:%02i:%02i' % (d, h, m, s)) return value elif isinstance(value, numbers.Number): value = str(value) + else: value = repr(value) value = _trim_string(value) diff --git a/packages/vaex-core/vaex/registry.py b/packages/vaex-core/vaex/registry.py index ef625ac0db..e364c63748 100644 --- a/packages/vaex-core/vaex/registry.py +++ b/packages/vaex-core/vaex/registry.py @@ -11,7 +11,8 @@ 'str_pandas': vaex.expression.StringOperationsPandas, 'dt': vaex.expression.DateTime, 'td': vaex.expression.TimeDelta, - 'struct': vaex.expression.StructOperations + 'struct': vaex.expression.StructOperations, + 'vision': vaex.expression.Image } diff --git a/packages/vaex-core/vaex/vision.py b/packages/vaex-core/vaex/vision.py new file mode 100644 index 0000000000..ba76d12c39 --- /dev/null +++ b/packages/vaex-core/vaex/vision.py @@ -0,0 +1,183 @@ +__author__ = 'yonatanalexander' + +import glob +import os +import pathlib +import functools +import numpy as np +import warnings +import io +import vaex +import vaex.utils + +try: + import PIL + import base64 +except: + PIL = vaex.utils.optional_import("PIL.Image", modules="pillow") + + +def get_paths(path, suffix=None): + if isinstance(path, list): + return functools.reduce(lambda a, b: get_paths(a, suffix=suffix) + get_paths(b, suffix=suffix), path) + if os.path.isfile(path): + files = [path] + elif os.path.isdir(path): + files = [] + if suffix is not None: + files = [str(path) for path in pathlib.Path(path).rglob(f"*{suffix}")] + else: + for suffix in ['jpg', 'png', 'jpeg', 'ppm', 'thumbnail']: + files.extend([str(path) for path in pathlib.Path(path).rglob(f"*{suffix}")]) + elif isinstance(path, str) and len(glob.glob(path)) > 0: + return glob.glob(path) + else: + raise ValueError( + f"path: {path} do not point to an image, a directory of images, or a nested directory of images, or a glob path of files") + # TODO validate the files without opening it + return files + + +def _safe_apply(f, image_array): + try: + return f(image_array) + except Exception as e: + return None + + +def _infer(item): + if hasattr(item, 'as_py'): + item = item.as_py() + if isinstance(item, np.ndarray): + decode = numpy_2_pil + elif isinstance(item, int): + item = np.ndarray(item) + decode = numpy_2_pil + elif isinstance(item, bytes): + decode = bytes_2_pil + elif isinstance(item, str): + if os.path.isfile(item): + decode = PIL.Image.open + else: + decode = str_2_pil + else: + raise RuntimeError(f"Can't handle item {item}") + return _safe_apply(decode, item) + + +@vaex.register_function(scope='vision') +def infer(images): + images = [_infer(image) for image in images] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def open(path, suffix=None): + files = get_paths(path=path, suffix=suffix) + df = vaex.from_arrays(path=files) + df['image'] = df['path'].vision.infer() + return df + + +@vaex.register_function(scope='vision') +def 
filename(images): + images = [image.filename if hasattr(image, 'filename') else None for image in images] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def resize(images, size, resample=3, **kwargs): + images = [image.resize(size, resample=resample, **kwargs) for image in images] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def to_numpy(images): + images = [pil_2_numpy(image) for image in images] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def to_bytes(arrays, format='png'): + images = [pil_2_bytes(image_array, format=format) for image_array in arrays] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def to_str(arrays, format='png', encoding=None): + images = [pil_2_str(image_array, format=format, encoding=encoding) for image_array in arrays] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def from_numpy(arrays): + images = [_safe_apply(numpy_2_pil, image_array) for image_array in arrays] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def from_bytes(arrays): + images = [_safe_apply(bytes_2_pil, image_array) for image_array in arrays] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def from_str(arrays): + images = [_safe_apply(str_2_pil, image_array) for image_array in arrays] + return np.array(images, dtype="O") + + +@vaex.register_function(scope='vision') +def from_path(arrays): + images = [_safe_apply(PIL.Image.open, image_array) for image_array in vaex.array_types.tolist(arrays)] + return np.array(images, dtype="O") + + +def rgba_2_pil(rgba): + # TODO remove? + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + im = PIL.Image.fromarray(rgba[::-1], "RGBA") # , "RGBA", 0, -1) + return im + + +def numpy_2_pil(array): + return PIL.Image.fromarray(np.uint8(array)) + + +def pil_2_numpy(im): + if im is not None: + return np.array(im).astype(object) + return None + + +def pil_2_bytes(im, format="png"): + f = io.BytesIO() + im.save(f, format) + return base64.b64encode(f.getvalue()) + + +def bytes_2_pil(b): + return PIL.Image.open(io.BytesIO(base64.b64decode(b))) + + +def pil_2_str(im, format="png", encoding=None): + args = [encoding] if encoding else [] + return pil_2_bytes(im, format=format).decode(*args) + + +def str_2_pil(im, encoding=None): + args = [encoding] if encoding else [] + return bytes_2_pil(im.encode(*args)) + + +def rgba_to_url(rgba): + bit8 = rgba.dtype == np.uint8 + if not bit8: + rgba = (rgba * 255.).astype(np.uint8) + im = rgba_2_pil(rgba) + data = pil_2_bytes(im) + data = base64.b64encode(data) + data = data.decode("ascii") + imgurl = "data:image/png;base64," + data + "" + return imgurl diff --git a/tests/data/images/cats/cat.4865.jpg b/tests/data/images/cats/cat.4865.jpg new file mode 100755 index 0000000000..4818086f22 Binary files /dev/null and b/tests/data/images/cats/cat.4865.jpg differ diff --git a/tests/data/images/cats/cat.9021.jpg b/tests/data/images/cats/cat.9021.jpg new file mode 100755 index 0000000000..aeddb381a9 Binary files /dev/null and b/tests/data/images/cats/cat.9021.jpg differ diff --git a/tests/data/images/dogs/dog.2423.jpg b/tests/data/images/dogs/dog.2423.jpg new file mode 100755 index 0000000000..c468a0ec69 Binary files /dev/null and b/tests/data/images/dogs/dog.2423.jpg differ diff --git a/tests/data/images/dogs/dog.8091.jpg b/tests/data/images/dogs/dog.8091.jpg new file mode 100755 
index 0000000000..5234c8b7c4 Binary files /dev/null and b/tests/data/images/dogs/dog.8091.jpg differ diff --git a/tests/ml/vision_test.py b/tests/ml/vision_test.py new file mode 100644 index 0000000000..723b56e3a8 --- /dev/null +++ b/tests/ml/vision_test.py @@ -0,0 +1,39 @@ +import vaex.vision +import PIL + +basedir = 'tests/data/images' + + +def test_vision_conversions(): + df = vaex.vision.open(basedir) + df['image_bytes'] = df['image'].vision.to_bytes() + df['image_str'] = df['image'].vision.to_str() + df['image_array'] = df['image'].vision.resize((10, 10)).vision.to_numpy() + + assert isinstance(df['image_bytes'].vision.from_bytes().values[0], PIL.Image.Image) + assert isinstance(df['image_str'].vision.from_str().values[0], PIL.Image.Image) + assert isinstance(df['image_array'].vision.from_numpy().values[0], PIL.Image.Image) + + assert isinstance(df['image_bytes'].vision.infer().values[0], PIL.Image.Image) + assert isinstance(df['image_str'].vision.infer().values[0], PIL.Image.Image) + assert isinstance(df['image_array'].vision.infer().values[0], PIL.Image.Image) + assert isinstance(df['path'].vision.infer().values[0], PIL.Image.Image) + + +def test_vision_open(): + df = vaex.vision.open(basedir) + assert df.shape == (4, 2) + assert vaex.vision.open(basedir + '/dogs').shape == (2, 2) + assert vaex.vision.open(basedir + '/dogs/dog*').shape == (2, 2) + assert vaex.vision.open(basedir + '/dogs/dog.2423.jpg').shape == (1, 2) + assert vaex.vision.open([basedir + '/dogs/dog.2423.jpg', basedir + '/cats/cat.4865.jpg']).shape == (2, 2) + assert 'path' in df + assert 'image' in df + + +def test_vision(): + df = vaex.vision.open(basedir) + assert df.shape == (4, 2) + assert isinstance(df.image.tolist()[0], PIL.Image.Image) + assert df.image.vision.to_numpy().shape == (4, 261, 350, 3) + assert df.image.vision.resize((8, 4)).vision.to_numpy().shape == (4, 4, 8, 3)
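
Not part of the diff above: a minimal usage sketch of the vision API this PR introduces, assuming pillow is installed and the script runs from the repo root so the sample files added under tests/data/images are available; the `thumb*` column names are purely illustrative.

import vaex.vision  # importing this module registers the 'vision' functions and accessor
import PIL.Image

# Build a DataFrame from a directory of images: one row per file,
# with a 'path' column and an 'image' column of decoded PIL.Image objects.
df = vaex.vision.open('tests/data/images')

# The new Expression.is_image() detects object columns holding PIL images.
assert df.image.is_image()

# Image operations live under the 'vision' accessor and stay chainable.
df['thumb'] = df.image.vision.resize((32, 32))   # illustrative column names
df['thumb_png'] = df.thumb.vision.to_bytes()     # base64-encoded PNG bytes
df['thumb_arr'] = df.thumb.vision.to_numpy()     # per-row numpy arrays

# Round-trip back to PIL images.
assert isinstance(df.thumb_png.vision.from_bytes().values[0], PIL.Image.Image)

# In a notebook, the HTML repr now embeds images as base64 data URLs,
# scaled by IMAGE_WIDTH/IMAGE_HEIGHT (100 px) in formatting.py.
df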