[Bug]: Pickle Can't get attribute 'ObjectRef' #599

Closed
warpoons opened this issue Mar 12, 2024 · 11 comments · Fixed by #610

@warpoons

Issue Type

Usability

Modules Involved

SPU compiler

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.9.0.dev20240311

OS Platform and Distribution

Ubuntu 18.04.6 LTS by WSL

Python Version

3.10

Compiler Version

GCC 11.3.0

Current Behavior?

A bug happened!

Standalone code to reproduce the issue

I was testing the tutorial at "https://www.secretflow.org.cn/en/docs/spu/0.7.0b0/tutorials/quick_start" to privately run the compare protocol using SPU. Everything works until I move JAX to SPU with the code "x = ppd.device("P1")(make_rand)()". The error is:
Traceback (most recent call last):
  File "/mnt/c/Users/78299/Desktop/my_spu/demo.py", line 33, in <module>
    x = ppd.device("P1")(make_rand)()
  File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 503, in __call__
    self.device.node_client.run(server_fn, *args, **kwargs),
  File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 253, in run
    return self._call(self._stub.Run, fn, *args, **kwargs)
  File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 244, in _call
    result = pickle.loads(rsp_data)
AttributeError: Can't get attribute 'ObjectRef' on <module '__main__' from '/mnt/c/Users/78299/Desktop/my_spu/demo.py'>

@warpoons
Author

Currently, my terminal shows:
(spu_py310) warpoons@LabPcPZ:/mnt/c/Users/78299/Desktop/my_spu$ python -m spu.utils.distributed up
[2024-03-12 09:58:14,182] [ForkServerProcess-3] Starting grpc server at 127.0.0.1:61329
[2024-03-12 09:58:14,187] [ForkServerProcess-2] Starting grpc server at 127.0.0.1:61328
[2024-03-12 09:58:14,194] [ForkServerProcess-1] Starting grpc server at 127.0.0.1:61327
--- Logging error ---
--- Logging error ---
--- Logging error ---
Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 440, in format
return self._format(record)
Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 436, in _format
return self._fmt % values
KeyError: 'processNameCorrected'
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 440, in format
return self._format(record)

During handling of the above exception, another exception occurred:

File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 436, in _format
return self._fmt % values
Traceback (most recent call last):
KeyError: 'processNameCorrected'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 1100, in emit
msg = self.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 943, in format
return fmt.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 681, in format
s = self.formatMessage(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 650, in formatMessage
return self._style.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 1100, in emit
msg = self.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 943, in format
return fmt.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 442, in format
raise ValueError('Formatting field not found in record: %s' % e)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 681, in format
s = self.formatMessage(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 650, in formatMessage
return self._style.format(record)
ValueError: Formatting field not found in record: 'processNameCorrected'
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 442, in format
raise ValueError('Formatting field not found in record: %s' % e)
Call stack:
ValueError: Formatting field not found in record: 'processNameCorrected'
Call stack:
Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 440, in format
return self._format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 436, in _format
return self._fmt % values
KeyError: 'processNameCorrected'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 1100, in emit
msg = self.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 943, in format
return fmt.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 681, in format
s = self.formatMessage(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 650, in formatMessage
return self._style.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 442, in format
raise ValueError('Formatting field not found in record: %s' % e)
ValueError: Formatting field not found in record: 'processNameCorrected'
Call stack:
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 973, in _bootstrap
self._bootstrap_inner()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 973, in _bootstrap
self._bootstrap_inner()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 83, in _worker
work_item.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 83, in _worker
work_item.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 793, in _stream_response_in_pool
_send_message_callback_to_blocking_iterator_adapter(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 793, in _stream_response_in_pool
_send_message_callback_to_blocking_iterator_adapter(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 813, in _send_message_callback_to_blocking_iterator_adapter
response, proceed = _take_response_from_response_iterator(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 813, in _send_message_callback_to_blocking_iterator_adapter
response, proceed = _take_response_from_response_iterator(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 599, in _take_response_from_response_iterator
return next(response_iterator), True
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 599, in _take_response_from_response_iterator
return next(response_iterator), True
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 325, in Run
logger.info(f"Run : {fn.name} at {self.node_id}")
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 325, in Run
logger.info(f"Run : {fn.name} at {self.node_id}")
Message: 'Run : builtin_spu_init at node:2'
Arguments: ()
Message: 'Run : builtin_spu_init at node:0'
Arguments: ()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 973, in _bootstrap
self._bootstrap_inner()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 83, in _worker
work_item.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 793, in _stream_response_in_pool
_send_message_callback_to_blocking_iterator_adapter(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 813, in _send_message_callback_to_blocking_iterator_adapter
response, proceed = _take_response_from_response_iterator(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 599, in _take_response_from_response_iterator
return next(response_iterator), True
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 325, in Run
logger.info(f"Run : {fn.name} at {self.node_id}")
Message: 'Run : builtin_spu_init at node:1'
Arguments: ()
I0312 09:58:23.638907 5129 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61438.
W0312 09:58:23.638925 5129 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
I0312 09:58:23.639010 5130 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61439.
W0312 09:58:23.639024 5130 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
I0312 09:58:23.639154 5128 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61437.
W0312 09:58:23.639168 5128 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
[2024-03-12 09:58:23,640] [ForkServerProcess-3] spu-runtime (SPU) initialized
[2024-03-12 09:58:23,640] [ForkServerProcess-1] spu-runtime (SPU) initialized
[2024-03-12 09:58:23,640] [ForkServerProcess-2] spu-runtime (SPU) initialized
--- Logging error ---
Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 440, in format
return self._format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 436, in _format
return self._fmt % values
KeyError: 'processNameCorrected'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 1100, in emit
msg = self.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 943, in format
return fmt.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 681, in format
s = self.formatMessage(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 650, in formatMessage
return self._style.format(record)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/logging/init.py", line 442, in format
raise ValueError('Formatting field not found in record: %s' % e)
ValueError: Formatting field not found in record: 'processNameCorrected'
Call stack:
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 973, in _bootstrap
self._bootstrap_inner()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 83, in _worker
work_item.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 793, in _stream_response_in_pool
_send_message_callback_to_blocking_iterator_adapter(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 813, in _send_message_callback_to_blocking_iterator_adapter
response, proceed = _take_response_from_response_iterator(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 599, in _take_response_from_response_iterator
return next(response_iterator), True
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 325, in Run
logger.info(f"Run : {fn.name} at {self.node_id}")
Message: 'Run : make_rand at node:0'
Arguments: ()
^CProcess ForkServerProcess-3:
Traceback (most recent call last):
Process ForkServerProcess-2:
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
Process ForkServerProcess-1:
return _run_code(code, main_globals, None,
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 1367, in
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 314, in _bootstrap
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 213, in serve
server.wait_for_termination()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 314, in _bootstrap
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 1350, in wait_for_termination
return _common.wait(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 314, in _bootstrap
self.run()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_common.py", line 156, in wait
_wait_once(wait_fn, MAXIMUM_WAIT_TIMEOUT, spin_cb)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 213, in serve
server.wait_for_termination()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_common.py", line 116, in _wait_once
wait_fn(timeout=timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 1350, in wait_for_termination
return _common.wait(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/spu/utils/distributed.py", line 213, in serve
server.wait_for_termination()
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 607, in wait
signaled = self._cond.wait(timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_server.py", line 1350, in wait_for_termination
return _common.wait(
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_common.py", line 156, in wait
_wait_once(wait_fn, MAXIMUM_WAIT_TIMEOUT, spin_cb)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 324, in wait
gotit = waiter.acquire(True, timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_common.py", line 156, in wait
_wait_once(wait_fn, MAXIMUM_WAIT_TIMEOUT, spin_cb)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_common.py", line 116, in _wait_once
wait_fn(timeout=timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/grpc/_common.py", line 116, in _wait_once
wait_fn(timeout=timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 607, in wait
signaled = self._cond.wait(timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 607, in wait
signaled = self._cond.wait(timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 324, in wait
gotit = waiter.acquire(True, timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/threading.py", line 324, in wait
gotit = waiter.acquire(True, timeout)
KeyboardInterrupt
worker.join()
KeyboardInterrupt
KeyboardInterrupt
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/process.py", line 149, in join
res = self._popen.wait(timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/popen_fork.py", line 43, in wait
return self.poll(os.WNOHANG if timeout == 0.0 else 0)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/popen_forkserver.py", line 65, in poll
if not wait([self.sentinel], timeout):
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/site-packages/multiprocess/connection.py", line 934, in wait
ready = selector.select(timeout)
File "/home/warpoons/anaconda3/envs/spu_py310/lib/python3.10/selectors.py", line 416, in select
fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

Thanks for reading. I look forward to your reply!

@tpppppub
Member

tpppppub commented Mar 12, 2024

We have reproduced this issue. For now, please use version 0.8.0b0.
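
The error message in the report is consistent with how pickle stores class instances: it records only the defining module name and the class name, and the loading process must be able to find that class under the same module name. The thread does not state the root cause, so the sketch below only illustrates that general pickle behavior. The names `fake_main`, `make_server_module`, and `roundtrip_across_processes` are hypothetical, and the `ObjectRef` class here is a stand-in, not SPU's actual code.

```python
import pickle
import sys
import types


def make_server_module(name):
    """Build a throwaway module that defines ObjectRef, as a server script might."""
    mod = types.ModuleType(name)

    class ObjectRef:
        def __init__(self, uuid):
            self.uuid = uuid

    # Make the class look as if it were defined at the top level of `name`.
    ObjectRef.__module__ = name
    ObjectRef.__qualname__ = "ObjectRef"
    mod.ObjectRef = ObjectRef
    return mod


def roundtrip_across_processes(name="fake_main"):
    """Pickle on a simulated server, unpickle on a simulated client."""
    # "Server side": pickle records only the pair (name, "ObjectRef"),
    # not the body of the class.
    sys.modules[name] = make_server_module(name)
    payload = pickle.dumps(sys.modules[name].ObjectRef("abc"))

    # "Client side": a module with the same name exists but lacks the class.
    sys.modules[name] = types.ModuleType(name)
    try:
        pickle.loads(payload)
    except AttributeError as exc:
        return str(exc)
    finally:
        del sys.modules[name]
    return None


if __name__ == "__main__":
    print(roundtrip_across_processes())
```

If this is what happens here, a class that is pickled while its defining file runs as `__main__` can only be loaded by a process whose own `__main__` also defines it.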

@tpppppub tpppppub changed the title from "[Bug]:" to "[Bug]: Pickle Can't get attribute 'ObjectRef'" on Mar 12, 2024
@warpoons
Author

> We have reproduced this issue, please try the 0.8.0b0 version for your current use.

Thanks for your reply! I have another question:
My current work is to optimize CNN architectures for the needs of PPML. The first step is to obtain a checkpoint of the optimized network through pruning, quantization, or linearization. I then use SPU to evaluate the efficiency of that checkpoint in a 2PC inference setting. Does SPU have two stages, an offline pre-generation stage and an online inference stage, similar to Delphi or CrypTFlow2? More importantly, does SPU support measuring latency and communication costs layer by layer, or even module by module? Thanks!

@tpppppub
Member

Currently, the offline and online stages are not separated in SPU. For communication and latency costs, you can enable the two configs enable_pphlo_profile and enable_hal_profile (see #590). The results are at the granularity of the entire program, but if you run a single layer on its own, you get the costs of that layer.
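
As a minimal sketch of turning on the two flags named in this comment, the helper below sets them inside the `runtime_config` of every SPU device in a cluster config dict. The dict layout (`devices`, `kind`, `config`, `runtime_config`) and the protocol and field values are assumptions modeled on the JSON config files used by SPU's Python examples. `enable_profiling` is a hypothetical helper, not part of SPU.

```python
import copy


def enable_profiling(conf):
    """Return a copy of `conf` with both profiling flags set on SPU devices."""
    conf = copy.deepcopy(conf)
    for device in conf.get("devices", {}).values():
        if device.get("kind") == "SPU":
            runtime = device.setdefault("config", {}).setdefault("runtime_config", {})
            runtime["enable_pphlo_profile"] = True
            runtime["enable_hal_profile"] = True
    return conf


# Assumed layout, trimmed to the parts that matter here.
example_conf = {
    "devices": {
        "SPU": {
            "kind": "SPU",
            "config": {"runtime_config": {"protocol": "SEMI2K", "field": "FM128"}},
        },
        "P1": {"kind": "PYU", "config": {"node_id": "node:0"}},
    }
}

profiled_conf = enable_profiling(example_conf)
```

The resulting dict would then go to whatever call sets up the cluster in your script, assuming it accepts these keys. The input dict is left unmodified, so the same base config can be reused for runs with profiling off.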

@warpoons
Author

> Currently, the offline and online stages are not separated in SPU. As for the communication and latency costs, you can use two configs enable_pphlo_profile and enable_hal_profile to obtain them (you can refer to this #590 ). But the result is at the granularity of the entire program. If you run a single layer independently, you can get the costs of that layer.

Thanks for your patient response. Your suggestions have been of great help to me.

@warpoons
Author

> Currently, the offline and online stages are not separated in SPU. As for the communication and latency costs, you can use two configs enable_pphlo_profile and enable_hal_profile to obtain them (you can refer to this #590 ). But the result is at the granularity of the entire program. If you run a single layer independently, you can get the costs of that layer.

I have an additional question about evaluating the model checkpoint; sorry for taking up your time. Are there any tutorials or guidelines on how to make a highly customized model architecture evaluable with SPU? Given that some of the convolutions or activations could be replaced by cheaper variants, is the key to define the model in JAX rather than in SPU itself (since SPU is highly compatible with JAX)? I appreciate your efforts. Thanks!

@tpppppub
Member

I think Puma's paper and its implementation in SPU may be a good tutorial.

@warpoons
Author

> I think Puma's paper and its implementation in SPU may be a good tutorial.

I will read this great article in detail, thank you!

@anakinxc anakinxc reopened this Mar 14, 2024
@anakinxc anakinxc mentioned this issue Mar 15, 2024
anakinxc added a commit that referenced this issue Mar 15, 2024
# Pull Request

## What problem does this PR solve?

Issue Number: Fixed #599 

## Possible side effects?

- Performance:

- Backward compatibility:
@warpoons
Author

warpoons commented Apr 8, 2024

> Currently, the offline and online stages are not separated in SPU. As for the communication and latency costs, you can use two configs enable_pphlo_profile and enable_hal_profile to obtain them (you can refer to this #590 ). But the result is at the granularity of the entire program. If you run a single layer independently, you can get the costs of that layer.

Hi @tpppppub. Thanks for the idea of evaluating individual layers to obtain their private inference costs. I tried the example at "https://github.com/secretflow/spu/tree/main/examples/python/ml/flax_resnet" and it works: evaluating a ResNet50 model privately takes 2.6GB of communication and 20+s of latency. I would now like to study the communication and latency costs of each linear and nonlinear layer in the DNN and optimize further, but I am not sure how to evaluate each layer individually, since the model is loaded into the program as a whole. Could you provide any specific tutorials, examples, or guidance? Thank you so much for your time!

@tpppppub
Member

tpppppub commented Apr 8, 2024

SPU just runs JAX programs in an MPC manner. Model partitioning happens in the plaintext world and is outside the scope of SPU. In short, take a look at the model definitions in the Hugging Face transformers library and use Flax to write your own partitioned version of the model.
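
The partitioning idea can be sketched in plain Python under some loud assumptions: the model is written as an ordered list of named layer functions, and each layer is handed to a `run` hook as its own program, so any cost measured around that call belongs to that layer alone. Everything here (`dense`, `relu`, `run_layerwise`, the `run` hook, and the toy weights) is hypothetical. A real version would use JAX or Flax layer functions and a secure backend in place of the plaintext default.

```python
import time


def dense(weights, bias):
    """Return a dense layer; `weights` holds one list of input weights per output unit."""
    def layer(x):
        return [sum(xi * w for xi, w in zip(x, col)) + b
                for col, b in zip(weights, bias)]
    return layer


def relu(x):
    """Elementwise max(v, 0) on a plain Python list."""
    return [max(v, 0.0) for v in x]


def run_layerwise(layers, x, run=lambda fn, arg: fn(arg)):
    """Run each layer as its own program and record per-layer wall time.

    `run` is a hypothetical hook: the default calls the layer in plaintext,
    and a secure backend would be plugged in here instead.
    """
    timings = []
    for name, fn in layers:
        start = time.perf_counter()
        x = run(fn, x)
        timings.append((name, time.perf_counter() - start))
    return x, timings


# A toy two-layer network with made-up weights.
layers = [
    ("dense1", dense([[1.0, -1.0], [0.5, 0.5]], [0.0, -2.0])),
    ("relu1", relu),
    ("dense2", dense([[2.0, 1.0]], [1.0])),
]
output, timings = run_layerwise(layers, [3.0, 1.0])
```

Because each layer receives the previous layer's output as an ordinary input, the split itself needs nothing from the MPC runtime, which matches the point that partitioning is done in plaintext.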

@warpoons
Author

warpoons commented Apr 8, 2024

> SPU just runs Jax programs in an MPC manner. Model partitioning is done in the plaintext world and is out of the scope of SPU. In short, take a look at the model definition in the huggingface transformers library and use Flax to write your own model partition version.

Thanks for your comments. I'll try this.


4 participants