Commit

refactor: Fix python linting issues
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
narendasan committed Apr 21, 2021
1 parent 2af2c11 commit 59868d2
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cpp/ptq/training/vgg16/export_ckpt.py
@@ -22,7 +22,7 @@ def test(model, dataloader, crit):
 
     with torch.no_grad():
         for data, labels in dataloader:
-            data, labels = data.cuda(), labels.cuda(async=True)
+            data, labels = data.cuda(), labels.cuda(non_blocking=True)
             out = model(data)
             loss += crit(out, labels)
             preds = torch.max(out, 1)[1]
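
Note on this hunk: async has been a reserved keyword since Python 3.7, so labels.cuda(async=True) no longer even parses, and non_blocking is the parameter name PyTorch has used since 0.4.0. A minimal standalone sketch of the fixed pattern (the tensor below is a stand-in, not code from this repository):

import torch

# The old spelling fails to parse on Python >= 3.7:
#   labels = labels.cuda(async=True)   # SyntaxError: invalid syntax
# The supported keyword argument is non_blocking:
x = torch.randn(8, 3, 32, 32)
if torch.cuda.is_available():
    x = x.cuda(non_blocking=True)  # copy may overlap with GPU compute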
2 changes: 1 addition & 1 deletion cpp/ptq/training/vgg16/main.py
@@ -141,7 +141,7 @@ def train(model, dataloader, crit, opt, epoch):
     model.train()
     running_loss = 0.0
     for batch, (data, labels) in enumerate(dataloader):
-        data, labels = data.cuda(), labels.cuda(async=True)
+        data, labels = data.cuda(), labels.cuda(non_blocking=True)
         opt.zero_grad()
         out = model(data)
         loss = crit(out, labels)
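
The same fix applies in the training loop. One practical detail: cuda(non_blocking=True) only performs a truly asynchronous host-to-device copy when the source tensor sits in pinned (page-locked) host memory; otherwise it silently degrades to a synchronous copy. A minimal sketch of how a loader would opt in (the dataset variable is a placeholder, not this repository's code):

from torch.utils.data import DataLoader

# pin_memory=True page-locks each host batch so that the
# .cuda(non_blocking=True) calls in the loop above can overlap
# the transfer with GPU compute.
dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                        num_workers=2, pin_memory=True)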
2 changes: 1 addition & 1 deletion py/trtorch/_compile_spec.py
@@ -176,7 +176,7 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     if "max_batch_size" in compile_spec:
         assert type(compile_spec["max_batch_size"]) is int
         info.max_batch_size = compile_spec["max_batch_size"]
 
     if "truncate_long_and_double" in compile_spec:
         assert type(compile_spec["truncate_long_and_double"]) is bool
         info.truncate_long_and_double = compile_spec["truncate_long_and_double"]
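
For reference, these keys come from the user-facing compile spec dict that _parse_compile_spec validates. A hypothetical usage sketch under the dict-based API of this era follows; only "max_batch_size" and "truncate_long_and_double" appear in the hunk above, and every other key, name, and shape here is an assumption that may differ across TRTorch versions:

import torch
import trtorch

# model is a placeholder torch.nn.Module, assumed to exist.
scripted = torch.jit.script(model.eval().cuda())
spec = {
    "input_shapes": [[1, 3, 224, 224]],  # assumed key name
    "max_batch_size": 1,                  # validated as int (see hunk above)
    "truncate_long_and_double": True,     # validated as bool (see hunk above)
}
trt_mod = trtorch.compile(scripted, spec)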
