Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip default headers #486

Merged
merged 11 commits into from
Sep 3, 2015
2 changes: 2 additions & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ CHANGES
Using `force` parameter for the method is deprecated: use `.release()` instead.

* Properly requote URL's path #480

* add `skip_auto_headers` parameter for client API #486
17 changes: 15 additions & 2 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class ClientSession:
_connector = None

def __init__(self, *, connector=None, loop=None, cookies=None,
headers=None, auth=None, request_class=ClientRequest,
headers=None, skip_auto_headers=None,
auth=None, request_class=ClientRequest,
response_class=ClientResponse,
ws_response_class=ClientWebSocketResponse):

Expand Down Expand Up @@ -65,6 +66,11 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
else:
headers = CIMultiDict()
self._default_headers = headers
if skip_auto_headers is not None:
self._skip_auto_headers = frozenset([upstr(i)
for i in skip_auto_headers])
else:
self._skip_auto_headers = frozenset()

self._request_class = request_class
self._response_class = response_class
Expand All @@ -88,6 +94,7 @@ def request(self, method, url, *,
params=None,
data=None,
headers=None,
skip_auto_headers=None,
files=None,
auth=None,
allow_redirects=True,
Expand Down Expand Up @@ -119,9 +126,15 @@ def request(self, method, url, *,
raise ValueError("Can't combine `Authorization` header with "
"`auth` argument")

skip_headers = set(self._skip_auto_headers)
if skip_auto_headers is not None:
for i in skip_auto_headers:
skip_headers.add(upstr(i))

while True:
req = self._request_class(
method, url, params=params, headers=headers, data=data,
method, url, params=params, headers=headers,
skip_auto_headers=skip_headers, data=data,
cookies=self.cookies, files=files, encoding=encoding,
auth=auth, version=version, compress=compress, chunked=chunked,
expect100=expect100,
Expand Down
24 changes: 19 additions & 5 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from . import hdrs, helpers, streams
from .log import client_logger
from .streams import EOF_MARKER, FlowControlStreamReader
from .multidict import CIMultiDictProxy, MultiDictProxy, MultiDict, CIMultiDict
from .multidict import (CIMultiDictProxy, MultiDictProxy, MultiDict,
CIMultiDict)
from .multipart import MultipartWriter
from .protocol import HttpMessage

PY_341 = sys.version_info >= (3, 4, 1)

Expand All @@ -39,6 +41,8 @@ class ClientRequest:
hdrs.ACCEPT_ENCODING: 'gzip, deflate',
}

SERVER_SOFTWARE = HttpMessage.SERVER_SOFTWARE

body = b''
auth = None
response = None
Expand All @@ -53,7 +57,8 @@ class ClientRequest:
# Until writer has finished finalizer will not be called.

def __init__(self, method, url, *,
params=None, headers=None, data=None, cookies=None,
params=None, headers=None, skip_auto_headers=frozenset(),
data=None, cookies=None,
files=None, auth=None, encoding='utf-8',
version=aiohttp.HttpVersion11, compress=None,
chunked=None, expect100=False,
Expand All @@ -77,6 +82,7 @@ def __init__(self, method, url, *,
self.update_host(url)
self.update_path(params)
self.update_headers(headers)
self.update_auto_headers(skip_auto_headers)
self.update_cookies(cookies)
self.update_content_encoding()
self.update_auth(auth)
Expand Down Expand Up @@ -191,14 +197,21 @@ def update_headers(self, headers):
for key, value in headers:
self.headers.add(key, value)

def update_auto_headers(self, skip_auto_headers):
self.skip_auto_headers = skip_auto_headers
used_headers = set(self.headers) | skip_auto_headers

for hdr, val in self.DEFAULT_HEADERS.items():
if hdr not in self.headers:
self.headers[hdr] = val
if hdr not in used_headers:
self.headers.add(hdr, val)

# add host
if hdrs.HOST not in self.headers:
if hdrs.HOST not in used_headers:
self.headers[hdrs.HOST] = self.netloc

if hdrs.USER_AGENT not in used_headers:
self.headers[hdrs.USER_AGENT] = self.SERVER_SOFTWARE

def update_cookies(self, cookies):
"""Update request cookies header."""
if not cookies:
Expand Down Expand Up @@ -445,6 +458,7 @@ def send(self, writer, reader):

# set default content-type
if (self.method in self.POST_METHODS and
hdrs.CONTENT_TYPE not in self.skip_auto_headers and
hdrs.CONTENT_TYPE not in self.headers):
self.headers[hdrs.CONTENT_TYPE] = 'application/octet-stream'

Expand Down
5 changes: 0 additions & 5 deletions aiohttp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,8 +857,3 @@ def __init__(self, transport, method, path,
self.path = path
self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format(
method, path, http_version)

def _add_default_headers(self):
super()._add_default_headers()

self.headers.setdefault(hdrs.USER_AGENT, self.SERVER_SOFTWARE)
36 changes: 32 additions & 4 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ The client session supports context manager protocol for self closing::


.. class:: ClientSession(*, connector=None, loop=None, cookies=None,\
headers=None, auth=None, request_class=ClientRequest,\
headers=None, skip_auto_headers=None, \
auth=None, request_class=ClientRequest,\
response_class=ClientResponse, \
ws_response_class=ClientWebSocketResponse)

Expand All @@ -61,8 +62,23 @@ The client session supports context manager protocol for self closing::

:param dict cookies: Cookies to send with the request (optional)

:param dict headers: HTTP Headers to send with
the request (optional)
:param headers: HTTP Headers to send with
the request (optional).

May be either *iterable of key-value pairs* or
:class:`~collections.abc.Mapping`
(e.g. :class:`dict`,
:class:`~aiohttp.multidict.CIMultiDict`).

:param skip_auto_headers: set of headers for which autogeneration
should be skipped.

*aiohttp* autogenerates headers like ``User-Agent`` or
``Content-Type`` if these headers are not explicitly
passed. Using ``skip_auto_headers`` parameter allows to skip
that generation.

Iterable of :class:`str` or :class:`~aiohttp.multidict.upstr` (optional)

:param aiohttp.helpers.BasicAuth auth: BasicAuth named tuple that represents
HTTP Basic Authorization (optional)
Expand Down Expand Up @@ -106,7 +122,8 @@ The client session supports context manager protocol for self closing::


.. coroutinemethod:: request(method, url, *, params=None, data=None,\
headers=None, auth=None, allow_redirects=True,\
headers=None, skip_auto_headers=None, \
auth=None, allow_redirects=True,\
max_redirects=10, encoding='utf-8',\
version=HttpVersion(major=1, minor=1),\
compress=None, chunked=None, expect100=False,\
Expand All @@ -128,6 +145,17 @@ The client session supports context manager protocol for self closing::
:param dict headers: HTTP Headers to send with
the request (optional)

:param skip_auto_headers: set of headers for which autogeneration
should be skipped.

*aiohttp* autogenerates headers like ``User-Agent`` or
``Content-Type`` if these headers are not explicitly
passed. Using ``skip_auto_headers`` parameter allows to skip
that generation.

Iterable of :class:`str` or :class:`~aiohttp.multidict.upstr`
(optional)

:param aiohttp.helpers.BasicAuth auth: BasicAuth named tuple that
represents HTTP Basic Authorization
(optional)
Expand Down
116 changes: 116 additions & 0 deletions tests/test_client_functional2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
import socket
import unittest

import aiohttp
from aiohttp import hdrs, log, web


class TestClientFunctional2(unittest.TestCase):

def setUp(self):
self.handler = None
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
self.client = aiohttp.ClientSession(loop=self.loop)

def tearDown(self):
if self.handler:
self.loop.run_until_complete(self.handler.finish_connections())
self.client.close()
self.loop.stop()
self.loop.run_forever()
self.loop.close()

def find_unused_port(self):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]
s.close()
return port

@asyncio.coroutine
def create_server(self):
app = web.Application(loop=self.loop)

port = self.find_unused_port()
self.handler = app.make_handler(
debug=True, keep_alive_on=False,
access_log=log.access_logger)
srv = yield from self.loop.create_server(
self.handler, '127.0.0.1', port)
url = "http://127.0.0.1:{}".format(port)
self.addCleanup(srv.close)
return app, srv, url

def test_auto_header_user_agent(self):
@asyncio.coroutine
def handler(request):
self.assertIn('aiohttp', request.headers['user-agent'])
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)
resp = yield from self.client.get(url+'/')
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())

def test_skip_auto_headers_user_agent(self):
@asyncio.coroutine
def handler(request):
self.assertNotIn(hdrs.USER_AGENT, request.headers)
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)
resp = yield from self.client.get(url+'/',
skip_auto_headers=['user-agent'])
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())

def test_skip_default_auto_headers_user_agent(self):
@asyncio.coroutine
def handler(request):
self.assertNotIn(hdrs.USER_AGENT, request.headers)
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)

