Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tweak the Printer code in runtime for smaller code #8023

Merged
merged 7 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cmake/HalideTestHelpers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ function(add_halide_test TARGET)
add_test(NAME ${TARGET}
COMMAND ${args_COMMAND} ${args_ARGS}
WORKING_DIRECTORY "${args_WORKING_DIRECTORY}")
set_halide_compiler_warnings(${TARGET})
if (NOT Halide_TARGET MATCHES "wasm")
set_halide_compiler_warnings(${TARGET})
endif ()

# We can't add Halide::TerminateHandler here, because it requires Halide::Error
# and friends to be present in the final linkage, but some callers of add_halide_test()
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/d3d12compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ static constexpr uint64_t trace_buf_size = 4096;
WEAK char trace_buf[trace_buf_size] = {};
WEAK int trace_indent = 0;

struct trace : public BasicPrinter<trace_buf_size> {
struct trace : public PrinterBase {
ScopedMutexLock lock;

explicit trace(void *user_context = nullptr)
: BasicPrinter<trace_buf_size>(user_context, trace_buf),
: PrinterBase(user_context, trace_buf, trace_buf_size),
lock(&trace_lock) {
for (int i = 0; i < trace_indent; i++) {
*this << " ";
Expand Down
20 changes: 9 additions & 11 deletions src/runtime/posix_error_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,16 @@ extern "C" {
extern void abort();

WEAK void halide_default_error(void *user_context, const char *msg) {
char buf[4096];
char *dst = halide_string_to_string(buf, buf + 4094, "Error: ");
dst = halide_string_to_string(dst, dst + 4094, msg);
// We still have one character free. Add a newline if there
// isn't one already.
if (dst[-1] != '\n') {
dst[0] = '\n';
dst[1] = 0;
dst += 1;
// Can't use StackBasicPrinter here because it limits size to 256
constexpr int buf_size = 4096;
char buf[buf_size];
PrinterBase dst(user_context, buf, buf_size);
dst << "Error: " << msg;
const char *d = dst.str();
if (d && *d && d[strlen(d) - 1] != '\n') {
dst << "\n";
}
(void)halide_msan_annotate_memory_is_initialized(user_context, buf, dst - buf + 1);
halide_print(user_context, buf);
halide_print(user_context, dst.str());
abort();
}
}
Expand Down
226 changes: 111 additions & 115 deletions src/runtime/printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,179 +41,174 @@ constexpr uint64_t default_printer_buffer_length = 1024;
// Then remember the print only happens when the debug object leaves
// scope, which may print at a confusing time.

namespace {
template<PrinterType printer_type, uint64_t buffer_length = default_printer_buffer_length>
class Printer {
char *buf, *dst, *end;
void *user_context;
bool own_mem;
class PrinterBase {
protected:
char *dst;
char *const end;
char *const start;
void *const user_context;

// Report a failed buffer allocation: passes a fixed message to
// halide_error() together with this printer's user_context.
NEVER_INLINE void allocation_error() const {
halide_error(user_context, "Printer buffer allocation failed.\n");
}

public:
explicit Printer(void *ctx, char *mem = nullptr)
: user_context(ctx), own_mem(mem == nullptr) {
if (mem != nullptr) {
buf = mem;
} else {
buf = (char *)malloc(buffer_length);
// This class will stream text into the range [start, start + size - 1].
// It does *not* assume any ownership of the memory; it assumes
// the memory will remain valid for its lifespan, and doesn't
// attempt to free any allocations. It also doesn't do any sanity
// checking of the pointers, so if you pass in a null or bogus value,
// it will attempt to use it.
//
// Parameters:
//   user_context_  stored unchanged in `user_context`.
//   start_         first byte of the caller-provided buffer; also the
//                  initial write position (`dst`).
//   size_          buffer size in bytes; `end` is set to the buffer's
//                  last byte, start_ + size_ - 1.
NEVER_INLINE PrinterBase(void *user_context_, char *start_, uint64_t size_)
: dst(start_),
// (If start is null, set end = start to ensure no writes are done)
end(start_ ? start_ + size_ - 1 : start_),
start(start_),
user_context(user_context_) {
if (end > start) {
// Write a 0 into the final byte of the buffer so that the
// buffer contents always end in a null terminator.
*end = 0;
}
}

// Return a pointer to the beginning of the buffer (`start`).
// Before returning, the dst - start + 1 bytes beginning at `start` are
// annotated as initialized for MSAN; the annotation's return value is
// deliberately discarded.
NEVER_INLINE const char *str() {
(void)halide_msan_annotate_memory_is_initialized(user_context, start, dst - start + 1);
return start;
}

// Return the number of bytes between `start` and the current write
// position `dst`. Debug-asserts that dst has not moved before start.
uint64_t size() const {
halide_debug_assert(user_context, dst >= start);
return (uint64_t)(dst - start);
}

// Return the distance in bytes from `start` to `end`.
// Debug-asserts that end is not before start.
uint64_t capacity() const {
halide_debug_assert(user_context, end >= start);
return (uint64_t)(end - start);
}

dst = buf;
NEVER_INLINE void clear() {
dst = start;
if (dst) {
end = buf + (buffer_length - 1);
*end = 0;
} else {
// Pointers equal ensures no writes to buffer via formatting code
end = dst;
dst[0] = 0;
}
}

#if HALIDE_RUNTIME_PRINTER_LOG_THREADID
uint64_t tid;
pthread_threadid_np(0, &tid);
*this << "(TID:" << tid << ")";
#endif
NEVER_INLINE void erase(int n) {
if (dst) {
dst -= n;
if (dst < start) {
dst = start;
}
dst[0] = 0;
}
}

// Not movable, not copyable
Printer(const Printer &copy) = delete;
Printer &operator=(const Printer &) = delete;
Printer(Printer &&) = delete;
Printer &operator=(Printer &&) = delete;
// Wrapper holding the raw 16-bit pattern of a float16 value.
// NOTE(review): presumably exists so operator<< can tell float16 bits
// apart from a plain uint16_t - confirm against callers.
struct Float16Bits {
uint16_t bits;
};

Printer &operator<<(const char *arg) {
// These are NEVER_INLINE because Clang will aggressively inline
// all of them, but the code size of calling out-of-line here is slightly
// smaller, and we ~always prefer smaller code size when using Printer
// in the runtime (it's a modest but nonzero difference).
NEVER_INLINE PrinterBase &operator<<(const char *arg) {
dst = halide_string_to_string(dst, end, arg);
return *this;
}

Printer &operator<<(int64_t arg) {
NEVER_INLINE PrinterBase &operator<<(int64_t arg) {
dst = halide_int64_to_string(dst, end, arg, 1);
return *this;
}

Printer &operator<<(int32_t arg) {
NEVER_INLINE PrinterBase &operator<<(int32_t arg) {
dst = halide_int64_to_string(dst, end, arg, 1);
return *this;
}

Printer &operator<<(uint64_t arg) {
NEVER_INLINE PrinterBase &operator<<(uint64_t arg) {
dst = halide_uint64_to_string(dst, end, arg, 1);
return *this;
}

Printer &operator<<(uint32_t arg) {
NEVER_INLINE PrinterBase &operator<<(uint32_t arg) {
dst = halide_uint64_to_string(dst, end, arg, 1);
return *this;
}

Printer &operator<<(double arg) {
NEVER_INLINE PrinterBase &operator<<(double arg) {
dst = halide_double_to_string(dst, end, arg, 1);
return *this;
}

Printer &operator<<(float arg) {
NEVER_INLINE PrinterBase &operator<<(float arg) {
dst = halide_double_to_string(dst, end, arg, 0);
return *this;
}

Printer &operator<<(const void *arg) {
dst = halide_pointer_to_string(dst, end, arg);
NEVER_INLINE PrinterBase &operator<<(Float16Bits arg) {
double value = halide_float16_bits_to_double(arg.bits);
dst = halide_double_to_string(dst, end, value, 1);
return *this;
}

Printer &write_float16_from_bits(const uint16_t arg) {
double value = halide_float16_bits_to_double(arg);
dst = halide_double_to_string(dst, end, value, 1);
NEVER_INLINE PrinterBase &operator<<(const void *arg) {
dst = halide_pointer_to_string(dst, end, arg);
return *this;
}

Printer &operator<<(const halide_type_t &t) {
NEVER_INLINE PrinterBase &operator<<(const halide_type_t &t) {
dst = halide_type_to_string(dst, end, &t);
return *this;
}

Printer &operator<<(const halide_buffer_t &buf) {
NEVER_INLINE PrinterBase &operator<<(const halide_buffer_t &buf) {
dst = halide_buffer_to_string(dst, end, &buf);
return *this;
}

template<typename T>
void append(const T &value) {
*this << value;
}

template<typename First, typename Second, typename... Rest>
void append(const First &first, const Second &second, const Rest &...rest) {
append<First>(first);
append<Second, Rest...>(second, rest...);
}

// Use it like a stringstream.
const char *str() {
if (buf) {
if (printer_type == StringStreamPrinterType) {
msan_annotate_is_initialized();
}
return buf;
} else {
return allocation_error();
}
}

// Clear it. Useful for reusing a stringstream.
void clear() {
dst = buf;
if (dst) {
dst[0] = 0;
}
// Stream every argument into this printer via operator<<, in the order
// given (a fold expression over the comma operator).
template<typename... Args>
void append(const Args &...args) {
((*this << args), ...);
}

// Returns the number of characters in the buffer
uint64_t size() const {
return (uint64_t)(dst - buf);
}
// Not movable, not copyable
PrinterBase(const PrinterBase &copy) = delete;
PrinterBase &operator=(const PrinterBase &) = delete;
PrinterBase(PrinterBase &&) = delete;
PrinterBase &operator=(PrinterBase &&) = delete;
};

uint64_t capacity() const {
return buffer_length;
}
namespace {

// Delete the last N characters
void erase(int n) {
if (dst) {
dst -= n;
if (dst < buf) {
dst = buf;
}
dst[0] = 0;
template<PrinterType printer_type, uint64_t buffer_length = default_printer_buffer_length>
class HeapPrinter : public PrinterBase {
public:
NEVER_INLINE explicit HeapPrinter(void *user_context)
: PrinterBase(user_context, (char *)malloc(buffer_length), buffer_length) {
if (!start) {
allocation_error();
}
}

const char *allocation_error() {
return "Printer buffer allocation failed.\n";
}

void msan_annotate_is_initialized() {
(void)halide_msan_annotate_memory_is_initialized(user_context, buf, dst - buf + 1);
#if HALIDE_RUNTIME_PRINTER_LOG_THREADID
uint64_t tid;
pthread_threadid_np(0, &tid);
*this << "(TID:" << tid << ")";
#endif
}

~Printer() {
if (!buf) {
halide_error(user_context, allocation_error());
NEVER_INLINE ~HeapPrinter() {
if (printer_type == ErrorPrinterType) {
halide_error(user_context, str());
} else if (printer_type == BasicPrinterType) {
halide_print(user_context, str());
} else {
msan_annotate_is_initialized();
if (printer_type == ErrorPrinterType) {
halide_error(user_context, buf);
} else if (printer_type == BasicPrinterType) {
halide_print(user_context, buf);
} else {
// It's a stringstream. Do nothing.
}
// It's a stringstream. Do nothing.
}

if (own_mem) {
free(buf);
}
free(start);
}
};

// A class that supports << with all the same types as Printer, but
// does nothing and should compile to a no-op.
class SinkPrinter {
Expand All @@ -227,13 +222,13 @@ ALWAYS_INLINE SinkPrinter operator<<(const SinkPrinter &s, T) {
}

template<uint64_t buffer_length = default_printer_buffer_length>
using BasicPrinter = Printer<BasicPrinterType, buffer_length>;
using BasicPrinter = HeapPrinter<BasicPrinterType, buffer_length>;

template<uint64_t buffer_length = default_printer_buffer_length>
using ErrorPrinter = Printer<ErrorPrinterType, buffer_length>;
using ErrorPrinter = HeapPrinter<ErrorPrinterType, buffer_length>;

template<uint64_t buffer_length = default_printer_buffer_length>
using StringStreamPrinter = Printer<StringStreamPrinterType, buffer_length>;
using StringStreamPrinter = HeapPrinter<StringStreamPrinterType, buffer_length>;

using print = BasicPrinter<>;
using error = ErrorPrinter<>;
Expand All @@ -244,17 +239,16 @@ using debug = BasicPrinter<>;
#else
using debug = SinkPrinter;
#endif
} // namespace

// A Printer that automatically reserves stack space for the printer buffer, rather than malloc.
// Note that this requires an explicit buffer_length, and it (generally) should be <= 256.
template<PrinterType printer_type, uint64_t buffer_length>
class StackPrinter : public Printer<printer_type, buffer_length> {
class StackPrinter : public PrinterBase {
char scratch[buffer_length];

public:
explicit StackPrinter(void *ctx)
: Printer<printer_type, buffer_length>(ctx, scratch) {
explicit StackPrinter(void *user_context)
: PrinterBase(user_context, scratch, buffer_length) {
static_assert(buffer_length <= 256, "StackPrinter is meant only for small buffer sizes; you are probably making a mistake.");
}
};
Expand All @@ -268,6 +262,8 @@ using StackErrorPrinter = StackPrinter<ErrorPrinterType, buffer_length>;
template<uint64_t buffer_length = default_printer_buffer_length>
using StackStringStreamPrinter = StackPrinter<StringStreamPrinterType, buffer_length>;

} // namespace

} // namespace Internal
} // namespace Runtime
} // namespace Halide
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/runtime_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ typedef ptrdiff_t ssize_t;

#define WEAK __attribute__((weak))

// NEVER_INLINE marks a function with the `noinline` attribute.
#define NEVER_INLINE __attribute__((noinline))

// Note that ALWAYS_INLINE should *always* also be `inline`.
#define ALWAYS_INLINE inline __attribute__((always_inline))

Expand Down
Loading