@@ -49,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
49
49
}
50
50
51
51
type msgWriter struct {
52
- mw * msgWriterState
53
- closed bool
54
- }
55
-
56
- func (mw * msgWriter ) Write (p []byte ) (int , error ) {
57
- if mw .closed {
58
- return 0 , errors .New ("cannot use closed writer" )
59
- }
60
- return mw .mw .Write (p )
61
- }
62
-
63
- func (mw * msgWriter ) Close () error {
64
- if mw .closed {
65
- return errors .New ("cannot use closed writer" )
66
- }
67
- mw .closed = true
68
- return mw .mw .Close ()
69
- }
70
-
71
- type msgWriterState struct {
72
52
c * Conn
73
53
74
54
mu * mu
75
55
writeMu * mu
56
+ closed bool
76
57
77
58
ctx context.Context
78
59
opcode opcode
@@ -82,16 +63,16 @@ type msgWriterState struct {
82
63
flateWriter * flate.Writer
83
64
}
84
65
85
- func newMsgWriterState (c * Conn ) * msgWriterState {
86
- mw := & msgWriterState {
66
+ func newMsgWriter (c * Conn ) * msgWriter {
67
+ mw := & msgWriter {
87
68
c : c ,
88
69
mu : newMu (c ),
89
70
writeMu : newMu (c ),
90
71
}
91
72
return mw
92
73
}
93
74
94
- func (mw * msgWriterState ) ensureFlate () {
75
+ func (mw * msgWriter ) ensureFlate () {
95
76
if mw .trimWriter == nil {
96
77
mw .trimWriter = & trimLastFourBytesWriter {
97
78
w : util .WriterFunc (mw .write ),
@@ -104,22 +85,19 @@ func (mw *msgWriterState) ensureFlate() {
104
85
mw .flate = true
105
86
}
106
87
107
- func (mw * msgWriterState ) flateContextTakeover () bool {
88
+ func (mw * msgWriter ) flateContextTakeover () bool {
108
89
if mw .c .client {
109
90
return ! mw .c .copts .clientNoContextTakeover
110
91
}
111
92
return ! mw .c .copts .serverNoContextTakeover
112
93
}
113
94
114
95
func (c * Conn ) writer (ctx context.Context , typ MessageType ) (io.WriteCloser , error ) {
115
- err := c .msgWriterState .reset (ctx , typ )
96
+ err := c .msgWriter .reset (ctx , typ )
116
97
if err != nil {
117
98
return nil , err
118
99
}
119
- return & msgWriter {
120
- mw : c .msgWriterState ,
121
- closed : false ,
122
- }, nil
100
+ return c .msgWriter , nil
123
101
}
124
102
125
103
func (c * Conn ) write (ctx context.Context , typ MessageType , p []byte ) (int , error ) {
@@ -129,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
129
107
}
130
108
131
109
if ! c .flate () {
132
- defer c .msgWriterState .mu .unlock ()
133
- return c .writeFrame (ctx , true , false , c .msgWriterState .opcode , p )
110
+ defer c .msgWriter .mu .unlock ()
111
+ return c .writeFrame (ctx , true , false , c .msgWriter .opcode , p )
134
112
}
135
113
136
114
n , err := mw .Write (p )
@@ -142,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
142
120
return n , err
143
121
}
144
122
145
- func (mw * msgWriterState ) reset (ctx context.Context , typ MessageType ) error {
123
+ func (mw * msgWriter ) reset (ctx context.Context , typ MessageType ) error {
146
124
err := mw .mu .lock (ctx )
147
125
if err != nil {
148
126
return err
@@ -151,21 +129,26 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
151
129
mw .ctx = ctx
152
130
mw .opcode = opcode (typ )
153
131
mw .flate = false
132
+ mw .closed = false
154
133
155
134
mw .trimWriter .reset ()
156
135
157
136
return nil
158
137
}
159
138
160
- func (mw * msgWriterState ) putFlateWriter () {
139
+ func (mw * msgWriter ) putFlateWriter () {
161
140
if mw .flateWriter != nil {
162
141
putFlateWriter (mw .flateWriter )
163
142
mw .flateWriter = nil
164
143
}
165
144
}
166
145
167
146
// Write writes the given bytes to the WebSocket connection.
168
- func (mw * msgWriterState ) Write (p []byte ) (_ int , err error ) {
147
+ func (mw * msgWriter ) Write (p []byte ) (_ int , err error ) {
148
+ if mw .closed {
149
+ return 0 , errors .New ("cannot use closed writer" )
150
+ }
151
+
169
152
err = mw .writeMu .lock (mw .ctx )
170
153
if err != nil {
171
154
return 0 , fmt .Errorf ("failed to write: %w" , err )
@@ -194,7 +177,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
194
177
return mw .write (p )
195
178
}
196
179
197
- func (mw * msgWriterState ) write (p []byte ) (int , error ) {
180
+ func (mw * msgWriter ) write (p []byte ) (int , error ) {
198
181
n , err := mw .c .writeFrame (mw .ctx , false , mw .flate , mw .opcode , p )
199
182
if err != nil {
200
183
return n , fmt .Errorf ("failed to write data frame: %w" , err )
@@ -204,9 +187,14 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
204
187
}
205
188
206
189
// Close flushes the frame to the connection.
207
- func (mw * msgWriterState ) Close () (err error ) {
190
+ func (mw * msgWriter ) Close () (err error ) {
208
191
defer errd .Wrap (& err , "failed to close writer" )
209
192
193
+ if mw .closed {
194
+ return errors .New ("writer already closed" )
195
+ }
196
+ mw .closed = true
197
+
210
198
err = mw .writeMu .lock (mw .ctx )
211
199
if err != nil {
212
200
return err
@@ -232,7 +220,7 @@ func (mw *msgWriterState) Close() (err error) {
232
220
return nil
233
221
}
234
222
235
- func (mw * msgWriterState ) close () {
223
+ func (mw * msgWriter ) close () {
236
224
if mw .c .client {
237
225
mw .c .writeFrameMu .forceLock ()
238
226
putBufioWriter (mw .c .bw )
0 commit comments