Skip to content

Commit

Permalink
[BUG] Fixing an error triggered from the conv_channel_last_pass whi…
Browse files Browse the repository at this point in the history
…le compiling the model `sam` (#444)

Closes #325 

The error in the linked issue was caused by [this code
segment](https://github.com/CentML/hidet/blob/bfbb4db6d7792ed3de3be4e9702e597b8fbbe373/python/hidet/graph/transforms/conv_channel_last.py#L46-L75)
in `graph/transforms/conv_channel_last.py`.

By the logic flow of this code segment, if the operator `node` has two
inputs, the first one with rank 4 and the second rank 3(an example case
in the model: an `AddOp` where the first input has shape `[1, 256, 64,
64]` and the second `[256, 1, 1]`) , then by the time the code reaches
the line 75, the variable `new_perm`would have value `[1, 2, 0]`, and
this value will be recorded as the permutation scheme used to get the
new output, which is incorrect as the appropriate value should be `[0,
2, 3, 1]` here.
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Dec 20, 2024
1 parent 4f142c4 commit ba45522
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions python/hidet/graph/transforms/conv_channel_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]]
node = self.op
new_inputs: List[Tensor] = []
update_attributes: Dict[str, Any] = {}

rank_to_perm = {4: [0, 2, 3, 1], 3: [1, 2, 0], 2: [1, 0], 1: [0]}

for x in node.inputs:
if x in tensor_map:
current_x, current_perm = tensor_map[x]
Expand All @@ -55,24 +58,18 @@ def reforward(self, tensor_map: Dict[Tensor, Tuple[Tensor, Optional[List[int]]]]
else:
# Input is not channel last, convert it to channel last
x_rank = len(x.shape)
if x_rank == 4:
new_perm = [0, 2, 3, 1]
elif x_rank == 3:
new_perm = [1, 2, 0]
elif x_rank == 2:
new_perm = [1, 0]
elif x_rank == 1:
new_perm = [0]
else:
new_perm = rank_to_perm.get(x_rank, None)
if new_perm is None:
raise ValueError('Channel Last Pass met input tensor of scoped operator with shape > 4.')

new_x = transpose(current_x, new_perm)
tensor_map[x] = (new_x, new_perm)
new_inputs.append(new_x)
if 'axis' in node.attrs and isinstance(node.attrs['axis'], int):
update_attributes['axis'] = new_perm.index(node.attrs['axis'])
outputs = node.reforward(new_inputs, update_attributes)
for idx, y in enumerate(node.outputs):
tensor_map[y] = (outputs[idx], new_perm)
tensor_map[y] = (outputs[idx], rank_to_perm[len(outputs[idx].shape)])

@staticmethod
def initialize_scoped_ops():
Expand Down

0 comments on commit ba45522

Please sign in to comment.