Skip to content

Commit

Permalink
Preserve non-decodable %-sequences intact when unquote.
Browse files Browse the repository at this point in the history
  • Loading branch information
serhiy-storchaka committed Oct 13, 2020
1 parent c8ac5e0 commit e3eb8bb
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 73 deletions.
1 change: 1 addition & 0 deletions CHANGES/517.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
No longer loss characters when decode incorrect percent-sequences (like ``%e2%82%f8``). All non-decodable percent-sequences are now preserved intact.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ junit_suite_name = yarl_test_suite


[flake8]
ignore = E301,E302,E704,W503,W504,F811
ignore = E203,E301,E302,E704,W503,W504,F811
max-line-length = 88

[mypy]
Expand Down
23 changes: 19 additions & 4 deletions tests/test_quoting.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,27 @@ def test_unquote_unsafe4(unquoter):
assert unquoter(unsafe="@")("a@b") == "a%40b"


def test_unquote_non_ascii(unquoter):
assert unquoter()("%F8") == "%F8"
@pytest.mark.parametrize(
("input", "expected"),
[
("%e2%82", "%e2%82"),
("%e2%82ac", "%e2%82ac"),
("%e2%82%f8", "%e2%82%f8"),
("%e2%82%2b", "%e2%82+"),
("%e2%82%e2%82%ac", "%e2%82€"),
("%e2%82%e2%82", "%e2%82%e2%82"),
],
)
def test_unquote_non_utf8(unquoter, input, expected):
assert unquoter()(input) == expected


def test_unquote_unsafe_non_utf8(unquoter):
assert unquoter(unsafe="\n")("%e2%82%0a") == "%e2%82%0A"


def test_unquote_non_ascii_non_tailing(unquoter):
assert unquoter()("%F8ab") == "%F8ab"
def test_unquote_plus_non_utf8(unquoter):
assert unquoter(qs=True)("%e2%82%2b") == "%e2%82%2B"


def test_quote_non_ascii(quoter):
Expand Down
62 changes: 31 additions & 31 deletions yarl/_quoting_c.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ from cpython.exc cimport PyErr_NoMemory
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
from cpython.unicode cimport PyUnicode_DecodeASCII

import codecs
from string import ascii_letters, digits

cdef str GEN_DELIMS = ":/?#[]@"
Expand All @@ -20,6 +21,7 @@ cdef str QS = '+&=;'
DEF BUF_SIZE = 8 * 1024 # 8KiB
cdef char BUFFER[BUF_SIZE]

utf8_decoder = codecs.getincrementaldecoder("utf-8")

cdef inline Py_UCS4 _to_hex(uint8_t v):
if v < 10:
Expand Down Expand Up @@ -295,44 +297,51 @@ cdef class _Unquoter:
cdef str _do_unquote(self, str val):
if len(val) == 0:
return val
cdef str last_pct = ''
cdef bytearray pcts = bytearray()
cdef list ret = []
cdef bytes b
cdef str unquoted
cdef Py_UCS4 ch = 0
cdef int idx = 0
cdef int length = len(val)
cdef int start_pct

decoder = utf8_decoder()

while idx < length:
ch = val[idx]
idx += 1
if pcts:
try:
unquoted = pcts.decode('utf8')
except UnicodeDecodeError:
pass
else:
if ch == '%' and idx <= length - 2:
ch = _restore_ch(val[idx], val[idx + 1])
if ch != <Py_UCS4>-1:
b = bytes((<int>ch,))
idx += 2
try:
unquoted = decoder.decode(b)
except UnicodeDecodeError:
start_pct = idx - 3 - len(decoder.buffer) * 3
ret.append(val[start_pct : idx - 3])
decoder.reset()
try:
unquoted = decoder.decode(b)
except UnicodeDecodeError:
ret.append(val[idx - 3 : idx])
continue
if not unquoted:
continue
if self._qs and unquoted in '+=&;':
ret.append(self._qs_quoter(unquoted))
elif unquoted in self._unsafe:
ret.append(self._quoter(unquoted))
else:
ret.append(unquoted)
del pcts[:]

if ch == '%' and idx <= length - 2:
ch = _restore_ch(val[idx], val[idx + 1])
if ch != <Py_UCS4>-1:
pcts.append(ch)
last_pct = val[idx - 1 : idx + 2]
idx += 2
continue
else:
ch = '%'

