From 1211d8b2ddc325538d46254eafd7f5010310b07c Mon Sep 17 00:00:00 2001
From: lizz
Date: Thu, 25 Feb 2021 16:25:32 +0800
Subject: [PATCH] Linting (#207)

* add pylintrc
Signed-off-by: lizz
* python3 style super
Signed-off-by: lizz
* add
Signed-off-by: lizz
* lint
Signed-off-by: lizz
* no (object)
Signed-off-by: lizz
* tiny
Signed-off-by: lizz
* ha
Signed-off-by: lizz
* typos
Signed-off-by: lizz
* typo
Signed-off-by: lizz
* typo
Signed-off-by: lizz
* lint
Signed-off-by: lizz
* lint
Signed-off-by: lizz
* more lint
Signed-off-by: lizz
* Fix out_channels unused bug in EDVRNet
Signed-off-by: lizz
* lint
Signed-off-by: lizz
---
 .pylintrc | 621 ++++++++++++++++++
 mmedit/__init__.py | 10 +-
 mmedit/apis/test.py | 55 +-
 mmedit/core/distributed_wrapper.py | 2 +-
 mmedit/core/evaluation/eval_hooks.py | 3 +-
 mmedit/core/evaluation/metric_utils.py | 4 +-
 mmedit/core/evaluation/metrics.py | 25 +-
 mmedit/core/mask.py | 2 +-
 mmedit/core/optimizer/builder.py | 5 +-
 mmedit/core/scheduler/lr_updater.py | 13 +-
 mmedit/datasets/base_dataset.py | 8 +-
 mmedit/datasets/base_matting_dataset.py | 2 +-
 mmedit/datasets/base_sr_dataset.py | 2 +-
 mmedit/datasets/dataset_wrappers.py | 2 +-
 mmedit/datasets/generation_paired_dataset.py | 2 +-
 .../datasets/generation_unpaired_dataset.py | 2 +-
 mmedit/datasets/img_inpainting_dataset.py | 2 +-
 mmedit/datasets/pipelines/augmentation.py | 29 +-
 mmedit/datasets/pipelines/compose.py | 2 +-
 mmedit/datasets/pipelines/crop.py | 16 +-
 mmedit/datasets/pipelines/formating.py | 22 +-
 mmedit/datasets/pipelines/loading.py | 8 +-
 mmedit/datasets/pipelines/matting_aug.py | 22 +-
 mmedit/datasets/pipelines/normalization.py | 4 +-
 mmedit/datasets/pipelines/utils.py | 7 +-
 mmedit/datasets/sr_annotation_dataset.py | 3 +-
 mmedit/datasets/sr_folder_dataset.py | 2 +-
 mmedit/datasets/sr_lmdb_dataset.py | 2 +-
 mmedit/datasets/sr_reds_dataset.py | 2 +-
 mmedit/datasets/sr_vid4_dataset.py | 2 +-
 mmedit/datasets/sr_vimeo90k_dataset.py | 2 +-
 .../decoders/deepfill_decoder.py | 4 +-
 .../encoder_decoders/decoders/gl_decoder.py | 2 +-
 .../decoders/indexnet_decoder.py | 12 +-
 .../decoders/pconv_decoder.py | 2 +-
 .../decoders/plain_decoder.py | 2 +-
 .../encoder_decoders/decoders/resnet_dec.py | 9 +-
 .../encoders/deepfill_encoder.py | 2 +-
 .../encoder_decoders/encoders/gl_encoder.py | 2 +-
 .../encoders/indexnet_encoder.py | 38 +-
 .../encoders/pconv_encoder.py | 4 +-
 .../encoder_decoders/encoders/resnet_enc.py | 35 +-
 .../encoder_decoders/encoders/vgg.py | 4 +-
 .../encoder_decoders/gl_encoder_decoder.py | 2 +-
 .../necks/contextual_attention_neck.py | 2 +-
 .../encoder_decoders/necks/gl_dilation.py | 2 +-
 .../encoder_decoders/pconv_encoder_decoder.py | 2 +-
 .../simple_encoder_decoder.py | 2 +-
 .../two_stage_encoder_decoder.py | 2 +-
 .../generation_backbones/resnet_generator.py | 2 +-
 .../generation_backbones/unet_generator.py | 4 +-
 mmedit/models/backbones/sr_backbones/duf.py | 2 +-
 mmedit/models/backbones/sr_backbones/edsr.py | 4 +-
 .../models/backbones/sr_backbones/edvr_net.py | 10 +-
 .../models/backbones/sr_backbones/rrdb_net.py | 6 +-
 .../backbones/sr_backbones/sr_resnet.py | 2 +-
 mmedit/models/backbones/sr_backbones/srcnn.py | 2 +-
 mmedit/models/backbones/sr_backbones/tof.py | 8 +-
 mmedit/models/base.py | 13 +-
 mmedit/models/builder.py | 4 +-
 mmedit/models/common/aspp.py | 7 +-
 mmedit/models/common/contextual_attention.py | 12 +-
 mmedit/models/common/conv.py | 2 +-
 mmedit/models/common/gated_conv_module.py | 2 +-
 mmedit/models/common/gca_module.py | 14 +-
 .../models/common/generation_model_utils.py | 12 +-
mmedit/models/common/linear_module.py | 6 +- mmedit/models/common/mask_conv_module.py | 8 +- mmedit/models/common/model_utils.py | 2 +- mmedit/models/common/partial_conv.py | 10 +- mmedit/models/common/separable_conv_module.py | 2 +- mmedit/models/common/sr_backbone_utils.py | 2 +- mmedit/models/common/upsample.py | 2 +- .../discriminators/deepfill_disc.py | 2 +- .../components/discriminators/gl_disc.py | 2 +- .../components/discriminators/modified_vgg.py | 2 +- .../discriminators/multi_layer_disc.py | 2 +- .../components/discriminators/patch_disc.py | 2 +- .../components/refiners/deepfill_refiner.py | 2 +- .../components/refiners/plain_refiner.py | 4 +- mmedit/models/inpaintors/deepfillv1.py | 4 +- mmedit/models/inpaintors/gl_inpaintor.py | 6 +- mmedit/models/inpaintors/one_stage.py | 8 +- mmedit/models/inpaintors/pconv_inpaintor.py | 2 +- mmedit/models/inpaintors/two_stage.py | 4 +- mmedit/models/losses/composition_loss.py | 6 +- mmedit/models/losses/gan_loss.py | 10 +- mmedit/models/losses/gradient_loss.py | 2 +- mmedit/models/losses/perceptual_loss.py | 4 +- mmedit/models/losses/pixelwise_loss.py | 12 +- mmedit/models/losses/utils.py | 6 +- mmedit/models/mattors/base_mattor.py | 13 +- mmedit/models/mattors/dim.py | 3 +- mmedit/models/mattors/gca.py | 3 +- mmedit/models/mattors/indexnet.py | 3 +- mmedit/models/restorers/basic_restorer.py | 8 +- mmedit/models/restorers/edvr.py | 4 +- mmedit/models/restorers/srgan.py | 8 +- mmedit/models/synthesizers/cycle_gan.py | 14 +- mmedit/models/synthesizers/pix2pix.py | 10 +- mmedit/version.py | 10 +- tests/test_augmentation.py | 2 +- tests/test_crop.py | 2 +- tests/test_dataset_builder.py | 4 +- tests/test_datasets.py | 46 +- tests/test_eval_hook.py | 2 +- tests/test_loading.py | 6 +- tests/test_normalization.py | 2 +- tests/test_optimizer.py | 2 +- tests/test_visual_hook.py | 2 +- 110 files changed, 994 insertions(+), 379 deletions(-) create mode 100644 .pylintrc diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000000..d7a39be85d --- /dev/null +++ b/.pylintrc @@ -0,0 +1,621 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Specify a score threshold to be exceeded before program exits with error. +fail-under=10.0 + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS,configs + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. 
Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=print-statement, + parameter-unpacking, + unpacking-in-except, + old-raise-syntax, + backtick, + long-suffix, + old-ne-operator, + old-octal-literal, + import-star-module-level, + non-ascii-bytes-literal, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + apply-builtin, + basestring-builtin, + buffer-builtin, + cmp-builtin, + coerce-builtin, + execfile-builtin, + file-builtin, + long-builtin, + raw_input-builtin, + reduce-builtin, + standarderror-builtin, + unicode-builtin, + xrange-builtin, + coerce-method, + delslice-method, + getslice-method, + setslice-method, + no-absolute-import, + old-division, + dict-iter-method, + dict-view-method, + next-method-called, + metaclass-assignment, + indexing-exception, + raising-string, + reload-builtin, + oct-method, + hex-method, + nonzero-method, + cmp-method, + input-builtin, + round-builtin, + intern-builtin, + unichr-builtin, + map-builtin-not-iterating, + zip-builtin-not-iterating, + range-builtin-not-iterating, + filter-builtin-not-iterating, + using-cmp-argument, + eq-without-hash, + div-method, + idiv-method, + rdiv-method, + exception-message-attribute, + invalid-str-codec, + sys-max-int, + bad-python3-import, + deprecated-string-function, + deprecated-str-translate-call, + deprecated-itertools-function, + deprecated-types-field, + next-method-defined, + dict-items-not-iterating, + dict-keys-not-iterating, + dict-values-not-iterating, + deprecated-operator-function, + deprecated-urllib-function, + xreadlines-attribute, + deprecated-sys-function, + exception-escape, + comprehension-escape, + no-member, + invalid-name, + too-many-branches, + wrong-import-order, + too-many-arguments, + missing-function-docstring, + missing-module-docstring, + too-many-locals, + too-few-public-methods, + abstract-method, + broad-except, + too-many-nested-blocks, + too-many-instance-attributes, + missing-class-docstring, + duplicate-code, + not-callable, + protected-access, + dangerous-default-value, + no-name-in-module, + logging-fstring-interpolation, + super-init-not-called, + redefined-builtin, + attribute-defined-outside-init, + arguments-differ, + cyclic-import, + bad-super-call, + too-many-statements + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). 
See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. 
+missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[SIMILARITIES] + +# Ignore comments when computing similarities. 
+ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +#notes-rgx= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _, + x, + y, + w, + h, + a, + b + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. 
+variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +#variable-rgx= + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". 
+overgeneral-exceptions=BaseException, + Exception diff --git a/mmedit/__init__.py b/mmedit/__init__.py index 7d15f8a324..2fd06c244b 100644 --- a/mmedit/__init__.py +++ b/mmedit/__init__.py @@ -7,15 +7,15 @@ def digit_version(version_str): - digit_version = [] + digit_ver = [] for x in version_str.split('.'): if x.isdigit(): - digit_version.append(int(x)) + digit_ver.append(int(x)) elif x.find('rc') != -1: patch_version = x.split('rc') - digit_version.append(int(patch_version[0]) - 1) - digit_version.append(int(patch_version[1])) - return digit_version + digit_ver.append(int(patch_version[0]) - 1) + digit_ver.append(int(patch_version[1])) + return digit_ver mmcv_min_version = digit_version(MMCV_MIN) diff --git a/mmedit/apis/test.py b/mmedit/apis/test.py index 0fc8abfb40..d6be06ac82 100644 --- a/mmedit/apis/test.py +++ b/mmedit/apis/test.py @@ -162,21 +162,21 @@ def collect_results_cpu(result_part, size, tmpdir=None): # collect all parts if rank != 0: return None - else: - # load results of all parts from tmp dir - part_list = [] - for i in range(world_size): - part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) - part_list.append(mmcv.load(part_file)) - # sort the results - ordered_results = [] - for res in zip(*part_list): - ordered_results.extend(list(res)) - # the dataloader may pad some samples - ordered_results = ordered_results[:size] - # remove tmp dir - shutil.rmtree(tmpdir) - return ordered_results + + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, 'part_{}.pkl'.format(i)) + part_list.append(mmcv.load(part_file)) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results def collect_results_gpu(result_part, size): @@ -211,15 +211,16 @@ def collect_results_gpu(result_part, size): # gather all result part dist.all_gather(part_recv_list, part_send) - if rank == 0: - part_list = [] - for recv, shape in zip(part_recv_list, shape_list): - part_list.append( - pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) - # sort the results - ordered_results = [] - for res in zip(*part_list): - ordered_results.extend(list(res)) - # the dataloader may pad some samples - ordered_results = ordered_results[:size] - return ordered_results + if rank != 0: + return None + + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_list.append(pickle.loads(recv[:shape[0]].cpu().numpy().tobytes())) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/mmedit/core/distributed_wrapper.py b/mmedit/core/distributed_wrapper.py index 1107a339ea..cba5a6368b 100644 --- a/mmedit/core/distributed_wrapper.py +++ b/mmedit/core/distributed_wrapper.py @@ -53,7 +53,7 @@ def __init__(self, broadcast_buffers=False, find_unused_parameters=False, **kwargs): - super(DistributedDataParallelWrapper, self).__init__() + super().__init__() assert len(device_ids) == 1, ( 'Currently, DistributedDataParallelWrapper only supports one' 'single CUDA device for each process.' 
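
The hunks above, together with the commit bullets "python3 style super" and "no (object)", repeat three refactors that recur throughout the patch: Python 3 style super() calls, class definitions without the explicit object base, and early returns in place of an else branch after a return. A minimal before/after sketch with made-up names (LegacyLoader, Loader, fetch, not taken from mmedit):

# Before: Python 2 style super(), explicit (object) base, else after return
class LegacyLoader(object):

    def __init__(self, root):
        super(LegacyLoader, self).__init__()
        self.root = root

    def fetch(self, rank):
        if rank != 0:
            return None
        else:
            return self.root


# After: the shape this patch moves the codebase towards
class Loader:

    def __init__(self, root):
        super().__init__()
        self.root = root

    def fetch(self, rank):
        if rank != 0:
            return None
        return self.root
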
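The evaluation expression in the [REPORTS] section of the .pylintrc above is what produces the 0-10 score pylint prints, and with fail-under=10.0 the run fails as soon as any message lowers that score. A small sketch of the same formula in plain Python; the function name and the sample counts are illustrative, not part of the patch:

def pylint_score(error, warning, refactor, convention, statement):
    # mirrors: 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
    penalty = float(5 * error + warning + refactor + convention) / statement
    return 10.0 - penalty * 10


# e.g. 2 warnings and 1 convention message across 200 analysed statements
print(pylint_score(error=0, warning=2, refactor=0, convention=1, statement=200))  # 9.85
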
diff --git a/mmedit/core/evaluation/eval_hooks.py b/mmedit/core/evaluation/eval_hooks.py index 95dcb93f5b..656776b3da 100644 --- a/mmedit/core/evaluation/eval_hooks.py +++ b/mmedit/core/evaluation/eval_hooks.py @@ -80,8 +80,7 @@ def __init__(self, interval=1, gpu_collect=False, **eval_kwargs): - super(DistEvalIterHook, self).__init__(dataloader, interval, - **eval_kwargs) + super().__init__(dataloader, interval, **eval_kwargs) self.gpu_collect = gpu_collect def after_train_iter(self, runner): diff --git a/mmedit/core/evaluation/metric_utils.py b/mmedit/core/evaluation/metric_utils.py index 70909a6e6b..2735a2e6cf 100644 --- a/mmedit/core/evaluation/metric_utils.py +++ b/mmedit/core/evaluation/metric_utils.py @@ -61,7 +61,9 @@ def gauss_filter(sigma, epsilon=1e-2): def gauss_gradient(img, sigma): """Gaussian gradient. - From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/index.html # noqa + From https://www.mathworks.com/matlabcentral/mlc-downloads/downloads/ + submissions/8060/versions/2/previews/gaussgradient/gaussgradient.m/ + index.html Args: img (ndarray): Input image. diff --git a/mmedit/core/evaluation/metrics.py b/mmedit/core/evaluation/metrics.py index 897639886f..16256ba52b 100644 --- a/mmedit/core/evaluation/metrics.py +++ b/mmedit/core/evaluation/metrics.py @@ -19,8 +19,8 @@ def sad(alpha, trimap, pred_alpha): assert (pred_alpha[trimap == 255] == 255).all() alpha = alpha.astype(np.float64) / 255 pred_alpha = pred_alpha.astype(np.float64) / 255 - sad = np.abs(pred_alpha - alpha).sum() / 1000 - return sad + sad_result = np.abs(pred_alpha - alpha).sum() / 1000 + return sad_result def mse(alpha, trimap, pred_alpha): @@ -35,10 +35,10 @@ def mse(alpha, trimap, pred_alpha): pred_alpha = pred_alpha.astype(np.float64) / 255 weight_sum = (trimap == 128).sum() if weight_sum != 0: - mse = ((pred_alpha - alpha)**2).sum() / weight_sum + mse_result = ((pred_alpha - alpha)**2).sum() / weight_sum else: - mse = 0 - return mse + mse_result = 0 + return mse_result def gradient_error(alpha, trimap, pred_alpha, sigma=1.4): @@ -100,7 +100,6 @@ def connectivity(alpha, trimap, pred_alpha, step=0.1): alpha = alpha.astype(np.float32) / 255 pred_alpha = pred_alpha.astype(np.float32) / 255 - height, width = alpha.shape thresh_steps = np.arange(0, 1 + step, step) round_down_map = -np.ones_like(alpha) for i in range(1, len(thresh_steps)): @@ -196,10 +195,10 @@ def psnr(img1, img2, crop_border=0, input_order='HWC'): img1 = img1[crop_border:-crop_border, crop_border:-crop_border, None] img2 = img2[crop_border:-crop_border, crop_border:-crop_border, None] - mse = np.mean((img1 - img2)**2) - if mse == 0: + mse_value = np.mean((img1 - img2)**2) + if mse_value == 0: return float('inf') - return 20. * np.log10(255. / np.sqrt(mse)) + return 20. * np.log10(255. / np.sqrt(mse_value)) def _ssim(img1, img2): @@ -280,7 +279,7 @@ def ssim(img1, img2, crop_border=0, input_order='HWC'): return np.array(ssims).mean() -class L1Evaluation(object): +class L1Evaluation: """L1 evaluation metric. Args: @@ -347,8 +346,8 @@ def compute_feature(block): # the products of pairs of adjacent coefficients computed along # horizontal, vertical and diagonal orientations. 
shifts = [[0, 1], [1, 0], [1, 1], [1, -1]] - for i in range(len(shifts)): - shifted_block = np.roll(block, shifts[i], axis=(0, 1)) + for shift in shifts: + shifted_block = np.roll(block, shift, axis=(0, 1)) alpha, beta_l, beta_r = estimate_aggd_param(block * shifted_block) mean = (beta_r - beta_l) * (gamma(2 / alpha) / gamma(1 / alpha)) feat.extend([alpha, mean, beta_l, beta_r]) @@ -408,7 +407,7 @@ def niqe_core(img, feat = [] for idx_w in range(num_block_w): for idx_h in range(num_block_h): - # process ecah block + # process each block block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale, idx_w * block_size_w // diff --git a/mmedit/core/mask.py b/mmedit/core/mask.py index 9486a07fc4..9c1435f319 100644 --- a/mmedit/core/mask.py +++ b/mmedit/core/mask.py @@ -282,7 +282,7 @@ def random_irregular_mask(img_shape, angle = 2 * math.pi - angle length = length_list[direct_n] brush_w = brush_width_list[direct_n] - # compute end point accoriding to the random angle + # compute end point according to the random angle end_x = (start_x + length * np.sin(angle)).astype(np.int32) end_y = (start_y + length * np.cos(angle)).astype(np.int32) diff --git a/mmedit/core/optimizer/builder.py b/mmedit/core/optimizer/builder.py index 1c99e2bbf4..28d8f65574 100644 --- a/mmedit/core/optimizer/builder.py +++ b/mmedit/core/optimizer/builder.py @@ -46,11 +46,12 @@ def build_optimizers(model, cfgs): for key, cfg in cfgs.items(): if not isinstance(cfg, dict): is_dict_of_dict = False + if is_dict_of_dict: for key, cfg in cfgs.items(): cfg_ = cfg.copy() module = getattr(model, key) optimizers[key] = build_optimizer(module, cfg_) return optimizers - else: - return build_optimizer(model, cfgs) + + return build_optimizer(model, cfgs) diff --git a/mmedit/core/scheduler/lr_updater.py b/mmedit/core/scheduler/lr_updater.py index 43c51fdc38..fb9ad23b30 100644 --- a/mmedit/core/scheduler/lr_updater.py +++ b/mmedit/core/scheduler/lr_updater.py @@ -19,7 +19,7 @@ class LinearLrUpdaterHook(LrUpdaterHook): """ def __init__(self, target_lr=0, start=0, interval=1, **kwargs): - super(LinearLrUpdaterHook, self).__init__(**kwargs) + super().__init__(**kwargs) self.target_lr = target_lr self.start = start self.interval = interval @@ -41,10 +41,11 @@ def get_lr(self, runner, base_lr): progress = runner.iter max_progress = runner.max_iters assert max_progress >= self.start + if max_progress == self.start: return base_lr - else: - # Before 'start', fix lr; After 'start', linearly update lr. - factor = (max(0, progress - self.start) // self.interval) / ( - (max_progress - self.start) // self.interval) - return base_lr + (self.target_lr - base_lr) * factor + + # Before 'start', fix lr; After 'start', linearly update lr. + factor = (max(0, progress - self.start) // self.interval) / ( + (max_progress - self.start) // self.interval) + return base_lr + (self.target_lr - base_lr) * factor diff --git a/mmedit/datasets/base_dataset.py b/mmedit/datasets/base_dataset.py index 130f378ea2..a1cb4b1322 100644 --- a/mmedit/datasets/base_dataset.py +++ b/mmedit/datasets/base_dataset.py @@ -22,7 +22,7 @@ class BaseDataset(Dataset, metaclass=ABCMeta): """ def __init__(self, pipeline, test_mode=False): - super(BaseDataset, self).__init__() + super().__init__() self.test_mode = test_mode self.pipeline = Compose(pipeline) @@ -71,7 +71,7 @@ def __getitem__(self, idx): Args: idx (int): Index for getting each item. 
""" - if not self.test_mode: - return self.prepare_train_data(idx) - else: + if self.test_mode: return self.prepare_test_data(idx) + + return self.prepare_train_data(idx) diff --git a/mmedit/datasets/base_matting_dataset.py b/mmedit/datasets/base_matting_dataset.py index af156c12f1..35d838aee3 100644 --- a/mmedit/datasets/base_matting_dataset.py +++ b/mmedit/datasets/base_matting_dataset.py @@ -10,7 +10,7 @@ class BaseMattingDataset(BaseDataset): """ def __init__(self, ann_file, pipeline, data_prefix=None, test_mode=False): - super(BaseMattingDataset, self).__init__(pipeline, test_mode) + super().__init__(pipeline, test_mode) self.ann_file = str(ann_file) self.data_prefix = str(data_prefix) self.data_infos = self.load_annotations() diff --git a/mmedit/datasets/base_sr_dataset.py b/mmedit/datasets/base_sr_dataset.py index b372b3772b..07ff48c8dc 100644 --- a/mmedit/datasets/base_sr_dataset.py +++ b/mmedit/datasets/base_sr_dataset.py @@ -16,7 +16,7 @@ class BaseSRDataset(BaseDataset): """ def __init__(self, pipeline, scale, test_mode=False): - super(BaseSRDataset, self).__init__(pipeline, test_mode) + super().__init__(pipeline, test_mode) self.scale = scale @staticmethod diff --git a/mmedit/datasets/dataset_wrappers.py b/mmedit/datasets/dataset_wrappers.py index 5b9a3783c3..8c189c09f4 100644 --- a/mmedit/datasets/dataset_wrappers.py +++ b/mmedit/datasets/dataset_wrappers.py @@ -2,7 +2,7 @@ @DATASETS.register_module() -class RepeatDataset(object): +class RepeatDataset: """A wrapper of repeated dataset. The length of repeated dataset will be `times` larger than the original diff --git a/mmedit/datasets/generation_paired_dataset.py b/mmedit/datasets/generation_paired_dataset.py index 2df184d567..2f3ce262ac 100644 --- a/mmedit/datasets/generation_paired_dataset.py +++ b/mmedit/datasets/generation_paired_dataset.py @@ -21,7 +21,7 @@ class GenerationPairedDataset(BaseGenerationDataset): """ def __init__(self, dataroot, pipeline, test_mode=False): - super(GenerationPairedDataset, self).__init__(pipeline, test_mode) + super().__init__(pipeline, test_mode) phase = 'test' if test_mode else 'train' self.dataroot = osp.join(str(dataroot), phase) self.data_infos = self.load_annotations() diff --git a/mmedit/datasets/generation_unpaired_dataset.py b/mmedit/datasets/generation_unpaired_dataset.py index d7a6aedf09..acf5ce9307 100644 --- a/mmedit/datasets/generation_unpaired_dataset.py +++ b/mmedit/datasets/generation_unpaired_dataset.py @@ -25,7 +25,7 @@ class GenerationUnpairedDataset(BaseGenerationDataset): """ def __init__(self, dataroot, pipeline, test_mode=False): - super(GenerationUnpairedDataset, self).__init__(pipeline, test_mode) + super().__init__(pipeline, test_mode) phase = 'test' if test_mode else 'train' self.dataroot_a = osp.join(str(dataroot), phase + 'A') self.dataroot_b = osp.join(str(dataroot), phase + 'B') diff --git a/mmedit/datasets/img_inpainting_dataset.py b/mmedit/datasets/img_inpainting_dataset.py index cb878096c2..216e6b5c3e 100644 --- a/mmedit/datasets/img_inpainting_dataset.py +++ b/mmedit/datasets/img_inpainting_dataset.py @@ -10,7 +10,7 @@ class ImgInpaintingDataset(BaseDataset): """ def __init__(self, ann_file, pipeline, data_prefix=None, test_mode=False): - super(ImgInpaintingDataset, self).__init__(pipeline, test_mode) + super().__init__(pipeline, test_mode) self.ann_file = str(ann_file) self.data_prefix = str(data_prefix) self.data_infos = self.load_annotations() diff --git a/mmedit/datasets/pipelines/augmentation.py b/mmedit/datasets/pipelines/augmentation.py index 
040d808f8e..a36d427b8c 100644 --- a/mmedit/datasets/pipelines/augmentation.py +++ b/mmedit/datasets/pipelines/augmentation.py @@ -11,7 +11,7 @@ @PIPELINES.register_module() -class Resize(object): +class Resize: """Resize data to a specific size for training or resize the images to fit the network input regulation for testing. @@ -147,7 +147,7 @@ def __repr__(self): @PIPELINES.register_module() -class Flip(object): +class Flip: """Flip the input data with a probability. Reverse the order of elements in the given data with a specific direction. @@ -205,7 +205,7 @@ def __repr__(self): @PIPELINES.register_module() -class Pad(object): +class Pad: """Pad the images to align with network downsample factor for testing. See `Reshape` for more explanation. `numpy.pad` is used for the pad @@ -263,13 +263,15 @@ def __repr__(self): @PIPELINES.register_module() -class RandomAffine(object): +class RandomAffine: """Apply random affine to input images. This class is adopted from - https://github.com/pytorch/vision/blob/v0.5.0/torchvision/transforms/transforms.py#L1015 # noqa + https://github.com/pytorch/vision/blob/v0.5.0/torchvision/transforms/ + transforms.py#L1015 It should be noted that in - https://github.com/Yaoyi-Li/GCA-Matting/blob/master/dataloader/data_generator.py#L70 # noqa + https://github.com/Yaoyi-Li/GCA-Matting/blob/master/dataloader/ + data_generator.py#L70 random flip is added. See explanation of `flip_ratio` below. Required keys are the keys in attribute "keys", modified keys are keys in attribute "keys". @@ -475,7 +477,7 @@ def __repr__(self): @PIPELINES.register_module() -class RandomJitter(object): +class RandomJitter: """Randomly jitter the foreground in hsv space. The jitter range of hue is adjustable while the jitter ranges of saturation @@ -545,7 +547,7 @@ def __repr__(self): return self.__class__.__name__ + f'hue_range={self.hue_range}' -class BinarizeImage(object): +class BinarizeImage: """Binarize image. Args: @@ -590,7 +592,7 @@ def __repr__(self): @PIPELINES.register_module() -class RandomMaskDilation(object): +class RandomMaskDilation: """Randomly dilate binary masks. Args: @@ -613,7 +615,6 @@ def _random_dilate(self, img): kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) dilate_kernel_size = kernel_size img_ = cv2.dilate(img, kernel, iterations=1) - h, w = img_.shape[:2] img_ = (img_ > self.binary_thr).astype(np.float32) @@ -646,7 +647,7 @@ def __repr__(self): @PIPELINES.register_module() -class RandomTransposeHW(object): +class RandomTransposeHW: """Randomly transpose images in H and W dimensions with a probability. (TransposeHW = horizontal flip + anti-clockwise rotatation by 90 degrees) @@ -697,7 +698,7 @@ def __repr__(self): @PIPELINES.register_module() -class GenerateFrameIndiceswithPadding(object): +class GenerateFrameIndiceswithPadding: """Generate frame index with padding for REDS dataset and Vid4 dataset during testing. @@ -787,7 +788,7 @@ def __repr__(self): @PIPELINES.register_module() -class GenerateFrameIndices(object): +class GenerateFrameIndices: """Generate frame index for REDS datasets. It also performs temporal augmention with random interval. @@ -855,7 +856,7 @@ def __repr__(self): @PIPELINES.register_module() -class TemporalReverse(object): +class TemporalReverse: """Reverse frame lists for temporal augmentation. 
Required keys are the keys in attributes "lq" and "gt", diff --git a/mmedit/datasets/pipelines/compose.py b/mmedit/datasets/pipelines/compose.py index e7cb68e48f..9fcf2d1d73 100644 --- a/mmedit/datasets/pipelines/compose.py +++ b/mmedit/datasets/pipelines/compose.py @@ -6,7 +6,7 @@ @PIPELINES.register_module() -class Compose(object): +class Compose: """Compose a data pipeline with a sequence of transforms. Args: diff --git a/mmedit/datasets/pipelines/crop.py b/mmedit/datasets/pipelines/crop.py index 3cfb2be0ca..c7f0520700 100644 --- a/mmedit/datasets/pipelines/crop.py +++ b/mmedit/datasets/pipelines/crop.py @@ -9,7 +9,7 @@ @PIPELINES.register_module() -class Crop(object): +class Crop: """Crop data to specific size for training. Args: @@ -87,7 +87,7 @@ def __repr__(self): @PIPELINES.register_module() -class FixedCrop(object): +class FixedCrop: """Crop paired data (at a specific position) to specific size for training. Args: @@ -165,7 +165,7 @@ def __repr__(self): @PIPELINES.register_module() -class PairedRandomCrop(object): +class PairedRandomCrop: """Paried random crop. It crops a pair of lq and gt images with corresponding locations. @@ -241,7 +241,7 @@ def __repr__(self): @PIPELINES.register_module() -class CropAroundCenter(object): +class CropAroundCenter: """Randomly crop the images around unknown area in the center 1/4 images. This cropping strategy is adopted in GCA matting. The `unknown area` is the @@ -329,7 +329,7 @@ def __repr__(self): @PIPELINES.register_module() -class CropAroundUnknown(object): +class CropAroundUnknown: """Crop around unknown area with a randomly selected scale. Randomly select the w and h from a list of (w, h). @@ -372,7 +372,7 @@ def __init__(self, if unknown_source not in ['alpha', 'trimap']: raise ValueError('unknown_source must be "alpha" or "trimap", ' f'but got {unknown_source}') - elif unknown_source not in keys: + if unknown_source not in keys: # it could only be trimap, since alpha is checked before raise ValueError( 'if unknown_source is "trimap", it must also be set in keys') @@ -436,7 +436,7 @@ def __repr__(self): @PIPELINES.register_module() -class CropAroundFg(object): +class CropAroundFg: """Crop around the whole foreground in the segmentation mask. Required keys are "seg" and the keys in argument `keys`. @@ -507,7 +507,7 @@ def __call__(self, results): @PIPELINES.register_module() -class ModCrop(object): +class ModCrop: """Mod crop gt images, used during testing. Required keys are "scale" and "gt", diff --git a/mmedit/datasets/pipelines/formating.py b/mmedit/datasets/pipelines/formating.py index 012bdc817e..4c1ce105d1 100644 --- a/mmedit/datasets/pipelines/formating.py +++ b/mmedit/datasets/pipelines/formating.py @@ -17,20 +17,20 @@ def to_tensor(data): """ if isinstance(data, torch.Tensor): return data - elif isinstance(data, np.ndarray): + if isinstance(data, np.ndarray): return torch.from_numpy(data) - elif isinstance(data, Sequence) and not mmcv.is_str(data): + if isinstance(data, Sequence) and not mmcv.is_str(data): return torch.tensor(data) - elif isinstance(data, int): + if isinstance(data, int): return torch.LongTensor([data]) - elif isinstance(data, float): + if isinstance(data, float): return torch.FloatTensor([data]) - else: - raise TypeError(f'type {type(data)} cannot be converted to tensor.') + + raise TypeError(f'type {type(data)} cannot be converted to tensor.') @PIPELINES.register_module() -class ToTensor(object): +class ToTensor: """Convert some values in results dict to `torch.Tensor` type in data loader pipeline. 
@@ -60,7 +60,7 @@ def __repr__(self): @PIPELINES.register_module() -class ImageToTensor(object): +class ImageToTensor: """Convert image type to `torch.Tensor` type. Args: @@ -138,7 +138,7 @@ def __call__(self, results): @PIPELINES.register_module() -class GetMaskedImage(object): +class GetMaskedImage: """Get masked image. Args: @@ -176,7 +176,7 @@ def __repr__(self): @PIPELINES.register_module() -class FormatTrimap(object): +class FormatTrimap: """Convert trimap (tensor) to one-hot representation. It transforms the trimap label from (0, 128, 255) to (0, 1, 2). If @@ -219,7 +219,7 @@ def __repr__(self): @PIPELINES.register_module() -class Collect(object): +class Collect: """Collect data from the loader relevant to the specific task. This is usually the last stage of the data loader pipeline. Typically keys diff --git a/mmedit/datasets/pipelines/loading.py b/mmedit/datasets/pipelines/loading.py index 504c471ac5..b0a4d0152c 100644 --- a/mmedit/datasets/pipelines/loading.py +++ b/mmedit/datasets/pipelines/loading.py @@ -10,7 +10,7 @@ @PIPELINES.register_module() -class LoadImageFromFile(object): +class LoadImageFromFile: """Load image from file. Args: @@ -132,7 +132,7 @@ def __call__(self, results): @PIPELINES.register_module() -class RandomLoadResizeBg(object): +class RandomLoadResizeBg: """Randomly load a background image and resize it. Required key is "fg", added key is "bg". @@ -178,7 +178,7 @@ def __repr__(self): @PIPELINES.register_module() -class LoadMask(object): +class LoadMask: """Load Mask for multiple types. For different types of mask, users need to provide the corresponding @@ -340,7 +340,7 @@ def __repr__(self): @PIPELINES.register_module() -class GetSpatialDiscountMask(object): +class GetSpatialDiscountMask: """Get spatial discounting mask constant. Spatial discounting mask is first introduced in: diff --git a/mmedit/datasets/pipelines/matting_aug.py b/mmedit/datasets/pipelines/matting_aug.py index d68d99f3f1..eb8f29b79f 100644 --- a/mmedit/datasets/pipelines/matting_aug.py +++ b/mmedit/datasets/pipelines/matting_aug.py @@ -18,7 +18,7 @@ def add_gaussian_noise(img, mu, sigma): @PIPELINES.register_module() -class MergeFgAndBg(object): +class MergeFgAndBg: """Composite foreground image and background image with alpha. Required keys are "alpha", "fg" and "bg", added key is "merged". @@ -43,7 +43,7 @@ def __call__(self, results): @PIPELINES.register_module() -class GenerateTrimap(object): +class GenerateTrimap: """Using random erode/dilate to generate trimap from alpha matte. Required key is "alpha", added key is "trimap". @@ -138,7 +138,7 @@ def __repr__(self): @PIPELINES.register_module() -class GenerateTrimapWithDistTransform(object): +class GenerateTrimapWithDistTransform: """Generate trimap with distance transform function. Args: @@ -190,7 +190,7 @@ def __repr__(self): @PIPELINES.register_module() -class CompositeFg(object): +class CompositeFg: """Composite foreground with a random foreground. This class composites the current training sample with additional data @@ -292,7 +292,7 @@ def __repr__(self): @PIPELINES.register_module() -class GenerateSeg(object): +class GenerateSeg: """Generate segmentation mask from alpha matte. 
Args: @@ -373,7 +373,7 @@ def __call__(self, results): # generate some holes in segmentation mask num_holes = np.random.randint(*self.num_holes_range) - for i in range(num_holes): + for _ in range(num_holes): hole_size = random.choice(self.hole_sizes) unknown = trimap == 128 start_point = random_choose_unknown(unknown, hole_size) @@ -399,7 +399,7 @@ def __repr__(self): @PIPELINES.register_module() -class PerturbBg(object): +class PerturbBg: """Randomly add gaussian noise or gamma change to background image. Required key is "bg", added key is "noisy_bg". @@ -426,7 +426,7 @@ def __call__(self, results): dict: A dict containing the processed data and information. """ if np.random.rand() >= self.gamma_ratio: - # generate gaussian noise with random guassian N([-7, 7), [2, 6)) + # generate gaussian noise with random gaussian N([-7, 7), [2, 6)) mu = np.random.randint(-7, 7) sigma = np.random.randint(2, 6) results['noisy_bg'] = add_gaussian_noise(results['bg'], mu, sigma) @@ -441,13 +441,13 @@ def __repr__(self): @PIPELINES.register_module() -class GenerateSoftSeg(object): +class GenerateSoftSeg: """Generate soft segmentation mask from input segmentation mask. Required key is "seg", added key is "soft_seg". Args: - fg_thr (float, optional): Threhold of the foreground in the normalized + fg_thr (float, optional): Threshold of the foreground in the normalized input segmentation mask. Defaults to 0.2. border_width (int, optional): Width of border to be padded to the bottom of the mask. Defaults to 25. @@ -514,7 +514,7 @@ def __call__(self, results): dict: A dict containing the processed data and information. """ seg = results['seg'].astype(np.float32) / 255 - height, width = seg.shape[:2] + height, _ = seg.shape[:2] seg[seg > self.fg_thr] = 1 # to align with the original repo, pad the bottom of the mask diff --git a/mmedit/datasets/pipelines/normalization.py b/mmedit/datasets/pipelines/normalization.py index 292d02d8ce..fcb462b1e7 100644 --- a/mmedit/datasets/pipelines/normalization.py +++ b/mmedit/datasets/pipelines/normalization.py @@ -5,7 +5,7 @@ @PIPELINES.register_module() -class Normalize(object): +class Normalize: """Normalize images with the given mean and std value. Required keys are the keys in attribute "keys", added or modified keys are @@ -58,7 +58,7 @@ def __repr__(self): @PIPELINES.register_module() -class RescaleToZeroOne(object): +class RescaleToZeroOne: """Transform the images into a range between 0 and 1. Required keys are the keys in attribute "keys", added or modified keys are diff --git a/mmedit/datasets/pipelines/utils.py b/mmedit/datasets/pipelines/utils.py index 1ea2561f8b..2e1b0be7bb 100644 --- a/mmedit/datasets/pipelines/utils.py +++ b/mmedit/datasets/pipelines/utils.py @@ -34,7 +34,8 @@ def dtype_limits(image, clip_negative=False): """Return intensity limits, i.e. (min, max) tuple, of the image's dtype. This function is adopted from skimage: - https://github.com/scikit-image/scikit-image/blob/7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/util/dtype.py#L35 # noqa + https://github.com/scikit-image/scikit-image/blob/ + 7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/util/dtype.py#L35 Args: image (ndarray): Input image. @@ -55,7 +56,9 @@ def adjust_gamma(image, gamma=1, gain=1): """Performs Gamma Correction on the input image. 
This function is adopted from skimage: - https://github.com/scikit-image/scikit-image/blob/7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/exposure/exposure.py#L439-L494 # noqa + https://github.com/scikit-image/scikit-image/blob/ + 7e4840bd9439d1dfb6beaf549998452c99f97fdd/skimage/exposure/ + exposure.py#L439-L494 Also known as Power Law Transform. This function transforms the input image pixelwise according to the diff --git a/mmedit/datasets/sr_annotation_dataset.py b/mmedit/datasets/sr_annotation_dataset.py index c80abf07c0..3e614f3d0e 100644 --- a/mmedit/datasets/sr_annotation_dataset.py +++ b/mmedit/datasets/sr_annotation_dataset.py @@ -42,10 +42,9 @@ def __init__(self, ann_file, pipeline, scale, - data_prefix=None, test_mode=False, filename_tmpl='{}'): - super(SRAnnotationDataset, self).__init__(pipeline, scale, test_mode) + super().__init__(pipeline, scale, test_mode) self.lq_folder = str(lq_folder) self.gt_folder = str(gt_folder) self.ann_file = str(ann_file) diff --git a/mmedit/datasets/sr_folder_dataset.py b/mmedit/datasets/sr_folder_dataset.py index d6e4ed4971..5ff4c989fa 100644 --- a/mmedit/datasets/sr_folder_dataset.py +++ b/mmedit/datasets/sr_folder_dataset.py @@ -55,7 +55,7 @@ def __init__(self, scale, test_mode=False, filename_tmpl='{}'): - super(SRFolderDataset, self).__init__(pipeline, scale, test_mode) + super().__init__(pipeline, scale, test_mode) self.lq_folder = str(lq_folder) self.gt_folder = str(gt_folder) self.filename_tmpl = filename_tmpl diff --git a/mmedit/datasets/sr_lmdb_dataset.py b/mmedit/datasets/sr_lmdb_dataset.py index 328d215621..23dd5b9773 100644 --- a/mmedit/datasets/sr_lmdb_dataset.py +++ b/mmedit/datasets/sr_lmdb_dataset.py @@ -60,7 +60,7 @@ class SRLmdbDataset(BaseSRDataset): """ def __init__(self, lq_folder, gt_folder, pipeline, scale, test_mode=False): - super(SRLmdbDataset, self).__init__(pipeline, scale, test_mode) + super().__init__(pipeline, scale, test_mode) self.lq_folder = str(lq_folder) self.gt_folder = str(gt_folder) self.scale = scale diff --git a/mmedit/datasets/sr_reds_dataset.py b/mmedit/datasets/sr_reds_dataset.py index eaa9697d73..add7f821b6 100644 --- a/mmedit/datasets/sr_reds_dataset.py +++ b/mmedit/datasets/sr_reds_dataset.py @@ -42,7 +42,7 @@ def __init__(self, scale, val_partition='official', test_mode=False): - super(SRREDSDataset, self).__init__(pipeline, scale, test_mode) + super().__init__(pipeline, scale, test_mode) assert num_input_frames % 2 == 1, ( f'num_input_frames should be odd numbers, ' f'but received {num_input_frames }.') diff --git a/mmedit/datasets/sr_vid4_dataset.py b/mmedit/datasets/sr_vid4_dataset.py index 0b4758fa47..c96680c827 100644 --- a/mmedit/datasets/sr_vid4_dataset.py +++ b/mmedit/datasets/sr_vid4_dataset.py @@ -46,7 +46,7 @@ def __init__(self, scale, filename_tmpl='{:08d}', test_mode=False): - super(SRVid4Dataset, self).__init__(pipeline, scale, test_mode) + super().__init__(pipeline, scale, test_mode) assert num_input_frames % 2 == 1, ( f'num_input_frames should be odd numbers, ' f'but received {num_input_frames}.') diff --git a/mmedit/datasets/sr_vimeo90k_dataset.py b/mmedit/datasets/sr_vimeo90k_dataset.py index 14a3487296..c856936f5d 100644 --- a/mmedit/datasets/sr_vimeo90k_dataset.py +++ b/mmedit/datasets/sr_vimeo90k_dataset.py @@ -41,7 +41,7 @@ def __init__(self, pipeline, scale, test_mode=False): - super(SRVimeo90KDataset, self).__init__(pipeline, scale, test_mode) + super().__init__(pipeline, scale, test_mode) assert num_input_frames % 2 == 1, ( f'num_input_frames should be odd numbers, ' 
f'but received {num_input_frames}.') diff --git a/mmedit/models/backbones/encoder_decoders/decoders/deepfill_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/deepfill_decoder.py index ece212a6be..f8d38005d8 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/deepfill_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/deepfill_decoder.py @@ -40,7 +40,7 @@ def __init__(self, out_act_cfg=dict(type='clip', min=-1., max=1.), channel_factor=1., **kwargs): - super(DeepFillDecoder, self).__init__() + super().__init__() self.with_out_activation = out_act_cfg is not None conv_module = self._conv_type[conv_type] @@ -91,7 +91,7 @@ def forward(self, input_dict): x = input_dict for i in range(7): x = getattr(self, f'dec{i + 1}')(x) - if i == 1 or i == 3: + if i in (1, 3): x = F.interpolate(x, scale_factor=2) if self.with_out_activation: diff --git a/mmedit/models/backbones/encoder_decoders/decoders/gl_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/gl_decoder.py index f27e2e6782..8ecfe750a7 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/gl_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/gl_decoder.py @@ -27,7 +27,7 @@ def __init__(self, norm_cfg=None, act_cfg=dict(type='ReLU'), out_act='clip'): - super(GLDecoder, self).__init__() + super().__init__() self.dec1 = ConvModule( in_channels, 256, diff --git a/mmedit/models/backbones/encoder_decoders/decoders/indexnet_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/indexnet_decoder.py index 6045eac829..feeb5d7e22 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/indexnet_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/indexnet_decoder.py @@ -29,7 +29,7 @@ def __init__(self, kernel_size=5, norm_cfg=dict(type='BN'), conv_module=ConvModule): - super(IndexedUpsample, self).__init__() + super().__init__() self.conv = conv_module( in_channels, @@ -78,7 +78,7 @@ def __init__(self, norm_cfg=dict(type='BN'), separable_conv=False): # TODO: remove in_channels argument - super(IndexNetDecoder, self).__init__() + super().__init__() if separable_conv: conv_module = DepthwiseSeparableConvModule @@ -91,11 +91,11 @@ def __init__(self, blocks_out_channels = [96, 64, 32, 24, 16, 32, 32] self.decoder_layers = nn.ModuleList() - for in_channels, out_channels in zip(blocks_in_channels, - blocks_out_channels): + for in_channel, out_channel in zip(blocks_in_channels, + blocks_out_channels): self.decoder_layers.append( - IndexedUpsample(in_channels, out_channels, kernel_size, - norm_cfg, conv_module)) + IndexedUpsample(in_channel, out_channel, kernel_size, norm_cfg, + conv_module)) self.pred = nn.Sequential( conv_module( diff --git a/mmedit/models/backbones/encoder_decoders/decoders/pconv_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/pconv_decoder.py index 55090d8193..e18cb485ad 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/pconv_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/pconv_decoder.py @@ -27,7 +27,7 @@ def __init__(self, interpolation='nearest', conv_cfg=dict(type='PConv', multi_channel=True), norm_cfg=dict(type='BN')): - super(PConvDecoder, self).__init__() + super().__init__() self.num_layers = num_layers self.interpolation = interpolation diff --git a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py index 181c1182f8..b93f581639 100644 --- 
a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py @@ -13,7 +13,7 @@ class PlainDecoder(nn.Module): """ def __init__(self, in_channels): - super(PlainDecoder, self).__init__() + super().__init__() self.deconv6_1 = nn.Conv2d(in_channels, 512, kernel_size=1) self.deconv5_1 = nn.Conv2d(512, 512, kernel_size=5, padding=2) diff --git a/mmedit/models/backbones/encoder_decoders/decoders/resnet_dec.py b/mmedit/models/backbones/encoder_decoders/decoders/resnet_dec.py index 1db268d945..8eaa4825c6 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/resnet_dec.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/resnet_dec.py @@ -112,7 +112,7 @@ def __init__(self, type='LeakyReLU', negative_slope=0.2, inplace=True), with_spectral_norm=False, late_downsample=False): - super(ResNetDec, self).__init__() + super().__init__() if block == 'BasicBlockDec': block = BasicBlockDec else: @@ -335,10 +335,9 @@ def __init__(self, type='LeakyReLU', negative_slope=0.2, inplace=True), with_spectral_norm=False, late_downsample=False): - super(ResGCADecoder, - self).__init__(block, layers, in_channels, kernel_size, conv_cfg, - norm_cfg, act_cfg, with_spectral_norm, - late_downsample) + super().__init__(block, layers, in_channels, kernel_size, conv_cfg, + norm_cfg, act_cfg, with_spectral_norm, + late_downsample) self.gca = GCAModule(128, 128) def forward(self, inputs): diff --git a/mmedit/models/backbones/encoder_decoders/encoders/deepfill_encoder.py b/mmedit/models/backbones/encoder_decoders/encoders/deepfill_encoder.py index 5aa61aebb4..1b740a09d3 100644 --- a/mmedit/models/backbones/encoder_decoders/encoders/deepfill_encoder.py +++ b/mmedit/models/backbones/encoder_decoders/encoders/deepfill_encoder.py @@ -35,7 +35,7 @@ def __init__(self, encoder_type='stage1', channel_factor=1., **kwargs): - super(DeepFillEncoder, self).__init__() + super().__init__() conv_module = self._conv_type[conv_type] channel_list_dict = dict( stage1=[32, 64, 64, 128, 128, 128], diff --git a/mmedit/models/backbones/encoder_decoders/encoders/gl_encoder.py b/mmedit/models/backbones/encoder_decoders/encoders/gl_encoder.py index 57fdbb6a62..46bd602c52 100644 --- a/mmedit/models/backbones/encoder_decoders/encoders/gl_encoder.py +++ b/mmedit/models/backbones/encoder_decoders/encoders/gl_encoder.py @@ -17,7 +17,7 @@ class GLEncoder(nn.Module): """ def __init__(self, norm_cfg=None, act_cfg=dict(type='ReLU')): - super(GLEncoder, self).__init__() + super().__init__() channel_list = [64, 128, 128, 256, 256, 256] kernel_size_list = [5, 3, 3, 3, 3, 3] diff --git a/mmedit/models/backbones/encoder_decoders/encoders/indexnet_encoder.py b/mmedit/models/backbones/encoder_decoders/encoders/indexnet_encoder.py index 3d8caf3c78..98b01cd217 100644 --- a/mmedit/models/backbones/encoder_decoders/encoders/indexnet_encoder.py +++ b/mmedit/models/backbones/encoder_decoders/encoders/indexnet_encoder.py @@ -63,17 +63,17 @@ def build_index_block(in_channels, bias=False, norm_cfg=None, act_cfg=None)) - else: - return ConvModule( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - groups=groups, - bias=False, - norm_cfg=None, - act_cfg=None) + + return ConvModule( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + norm_cfg=None, + act_cfg=None) class HolisticIndexBlock(nn.Module): @@ -96,7 +96,7 @@ def __init__(self, norm_cfg=dict(type='BN'), use_context=False, 
use_nonlinear=False): - super(HolisticIndexBlock, self).__init__() + super().__init__() if use_context: kernel_size, padding = 4, 1 @@ -163,7 +163,7 @@ def __init__(self, use_context=False, use_nonlinear=False, mode='o2o'): - super(DepthwiseIndexBlock, self).__init__() + super().__init__() groups = in_channels if mode == 'o2o' else 1 @@ -173,7 +173,7 @@ def __init__(self, kernel_size, padding = 2, 0 self.index_blocks = nn.ModuleList() - for i in range(4): + for _ in range(4): self.index_blocks.append( build_index_block( in_channels, @@ -220,7 +220,7 @@ def forward(self, x): class InvertedResidual(nn.Module): """Inverted residual layer for indexnet encoder. - It basicaly is a depthwise separable conv module. If `expand_ratio` is not + It basically is a depthwise separable conv module. If `expand_ratio` is not one, then a conv module of kernel_size 1 will be inserted to change the input channels to `in_channels * expand_ratio`. @@ -244,7 +244,7 @@ def __init__(self, expand_ratio, norm_cfg, use_res_connect=False): - super(InvertedResidual, self).__init__() + super().__init__() assert stride in [1, 2], 'stride must 1 or 2' self.use_res_connect = use_res_connect @@ -351,7 +351,7 @@ def __init__(self, freeze_bn=False, use_nonlinear=True, use_context=True): - super(IndexNetEncoder, self).__init__() + super().__init__() if out_stride not in [16, 32]: raise ValueError(f'out_stride must 16 or 32, got {out_stride}') @@ -361,7 +361,7 @@ def __init__(self, # we name the index network in the paper index_block if index_mode == 'holistic': index_block = HolisticIndexBlock - elif index_mode == 'o2o' or index_mode == 'm2o': + elif index_mode in ('o2o', 'm2o'): index_block = partial(DepthwiseIndexBlock, mode=index_mode) else: raise NameError('Unknown index block mode {}'.format(index_mode)) @@ -461,7 +461,7 @@ def _make_layer(self, layer_setting, norm_cfg): ] in_channels = out_channels - for i in range(1, num_blocks): + for _ in range(1, num_blocks): layers.append( InvertedResidual( in_channels, diff --git a/mmedit/models/backbones/encoder_decoders/encoders/pconv_encoder.py b/mmedit/models/backbones/encoder_decoders/encoders/pconv_encoder.py index 160f38292b..81265823fd 100644 --- a/mmedit/models/backbones/encoder_decoders/encoders/pconv_encoder.py +++ b/mmedit/models/backbones/encoder_decoders/encoders/pconv_encoder.py @@ -30,7 +30,7 @@ def __init__(self, conv_cfg=dict(type='PConv', multi_channel=True), norm_cfg=dict(type='BN', requires_grad=True), norm_eval=False): - super(PConvEncoder, self).__init__() + super().__init__() self.num_layers = num_layers self.norm_eval = norm_eval @@ -89,7 +89,7 @@ def __init__(self, act_cfg=dict(type='ReLU'))) def train(self, mode=True): - super(PConvEncoder, self).train(mode) + super().train(mode) if mode and self.norm_eval: for m in self.modules(): # trick: eval have effect on BatchNorm only diff --git a/mmedit/models/backbones/encoder_decoders/encoders/resnet_enc.py b/mmedit/models/backbones/encoder_decoders/encoders/resnet_enc.py index dc5aa4cee5..3f93ec4623 100644 --- a/mmedit/models/backbones/encoder_decoders/encoders/resnet_enc.py +++ b/mmedit/models/backbones/encoder_decoders/encoders/resnet_enc.py @@ -37,8 +37,8 @@ def __init__(self, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), with_spectral_norm=False): - super(BasicBlock, self).__init__() - assert stride == 1 or stride == 2, ( + super().__init__() + assert stride in (1, 2), ( f'stride other than 1 and 2 is not implemented, got {stride}') assert stride != 2 or interpolation is not None, ( @@ -128,7 +128,7 
@@ def __init__(self, act_cfg=dict(type='ReLU'), with_spectral_norm=False, late_downsample=False): - super(ResNetEnc, self).__init__() + super().__init__() if block == 'BasicBlock': block = BasicBlock else: @@ -314,19 +314,18 @@ def __init__(self, with_spectral_norm=False, late_downsample=False, order=('conv', 'act', 'norm')): - super(ResShortcutEnc, - self).__init__(block, layers, in_channels, conv_cfg, norm_cfg, - act_cfg, with_spectral_norm, late_downsample) + super().__init__(block, layers, in_channels, conv_cfg, norm_cfg, + act_cfg, with_spectral_norm, late_downsample) # TODO: rename self.midplanes to self.mid_channels in ResNetEnc self.shortcut_in_channels = [in_channels, self.midplanes, 64, 128, 256] self.shortcut_out_channels = [32, self.midplanes, 64, 128, 256] self.shortcut = nn.ModuleList() - for in_channels, out_channels in zip(self.shortcut_in_channels, - self.shortcut_out_channels): + for in_channel, out_channel in zip(self.shortcut_in_channels, + self.shortcut_out_channels): self.shortcut.append( - self._make_shortcut(in_channels, out_channels, conv_cfg, + self._make_shortcut(in_channel, out_channel, conv_cfg, norm_cfg, act_cfg, order, with_spectral_norm)) @@ -439,12 +438,10 @@ def __init__(self, with_spectral_norm=False, late_downsample=False, order=('conv', 'act', 'norm')): - super(ResGCAEncoder, - self).__init__(block, layers, in_channels, conv_cfg, norm_cfg, - act_cfg, with_spectral_norm, late_downsample, - order) + super().__init__(block, layers, in_channels, conv_cfg, norm_cfg, + act_cfg, with_spectral_norm, late_downsample, order) - assert in_channels == 4 or in_channels == 6, ( + assert in_channels in (4, 6), ( f'in_channels must be 4 or 6, but got {in_channels}') self.trimap_channels = in_channels - 3 @@ -453,12 +450,12 @@ def __init__(self, guidance_out_channels = [16, 32, 128] guidance_head = [] - for in_channels, out_channels in zip(guidance_in_channels, - guidance_out_channels): + for in_channel, out_channel in zip(guidance_in_channels, + guidance_out_channels): guidance_head += [ ConvModule( - in_channels, - out_channels, + in_channel, + out_channel, 3, stride=2, padding=1, @@ -477,7 +474,7 @@ def init_weights(self, pretrained=None): logger = get_root_logger() load_checkpoint(self, pretrained, strict=False, logger=logger) elif pretrained is None: - super(ResGCAEncoder, self).init_weights() + super().init_weights() else: raise TypeError('"pretrained" must be a str or None. ' f'But received {type(pretrained)}.') diff --git a/mmedit/models/backbones/encoder_decoders/encoders/vgg.py b/mmedit/models/backbones/encoder_decoders/encoders/vgg.py index b14f03d131..26fdc65ce3 100644 --- a/mmedit/models/backbones/encoder_decoders/encoders/vgg.py +++ b/mmedit/models/backbones/encoder_decoders/encoders/vgg.py @@ -9,7 +9,7 @@ @COMPONENTS.register_module() class VGG16(nn.Module): - """Customed VGG16 Encoder. + """Customized VGG16 Encoder. A 1x1 conv is added after the original VGG16 conv layers. The indices of max pooling layers are returned for unpooling layers in decoders. 
@@ -29,7 +29,7 @@ def __init__(self, batch_norm=False, aspp=False, dilations=None): - super(VGG16, self).__init__() + super().__init__() self.batch_norm = batch_norm self.aspp = aspp self.dilations = dilations diff --git a/mmedit/models/backbones/encoder_decoders/gl_encoder_decoder.py b/mmedit/models/backbones/encoder_decoders/gl_encoder_decoder.py index 0039f7eab2..797c1013a3 100644 --- a/mmedit/models/backbones/encoder_decoders/gl_encoder_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/gl_encoder_decoder.py @@ -26,7 +26,7 @@ def __init__(self, encoder=dict(type='GLEncoder'), decoder=dict(type='GLDecoder'), dilation_neck=dict(type='GLDilationNeck')): - super(GLEncoderDecoder, self).__init__() + super().__init__() self.encoder = build_component(encoder) self.decoder = build_component(decoder) self.dilation_neck = build_component(dilation_neck) diff --git a/mmedit/models/backbones/encoder_decoders/necks/contextual_attention_neck.py b/mmedit/models/backbones/encoder_decoders/necks/contextual_attention_neck.py index 535527c32d..e4638e18e5 100644 --- a/mmedit/models/backbones/encoder_decoders/necks/contextual_attention_neck.py +++ b/mmedit/models/backbones/encoder_decoders/necks/contextual_attention_neck.py @@ -33,7 +33,7 @@ def __init__(self, act_cfg=dict(type='ELU'), contextual_attention_args=dict(softmax_scale=10.), **kwargs): - super(ContextualAttentionNeck, self).__init__() + super().__init__() self.contextual_attention = ContextualAttentionModule( **contextual_attention_args) conv_module = self._conv_type[conv_type] diff --git a/mmedit/models/backbones/encoder_decoders/necks/gl_dilation.py b/mmedit/models/backbones/encoder_decoders/necks/gl_dilation.py index d6510df3e2..d1cfc27839 100644 --- a/mmedit/models/backbones/encoder_decoders/necks/gl_dilation.py +++ b/mmedit/models/backbones/encoder_decoders/necks/gl_dilation.py @@ -29,7 +29,7 @@ def __init__(self, norm_cfg=None, act_cfg=dict(type='ReLU'), **kwargs): - super(GLDilationNeck, self).__init__() + super().__init__() conv_module = self._conv_type[conv_type] dilation_convs_ = [] for i in range(4): diff --git a/mmedit/models/backbones/encoder_decoders/pconv_encoder_decoder.py b/mmedit/models/backbones/encoder_decoders/pconv_encoder_decoder.py index 939bfbdb85..57983c400d 100644 --- a/mmedit/models/backbones/encoder_decoders/pconv_encoder_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/pconv_encoder_decoder.py @@ -16,7 +16,7 @@ class PConvEncoderDecoder(nn.Module): """ def __init__(self, encoder, decoder): - super(PConvEncoderDecoder, self).__init__() + super().__init__() self.encoder = build_component(encoder) self.decoder = build_component(decoder) diff --git a/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py b/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py index a3de3f61bc..08c5a411b2 100644 --- a/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/simple_encoder_decoder.py @@ -14,7 +14,7 @@ class SimpleEncoderDecoder(nn.Module): """ def __init__(self, encoder, decoder): - super(SimpleEncoderDecoder, self).__init__() + super().__init__() self.encoder = build_component(encoder) decoder['in_channels'] = self.encoder.out_channels diff --git a/mmedit/models/backbones/encoder_decoders/two_stage_encoder_decoder.py b/mmedit/models/backbones/encoder_decoders/two_stage_encoder_decoder.py index c5c01dee8e..a0d2975847 100644 --- a/mmedit/models/backbones/encoder_decoders/two_stage_encoder_decoder.py +++ 
b/mmedit/models/backbones/encoder_decoders/two_stage_encoder_decoder.py @@ -36,7 +36,7 @@ def __init__(self, act_cfg=dict(type='ELU'))), stage2=dict(type='DeepFillRefiner'), return_offset=False): - super(DeepFillEncoderDecoder, self).__init__() + super().__init__() self.stage1 = build_backbone(stage1) self.stage2 = build_component(stage2) diff --git a/mmedit/models/backbones/generation_backbones/resnet_generator.py b/mmedit/models/backbones/generation_backbones/resnet_generator.py index 6a4716171f..a946d1bb6f 100644 --- a/mmedit/models/backbones/generation_backbones/resnet_generator.py +++ b/mmedit/models/backbones/generation_backbones/resnet_generator.py @@ -39,7 +39,7 @@ def __init__(self, num_blocks=9, padding_mode='reflect', init_cfg=dict(type='normal', gain=0.02)): - super(ResnetGenerator, self).__init__() + super().__init__() assert num_blocks >= 0, ('Number of residual blocks must be ' f'non-negative, but got {num_blocks}.') assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but" diff --git a/mmedit/models/backbones/generation_backbones/unet_generator.py b/mmedit/models/backbones/generation_backbones/unet_generator.py index 9ba3d25c9b..8d8a528f78 100644 --- a/mmedit/models/backbones/generation_backbones/unet_generator.py +++ b/mmedit/models/backbones/generation_backbones/unet_generator.py @@ -37,7 +37,7 @@ def __init__(self, norm_cfg=dict(type='BN'), use_dropout=False, init_cfg=dict(type='normal', gain=0.02)): - super(UnetGenerator, self).__init__() + super().__init__() # We use norm layers in the unet generator. assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but" f'got {type(norm_cfg)}') @@ -52,7 +52,7 @@ def __init__(self, norm_cfg=norm_cfg, is_innermost=True) # add intermediate layers with base_channels * 8 filters - for i in range(num_down - 5): + for _ in range(num_down - 5): unet_block = UnetSkipConnectionBlock( base_channels * 8, base_channels * 8, diff --git a/mmedit/models/backbones/sr_backbones/duf.py b/mmedit/models/backbones/sr_backbones/duf.py index b66ff00731..3294f3fc05 100644 --- a/mmedit/models/backbones/sr_backbones/duf.py +++ b/mmedit/models/backbones/sr_backbones/duf.py @@ -17,7 +17,7 @@ class DynamicUpsamplingFilter(nn.Module): """ def __init__(self, filter_size=(5, 5)): - super(DynamicUpsamplingFilter, self).__init__() + super().__init__() if not isinstance(filter_size, tuple): raise TypeError('The type of filter_size must be tuple, ' f'but got type{filter_size}') diff --git a/mmedit/models/backbones/sr_backbones/edsr.py b/mmedit/models/backbones/sr_backbones/edsr.py index d201a5439a..ca66ea8996 100644 --- a/mmedit/models/backbones/sr_backbones/edsr.py +++ b/mmedit/models/backbones/sr_backbones/edsr.py @@ -33,7 +33,7 @@ def __init__(self, scale, mid_channels): raise ValueError(f'scale {scale} is not supported. 
' 'Supported scales: 2^n and 3.') - super(UpsampleModule, self).__init__(*modules) + super().__init__(*modules) @BACKBONES.register_module() @@ -68,7 +68,7 @@ def __init__(self, res_scale=1, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0)): - super(EDSR, self).__init__() + super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.mid_channels = mid_channels diff --git a/mmedit/models/backbones/sr_backbones/edvr_net.py b/mmedit/models/backbones/sr_backbones/edvr_net.py index 3731b45129..ffe4f0f27f 100644 --- a/mmedit/models/backbones/sr_backbones/edvr_net.py +++ b/mmedit/models/backbones/sr_backbones/edvr_net.py @@ -32,7 +32,7 @@ class ModulatedDCNPack(ModulatedDeformConv2d): """ def __init__(self, *args, **kwargs): - super(ModulatedDCNPack, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.conv_offset = nn.Conv2d( self.in_channels, @@ -73,7 +73,7 @@ def __init__(self, mid_channels=64, deform_groups=8, act_cfg=dict(type='LeakyReLU', negative_slope=0.1)): - super(PCDAlignment, self).__init__() + super().__init__() # Pyramid has three levels: # L3: level 3, 1/4 spatial size @@ -203,7 +203,7 @@ def __init__(self, num_frames=5, center_frame_idx=2, act_cfg=dict(type='LeakyReLU', negative_slope=0.1)): - super(TSAFusion, self).__init__() + super().__init__() self.center_frame_idx = center_frame_idx # temporal attention (before fusion conv) self.temporal_attn1 = nn.Conv2d( @@ -329,7 +329,7 @@ def __init__(self, num_blocks_reconstruction=10, center_frame_idx=2, with_tsa=True): - super(EDVRNet, self).__init__() + super().__init__() self.center_frame_idx = center_frame_idx self.with_tsa = with_tsa act_cfg = dict(type='LeakyReLU', negative_slope=0.1) @@ -374,7 +374,7 @@ def __init__(self, mid_channels, 64, 2, upsample_kernel=3) # we fix the output channels in the last few layers to 64. 
self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1) - self.conv_last = nn.Conv2d(64, 3, 3, 1, 1) + self.conv_last = nn.Conv2d(64, out_channels, 3, 1, 1) self.img_upsample = nn.Upsample( scale_factor=4, mode='bilinear', align_corners=False) # activation function diff --git a/mmedit/models/backbones/sr_backbones/rrdb_net.py b/mmedit/models/backbones/sr_backbones/rrdb_net.py index 7ea16597ca..4ae4e2ff98 100644 --- a/mmedit/models/backbones/sr_backbones/rrdb_net.py +++ b/mmedit/models/backbones/sr_backbones/rrdb_net.py @@ -19,7 +19,7 @@ class ResidualDenseBlock(nn.Module): """ def __init__(self, mid_channels=64, growth_channels=32): - super(ResidualDenseBlock, self).__init__() + super().__init__() for i in range(5): out_channels = mid_channels if i == 4 else growth_channels self.add_module( @@ -69,7 +69,7 @@ class RRDB(nn.Module): """ def __init__(self, mid_channels, growth_channels=32): - super(RRDB, self).__init__() + super().__init__() self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) @@ -113,7 +113,7 @@ def __init__(self, mid_channels=64, num_blocks=23, growth_channels=32): - super(RRDBNet, self).__init__() + super().__init__() self.conv_first = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) self.body = make_layer( RRDB, diff --git a/mmedit/models/backbones/sr_backbones/sr_resnet.py b/mmedit/models/backbones/sr_backbones/sr_resnet.py index 84bbb6c010..1916f855c6 100644 --- a/mmedit/models/backbones/sr_backbones/sr_resnet.py +++ b/mmedit/models/backbones/sr_backbones/sr_resnet.py @@ -35,7 +35,7 @@ def __init__(self, num_blocks=16, upscale_factor=4): - super(MSRResNet, self).__init__() + super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.mid_channels = mid_channels diff --git a/mmedit/models/backbones/sr_backbones/srcnn.py b/mmedit/models/backbones/sr_backbones/srcnn.py index 324a3e1dcc..9844056f5d 100644 --- a/mmedit/models/backbones/sr_backbones/srcnn.py +++ b/mmedit/models/backbones/sr_backbones/srcnn.py @@ -28,7 +28,7 @@ def __init__(self, channels=(3, 64, 32, 3), kernel_sizes=(9, 1, 5), upscale_factor=4): - super(SRCNN, self).__init__() + super().__init__() assert len(channels) == 4, ('The length of channel tuple should be 4, ' f'but got {len(channels)}') assert len(kernel_sizes) == 3, ( diff --git a/mmedit/models/backbones/sr_backbones/tof.py b/mmedit/models/backbones/sr_backbones/tof.py index 05e48ddac6..0348ae3092 100644 --- a/mmedit/models/backbones/sr_backbones/tof.py +++ b/mmedit/models/backbones/sr_backbones/tof.py @@ -17,7 +17,7 @@ class BasicModule(nn.Module): """ def __init__(self): - super(BasicModule, self).__init__() + super().__init__() self.basic_module = nn.Sequential( ConvModule( @@ -88,8 +88,8 @@ class SPyNet(nn.Module): https://github.com/Coldog2333/pytoflow """ - def __init__(self, load_path=None): - super(SPyNet, self).__init__() + def __init__(self): + super().__init__() self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)]) @@ -159,7 +159,7 @@ class TOFlow(nn.Module): """ def __init__(self, adapt_official_weights=False): - super(TOFlow, self).__init__() + super().__init__() self.adapt_official_weights = adapt_official_weights self.ref_idx = 0 if adapt_official_weights else 3 diff --git a/mmedit/models/base.py b/mmedit/models/base.py index a42fd239fb..a7b48e0f50 100644 --- a/mmedit/models/base.py +++ b/mmedit/models/base.py @@ -20,16 +20,12 @@ class BaseModel(nn.Module, metaclass=ABCMeta): 
``train_step``, supporting to train one step when training. """ - def __init__(self): - super(BaseModel, self).__init__() - @abstractmethod def init_weights(self): """Abstract method for initializing weight. All subclass should overwrite it. """ - pass @abstractmethod def forward_train(self, imgs, labels): @@ -37,7 +33,6 @@ def forward_train(self, imgs, labels): All subclass should overwrite it. """ - pass @abstractmethod def forward_test(self, imgs): @@ -45,7 +40,6 @@ def forward_test(self, imgs): All subclass should overwrite it. """ - pass def forward(self, imgs, labels, test_mode, **kwargs): """Forward function for base model. @@ -60,18 +54,17 @@ def forward(self, imgs, labels, test_mode, **kwargs): Tensor: Forward results. """ - if not test_mode: - return self.forward_train(imgs, labels, **kwargs) - else: + if test_mode: return self.forward_test(imgs, **kwargs) + return self.forward_train(imgs, labels, **kwargs) + @abstractmethod def train_step(self, data_batch, optimizer): """Abstract method for one training step. All subclass should overwrite it. """ - pass def val_step(self, data_batch, **kwargs): """Abstract method for one validation step. diff --git a/mmedit/models/builder.py b/mmedit/models/builder.py index d34ed4729e..9b0b320f1b 100644 --- a/mmedit/models/builder.py +++ b/mmedit/models/builder.py @@ -17,8 +17,8 @@ def build(cfg, registry, default_args=None): build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg ] return nn.Sequential(*modules) - else: - return build_from_cfg(cfg, registry, default_args) + + return build_from_cfg(cfg, registry, default_args) def build_backbone(cfg): diff --git a/mmedit/models/common/aspp.py b/mmedit/models/common/aspp.py index 05eab59526..6b95080d03 100644 --- a/mmedit/models/common/aspp.py +++ b/mmedit/models/common/aspp.py @@ -9,7 +9,7 @@ class ASPPPooling(nn.Sequential): def __init__(self, in_channels, out_channels, conv_cfg, norm_cfg, act_cfg): - super(ASPPPooling, self).__init__( + super().__init__( nn.AdaptiveAvgPool2d(1), ConvModule( in_channels, @@ -31,7 +31,8 @@ class ASPP(nn.Module): """ASPP module from DeepLabV3. The code is adopted from - https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/deeplabv3.py # noqa + https://github.com/pytorch/vision/blob/master/torchvision/models/ + segmentation/deeplabv3.py For more information about the module: `"Rethinking Atrous Convolution for Semantic Image Segmentation" @@ -63,7 +64,7 @@ def __init__(self, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'), separable_conv=False): - super(ASPP, self).__init__() + super().__init__() if separable_conv: conv_module = DepthwiseSeparableConvModule diff --git a/mmedit/models/common/contextual_attention.py b/mmedit/models/common/contextual_attention.py index 9129dc2d99..3b28b4a71e 100644 --- a/mmedit/models/common/contextual_attention.py +++ b/mmedit/models/common/contextual_attention.py @@ -32,7 +32,7 @@ class ContextualAttentionModule(nn.Module): Default: 3. softmax_scale (float): The scale factor for softmax function. Default: 10. - return_attenion_score (bool): If True, the attention score will be + return_attention_score (bool): If True, the attention score will be returned. Default: True. 
""" @@ -47,8 +47,8 @@ def __init__(self, scale=0.5, fuse_kernel_size=3, softmax_scale=10, - return_attenion_score=True): - super(ContextualAttentionModule, self).__init__() + return_attention_score=True): + super().__init__() self.unfold_raw_kernel_size = unfold_raw_kernel_size self.unfold_raw_stride = unfold_raw_stride self.unfold_raw_padding = unfold_raw_padding @@ -60,7 +60,7 @@ def __init__(self, self.fuse_kernel_size = fuse_kernel_size self.with_fuse_correlation = fuse_kernel_size > 1 self.softmax_scale = softmax_scale - self.return_attention_score = return_attenion_score + self.return_attention_score = return_attention_score if self.with_fuse_correlation: assert fuse_kernel_size % 2 == 1 @@ -174,7 +174,7 @@ def patch_copy_deconv(self, attention_score, context_filter): Returns: torch.Tensor: Tensor with shape of (n, c, h, w). """ - n, num_context, h, w = attention_score.size() + n, _, h, w = attention_score.size() attention_score = attention_score.view(1, -1, h, w) output = F.conv_transpose2d( attention_score, @@ -258,7 +258,7 @@ def calculate_unfold_hw(self, return h_unfold, w_unfold def calculate_overlap_factor(self, attention_score): - """Calculte the overlap factor after applying deconv. + """Calculate the overlap factor after applying deconv. Args: attention_score (torch.Tensor): The attention score with shape of diff --git a/mmedit/models/common/conv.py b/mmedit/models/common/conv.py index c827feffc4..e62412e954 100644 --- a/mmedit/models/common/conv.py +++ b/mmedit/models/common/conv.py @@ -1,5 +1,5 @@ from mmcv.cnn import CONV_LAYERS -from torch import nn as nn +from torch import nn CONV_LAYERS.register_module('Deconv', module=nn.ConvTranspose2d) # TODO: octave conv diff --git a/mmedit/models/common/gated_conv_module.py b/mmedit/models/common/gated_conv_module.py index 67907ab8b2..33222245ff 100644 --- a/mmedit/models/common/gated_conv_module.py +++ b/mmedit/models/common/gated_conv_module.py @@ -35,7 +35,7 @@ def __init__(self, feat_act_cfg=dict(type='ELU'), gate_act_cfg=dict(type='Sigmoid'), **kwargs): - super(SimpleGatedConvModule, self).__init__() + super().__init__() # the activation function should specified outside conv module kwargs_ = copy.deepcopy(kwargs) kwargs_['act_cfg'] = None diff --git a/mmedit/models/common/gca_module.py b/mmedit/models/common/gca_module.py index 2c8ee69903..3588061b9d 100644 --- a/mmedit/models/common/gca_module.py +++ b/mmedit/models/common/gca_module.py @@ -18,8 +18,8 @@ class GCAModule(nn.Module): alpha feature patches could be specified by `rate` (see `rate` below). The image feature patches are used to convolve with the image feature itself to calculate the contextual attention. Then the attention feature map is - convolved by alpha feature patches to obtain the attentioned alpha feature. - At last, the attentioned alpah feature is added to the input alpha feature. + convolved by alpha feature patches to obtain the attention alpha feature. + At last, the attention alpha feature is added to the input alpha feature. Args: in_channels (int): Input channels of the guided contextual attention @@ -54,7 +54,7 @@ def __init__(self, interpolation='nearest', penalty=-1e4, eps=1e-4): - super(GCAModule, self).__init__() + super().__init__() self.kernel_size = kernel_size self.stride = stride self.rate = rate @@ -66,7 +66,7 @@ def __init__(self, # reduced the channels of input image feature. 
self.guidance_conv = nn.Conv2d(in_channels, in_channels // 2, 1) - # convolution after the attentioned alpha feature + # convolution after the attention alpha feature self.out_conv = ConvModule( out_channels, out_channels, @@ -171,7 +171,7 @@ def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown): ``Tensor``: Image feature patches of shape \ (N, img_h*img_w, img_c, img_ks, img_ks). - ``Tensor``: Guided contextual attentioned alpha feature map. \ + ``Tensor``: Guided contextual attention alpha feature map. \ (N, img_h*img_w, alpha_c, alpha_ks, alpha_ks). ``Tensor``: Unknown mask of shape (N, img_h*img_w, 1, 1). @@ -187,7 +187,7 @@ def extract_feature_maps_patches(self, img_feat, alpha_feat, unknown): # extract unknown mask patches with shape: (N, img_h*img_w, 1, 1) unknown_ps = self.extract_patches(unknown, img_ks, self.stride) - unknown_ps = unknown_ps.squeeze(dim=2) # squeezz channel dimention + unknown_ps = unknown_ps.squeeze(dim=2) # squeeze channel dimension unknown_ps = unknown_ps.mean(dim=[2, 3], keepdim=True) return img_ps, alpha_ps, unknown_ps @@ -256,7 +256,7 @@ def propagate_alpha_feature(self, gca_score, alpha_ps): (1, img_h*img_w, alpha_c, alpha_ks, alpha_ks). Returns: - Tensor: Propagted alpha feature map of shape \ + Tensor: Propagated alpha feature map of shape \ (1, alpha_c, alpha_h, alpha_w). """ alpha_ps = alpha_ps[0] # squeeze dim 0 diff --git a/mmedit/models/common/generation_model_utils.py b/mmedit/models/common/generation_model_utils.py index 830a68f0c6..c493f62f9d 100644 --- a/mmedit/models/common/generation_model_utils.py +++ b/mmedit/models/common/generation_model_utils.py @@ -53,7 +53,7 @@ def init_func(m): module.apply(init_func) -class GANImageBuffer(object): +class GANImageBuffer: """This class implements an image buffer that stores previously generated images. @@ -139,7 +139,7 @@ def __init__(self, is_innermost=False, norm_cfg=dict(type='BN'), use_dropout=False): - super(UnetSkipConnectionBlock, self).__init__() + super().__init__() # cannot be both outermost and innermost assert not (is_outermost and is_innermost), ( "'is_outermost' and 'is_innermost' cannot be True" @@ -223,9 +223,9 @@ def forward(self, x): """ if self.is_outermost: return self.model(x) - else: - # add skip connections - return torch.cat([x, self.model(x)], 1) + + # add skip connections + return torch.cat([x, self.model(x)], 1) class ResidualBlockWithDropout(nn.Module): @@ -251,7 +251,7 @@ def __init__(self, padding_mode, norm_cfg=dict(type='BN'), use_dropout=True): - super(ResidualBlockWithDropout, self).__init__() + super().__init__() assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but" f'got {type(norm_cfg)}') assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'" diff --git a/mmedit/models/common/linear_module.py b/mmedit/models/common/linear_module.py index 130e514d23..34b438db03 100644 --- a/mmedit/models/common/linear_module.py +++ b/mmedit/models/common/linear_module.py @@ -5,7 +5,7 @@ class LinearModule(nn.Module): """A linear block that contains linear/norm/activation layers. - For low level visioin, we add spectral norm and padding layer. + For low level vision, we add spectral norm and padding layer. Args: in_features (int): Same as nn.Linear. 
@@ -27,7 +27,7 @@ def __init__(self, inplace=True, with_spectral_norm=False, order=('linear', 'act')): - super(LinearModule, self).__init__() + super().__init__() assert act_cfg is None or isinstance(act_cfg, dict) self.act_cfg = act_cfg self.inplace = inplace @@ -69,7 +69,7 @@ def init_weights(self): kaiming_init(self.linear, a=a, nonlinearity=nonlinearity) def forward(self, x, activate=True): - """Foward Function. + """Forward Function. Args: x (torch.Tensor): Input tensor with shape of (n, \*, # noqa: W605 diff --git a/mmedit/models/common/mask_conv_module.py b/mmedit/models/common/mask_conv_module.py index c265c1082e..d7556477e3 100644 --- a/mmedit/models/common/mask_conv_module.py +++ b/mmedit/models/common/mask_conv_module.py @@ -26,7 +26,7 @@ class MaskConvModule(ConvModule): padding_mode (str): If the `padding_mode` has not been supported by current `Conv2d` in Pytorch, we will use our own padding layer instead. Currently, we support ['zeros', 'circular'] with official - implementation and ['reflect'] with our own implementaion. + implementation and ['reflect'] with our own implementation. Default: 'zeros'. order (tuple[str]): The order of conv/norm/activation layers. It is a sequence of "conv", "norm" and "act". Examples are @@ -35,7 +35,7 @@ class MaskConvModule(ConvModule): supported_conv_list = ['PConv'] def __init__(self, *args, **kwargs): - super(MaskConvModule, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) assert self.conv_cfg['type'] in self.supported_conv_list self.init_weights() @@ -83,5 +83,5 @@ def forward(self, if return_mask: return x, updated_mask - else: - return x + + return x diff --git a/mmedit/models/common/model_utils.py b/mmedit/models/common/model_utils.py index 794d579ee5..32cd0bbb36 100644 --- a/mmedit/models/common/model_utils.py +++ b/mmedit/models/common/model_utils.py @@ -3,7 +3,7 @@ def set_requires_grad(nets, requires_grad=False): - """Set requies_grad for all the networks. + """Set requires_grad for all the networks. Args: nets (nn.Module | list[nn.Module]): A list of networks or a single diff --git a/mmedit/models/common/partial_conv.py b/mmedit/models/common/partial_conv.py index 2a3dab6feb..21083714b6 100644 --- a/mmedit/models/common/partial_conv.py +++ b/mmedit/models/common/partial_conv.py @@ -13,14 +13,14 @@ class PartialConv2d(nn.Conv2d): [https://arxiv.org/abs/1804.07723] Args: - multi_channel (bool): If True, the mask is multi-channle. Otherwise, + multi_channel (bool): If True, the mask is multi-channel. Otherwise, the mask is single-channel. eps (float): Need to be changed for mixed precision training. For mixed precision training, you need change 1e-8 to 1e-6. 
""" def __init__(self, *args, multi_channel=False, eps=1e-8, **kwargs): - super(PartialConv2d, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # whether the mask is multi-channel or not self.multi_channel = multi_channel @@ -82,7 +82,7 @@ def forward(self, input, mask=None, return_mask=True): # standard conv2d if mask is not None: input = input * mask - raw_out = super(PartialConv2d, self).forward(input) + raw_out = super().forward(input) if mask is not None: if self.bias is None: @@ -97,5 +97,5 @@ def forward(self, input, mask=None, return_mask=True): if return_mask and mask is not None: return output, updated_mask - else: - return output + + return output diff --git a/mmedit/models/common/separable_conv_module.py b/mmedit/models/common/separable_conv_module.py index ed0ab301b1..cc8badc1d6 100644 --- a/mmedit/models/common/separable_conv_module.py +++ b/mmedit/models/common/separable_conv_module.py @@ -51,7 +51,7 @@ def __init__(self, pw_norm_cfg='default', pw_act_cfg='default', **kwargs): - super(DepthwiseSeparableConvModule, self).__init__() + super().__init__() assert 'groups' not in kwargs, 'groups should not be specified' # if norm/activation config of depthwise/pointwise ConvModule is not diff --git a/mmedit/models/common/sr_backbone_utils.py b/mmedit/models/common/sr_backbone_utils.py index f1f7e8e0fb..a4bd8569e9 100644 --- a/mmedit/models/common/sr_backbone_utils.py +++ b/mmedit/models/common/sr_backbone_utils.py @@ -56,7 +56,7 @@ class ResidualBlockNoBN(nn.Module): """ def __init__(self, mid_channels=64, res_scale=1.0): - super(ResidualBlockNoBN, self).__init__() + super().__init__() self.res_scale = res_scale self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True) diff --git a/mmedit/models/common/upsample.py b/mmedit/models/common/upsample.py index 0a564f939c..1dd55357d4 100644 --- a/mmedit/models/common/upsample.py +++ b/mmedit/models/common/upsample.py @@ -19,7 +19,7 @@ class PixelShufflePack(nn.Module): def __init__(self, in_channels, out_channels, scale_factor, upsample_kernel): - super(PixelShufflePack, self).__init__() + super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.scale_factor = scale_factor diff --git a/mmedit/models/components/discriminators/deepfill_disc.py b/mmedit/models/components/discriminators/deepfill_disc.py index 28a33d7a3a..f679ab378f 100644 --- a/mmedit/models/components/discriminators/deepfill_disc.py +++ b/mmedit/models/components/discriminators/deepfill_disc.py @@ -25,7 +25,7 @@ class DeepFillv1Discriminators(nn.Module): """ def __init__(self, global_disc_cfg, local_disc_cfg): - super(DeepFillv1Discriminators, self).__init__() + super().__init__() self.global_disc = build_component(global_disc_cfg) self.local_disc = build_component(local_disc_cfg) diff --git a/mmedit/models/components/discriminators/gl_disc.py b/mmedit/models/components/discriminators/gl_disc.py index ad04cf6686..99cc5ff898 100644 --- a/mmedit/models/components/discriminators/gl_disc.py +++ b/mmedit/models/components/discriminators/gl_disc.py @@ -21,7 +21,7 @@ class GLDiscs(nn.Module): """ def __init__(self, global_disc_cfg, local_disc_cfg): - super(GLDiscs, self).__init__() + super().__init__() self.global_disc = MultiLayerDiscriminator(**global_disc_cfg) self.local_disc = MultiLayerDiscriminator(**local_disc_cfg) diff --git a/mmedit/models/components/discriminators/modified_vgg.py b/mmedit/models/components/discriminators/modified_vgg.py 
index b6e3e53dec..de3977303b 100644 --- a/mmedit/models/components/discriminators/modified_vgg.py +++ b/mmedit/models/components/discriminators/modified_vgg.py @@ -18,7 +18,7 @@ class ModifiedVGG(nn.Module): """ def __init__(self, in_channels, mid_channels): - super(ModifiedVGG, self).__init__() + super().__init__() self.conv0_0 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d( diff --git a/mmedit/models/components/discriminators/multi_layer_disc.py b/mmedit/models/components/discriminators/multi_layer_disc.py index cc949a1e58..04b5d4fd37 100644 --- a/mmedit/models/components/discriminators/multi_layer_disc.py +++ b/mmedit/models/components/discriminators/multi_layer_disc.py @@ -56,7 +56,7 @@ def __init__(self, with_out_convs=False, with_spectral_norm=False, **kwargs): - super(MultiLayerDiscriminator, self).__init__() + super().__init__() if fc_in_channels is not None: assert fc_in_channels > 0 diff --git a/mmedit/models/components/discriminators/patch_disc.py b/mmedit/models/components/discriminators/patch_disc.py index 421738a18c..7e14e5c4a6 100644 --- a/mmedit/models/components/discriminators/patch_disc.py +++ b/mmedit/models/components/discriminators/patch_disc.py @@ -31,7 +31,7 @@ def __init__(self, num_conv=3, norm_cfg=dict(type='BN'), init_cfg=dict(type='normal', gain=0.02)): - super(PatchDiscriminator, self).__init__() + super().__init__() assert isinstance(norm_cfg, dict), ("'norm_cfg' should be dict, but" f'got {type(norm_cfg)}') assert 'type' in norm_cfg, "'norm_cfg' must have key 'type'" diff --git a/mmedit/models/components/refiners/deepfill_refiner.py b/mmedit/models/components/refiners/deepfill_refiner.py index 92a83a3dc7..e5ad872524 100644 --- a/mmedit/models/components/refiners/deepfill_refiner.py +++ b/mmedit/models/components/refiners/deepfill_refiner.py @@ -38,7 +38,7 @@ def __init__(self, contextual_attention=dict( type='ContextualAttentionNeck', in_channels=128), decoder=dict(type='DeepFillDecoder', in_channels=256)): - super(DeepFillRefiner, self).__init__() + super().__init__() self.encoder_attention = build_component(encoder_attention) self.encoder_conv = build_component(encoder_conv) self.contextual_attention_neck = build_component(contextual_attention) diff --git a/mmedit/models/components/refiners/plain_refiner.py b/mmedit/models/components/refiners/plain_refiner.py index 1a32a82190..74694718b6 100644 --- a/mmedit/models/components/refiners/plain_refiner.py +++ b/mmedit/models/components/refiners/plain_refiner.py @@ -17,7 +17,9 @@ class PlainRefiner(nn.Module): """ def __init__(self, conv_channels=64, pretrained=None): - super(PlainRefiner, self).__init__() + super().__init__() + + assert pretrained is None, 'pretrained not supported yet' self.refine_conv1 = nn.Conv2d( 4, conv_channels, kernel_size=3, padding=1) diff --git a/mmedit/models/inpaintors/deepfillv1.py b/mmedit/models/inpaintors/deepfillv1.py index ed9733adcf..0a7d108ea6 100644 --- a/mmedit/models/inpaintors/deepfillv1.py +++ b/mmedit/models/inpaintors/deepfillv1.py @@ -24,8 +24,8 @@ def get_module(self, model, module_name): """ if isinstance(model, (DataParallel, DistributedDataParallel)): return getattr(model.module, module_name) - else: - return getattr(model, module_name) + + return getattr(model, module_name) def forward_train_d(self, data_batch, is_real, is_disc): """Forward function in discriminator training step. 
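The hunks above all apply the same zero-argument super() cleanup. A minimal, self-contained sketch of that pattern, using a hypothetical module name that does not appear in the codebase:

    import torch.nn as nn


    class ExampleBlock(nn.Module):
        """Hypothetical module, used only to illustrate the refactor."""

        def __init__(self, channels):
            # Python 2 compatible form that the linter flags as outdated:
            #     super(ExampleBlock, self).__init__()
            # In Python 3 the class and instance are resolved implicitly,
            # so the zero-argument call below is equivalent.
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)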
diff --git a/mmedit/models/inpaintors/gl_inpaintor.py b/mmedit/models/inpaintors/gl_inpaintor.py index 4c83707bc8..f1a77fc639 100644 --- a/mmedit/models/inpaintors/gl_inpaintor.py +++ b/mmedit/models/inpaintors/gl_inpaintor.py @@ -38,7 +38,7 @@ class GLInpaintor(OneStageInpaintor): iter_td=100000 ) - `iter_tc` and `iter_td` correspond to the noation :math:`T_C` and + `iter_tc` and `iter_td` correspond to the notation :math:`T_C` and :math:`T_D` of theoriginal paper. Args: @@ -75,7 +75,7 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(GLInpaintor, self).__init__( + super().__init__( encdec, disc=disc, loss_gan=loss_gan, @@ -149,7 +149,7 @@ def train_step(self, data_batch, optimizer): 1. get fake res/image 2. optimize discriminator (if in current schedule) - 3. optimzie generator (if in current schedule) + 3. optimize generator (if in current schedule) If ``self.train_cfg.disc_step > 1``, the train step will contain multiple iterations for optimizing discriminator with different input diff --git a/mmedit/models/inpaintors/one_stage.py b/mmedit/models/inpaintors/one_stage.py index 91a382273d..d5de992fee 100644 --- a/mmedit/models/inpaintors/one_stage.py +++ b/mmedit/models/inpaintors/one_stage.py @@ -60,7 +60,7 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(OneStageInpaintor, self).__init__() + super().__init__() self.with_l1_hole_loss = loss_l1_hole is not None self.with_l1_valid_loss = loss_l1_valid is not None self.with_tv_loss = loss_tv is not None @@ -130,11 +130,11 @@ def forward(self, masked_img, mask, test_mode=True, **kwargs): Returns: dict: Dict contains output results. """ - if not test_mode: - return self.forward_train(masked_img, mask, **kwargs) - else: + if test_mode: return self.forward_test(masked_img, mask, **kwargs) + return self.forward_train(masked_img, mask, **kwargs) + def forward_train(self, *args, **kwargs): """Forward function for training. diff --git a/mmedit/models/inpaintors/pconv_inpaintor.py b/mmedit/models/inpaintors/pconv_inpaintor.py index 25e09ceb9e..eff6f9b8d3 100644 --- a/mmedit/models/inpaintors/pconv_inpaintor.py +++ b/mmedit/models/inpaintors/pconv_inpaintor.py @@ -26,7 +26,7 @@ def forward_test(self, mask (torch.Tensor): Tensor with shape of (n, 1, h, w). save_image (bool, optional): If True, results will be saved as image. Defaults to False. - save_path (str, optional): If given a valid str, the reuslts will + save_path (str, optional): If given a valid str, the results will be saved in this path. Defaults to None. iteration (int, optional): Iteration number. Defaults to None. diff --git a/mmedit/models/inpaintors/two_stage.py b/mmedit/models/inpaintors/two_stage.py index b9daa8bff6..7f48770cc8 100644 --- a/mmedit/models/inpaintors/two_stage.py +++ b/mmedit/models/inpaintors/two_stage.py @@ -39,7 +39,7 @@ def __init__(self, input_with_ones=True, disc_input_with_mask=False, **kwargs): - super(TwoStageInpaintor, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.stage1_loss_type = stage1_loss_type self.stage2_loss_type = stage2_loss_type @@ -62,7 +62,7 @@ def forward_test(self, mask (torch.Tensor): Tensor with shape of (n, 1, h, w). save_image (bool, optional): If True, results will be saved as image. Defaults to False. - save_path (str, optional): If given a valid str, the reuslts will + save_path (str, optional): If given a valid str, the results will be saved in this path. Defaults to None. iteration (int, optional): Iteration number. Defaults to None. 
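Several forward functions in the hunks above are rearranged to satisfy pylint's no-else-return check: the test branch returns early and the training branch follows without an else. A small self-contained sketch with placeholder names, not taken from the codebase:

    class SketchModel:
        """Placeholder model showing the early-return dispatch style."""

        def forward(self, inputs, test_mode=False):
            # Handle the test branch first and return early ...
            if test_mode:
                return self.forward_test(inputs)
            # ... so the training branch no longer needs an `else`.
            return self.forward_train(inputs)

        def forward_train(self, inputs):
            return dict(mode='train', inputs=inputs)

        def forward_test(self, inputs):
            return dict(mode='test', inputs=inputs)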
diff --git a/mmedit/models/losses/composition_loss.py b/mmedit/models/losses/composition_loss.py index 3ed097c8b0..cc0d3ace00 100644 --- a/mmedit/models/losses/composition_loss.py +++ b/mmedit/models/losses/composition_loss.py @@ -22,7 +22,7 @@ class L1CompositionLoss(nn.Module): """ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): - super(L1CompositionLoss, self).__init__() + super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') @@ -69,7 +69,7 @@ class MSECompositionLoss(nn.Module): """ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): - super(MSECompositionLoss, self).__init__() + super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') @@ -122,7 +122,7 @@ def __init__(self, reduction='mean', sample_wise=False, eps=1e-12): - super(CharbonnierCompLoss, self).__init__() + super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') diff --git a/mmedit/models/losses/gan_loss.py b/mmedit/models/losses/gan_loss.py index 2d973ea354..618d7e820d 100644 --- a/mmedit/models/losses/gan_loss.py +++ b/mmedit/models/losses/gan_loss.py @@ -23,7 +23,7 @@ def __init__(self, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): - super(GANLoss, self).__init__() + super().__init__() self.gan_type = gan_type self.loss_weight = loss_weight self.real_label_val = real_label_val @@ -104,7 +104,7 @@ def gradient_penalty_loss(discriminator, real_data, fake_data, mask=None): discriminator (nn.Module): Network for the discriminator. real_data (Tensor): Real input data. fake_data (Tensor): Fake input data. - mask (Tensor): Masks for inpaitting. Default: None. + mask (Tensor): Masks for inpainting. Default: None. Returns: Tensor: A tensor for gradient penalty. @@ -145,7 +145,7 @@ class GradientPenaltyLoss(nn.Module): """ def __init__(self, loss_weight=1.): - super(GradientPenaltyLoss, self).__init__() + super().__init__() self.loss_weight = loss_weight def forward(self, discriminator, real_data, fake_data, mask=None): @@ -155,7 +155,7 @@ def forward(self, discriminator, real_data, fake_data, mask=None): discriminator (nn.Module): Network for the discriminator. real_data (Tensor): Real input data. fake_data (Tensor): Fake input data. - mask (Tensor): Masks for inpaitting. Default: None. + mask (Tensor): Masks for inpainting. Default: None. Returns: Tensor: Loss. 
@@ -175,7 +175,7 @@ class DiscShiftLoss(nn.Module): """ def __init__(self, loss_weight=0.1): - super(DiscShiftLoss, self).__init__() + super().__init__() self.loss_weight = loss_weight def forward(self, x): diff --git a/mmedit/models/losses/gradient_loss.py b/mmedit/models/losses/gradient_loss.py index 57cc73340b..a692593d91 100644 --- a/mmedit/models/losses/gradient_loss.py +++ b/mmedit/models/losses/gradient_loss.py @@ -19,7 +19,7 @@ class GradientLoss(nn.Module): """ def __init__(self, loss_weight=1.0, reduction='mean'): - super(GradientLoss, self).__init__() + super().__init__() self.loss_weight = loss_weight self.reduction = reduction if self.reduction not in ['none', 'mean', 'sum']: diff --git a/mmedit/models/losses/perceptual_loss.py b/mmedit/models/losses/perceptual_loss.py index 210e93a90d..9c9cb2508f 100644 --- a/mmedit/models/losses/perceptual_loss.py +++ b/mmedit/models/losses/perceptual_loss.py @@ -32,7 +32,7 @@ def __init__(self, vgg_type='vgg19', use_input_norm=True, pretrained='torchvision://vgg19'): - super(PerceptualVGG, self).__init__() + super().__init__() if pretrained.startswith('torchvision://'): assert vgg_type in pretrained self.layer_name_list = layer_name_list @@ -127,7 +127,7 @@ def __init__(self, norm_img=True, pretrained='torchvision://vgg19', criterion='l1'): - super(PerceptualLoss, self).__init__() + super().__init__() self.norm_img = norm_img self.perceptual_weight = perceptual_weight self.style_weight = style_weight diff --git a/mmedit/models/losses/pixelwise_loss.py b/mmedit/models/losses/pixelwise_loss.py index 3e0dbbec4b..632f343877 100644 --- a/mmedit/models/losses/pixelwise_loss.py +++ b/mmedit/models/losses/pixelwise_loss.py @@ -66,7 +66,7 @@ class L1Loss(nn.Module): """ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): - super(L1Loss, self).__init__() + super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') @@ -108,7 +108,7 @@ class MSELoss(nn.Module): """ def __init__(self, loss_weight=1.0, reduction='mean', sample_wise=False): - super(MSELoss, self).__init__() + super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') @@ -160,7 +160,7 @@ def __init__(self, reduction='mean', sample_wise=False, eps=1e-12): - super(CharbonnierLoss, self).__init__() + super().__init__() if reduction not in ['none', 'mean', 'sum']: raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') @@ -197,7 +197,7 @@ class MaskedTVLoss(L1Loss): """ def __init__(self, loss_weight=1.0): - super(MaskedTVLoss, self).__init__(loss_weight=loss_weight) + super().__init__(loss_weight=loss_weight) def forward(self, pred, mask=None): """Forward function. 
@@ -210,9 +210,9 @@ def forward(self, pred, mask=None): Returns: [type]: [description] """ - y_diff = super(MaskedTVLoss, self).forward( + y_diff = super().forward( pred[:, :, :-1, :], pred[:, :, 1:, :], weight=mask[:, :, :-1, :]) - x_diff = super(MaskedTVLoss, self).forward( + x_diff = super().forward( pred[:, :, :, :-1], pred[:, :, :, 1:], weight=mask[:, :, :, :-1]) loss = x_diff + y_diff diff --git a/mmedit/models/losses/utils.py b/mmedit/models/losses/utils.py index db7ac7fc54..e300fb9732 100644 --- a/mmedit/models/losses/utils.py +++ b/mmedit/models/losses/utils.py @@ -17,10 +17,10 @@ def reduce_loss(loss, reduction): # none: 0, elementwise_mean:1, sum: 2 if reduction_enum == 0: return loss - elif reduction_enum == 1: + if reduction_enum == 1: return loss.mean() - else: - return loss.sum() + + return loss.sum() def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False): diff --git a/mmedit/models/mattors/base_mattor.py b/mmedit/models/mattors/base_mattor.py index b01f9533c2..c83b2de323 100644 --- a/mmedit/models/mattors/base_mattor.py +++ b/mmedit/models/mattors/base_mattor.py @@ -49,7 +49,7 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(BaseMattor, self).__init__() + super().__init__() self.train_cfg = train_cfg if train_cfg is not None else ConfigDict() self.test_cfg = test_cfg if test_cfg is not None else ConfigDict() @@ -77,7 +77,8 @@ def __init__(self, # validate if test config is proper if not hasattr(self.test_cfg, 'metrics'): raise KeyError('Missing key "metrics" in test_cfg') - elif mmcv.is_list_of(self.test_cfg.metrics, str): + + if mmcv.is_list_of(self.test_cfg.metrics, str): for metric in self.test_cfg.metrics: if metric not in self.allowed_metrics: raise KeyError(f'metric {metric} is not supported') @@ -207,13 +208,11 @@ def forward_train(self, merged, trimap, alpha, **kwargs): trimap (Tensor): Trimap of the input image. alpha (Tensor): Ground-truth alpha matte. """ - pass @abstractmethod def forward_test(self, merged, trimap, meta, **kwargs): """Defines the computation performed at every test call. """ - pass def train_step(self, data_batch, optimizer): """Defines the computation and network update at every training call. @@ -262,7 +261,7 @@ def forward(self, are set to ``True``. Otherwise return the output of \ ``self.forward_train``. 
""" - if not test_mode: - return self.forward_train(merged, trimap, meta, alpha, **kwargs) - else: + if test_mode: return self.forward_test(merged, trimap, meta, **kwargs) + + return self.forward_train(merged, trimap, meta, alpha, **kwargs) diff --git a/mmedit/models/mattors/dim.py b/mmedit/models/mattors/dim.py index b5a9480f48..4f95eb4c28 100644 --- a/mmedit/models/mattors/dim.py +++ b/mmedit/models/mattors/dim.py @@ -46,8 +46,7 @@ def __init__(self, loss_alpha=None, loss_comp=None, loss_refine=None): - super(DIM, self).__init__(backbone, refiner, train_cfg, test_cfg, - pretrained) + super().__init__(backbone, refiner, train_cfg, test_cfg, pretrained) if all(v is None for v in (loss_alpha, loss_comp, loss_refine)): raise ValueError('Please specify one loss for DIM.') diff --git a/mmedit/models/mattors/gca.py b/mmedit/models/mattors/gca.py index 0fe6f02f36..46fbf309b2 100644 --- a/mmedit/models/mattors/gca.py +++ b/mmedit/models/mattors/gca.py @@ -30,8 +30,7 @@ def __init__(self, test_cfg=None, pretrained=None, loss_alpha=None): - super(GCA, self).__init__(backbone, None, train_cfg, test_cfg, - pretrained) + super().__init__(backbone, None, train_cfg, test_cfg, pretrained) self.loss_alpha = build_loss(loss_alpha) # support fp16 self.fp16_enabled = False diff --git a/mmedit/models/mattors/indexnet.py b/mmedit/models/mattors/indexnet.py index 3965250557..e886df3759 100644 --- a/mmedit/models/mattors/indexnet.py +++ b/mmedit/models/mattors/indexnet.py @@ -31,8 +31,7 @@ def __init__(self, pretrained=None, loss_alpha=None, loss_comp=None): - super(IndexNet, self).__init__(backbone, None, train_cfg, test_cfg, - pretrained) + super().__init__(backbone, None, train_cfg, test_cfg, pretrained) self.loss_alpha = ( build_loss(loss_alpha) if loss_alpha is not None else None) diff --git a/mmedit/models/restorers/basic_restorer.py b/mmedit/models/restorers/basic_restorer.py index 983e706cb9..9d6052b350 100644 --- a/mmedit/models/restorers/basic_restorer.py +++ b/mmedit/models/restorers/basic_restorer.py @@ -35,7 +35,7 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(BasicRestorer, self).__init__() + super().__init__() self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -70,11 +70,11 @@ def forward(self, lq, gt=None, test_mode=False, **kwargs): kwargs (dict): Other arguments. """ - if not test_mode: - return self.forward_train(lq, gt) - else: + if test_mode: return self.forward_test(lq, gt, **kwargs) + return self.forward_train(lq, gt) + def forward_train(self, lq, gt): """Training forward function. diff --git a/mmedit/models/restorers/edvr.py b/mmedit/models/restorers/edvr.py index 5bbbb76893..8b582329d5 100644 --- a/mmedit/models/restorers/edvr.py +++ b/mmedit/models/restorers/edvr.py @@ -28,8 +28,8 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(EDVR, self).__init__(generator, pixel_loss, train_cfg, test_cfg, - pretrained) + super().__init__(generator, pixel_loss, train_cfg, test_cfg, + pretrained) self.with_tsa = generator.get('with_tsa', False) self.step_counter = 0 # count training steps diff --git a/mmedit/models/restorers/srgan.py b/mmedit/models/restorers/srgan.py index fe7ff17ae2..21ffa46088 100644 --- a/mmedit/models/restorers/srgan.py +++ b/mmedit/models/restorers/srgan.py @@ -90,12 +90,12 @@ def forward(self, lq, gt=None, test_mode=False, **kwargs): test_mode (bool): Whether in test mode or not. Default: False. kwargs (dict): Other arguments. 
""" - if not test_mode: - raise ValueError( - 'SRGAN model does not supprot `forward_train` function.') - else: + if test_mode: return self.forward_test(lq, gt, **kwargs) + raise ValueError( + 'SRGAN model does not supprot `forward_train` function.') + def train_step(self, data_batch, optimizer): """Train step. diff --git a/mmedit/models/synthesizers/cycle_gan.py b/mmedit/models/synthesizers/cycle_gan.py index e02da764d7..af68e4304e 100644 --- a/mmedit/models/synthesizers/cycle_gan.py +++ b/mmedit/models/synthesizers/cycle_gan.py @@ -58,7 +58,7 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(CycleGAN, self).__init__() + super().__init__() self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -144,8 +144,8 @@ def get_module(self, module): """ if isinstance(module, MMDistributedDataParallel): return module.module - else: - return module + + return module def setup(self, img_a, img_b, meta): """Perform necessary pre-processing steps. @@ -179,7 +179,7 @@ def forward_train(self, img_a, img_b, meta): dict: Dict of forward results for training. """ # necessary setup - real_a, real_b, image_path = self.setup(img_a, img_b, meta) + real_a, real_b, _ = self.setup(img_a, img_b, meta) generators = self.get_module(self.generators) @@ -305,11 +305,11 @@ def forward(self, img_a, img_b, meta, test_mode=False, **kwargs): test_mode (bool): Whether in test mode or not. Default: False. kwargs (dict): Other arguments. """ - if not test_mode: - return self.forward_train(img_a, img_b, meta) - else: + if test_mode: return self.forward_test(img_a, img_b, meta, **kwargs) + return self.forward_train(img_a, img_b, meta) + def backward_discriminators(self, outputs): """Backward function for the discriminators. diff --git a/mmedit/models/synthesizers/pix2pix.py b/mmedit/models/synthesizers/pix2pix.py index 2b3416f79e..daa404a33a 100644 --- a/mmedit/models/synthesizers/pix2pix.py +++ b/mmedit/models/synthesizers/pix2pix.py @@ -49,7 +49,7 @@ def __init__(self, train_cfg=None, test_cfg=None, pretrained=None): - super(Pix2Pix, self).__init__() + super().__init__() self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -124,7 +124,7 @@ def forward_train(self, img_a, img_b, meta): dict: Dict of forward results for training. """ # necessary setup - real_a, real_b, image_path = self.setup(img_a, img_b, meta) + real_a, real_b, _ = self.setup(img_a, img_b, meta) fake_b = self.generator(real_a) results = dict(real_a=real_a, fake_b=fake_b, real_b=real_b) return results @@ -216,11 +216,11 @@ def forward(self, img_a, img_b, meta, test_mode=False, **kwargs): test_mode (bool): Whether in test mode or not. Default: False. kwargs (dict): Other arguments. """ - if not test_mode: - return self.forward_train(img_a, img_b, meta) - else: + if test_mode: return self.forward_test(img_a, img_b, meta, **kwargs) + return self.forward_train(img_a, img_b, meta) + def backward_discriminator(self, outputs): """Backward function for the discriminator. 
diff --git a/mmedit/version.py b/mmedit/version.py index 26079166a9..7b444688f6 100644 --- a/mmedit/version.py +++ b/mmedit/version.py @@ -4,15 +4,15 @@ def parse_version_info(version_str): - version_info = [] + ver_info = [] for x in version_str.split('.'): if x.isdigit(): - version_info.append(int(x)) + ver_info.append(int(x)) elif x.find('rc') != -1: patch_version = x.split('rc') - version_info.append(int(patch_version[0])) - version_info.append(f'rc{patch_version[1]}') - return tuple(version_info) + ver_info.append(int(patch_version[0])) + ver_info.append(f'rc{patch_version[1]}') + return tuple(ver_info) version_info = parse_version_info(__version__) diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 6644743d50..227587105e 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -15,7 +15,7 @@ # yapf: enable -class TestAugmentations(object): +class TestAugmentations: @classmethod def setup_class(cls): diff --git a/tests/test_crop.py b/tests/test_crop.py index b89a4204ca..46a1fbdcf9 100644 --- a/tests/test_crop.py +++ b/tests/test_crop.py @@ -8,7 +8,7 @@ PairedRandomCrop) -class TestAugmentations(object): +class TestAugmentations: @classmethod def setup_class(cls): diff --git a/tests/test_dataset_builder.py b/tests/test_dataset_builder.py index 518692cda0..ffe1aa6958 100644 --- a/tests/test_dataset_builder.py +++ b/tests/test_dataset_builder.py @@ -8,7 +8,7 @@ @DATASETS.register_module() -class ToyDataset(object): +class ToyDataset: def __init__(self, ann_file=None, cnt=0): self.ann_file = ann_file @@ -22,7 +22,7 @@ def __len__(self): @DATASETS.register_module() -class ToyDatasetWithAnnFile(object): +class ToyDatasetWithAnnFile: def __init__(self, ann_file): self.ann_file = ann_file diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 01e9791fb5..5f6aef961e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -14,14 +14,14 @@ SRVimeo90KDataset) -def mock_open(*args, **kargs): +def mock_open(*args, **kwargs): """unittest.mock_open wrapper. unittest.mock_open doesn't support iteration. Wrap it to fix this bug. 
Reference: https://stackoverflow.com/a/41656192 """ import unittest - f_open = unittest.mock.mock_open(*args, **kargs) + f_open = unittest.mock.mock_open(*args, **kwargs) f_open.return_value.__iter__ = lambda self: iter(self.readline, '') return f_open @@ -31,7 +31,7 @@ def check_keys_contain(result_keys, target_keys): return set(target_keys).issubset(set(result_keys)) -class TestMattingDatasets(object): +class TestMattingDatasets: @classmethod def setup_class(cls): @@ -78,7 +78,7 @@ def test_comp1k_evaluate(self): assert eval_result['MSE'] == 0.005 -class TestSRDatasets(object): +class TestSRDatasets: @classmethod def setup_class(cls): @@ -90,7 +90,7 @@ class ToyDataset(BaseSRDataset): """Toy dataset for testing SRDataset.""" def __init__(self, pipeline, test_mode=False): - super(ToyDataset, self).__init__(pipeline, test_mode) + super().__init__(pipeline, test_mode) def load_annotations(self): pass @@ -312,7 +312,7 @@ def test_sr_lmdb_dataset(self): scale=1) -class TestGenerationDatasets(object): +class TestGenerationDatasets: @classmethod def setup_class(cls): @@ -490,18 +490,18 @@ def test_generation_unpaired_dataset(self): unpair_folder = self.data_prefix / 'unpaired' # input path is Path object - generation_unparied_dataset = GenerationUnpairedDataset( + generation_unpaired_dataset = GenerationUnpairedDataset( dataroot=unpair_folder, pipeline=pipeline, test_mode=True) - data_infos_a = generation_unparied_dataset.data_infos_a - data_infos_b = generation_unparied_dataset.data_infos_b + data_infos_a = generation_unpaired_dataset.data_infos_a + data_infos_b = generation_unpaired_dataset.data_infos_b assert data_infos_a == [ dict(path=str(unpair_folder / 'testA' / '5.jpg')) ] assert data_infos_b == [ dict(path=str(unpair_folder / 'testB' / '6.jpg')) ] - result = generation_unparied_dataset[0] - assert (len(generation_unparied_dataset) == 1) + result = generation_unpaired_dataset[0] + assert (len(generation_unpaired_dataset) == 1) assert check_keys_contain(result.keys(), target_keys) assert check_keys_contain(result['meta'].data.keys(), target_meta_keys) assert (result['meta'].data['img_a_path'] == str(unpair_folder / @@ -510,18 +510,18 @@ def test_generation_unpaired_dataset(self): 'testB' / '6.jpg')) # input path is str - generation_unparied_dataset = GenerationUnpairedDataset( + generation_unpaired_dataset = GenerationUnpairedDataset( dataroot=str(unpair_folder), pipeline=pipeline, test_mode=True) - data_infos_a = generation_unparied_dataset.data_infos_a - data_infos_b = generation_unparied_dataset.data_infos_b + data_infos_a = generation_unpaired_dataset.data_infos_a + data_infos_b = generation_unpaired_dataset.data_infos_b assert data_infos_a == [ dict(path=str(unpair_folder / 'testA' / '5.jpg')) ] assert data_infos_b == [ dict(path=str(unpair_folder / 'testB' / '6.jpg')) ] - result = generation_unparied_dataset[0] - assert (len(generation_unparied_dataset) == 1) + result = generation_unpaired_dataset[0] + assert (len(generation_unpaired_dataset) == 1) assert check_keys_contain(result.keys(), target_keys) assert check_keys_contain(result['meta'].data.keys(), target_meta_keys) assert (result['meta'].data['img_a_path'] == str(unpair_folder / @@ -530,10 +530,10 @@ def test_generation_unpaired_dataset(self): 'testB' / '6.jpg')) # test_mode = False - generation_unparied_dataset = GenerationUnpairedDataset( + generation_unpaired_dataset = GenerationUnpairedDataset( dataroot=str(unpair_folder), pipeline=pipeline, test_mode=False) - data_infos_a = generation_unparied_dataset.data_infos_a - 
data_infos_b = generation_unparied_dataset.data_infos_b + data_infos_a = generation_unpaired_dataset.data_infos_a + data_infos_b = generation_unpaired_dataset.data_infos_b assert data_infos_a == [ dict(path=str(unpair_folder / 'trainA' / '1.jpg')), dict(path=str(unpair_folder / 'trainA' / '2.jpg')) @@ -542,18 +542,18 @@ def test_generation_unpaired_dataset(self): dict(path=str(unpair_folder / 'trainB' / '3.jpg')), dict(path=str(unpair_folder / 'trainB' / '4.jpg')) ] - assert (len(generation_unparied_dataset) == 2) + assert (len(generation_unpaired_dataset) == 2) img_b_paths = [ str(unpair_folder / 'trainB' / '3.jpg'), str(unpair_folder / 'trainB' / '4.jpg') ] - result = generation_unparied_dataset[0] + result = generation_unpaired_dataset[0] assert check_keys_contain(result.keys(), target_keys) assert check_keys_contain(result['meta'].data.keys(), target_meta_keys) assert (result['meta'].data['img_a_path'] == str(unpair_folder / 'trainA' / '1.jpg')) assert result['meta'].data['img_b_path'] in img_b_paths - result = generation_unparied_dataset[1] + result = generation_unpaired_dataset[1] assert check_keys_contain(result.keys(), target_keys) assert check_keys_contain(result['meta'].data.keys(), target_meta_keys) assert (result['meta'].data['img_a_path'] == str(unpair_folder / @@ -566,7 +566,7 @@ def test_repeat_dataset(): class ToyDataset(Dataset): def __init__(self): - super(ToyDataset, self).__init__() + super().__init__() self.members = [1, 2, 3, 4, 5] def __len__(self): diff --git a/tests/test_eval_hook.py b/tests/test_eval_hook.py index 53a87d57ae..67387b0b61 100644 --- a/tests/test_eval_hook.py +++ b/tests/test_eval_hook.py @@ -25,7 +25,7 @@ def __len__(self): class ExampleModel(nn.Module): def __init__(self): - super(ExampleModel, self).__init__() + super().__init__() self.test_cfg = None self.conv = nn.Conv2d(3, 3, 3) diff --git a/tests/test_loading.py b/tests/test_loading.py index 6578b807dd..d2ad3f358f 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -107,7 +107,7 @@ def test_load_image_from_file_list(): image_loader(results) -class TestMattingLoading(object): +class TestMattingLoading: @staticmethod def check_keys_contain(result_keys, target_keys): @@ -140,7 +140,7 @@ def test_random_load_bg(self): "(bg_dir='tests/data/bg')") -class TestInpaintLoading(object): +class TestInpaintLoading: @classmethod def setup_class(cls): @@ -229,7 +229,7 @@ def test_load_mask(self): results = loader(results) -class TestGenerationLoading(object): +class TestGenerationLoading: @staticmethod def check_keys_contain(result_keys, target_keys): diff --git a/tests/test_normalization.py b/tests/test_normalization.py index aaf7a74423..0e69317993 100644 --- a/tests/test_normalization.py +++ b/tests/test_normalization.py @@ -4,7 +4,7 @@ from mmedit.datasets.pipelines import Normalize, RescaleToZeroOne -class TestAugmentations(object): +class TestAugmentations: @staticmethod def assert_img_equal(img, ref_img, ratio_thr=0.999): diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 034f20e4c0..cb396c8305 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -7,7 +7,7 @@ class ExampleModel(nn.Module): def __init__(self): - super(ExampleModel, self).__init__() + super().__init__() self.model1 = nn.Conv2d(3, 8, kernel_size=3) self.model2 = nn.Conv2d(3, 4, kernel_size=3) diff --git a/tests/test_visual_hook.py b/tests/test_visual_hook.py index 35f84f5a99..8de925f37e 100644 --- a/tests/test_visual_hook.py +++ b/tests/test_visual_hook.py @@ -28,7 +28,7 @@ def 
__len__(self): class ExampleModel(nn.Module): def __init__(self): - super(ExampleModel, self).__init__() + super().__init__() self.test_cfg = None def train_step(self, data_batch, optimizer):
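The remaining changes above drop the redundant (object) base from class definitions, replace unused unpacked values with _, fix the kargs/unparied typos, and rename a local variable that shadowed the module-level version_info. A compact sketch of those patterns follows; the names and the placeholder version value are hypothetical, not code from the repository:

# Illustrative sketch only; names and values below are placeholders.

# New-style classes need no explicit `object` base in Python 3.
class ExampleTests:  # instead of `class ExampleTests(object):`
    pass

# Unused unpacked values are named `_` so pylint does not flag them.
def split_pair(triple):
    real_a, real_b, _ = triple  # the third element (a path) is unused here
    return real_a, real_b

# A local renamed to avoid shadowing a module-level name,
# as done for `version_info` in mmedit/version.py.
version_info = (0, 0, 0)  # placeholder, not the real mmedit version

def parse_version_info(version_str):
    ver_info = []  # not `version_info`, which already exists above
    for part in version_str.split('.'):
        if part.isdigit():
            ver_info.append(int(part))
        elif 'rc' in part:
            patch_version = part.split('rc')
            ver_info.append(int(patch_version[0]))
            ver_info.append(f'rc{patch_version[1]}')
    return tuple(ver_info)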