1
2
3
4
5 package http2_test
6
7 import (
8 "bufio"
9 "bytes"
10 "compress/gzip"
11 "context"
12 crand "crypto/rand"
13 "crypto/tls"
14 "encoding/hex"
15 "errors"
16 "flag"
17 "fmt"
18 "io"
19 "log"
20 "math/rand"
21 "net"
22 "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/textproto"
26 "net/url"
27 "os"
28 "reflect"
29 "sort"
30 "strconv"
31 "strings"
32 "sync"
33 "sync/atomic"
34 "testing"
35 "testing/synctest"
36 "time"
37
38 . "net/http/internal/http2"
39 "net/http/internal/httpcommon"
40
41 "golang.org/x/net/http2/hpack"
42 )
43
// Flags gating tests that need the public internet.
var (
	extNet        = flag.Bool("extnet", false, "do external network tests")
	transportHost = flag.String("transporthost", "go.dev", "hostname to use for TestTransport")
)

// tlsConfigInsecure disables server certificate verification so clients in
// these tests can connect to test servers without a trusted certificate.
var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}

// canceledCtx is a context that is already canceled by the time package
// initialization completes.
var canceledCtx = func() context.Context {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	return ctx
}()
58
59
60 func newTransport(t testing.TB, opts ...any) *http.Transport {
61 tr1 := &http.Transport{
62 TLSClientConfig: tlsConfigInsecure,
63 Protocols: protocols("h2"),
64 HTTP2: &http.HTTP2Config{},
65 }
66 for _, o := range opts {
67 switch o := o.(type) {
68 case func(*http.Transport):
69 o(tr1)
70 case func(*http.HTTP2Config):
71 o(tr1.HTTP2)
72 default:
73 t.Fatalf("unknown newTransport option %T", o)
74 }
75 }
76 t.Cleanup(tr1.CloseIdleConnections)
77 return tr1
78 }
79
80 func TestTransportExternal(t *testing.T) {
81 if !*extNet {
82 t.Skip("skipping external network test")
83 }
84 req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
85 rt := newTransport(t)
86 res, err := rt.RoundTrip(req)
87 if err != nil {
88 t.Fatalf("%v", err)
89 }
90 res.Write(os.Stdout)
91 }
92
// TestIdleConnTimeout checks how Transport.IdleConnTimeout governs reuse of an
// idle HTTP/2 connection. Each subtest issues three sequential requests. After
// each response, it sleeps for test.wait and checks whether the connection has
// been closed. When wantNewConn is true, every request is expected to dial a
// new connection; otherwise only the first one does.
func TestIdleConnTimeout(t *testing.T) {
	for _, test := range []struct {
		name            string
		idleConnTimeout time.Duration
		wait            time.Duration
		// NOTE(review): baseTransport is set in the table below but never
		// read by the subtest body. It looks like a leftover from when the
		// HTTP/2 transport wrapped a separate HTTP/1 transport; confirm and
		// consider removing it.
		baseTransport *http.Transport
		wantNewConn   bool
	}{{
		name:            "NoExpiry",
		idleConnTimeout: 2 * time.Second,
		wait:            1 * time.Second,
		baseTransport:   nil,
		wantNewConn:     false,
	}, {
		name:            "H2TransportTimeoutExpires",
		idleConnTimeout: 1 * time.Second,
		wait:            2 * time.Second,
		baseTransport:   nil,
		wantNewConn:     true,
	}, {
		// NOTE(review): despite the name, this case expects the connection
		// to survive. The transport's own IdleConnTimeout is 0, and the
		// 2s timeout on baseTransport is not used (see above).
		name:            "H1TransportTimeoutExpires",
		idleConnTimeout: 0 * time.Second,
		wait:            1 * time.Second,
		baseTransport: newTransport(t, func(tr1 *http.Transport) {
			tr1.IdleConnTimeout = 2 * time.Second
		}),
		wantNewConn: false,
	}} {
		synctestSubtest(t, test.name, func(t testing.TB) {
			tt := newTestTransport(t, func(tr *http.Transport) {
				tr.IdleConnTimeout = test.idleConnTimeout
			})
			var tc *testClientConn
			for i := 0; i < 3; i++ {
				req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
				rt := tt.roundTrip(req)

				// A new connection is expected for the first request, and for
				// every later request when the previous one should have been
				// closed by the idle timeout.
				wantConn := i == 0 || test.wantNewConn
				if has := tt.hasConn(); has != wantConn {
					t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn)
				}
				if wantConn {
					tc = tt.getConn()
					// Consume the client's initial SETTINGS and WINDOW_UPDATE
					// frames, then reply with the server's SETTINGS.
					tc.wantFrameType(FrameSettings)
					tc.wantFrameType(FrameWindowUpdate)
					tc.writeSettings()
				}
				// Presumably hasConn reports only connections not yet taken
				// with getConn, so a second one here means an extra dial.
				if tt.hasConn() {
					t.Fatalf("request %v: Transport has more than one conn", i)
				}

				// Answer the request's HEADERS frame with a 200 that ends
				// the stream.
				hf := readFrame[*HeadersFrame](t, tc)
				tc.writeHeaders(HeadersFrameParam{
					StreamID:   hf.StreamID,
					EndHeaders: true,
					EndStream:  true,
					BlockFragment: tc.makeHeaderBlockFragment(
						":status", "200",
					),
				})
				rt.wantStatus(200)

				// On a new connection the client sends one more SETTINGS
				// frame, presumably its ACK of the server's SETTINGS.
				if wantConn {
					tc.wantFrameType(FrameSettings)
				}

				time.Sleep(test.wait)
				if got, want := tc.isClosed(), test.wantNewConn; got != want {
					t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want)
				}
			}
		})
	}
}
174
175 func TestTransportH2c(t *testing.T) {
176 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
177 fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
178 }, func(s *http.Server) {
179 s.Protocols = protocols("h2c")
180 })
181 req, err := http.NewRequest("GET", ts.URL+"/foobar", nil)
182 if err != nil {
183 t.Fatal(err)
184 }
185 var gotConnCnt int32
186 trace := &httptrace.ClientTrace{
187 GotConn: func(connInfo httptrace.GotConnInfo) {
188 if !connInfo.Reused {
189 atomic.AddInt32(&gotConnCnt, 1)
190 }
191 },
192 }
193 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
194 tr := newTransport(t)
195 tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
196 return net.Dial(network, addr)
197 }
198 tr.Protocols = protocols("h2c")
199 res, err := tr.RoundTrip(req)
200 if err != nil {
201 t.Fatal(err)
202 }
203 if res.ProtoMajor != 2 {
204 t.Fatal("proto not h2c")
205 }
206 body, err := io.ReadAll(res.Body)
207 if err != nil {
208 t.Fatal(err)
209 }
210 if got, want := string(body), "Hello, /foobar, http: true"; got != want {
211 t.Fatalf("response got %v, want %v", got, want)
212 }
213 if got, want := gotConnCnt, int32(1); got != want {
214 t.Errorf("Too many got connections: %d", gotConnCnt)
215 }
216 }
217
218 func TestTransport(t *testing.T) {
219 const body = "sup"
220 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
221 io.WriteString(w, body)
222 })
223
224 tr := ts.Client().Transport.(*http.Transport)
225 defer tr.CloseIdleConnections()
226
227 u, err := url.Parse(ts.URL)
228 if err != nil {
229 t.Fatal(err)
230 }
231 for i, m := range []string{"GET", ""} {
232 req := &http.Request{
233 Method: m,
234 URL: u,
235 Header: http.Header{},
236 }
237 res, err := tr.RoundTrip(req)
238 if err != nil {
239 t.Fatalf("%d: %s", i, err)
240 }
241
242 t.Logf("%d: Got res: %+v", i, res)
243 if g, w := res.StatusCode, 200; g != w {
244 t.Errorf("%d: StatusCode = %v; want %v", i, g, w)
245 }
246 if g, w := res.Status, "200 OK"; g != w {
247 t.Errorf("%d: Status = %q; want %q", i, g, w)
248 }
249 wantHeader := http.Header{
250 "Content-Length": []string{"3"},
251 "Content-Type": []string{"text/plain; charset=utf-8"},
252 "Date": []string{"XXX"},
253 }
254
255 if d := res.Header["Date"]; len(d) == 1 {
256 d[0] = "XXX"
257 }
258 if !reflect.DeepEqual(res.Header, wantHeader) {
259 t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader)
260 }
261 if res.Request != req {
262 t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req)
263 }
264 if res.TLS == nil {
265 t.Errorf("%d: Response.TLS = nil; want non-nil", i)
266 }
267 slurp, err := io.ReadAll(res.Body)
268 if err != nil {
269 t.Errorf("%d: Body read: %v", i, err)
270 } else if string(slurp) != body {
271 t.Errorf("%d: Body = %q; want %q", i, slurp, body)
272 }
273 res.Body.Close()
274 }
275 }
276
// TestTransportFailureErrorForHTTP1Response sends an h2c request to a plain
// httptest.NewServer (no h2c is configured on it here). It expects the
// RoundTrip error to contain "frame header looked like an HTTP/1.1 header".
// Two MaxReadFrameSize settings are tried: the default (0), and one byte more
// than the Length of InvalidHTTP1LookingFrameHeader().
//
// NOTE(review): the test is unconditionally skipped as racy; the rationale is
// not visible in this file.
func TestTransportFailureErrorForHTTP1Response(t *testing.T) {
	t.Skip("test is racy")

	const expectedHTTP1PayloadHint = "frame header looked like an HTTP/1.1 header"

	ts := httptest.NewServer(http.NewServeMux())
	t.Cleanup(ts.Close)

	for _, tc := range []struct {
		name         string
		maxFrameSize uint32
		// NOTE(review): expectedErrorIs is never set or read; the error is
		// matched by substring only.
		expectedErrorIs error
	}{
		{
			name:         "with default max frame size",
			maxFrameSize: 0,
		},
		{
			// Presumably chosen so the HTTP/1-looking frame is not rejected
			// for exceeding the read limit before it is inspected; confirm.
			name:         "with enough frame size to start reading",
			maxFrameSize: InvalidHTTP1LookingFrameHeader().Length + 1,
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			tr := newTransport(t)
			tr.HTTP2.MaxReadFrameSize = int(tc.maxFrameSize)
			tr.Protocols = protocols("h2c")

			req, err := http.NewRequest("GET", ts.URL, nil)
			if err != nil {
				t.Fatal(err)
			}

			_, err = tr.RoundTrip(req)
			if err == nil || !strings.Contains(err.Error(), expectedHTTP1PayloadHint) {
				t.Errorf("expected error to contain %q, got %v", expectedHTTP1PayloadHint, err)
			}
		})
	}
}
337
338 func testTransportReusesConns(t *testing.T, wantSame bool, modReq func(*http.Request)) {
339 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
340 io.WriteString(w, r.RemoteAddr)
341 }, func(ts *httptest.Server) {
342 ts.Config.ConnState = func(c net.Conn, st http.ConnState) {
343 t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
344 }
345 })
346 tr := newTransport(t)
347 get := func() string {
348 req, err := http.NewRequest("GET", ts.URL, nil)
349 if err != nil {
350 t.Fatal(err)
351 }
352 modReq(req)
353 res, err := tr.RoundTrip(req)
354 if err != nil {
355 t.Fatal(err)
356 }
357 defer res.Body.Close()
358 slurp, err := io.ReadAll(res.Body)
359 if err != nil {
360 t.Fatalf("Body read: %v", err)
361 }
362 addr := strings.TrimSpace(string(slurp))
363 if addr == "" {
364 t.Fatalf("didn't get an addr in response")
365 }
366 return addr
367 }
368 first := get()
369 second := get()
370 if got := first == second; got != wantSame {
371 t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame)
372 }
373 }
374
375 func TestTransportReusesConns(t *testing.T) {
376 for _, test := range []struct {
377 name string
378 modReq func(*http.Request)
379 wantSame bool
380 }{{
381 name: "ReuseConn",
382 modReq: func(*http.Request) {},
383 wantSame: true,
384 }, {
385 name: "RequestClose",
386 modReq: func(r *http.Request) { r.Close = true },
387 wantSame: false,
388 }, {
389 name: "ConnClose",
390 modReq: func(r *http.Request) { r.Header.Set("Connection", "close") },
391 wantSame: false,
392 }} {
393 t.Run(test.name, func(t *testing.T) {
394 testTransportReusesConns(t, test.wantSame, test.modReq)
395 })
396 }
397 }
398
399 func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
400 testTransportGetGotConnHooks(t, false)
401 }
402 func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
403
// testTransportGetGotConnHooks sends two sequential requests and checks the
// httptrace GetConn and GotConn hooks. The requests go through the test
// server's http.Client when useClient is true, and through a transport from
// newTransport otherwise. Each request must trigger exactly one GetConn and one
// GotConn. GotConn must report Reused and WasIdle as false for the first call
// and true for later calls.
func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, r.RemoteAddr)
	})

	// Only one of tr and client is used, depending on useClient.
	tr := newTransport(t)
	client := ts.Client()

	// Hook invocation counters; updated atomically from the trace hooks.
	var (
		getConns int32
		gotConns int32
	)
	for i := 0; i < 2; i++ {
		trace := &httptrace.ClientTrace{
			GetConn: func(hostport string) {
				atomic.AddInt32(&getConns, 1)
			},
			GotConn: func(connInfo httptrace.GotConnInfo) {
				got := atomic.AddInt32(&gotConns, 1)
				// Only the very first connection is fresh; all later ones
				// should be the same idle connection, reused.
				wantReused, wantWasIdle := false, false
				if got > 1 {
					wantReused, wantWasIdle = true, true
				}
				if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle {
					t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle)
				}
			},
		}
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))

		var res *http.Response
		if useClient {
			res, err = client.Do(req)
		} else {
			res, err = tr.RoundTrip(req)
		}
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		if get := atomic.LoadInt32(&getConns); get != int32(i+1) {
			t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1)
		}
		if got := atomic.LoadInt32(&gotConns); got != int32(i+1) {
			t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1)
		}
	}
}
456
457 func TestTransportAbortClosesPipes(t *testing.T) {
458 shutdown := make(chan struct{})
459 ts := newTestServer(t,
460 func(w http.ResponseWriter, r *http.Request) {
461 w.(http.Flusher).Flush()
462 <-shutdown
463 },
464 )
465 defer close(shutdown)
466
467 errCh := make(chan error)
468 go func() {
469 defer close(errCh)
470 tr := newTransport(t)
471 req, err := http.NewRequest("GET", ts.URL, nil)
472 if err != nil {
473 errCh <- err
474 return
475 }
476 res, err := tr.RoundTrip(req)
477 if err != nil {
478 errCh <- err
479 return
480 }
481 defer res.Body.Close()
482 ts.CloseClientConnections()
483 _, err = io.ReadAll(res.Body)
484 if err == nil {
485 errCh <- errors.New("expected error from res.Body.Read")
486 return
487 }
488 }()
489
490 select {
491 case err := <-errCh:
492 if err != nil {
493 t.Fatal(err)
494 }
495
496 case <-time.After(3 * time.Second):
497 t.Fatal("timeout")
498 }
499 }
500
501
502
503 func TestTransportPath(t *testing.T) {
504 gotc := make(chan *url.URL, 1)
505 ts := newTestServer(t,
506 func(w http.ResponseWriter, r *http.Request) {
507 gotc <- r.URL
508 },
509 )
510
511 tr := newTransport(t)
512 const (
513 path = "/testpath"
514 query = "q=1"
515 )
516 surl := ts.URL + path + "?" + query
517 req, err := http.NewRequest("POST", surl, nil)
518 if err != nil {
519 t.Fatal(err)
520 }
521 c := &http.Client{Transport: tr}
522 res, err := c.Do(req)
523 if err != nil {
524 t.Fatal(err)
525 }
526 defer res.Body.Close()
527 got := <-gotc
528 if got.Path != path {
529 t.Errorf("Read Path = %q; want %q", got.Path, path)
530 }
531 if got.RawQuery != query {
532 t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
533 }
534 }
535
// randString returns a string of n pseudo-random bytes, which are not
// necessarily valid UTF-8. The generator is seeded with n, so calls with the
// same length return the same string.
func randString(n int) string {
	src := rand.New(rand.NewSource(int64(n)))
	out := make([]byte, 0, n)
	for len(out) < n {
		out = append(out, byte(src.Intn(256)))
	}
	return string(out)
}
544
545 func TestTransportBody(t *testing.T) {
546 bodyTests := []struct {
547 body string
548 noContentLen bool
549 }{
550 {body: "some message"},
551 {body: "some message", noContentLen: true},
552 {body: strings.Repeat("a", 1<<20), noContentLen: true},
553 {body: strings.Repeat("a", 1<<20)},
554 {body: randString(16<<10 - 1)},
555 {body: randString(16 << 10)},
556 {body: randString(16<<10 + 1)},
557 {body: randString(512<<10 - 1)},
558 {body: randString(512 << 10)},
559 {body: randString(512<<10 + 1)},
560 {body: randString(1<<20 - 1)},
561 {body: randString(1 << 20)},
562 {body: randString(1<<20 + 2)},
563 }
564
565 type reqInfo struct {
566 req *http.Request
567 slurp []byte
568 err error
569 }
570 gotc := make(chan reqInfo, 1)
571 ts := newTestServer(t,
572 func(w http.ResponseWriter, r *http.Request) {
573 slurp, err := io.ReadAll(r.Body)
574 if err != nil {
575 gotc <- reqInfo{err: err}
576 } else {
577 gotc <- reqInfo{req: r, slurp: slurp}
578 }
579 },
580 )
581
582 for i, tt := range bodyTests {
583 tr := newTransport(t)
584
585 var body io.Reader = strings.NewReader(tt.body)
586 if tt.noContentLen {
587 body = struct{ io.Reader }{body}
588 }
589 req, err := http.NewRequest("POST", ts.URL, body)
590 if err != nil {
591 t.Fatalf("#%d: %v", i, err)
592 }
593 c := &http.Client{Transport: tr}
594 res, err := c.Do(req)
595 if err != nil {
596 t.Fatalf("#%d: %v", i, err)
597 }
598 defer res.Body.Close()
599 ri := <-gotc
600 if ri.err != nil {
601 t.Errorf("#%d: read error: %v", i, ri.err)
602 continue
603 }
604 if got := string(ri.slurp); got != tt.body {
605 t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
606 }
607 wantLen := int64(len(tt.body))
608 if tt.noContentLen && tt.body != "" {
609 wantLen = -1
610 }
611 if ri.req.ContentLength != wantLen {
612 t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
613 }
614 }
615 }
616
// shortString abbreviates v for test failure messages. Strings of at most 100
// bytes are returned unchanged. Longer strings keep their first and last 50
// bytes, with a marker between them giving the number of bytes omitted.
func shortString(v string) string {
	const maxLen = 100
	if len(v) > maxLen {
		const keep = maxLen / 2
		omitted := len(v) - maxLen
		return v[:keep] + "[..." + strconv.Itoa(omitted) + " bytes omitted...]" + v[len(v)-keep:]
	}
	return v
}
624
// capitalizeReader wraps r and converts the ASCII letters a-z to upper case
// in everything read through it. All other bytes pass through unchanged.
type capitalizeReader struct {
	r io.Reader
}

// Read reads from the underlying reader into p, then upper-cases any
// lowercase ASCII letters among the n bytes read, in place. It returns the
// underlying reader's n and err unchanged.
func (cr capitalizeReader) Read(p []byte) (int, error) {
	n, err := cr.r.Read(p)
	filled := p[:n]
	for i := range filled {
		if c := filled[i]; 'a' <= c && c <= 'z' {
			filled[i] = c - 'a' + 'A'
		}
	}
	return n, err
}
638
// flushWriter forwards writes to w. When w also implements http.Flusher, it
// flushes after every Write.
type flushWriter struct {
	w io.Writer
}

