Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCP+SNI support arbitrary large Client Hello #423

Merged
merged 3 commits into from
Feb 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions proxy/tcp/sni_proxy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tcp

import (
"bufio"
"io"
"log"
"net"
Expand Down Expand Up @@ -41,20 +42,40 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error {
p.Conn.Inc(1)
}

// capture client hello
data := make([]byte, 1024)
n, err := in.Read(data)
tlsReader := bufio.NewReader(in)
tlsHeaders, err := tlsReader.Peek(9)
if err != nil {
log.Print("[DEBUG] tcp+sni: TLS handshake failed (failed to peek data)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}
data = data[:n]

host, ok := readServerName(data)
bufferSize, err := clientHelloBufferSize(tlsHeaders)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

data := make([]byte, bufferSize)
_, err = io.ReadFull(tlsReader, data)
if err != nil {
log.Printf("[DEBUG] tcp+sni: TLS handshake failed (%s)", err)
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
return err
}

// readServerName wants only the handshake message so ignore the first
// 5 bytes which is the TLS record header
host, ok := readServerName(data[5:])
if !ok {
log.Print("[DEBUG] tcp+sni: TLS handshake failed")
log.Print("[DEBUG] tcp+sni: TLS handshake failed (unable to parse client hello)")
if p.ConnFail != nil {
p.ConnFail.Inc(1)
}
Expand Down Expand Up @@ -88,8 +109,8 @@ func (p *SNIProxy) ServeTCP(in net.Conn) error {
}
defer out.Close()

// copy client hello
n, err = out.Write(data)
// write the data already read from the connection
n, err := out.Write(data)
if err != nil {
log.Print("[WARN] tcp+sni: copy client hello failed. ", err)
if p.ConnFail != nil {
Expand Down
90 changes: 41 additions & 49 deletions proxy/tcp/tls_clienthello.go
Original file line number Diff line number Diff line change
@@ -1,73 +1,65 @@
package tcp

// record types
const (
handshakeRecord = 0x16
clientHelloType = 0x01
)

// readServerName returns the server name from a TLS ClientHello message which
// has the server_name extension (SNI). ok is set to true if the ClientHello
// message was parsed successfully. If the server_name extension was not set
// and empty string is returned as serverName.
func readServerName(data []byte) (serverName string, ok bool) {
if m, ok := readClientHello(data); ok {
return m.serverName, true
}
return "", false
}

// readClientHello
func readClientHello(data []byte) (m *clientHelloMsg, ok bool) {
if len(data) < 9 {
// println("buf too short")
return nil, false
}
import "errors"

// Determines the required size of a buffer large enough to hold
// a client hello message including the tls record header and the
// handshake message header.
// The function requires at least the first 9 bytes of the tls conversation
// in "data".
// An error is returned if the data does not follow the
// specification (https://tools.ietf.org/html/rfc5246) or if the client hello
// is fragmented over multiple records.
func clientHelloBufferSize(data []byte) (int, error) {
// TLS record header
// -----------------
// byte 0: rec type (should be 0x16 == Handshake)
// byte 1-2: version (should be 0x3000 < v < 0x3003)
// byte 3-4: rec len
recType := data[0]
if recType != handshakeRecord {
// println("no handshake ")
return nil, false
if len(data) < 9 {
return 0, errors.New("At least 9 bytes required to determine client hello length")
}

recLen := int(data[3])<<8 | int(data[4])
if recLen == 0 || recLen > len(data)-5 {
// println("rec too short")
return nil, false
if data[0] != 0x16 {
return 0, errors.New("Not a TLS handshake")
}

recordLength := int(data[3])<<8 | int(data[4])
if recordLength <= 0 || recordLength > 16384 {
return 0, errors.New("Invalid TLS record length")
}

// Handshake record header
// -----------------------
// byte 5: hs msg type (should be 0x01 == client_hello)
// byte 6-8: hs msg len
hsType := data[5]
if hsType != clientHelloType {
// println("no client_hello")
return nil, false
if data[5] != 0x01 {
return 0, errors.New("Not a client hello")
}

hsLen := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if hsLen == 0 || hsLen > len(data)-9 {
// println("handshake rec too short")
return nil, false
handshakeLength := int(data[6])<<16 | int(data[7])<<8 | int(data[8])
if handshakeLength <= 0 || handshakeLength > recordLength-4 {
return 0, errors.New("Invalid client hello length (fragmentation not implemented)")
}

// byte 9- : client hello msg
//
// m.unmarshal parses the entire handshake message and
// not just the client hello. Therefore, we need to pass
// data from byte 5 instead of byte 9. (see comment below)
m = new(clientHelloMsg)
if !m.unmarshal(data[5:]) {
// println("client_hello unmarshal failed")
return nil, false
return handshakeLength + 9, nil //9 for the header bytes
}

// readServerName returns the server name from a TLS ClientHello message which
// has the server_name extension (SNI). ok is set to true if the ClientHello
// message was parsed successfully. If the server_name extension was not set
// an empty string is returned as serverName.
// clientHelloHandshakeMsg must contain the full client hello handshake
// message including the 4 byte header.
// See: https://www.ietf.org/rfc/rfc5246.txt
func readServerName(clientHelloHandshakeMsg []byte) (serverName string, ok bool) {
m := new(clientHelloMsg)
if !m.unmarshal(clientHelloHandshakeMsg) {
//println("client_hello unmarshal failed")
return "", false
}
return m, true

return m.serverName, true
}

// The code below is a verbatim copy from go1.7/src/crypto/tls/handshake_messages.go
Expand Down
162 changes: 162 additions & 0 deletions proxy/tcp/tls_clienthello_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package tcp

import (
"encoding/hex"
"testing"
)

func TestClientHelloBufferSize(t *testing.T) {
tests := []struct {
name string
data []byte
size int
fail bool
}{
{
name: "valid data",
// Largest possible client hello message
// |- 16384 -| |----- 16380 ----|
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f, 0xfc},
size: 16384 + 5, // max record length + record header
fail: false,
},
{
name: "not enough data",
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x3f},
size: 0,
fail: true,
},
{
name: "not a TLS record",
data: []byte{0x15, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xeb},
size: 0,
fail: true,
},

{
name: "TLS record too large",
// | max + 1 |
data: []byte{0x16, 0x03, 0x01, 0x40, 0x01, 0x01, 0x00, 0x3f, 0xfc},
size: 0,
fail: true,
},

{
name: "TLS record length zero",
// |----------|
data: []byte{0x16, 0x03, 0x01, 0x00, 0x00, 0x01, 0x00, 0x3f, 0xfc},
size: 0,
fail: true,
},

{
name: "Not a client hello",
// |----|
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x02, 0x00, 0x3f, 0xfc},
size: 0,
fail: true,
},

{
name: "Invalid handshake message record length",
// |----- 0 --------|
data: []byte{0x16, 0x03, 0x01, 0x40, 0x00, 0x01, 0x00, 0x00, 0x00},
size: 0,
fail: true,
},

{
name: "Fragmentation (handshake message larger than record)",
// |- 500 ---| |----- 497 ------|
data: []byte{0x16, 0x03, 0x01, 0x01, 0xF4, 0x01, 0x00, 0x01, 0xf1},
size: 0,
fail: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := clientHelloBufferSize(tt.data)

if tt.fail && err == nil {
t.Fatal("expected error, got nil")
} else if !tt.fail && err != nil {
t.Fatalf("expected error to be nil, got %s", err)
}

if want := tt.size; got != want {
t.Fatalf("want size %d, got %d", want, got)
}
})
}
}

func TestReadServerName(t *testing.T) {
tests := []struct {
name string
servername string
ok bool
data string //Hex string, decoded by test
}{
{
// Client hello from:
// openssl s_client -connect google.com:443 -servername google.com
name: "valid client hello with server name",
servername: "google.com",
ok: true,
data: "0100014803032657cacce41598fa82e5b75061050bc31c5affdba106b8e7431852" +
"24af0fa1aa000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f00" +
"6b006a00390038ff8500c400c3008800870081c032c02ec02ac026c00fc005" +
"009d003d003500c00084c02fc02bc027c023c013c00900a2009e0067004000" +
"33003200be00bd00450044c031c02dc029c025c00ec004009c003c002f00ba" +
"0041c011c007c00cc00200050004c012c00800160013c00dc003000a001500" +
"12000900ff010000870000000f000d00000a676f6f676c652e636f6d000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a" +
"001a0016001700080006000700140015000400050012001300010002000300" +
"0f0010001100230000000d00260024060106020603efef0501050205030401" +
"04020403eeeeeded030103020303020102020203",
},
{
// Client hello from:
// openssl s_client -connect google.com:443
name: "valid client hello but no server name extension",
servername: "",
ok: true,
data: "0100013503036dfb09de7b16503dd1bb304dcbe54079913b65abf53de997f73b26c99e" +
"67ba28000098cc14cc13cc15c030c02cc028c024c014c00a00a3009f006b006a00" +
"390038ff8500c400c3008800870081c032c02ec02ac026c00fc005009d003d0035" +
"00c00084c02fc02bc027c023c013c00900a2009e006700400033003200be00bd00" +
"450044c031c02dc029c025c00ec004009c003c002f00ba0041c011c007c00cc002" +
"00050004c012c00800160013c00dc003000a00150012000900ff01000074000b00" +
"0403000102000a003a0038000e000d0019001c000b000c001b00180009000a001a" +
"00160017000800060007001400150004000500120013000100020003000f001000" +
"1100230000000d00260024060106020603efef050105020503040104020403eeee" +
"eded030103020303020102020203",
},
{
name: "invalid client hello",
servername: "",
ok: false,
data: "0100014c5768656e2070656f706c652073617920746f206d653a20776f756c6420796f" +
"75207261746865722062652074686f75676874206f6620617320612066756e6e79" +
"206d616e206f72206120677265617420626f73733f204d7920616e737765722773" +
"20616c77617973207468652073616d652c20746f206d652c207468657927726520" +
"6e6f74206d757475616c6c79206578636c75736976652e2d204461766964204272" +
"656e74",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientHelloMsg, _ := hex.DecodeString(tt.data)
servername, ok := readServerName(clientHelloMsg)
if got, want := servername, tt.servername; got != want {
t.Fatalf("%s: got servername \"%s\" want \"%s\"", tt.name, got, want)
}

if got, want := ok, tt.ok; got != want {
t.Fatalf("%s: got ok %t want %t", tt.name, got, want)
}
})
}
}