forked from hashicorp/memberlist
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlabel.go
181 lines (151 loc) · 4.79 KB
/
label.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package memberlist
import (
"bufio"
"fmt"
"io"
"net"
)
// General approach is to prefix all packets and streams with the same structure:
//
// magic type byte (244): uint8
// length of label name: uint8 (because labels can't be longer than 255 bytes)
// label name: []uint8
// LabelMaxSize is the maximum length, in bytes, of a packet or stream label.
// The header encodes the label length in a single uint8, so 255 is the
// largest length that can be represented.
const LabelMaxSize = 255
// AddLabelHeaderToPacket returns buf prefixed with a label header.
//
// An empty label means "no header": buf is returned unchanged. A label longer
// than LabelMaxSize bytes cannot be encoded and yields a nil slice and an
// error.
func AddLabelHeaderToPacket(buf []byte, label string) ([]byte, error) {
	switch n := len(label); {
	case n == 0:
		// Unlabeled packets go out exactly as they came in.
		return buf, nil
	case n > LabelMaxSize:
		return nil, fmt.Errorf("label %q is too long", label)
	default:
		return makeLabelHeader(label, buf), nil
	}
}
// RemoveLabelHeaderFromPacket strips a leading label header from buf when one
// is present.
//
// It returns the packet contents that follow the header together with the
// decoded label. A packet that is empty, or whose first byte is not the label
// marker, is returned unchanged with an empty label. A header that is cut
// short or declares a zero-length label produces an error and a nil slice.
func RemoveLabelHeaderFromPacket(buf []byte) (newBuf []byte, label string, err error) {
	// Wire layout: [type:byte] [size:byte] [size bytes of label]
	const headerLen = 2

	// No bytes at all, or no label marker up front: nothing to remove.
	if len(buf) == 0 || messageType(buf[0]) != hasLabelMsg {
		return buf, "", nil
	}

	if len(buf) < headerLen {
		return nil, "", fmt.Errorf("cannot decode label; packet has been truncated")
	}

	labelLen := int(buf[1])
	if labelLen < 1 {
		return nil, "", fmt.Errorf("label header cannot be empty when present")
	}

	end := headerLen + labelLen
	if len(buf) < end {
		return nil, "", fmt.Errorf("cannot decode label; packet has been truncated")
	}

	return buf[end:], string(buf[headerLen:end]), nil
}
// AddLabelHeaderToStream writes a label header to conn so that it precedes
// any stream data written afterwards.
//
// Nothing is written when label is empty. A label longer than LabelMaxSize
// bytes is rejected before anything is written. Otherwise the error from the
// write, if any, is returned.
func AddLabelHeaderToStream(conn net.Conn, label string) error {
	switch {
	case label == "":
		return nil
	case len(label) > LabelMaxSize:
		return fmt.Errorf("label %q is too long", label)
	}

	// The header carries no payload of its own; the stream data follows in
	// later writes made by the caller.
	_, err := conn.Write(makeLabelHeader(label, nil))
	return err
}
// RemoveLabelHeaderFromStream removes any label header from the beginning of
// the stream if present and returns it along with an updated conn with that
// header removed.
//
// The returned connection is:
//   - the original conn, when the stream hit EOF before yielding any data;
//   - a peeked conn that replays the bytes already buffered, when the stream
//     has data (labeled or not).
//
// On any error the returned net.Conn is a true nil interface value and the
// label is empty. Note that on error it is the caller's responsibility to
// close the connection.
func RemoveLabelHeaderFromStream(conn net.Conn) (net.Conn, string, error) {
	br := bufio.NewReader(conn)

	// First check for the type byte.
	peeked, err := br.Peek(1)
	if err != nil {
		if err == io.EOF {
			// It is safe to return the original net.Conn at this point because
			// it never contained any data in the first place so we don't have
			// to splice the buffer into the conn because both are empty.
			return conn, "", nil
		}
		return nil, "", err
	}

	if messageType(peeked[0]) != hasLabelMsg {
		// Not labeled: hand back everything we buffered, starting at byte 0.
		pc, err := newPeekedConnFromBufferedReader(conn, br, 0)
		if err != nil {
			// Return a literal nil rather than pc: a nil *peekedConn stored
			// in the net.Conn result would compare as non-nil to the caller.
			return nil, "", err
		}
		return pc, "", nil
	}

	// A labeled stream must carry a size byte right after the type byte; if
	// the stream ends before it arrives the header is truncated.
	peeked, err = br.Peek(2)
	if err != nil {
		if err == io.EOF {
			return nil, "", fmt.Errorf("cannot decode label; stream has been truncated")
		}
		return nil, "", err
	}

	size := int(peeked[1])
	if size < 1 {
		return nil, "", fmt.Errorf("label header cannot be empty when present")
	}
	// NOTE: we don't have to check this against LabelMaxSize because a byte
	// already has a max value of 255.

	// Once we know the size we can peek the label as well. The whole header
	// is at most 2+255 bytes, which fits in the default bufio.Reader buffer
	// of 4096 bytes.
	peeked, err = br.Peek(2 + size)
	if err != nil {
		if err == io.EOF {
			return nil, "", fmt.Errorf("cannot decode label; stream has been truncated")
		}
		return nil, "", err
	}

	label := string(peeked[2 : 2+size])

	// Skip past the header so the caller only sees the stream payload.
	pc, err := newPeekedConnFromBufferedReader(conn, br, 2+size)
	if err != nil {
		return nil, "", err
	}
	return pc, label, nil
}
// newPeekedConnFromBufferedReader wraps conn in a peekedConn whose Peeked
// field holds whatever br currently has buffered, minus the first offset
// bytes. The intent is that reads on the result drain those leftover bytes
// before falling through to conn itself.
//
// offset must not exceed the number of bytes buffered in br.
func newPeekedConnFromBufferedReader(conn net.Conn, br *bufio.Reader, offset int) (*peekedConn, error) {
	// Grab the entire readahead; Peek does not advance the reader.
	readahead, err := br.Peek(br.Buffered())
	if err != nil {
		return nil, err
	}

	leftover := readahead[offset:]
	return &peekedConn{Conn: conn, Peeked: leftover}, nil
}
// makeLabelHeader builds a freshly allocated slice laid out as
// [label marker][label length][label bytes][rest...].
//
// The label length is stored in one byte, so callers must ensure the label
// is no longer than 255 bytes.
func makeLabelHeader(label string, rest []byte) []byte {
	// Size the backing array once so the appends below never reallocate.
	out := make([]byte, 0, 2+len(label)+len(rest))
	out = append(out, byte(hasLabelMsg), byte(len(label)))
	out = append(out, label...)
	out = append(out, rest...)
	return out
}
// labelOverhead reports how many bytes a label header adds for the given
// label: two fixed bytes plus the label itself, or zero when the label is
// empty.
func labelOverhead(label string) int {
	if n := len(label); n > 0 {
		return 2 + n
	}
	return 0
}