14
14
15
15
use futures:: io:: Error ;
16
16
use log:: * ;
17
- use sfw_provider_requests:: requests:: { ProviderRequest , PullRequest , RegisterRequest } ;
17
+ use sfw_provider_requests:: auth_token:: AuthToken ;
18
+ use sfw_provider_requests:: requests:: {
19
+ async_io:: TokioAsyncRequestWriter , ProviderRequest , PullRequest , RegisterRequest ,
20
+ } ;
18
21
use sfw_provider_requests:: responses:: {
19
- ProviderResponse , ProviderResponseError , PullResponse , RegisterResponse ,
22
+ async_io :: TokioAsyncResponseReader , ProviderResponse , ProviderResponseError ,
20
23
} ;
21
- use sfw_provider_requests:: AuthToken ;
22
24
use sphinx:: route:: DestinationAddressBytes ;
23
- use std:: net:: { Shutdown , SocketAddr } ;
24
- use std:: time:: Duration ;
25
+ use std:: net:: SocketAddr ;
25
26
use tokio:: prelude:: * ;
26
27
27
28
#[ derive( Debug ) ]
@@ -50,6 +51,12 @@ impl From<ProviderResponseError> for ProviderClientError {
50
51
ProviderResponseError :: MarshalError => InvalidRequestError ,
51
52
ProviderResponseError :: UnmarshalError => InvalidResponseError ,
52
53
ProviderResponseError :: UnmarshalErrorInvalidLength => InvalidResponseLengthError ,
54
+ ProviderResponseError :: UnmarshalErrorInvalidKind => InvalidResponseLengthError ,
55
+
56
+ ProviderResponseError :: TooLongResponseError => InvalidResponseError ,
57
+ ProviderResponseError :: TooShortResponseError => InvalidResponseError ,
58
+ ProviderResponseError :: IOError ( _) => NetworkError ,
59
+ ProviderResponseError :: RemoteConnectionClosed => NetworkError ,
53
60
}
54
61
}
55
62
}
@@ -58,72 +65,120 @@ pub struct ProviderClient {
58
65
provider_network_address : SocketAddr ,
59
66
our_address : DestinationAddressBytes ,
60
67
auth_token : Option < AuthToken > ,
68
+ connection : Option < tokio:: net:: TcpStream > ,
69
+ max_response_size : usize ,
61
70
}
62
71
63
72
impl ProviderClient {
64
73
pub fn new (
65
74
provider_network_address : SocketAddr ,
66
75
our_address : DestinationAddressBytes ,
67
76
auth_token : Option < AuthToken > ,
77
+ max_response_size : usize ,
68
78
) -> Self {
69
79
ProviderClient {
70
80
provider_network_address,
71
81
our_address,
72
82
auth_token,
83
+ max_response_size,
84
+ // establish connection when it's necessary (mainly to not break current code
85
+ // as then 'new' would need to be called within async context)
86
+ connection : None ,
87
+ }
88
+ }
89
+
90
+ async fn check_connection ( & mut self ) -> bool {
91
+ if self . connection . is_some ( ) {
92
+ true
93
+ } else {
94
+ // TODO: possibly also introduce timeouts here?
95
+ // However, at this point it's slightly less important as we are in full control
96
+ // of providers.
97
+ self . connection = tokio:: net:: TcpStream :: connect ( self . provider_network_address )
98
+ . await
99
+ . ok ( ) ;
100
+ self . connection . is_some ( )
73
101
}
74
102
}
75
103
76
104
pub fn update_token ( & mut self , auth_token : AuthToken ) {
77
105
self . auth_token = Some ( auth_token)
78
106
}
79
107
80
- pub async fn send_request ( & self , bytes : Vec < u8 > ) -> Result < Vec < u8 > , ProviderClientError > {
81
- let mut socket = tokio:: net:: TcpStream :: connect ( self . provider_network_address ) . await ?;
82
-
83
- socket. set_keepalive ( Some ( Duration :: from_secs ( 2 ) ) ) ?;
84
- socket. write_all ( & bytes[ ..] ) . await ?;
85
- if let Err ( e) = socket. shutdown ( Shutdown :: Write ) {
86
- warn ! ( "failed to close write part of the socket; err = {:?}" , e)
108
+ pub async fn send_request (
109
+ & mut self ,
110
+ request : ProviderRequest ,
111
+ ) -> Result < ProviderResponse , ProviderClientError > {
112
+ if !self . check_connection ( ) . await {
113
+ return Err ( ProviderClientError :: NetworkError ) ;
87
114
}
88
115
89
- let mut response = Vec :: new ( ) ;
90
- socket. read_to_end ( & mut response) . await ?;
91
- if let Err ( e) = socket. shutdown ( Shutdown :: Read ) {
92
- debug ! ( "failed to close read part of the socket; err = {:?}. It was probably already closed by the provider" , e)
116
+ let socket = self . connection . as_mut ( ) . unwrap ( ) ;
117
+ let ( mut socket_reader, mut socket_writer) = socket. split ( ) ;
118
+
119
+ // TODO: benchmark and determine if below should be done:
120
+ // let mut socket_writer = tokio::io::BufWriter::new(socket_writer);
121
+ // let mut socket_reader = tokio::io::BufReader::new(socket_reader);
122
+
123
+ let mut request_writer = TokioAsyncRequestWriter :: new ( & mut socket_writer) ;
124
+ let mut response_reader =
125
+ TokioAsyncResponseReader :: new ( & mut socket_reader, self . max_response_size ) ;
126
+
127
+ if let Err ( e) = request_writer. try_write_request ( request) . await {
128
+ debug ! ( "Failed to write the request - {:?}" , e) ;
129
+ return Err ( e. into ( ) ) ;
93
130
}
94
131
95
- Ok ( response )
132
+ Ok ( response_reader . try_read_response ( ) . await ? )
96
133
}
97
134
98
- pub async fn retrieve_messages ( & self ) -> Result < Vec < Vec < u8 > > , ProviderClientError > {
135
+ pub async fn retrieve_messages ( & mut self ) -> Result < Vec < Vec < u8 > > , ProviderClientError > {
99
136
let auth_token = match self . auth_token . as_ref ( ) {
100
137
Some ( token) => token. clone ( ) ,
101
138
None => {
102
139
return Err ( ProviderClientError :: EmptyAuthTokenError ) ;
103
140
}
104
141
} ;
105
142
106
- let pull_request = PullRequest :: new ( self . our_address . clone ( ) , auth_token) ;
107
- let bytes = pull_request. to_bytes ( ) ;
108
-
109
- let response = self . send_request ( bytes) . await ?;
110
-
111
- let parsed_response = PullResponse :: from_bytes ( & response) ?;
112
- Ok ( parsed_response. messages )
143
+ let pull_request =
144
+ ProviderRequest :: Pull ( PullRequest :: new ( self . our_address . clone ( ) , auth_token) ) ;
145
+ match self . send_request ( pull_request) . await ? {
146
+ ProviderResponse :: Pull ( res) => Ok ( res. extract_messages ( ) ) ,
147
+ ProviderResponse :: Failure ( res) => {
148
+ error ! (
149
+ "We failed to get our request processed - {:?}" ,
150
+ res. get_message( )
151
+ ) ;
152
+ Err ( ProviderClientError :: InvalidResponseError )
153
+ }
154
+ _ => {
155
+ error ! ( "Received response of unexpected type!" ) ;
156
+ Err ( ProviderClientError :: InvalidResponseError )
157
+ }
158
+ }
113
159
}
114
160
115
- pub async fn register ( & self ) -> Result < AuthToken , ProviderClientError > {
161
+ pub async fn register ( & mut self ) -> Result < AuthToken , ProviderClientError > {
116
162
if self . auth_token . is_some ( ) {
117
163
return Err ( ProviderClientError :: ClientAlreadyRegisteredError ) ;
118
164
}
119
165
120
- let register_request = RegisterRequest :: new ( self . our_address . clone ( ) ) ;
121
- let bytes = register_request. to_bytes ( ) ;
122
-
123
- let response = self . send_request ( bytes) . await ?;
124
- let parsed_response = RegisterResponse :: from_bytes ( & response) ?;
125
-
126
- Ok ( parsed_response. auth_token )
166
+ let register_request =
167
+ ProviderRequest :: Register ( RegisterRequest :: new ( self . our_address . clone ( ) ) ) ;
168
+ match self . send_request ( register_request) . await ? {
169
+ ProviderResponse :: Register ( res) => Ok ( res. get_token ( ) ) ,
170
+ ProviderResponse :: Failure ( res) => {
171
+ error ! (
172
+ "We failed to get our request processed - {:?}" ,
173
+ res. get_message( )
174
+ ) ;
175
+ Err ( ProviderClientError :: InvalidResponseError )
176
+ }
177
+ _ => {
178
+ error ! ( "Received response of unexpected type!" ) ;
179
+ Err ( ProviderClientError :: InvalidResponseError )
180
+ }
181
+ }
127
182
}
128
183
129
184
pub fn is_registered ( & self ) -> bool {
0 commit comments