Skip to content

Commit

Permalink
Add binary safety check to server
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Nov 30, 2024
1 parent 36696f3 commit 6d89f8f
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 6 deletions.
32 changes: 26 additions & 6 deletions llamafile/server/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llamafile/server/server.h"
#include "llamafile/server/time.h"
#include "llamafile/server/tokenbucket.h"
#include "llamafile/server/utils.h"
#include "llamafile/server/worker.h"
#include "llamafile/string.h"
#include "llamafile/threadlocal.h"
Expand Down Expand Up @@ -478,7 +479,7 @@ Client::send_response_chunk(const std::string_view content)

// perform send system call
ssize_t sent;
if ((sent = writev(fd_, iov, 3)) != bytes) {
if ((sent = safe_writev(fd_, iov, 3)) != bytes) {
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
SLOG("writev failed %m");
close_connection_ = true;
Expand All @@ -504,15 +505,34 @@ Client::send_response_finish()
return send("0\r\n\r\n");
}

// writes raw data to socket
// writes any old data to socket
//
// unlike send() this won't fail if binary content is detected.
bool
Client::send_binary(const void* p, size_t n)
{
ssize_t sent;
if ((sent = write(fd_, p, n)) != n) {
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
SLOG("write failed %m");
close_connection_ = true;
return false;
}
return true;
}

// writes non-binary data to socket
//
// consider using the higher level methods like send_error(),
// send_response(), send_response_start(), etc.
bool
Client::send(const std::string_view s)
{
iovec iov[1];
ssize_t sent;
if ((sent = write(fd_, s.data(), s.size())) != s.size()) {
iov[0].iov_base = (void*)s.data();
iov[0].iov_len = s.size();
if ((sent = safe_writev(fd_, iov, 1)) != s.size()) {
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
SLOG("write failed %m");
close_connection_ = true;
Expand All @@ -521,7 +541,7 @@ Client::send(const std::string_view s)
return true;
}

// writes two pieces of raw data to socket in single system call
// writes two pieces of non-binary data to socket in single system call
//
// consider using the higher level methods like send_error(),
// send_response(), send_response_start(), etc.
Expand All @@ -534,7 +554,7 @@ Client::send2(const std::string_view s1, const std::string_view s2)
iov[0].iov_len = s1.size();
iov[1].iov_base = (void*)s2.data();
iov[1].iov_len = s2.size();
if ((sent = writev(fd_, iov, 2)) != s1.size() + s2.size()) {
if ((sent = safe_writev(fd_, iov, 2)) != s1.size() + s2.size()) {
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
SLOG("writev failed %m");
close_connection_ = true;
Expand Down Expand Up @@ -755,7 +775,7 @@ Client::dispatcher()
close_connection_ = true;
return false;
}
if (!send(std::string_view(buf, chunk))) {
if (!send_binary(buf, chunk)) {
close_connection_ = true;
return false;
}
Expand Down
1 change: 1 addition & 0 deletions llamafile/server/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ struct Client
bool read_content() __wur;
bool send_continue() __wur;
bool send(const std::string_view) __wur;
bool send_binary(const void*, size_t) __wur;
void defer_cleanup(void (*)(void*), void*);
bool send_error(int, const char* = nullptr);
char* append_http_response_message(char*, int, const char* = nullptr);
Expand Down
4 changes: 4 additions & 0 deletions llamafile/server/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <__fwd/string_view.h>
#include <__fwd/vector.h>
#include <optional>
#include <sys/uio.h>

struct llama_model;

Expand All @@ -28,6 +29,9 @@ namespace server {

class Atom;

ssize_t
safe_writev(int, const iovec*, int);

bool
atob(std::string_view, bool);

Expand Down
46 changes: 46 additions & 0 deletions llamafile/server/writev.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llamafile/server/log.h"
#include "utils.h"
#include <cerrno>
#include <string_view>

namespace lf {
namespace server {

ssize_t
safe_writev(int fd, const iovec* iov, int iovcnt)
{
for (int i = 0; i < iovcnt; ++i) {
bool has_binary = false;
size_t n = iov[i].iov_len;
unsigned char* p = (unsigned char*)iov[i].iov_base;
for (size_t j = 0; j < n; ++j) {
has_binary |= p[j] < 7;
}
if (has_binary) {
SLOG("safe_writev() detected binary server is compromised");
errno = EINVAL;
return -1;
}
}
return writev(fd, iov, iovcnt);
}

} // namespace server
} // namespace lf

0 comments on commit 6d89f8f

Please sign in to comment.