forked from facebook/wdt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
WdtBase.h
270 lines (212 loc) · 7.83 KB
/
WdtBase.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
/**
* Copyright (c) 2014-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*/
#pragma once
#include "ErrorCodes.h"
#include "WdtOptions.h"
#include "Reporting.h"
#include "Throttler.h"
#include "Protocol.h"
#include <memory>
#include <string>
#include <vector>
#include <folly/RWSpinLock.h>
#include <unordered_map>
namespace facebook {
namespace wdt {
/// A (file name, file size) pair. A negative file size denotes the entire file.
using FileInfo = std::pair<std::string, int64_t>;
/**
 * Small helper for working with wdt urls.
 *
 * Two usage patterns are supported:
 *   1. Build the object from a url string, then read back the host name
 *      and the individual query ("get") parameters.
 *   2. Start from an empty object, fill in the host name and query
 *      parameters through the setters, then call generateUrl().
 *
 * Example url:
 *   wdt://localhost?dir=/tmp/wdt&ports=22356,22357
 *
 * NOTE(review): WDT_URL_PREFIX is a const non-static data member, so every
 * instance stores its own copy of the prefix and the implicitly-declared
 * copy/move assignment operators are deleted (only assignment from a
 * std::string is available). Confirm this is intended.
 */
class WdtUri {
 public:
  /// Creates an empty uri whose fields can be populated via the setters
  WdtUri() = default;

  /// Creates a uri from the given url string
  explicit WdtUri(const std::string& url);

  /// @return the host name of the url
  std::string getHostName() const;

  /// @return the query parameter stored under `key`
  std::string getQueryParam(const std::string& key) const;

  /// @return the full map of query parameters (name -> value)
  const std::unordered_map<std::string, std::string>& getQueryParams() const;

  /// Sets the host name used when generating a url
  void setHostName(const std::string& hostName);

  /// Stores `value` under `key` in the query parameter map
  void setQueryParam(const std::string& key, const std::string& value);

  /// @return a url produced by serializing the members of this object
  std::string generateUrl() const;

  /// Assigns from a url string, converting it to a wdt uri object
  WdtUri& operator=(const std::string& url);

  /// Clears the fields of the uri
  void clear();

  /// @return the error code recorded while parsing, if any
  ErrorCode getErrorCode() const;

 private:
  /**
   * Processes the url string. The returned code tells whether the url
   * could be processed successfully; values are populated on a best
   * effort basis.
   */
  ErrorCode process(const std::string& url);

  /**
   * Query ("get") parameters of the url: the key is the parameter name
   * and the mapped value is the parameter value.
   */
  std::unordered_map<std::string, std::string> queryParams_;

  /// Prefix (scheme part) of a wdt url
  const std::string WDT_URL_PREFIX{"wdt://"};

  /// Host name where the receiver is running (empty by default)
  std::string hostName_;

  /// Status of parsing the url; OK until an error is recorded
  ErrorCode errorCode_{OK};
};
/**
 * Plain request object used to create wdt objects.
 *
 * A request can be used to create a receiver and then its counterpart
 * sender, or the other way around.
 */
struct WdtTransferRequest {
  /// Id of the transfer; it has to be the same on both sender and receiver
  std::string transferId;

  /// Protocol version on sender and receiver
  int64_t protocolVersion{Protocol::protocol_version};

  /// Ports the receiver is listening on / the sender is sending to
  std::vector<int32_t> ports;

  /// Address the receiver bound the ports on / the sender is sending data to
  std::string hostName;

  /// Directory to write the data to / read the data from
  std::string directory;

  /// File list; only required for the sender
  std::vector<FileInfo> fileInfo;

  /// Any error associated with this transfer request upon processing
  ErrorCode errorCode{OK};

  /// Builds a request from an explicit list of ports
  explicit WdtTransferRequest(const std::vector<int32_t>& ports);

  /**
   * Builds a request from a start port and a port count. The port vector
   * is filled with the ports in [startPort, startPort + numPorts).
   */
  WdtTransferRequest(int startPort, int numPorts, const std::string& directory);

  /// Builds a request from a url string
  explicit WdtTransferRequest(const std::string& uriString);

  /**
   * Serializes this structure into a url string.
   * NOTE(review): genFull presumably selects whether every field is
   * emitted -- confirm against the implementation.
   */
  std::string generateUrl(bool genFull = false) const;

  /// @return the port list in stringified form
  std::string getSerializedPortsList() const;

