Skip to content

Sync fork #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packaging/windows/internal/cuda_install.bat
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" (
)

if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" (
curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip"
curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip"
if errorlevel 1 exit /b 1
set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip"
)
Expand Down
12 changes: 10 additions & 2 deletions references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,16 @@ def main(args):
print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
model.to(device)

if not (args.test_only or args.post_training_quantize):
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(args.backend)
torch.quantization.prepare_qat(model, inplace=True)

if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay)
Expand All @@ -65,8 +69,6 @@ def main(args):
step_size=args.lr_step_size,
gamma=args.lr_gamma)

model.to(device)

criterion = nn.CrossEntropyLoss()
model_without_ddp = model
if args.distributed:
Expand Down Expand Up @@ -224,6 +226,12 @@ def parse_args():
It also serializes the transforms",
action="store_true",
)
parser.add_argument(
"--sync-bn",
dest="sync_bn",
help="Use sync batch norm",
action="store_true",
)
parser.add_argument(
"--test-only",
dest="test_only",
Expand Down
57 changes: 57 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import unittest
import random
import colorsys
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


Expand Down Expand Up @@ -56,6 +57,45 @@ def test_crop(self):
cropped_img_script = script_crop(img_tensor, top, left, height, width)
self.assertTrue(torch.equal(img_cropped, cropped_img_script))

def test_hsv2rgb(self):
shape = (3, 100, 150)
for _ in range(20):
img = torch.rand(*shape, dtype=torch.float)
ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)

h, s, v, = img.unbind(0)
h = h.flatten().numpy()
s = s.flatten().numpy()
v = v.flatten().numpy()

rgb = []
for h1, s1, v1 in zip(h, s, v):
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))

colorsys_img = torch.tensor(rgb, dtype=torch.float32)
max_diff = (ft_img - colorsys_img).abs().max()
self.assertLess(max_diff, 1e-5)

def test_rgb2hsv(self):
shape = (3, 150, 100)
for _ in range(20):
img = torch.rand(*shape, dtype=torch.float)
ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)

r, g, b, = img.unbind(0)
r = r.flatten().numpy()
g = g.flatten().numpy()
b = b.flatten().numpy()

hsv = []
for r1, g1, b1 in zip(r, g, b):
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))

colorsys_img = torch.tensor(hsv, dtype=torch.float32)

max_diff = (colorsys_img - ft_hsv_img).abs().max()
self.assertLess(max_diff, 1e-5)

def test_adjustments(self):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
Expand Down Expand Up @@ -97,6 +137,23 @@ def test_adjustments(self):
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertTrue(torch.equal(img, img_clone))

