From 937c80d278dc7aced1d046ae8528292e1eb16293 Mon Sep 17 00:00:00 2001 From: Bryn Keller Date: Fri, 11 Jan 2019 16:38:17 -0800 Subject: [PATCH] Service-based support for summarization --- pyrivet/rivet.py | 93 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 89 insertions(+), 4 deletions(-) diff --git a/pyrivet/rivet.py b/pyrivet/rivet.py index ff2b4cb..d7255ad 100644 --- a/pyrivet/rivet.py +++ b/pyrivet/rivet.py @@ -9,12 +9,17 @@ import shutil import numpy as np import scipy.spatial.distance as distance +import json from typing import List, Tuple """An interface for rivet_console, using the command line and subprocesses.""" rivet_executable = 'rivet_console' +rivet_client = 'rivet-client' + +"""If set, use the rivet-client to run jobs on a remote server""" +server_url = os.getenv("RIVET_SERVER") class PointCloud: @@ -215,15 +220,19 @@ def barcodes(bytes, slices): def _rivet_name(base, homology, x, y): - output_name = base + (".H%d_x%d_y%d.rivet" % (homology, x, y)) + output_name = base.strip() + (".H%d_x%d_y%d.rivet" % (homology, x, y)) return output_name -def compute_file(input_name, output_name=None, homology=0, x=0, y=0): +def compute_file(input_name, output_name=None, homology=0, x=0, y=0, threads=1): if not output_name: output_name = _rivet_name(input_name, homology, x, y) - cmd = "%s %s %s -H %d -x %d -y %d -f msgpack" % \ - (rivet_executable, input_name, output_name, homology, x, y) + if server_url: + cmd = "%s %s %s -H %d -x %d -y %d --redis %s --threads %d" % \ + (rivet_client, input_name, output_name, homology, x, y, server_url, threads) + else: + cmd = "%s %s %s -H %d -x %d -y %d -f msgpack --num_threads %s" % \ + (rivet_executable, input_name, output_name, homology, x, y, threads) subprocess.check_output(shlex.split(cmd)) return output_name @@ -254,6 +263,82 @@ def bounds_file(name): return parse_bounds(subprocess.check_output(shlex.split(cmd)).split(b'\n')) +class Summary: + def __init__(self, invariants, structure, slices, barcodes, bounds): + self.invariants = invariants + self.structure = structure + self.slices = slices + self.barcodes = barcodes, + self.bounds = bounds + + +def summarize(saveable, homology=0, x=0, y=0, slices=None, bounds=False, structure=False, return_invariants=False, threads=1): + """Computes module invariants for the given dataset, and queries them according the the parameters.""" + if not server_url: + raise NotImplementedError + + with TempDir() as dir: + input_path = os.path.join(dir, 'input.txt') + output_path = os.path.join(dir, 'output.rivet') + slices_path = os.path.join(dir, 'slices.txt') + bounds_path = output_path + '.bounds.json' + structure_path = output_path + '.structure.json' + barcodes_path = output_path + '.barcodes.json' + with open(input_path, 'wt') as input: + saveable.save(input) + cmd = "%s %s %s -H %d -x %d -y %d --redis %s --threads %d" % \ + (rivet_client, input_path, output_path, homology, x, y, server_url, threads) + if slices: + with open(os.path.join(dir, 'slices.txt'), 'wt') as slice_temp: + for angle, offset in slices: + slice_temp.write("%s %s\n" % (angle, offset)) + + cmd += (" --slices %s" % slices_path) + if structure: + cmd += " --structure" + if bounds: + cmd += " --bounds" + if not return_invariants: + cmd += " --no-invariants" + output = subprocess.check_output(shlex.split(cmd)).split(b'\n') + print(output) + if return_invariants: + invariants = open(output_path, 'rb').read() + else: + invariants = None + if bounds: + bounds = json.load(open(bounds_path)) + bounds = Bounds((bounds['x_low'], bounds['y_low']), (bounds['x_high'], bounds['y_high'])) + else: + bounds = None + if slices: + barcodes_json = json.load(open(barcodes_path)) + barcodes = [] + for bc in barcodes_json: + angle = bc['angle'] + offset = bc['offset'] + bars_array = np.array(bc['bars']['data']).reshape(bc['bars']['dim']) + bars = [barcode.Bar(row[0], row[1], row[2]) for row in bars_array] + barcodes.append(((angle, offset), barcode.Barcode(bars))) + else: + barcodes = [] + if structure: + structure_json = json.load(open(structure_path)) + x_grades = [fractions.Fraction(*num) for num in structure_json['x_grades']] + y_grades = [fractions.Fraction(*num) for num in structure_json['y_grades']] + graded_rank = None # TODO! Rust API doesn't have this right now. + xi_0 = [(val['x'], val['y'], val['betti_0']) for val in structure_json['points'] + if val['betti_0'] > 0] + xi_1 = [(val['x'], val['y'], val['betti_1']) for val in structure_json['points'] + if val['betti_1'] > 0] + xi_2 = [(val['x'], val['y'], val['betti_2']) for val in structure_json['points'] + if val['betti_2'] > 0] + structure = MultiBetti(Dimensions(x_grades, y_grades), graded_rank, xi_0, xi_1, xi_2) + else: + structure = None + return Summary(invariants, structure, slices, barcodes, bounds) + + class TempDir(os.PathLike): def __enter__(self): self.dirname = os.path.join(tempfile.gettempdir(),