// Write writes p to the underlying writer and then flushes it if it can. The
// flush is attempted even if the write returned an error.
func (fw flushWriter) Write(p []byte) (int, error) {
	n, err := fw.w.Write(p)
	if flusher, canFlush := fw.w.(http.Flusher); canFlush {
		flusher.Flush()
	}
	return n, err
}
650
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 (127.0.0.1) first and falls back to IPv6 (::1). If the IPv6
// listen also fails, it fails the test.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
662
663 func TestTransportReqBodyAfterResponse_200(t *testing.T) {
664 synctestTest(t, func(t testing.TB) {
665 testTransportReqBodyAfterResponse(t, 200)
666 })
667 }
668 func TestTransportReqBodyAfterResponse_403(t *testing.T) {
669 synctestTest(t, func(t testing.TB) {
670 testTransportReqBodyAfterResponse(t, 403)
671 })
672 }
673
// testTransportReqBodyAfterResponse sends a PUT whose body is still being
// written when the server responds with the given status code.
//
// The first half of the 1 KiB body is written before the response arrives and
// the second half after it. With status 200, the client must send the
// remainder and end the stream. With any other status (403 in the callers
// here), the client must reset the stream instead. Either way, the response
// body must be empty.
func testTransportReqBodyAfterResponse(t testing.TB, status int) {
	const bodySize = 1 << 10

	tc := newTestClientConn(t)
	tc.greet()

	body := tc.newRequestBody()
	body.writeBytes(bodySize / 2)
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: false,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"PUT"},
			":path":      []string{"/"},
		},
	})

	// Grant enough flow-control window for the whole body, on both the
	// connection (stream 0) and the request stream.
	tc.writeWindowUpdate(0, bodySize)
	tc.writeWindowUpdate(rt.streamID(), bodySize)

	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: false,
		size:      bodySize / 2,
	})

	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", strconv.Itoa(status),
		),
	})

	res := rt.response()
	if res.StatusCode != status {
		t.Fatalf("status code = %v; want %v", res.StatusCode, status)
	}

	// Finish the request body only after the response has been received.
	body.writeBytes(bodySize / 2)
	body.closeWithError(io.EOF)

	if status == 200 {
		// The client sends the remaining bytes and ends the stream. The
		// multiple flag presumably allows them to span several DATA frames.
		tc.wantData(wantData{
			streamID:  rt.streamID(),
			endStream: true,
			size:      bodySize / 2,
			multiple:  true,
		})
	} else {
		// The client abandons the rest of the body and resets the stream.
		tc.wantFrameType(FrameRSTStream)
	}

	rt.wantBody(nil)
}
737
738
739 func TestTransportFullDuplex(t *testing.T) {
740 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
741 w.WriteHeader(200)
742 w.(http.Flusher).Flush()
743 io.Copy(flushWriter{w}, capitalizeReader{r.Body})
744 fmt.Fprintf(w, "bye.\n")
745 })
746
747 tr := newTransport(t)
748 c := &http.Client{Transport: tr}
749
750 pr, pw := io.Pipe()
751 req, err := http.NewRequest("PUT", ts.URL, io.NopCloser(pr))
752 if err != nil {
753 t.Fatal(err)
754 }
755 req.ContentLength = -1
756 res, err := c.Do(req)
757 if err != nil {
758 t.Fatal(err)
759 }
760 defer res.Body.Close()
761 if res.StatusCode != 200 {
762 t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
763 }
764 bs := bufio.NewScanner(res.Body)
765 want := func(v string) {
766 if !bs.Scan() {
767 t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
768 }
769 }
770 write := func(v string) {
771 _, err := io.WriteString(pw, v)
772 if err != nil {
773 t.Fatalf("pipe write: %v", err)
774 }
775 }
776 write("foo\n")
777 want("FOO")
778 write("bar\n")
779 want("BAR")
780 pw.Close()
781 want("bye.")
782 if err := bs.Err(); err != nil {
783 t.Fatal(err)
784 }
785 }
786
// TestTransportConnectRequest sends CONNECT requests and checks that the
// server sees method CONNECT. Both the server-side Host and URL.Host must
// equal Request.Host when it is set, and the URL's host otherwise.
func TestTransportConnectRequest(t *testing.T) {
	gotc := make(chan *http.Request, 1)
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		gotc <- r
	})

	u, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatal(err)
	}

	tr := newTransport(t)
	c := &http.Client{Transport: tr}

	tests := []struct {
		req  *http.Request
		want string
	}{
		{
			// No Request.Host: the authority comes from the URL.
			req: &http.Request{
				Method: "CONNECT",
				Header: http.Header{},
				URL:    u,
			},
			want: u.Host,
		},
		{
			// Request.Host overrides the URL's host.
			req: &http.Request{
				Method: "CONNECT",
				Header: http.Header{},
				URL:    u,
				Host:   "example.com:123",
			},
			want: "example.com:123",
		},
	}

	for i, tt := range tests {
		res, err := c.Do(tt.req)
		if err != nil {
			t.Errorf("%d. RoundTrip = %v", i, err)
			continue
		}
		res.Body.Close()
		req := <-gotc
		if req.Method != "CONNECT" {
			t.Errorf("method = %q; want CONNECT", req.Method)
		}
		if req.Host != tt.want {
			t.Errorf("Host = %q; want %q", req.Host, tt.want)
		}
		if req.URL.Host != tt.want {
			t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
		}
	}
}
843
// headerType selects how a test sends a header block: not at all, in a single
// frame, or split across frames (presumably HEADERS plus CONTINUATION; see
// writeHeadersMode).
type headerType int

const (
	noHeader    headerType = 0
	oneHeader   headerType = 1
	splitHeader headerType = 2
)

// Short aliases that keep the TestTransportResPattern_* names and calls
// compact. fN is a headerType, and dN says whether the response carries a
// DATA frame.
const (
	f0 = noHeader
	f1 = oneHeader
	f2 = splitHeader

	d0 = false
	d1 = true
)
859
860
861
862
863
864
// The TestTransportResPattern_cWhXdYtZ tests cover every combination of:
//
//	c: how the 100-continue response is sent (f0 none, f1 one frame, f2 split)
//	h: how the final response headers are sent (f1 or f2; never f0)
//	d: whether a response DATA frame is sent (d0 no, d1 yes)
//	t: how trailers are sent (f0 none, f1 one frame, f2 split)
func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
901
902 func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
903 synctestTest(t, func(t testing.TB) {
904 testTransportResPatternBubble(t, expect100Continue, resHeader, withData, trailers)
905 })
906 }
// testTransportResPatternBubble drives one POST request through a test client
// connection and replies with the given frame pattern. The reply consists of
// an optional 100 Continue, the final 200 headers, an optional DATA frame, and
// optional trailers. It checks the status, body and trailers the client
// reports. It panics if resHeader is noHeader, because a response needs final
// headers.
func testTransportResPatternBubble(t testing.TB, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
	const reqBody = "some request body"
	const resBody = "some response body"

	if resHeader == noHeader {
		// The table of TestTransportResPattern_* tests never uses f0 for the
		// response headers; reaching here indicates a table mistake.
		panic("invalid combination")
	}

	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
	if expect100Continue != noHeader {
		req.Header.Set("Expect", "100-continue")
	}
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Send the 100 Continue response in the requested mode. Presumably
	// writeHeadersMode writes nothing for noHeader; confirm there.
	tc.writeHeadersMode(expect100Continue, HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "100",
		),
	})

	// The client then sends the whole request body and ends its side of the
	// stream.
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: true,
		size:      len(reqBody),
	})

	hdr := []string{
		":status", "200",
		"x-foo", "blah",
		"x-bar", "more",
	}
	if trailers != noHeader {
		hdr = append(hdr, "trailer", "some-trailer")
	}
	// The final headers end the stream only if no DATA or trailers follow.
	tc.writeHeadersMode(resHeader, HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     withData == false && trailers == noHeader,
		BlockFragment: tc.makeHeaderBlockFragment(hdr...),
	})
	if withData {
		endStream := trailers == noHeader
		tc.writeData(rt.streamID(), endStream, []byte(resBody))
	}
	tc.writeHeadersMode(trailers, HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			"some-trailer", "some-value",
		),
	})

	rt.wantStatus(200)
	if !withData {
		rt.wantBody(nil)
	} else {
		rt.wantBody([]byte(resBody))
	}
	if trailers == noHeader {
		rt.wantTrailers(nil)
	} else {
		rt.wantTrailers(http.Header{
			"Some-Trailer": {"some-value"},
		})
	}
}
986
987
// TestTransportUnknown1xx checks that informational responses with unknown
// status codes (110-114) are passed, in order and with their headers, to the
// Got1xx test hook. The final 204 response must still be delivered normally.
func TestTransportUnknown1xx(t *testing.T) { synctestTest(t, testTransportUnknown1xx) }
func testTransportUnknown1xx(t testing.TB) {
	// Record every 1xx response the client reports.
	var buf bytes.Buffer
	SetTestHookGot1xx(t, func(code int, header textproto.MIMEHeader) error {
		fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
		return nil
	})

	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	// Five non-final 1xx responses, each with a distinguishing header.
	for i := 110; i <= 114; i++ {
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  false,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", fmt.Sprint(i),
				"foo-bar", fmt.Sprint(i),
			),
		})
	}
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "204",
		),
	})

	res := rt.response()
	if res.StatusCode != 204 {
		t.Fatalf("status code = %v; want 204", res.StatusCode)
	}
	want := `code=110 header=map[Foo-Bar:[110]]
code=111 header=map[Foo-Bar:[111]]
code=112 header=map[Foo-Bar:[112]]
code=113 header=map[Foo-Bar:[113]]
code=114 header=map[Foo-Bar:[114]]
`
	if got := buf.String(); got != want {
		t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
	}
}
1036
// TestTransportReceiveUndeclaredTrailer checks that a trailer not announced in
// the response's "Trailer" header is still exposed in Response.Trailer.
func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
	synctestTest(t, testTransportReceiveUndeclaredTrailer)
}
func testTransportReceiveUndeclaredTrailer(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	// Response headers with no "trailer" declaration, leaving the stream open.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	// A second header block that ends the stream carries the trailer.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			"some-trailer", "I'm an undeclared Trailer!",
		),
	})

	rt.wantStatus(200)
	rt.wantBody(nil)
	rt.wantTrailers(http.Header{
		"Some-Trailer": []string{"I'm an undeclared Trailer!"},
	})
}
1070
// Invalid-trailer tests: each sends a trailer block containing one invalid
// field and expects reading the body to fail with a StreamError whose Cause
// is the given error. The "1" variants send the trailers as one header block;
// the "2" variants use splitHeader mode.

// TestTransportInvalidTrailer_Pseudo1 rejects a pseudo-header in trailers.
func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
	testTransportInvalidTrailer_Pseudo(t, oneHeader)
}
func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
	testTransportInvalidTrailer_Pseudo(t, splitHeader)
}
func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
	testInvalidTrailer(t, trailers, PseudoHeaderError(":colon"),
		":colon", "foo",
		"foo", "bar",
	)
}

// TestTransportInvalidTrailer_Capital1 rejects a trailer name containing an
// uppercase letter.
func TestTransportInvalidTrailer_Capital1(t *testing.T) {
	testTransportInvalidTrailer_Capital(t, oneHeader)
}
func TestTransportInvalidTrailer_Capital2(t *testing.T) {
	testTransportInvalidTrailer_Capital(t, splitHeader)
}
func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
	testInvalidTrailer(t, trailers, HeaderFieldNameError("Capital"),
		"foo", "bar",
		"Capital", "bad",
	)
}

// TestTransportInvalidTrailer_EmptyFieldName rejects an empty trailer name.
func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
	testInvalidTrailer(t, oneHeader, HeaderFieldNameError(""),
		"", "bad",
	)
}

// TestTransportInvalidTrailer_BinaryFieldValue rejects a trailer value
// containing a newline.
func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
	testInvalidTrailer(t, oneHeader, HeaderFieldValueError("x"),
		"x", "has\nnewline",
	)
}
1106
// testInvalidTrailer runs testInvalidTrailerBubble inside a synctest bubble.
func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) {
	synctestTest(t, func(t testing.TB) {
		testInvalidTrailerBubble(t, mode, wantErr, trailers...)
	})
}

// testInvalidTrailerBubble sends a 200 response that declares a trailer. It
// then sends a final header block, in the given mode, built from the
// name/value pairs in trailers. Reading the body must fail with a StreamError
// whose Cause equals wantErr (compared with !=, so wantErr must be
// comparable), and no body bytes may be returned.
func testInvalidTrailerBubble(t testing.TB, mode headerType, wantErr error, trailers ...string) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"trailer", "declared",
		),
	})
	tc.writeHeadersMode(mode, HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: tc.makeHeaderBlockFragment(trailers...),
	})

	rt.wantStatus(200)
	body, err := rt.readBody()
	se, ok := err.(StreamError)
	if !ok || se.Cause != wantErr {
		t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr)
	}
	if len(body) > 0 {
		t.Fatalf("body = %q; want nothing", body)
	}
}
1145
1146
1147
1148
1149
1150 func headerListSize(h http.Header) (size uint32) {
1151 for k, vv := range h {
1152 for _, v := range vv {
1153 hf := hpack.HeaderField{Name: k, Value: v}
1154 size += hf.Size()
1155 }
1156 }
1157 return size
1158 }
1159
1160
1161
1162
1163
1164
1165
1166
// padHeaders adds fields to h so that headerListSize(h) becomes exactly limit.
// Existing fields in h are kept.
//
// It adds "Pad-Headers-NNNNNN" fields with value filler while another whole
// one still fits. It then adds a final "Pad-Headers" field whose value is a
// run of '*' sized to consume the remainder. It fails the test if limit
// exceeds 2^32-1, or if limit is smaller than h's current size plus an empty
// "Pad-Headers" field.
func padHeaders(t testing.TB, h http.Header, limit uint64, filler string) {
	if limit > 0xffffffff {
		t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
	}
	// minPadding is the size of the final "Pad-Headers" field with an
	// empty value, which is always added.
	hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
	minPadding := uint64(hf.Size())
	size := uint64(headerListSize(h))

	minlimit := size + minPadding
	if limit < minlimit {
		t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
	}

	// Every filler name has a six-digit index, so (for indices below
	// 1,000,000) every filler field has the same size as this sample.
	nameFmt := "Pad-Headers-%06d"
	hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
	fieldSize := uint64(hf.Size())

	// Reserve room for the final "Pad-Headers" field, then add filler
	// fields while another one still fits in the remaining budget.
	limit = limit - minPadding
	for i := 0; size+fieldSize < limit; i++ {
		name := fmt.Sprintf(nameFmt, i)
		h.Add(name, filler)
		size += fieldSize
	}

	// The final field's value absorbs whatever space is left, bringing the
	// total to exactly the original limit.
	remain := limit - size
	lastValue := strings.Repeat("*", int(remain))
	h.Add("Pad-Headers", lastValue)
}
1200
1201 func TestPadHeaders(t *testing.T) {
1202 check := func(h http.Header, limit uint32, fillerLen int) {
1203 if h == nil {
1204 h = make(http.Header)
1205 }
1206 filler := strings.Repeat("f", fillerLen)
1207 padHeaders(t, h, uint64(limit), filler)
1208 gotSize := headerListSize(h)
1209 if gotSize != limit {
1210 t.Errorf("Got size = %v; want %v", gotSize, limit)
1211 }
1212 }
1213
1214 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1215 minLimit := hf.Size()
1216 for limit := minLimit; limit <= 128; limit++ {
1217 for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
1218 check(nil, limit, fillerLen)
1219 }
1220 }
1221
1222
1223
1224
1225
1226
1227 tests := []struct {
1228 fillerLen int
1229 limit uint32
1230 }{
1231 {
1232 fillerLen: 64,
1233 limit: 1024,
1234 },
1235 {
1236 fillerLen: 1024,
1237 limit: 1286,
1238 },
1239 {
1240 fillerLen: 256,
1241 limit: 2048,
1242 },
1243 {
1244 fillerLen: 1024,
1245 limit: 10 * 1024,
1246 },
1247 {
1248 fillerLen: 1023,
1249 limit: 11 * 1024,
1250 },
1251 }
1252 h := make(http.Header)
1253 for _, tc := range tests {
1254 check(nil, tc.limit, tc.fillerLen)
1255 check(h, tc.limit, tc.fillerLen)
1256 }
1257 }
1258
// TestTransportChecksRequestHeaderListSize checks that the transport enforces
// the peer's SETTINGS_MAX_HEADER_LIST_SIZE on request headers and trailers.
func TestTransportChecksRequestHeaderListSize(t *testing.T) {
	synctestTest(t, testTransportChecksRequestHeaderListSize)
}
func testTransportChecksRequestHeaderListSize(t testing.TB) {
	// Header list size limit advertised by the fake server.
	const peerSize = 16 << 10

	tc := newTestClientConn(t)
	tc.greet(Setting{SettingMaxHeaderListSize, peerSize})

	// checkRoundTrip issues req. If wantErr is non-nil, RoundTrip must fail
	// with an error matching wantErr (errors.Is). Otherwise the client must
	// send a HEADERS frame, which the fake server answers with a 200 that
	// ends the stream. desc labels failures.
	checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
		t.Helper()
		rt := tc.roundTrip(req)
		if wantErr != nil {
			if err := rt.err(); !errors.Is(err, wantErr) {
				t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
			}
			return
		}

		tc.wantFrameType(FrameHeaders)
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})

		rt.wantStatus(http.StatusOK)
	}
	// headerListSizeForRequest sums hpack.HeaderField.Size over every field
	// httpcommon.EncodeHeaders emits for req, with the gzip header enabled
	// and effectively no peer limit. This can include fields that are not in
	// req.Header, so it is used below to leave room for them.
	headerListSizeForRequest := func(req *http.Request) (size uint64) {
		_, err := httpcommon.EncodeHeaders(context.Background(), httpcommon.EncodeHeadersParam{
			Request: httpcommon.Request{
				Header:              req.Header,
				Trailer:             req.Trailer,
				URL:                 req.URL,
				Host:                req.Host,
				Method:              req.Method,
				ActualContentLength: req.ContentLength,
			},
			AddGzipHeader:         true,
			PeerMaxHeaderListSize: 0xffffffffffffffff,
		}, func(name, value string) {
			hf := hpack.HeaderField{Name: name, Value: value}
			size += uint64(hf.Size())
		})
		if err != nil {
			t.Fatal(err)
		}
		return size
	}

	// newRequest returns a POST with a 5-byte body and an explicit
	// ContentLength. User-Agent is set to a nil slice, presumably so that no
	// default User-Agent is added and the size accounting stays predictable
	// (TODO confirm).
	newRequest := func() *http.Request {
		const bodytext = "hello"
		body := strings.NewReader(bodytext)
		req, err := http.NewRequest("POST", "https://example.tld/", body)
		if err != nil {
			t.Fatalf("newRequest: NewRequest: %v", err)
		}
		req.ContentLength = int64(len(bodytext))
		req.Header = http.Header{"User-Agent": nil}
		return req
	}

	// Headers and trailers each within the limit. The trailers are padded to
	// exactly peerSize. The headers are padded to peerSize minus what the
	// encoder already emits for the request at this point.
	req := newRequest()
	req.Trailer = make(http.Header)
	filler := strings.Repeat("*", 1024)
	padHeaders(t, req.Trailer, peerSize, filler)

	defaultBytes := headerListSizeForRequest(req)
	padHeaders(t, req.Header, peerSize-defaultBytes, filler)
	checkRoundTrip(req, nil, "Headers & Trailers under limit")

	// Headers alone padded to peerSize; the encoder's own fields push the
	// total over the limit.
	req = newRequest()
	padHeaders(t, req.Header, peerSize, filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Headers over limit")

	// Trailers padded one byte past the limit.
	req = newRequest()
	req.Trailer = make(http.Header)
	padHeaders(t, req.Trailer, peerSize+1, filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Trailers over limit")

	// A single header whose value alone is peerSize bytes.
	req = newRequest()
	filler = strings.Repeat("*", int(peerSize))
	req.Header.Set("Big", filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Single large header")

	// A single trailer with the same peerSize-byte value.
	req = newRequest()
	req.Trailer = make(http.Header)
	req.Trailer.Set("Big", filler)
	checkRoundTrip(req, ErrRequestHeaderListSize, "Single large trailer")
}
1361
// TestTransportChecksResponseHeaderListSize checks that the transport rejects
// a response whose decoded header list is too large, even though the encoded
// header block is small.
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
	synctestTest(t, testTransportChecksResponseHeaderListSize)
}
func testTransportChecksResponseHeaderListSize(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// 5042 copies of one 1 KiB name/value pair. Each copy counts
	// 1024+1024+32 bytes toward the header list size, just over 10 MB total.
	hdr := []string{":status", "200"}
	large := strings.Repeat("a", 1<<10)
	for i := 0; i < 5042; i++ {
		hdr = append(hdr, large, large)
	}
	hbf := tc.makeHeaderBlockFragment(hdr...)

	// Because the pair repeats, the encoded block stays tiny. The exact size
	// is pinned so the test notices if the encoding changes.
	if size, want := len(hbf), 6329; size != want {
		t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
	}
	tc.writeHeaders(HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: hbf,
	})

	// RoundTrip must fail with ErrResponseHeaderListSize, possibly wrapped in
	// a StreamError. On failure, report the approximate size of any response
	// headers that were delivered.
	res, err := rt.result()
	if e, ok := err.(StreamError); ok {
		err = e.Cause
	}
	if err != ErrResponseHeaderListSize {
		size := int64(0)
		if res != nil {
			res.Body.Close()
			for k, vv := range res.Header {
				for _, v := range vv {
					size += int64(len(k)) + int64(len(v)) + 32
				}
			}
		}
		t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
	}
}
1410
// TestTransportCookieHeaderSplit checks that Cookie header values are sent
// as separate cookie fields, one per cookie-pair.
func TestTransportCookieHeaderSplit(t *testing.T) { synctestTest(t, testTransportCookieHeaderSplit) }
func testTransportCookieHeaderSplit(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	// The values mix ";" and "; " separators and include trailing separators.
	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	req.Header.Add("Cookie", "a=b;c=d; e=f;")
	req.Header.Add("Cookie", "e=f;g=h; ")
	req.Header.Add("Cookie", "i=j")
	rt := tc.roundTrip(req)

	// Expect one cookie field per pair, in order, with duplicates kept and
	// empty pieces dropped.
	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: true,
		header: http.Header{
			"cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"},
		},
	})
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "204",
		),
	})

	if err := rt.err(); err != nil {
		t.Fatalf("RoundTrip = %v, want success", err)
	}
}
1442
1443
1444
1445
1446 func TestTransportBodyReadErrorType(t *testing.T) {
1447 doPanic := make(chan bool, 1)
1448 ts := newTestServer(t,
1449 func(w http.ResponseWriter, r *http.Request) {
1450 w.(http.Flusher).Flush()
1451 <-doPanic
1452 panic("boom")
1453 },
1454 optQuiet,
1455 )
1456
1457 tr := newTransport(t)
1458 c := &http.Client{Transport: tr}
1459
1460 res, err := c.Get(ts.URL)
1461 if err != nil {
1462 t.Fatal(err)
1463 }
1464 defer res.Body.Close()
1465 doPanic <- true
1466 buf := make([]byte, 100)
1467 n, err := res.Body.Read(buf)
1468 got, ok := err.(StreamError)
1469 want := StreamError{StreamID: 0x1, Code: 0x2}
1470 if !ok || got.StreamID != want.StreamID || got.Code != want.Code {
1471 t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
1472 }
1473 }
1474
1475
1476
1477
1478 func TestTransportDoubleCloseOnWriteError(t *testing.T) {
1479 var (
1480 mu sync.Mutex
1481 conn net.Conn
1482 )
1483
1484 ts := newTestServer(t,
1485 func(w http.ResponseWriter, r *http.Request) {
1486 mu.Lock()
1487 defer mu.Unlock()
1488 if conn != nil {
1489 conn.Close()
1490 }
1491 },
1492 )
1493
1494 tr := newTransport(t)
1495 tr.DialTLS = func(network, addr string) (net.Conn, error) {
1496 tc, err := tls.Dial(network, addr, tlsConfigInsecure)
1497 if err != nil {
1498 return nil, err
1499 }
1500 mu.Lock()
1501 defer mu.Unlock()
1502 conn = tc
1503 return tc, nil
1504 }
1505 c := &http.Client{Transport: tr}
1506 c.Get(ts.URL)
1507 }
1508
1509
1510
1511
1512 func TestTransportDisableKeepAlives(t *testing.T) {
1513 ts := newTestServer(t,
1514 func(w http.ResponseWriter, r *http.Request) {
1515 io.WriteString(w, "hi")
1516 },
1517 )
1518
1519 connClosed := make(chan struct{})
1520 tr := newTransport(t)
1521 tr.Dial = func(network, addr string) (net.Conn, error) {
1522 tc, err := net.Dial(network, addr)
1523 if err != nil {
1524 return nil, err
1525 }
1526 return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
1527 }
1528 tr.DisableKeepAlives = true
1529 c := &http.Client{Transport: tr}
1530 res, err := c.Get(ts.URL)
1531 if err != nil {
1532 t.Fatal(err)
1533 }
1534 if _, err := io.ReadAll(res.Body); err != nil {
1535 t.Fatal(err)
1536 }
1537 defer res.Body.Close()
1538
1539 select {
1540 case <-connClosed:
1541 case <-time.After(1 * time.Second):
1542 t.Errorf("timeout")
1543 }
1544
1545 }
1546
1547
1548
1549 func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
1550 const D = 25 * time.Millisecond
1551 ts := newTestServer(t,
1552 func(w http.ResponseWriter, r *http.Request) {
1553 time.Sleep(D)
1554 io.WriteString(w, "hi")
1555 },
1556 )
1557
1558 var dials int32
1559 var conns sync.WaitGroup
1560 tr := newTransport(t)
1561 tr.Dial = func(network, addr string) (net.Conn, error) {
1562 tc, err := net.Dial(network, addr)
1563 if err != nil {
1564 return nil, err
1565 }
1566 atomic.AddInt32(&dials, 1)
1567 conns.Add(1)
1568 return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
1569 }
1570 tr.DisableKeepAlives = true
1571 c := &http.Client{Transport: tr}
1572 var reqs sync.WaitGroup
1573 const N = 20
1574 for i := 0; i < N; i++ {
1575 reqs.Add(1)
1576 if i == N-1 {
1577
1578
1579
1580
1581
1582
1583 time.Sleep(D * 2)
1584 }
1585 go func() {
1586 defer reqs.Done()
1587 res, err := c.Get(ts.URL)
1588 if err != nil {
1589 t.Error(err)
1590 return
1591 }
1592 if _, err := io.ReadAll(res.Body); err != nil {
1593 t.Error(err)
1594 return
1595 }
1596 res.Body.Close()
1597 }()
1598 }
1599 reqs.Wait()
1600 conns.Wait()
1601 t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
1602 }
1603
// noteCloseConn wraps a net.Conn and calls closefn the first time Close is
// called. Every Close call, including later ones, is forwarded to the
// embedded Conn.
type noteCloseConn struct {
	net.Conn
	onceClose sync.Once // ensures closefn runs at most once
	closefn   func()    // invoked on the first Close, before the underlying Close
}

