-
Notifications
You must be signed in to change notification settings - Fork 145
/
Copy pathasync_context.py
111 lines (87 loc) · 3.61 KB
/
async_context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import asyncio
import copy
from .context import Context as _Context
class AsyncContext(_Context):
    """
    Async Context for storing segments.

    Inherits nearly everything from the main Context class.
    Replaces threading.local with a task based local storage class,
    Also overrides clear_trace_entities
    """

    def __init__(self, *args, loop=None, use_task_factory=True, **kwargs):
        """
        Set up the base context, then swap in task-local storage.

        :param loop: event loop whose tasks hold the stored entities.
            When None, ``asyncio.get_event_loop()`` is used.
        :param bool use_task_factory: when True, ``task_factory`` is
            installed on the loop so that newly created tasks receive
            the context of the task that created them.

        All other positional and keyword arguments are forwarded to the
        base ``Context`` initializer.
        """
        super().__init__(*args, **kwargs)

        if loop is None:
            loop = asyncio.get_event_loop()
        self._loop = loop

        if use_task_factory:
            self._loop.set_task_factory(task_factory)

        # Hand the already-resolved loop to the storage so it is bound to
        # exactly the loop that received the task factory above, instead of
        # letting the storage perform a second, independent loop lookup.
        self._local = TaskLocalStorage(loop=self._loop)

    def clear_trace_entities(self):
        """
        Clear all trace_entities stored in the task local context.
        """
        if self._local is not None:
            self._local.clear()
class TaskLocalStorage:
    """
    Simple task local storage.

    Attribute reads and writes on an instance are redirected to a
    ``context`` dictionary attached to the asyncio task that is currently
    running on the bound loop. Only ``_loop`` (and the ``clear`` method)
    live on the instance itself.
    """

    def __init__(self, loop=None):
        """
        :param loop: event loop whose current task is consulted; when
            None, ``asyncio.get_event_loop()`` is used.
        """
        self._loop = asyncio.get_event_loop() if loop is None else loop

    def __setattr__(self, name, value):
        """
        Store ``value`` under ``name`` in the current task's context.

        ``_loop`` is stored on the instance as a normal attribute. When
        no task is running on the loop the assignment is silently dropped.
        """
        if name == '_loop':
            # Regular instance attribute, not task-local data
            object.__setattr__(self, name, value)
            return None

        running = asyncio.current_task(loop=self._loop)
        if running is None:
            return None

        # Lazily create the per-task dictionary on first write
        if not hasattr(running, 'context'):
            running.context = {}
        running.context[name] = value
        return None

    def __getattribute__(self, item):
        """
        Look ``item`` up in the current task's context.

        :return: the stored value, or None when no task is running on
            the loop.
        :raises AttributeError: when a task is running but its context
            does not hold ``item``.
        """
        if item == '_loop' or item == 'clear':
            # These two are resolved on the instance itself
            return object.__getattribute__(self, item)

        running = asyncio.current_task(loop=self._loop)
        if running is None:
            return None

        if not hasattr(running, 'context') or item not in running.context:
            raise AttributeError('Task context does not have attribute {0}'.format(item))
        return running.context[item]

    def clear(self):
        """
        Empty the context dictionary of the current task, if there is one.
        """
        running = asyncio.current_task(loop=self._loop)
        if running is None or not hasattr(running, 'context'):
            return
        running.context.clear()
def task_factory(loop, coro, **kwargs):
    """
    Task factory function

    Function closely mirrors the logic inside of
    asyncio.BaseEventLoop.create_task. Then if there is a current
    task and the current task has a context then share that context
    with the new task

    :param loop: event loop the new task is created on.
    :param coro: coroutine object to wrap in a task.
    :param kwargs: any extra keyword arguments the event loop hands to
        task factories (for example ``name`` or ``context``); they are
        forwarded unchanged to ``asyncio.Task``.
    :return: the newly created ``asyncio.Task``.
    """
    task = asyncio.Task(coro, loop=loop, **kwargs)
    if task._source_traceback:  # flake8: noqa
        # Drop the last recorded frame, as create_task does
        del task._source_traceback[-1]  # flake8: noqa

    # Share context with new task if possible
    current_task = asyncio.current_task(loop=loop)
    if current_task is not None and hasattr(current_task, 'context'):
        parent_context = current_task.context
        if parent_context.get('entities'):
            # NOTE: (enowell) Because the `AWSXRayRecorder`'s `Context` decides
            # the parent by looking at its `_local.entities`, we must copy the entities
            # for concurrent subsegments. Otherwise, the subsegments would be
            # modifying the same `entities` list and subsegments would take other
            # subsegments as parents instead of the original `segment`.
            #
            # See more: https://github.com/aws/aws-xray-sdk-python/blob/0f13101e4dba7b5c735371cb922f727b1d9f46d8/aws_xray_sdk/core/context.py#L90-L101
            new_context = copy.copy(parent_context)
            new_context['entities'] = list(parent_context['entities'])
        else:
            # Nothing to isolate: the new task shares the very same dict
            new_context = parent_context
        task.context = new_context

    return task