From 155b33bc8448a10cfdaa3c30c0ccd466c1357d6e Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Wed, 14 Jun 2023 09:34:36 -0700 Subject: [PATCH] Surface websocket protocol errors to user applications Currently, if a websocket client sends bad data[1] to an application, the application just receives a generic error message with no indication of what went wrong. Also, the client just gets an aborted[2] websocket connection without any clue as to what they did wrong. This commit changes two things: first, the application now gets an error event describing what was wrong with the received data. This gives the application's owner some clue what's going wrong. Second, we now send a Close frame to the client with an appropriate error code, as we SHOULD do[3]. In an ideal world, this will let the client's owner figure out what they're doing wrong and fix it. [1] Invalid according to RFC 6455, for example sending a continuation frame without a preceding start frame or sending a frame with reserved bits (RSV1, RSV2, and RSV3; see https://datatracker.ietf.org/doc/html/rfc6455#section-5.2) set. [2] The underlying TCP connection is closed without first sending a websocket "Close" frame. [3] https://datatracker.ietf.org/doc/html/rfc6455#section-7.1.7 --- WORKSPACE | 10 +- src/workerd/api/web-socket.c++ | 63 +++++++++- src/workerd/api/web-socket.h | 33 +++++ src/workerd/io/io-context.c++ | 2 +- src/workerd/server/server-test.c++ | 187 ++++++++++++++++++++++++++++- src/workerd/server/server.c++ | 3 + 6 files changed, 285 insertions(+), 13 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 13a9c07fe704..d0061a2dab9a 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -24,10 +24,10 @@ bazel_skylib_workspace() http_archive( name = "capnp-cpp", - sha256 = "133d70c0f7482eb36cb6a4c662445bccb6219677aa61e15d22b9ace67bc36aa3", - strip_prefix = "capnproto-capnproto-b38fc11/c++", + sha256 = "b3d756251af1861681c29aa81ee074dac09122a4429f87950ef9235b17122389", + strip_prefix = "capnproto-capnproto-495cebe/c++", type = "tgz", - urls = ["https://github.com/capnproto/capnproto/tarball/b38fc110aed669be98195c957acf0d35fccb8252"], + urls = ["https://github.com/capnproto/capnproto/tarball/495cebe8673a0820e3202de66970059be70d4928"], ) http_archive( @@ -108,8 +108,8 @@ git_repository( # tcmalloc requires Abseil. # -# WARNING: This MUST appear before rules_fuzzing_depnedencies(), below. Otherwise, -# rules_fuzzing_depnedencies() will choose to pull in a different version of Abseil that is too +# WARNING: This MUST appear before rules_fuzzing_dependencies(), below. Otherwise, +# rules_fuzzing_dependencies() will choose to pull in a different version of Abseil that is too # old for tcmalloc. Absurdly, Bazel simply ignores later attempts to define the same repo name, # rather than erroring out. Thus this leads to confusing compiler errors in tcmalloc complaining # that ABSL_ATTRIBUTE_PURE_FUNCTION is not defined. diff --git a/src/workerd/api/web-socket.c++ b/src/workerd/api/web-socket.c++ index 08b0679e971e..2518dd50dcd9 100644 --- a/src/workerd/api/web-socket.c++ +++ b/src/workerd/api/web-socket.c++ @@ -12,6 +12,22 @@ namespace workerd::api { +kj::Maybe WebSocketProtocolError::fromException(const kj::Exception& ex) { + auto maybeContext = ex.getContext(); + while (maybeContext != nullptr) { + auto& context = KJ_ASSERT_NONNULL(maybeContext); + if (magicFileValue == context.file) { + return WebSocketProtocolError(context.line, kj::str(context.description)); + } + maybeContext = context.next; + } + return nullptr; +} + +void WebSocketProtocolError::encodeToException(kj::Exception &ex) && { + ex.wrapContext(magicFileValue.cStr(), code, kj::mv(description)); +} + kj::StringPtr KJ_STRINGIFY(const WebSocket::NativeState& state) { // TODO(someday) We might care more about this `OneOf` than its which, that probably means // returning a kj::String instead. @@ -25,6 +41,20 @@ kj::StringPtr KJ_STRINGIFY(const WebSocket::NativeState& state) { KJ_UNREACHABLE; } +kj::Exception WebSocketErrorHandler::handleWebSocketProtocolError(kj::WebSocket::ProtocolError protocolError) { + KJ_REQUIRE(protocolError.description.size() <= 122); + // We're going to put this as the reason in a WebSocket Close frame, so it has to fit. + // RFC6455 section 5.5 puts a maximum payload length of 125 on control frames, of which Close + // is one type. The payload also needs room for a status code (2 bytes) and maybe a null + // terminator (1 byte), leaving 122 bytes for the description. + kj::Exception ex = KJ_EXCEPTION(FAILED, "worker_do_not_log: WebSocket protocol error"); + api::WebSocketProtocolError wspe{static_cast(protocolError.statusCode), + kj::heapString(protocolError.description)}; + // The status codes are all 4-digit decimal numbers, so they'll easily fit in an int. + kj::mv(wspe).encodeToException(ex); + return ex; +} + IoOwn WebSocket::initNative(IoContext& ioContext, kj::WebSocket& ws) { auto nativeObj = kj::heap(); nativeObj->state.init(Accepted::Hibernatable{.ws = ws}, *nativeObj, ioContext); @@ -309,8 +339,16 @@ kj::Promise> WebSocket::couple(kj::Own other) auto& context = IoContext::current(); - auto upstream = other->pumpTo(*self); - auto downstream = self->pumpTo(*other); + auto upstream = other->pumpTo(*self).catch_([](kj::Exception&& e) { + if (WebSocketProtocolError::fromException(e) == nullptr) { + kj::throwFatalException(kj::mv(e)); + }; + }); + auto downstream = self->pumpTo(*other).catch_([](kj::Exception&& e) { + if (WebSocketProtocolError::fromException(e) == nullptr) { + kj::throwFatalException(kj::mv(e)); + }; + }); if (locality == LOCAL) { // We're terminating the WebSocket in this worker, so the upstream promise (which pumps @@ -349,7 +387,7 @@ void WebSocket::accept(jsg::Lock& js) { "Can't accept() WebSocket after enabling hibernation."); // Technically, this means it's okay to invoke `accept()` once a `new WebSocket()` resolves to // an established connection. This is probably okay? It might spare the worker devs a class of - // errors they do not care care about. + // errors they do not care about. return; } @@ -429,7 +467,7 @@ WebSocket::Accepted::~Accepted() noexcept(false) { void WebSocket::startReadLoop(jsg::Lock& js) { // If the kj::WebSocket happens to be an AbortableWebSocket (see util/abortable.h), then // calling readLoop here could throw synchronously if the canceler has already been tripped. - // Using kj::evalNow() here let's us capture that and handle correctly. + // Using kj::evalNow() here lets us capture that and handle correctly. // // We catch exceptions and return Maybe instead since we want to handle the exceptions // in awaitIo() below, but we don't want the KJ exception converted to JavaScript before we can @@ -464,12 +502,25 @@ void WebSocket::startReadLoop(jsg::Lock& js) { (jsg::Lock& js, kj::Maybe&& maybeError) mutable { auto& native = *farNative; KJ_IF_MAYBE(e, maybeError) { - if (!native.closedIncoming && e->getType() == kj::Exception::Type::DISCONNECTED) { + KJ_IF_MAYBE(wspe, WebSocketProtocolError::fromException(*e)) { + // The client sent us an invalid websocket message. + if (!native.closedOutgoing) { + // Send a close message to the client if we can. + native.closedIncoming = true; + close(js, wspe->getCode(), kj::str(wspe->getDescription())); + } + // Report to the application as an error event. + jsg::Value errorDescription{js.v8Isolate, js.wrapString(wspe->getDescription())}; + error = errorDescription.addRef(js); + dispatchEventImpl(js, jsg::alloc( + kj::str("WebSocket code ", wspe->getCode()), + kj::mv(errorDescription), js.v8Isolate)); + } else if (!native.closedIncoming && e->getType() == kj::Exception::Type::DISCONNECTED) { // Report premature disconnect or cancel as a close event. dispatchEventImpl(js, jsg::alloc( 1006, kj::str("WebSocket disconnected without sending Close frame."), false)); native.closedIncoming = true; - // If there are no further messages to send, so we can discard the underlying connection. + // If there are no further messages to send, we can discard the underlying connection. tryReleaseNative(js); } else { native.closedIncoming = true; diff --git a/src/workerd/api/web-socket.h b/src/workerd/api/web-socket.h index 189d9e04796c..dc1494ab67e4 100644 --- a/src/workerd/api/web-socket.h +++ b/src/workerd/api/web-socket.h @@ -146,6 +146,39 @@ class ErrorEvent: public Event { void visitForGc(jsg::GcVisitor& visitor); }; +class WebSocketProtocolError { +public: + // Exception-like thing for WebSocket protocol errors. Since kj::WebSocket indicates protocol + // errors by throwing an exception and since exceptions are caught by kj::Promise, we can't just + // throw WebSocketProtocolError. Instead, we smuggle the extra in the exception's context. + WebSocketProtocolError(int code, kj::String description) + : code(code), description(kj::mv(description)) {} + + static kj::Maybe fromException(const kj::Exception& ex); + // Generates a WebSocketProtocolError from the exception's context, but only if the context + // actually holds appropriate data, i.e. someone previously used encodeToException on it. + + int getCode() const { return code; } + kj::StringPtr getDescription() const { return description; } + + void encodeToException(kj::Exception& ex) &&; + // Adds a context entry to the exception containing this object's data. This is only useful if + // you're going to retrieve it with fromException later. + +private: + int code; + kj::String description; + static inline constexpr kj::StringPtr magicFileValue = "__WebSocketProtocolError_magicFileValue"_kj; + // Used as a sentinel in exception context frames; if frame.file == magicFileValue, then that + // frame contains data about a websocket protocol error. The exact value is unimportant; it just + // has to not look like a real file path. +}; + +class WebSocketErrorHandler : public kj::WebSocketErrorHandler { + kj::Exception handleWebSocketProtocolError(kj::WebSocket::ProtocolError protocolError) override; +}; +// Handler for WebSocket protocol errors. + // The forward declaration is necessary so we can make some // WebSocket methods accessible to WebSocketPair via friend declaration. class WebSocket; diff --git a/src/workerd/io/io-context.c++ b/src/workerd/io/io-context.c++ index b82f1c21ede4..b0d7b20b1842 100644 --- a/src/workerd/io/io-context.c++ +++ b/src/workerd/io/io-context.c++ @@ -372,7 +372,7 @@ void IoContext::logUncaughtExceptionAsync(UncaughtExceptionSource source, // do still want to syslog if relevant, but we can do that without a lock. if (!jsg::isTunneledException(exception.getDescription()) && !jsg::isDoNotLogException(exception.getDescription()) && - // TODO(soon): Figure out why client disconncects are getting logged here if we don't + // TODO(soon): Figure out why client disconnects are getting logged here if we don't // ignore DISCONNECTED. If we fix that, do we still want to filter these? exception.getType() != kj::Exception::Type::DISCONNECTED) { LOG_EXCEPTION("jsgInternalError", exception); diff --git a/src/workerd/server/server-test.c++ b/src/workerd/server/server-test.c++ index 8ac034f3e265..9ca0c50e14fb 100644 --- a/src/workerd/server/server-test.c++ +++ b/src/workerd/server/server-test.c++ @@ -3,6 +3,7 @@ // https://opensource.org/licenses/Apache-2.0 #include "server.h" +#include #include #include #include @@ -103,11 +104,29 @@ public: if (actual == nullptr) { KJ_FAIL_EXPECT_AT(loc, "message never received"); } else { - std::regex target(matcher.cStr()); + std::regex target(matcher.cStr(), std::regex::extended); KJ_EXPECT(std::regex_match(actual.cStr(), target), actual, matcher, loc); } } + void recvWebSocket(kj::StringPtr expected, kj::SourceLocation loc = {}) { + auto actual = readWebSocketMessage(); + KJ_EXPECT_AT(actual == expected, loc); + } + + void recvWebSocketRegex(kj::StringPtr matcher, kj::SourceLocation loc = {}) { + auto actual = readWebSocketMessage(); + std::regex target(matcher.cStr()); + KJ_EXPECT(std::regex_match(actual.cStr(), target), actual, matcher, loc); + } + + void recvWebSocketClose(int expectedCode) { + auto actual = readWebSocketMessage(); + KJ_EXPECT(actual.size() >= 2); + int gotCode = (static_cast(actual[0]) << 8) + static_cast(actual[1]); + KJ_EXPECT(gotCode == expectedCode); + } + void sendHttpGet(kj::StringPtr path, kj::SourceLocation loc = {}) { send(kj::str( "GET ", path, " HTTP/1.1\n" @@ -129,6 +148,25 @@ public: recvHttp200(expectedResponse, loc); } + void upgradeToWebSocket() { + send(R"( + GET / HTTP/1.1 + Host: somehost + Upgrade: websocket + Sec-WebSocket-Key: AAAAAAAAAAAAAAAAAAAAAA== + Sec-WebSocket-Version: 13 + + )"_blockquote); + + recv(R"( + HTTP/1.1 101 Switching Protocols + Connection: Upgrade + Upgrade: websocket + Sec-WebSocket-Accept: ICX+Yqv66kxgM0FcWaLWlFLwTAI= + + )"_blockquote); + } + bool isEof() { // Return true if the stream is at EOF. @@ -199,6 +237,69 @@ private: buffer.add('\0'); return kj::String(buffer.releaseAsArray()); } + + kj::String readWebSocketMessage(size_t maxMessageSize = 1 << 24) { + // Reads a single, non-fragmented WebSocket message. Returns just the payload. + kj::Vector header(256); + kj::Vector mask(4); + + KJ_IF_MAYBE(p, premature) { + header.add(*p); + premature = kj::Maybe(); + } + + tryRead(header, 2 - header.size(), "reading first two bytes of header"); + bool masked = header[1] & 0x80; + size_t sevenBitPayloadLength = header[1] & 0x7f; + size_t realPayloadLength = sevenBitPayloadLength; + + if (sevenBitPayloadLength == 126) { + tryRead(header, 2, "reading 16-bit payload length"); + realPayloadLength = (static_cast(header[2]) << 8) + static_cast(header[3]); + } else if (sevenBitPayloadLength == 127) { + tryRead(header, 8, "reading 64-bit payload length"); + realPayloadLength = (static_cast(header[2]) << 56) + + (static_cast(header[3]) << 48) + + (static_cast(header[4]) << 40) + + (static_cast(header[5]) << 32) + + (static_cast(header[6]) << 24) + + (static_cast(header[7]) << 16) + + (static_cast(header[8]) << 8) + + (static_cast(header[9])); + + KJ_REQUIRE(realPayloadLength <= maxMessageSize, + kj::str("Payload size too big (", realPayloadLength, " > ", maxMessageSize, ")")); + } + + if (masked) { + tryRead(mask, 4, "reading mask key"); + // Currently we assume the mask is always 0, so its application is a no-op, hence we don't + // bother. + } + kj::Vector payload(realPayloadLength + 1); + + tryRead(payload, realPayloadLength, "reading payload"); + payload.add('\0'); + return kj::String(payload.releaseAsArray()); + } + + template + void tryRead(kj::Vector& buffer, size_t bytesToRead, kj::StringPtr what) { + static_assert(sizeof(T) == 1, "not byte-sized"); + + size_t pos = buffer.size(); + size_t bytesRead = 0; + buffer.resize(buffer.size() + bytesToRead); + while (bytesRead < bytesToRead) { + auto promise = stream->tryRead(buffer.begin() + pos, 1, buffer.size() - pos); + KJ_REQUIRE(promise.poll(ws), kj::str("No data available while ", what)); + // A tryRead() of 1 byte didn't resolve, there must be no data to read. + + size_t n = promise.wait(ws); + KJ_REQUIRE(n > 0, kj::str("Not enough data while ", what)); + bytesRead += n; + } + } }; class TestServer final: private kj::Filesystem, private kj::EntropySource, private kj::Clock { @@ -495,6 +596,90 @@ KJ_TEST("Server: serve basic Service Worker") { Bad Request)"_blockquote); } +KJ_TEST("Server: serve simple WebSocket echo service") { + TestServer test(singleWorker(R"(( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `export default { + ` async fetch(request) { + ` const upgradeHeader = request.headers.get('Upgrade'); + ` if (!upgradeHeader || upgradeHeader !== 'websocket') { + ` return new Response('Expected Upgrade: websocket', {status: 400}); + ` } + ` const pair = new WebSocketPair(); + ` const client = pair[0], server = pair[1]; + ` server.accept(); + ` server.addEventListener('message', (event) => { + ` server.send(event.data); // echo it back + ` }); + ` return new Response(null, {status: 101, webSocket: client}); + ` } + `} + ) + ] + ))"_kj)); + + test.start(); + auto conn = test.connect("test-addr"); + conn.upgradeToWebSocket(); + conn.send("\x82\x05hello"); + conn.recvWebSocket("hello"); +} + +KJ_TEST("Server: test WebSocket errors: bad RSV bits") { + TestServer test(singleWorker(R"(( + compatibilityDate = "2022-08-17", + modules = [ + ( name = "main.js", + esModule = + `var errors = []; + `export default { + ` async fetch(request) { + ` const upgradeHeader = request.headers.get('Upgrade'); + ` if (!upgradeHeader || upgradeHeader !== 'websocket') { + ` return new Response("expected Upgrade: 'websocket'", {status: 400}); + ` } + ` const pair = new WebSocketPair(); + ` const client = pair[0], server = pair[1]; + ` server.accept(); + ` server.addEventListener('message', (event) => { + ` if (event.data === "getErrors") { + ` server.send(JSON.stringify(errors)) + ` } else if (event.data === "getErrorCount") { + ` server.send(errors.length.toString()) + ` } else { + ` server.send(event.data); // echo + ` } + ` }); + ` server.addEventListener('error', (event) => { + ` console.log(event.error); + ` errors.push(event); + ` }); + ` return new Response(null, {status: 101, webSocket: client}); + ` } + `} + ) + ] + ))"_kj)); + + test.start(); + + auto wsConn = test.connect("test-addr"); + // wsConn will send some bad WebSocket data. + wsConn.upgradeToWebSocket(); + wsConn.send("\xf1\x08hi there"); // bad frame: all RSV bits set + wsConn.recvWebSocketClose(1002); + + auto errWsConn = test.connect("test-addr"); + errWsConn.upgradeToWebSocket(); + errWsConn.send("\x81\x0dgetErrorCount"); + errWsConn.recvWebSocketRegex("1"); + errWsConn.send("\x81\x09getErrors"); + errWsConn.recvWebSocketRegex(".*RSV bits.*"); +} + KJ_TEST("Server: use service name as Service Worker origin") { TestServer test(singleWorker(R"(( compatibilityDate = "2022-08-17", diff --git a/src/workerd/server/server.c++ b/src/workerd/server/server.c++ index 0cf5ac9df048..78d2511645d3 100644 --- a/src/workerd/server/server.c++ +++ b/src/workerd/server/server.c++ @@ -25,6 +25,7 @@ #include #include #include +#include "src/workerd/api/web-socket.h" #include "workerd-api.h" #include @@ -2350,11 +2351,13 @@ private: : parent(parent), cfBlobJson(kj::mv(cfBlobJson)), listedHttp(parent.owner, parent.timer, parent.headerTable, *this, kj::HttpServerSettings { .errorHandler = *this, + .webSocketErrorHandler = this->webSocketErrorHandler, .webSocketCompressionMode = kj::HttpServerSettings::MANUAL_COMPRESSION }) {} HttpListener& parent; kj::Maybe cfBlobJson; + workerd::api::WebSocketErrorHandler webSocketErrorHandler; ListedHttpServer listedHttp; class ResponseWrapper final: public kj::HttpService::Response {