// Close runs closefn on the first call only, then closes the wrapped Conn and
// returns its error.
func (c *noteCloseConn) Close() error {
	c.onceClose.Do(c.closefn)
	return c.Conn.Close()
}
1614
// isTimeout reports whether err is a timeout. Any *url.Error is unwrapped,
// repeatedly if nested, to its Err field. The remaining error counts as a
// timeout only if it implements net.Error and its Timeout method returns
// true. Other wrappers (such as fmt.Errorf with %w) are not looked through,
// and a nil error is not a timeout.
func isTimeout(err error) bool {
	for {
		ue, ok := err.(*url.Error)
		if !ok {
			break
		}
		err = ue.Err
	}
	ne, ok := err.(net.Error)
	return ok && ne.Timeout()
}
1626
1627
// TestTransportResponseHeaderTimeout_NoBody checks ResponseHeaderTimeout for a
// GET with no request body.
func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportResponseHeaderTimeout(t, false)
	})
}

// TestTransportResponseHeaderTimeout_Body checks ResponseHeaderTimeout for a
// POST with a 4 MiB request body.
func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportResponseHeaderTimeout(t, true)
	})
}

// testTransportResponseHeaderTimeout sets a 5ms ResponseHeaderTimeout and
// sends a request (a 4 MiB POST if body is true, otherwise a GET). The fake
// server never sends response headers. RoundTrip must still be pending 4ms
// after the request (and any body) has been sent, and must fail with a
// timeout error 1ms later. Time presumably is the synctest bubble's fake
// clock, since this runs under synctestTest.
func testTransportResponseHeaderTimeout(t testing.TB, body bool) {
	const bodySize = 4 << 20
	tc := newTestClientConn(t, func(t1 *http.Transport) {
		t1.ResponseHeaderTimeout = 5 * time.Millisecond
	})
	tc.greet()

	var req *http.Request
	var reqBody *testRequestBody
	if body {
		reqBody = tc.newRequestBody()
		reqBody.writeBytes(bodySize)
		reqBody.closeWithError(io.EOF)
		req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody)
		req.Header.Set("Content-Type", "text/foo")
	} else {
		req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
	}

	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Grant enough connection and stream flow control for the whole body.
	tc.writeWindowUpdate(0, bodySize)
	tc.writeWindowUpdate(rt.streamID(), bodySize)

	if body {
		tc.wantData(wantData{
			endStream: true,
			size:      bodySize,
			multiple:  true,
		})
	}

	// Just short of the timeout, RoundTrip must still be waiting.
	time.Sleep(4 * time.Millisecond)
	if rt.done() {
		t.Fatalf("RoundTrip is done after 4ms; want still waiting")
	}
	// Reaching 5ms total must trigger the timeout.
	time.Sleep(1 * time.Millisecond)

	if err := rt.err(); !isTimeout(err) {
		t.Fatalf("RoundTrip error: %v; want timeout error", err)
	}
}
1683
1684
// TestTransportWindowUpdateBeyondLimit checks that a WINDOW_UPDATE pushing a
// flow-control window past 2^31-1 is a flow-control error (RFC 9113 §6.9.1).
// A stream-level overflow resets the stream with FLOW_CONTROL_ERROR, and a
// connection-level overflow closes the connection.
func TestTransportWindowUpdateBeyondLimit(t *testing.T) {
	synctestTest(t, testTransportWindowUpdateBeyondLimit)
}
func testTransportWindowUpdateBeyondLimit(t testing.TB) {
	// Adding the maximum window size to any non-empty window overflows it.
	const windowIncrease uint32 = (1 << 31) - 1
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: true,
	})

	// Stream-level overflow: the client resets just this stream.
	tc.writeWindowUpdate(rt.streamID(), windowIncrease)
	tc.wantRSTStream(rt.streamID(), ErrCodeFlowControl)

	// Connection-level overflow: the client closes the connection.
	tc.writeWindowUpdate(0, windowIncrease)
	tc.wantClosed()
}
1706
1707 func TestTransportDisableCompression(t *testing.T) {
1708 const body = "sup"
1709 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1710 want := http.Header{
1711 "User-Agent": []string{"Go-http-client/2.0"},
1712 }
1713 if !reflect.DeepEqual(r.Header, want) {
1714 t.Errorf("request headers = %v; want %v", r.Header, want)
1715 }
1716 })
1717
1718 tr := newTransport(t)
1719 tr.DisableCompression = true
1720
1721 req, err := http.NewRequest("GET", ts.URL, nil)
1722 if err != nil {
1723 t.Fatal(err)
1724 }
1725 res, err := tr.RoundTrip(req)
1726 if err != nil {
1727 t.Fatal(err)
1728 }
1729 defer res.Body.Close()
1730 }
1731
1732
1733 func TestTransportRejectsConnHeaders(t *testing.T) {
1734 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1735 var got []string
1736 for k := range r.Header {
1737 got = append(got, k)
1738 }
1739 sort.Strings(got)
1740 w.Header().Set("Got-Header", strings.Join(got, ","))
1741 })
1742
1743 tr := newTransport(t)
1744
1745 tests := []struct {
1746 key string
1747 value []string
1748 want string
1749 }{
1750 {
1751 key: "Upgrade",
1752 value: []string{"anything"},
1753 want: "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
1754 },
1755 {
1756 key: "Connection",
1757 value: []string{"foo"},
1758 want: "ERROR: http2: invalid Connection request header: [\"foo\"]",
1759 },
1760 {
1761 key: "Connection",
1762 value: []string{"close"},
1763 want: "Accept-Encoding,User-Agent",
1764 },
1765 {
1766 key: "Connection",
1767 value: []string{"CLoSe"},
1768 want: "Accept-Encoding,User-Agent",
1769 },
1770 {
1771 key: "Connection",
1772 value: []string{"close", "something-else"},
1773 want: "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
1774 },
1775 {
1776 key: "Connection",
1777 value: []string{"keep-alive"},
1778 want: "Accept-Encoding,User-Agent",
1779 },
1780 {
1781 key: "Connection",
1782 value: []string{"Keep-ALIVE"},
1783 want: "Accept-Encoding,User-Agent",
1784 },
1785 {
1786 key: "Proxy-Connection",
1787 value: []string{"keep-alive"},
1788 want: "Accept-Encoding,User-Agent",
1789 },
1790 {
1791 key: "Transfer-Encoding",
1792 value: []string{""},
1793 want: "Accept-Encoding,User-Agent",
1794 },
1795 {
1796 key: "Transfer-Encoding",
1797 value: []string{"foo"},
1798 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
1799 },
1800 {
1801 key: "Transfer-Encoding",
1802 value: []string{"chunked"},
1803 want: "Accept-Encoding,User-Agent",
1804 },
1805 {
1806 key: "Transfer-Encoding",
1807 value: []string{"chunKed"},
1808 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]",
1809 },
1810 {
1811 key: "Transfer-Encoding",
1812 value: []string{"chunked", "other"},
1813 want: "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
1814 },
1815 {
1816 key: "Content-Length",
1817 value: []string{"123"},
1818 want: "Accept-Encoding,User-Agent",
1819 },
1820 {
1821 key: "Keep-Alive",
1822 value: []string{"doop"},
1823 want: "Accept-Encoding,User-Agent",
1824 },
1825 }
1826
1827 for _, tt := range tests {
1828 req, _ := http.NewRequest("GET", ts.URL, nil)
1829 req.Header[tt.key] = tt.value
1830 res, err := tr.RoundTrip(req)
1831 var got string
1832 if err != nil {
1833 got = fmt.Sprintf("ERROR: %v", err)
1834 } else {
1835 got = res.Header.Get("Got-Header")
1836 res.Body.Close()
1837 }
1838 if got != tt.want {
1839 t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
1840 }
1841 }
1842 }
1843
1844
1845
1846 func TestTransportRejectsContentLengthWithSign(t *testing.T) {
1847 tests := []struct {
1848 name string
1849 cl []string
1850 wantCL string
1851 }{
1852 {
1853 name: "proper content-length",
1854 cl: []string{"3"},
1855 wantCL: "3",
1856 },
1857 {
1858 name: "ignore cl with plus sign",
1859 cl: []string{"+3"},
1860 wantCL: "",
1861 },
1862 {
1863 name: "ignore cl with minus sign",
1864 cl: []string{"-3"},
1865 wantCL: "",
1866 },
1867 {
1868 name: "max int64, for safe uint64->int64 conversion",
1869 cl: []string{"9223372036854775807"},
1870 wantCL: "9223372036854775807",
1871 },
1872 {
1873 name: "overflows int64, so ignored",
1874 cl: []string{"9223372036854775808"},
1875 wantCL: "",
1876 },
1877 }
1878
1879 for _, tt := range tests {
1880 tt := tt
1881 t.Run(tt.name, func(t *testing.T) {
1882 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1883 w.Header().Set("Content-Length", tt.cl[0])
1884 })
1885 tr := newTransport(t)
1886
1887 req, _ := http.NewRequest("HEAD", ts.URL, nil)
1888 res, err := tr.RoundTrip(req)
1889
1890 var got string
1891 if err != nil {
1892 got = fmt.Sprintf("ERROR: %v", err)
1893 } else {
1894 got = res.Header.Get("Content-Length")
1895 res.Body.Close()
1896 }
1897
1898 if got != tt.wantCL {
1899 t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL)
1900 }
1901 })
1902 }
1903 }
1904
1905
1906
1907 func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) {
1908 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1909 var got []string
1910 for k := range r.Header {
1911 got = append(got, k)
1912 }
1913 sort.Strings(got)
1914 w.Header().Set("Got-Header", strings.Join(got, ","))
1915 })
1916
1917 tests := [...]struct {
1918 h http.Header
1919 t http.Header
1920 wantErr string
1921 }{
1922 0: {
1923 h: http.Header{"with space": {"foo"}},
1924 wantErr: `net/http: invalid header field name "with space"`,
1925 },
1926 1: {
1927 h: http.Header{"name": {"Брэд"}},
1928 wantErr: "",
1929 },
1930 2: {
1931 h: http.Header{"имя": {"Brad"}},
1932 wantErr: `net/http: invalid header field name "имя"`,
1933 },
1934 3: {
1935 h: http.Header{"foo": {"foo\x01bar"}},
1936 wantErr: `net/http: invalid header field value for "foo"`,
1937 },
1938 4: {
1939 t: http.Header{"foo": {"foo\x01bar"}},
1940 wantErr: `net/http: invalid trailer field value for "foo"`,
1941 },
1942 5: {
1943 t: http.Header{"x-\r\nda": {"foo\x01bar"}},
1944 wantErr: `net/http: invalid trailer field name "x-\r\nda"`,
1945 },
1946 }
1947
1948 tr := newTransport(t)
1949
1950 for i, tt := range tests {
1951 req, _ := http.NewRequest("GET", ts.URL, nil)
1952 req.Header = tt.h
1953 if req.Header == nil {
1954 req.Header = http.Header{}
1955 }
1956 req.Trailer = tt.t
1957 res, err := tr.RoundTrip(req)
1958 var bad bool
1959 if tt.wantErr == "" {
1960 if err != nil {
1961 bad = true
1962 t.Errorf("case %d: error = %v; want no error", i, err)
1963 }
1964 } else {
1965 if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
1966 bad = true
1967 t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
1968 }
1969 }
1970 if err == nil {
1971 if bad {
1972 t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
1973 }
1974 res.Body.Close()
1975 }
1976 }
1977 }
1978
1979
1980
1981
// TestTransportReadHeadResponse checks a HEAD response that advertises a
// Content-Length and then ends the stream with an empty DATA frame.
func TestTransportReadHeadResponse(t *testing.T) { synctestTest(t, testTransportReadHeadResponse) }
func testTransportReadHeadResponse(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "123",
		),
	})
	// End the stream with an empty DATA frame.
	tc.writeData(rt.streamID(), true, nil)

	// ContentLength reflects the header, and the body is empty.
	res := rt.response()
	if res.ContentLength != 123 {
		t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
	}
	rt.wantBody(nil)
}
2008
// TestTransportReadHeadResponseWithBody checks a malformed HEAD response that
// carries body bytes anyway. ContentLength must still reflect the header, and
// the client must see an empty body.
func TestTransportReadHeadResponseWithBody(t *testing.T) {
	synctestTest(t, testTransportReadHeadResponseWithBody)
}
func testTransportReadHeadResponseWithBody(t testing.TB) {
	// The response below is deliberately invalid. Standard log output is
	// discarded for the duration, presumably to hide what the transport logs
	// about it (TODO confirm).
	log.SetOutput(io.Discard)
	defer log.SetOutput(os.Stderr)

	response := "redirecting to /elsewhere"
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", strconv.Itoa(len(response)),
		),
	})
	// Body bytes on a HEAD response.
	tc.writeData(rt.streamID(), true, []byte(response))

	res := rt.response()
	if res.ContentLength != int64(len(response)) {
		t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response))
	}
	rt.wantBody(nil)
}
2043
// neverEnding is an io.Reader that never reaches EOF. Every Read fills the
// entire buffer with the receiver's byte value.
type neverEnding byte

// Read fills p completely with byte(b) and always returns len(p), nil.
func (b neverEnding) Read(p []byte) (int, error) {
	n := len(p)
	if n == 0 {
		return 0, nil
	}
	// Seed the first byte, then repeatedly double the filled prefix.
	p[0] = byte(b)
	for filled := 1; filled < n; filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return n, nil
}
2052
2053
2054
// TestTransportStreamEndsWhileBodyIsBeingWritten checks that RoundTrip
// returns the server's response when the server ends the stream while the
// client still has request body left to send.
func TestTransportStreamEndsWhileBodyIsBeingWritten(t *testing.T) {
	synctestTest(t, testTransportStreamEndsWhileBodyIsBeingWritten)
}
func testTransportStreamEndsWhileBodyIsBeingWritten(t testing.TB) {
	body := "this is the client request body"
	const windowSize = 10

	// A 10-byte initial window lets the client send only part of the body.
	tc := newTestClientConn(t)
	tc.greet(Setting{SettingInitialWindowSize, windowSize})

	req, _ := http.NewRequest("PUT", "https://dummy.tld/", strings.NewReader(body))
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: false,
		size:      windowSize,
	})

	// With the rest of the body still unsent, the server responds and ends
	// the stream.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "413",
		),
	})
	rt.wantStatus(413)
}
2086
2087 func TestTransportFlowControl(t *testing.T) { synctestTest(t, testTransportFlowControl) }
2088 func testTransportFlowControl(t testing.TB) {
2089 const maxBuffer = 64 << 10
2090 tc := newTestClientConn(t, func(tr *http.Transport) {
2091 tr.HTTP2 = &http.HTTP2Config{
2092 MaxReceiveBufferPerConnection: maxBuffer,
2093 MaxReceiveBufferPerStream: maxBuffer,
2094 MaxReadFrameSize: 16 << 20,
2095 }
2096 })
2097 tc.greet()
2098
2099 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2100 rt := tc.roundTrip(req)
2101 tc.wantFrameType(FrameHeaders)
2102
2103 tc.writeHeaders(HeadersFrameParam{
2104 StreamID: rt.streamID(),
2105 EndHeaders: true,
2106 EndStream: false,
2107 BlockFragment: tc.makeHeaderBlockFragment(
2108 ":status", "200",
2109 ),
2110 })
2111 rt.wantStatus(200)
2112
2113
2114
2115
2116 tc.writeData(rt.streamID(), false, make([]byte, maxBuffer))
2117 tc.wantIdle()
2118
2119
2120
2121 resp := rt.response()
2122 if _, err := io.ReadFull(resp.Body, make([]byte, maxBuffer)); err != nil {
2123 t.Fatalf("io.Body.Read: %v", err)
2124 }
2125 var connTokens, streamTokens uint32
2126 for {
2127 f := tc.readFrame()
2128 if f == nil {
2129 break
2130 }
2131 wu, ok := f.(*WindowUpdateFrame)
2132 if !ok {
2133 t.Fatalf("received unexpected frame %T (want WINDOW_UPDATE)", f)
2134 }
2135 switch wu.StreamID {
2136 case 0:
2137 connTokens += wu.Increment
2138 case wu.StreamID:
2139 streamTokens += wu.Increment
2140 default:
2141 t.Fatalf("received unexpected WINDOW_UPDATE for stream %v", wu.StreamID)
2142 }
2143 }
2144 if got, want := connTokens, uint32(maxBuffer); got != want {
2145 t.Errorf("transport provided %v bytes of connection WINDOW_UPDATE, want %v", got, want)
2146 }
2147 if got, want := streamTokens, uint32(maxBuffer); got != want {
2148 t.Errorf("transport provided %v bytes of stream WINDOW_UPDATE, want %v", got, want)
2149 }
2150 }
2151
2152
2153
2154
2155
2156
// TestTransportUsesGoAwayDebugError_RoundTrip checks the GOAWAY error returned
// by RoundTrip when no response headers have arrived.
func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportUsesGoAwayDebugError(t, false)
	})
}

