Skip to content

Commit

Permalink
Add np.atleast_1,2,3d() for finfields. Other minor changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
lschoe authored Apr 11, 2024
1 parent d52a9ba commit 5a7f5b6
Show file tree
Hide file tree
Showing 14 changed files with 228 additions and 226 deletions.
362 changes: 174 additions & 188 deletions demos/KaplanMeierSurvivalExplained.ipynb

Large diffs are not rendered by default.

27 changes: 14 additions & 13 deletions demos/kmsurvival.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,16 @@
from mpyc.runtime import mpc


def fit_plot(T1, T2, E1, E2, title, unit_of_time, label1, label2):
kmf1 = KaplanMeierFitter()
kmf2 = KaplanMeierFitter()
ax = kmf1.fit(T1, E1, label=label1, alpha=0.05).plot(show_censors=True)
ax = kmf2.fit(T2, E2, label=label2, alpha=0.05).plot(ax=ax, show_censors=True)
def plot_fits(kmf1, kmf2, title, unit_of_time):
ax = kmf1.plot(show_censors=True)
ax = kmf2.plot(ax=ax, show_censors=True)
ax.set_title(title)
if unit_of_time:
plt.xlabel(f'timeline ({unit_of_time})')
lifelines.plotting.add_at_risk_counts(kmf1, kmf2, ax=ax, labels=None)
plt.tight_layout()
figname = ax.figure.canvas.manager.get_window_title()
ax.figure.canvas.manager.set_window_title(f'Party {mpc.pid} - {figname}')
return kmf1, kmf2


def events_to_table(maxT, T, E):
Expand Down Expand Up @@ -261,12 +258,15 @@ async def main():
' for own events in the clear')

if args.print_tables or args.plot_curves:
plt.figure(1)
title = f'Party {mpc.pid}: {name} Survival - individual events'
kmf1, kmf2 = fit_plot(T1, T2, E1, E2, title, unit_of_time, label1, label2)
kmf1 = KaplanMeierFitter(alpha=0.05, label=label1).fit(T1, E1)
kmf2 = KaplanMeierFitter(alpha=0.05, label=label2).fit(T2, E2)
if args.print_tables:
print(kmf1.event_table)
print(kmf2.event_table)
if args.plot_curves:
plt.figure(1)
title = f'Party {mpc.pid}: {name} Survival - individual events'
plot_fits(kmf1, kmf2, title, unit_of_time)

# expand to timeline 1..maxT and add all input data homomorphically per group
d1, n1 = events_to_table(maxT, T1, E1)
Expand All @@ -286,14 +286,15 @@ async def main():
' for aggregated events in the clear')

if args.print_tables or args.plot_curves:
plt.figure(2)
title = f'Party {mpc.pid}: {name} Survival - aggregated by {stride} {unit_of_time}'
kmf1, kmf2 = fit_plot([t * stride for t in T1], [t * stride for t in T2], E1, E2,
title, unit_of_time, label1, label2)
kmf1 = KaplanMeierFitter(alpha=0.05, label=label1).fit([t * stride for t in T1], E1)
kmf2 = KaplanMeierFitter(alpha=0.05, label=label2).fit([t * stride for t in T2], E2)
if args.print_tables:
print(kmf1.event_table)
print(kmf2.event_table)
if args.plot_curves:
plt.figure(2)
title = f'Party {mpc.pid}: {name} Survival - aggregated by {stride} {unit_of_time}'
plot_fits(kmf1, kmf2, title, unit_of_time)
plt.show()

logging.info('Optimized secure logrank test on all individual events.')
Expand Down
2 changes: 1 addition & 1 deletion docs/cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ reserving lower case for the single-letter options of MPyC programs:

.. argparse::
:module: mpyc
:func: get_arg_parser
:func: _get_arg_parser
:prog: mpycprog.py

MPyC configuration : @before
Expand Down
12 changes: 6 additions & 6 deletions docs/mpyc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,18 @@ MPyC package
.. automodule:: mpyc
:members:

mpyc.gmpy
---------

.. automodule:: mpyc.gmpy
:members:

mpyc.numpy
----------

.. automodule:: mpyc.numpy
:members:

mpyc.gmpy
---------

.. automodule:: mpyc.gmpy
:members:

mpyc.gfpx
---------

