Skip to content

Commit

Permalink
Service-based support for summarization
Browse files Browse the repository at this point in the history
  • Loading branch information
xoltar committed Jan 12, 2019
1 parent 81843e6 commit 937c80d
Showing 1 changed file with 89 additions and 4 deletions.
93 changes: 89 additions & 4 deletions pyrivet/rivet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
import shutil
import numpy as np
import scipy.spatial.distance as distance
import json
from typing import List, Tuple

"""An interface for rivet_console, using the command line
and subprocesses."""

rivet_executable = 'rivet_console'
rivet_client = 'rivet-client'

"""If set, use the rivet-client to run jobs on a remote server"""
server_url = os.getenv("RIVET_SERVER")


class PointCloud:
Expand Down Expand Up @@ -215,15 +220,19 @@ def barcodes(bytes, slices):


def _rivet_name(base, homology, x, y):
output_name = base + (".H%d_x%d_y%d.rivet" % (homology, x, y))
output_name = base.strip() + (".H%d_x%d_y%d.rivet" % (homology, x, y))
return output_name


def compute_file(input_name, output_name=None, homology=0, x=0, y=0):
def compute_file(input_name, output_name=None, homology=0, x=0, y=0, threads=1):
if not output_name:
output_name = _rivet_name(input_name, homology, x, y)
cmd = "%s %s %s -H %d -x %d -y %d -f msgpack" % \
(rivet_executable, input_name, output_name, homology, x, y)
if server_url:
cmd = "%s %s %s -H %d -x %d -y %d --redis %s --threads %d" % \
(rivet_client, input_name, output_name, homology, x, y, server_url, threads)
else:
cmd = "%s %s %s -H %d -x %d -y %d -f msgpack --num_threads %s" % \
(rivet_executable, input_name, output_name, homology, x, y, threads)
subprocess.check_output(shlex.split(cmd))
return output_name

Expand Down Expand Up @@ -254,6 +263,82 @@ def bounds_file(name):
return parse_bounds(subprocess.check_output(shlex.split(cmd)).split(b'\n'))


class Summary:
def __init__(self, invariants, structure, slices, barcodes, bounds):
self.invariants = invariants
self.structure = structure
self.slices = slices
self.barcodes = barcodes,
self.bounds = bounds


def summarize(saveable, homology=0, x=0, y=0, slices=None, bounds=False, structure=False, return_invariants=False, threads=1):
"""Computes module invariants for the given dataset, and queries them according the the parameters."""
if not server_url:
raise NotImplementedError

with TempDir() as dir:
input_path = os.path.join(dir, 'input.txt')
output_path = os.path.join(dir, 'output.rivet')
slices_path = os.path.join(dir, 'slices.txt')
bounds_path = output_path + '.bounds.json'
structure_path = output_path + '.structure.json'
barcodes_path = output_path + '.barcodes.json'
with open(input_path, 'wt') as input:
saveable.save(input)
cmd = "%s %s %s -H %d -x %d -y %d --redis %s --threads %d" % \
(rivet_client, input_path, output_path, homology, x, y, server_url, threads)
if slices:
with open(os.path.join(dir, 'slices.txt'), 'wt') as slice_temp:
for angle, offset in slices:
slice_temp.write("%s %s\n" % (angle, offset))

cmd += (" --slices %s" % slices_path)
if structure:
cmd += " --structure"
if bounds:
cmd += " --bounds"
if not return_invariants:
cmd += " --no-invariants"
output = subprocess.check_output(shlex.split(cmd)).split(b'\n')
print(output)
if return_invariants:
invariants = open(output_path, 'rb').read()
else:
invariants = None
if bounds:
bounds = json.load(open(bounds_path))
bounds = Bounds((bounds['x_low'], bounds['y_low']), (bounds['x_high'], bounds['y_high']))
else:
bounds = None
if slices:
barcodes_json = json.load(open(barcodes_path))
barcodes = []
for bc in barcodes_json:
angle = bc['angle']
offset = bc['offset']
bars_array = np.array(bc['bars']['data']).reshape(bc['bars']['dim'])
bars = [barcode.Bar(row[0], row[1], row[2]) for row in bars_array]
barcodes.append(((angle, offset), barcode.Barcode(bars)))
else:
barcodes = []
if structure:
structure_json = json.load(open(structure_path))
x_grades = [fractions.Fraction(*num) for num in structure_json['x_grades']]
y_grades = [fractions.Fraction(*num) for num in structure_json['y_grades']]
graded_rank = None # TODO! Rust API doesn't have this right now.
xi_0 = [(val['x'], val['y'], val['betti_0']) for val in structure_json['points']
if val['betti_0'] > 0]
xi_1 = [(val['x'], val['y'], val['betti_1']) for val in structure_json['points']
if val['betti_1'] > 0]
xi_2 = [(val['x'], val['y'], val['betti_2']) for val in structure_json['points']
if val['betti_2'] > 0]
structure = MultiBetti(Dimensions(x_grades, y_grades), graded_rank, xi_0, xi_1, xi_2)
else:
structure = None
return Summary(invariants, structure, slices, barcodes, bounds)


class TempDir(os.PathLike):
def __enter__(self):
self.dirname = os.path.join(tempfile.gettempdir(),
Expand Down

0 comments on commit 937c80d

Please sign in to comment.