// TestTransportUsesGoAwayDebugError_Body checks the GOAWAY error returned by
// Body.Read after response headers have arrived.
func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportUsesGoAwayDebugError(t, true)
	})
}

// testTransportUsesGoAwayDebugError has the server send two GOAWAY frames and
// then close its write side. The first GOAWAY has ErrCodeNo and debug data;
// the second has goAwayErrCode and no debug data. The client's error must be
// a GoAwayError carrying the second frame's code and the first frame's debug
// data. If failMidBody is true, response headers (Content-Length 123) are
// sent first, so the error is expected from Body.Read instead of RoundTrip.
func testTransportUsesGoAwayDebugError(t testing.TB, failMidBody bool) {
	tc := newTestClientConn(t)
	tc.greet()

	const goAwayErrCode = ErrCodeHTTP11Required
	const goAwayDebugData = "some debug data"

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	if failMidBody {
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  false,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
				"content-length", "123",
			),
		})
	}

	// The debug data is sent only in the first GOAWAY, but the client must
	// still report it with the second GOAWAY's error code.
	tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
	tc.writeGoAway(5, goAwayErrCode, nil)
	tc.closeWrite()

	res, err := rt.result()
	whence := "RoundTrip"
	if failMidBody {
		whence = "Body.Read"
		if err != nil {
			t.Fatalf("RoundTrip error = %v, want success", err)
		}
		_, err = res.Body.Read(make([]byte, 1))
	}

	want := GoAwayError{
		LastStreamID: 5,
		ErrCode:      goAwayErrCode,
		DebugData:    goAwayDebugData,
	}
	if !reflect.DeepEqual(err, want) {
		t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want)
	}
}
2218
// testTransportReturnsUnusedFlowControl checks what happens when the client
// closes a response body before reading all of it. The transport must reset
// the stream with CANCEL and send a 5000-byte WINDOW_UPDATE, so that the
// connection inflow window ends where it started.
//
// If oneDataFrame is true, all 5000 body bytes arrive in one DATA frame
// before the close. Otherwise only 1 byte arrives first, and the remaining
// 4999 bytes are sent after the client's RST_STREAM.
func testTransportReturnsUnusedFlowControl(t testing.TB, oneDataFrame bool) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "5000",
		),
	})
	initialInflow := tc.inflowWindow(0)

	// Send some or all of the body without ending the stream, so closing
	// the body early leaves the stream open from the server's side.
	const streamNotEnded = false
	if oneDataFrame {
		tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000))
	} else {
		tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1))
	}

	// Read one byte, then close the body with the rest unread.
	res := rt.response()
	if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
		t.Fatalf("body read = %v, %v; want 1, nil", n, err)
	}
	res.Body.Close()
	synctest.Wait()

	// Expect an RST_STREAM with CANCEL and a WINDOW_UPDATE of 5000
	// (presumably in either order, per wantUnorderedFrames). In the
	// multi-write case the rest of the body is sent once the reset is seen,
	// and the WINDOW_UPDATE must not arrive before that.
	sentAdditionalData := false
	tc.wantUnorderedFrames(
		func(f *RSTStreamFrame) bool {
			if f.ErrCode != ErrCodeCancel {
				t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", SummarizeFrame(f))
			}
			if !oneDataFrame {
				// Send the remaining 4999 bytes after the reset.
				tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999))
				sentAdditionalData = true
			}
			return true
		},
		func(f *WindowUpdateFrame) bool {
			if !oneDataFrame && !sentAdditionalData {
				t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
			}
			if f.Increment != 5000 {
				t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", SummarizeFrame(f))
			}
			return true
		},
	)

	// All connection-level flow control consumed by this stream is back.
	if got, want := tc.inflowWindow(0), initialInflow; got != want {
		t.Fatalf("connection flow tokens = %v, want %v", got, want)
	}
}
2290
2291
2292 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
2293 synctestTest(t, func(t testing.TB) {
2294 testTransportReturnsUnusedFlowControl(t, true)
2295 })
2296 }
2297
2298
2299 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
2300 synctestTest(t, func(t testing.TB) {
2301 testTransportReturnsUnusedFlowControl(t, false)
2302 })
2303 }
2304
2305
2306
// TestTransportAdjustsFlowControl checks that a request body stalled on the
// initial flow-control window resumes once the server raises
// SETTINGS_INITIAL_WINDOW_SIZE and the connection window, and that the whole
// body is then delivered.
func TestTransportAdjustsFlowControl(t *testing.T) { synctestTest(t, testTransportAdjustsFlowControl) }
func testTransportAdjustsFlowControl(t testing.TB) {
	const bodySize = 1 << 20

	// Consume the client's initial SETTINGS and WINDOW_UPDATE by hand instead
	// of calling greet. Presumably this keeps the server's own SETTINGS
	// unsent until mid-body (below).
	tc := newTestClientConn(t)
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)

	// A 1 MiB body, far larger than the 65,535-byte default initial window.
	body := tc.newRequestBody()
	body.writeBytes(bodySize)
	body.closeWithError(io.EOF)

	req, _ := http.NewRequest("POST", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Read DATA until at least half of the initial window has arrived.
	gotBytes := int64(0)
	for {
		f := readFrame[*DataFrame](t, tc)
		gotBytes += int64(len(f.Data()))

		if gotBytes >= InitialWindowSize/2 {
			break
		}
	}

	// Raise the stream window to the full body size, add the same amount to
	// the connection window, and acknowledge the client's SETTINGS.
	tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
	tc.writeWindowUpdate(0, bodySize)
	tc.writeSettingsAck()

	// Expect a SETTINGS frame (presumably the client's ACK) and the rest of
	// the body, which must finish with END_STREAM.
	tc.wantUnorderedFrames(
		func(f *SettingsFrame) bool { return true },
		func(f *DataFrame) bool {
			gotBytes += int64(len(f.Data()))
			return f.StreamEnded()
		},
	)

	if gotBytes != bodySize {
		t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize)
	}

	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt.wantStatus(200)
}
2362
2363
// TestTransportReturnsDataPaddingFlowControl checks that flow control
// consumed by DATA frame padding is returned immediately. Padding counts
// against flow control (RFC 9113 §6.1), so after a padded 5000-byte DATA
// frame, both inflow windows can be down by exactly 5000 only if the
// padding's share has been credited back.
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
	synctestTest(t, testTransportReturnsDataPaddingFlowControl)
}
func testTransportReturnsDataPaddingFlowControl(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "5000",
		),
	})

	initialConnWindow := tc.inflowWindow(0)
	initialStreamWindow := tc.inflowWindow(rt.streamID())

	pad := make([]byte, 5)
	tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad)

	// Let the client process the frame. The unread 5000 data bytes remain
	// charged against both windows; the padding must not be.
	synctest.Wait()
	if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want {
		t.Errorf("conn inflow window = %v, want %v", got, want)
	}
	if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want {
		t.Errorf("stream inflow window = %v, want %v", got, want)
	}
}
2400
2401
2402
2403 func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
2404 synctestTest(t, testTransportReturnsErrorOnBadResponseHeaders)
2405 }
2406 func testTransportReturnsErrorOnBadResponseHeaders(t testing.TB) {
2407 tc := newTestClientConn(t)
2408 tc.greet()
2409
2410 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2411 rt := tc.roundTrip(req)
2412
2413 tc.wantFrameType(FrameHeaders)
2414 tc.writeHeaders(HeadersFrameParam{
2415 StreamID: rt.streamID(),
2416 EndHeaders: true,
2417 EndStream: false,
2418 BlockFragment: tc.makeHeaderBlockFragment(
2419 ":status", "200",
2420 " content-type", "bogus",
2421 ),
2422 })
2423
2424 err := rt.err()
2425 want := StreamError{1, ErrCodeProtocol, HeaderFieldNameError(" content-type")}
2426 if !reflect.DeepEqual(err, want) {
2427 t.Fatalf("RoundTrip error = %#v; want %#v", err, want)
2428 }
2429
2430 fr := readFrame[*RSTStreamFrame](t, tc)
2431 if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol {
2432 t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", SummarizeFrame(fr))
2433 }
2434 }
2435
2436
2437
// byteAndEOFReader is a stateless io.Reader: every Read returns the reader's
// own byte value together with io.EOF.
type byteAndEOFReader byte