Expand Down
10 changes: 5 additions & 5 deletions mpyc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
8. [sectypes](https://lschoe.github.io/mpyc/mpyc.sectypes.html): SecInt/Fld/Fxp/Flt types for secure (secret-shared) integer/finite-field/fixed-/floating-point values
9. [runtime](https://lschoe.github.io/mpyc/mpyc.runtime.html): core MPC protocols (many hidden by Python's operator overloading)
10. [mpctools](https://lschoe.github.io/mpyc/mpyc.mpctools.html): reduce and accumulate with log round complexity
11. [seclists](https://lschoe.github.io/mpyc/mpyc.seclists.html): secure lists with oblivious access and updates
12. [secgroups](https://lschoe.github.io/mpyc/mpyc.secgroups.html): SecGrp types for secure (secret-shared) finite group elements
13. [random](https://lschoe.github.io/mpyc/mpyc.random.html): securely mimicking Python’s [random](https://docs.python.org/3/library/random.html) module
11. [random](https://lschoe.github.io/mpyc/mpyc.random.html): securely mimicking Python’s [random](https://docs.python.org/3/library/random.html) module
12. [seclists](https://lschoe.github.io/mpyc/mpyc.seclists.html): secure lists with oblivious access and updates
13. [secgroups](https://lschoe.github.io/mpyc/mpyc.secgroups.html): SecGrp types for secure (secret-shared) finite group elements
14. [statistics](https://lschoe.github.io/mpyc/mpyc.statistics.html): securely mimicking Python’s [statistics](https://docs.python.org/3/library/statistics.html) module

The modules are listed in topological order w.r.t. internal dependencies:

- Modules 1-5 are basic modules which can also be used outside an MPC context
- Modules 6-9 form the core of MPyC
- Modules 10-12 form the extended core of MPyC
- Modules 13-14 are small libraries on top of the (extended) core
- Modules 10-13 form the extended core of MPyC
- Module 14 is a small library on top of the (extended) core
4 changes: 2 additions & 2 deletions mpyc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import asyncio


def get_arg_parser():
def _get_arg_parser():
"""Return parser for command line arguments passed to the MPyC runtime."""
parser = argparse.ArgumentParser(add_help=False)

Expand Down Expand Up @@ -108,7 +108,7 @@ def get_arg_parser():


if os.getenv('READTHEDOCS') != 'True':
options = get_arg_parser().parse_known_args()[0]
options = _get_arg_parser().parse_known_args()[0]
if options.VERSION or options.HELP:
options.no_log = True

Expand Down
3 changes: 1 addition & 2 deletions mpyc/asyncoro.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ def __await__(self):

def _nested_list(rt, n, dims):
if dims:
n0 = dims[0]
dims = dims[1:]
n0, *dims = dims
s = [_nested_list(rt, n0, dims) for _ in range(n)]
else:
s = [rt() for _ in range(n)]
Expand Down
8 changes: 5 additions & 3 deletions mpyc/finfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,11 @@ def __array_function__(self, func, types, args, kwargs):
a = list(map(cls, a))
elif isinstance(a, bool) or a is np.True_ or a is np.False_:
pass
elif not isinstance(a, tuple): # shape
elif func.__name__.startswith('atleast_'):
a = tuple(map(lambda _: cls(_, check=False), a))
elif isinstance(a, tuple):
pass # e.g., for func like shape
else:
a = cls.field(a)
return a

Expand Down Expand Up @@ -1342,8 +1346,6 @@ def trace(self, *args, **kwargs):

return type(self)(a, check=True)

# TODO: add atleast1d(a), atleast2d(a), atleast3d(a)

def transpose(self, *axes):
a = self.value.transpose(*axes)
return type(self)(a, check=False)
Expand Down
2 changes: 1 addition & 1 deletion mpyc/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4224,7 +4224,7 @@ def generate_configs(m, addresses):

def setup():
"""Setup a runtime."""
parser = mpyc.get_arg_parser()
parser = mpyc._get_arg_parser()
argv = sys.argv # keep raw args
options, args = parser.parse_known_args()
if options.VERSION:
Expand Down
5 changes: 5 additions & 0 deletions mpyc/secgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,11 @@ def SecGrpFunc(*args, **kwargs):
globals()[name] = SecGrpFunc


SecSymmetricGroup: type
SecQuadraticResidues: type
SecSchnorrGroup: type
SecEllipticCurve: type
SecClassGroup: type
_toSecGrpFunc(fg.SymmetricGroup) # make SecSymmetricGroup as secure SymmetricGroup version
_toSecGrpFunc(fg.QuadraticResidues) # make SecQuadraticResidues as secure QuadraticResidues version
_toSecGrpFunc(fg.SchnorrGroup) # make SecSchnorrGroup as secure SchnorrGroup version
Expand Down
3 changes: 2 additions & 1 deletion mpyc/sectypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,8 @@ def __bool__(self):
def __array_function__(self, func, types, args, kwargs):
# minimal redirect for now
if f'{func.__name__}' == 'vstack':
# NB: Numpy 2.0 inserts keyword arguments for np.vstack when converting deprecated np.row_stack call.
# NB: Numpy 2.0 inserts keyword arguments for np.vstack
# when converting deprecated np.row_stack call.
kwargs = {}
return eval(f'runtime.np_{func.__name__}')(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions mpyc/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,11 +602,11 @@ def linear_regression(x, y):
"""
n = len(x)
if len(y) != n:
raise statistics.StatisticsError('covariance requires that both inputs '
raise statistics.StatisticsError('linear regression requires that both inputs '
'have same number of data points')

if n < 2:
raise statistics.StatisticsError('covariance requires at least two data points')
raise statistics.StatisticsError('linear regression requires at least two data points')

sectype = type(x[0]) # all elts of x assumed of same type
if not issubclass(sectype, SecureObject):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_finfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,12 @@ def test_array_function(self):
np.assertEqual(np.tensordot(F_a, F_b), np.tensordot(a, b))
np.assertEqual(np.inner(F_a, F_b), np.inner(a, b))
np.assertEqual(np.outer(F_a, F_b), np.outer(a, b))
np.assertEqual(np.atleast_1d(F(1)), np.atleast_1d(1))
np.assertEqual(np.atleast_1d(F(1), F_a)[1], np.atleast_1d(1, a)[1])
np.assertEqual(np.atleast_2d(F(1)), np.atleast_2d(1))
np.assertEqual(np.atleast_2d(F(1), F_a)[1], np.atleast_2d(1, a)[1])
np.assertEqual(np.atleast_3d(F(1)), np.atleast_3d(1))
np.assertEqual(np.atleast_3d(F(1), F_a)[1], np.atleast_3d(1, a)[1])
np.assertEqual(np.concatenate((F_a, F_b, F_a.T)), np.concatenate((a, b, a.T)))
np.assertEqual(np.stack([F_a, F_b]), np.stack([a, b]))
np.assertEqual(np.stack([a, F_b], axis=1), np.stack([a, b], axis=1))
Expand All @@ -478,7 +484,8 @@ def test_array_function(self):
np.assertEqual(np.hstack([F_a, F_b]), np.hstack([a, b]))
np.assertEqual(np.dstack([F_a, F_b]), np.dstack([a, b]))
np.assertEqual(np.column_stack([F_a, F_b]), np.column_stack([a, b]))
np.assertEqual(np.row_stack([F_a, F_b]), np.row_stack([a, b]))
if np.lib.NumpyVersion(np.__version__) < '2.0.0b1':
np.assertEqual(np.row_stack([F_a, F_b]), np.row_stack([a, b]))
np.assertEqual(np.split(F_a, 2)[0], np.split(a, 2)[0])
np.assertEqual(np.array_split(F_a, 2)[1], np.array_split(a, 2)[1])
np.assertEqual(np.dsplit(F_a.reshape(1, 4, 4), 2)[1], np.dsplit(a.reshape(1, 4, 4), 2)[1])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def test_secint_array(self):
np.assertEqual(mpc.run(mpc.output(np.dstack((c[0], c[0])))), np.dstack((a[0], a[0])))
np.assertEqual(mpc.run(mpc.output(np.dstack((c[0, 0],)))), np.dstack((a[0, 0],)))
np.assertEqual(mpc.run(mpc.output(np.column_stack((c, c, c)))), np.column_stack((a, a, a)))
np.assertEqual(mpc.run(mpc.output(np.row_stack((c, c, c)))), np.row_stack((a, a, a)))
if np.lib.NumpyVersion(np.__version__) < '2.0.0b1':
np.assertEqual(mpc.run(mpc.output(np.row_stack((c, c, c)))), np.row_stack((a, a, a)))
np.assertEqual(mpc.run(mpc.output(np.split(c, 2, 1)[0])), np.split(a, 2, 1)[0])
np.assertEqual(mpc.run(mpc.output(np.dsplit(d, 1)[0])), np.dsplit(b, 1)[0])
np.assertEqual(mpc.run(mpc.output(np.hsplit(c, 2)[0])), np.hsplit(a, 2)[0])
Expand Down

0 comments on commit 5a7f5b6

Please sign in to comment.