diff --git a/utils/ws.go b/utils/ws.go index e7e243c..ab0d51e 100644 --- a/utils/ws.go +++ b/utils/ws.go @@ -3,7 +3,9 @@ package utils import ( "bufio" "context" + "crypto/sha1" "crypto/tls" + "encoding/base64" "errors" "fmt" "io" @@ -145,32 +147,28 @@ func ServerWebsocketUpgrade(w http.ResponseWriter, r *http.Request) (*WebsocketC var rw *bufio.ReadWriter var err error isRaw := IsV2rayHttpUpdate(r) - if isRaw { // v2ray-http-upgrade - w.Header().Set("Connection", "upgrade") - w.Header().Set("Upgrade", "websocket") - w.WriteHeader(http.StatusSwitchingProtocols) - if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher { - err = flusher.FlushError() - if err != nil { - return nil, fmt.Errorf("flush response: %w", err) - } - } - hijacker, canHijack := w.(http.Hijacker) - if !canHijack { - return nil, errors.New("invalid connection, maybe HTTP/2") - } - conn, rw, err = hijacker.Hijack() - if err != nil { - return nil, fmt.Errorf("hijack failed: %w", err) - } - } else { - conn, rw, _, err = ws.UpgradeHTTP(r, w) + w.Header().Set("Connection", "upgrade") + w.Header().Set("Upgrade", "websocket") + if !isRaw { + w.Header().Set("Sec-Websocket-Accept", getSecAccept(r.Header.Get("Sec-WebSocket-Key"))) + } + w.WriteHeader(http.StatusSwitchingProtocols) + if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher { + err = flusher.FlushError() if err != nil { - return nil, fmt.Errorf("ws upgrade failed: %w", err) + return nil, fmt.Errorf("flush response: %w", err) } } + hijacker, canHijack := w.(http.Hijacker) + if !canHijack { + return nil, errors.New("invalid connection, maybe HTTP/2") + } + conn, rw, err = hijacker.Hijack() + if err != nil { + return nil, fmt.Errorf("hijack failed: %w", err) + } - // gobwas/ws will flush rw.Writer, so we only need warp rw.Reader + // rw.Writer was flushed, so we only need warp rw.Reader conn = peek.WarpConnWithBioReader(conn, rw.Reader) return NewWebsocketConn(conn, ws.StateServerSide, isRaw), nil @@ -233,3 +231,13 @@ func ClientWebsocketDial(uri url.URL, cHeaders http.Header, netDial proxy.NetDia return NewWebsocketConn(conn, ws.StateClientSide, false), headers, nil } + +func getSecAccept(secKey string) string { + const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize) + p := make([]byte, nonceSize+len(magic)) + copy(p[:nonceSize], secKey) + copy(p[nonceSize:], magic) + sum := sha1.Sum(p) + return base64.StdEncoding.EncodeToString(sum[:]) +}