1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import asyncio
1516import types
1617from unittest import mock
1718
@@ -45,6 +46,35 @@ def __exit__(self, *args):
4546 return self
4647
4748
49+ class MockAsyncCursor :
50+ def __init__ (self , * args , ** kwargs ):
51+ pass
52+
53+ # pylint: disable=unused-argument, no-self-use
54+ async def execute (self , query , params = None , throw_exception = False ):
55+ if throw_exception :
56+ raise Exception ("Test Exception" )
57+
58+ # pylint: disable=unused-argument, no-self-use
59+ async def executemany (self , query , params = None , throw_exception = False ):
60+ if throw_exception :
61+ raise Exception ("Test Exception" )
62+
63+ # pylint: disable=unused-argument, no-self-use
64+ async def callproc (self , query , params = None , throw_exception = False ):
65+ if throw_exception :
66+ raise Exception ("Test Exception" )
67+
68+ async def __aenter__ (self , * args , ** kwargs ):
69+ return self
70+
71+ async def __aexit__ (self , * args , ** kwargs ):
72+ pass
73+
74+ def close (self ):
75+ pass
76+
77+
4878class MockConnection :
4979 commit = mock .MagicMock (spec = types .MethodType )
5080 commit .__name__ = "commit"
@@ -64,22 +94,68 @@ def get_dsn_parameters(self): # pylint: disable=no-self-use
6494 return {"dbname" : "test" }
6595
6696
97+ class MockAsyncConnection :
98+ commit = mock .MagicMock (spec = types .MethodType )
99+ commit .__name__ = "commit"
100+
101+ rollback = mock .MagicMock (spec = types .MethodType )
102+ rollback .__name__ = "rollback"
103+
104+ def __init__ (self , * args , ** kwargs ):
105+ self .cursor_factory = kwargs .pop ("cursor_factory" , None )
106+
107+ @staticmethod
108+ async def connect (* args , ** kwargs ):
109+ return MockAsyncConnection (** kwargs )
110+
111+ def cursor (self ):
112+ if self .cursor_factory :
113+ cur = self .cursor_factory (self )
114+ return cur
115+ return MockAsyncCursor ()
116+
117+ def get_dsn_parameters (self ): # pylint: disable=no-self-use
118+ return {"dbname" : "test" }
119+
120+ async def __aenter__ (self ):
121+ return self
122+
123+ async def __aexit__ (self , * args ):
124+ return mock .MagicMock (spec = types .MethodType )
125+
126+
67127class TestPostgresqlIntegration (TestBase ):
68128 def setUp (self ):
69129 super ().setUp ()
70130 self .cursor_mock = mock .patch (
71131 "opentelemetry.instrumentation.psycopg.pg_cursor" , MockCursor
72132 )
133+ self .cursor_async_mock = mock .patch (
134+ "opentelemetry.instrumentation.psycopg.pg_async_cursor" ,
135+ MockAsyncCursor ,
136+ )
73137 self .connection_mock = mock .patch ("psycopg.connect" , MockConnection )
138+ self .connection_sync_mock = mock .patch (
139+ "psycopg.Connection.connect" , MockConnection
140+ )
141+ self .connection_async_mock = mock .patch (
142+ "psycopg.AsyncConnection.connect" , MockAsyncConnection .connect
143+ )
74144
75145 self .cursor_mock .start ()
146+ self .cursor_async_mock .start ()
76147 self .connection_mock .start ()
148+ self .connection_sync_mock .start ()
149+ self .connection_async_mock .start ()
77150
78151 def tearDown (self ):
79152 super ().tearDown ()
80153 self .memory_exporter .clear ()
81154 self .cursor_mock .stop ()
155+ self .cursor_async_mock .stop ()
82156 self .connection_mock .stop ()
157+ self .connection_sync_mock .stop ()
158+ self .connection_async_mock .stop ()
83159 with self .disable_logging ():
84160 PsycopgInstrumentor ().uninstrument ()
85161
@@ -114,6 +190,91 @@ def test_instrumentor(self):
114190 spans_list = self .memory_exporter .get_finished_spans ()
115191 self .assertEqual (len (spans_list ), 1 )
116192
193+ # pylint: disable=unused-argument
194+ def test_instrumentor_with_connection_class (self ):
195+ PsycopgInstrumentor ().instrument ()
196+
197+ cnx = psycopg .Connection .connect (database = "test" )
198+
199+ cursor = cnx .cursor ()
200+
201+ query = "SELECT * FROM test"
202+ cursor .execute (query )
203+
204+ spans_list = self .memory_exporter .get_finished_spans ()
205+ self .assertEqual (len (spans_list ), 1 )
206+ span = spans_list [0 ]
207+
208+ # Check version and name in span's instrumentation info
209+ self .assertEqualSpanInstrumentationInfo (
210+ span , opentelemetry .instrumentation .psycopg
211+ )
212+
213+ # check that no spans are generated after uninstrument
214+ PsycopgInstrumentor ().uninstrument ()
215+
216+ cnx = psycopg .Connection .connect (database = "test" )
217+ cursor = cnx .cursor ()
218+ query = "SELECT * FROM test"
219+ cursor .execute (query )
220+
221+ spans_list = self .memory_exporter .get_finished_spans ()
222+ self .assertEqual (len (spans_list ), 1 )
223+
224+ async def test_wrap_async_connection_class_with_cursor (self ):
225+ PsycopgInstrumentor ().instrument ()
226+
227+ async def test_async_connection ():
228+ acnx = await psycopg .AsyncConnection .connect (database = "test" )
229+ async with acnx as cnx :
230+ async with cnx .cursor () as cursor :
231+ await cursor .execute ("SELECT * FROM test" )
232+
233+ asyncio .run (test_async_connection ())
234+ spans_list = self .memory_exporter .get_finished_spans ()
235+ self .assertEqual (len (spans_list ), 1 )
236+ span = spans_list [0 ]
237+
238+ # Check version and name in span's instrumentation info
239+ self .assertEqualSpanInstrumentationInfo (
240+ span , opentelemetry .instrumentation .psycopg
241+ )
242+
243+ # check that no spans are generated after uninstrument
244+ PsycopgInstrumentor ().uninstrument ()
245+
246+ asyncio .run (test_async_connection ())
247+
248+ spans_list = self .memory_exporter .get_finished_spans ()
249+ self .assertEqual (len (spans_list ), 1 )
250+
251+ # pylint: disable=unused-argument
252+ async def test_instrumentor_with_async_connection_class (self ):
253+ PsycopgInstrumentor ().instrument ()
254+
255+ async def test_async_connection ():
256+ acnx = await psycopg .AsyncConnection .connect (database = "test" )
257+ async with acnx as cnx :
258+ await cnx .execute ("SELECT * FROM test" )
259+
260+ asyncio .run (test_async_connection ())
261+
262+ spans_list = self .memory_exporter .get_finished_spans ()
263+ self .assertEqual (len (spans_list ), 1 )
264+ span = spans_list [0 ]
265+
266+ # Check version and name in span's instrumentation info
267+ self .assertEqualSpanInstrumentationInfo (
268+ span , opentelemetry .instrumentation .psycopg
269+ )
270+
271+ # check that no spans are generated after uninstrument
272+ PsycopgInstrumentor ().uninstrument ()
273+ asyncio .run (test_async_connection ())
274+
275+ spans_list = self .memory_exporter .get_finished_spans ()
276+ self .assertEqual (len (spans_list ), 1 )
277+
117278 def test_span_name (self ):
118279 PsycopgInstrumentor ().instrument ()
119280
@@ -140,6 +301,33 @@ def test_span_name(self):
140301 self .assertEqual (spans_list [4 ].name , "query" )
141302 self .assertEqual (spans_list [5 ].name , "query" )
142303
304+ async def test_span_name_async (self ):
305+ PsycopgInstrumentor ().instrument ()
306+
307+ cnx = psycopg .AsyncConnection .connect (database = "test" )
308+ async with cnx .cursor () as cursor :
309+ await cursor .execute ("Test query" , ("param1Value" , False ))
310+ await cursor .execute (
311+ """multi
312+ line
313+ query"""
314+ )
315+ await cursor .execute ("tab\t separated query" )
316+ await cursor .execute ("/* leading comment */ query" )
317+ await cursor .execute (
318+ "/* leading comment */ query /* trailing comment */"
319+ )
320+ await cursor .execute ("query /* trailing comment */" )
321+
322+ spans_list = self .memory_exporter .get_finished_spans ()
323+ self .assertEqual (len (spans_list ), 6 )
324+ self .assertEqual (spans_list [0 ].name , "Test" )
325+ self .assertEqual (spans_list [1 ].name , "multi" )
326+ self .assertEqual (spans_list [2 ].name , "tab" )
327+ self .assertEqual (spans_list [3 ].name , "query" )
328+ self .assertEqual (spans_list [4 ].name , "query" )
329+ self .assertEqual (spans_list [5 ].name , "query" )
330+
143331 # pylint: disable=unused-argument
144332 def test_not_recording (self ):
145333 mock_tracer = mock .Mock ()
@@ -160,6 +348,26 @@ def test_not_recording(self):
160348
161349 PsycopgInstrumentor ().uninstrument ()
162350
351+ # pylint: disable=unused-argument
352+ async def test_not_recording_async (self ):
353+ mock_tracer = mock .Mock ()
354+ mock_span = mock .Mock ()
355+ mock_span .is_recording .return_value = False
356+ mock_tracer .start_span .return_value = mock_span
357+ PsycopgInstrumentor ().instrument ()
358+ with mock .patch ("opentelemetry.trace.get_tracer" ) as tracer :
359+ tracer .return_value = mock_tracer
360+ cnx = psycopg .AsyncConnection .connect (database = "test" )
361+ async with cnx .cursor () as cursor :
362+ query = "SELECT * FROM test"
363+ cursor .execute (query )
364+ self .assertFalse (mock_span .is_recording ())
365+ self .assertTrue (mock_span .is_recording .called )
366+ self .assertFalse (mock_span .set_attribute .called )
367+ self .assertFalse (mock_span .set_status .called )
368+
369+ PsycopgInstrumentor ().uninstrument ()
370+
163371 # pylint: disable=unused-argument
164372 def test_custom_tracer_provider (self ):
165373 resource = resources .Resource .create ({})
0 commit comments