Skip to content

Commit

Permalink
[CodeStyle][B017] catch more specific exceptions in unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Apr 5, 2023
1 parent ea8aa43 commit 8b3cfab
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 66 deletions.
41 changes: 1 addition & 40 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ select = [
"B002",
"B003",
"B004",
# "B005",
# "B006",
# "B007",
# "B008",
"B009",
"B010",
"B011",
Expand All @@ -55,65 +51,30 @@ select = [
"B014",
"B015",
"B016",
# "B017",
"B017",
"B018",
"B019",
"B020",
"B021",
"B022",
# "B023",
# "B024",
"B025",
# "B026",
# "B027",
# "B028",
"B029",
# "B030",
"B032",
# "B904",

# Pylint
"PLC0414",
# "PLC1901",
"PLC3002",
"PLE0100",
"PLE0101",
# "PLE0116",
# "PLE0117",
# "PLE0118",
"PLE0604",
"PLE0605",
"PLE1142",
"PLE1205",
"PLE1206",
"PLE1307",
# "PLE1310",
# "PLE1507",
"PLE2502",
# "PLE2510",
# "PLE2512",
# "PLE2513",
# "PLE2514",
# "PLE2515",
# "PLR0133",
"PLR0206",
"PLR0402",
# "PLR0911",
# "PLR0912",
# "PLR0913",
# "PLR0915",
# "PLR1701",
# "PLR1711",
# "PLR1722",
# "PLR2004",
# "PLR5501",
# "PLW0120",
# "PLW0129",
# "PLW0602",
# "PLW0603",
# "PLW0711",
# "PLW1508",
# "PLW2901",
]
unfixable = [
"NPY001"
Expand Down
18 changes: 7 additions & 11 deletions python/paddle/distributed/fleet/base/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __str__(self):

return (
'{'
+ 'rank={};'.format(self.kind)
+ f'rank={self.kind};'
+ ','.join([node.name for node in self.nodes])
+ '}'
)
Expand Down Expand Up @@ -97,7 +97,7 @@ def compile(self, dot_path):
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
logging.warning("write block debug graph to {}".format(image_path))
logging.warning(f"write block debug graph to {image_path}")
return image_path

def show(self, dot_path):
Expand Down Expand Up @@ -125,13 +125,11 @@ def _rank_repr(self):
def __str__(self):
reprs = [
'digraph G {',
'title = {}'.format(crepr(self.title)),
f'title = {crepr(self.title)}',
]

for attr in self.attrs:
reprs.append(
"{key}={value};".format(key=attr, value=crepr(self.attrs[attr]))
)
reprs.append(f"{attr}={crepr(self.attrs[attr])};")

reprs.append(self._rank_repr())

Expand Down Expand Up @@ -161,8 +159,7 @@ def __str__(self):
label=self.label,
extra=','
+ ','.join(
"%s=%s" % (key, crepr(value))
for key, value in self.attrs.items()
f"{key}={crepr(value)}" for key, value in self.attrs.items()
)
if self.attrs
else "",
Expand Down Expand Up @@ -191,8 +188,7 @@ def __str__(self):
if not self.attrs
else "["
+ ','.join(
"{}={}".format(attr[0], crepr(attr[1]))
for attr in self.attrs.items()
f"{attr[0]}={crepr(attr[1])}" for attr in self.attrs.items()
)
+ "]",
)
Expand Down Expand Up @@ -292,5 +288,5 @@ def add_edge(self, source, target, **kwargs):
source,
target,
color="#00000" if not highlight else "orange",
**kwargs
**kwargs,
)
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def pass_config(self):
]

def test_resnet(self):
with self.assertRaises(Exception):
with self.assertRaises(Exception): # noqa: B017
self.check_main(resnet_model, batch_size=32)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def set_feeds(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
with self.assertRaises(Exception):
with self.assertRaises(Exception): # noqa: B017
self.check_output_with_option(use_gpu)


Expand All @@ -99,7 +99,7 @@ def set_feeds(self):
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
with self.assertRaises(Exception):
with self.assertRaises(Exception): # noqa: B017
self.check_output_with_option(use_gpu)


Expand Down
14 changes: 12 additions & 2 deletions test/dygraph_to_static/test_convert_call_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import paddle
from paddle.jit import to_static
from paddle.jit.dy2static.convert_call_func import translator_logger


def dyfunc_generator():
for i in range(100):
yield paddle.fluid.dygraph.to_variable([i] * 10)
yield paddle.to_tensor([i] * 10)


def main_func():
Expand All @@ -31,8 +32,17 @@ def main_func():

class TestConvertGenerator(unittest.TestCase):
def test_raise_error(self):
with self.assertRaises(Exception):
translator_logger.verbosity_level = 1
with self.assertLogs(
translator_logger.logger_name, level='WARNING'
) as cm:
to_static(main_func)()
self.assertRegex(
cm.output[0],
"Your function:`dyfunc_generator` doesn't support "
"to transform to static function because it is a "
"generator function",
)


if __name__ == '__main__':
Expand Down
33 changes: 23 additions & 10 deletions test/legacy_test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,10 @@ def test_exception(self):

trans_batch = transforms.Compose([transforms.Resize(-1)])

with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans)

with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans_batch)

with self.assertRaises(ValueError):
Expand Down Expand Up @@ -411,22 +411,35 @@ def test_exception(self):
with self.assertRaises(NotImplementedError):
transform = transforms.BrightnessTransform('0.1', keys='a')

with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, "scale should be a tuple or list"
):
transform = transforms.RandomErasing(scale=0.5)

with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, "ratio should be a tuple or list"
):
transform = transforms.RandomErasing(ratio=0.8)

with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError,
r"scale should be of kind \(min, max\) and in range \[0, 1\]",
):
transform = transforms.RandomErasing(scale=(10, 0.4))

with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, r"ratio should be of kind \(min, max\)"
):
transform = transforms.RandomErasing(ratio=(3.3, 0.3))

with self.assertRaises(Exception):
with self.assertRaisesRegex(
AssertionError, r"The probability should be in range \[0, 1\]"
):
transform = transforms.RandomErasing(prob=1.5)

with self.assertRaises(Exception):
with self.assertRaisesRegex(
ValueError, r"value must be 'random' when type is str"
):
transform = transforms.RandomErasing(value="0")

def test_info(self):
Expand Down Expand Up @@ -571,10 +584,10 @@ def test_exception(self):

trans_batch = transforms.Compose([transforms.Resize(-1)])

with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans)

with self.assertRaises(Exception):
with self.assertRaises((cv2.error, AssertionError, ValueError)):
self.do_transform(trans_batch)

with self.assertRaises(ValueError):
Expand Down

0 comments on commit 8b3cfab

Please sign in to comment.