
Fix precision + QLoRA state dict tests, DTensor init #1087

Merged: 8 commits, Jun 18, 2024

Conversation

Contributor

@ebsmothers commented Jun 13, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

After #855 added usage of FSDPTest, we cannot set torch.backends.cudnn flags in the same test env as an FSDPTest. The fix is to set torch.backends.__allow_nonbracketed_mutation_flag to True inside the test (sketched below), similar to what was done in #855. Unfortunately we missed this test case because our unit tests that require GPUs do not currently run in CI.
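
A minimal sketch of that workaround, pieced together from the diff further down; the surrounding test code is illustrative, not the exact change. Running an FSDPTest in the same environment appears to trigger torch.backends.disable_global_flags() (per the error message in the test plan), after which direct mutation of torch.backends.cudnn raises.

import torch

# After disable_global_flags() has run, directly mutating torch.backends.cudnn
# flags raises a RuntimeError. Re-enabling the private mutation flag around the
# mutation sidesteps that, and we restore it to False afterwards.
setattr(  # noqa: B010
    torch.backends, "__allow_nonbracketed_mutation_flag", True
)
torch.backends.cudnn.allow_tf32 = False  # now permitted inside the test
setattr(  # noqa: B010
    torch.backends, "__allow_nonbracketed_mutation_flag", False
)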

Also, test_qlora_state_dict is failing because it requires a newer version of torchao. The current torchao release does not have __version__ defined, so we add the wonderful check "__version__" not in dir(torchao). Separately, there is an issue where the DTensor API changed in a BC-breaking way. I bisected to the appropriate nightly, and for now I gate based on that torch version (see the sketch below). Once 2.4 releases, we should be able to remove this check.
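
Roughly, the two gates described above look like this; the skip message and structure are illustrative rather than the literal test code, and the nightly version string comes from the diff further down.

import pytest
import torch
import torchao
from packaging import version

# torchao did not define __version__ at the time, so its mere presence is used
# as a proxy for a new-enough torchao.
if "__version__" not in dir(torchao):
    pytest.skip("test_qlora_state_dict requires a newer torchao", allow_module_level=True)

# The DTensor constructor changed in a BC-breaking way; gate on the bisected
# nightly until 2.4 is released.
_NEW_DTENSOR_API = version.parse(torch.__version__) >= version.parse("2.4.0.dev20240606")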

Oh, also: the way we were checking version("torchao") in _register_nf4_dispatch_ops.py doesn't really account for the possibility of having torchao-nightly installed (increasingly prevalent lately). So I updated that bit of code to account for a possible nightly torchao install (sketched below). Once torchao has a __version__ and we are on 0.3, all of this can go.
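
Something along these lines, assuming the two distribution names mentioned above; the helper is a sketch, not the literal code in _register_nf4_dispatch_ops.py.

from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as pkg_version

def _installed_torchao_version() -> str:
    # version("torchao") alone misses the nightly build, which ships under a
    # separate distribution name.
    try:
        return pkg_version("torchao")
    except PackageNotFoundError:
        return pkg_version("torchao-nightly")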

Follow-ups:

  • Test GPU unit tests in our CI (there are quite a few of them, and right now we have no signal on these, many of which are for more bleeding-edge features)
  • Define proper version-based gating across our various dependencies and apply it consistently across the repo

Test plan

Prior to these changes:

pytest tests/
...
FAILED tests/torchtune/utils/test_distributed.py::TestFullyShardState::test_qlora_state_dict - RuntimeError: Process 0 exited with error code 10 and exception:
FAILED tests/torchtune/utils/test_precision.py::TestPrecisionUtils::test_set_float32_precision - RuntimeError: not allowed to set torch.backends.cudnn flags after disable_global_flags; please use flags() context manager instead
============== 2 failed, 242 passed, 34 skipped, 2 warnings in 55.96s ======================

After these changes (and on a nightly version of torchao):

pytest tests/
...
===== 244 passed, 34 skipped, 2 warnings in 63.99s (0:01:03) ========


pytorch-bot bot commented Jun 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1087

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 69a372d with merge base abe798d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Jun 13, 2024
Contributor

@pbontrager left a comment


Good catch, just one comment and then good to go.

tests/torchtune/utils/test_precision.py
_set_float32_precision("highest")
assert torch.get_float32_matmul_precision() == "highest"
assert not torch.backends.cudnn.allow_tf32
assert not torch.backends.cuda.matmul.allow_tf32

_set_float32_precision("high")
setattr(  # noqa: B010
    torch.backends, "__allow_nonbracketed_mutation_flag", True
)
Contributor


Is there a reason for adding setattr a second time, and for its placement here? Or is doing it once at the start enough?

Contributor Author


Sorry, the second one should be False. Will fix.

@ebsmothers changed the title from "Fix precision test" to "Fix precision test and DTensor init" Jun 13, 2024
@ebsmothers changed the title from "Fix precision test and DTensor init" to "Fix precision + QLoRA state dict tests, DTensor init" Jun 13, 2024
stride=sharded_meta_param.stride(),
)
# BC-breaking change to DTensor API in https://github.com/pytorch/pytorch/pull/128112
if version.parse(torch.__version__) >= version.parse("2.4.0.dev20240606"):
Contributor Author


@wanchaol @weifengpy please let me know if this is the best way to handle this; I'm also open to any other ideas either of you have here.

Contributor Author


Also regarding a longer-term plan here, are there any plans to make DTensor a public API? Or should we consider whether there's an alternative way to do this without directly calling into the DTensor API?

Contributor


Chatted about this. The long-term solution is to use the public DTensor API:

return DTensor.from_local(
    local_tensor,
    sharding_spec.mesh,
    sharding_spec.placements,
    shape=sharding_spec.shape,
    stride=sharding_spec.stride,
)


@wanchaol Jun 14, 2024


Since this is state_dict logic, I think we should just directly use DTensor.from_local? DTensor.__new__ is kind of a private API of DTensor.

We are going to make the DTensor API public, but probably not DTensor.__new__.

Contributor Author


One more question here: if using this with NF4, does it mean we will also have to add logic for aten.view in ao (based on this line)?

Contributor

@weifengpy Jun 15, 2024


Good catch, I think so. Do you know the args/kwargs after dispatching? If view_as(shape) matches NF4.shape, that's trivial. Otherwise we need a deeper look, because the semantics of NF4.view_as(a different shape) are not trivial.

Contributor


BTW, it does not have to be addressed in this PR if it needs a torchao change.

Contributor Author


Yeah, personally I don't know the args or kwargs, so I need to do a bit more investigation. Will leave this as a fast follow for now.

Contributor


I just found that PyTorch built from source does not work: torch.__version__ is 2.4.0a0+git..... My commit hash is before the DTensor BC break, but it still goes into the DTensor(spec) branch and complains. cc @ebsmothers

Contributor

@joecummings left a comment


🫡🫡🫡

@joecummings
Contributor

You can rebase on main now that #1095 has been merged, and the workflows should pass.

@ebsmothers ebsmothers merged commit 66d1a9c into pytorch:main Jun 18, 2024
29 checks passed
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
7 participants