// Read writes b into p[0] and returns (1, io.EOF). A zero-length p is treated
// as a caller bug and panics.
func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) != 0 {
		p[0] = byte(b)
		return 1, io.EOF
	}
	panic("unexpected useless call")
}
2447
2448
2449
2450
2451
2452
2453
2454
2455
2456
2457 func TestTransportBodyDoubleEndStream(t *testing.T) {
2458 synctestTest(t, testTransportBodyDoubleEndStream)
2459 }
2460 func testTransportBodyDoubleEndStream(t testing.TB) {
2461 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2462
2463 })
2464
2465 tr := newTransport(t)
2466
2467 for i := 0; i < 2; i++ {
2468 req, _ := http.NewRequest("POST", ts.URL, byteAndEOFReader('a'))
2469 req.ContentLength = 1
2470 res, err := tr.RoundTrip(req)
2471 if err != nil {
2472 t.Fatalf("failure on req %d: %v", i+1, err)
2473 }
2474 defer res.Body.Close()
2475 }
2476 }
2477
2478
2479 func TestTransportRequestPathPseudo(t *testing.T) {
2480 type result struct {
2481 path string
2482 err string
2483 }
2484 tests := []struct {
2485 req *http.Request
2486 want result
2487 }{
2488 0: {
2489 req: &http.Request{
2490 Method: "GET",
2491 URL: &url.URL{
2492 Host: "foo.com",
2493 Path: "/foo",
2494 },
2495 },
2496 want: result{path: "/foo"},
2497 },
2498
2499
2500
2501 1: {
2502 req: &http.Request{
2503 Method: "GET",
2504 URL: &url.URL{
2505 Host: "foo.com",
2506 Path: "//foo",
2507 },
2508 },
2509 want: result{path: "//foo"},
2510 },
2511
2512
2513 2: {
2514 req: &http.Request{
2515 Method: "GET",
2516 URL: &url.URL{
2517 Scheme: "https",
2518 Opaque: "//foo.com/path",
2519 Host: "foo.com",
2520 Path: "/ignored",
2521 },
2522 },
2523 want: result{path: "/path"},
2524 },
2525
2526
2527 3: {
2528 req: &http.Request{
2529 Method: "GET",
2530 Host: "bar.com",
2531 URL: &url.URL{
2532 Scheme: "https",
2533 Opaque: "//bar.com/path",
2534 Host: "foo.com",
2535 Path: "/ignored",
2536 },
2537 },
2538 want: result{path: "/path"},
2539 },
2540
2541
2542 4: {
2543 req: &http.Request{
2544 Method: "GET",
2545 URL: &url.URL{
2546 Opaque: "/path",
2547 Host: "foo.com",
2548 Path: "/ignored",
2549 },
2550 },
2551 want: result{path: "/path"},
2552 },
2553
2554
2555 5: {
2556 req: &http.Request{
2557 Method: "GET",
2558 URL: &url.URL{
2559 Scheme: "https",
2560 Opaque: "//unknown_host/path",
2561 Host: "foo.com",
2562 Path: "/ignored",
2563 },
2564 },
2565 want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
2566 },
2567
2568
2569 6: {
2570 req: &http.Request{
2571 Method: "CONNECT",
2572 URL: &url.URL{
2573 Host: "foo.com",
2574 },
2575 },
2576 want: result{},
2577 },
2578 }
2579 for i, tt := range tests {
2580 hbuf := &bytes.Buffer{}
2581 henc := hpack.NewEncoder(hbuf)
2582 _, err := httpcommon.EncodeHeaders(context.Background(), httpcommon.EncodeHeadersParam{
2583 Request: httpcommon.Request{
2584 Header: tt.req.Header,
2585 Trailer: tt.req.Trailer,
2586 URL: tt.req.URL,
2587 Host: tt.req.Host,
2588 Method: tt.req.Method,
2589 ActualContentLength: tt.req.ContentLength,
2590 },
2591 AddGzipHeader: false,
2592 PeerMaxHeaderListSize: 0xffffffffffffffff,
2593 }, func(name, value string) {
2594 henc.WriteField(hpack.HeaderField{Name: name, Value: value})
2595 })
2596 hdrs := hbuf.Bytes()
2597 var got result
2598 hpackDec := hpack.NewDecoder(InitialHeaderTableSize, func(f hpack.HeaderField) {
2599 if f.Name == ":path" {
2600 got.path = f.Value
2601 }
2602 })
2603 if err != nil {
2604 got.err = err.Error()
2605 } else if len(hdrs) > 0 {
2606 if _, err := hpackDec.Write(hdrs); err != nil {
2607 t.Errorf("%d. bogus hpack: %v", i, err)
2608 continue
2609 }
2610 }
2611 if got != tt.want {
2612 t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
2613 }
2614
2615 }
2616
2617 }
2618
2619
2620
2621 func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
2622 synctestTest(t, testRoundTripDoesntConsumeRequestBodyEarly)
2623 }
2624 func testRoundTripDoesntConsumeRequestBodyEarly(t testing.TB) {
2625 tc := newTestClientConn(t)
2626 tc.greet()
2627 tc.closeWrite()
2628 synctest.Wait()
2629
2630 const body = "foo"
2631 req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body)))
2632 rt := tc.roundTrip(req)
2633 if err := rt.err(); err != ErrClientConnNotEstablished {
2634 t.Fatalf("RoundTrip = %v; want errClientConnNotEstablished", err)
2635 }
2636
2637 slurp, err := io.ReadAll(req.Body)
2638 if err != nil {
2639 t.Errorf("ReadAll = %v", err)
2640 }
2641 if string(slurp) != body {
2642 t.Errorf("Body = %q; want %q", slurp, body)
2643 }
2644 }
2645
2646
2647
2648
2649
2650 func TestTransportCancelDataResponseRace(t *testing.T) {
2651 cancel := make(chan struct{})
2652 clientGotResponse := make(chan bool, 1)
2653
2654 const msg = "Hello."
2655 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2656 if strings.Contains(r.URL.Path, "/hello") {
2657 time.Sleep(50 * time.Millisecond)
2658 io.WriteString(w, msg)
2659 return
2660 }
2661 for i := 0; i < 50; i++ {
2662 io.WriteString(w, "Some data.")
2663 w.(http.Flusher).Flush()
2664 if i == 2 {
2665 <-clientGotResponse
2666 close(cancel)
2667 }
2668 time.Sleep(10 * time.Millisecond)
2669 }
2670 })
2671
2672 tr := newTransport(t)
2673
2674 c := &http.Client{Transport: tr}
2675 req, _ := http.NewRequest("GET", ts.URL, nil)
2676 req.Cancel = cancel
2677 res, err := c.Do(req)
2678 clientGotResponse <- true
2679 if err != nil {
2680 t.Fatal(err)
2681 }
2682 if _, err = io.Copy(io.Discard, res.Body); err == nil {
2683 t.Fatal("unexpected success")
2684 }
2685
2686 res, err = c.Get(ts.URL + "/hello")
2687 if err != nil {
2688 t.Fatal(err)
2689 }
2690 slurp, err := io.ReadAll(res.Body)
2691 if err != nil {
2692 t.Fatal(err)
2693 }
2694 if string(slurp) != msg {
2695 t.Errorf("Got = %q; want %q", slurp, msg)
2696 }
2697 }
2698
2699
2700
2701 func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
2702 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2703 w.WriteHeader(200)
2704 io.WriteString(w, "body")
2705 })
2706
2707 tr := newTransport(t)
2708
2709 req, _ := http.NewRequest("GET", ts.URL, nil)
2710 resp, err := tr.RoundTrip(req)
2711 if err != nil {
2712 t.Fatal(err)
2713 }
2714 if _, err = io.Copy(io.Discard, resp.Body); err != nil {
2715 t.Fatalf("error reading response body: %v", err)
2716 }
2717 if err := resp.Body.Close(); err != nil {
2718 t.Fatalf("error closing response body: %v", err)
2719 }
2720
2721
2722 req.Header = http.Header{}
2723 }
2724
2725 func TestTransportCloseAfterLostPing(t *testing.T) { synctestTest(t, testTransportCloseAfterLostPing) }
2726 func testTransportCloseAfterLostPing(t testing.TB) {
2727 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
2728 h2.PingTimeout = 1 * time.Second
2729 h2.SendPingTimeout = 1 * time.Second
2730 })
2731 tc.greet()
2732
2733 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2734 rt := tc.roundTrip(req)
2735 tc.wantFrameType(FrameHeaders)
2736
2737 time.Sleep(1 * time.Second)
2738 tc.wantFrameType(FramePing)
2739
2740 time.Sleep(1 * time.Second)
2741 err := rt.err()
2742 if err == nil || !strings.Contains(err.Error(), "client connection lost") {
2743 t.Fatalf("expected to get error about \"connection lost\", got %v", err)
2744 }
2745 }
2746
2747 func TestTransportPingWriteBlocks(t *testing.T) {
2748 ts := newTestServer(t,
2749 func(w http.ResponseWriter, r *http.Request) {},
2750 )
2751 tr := newTransport(t)
2752 tr.Dial = func(network, addr string) (net.Conn, error) {
2753 s, c := net.Pipe()
2754 go func() {
2755 srv := tls.Server(s, tlsConfigInsecure)
2756 srv.Handshake()
2757
2758
2759
2760
2761 var buf [1024]byte
2762 s.Read(buf[:])
2763 }()
2764 return c, nil
2765 }
2766 tr.HTTP2.PingTimeout = 1 * time.Millisecond
2767 tr.HTTP2.SendPingTimeout = 1 * time.Millisecond
2768 c := &http.Client{Transport: tr}
2769 _, err := c.Get(ts.URL)
2770 if err == nil {
2771 t.Fatalf("Get = nil, want error")
2772 }
2773 }
2774
2775 func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
2776 synctestTest(t, testTransportPingWhenReadingMultiplePings)
2777 }
2778 func testTransportPingWhenReadingMultiplePings(t testing.TB) {
2779 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
2780 h2.SendPingTimeout = 1000 * time.Millisecond
2781 })
2782 tc.greet()
2783
2784 ctx, cancel := context.WithCancel(context.Background())
2785 req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
2786 rt := tc.roundTrip(req)
2787
2788 tc.wantFrameType(FrameHeaders)
2789 tc.writeHeaders(HeadersFrameParam{
2790 StreamID: rt.streamID(),
2791 EndHeaders: true,
2792 EndStream: false,
2793 BlockFragment: tc.makeHeaderBlockFragment(
2794 ":status", "200",
2795 ),
2796 })
2797
2798 for i := 0; i < 5; i++ {
2799
2800 time.Sleep(999 * time.Millisecond)
2801 if f := tc.readFrame(); f != nil {
2802 t.Fatalf("unexpected frame: %v", f)
2803 }
2804
2805
2806 time.Sleep(1 * time.Millisecond)
2807 f := readFrame[*PingFrame](t, tc)
2808 tc.writePing(true, f.Data)
2809 }
2810
2811
2812 cancel()
2813 synctest.Wait()
2814
2815 tc.wantFrameType(FrameRSTStream)
2816 _, err := rt.readBody()
2817 if err == nil {
2818 t.Fatalf("Response.Body.Read() = %v, want error", err)
2819 }
2820 }
2821
2822 func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
2823 synctestTest(t, testTransportPingWhenReadingPingDisabled)
2824 }
2825 func testTransportPingWhenReadingPingDisabled(t testing.TB) {
2826 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
2827 h2.SendPingTimeout = 0
2828 })
2829 tc.greet()
2830
2831 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2832 rt := tc.roundTrip(req)
2833
2834 tc.wantFrameType(FrameHeaders)
2835 tc.writeHeaders(HeadersFrameParam{
2836 StreamID: rt.streamID(),
2837 EndHeaders: true,
2838 EndStream: false,
2839 BlockFragment: tc.makeHeaderBlockFragment(
2840 ":status", "200",
2841 ),
2842 })
2843
2844
2845 time.Sleep(1 * time.Minute)
2846 if f := tc.readFrame(); f != nil {
2847 t.Fatalf("unexpected frame: %v", f)
2848 }
2849 }
2850
2851 func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) {
2852 synctestTest(t, testTransportRetryAfterGOAWAYNoRetry)
2853 }
2854 func testTransportRetryAfterGOAWAYNoRetry(t testing.TB) {
2855 tt := newTestTransport(t)
2856
2857 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2858 rt := tt.roundTrip(req)
2859
2860
2861
2862
2863
2864 tc := tt.getConn()
2865 tc.wantFrameType(FrameSettings)
2866 tc.wantFrameType(FrameWindowUpdate)
2867 tc.wantHeaders(wantHeader{
2868 streamID: 1,
2869 endStream: true,
2870 })
2871 tc.writeSettings()
2872 tc.writeGoAway(0 , ErrCodeInternal, nil)
2873 if rt.err() == nil {
2874 t.Fatalf("after GOAWAY, RoundTrip is not done, want error")
2875 }
2876 }
2877
2878 func TestTransportRetryAfterGOAWAYRetry(t *testing.T) {
2879 synctestTest(t, testTransportRetryAfterGOAWAYRetry)
2880 }
2881 func testTransportRetryAfterGOAWAYRetry(t testing.TB) {
2882 tt := newTestTransport(t)
2883
2884 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2885 rt := tt.roundTrip(req)
2886
2887
2888
2889
2890
2891 tc := tt.getConn()
2892 tc.wantFrameType(FrameSettings)
2893 tc.wantFrameType(FrameWindowUpdate)
2894 tc.wantHeaders(wantHeader{
2895 streamID: 1,
2896 endStream: true,
2897 })
2898 tc.writeSettings()
2899 tc.writeGoAway(0 , ErrCodeNo, nil)
2900 if rt.done() {
2901 t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
2902 }
2903
2904
2905 tc = tt.getConn()
2906 tc.wantFrameType(FrameSettings)
2907 tc.wantFrameType(FrameWindowUpdate)
2908 tc.wantHeaders(wantHeader{
2909 streamID: 1,
2910 endStream: true,
2911 })
2912 tc.writeSettings()
2913 tc.writeHeaders(HeadersFrameParam{
2914 StreamID: 1,
2915 EndHeaders: true,
2916 EndStream: true,
2917 BlockFragment: tc.makeHeaderBlockFragment(
2918 ":status", "200",
2919 ),
2920 })
2921
2922 rt.wantStatus(200)
2923 }
2924
2925 func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) {
2926 synctestTest(t, testTransportRetryAfterGOAWAYSecondRequest)
2927 }
2928 func testTransportRetryAfterGOAWAYSecondRequest(t testing.TB) {
2929 tt := newTestTransport(t)
2930
2931
2932 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2933 rt1 := tt.roundTrip(req)
2934 tc := tt.getConn()
2935 tc.wantFrameType(FrameSettings)
2936 tc.wantFrameType(FrameWindowUpdate)
2937 tc.wantHeaders(wantHeader{
2938 streamID: 1,
2939 endStream: true,
2940 })
2941 tc.writeSettings()
2942 tc.wantFrameType(FrameSettings)
2943 tc.writeHeaders(HeadersFrameParam{
2944 StreamID: 1,
2945 EndHeaders: true,
2946 EndStream: true,
2947 BlockFragment: tc.makeHeaderBlockFragment(
2948 ":status", "200",
2949 ),
2950 })
2951 rt1.wantStatus(200)
2952
2953
2954
2955
2956
2957 req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
2958 rt2 := tt.roundTrip(req)
2959
2960
2961 tc.wantHeaders(wantHeader{
2962 streamID: 3,
2963 endStream: true,
2964 })
2965 tc.writeSettings()
2966 tc.writeGoAway(1 , ErrCodeProtocol, nil)
2967 if rt2.done() {
2968 t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
2969 }
2970
2971
2972 tc = tt.getConn()
2973 tc.wantFrameType(FrameSettings)
2974 tc.wantFrameType(FrameWindowUpdate)
2975 tc.wantHeaders(wantHeader{
2976 streamID: 1,
2977 endStream: true,
2978 })
2979 tc.writeSettings()
2980 tc.writeHeaders(HeadersFrameParam{
2981 StreamID: 1,
2982 EndHeaders: true,
2983 EndStream: true,
2984 BlockFragment: tc.makeHeaderBlockFragment(
2985 ":status", "200",
2986 ),
2987 })
2988 rt2.wantStatus(200)
2989 }
2990
2991 func TestTransportRetryAfterRefusedStream(t *testing.T) {
2992 synctestTest(t, testTransportRetryAfterRefusedStream)
2993 }
2994 func testTransportRetryAfterRefusedStream(t testing.TB) {
2995 tt := newTestTransport(t)
2996
2997 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
2998 rt := tt.roundTrip(req)
2999
3000
3001 tc := tt.getConn()
3002 tc.wantFrameType(FrameSettings)
3003 tc.wantFrameType(FrameWindowUpdate)
3004 tc.wantHeaders(wantHeader{
3005 streamID: 1,
3006 endStream: true,
3007 })
3008 tc.writeSettings()
3009 tc.wantFrameType(FrameSettings)
3010 tc.writeRSTStream(1, ErrCodeRefusedStream)
3011 if rt.done() {
3012 t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying")
3013 }
3014
3015
3016 tc.wantHeaders(wantHeader{
3017 streamID: 3,
3018 endStream: true,
3019 })
3020 tc.writeSettings()
3021 tc.writeHeaders(HeadersFrameParam{
3022 StreamID: 3,
3023 EndHeaders: true,
3024 EndStream: true,
3025 BlockFragment: tc.makeHeaderBlockFragment(
3026 ":status", "204",
3027 ),
3028 })
3029
3030 rt.wantStatus(204)
3031 }
3032
3033 func TestTransportRetryHasLimit(t *testing.T) { synctestTest(t, testTransportRetryHasLimit) }
3034 func testTransportRetryHasLimit(t testing.TB) {
3035 tt := newTestTransport(t)
3036
3037 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3038 rt := tt.roundTrip(req)
3039
3040 tc := tt.getConn()
3041 tc.netconn.SetReadDeadline(time.Time{})
3042 tc.wantFrameType(FrameSettings)
3043 tc.wantFrameType(FrameWindowUpdate)
3044
3045 count := 0
3046 start := time.Now()
3047 for streamID := uint32(1); !rt.done(); streamID += 2 {
3048 count++
3049 tc.wantHeaders(wantHeader{
3050 streamID: streamID,
3051 endStream: true,
3052 })
3053 if streamID == 1 {
3054 tc.writeSettings()
3055 tc.wantFrameType(FrameSettings)
3056 }
3057 tc.writeRSTStream(streamID, ErrCodeRefusedStream)
3058
3059 if totalDelay := time.Since(start); totalDelay > 5*time.Minute {
3060 t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay)
3061 }
3062 synctest.Wait()
3063 }
3064 if got, want := count, 5; got < count {
3065 t.Errorf("RoundTrip made %v attempts, want at least %v", got, want)
3066 }
3067 if rt.err() == nil {
3068 t.Errorf("RoundTrip succeeded, want error")
3069 }
3070 }
3071
3072 func TestTransportResponseDataBeforeHeaders(t *testing.T) {
3073 synctestTest(t, testTransportResponseDataBeforeHeaders)
3074 }
3075 func testTransportResponseDataBeforeHeaders(t testing.TB) {
3076
3077 log.SetOutput(io.Discard)
3078 t.Cleanup(func() { log.SetOutput(os.Stderr) })
3079
3080 tc := newTestClientConn(t)
3081 tc.greet()
3082
3083
3084 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3085 rt1 := tc.roundTrip(req)
3086 tc.wantFrameType(FrameHeaders)
3087 tc.writeHeaders(HeadersFrameParam{
3088 StreamID: rt1.streamID(),
3089 EndHeaders: true,
3090 EndStream: true,
3091 BlockFragment: tc.makeHeaderBlockFragment(
3092 ":status", "200",
3093 ),
3094 })
3095 rt1.wantStatus(200)
3096
3097
3098 rt2 := tc.roundTrip(req)
3099 tc.wantFrameType(FrameHeaders)
3100 tc.writeData(rt2.streamID(), true, []byte("payload"))
3101 if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol {
3102 t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err)
3103 }
3104 }
3105
3106 func TestTransportMaxFrameReadSize(t *testing.T) {
3107 for _, test := range []struct {
3108 maxReadFrameSize uint32
3109 want uint32
3110 }{{
3111 maxReadFrameSize: 64000,
3112 want: 64000,
3113 }, {
3114 maxReadFrameSize: 1024,
3115
3116
3117
3118
3119
3120
3121
3122
3123 want: DefaultMaxReadFrameSize,
3124 }} {
3125 synctestSubtest(t, fmt.Sprint(test.maxReadFrameSize), func(t testing.TB) {
3126 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3127 h2.MaxReadFrameSize = int(test.maxReadFrameSize)
3128 })
3129
3130 fr := readFrame[*SettingsFrame](t, tc)
3131 got, ok := fr.Value(SettingMaxFrameSize)
3132 if !ok {
3133 t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want)
3134 } else if got != test.want {
3135 t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
3136 }
3137 })
3138 }
3139 }
3140
3141 func TestTransportRequestsLowServerLimit(t *testing.T) {
3142 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3143 }, func(h2 *http.HTTP2Config) {
3144 h2.MaxConcurrentStreams = 1
3145 })
3146
3147 var (
3148 connCountMu sync.Mutex
3149 connCount int
3150 )
3151 tr := newTransport(t)
3152 tr.DialTLS = func(network, addr string) (net.Conn, error) {
3153 connCountMu.Lock()
3154 defer connCountMu.Unlock()
3155 connCount++
3156 return tls.Dial(network, addr, tlsConfigInsecure)
3157 }
3158
3159 const reqCount = 3
3160 for i := 0; i < reqCount; i++ {
3161 req, err := http.NewRequest("GET", ts.URL, nil)
3162 if err != nil {
3163 t.Fatal(err)
3164 }
3165 res, err := tr.RoundTrip(req)
3166 if err != nil {
3167 t.Fatal(err)
3168 }
3169 if got, want := res.StatusCode, 200; got != want {
3170 t.Errorf("StatusCode = %v; want %v", got, want)
3171 }
3172 if res != nil && res.Body != nil {
3173 res.Body.Close()
3174 }
3175 }
3176
3177 if connCount != 1 {
3178 t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
3179 }
3180 }
3181
3182
3183 func TestTransportRequestsStallAtServerLimit(t *testing.T) {
3184 synctest.Test(t, testTransportRequestsStallAtServerLimit)
3185 }
3186 func testTransportRequestsStallAtServerLimit(t *testing.T) {
3187 const maxConcurrent = 2
3188
3189 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3190 h2.StrictMaxConcurrentRequests = true
3191 })
3192 tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})
3193
3194 cancelClientRequest := make(chan struct{})
3195
3196
3197
3198 var rts []*testRoundTrip
3199 for k := 0; k < maxConcurrent+2; k++ {
3200 req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
3201 if k == maxConcurrent {
3202 req.Cancel = cancelClientRequest
3203 }
3204 rt := tc.roundTrip(req)
3205 rts = append(rts, rt)
3206
3207 if k < maxConcurrent {
3208
3209 tc.wantHeaders(wantHeader{
3210 streamID: rt.streamID(),
3211 endStream: true,
3212 header: http.Header{
3213 ":authority": []string{"dummy.tld"},
3214 ":method": []string{"GET"},
3215 ":path": []string{fmt.Sprintf("/%d", k)},
3216 },
3217 })
3218 } else {
3219
3220
3221 if fr := tc.readFrame(); fr != nil {
3222 t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr)
3223 }
3224 }
3225
3226 if rt.done() {
3227 t.Fatalf("rt %v done", k)
3228 }
3229 }
3230
3231
3232
3233 close(cancelClientRequest)
3234 synctest.Wait()
3235 if err := rts[maxConcurrent].err(); err == nil {
3236 t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent)
3237 }
3238
3239
3240 for i, rt := range rts {
3241 if i != maxConcurrent && rt.done() {
3242 t.Fatalf("RoundTrip(%d) is done, but should not be", i)
3243 }
3244 }
3245
3246
3247 tc.writeHeaders(HeadersFrameParam{
3248 StreamID: rts[0].streamID(),
3249 EndHeaders: true,
3250 EndStream: true,
3251 BlockFragment: tc.makeHeaderBlockFragment(
3252 ":status", "200",
3253 ),
3254 })
3255 synctest.Wait()
3256 tc.wantHeaders(wantHeader{
3257 streamID: rts[maxConcurrent+1].streamID(),
3258 endStream: true,
3259 header: http.Header{
3260 ":authority": []string{"dummy.tld"},
3261 ":method": []string{"GET"},
3262 ":path": []string{fmt.Sprintf("/%d", maxConcurrent+1)},
3263 },
3264 })
3265 rts[0].wantStatus(200)
3266 }
3267
3268 func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
3269 synctestTest(t, testTransportMaxDecoderHeaderTableSize)
3270 }
3271 func testTransportMaxDecoderHeaderTableSize(t testing.TB) {
3272 var reqSize, resSize uint32 = 8192, 16384
3273 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3274 h2.MaxDecoderHeaderTableSize = int(reqSize)
3275 })
3276
3277 fr := readFrame[*SettingsFrame](t, tc)
3278 if v, ok := fr.Value(SettingHeaderTableSize); !ok {
3279 t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting")
3280 } else if v != reqSize {
3281 t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize)
3282 }
3283
3284 tc.writeSettings(Setting{SettingHeaderTableSize, resSize})
3285 synctest.Wait()
3286 if got, want := tc.cc.TestPeerMaxHeaderTableSize(), resSize; got != want {
3287 t.Fatalf("peerHeaderTableSize = %d, want %d", got, want)
3288 }
3289 }
3290
3291 func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
3292 synctestTest(t, testTransportMaxEncoderHeaderTableSize)
3293 }
3294 func testTransportMaxEncoderHeaderTableSize(t testing.TB) {
3295 var peerAdvertisedMaxHeaderTableSize uint32 = 16384
3296 const wantMaxEncoderHeaderTableSize = 8192
3297 tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
3298 h2.MaxEncoderHeaderTableSize = wantMaxEncoderHeaderTableSize
3299 })
3300 tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize})
3301
3302 if got, want := tc.cc.TestHPACKEncoder().MaxDynamicTableSize(), uint32(wantMaxEncoderHeaderTableSize); got != want {
3303 t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
3304 }
3305 }
3306
3307
3308
3309 func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
3310 synctestTest(t, testTransportAllocationsAfterResponseBodyClose)
3311 }
3312 func testTransportAllocationsAfterResponseBodyClose(t testing.TB) {
3313 tc := newTestClientConn(t)
3314 tc.greet()
3315
3316
3317 req, _ := http.NewRequest("PUT", "https://dummy.tld/", nil)
3318 rt := tc.roundTrip(req)
3319 tc.wantFrameType(FrameHeaders)
3320
3321
3322 tc.writeHeaders(HeadersFrameParam{
3323 StreamID: rt.streamID(),
3324 EndHeaders: true,
3325 EndStream: false,
3326 BlockFragment: tc.makeHeaderBlockFragment(
3327 ":status", "200",
3328 ),
3329 })
3330 tc.writeData(rt.streamID(), false, make([]byte, 64))
3331 tc.wantIdle()
3332
3333
3334 respBody := rt.response().Body
3335 var buf [1]byte
3336 if _, err := respBody.Read(buf[:]); err != nil {
3337 t.Error(err)
3338 }
3339 if err := respBody.Close(); err != nil {
3340 t.Error(err)
3341 }
3342 tc.wantFrameType(FrameRSTStream)
3343
3344
3345 tc.writeData(rt.streamID(), false, make([]byte, 64))
3346
3347 if _, err := respBody.Read(buf[:]); err == nil {
3348 t.Error("read from closed body unexpectedly succeeded")
3349 }
3350 }
3351
3352
3353
3354 func TestTransportNoBodyMeansNoDATA(t *testing.T) { synctestTest(t, testTransportNoBodyMeansNoDATA) }
3355 func testTransportNoBodyMeansNoDATA(t testing.TB) {
3356 tc := newTestClientConn(t)
3357 tc.greet()
3358
3359 req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
3360 rt := tc.roundTrip(req)
3361
3362 tc.wantHeaders(wantHeader{
3363 streamID: rt.streamID(),
3364 endStream: true,
3365 header: http.Header{
3366 ":authority": []string{"dummy.tld"},
3367 ":method": []string{"GET"},
3368 ":path": []string{"/"},
3369 },
3370 })
3371 if fr := tc.readFrame(); fr != nil {
3372 t.Fatalf("unexpected frame after headers: %v", fr)
3373 }
3374 }
3375
3376 func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
3377 DisableGoroutineTracking(b)
3378 b.ReportAllocs()
3379 ts := newTestServer(b,
3380 func(w http.ResponseWriter, r *http.Request) {
3381 for i := 0; i < nResHeader; i++ {
3382 name := fmt.Sprint("A-", i)
3383 w.Header().Set(name, "*")
3384 }
3385 },
3386 optQuiet,
3387 )
3388
3389 tr := newTransport(b)
3390
3391 req, err := http.NewRequest("GET", ts.URL, nil)
3392 if err != nil {
3393 b.Fatal(err)
3394 }
3395
3396 for i := 0; i < nReqHeaders; i++ {
3397 name := fmt.Sprint("A-", i)
3398 req.Header.Set(name, "*")
3399 }
3400
3401 b.ResetTimer()
3402
3403 for i := 0; i < b.N; i++ {
3404 res, err := tr.RoundTrip(req)
3405 if err != nil {
3406 if res != nil {
3407 res.Body.Close()
3408 }
3409 b.Fatalf("RoundTrip err = %v; want nil", err)
3410 }
3411 res.Body.Close()
3412 if res.StatusCode != http.StatusOK {
3413 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3414 }
3415 }
3416 }
3417
// infiniteReader is an io.Reader that never reaches EOF. Every Read claims to
// have filled the whole buffer, without modifying its contents.
type infiniteReader struct{}

