Skip to content

Commit

Permalink
Fixed task output stream
Browse files Browse the repository at this point in the history
  • Loading branch information
Lubos Matl committed Sep 20, 2019
1 parent a356291 commit 1641eb3
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 39 deletions.
4 changes: 2 additions & 2 deletions example/apps/test_security/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
base=LoggedTask,
bind=True,
name='sum_task')
def sum_task(self, task_id, a, b):
def sum_task(self, a, b):
return a + b


@celery_app.task(
base=LoggedTask,
bind=True,
name='error_task')
def error_task(self, task_id):
def error_task(self):
raise RuntimeError('error')
18 changes: 9 additions & 9 deletions example/apps/test_security/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,28 @@ def test_request_body_should_be_truncated(self):
self.post('/admin/login/', data={'username': 20 * 'a', 'password': 20 * 'b'})
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(len(input_logged_request.request_body), 10)
assert_true(input_logged_request.request_body.endswith('...'))
assert_true(input_logged_request.request_body.endswith(''))

@override_settings(SECURITY_LOG_RESPONSE_BODY_LENGTH=10)
def test_response_body_should_be_truncated(self):
self.post('/admin/login/', data={'username': 20 * 'a', 'password': 20 * 'b'})
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(len(input_logged_request.response_body), 10)
assert_true(input_logged_request.response_body.endswith('...'))
assert_true(input_logged_request.response_body.endswith(''))

@override_settings(SECURITY_LOG_REQUEST_BODY_LENGTH=None)
def test_request_body_truncation_should_be_turned_off(self):
self.post('/admin/login/', data={'username': 2000 * 'a', 'password': 2000 * 'b'})
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(len(input_logged_request.request_body), 4183)
assert_false(input_logged_request.request_body.endswith('...'))
assert_false(input_logged_request.request_body.endswith(''))

@override_settings(SECURITY_LOG_RESPONSE_BODY_LENGTH=None)
def test_response_body_truncation_should_be_turned_off(self):
resp = self.post('/admin/login/', data={'username': 20 * 'a', 'password': 20 * 'b'})
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(input_logged_request.response_body, str(resp.content))
assert_false(input_logged_request.response_body.endswith('...'))
assert_false(input_logged_request.response_body.endswith(''))

@override_settings(SECURITY_LOG_RESPONSE_BODY_CONTENT_TYPES=())
def test_not_allowed_content_type_should_not_be_logged(self):
Expand All @@ -125,25 +125,25 @@ def test_json_request_should_be_truncated_with_another_method(self):
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(
json.loads(input_logged_request.request_body),
json.loads('{"a": "aaaaaaa...", "b": "bbbbbbb..."}')
json.loads('{"a": "aaaaaaaaa…", "b": "bbbbbbbbb…"}')
)
assert_false(input_logged_request.request_body.endswith('...'))
assert_false(input_logged_request.request_body.endswith(''))

@override_settings(SECURITY_LOG_REQUEST_BODY_LENGTH=50, SECURITY_LOG_JSON_STRING_LENGTH=None)
def test_json_request_should_not_be_truncated_with_another_method(self):
self.c.post('/admin/login/', data=json.dumps({'a': 50 * 'a'}),
content_type='application/json')
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(input_logged_request.request_body, '{"a": "' + 40* 'a' + '...')
assert_true(input_logged_request.request_body.endswith('...'))
assert_equal(input_logged_request.request_body, '{"a": "' + 42* 'a' + '')
assert_true(input_logged_request.request_body.endswith(''))

@override_settings(SECURITY_LOG_REQUEST_BODY_LENGTH=100, SECURITY_LOG_JSON_STRING_LENGTH=10)
def test_json_request_should_be_truncated_with_another_method_and_standard_method_too(self):
self.c.post('/admin/login/', data=json.dumps({50 * 'a': 50 * 'a', 50 * 'b': 50 * 'b'}),
content_type='application/json')
input_logged_request = InputLoggedRequest.objects.get()
assert_equal(len(input_logged_request.request_body), 100)
assert_true(input_logged_request.request_body.endswith('...'))
assert_true(input_logged_request.request_body.endswith(''))

def test_response_with_exception_should_be_logged(self):
assert_equal(InputLoggedRequest.objects.count(), 0)
Expand Down
2 changes: 1 addition & 1 deletion example/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Django==2.0
Django==2.2
django-germanium==2.0.5
django-ipware>=1.0.0
coverage==4.0.2
Expand Down
1 change: 1 addition & 0 deletions example/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@
'OPTIONS': {
'context_processors': [
'django.contrib.auth.context_processors.auth',
'django.contrib.messages.context_processors.messages',
]
}
}
Expand Down
2 changes: 0 additions & 2 deletions example/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
proxy_view, hide_request_body_view, log_exempt_view, throttling_exempt_view, extra_throttling_view
)

