1
+ from contextlib import contextmanager
1
2
import unittest
2
3
from unittest .mock import Mock , patch , MagicMock
3
4
4
5
import requests
5
6
6
7
import databricks .sql .cloudfetch .downloader as downloader
8
+ from databricks .sql .common .http import DatabricksHttpClient
7
9
from databricks .sql .exc import Error
8
10
from databricks .sql .types import SSLOptions
9
11
@@ -12,6 +14,7 @@ def create_response(**kwargs) -> requests.Response:
12
14
result = requests .Response ()
13
15
for k , v in kwargs .items ():
14
16
setattr (result , k , v )
17
+ result .close = Mock ()
15
18
return result
16
19
17
20
@@ -52,91 +55,94 @@ def test_run_link_past_expiry_buffer(self, mock_time):
52
55
53
56
mock_time .assert_called_once ()
54
57
55
- @patch ("requests.Session" , return_value = MagicMock (get = MagicMock (return_value = None )))
56
58
@patch ("time.time" , return_value = 1000 )
57
- def test_run_get_response_not_ok (self , mock_time , mock_session ):
58
- mock_session .return_value .get .return_value = create_response (status_code = 404 )
59
-
59
+ def test_run_get_response_not_ok (self , mock_time ):
60
+ http_client = DatabricksHttpClient .get_instance ()
60
61
settings = Mock (link_expiry_buffer_secs = 0 , download_timeout = 0 )
61
62
settings .download_timeout = 0
62
63
settings .use_proxy = False
63
64
result_link = Mock (expiryTime = 1001 )
64
65
65
- d = downloader .ResultSetDownloadHandler (
66
- settings , result_link , ssl_options = SSLOptions ()
67
- )
68
- with self .assertRaises (requests .exceptions .HTTPError ) as context :
69
- d .run ()
70
- self .assertTrue ("404" in str (context .exception ))
66
+ with patch .object (
67
+ http_client ,
68
+ "execute" ,
69
+ return_value = create_response (status_code = 404 , _content = b"1234" ),
70
+ ):
71
+ d = downloader .ResultSetDownloadHandler (
72
+ settings , result_link , ssl_options = SSLOptions ()
73
+ )
74
+ with self .assertRaises (requests .exceptions .HTTPError ) as context :
75
+ d .run ()
76
+ self .assertTrue ("404" in str (context .exception ))
71
77
72
- @patch ("requests.Session" , return_value = MagicMock (get = MagicMock (return_value = None )))
73
78
@patch ("time.time" , return_value = 1000 )
74
- def test_run_uncompressed_successful (self , mock_time , mock_session ):
79
+ def test_run_uncompressed_successful (self , mock_time ):
80
+ http_client = DatabricksHttpClient .get_instance ()
75
81
file_bytes = b"1234567890" * 10
76
- mock_session .return_value .get .return_value = create_response (
77
- status_code = 200 , _content = file_bytes
78
- )
79
-
80
82
settings = Mock (link_expiry_buffer_secs = 0 , download_timeout = 0 , use_proxy = False )
81
83
settings .is_lz4_compressed = False
82
84
result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
83
85
84
- d = downloader .ResultSetDownloadHandler (
85
- settings , result_link , ssl_options = SSLOptions ()
86
- )
87
- file = d .run ()
86
+ with patch .object (
87
+ http_client ,
88
+ "execute" ,
89
+ return_value = create_response (status_code = 200 , _content = file_bytes ),
90
+ ):
91
+ d = downloader .ResultSetDownloadHandler (
92
+ settings , result_link , ssl_options = SSLOptions ()
93
+ )
94
+ file = d .run ()
88
95
89
- assert file .file_bytes == b"1234567890" * 10
96
+ assert file .file_bytes == b"1234567890" * 10
90
97
91
- @patch (
92
- "requests.Session" ,
93
- return_value = MagicMock (get = MagicMock (return_value = MagicMock (ok = True ))),
94
- )
95
98
@patch ("time.time" , return_value = 1000 )
96
- def test_run_compressed_successful (self , mock_time , mock_session ):
99
+ def test_run_compressed_successful (self , mock_time ):
100
+ http_client = DatabricksHttpClient .get_instance ()
97
101
file_bytes = b"1234567890" * 10
98
102
compressed_bytes = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
99
- mock_session .return_value .get .return_value = create_response (
100
- status_code = 200 , _content = compressed_bytes
101
- )
102
103
103
104
settings = Mock (link_expiry_buffer_secs = 0 , download_timeout = 0 , use_proxy = False )
104
105
settings .is_lz4_compressed = True
105
106
result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
107
+ with patch .object (
108
+ http_client ,
109
+ "execute" ,
110
+ return_value = create_response (status_code = 200 , _content = compressed_bytes ),
111
+ ):
112
+ d = downloader .ResultSetDownloadHandler (
113
+ settings , result_link , ssl_options = SSLOptions ()
114
+ )
115
+ file = d .run ()
116
+
117
+ assert file .file_bytes == b"1234567890" * 10
106
118
107
- d = downloader .ResultSetDownloadHandler (
108
- settings , result_link , ssl_options = SSLOptions ()
109
- )
110
- file = d .run ()
111
-
112
- assert file .file_bytes == b"1234567890" * 10
113
-
114
- @patch ("requests.Session.get" , side_effect = ConnectionError ("foo" ))
115
119
@patch ("time.time" , return_value = 1000 )
116
- def test_download_connection_error (self , mock_time , mock_session ):
120
+ def test_download_connection_error (self , mock_time ):
121
+
122
+ http_client = DatabricksHttpClient .get_instance ()
117
123
settings = Mock (
118
124
link_expiry_buffer_secs = 0 , use_proxy = False , is_lz4_compressed = True
119
125
)
120
126
result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
121
- mock_session .return_value .get .return_value .content = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
122
127
123
- d = downloader .ResultSetDownloadHandler (
124
- settings , result_link , ssl_options = SSLOptions ()
125
- )
126
- with self .assertRaises (ConnectionError ):
127
- d .run ()
128
+ with patch .object (http_client , "execute" , side_effect = ConnectionError ("foo" )):
129
+ d = downloader .ResultSetDownloadHandler (
130
+ settings , result_link , ssl_options = SSLOptions ()
131
+ )
132
+ with self .assertRaises (ConnectionError ):
133
+ d .run ()
128
134
129
- @patch ("requests.Session.get" , side_effect = TimeoutError ("foo" ))
130
135
@patch ("time.time" , return_value = 1000 )
131
- def test_download_timeout (self , mock_time , mock_session ):
136
+ def test_download_timeout (self , mock_time ):
137
+ http_client = DatabricksHttpClient .get_instance ()
132
138
settings = Mock (
133
139
link_expiry_buffer_secs = 0 , use_proxy = False , is_lz4_compressed = True
134
140
)
135
141
result_link = Mock (bytesNum = 100 , expiryTime = 1001 )
136
- mock_session .return_value .get .return_value .content = b'\x04 "M\x18 h@d\x00 \x00 \x00 \x00 \x00 \x00 \x00 #\x14 \x00 \x00 \x00 \xaf 1234567890\n \x00 BP67890\x00 \x00 \x00 \x00 '
137
142
138
- d = downloader .ResultSetDownloadHandler (
139
- settings , result_link , ssl_options = SSLOptions ()
140
- )
141
- with self .assertRaises (TimeoutError ):
142
- d .run ()
143
+ with patch .object (http_client , "execute" , side_effect = TimeoutError ("foo" )):
144
+ d = downloader .ResultSetDownloadHandler (
145
+ settings , result_link , ssl_options = SSLOptions ()
146
+ )
147
+ with self .assertRaises (TimeoutError ):
148
+ d .run ()
0 commit comments