client = aiohttp.ClientSession(loop=self.loop,
skip_auto_headers=['user-agent'])
resp = yield from client.get(url+'/')
self.assertEqual(200, resp.status)
yield from resp.release()

client.close()

self.loop.run_until_complete(go())

def test_skip_auto_headers_content_type(self):
@asyncio.coroutine
def handler(request):
self.assertNotIn(hdrs.CONTENT_TYPE, request.headers)
return web.Response()

@asyncio.coroutine
def go():
app, srv, url = yield from self.create_server()
app.router.add_route('get', '/', handler)
resp = yield from self.client.get(
url+'/',
skip_auto_headers=['content-type'])
self.assertEqual(200, resp.status)
yield from resp.release()

self.loop.run_until_complete(go())
20 changes: 20 additions & 0 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import aiohttp
from aiohttp.client_reqrep import ClientRequest, ClientResponse
from aiohttp.multidict import upstr

PY_341 = sys.version_info >= (3, 4, 1)

Expand Down Expand Up @@ -116,6 +117,25 @@ def test_host_header(self):
self.assertEqual(req.headers['HOST'], 'example.com:99')
self.loop.run_until_complete(req.close())

def test_default_headers_useragent(self):
req = ClientRequest('get', 'http://python.org/', loop=self.loop)