admin.autodiscover()


urlpatterns = [
path('admin/', admin.site.urls),
Expand Down
2 changes: 2 additions & 0 deletions security/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json
from json import JSONDecodeError

from datetime import timedelta

from ipware.ip import get_ip
from jsonfield import JSONField

Expand Down
43 changes: 18 additions & 25 deletions security/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,19 @@ class LoggedTask(Task):
abstract = True
logger_level = logging.WARNING

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.output = StringIO()

@property
def stdout(self):
return self.output

@property
def stderr(self):
return self.output
def push_request(self, *args, **kwargs):
task_id = self.request.id
output_stream = self.request.output_stream
super().push_request(*args, **kwargs)
self.request.id = task_id
self.request.output_stream = output_stream

def get_task(self, task_id):
return CeleryTaskLog.objects.get(pk=task_id)

def __call__(self, *args, **kwargs):
# Every set attr is send here
self.request.output_stream = StringIO()
self.on_start(self.request.id, args, kwargs)
super().__call__(*args, **kwargs)

Expand All @@ -62,10 +58,13 @@ def on_start(self, task_id, args, kwargs):
self._call_callback('start', task_id, args, kwargs)

def on_success(self, retval, task_id, args, kwargs):
if retval:
self.request.output_stream.write('Return value is "{}"'.format(retval))

self.get_task(task_id).change_and_save(
state=CeleryTaskLogState.SUCCEEDED,
stop=now(),
output=self.output.getvalue()
output=self.request.output_stream.getvalue()
)
self._call_callback('success', task_id, args, kwargs)

Expand All @@ -75,9 +74,9 @@ def on_failure(self, exc, task_id, args, kwargs, einfo):
state=CeleryTaskLogState.FAILED,
stop=now(),
error_message=einfo,
output = self.output.getvalue()
output=self.request.output_stream.getvalue()
)
except CeleryTask.DoesNotExist:
except CeleryTaskLog.DoesNotExist:
pass
self._call_callback('failure', task_id, args, kwargs)

Expand All @@ -96,12 +95,8 @@ def _create_task(self, options, task_args, task_kwargs):
)
return str(task.pk)

def _get_args(self, task_id, args):
return (task_id,) + tuple(args or ())

def apply_async_on_commit(self, args=None, kwargs=None, **options):
task_id = self._create_task(options, args, kwargs)
args = self._get_args(task_id, args)
self.on_apply(task_id, args, kwargs)
if sys.argv[1:2] == ['test']:
super().apply_async(args=args, kwargs=kwargs, task_id=task_id, **options)
Expand All @@ -113,11 +108,10 @@ def apply_async_on_commit(self, args=None, kwargs=None, **options):

def apply_async(self, args=None, kwargs=None, **options):
task_id = self._create_task(options, args, kwargs)
args = self._get_args(task_id, args)
self.on_apply(task_id, args, kwargs)
return super().apply_async(args=args, kwargs=kwargs, task_id=task_id, **options)

def log_and_retry(self, attempt, exception_message=None, *args, **kwargs):
def log_and_retry(self, attempt, exception_message=None, queue=None, *args, **kwargs):
LOGGER.log(self.logger_level, self.retry_error_message.format(
attempt=attempt, exception_message=exception_message, **kwargs
))
Expand All @@ -126,7 +120,7 @@ def log_and_retry(self, attempt, exception_message=None, *args, **kwargs):
args=args,
kwargs={**kwargs, 'attempt': attempt+1},
countdown=self.repeat_timeouts[attempt - 1] * 60,
queue=self.queue
queue=queue or getattr(self, 'queue', settings.CELERY_DEFAULT_QUEUE)
)


Expand All @@ -144,13 +138,12 @@ def string_to_obj(obj_string):
name='call_django_command'
)
@atomic_with_signals
def call_django_command(self, task_id, command_name, command_args=None):
def call_django_command(self, command_name, command_args=None):
command_args = [] if command_args is None else command_args
call_command(
command_name,
settings=os.environ.get('DJANGO_SETTINGS_MODULE'),
*command_args,
stdout=self.stdout,
stderr=self.stderr
stdout=self.request.output_stream,
stderr=self.request.output_stream,
)

0 comments on commit 1641eb3

Please sign in to comment.