Skip to content

Commit

Permalink
fixing dtype and device in test
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
HDCharles committed Oct 22, 2024
1 parent 708c698 commit 2b08079
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,8 +1291,8 @@ def forward(self, x):
y = self.mha(x, x, x)[0]
return self.lin(y)

mod = MHAModel()
input = torch.randn(1,1,4096)
mod = MHAModel().to(device).to(dtype)
input = torch.randn(1,1,4096).to(device).to(dtype)
out=mod(*input)

torchao.autoquant(mod, set_inductor_config=False)
Expand Down

0 comments on commit 2b08079

Please sign in to comment.