if pcts:
ret.append(last_pct) # %F8ab
last_pct = ''
if decoder.buffer:
start_pct = idx - 1 - len(decoder.buffer) * 3
ret.append(val[start_pct : idx - 1])
decoder.reset()

if ch == '+':
if not self._qs or ch in self._unsafe:
Expand All @@ -350,16 +359,7 @@ cdef class _Unquoter:

ret.append(ch)

if pcts:
try:
unquoted = pcts.decode('utf8')
except UnicodeDecodeError:
ret.append(last_pct) # %F8
else:
if self._qs and unquoted in '+=&;':
ret.append(self._qs_quoter(unquoted))
elif unquoted in self._unsafe:
ret.append(self._quoter(unquoted))
else:
ret.append(unquoted)
if decoder.buffer:
ret.append(val[length - len(decoder.buffer) * 3 : length])

return ''.join(ret)
66 changes: 29 additions & 37 deletions yarl/_quoting_py.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import codecs
import re
from string import ascii_letters, ascii_lowercase, digits
from typing import Optional, cast
Expand All @@ -16,6 +17,8 @@
_IS_HEX = re.compile(b"[A-Z0-9][A-Z0-9]")
_IS_HEX_STR = re.compile("[A-Fa-f0-9][A-Fa-f0-9]")

utf8_decoder = codecs.getincrementaldecoder("utf-8")


class _Quoter:
def __init__(
Expand Down Expand Up @@ -127,19 +130,30 @@ def __call__(self, val: Optional[str]) -> Optional[str]:
raise TypeError("Argument should be str")
if not val:
return ""
last_pct = ""
pcts = bytearray()
decoder = utf8_decoder()
ret = []
idx = 0
while idx < len(val):
ch = val[idx]
idx += 1
if pcts:
try:
unquoted = pcts.decode("utf8")
except UnicodeDecodeError:
pass
else:
if ch == "%" and idx <= len(val) - 2:
pct = val[idx : idx + 2] # noqa: E203
if _IS_HEX_STR.fullmatch(pct):
b = bytes([int(pct, base=16)])
idx += 2
try:
unquoted = decoder.decode(b)
except UnicodeDecodeError:
start_pct = idx - 3 - len(decoder.buffer) * 3 # type: ignore
ret.append(val[start_pct : idx - 3])
decoder.reset()
try:
unquoted = decoder.decode(b)
except UnicodeDecodeError:
ret.append(val[idx - 3 : idx])
continue
if not unquoted:
continue
if self._qs and unquoted in "+=&;":
to_add = self._qs_quoter(unquoted)
if to_add is None: # pragma: no cover
Expand All @@ -152,19 +166,12 @@ def __call__(self, val: Optional[str]) -> Optional[str]:
ret.append(to_add)
else:
ret.append(unquoted)
del pcts[:]

if ch == "%" and idx <= len(val) - 2:
pct = val[idx : idx + 2] # noqa: E203
if _IS_HEX_STR.fullmatch(pct):
pcts.append(int(pct, base=16))
last_pct = "%" + pct
idx += 2
continue

if pcts:
ret.append(last_pct) # %F8ab
last_pct = ""
if decoder.buffer: # type: ignore
start_pct = idx - 1 - len(decoder.buffer) * 3 # type: ignore
ret.append(val[start_pct : idx - 1])
decoder.reset()

if ch == "+":
if not self._qs or ch in self._unsafe:
Expand All @@ -182,24 +189,9 @@ def __call__(self, val: Optional[str]) -> Optional[str]:

ret.append(ch)

if pcts:
try:
unquoted = pcts.decode("utf8")
except UnicodeDecodeError:
ret.append(last_pct) # %F8
else:
if self._qs and unquoted in "+=&;":
to_add = self._qs_quoter(unquoted)
if to_add is None: # pragma: no cover
raise RuntimeError("Cannot quote None")
ret.append(to_add)
elif unquoted in self._unsafe:
to_add = self._qs_quoter(unquoted)
if to_add is None: # pragma: no cover
raise RuntimeError("Cannot quote None")
ret.append(to_add)
else:
ret.append(unquoted)
if decoder.buffer: # type: ignore
ret.append(val[-len(decoder.buffer) * 3 :]) # type: ignore

ret2 = "".join(ret)
if ret2 == val:
return val
Expand Down

0 comments on commit e3eb8bb

Please sign in to comment.