// Read reports len(b) bytes read and a nil error, leaving b untouched.
func (infiniteReader) Read(b []byte) (int, error) {
	n := len(b)
	return n, nil
}
3423
3424
3425
3426 func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
3427 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3428 w.WriteHeader(http.StatusOK)
3429 })
3430
3431 tr := newTransport(t)
3432
3433
3434 req, _ := http.NewRequest("PUT", ts.URL, infiniteReader{})
3435 res, err := tr.RoundTrip(req)
3436 if err != nil {
3437 t.Fatal(err)
3438 }
3439 if res.StatusCode != http.StatusOK {
3440 t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3441 }
3442 }
3443
3444
3445
3446 func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
3447 synctestTest(t, testTransportHandlesInvalidStatuslessResponse)
3448 }
3449 func testTransportHandlesInvalidStatuslessResponse(t testing.TB) {
3450 tc := newTestClientConn(t)
3451 tc.greet()
3452
3453 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3454 rt := tc.roundTrip(req)
3455
3456 tc.wantFrameType(FrameHeaders)
3457 tc.writeHeaders(HeadersFrameParam{
3458 StreamID: rt.streamID(),
3459 EndHeaders: true,
3460 EndStream: false,
3461 BlockFragment: tc.makeHeaderBlockFragment(
3462 "content-type", "text/html",
3463 ),
3464 })
3465 tc.writeData(rt.streamID(), true, []byte("payload"))
3466 }
3467
3468 func BenchmarkClientRequestHeaders(b *testing.B) {
3469 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
3470 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) })
3471 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) })
3472 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) })
3473 }
3474
3475 func BenchmarkClientResponseHeaders(b *testing.B) {
3476 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
3477 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) })
3478 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) })
3479 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
3480 }
3481
3482 func BenchmarkDownloadFrameSize(b *testing.B) {
3483 b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
3484 b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
3485 b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
3486 b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
3487 b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
3488 }
3489 func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
3490 DisableGoroutineTracking(b)
3491 const transferSize = 1024 * 1024 * 1024
3492 b.ReportAllocs()
3493 ts := newTestServer(b,
3494 func(w http.ResponseWriter, r *http.Request) {
3495
3496 w.Header().Set("Content-Length", strconv.Itoa(transferSize))
3497 w.Header().Set("Content-Transfer-Encoding", "binary")
3498 var data [1024 * 1024]byte
3499 for i := 0; i < transferSize/(1024*1024); i++ {
3500 w.Write(data[:])
3501 }
3502 }, optQuiet,
3503 )
3504
3505 tr := newTransport(b)
3506 tr.HTTP2.MaxReadFrameSize = int(frameSize)
3507
3508 req, err := http.NewRequest("GET", ts.URL, nil)
3509 if err != nil {
3510 b.Fatal(err)
3511 }
3512
3513 b.N = 3
3514 b.SetBytes(transferSize)
3515 b.ResetTimer()
3516
3517 for i := 0; i < b.N; i++ {
3518 res, err := tr.RoundTrip(req)
3519 if err != nil {
3520 if res != nil {
3521 res.Body.Close()
3522 }
3523 b.Fatalf("RoundTrip err = %v; want nil", err)
3524 }
3525 data, _ := io.ReadAll(res.Body)
3526 if len(data) != transferSize {
3527 b.Fatalf("Response length invalid")
3528 }
3529 res.Body.Close()
3530 if res.StatusCode != http.StatusOK {
3531 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3532 }
3533 }
3534 }
3535
3536 func BenchmarkClientGzip(b *testing.B) {
3537 DisableGoroutineTracking(b)
3538 b.ReportAllocs()
3539
3540 const responseSize = 1024 * 1024
3541
3542 var buf bytes.Buffer
3543 gz := gzip.NewWriter(&buf)
3544 if _, err := io.CopyN(gz, crand.Reader, responseSize); err != nil {
3545 b.Fatal(err)
3546 }
3547 gz.Close()
3548
3549 data := buf.Bytes()
3550 ts := newTestServer(b,
3551 func(w http.ResponseWriter, r *http.Request) {
3552 w.Header().Set("Content-Encoding", "gzip")
3553 w.Write(data)
3554 },
3555 optQuiet,
3556 )
3557
3558 tr := newTransport(b)
3559
3560 req, err := http.NewRequest("GET", ts.URL, nil)
3561 if err != nil {
3562 b.Fatal(err)
3563 }
3564
3565 b.ResetTimer()
3566
3567 for i := 0; i < b.N; i++ {
3568 res, err := tr.RoundTrip(req)
3569 if err != nil {
3570 b.Fatalf("RoundTrip err = %v; want nil", err)
3571 }
3572 if res.StatusCode != http.StatusOK {
3573 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3574 }
3575 n, err := io.Copy(io.Discard, res.Body)
3576 res.Body.Close()
3577 if err != nil {
3578 b.Fatalf("RoundTrip err = %v; want nil", err)
3579 }
3580 if n != responseSize {
3581 b.Fatalf("RoundTrip expected %d bytes, got %d", responseSize, n)
3582 }
3583 }
3584 }
3585
3586
3587
3588
// TestClientConnCloseAtHeaders checks that closing a ClientConn while a
// request is waiting for its response headers fails that RoundTrip with
// ErrClientConnForceClosed.
func TestClientConnCloseAtHeaders(t *testing.T) { synctestTest(t, testClientConnCloseAtHeaders) }
func testClientConnCloseAtHeaders(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	// The request's HEADERS have been sent; the server never responds.
	tc.wantFrameType(FrameHeaders)

	tc.cc.Close()
	// Let the bubble settle so the pending RoundTrip observes the close.
	synctest.Wait()
	if err := rt.err(); err != ErrClientConnForceClosed {
		t.Fatalf("RoundTrip error = %v, want errClientConnForceClosed", err)
	}
}
3604
3605
3606
// TestClientConnCloseAtBody checks that closing a ClientConn after the
// response headers and part of the body have arrived makes draining the
// rest of the response body fail.
func TestClientConnCloseAtBody(t *testing.T) { synctestTest(t, testClientConnCloseAtBody) }
func testClientConnCloseAtBody(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	// Response headers and 64 bytes of body, neither with END_STREAM, so the
	// response body is left open.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.writeData(rt.streamID(), false, make([]byte, 64))
	resp := rt.response()
	tc.cc.Close()
	synctest.Wait()

	// The stream never ended, so reading to the end must report an error
	// rather than a clean EOF.
	if _, err := io.Copy(io.Discard, resp.Body); err == nil {
		t.Error("expected a Copy error, got nil")
	}
}
3633
3634
3635
// TestClientConnShutdown checks graceful shutdown with a request in flight.
// Shutdown sends GOAWAY, the pending request still receives its full
// response, and the connection is closed once that request finishes.
func TestClientConnShutdown(t *testing.T) { synctestTest(t, testClientConnShutdown) }
func testClientConnShutdown(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	// Shutdown runs in the background, presumably because it does not return
	// while the request is in flight. Its return value is not checked here.
	go tc.cc.Shutdown(context.Background())
	synctest.Wait()

	// So far the only frame Shutdown writes is GOAWAY.
	tc.wantFrameType(FrameGoAway)
	tc.wantIdle()
	body := []byte("body")
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.writeData(rt.streamID(), true, body)

	// The request issued before Shutdown completes normally.
	rt.wantStatus(200)
	rt.wantBody(body)

	// With its last stream finished, the connection is closed.
	tc.wantClosed()
}
3667
3668
3669
3670
// TestClientConnShutdownCancel checks that canceling the context passed to
// Shutdown makes Shutdown return context.Canceled while a request is still
// in flight. It also checks that this does not complete that request.
func TestClientConnShutdownCancel(t *testing.T) { synctestTest(t, testClientConnShutdownCancel) }
func testClientConnShutdownCancel(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	// Run Shutdown in the background and record what it returns.
	ctx, cancel := context.WithCancel(t.Context())
	var shutdownErr error
	go func() {
		shutdownErr = tc.cc.Shutdown(ctx)
	}()
	synctest.Wait()

	tc.wantFrameType(FrameGoAway)
	tc.wantIdle()

	// Cancel Shutdown's context. After the bubble settles, the Shutdown
	// goroutine has had the chance to return and store its result.
	cancel()
	synctest.Wait()

	if shutdownErr != context.Canceled {
		t.Fatalf("ClientConn.Shutdown(ctx) did not return context.Canceled after cancelling context")
	}

	// The server never answered or reset the request, so abandoning Shutdown
	// must not have finished the RoundTrip.
	if rt.done() {
		t.Fatal("RoundTrip unexpectedly returned during shutdown")
	}
}
3710
// errReader is an io.Reader that yields the bytes of body. After body is
// exhausted, every Read returns (0, err).
type errReader struct {
	body []byte
	err  error
}

// Read copies as much of the remaining body into p as fits. Once nothing is
// left, it reports r.err instead.
func (r *errReader) Read(p []byte) (int, error) {
	if len(r.body) == 0 {
		return 0, r.err
	}
	copied := copy(p, r.body)
	r.body = r.body[copied:]
	return copied, nil
}
3724
// testTransportBodyReadError runs testTransportBodyReadErrorBubble through
// synctestTest.
func testTransportBodyReadError(t *testing.T, body []byte) {
	synctestTest(t, func(t testing.TB) {
		testTransportBodyReadErrorBubble(t, body)
	})
}

// testTransportBodyReadErrorBubble sends a PUT whose request body yields
// body and then fails with a read error. It checks three things:
//   - every byte yielded before the failure is sent in DATA frames;
//   - the stream is then reset with RST_STREAM;
//   - RoundTrip returns the body's read error.
func testTransportBodyReadErrorBubble(t testing.TB, body []byte) {
	tc := newTestClientConn(t)
	tc.greet()

	// Presumably b yields body and then reports bodyReadError on later reads.
	bodyReadError := errors.New("body read error")
	b := tc.newRequestBody()
	b.Write(body)
	b.closeWithError(bodyReadError)
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", b)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	// Collect DATA payloads until the transport resets the stream.
	var receivedBody []byte
readFrames:
	for {
		switch f := tc.readFrame().(type) {
		case *DataFrame:
			receivedBody = append(receivedBody, f.Data()...)
		case *RSTStreamFrame:
			break readFrames
		default:
			t.Fatalf("unexpected frame: %v", f)
		case nil:
			// No further frame was available: the transport went idle
			// without resetting the stream.
			t.Fatalf("transport is idle, want RST_STREAM")
		}
	}
	if !bytes.Equal(receivedBody, body) {
		t.Fatalf("body: %q; expected %q", receivedBody, body)
	}

	if err := rt.err(); err != bodyReadError {
		t.Fatalf("err = %v; want %v", err, bodyReadError)
	}
}

// TestTransportBodyReadError_Immediately covers a body that fails before
// yielding any data.
func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }

// TestTransportBodyReadError_Some covers a body that yields "123" and then fails.
func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }
3767
3768
3769
3770
3771 func TestTransportBodyEagerEndStream(t *testing.T) { synctestTest(t, testTransportBodyEagerEndStream) }
3772 func testTransportBodyEagerEndStream(t testing.TB) {
3773 const reqBody = "some request body"
3774 const resBody = "some response body"
3775
3776 tc := newTestClientConn(t)
3777 tc.greet()
3778
3779 body := strings.NewReader(reqBody)
3780 req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
3781 tc.roundTrip(req)
3782
3783 tc.wantFrameType(FrameHeaders)
3784 f := readFrame[*DataFrame](t, tc)
3785 if !f.StreamEnded() {
3786 t.Fatalf("data frame without END_STREAM %v", f)
3787 }
3788 }
3789
// chunkReader is an io.Reader that hands out one element of chunks per Read
// call and panics if Read is called again after the last chunk. This lets a
// test detect a consumer that reads more often than expected.
//
// A chunk longer than the caller's buffer is truncated: its unread tail is
// dropped, not returned by the next Read.
type chunkReader struct {
	chunks [][]byte
}

func (r *chunkReader) Read(p []byte) (int, error) {
	if len(r.chunks) == 0 {
		panic("shouldn't read this many times")
	}
	next := r.chunks[0]
	r.chunks = r.chunks[1:]
	return copy(p, next), nil
}
3802
3803
3804
3805
3806
3807
3808
3809
3810
3811 func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) {
3812 body := &chunkReader{[][]byte{
3813 []byte("123"),
3814 []byte("456"),
3815 }}
3816 synctestTest(t, func(t testing.TB) {
3817 testTransportBodyLargerThanSpecifiedContentLength(t, body, 3)
3818 })
3819 }
3820
3821 func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
3822 body := &chunkReader{[][]byte{
3823 []byte("123"),
3824 }}
3825 synctestTest(t, func(t testing.TB) {
3826 testTransportBodyLargerThanSpecifiedContentLength(t, body, 2)
3827 })
3828 }
3829
3830 func testTransportBodyLargerThanSpecifiedContentLength(t testing.TB, body *chunkReader, contentLen int64) {
3831 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3832 r.Body.Read(make([]byte, 6))
3833 })
3834
3835 tr := newTransport(t)
3836
3837 req, _ := http.NewRequest("POST", ts.URL, body)
3838 req.ContentLength = contentLen
3839 _, err := tr.RoundTrip(req)
3840 if err != ErrReqBodyTooLong {
3841 t.Fatalf("expected %v, got %v", ErrReqBodyTooLong, err)
3842 }
3843 }
3844
3845
// TestTransportNewClientConnCloseOnWriteError is meant to check that a write
// error during the ClientConn's opening exchange causes the connection to be
// closed. It is currently skipped (see the TODO in the t.Skip message), so
// nothing after t.Skip runs.
func TestTransportNewClientConnCloseOnWriteError(t *testing.T) {
	synctestTest(t, testTransportNewClientConnCloseOnWriteError)
}
func testTransportNewClientConnCloseOnWriteError(t testing.TB) {
	t.Skip("TODO: test fails because write errors don't cause the conn to close")

	tc := newTestClientConn(t)

	// Let the new ClientConn run, then make every later write to the
	// connection fail.
	synctest.Wait()
	writeErr := errors.New("write error")
	tc.netconn.loc.setWriteError(writeErr)

	// Server SETTINGS presumably prompt the client to write (and fail).
	tc.writeSettings()
	tc.wantIdle()

	// The client's initial SETTINGS and WINDOW_UPDATE remain readable;
	// presumably they were written before the error was injected.
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantIdle()

	synctest.Wait()
	if !tc.netconn.IsClosedByPeer() {
		t.Error("expected closed conn")
	}
}
3878
// TestTransportRoundtripCloseOnWriteError checks that a write error on the
// connection fails the in-flight RoundTrip with that same error. It also
// checks that the ClientConn then refuses new requests with
// ErrClientConnUnusable.
func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
	synctestTest(t, testTransportRoundtripCloseOnWriteError)
}
func testTransportRoundtripCloseOnWriteError(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	body := tc.newRequestBody()
	body.writeBytes(1)
	req, _ := http.NewRequest("GET", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	// From here on, writes to the connection fail with writeErr.
	writeErr := errors.New("write error")
	tc.closeWriteWithError(writeErr)

	// More body data presumably makes the transport attempt another write,
	// which fails.
	body.writeBytes(1)
	if err := rt.err(); err != writeErr {
		t.Fatalf("RoundTrip error %v, want %v", err, writeErr)
	}

	// The connection must not accept any further requests.
	rt2 := tc.roundTrip(req)
	if err := rt2.err(); err != ErrClientConnUnusable {
		t.Fatalf("RoundTrip error %v, want errClientConnUnusable", err)
	}
}
3904
3905
3906
3907
3908 func TestTransportBodyRewindRace(t *testing.T) {
3909 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3910 w.Header().Set("Connection", "close")
3911 w.WriteHeader(http.StatusOK)
3912 return
3913 })
3914
3915 tr := newTransport(t)
3916 tr.MaxConnsPerHost = 1
3917 client := &http.Client{
3918 Transport: tr,
3919 }
3920
3921 const clients = 50
3922
3923 var wg sync.WaitGroup
3924 wg.Add(clients)
3925 for i := 0; i < clients; i++ {
3926 req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("abcdef"))
3927 if err != nil {
3928 t.Fatalf("unexpected new request error: %v", err)
3929 }
3930
3931 go func() {
3932 defer wg.Done()
3933 res, err := client.Do(req)
3934 if err == nil {
3935 res.Body.Close()
3936 }
3937 }()
3938 }
3939
3940 wg.Wait()
3941 }
3942
// errorReader is an io.Reader that never produces data: every Read fails
// immediately with err.
type errorReader struct{ err error }

// Read ignores its buffer and reports r.err without reading anything.
func (r errorReader) Read([]byte) (int, error) {
	failure := r.err
	return 0, failure
}
3946
3947
3948
3949 func TestTransportServerResetStreamAtHeaders(t *testing.T) {
3950 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3951 w.WriteHeader(http.StatusUnauthorized)
3952 return
3953 })
3954
3955 tr := newTransport(t)
3956 tr.MaxConnsPerHost = 1
3957 tr.ExpectContinueTimeout = 10 * time.Second
3958
3959 client := &http.Client{
3960 Transport: tr,
3961 }
3962
3963 req, err := http.NewRequest("POST", ts.URL, errorReader{io.EOF})
3964 if err != nil {
3965 t.Fatalf("unexpected new request error: %v", err)
3966 }
3967 req.ContentLength = 0
3968 req.Header.Set("Expect", "100-continue")
3969 res, err := client.Do(req)
3970 if err != nil {
3971 t.Fatal(err)
3972 }
3973 res.Body.Close()
3974 }
3975
3976 type trackingReader struct {
3977 rdr io.Reader
3978 wasRead uint32
3979 }
3980
3981 func (tr *trackingReader) Read(p []byte) (int, error) {
3982 atomic.StoreUint32(&tr.wasRead, 1)
3983 return tr.rdr.Read(p)
3984 }
3985
3986 func (tr *trackingReader) WasRead() bool {
3987 return atomic.LoadUint32(&tr.wasRead) != 0
3988 }
3989
3990 func TestTransportExpectContinue(t *testing.T) {
3991 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3992 switch r.URL.Path {
3993 case "/reject":
3994 w.WriteHeader(403)
3995 default:
3996 io.Copy(io.Discard, r.Body)
3997 }
3998 })
3999
4000 tr := newTransport(t)
4001 tr.MaxConnsPerHost = 1
4002 tr.ExpectContinueTimeout = 10 * time.Second
4003
4004 client := &http.Client{
4005 Transport: tr,
4006 }
4007
4008 testCases := []struct {
4009 Name string
4010 Path string
4011 Body *trackingReader
4012 ExpectedCode int
4013 ShouldRead bool
4014 }{
4015 {
4016 Name: "read-all",
4017 Path: "/",
4018 Body: &trackingReader{rdr: strings.NewReader("hello")},
4019 ExpectedCode: 200,
4020 ShouldRead: true,
4021 },
4022 {
4023 Name: "reject",
4024 Path: "/reject",
4025 Body: &trackingReader{rdr: strings.NewReader("hello")},
4026 ExpectedCode: 403,
4027 ShouldRead: false,
4028 },
4029 }
4030
4031 for _, tc := range testCases {
4032 t.Run(tc.Name, func(t *testing.T) {
4033 startTime := time.Now()
4034
4035 req, err := http.NewRequest("POST", ts.URL+tc.Path, tc.Body)
4036 if err != nil {
4037 t.Fatal(err)
4038 }
4039 req.Header.Set("Expect", "100-continue")
4040 res, err := client.Do(req)
4041 if err != nil {
4042 t.Fatal(err)
4043 }
4044 res.Body.Close()
4045
4046 if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout {
4047 t.Error("Request didn't finish before expect continue timeout")
4048 }
4049 if res.StatusCode != tc.ExpectedCode {
4050 t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode)
4051 }
4052 if tc.Body.WasRead() != tc.ShouldRead {
4053 t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead)
4054 }
4055 })
4056 }
4057 }
4058
// closeChecker wraps an io.ReadCloser and records whether Close was called.
// After Close, Read fails instead of delegating. isClosed lets a test wait
// for the close to happen.
type closeChecker struct {
	io.ReadCloser
	closed chan struct{} // closed by Close
}

// newCloseChecker returns a closeChecker around r.
func newCloseChecker(r io.ReadCloser) *closeChecker {
	return &closeChecker{ReadCloser: r, closed: make(chan struct{})}
}

// newStaticCloseChecker returns a closeChecker over a fixed in-memory body.
//
// NOTE(review): the body parameter is ignored; the content is always the
// literal string "body". Confirm whether callers rely on that before making
// it honor the argument.
func newStaticCloseChecker(body string) *closeChecker {
	return newCloseChecker(io.NopCloser(strings.NewReader("body")))
}

// Read delegates to the wrapped reader unless Close has already been called.
func (rc *closeChecker) Read(b []byte) (n int, err error) {
	select {
	case <-rc.closed:
		// Reading a body after it was closed is the misuse this type exists
		// to catch, so surface it as an error.
		return 0, errors.New("read after Body.Close")
	default:
		return rc.ReadCloser.Read(b)
	}
}

// Close marks the checker closed and then closes the wrapped reader.
// Calling it twice panics, because the channel would be closed twice.
func (rc *closeChecker) Close() error {
	close(rc.closed)
	return rc.ReadCloser.Close()
}