# test for class interface
f = transforms.ColorJitter(brightness=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(contrast=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(saturation=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(brightness=1)
scripted_fn = torch.jit.script(f)
scripted_fn(img)

def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
Expand Down
35 changes: 9 additions & 26 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,39 +446,22 @@ def test_heatmaps_to_keypoints(self):
assert torch.all(out2[1].eq(out_trace2[1]))

def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
super(KeyPointRCNN, self).__init__()
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(
pretrained=True, min_size=200, max_size=300)

def forward(self, images):
output = self.model(images)
# TODO: The keypoints_scores require the use of Argmax that is updated in ONNX.
# For now we are testing all the output of KeypointRCNN except keypoints_scores.
# Enable When Argmax is updated in ONNX Runtime.
return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints']

images, test_images = self.get_test_images()
# TODO:
# Enable test for dummy_image (no detection) once issue is
# _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed
# dummy_images = [torch.ones(3, 100, 100) * 0.3]
model = KeyPointRCNN()
dummy_images = [torch.ones(3, 100, 100) * 0.3]
model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)],
self.run_model(model, [(images,), (test_images,), (dummy_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)

self.run_model(model, [(dummy_images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
# TODO: enable this test once dynamic model export is fixed
# Test exported model for an image with no detections on other images
# self.run_model(model, [(dummy_images,), (test_images,)],
# input_names=["images_tensors"],
# output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
# dynamic_axes={"images_tensors": [0, 1, 2, 3]},
# tolerate_small_mismatch=True)


if __name__ == '__main__':
Expand Down
4 changes: 4 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ def test_nms(self):
keep_ref = self.reference_nms(boxes, scores, iou)
keep = ops.nms(boxes, scores, iou)
self.assertTrue(torch.allclose(keep, keep_ref), err_msg.format(iou))
self.assertRaises(RuntimeError, ops.nms, torch.rand(4), torch.rand(3), 0.5)
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 5), torch.rand(3), 0.5)
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(3, 2), 0.5)
self.assertRaises(RuntimeError, ops.nms, torch.rand(3, 4), torch.rand(4), 0.5)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_nms_cuda(self):
Expand Down
121 changes: 120 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')


def cycle_over(objs):
objs = list(objs)
for idx, obj in enumerate(objs):
yield obj, objs[:idx] + objs[idx + 1:]


def int_dtypes():
yield from iter(
(torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,)
)


def float_dtypes():
yield from iter((torch.float32, torch.float, torch.float64, torch.double))


class Tester(unittest.TestCase):

def test_crop(self):
Expand Down Expand Up @@ -299,13 +315,22 @@ def test_pad(self):
width = random.randint(10, 32) * 2
img = torch.ones(3, height, width)
padding = random.randint(1, 20)
fill = random.randint(1, 50)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.Pad(padding),
transforms.Pad(padding, fill=fill),
transforms.ToTensor(),
])(img)
self.assertEqual(result.size(1), height + 2 * padding)
self.assertEqual(result.size(2), width + 2 * padding)
# check that all elements in the padded region correspond
# to the pad value
fill_v = fill / 255
eps = 1e-5
self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps)
self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps)
self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)),
transforms.ToPILImage()(img))

def test_pad_with_tuple_of_pad_values(self):
height = random.randint(10, 32) * 2
Expand Down Expand Up @@ -501,6 +526,100 @@ def test_to_tensor(self):
output = trans(img)
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))

def test_convert_image_dtype_float_to_float(self):
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in output_dtypes:
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)

def test_convert_image_dtype_float_to_int(self):
for input_dtype in float_dtypes():
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
for output_dtype in int_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)

if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
input_dtype == torch.float64 and output_dtype == torch.int64
):
with self.assertRaises(RuntimeError):
transform(input_image)
else:
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, torch.iinfo(output_dtype).max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_float(self):
for input_dtype in int_dtypes():
input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype)
for output_dtype in float_dtypes():
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0.0, 1.0

self.assertAlmostEqual(actual_min, desired_min)
self.assertGreaterEqual(actual_min, desired_min)
self.assertAlmostEqual(actual_max, desired_max)
self.assertLessEqual(actual_max, desired_max)

def test_convert_image_dtype_int_to_int(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
output_image = transform(input_image)

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, output_max

# see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details
if input_max >= output_max:
error_term = 0
else:
error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1)

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max + error_term)

def test_convert_image_dtype_int_to_int_consistency(self):
for input_dtype, output_dtypes in cycle_over(int_dtypes()):
input_max = torch.iinfo(input_dtype).max
input_image = torch.tensor((0, input_max), dtype=input_dtype)
for output_dtype in output_dtypes:
output_max = torch.iinfo(output_dtype).max
if output_max <= input_max:
continue

with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
transform = transforms.ConvertImageDtype(output_dtype)
inverse_transfrom = transforms.ConvertImageDtype(input_dtype)
output_image = inverse_transfrom(transform(input_image))

actual_min, actual_max = output_image.tolist()
desired_min, desired_max = 0, input_max

self.assertEqual(actual_min, desired_min)
self.assertEqual(actual_max, desired_max)

@unittest.skipIf(accimage is None, 'accimage not available')
def test_accimage_to_tensor(self):
trans = transforms.ToTensor()
Expand Down
Loading