diff --git a/Makefile b/Makefile index 09048eeba76..4662dd51928 100644 --- a/Makefile +++ b/Makefile @@ -51,7 +51,7 @@ cov-dev-full: .develop @echo "Run without extensions" @AIOHTTP_NO_EXTENSIONS=1 py.test --cov=aiohttp tests @echo "Run in debug mode" - @PYTHONASYNCIODEBUG=1 py.test --cov=aiohttp --cov-append tests + @PYTHONASYNCIODEBUG=1 py.test -s -v --cov=aiohttp --cov-append tests @echo "Regular run" @py.test --cov=aiohttp --cov-report=term --cov-report=html --cov-append tests @echo "open file://`pwd`/coverage/index.html" diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index 9b5d862baa9..c0ee9ee514a 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -4,6 +4,7 @@ from . import hdrs # noqa from .client import * # noqa +from .formdata import * # noqa from .helpers import * # noqa from .http_message import HttpVersion, HttpVersion10, HttpVersion11 # noqa from .http_websocket import WSMsgType, WSCloseCode, WSMessage, WebSocketError # noqa @@ -11,6 +12,8 @@ from .multipart import * # noqa from .file_sender import FileSender # noqa from .cookiejar import CookieJar # noqa +from .payload import * # noqa +from .payload_streamer import * # noqa from .resolver import * # noqa # deprecated #1657 @@ -23,9 +26,12 @@ __all__ = (client.__all__ + # noqa + formdata.__all__ + # noqa helpers.__all__ + # noqa - streams.__all__ + # noqa multipart.__all__ + # noqa + payload.__all__ + # noqa + payload_streamer.__all__ + # noqa + streams.__all__ + # noqa ('hdrs', 'FileSender', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', 'WSMsgType', 'MsgType', 'WSCloseCode', diff --git a/aiohttp/abc.py b/aiohttp/abc.py index fb426f14421..bc38f9bf7c8 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -129,3 +129,20 @@ def update_cookies(self, cookies, response_url=None): @abstractmethod def filter_cookies(self, request_url): """Return the jar's cookies filtered by their attributes.""" + + +class AbstractPayloadWriter(ABC): + + @abstractmethod + def write(self, chunk): + """Write chunk into stream""" + + @asyncio.coroutine # pragma: no branch + @abstractmethod + def write_eof(self, chunk=b''): + """Write last chunk""" + + @asyncio.coroutine # pragma: no branch + @asyncio.coroutine + def drain(self): + """Flush the write buffer.""" diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index cd902c60d39..557f4fc1c83 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -1,8 +1,6 @@ import asyncio import io import json -import mimetypes -import os import sys import traceback import warnings @@ -13,11 +11,11 @@ import aiohttp -from . import hdrs, helpers, http, streams +from . import hdrs, helpers, http, payload +from .formdata import FormData from .helpers import PY_35, HeadersMixin, SimpleCookie, _TimeServiceTimeoutNoop from .http import HttpMessage from .log import client_logger -from .multipart import MultipartWriter from .streams import FlowControlStreamReader try: @@ -219,86 +217,54 @@ def update_auth(self, auth): self.headers[hdrs.AUTHORIZATION] = auth.encode() - def update_body_from_data(self, data, skip_auto_headers): - if not data: + def update_body_from_data(self, body, skip_auto_headers): + if not body: return - if isinstance(data, str): - data = data.encode(self.encoding) - - if isinstance(data, (bytes, bytearray)): - self.body = data - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in skip_auto_headers): - self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream' - if hdrs.CONTENT_LENGTH not in self.headers and not self.chunked: - self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) - - elif isinstance(data, (asyncio.StreamReader, streams.StreamReader, - streams.DataQueue)): - self.body = data + if asyncio.iscoroutine(body): + warnings.warn( + 'coroutine as data object is deprecated, ' + 'use aiohttp.streamer #1664', + DeprecationWarning, stacklevel=2) - elif asyncio.iscoroutine(data): - self.body = data + self.body = body if (hdrs.CONTENT_LENGTH not in self.headers and self.chunked is None): self.chunked = True - elif isinstance(data, io.IOBase): - assert not isinstance(data, io.StringIO), \ - 'attempt to send text data instead of binary' - self.body = data - if not self.chunked and isinstance(data, io.BytesIO): - # Not chunking if content-length can be determined - size = len(data.getbuffer()) - self.headers[hdrs.CONTENT_LENGTH] = str(size) - self.chunked = False - elif (not self.chunked and - isinstance(data, (io.BufferedReader, io.BufferedRandom))): - # Not chunking if content-length can be determined - try: - size = os.fstat(data.fileno()).st_size - data.tell() - self.headers[hdrs.CONTENT_LENGTH] = str(size) - self.chunked = False - except OSError: - # data.fileno() is not supported, e.g. - # io.BufferedReader(io.BytesIO(b'data')) - self.chunked = True - else: - self.chunked = True + return - if hasattr(data, 'mode'): - if data.mode == 'r': - raise ValueError('file {!r} should be open in binary mode' - ''.format(data)) - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in skip_auto_headers and - hasattr(data, 'name')): - mime = mimetypes.guess_type(data.name)[0] - mime = 'application/octet-stream' if mime is None else mime - self.headers[hdrs.CONTENT_TYPE] = mime - - elif isinstance(data, MultipartWriter): - self.body = data.serialize() - self.headers.update(data.headers) - self.chunked = True + # FormData + if isinstance(body, FormData): + body = body(self.encoding) - else: - if not isinstance(data, helpers.FormData): - data = helpers.FormData(data) + try: + body = payload.PAYLOAD_REGISTRY.get(body) + except payload.LookupError: + body = FormData(body)(self.encoding) - self.body = data(self.encoding) + self.body = body - if (hdrs.CONTENT_TYPE not in self.headers and - hdrs.CONTENT_TYPE not in skip_auto_headers): - self.headers[hdrs.CONTENT_TYPE] = data.content_type + # enable chunked encoding if needed + if not self.chunked: + if hdrs.CONTENT_LENGTH not in self.headers: + size = body.size + if size is None: + self.chunked = True + else: + if hdrs.CONTENT_LENGTH not in self.headers: + self.headers[hdrs.CONTENT_LENGTH] = str(size) - if data.is_multipart: - self.chunked = True - else: - if (hdrs.CONTENT_LENGTH not in self.headers and - not self.chunked): - self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) + # set content-type + if (hdrs.CONTENT_TYPE not in self.headers and + hdrs.CONTENT_TYPE not in skip_auto_headers): + self.headers[hdrs.CONTENT_TYPE] = body.content_type + + # copy payload headers + if body.headers: + for (key, value) in body.headers.items(): + if key not in self.headers: + self.headers[key] = value def update_transfer_encoding(self): """Analyze transfer-encoding header.""" @@ -344,7 +310,10 @@ def write_bytes(self, request, conn): yield from self._continue try: - if asyncio.iscoroutine(self.body): + if isinstance(self.body, payload.Payload): + yield from self.body.write(request) + + elif asyncio.iscoroutine(self.body): exc = None value = None stream = self.body @@ -377,29 +346,6 @@ def write_bytes(self, request, conn): raise ValueError( 'Bytes object is expected, got: %s.' % type(result)) - - elif isinstance(self.body, (asyncio.StreamReader, - streams.StreamReader)): - chunk = yield from self.body.read(streams.DEFAULT_LIMIT) - while chunk: - yield from request.write(chunk, drain=True) - chunk = yield from self.body.read(streams.DEFAULT_LIMIT) - - elif isinstance(self.body, streams.DataQueue): - while True: - try: - chunk = yield from self.body.read() - if not chunk: - break - yield from request.write(chunk) - except streams.EofStream: - break - - elif isinstance(self.body, io.IOBase): - chunk = self.body.read(streams.DEFAULT_LIMIT) - while chunk: - request.write(chunk) - chunk = self.body.read(self.chunked) else: if isinstance(self.body, (bytes, bytearray)): self.body = (self.body,) diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py new file mode 100644 index 00000000000..86550952366 --- /dev/null +++ b/aiohttp/formdata.py @@ -0,0 +1,122 @@ +import io +from urllib.parse import urlencode + +from multidict import MultiDict, MultiDictProxy + +from . import hdrs, multipart, payload +from .helpers import guess_filename + +__all__ = ('FormData',) + + +class FormData: + """Helper class for multipart/form-data and + application/x-www-form-urlencoded body generation.""" + + def __init__(self, fields=(), quote_fields=True): + self._writer = multipart.MultipartWriter('form-data') + self._fields = [] + self._is_multipart = False + self._quote_fields = quote_fields + + if isinstance(fields, dict): + fields = list(fields.items()) + elif not isinstance(fields, (list, tuple)): + fields = (fields,) + self.add_fields(*fields) + + def add_field(self, name, value, *, content_type=None, filename=None, + content_transfer_encoding=None): + + if isinstance(value, io.IOBase): + self._is_multipart = True + elif isinstance(value, (bytes, bytearray, memoryview)): + if filename is None and content_transfer_encoding is None: + filename = name + + type_options = MultiDict({'name': name}) + if filename is not None and not isinstance(filename, str): + raise TypeError('filename must be an instance of str. ' + 'Got: %s' % filename) + if filename is None and isinstance(value, io.IOBase): + filename = guess_filename(value, name) + if filename is not None: + type_options['filename'] = filename + self._is_multipart = True + + headers = {} + if content_type is not None: + if not isinstance(content_type, str): + raise TypeError('content_type must be an instance of str. ' + 'Got: %s' % content_type) + headers[hdrs.CONTENT_TYPE] = content_type + self._is_multipart = True + if content_transfer_encoding is not None: + if not isinstance(content_transfer_encoding, str): + raise TypeError('content_transfer_encoding must be an instance' + ' of str. Got: %s' % content_transfer_encoding) + headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding + self._is_multipart = True + + self._fields.append((type_options, headers, value)) + + def add_fields(self, *fields): + to_add = list(fields) + + while to_add: + rec = to_add.pop(0) + + if isinstance(rec, io.IOBase): + k = guess_filename(rec, 'unknown') + self.add_field(k, rec) + + elif isinstance(rec, (MultiDictProxy, MultiDict)): + to_add.extend(rec.items()) + + elif isinstance(rec, (list, tuple)) and len(rec) == 2: + k, fp = rec + self.add_field(k, fp) + + else: + raise TypeError('Only io.IOBase, multidict and (name, file) ' + 'pairs allowed, use .add_field() for passing ' + 'more complex parameters, got {!r}' + .format(rec)) + + def _gen_form_urlencoded(self, encoding): + # form data (x-www-form-urlencoded) + data = [] + for type_options, _, value in self._fields: + data.append((type_options['name'], value)) + + return payload.BytesPayload( + urlencode(data, doseq=True).encode(encoding), + content_type='application/x-www-form-urlencoded') + + def _gen_form_data(self, encoding): + """Encode a list of fields using the multipart/form-data MIME format""" + for dispparams, headers, value in self._fields: + if hdrs.CONTENT_TYPE in headers: + part = payload.get_payload( + value, content_type=headers[hdrs.CONTENT_TYPE], + headers=headers, encoding=encoding) + else: + part = payload.get_payload( + value, headers=headers, encoding=encoding) + if dispparams: + part.set_content_disposition( + 'form-data', quote_fields=self._quote_fields, **dispparams + ) + # FIXME cgi.FieldStorage doesn't likes body parts with + # Content-Length which were sent via chunked transfer encoding + part.headers.pop(hdrs.CONTENT_LENGTH, None) + + self._writer.append_payload(part) + + return self._writer + + def __call__(self, encoding): + if self._is_multipart: + return self._gen_form_data(encoding) + else: + return self._gen_form_urlencoded(encoding) diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 98a749409ab..6da1edb41ba 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -7,7 +7,6 @@ import datetime import functools import heapq -import io import os import re import sys @@ -16,10 +15,9 @@ from functools import total_ordering from pathlib import Path from time import gmtime -from urllib.parse import urlencode +from urllib.parse import quote from async_timeout import timeout -from multidict import MultiDict, MultiDictProxy from . import hdrs @@ -38,7 +36,7 @@ from .backport_cookies import SimpleCookie # noqa -__all__ = ('BasicAuth', 'create_future', 'FormData', 'parse_mimetype', +__all__ = ('BasicAuth', 'create_future', 'parse_mimetype', 'Timeout', 'ensure_future', 'noop') @@ -46,6 +44,13 @@ Timeout = timeout NO_EXTENSIONS = bool(os.environ.get('AIOHTTP_NO_EXTENSIONS')) +CHAR = set(chr(i) for i in range(0, 128)) +CTL = set(chr(i) for i in range(0, 32)) | {chr(127), } +SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', + '?', '=', '{', '}', ' ', chr(9)} +TOKEN = CHAR ^ CTL ^ SEPARATORS + + if sys.version_info < (3, 5): noop = tuple else: @@ -112,121 +117,6 @@ def create_future(loop): return asyncio.Future(loop=loop) -class FormData: - """Helper class for multipart/form-data and - application/x-www-form-urlencoded body generation.""" - - def __init__(self, fields=(), quote_fields=True): - from . import multipart - self._writer = multipart.MultipartWriter('form-data') - self._fields = [] - self._is_multipart = False - self._quote_fields = quote_fields - - if isinstance(fields, dict): - fields = list(fields.items()) - elif not isinstance(fields, (list, tuple)): - fields = (fields,) - self.add_fields(*fields) - - @property - def is_multipart(self): - return self._is_multipart - - @property - def content_type(self): - if self._is_multipart: - return self._writer.headers[hdrs.CONTENT_TYPE] - else: - return 'application/x-www-form-urlencoded' - - def add_field(self, name, value, *, content_type=None, filename=None, - content_transfer_encoding=None): - - if isinstance(value, io.IOBase): - self._is_multipart = True - elif isinstance(value, (bytes, bytearray, memoryview)): - if filename is None and content_transfer_encoding is None: - filename = name - - type_options = MultiDict({'name': name}) - if filename is not None and not isinstance(filename, str): - raise TypeError('filename must be an instance of str. ' - 'Got: %s' % filename) - if filename is None and isinstance(value, io.IOBase): - filename = guess_filename(value, name) - if filename is not None: - type_options['filename'] = filename - self._is_multipart = True - - headers = {} - if content_type is not None: - if not isinstance(content_type, str): - raise TypeError('content_type must be an instance of str. ' - 'Got: %s' % content_type) - headers[hdrs.CONTENT_TYPE] = content_type - self._is_multipart = True - if content_transfer_encoding is not None: - if not isinstance(content_transfer_encoding, str): - raise TypeError('content_transfer_encoding must be an instance' - ' of str. Got: %s' % content_transfer_encoding) - headers[hdrs.CONTENT_TRANSFER_ENCODING] = content_transfer_encoding - self._is_multipart = True - - self._fields.append((type_options, headers, value)) - - def add_fields(self, *fields): - to_add = list(fields) - - while to_add: - rec = to_add.pop(0) - - if isinstance(rec, io.IOBase): - k = guess_filename(rec, 'unknown') - self.add_field(k, rec) - - elif isinstance(rec, (MultiDictProxy, MultiDict)): - to_add.extend(rec.items()) - - elif isinstance(rec, (list, tuple)) and len(rec) == 2: - k, fp = rec - self.add_field(k, fp) - - else: - raise TypeError('Only io.IOBase, multidict and (name, file) ' - 'pairs allowed, use .add_field() for passing ' - 'more complex parameters, got {!r}' - .format(rec)) - - def _gen_form_urlencoded(self, encoding): - # form data (x-www-form-urlencoded) - data = [] - for type_options, _, value in self._fields: - data.append((type_options['name'], value)) - - data = urlencode(data, doseq=True) - return data.encode(encoding) - - def _gen_form_data(self, *args, **kwargs): - """Encode a list of fields using the multipart/form-data MIME format""" - for dispparams, headers, value in self._fields: - part = self._writer.append(value, headers) - if dispparams: - part.set_content_disposition( - 'form-data', quote_fields=self._quote_fields, **dispparams - ) - # FIXME cgi.FieldStorage doesn't likes body parts with - # Content-Length which were sent via chunked transfer encoding - part.headers.pop(hdrs.CONTENT_LENGTH, None) - yield from self._writer.serialize() - - def __call__(self, encoding): - if self._is_multipart: - return self._gen_form_data(encoding) - else: - return self._gen_form_urlencoded(encoding) - - def parse_mimetype(mimetype): """Parses a MIME type into its components. @@ -271,6 +161,33 @@ def guess_filename(obj, default=None): return default +def content_disposition_header(disptype, quote_fields=True, **params): + """Sets ``Content-Disposition`` header. + + :param str disptype: Disposition type: inline, attachment, form-data. + Should be valid extension token (see RFC 2183) + :param dict params: Disposition params + """ + if not disptype or not (TOKEN > set(disptype)): + raise ValueError('bad content disposition type {!r}' + ''.format(disptype)) + + value = disptype + if params: + lparams = [] + for key, val in params.items(): + if not key or not (TOKEN > set(key)): + raise ValueError('bad content disposition parameter' + ' {!r}={!r}'.format(key, val)) + qval = quote(val, '') if quote_fields else val + lparams.append((key, '"%s"' % qval)) + if key == 'filename': + lparams.append(('filename*', "utf-8''" + qval)) + sparams = '; '.join('='.join(pair) for pair in lparams) + value = '; '.join((value, sparams)) + return value + + class AccessLogger: """Helper object to log access. diff --git a/aiohttp/http_message.py b/aiohttp/http_message.py index 92a34a2596b..14c8b89154a 100644 --- a/aiohttp/http_message.py +++ b/aiohttp/http_message.py @@ -15,6 +15,7 @@ import aiohttp from . import hdrs +from .abc import AbstractPayloadWriter from .helpers import create_future, noop __all__ = ('RESPONSES', 'SERVER_SOFTWARE', @@ -33,7 +34,7 @@ HttpVersion11 = HttpVersion(1, 1) -class PayloadWriter: +class PayloadWriter(AbstractPayloadWriter): def __init__(self, stream, loop): if loop is None: diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 09d8f0d9e70..7952905c681 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -1,38 +1,28 @@ import asyncio import base64 import binascii -import io import json -import mimetypes -import os import re import uuid import warnings import zlib from collections import Mapping, Sequence, deque -from pathlib import Path -from urllib.parse import parse_qsl, quote, unquote, urlencode +from urllib.parse import parse_qsl, unquote, urlencode from multidict import CIMultiDict from .hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) -from .helpers import PY_35, PY_352, parse_mimetype +from .helpers import CHAR, PY_35, PY_352, TOKEN, parse_mimetype from .http import HttpParser +from .payload import (BytesPayload, LookupError, Payload, StringPayload, + get_payload) -__all__ = ('MultipartReader', 'MultipartWriter', - 'BodyPartReader', 'BodyPartWriter', +__all__ = ('MultipartReader', 'MultipartWriter', 'BodyPartReader', 'BadContentDispositionHeader', 'BadContentDispositionParam', 'parse_content_disposition', 'content_disposition_filename') -CHAR = set(chr(i) for i in range(0, 128)) -CTL = set(chr(i) for i in range(0, 32)) | {chr(127), } -SEPARATORS = {'(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', - '?', '=', '{', '}', ' ', chr(9)} -TOKEN = CHAR ^ CTL ^ SEPARATORS - - class BadContentDispositionHeader(RuntimeWarning): pass @@ -668,268 +658,22 @@ def _maybe_release_last_part(self): self._last_part = None -class BodyPartWriter(object): - """Multipart writer for single body part.""" - - def __init__(self, obj, headers=None, *, chunk_size=8192): - if isinstance(obj, MultipartWriter): - if headers is not None: - obj.headers.update(headers) - headers = obj.headers - elif headers is None: - headers = CIMultiDict() - elif not isinstance(headers, CIMultiDict): - headers = CIMultiDict(headers) - - self.obj = obj - self.headers = headers - self._chunk_size = chunk_size - self._fill_headers_with_defaults() - - self._serialize_map = { - bytes: self._serialize_bytes, - str: self._serialize_str, - io.IOBase: self._serialize_io, - MultipartWriter: self._serialize_multipart, - ('application', 'json'): self._serialize_json, - ('application', 'x-www-form-urlencoded'): self._serialize_form - } - self._validate_obj(obj, headers) - - def _validate_obj(self, obj, headers): - mtype, stype, *_ = parse_mimetype(headers.get(CONTENT_TYPE)) - if (mtype, stype) in self._serialize_map: - return - for key in self._serialize_map: - if isinstance(key, tuple): - continue - if isinstance(obj, key): - return - else: - raise TypeError('unexpected body part value type %r' % type(obj)) - - def _fill_headers_with_defaults(self): - if CONTENT_TYPE not in self.headers: - content_type = self._guess_content_type(self.obj) - if content_type is not None: - self.headers[CONTENT_TYPE] = content_type - - if CONTENT_LENGTH not in self.headers: - content_length = self._guess_content_length(self.obj) - if content_length is not None: - self.headers[CONTENT_LENGTH] = str(content_length) - - if CONTENT_DISPOSITION not in self.headers: - filename = self._guess_filename(self.obj) - if filename is not None: - self.set_content_disposition('attachment', filename=filename) - - def _guess_content_length(self, obj): - if isinstance(obj, bytes): - return len(obj) - elif isinstance(obj, str): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - charset = params.get('charset', 'us-ascii') - return len(obj.encode(charset)) - elif isinstance(obj, io.StringIO): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - charset = params.get('charset', 'us-ascii') - return len(obj.getvalue().encode(charset)) - obj.tell() - elif isinstance(obj, io.BytesIO): - return len(obj.getvalue()) - obj.tell() - elif isinstance(obj, io.IOBase): - try: - return os.fstat(obj.fileno()).st_size - obj.tell() - except (AttributeError, OSError): - return None - else: - return None - - def _guess_content_type(self, obj, default='application/octet-stream'): - if hasattr(obj, 'name'): - name = getattr(obj, 'name') - return mimetypes.guess_type(name)[0] - elif isinstance(obj, (str, io.StringIO)): - return 'text/plain; charset=utf-8' - else: - return default - - def _guess_filename(self, obj): - if isinstance(obj, io.IOBase): - name = getattr(obj, 'name', None) - if name is not None: - return Path(name).name - - def serialize(self): - """Yields byte chunks for body part.""" - - has_encoding = ( - CONTENT_ENCODING in self.headers and - self.headers[CONTENT_ENCODING] != 'identity' or - CONTENT_TRANSFER_ENCODING in self.headers - ) - if has_encoding: - # since we're following streaming approach which doesn't assumes - # any intermediate buffers, we cannot calculate real content length - # with the specified content encoding scheme. So, instead of lying - # about content length and cause reading issues, we have to strip - # this information. - self.headers.pop(CONTENT_LENGTH, None) - - if self.headers: - yield b'\r\n'.join( - b': '.join(map(lambda i: i.encode('latin1'), item)) - for item in self.headers.items() - ) - yield b'\r\n\r\n' - yield from self._maybe_encode_stream(self._serialize_obj()) - yield b'\r\n' - - def _serialize_obj(self): - obj = self.obj - mtype, stype, *_ = parse_mimetype(self.headers.get(CONTENT_TYPE)) - serializer = self._serialize_map.get((mtype, stype)) - if serializer is not None: - return serializer(obj) - - for key in self._serialize_map: - if not isinstance(key, tuple) and isinstance(obj, key): - return self._serialize_map[key](obj) - return self._serialize_default(obj) - - def _serialize_bytes(self, obj): - yield obj - - def _serialize_str(self, obj): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - yield obj.encode(params.get('charset', 'us-ascii')) - - def _serialize_io(self, obj): - while True: - chunk = obj.read(self._chunk_size) - if not chunk: - break - if isinstance(chunk, str): - yield from self._serialize_str(chunk) - else: - yield from self._serialize_bytes(chunk) - - def _serialize_multipart(self, obj): - yield from obj.serialize() - - def _serialize_json(self, obj): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - yield json.dumps(obj).encode(params.get('charset', 'utf-8')) - - def _serialize_form(self, obj): - if isinstance(obj, Mapping): - obj = list(obj.items()) - return self._serialize_str(urlencode(obj, doseq=True)) - - def _serialize_default(self, obj): - raise TypeError('unknown body part type %r' % type(obj)) - - def _maybe_encode_stream(self, stream): - if CONTENT_ENCODING in self.headers: - stream = self._apply_content_encoding(stream) - if CONTENT_TRANSFER_ENCODING in self.headers: - stream = self._apply_content_transfer_encoding(stream) - yield from stream - - def _apply_content_encoding(self, stream): - encoding = self.headers[CONTENT_ENCODING].lower() - if encoding == 'identity': - yield from stream - elif encoding in ('deflate', 'gzip'): - if encoding == 'gzip': - zlib_mode = 16 + zlib.MAX_WBITS - else: - zlib_mode = -zlib.MAX_WBITS - zcomp = zlib.compressobj(wbits=zlib_mode) - for chunk in stream: - yield zcomp.compress(chunk) - else: - yield zcomp.flush() - else: - raise RuntimeError('unknown content encoding: {}' - ''.format(encoding)) - - def _apply_content_transfer_encoding(self, stream): - encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower() - if encoding == 'base64': - buffer = bytearray() - while True: - if buffer: - div, mod = divmod(len(buffer), 3) - chunk, buffer = buffer[:div * 3], buffer[div * 3:] - if chunk: - yield base64.b64encode(chunk) - chunk = next(stream, None) - if not chunk: - if buffer: - yield base64.b64encode(buffer[:]) - return - buffer.extend(chunk) - elif encoding == 'quoted-printable': - for chunk in stream: - yield binascii.b2a_qp(chunk) - elif encoding == 'binary': - yield from stream - else: - raise RuntimeError('unknown content transfer encoding: {}' - ''.format(encoding)) - - def set_content_disposition(self, disptype, quote_fields=True, **params): - """Sets ``Content-Disposition`` header. - - :param str disptype: Disposition type: inline, attachment, form-data. - Should be valid extension token (see RFC 2183) - :param dict params: Disposition params - """ - if not disptype or not (TOKEN > set(disptype)): - raise ValueError('bad content disposition type {!r}' - ''.format(disptype)) - value = disptype - if params: - lparams = [] - for key, val in params.items(): - if not key or not (TOKEN > set(key)): - raise ValueError('bad content disposition parameter' - ' {!r}={!r}'.format(key, val)) - qval = quote(val, '') if quote_fields else val - lparams.append((key, '"%s"' % qval)) - if key == 'filename': - lparams.append(('filename*', "utf-8''" + qval)) - sparams = '; '.join('='.join(pair) for pair in lparams) - value = '; '.join((value, sparams)) - self.headers[CONTENT_DISPOSITION] = value - - @property - def filename(self): - """Returns filename specified in Content-Disposition header or ``None`` - if missed.""" - _, params = parse_content_disposition( - self.headers.get(CONTENT_DISPOSITION)) - return content_disposition_filename(params) - - -class MultipartWriter(object): +class MultipartWriter(Payload): """Multipart body writer.""" - #: Body part reader class for non multipart/* content types. - part_writer_cls = BodyPartWriter - def __init__(self, subtype='mixed', boundary=None): boundary = boundary if boundary is not None else uuid.uuid4().hex try: - boundary.encode('us-ascii') + self._boundary = boundary.encode('us-ascii') except UnicodeEncodeError: raise ValueError('boundary should contains ASCII only chars') - self.headers = CIMultiDict() - self.headers[CONTENT_TYPE] = 'multipart/{}; boundary="{}"'.format( - subtype, boundary - ) - self.parts = [] + ctype = 'multipart/{}; boundary="{}"'.format(subtype, boundary) + + super().__init__(None, content_type=ctype) + + self._parts = [] + self._headers = CIMultiDict() + self._headers[CONTENT_TYPE] = self.content_type def __enter__(self): return self @@ -938,53 +682,191 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def __iter__(self): - return iter(self.parts) + return iter(self._parts) def __len__(self): - return len(self.parts) + return len(self._parts) @property def boundary(self): - *_, params = parse_mimetype(self.headers.get(CONTENT_TYPE)) - return params['boundary'].encode('us-ascii') + return self._boundary def append(self, obj, headers=None): - """Adds a new body part to multipart writer.""" - if isinstance(obj, self.part_writer_cls): - if headers: + if headers is None: + headers = CIMultiDict() + + if isinstance(obj, Payload): + if obj.headers is not None: obj.headers.update(headers) - self.parts.append(obj) + else: + obj._headers = headers + self.append_payload(obj) else: - if not headers: - headers = CIMultiDict() - self.parts.append(self.part_writer_cls(obj, headers)) - return self.parts[-1] + try: + self.append_payload(get_payload(obj, headers=headers)) + except LookupError: + raise TypeError + + def append_payload(self, payload): + """Adds a new body part to multipart writer.""" + # content-type + if CONTENT_TYPE not in payload.headers: + payload.headers[CONTENT_TYPE] = payload.content_type + + # compression + encoding = payload.headers.get(CONTENT_ENCODING, '').lower() + if encoding and encoding not in ('deflate', 'gzip', 'identity'): + raise RuntimeError('unknown content encoding: {}'.format(encoding)) + if encoding == 'identity': + encoding = None + + # te encoding + te_encoding = payload.headers.get( + CONTENT_TRANSFER_ENCODING, '').lower() + if te_encoding not in ('', 'base64', 'quoted-printable', 'binary'): + raise RuntimeError('unknown content transfer encoding: {}' + ''.format(te_encoding)) + if te_encoding == 'binary': + te_encoding = None + + # size + size = payload.size + if size is not None and not (encoding or te_encoding): + payload.headers[CONTENT_LENGTH] = str(size) + + # render headers + headers = ''.join( + [k + ': ' + v + '\r\n' for k, v in payload.headers.items()] + ).encode('utf-8') + b'\r\n' + + self._parts.append((payload, headers, encoding, te_encoding)) def append_json(self, obj, headers=None): """Helper to append JSON part.""" - if not headers: + if headers is None: headers = CIMultiDict() - headers[CONTENT_TYPE] = 'application/json' - return self.append(obj, headers) + + *_, params = parse_mimetype(headers.get(CONTENT_TYPE)) + charset = params.get('charset', 'utf-8') + + data = json.dumps(obj).encode(charset) + self.append_payload( + BytesPayload( + data, headers=headers, content_type='application/json')) def append_form(self, obj, headers=None): """Helper to append form urlencoded part.""" - if not headers: - headers = CIMultiDict() - headers[CONTENT_TYPE] = 'application/x-www-form-urlencoded' assert isinstance(obj, (Sequence, Mapping)) - return self.append(obj, headers) - def serialize(self): - """Yields multipart byte chunks.""" - if not self.parts: - yield b'' + if headers is None: + headers = CIMultiDict() + + if isinstance(obj, Mapping): + obj = list(obj.items()) + data = urlencode(obj, doseq=True) + + return self.append_payload( + StringPayload(data, headers=headers, + content_type='application/x-www-form-urlencoded')) + + @property + def size(self): + """Size of the payload.""" + if not self._parts: + return 0 + + total = 0 + for part, headers, encoding, te_encoding in self._parts: + if encoding or te_encoding or part.size is None: + return None + + total += ( + 2 + len(self._boundary) + 2 + # b'--'+self._boundary+b'\r\n' + part.size + len(headers) + + 2 # b'\r\n' + ) + + total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' + return total + + @asyncio.coroutine + def write(self, writer): + """Write body.""" + if not self._parts: return - for part in self.parts: - yield b'--' + self.boundary + b'\r\n' - yield from part.serialize() - else: - yield b'--' + self.boundary + b'--\r\n' + for part, headers, encoding, te_encoding in self._parts: + yield from writer.write(b'--' + self._boundary + b'\r\n') + yield from writer.write(headers) + + if encoding or te_encoding: + w = MultipartPayloadWriter(writer) + if encoding: + w.enable_compression(encoding) + if te_encoding: + w.enable_encoding(te_encoding) + yield from part.write(w) + yield from w.write_eof() + else: + yield from part.write(writer) + + yield from writer.write(b'\r\n') + + yield from writer.write(b'--' + self._boundary + b'--\r\n') - yield b'' + +class MultipartPayloadWriter: + + def __init__(self, writer): + self._writer = writer + self._encoding = None + self._compress = None + + def enable_encoding(self, encoding): + if encoding == 'base64': + self._encoding = encoding + self._encoding_buffer = bytearray() + elif encoding == 'quoted-printable': + self._encoding = 'quoted-printable' + + def enable_compression(self, encoding='deflate'): + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + self._compress = zlib.compressobj(wbits=zlib_mode) + + @asyncio.coroutine + def write_eof(self): + if self._compress is not None: + chunk = self._compress.flush() + if chunk: + self._compress = None + yield from self.write(chunk) + + if self._encoding == 'base64': + if self._encoding_buffer: + yield from self._writer.write(base64.b64encode( + self._encoding_buffer)) + + @asyncio.coroutine + def write(self, chunk): + if self._compress is not None: + if chunk: + chunk = self._compress.compress(chunk) + if not chunk: + return + + if self._encoding == 'base64': + self._encoding_buffer.extend(chunk) + + if self._encoding_buffer: + buffer = self._encoding_buffer + div, mod = divmod(len(buffer), 3) + enc_chunk, self._encoding_buffer = ( + buffer[:div * 3], buffer[div * 3:]) + if enc_chunk: + enc_chunk = base64.b64encode(enc_chunk) + yield from self._writer.write(enc_chunk) + elif self._encoding == 'quoted-printable': + yield from self._writer.write(binascii.b2a_qp(chunk)) + else: + yield from self._writer.write(chunk) diff --git a/aiohttp/payload.py b/aiohttp/payload.py new file mode 100644 index 00000000000..9924448ef45 --- /dev/null +++ b/aiohttp/payload.py @@ -0,0 +1,267 @@ +import asyncio +import io +import mimetypes +import os +from abc import ABC, abstractmethod + +from multidict import CIMultiDict + +from . import hdrs +from .helpers import (content_disposition_header, guess_filename, + parse_mimetype, sentinel) +from .streams import DEFAULT_LIMIT, DataQueue, EofStream, StreamReader + +__all__ = ('PAYLOAD_REGISTRY', 'get_payload', 'Payload', + 'BytesPayload', 'StringPayload', 'StreamReaderPayload', + 'IOBasePayload', 'BytesIOPayload', 'BufferedReaderPayload', + 'TextIOPayload', 'StringIOPayload') + + +class LookupError(Exception): + pass + + +def get_payload(data, *args, **kwargs): + return PAYLOAD_REGISTRY.get(data, *args, **kwargs) + + +class PayloadRegistry: + """Payload registry. + + note: we need zope.interface for more efficient adapter search + """ + + def __init__(self): + self._registry = [] + + def get(self, data, *args, **kwargs): + if isinstance(data, Payload): + return data + for ctor, type in self._registry: + if isinstance(data, type): + return ctor(data, *args, **kwargs) + + raise LookupError() + + def register(self, ctor, type): + self._registry.append((ctor, type)) + + +class Payload(ABC): + + _size = None + _headers = None + _content_type = 'application/octet-stream' + + def __init__(self, value, *, headers=None, + content_type=sentinel, filename=None, encoding=None): + self._value = value + self._encoding = encoding + self._filename = filename + if headers is not None: + self._headers = CIMultiDict(headers) + if content_type is sentinel and hdrs.CONTENT_TYPE in headers: + content_type = headers[hdrs.CONTENT_TYPE] + + if content_type is sentinel: + content_type = None + + self._content_type = content_type + + @property + def size(self): + """Size of the payload.""" + return self._size + + @property + def filename(self): + """Filename of the payload.""" + return self._filename + + @property + def headers(self): + """Custom item headers""" + return self._headers + + @property + def content_type(self): + """Content type""" + if self._content_type is not None: + return self._content_type + elif self._filename is not None: + mime = mimetypes.guess_type(self._filename)[0] + return 'application/octet-stream' if mime is None else mime + else: + return Payload._content_type + + def set_content_disposition(self, disptype, quote_fields=True, **params): + """Sets ``Content-Disposition`` header. + + :param str disptype: Disposition type: inline, attachment, form-data. + Should be valid extension token (see RFC 2183) + :param dict params: Disposition params + """ + if self._headers is None: + self._headers = CIMultiDict() + + self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header( + disptype, quote_fields=quote_fields, **params) + + @asyncio.coroutine # pragma: no branch + @abstractmethod + def write(self, writer): + """Write payload + + :param AbstractPayloadWriter writer: + """ + + +class BytesPayload(Payload): + + def __init__(self, value, *args, **kwargs): + assert isinstance(value, (bytes, bytearray, memoryview)), \ + "value argument must be byte-ish (%r)" % type(value) + + if 'content_type' not in kwargs: + kwargs['content_type'] = 'application/octet-stream' + + super().__init__(value, *args, **kwargs) + + self._size = len(value) + + @asyncio.coroutine + def write(self, writer): + yield from writer.write(self._value) + + +class StringPayload(BytesPayload): + + def __init__(self, value, *args, + content_type='text/plain; charset=utf-8', **kwargs): + + *_, params = parse_mimetype(content_type) + charset = params.get('charset', 'utf-8') + kwargs['encoding'] = charset + + super().__init__( + value.encode(charset), content_type=content_type, *args, **kwargs) + + +class IOBasePayload(Payload): + + def __init__(self, value, *args, **kwargs): + if 'filename' not in kwargs: + kwargs['filename'] = guess_filename(value) + + super().__init__(value, *args, **kwargs) + + if self._filename is not None: + self.set_content_disposition('attachment', filename=self._filename) + + @asyncio.coroutine + def write(self, writer): + chunk = self._value.read(DEFAULT_LIMIT) + while chunk: + yield from writer.write(chunk) + chunk = self._value.read(DEFAULT_LIMIT) + + self._value.close() + + +class StringIOPayload(IOBasePayload): + + def __init__(self, value, *args, + content_type='text/plain; charset=utf-8', **kwargs): + *_, params = parse_mimetype(content_type) + charset = params.get('charset', 'utf-8') + + super().__init__( + value, + content_type=content_type, + encoding=charset, *args, **kwargs) + + @asyncio.coroutine + def write(self, writer): + chunk = self._value.read(DEFAULT_LIMIT) + while chunk: + yield from writer.write(chunk.encode(self._encoding)) + chunk = self._value.read(DEFAULT_LIMIT) + + self._value.close() + + +class TextIOPayload(IOBasePayload): + + @property + def size(self): + try: + return os.fstat(self._value.fileno()).st_size - self._value.tell() + except OSError: + return None + + @asyncio.coroutine + def write(self, writer): + encoding = self._value.encoding + chunk = self._value.read(DEFAULT_LIMIT) + while chunk: + yield from writer.write(chunk.encode(encoding)) + chunk = self._value.read(DEFAULT_LIMIT) + + self._value.close() + + +class BytesIOPayload(IOBasePayload): + + @property + def size(self): + return len(self._value.getbuffer()) + + +class BufferedReaderPayload(IOBasePayload): + + @property + def size(self): + try: + return os.fstat(self._value.fileno()).st_size - self._value.tell() + except OSError: + # data.fileno() is not supported, e.g. + # io.BufferedReader(io.BytesIO(b'data')) + return None + + +class StreamReaderPayload(Payload): + + @asyncio.coroutine + def write(self, writer): + chunk = yield from self._value.read(DEFAULT_LIMIT) + while chunk: + yield from writer.write(chunk) + chunk = yield from self._value.read(DEFAULT_LIMIT) + + +class DataQueuePayload(Payload): + + @asyncio.coroutine + def write(self, writer): + while True: + try: + chunk = yield from self._value.read() + if not chunk: + break + yield from writer.write(chunk) + except EofStream: + break + + +PAYLOAD_REGISTRY = PayloadRegistry() +PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview)) +PAYLOAD_REGISTRY.register(StringPayload, str) +PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO) +PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase) +PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO) +PAYLOAD_REGISTRY.register( + BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) +PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) +PAYLOAD_REGISTRY.register( + StreamReaderPayload, (asyncio.StreamReader, StreamReader)) +PAYLOAD_REGISTRY.register(DataQueuePayload, DataQueue) diff --git a/aiohttp/payload_streamer.py b/aiohttp/payload_streamer.py new file mode 100644 index 00000000000..53028115146 --- /dev/null +++ b/aiohttp/payload_streamer.py @@ -0,0 +1,70 @@ +""" Payload implemenation for coroutines as data provider. + +As a simple case, you can upload data from file:: + + @aiohttp.streamer + def file_sender(writer, file_name=None): + with open(file_name, 'rb') as f: + chunk = f.read(2**16) + while chunk: + yield from writer.write(chunk) + + chunk = f.read(2**16) + +Then you can use `file_sender` like this: + + async with session.post('http://httpbin.org/post', + data=file_sender(file_name='hude_file')) as resp: + print(await resp.text()) + +..note:: Coroutine must accept `writer` as first argument + +""" + +import asyncio + +from . import payload + +__all__ = ('streamer',) + + +class _stream_wrapper: + + def __init__(self, coro, args, kwargs): + self.coro = coro + self.args = args + self.kwargs = kwargs + + @asyncio.coroutine + def __call__(self, writer): + yield from self.coro(writer, *self.args, **self.kwargs) + + +class streamer: + + def __init__(self, coro): + self.coro = coro + + def __call__(self, *args, **kwargs): + return _stream_wrapper(self.coro, args, kwargs) + + +class StreamWrapperPayload(payload.Payload): + + @asyncio.coroutine + def write(self, writer): + yield from self._value(writer) + + +class StreamPayload(StreamWrapperPayload): + + def __init__(self, value, *args, **kwargs): + super().__init__(value(), *args, **kwargs) + + @asyncio.coroutine + def write(self, writer): + yield from self._value(writer) + + +payload.PAYLOAD_REGISTRY.register(StreamPayload, streamer) +payload.PAYLOAD_REGISTRY.register(StreamWrapperPayload, _stream_wrapper) diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 405dee1f4d7..eff07f3df37 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -9,7 +9,7 @@ from multidict import CIMultiDict, CIMultiDictProxy -from . import hdrs +from . import hdrs, payload from .helpers import HeadersMixin, SimpleCookie, sentinel from .http import (RESPONSES, SERVER_SOFTWARE, HttpVersion10, HttpVersion11, PayloadWriter) @@ -478,20 +478,53 @@ def __init__(self, *, body=None, status=200, headers[hdrs.CONTENT_TYPE] = content_type super().__init__(status=status, reason=reason, headers=headers) + if text is not None: self.text = text else: - self._body = body + self.body = body @property def body(self): return self._body @body.setter - def body(self, body): - assert body is None or isinstance(body, (bytes, bytearray)), \ - "body argument must be bytes (%r)" % type(body) - self._body = body + def body(self, body, + CONTENT_TYPE=hdrs.CONTENT_TYPE, + CONTENT_LENGTH=hdrs.CONTENT_LENGTH): + if body is None: + self._body = None + self._body_payload = False + elif isinstance(body, (bytes, bytearray)): + self._body = body + self._body_payload = False + else: + try: + self._body = body = payload.PAYLOAD_REGISTRY.get(body) + except payload.LookupError: + raise ValueError('Unsupported body type %r' % type(body)) + + self._body_payload = True + + headers = self._headers + + # enable chunked encoding if needed + if not self._chunked and CONTENT_LENGTH not in headers: + size = body.size + if size is None: + self._chunked = True + elif CONTENT_LENGTH not in headers: + headers[CONTENT_LENGTH] = str(size) + + # set content-type + if CONTENT_TYPE not in headers: + headers[CONTENT_TYPE] = body.content_type + + # copy payload headers + if body.headers: + for (key, value) in body.headers.items(): + if key not in headers: + headers[key] = value @property def text(self): @@ -531,18 +564,20 @@ def content_length(self, value): @asyncio.coroutine def write_eof(self): body = self._body - if (body is not None and - (self._req._method == hdrs.METH_HEAD or - self._status in [204, 304])): - body = b'' - - if body is None: - body = b'' - - yield from super().write_eof(body) + if body is not None: + if (self._req._method == hdrs.METH_HEAD or + self._status in [204, 304]): + yield from super().write_eof() + elif self._body_payload: + yield from body.write(self._payload_writer) + yield from super().write_eof() + else: + yield from super().write_eof(body) + else: + yield from super().write_eof() def _start(self, request): - if not self._chunked: + if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: if self._body is not None: self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body)) else: diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 12d0a576c10..42767ff4dee 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1014,9 +1014,32 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_DATA_with_charset(loop, test_client): + @asyncio.coroutine + def handler(request): + mp = yield from request.multipart() + part = yield from mp.next() + text = yield from part.text() + return web.Response(text=text) + + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + + form = aiohttp.FormData() + form.add_field('name', 'текст', content_type='text/plain; charset=koi8-r') + + resp = yield from client.post('/', data=form) + assert 200 == resp.status + content = yield from resp.text() + assert content == 'текст' + resp.close() + + +@pytest.mark.xfail +@asyncio.coroutine +def test_POST_DATA_with_charset_post(loop, test_client): @asyncio.coroutine def handler(request): data = yield from request.post() @@ -1036,14 +1059,13 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_DATA_with_context_transfer_encoding(loop, test_client): @asyncio.coroutine def handler(request): data = yield from request.post() assert data['name'] == b'text' # should it be str? - return web.Response() + return web.Response(body=data['name']) app = web.Application(loop=loop) app.router.add_post('/', handler) @@ -1059,6 +1081,32 @@ def handler(request): resp.close() +@pytest.mark.xfail +@asyncio.coroutine +def test_POST_DATA_with_content_type_context_transfer_encoding( + loop, test_client): + @asyncio.coroutine + def handler(request): + data = yield from request.post() + assert data['name'] == 'text' # should it be str? + return web.Response(body=data['name']) + + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + + form = aiohttp.FormData() + form.add_field('name', 'text', + content_type='text/plain', + content_transfer_encoding='base64') + + resp = yield from client.post('/', data=form) + assert 200 == resp.status + content = yield from resp.text() + assert content == 'text' + resp.close() + + @asyncio.coroutine def test_POST_MultiDict(loop, test_client): @asyncio.coroutine @@ -1233,12 +1281,27 @@ def handler(request): @asyncio.coroutine def test_POST_FILES_SINGLE(loop, test_client, fname): + @asyncio.coroutine + def handler(request): + data = yield from request.text() + with fname.open('r') as f: + content = f.read() + assert content == data + # if system cannot determine 'application/pgp-keys' MIME type + # then use 'application/octet-stream' default + assert request.content_type in ['application/pgp-keys', + 'text/plain', + 'application/octet-stream'] + return web.HTTPOk() + app = web.Application(loop=loop) + app.router.add_post('/', handler) client = yield from test_client(app) with fname.open() as f: - with pytest.raises(ValueError): - yield from client.post('/', data=f) + resp = yield from client.post('/', data=f) + assert 200 == resp.status + resp.close() @asyncio.coroutine @@ -1366,7 +1429,6 @@ def handler(request): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_STREAM_DATA(loop, test_client, fname): @asyncio.coroutine @@ -1375,7 +1437,75 @@ def handler(request): content = yield from request.read() with fname.open('rb') as f: expected = f.read() - assert request.content_length == str(len(expected)) + assert request.content_length == len(expected) + assert content == expected + + return web.HTTPOk() + + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + + with fname.open('rb') as f: + data_size = len(f.read()) + + @aiohttp.streamer + def stream(writer, fname): + with fname.open('rb') as f: + data = f.read(100) + while data: + yield from writer.write(data) + data = f.read(100) + + resp = yield from client.post( + '/', data=stream(fname), headers={'Content-Length': str(data_size)}) + assert 200 == resp.status + resp.close() + + +@asyncio.coroutine +def test_POST_STREAM_DATA_no_params(loop, test_client, fname): + @asyncio.coroutine + def handler(request): + assert request.content_type == 'application/octet-stream' + content = yield from request.read() + with fname.open('rb') as f: + expected = f.read() + assert request.content_length == len(expected) + assert content == expected + + return web.HTTPOk() + + app = web.Application(loop=loop) + app.router.add_post('/', handler) + client = yield from test_client(app) + + with fname.open('rb') as f: + data_size = len(f.read()) + + @aiohttp.streamer + def stream(writer): + with fname.open('rb') as f: + data = f.read(100) + while data: + yield from writer.write(data) + data = f.read(100) + + resp = yield from client.post( + '/', data=stream, headers={'Content-Length': str(data_size)}) + assert 200 == resp.status + resp.close() + + +@asyncio.coroutine +def test_POST_STREAM_DATA_coroutine_deprecated(loop, test_client, fname): + @asyncio.coroutine + def handler(request): + assert request.content_type == 'application/octet-stream' + content = yield from request.read() + with fname.open('rb') as f: + expected = f.read() + assert request.content_length == len(expected) assert content == expected return web.HTTPOk() @@ -1384,7 +1514,7 @@ def handler(request): app.router.add_post('/', handler) client = yield from test_client(app) - with fname.open() as f: + with fname.open('rb') as f: data = f.read() fut = create_future(loop) @@ -1402,7 +1532,6 @@ def stream(): resp.close() -@pytest.mark.xfail @asyncio.coroutine def test_POST_StreamReader(fname, loop, test_client): @asyncio.coroutine @@ -1411,7 +1540,7 @@ def handler(request): content = yield from request.read() with fname.open('rb') as f: expected = f.read() - assert request.content_length == str(len(expected)) + assert request.content_length == len(expected) assert content == expected return web.HTTPOk() @@ -1420,7 +1549,7 @@ def handler(request): app.router.add_post('/', handler) client = yield from test_client(app) - with fname.open() as f: + with fname.open('rb') as f: data = f.read() stream = aiohttp.StreamReader(loop=loop) diff --git a/tests/test_client_request.py b/tests/test_client_request.py index e1787882d78..7dd10fe1554 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -13,7 +13,7 @@ from yarl import URL import aiohttp -from aiohttp import BaseConnector, hdrs, helpers +from aiohttp import BaseConnector, hdrs, helpers, payload from aiohttp.client_reqrep import ClientRequest, ClientResponse from aiohttp.helpers import SimpleCookie @@ -540,7 +540,7 @@ def test_post_data(loop): data={'life': '42'}, loop=loop) resp = req.send(mock.Mock(acquire=acquire)) assert '/' == req.url.path - assert b'life=42' == req.body + assert b'life=42' == req.body._value assert 'application/x-www-form-urlencoded' ==\ req.headers['CONTENT-TYPE'] yield from req.close() @@ -580,7 +580,7 @@ def test_get_with_data(loop): meth, URL('http://python.org/'), data={'life': '42'}, loop=loop) assert '/' == req.url.path - assert b'life=42' == req.body + assert b'life=42' == req.body._value yield from req.close() @@ -592,7 +592,8 @@ def test_bytes_data(loop): data=b'binary data', loop=loop) resp = req.send(mock.Mock(acquire=acquire)) assert '/' == req.url.path - assert b'binary data' == req.body + assert isinstance(req.body, payload.BytesPayload) + assert b'binary data' == req.body._value assert 'application/octet-stream' == req.headers['CONTENT-TYPE'] yield from req.close() resp.close() @@ -813,7 +814,7 @@ def test_data_file(loop, transport): data=io.BufferedReader(io.BytesIO(b'*' * 2)), loop=loop) assert req.chunked - assert isinstance(req.body, io.IOBase) + assert isinstance(req.body, payload.BufferedReaderPayload) assert req.headers['TRANSFER-ENCODING'] == 'chunked' transport, buf = transport diff --git a/tests/test_formdata.py b/tests/test_formdata.py new file mode 100644 index 00000000000..c2e8e667f68 --- /dev/null +++ b/tests/test_formdata.py @@ -0,0 +1,77 @@ +import asyncio +from unittest import mock + +import pytest + +from aiohttp.formdata import FormData + + +@pytest.fixture +def buf(): + return bytearray() + + +@pytest.fixture +def writer(buf): + writer = mock.Mock() + + def write(chunk): + buf.extend(chunk) + return () + + writer.write.side_effect = write + return writer + + +def test_invalid_formdata_params(): + with pytest.raises(TypeError): + FormData('asdasf') + + +def test_invalid_formdata_params2(): + with pytest.raises(TypeError): + FormData('as') # 2-char str is not allowed + + +def test_invalid_formdata_content_type(): + form = FormData() + invalid_vals = [0, 0.1, {}, [], b'foo'] + for invalid_val in invalid_vals: + with pytest.raises(TypeError): + form.add_field('foo', 'bar', content_type=invalid_val) + + +def test_invalid_formdata_filename(): + form = FormData() + invalid_vals = [0, 0.1, {}, [], b'foo'] + for invalid_val in invalid_vals: + with pytest.raises(TypeError): + form.add_field('foo', 'bar', filename=invalid_val) + + +def test_invalid_formdata_content_transfer_encoding(): + form = FormData() + invalid_vals = [0, 0.1, {}, [], b'foo'] + for invalid_val in invalid_vals: + with pytest.raises(TypeError): + form.add_field('foo', + 'bar', + content_transfer_encoding=invalid_val) + + +@asyncio.coroutine +def test_formdata_field_name_is_quoted(buf, writer): + form = FormData() + form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") + payload = form("ascii") + yield from payload.write(writer) + assert b'name="emails%5B%5D"' in buf + + +@asyncio.coroutine +def test_formdata_field_name_is_not_quoted(buf, writer): + form = FormData(quote_fields=False) + form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") + payload = form("ascii") + yield from payload.write(writer) + assert b'name="emails[]"' in buf diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 529ed4a6ade..a396283bf5c 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -101,58 +101,9 @@ def test_basic_auth_decode_bad_base64(): helpers.BasicAuth.decode('Basic bmtpbTpwd2Q') -def test_invalid_formdata_params(): - with pytest.raises(TypeError): - helpers.FormData('asdasf') - - -def test_invalid_formdata_params2(): - with pytest.raises(TypeError): - helpers.FormData('as') # 2-char str is not allowed - - -def test_invalid_formdata_content_type(): - form = helpers.FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] - for invalid_val in invalid_vals: - with pytest.raises(TypeError): - form.add_field('foo', 'bar', content_type=invalid_val) - - -def test_invalid_formdata_filename(): - form = helpers.FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] - for invalid_val in invalid_vals: - with pytest.raises(TypeError): - form.add_field('foo', 'bar', filename=invalid_val) - - -def test_invalid_formdata_content_transfer_encoding(): - form = helpers.FormData() - invalid_vals = [0, 0.1, {}, [], b'foo'] - for invalid_val in invalid_vals: - with pytest.raises(TypeError): - form.add_field('foo', - 'bar', - content_transfer_encoding=invalid_val) - # ------------- access logger ------------------------- -def test_formdata_field_name_is_quoted(): - form = helpers.FormData() - form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") - res = b"".join(form("ascii")) - assert b'name="emails%5B%5D"' in res - - -def test_formdata_field_name_is_not_quoted(): - form = helpers.FormData(quote_fields=False) - form.add_field("emails[]", "xxx@x.co", content_type="multipart/form-data") - res = b"".join(form("ascii")) - assert b'name="emails[]"' in res - - def test_access_logger_format(): log_format = '%T {%{SPAM}e} "%{ETag}o" %X {X} %%P %{FOO_TEST}e %{FOO1}e' mock_logger = mock.Mock() @@ -547,3 +498,33 @@ def test_eq(self): def test_le(self): l = helpers.FrozenList([1]) assert l < [2] + + +# -------------------------------- ContentDisposition ------------------- + +def test_content_disposition(): + assert (helpers.content_disposition_header('attachment', foo='bar') == + 'attachment; foo="bar"') + + +def test_content_disposition_bad_type(): + with pytest.raises(ValueError): + helpers.content_disposition_header('foo bar') + with pytest.raises(ValueError): + helpers.content_disposition_header('—Ç–µ—Å—Ç') + with pytest.raises(ValueError): + helpers.content_disposition_header('foo\x00bar') + with pytest.raises(ValueError): + helpers.content_disposition_header('') + + +def test_set_content_disposition_bad_param(): + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', **{'foo bar': 'baz'}) + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', **{'—Ç–µ—Å—Ç': 'baz'}) + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', **{'': 'baz'}) + with pytest.raises(ValueError): + helpers.content_disposition_header('inline', + **{'foo\x00bar': 'baz'}) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 35f2b757bf5..52f2bd2e4b8 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,13 +1,14 @@ import asyncio import functools import io -import os import unittest import zlib from unittest import mock +import pytest + import aiohttp.multipart -from aiohttp import helpers +from aiohttp import helpers, payload from aiohttp.hdrs import (CONTENT_DISPOSITION, CONTENT_ENCODING, CONTENT_TRANSFER_ENCODING, CONTENT_TYPE) from aiohttp.helpers import parse_mimetype @@ -17,6 +18,28 @@ from aiohttp.streams import StreamReader +@pytest.fixture +def buf(): + return bytearray() + + +@pytest.fixture +def stream(buf): + writer = mock.Mock() + + def write(chunk): + buf.extend(chunk) + return () + + writer.write.side_effect = write + return writer + + +@pytest.fixture +def writer(): + return aiohttp.multipart.MultipartWriter(boundary=':') + + def run_in_loop(f): @functools.wraps(f) def wrapper(testcase, *args, **kwargs): @@ -723,318 +746,189 @@ def test_reading_skips_prelude(self): self.assertFalse(second.at_eof()) -class BodyPartWriterTestCase(unittest.TestCase): +@asyncio.coroutine +def test_writer(writer): + assert writer.size == 0 + assert writer.boundary == b':' + + +@asyncio.coroutine +def test_writer_serialize_io_chunk(buf, stream, writer): + flo = io.BytesIO(b'foobarbaz') + writer.append(flo) + yield from writer.write(stream) + print(buf) + assert (buf == b'--:\r\nContent-Type: application/octet-stream' + b'\r\nContent-Length: 9\r\n\r\nfoobarbaz\r\n--:--\r\n') + + +@asyncio.coroutine +def test_writer_serialize_json(buf, stream, writer): + writer.append_json({'привет': 'мир'}) + yield from writer.write(stream) + assert (b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' + b' "\\u043c\\u0438\\u0440"}' in buf) + + +@asyncio.coroutine +def test_writer_serialize_form(buf, stream, writer): + data = [('foo', 'bar'), ('foo', 'baz'), ('boo', 'zoo')] + writer.append_form(data) + yield from writer.write(stream) + + assert (b'foo=bar&foo=baz&boo=zoo' in buf) + - def setUp(self): - self.part = aiohttp.multipart.BodyPartWriter(b'') - - def test_guess_content_length(self): - self.part.headers[CONTENT_TYPE] = 'text/plain; charset=utf-8' - self.assertIsNone(self.part._guess_content_length({})) - self.assertIsNone(self.part._guess_content_length(object())) - self.assertEqual(3, - self.part._guess_content_length(io.BytesIO(b'foo'))) - self.assertEqual(3, - self.part._guess_content_length(io.StringIO('foo'))) - self.assertEqual(6, - self.part._guess_content_length(io.StringIO('мяу'))) - self.assertEqual(3, self.part._guess_content_length(b'bar')) - self.assertEqual(12, self.part._guess_content_length('пассед')) - with open(__file__, 'rb') as f: - self.assertEqual(os.fstat(f.fileno()).st_size, - self.part._guess_content_length(f)) - - def test_guess_content_type(self): - default = 'application/octet-stream' - self.assertEqual(default, self.part._guess_content_type(b'foo')) - self.assertEqual('text/plain; charset=utf-8', - self.part._guess_content_type('foo')) - - here = os.path.dirname(__file__) - filename = os.path.join(here, 'aiohttp.png') - - with open(filename, 'rb') as f: - self.assertEqual('image/png', - self.part._guess_content_type(f)) - - def test_guess_filename(self): - class Named: - name = 'foo' - self.assertIsNone(self.part._guess_filename({})) - self.assertIsNone(self.part._guess_filename(object())) - self.assertIsNone(self.part._guess_filename(io.BytesIO(b'foo'))) - self.assertIsNone(self.part._guess_filename(Named())) - with open(__file__, 'rb') as f: - self.assertEqual(os.path.basename(f.name), - self.part._guess_filename(f)) - - def test_autoset_content_disposition(self): - self.part.obj = open(__file__, 'rb') - self.addCleanup(self.part.obj.close) - self.part._fill_headers_with_defaults() - self.assertIn(CONTENT_DISPOSITION, self.part.headers) - fname = os.path.basename(self.part.obj.name) - self.assertEqual( - 'attachment; filename="{0}"; filename*=utf-8\'\'{0}'.format(fname), - self.part.headers[CONTENT_DISPOSITION]) +@asyncio.coroutine +def test_writer_serialize_form_dict(buf, stream, writer): + data = {'hello': 'мир'} + writer.append_form(data) + yield from writer.write(stream) - def test_set_content_disposition(self): - self.part.set_content_disposition('attachment', foo='bar') - self.assertEqual( - 'attachment; foo="bar"', - self.part.headers[CONTENT_DISPOSITION]) + assert (b'hello=%D0%BC%D0%B8%D1%80' in buf) - def test_set_content_disposition_bad_type(self): - with self.assertRaises(ValueError): - self.part.set_content_disposition('foo bar') - with self.assertRaises(ValueError): - self.part.set_content_disposition('тест') - with self.assertRaises(ValueError): - self.part.set_content_disposition('foo\x00bar') - with self.assertRaises(ValueError): - self.part.set_content_disposition('') - def test_set_content_disposition_bad_param(self): - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', **{'foo bar': 'baz'}) - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', **{'тест': 'baz'}) - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', **{'': 'baz'}) - with self.assertRaises(ValueError): - self.part.set_content_disposition('inline', - **{'foo\x00bar': 'baz'}) - - def test_serialize_bytes(self): - self.assertEqual(b'foo', next(self.part._serialize_bytes(b'foo'))) - - def test_serialize_str(self): - self.assertEqual(b'foo', next(self.part._serialize_str('foo'))) - - def test_serialize_str_custom_encoding(self): - self.part.headers[CONTENT_TYPE] = \ - 'text/plain;charset=cp1251' - self.assertEqual('привет'.encode('cp1251'), - next(self.part._serialize_str('привет'))) - - def test_serialize_io(self): - self.assertEqual(b'foo', - next(self.part._serialize_io(io.BytesIO(b'foo')))) - self.assertEqual(b'foo', - next(self.part._serialize_io(io.StringIO('foo')))) - - def test_serialize_io_chunk(self): - flo = io.BytesIO(b'foobarbaz') - self.part._chunk_size = 3 - self.assertEqual([b'foo', b'bar', b'baz'], - list(self.part._serialize_io(flo))) - - def test_serialize_json(self): - self.assertEqual(b'{"\\u043f\\u0440\\u0438\\u0432\\u0435\\u0442":' - b' "\\u043c\\u0438\\u0440"}', - next(self.part._serialize_json({'привет': 'мир'}))) - - def test_serialize_form(self): - data = [('foo', 'bar'), ('foo', 'baz'), ('boo', 'zoo')] - self.assertEqual(b'foo=bar&foo=baz&boo=zoo', - next(self.part._serialize_form(data))) - - def test_serialize_form_dict(self): - data = {'hello': 'мир'} - self.assertEqual(b'hello=%D0%BC%D0%B8%D1%80', - next(self.part._serialize_form(data))) - - def test_serialize_multipart(self): - multipart = aiohttp.multipart.MultipartWriter(boundary=':') - multipart.append('foo-bar-baz') - multipart.append_json({'test': 'passed'}) - multipart.append_form({'test': 'passed'}) - multipart.append_form([('one', 1), ('two', 2)]) - sub_multipart = aiohttp.multipart.MultipartWriter(boundary='::') - sub_multipart.append('nested content') - sub_multipart.headers['X-CUSTOM'] = 'test' - multipart.append(sub_multipart) - self.assertEqual( - [b'--:\r\n', - b'Content-Type: text/plain; charset=utf-8\r\n' - b'Content-Length: 11', - b'\r\n\r\n', - b'foo-bar-baz', - b'\r\n', - - b'--:\r\n', - b'Content-Type: application/json', - b'\r\n\r\n', - b'{"test": "passed"}', - b'\r\n', - - b'--:\r\n', - b'Content-Type: application/x-www-form-urlencoded', - b'\r\n\r\n', - b'test=passed', - b'\r\n', - - b'--:\r\n', - b'Content-Type: application/x-www-form-urlencoded', - b'\r\n\r\n', - b'one=1&two=2', - b'\r\n', - - b'--:\r\n', - b'Content-Type: multipart/mixed; boundary="::"\r\nX-Custom: test', - b'\r\n\r\n', - b'--::\r\n', - b'Content-Type: text/plain; charset=utf-8\r\n' - b'Content-Length: 14', - b'\r\n\r\n', - b'nested content', - b'\r\n', - b'--::--\r\n', - b'', - b'\r\n', - b'--:--\r\n', - b''], - list(self.part._serialize_multipart(multipart)) - ) - - def test_serialize_default(self): - with self.assertRaises(TypeError): - self.part.obj = object() - list(self.part.serialize()) - with self.assertRaises(TypeError): - next(self.part._serialize_default(object())) - - def test_serialize_with_content_encoding_gzip(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_ENCODING: 'gzip'}) - stream = part.serialize() - self.assertEqual(b'Content-Encoding: gzip\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - result = b''.join(stream) - - decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS) - data = decompressor.decompress(result) - self.assertEqual(b'Time to Relax!', data) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_encoding_deflate(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_ENCODING: 'deflate'}) - stream = part.serialize() - self.assertEqual(b'Content-Encoding: deflate\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n' - self.assertEqual(thing, b''.join(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_encoding_identity(self): - thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00' - part = aiohttp.multipart.BodyPartWriter( - thing, {CONTENT_ENCODING: 'identity'}) - stream = part.serialize() - self.assertEqual(b'Content-Encoding: identity\r\n' - b'Content-Type: application/octet-stream\r\n' - b'Content-Length: 16', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(thing, next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_encoding_unknown(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_ENCODING: 'snappy'}) - with self.assertRaises(RuntimeError): - list(part.serialize()) - - def test_serialize_with_content_transfer_encoding_base64(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'base64'}) - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: base64\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(b'VGltZSB0byBSZWxh', next(stream)) - self.assertEqual(b'eCE=', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_io_with_content_transfer_encoding_base64(self): - part = aiohttp.multipart.BodyPartWriter( - io.BytesIO(b'Time to Relax!'), - {CONTENT_TRANSFER_ENCODING: 'base64'}) - part._chunk_size = 6 - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: base64\r\n' - b'Content-Type: application/octet-stream', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(b'VGltZSB0', next(stream)) - self.assertEqual(b'byBSZWxh', next(stream)) - self.assertEqual(b'eCE=', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_transfer_encoding_quote_printable(self): - part = aiohttp.multipart.BodyPartWriter( - 'Привет, мир!', {CONTENT_TRANSFER_ENCODING: 'quoted-printable'}) - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: quoted-printable\r\n' - b'Content-Type: text/plain; charset=utf-8', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) - - self.assertEqual(b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' - b' =D0=BC=D0=B8=D1=80!', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) - - def test_serialize_with_content_transfer_encoding_binary(self): - part = aiohttp.multipart.BodyPartWriter( - 'Привет, мир!'.encode('utf-8'), - {CONTENT_TRANSFER_ENCODING: 'binary'}) - stream = part.serialize() - self.assertEqual(b'Content-Transfer-Encoding: binary\r\n' - b'Content-Type: application/octet-stream', - next(stream)) - self.assertEqual(b'\r\n\r\n', next(stream)) +@asyncio.coroutine +def test_writer_write(buf, stream, writer): + writer.append('foo-bar-baz') + writer.append_json({'test': 'passed'}) + writer.append_form({'test': 'passed'}) + writer.append_form([('one', 1), ('two', 2)]) - self.assertEqual(b'\xd0\x9f\xd1\x80\xd0\xb8\xd0\xb2\xd0\xb5\xd1\x82,' - b' \xd0\xbc\xd0\xb8\xd1\x80!', next(stream)) - self.assertEqual(b'\r\n', next(stream)) - self.assertIsNone(next(stream, None)) + sub_multipart = aiohttp.multipart.MultipartWriter(boundary='::') + sub_multipart.append('nested content') + sub_multipart.headers['X-CUSTOM'] = 'test' + writer.append(sub_multipart) + yield from writer.write(stream) - def test_serialize_with_content_transfer_encoding_unknown(self): - part = aiohttp.multipart.BodyPartWriter( - 'Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'}) - with self.assertRaises(RuntimeError): - list(part.serialize()) + assert ( + (b'--:\r\n' + b'Content-Type: text/plain; charset=utf-8\r\n' + b'Content-Length: 11\r\n\r\n' + b'foo-bar-baz' + b'\r\n' - def test_filename(self): - self.part.set_content_disposition('related', filename='foo.html') - self.assertEqual('foo.html', self.part.filename) + b'--:\r\n' + b'Content-Type: application/json\r\n' + b'Content-Length: 18\r\n\r\n' + b'{"test": "passed"}' + b'\r\n' + + b'--:\r\n' + b'Content-Type: application/x-www-form-urlencoded\r\n' + b'Content-Length: 11\r\n\r\n' + b'test=passed' + b'\r\n' + + b'--:\r\n' + b'Content-Type: application/x-www-form-urlencoded\r\n' + b'Content-Length: 11\r\n\r\n' + b'one=1&two=2' + b'\r\n' + + b'--:\r\n' + b'Content-Type: multipart/mixed; boundary="::"\r\n' + b'X-Custom: test\r\nContent-Length: 93\r\n\r\n' + b'--::\r\n' + b'Content-Type: text/plain; charset=utf-8\r\n' + b'Content-Length: 14\r\n\r\n' + b'nested content\r\n' + b'--::--\r\n' + b'\r\n' + b'--:--\r\n') == bytes(buf)) - def test_wrap_multipart(self): - writer = aiohttp.multipart.MultipartWriter(boundary=':') - part = aiohttp.multipart.BodyPartWriter(writer) - self.assertEqual(part.headers, writer.headers) - part.headers['X-Custom'] = 'test' - self.assertEqual(part.headers, writer.headers) + +@asyncio.coroutine +def test_writer_serialize_with_content_encoding_gzip(buf, stream, writer): + writer.append('Time to Relax!', {CONTENT_ENCODING: 'gzip'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Encoding: gzip\r\n' + b'Content-Type: text/plain; charset=utf-8' == headers) + + decompressor = zlib.decompressobj(wbits=16+zlib.MAX_WBITS) + data = decompressor.decompress(message.split(b'\r\n')[0]) + data += decompressor.flush() + assert b'Time to Relax!' == data + + +@asyncio.coroutine +def test_writer_serialize_with_content_encoding_deflate(buf, stream, writer): + writer.append('Time to Relax!', {CONTENT_ENCODING: 'deflate'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Encoding: deflate\r\n' + b'Content-Type: text/plain; charset=utf-8' == headers) + + thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00\r\n--:--\r\n' + assert thing == message + + +@asyncio.coroutine +def test_writer_serialize_with_content_encoding_identity(buf, stream, writer): + thing = b'\x0b\xc9\xccMU(\xc9W\x08J\xcdI\xacP\x04\x00' + writer.append(thing, {CONTENT_ENCODING: 'identity'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Encoding: identity\r\n' + b'Content-Type: application/octet-stream\r\n' + b'Content-Length: 16' == headers) + + assert thing == message.split(b'\r\n')[0] + + +def test_writer_serialize_with_content_encoding_unknown(buf, stream, writer): + with pytest.raises(RuntimeError): + writer.append('Time to Relax!', {CONTENT_ENCODING: 'snappy'}) + + +@asyncio.coroutine +def test_writer_with_content_transfer_encoding_base64(buf, stream, writer): + writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'base64'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Transfer-Encoding: base64\r\n' + b'Content-Type: text/plain; charset=utf-8' == + headers) + + assert b'VGltZSB0byBSZWxheCE=' == message.split(b'\r\n')[0] + + +@asyncio.coroutine +def test_writer_content_transfer_encoding_quote_printable(buf, stream, writer): + writer.append('Привет, мир!', + {CONTENT_TRANSFER_ENCODING: 'quoted-printable'}) + yield from writer.write(stream) + headers, message = bytes(buf).split(b'\r\n\r\n', 1) + + assert (b'--:\r\nContent-Transfer-Encoding: quoted-printable\r\n' + b'Content-Type: text/plain; charset=utf-8' == headers) + + assert (b'=D0=9F=D1=80=D0=B8=D0=B2=D0=B5=D1=82,' + b' =D0=BC=D0=B8=D1=80!' == message.split(b'\r\n')[0]) + + +def test_writer_content_transfer_encoding_unknown(buf, stream, writer): + with pytest.raises(RuntimeError): + writer.append('Time to Relax!', {CONTENT_TRANSFER_ENCODING: 'unknown'}) class MultipartWriterTestCase(unittest.TestCase): def setUp(self): + self.buf = bytearray() + self.stream = mock.Mock() + + def write(chunk): + self.buf.extend(chunk) + return () + + self.stream.write.side_effect = write + self.writer = aiohttp.multipart.MultipartWriter(boundary=':') def test_default_subtype(self): @@ -1061,52 +955,50 @@ def test_append(self): self.assertEqual(0, len(self.writer)) self.writer.append('hello, world!') self.assertEqual(1, len(self.writer)) - self.assertIsInstance(self.writer.parts[0], - self.writer.part_writer_cls) + self.assertIsInstance(self.writer._parts[0][0], payload.Payload) def test_append_with_headers(self): self.writer.append('hello, world!', {'x-foo': 'bar'}) self.assertEqual(1, len(self.writer)) - self.assertIn('x-foo', self.writer.parts[0].headers) - self.assertEqual(self.writer.parts[0].headers['x-foo'], 'bar') + self.assertIn('x-foo', self.writer._parts[0][0].headers) + self.assertEqual(self.writer._parts[0][0].headers['x-foo'], 'bar') def test_append_json(self): self.writer.append_json({'foo': 'bar'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] + part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'application/json') def test_append_part(self): - part = aiohttp.multipart.BodyPartWriter('test', - {CONTENT_TYPE: 'text/plain'}) + part = payload.get_payload( + 'test', headers={CONTENT_TYPE: 'text/plain'}) self.writer.append(part, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] + part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') def test_append_json_overrides_content_type(self): self.writer.append_json({'foo': 'bar'}, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] - self.assertEqual(part.headers[CONTENT_TYPE], 'application/json') + part = self.writer._parts[0][0] + self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') def test_append_form(self): self.writer.append_form({'foo': 'bar'}, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] - self.assertEqual(part.headers[CONTENT_TYPE], - 'application/x-www-form-urlencoded') + part = self.writer._parts[0][0] + self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') def test_append_multipart(self): subwriter = aiohttp.multipart.MultipartWriter(boundary=':') subwriter.append_json({'foo': 'bar'}) self.writer.append(subwriter, {CONTENT_TYPE: 'test/passed'}) self.assertEqual(1, len(self.writer)) - part = self.writer.parts[0] + part = self.writer._parts[0][0] self.assertEqual(part.headers[CONTENT_TYPE], 'test/passed') - def test_serialize(self): - self.assertEqual([b''], list(self.writer.serialize())) + def test_write(self): + self.assertEqual([], list(self.writer.write(self.stream))) def test_with(self): with aiohttp.multipart.MultipartWriter(boundary=':') as writer: diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index ae2fa2f4fba..67b274cb7f1 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1,4 +1,5 @@ import asyncio +import io import json import pathlib import zlib @@ -8,6 +9,7 @@ from multidict import MultiDict from yarl import URL +import aiohttp from aiohttp import FormData, HttpVersion10, HttpVersion11, multipart, web try: @@ -16,6 +18,16 @@ ssl = False +@pytest.fixture +def here(): + return pathlib.Path(__file__).parent + + +@pytest.fixture +def fname(here): + return here / 'sample.key' + + @asyncio.coroutine def test_simple_get(loop, test_client): @@ -739,6 +751,164 @@ def handler(request): assert 200 == resp.status +@asyncio.coroutine +def test_response_with_streamer(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + data_size = len(data) + + @aiohttp.streamer + def stream(writer, f_name): + with f_name.open('rb') as f: + data = f.read(100) + while data: + yield from writer.write(data) + data = f.read(100) + + @asyncio.coroutine + def handler(request): + headers = {'Content-Length': str(data_size)} + return web.Response(body=stream(fname), headers=headers) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Length') == str(len(resp_data)) + + +@asyncio.coroutine +def test_response_with_streamer_no_params(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + data_size = len(data) + + @aiohttp.streamer + def stream(writer): + with fname.open('rb') as f: + data = f.read(100) + while data: + yield from writer.write(data) + data = f.read(100) + + @asyncio.coroutine + def handler(request): + headers = {'Content-Length': str(data_size)} + return web.Response(body=stream, headers=headers) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Length') == str(len(resp_data)) + + +@asyncio.coroutine +def test_response_with_file(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + @asyncio.coroutine + def handler(request): + return web.Response(body=fname.open('rb')) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Type') in ( + 'application/octet-stream', 'application/pgp-keys') + assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert (resp.headers.get('Content-Disposition') == + 'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key') + + +@asyncio.coroutine +def test_response_with_file_ctype(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + @asyncio.coroutine + def handler(request): + return web.Response( + body=fname.open('rb'), headers={'content-type': 'text/binary'}) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Type') == 'text/binary' + assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert (resp.headers.get('Content-Disposition') == + 'attachment; filename="sample.key"; filename*=utf-8\'\'sample.key') + + +@asyncio.coroutine +def test_response_with_payload_disp(loop, test_client, fname): + + with fname.open('rb') as f: + data = f.read() + + @asyncio.coroutine + def handler(request): + pl = aiohttp.get_payload(fname.open('rb')) + pl.set_content_disposition('inline', filename='test.txt') + return web.Response( + body=pl, headers={'content-type': 'text/binary'}) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == data + assert resp.headers.get('Content-Type') == 'text/binary' + assert resp.headers.get('Content-Length') == str(len(resp_data)) + assert (resp.headers.get('Content-Disposition') == + 'inline; filename="test.txt"; filename*=utf-8\'\'test.txt') + + +@asyncio.coroutine +def test_response_with_payload_stringio(loop, test_client, fname): + + @asyncio.coroutine + def handler(request): + return web.Response(body=io.StringIO('test')) + + app = web.Application(loop=loop) + app.router.add_get('/', handler) + client = yield from test_client(app) + + resp = yield from client.get('/') + assert 200 == resp.status + resp_data = yield from resp.read() + assert resp_data == b'test' + + @asyncio.coroutine def test_response_with_precompressed_body_gzip(loop, test_client): diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 5f14f5cc8ba..222edcdee5e 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -802,7 +802,7 @@ def test_ctor_both_charset_param_and_header(): def test_assign_nonbyteish_body(): resp = Response(body=b'data') - with pytest.raises(AssertionError): + with pytest.raises(ValueError): resp.body = 123 assert b'data' == resp.body assert 4 == resp.content_length