// isClosed waits up to 10 seconds for Close. It returns nil if Close was
// called, or an error if the wait timed out.
func (rc *closeChecker) isClosed() error {
	const timeout = 10 * time.Second
	select {
	case <-rc.closed:
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("body not closed after %v", timeout)
	}
}
4101
4102
4103 type blockingWriteConn struct {
4104 net.Conn
4105 writeOnce sync.Once
4106 writec chan struct{}
4107 unblockc chan struct{}
4108 count, limit int
4109 }
4110
4111 func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn {
4112 return &blockingWriteConn{
4113 Conn: conn,
4114 limit: limit,
4115 writec: make(chan struct{}),
4116 unblockc: make(chan struct{}),
4117 }
4118 }
4119
4120
4121 func (c *blockingWriteConn) wait() {
4122 <-c.writec
4123 }
4124
4125
4126 func (c *blockingWriteConn) unblock() {
4127 close(c.unblockc)
4128 }
4129
4130 func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
4131 if c.count+len(b) > c.limit {
4132 c.writeOnce.Do(func() {
4133 close(c.writec)
4134 })
4135 <-c.unblockc
4136 }
4137 n, err = c.Conn.Write(b)
4138 c.count += n
4139 return n, err
4140 }
4141
4142
4143
4144 func TestTransportFrameBufferReuse(t *testing.T) {
4145 filler := hex.EncodeToString([]byte(randString(2048)))
4146
4147 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4148 if got, want := r.Header.Get("Big"), filler; got != want {
4149 t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want)
4150 }
4151 b, err := io.ReadAll(r.Body)
4152 if err != nil {
4153 t.Errorf("error reading request body: %v", err)
4154 }
4155 if got, want := string(b), filler; got != want {
4156 t.Errorf("request body = %q, want %q", got, want)
4157 }
4158 if got, want := r.Trailer.Get("Big"), filler; got != want {
4159 t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want)
4160 }
4161 })
4162
4163 tr := newTransport(t)
4164
4165 var wg sync.WaitGroup
4166 defer wg.Wait()
4167 for i := 0; i < 10; i++ {
4168 wg.Add(1)
4169 go func() {
4170 defer wg.Done()
4171 req, err := http.NewRequest("POST", ts.URL, strings.NewReader(filler))
4172 if err != nil {
4173 t.Error(err)
4174 return
4175 }
4176 req.Header.Set("Big", filler)
4177 req.Trailer = make(http.Header)
4178 req.Trailer.Set("Big", filler)
4179 res, err := tr.RoundTrip(req)
4180 if err != nil {
4181 t.Error(err)
4182 return
4183 }
4184 if got, want := res.StatusCode, 200; got != want {
4185 t.Errorf("StatusCode = %v; want %v", got, want)
4186 }
4187 if res != nil && res.Body != nil {
4188 res.Body.Close()
4189 }
4190 }()
4191 }
4192
4193 }
4194
4195
4196
4197
4198
4199 func TestTransportBlockingRequestWrite(t *testing.T) {
4200 filler := hex.EncodeToString([]byte(randString(2048)))
4201 for _, test := range []struct {
4202 name string
4203 req func(url string) (*http.Request, error)
4204 }{{
4205 name: "headers",
4206 req: func(url string) (*http.Request, error) {
4207 req, err := http.NewRequest("POST", url, nil)
4208 if err != nil {
4209 return nil, err
4210 }
4211 req.Header.Set("Big", filler)
4212 return req, err
4213 },
4214 }, {
4215 name: "body",
4216 req: func(url string) (*http.Request, error) {
4217 req, err := http.NewRequest("POST", url, strings.NewReader(filler))
4218 if err != nil {
4219 return nil, err
4220 }
4221 return req, err
4222 },
4223 }, {
4224 name: "trailer",
4225 req: func(url string) (*http.Request, error) {
4226 req, err := http.NewRequest("POST", url, strings.NewReader("body"))
4227 if err != nil {
4228 return nil, err
4229 }
4230 req.Trailer = make(http.Header)
4231 req.Trailer.Set("Big", filler)
4232 return req, err
4233 },
4234 }} {
4235 test := test
4236 t.Run(test.name, func(t *testing.T) {
4237 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4238 if v := r.Header.Get("Big"); v != "" && v != filler {
4239 t.Errorf("request header mismatch")
4240 }
4241 if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler {
4242 t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler)
4243 }
4244 if v := r.Trailer.Get("Big"); v != "" && v != filler {
4245 t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler)
4246 }
4247 }, func(h2 *http.HTTP2Config) {
4248 h2.MaxConcurrentStreams = 1
4249 }, func(s *http.Server) {
4250 s.Protocols = protocols("h2c")
4251 })
4252
4253
4254 connc := make(chan *blockingWriteConn, 1)
4255 connCount := 0
4256 tr := newTransport(t)
4257 tr.Protocols = protocols("h2c")
4258 tr.Dial = func(network, addr string) (net.Conn, error) {
4259 connCount++
4260 c, err := net.Dial(network, addr)
4261 wc := newBlockingWriteConn(c, 1024)
4262 select {
4263 case connc <- wc:
4264 default:
4265 }
4266 return wc, err
4267 }
4268 t.Log(ts.URL)
4269
4270
4271 {
4272 req, err := http.NewRequest("POST", ts.URL, nil)
4273 if err != nil {
4274 t.Fatal(err)
4275 }
4276 res, err := tr.RoundTrip(req)
4277 if err != nil {
4278 t.Fatal(err)
4279 }
4280 if got, want := res.StatusCode, 200; got != want {
4281 t.Errorf("StatusCode = %v; want %v", got, want)
4282 }
4283 if res != nil && res.Body != nil {
4284 res.Body.Close()
4285 }
4286 }
4287
4288
4289 reqc := make(chan struct{})
4290 go func() {
4291 defer close(reqc)
4292 req, err := test.req(ts.URL)
4293 if err != nil {
4294 t.Error(err)
4295 return
4296 }
4297 res, _ := tr.RoundTrip(req)
4298 if res != nil && res.Body != nil {
4299 res.Body.Close()
4300 }
4301 }()
4302 conn := <-connc
4303 conn.wait()
4304
4305
4306
4307 {
4308 req, err := http.NewRequest("POST", ts.URL, nil)
4309 if err != nil {
4310 t.Fatal(err)
4311 }
4312 res, err := tr.RoundTrip(req)
4313 if err != nil {
4314 t.Fatal(err)
4315 }
4316 if got, want := res.StatusCode, 200; got != want {
4317 t.Errorf("StatusCode = %v; want %v", got, want)
4318 }
4319 if res != nil && res.Body != nil {
4320 res.Body.Close()
4321 }
4322 }
4323
4324
4325 select {
4326 case <-reqc:
4327 t.Errorf("request 2 unexpectedly completed")
4328 default:
4329 }
4330
4331 conn.unblock()
4332 <-reqc
4333
4334 if connCount != 2 {
4335 t.Errorf("created %v connections, want 1", connCount)
4336 }
4337 })
4338 }
4339 }
4340
4341 func TestTransportCloseRequestBody(t *testing.T) {
4342 var statusCode int
4343 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4344 w.WriteHeader(statusCode)
4345 })
4346
4347 tr := newTransport(t)
4348 ctx := context.Background()
4349 cc, err := tr.NewClientConn(ctx, "https", ts.Listener.Addr().String())
4350 if err != nil {
4351 t.Fatal(err)
4352 }
4353 defer cc.Close()
4354
4355 for _, status := range []int{200, 401} {
4356 t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
4357 statusCode = status
4358 pr, pw := io.Pipe()
4359 body := newCloseChecker(pr)
4360 req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
4361 if err != nil {
4362 t.Fatal(err)
4363 }
4364 res, err := cc.RoundTrip(req)
4365 if err != nil {
4366 t.Fatal(err)
4367 }
4368 res.Body.Close()
4369 pw.Close()
4370 if err := body.isClosed(); err != nil {
4371 t.Fatal(err)
4372 }
4373 })
4374 }
4375 }
4376
// TestTransportNoRetryOnStreamProtocolError checks what happens when the
// server resets one stream with PROTOCOL_ERROR:
//   - only that request fails;
//   - it is not retried on a new connection;
//   - another request on the same connection is unaffected.
func TestTransportNoRetryOnStreamProtocolError(t *testing.T) {
	synctestTest(t, testTransportNoRetryOnStreamProtocolError)
}
func testTransportNoRetryOnStreamProtocolError(t testing.TB) {
	tt := newTestTransport(t)

	// Request 1 opens a connection and is sent on stream 1. The server's
	// SETTINGS are answered by the client with another SETTINGS frame
	// (presumably the ACK). Request 1 gets no response yet.
	req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt1 := tt.roundTrip(req1)
	tc1 := tt.getConn()
	tc1.wantFrameType(FrameSettings)
	tc1.wantFrameType(FrameWindowUpdate)
	tc1.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc1.writeSettings()
	tc1.wantFrameType(FrameSettings)

	// Request 2 reuses the same connection, on stream 3.
	req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt2 := tt.roundTrip(req2)
	tc1.wantHeaders(wantHeader{
		streamID:  3,
		endStream: true,
	})

	// Reset request 2's stream with a protocol error: it fails, request 1
	// keeps waiting.
	tc1.writeRSTStream(3, ErrCodeProtocol)
	if rt1.done() {
		t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress")
	}
	if !rt2.done() {
		t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is in progress; want done")
	}
	// A new connection on the test transport would mean request 2 was retried.
	if tt.hasConn() {
		t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is unexpectedly retried")
	}

	// Request 1 still completes normally on the original connection.
	tc1.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc1.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt1.wantStatus(200)
}
4437
// TestClientConnReservations checks ReserveNewRequest accounting against the
// server's SETTINGS_MAX_CONCURRENT_STREAMS limit:
//   - reservations stop at the limit;
//   - a completed RoundTrip frees exactly one slot;
//   - once every reservation has been consumed by a finished RoundTrip, the
//     full limit is reservable again.
func TestClientConnReservations(t *testing.T) { synctestTest(t, testClientConnReservations) }
func testClientConnReservations(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet(
		Setting{ID: SettingMaxConcurrentStreams, Val: InitialMaxConcurrentStreams},
	)

	// doRoundTrip sends one GET and answers it with a complete 200 response.
	doRoundTrip := func() {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		rt := tc.roundTrip(req)
		tc.wantFrameType(FrameHeaders)
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})
		rt.wantStatus(200)
	}

	// Reserve until refused. The loop bound stops a runaway if reservations
	// are never refused.
	n := 0
	for n <= InitialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
		n++
	}
	if n != InitialMaxConcurrentStreams {
		t.Errorf("did %v reservations; want %v", n, InitialMaxConcurrentStreams)
	}
	// One finished RoundTrip (presumably consuming one reservation) should
	// leave exactly one slot free.
	doRoundTrip()
	n2 := 0
	for n2 <= 5 && tc.cc.ReserveNewRequest() {
		n2++
	}
	if n2 != 1 {
		t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
	}

	// Use up the n reservations still outstanding (n-1 from the first loop,
	// 1 from the second).
	for i := 0; i < n; i++ {
		doRoundTrip()
	}

	// With nothing reserved and no streams open, the full limit is available.
	n2 = 0
	for n2 <= InitialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
		n2++
	}
	if n2 != n {
		t.Errorf("after reset, reservations = %v; want %v", n2, n)
	}
}
4489
4490 func TestTransportTimeoutServerHangs(t *testing.T) { synctestTest(t, testTransportTimeoutServerHangs) }
4491 func testTransportTimeoutServerHangs(t testing.TB) {
4492 tc := newTestClientConn(t)
4493 tc.greet()
4494
4495 ctx, cancel := context.WithCancel(context.Background())
4496 req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil)
4497 rt := tc.roundTrip(req)
4498
4499 tc.wantFrameType(FrameHeaders)
4500 time.Sleep(5 * time.Second)
4501 if f := tc.readFrame(); f != nil {
4502 t.Fatalf("unexpected frame: %v", f)
4503 }
4504 if rt.done() {
4505 t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned")
4506 }
4507
4508 cancel()
4509 synctest.Wait()
4510 if rt.err() != context.Canceled {
4511 t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err())
4512 }
4513 }
4514
4515 func TestTransportContentLengthWithoutBody(t *testing.T) {
4516 for _, test := range []struct {
4517 name string
4518 contentLength string
4519 wantBody string
4520 wantErr error
4521 wantContentLength int64
4522 }{
4523 {
4524 name: "non-zero content length",
4525 contentLength: "42",
4526 wantErr: io.ErrUnexpectedEOF,
4527 wantContentLength: 42,
4528 },
4529 {
4530 name: "zero content length",
4531 contentLength: "0",
4532 wantErr: nil,
4533 wantContentLength: 0,
4534 },
4535 } {
4536 synctestSubtest(t, test.name, func(t testing.TB) {
4537 contentLength := ""
4538 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4539 w.Header().Set("Content-Length", contentLength)
4540 })
4541 tr := newTransport(t)
4542
4543 contentLength = test.contentLength
4544
4545 req, _ := http.NewRequest("GET", ts.URL, nil)
4546 res, err := tr.RoundTrip(req)
4547 if err != nil {
4548 t.Fatal(err)
4549 }
4550 defer res.Body.Close()
4551 body, err := io.ReadAll(res.Body)
4552
4553 if err != test.wantErr {
4554 t.Errorf("Expected error %v, got: %v", test.wantErr, err)
4555 }
4556 if len(body) > 0 {
4557 t.Errorf("Expected empty body, got: %v", body)
4558 }
4559 if res.ContentLength != test.wantContentLength {
4560 t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength)
4561 }
4562 })
4563 }
4564 }
4565
// TestTransportCloseResponseBodyWhileRequestBodyHangs checks that closing a
// response body returns even while the request body is still blocked. The
// request body is one end of a net.Pipe whose other end is never written to.
// The handler flushes its headers and then keeps reading the request body.
func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
	synctestTest(t, testTransportCloseResponseBodyWhileRequestBodyHangs)
}
func testTransportCloseResponseBodyWhileRequestBodyHangs(t testing.TB) {
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		io.Copy(io.Discard, r.Body)
	})

	tr := newTransport(t)

	pr, pw := net.Pipe()
	req, err := http.NewRequest("GET", ts.URL, pr)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}

	// Close the response while the request body is still stalled; only then
	// close the pipe's other end.
	res.Body.Close()
	pw.Close()
}
4591
// TestTransport300ResponseBody checks that the body of a 300 response is
// delivered in full. The handler flushes the 300 status and waits until the
// client has the response headers before writing the body. The request body
// is one end of a net.Pipe that is never written to.
func TestTransport300ResponseBody(t *testing.T) { synctestTest(t, testTransport300ResponseBody) }
func testTransport300ResponseBody(t testing.TB) {
	reqc := make(chan struct{})
	body := []byte("response body")
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(300)
		w.(http.Flusher).Flush()
		// Hold the body back until the client has received the headers.
		<-reqc
		w.Write(body)
	})

	tr := newTransport(t)

	pr, pw := net.Pipe()
	req, err := http.NewRequest("GET", ts.URL, pr)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	close(reqc)
	got, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("error reading response body: %v", err)
	}
	if !bytes.Equal(got, body) {
		t.Errorf("got response body %q, want %q", string(got), string(body))
	}
	res.Body.Close()
	pw.Close()
}
4625
4626 func TestTransportWriteByteTimeout(t *testing.T) {
4627 ts := newTestServer(t, nil, func(s *http.Server) {
4628 s.Protocols = protocols("h2c")
4629 })
4630 tr := newTransport(t)
4631 tr.Protocols = protocols("h2c")
4632 tr.Dial = func(network, addr string) (net.Conn, error) {
4633 _, c := net.Pipe()
4634 return c, nil
4635 }
4636 tr.HTTP2.WriteByteTimeout = 1 * time.Millisecond
4637 defer tr.CloseIdleConnections()
4638 c := &http.Client{Transport: tr}
4639
4640 _, err := c.Get(ts.URL)
4641 if !errors.Is(err, os.ErrDeadlineExceeded) {
4642 t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
4643 }
4644 }
4645
4646 type slowWriteConn struct {
4647 net.Conn
4648 hasWriteDeadline bool
4649 }
4650
4651 func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
4652 c.hasWriteDeadline = !t.IsZero()
4653 return nil
4654 }
4655
4656 func (c *slowWriteConn) Write(b []byte) (n int, err error) {
4657 if c.hasWriteDeadline && len(b) > 1 {
4658 n, err = c.Conn.Write(b[:1])
4659 if err != nil {
4660 return n, err
4661 }
4662 return n, fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
4663 }
4664 return c.Conn.Write(b)
4665 }
4666
4667 func TestTransportSlowWrites(t *testing.T) { synctestTest(t, testTransportSlowWrites) }
4668 func testTransportSlowWrites(t testing.TB) {
4669 ts := newTestServer(t, nil, func(s *http.Server) {
4670 s.Protocols = protocols("h2c")
4671 })
4672 tr := newTransport(t)
4673 tr.Protocols = protocols("h2c")
4674 tr.Dial = func(network, addr string) (net.Conn, error) {
4675 c, err := net.Dial(network, addr)
4676 return &slowWriteConn{Conn: c}, err
4677 }
4678 tr.HTTP2.WriteByteTimeout = 1 * time.Millisecond
4679 c := &http.Client{Transport: tr}
4680
4681 const bodySize = 1 << 20
4682 resp, err := c.Post(ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
4683 if err != nil {
4684 t.Fatal(err)
4685 }
4686 resp.Body.Close()
4687 }
4688
// TestTransportClosesConnAfterGoAwayNoStreams: GOAWAY with last stream 0,
// so the in-flight request is not covered and must fail.
func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportClosesConnAfterGoAway(t, 0)
	})
}

// TestTransportClosesConnAfterGoAwayLastStream: GOAWAY with last stream 1
// (presumably the request's stream ID). The server then completes the
// request, which must succeed.
func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
	synctestTest(t, func(t testing.TB) {
		testTransportClosesConnAfterGoAway(t, 1)
	})
}

