diff --git a/src/WebSockets.jl b/src/WebSockets.jl index 9019f2052..2f29b1df1 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -10,7 +10,6 @@ using HTTP: header import ..@debug, ..DEBUG_LEVEL, ..@require, ..precondition_error - const WS_FINAL = 0x80 const WS_CONTINUATION = 0x00 const WS_TEXT = 0x01 @@ -18,7 +17,6 @@ const WS_BINARY = 0x02 const WS_CLOSE = 0x08 const WS_PING = 0x09 const WS_PONG = 0x0A - const WS_MASK = 0x80 @@ -37,29 +35,33 @@ struct WebSocketHeader end +@enum ReadyState CONNECTED=0x1 CLOSING=0x2 CLOSED=0x3 + + mutable struct WebSocket{T <: IO} <: IO io::T frame_type::UInt8 server::Bool rxpayload::Vector{UInt8} txpayload::Vector{UInt8} - txclosed::Bool - rxclosed::Bool + state::ReadyState end + function WebSocket(io::T; server=false, binary=false) where T <: IO WebSocket{T}(io, binary ? WS_BINARY : WS_TEXT, server, - UInt8[], UInt8[], false, false) + UInt8[], UInt8[], CONNECTED) end # Handshake + is_websocket_upgrade(r::HTTP.Message) = (r isa HTTP.Request && r.method == "GET" || r.status == 101) && HTTP.hasheader(r, "Connection", "upgrade") && - HTTP.hasheader(r, "Upgrade", "webscoket") + HTTP.hasheader(r, "Upgrade", "websocket") function check_upgrade(http) @@ -168,15 +170,6 @@ function Base.write(ws::WebSocket, x1, x2, xs...) end -function IOExtras.closewrite(ws::WebSocket) - @require !ws.txclosed - opcode = WS_FINAL | WS_CLOSE - @debug 1 "WebSocket ⬅️ $(WebSocketHeader(opcode, 0x00))" - write(ws.io, opcode, 0x00) - ws.txclosed = true -end - - wslength(l) = l < 0x7E ? (UInt8(l), UInt8[]) : l <= 0xFFFF ? (0x7E, reinterpret(UInt8, [hton(UInt16(l))])) : (0x7F, reinterpret(UInt8, [hton(UInt64(l))])) @@ -184,10 +177,11 @@ wslength(l) = l < 0x7E ? (UInt8(l), UInt8[]) : wswrite(ws::WebSocket, x) = wswrite(ws, WS_FINAL | ws.frame_type, x) + wswrite(ws::WebSocket, opcode::UInt8, x) = wswrite(ws, opcode, Vector{UInt8}(x)) -function wswrite(ws::WebSocket, opcode::UInt8, bytes::Vector{UInt8}) +function wswrite(ws::WebSocket, opcode::UInt8, bytes::Vector{UInt8}) n = length(bytes) len, extended_len = wslength(n) if ws.server @@ -218,23 +212,27 @@ end function Base.close(ws::WebSocket) - if !ws.txclosed - closewrite(ws) - end - while !ws.rxclosed + @require ws.state == CONNECTED + opcode = WS_FINAL | WS_CLOSE + @debug 1 "WebSocket ⬅️ $(WebSocketHeader(opcode, 0x00))" + write(ws.io, opcode, 0x00) + ws.state = CLOSING + while !eof(ws) && ws.state == CLOSING readframe(ws) end end -Base.isopen(ws::WebSocket) = !ws.rxclosed +Base.isopen(ws::WebSocket) = (ws.state == CONNECTED) && isopen(ws.io) # Receiving Frames + Base.eof(ws::WebSocket) = eof(ws.io) + Base.readavailable(ws::WebSocket) = collect(readframe(ws)) @@ -265,7 +263,6 @@ function readframe(ws::WebSocket) end if h.opcode == WS_CLOSE - ws.rxclosed = true if h.length >= 2 status = UInt16(ws.rxpayload[1]) << 8 | ws.rxpayload[2] if status != 1000 @@ -273,6 +270,11 @@ function readframe(ws::WebSocket) throw(WebSocketError(status, message)) end end + if ws.state == CONNECTED + close(ws) + end + ws.state = CLOSED + close(ws.io) return UInt8[] elseif h.opcode == WS_PING write(ws.io, [WS_PONG, 0x00]) @@ -294,6 +296,7 @@ function WebSocketHeader(bytes...) return readheader(io) end + function Base.show(io::IO, h::WebSocketHeader) print(io, "WebSocketHeader(", h.opcode == WS_CONTINUATION ? "CONTINUATION" : diff --git a/test/WebSockets.jl b/test/WebSockets.jl index 303e9310e..9cfc47ef6 100644 --- a/test/WebSockets.jl +++ b/test/WebSockets.jl @@ -5,24 +5,39 @@ using HTTP.IOExtras @testset "WebSockets" begin for s in ["ws", "wss"] + info("Testing $(s)...") + HTTP.WebSockets.open("$s://echo.websocket.org") do ws + write(ws, HTTP.bytes("Foo")) + @test !eof(ws) + @test String(readavailable(ws)) == "Foo" - HTTP.WebSockets.open("$s://echo.websocket.org") do io - write(io, HTTP.bytes("Foo")) - @test !eof(io) - @test String(readavailable(io)) == "Foo" - - write(io, HTTP.bytes("Hello")) - write(io, " There") - write(io, " World", "!") - closewrite(io) + close(ws) + end +end - buf = IOBuffer() - write(buf, io) - @test String(take!(buf)) == "Hello There World!" - close(io) +p = 8000 +@async HTTP.listen(ip"127.0.0.1",p) do http + if HTTP.WebSockets.is_websocket_upgrade(http.message) + HTTP.WebSockets.upgrade(http) do ws + data = "" + while !eof(ws); + data = String(readavailable(ws)) + write(ws,data) + end + end end +end + +sleep(2) + +info("Testing local server...") +HTTP.WebSockets.open("ws://127.0.0.1:$(p)") do ws + write(ws, HTTP.bytes("Foo")) + @test !eof(ws) + @test String(readavailable(ws)) == "Foo" + close(ws) end end # testset