55import string
66import sys
77import time
8- from adb .usb_exceptions import TcpTimeoutException
8+ from mock import mock
9+
10+ from adb .common import TcpHandle , UsbHandle
11+ from adb .usb_exceptions import TcpTimeoutException
912
1013PRINTABLE_DATA = set (string .printable ) - set (string .whitespace )
1114
@@ -16,33 +19,23 @@ def _Dotify(data):
1619 return '' .join (char if char in PRINTABLE_DATA else '.' for char in data )
1720
1821
19- class StubUsb (object ):
20- """UsbHandle stub."""
21-
22- def __init__ (self ):
22+ class StubHandleBase (object ):
23+ def __init__ (self , timeout_ms , is_tcp = False ):
2324 self .written_data = []
2425 self .read_data = []
25- self .timeout_ms = 0
26+ self .is_tcp = is_tcp
27+ self .timeout_ms = timeout_ms
2628
27- def BulkWrite (self , data , unused_timeout_ms = None ):
28- expected_data = self .written_data .pop (0 )
29- if isinstance (data , bytearray ):
30- data = bytes (data )
31- if not isinstance (data , bytes ):
32- data = data .encode ('utf8' )
33- if expected_data != data :
34- raise ValueError ('Expected %s (%s) got %s (%s)' % (
35- binascii .hexlify (expected_data ), _Dotify (expected_data ),
36- binascii .hexlify (data ), _Dotify (data )))
29+ def _signal_handler (self , signum , frame ):
30+ raise TcpTimeoutException ('End of time' )
3731
38- def BulkRead (self , length ,
39- timeout_ms = None ): # pylint: disable=unused-argument
40- data = self .read_data .pop (0 )
41- if length < len (data ):
42- raise ValueError (
43- 'Overflow packet length. Read %d bytes, got %d bytes: %s' ,
44- length , len (data ))
45- return bytearray (data )
32+ def _return_seconds (self , time_ms ):
33+ return (float (time_ms )/ 1000 ) if time_ms else 0
34+
35+ def _alarm_sounder (self , timeout_ms ):
36+ signal .signal (signal .SIGALRM , self ._signal_handler )
37+ signal .setitimer (signal .ITIMER_REAL ,
38+ self ._return_seconds (timeout_ms ))
4639
4740 def ExpectWrite (self , data ):
4841 if not isinstance (data , bytes ):
@@ -54,22 +47,6 @@ def ExpectRead(self, data):
5447 data = data .encode ('utf8' )
5548 self .read_data .append (data )
5649
57- def Timeout (self , timeout_ms ):
58- return timeout_ms if timeout_ms is not None else self .timeout_ms
59-
60- class StubTcp (StubUsb ):
61-
62- def _signal_handler (self , signum , frame ):
63- raise TcpTimeoutException ('End of time' )
64-
65- def _return_seconds (self , time_ms ):
66- return (float (time_ms )/ 1000 ) if time_ms else 0
67-
68- def _alarm_sounder (self , timeout_ms ):
69- signal .signal (signal .SIGALRM , self ._signal_handler )
70- signal .setitimer (signal .ITIMER_REAL ,
71- self ._return_seconds (timeout_ms ))
72-
7350 def BulkWrite (self , data , timeout_ms = None ):
7451 expected_data = self .written_data .pop (0 )
7552 if isinstance (data , bytearray ):
@@ -80,8 +57,8 @@ def BulkWrite(self, data, timeout_ms=None):
8057 raise ValueError ('Expected %s (%s) got %s (%s)' % (
8158 binascii .hexlify (expected_data ), _Dotify (expected_data ),
8259 binascii .hexlify (data ), _Dotify (data )))
83- if b'i_need_a_timeout' in data :
84- self ._alarm_sounder (timeout_ms )
60+ if self . is_tcp and b'i_need_a_timeout' in data :
61+ self ._alarm_sounder (timeout_ms )
8562 time .sleep (2 * self ._return_seconds (timeout_ms ))
8663
8764 def BulkRead (self , length ,
@@ -91,8 +68,56 @@ def BulkRead(self, length,
9168 raise ValueError (
9269 'Overflow packet length. Read %d bytes, got %d bytes: %s' ,
9370 length , len (data ))
94- if b'i_need_a_timeout' in data :
95- self ._alarm_sounder (timeout_ms )
71+ if self . is_tcp and b'i_need_a_timeout' in data :
72+ self ._alarm_sounder (timeout_ms )
9673 time .sleep (2 * self ._return_seconds (timeout_ms ))
97- return bytearray (data )
74+ return bytearray (data )
75+
76+ def Timeout (self , timeout_ms ):
77+ return timeout_ms if timeout_ms is not None else self .timeout_ms
78+
79+
80+ class StubUsb (UsbHandle ):
81+ """UsbHandle stub."""
82+ def __init__ (self , device , setting , usb_info = None , timeout_ms = None ):
83+ super (StubUsb , self ).__init__ (device , setting , usb_info , timeout_ms )
84+ self .stub_base = StubHandleBase (0 )
85+
86+ def ExpectWrite (self , data ):
87+ return self .stub_base .ExpectWrite (data )
88+
89+ def ExpectRead (self , data ):
90+ return self .stub_base .ExpectRead (data )
91+
92+ def BulkWrite (self , data , unused_timeout_ms = None ):
93+ return self .stub_base .BulkWrite (data , unused_timeout_ms )
94+
95+ def BulkRead (self , length , timeout_ms = None ):
96+ return self .stub_base .BulkRead (length , timeout_ms )
97+
98+ def Timeout (self , timeout_ms ):
99+ return self .stub_base .Timeout (timeout_ms )
100+
101+
102+ class StubTcp (TcpHandle ):
103+ def __init__ (self , serial , timeout_ms = None ):
104+ """TcpHandle stub."""
105+ self ._connect = mock .MagicMock (return_value = None )
106+
107+ super (StubTcp , self ).__init__ (serial , timeout_ms )
108+ self .stub_base = StubHandleBase (0 , is_tcp = True )
98109
110+ def ExpectWrite (self , data ):
111+ return self .stub_base .ExpectWrite (data )
112+
113+ def ExpectRead (self , data ):
114+ return self .stub_base .ExpectRead (data )
115+
116+ def BulkWrite (self , data , unused_timeout_ms = None ):
117+ return self .stub_base .BulkWrite (data , unused_timeout_ms )
118+
119+ def BulkRead (self , length , timeout_ms = None ):
120+ return self .stub_base .BulkRead (length , timeout_ms )
121+
122+ def Timeout (self , timeout_ms ):
123+ return self .stub_base .Timeout (timeout_ms )
0 commit comments