Skip to content

Commit

Permalink
Remove two if statements in fp8 padding (pytorch#935)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#935

Reviewed By: vkuzo

Differential Revision: D63051205
  • Loading branch information
y-sq authored and facebook-github-bot committed Sep 24, 2024
1 parent 728d629 commit abec9d6
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions torchao/float8/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ def _get_min_alignment(size: int, alignment_value: int) -> int:
16
```
"""
if size % alignment_value == 0:
return size
return (1 + (size // alignment_value)) * alignment_value
return (1 + ((size-1) // alignment_value)) * alignment_value


def pad_tensor_for_matmul(
Expand Down Expand Up @@ -234,10 +232,6 @@ def pad_tensor_for_matmul(
dim1_aligned = _get_min_alignment(dim1, 16) if 0 in dims else dim1
dim2_aligned = _get_min_alignment(dim2, 16) if 1 in dims else dim2

# Check if padding is needed for either dimension
if dim1 == dim1_aligned and dim2 == dim2_aligned:
return tensor

# Calculate padding values for both dimensions
pad_dim1 = dim1_aligned - dim1
pad_dim2 = dim2_aligned - dim2
Expand Down

0 comments on commit abec9d6

Please sign in to comment.