// testTransportClosesConnAfterGoAway sends one request, then has the server
// send GOAWAY(lastStream) and, when lastStream > 0, complete the request.
// After the server closes its write side:
//   - RoundTrip must have failed if and only if lastStream == 0;
//   - the ClientConn must have closed its net.Conn.
func testTransportClosesConnAfterGoAway(t testing.TB, lastStream uint32) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeGoAway(lastStream, ErrCodeNo, nil)

	if lastStream > 0 {
		// The GOAWAY covers the request's stream, so answer it normally.
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})
	}

	tc.closeWrite()
	err := rt.err()
	if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
		t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
	}
	if !tc.isClosed() {
		t.Errorf("ClientConn did not close its net.Conn, expected it to")
	}
}
4737
4738 type slowCloser struct {
4739 closing chan struct{}
4740 closed chan struct{}
4741 }
4742
4743 func (r *slowCloser) Read([]byte) (int, error) {
4744 return 0, io.EOF
4745 }
4746
4747 func (r *slowCloser) Close() error {
4748 close(r.closing)
4749 <-r.closed
4750 return nil
4751 }
4752
4753 func TestTransportSlowClose(t *testing.T) {
4754 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4755 })
4756
4757 client := ts.Client()
4758 body := &slowCloser{
4759 closing: make(chan struct{}),
4760 closed: make(chan struct{}),
4761 }
4762
4763 reqc := make(chan struct{})
4764 go func() {
4765 defer close(reqc)
4766 res, err := client.Post(ts.URL, "text/plain", body)
4767 if err != nil {
4768 t.Error(err)
4769 }
4770 res.Body.Close()
4771 }()
4772 defer func() {
4773 close(body.closed)
4774 <-reqc
4775 }()
4776
4777 <-body.closing
4778
4779 res, err := client.Get(ts.URL)
4780 if err != nil {
4781 t.Fatal(err)
4782 }
4783 res.Body.Close()
4784 }
4785
// TestTransportDialTLSContext checks that canceling a request's context while
// the TLS handshake is in progress aborts the handshake. The client's
// GetClientCertificate callback blocks until its handshake context is done,
// and RoundTrip must then fail with an error wrapping context.Canceled.
func TestTransportDialTLSContext(t *testing.T) {
	blockCh := make(chan struct{})
	serverTLSConfigFunc := func(ts *httptest.Server) {
		ts.Config.TLSConfig = &tls.Config{
			// Request (but don't require) a client certificate so that the
			// client's GetClientCertificate callback runs during the handshake.
			ClientAuth: tls.RequestClientCert,
		}
	}
	ts := newTestServer(t,
		func(w http.ResponseWriter, r *http.Request) {},
		serverTLSConfigFunc,
	)
	tr := newTransport(t)
	tr.TLSClientConfig = &tls.Config{
		GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
			// Signal that the handshake has reached this point, then block
			// until the handshake's context ends.
			close(blockCh)
			<-cri.Context().Done()
			return nil, cri.Context().Err()
		},
		InsecureSkipVerify: true,
	}
	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	req = req.WithContext(ctx)
	// errCh receives RoundTrip's error, or is closed without a value if
	// RoundTrip succeeds.
	errCh := make(chan error)
	go func() {
		defer close(errCh)
		res, err := tr.RoundTrip(req)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
	}()

	// Wait until the handshake is blocked in GetClientCertificate.
	<-blockCh

	cancel()

	err = <-errCh
	if err == nil {
		t.Fatal("cancelling context during client certificate fetch did not error as expected")
		return
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error returned after cancellation: %v", err)
	}
}
4841
4842
4843
4844
4845
// TestDialRaceResumesDial checks two requests racing for the same new
// connection. When the first request's context is canceled during the TLS
// handshake, that request must fail with context.Canceled, while the second
// request must still succeed; GetClientCertificate returns a certificate on
// its later calls.
//
// The test is currently skipped; see the issue in the t.Skip message.
func TestDialRaceResumesDial(t *testing.T) {
	t.Skip("https://go.dev/issue/77908: test fails when using an http.Transport")
	blockCh := make(chan struct{})
	serverTLSConfigFunc := func(ts *httptest.Server) {
		ts.Config.TLSConfig = &tls.Config{
			// Request (but don't require) a client certificate so that the
			// client's GetClientCertificate callback runs during the handshake.
			ClientAuth: tls.RequestClientCert,
		}
	}
	ts := newTestServer(t,
		func(w http.ResponseWriter, r *http.Request) {},
		serverTLSConfigFunc,
	)
	tr := newTransport(t)
	tr.TLSClientConfig = &tls.Config{
		GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
			select {
			case <-blockCh:
				// A later handshake: succeed immediately.
				return &tls.Certificate{}, nil
			default:
			}
			// First handshake: signal, then block until its context ends.
			close(blockCh)
			<-cri.Context().Done()
			return nil, cri.Context().Err()
		},
		InsecureSkipVerify: true,
	}
	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}

	ctx1, cancel1 := context.WithCancel(context.Background())
	defer cancel1()
	req1 := req.WithContext(ctx1)
	ctx2, cancel2 := context.WithCancel(context.Background())
	defer cancel2()
	req2 := req.WithContext(ctx2)
	// errCh collects errors from both requests.
	errCh := make(chan error)
	go func() {
		res, err := tr.RoundTrip(req1)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()
	}()
	successCh := make(chan struct{})
	go func() {
		// Start the second request only once the first handshake is blocked
		// in GetClientCertificate.
		<-blockCh
		res, err := tr.RoundTrip(req2)
		if err != nil {
			errCh <- err
			return
		}
		res.Body.Close()

		// Signal that the second request succeeded.
		close(successCh)
	}()

	<-blockCh

	cancel1()

	// The first error must be the first request's cancellation...
	err = <-errCh
	if err == nil {
		t.Fatal("cancelling context during client certificate fetch did not error as expected")
		return
	}
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("unexpected error returned after cancellation: %v", err)
	}
	// ...and the second request must succeed rather than fail.
	select {
	case err := <-errCh:
		t.Fatalf("unexpected second error: %v", err)
	case <-successCh:
	}
}
4929
4930 func TestTransportDataAfter1xxHeader(t *testing.T) { synctestTest(t, testTransportDataAfter1xxHeader) }
4931 func testTransportDataAfter1xxHeader(t testing.TB) {
4932
4933 log.SetOutput(io.Discard)
4934 defer log.SetOutput(os.Stderr)
4935
4936
4937 tc := newTestClientConn(t)
4938 tc.greet()
4939
4940 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4941 rt := tc.roundTrip(req)
4942
4943 tc.wantFrameType(FrameHeaders)
4944 tc.writeHeaders(HeadersFrameParam{
4945 StreamID: rt.streamID(),
4946 EndHeaders: true,
4947 EndStream: false,
4948 BlockFragment: tc.makeHeaderBlockFragment(
4949 ":status", "100",
4950 ),
4951 })
4952 tc.writeData(rt.streamID(), true, []byte{0})
4953 err := rt.err()
4954 if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
4955 t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err)
4956 }
4957 tc.wantFrameType(FrameRSTStream)
4958 }
4959
4960 func TestIssue66763Race(t *testing.T) {
4961 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {},
4962 func(s *http.Server) {
4963 s.Protocols = protocols("h2c")
4964 })
4965 tr := newTransport(t)
4966 tr.IdleConnTimeout = 1 * time.Nanosecond
4967 tr.Protocols = protocols("h2c")
4968
4969 donec := make(chan struct{})
4970 go func() {
4971
4972
4973
4974 conn, err := tr.NewClientConn(t.Context(), "http", ts.URL)
4975 close(donec)
4976 if err == nil {
4977 conn.Close()
4978 }
4979 }()
4980
4981
4982
4983 <-donec
4984 }
4985
4986
4987
4988 func TestIssue67671(t *testing.T) {
4989 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {},
4990 func(s *http.Server) {
4991 s.Protocols = protocols("h2c")
4992 })
4993 tr := newTransport(t)
4994 tr.Protocols = protocols("h2c")
4995 req, _ := http.NewRequest("GET", ts.URL, nil)
4996 req.Close = true
4997 for i := 0; i < 2; i++ {
4998 res, err := tr.RoundTrip(req)
4999 if err != nil {
5000 t.Fatal(err)
5001 }
5002 res.Body.Close()
5003 }
5004 }
5005
5006 func TestTransport1xxLimits(t *testing.T) {
5007 for _, test := range []struct {
5008 name string
5009 opt any
5010 ctxfn func(context.Context) context.Context
5011 hcount int
5012 limited bool
5013 }{{
5014 name: "default",
5015 hcount: 10,
5016 limited: false,
5017 }, {
5018 name: "MaxResponseHeaderBytes",
5019 opt: func(tr *http.Transport) {
5020 tr.MaxResponseHeaderBytes = 10000
5021 },
5022 hcount: 10,
5023 limited: true,
5024 }, {
5025 name: "limit by client trace",
5026 ctxfn: func(ctx context.Context) context.Context {
5027 count := 0
5028 return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
5029 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
5030 count++
5031 if count >= 10 {
5032 return errors.New("too many 1xx")
5033 }
5034 return nil
5035 },
5036 })
5037 },
5038 hcount: 10,
5039 limited: true,
5040 }, {
5041 name: "limit disabled by client trace",
5042 opt: func(tr *http.Transport) {
5043 tr.MaxResponseHeaderBytes = 10000
5044 },
5045 ctxfn: func(ctx context.Context) context.Context {
5046 return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
5047 Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
5048 return nil
5049 },
5050 })
5051 },
5052 hcount: 20,
5053 limited: false,
5054 }} {
5055 synctestSubtest(t, test.name, func(t testing.TB) {
5056 tc := newTestClientConn(t, test.opt)
5057 tc.greet()
5058
5059 ctx := context.Background()
5060 if test.ctxfn != nil {
5061 ctx = test.ctxfn(ctx)
5062 }
5063 req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
5064 rt := tc.roundTrip(req)
5065 tc.wantFrameType(FrameHeaders)
5066
5067 for i := 0; i < test.hcount; i++ {
5068 if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded {
5069 t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err)
5070 }
5071 tc.writeHeaders(HeadersFrameParam{
5072 StreamID: rt.streamID(),
5073 EndHeaders: true,
5074 EndStream: false,
5075 BlockFragment: tc.makeHeaderBlockFragment(
5076 ":status", "103",
5077 "x-field", strings.Repeat("a", 1000),
5078 ),
5079 })
5080 }
5081 if test.limited {
5082 tc.wantFrameType(FrameRSTStream)
5083 } else {
5084 tc.wantIdle()
5085 }
5086 })
5087 }
5088 }
5089
5090
5091
// TestTransportSendPingWithReset covers resetting streams on a connection with
// StrictMaxConcurrentRequests set. The test expects:
//   - the first RST_STREAM to be followed by a PING;
//   - a second reset, made while that PING is unacknowledged, to send no
//     further PING;
//   - a request queued behind the concurrency limit to be sent only after the
//     PING is acknowledged.
func TestTransportSendPingWithReset(t *testing.T) { synctestTest(t, testTransportSendPingWithReset) }
func testTransportSendPingWithReset(t testing.TB) {
	tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
		h2.StrictMaxConcurrentRequests = true
	})

	const maxConcurrent = 3
	tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

	// Start maxConcurrent+1 requests. The first maxConcurrent are sent
	// (HEADERS). The last exceeds the server's limit, so the client writes
	// nothing for it and it stays queued.
	var rts []*testRoundTrip
	for i := range maxConcurrent + 1 {
		req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
		rt := tc.roundTrip(req)
		if i >= maxConcurrent {
			tc.wantIdle()
			continue
		}
		tc.wantFrameType(FrameHeaders)
		rts = append(rts, rt)
	}

	// Cancel the first request: the stream is reset and a PING follows.
	rts[0].cancel()
	tc.wantRSTStream(rts[0].streamID(), ErrCodeCancel)
	pf := readFrame[*PingFrame](t, tc)
	tc.wantIdle()

	// Cancel the second request: the stream is reset, but no new PING is sent
	// and the queued request still does not start. Both reset streams
	// evidently still count against the concurrency limit at this point.
	rts[1].cancel()
	tc.wantRSTStream(rts[1].streamID(), ErrCodeCancel)
	tc.wantIdle()

	// Acknowledge the PING. The queued request's HEADERS are now sent, and
	// nothing else follows.
	tc.writePing(true, pf.Data)
	tc.wantFrameType(FrameHeaders)
	tc.wantIdle()
}
5131
5132
5133
5134
5135
// TestTransportNoPingAfterResetWithFrames checks two cases with
// StrictMaxConcurrentRequests set and a limit of one stream.
//
// Resetting a stream on which the server has already sent response headers
// produces no PING, and the queued request is sent right after the
// RST_STREAM. Resetting a stream that has received nothing from the server
// does produce a PING.
func TestTransportNoPingAfterResetWithFrames(t *testing.T) {
	synctestTest(t, testTransportNoPingAfterResetWithFrames)
}
func testTransportNoPingAfterResetWithFrames(t testing.TB) {
	tc := newTestClientConn(t, func(h2 *http.HTTP2Config) {
		h2.StrictMaxConcurrentRequests = true
	})

	const maxConcurrent = 1
	tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

	// First request: sent, and the server responds with headers (the body
	// stays open because EndStream is not set).
	req1 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt1 := tc.roundTrip(req1)
	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt1.streamID(),
		EndHeaders: true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt1.wantStatus(200)

	// Second request: blocked by the one-stream limit, so nothing is written.
	req2 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt2 := tc.roundTrip(req2)
	tc.wantIdle()

	// Cancel the first request. The stream is reset and the second request's
	// HEADERS follow immediately, with no PING in between.
	rt1.cancel()
	tc.wantRSTStream(rt1.streamID(), ErrCodeCancel)
	tc.wantFrameType(FrameHeaders)

	// Cancel the second request, which has received nothing from the server.
	// This reset is followed by a PING.
	rt2.cancel()
	tc.wantRSTStream(rt2.streamID(), ErrCodeCancel)
	tc.wantFrameType(FramePing)
}
5180
5181
5182
// TestTransportSendNoMoreThanOnePingWithReset checks that canceling requests
// (each producing a RST_STREAM) does not produce more than one PING at a time.
// The test walks through a fixed sequence of resets, server responses and a
// PING acknowledgement. It expects a second PING only at the very end.
//
// Client stream IDs are 1, 3, 5, ... in request order, so the headers written
// for StreamID 1 and 3 below answer the first and second reset requests.
func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) {
	synctestTest(t, testTransportSendNoMoreThanOnePingWithReset)
}
func testTransportSendNoMoreThanOnePingWithReset(t testing.TB) {
	tc := newTestClientConn(t)
	tc.greet()

	// makeAndResetRequest starts a request, waits for its HEADERS, cancels it,
	// and waits for the resulting RST_STREAM(CANCEL).
	makeAndResetRequest := func() {
		t.Helper()
		ctx, cancel := context.WithCancel(context.Background())
		req := Must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
		rt := tc.roundTrip(req)
		tc.wantFrameType(FrameHeaders)
		cancel()
		tc.wantRSTStream(rt.streamID(), ErrCodeCancel)
	}

	// Reset stream 1: a PING is sent.
	makeAndResetRequest()
	pf1 := readFrame[*PingFrame](t, tc)
	tc.wantIdle()

	// Reset stream 3 while the PING is unacknowledged: no second PING.
	makeAndResetRequest()
	tc.wantIdle()

	// The server answers the reset stream 1: the client writes nothing.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.wantIdle()

	// Reset stream 5, with the first PING still unacknowledged: no PING.
	makeAndResetRequest()
	tc.wantIdle()

	// The server acknowledges the first PING: the client writes nothing.
	tc.writePing(true, pf1.Data)
	tc.wantIdle()

	// Reset stream 7: the test still expects no PING here.
	makeAndResetRequest()
	tc.wantIdle()

	// The server answers the reset stream 3: the client writes nothing.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   3,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.wantIdle()

	// Reset stream 9: now a new PING is sent.
	makeAndResetRequest()
	tc.wantFrameType(FramePing)
}
5257
// TestTransportConnBecomesUnresponsive fills a connection's concurrency limit
// with requests that are canceled without any response from the server and
// without the PING ever being acknowledged. It expects the transport to keep
// reusing that connection until then, and to dial a new connection for the
// next request.
func TestTransportConnBecomesUnresponsive(t *testing.T) {
	synctestTest(t, testTransportConnBecomesUnresponsive)
}
func testTransportConnBecomesUnresponsive(t testing.TB) {
	tt := newTestTransport(t)

	const maxConcurrent = 3

	t.Logf("first request opens a new connection and succeeds")
	req1 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt1 := tt.roundTrip(req1)
	tc1 := tt.getConn()
	tc1.wantFrameType(FrameSettings)
	tc1.wantFrameType(FrameWindowUpdate)
	hf1 := readFrame[*HeadersFrame](t, tc1)
	tc1.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
	tc1.wantFrameType(FrameSettings)
	tc1.writeHeaders(HeadersFrameParam{
		StreamID:   hf1.StreamID,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc1.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt1.wantStatus(200)
	rt1.response().Body.Close()

	// Send maxConcurrent requests on tc1 and cancel each one without a
	// response. Each produces a RST_STREAM. Only the first is followed by a
	// PING, which is never acknowledged.
	for i := 0; i < maxConcurrent; i++ {
		t.Logf("request %v receives no response and is canceled", i)
		ctx, cancel := context.WithCancel(context.Background())
		req := Must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
		tt.roundTrip(req)
		if tt.hasConn() {
			t.Fatalf("new connection created; expect existing conn to be reused")
		}
		tc1.wantFrameType(FrameHeaders)
		cancel()
		tc1.wantFrameType(FrameRSTStream)
		if i == 0 {
			tc1.wantFrameType(FramePing)
		}
		tc1.wantIdle()
	}

	// With every slot on tc1 taken by unconfirmed resets, the next request
	// must go to a newly dialed connection (tc2) and succeed there.
	req2 := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt2 := tt.roundTrip(req2)
	tc2 := tt.getConn()
	tc2.wantFrameType(FrameSettings)
	tc2.wantFrameType(FrameWindowUpdate)
	hf := readFrame[*HeadersFrame](t, tc2)
	tc2.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
	tc2.wantFrameType(FrameSettings)
	tc2.writeHeaders(HeadersFrameParam{
		StreamID:   hf.StreamID,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc2.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt2.wantStatus(200)
	rt2.response().Body.Close()
}
5330
5331
5332
5333
5334
5335
5336 func newTestTransportWithUnusedConn(t testing.TB, opts ...any) *testTransport {
5337 tt := newTestTransport(t, opts...)
5338
5339 waitc := make(chan struct{})
5340 dialContext := tt.tr1.DialContext
5341 tt.tr1.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
5342 <-waitc
5343 return dialContext(ctx, network, address)
5344 }
5345
5346 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5347 rt := tt.roundTrip(req)
5348 rt.cancel()
5349 if rt.err() == nil {
5350 t.Fatalf("RoundTrip still running after request is canceled")
5351 }
5352
5353 close(waitc)
5354 synctest.Wait()
5355 return tt
5356 }
5357
5358
// TestTransportUnusedConnOK checks that the connection left unused by
// newTestTransportWithUnusedConn serves the next request. The request goes out
// as stream 1 on that connection, before the server's SETTINGS arrive, and its
// response is delivered.
func TestTransportUnusedConnOK(t *testing.T) { synctestTest(t, testTransportUnusedConnOK) }
func testTransportUnusedConnOK(t testing.TB) {
	tt := newTestTransportWithUnusedConn(t)

	req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
	// The unused conn has already written its preface SETTINGS and WINDOW_UPDATE.
	tc := tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)

	// The new request appears on this conn as stream 1, before the server has
	// written any SETTINGS.
	rt := tt.roundTrip(req)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"GET"},
			":path":      []string{"/"},
		},
	})

	// Complete the SETTINGS exchange; the client acknowledges with SETTINGS.
	tc.writeSettings()
	tc.writeSettingsAck()
	tc.wantFrameType(FrameSettings)

	// Respond with an empty 200.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt.wantStatus(200)
	rt.wantBody(nil)
}
5396
5397
5398 func TestTransportUnusedConnImmediateFailureUsed(t *testing.T) {
5399 synctestTest(t, testTransportUnusedConnImmediateFailureUsed)
5400 }
5401 func testTransportUnusedConnImmediateFailureUsed(t testing.TB) {
5402 tt := newTestTransportWithUnusedConn(t)
5403
5404
5405 tc1 := tt.getConn()
5406 tc1.closeWrite()
5407
5408
5409
5410
5411 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5412 rt := tt.roundTrip(req)
5413 if err := rt.err(); err == nil || errors.Is(err, ErrNoCachedConn) {
5414 t.Fatalf("RoundTrip with broken conn: got %v, want an error other than ErrNoCachedConn", err)
5415 }
5416
5417
5418
5419
5420 _ = tt.roundTrip(req)
5421 tc2 := tt.getConn()
5422 tc2.wantFrameType(FrameSettings)
5423 tc2.wantFrameType(FrameWindowUpdate)
5424 tc2.wantFrameType(FrameHeaders)
5425 }
5426
5427
5428 func TestTransportUnusedConnIdleTimoutBeforeUse(t *testing.T) {
5429 synctestTest(t, testTransportUnusedConnIdleTimoutBeforeUse)
5430 }
5431 func testTransportUnusedConnIdleTimoutBeforeUse(t testing.TB) {
5432 tt := newTestTransportWithUnusedConn(t, func(t1 *http.Transport) {
5433 t1.IdleConnTimeout = 1 * time.Second
5434 })
5435
5436 _ = tt.getConn()
5437
5438
5439 time.Sleep(2 * time.Second)
5440 synctest.Wait()
5441
5442
5443
5444
5445
5446 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5447 _ = tt.roundTrip(req)
5448 tc2 := tt.getConn()
5449 tc2.wantFrameType(FrameSettings)
5450 tc2.wantFrameType(FrameWindowUpdate)
5451 tc2.wantFrameType(FrameHeaders)
5452 }
5453
5454
5455
5456 func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) {
5457 synctestTest(t, testTransportTLSNextProtoConnImmediateFailureUnused)
5458 }
5459 func testTransportTLSNextProtoConnImmediateFailureUnused(t testing.TB) {
5460 tt := newTestTransportWithUnusedConn(t, func(t1 *http.Transport) {
5461 t1.IdleConnTimeout = 1 * time.Second
5462 })
5463
5464
5465 tc1 := tt.getConn()
5466 tc1.closeWrite()
5467
5468
5469
5470 time.Sleep(10 * time.Second)
5471
5472
5473
5474
5475 req := Must(http.NewRequest("GET", "https://dummy.tld/", nil))
5476 _ = tt.roundTrip(req)
5477 tc2 := tt.getConn()
5478 tc2.wantFrameType(FrameSettings)
5479 tc2.wantFrameType(FrameWindowUpdate)
5480 tc2.wantFrameType(FrameHeaders)
5481 }
5482
5483 func TestExtendedConnectClientWithServerSupport(t *testing.T) {
5484 t.Skip("https://go.dev/issue/53208 -- net/http needs to support the :protocol header")
5485 SetDisableExtendedConnectProtocol(t, false)
5486 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5487 if r.Header.Get(":protocol") != "extended-connect" {
5488 t.Fatalf("unexpected :protocol header received")
5489 }
5490 t.Log(io.Copy(w, r.Body))
5491 })
5492 tr := newTransport(t)
5493 pr, pw := io.Pipe()
5494 pwDone := make(chan struct{})
5495 req, _ := http.NewRequest("CONNECT", ts.URL, pr)
5496 req.Header.Set(":protocol", "extended-connect")
5497 req.Header.Set("X-A", "A")
5498 req.Header.Set("X-B", "B")
5499 req.Header.Set("X-C", "C")
5500 go func() {
5501 pw.Write([]byte("hello, extended connect"))
5502 pw.Close()
5503 close(pwDone)
5504 }()
5505
5506 res, err := tr.RoundTrip(req)
5507 if err != nil {
5508 t.Fatal(err)
5509 }
5510 body, err := io.ReadAll(res.Body)
5511 if err != nil {
5512 t.Fatal(err)
5513 }
5514 if !bytes.Equal(body, []byte("hello, extended connect")) {
5515 t.Fatal("unexpected body received")
5516 }
5517 }
5518
5519 func TestExtendedConnectClientWithoutServerSupport(t *testing.T) {
5520 t.Skip("https://go.dev/issue/53208 -- net/http needs to support the :protocol header")
5521 SetDisableExtendedConnectProtocol(t, true)
5522 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5523 io.Copy(w, r.Body)
5524 })
5525 tr := newTransport(t)
5526 pr, pw := io.Pipe()
5527 pwDone := make(chan struct{})
5528 req, _ := http.NewRequest("CONNECT", ts.URL, pr)
5529 req.Header.Set(":protocol", "extended-connect")
5530 req.Header.Set("X-A", "A")
5531 req.Header.Set("X-B", "B")
5532 req.Header.Set("X-C", "C")
5533 go func() {
5534 pw.Write([]byte("hello, extended connect"))
5535 pw.Close()
5536 close(pwDone)
5537 }()
5538
5539 _, err := tr.RoundTrip(req)
5540 if !errors.Is(err, ErrExtendedConnectNotSupported) {
5541 t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err)
5542 }
5543 }
5544
5545
5546
// TestExtendedConnectReadFrameError starts an extended CONNECT request on a
// connection where the server has written nothing (no greet, so no server
// SETTINGS). The client sends no HEADERS for the request. After the server
// side closes its write half, the test expects RoundTrip to finish with an
// error.
//
// NOTE(review): presumably the request is held waiting for the server's
// SETTINGS to say whether extended CONNECT is allowed; not visible here.
//
// Currently skipped (go.dev/issue/53208).
func TestExtendedConnectReadFrameError(t *testing.T) {
	synctestTest(t, testExtendedConnectReadFrameError)
}
func testExtendedConnectReadFrameError(t testing.TB) {
	t.Skip("https://go.dev/issue/53208 -- net/http needs to support the :protocol header")
	tc := newTestClientConn(t)
	// Client preface only; the test never writes server SETTINGS.
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)

	req, _ := http.NewRequest("CONNECT", "https://dummy.tld/", nil)
	req.Header.Set(":protocol", "extended-connect")
	rt := tc.roundTrip(req)
	// The request is not sent yet.
	tc.wantIdle()

	// Close the server's write side; the pending RoundTrip must fail.
	tc.closeWrite()
	if !rt.done() {
		t.Fatalf("after connection closed: RoundTrip still running; want done")
	}
	if rt.err() == nil {
		t.Fatalf("after connection closed: RoundTrip succeeded; want error")
	}
}
5569
View as plain text