diff --git a/main/compilation/expression.cpp b/main/compilation/expression.cpp index cabe9db3..2c87417d 100644 --- a/main/compilation/expression.cpp +++ b/main/compilation/expression.cpp @@ -1,4 +1,5 @@ #include "expression.h" +#include "../utils/string_utils.h" #include Expression::Expression(const Type type) : type(type) { @@ -28,18 +29,18 @@ bool Expression::is_numbery() const { return this->type == number || this->type == integer || this->type == boolean; } -int Expression::print_to_buffer(char *buffer) const { +int Expression::print_to_buffer(char *buffer, size_t buffer_len) const { switch (this->type) { case boolean: - return sprintf(buffer, "%s", this->evaluate_boolean() ? "true" : "false"); + return csprintf(buffer, buffer_len, "%s", this->evaluate_boolean() ? "true" : "false"); case integer: - return sprintf(buffer, "%lld", this->evaluate_integer()); + return csprintf(buffer, buffer_len, "%lld", this->evaluate_integer()); case number: - return sprintf(buffer, "%f", this->evaluate_number()); + return csprintf(buffer, buffer_len, "%f", this->evaluate_number()); case string: - return sprintf(buffer, "\"%s\"", this->evaluate_string().c_str()); + return csprintf(buffer, buffer_len, "\"%s\"", this->evaluate_string().c_str()); case identifier: - return sprintf(buffer, "%s", this->evaluate_identifier().c_str()); + return csprintf(buffer, buffer_len, "%s", this->evaluate_identifier().c_str()); default: throw std::runtime_error("expression has an invalid datatype"); } diff --git a/main/compilation/expression.h b/main/compilation/expression.h index 4f5a03fe..02b06009 100644 --- a/main/compilation/expression.h +++ b/main/compilation/expression.h @@ -10,7 +10,7 @@ class Expression; using Expression_ptr = std::shared_ptr; using ConstExpression_ptr = std::shared_ptr; -int write_arguments_to_buffer(const std::vector arguments, char *buffer); +int write_arguments_to_buffer(const std::vector arguments, char *buffer, size_t buffer_len); class Expression { protected: @@ -25,5 +25,5 @@ class Expression { virtual std::string evaluate_string() const; virtual std::string evaluate_identifier() const; bool is_numbery() const; - int print_to_buffer(char *buffer) const; + int print_to_buffer(char *buffer, size_t buffer_len) const; }; diff --git a/main/compilation/expressions.cpp b/main/compilation/expressions.cpp index 98ea2e2a..7dc106fb 100644 --- a/main/compilation/expressions.cpp +++ b/main/compilation/expressions.cpp @@ -1,15 +1,16 @@ #include "expressions.h" #include "../modules/module.h" +#include "../utils/string_utils.h" #include "math.h" #include -int write_arguments_to_buffer(const std::vector arguments, char *buffer) { +int write_arguments_to_buffer(const std::vector arguments, char *buffer, size_t buffer_len) { int pos = 0; for (auto const &argument : arguments) { if (argument != arguments[0]) { - pos += std::sprintf(&buffer[pos], ", "); + pos += csprintf(&buffer[pos], buffer_len - pos, ", "); } - pos += argument->print_to_buffer(&buffer[pos]); + pos += argument->print_to_buffer(&buffer[pos], buffer_len - pos); } return pos; } diff --git a/main/compilation/variable.cpp b/main/compilation/variable.cpp index 2e18d0de..7d23d64c 100644 --- a/main/compilation/variable.cpp +++ b/main/compilation/variable.cpp @@ -1,4 +1,5 @@ #include "variable.h" +#include "../utils/string_utils.h" #include "expression.h" #include @@ -21,18 +22,18 @@ void Variable::assign(const ConstExpression_ptr expression) { } } -int Variable::print_to_buffer(char *const buffer) const { +int Variable::print_to_buffer(char *const buffer, size_t buffer_len) const { switch (this->type) { case boolean: - return sprintf(buffer, "%s", this->boolean_value ? "true" : "false"); + return csprintf(buffer, buffer_len, "%s", this->boolean_value ? "true" : "false"); case integer: - return sprintf(buffer, "%lld", this->integer_value); + return csprintf(buffer, buffer_len, "%lld", this->integer_value); case number: - return sprintf(buffer, "%f", this->number_value); + return csprintf(buffer, buffer_len, "%f", this->number_value); case string: - return sprintf(buffer, "\"%s\"", this->string_value.c_str()); + return csprintf(buffer, buffer_len, "\"%s\"", this->string_value.c_str()); case identifier: - return sprintf(buffer, "%s", this->identifier_value.c_str()); + return csprintf(buffer, buffer_len, "%s", this->identifier_value.c_str()); default: throw std::runtime_error("variable has an invalid datatype"); } diff --git a/main/compilation/variable.h b/main/compilation/variable.h index 46148e1e..1c533c05 100644 --- a/main/compilation/variable.h +++ b/main/compilation/variable.h @@ -23,7 +23,7 @@ class Variable { Variable(const Type type); void assign(const ConstExpression_ptr expression); - int print_to_buffer(char *const buffer) const; + int print_to_buffer(char *const buffer, size_t buffer_len) const; }; class BooleanVariable : public Variable { diff --git a/main/main.cpp b/main/main.cpp index 6fab2ad0..51e9d49e 100644 --- a/main/main.cpp +++ b/main/main.cpp @@ -189,7 +189,7 @@ void process_tree(owl_tree *const tree, bool from_expander) { } else if (!statement.expression.empty) { const ConstExpression_ptr expression = compile_expression(statement.expression); static char buffer[256]; - expression->print_to_buffer(buffer); + expression->print_to_buffer(buffer, sizeof(buffer)); echo(buffer); } else if (!statement.constructor.empty) { const struct parsed_constructor constructor = parsed_constructor_get(statement.constructor); diff --git a/main/modules/can.cpp b/main/modules/can.cpp index ebea6bde..05542903 100644 --- a/main/modules/can.cpp +++ b/main/modules/can.cpp @@ -1,4 +1,5 @@ #include "can.h" +#include "../utils/string_utils.h" #include "../utils/uart.h" #include "driver/twai.h" #include @@ -97,10 +98,10 @@ bool Can::receive() { if (this->output_on) { static char buffer[256]; - int pos = std::sprintf(buffer, "%s %03lx", this->name.c_str(), message.identifier); + int pos = csprintf(buffer, sizeof(buffer), "%s %03lx", this->name.c_str(), message.identifier); if (!(message.flags & TWAI_MSG_FLAG_RTR)) { for (int i = 0; i < message.data_length_code; ++i) { - pos += std::sprintf(&buffer[pos], ",%02x", message.data[i]); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, ",%02x", message.data[i]); } } echo(buffer); diff --git a/main/modules/core.cpp b/main/modules/core.cpp index be30afa8..858d31c2 100644 --- a/main/modules/core.cpp +++ b/main/modules/core.cpp @@ -48,9 +48,9 @@ void Core::call(const std::string method_name, const std::vectorprint_to_buffer(&buffer[pos]); + pos += argument->print_to_buffer(&buffer[pos], sizeof(buffer) - pos); } echo(buffer); } else if (method_name == "output") { @@ -173,22 +173,22 @@ std::string Core::get_output() const { int pos = 0; for (auto const &element : this->output_list) { if (pos > 0) { - pos += sprintf(&output_buffer[pos], " "); + pos += csprintf(&output_buffer[pos], sizeof(output_buffer) - pos, " "); } const Variable_ptr variable = element.module ? element.module->get_property(element.property_name) : Global::get_variable(element.property_name); switch (variable->type) { case boolean: - pos += sprintf(&output_buffer[pos], "%s", variable->boolean_value ? "true" : "false"); + pos += csprintf(&output_buffer[pos], sizeof(output_buffer) - pos, "%s", variable->boolean_value ? "true" : "false"); break; case integer: - pos += sprintf(&output_buffer[pos], "%lld", variable->integer_value); + pos += csprintf(&output_buffer[pos], sizeof(output_buffer) - pos, "%lld", variable->integer_value); break; case number: - pos += sprintf(&output_buffer[pos], "%.*f", element.precision, variable->number_value); + pos += csprintf(&output_buffer[pos], sizeof(output_buffer) - pos, "%.*f", element.precision, variable->number_value); break; case string: - pos += sprintf(&output_buffer[pos], "\"%s\"", variable->string_value.c_str()); + pos += csprintf(&output_buffer[pos], sizeof(output_buffer) - pos, "\"%s\"", variable->string_value.c_str()); break; default: throw std::runtime_error("invalid type"); diff --git a/main/modules/expander.cpp b/main/modules/expander.cpp index db2368d7..4f785ea5 100644 --- a/main/modules/expander.cpp +++ b/main/modules/expander.cpp @@ -2,6 +2,7 @@ #include "storage.h" #include "utils/serial-replicator.h" +#include "utils/string_utils.h" #include "utils/timing.h" #include "utils/uart.h" #include @@ -36,7 +37,7 @@ Expander::Expander(const std::string name, break; } if (serial->available()) { - len = serial->read_line(buffer); + len = serial->read_line(buffer, sizeof(buffer)); strip(buffer, len); echo("%s: %s", name.c_str(), buffer); } @@ -46,7 +47,7 @@ Expander::Expander(const std::string name, void Expander::step() { static char buffer[1024]; while (this->serial->has_buffered_lines()) { - int len = this->serial->read_line(buffer); + int len = this->serial->read_line(buffer, sizeof(buffer)); check(buffer, len); this->last_message_millis = millis(); if (buffer[0] == '!' && buffer[1] == '!') { @@ -94,9 +95,9 @@ void Expander::call(const std::string method_name, const std::vectorserial->write_checked_line(buffer, pos); } } diff --git a/main/modules/input.cpp b/main/modules/input.cpp index 93031ffa..608c4ba5 100644 --- a/main/modules/input.cpp +++ b/main/modules/input.cpp @@ -1,4 +1,5 @@ #include "input.h" +#include "../utils/string_utils.h" #include "../utils/uart.h" #include #include @@ -38,7 +39,7 @@ void Input::call(const std::string method_name, const std::vectorget_level()); + csprintf(buffer, sizeof(buffer), "%d", this->get_level()); return buffer; } diff --git a/main/modules/module.cpp b/main/modules/module.cpp index bf0c68f5..44d10051 100644 --- a/main/modules/module.cpp +++ b/main/modules/module.cpp @@ -1,5 +1,6 @@ #include "module.h" #include "../global.h" +#include "../utils/string_utils.h" #include "../utils/uart.h" #include "analog.h" #include "bluetooth.h" @@ -358,11 +359,11 @@ void Module::step() { } if (this->broadcast && !this->properties.empty()) { static char buffer[1024]; - int pos = sprintf(buffer, "!!"); + int pos = csprintf(buffer, sizeof(buffer), "!!"); for (auto const &[property_name, property] : this->properties) { - pos += sprintf(&buffer[pos], "%s.%s=", this->name.c_str(), property_name.c_str()); - pos += property->print_to_buffer(&buffer[pos]); - pos += sprintf(&buffer[pos], ";"); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, "%s.%s=", this->name.c_str(), property_name.c_str()); + pos += property->print_to_buffer(&buffer[pos], sizeof(buffer) - pos); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, ";"); } echo(buffer); } diff --git a/main/modules/proxy.cpp b/main/modules/proxy.cpp index baff1900..1de7c43a 100644 --- a/main/modules/proxy.cpp +++ b/main/modules/proxy.cpp @@ -1,4 +1,5 @@ #include "proxy.h" +#include "../utils/string_utils.h" #include "driver/uart.h" #include @@ -9,19 +10,19 @@ Proxy::Proxy(const std::string name, const std::vector arguments) : Module(proxy, name), expander(expander) { static char buffer[256]; - int pos = std::sprintf(buffer, "%s = %s(", name.c_str(), module_type.c_str()); - pos += write_arguments_to_buffer(arguments, &buffer[pos]); - pos += std::sprintf(&buffer[pos], "); "); - pos += std::sprintf(&buffer[pos], "%s.broadcast()", name.c_str()); + int pos = csprintf(buffer, sizeof(buffer), "%s = %s(", name.c_str(), module_type.c_str()); + pos += write_arguments_to_buffer(arguments, &buffer[pos], sizeof(buffer) - pos); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, "); "); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, "%s.broadcast()", name.c_str()); expander->serial->write_checked_line(buffer, pos); } void Proxy::call(const std::string method_name, const std::vector arguments) { static char buffer[256]; - int pos = std::sprintf(buffer, "%s.%s(", this->name.c_str(), method_name.c_str()); - pos += write_arguments_to_buffer(arguments, &buffer[pos]); - pos += std::sprintf(&buffer[pos], ")"); + int pos = csprintf(buffer, sizeof(buffer), "%s.%s(", this->name.c_str(), method_name.c_str()); + pos += write_arguments_to_buffer(arguments, &buffer[pos], sizeof(buffer) - pos); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, ")"); this->expander->serial->write_checked_line(buffer, pos); } @@ -31,8 +32,8 @@ void Proxy::write_property(const std::string property_name, const ConstExpressio } if (!from_expander) { static char buffer[256]; - int pos = std::sprintf(buffer, "%s.%s = ", this->name.c_str(), property_name.c_str()); - pos += expression->print_to_buffer(&buffer[pos]); + int pos = csprintf(buffer, sizeof(buffer), "%s.%s = ", this->name.c_str(), property_name.c_str()); + pos += expression->print_to_buffer(&buffer[pos], sizeof(buffer) - pos); this->expander->serial->write_checked_line(buffer, pos); } Module::get_property(property_name)->assign(expression); diff --git a/main/modules/serial.cpp b/main/modules/serial.cpp index 11e7573f..7d1ff072 100644 --- a/main/modules/serial.cpp +++ b/main/modules/serial.cpp @@ -1,4 +1,5 @@ #include "serial.h" +#include "utils/string_utils.h" #include "utils/uart.h" #include #include @@ -55,7 +56,7 @@ void Serial::write_checked_line(const char *message, const int length) const { int start = 0; for (unsigned int i = 0; i < length + 1; ++i) { if (i >= length || message[i] == '\n') { - sprintf(checksum_buffer, "@%02x\n", checksum); + csprintf(checksum_buffer, sizeof(checksum_buffer), "@%02x\n", checksum); uart_write_bytes(this->uart_num, &message[start], i - start); uart_write_bytes(this->uart_num, checksum_buffer, 4); start = i + 1; @@ -89,8 +90,20 @@ int Serial::read(uint32_t timeout) const { return length > 0 ? data : -1; } -int Serial::read_line(char *buffer) const { +int Serial::read_line(char *buffer, size_t buffer_len) const { int pos = uart_pattern_pop_pos(this->uart_num); + if (pos >= buffer_len) { + if (this->available() < pos) { + uart_flush_input(this->uart_num); + while (uart_pattern_pop_pos(this->uart_num) > 0) + ; + throw std::runtime_error("buffer too small, but cannot discard line. flushed serial."); + } + + for (int i = 0; i < pos; i++) + this->read(); + throw std::runtime_error("buffer too small. discarded line."); + } return pos >= 0 ? uart_read_bytes(this->uart_num, (uint8_t *)buffer, pos + 1, 0) : 0; } @@ -109,7 +122,7 @@ std::string Serial::get_output() const { int byte; int pos = 0; while ((byte = this->read()) >= 0) { - pos += std::sprintf(&buffer[pos], pos == 0 ? "%02x" : " %02x", byte); + pos += csprintf(&buffer[pos], sizeof(buffer) - pos, pos == 0 ? "%02x" : " %02x", byte); } return buffer; } diff --git a/main/modules/serial.h b/main/modules/serial.h index fca40658..39b5e489 100644 --- a/main/modules/serial.h +++ b/main/modules/serial.h @@ -24,7 +24,7 @@ class Serial : public Module { int available() const; bool has_buffered_lines() const; int read(const uint32_t timeout = 0) const; - int read_line(char *buffer) const; + int read_line(char *buffer, size_t buffer_len) const; size_t write(const uint8_t byte) const; void write_checked_line(const char *message, const int length) const; void flush() const; diff --git a/main/utils/string_utils.cpp b/main/utils/string_utils.cpp index 63cbbacf..4ef552b5 100644 --- a/main/utils/string_utils.cpp +++ b/main/utils/string_utils.cpp @@ -1,4 +1,6 @@ #include "string_utils.h" +#include +#include #include std::string cut_first_word(std::string &msg, const char delimiter) { @@ -10,4 +12,19 @@ std::string cut_first_word(std::string &msg, const char delimiter) { bool starts_with(const std::string haystack, const std::string needle) { return haystack.substr(0, needle.length()) == needle; +} + +int csprintf(char *buffer, size_t buffer_len, const char *format, ...) { + va_list args; + + va_start(args, format); + const int num_chars = std::vsnprintf(buffer, buffer_len, format, args); + va_end(args); + + if (num_chars < 0) + throw std::runtime_error("encoding error"); + if (num_chars > buffer_len - 1) + throw std::runtime_error("buffer too small"); + + return num_chars; } \ No newline at end of file diff --git a/main/utils/string_utils.h b/main/utils/string_utils.h index d26801d5..16e71db9 100644 --- a/main/utils/string_utils.h +++ b/main/utils/string_utils.h @@ -4,4 +4,6 @@ std::string cut_first_word(std::string &msg, char delimiter = ' '); -bool starts_with(const std::string haystack, const std::string needle); \ No newline at end of file +bool starts_with(const std::string haystack, const std::string needle); + +int csprintf(char *buffer, size_t buffer_len, const char *format, ...); \ No newline at end of file diff --git a/main/utils/uart.cpp b/main/utils/uart.cpp index 74dd0511..c471b8c7 100644 --- a/main/utils/uart.cpp +++ b/main/utils/uart.cpp @@ -6,11 +6,11 @@ void echo(const char *format, ...) { static char buffer[1024]; - int pos = 0; va_list args; va_start(args, format); - pos += std::vsnprintf(&buffer[pos], sizeof buffer - pos - 1, format, args); + const int num_chars = std::vsnprintf(buffer, sizeof(buffer) - 1, format, args); + int pos = std::min(num_chars, static_cast(sizeof(buffer) - 2)); va_end(args); pos += std::sprintf(&buffer[pos], "\n"); @@ -30,10 +30,11 @@ void echo(const char *format, ...) { } int strip(char *buffer, int len) { - while (buffer[len - 1] == ' ' || - buffer[len - 1] == '\t' || - buffer[len - 1] == '\r' || - buffer[len - 1] == '\n') { + while (len > 0 && + (buffer[len - 1] == ' ' || + buffer[len - 1] == '\t' || + buffer[len - 1] == '\r' || + buffer[len - 1] == '\n')) { len--; } buffer[len] = 0;