  /// Equality comparison between two request objects
  bool operator==(const WdtTransferRequest& that) const;

  /// Names of the url get parameters used for the different fields
  static const std::string TRANSFER_ID_PARAM;
  static const std::string PROTOCOL_VERSION_PARAM;
  static const std::string DIRECTORY_PARAM;
  static const std::string PORTS_PARAM;
};
/**
 * Shared code/functionality between Receiver and Sender
 * TODO: a lot more code from sender/receiver should move here
 */
class WdtBase {
 public:
  /// Interface for external abort checks (pull mode)
  class IAbortChecker {
   public:
    /// @return true when the transfer should be aborted
    virtual bool shouldAbort() const = 0;
    virtual ~IAbortChecker() = default;
  };

  /// Constructor
  WdtBase();

  /**
   * Does the setup before start, returns the transfer request
   * that corresponds to the information relating to the sender
   * The transfer request has error code set should there be an error
   */
  virtual WdtTransferRequest init() = 0;

  /// Destructor
  virtual ~WdtBase();

  /// Transfer can be marked to abort and threads will eventually
  /// get aborted after this method has been called based on
  /// whether they are doing read/write on the socket and the timeout for the
  /// socket. Push mode for abort.
  void abort(const ErrorCode abortCode);

  /// Clears the abort flag
  void clearAbort();

  /**
   * Sets an extra external callback to check for abort.
   * This can be, for instance, a class implementing IAbortChecker with
   *   bool shouldAbort() const override { return atomicBool->load(); }
   * See wdtCmdLine.cpp for an example.
   */
  void setAbortChecker(const std::shared_ptr<IAbortChecker>& checker);

  /// Threads can call this method to find out
  /// whether the transfer has been marked for abort
  ErrorCode getCurAbortCode();

  /// Wdt objects can report progress. Setter for progress reporter
  /// defined in Reporting.h
  void setProgressReporter(std::unique_ptr<ProgressReporter>& progressReporter);

  /// Set throttler externally. Should be set before any transfer calls
  void setThrottler(std::shared_ptr<Throttler> throttler);

  /// Sets the transferId for this transfer
  void setTransferId(const std::string& transferId);

  /// Sets the protocol version for the transfer
  void setProtocolVersion(int64_t protocolVersion);

  /// Get the transfer id of the object
  std::string getTransferId();

  /// Finishes the wdt object and returns a report
  virtual std::unique_ptr<TransferReport> finish() = 0;

  /// Method to transfer the data. Doesn't block and
  /// returns with the status
  virtual ErrorCode transferAsync() = 0;

  /// Basic setup for throttler using options
  void configureThrottler();

  /// Utility to generate a random transfer id
  static std::string generateTransferId();

 protected:
  /// Global throttler across all threads
  std::shared_ptr<Throttler> throttler_;

  /// Holds the instance of the progress reporter default or customized
  std::unique_ptr<ProgressReporter> progressReporter_;

  /// Unique id for the transfer
  std::string transferId_;

  /// Protocol version to use, this is determined by negotiating protocol
  /// version with the other side.
  /// NOTE(review): stored as int while setProtocolVersion() takes int64_t
  /// and WdtTransferRequest::protocolVersion is int64_t -- confirm the
  /// narrowing is intended.
  int protocolVersion_{Protocol::protocol_version};

  /// Abort checker class passed to socket functions. It reports an abort
  /// whenever the owning WdtBase's current abort code is not OK.
  class AbortChecker : public IAbortChecker {
   public:
    /// Keeps a non-owning pointer to the WdtBase that is queried
    explicit AbortChecker(WdtBase* wdtBase) : wdtBase_(wdtBase) {
    }

    /// `override` makes the compiler verify this keeps matching the
    /// IAbortChecker signature
    bool shouldAbort() const override {
      return wdtBase_->getCurAbortCode() != OK;
    }

   private:
    /// Non-owning back pointer to the WdtBase being checked
    WdtBase* wdtBase_;
  };

  /// Abort checker passed to socket functions
  AbortChecker abortCheckerCallback_;

 private:
  /// Lock for the abort state (presumably guards abortCode_ -- confirm
  /// against the implementation)
  folly::RWSpinLock abortCodeLock_;

  /// Internal and default abort code
  ErrorCode abortCode_{OK};

  /// Additional external source of check for abort requested
  std::shared_ptr<IAbortChecker> abortChecker_{nullptr};
};
}
} // namespace facebook::wdt