Skip to content

Commit

Permalink
Feat (QuantTensor): initial support for interpolate and pixel_shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
volcacius committed Apr 20, 2023
1 parent 67df105 commit 80a6680
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/brevitas/quant_tensor/torch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,38 @@ def adaptive_max_pool2d_handler(*args, **kwargs):
@implements(F.adaptive_max_pool3d)
def adaptive_max_pool3d_handler(*args, **kwargs):
return quant_invariant_handler(F.adaptive_max_pool3d, *args, **kwargs)


@implements(F.interpolate)
def interpolate_handler(
inp,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
recompute_scale_factor=None,
**kwargs): # support newer kwargs added in recent pytorch versions
if mode == 'nearest' or mode == 'nearest_exact':
return quant_invariant_handler(
F.interpolate,
inp,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor,
**kwargs)
else:
return F.interpolate(
inp.value,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor,
**kwargs)


@implements(F.pixel_shuffle)
def pixel_shuffle_handler(*args, **kwargs):
return quant_invariant_handler(F.pixel_shuffle_handler, *args, **kwargs)

0 comments on commit 80a6680

Please sign in to comment.