self.assertNotIn('SERVER', req.headers)
self.assertIn('USER-AGENT', req.headers)

def test_default_headers_useragent_custom(self):
req = ClientRequest('get', 'http://python.org/', loop=self.loop,
headers={'user-agent': 'my custom agent'})

self.assertIn('USER-Agent', req.headers)
self.assertEqual('my custom agent', req.headers['User-Agent'])

def test_skip_default_useragent_header(self):
req = ClientRequest('get', 'http://python.org/', loop=self.loop,
skip_auto_headers=set([upstr('user-agent')]))

self.assertNotIn('User-Agent', req.headers)

def test_headers(self):
req = ClientRequest('get', 'http://python.org/',
headers={'Content-Type': 'text/plain'},
Expand Down
33 changes: 31 additions & 2 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ def test_init_headers_MultiDict(self):
("H3", "header3")]))
session.close()

def test_init_headers_list_of_tuples_with_duplicates(self):
session = ClientSession(
headers=[("h1", "header11"),
("h2", "header21"),
("h1", "header12")],
loop=self.loop)
self.assertEqual(
session._default_headers,
CIMultiDict([("H1", "header11"),
("H2", "header21"),
("H1", "header12")]))
session.close()

def test_init_cookies_with_simple_dict(self):
session = ClientSession(
cookies={
Expand Down Expand Up @@ -142,8 +155,24 @@ def test_merge_headers_with_list_of_tuples(self):
]))
session.close()

def _make_one(self):
session = ClientSession(loop=self.loop)
def test_merge_headers_with_list_of_tuples_duplicated_names(self):
session = ClientSession(
headers={
"h1": "header1",
"h2": "header2"
}, loop=self.loop)
headers = session._prepare_headers([("h1", "v1"),
("h1", "v2")])
self.assertIsInstance(headers, CIMultiDict)
self.assertEqual(headers, CIMultiDict([
("H2", "header2"),
("H1", "v1"),
("H1", "v2"),
]))
session.close()

def _make_one(self, **kwargs):
session = ClientSession(loop=self.loop, **kwargs)
params = dict(
headers={"Authorization": "Basic ..."},
max_redirects=2,
Expand Down
Loading