Skip to content

Commit 11101c9

Browse files
committed
update
1 parent 859e714 commit 11101c9

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2229,6 +2229,7 @@ def _aten_upsample_output_size(
22292229
output_size,
22302230
mode=mode,
22312231
coordinate_transformation_mode=coordinate_transformation_mode,
2232+
nearest_mode="floor",
22322233
)
22332234

22342235

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1683,7 +1683,6 @@ def sample_inputs_upsample_nearest3d(op_info, device, dtype, requires_grad, **kw
16831683
SS = 3
16841684
L = 5
16851685

1686-
align_corners_options = (True, False)
16871686
rank = 3
16881687

16891688
def shape(size, rank, with_batch_channel=True):
@@ -1724,6 +1723,30 @@ def shape(size, rank, with_batch_channel=True):
17241723

17251724

17261725
def sample_inputs_upsample_trilinear3d(op_info, device, dtype, requires_grad, **kwargs):
1726+
del op_info
1727+
del kwargs
1728+
1729+
N, C = 2, 3
1730+
D = 4
1731+
SS = 3
1732+
L = 5
1733+
1734+
align_corners_options = (True, False)
1735+
rank = 3
1736+
1737+
def shape(size, rank, with_batch_channel=True):
1738+
if with_batch_channel:
1739+
return tuple([N, C] + ([size] * rank))
1740+
return tuple([size] * rank)
1741+
1742+
make_arg = functools.partial(
1743+
torch_testing.make_tensor,
1744+
device=device,
1745+
dtype=dtype,
1746+
requires_grad=requires_grad,
1747+
low=-1,
1748+
high=1,
1749+
)
17271750

17281751
for align_corners in align_corners_options:
17291752
yield opinfo_core.SampleInput(

0 commit comments

Comments
 (0)