Skip to content

Commit

Permalink
Unify exception replacing and rewrite some TimeoutError cases
Browse files Browse the repository at this point in the history
  • Loading branch information
mxr committed Sep 17, 2023
1 parent 973eed1 commit 52794e6
Show file tree
Hide file tree
Showing 7 changed files with 891 additions and 624 deletions.
1 change: 1 addition & 0 deletions pyupgrade/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class State(NamedTuple):

RECORD_FROM_IMPORTS = frozenset((
'__future__',
'asyncio',
'functools',
'mmap',
'select',
Expand Down
189 changes: 189 additions & 0 deletions pyupgrade/_plugins/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
from __future__ import annotations

import ast
import functools
from typing import Iterable
from typing import NamedTuple

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._data import Version
from pyupgrade._token_helpers import arg_str
from pyupgrade._token_helpers import find_op
from pyupgrade._token_helpers import parse_call_args
from pyupgrade._token_helpers import replace_name


class ExceptionRewriteTarget(NamedTuple):
module: str | None
cls: str
min_version: Version


def _fix_except(
i: int,
tokens: list[Token],
*,
state: State,
rewritten_name: str,
exc_rewrite_targets: Iterable[ExceptionRewriteTarget],
) -> None:
# find all the arg strs in the tuple
except_index = i
while tokens[except_index].src != 'except':
except_index -= 1
start = find_op(tokens, except_index, '(')
func_args, end = parse_call_args(tokens, start)

# save the exceptions and remove the block
arg_strs = [arg_str(tokens, *arg) for arg in func_args]
del tokens[start:end]

# rewrite the block without dupes
args = []
for arg in arg_strs:
left, part, right = arg.partition('.')
if any(
state.settings.min_version >= ert.min_version and
left == ert.module and part == '.' and right == ert.cls
for ert in exc_rewrite_targets if ert.module
):
args.append(rewritten_name)
elif (
part == right == '' and
any(
left == ert.cls for ert in exc_rewrite_targets
if not ert.module and
state.settings.min_version >= ert.min_version
)
):
args.append(rewritten_name)
elif any(
state.settings.min_version >= ert.min_version and
left == ert.cls and part == right == '' and
ert.cls in state.from_imports[ert.module]
for ert in exc_rewrite_targets if ert.module
):
args.append(rewritten_name)
else:
args.append(arg)

unique_args = tuple(dict.fromkeys(args))

if len(unique_args) > 1:
joined = '({})'.format(', '.join(unique_args))
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
joined = f' {unique_args[0]}'
else:
joined = unique_args[0]

new = Token('CODE', joined)
tokens.insert(start, new)


def is_alias(
node: ast.AST,
state: State,
exc_rewrite_targets: Iterable[ExceptionRewriteTarget],
) -> tuple[Offset, str] | None:
if (
isinstance(node, ast.Name) and
any(
node.id == ert.cls for ert in exc_rewrite_targets if
not ert.module and
state.settings.min_version >= ert.min_version
)
):
return ast_to_offset(node), node.id
elif (
isinstance(node, ast.Name) and
any(
state.settings.min_version >= ert.min_version and
node.id == ert.cls and
node.id in state.from_imports[ert.module]
for ert in exc_rewrite_targets if ert.module
)
):
return ast_to_offset(node), node.id
elif (
isinstance(node, ast.Attribute) and
isinstance(node.value, ast.Name) and
any(
state.settings.min_version >= ert.min_version and
node.attr == ert.cls and
node.value.id == ert.module
for ert in exc_rewrite_targets if ert.module
)
):
return ast_to_offset(node), node.attr
else:
return None


def _alias_cbs(
node: ast.AST,
state: State,
rewritten_name: str,
exc_rewrite_targets: Iterable[ExceptionRewriteTarget],
) -> Iterable[tuple[Offset, TokenFunc]]:
offset_name = is_alias(node, state, exc_rewrite_targets)
if offset_name is not None:
offset, name = offset_name
func = functools.partial(replace_name, name=name, new=rewritten_name)
yield offset, func


def visit_Raise(
state: State,
node: ast.Raise,
rewritten_name: str,
exc_rewrite_targets: Iterable[ExceptionRewriteTarget],
) -> Iterable[tuple[Offset, TokenFunc]]:
if node.exc is not None:
yield from _alias_cbs(
node.exc,
state,
rewritten_name,
exc_rewrite_targets,
)
if isinstance(node.exc, ast.Call):
yield from _alias_cbs(
node.exc.func,
state,
rewritten_name,
exc_rewrite_targets,
)


def visit_Try(
state: State,
node: ast.Try,
rewritten_name: str,
exc_rewrite_targets: Iterable[ExceptionRewriteTarget],
) -> Iterable[tuple[Offset, TokenFunc]]:
for handler in node.handlers:
if (
isinstance(handler.type, ast.Tuple) and
any(
is_alias(elt, state, exc_rewrite_targets)
for elt in handler.type.elts
)
):
func = functools.partial(
_fix_except,
state=state,
rewritten_name=rewritten_name,
exc_rewrite_targets=exc_rewrite_targets,
)
yield ast_to_offset(handler.type), func
elif handler.type is not None:
yield from _alias_cbs(
handler.type,
state,
rewritten_name,
exc_rewrite_targets,
)
124 changes: 15 additions & 109 deletions pyupgrade/_plugins/oserror_aliases.py
Original file line number Diff line number Diff line change
@@ -1,104 +1,23 @@
from __future__ import annotations

import ast
import functools
from typing import Iterable

from tokenize_rt import Offset
from tokenize_rt import Token

from pyupgrade._ast_helpers import ast_to_offset
from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._token_helpers import arg_str
from pyupgrade._token_helpers import find_op
from pyupgrade._token_helpers import parse_call_args
from pyupgrade._token_helpers import replace_name
from pyupgrade._plugins import exceptions

ERROR_NAMES = frozenset(('EnvironmentError', 'IOError', 'WindowsError'))
ERROR_MODULES = frozenset(('mmap', 'select', 'socket'))


def _fix_oserror_except(
i: int,
tokens: list[Token],
*,
from_imports: dict[str, set[str]],
) -> None:
# find all the arg strs in the tuple
except_index = i
while tokens[except_index].src != 'except':
except_index -= 1
start = find_op(tokens, except_index, '(')
func_args, end = parse_call_args(tokens, start)

# save the exceptions and remove the block
arg_strs = [arg_str(tokens, *arg) for arg in func_args]
del tokens[start:end]

# rewrite the block without dupes
args = []
for arg in arg_strs:
left, part, right = arg.partition('.')
if left in ERROR_MODULES and part == '.' and right == 'error':
args.append('OSError')
elif left in ERROR_NAMES and part == right == '':
args.append('OSError')
elif (
left == 'error' and
part == right == '' and
any('error' in from_imports[mod] for mod in ERROR_MODULES)
):
args.append('OSError')
else:
args.append(arg)

unique_args = tuple(dict.fromkeys(args))

if len(unique_args) > 1:
joined = '({})'.format(', '.join(unique_args))
elif tokens[start - 1].name != 'UNIMPORTANT_WS':
joined = f' {unique_args[0]}'
else:
joined = unique_args[0]

new = Token('CODE', joined)
tokens.insert(start, new)


def _is_oserror_alias(
node: ast.AST,
from_imports: dict[str, set[str]],
) -> tuple[Offset, str] | None:
if isinstance(node, ast.Name) and node.id in ERROR_NAMES:
return ast_to_offset(node), node.id
elif (
isinstance(node, ast.Name) and
node.id == 'error' and
any(node.id in from_imports[mod] for mod in ERROR_MODULES)
):
return ast_to_offset(node), node.id
elif (
isinstance(node, ast.Attribute) and
isinstance(node.value, ast.Name) and
node.value.id in ERROR_MODULES and
node.attr == 'error'
):
return ast_to_offset(node), node.attr
else:
return None


def _oserror_alias_cbs(
node: ast.AST,
from_imports: dict[str, set[str]],
) -> Iterable[tuple[Offset, TokenFunc]]:
offset_name = _is_oserror_alias(node, from_imports)
if offset_name is not None:
offset, name = offset_name
func = functools.partial(replace_name, name=name, new='OSError')
yield offset, func
EXC_REWRITE_TARGETS = (
exceptions.ExceptionRewriteTarget('mmap', 'error', (3,)),
exceptions.ExceptionRewriteTarget('select', 'error', (3,)),
exceptions.ExceptionRewriteTarget('socket', 'error', (3,)),
exceptions.ExceptionRewriteTarget(None, 'IOError', (3,)),
exceptions.ExceptionRewriteTarget(None, 'EnvironmentError', (3,)),
exceptions.ExceptionRewriteTarget(None, 'WindowsError', (3,)),
)


@register(ast.Raise)
Expand All @@ -107,10 +26,9 @@ def visit_Raise(
node: ast.Raise,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
if node.exc is not None:
yield from _oserror_alias_cbs(node.exc, state.from_imports)
if isinstance(node.exc, ast.Call):
yield from _oserror_alias_cbs(node.exc.func, state.from_imports)
return exceptions.visit_Raise(
state, node, 'OSError', EXC_REWRITE_TARGETS,
)


@register(ast.Try)
Expand All @@ -119,18 +37,6 @@ def visit_Try(
node: ast.Try,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
for handler in node.handlers:
if (
isinstance(handler.type, ast.Tuple) and
any(
_is_oserror_alias(elt, state.from_imports)
for elt in handler.type.elts
)
):
func = functools.partial(
_fix_oserror_except,
from_imports=state.from_imports,
)
yield ast_to_offset(handler.type), func
elif handler.type is not None:
yield from _oserror_alias_cbs(handler.type, state.from_imports)
return exceptions.visit_Try(
state, node, 'OSError', EXC_REWRITE_TARGETS,
)
38 changes: 38 additions & 0 deletions pyupgrade/_plugins/timeouterror_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

import ast
from typing import Iterable

from tokenize_rt import Offset

from pyupgrade._data import register
from pyupgrade._data import State
from pyupgrade._data import TokenFunc
from pyupgrade._plugins import exceptions

EXC_REWRITE_TARGETS = (
exceptions.ExceptionRewriteTarget('socket', 'timeout', (3, 10)),
exceptions.ExceptionRewriteTarget('asyncio', 'TimeoutError', (3, 11)),
)


@register(ast.Raise)
def visit_Raise(
state: State,
node: ast.Raise,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
return exceptions.visit_Raise(
state, node, 'TimeoutError', EXC_REWRITE_TARGETS,
)


@register(ast.Try)
def visit_Try(
state: State,
node: ast.Try,
parent: ast.AST,
) -> Iterable[tuple[Offset, TokenFunc]]:
return exceptions.visit_Try(
state, node, 'TimeoutError', EXC_REWRITE_TARGETS,
)
Loading

0 comments on commit 52794e6

Please sign in to comment.