1
2
3
4
5 package http2_test
6
7 import (
8 "bytes"
9 "io"
10 "net/http"
11 "os"
12 "reflect"
13 "slices"
14 "testing"
15 "testing/synctest"
16
17 . "net/http/internal/http2"
18
19 "golang.org/x/net/http2/hpack"
20 )
21
22 type testConnFramer struct {
23 t testing.TB
24 fr *Framer
25 dec *hpack.Decoder
26 }
27
28
29
30 func (tf *testConnFramer) readFrame() Frame {
31 tf.t.Helper()
32 synctest.Wait()
33 fr, err := tf.fr.ReadFrame()
34 if err == io.EOF || err == os.ErrDeadlineExceeded {
35 return nil
36 }
37 if err != nil {
38 tf.t.Fatalf("ReadFrame: %v", err)
39 }
40 return fr
41 }
42
43 type readFramer interface {
44 readFrame() Frame
45 }
46
47
48 func readFrame[T any](t testing.TB, framer readFramer) T {
49 t.Helper()
50 var v T
51 fr := framer.readFrame()
52 if fr == nil {
53 t.Fatalf("got no frame, want frame %T", v)
54 }
55 v, ok := fr.(T)
56 if !ok {
57 t.Fatalf("got frame %T, want %T", fr, v)
58 }
59 return v
60 }
61
62
63
64 func (tf *testConnFramer) wantFrameType(want FrameType) {
65 tf.t.Helper()
66 fr := tf.readFrame()
67 if fr == nil {
68 tf.t.Fatalf("got no frame, want frame %v", want)
69 }
70 if got := fr.Header().Type; got != want {
71 tf.t.Fatalf("got frame %v, want %v", got, want)
72 }
73 }
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97 func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
98 tf.t.Helper()
99 want = slices.Clone(want)
100 seen := 0
101 frame:
102 for seen < len(want) && !tf.t.Failed() {
103 fr := tf.readFrame()
104 if fr == nil {
105 break
106 }
107 for i, f := range want {
108 if f == nil {
109 continue
110 }
111 typ := reflect.TypeOf(f)
112 if typ.Kind() != reflect.Func ||
113 typ.NumIn() != 1 ||
114 typ.NumOut() != 1 ||
115 typ.Out(0) != reflect.TypeFor[bool]() {
116 tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
117 }
118 if typ.In(0) == reflect.TypeOf(fr) {
119 out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
120 if out[0].Bool() {
121 want[i] = nil
122 seen++
123 }
124 continue frame
125 }
126 }
127 tf.t.Errorf("got unexpected frame type %T", fr)
128 }
129 if seen < len(want) {
130 for _, f := range want {
131 if f == nil {
132 continue
133 }
134 tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
135 }
136 tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
137 }
138 }
139
// wantHeader describes the HEADERS frame that wantHeaders expects.
type wantHeader struct {
	streamID  uint32      // expected stream ID
	endStream bool        // whether END_STREAM is expected to be set
	header    http.Header // headers that must be present with exactly these values; unlisted headers are not checked
}
145
146
147
148 func (tf *testConnFramer) wantHeaders(want wantHeader) {
149 tf.t.Helper()
150
151 hf := readFrame[*HeadersFrame](tf.t, tf)
152 if got, want := hf.StreamID, want.streamID; got != want {
153 tf.t.Fatalf("got stream ID %v, want %v", got, want)
154 }
155 if got, want := hf.StreamEnded(), want.endStream; got != want {
156 tf.t.Fatalf("got stream ended %v, want %v", got, want)
157 }
158
159 gotHeader := make(http.Header)
160 tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
161 gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value)
162 })
163 defer tf.dec.SetEmitFunc(nil)
164 if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
165 tf.t.Fatalf("decoding HEADERS frame: %v", err)
166 }
167 headersEnded := hf.HeadersEnded()
168 for !headersEnded {
169 cf := readFrame[*ContinuationFrame](tf.t, tf)
170 if cf == nil {
171 tf.t.Fatalf("got end of frames, want CONTINUATION")
172 }
173 if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
174 tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
175 }
176 headersEnded = cf.HeadersEnded()
177 }
178 if err := tf.dec.Close(); err != nil {
179 tf.t.Fatalf("hpack decoding error: %v", err)
180 }
181
182 for k, v := range want.header {
183 if !reflect.DeepEqual(v, gotHeader[k]) {
184 tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
185 }
186 }
187 }
188
189
190
191 func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
192 tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
193 if hf.Name == "date" {
194 return
195 }
196 pairs = append(pairs, [2]string{hf.Name, hf.Value})
197 })
198 defer tf.dec.SetEmitFunc(nil)
199 if _, err := tf.dec.Write(headerBlock); err != nil {
200 tf.t.Fatalf("hpack decoding error: %v", err)
201 }
202 if err := tf.dec.Close(); err != nil {
203 tf.t.Fatalf("hpack decoding error: %v", err)
204 }
205 return pairs
206 }
207
// wantData describes the DATA frames that wantData expects.
type wantData struct {
	streamID  uint32 // stream ID of the expected DATA frames
	endStream bool   // whether END_STREAM is expected on the last frame read
	size      int    // expected total payload bytes; replaced by len(data) when data is non-nil
	data      []byte // if non-nil, the exact expected concatenated payload
	multiple  bool   // whether more than one DATA frame may be read
}
215
216
217 func (tf *testConnFramer) wantData(want wantData) {
218 tf.t.Helper()
219 gotSize := 0
220 gotEndStream := false
221 if want.data != nil {
222 want.size = len(want.data)
223 }
224 var gotData []byte
225 for {
226 fr := tf.readFrame()
227 if fr == nil {
228 break
229 }
230 data, ok := fr.(*DataFrame)
231 if !ok {
232 tf.t.Fatalf("got frame %T, want DataFrame", fr)
233 }
234 if want.data != nil {
235 gotData = append(gotData, data.Data()...)
236 }
237 gotSize += len(data.Data())
238 if data.StreamEnded() {
239 gotEndStream = true
240 break
241 }
242 if !want.endStream && gotSize >= want.size {
243 break
244 }
245 if !want.multiple {
246 break
247 }
248 }
249 if gotSize != want.size {
250 tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
251 }
252 if gotEndStream != want.endStream {
253 tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
254 }
255 if want.data != nil && !bytes.Equal(gotData, want.data) {
256 tf.t.Fatalf("got data %q, want %q", gotData, want.data)
257 }
258 }
259
260 func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
261 tf.t.Helper()
262 fr := readFrame[*RSTStreamFrame](tf.t, tf)
263 if fr.StreamID != streamID || fr.ErrCode != code {
264 tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", SummarizeFrame(fr), streamID, code)
265 }
266 }
267
268 func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
269 fr := readFrame[*SettingsFrame](tf.t, tf)
270 if fr.Header().Flags.Has(FlagSettingsAck) {
271 tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
272 }
273 for wantID, wantVal := range want {
274 gotVal, ok := fr.Value(wantID)
275 if !ok {
276 tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
277 } else if gotVal != wantVal {
278 tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
279 }
280 }
281 if tf.t.Failed() {
282 tf.t.Fatalf("%v", fr)
283 }
284 }
285
286 func (tf *testConnFramer) wantSettingsAck() {
287 tf.t.Helper()
288 fr := readFrame[*SettingsFrame](tf.t, tf)
289 if !fr.Header().Flags.Has(FlagSettingsAck) {
290 tf.t.Fatal("Settings Frame didn't have ACK set")
291 }
292 }
293
294 func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
295 tf.t.Helper()
296 fr := readFrame[*GoAwayFrame](tf.t, tf)
297 if fr.LastStreamID != maxStreamID || fr.ErrCode != code {
298 tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", SummarizeFrame(fr), maxStreamID, code)
299 }
300 }
301
302 func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
303 tf.t.Helper()
304 wu := readFrame[*WindowUpdateFrame](tf.t, tf)
305 if wu.FrameHeader.StreamID != streamID {
306 tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
307 }
308 if wu.Increment != incr {
309 tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
310 }
311 }
312
313 func (tf *testConnFramer) wantClosed() {
314 tf.t.Helper()
315 synctest.Wait()
316 fr, err := tf.fr.ReadFrame()
317 if err == nil {
318 tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
319 }
320 if err == os.ErrDeadlineExceeded {
321 tf.t.Fatalf("connection is not closed; want it to be")
322 }
323 }
324
325 func (tf *testConnFramer) wantIdle() {
326 tf.t.Helper()
327 synctest.Wait()
328 fr, err := tf.fr.ReadFrame()
329 if err == nil {
330 tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
331 }
332 if err != os.ErrDeadlineExceeded {
333 tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
334 }
335 }
336
337 func (tf *testConnFramer) writeSettings(settings ...Setting) {
338 tf.t.Helper()
339 if err := tf.fr.WriteSettings(settings...); err != nil {
340 tf.t.Fatal(err)
341 }
342 }
343
344 func (tf *testConnFramer) writeSettingsAck() {
345 tf.t.Helper()
346 if err := tf.fr.WriteSettingsAck(); err != nil {
347 tf.t.Fatal(err)
348 }
349 }
350
351 func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
352 tf.t.Helper()
353 if err := tf.fr.WriteData(streamID, endStream, data); err != nil {
354 tf.t.Fatal(err)
355 }
356 }
357
358 func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
359 tf.t.Helper()
360 if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
361 tf.t.Fatal(err)
362 }
363 }
364
365 func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
366 tf.t.Helper()
367 if err := tf.fr.WriteHeaders(p); err != nil {
368 tf.t.Fatal(err)
369 }
370 }
371
372
373
374
375
376
377 func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
378 tf.t.Helper()
379 switch mode {
380 case noHeader:
381 case oneHeader:
382 tf.writeHeaders(p)
383 case splitHeader:
384 if len(p.BlockFragment) < 2 {
385 panic("too small")
386 }
387 contData := p.BlockFragment[1:]
388 contEnd := p.EndHeaders
389 p.BlockFragment = p.BlockFragment[:1]
390 p.EndHeaders = false
391 tf.writeHeaders(p)
392 tf.writeContinuation(p.StreamID, contEnd, contData)
393 default:
394 panic("bogus mode")
395 }
396 }
397
398 func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
399 tf.t.Helper()
400 if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
401 tf.t.Fatal(err)
402 }
403 }
404
405 func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
406 if err := tf.fr.WritePriority(id, p); err != nil {
407 tf.t.Fatal(err)
408 }
409 }
410
411 func (tf *testConnFramer) writePriorityUpdate(id uint32, p string) {
412 if err := tf.fr.WritePriorityUpdate(id, p); err != nil {
413 tf.t.Fatal(err)
414 }
415 }
416
417 func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
418 tf.t.Helper()
419 if err := tf.fr.WriteRSTStream(streamID, code); err != nil {
420 tf.t.Fatal(err)
421 }
422 }
423
424 func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
425 tf.t.Helper()
426 if err := tf.fr.WritePing(ack, data); err != nil {
427 tf.t.Fatal(err)
428 }
429 }
430
431 func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
432 tf.t.Helper()
433 if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
434 tf.t.Fatal(err)
435 }
436 }
437
438 func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
439 tf.t.Helper()
440 if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil {
441 tf.t.Fatal(err)
442 }
443 }
444
View as plain text