1
2
3
4
5
6
7
8 package http2_test
9
10 import (
11 "bytes"
12 "context"
13 "crypto/tls"
14 "fmt"
15 "internal/gate"
16 "io"
17 "net"
18 "net/http"
19 . "net/http/internal/http2"
20 "reflect"
21 "sync/atomic"
22 "testing"
23 "testing/synctest"
24 "time"
25 _ "unsafe"
26
27 "golang.org/x/net/http2/hpack"
28 )
29
30
31 func TestTestClientConn(t *testing.T) { synctestTest(t, testTestClientConn) }
32 func testTestClientConn(t testing.TB) {
33
34 tc := newTestClientConn(t)
35
36
37
38
39
40 tc.greet()
41
42
43
44 body := tc.newRequestBody()
45 body.writeBytes(10)
46 body.closeWithError(io.EOF)
47
48
49
50 req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
51 rt := tc.roundTrip(req)
52
53
54
55 tc.wantHeaders(wantHeader{
56 streamID: rt.streamID(),
57 endStream: false,
58 header: http.Header{
59 ":authority": []string{"dummy.tld"},
60 ":method": []string{"PUT"},
61 ":path": []string{"/"},
62 },
63 })
64
65 tc.wantData(wantData{
66 streamID: rt.streamID(),
67 endStream: true,
68 size: 10,
69 multiple: true,
70 })
71
72
73 tc.writeHeaders(HeadersFrameParam{
74 StreamID: rt.streamID(),
75 EndHeaders: true,
76 EndStream: true,
77 BlockFragment: tc.makeHeaderBlockFragment(
78 ":status", "200",
79 ),
80 })
81
82
83
84
85 rt.wantStatus(200)
86 rt.wantBody(nil)
87 }
88
89
90
91
92
93
94
95
96
97
98 type testClientConn struct {
99 t testing.TB
100
101 tr *Transport
102 fr *Framer
103 cc *ClientConn
104 testConnFramer
105
106 encbuf bytes.Buffer
107 enc *hpack.Encoder
108
109 roundtrips []*testRoundTrip
110
111 netconn *synctestNetConn
112 }
113
114 func newTestClientConnFromClientConn(t testing.TB, tr *Transport, cc *ClientConn) *testClientConn {
115 tc := &testClientConn{
116 t: t,
117 tr: tr,
118 cc: cc,
119 }
120
121
122
123
124 cli, srv := synctestNetPipe()
125 cc.TestSetNetConn(cli)
126
127 srv.SetReadDeadline(time.Now())
128 tc.netconn = srv
129 tc.enc = hpack.NewEncoder(&tc.encbuf)
130 tc.fr = NewFramer(srv, srv)
131 tc.testConnFramer = testConnFramer{
132 t: t,
133 fr: tc.fr,
134 dec: hpack.NewDecoder(InitialHeaderTableSize, nil),
135 }
136 tc.fr.SetMaxReadFrameSize(10 << 20)
137 t.Cleanup(func() {
138 tc.closeWrite()
139 })
140
141 return tc
142 }
143
144 func (tc *testClientConn) readClientPreface() {
145 tc.t.Helper()
146
147 buf := make([]byte, len(ClientPreface))
148 if _, err := io.ReadFull(tc.netconn, buf); err != nil {
149 tc.t.Fatalf("reading preface: %v", err)
150 }
151 if !bytes.Equal(buf, []byte(ClientPreface)) {
152 tc.t.Fatalf("client preface: %q, want %q", buf, ClientPreface)
153 }
154 }
155
156 func newTestClientConn(t testing.TB, opts ...any) *testClientConn {
157 t.Helper()
158
159 tt := newTestTransport(t, opts...)
160 const singleUse = false
161 tr := transportFromH1Transport(tt.tr1).(*Transport)
162 _, err := tr.TestNewClientConn(nil, singleUse, nil)
163 if err != nil {
164 t.Fatalf("newClientConn: %v", err)
165 }
166
167 return tt.getConn()
168 }
169
170
171 func (tc *testClientConn) hasFrame() bool {
172 synctest.Wait()
173 return len(tc.netconn.Peek()) > 0
174 }
175
176
177 func (tc *testClientConn) isClosed() bool {
178 synctest.Wait()
179 return tc.netconn.IsClosedByPeer()
180 }
181
182
183
184 func (tc *testClientConn) closeWrite() {
185 tc.netconn.Close()
186 }
187
188
189
190 func (tc *testClientConn) closeWriteWithError(err error) {
191 tc.netconn.loc.setReadError(io.EOF)
192 tc.netconn.loc.setWriteError(err)
193 }
194
195
196 type testRequestBody struct {
197 tc *testClientConn
198 gate gate.Gate
199
200
201 buf bytes.Buffer
202 bytes int
203
204 err error
205 }
206
207 func (tc *testClientConn) newRequestBody() *testRequestBody {
208 b := &testRequestBody{
209 tc: tc,
210 gate: gate.New(false),
211 }
212 return b
213 }
214
215 func (b *testRequestBody) unlock() {
216 b.gate.Unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
217 }
218
219
220 func (b *testRequestBody) Read(p []byte) (n int, _ error) {
221 if err := b.gate.WaitAndLock(context.Background()); err != nil {
222 return 0, err
223 }
224 defer b.unlock()
225 switch {
226 case b.buf.Len() > 0:
227 return b.buf.Read(p)
228 case b.bytes > 0:
229 if len(p) > b.bytes {
230 p = p[:b.bytes]
231 }
232 b.bytes -= len(p)
233 for i := range p {
234 p[i] = 'A'
235 }
236 return len(p), nil
237 default:
238 return 0, b.err
239 }
240 }
241
242
243 func (b *testRequestBody) Close() error {
244 return nil
245 }
246
247
248 func (b *testRequestBody) writeBytes(n int) {
249 defer synctest.Wait()
250 b.gate.Lock()
251 defer b.unlock()
252 b.bytes += n
253 b.checkWrite()
254 synctest.Wait()
255 }
256
257
258 func (b *testRequestBody) Write(p []byte) (int, error) {
259 defer synctest.Wait()
260 b.gate.Lock()
261 defer b.unlock()
262 n, err := b.buf.Write(p)
263 b.checkWrite()
264 return n, err
265 }
266
267 func (b *testRequestBody) checkWrite() {
268 if b.bytes > 0 && b.buf.Len() > 0 {
269 b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
270 }
271 if b.err != nil {
272 b.tc.t.Fatalf("can't write to request body after closeWithError")
273 }
274 }
275
276
277 func (b *testRequestBody) closeWithError(err error) {
278 defer synctest.Wait()
279 b.gate.Lock()
280 defer b.unlock()
281 b.err = err
282 }
283
284
285
286
287
288 func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
289 ctx, cancel := context.WithCancel(req.Context())
290 req = req.WithContext(ctx)
291 rt := &testRoundTrip{
292 t: tc.t,
293 donec: make(chan struct{}),
294 cancel: cancel,
295 }
296 tc.roundtrips = append(tc.roundtrips, rt)
297 go func() {
298
299
300
301
302
303
304
305 defer close(rt.donec)
306 cresp := &http.Response{}
307 creq := &ClientRequest{
308 Context: req.Context(),
309 Method: req.Method,
310 URL: req.URL,
311 Header: Header(req.Header),
312 Trailer: Header(req.Trailer),
313 Body: req.Body,
314 Host: req.Host,
315 GetBody: req.GetBody,
316 ContentLength: req.ContentLength,
317 Cancel: req.Cancel,
318 Close: req.Close,
319 ResTrailer: (*Header)(&cresp.Trailer),
320 }
321 resp, err := tc.cc.TestRoundTrip(creq, func(id uint32) {
322 rt.id.Store(id)
323 })
324 rt.respErr = err
325 if resp != nil {
326 cresp.Status = resp.Status + " " + http.StatusText(resp.StatusCode)
327 cresp.StatusCode = resp.StatusCode
328 cresp.Proto = "HTTP/2.0"
329 cresp.ProtoMajor = 2
330 cresp.ProtoMinor = 0
331 cresp.ContentLength = resp.ContentLength
332 cresp.Uncompressed = resp.Uncompressed
333 cresp.Header = http.Header(resp.Header)
334 cresp.Trailer = http.Header(resp.Trailer)
335 cresp.Body = resp.Body
336 cresp.TLS = resp.TLS
337 cresp.Request = req
338 rt.resp = cresp
339 }
340 }()
341 synctest.Wait()
342
343 tc.t.Cleanup(func() {
344 if !rt.done() {
345 return
346 }
347 res, _ := rt.result()
348 if res != nil {
349 res.Body.Close()
350 }
351 })
352
353 return rt
354 }
355
356 func (tc *testClientConn) greet(settings ...Setting) {
357 tc.wantFrameType(FrameSettings)
358 tc.wantFrameType(FrameWindowUpdate)
359 tc.writeSettings(settings...)
360 tc.writeSettingsAck()
361 tc.wantFrameType(FrameSettings)
362 }
363
364
365
366
367
368 func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
369 if len(s)%2 != 0 {
370 tc.t.Fatalf("uneven list of header name/value pairs")
371 }
372 tc.encbuf.Reset()
373 for i := 0; i < len(s); i += 2 {
374 tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
375 }
376 return tc.encbuf.Bytes()
377 }
378
379
380
381 func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
382 w, err := tc.cc.TestInflowWindow(streamID)
383 if err != nil {
384 tc.t.Error(err)
385 }
386 return w
387 }
388
389
// testRoundTrip is a handle to a RoundTrip call running in its own goroutine.
type testRoundTrip struct {
	t testing.TB

	// RoundTrip's results. The RoundTrip goroutine writes them before
	// closing donec, so read them only after donec is closed (as result does).
	resp    *http.Response
	respErr error
	donec   chan struct{}

	// Stream ID. Set by the callback that testClientConn.roundTrip passes to
	// TestRoundTrip; zero until then, and never set by
	// testTransport.roundTrip.
	id atomic.Uint32

	// Cancels the request's context.
	cancel context.CancelFunc
}

