Add RRef.local_value() for TorchScript #35433
Differential Revision: D7923050 Differential Version: 100883364
Dr. CI (automated comment): CircleCI build failures summary and remediations. As of commit 9838f3b, there are no CircleCI failures yet. This comment has been revised 7 times.
Differential Revision: D7923050 Differential Version: 100896349
LGTM!
If possible, shall we add a test that creates an RRef locally in a TorchScript function and then directly calls local_value? Something like:
    @torch.jit.script
    def remote_call_rref_local_value(dst_worker_name, rref):
        fut = rpc.rpc_async(self_worker_name, create_rref, args=(5,))
        rref = fut.wait()
        self.assertEqual(5, rref.local_value())

    @dist_init
    def test_rref_local_value(self):
        if self.rank != 0:
any reason for only testing on rank 0?
I don't think it needs to run on every node.
Running it from a single rank also helps: when an error occurs, it's clear which side it came from.
        if self.rank != 0:
            return

        dst_worker_name = worker_name((self.rank + 1) % self.world_size)
nit: this is different from how test_rref_is_owner gets the worker name
Sure, I can update that as well. I only intended to move it so that related tests are clustered together.
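For readers following the thread, the expression in the test picks the next rank in a ring. A minimal sketch, assuming a hypothetical `worker_name` stand-in that formats names as `worker{rank}` (the real test helper may use a different format) and a hypothetical `next_worker_name` wrapper:

```python
def worker_name(rank: int) -> str:
    # Hypothetical stand-in for the test helper; the naming format is an assumption.
    return f"worker{rank}"


def next_worker_name(rank: int, world_size: int) -> str:
    # Pick the neighbouring rank, wrapping around to rank 0 after the last rank.
    return worker_name((rank + 1) % world_size)
```

With a world size of 4, rank 3 would target rank 0, so every rank has a distinct destination that is never itself (as long as the world size is greater than 1).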
            # type: (RRef[Tensor]) -> Tensor
            return rref_local_value(rref)

        with self.assertRaisesRegex(RuntimeError, ""):
can we check if the error message matches?
Updating.
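To illustrate why the reviewer asks for this: `assertRaisesRegex` searches the exception text with the given pattern, so an empty pattern matches any message and only checks the exception type. A small sketch using the standard library; `local_value_on_user` and its message text are hypothetical stand-ins, not the actual PyTorch error:

```python
import unittest


def local_value_on_user(is_owner: bool) -> int:
    # Hypothetical stand-in for calling local_value() on a non-owner RRef.
    # The message text below is an assumption, not the real PyTorch message.
    if not is_owner:
        raise RuntimeError("Can't call local_value() on a user RRef")
    return 5


class ErrorMessageTest(unittest.TestCase):
    def test_empty_pattern_matches_anything(self):
        # Passes for any RuntimeError, whatever its message says.
        with self.assertRaisesRegex(RuntimeError, ""):
            local_value_on_user(is_owner=False)

    def test_specific_pattern(self):
        # Passes only if the message contains the expected text.
        with self.assertRaisesRegex(RuntimeError, r"local_value\(\) on a user RRef"):
            local_value_on_user(is_owner=False)
```

The second form would fail if an unrelated `RuntimeError` were raised, which is the regression the stricter check is meant to catch.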
            # type: (str, RRef[Tensor]) -> Tensor
            args = (rref,)
            kwargs = {}
            fut = rpc.rpc_async(dst_worker_name, rref_local_value, args, kwargs)
we cannot use rpc_sync yet because it is not added to TorchScript, right?
Correct.
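The workaround in the test, issuing an asynchronous call and immediately waiting on the returned future, behaves like a blocking call. A rough analogy using only the standard library; this is not the `torch.distributed.rpc` API, and `call_async` and `call_sync` are hypothetical names:

```python
from concurrent.futures import Future, ThreadPoolExecutor

# Single worker thread standing in for the remote side (an assumption for the sketch).
_pool = ThreadPoolExecutor(max_workers=1)


def call_async(fn, args=(), kwargs=None) -> Future:
    # Schedule fn on the worker thread and return a future immediately.
    return _pool.submit(fn, *args, **(kwargs or {}))


def call_sync(fn, args=(), kwargs=None):
    # A blocking call is the async call followed by a wait on its future.
    return call_async(fn, args, kwargs).result()
```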
@@ -82,7 +131,7 @@ def script_run_forward_rref_my_script_module(rref):
    return rref.to_here().forward()


class LocalRRefTest(RpcAgentTestFixture):
Is this change relevant? And why do we need this change?
No, it's an incidental change.
I shouldn't have added RpcAgentTestFixture here; it's already added at the final inheritance layer, JitRpcTest.
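The layering described here can be sketched as a plain mixin that gets the fixture only from the final class. The class bodies below are hypothetical placeholders (the `world_size` attribute and `describe` method are assumptions for illustration); only the class names come from the discussion:

```python
class RpcAgentTestFixture:
    # Hypothetical placeholder for the real fixture.
    world_size = 4


class LocalRRefTest:
    # Plain mixin: it relies on attributes the final class supplies
    # and does not inherit from the fixture itself.
    def describe(self) -> str:
        return f"world_size={self.world_size}"


class JitRpcTest(LocalRRefTest, RpcAgentTestFixture):
    # The fixture is mixed in once, at the final inheritance layer.
    pass
```

Keeping the fixture out of the intermediate class means it appears exactly once in the final class's method resolution order.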
            ret = fut.wait()
            return ret

        remote_call_rref_local_value(dst_worker_name, rref)
shall we assert the returned value?
sure, that's better.
Creating a local RRef in TorchScript is not supported yet.
I added a test closer to your intention.
Differential Revision: D7923050 Differential Version: 100957849
LGTM! Stamp to unblock. Let's wait for tests to pass.
Lint failure is irrelevant.
{
path: 'tools/clang_format_new.py',
start_line: 83,
end_line: 83,
start_column: 9,
end_column: 9,
annotation_level: 'failure',
message: '[E999] SyntaxError: invalid syntax'
}
Differential Revision: D7923050 Differential Version: 100961347
This pull request has been merged in 9b4bbaa.
Stack:
:black_circle: #35433 Add RRef.local_value() for TorchScript 💛
Make RRef TorchScript API the same as RRef Python API.
Differential Revision: D7923050