Fix precision + QLoRA state dict tests, DTensor init #1087
Conversation
Good catch, just one comment and then good to go.
_set_float32_precision("highest") | ||
assert torch.get_float32_matmul_precision() == "highest" | ||
assert not torch.backends.cudnn.allow_tf32 | ||
assert not torch.backends.cuda.matmul.allow_tf32 | ||
|
||
_set_float32_precision("high") | ||
setattr( # noqa: B010 | ||
torch.backends, "__allow_nonbracketed_mutation_flag", True |
Is there a reason for adding `setattr` a second time, and for its placement here? Or is doing it once at the start enough?
Sorry, the second one should be False. Will fix.
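For reference, a minimal sketch of the pattern under discussion, mirroring the fix from #855 (with the second call flipping the flag back to False):

```python
import torch

# Bracket the torch.backends mutation: flip the private flag to True
# before mutating and back to False afterwards. setattr is used because
# a bare dunder-attribute assignment would be name-mangled inside a
# test class (hence the noqa: B010).
setattr(torch.backends, "__allow_nonbracketed_mutation_flag", True)  # noqa: B010
torch.backends.cudnn.allow_tf32 = False
setattr(torch.backends, "__allow_nonbracketed_mutation_flag", False)  # noqa: B010
```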
    stride=sharded_meta_param.stride(),
)
# BC-breaking change to DTensor API in https://github.com/pytorch/pytorch/pull/128112
if version.parse(torch.__version__) >= version.parse("2.4.0.dev20240606"):
@wanchaol @weifengpy please let me know if this is the best way to handle this; I'm also open to any other ideas either of you have here.
Also, regarding a longer-term plan here: are there any plans to make `DTensor` a public API? Or should we consider whether there's an alternative way to do this without directly calling into the `DTensor` API?
Chatted about this; the long-term solution is to use the public `DTensor` API:
return DTensor.from_local(
local_tensor,
sharding_spec.mesh,
sharding_spec.placements,
shape=sharding_spec.shape,
stride=sharding_spec.stride,
)
Since this is state_dict logic, I think we should just directly use `DTensor.from_local`. `DTensor.__new__` is kind of a private API of `DTensor`. We are going to make the `DTensor` API public, but probably not `DTensor.__new__`.
One more question here: if using this with NF4, does that mean we will also have to add logic for `aten.view` in ao (based on this line)?
Good catch, I think so. Do you know the args/kwargs after dispatching? If `view_as(shape)` matches `NF4.shape`, that's trivial; otherwise we need a deeper look, because the semantics of `NF4.view_as(different shape)` are not trivial.
BTW, it does not have to be addressed in this PR if it needs a torchao change.
Yeah, personally I don't know the args or kwargs, so I need to do a bit more investigation. Will leave this as a fast follow for now.
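In case it helps the investigation, one generic way to see the args/kwargs an op receives after dispatch is to run it under a logging dispatch mode (a sketch, not code from this PR):

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Log every aten call along with its args and kwargs as dispatched.
class LogAtenCalls(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(func, [getattr(a, "shape", a) for a in args], kwargs)
        return func(*args, **kwargs)

with LogAtenCalls():
    x = torch.randn(4, 4)
    x.view(16)  # prints e.g. aten.view.default, [torch.Size([4, 4]), [16]], {}
```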
I just found that PyTorch built from source does not work: `torch.__version__` is `2.4.0a0+git...`. My commit hash is from before the DTensor BC break, but it goes into `DTensor(spec)` and complains. cc @ebsmothers
🫡🫡🫡
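For context on why the source build trips the gate: under PEP 440, an `a0` pre-release sorts after any `.devN` release of the same version, so every `2.4.0a0+git...` build takes the new-API branch regardless of its actual commit. A small demonstration (the git hash is hypothetical):

```python
from packaging import version

source_build = version.parse("2.4.0a0+gitabcdef0")  # hypothetical source build
nightly_gate = version.parse("2.4.0.dev20240606")
# True even if the commit predates the BC break: a0 > .devN under PEP 440.
print(source_build >= nightly_gate)
```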
You can rebase on main now that #1095 has been merged and the workflows should pass.
Context
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses.
Changelog
After #855 added usage of `FSDPTest`, we cannot set `torch.backends.cudnn` in the same test env as an `FSDPTest`. The fix is to set `torch.backends.__allow_nonbracketed_mutation_flag` to True inside the test, similar to what was done in #855. Unfortunately we missed this test case because our unit tests that require GPUs do not currently run in CI.

Also, `test_qlora_state_dict` is failing as it requires a newer version of torchao. Current torchao does not have `__version__` defined, so we add the wonderful check `"__version__" not in dir(torchao)`. Separately, there is an issue where the `DTensor` API changed in a BC-breaking way. I bisected to the appropriate nightly, and for now I gate based on that torch version. Once 2.4 releases we should be able to remove this check.

Oh, also: the way we were checking `version("torchao")` in `_register_nf4_dispatch_ops.py` doesn't really account for the possibility of having `torchao-nightly` installed (increasingly prevalent lately). So I updated that bit of code to account for a possible nightly torchao install. Once torchao has a version and we are on 0.3, all of this can go (see the sketch below).

Follow-ups:
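For illustration, a hedged sketch of the nightly-aware lookup described above (the helper name is hypothetical, not the actual code in `_register_nf4_dispatch_ops.py`):

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# torchao-nightly registers under a different distribution name, so
# version("torchao") alone raises PackageNotFoundError in nightly envs.
def _get_torchao_version() -> Optional[str]:
    for dist in ("torchao", "torchao-nightly"):
        try:
            return version(dist)
        except PackageNotFoundError:
            continue
    return None  # torchao not installed at all
```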
Test plan
Prior to these changes:
After these changes (and on a nightly version of torchao):