// streamID returns the round trip's stream ID. It panics if no ID has been
// recorded yet.
func (rt *testRoundTrip) streamID() uint32 {
	if id := rt.id.Load(); id != 0 {
		return id
	}
	panic("stream ID unknown")
}
407
408
409 func (rt *testRoundTrip) done() bool {
410 synctest.Wait()
411 select {
412 case <-rt.donec:
413 return true
414 default:
415 return false
416 }
417 }
418
419
420 func (rt *testRoundTrip) result() (*http.Response, error) {
421 t := rt.t
422 t.Helper()
423 synctest.Wait()
424 select {
425 case <-rt.donec:
426 default:
427 t.Fatalf("RoundTrip is not done; want it to be")
428 }
429 return rt.resp, rt.respErr
430 }
431
432
433
434 func (rt *testRoundTrip) response() *http.Response {
435 t := rt.t
436 t.Helper()
437 resp, err := rt.result()
438 if err != nil {
439 t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
440 }
441 if resp == nil {
442 t.Fatalf("RoundTrip returned nil *Response and nil error")
443 }
444 return resp
445 }
446
447
448 func (rt *testRoundTrip) err() error {
449 t := rt.t
450 t.Helper()
451 _, err := rt.result()
452 return err
453 }
454
455
456 func (rt *testRoundTrip) wantStatus(want int) {
457 t := rt.t
458 t.Helper()
459 if got := rt.response().StatusCode; got != want {
460 t.Fatalf("got response status %v, want %v", got, want)
461 }
462 }
463
464
465 func (rt *testRoundTrip) readBody() ([]byte, error) {
466 t := rt.t
467 t.Helper()
468 return io.ReadAll(rt.response().Body)
469 }
470
471
472
473 func (rt *testRoundTrip) wantBody(want []byte) {
474 t := rt.t
475 t.Helper()
476 got, err := rt.readBody()
477 if err != nil {
478 t.Fatalf("unexpected error reading response body: %v", err)
479 }
480 if !bytes.Equal(got, want) {
481 t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want)
482 }
483 }
484
485
486 func (rt *testRoundTrip) wantHeaders(want http.Header) {
487 t := rt.t
488 t.Helper()
489 res := rt.response()
490 if diff := diffHeaders(res.Header, want); diff != "" {
491 t.Fatalf("unexpected response headers:\n%v", diff)
492 }
493 }
494
495
496 func (rt *testRoundTrip) wantTrailers(want http.Header) {
497 t := rt.t
498 t.Helper()
499 res := rt.response()
500 if diff := diffHeaders(res.Trailer, want); diff != "" {
501 t.Fatalf("unexpected response trailers:\n%v", diff)
502 }
503 }
504
// diffHeaders returns "" if got and want are equal, and otherwise a
// two-line description showing both. nil and empty maps count as equal.
// Beyond that, comparison is reflect.DeepEqual, so a nil value slice differs
// from an empty one.
func diffHeaders(got, want http.Header) string {
	bothEmpty := len(got) == 0 && len(want) == 0
	if bothEmpty || reflect.DeepEqual(got, want) {
		return ""
	}
	return fmt.Sprintf("got: %v\nwant: %v", got, want)
}
517
518
519
520
521 type testTransport struct {
522 t testing.TB
523 tr1 *http.Transport
524
525 ccs []*testClientConn
526 }
527
528 func newTestTransport(t testing.TB, opts ...any) *testTransport {
529 t.Helper()
530 tt := &testTransport{
531 t: t,
532 }
533
534 tr1 := &http.Transport{
535 DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
536
537
538
539
540
541
542
543 cli, srv := synctestNetPipe()
544 go func() {
545 tlsSrv := tls.Server(srv, testServerTLSConfig)
546 if err := tlsSrv.Handshake(); err != nil {
547 t.Errorf("unexpected TLS server handshake error: %v", err)
548 }
549 }()
550 return cli, nil
551 },
552 Protocols: protocols("h2"),
553 TLSClientConfig: testClientTLSConfig,
554 }
555 for _, o := range opts {
556 switch o := o.(type) {
557 case nil:
558 case func(*http.Transport):
559 o(tr1)
560 case func(*http.HTTP2Config):
561 if tr1.HTTP2 == nil {
562 tr1.HTTP2 = &http.HTTP2Config{}
563 }
564 o(tr1.HTTP2)
565 default:
566 t.Fatalf("unknown newTestTransport option type %T", o)
567 }
568 }
569 tt.tr1 = tr1
570
571 tr2 := transportFromH1Transport(tr1).(*Transport)
572 tr2.TestSetNewClientConnHook(func(cc *ClientConn) {
573 tc := newTestClientConnFromClientConn(t, tr2, cc)
574 tt.ccs = append(tt.ccs, tc)
575 })
576
577 t.Cleanup(func() {
578 synctest.Wait()
579 if len(tt.ccs) > 0 {
580 t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
581 }
582 })
583
584 return tt
585 }
586
587 func (tt *testTransport) hasConn() bool {
588 return len(tt.ccs) > 0
589 }
590
591 func (tt *testTransport) getConn() *testClientConn {
592 tt.t.Helper()
593 synctest.Wait()
594 if len(tt.ccs) == 0 {
595 tt.t.Fatalf("no new ClientConns created; wanted one")
596 }
597 tc := tt.ccs[0]
598 tt.ccs = tt.ccs[1:]
599 tc.readClientPreface()
600 synctest.Wait()
601 return tc
602 }
603
604 func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
605 ctx, cancel := context.WithCancel(req.Context())
606 req = req.WithContext(ctx)
607 rt := &testRoundTrip{
608 t: tt.t,
609 donec: make(chan struct{}),
610 cancel: cancel,
611 }
612 go func() {
613 defer close(rt.donec)
614 rt.resp, rt.respErr = tt.tr1.RoundTrip(req)
615 }()
616 synctest.Wait()
617
618 tt.t.Cleanup(func() {
619 if !rt.done() {
620 return
621 }
622 res, _ := rt.result()
623 if res != nil {
624 res.Body.Close()
625 }
626 })
627
628 return rt
629 }
630
View as plain text