From d9c2bed4c7714f14ff5b53b1877a26c7f89a1da5 Mon Sep 17 00:00:00 2001 From: Loic Pottier Date: Fri, 3 Nov 2023 15:32:13 -0700 Subject: [PATCH] Added documentation and new AMSMsgHeader class + moved from memcpy to ResourceManager::copy Signed-off-by: Loic Pottier --- src/AMSlib/wf/basedb.hpp | 326 ++++++++++++++++++++++++++------------- 1 file changed, 215 insertions(+), 111 deletions(-) diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp index b8ea8ac8..4fe10878 100644 --- a/src/AMSlib/wf/basedb.hpp +++ b/src/AMSlib/wf/basedb.hpp @@ -641,9 +641,103 @@ class RedisDB : public BaseDB #ifdef __ENABLE_RMQ__ -/** @brief Structure that represents a received RabbitMQ message */ -typedef std::tuple - inbound_msg; +/** + * @brief AMS represents the header as follows: + * The header is 12 bytes long: + * - 1 byte is the size of the header (here 12). Limit max: 255 + * - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 + * - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 + * - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 + * - 2 bytes are the input dimension. Limit max: 65535 + * - 2 bytes are the output dimension. Limit max: 65535 + * + * |__Header__|__Datatype__|___Rank___|__#elems__|___InDim___|___OutDim___|...real data...| + * ^ ^ ^ ^ ^ ^ ^ ^ + * | Byte 1 | Byte 2 | Byte 3-4 | Byte 4-8 | Byte 8-10 | Byte 10-12 | Byte 12-X | + * + * where X = datatype * num_element * (InDim + OutDim). Total message size is 12+X. + * + * The data starts at byte 12, ends at byte X. + * The data is structured as pairs of input/outputs. Let K be the total number of + * elements, then we have K pairs of inputs/outputs (either float or double): + * + * |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + */ +template +struct AMSMsgHeader { + /** @brief Heaader size (bytes) */ + uint8_t hsize; + /** @brief Data type size (bytes) */ + uint8_t dtype; + /** @brief MPI rank */ + uint16_t mpi_rank; + /** @brief Number of elements */ + uint32_t num_elem; + /** @brief Inputs dimension */ + uint16_t in_dim; + /** @brief Outputs dimension */ + uint16_t out_dim; + + /** + * @brief Constructor for AMSMsgHeader + * @param[in] mpi_rank MPI rank + * @param[in] num_elem Number of elements (input/outputs) + * @param[in] in_dim Inputs dimension + * @param[in] out_dim Outputs dimension + */ + AMSMsgHeader(size_t mpi_rank, size_t num_elem, size_t in_dim, size_t out_dim) + : hsize(static_cast(AMSMsgHeader::size())), + dtype(static_cast(sizeof(TypeValue))), + mpi_rank(static_cast(mpi_rank)), + num_elem(static_cast(num_elem)), + in_dim(static_cast(in_dim)), + out_dim(static_cast(out_dim)) + { + } + + /** + * @brief Return the size of a header in the AMS protocol. + * @return The size of a message header in AMS (in byte) + */ + static size_t size() + { + return sizeof(hsize) + sizeof(dtype) + sizeof(mpi_rank) + sizeof(num_elem) + + sizeof(in_dim) + sizeof(out_dim); + } + + /** + * @brief Fill an empty buffer with a valid header. + * @param[in] data_blob The buffer to fill + * @return The number of bytes in the header or 0 if error + */ + size_t encode(uint8_t* data_blob) + { + if (!data_blob) return 0; + + size_t current_offset = 0; + // Header size (should be 1 bytes) + data_blob[current_offset] = hsize; + current_offset += sizeof(hsize); + // Data type (should be 1 bytes) + data_blob[current_offset] = dtype; + current_offset += sizeof(dtype); + // MPI rank (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(mpi_rank), sizeof(mpi_rank)); + current_offset += sizeof(mpi_rank); + // Num elem (should be 4 bytes) + std::memcpy(data_blob + current_offset, &(num_elem), sizeof(num_elem)); + current_offset += sizeof(num_elem); + // Input dim (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(in_dim), sizeof(in_dim)); + current_offset += sizeof(in_dim); + // Output dim (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(out_dim), sizeof(out_dim)); + current_offset += sizeof(out_dim); + + return current_offset; + } +}; + /** * @brief Class representing a message for the AMSLib @@ -658,10 +752,8 @@ class AMSMessage int _rank; /** @brief The data represented as a binary blob */ uint8_t* _data; - /** @brief The size of the header in bytes */ - size_t _header_size; /** @brief The total size of the binary blob in bytes */ - size_t _data_size; + size_t _total_size; /** @brief The number of input/output pairs */ size_t _num_elements; /** @brief The dimensions of inputs */ @@ -671,10 +763,10 @@ class AMSMessage public: /** - * @brief Constructor - * @param[in] num_elements Number of elements - * @param[in] inputs Inputs - * @param[in] outputs Outputs + * @brief Constructor + * @param[in] num_elements Number of elements + * @param[in] inputs Inputs + * @param[in] outputs Outputs */ AMSMessage(int id, size_t num_elements, @@ -685,35 +777,21 @@ class AMSMessage _input_dim(inputs.size()), _output_dim(outputs.size()), _data(nullptr), - _header_size(0), - _data_size(0) + _total_size(0) { #ifdef __ENABLE_MPI__ MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); #endif - _header_size = - 2 * sizeof(uint8_t) + sizeof(uint32_t) + 3 * sizeof(uint16_t); - _data_size = _header_size + (_num_elements * (_input_dim + _output_dim)) * - sizeof(TypeValue); - _data = ams::ResourceManager::allocate(_data_size, + _total_size = AMSMsgHeader::size() + getDataSize(); + _data = ams::ResourceManager::allocate(_total_size, AMSResourceType::HOST); - size_t current_offset = encode_header( - _data, _header_size, _num_elements, _input_dim, _output_dim); - // Creating the body part of the messages - for (size_t i = 0; i < _num_elements; i++) { - for (size_t j = 0; j < _input_dim; j++) { - std::memcpy(_data + current_offset, &(inputs[j][i]), sizeof(TypeValue)); - // We move on to the next TypeValue - current_offset += sizeof(TypeValue); - } - for (size_t j = 0; j < _output_dim; j++) { - std::memcpy(_data + current_offset, - &(outputs[j][i]), - sizeof(TypeValue)); - current_offset += sizeof(TypeValue); - } - } + AMSMsgHeader header(_rank, + _num_elements, + _input_dim, + _output_dim); + size_t current_offset = header.encode(_data); + current_offset = encode_data(_data, current_offset, inputs, outputs); } AMSMessage(const AMSMessage&) = delete; @@ -728,8 +806,7 @@ class AMSMessage _num_elements = other._num_elements; _input_dim = other._input_dim; _output_dim = other._output_dim; - _header_size = other._header_size; - _data_size = other._data_size; + _total_size = other._total_size; _data = other._data; other._data = nullptr; } @@ -737,78 +814,42 @@ class AMSMessage } /** - * @brief Return the header size in bytes - * @return The number of bytes in the header or -1 if error + * @brief Fill a buffer with a data section starting at a given position. + * @param[in] data_blob The buffer to fill + * @param[in] offset Position where to start writing in the buffer + * @param[in] inputs Inputs + * @param[in] outputs Outputs + * @return The number of bytes in the message or 0 if error */ - static size_t hsize() + size_t encode_data(uint8_t* data_blob, + size_t offset, + const std::vector& inputs, + const std::vector& outputs) { - return sizeof(uint8_t) + sizeof(uint8_t) + sizeof(uint16_t) + - sizeof(uint32_t) + sizeof(uint16_t) + sizeof(uint16_t); + if (!data_blob) return 0; + // Creating the body part of the messages + // TODO: slow method (one copy per element!), improve by reducing number of copies + for (size_t i = 0; i < _num_elements; i++) { + for (size_t j = 0; j < _input_dim; j++) { + ams::ResourceManager::copy(&(inputs[j][i]), reinterpret_cast(_data + offset), sizeof(TypeValue)); + offset += sizeof(TypeValue); + } + for (size_t j = 0; j < _output_dim; j++) { + ams::ResourceManager::copy(&(outputs[j][i]), reinterpret_cast(_data + offset), sizeof(TypeValue)); + offset += sizeof(TypeValue); + } + } + + return offset; } /** - * @brief Fill an empty buffer with a valid header. We encode the header as follow: - * The header is 12 bytes long and structured as follows: - * - 1 byte is the size of the header (here 12). Limit max: 255 - * - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 - * - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 - * - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 - * - 2 bytes are the input dimension. Limit max: 65535 - * - 2 bytes are the output dimension. Limit max: 65535 - * - * |__Header__|__Datatype__|___Rank___|__#elems__|___InDim___|___OutDim___|...real data...| - * ^ ^ ^ ^ ^ ^ ^ ^ - * | Byte 1 | Byte 2 | Byte 3-4 | Byte 4-8 | Byte 8-10 | Byte 10-12 | Byte 12-X | - * - * where X = datatype * num_element * (InDim + OutDim). Total message size is 12+X. - * - * The data starts at byte 12, ends at byte X. - * The data is structured as pairs of input/outputs. Let K be the total number of - * elements, then we have K pairs of inputs/outputs (either float or double): - * - * |__Header_(12B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| - * - * @param[in] data_blob The buffer to fill - * @param[in] header_size Size of the header in bytes - * @param[in] num_elements Total number of elements - * @param[in] input_dim The dimensions of inputs - * @param[in] output_dim The dimensions of outputs - * @return The number of bytes in the header or -1 if error + * @brief Return the size of the data portion for that message + * @return Size in bytes of the data portion */ - size_t encode_header(uint8_t* data_blob, - const size_t header_size, - const size_t num_elements, - const size_t input_dim, - const size_t output_dim) + size_t getDataSize() { - if (!data_blob) return 0; - - uint16_t mpirank = static_cast(_rank); - uint32_t num_elem = static_cast(num_elements); - uint16_t num_in = static_cast(input_dim); - uint16_t num_out = static_cast(output_dim); - - size_t current_offset = 0; - // Header size (should be 1 bytes) - data_blob[current_offset] = static_cast(header_size); - current_offset += sizeof(uint8_t); - // Data type (should be 1 bytes) - data_blob[current_offset] = static_cast(sizeof(TypeValue)); - current_offset += sizeof(uint8_t); - // MPI rank (should be 2 bytes) - std::memcpy(data_blob + current_offset, &(mpirank), sizeof(uint16_t)); - current_offset += sizeof(uint16_t); - // Num elem (should be 4 bytes) - std::memcpy(data_blob + current_offset, &(num_elem), sizeof(uint32_t)); - current_offset += sizeof(uint32_t); - // Input dim (should be 2 bytes) - std::memcpy(data_blob + current_offset, &(num_in), sizeof(uint16_t)); - current_offset += sizeof(uint16_t); - // Output dim (should be 2 bytes) - std::memcpy(data_blob + current_offset, &(num_out), sizeof(uint16_t)); - current_offset += sizeof(uint16_t); - - return current_offset; + return (_num_elements * (_input_dim + _output_dim)) * sizeof(TypeValue); } /** @@ -827,13 +868,7 @@ class AMSMessage * @brief Return the size in bytes of the underlying binary blob * @return Byte size of data pointer */ - size_t size() const { return _data_size; } - - /** - * @brief Return the size in bytes of the header - * @return Byte size of message header - */ - size_t header_size() const { return _header_size; } + size_t size() const { return _total_size; } ~AMSMessage() { @@ -842,6 +877,18 @@ class AMSMessage } }; // class AMSMessage +/** @brief Structure that represents a received RabbitMQ message. + * - The first field is the message content (body) + * - The second field is the RMQ exchange from which the message + * has been received + * - The third field is the routing key + * - The fourth is the delivery tag (ID of the message) + * - The fifth field is a boolean that indicates if that message + * has been redelivered by RMQ. + */ +typedef std::tuple + inbound_msg; + /** * @brief Specific handler for RabbitMQ connections based on libevent. */ @@ -908,6 +955,9 @@ class RMQConsumerHandler : public AMQP::LibEventHandler #else int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); #endif + // TODO: with openssl 3.0 + // SSL_set_options(ssl, SSL_OP_IGNORE_UNEXPECTED_EOF); + if (ret != 1) { std::string error("openssl: error loading ca-chain (" + _cacert + ") + from ["); @@ -1103,7 +1153,7 @@ class RMQConsumer std::shared_ptr _loop; /** @brief The handler which contains various callbacks for the sender */ std::shared_ptr> _handler; - /** @brief Queue that contains all the messages received on receiver queue */ + /** @brief Queue that contains all the messages received on receiver queue (messages can be popped in) */ std::vector _messages; public: @@ -1537,16 +1587,18 @@ class RMQPublisher * @brief Wait that the connection is ready (blocking call) * @return True if the publisher is ready to publish */ - void wait_ready(int ms = 300, int timeout_sec = 10) + void wait_ready(int ms = 500, int timeout_sec = 30) { // We wait for the connection to be ready int total_time = 0; while (!ready_publish()) { std::this_thread::sleep_for(std::chrono::milliseconds(ms)); - DBG(RMQPublisher, "Waiting for connection to be ready...") + DBG(RMQPublisher, + "[rank=%d] Waiting for connection to be ready...", + _rank) total_time += ms; if (total_time > timeout_sec * 1000) { - DBG(RMQPublisher, "Connection timeout") + DBG(RMQPublisher, "[rank=%d] Connection timeout", _rank) break; // TODO: if connection is not working -> revert to classic file DB. } @@ -1579,6 +1631,58 @@ class RMQPublisher /** * @brief Class that manages a RabbitMQ broker and handles connection, event * loop and set up various handlers. + * @details This class manages a specific type of database backend in AMSLib. + * Instead of writing inputs/outputs directly to files (CSV or HDF5), we + * send these elements (a collection of inputs and their corresponding outputs) + * to a service called RabbitMQ which is listening on a given IP and port. + * + * This class requires a RabbitMQ server to be running somewhere, + * the credentials of that server should be formatted as a JSON file as follows: + * + * { + * "rabbitmq-name": "testamsrabbitmq", + * "rabbitmq-password": "XXX", + * "rabbitmq-user": "pottier1", + * "rabbitmq-vhost": "ams", + * "service-port": 31495, + * "service-host": "url.czapps.llnl.gov", + * "rabbitmq-cert": "tls-cert.crt", + * "rabbitmq-inbound-queue": "test4", + * "rabbitmq-outbound-queue": "test3" + * } + * + * The TLS certificate must be generated by the user and the absolute paths are preferred. + * A TLS certificate can be generated with the following command: + * + * openssl s_client \ + * -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null \ + * 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > tls.crt + * + * RabbitMQDB creates two RabbitMQ connections per MPI rank, one for publishing data to RMQ and one for consuming data. + * Each connection has its own I/O loop (based on Libevent) running in a dedicated thread because I/O loop are blocking. + * Therefore, we have two threads per MPI rank. + * + * 1. Publishing data: When the store() method is being called, it triggers a series of calls: + * + * RabbitMQDB::store() -> RMQPublisher::publish() -> RMQPublisherHandler::publish() + * + * Here, RMQPublisherHandler::publish() has access to internal RabbitMQ channels and can publish the message + * on the outbound queue (rabbitmq-outbound-queue in the JSON configuration). + * Note that storing data like that is much faster than with writing files as a call to RabbitMQDB::store() + * is virtually free, the actual data sending part is taking place in a thread and does not slow down + * the main simulation (MPI). + * + * 2. Consuming data: The inbound queue (rabbitmq-inbound-queue in the JSON configuration) is the queue for incoming data. The + * RMQConsumer is listening on that queue for messages. In the AMSLib approach, that queue is used to communicate + * updates to rank regarding the ML surrrogate model. RMQConsumer will automatically populate a std::vector with all + * messages received since the execution of AMS started. + * + * Glabal note: Most calls dealing with RabbitMQ (to establish a RMQ connection, opening a channel, publish data etc) + * are asynchronous callbacks (similar to asyncio in Python or future in C++). + * So, the simulation can have already started and the RMQ connection might not be valid which is why most part + * of the code that deals with RMQ are wrapped into callbacks that will get run only in case of success. + * For example, we create a channel only if the underlying connection has been succesfuly initiated + * (see RMQPublisherHandler::onReady()). */ template class RabbitMQDB final : public BaseDB