Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 crand "crypto/rand"
16 "crypto/tls"
17 "crypto/x509"
18 "encoding/json"
19 "errors"
20 "fmt"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "testing/synctest"
47 "time"
48 )
49
// dummyAddr is a net.Addr backed by a plain string. The same string is
// reported both as the network name and as the address.
type dummyAddr string

// oneConnListener is a net.Listener that yields a single, pre-made
// connection.
type oneConnListener struct {
	conn net.Conn // returned by the first Accept, then cleared
}

// Accept returns the stored connection on the first call and clears it.
// Once no connection is stored, it returns a nil connection and io.EOF.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close does nothing and always returns nil.
func (l *oneConnListener) Close() error {
	return nil
}

// Addr returns a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr {
	const placeholder dummyAddr = "test-address"
	return placeholder
}

// Network returns the address string itself.
func (a dummyAddr) Network() string {
	return string(a)
}

// String returns the address string.
func (a dummyAddr) String() string {
	return string(a)
}
81
82 type noopConn struct{}
83
84 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
85 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
86 func (noopConn) SetDeadline(t time.Time) error { return nil }
87 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
88 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
89
90 type rwTestConn struct {
91 io.Reader
92 io.Writer
93 noopConn
94
95 closeFunc func() error
96 closec chan bool
97 }
98
99 func (c *rwTestConn) Close() error {
100 if c.closeFunc != nil {
101 return c.closeFunc()
102 }
103 select {
104 case c.closec <- true:
105 default:
106 }
107 return nil
108 }
109
110 type testConn struct {
111 readMu sync.Mutex
112 readBuf bytes.Buffer
113 writeBuf bytes.Buffer
114 closec chan bool
115 noopConn
116 }
117
118 func newTestConn() *testConn {
119 return &testConn{closec: make(chan bool, 1)}
120 }
121
122 func (c *testConn) Read(b []byte) (int, error) {
123 c.readMu.Lock()
124 defer c.readMu.Unlock()
125 return c.readBuf.Read(b)
126 }
127
128 func (c *testConn) Write(b []byte) (int, error) {
129 return c.writeBuf.Write(b)
130 }
131
132 func (c *testConn) Close() error {
133 select {
134 case c.closec <- true:
135 default:
136 }
137 return nil
138 }
139
140
141
// reqBytes converts req, written with "\n" line endings, into wire form:
// surrounding whitespace is trimmed, every "\n" becomes "\r\n", and the
// result is terminated with "\r\n\r\n".
func reqBytes(req string) []byte {
	trimmed := strings.TrimSpace(req)
	wire := strings.ReplaceAll(trimmed, "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
145
146 type handlerTest struct {
147 logbuf bytes.Buffer
148 handler Handler
149 }
150
151 func newHandlerTest(h Handler) handlerTest {
152 return handlerTest{handler: h}
153 }
154
155 func (ht *handlerTest) rawResponse(req string) string {
156 reqb := reqBytes(req)
157 var output strings.Builder
158 conn := &rwTestConn{
159 Reader: bytes.NewReader(reqb),
160 Writer: &output,
161 closec: make(chan bool, 1),
162 }
163 ln := &oneConnListener{conn: conn}
164 srv := &Server{
165 ErrorLog: log.New(&ht.logbuf, "", 0),
166 Handler: ht.handler,
167 }
168 go srv.Serve(ln)
169 <-conn.closec
170 return output.String()
171 }
172
173 func TestConsumingBodyOnNextConn(t *testing.T) {
174 t.Parallel()
175 defer afterTest(t)
176 conn := new(testConn)
177 for i := 0; i < 2; i++ {
178 conn.readBuf.Write([]byte(
179 "POST / HTTP/1.1\r\n" +
180 "Host: test\r\n" +
181 "Content-Length: 11\r\n" +
182 "\r\n" +
183 "foo=1&bar=1"))
184 }
185
186 reqNum := 0
187 ch := make(chan *Request)
188 servech := make(chan error)
189 listener := &oneConnListener{conn}
190 handler := func(res ResponseWriter, req *Request) {
191 reqNum++
192 ch <- req
193 }
194
195 go func() {
196 servech <- Serve(listener, HandlerFunc(handler))
197 }()
198
199 var req *Request
200 req = <-ch
201 if req == nil {
202 t.Fatal("Got nil first request.")
203 }
204 if req.Method != "POST" {
205 t.Errorf("For request #1's method, got %q; expected %q",
206 req.Method, "POST")
207 }
208
209 req = <-ch
210 if req == nil {
211 t.Fatal("Got nil first request.")
212 }
213 if req.Method != "POST" {
214 t.Errorf("For request #2's method, got %q; expected %q",
215 req.Method, "POST")
216 }
217
218 if serveerr := <-servech; serveerr != io.EOF {
219 t.Errorf("Serve returned %q; expected EOF", serveerr)
220 }
221 }
222
// stringHandler is a Handler that reports its own string value in the
// "Result" response header and writes nothing else.
type stringHandler string

// ServeHTTP sets the "Result" header to the handler's string value.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
228
// handlers lists mux patterns together with the message each pattern's
// stringHandler reports.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}
238
// vtests pairs request URLs with the value expected back: the "Result"
// header for a 200 response, or the "Location" header for a redirect.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Paths without the trailing slash: the expected value is a redirect target.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
254
255 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
256 func testHostHandlers(t *testing.T, mode testMode) {
257 mux := NewServeMux()
258 for _, h := range handlers {
259 mux.Handle(h.pattern, stringHandler(h.msg))
260 }
261 ts := newClientServerTest(t, mode, mux).ts
262
263 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
264 if err != nil {
265 t.Fatal(err)
266 }
267 defer conn.Close()
268 cc := httputil.NewClientConn(conn, nil)
269 for _, vt := range vtests {
270 var r *Response
271 var req Request
272 if req.URL, err = url.Parse(vt.url); err != nil {
273 t.Errorf("cannot parse url: %v", err)
274 continue
275 }
276 if err := cc.Write(&req); err != nil {
277 t.Errorf("writing request: %v", err)
278 continue
279 }
280 r, err := cc.Read(&req)
281 if err != nil {
282 t.Errorf("reading response: %v", err)
283 continue
284 }
285 switch r.StatusCode {
286 case StatusOK:
287 s := r.Header.Get("Result")
288 if s != vt.expected {
289 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
290 }
291 case StatusTemporaryRedirect:
292 s := r.Header.Get("Location")
293 if s != vt.expected {
294 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
295 }
296 default:
297 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
298 }
299 }
300 }
301
302 var serveMuxRegister = []struct {
303 pattern string
304 h Handler
305 }{
306 {"/dir/", serve(200)},
307 {"/search", serve(201)},
308 {"codesearch.google.com/search", serve(202)},
309 {"codesearch.google.com/", serve(203)},
310 {"example.com/", HandlerFunc(checkQueryStringHandler)},
311 }
312
313
// serve returns a handler that replies with the given status code and no
// body.
func serve(code int) HandlerFunc {
	writeStatus := func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	}
	return HandlerFunc(writeStatus)
}
319
320
321
322
// checkQueryStringHandler replies 200 when the request's raw query string,
// prefixed with "http://", equals the request URL rebuilt with scheme
// "http", the request's Host, and no query. Otherwise it replies 500.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	rebuilt := *r.URL // copy, so the request's own URL is left untouched
	rebuilt.Scheme = "http"
	rebuilt.Host = r.Host
	rebuilt.RawQuery = ""

	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
334
// serveMuxTests lists requests for TestServeMuxHandler. Each entry gives
// the request method, host and path, followed by the status code the
// selected handler is expected to write and the pattern mux.Handler is
// expected to report ("" when no registered pattern is reported).
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 307, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	// GET paths containing "." or ".." segments are expected to redirect.
	{"GET", "google.com", "/../search", 307, "/search"},
	{"GET", "google.com", "/dir/..", 307, ""},
	{"GET", "google.com", "/dir/..", 307, ""},
	{"GET", "google.com", "/dir/./file", 307, "/dir/"},

	// CONNECT requests: the missing-trailing-slash redirect is still
	// expected ("/dir" -> 307), but paths with "." or ".." segments are
	// expected to be matched as written rather than redirected.
	{"CONNECT", "google.com", "/dir", 307, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
370
371 func TestServeMuxHandler(t *testing.T) {
372 setParallel(t)
373 mux := NewServeMux()
374 for _, e := range serveMuxRegister {
375 mux.Handle(e.pattern, e.h)
376 }
377
378 for _, tt := range serveMuxTests {
379 r := &Request{
380 Method: tt.method,
381 Host: tt.host,
382 URL: &url.URL{
383 Path: tt.path,
384 },
385 }
386 h, pattern := mux.Handler(r)
387 rr := httptest.NewRecorder()
388 h.ServeHTTP(rr, r)
389 if pattern != tt.pattern || rr.Code != tt.code {
390 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
391 }
392 }
393 }
394
395
396 func TestServeMuxHandlerTrailingSlash(t *testing.T) {
397 setParallel(t)
398 mux := NewServeMux()
399 const original = "/{x}/"
400 mux.Handle(original, NotFoundHandler())
401 r, _ := NewRequest("POST", "/foo", nil)
402 _, p := mux.Handler(r)
403 if p != original {
404 t.Errorf("got %q, want %q", p, original)
405 }
406 }
407
408
409 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
410 setParallel(t)
411 defer func() {
412 if err := recover(); err == nil {
413 t.Error("expected call to mux.HandleFunc to panic")
414 }
415 }()
416 mux := NewServeMux()
417 mux.HandleFunc("/", nil)
418 }
419
// serveMuxTests2 lists requests for TestServeMuxHandlerRedirects. code is
// the status expected once no further redirect is returned; redirOk says
// whether a redirect on the way there is acceptable.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
431
432
433
434 func TestServeMuxHandlerRedirects(t *testing.T) {
435 setParallel(t)
436 mux := NewServeMux()
437 for _, e := range serveMuxRegister {
438 mux.Handle(e.pattern, e.h)
439 }
440
441 for _, tt := range serveMuxTests2 {
442 tries := 1
443 turl := tt.url
444 for {
445 u, e := url.Parse(turl)
446 if e != nil {
447 t.Fatal(e)
448 }
449 r := &Request{
450 Method: tt.method,
451 Host: tt.host,
452 URL: u,
453 }
454 h, _ := mux.Handler(r)
455 rr := httptest.NewRecorder()
456 h.ServeHTTP(rr, r)
457 if rr.Code != 307 {
458 if rr.Code != tt.code {
459 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
460 }
461 break
462 }
463 if !tt.redirOk {
464 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
465 break
466 }
467 turl = rr.HeaderMap.Get("Location")
468 tries--
469 }
470 if tries < 0 {
471 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
472 }
473 }
474 }
475
476 func TestServeMuxHandlerRedirectPost(t *testing.T) {
477 setParallel(t)
478 mux := NewServeMux()
479 mux.HandleFunc("POST /test/", func(w ResponseWriter, r *Request) {
480 w.WriteHeader(200)
481 })
482
483 var code, retries int
484 startURL := "http://example.com/test"
485 reqURL := startURL
486 for retries = 0; retries <= 1; retries++ {
487 r := httptest.NewRequest("POST", reqURL, strings.NewReader("hello world"))
488 h, _ := mux.Handler(r)
489 rr := httptest.NewRecorder()
490 h.ServeHTTP(rr, r)
491 code = rr.Code
492 switch rr.Code {
493 case 307:
494 reqURL = rr.Result().Header.Get("Location")
495 continue
496 case 200:
497
498 default:
499 t.Errorf("unhandled response code: %v", rr.Code)
500 }
501 }
502 if code != 200 {
503 t.Errorf("POST %s = %d after %d retries, want = 200", startURL, code, retries)
504 }
505 }
506
507
508 func TestMuxRedirectLeadingSlashes(t *testing.T) {
509 setParallel(t)
510 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
511 for _, path := range paths {
512 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
513 if err != nil {
514 t.Errorf("%s", err)
515 }
516 mux := NewServeMux()
517 resp := httptest.NewRecorder()
518
519 mux.ServeHTTP(resp, req)
520
521 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
522 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
523 return
524 }
525
526 if code, expected := resp.Code, StatusTemporaryRedirect; code != expected {
527 t.Errorf("Expected response code of StatusPermanentRedirect; got %d", code)
528 return
529 }
530 }
531 }
532
533
534
535
536
537 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
538 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
539 }
540 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
541 writeBackQuery := func(w ResponseWriter, r *Request) {
542 fmt.Fprintf(w, "%s", r.URL.RawQuery)
543 }
544
545 mux := NewServeMux()
546 mux.HandleFunc("/testOne", writeBackQuery)
547 mux.HandleFunc("/testTwo/", writeBackQuery)
548 mux.HandleFunc("/testThree", writeBackQuery)
549 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
550 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
551 })
552
553 ts := newClientServerTest(t, mode, mux).ts
554
555 tests := [...]struct {
556 path string
557 method string
558 want string
559 statusOk bool
560 }{
561 0: {"/testOne?this=that", "GET", "this=that", true},
562 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
563 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
564 3: {"/testTwo?", "GET", "", true},
565 4: {"/testThree?foo", "GET", "foo", true},
566 5: {"/testThree/?foo", "GET", "foo:bar", true},
567 6: {"/testThree?foo", "CONNECT", "foo", true},
568 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
569
570
571 8: {"/testOne/foo/..?foo", "GET", "foo", true},
572 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
573 }
574
575 for i, tt := range tests {
576 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
577 res, err := ts.Client().Do(req)
578 if err != nil {
579 continue
580 }
581 slurp, _ := io.ReadAll(res.Body)
582 res.Body.Close()
583 if !tt.statusOk {
584 if got, want := res.StatusCode, 404; got != want {
585 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
586 }
587 }
588 if got, want := string(slurp), tt.want; got != want {
589 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
590 }
591 }
592 }
593
594 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
595 setParallel(t)
596
597 mux := NewServeMux()
598 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
599 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
600 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
601 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
602 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
603 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
604
605 tests := []struct {
606 method string
607 url string
608 code int
609 loc string
610 want string
611 }{
612 {"GET", "http://example.com/", 404, "", ""},
613 {"GET", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
614 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
615 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
616 {"GET", "http://example.com/pkg/baz", 307, "/pkg/baz/", ""},
617 {"GET", "http://example.com:3000/pkg/foo", 307, "/pkg/foo/", ""},
618 {"CONNECT", "http://example.com/", 404, "", ""},
619 {"CONNECT", "http://example.com:3000/", 404, "", ""},
620 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
621 {"CONNECT", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
622 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
623 {"CONNECT", "http://example.com:3000/pkg/baz", 307, "/pkg/baz/", ""},
624 {"CONNECT", "http://example.com:3000/pkg/connect", 307, "/pkg/connect/", ""},
625 }
626
627 for i, tt := range tests {
628 req, _ := NewRequest(tt.method, tt.url, nil)
629 w := httptest.NewRecorder()
630 mux.ServeHTTP(w, req)
631
632 if got, want := w.Code, tt.code; got != want {
633 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
634 }
635
636 if tt.code == 301 {
637 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
638 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
639 }
640 } else {
641 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
642 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
643 }
644 }
645 }
646 }
647
648
649
650
// TestMuxNoSlashRedirectWithTrailingSlash checks that a mux whose only
// pattern is "/{x}/" answers a request for "/" with 404.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	okHandler := func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	}
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", okHandler)

	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)

	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
663
664
665
666
// TestMuxNoSlash405WithTrailingSlash checks that a mux whose only pattern
// is the method-qualified "GET /{x}/" answers a GET for "/" with 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	okHandler := func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	}
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", okHandler)

	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)

	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
679
680 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
681 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
682 mux := NewServeMux()
683 newClientServerTest(t, mode, mux)
684 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
685 }
686
687 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
688 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
689 func benchmarkServeMux(b *testing.B, runHandler bool) {
690 type test struct {
691 path string
692 code int
693 req *Request
694 }
695
696
697 var tests []test
698 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
699 for _, e := range endpoints {
700 for i := 200; i < 230; i++ {
701 p := fmt.Sprintf("/%s/%d/", e, i)
702 tests = append(tests, test{
703 path: p,
704 code: i,
705 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
706 })
707 }
708 }
709 mux := NewServeMux()
710 for _, tt := range tests {
711 mux.Handle(tt.path, serve(tt.code))
712 }
713
714 rw := httptest.NewRecorder()
715 b.ReportAllocs()
716 b.ResetTimer()
717 for i := 0; i < b.N; i++ {
718 for _, tt := range tests {
719 *rw = httptest.ResponseRecorder{}
720 h, pattern := mux.Handler(tt.req)
721 if runHandler {
722 h.ServeHTTP(rw, tt.req)
723 if pattern != tt.path || rw.Code != tt.code {
724 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
725 }
726 }
727 }
728 }
729 }
730
731 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
732 func testServerTimeouts(t *testing.T, mode testMode) {
733 runTimeSensitiveTest(t, []time.Duration{
734 10 * time.Millisecond,
735 50 * time.Millisecond,
736 100 * time.Millisecond,
737 500 * time.Millisecond,
738 1 * time.Second,
739 }, func(t *testing.T, timeout time.Duration) error {
740 return testServerTimeoutsWithTimeout(t, timeout, mode)
741 })
742 }
743
744 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
745 var reqNum atomic.Int32
746 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
747 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
748 }), func(ts *httptest.Server) {
749 ts.Config.ReadTimeout = timeout
750 ts.Config.WriteTimeout = timeout
751 })
752 defer cst.close()
753 ts := cst.ts
754
755
756 c := ts.Client()
757 r, err := c.Get(ts.URL)
758 if err != nil {
759 return fmt.Errorf("http Get #1: %v", err)
760 }
761 got, err := io.ReadAll(r.Body)
762 expected := "req=1"
763 if string(got) != expected || err != nil {
764 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
765 string(got), err, expected)
766 }
767
768
769 t1 := time.Now()
770 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
771 if err != nil {
772 return fmt.Errorf("Dial: %v", err)
773 }
774 buf := make([]byte, 1)
775 n, err := conn.Read(buf)
776 conn.Close()
777 latency := time.Since(t1)
778 if n != 0 || err != io.EOF {
779 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
780 }
781 minLatency := timeout / 5 * 4
782 if latency < minLatency {
783 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
784 }
785
786
787
788
789 r, err = c.Get(ts.URL)
790 if err != nil {
791 return fmt.Errorf("http Get #2: %v", err)
792 }
793 got, err = io.ReadAll(r.Body)
794 r.Body.Close()
795 expected = "req=2"
796 if string(got) != expected || err != nil {
797 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
798 }
799
800 if !testing.Short() {
801 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
802 if err != nil {
803 return fmt.Errorf("long Dial: %v", err)
804 }
805 defer conn.Close()
806 go io.Copy(io.Discard, conn)
807 for i := 0; i < 5; i++ {
808 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
809 if err != nil {
810 return fmt.Errorf("on write %d: %v", i, err)
811 }
812 time.Sleep(timeout / 2)
813 }
814 }
815 return nil
816 }
817
// TestServerReadTimeout runs testServerReadTimeout with the
// http3SkippedMode mode set.
func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout, http3SkippedMode) }

// testServerReadTimeout posts a request whose body (a pipe) never delivers
// any data and expects the handler's body read to fail with
// os.ErrDeadlineExceeded while the response still reaches the client.
// The attempt is repeated with a doubled ReadTimeout, starting at 5ms,
// whenever the client's Post itself returns an error.
func testServerReadTimeout(t *testing.T, mode testMode) {
	respBody := "response body"
	for timeout := 5 * time.Millisecond; ; timeout *= 2 {
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			_, err := io.Copy(io.Discard, req.Body)
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
			}
			res.Write([]byte(respBody))
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = -1
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})

		// The Proxy hook succeeds on its first call and fails on every
		// later one. NOTE(review): presumably this makes any automatic
		// transport retry fail instead of silently re-sending — confirm.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		// Nothing is written to pw before the response has been read, so
		// the server sees a request body that never arrives.
		pr, pw := io.Pipe()
		res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
		if err != nil {
			t.Logf("Get error, retrying: %v", err)
			cst.close()
			continue
		}
		// This defer is followed by an unconditional break, so it runs once.
		defer res.Body.Close()
		got, err := io.ReadAll(res.Body)
		if string(got) != respBody || err != nil {
			t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
		}
		pw.Close()
		break
	}
}
858
859 func TestServerNoReadTimeout(t *testing.T) {
860
861 run(t, testServerNoReadTimeout, http3SkippedMode)
862 }
863 func testServerNoReadTimeout(t *testing.T, mode testMode) {
864 reqBody := "Hello, Gophers!"
865 resBody := "Hi, Gophers!"
866 for _, timeout := range []time.Duration{0, -1} {
867 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
868 ctl := NewResponseController(res)
869 ctl.EnableFullDuplex()
870 res.WriteHeader(StatusOK)
871
872
873 if err := ctl.Flush(); err != nil {
874 t.Errorf("server flush response: %v", err)
875 return
876 }
877 got, err := io.ReadAll(req.Body)
878 if string(got) != reqBody || err != nil {
879 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
880 }
881 res.Write([]byte(resBody))
882 }), func(ts *httptest.Server) {
883 ts.Config.ReadTimeout = timeout
884 t.Logf("Server.Config.ReadTimeout = %d", timeout)
885 })
886
887 pr, pw := io.Pipe()
888 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
889 if err != nil {
890 t.Fatal(err)
891 }
892 defer res.Body.Close()
893
894
895 time.Sleep(10 * time.Millisecond)
896 pw.Write([]byte(reqBody))
897 pw.Close()
898
899 got, err := io.ReadAll(res.Body)
900 if string(got) != resBody || err != nil {
901 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
902 }
903 }
904 }
905
// TestServerWriteTimeout runs testServerWriteTimeout with the
// http3SkippedMode mode set.
func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout, http3SkippedMode) }

// testServerWriteTimeout has the handler write an endless response and
// expects that write to fail with os.ErrDeadlineExceeded, and the client's
// read of the body to fail as well. The attempt is repeated with a doubled
// WriteTimeout, starting at 5ms, when the Get fails or the handler never
// ran.
func testServerWriteTimeout(t *testing.T, mode testMode) {
	for timeout := 5 * time.Millisecond; ; timeout *= 2 {
		// The handler sends nil when it starts and its write error when it
		// finishes; the buffer of 2 lets it do both without a receiver.
		errc := make(chan error, 2)
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			errc <- nil
			_, err := io.Copy(res, neverEnding('a'))
			errc <- err
		}), func(ts *httptest.Server) {
			ts.Config.WriteTimeout = timeout
			t.Logf("Server.Config.WriteTimeout = %v", timeout)
		})

		// The Proxy hook succeeds on its first call and fails on every
		// later one. NOTE(review): presumably this makes any automatic
		// transport retry fail instead of silently re-sending — confirm.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// Close this server and try again with the next timeout.
			t.Logf("Get error, retrying: %v", err)
			cst.close()
			continue
		}
		defer res.Body.Close()
		_, err = io.Copy(io.Discard, res.Body)
		if err == nil {
			t.Errorf("client reading from truncated request body: got nil error, want non-nil")
		}
		select {
		case <-errc:
			// The handler started; the next value is its write error.
			err = <-errc
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
			}
			return
		default:
			// Nothing on errc means the handler was never entered.
			t.Logf("handler didn't run, retrying")
			cst.close()
		}
	}
}
970
// TestServerNoWriteTimeout runs testServerNoWriteTimeout in the default set
// of test modes.
func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }

// testServerNoWriteTimeout checks, for WriteTimeout values 0 and -1, that a
// client can read 1 MiB of an endless response without error. It then
// closes the body, calls Shutdown on the server, and waits for the handler
// to return.
func testServerNoWriteTimeout(t *testing.T, mode testMode) {
	for _, timeout := range []time.Duration{0, -1} {
		// Closed by the handler when it returns.
		handlerDone := make(chan struct{})
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			defer close(handlerDone)
			_, err := io.Copy(res, neverEnding('a'))
			t.Logf("server write response: %v", err)
		}), func(ts *httptest.Server) {
			ts.Config.WriteTimeout = timeout
			t.Logf("Server.Config.WriteTimeout = %d", timeout)
		})

		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		n, err := io.CopyN(io.Discard, res.Body, 1<<20)
		if n != 1<<20 || err != nil {
			t.Errorf("client read response body: %d, %v", n, err)
		}
		res.Body.Close()

		// Shutdown's result is not checked; the test only waits for the
		// handler to finish before moving on.
		cst.ts.Config.Shutdown(context.Background())
		<-handlerDone
	}
}
998
999
1000 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
1001 run(t, testWriteDeadlineExtendedOnNewRequest)
1002 }
1003 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
1004 if testing.Short() {
1005 t.Skip("skipping in short mode")
1006 }
1007 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
1008 func(ts *httptest.Server) {
1009 ts.Config.WriteTimeout = 250 * time.Millisecond
1010 },
1011 ).ts
1012
1013 c := ts.Client()
1014
1015 for i := 1; i <= 3; i++ {
1016 req, err := NewRequest("GET", ts.URL, nil)
1017 if err != nil {
1018 t.Fatal(err)
1019 }
1020
1021 r, err := c.Do(req)
1022 if err != nil {
1023 t.Fatalf("http2 Get #%d: %v", i, err)
1024 }
1025 r.Body.Close()
1026 time.Sleep(ts.Config.WriteTimeout / 2)
1027 }
1028 }
1029
1030
1031
// tryTimeouts calls testFunc with 250ms, 500ms and 1s in turn and returns
// as soon as one call returns nil. Each failure is logged; if all three
// attempts fail, the test is failed with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i := 0; i < len(tries); i++ {
		timeout := tries[i]
		err := testFunc(timeout)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeout, err)
		if next := i + 1; next < len(tries) {
			t.Logf("retrying at %v ...", tries[next])
		}
	}
	t.Fatal("all attempts failed")
}
1046
1047
1048 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1049 if testing.Short() {
1050 t.Skip("skipping in short mode")
1051 }
1052 setParallel(t)
1053 run(t, func(t *testing.T, mode testMode) {
1054 tryTimeouts(t, func(timeout time.Duration) error {
1055 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1056 })
1057 }, http3SkippedMode)
1058 }
1059
// testWriteDeadlineEnforcedPerStream sends two requests to a server whose
// WriteTimeout is timeout/2. The first handler invocation returns at once;
// every later one sleeps for the full timeout. The first request must
// succeed and the second must fail. It returns an error describing the
// first unmet expectation, or nil.
func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
	// Capacity 1: only the first handler invocation can send without
	// blocking, which is how it is told apart from later ones.
	firstRequest := make(chan bool, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
		select {
		case firstRequest <- true:
			// First request: return immediately.
		default:
			// Any later request: outlast the write timeout.
			time.Sleep(timeout)
		}
	}), func(ts *httptest.Server) {
		ts.Config.WriteTimeout = timeout / 2
	})
	defer cst.close()
	ts := cst.ts

	c := ts.Client()

	req, err := NewRequest("GET", ts.URL, nil)
	if err != nil {
		return fmt.Errorf("NewRequest: %v", err)
	}
	r, err := c.Do(req)
	if err != nil {
		return fmt.Errorf("Get #1: %v", err)
	}
	r.Body.Close()

	req, err = NewRequest("GET", ts.URL, nil)
	if err != nil {
		return fmt.Errorf("NewRequest: %v", err)
	}
	r, err = c.Do(req)
	if err == nil {
		r.Body.Close()
		return fmt.Errorf("Get #2 expected error, got nil")
	}
	if mode == http2Mode {
		// In HTTP/2 mode the error text must also name the stream and
		// the INTERNAL_ERROR code.
		expected := "stream ID 3; INTERNAL_ERROR"
		if !strings.Contains(err.Error(), expected) {
			return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
		}
	}
	return nil
}
1105
1106
1107 func TestNoWriteDeadline(t *testing.T) {
1108 if testing.Short() {
1109 t.Skip("skipping in short mode")
1110 }
1111 setParallel(t)
1112 defer afterTest(t)
1113 run(t, func(t *testing.T, mode testMode) {
1114 tryTimeouts(t, func(timeout time.Duration) error {
1115 return testNoWriteDeadline(t, mode, timeout)
1116 })
1117 })
1118 }
1119
1120 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1121 firstRequest := make(chan bool, 1)
1122 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1123 select {
1124 case firstRequest <- true:
1125
1126 default:
1127
1128 time.Sleep(timeout)
1129 }
1130 }))
1131 defer cst.close()
1132 ts := cst.ts
1133
1134 c := ts.Client()
1135
1136 for i := 0; i < 2; i++ {
1137 req, err := NewRequest("GET", ts.URL, nil)
1138 if err != nil {
1139 return fmt.Errorf("NewRequest: %v", err)
1140 }
1141 r, err := c.Do(req)
1142 if err != nil {
1143 return fmt.Errorf("Get #%d: %v", i, err)
1144 }
1145 r.Body.Close()
1146 }
1147 return nil
1148 }
1149
1150
1151
1152
// TestOnlyWriteTimeout runs testOnlyWriteTimeout in HTTP/1 mode only.
func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }

// testOnlyWriteTimeout has the handler write 512 KiB, then set a write
// deadline in the past on the underlying connection, then write another
// 512 KiB. The second write must fail, and the client must see an error
// while reading the response.
func testOnlyWriteTimeout(t *testing.T, mode testMode) {
	var (
		mu   sync.RWMutex // guards conn
		conn net.Conn     // set by trackLastConnListener on each accept
	)
	// Buffered so the handler can report its result without a receiver.
	var afterTimeoutErrc = make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		buf := make([]byte, 512<<10)
		_, err := w.Write(buf)
		if err != nil {
			t.Errorf("handler Write error: %v", err)
			return
		}
		mu.RLock()
		defer mu.RUnlock()
		if conn == nil {
			t.Error("no established connection found")
			return
		}
		// A deadline 30 seconds in the past has already expired.
		conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
		_, err = w.Write(buf)
		afterTimeoutErrc <- err
	}), func(ts *httptest.Server) {
		// Wrap the listener so the accepted connection is recorded in conn.
		ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
	}).ts

	c := ts.Client()

	// Fetch the response and drain it; either step may produce the error.
	err := func() error {
		res, err := c.Get(ts.URL)
		if err != nil {
			return err
		}
		_, err = io.Copy(io.Discard, res.Body)
		res.Body.Close()
		return err
	}()
	if err == nil {
		t.Errorf("expected an error copying body from Get request")
	}

	if err := <-afterTimeoutErrc; err == nil {
		t.Error("expected write error after timeout")
	}
}
1199
1200
// trackLastConnListener wraps a net.Listener and records the most recently
// accepted connection in *last, performing the store while holding mu.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // held for writing while *last is updated
	last *net.Conn     // receives every successfully accepted connection
}

// Accept accepts from the wrapped listener. On success the connection is also
// stored in *l.last under the write lock; on failure *l.last is left alone.
func (l trackLastConnListener) Accept() (c net.Conn, err error) {
	c, err = l.Listener.Accept()
	if err != nil {
		return c, err
	}
	l.mu.Lock()
	*l.last = c
	l.mu.Unlock()
	return c, nil
}
1217
1218
1219 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1220 func testIdentityResponse(t *testing.T, mode testMode) {
1221 if mode == http2Mode {
1222 t.Skip("https://go.dev/issue/56019")
1223 }
1224
1225 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1226 rw.Header().Set("Content-Length", "3")
1227 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1228 switch {
1229 case req.FormValue("overwrite") == "1":
1230 _, err := rw.Write([]byte("foo TOO LONG"))
1231 if err != ErrContentLength {
1232 t.Errorf("expected ErrContentLength; got %v", err)
1233 }
1234 case req.FormValue("underwrite") == "1":
1235 rw.Header().Set("Content-Length", "500")
1236 rw.Write([]byte("too short"))
1237 default:
1238 rw.Write([]byte("foo"))
1239 }
1240 })
1241
1242 ts := newClientServerTest(t, mode, handler).ts
1243 c := ts.Client()
1244
1245
1246
1247
1248
1249 for _, te := range []string{"", "identity"} {
1250 url := ts.URL + "/?te=" + te
1251 res, err := c.Get(url)
1252 if err != nil {
1253 t.Fatalf("error with Get of %s: %v", url, err)
1254 }
1255 if cl, expected := res.ContentLength, int64(3); cl != expected {
1256 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1257 }
1258 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1259 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1260 }
1261 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1262 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1263 url, expected, tl, res.TransferEncoding)
1264 }
1265 res.Body.Close()
1266 }
1267
1268
1269 url := ts.URL + "/?overwrite=1"
1270 res, err := c.Get(url)
1271 if err != nil {
1272 t.Fatalf("error with Get of %s: %v", url, err)
1273 }
1274 res.Body.Close()
1275
1276 if mode != http1Mode {
1277 return
1278 }
1279
1280
1281
1282 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1283 if err != nil {
1284 t.Fatalf("error dialing: %v", err)
1285 }
1286 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1287 if err != nil {
1288 t.Fatalf("error writing: %v", err)
1289 }
1290
1291
1292 got, _ := io.ReadAll(conn)
1293 expectedSuffix := "\r\n\r\ntoo short"
1294 if !strings.HasSuffix(string(got), expectedSuffix) {
1295 t.Errorf("Expected output to end with %q; got response body %q",
1296 expectedSuffix, string(got))
1297 }
1298 }
1299
1300 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1301 setParallel(t)
1302 s := newClientServerTest(t, http1Mode, h).ts
1303
1304 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1305 if err != nil {
1306 t.Fatal("dial error:", err)
1307 }
1308 defer conn.Close()
1309
1310 _, err = fmt.Fprint(conn, req)
1311 if err != nil {
1312 t.Fatal("print error:", err)
1313 }
1314
1315 r := bufio.NewReader(conn)
1316 res, err := ReadResponse(r, &Request{Method: "GET"})
1317 if err != nil {
1318 t.Fatal("ReadResponse error:", err)
1319 }
1320
1321 _, err = io.ReadAll(r)
1322 if err != nil {
1323 t.Fatal("read error:", err)
1324 }
1325
1326 if !res.Close {
1327 t.Errorf("Response.Close = false; want true")
1328 }
1329 }
1330
1331 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1332 setParallel(t)
1333 ts := newClientServerTest(t, http1Mode, handler).ts
1334 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1335 if err != nil {
1336 t.Fatal(err)
1337 }
1338 defer conn.Close()
1339 br := bufio.NewReader(conn)
1340 for i := 0; i < 2; i++ {
1341 if _, err := io.WriteString(conn, req); err != nil {
1342 t.Fatal(err)
1343 }
1344 res, err := ReadResponse(br, nil)
1345 if err != nil {
1346 t.Fatalf("res %d: %v", i+1, err)
1347 }
1348 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1349 t.Fatalf("res %d body copy: %v", i+1, err)
1350 }
1351 res.Body.Close()
1352 }
1353 }
1354
1355
1356 func TestServeHTTP10Close(t *testing.T) {
1357 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1358 ServeFile(w, r, "testdata/file")
1359 }))
1360 }
1361
1362
1363 func TestClientCanClose(t *testing.T) {
1364 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1365
1366 }))
1367 }
1368
1369
1370
1371 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1372 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1373 w.Header().Set("Connection", "close")
1374 }))
1375 }
1376
1377 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1378 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1379 w.Header().Set("Connection", "close")
1380 }))
1381 }
1382
1383 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1384 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1385
1386
1387 }))
1388 }
1389
// send204 is a handler that replies with status 204 and writes no body.
func send204(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNoContent)
}

// send304 is a handler that replies with status 304 and writes no body.
func send304(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNotModified)
}
1392
1393
1394 func TestHTTP10KeepAlive204Response(t *testing.T) {
1395 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1396 }
1397
1398 func TestHTTP11KeepAlive204Response(t *testing.T) {
1399 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1400 }
1401
1402 func TestHTTP10KeepAlive304Response(t *testing.T) {
1403 testTCPConnectionStaysOpen(t,
1404 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1405 HandlerFunc(send304))
1406 }
1407
1408
1409 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1410 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1411 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1412 w.(Flusher).Flush()
1413 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1414 }))
1415 type data struct {
1416 Addr string
1417 }
1418 var addrs [2]data
1419 for i := range addrs {
1420 res, err := cst.c.Get(cst.ts.URL)
1421 if err != nil {
1422 t.Fatal(err)
1423 }
1424 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1425 t.Fatal(err)
1426 }
1427 if addrs[i].Addr == "" {
1428 t.Fatal("no address")
1429 }
1430 res.Body.Close()
1431 }
1432 if addrs[0] != addrs[1] {
1433 t.Fatalf("connection not reused")
1434 }
1435 }
1436
1437 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1438 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1439 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1440 fmt.Fprintf(w, "%s", r.RemoteAddr)
1441 }))
1442
1443 res, err := cst.c.Get(cst.ts.URL)
1444 if err != nil {
1445 t.Fatalf("Get error: %v", err)
1446 }
1447 body, err := io.ReadAll(res.Body)
1448 if err != nil {
1449 t.Fatalf("ReadAll error: %v", err)
1450 }
1451 ip := string(body)
1452 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1453 t.Fatalf("Expected local addr; got %q", ip)
1454 }
1455 }
1456
1457 type blockingRemoteAddrListener struct {
1458 net.Listener
1459 conns chan<- net.Conn
1460 }
1461
1462 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1463 c, err := l.Listener.Accept()
1464 if err != nil {
1465 return nil, err
1466 }
1467 brac := &blockingRemoteAddrConn{
1468 Conn: c,
1469 addrs: make(chan net.Addr, 1),
1470 }
1471 l.conns <- brac
1472 return brac, nil
1473 }
1474
// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until an
// address is supplied on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr // each RemoteAddr call consumes one value
}

// RemoteAddr waits for the next address sent on c.addrs and returns it.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1483
1484
1485 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1486 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1487 }
1488 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1489 conns := make(chan net.Conn)
1490 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1491 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1492 }), func(ts *httptest.Server) {
1493 ts.Listener = &blockingRemoteAddrListener{
1494 Listener: ts.Listener,
1495 conns: conns,
1496 }
1497 }).ts
1498
1499 c := ts.Client()
1500
1501 c.Transport.(*Transport).DisableKeepAlives = true
1502
1503 fetch := func(num int, response chan<- string) {
1504 resp, err := c.Get(ts.URL)
1505 if err != nil {
1506 t.Errorf("Request %d: %v", num, err)
1507 response <- ""
1508 return
1509 }
1510 defer resp.Body.Close()
1511 body, err := io.ReadAll(resp.Body)
1512 if err != nil {
1513 t.Errorf("Request %d: %v", num, err)
1514 response <- ""
1515 return
1516 }
1517 response <- string(body)
1518 }
1519
1520
1521 response1c := make(chan string, 1)
1522 go fetch(1, response1c)
1523
1524
1525 conn1 := <-conns
1526
1527
1528 response2c := make(chan string, 1)
1529 go fetch(2, response2c)
1530 conn2 := <-conns
1531
1532
1533 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1534 IP: net.ParseIP("12.12.12.12"), Port: 12}
1535
1536
1537 response2 := <-response2c
1538 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1539 t.Fatalf("response 2 addr = %q; want %q", g, e)
1540 }
1541
1542
1543 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1544 IP: net.ParseIP("21.21.21.21"), Port: 21}
1545
1546
1547 response1 := <-response1c
1548 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1549 t.Fatalf("response 1 addr = %q; want %q", g, e)
1550 }
1551 }
1552
1553
1554
1555 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1556 func testHeadResponses(t *testing.T, mode testMode) {
1557 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1558 _, err := w.Write([]byte("<html>"))
1559 if err != nil {
1560 t.Errorf("ResponseWriter.Write: %v", err)
1561 }
1562
1563
1564 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1565 if err != nil {
1566 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1567 }
1568 }))
1569 res, err := cst.c.Head(cst.ts.URL)
1570 if err != nil {
1571 t.Error(err)
1572 }
1573 if len(res.TransferEncoding) > 0 {
1574 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1575 }
1576 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1577 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1578 }
1579
1580 if v := res.ContentLength; v != 10 && mode != http3Mode {
1581 t.Errorf("Content-Length: %d; want 10", v)
1582 }
1583 body, err := io.ReadAll(res.Body)
1584 if err != nil {
1585 t.Error(err)
1586 }
1587 if len(body) > 0 {
1588 t.Errorf("got unexpected body %q", string(body))
1589 }
1590 }
1591
1592
1593
1594 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1595 func testHeadReaderFrom(t *testing.T, mode testMode) {
1596
1597 wantBody := strings.Repeat("a", 4096)
1598 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1599 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1600 }))
1601 res, err := cst.c.Head(cst.ts.URL)
1602 if err != nil {
1603 t.Fatal(err)
1604 }
1605 res.Body.Close()
1606 res, err = cst.c.Get(cst.ts.URL)
1607 if err != nil {
1608 t.Fatal(err)
1609 }
1610 gotBody, err := io.ReadAll(res.Body)
1611 res.Body.Close()
1612 if err != nil {
1613 t.Fatal(err)
1614 }
1615 if string(gotBody) != wantBody {
1616 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1617 }
1618 }
1619
1620
1621
1622 func TestReaderFromTooLong(t *testing.T) { run(t, testReaderFromTooLong, []testMode{http1Mode}) }
1623 func testReaderFromTooLong(t *testing.T, mode testMode) {
1624 contentLen := 600
1625 tests := []struct {
1626 name string
1627 reader io.Reader
1628 wantHandlerErr error
1629 }{
1630 {
1631 name: "reader of correct length",
1632 reader: strings.NewReader(strings.Repeat("a", contentLen)),
1633 },
1634 {
1635 name: "wrapped reader of correct outer length",
1636 reader: io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(contentLen)),
1637 },
1638 {
1639 name: "wrapped reader of correct inner length",
1640 reader: io.LimitReader(io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(contentLen)), int64(2*contentLen)),
1641 },
1642 {
1643 name: "reader that is too long",
1644 reader: strings.NewReader(strings.Repeat("a", 2*contentLen)),
1645 wantHandlerErr: ErrContentLength,
1646 },
1647 {
1648 name: "wrapped reader that is too long",
1649 reader: io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(2*contentLen)),
1650 wantHandlerErr: ErrContentLength,
1651 },
1652 }
1653
1654 for _, tc := range tests {
1655 t.Run(tc.name, func(t *testing.T) {
1656 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1657 w.Header().Set("Content-Length", strconv.Itoa(contentLen))
1658 n, err := w.(io.ReaderFrom).ReadFrom(tc.reader)
1659 if int(n) != contentLen || !errors.Is(err, tc.wantHandlerErr) {
1660 t.Errorf("got %v, %v from w.ReadFrom; want %v, %v", n, err, contentLen, tc.wantHandlerErr)
1661 }
1662 }))
1663 res, err := cst.c.Get(cst.ts.URL)
1664 if err != nil {
1665 t.Fatal(err)
1666 }
1667 defer res.Body.Close()
1668 gotBody, err := io.ReadAll(res.Body)
1669 if err != nil {
1670 t.Fatal(err)
1671 }
1672 if len(gotBody) != contentLen {
1673 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), contentLen)
1674 }
1675 })
1676 }
1677 }
1678
1679 func TestTLSHandshakeTimeout(t *testing.T) {
1680 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1681 }
1682 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1683 errLog := new(strings.Builder)
1684 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1685 func(ts *httptest.Server) {
1686 ts.Config.ReadTimeout = 250 * time.Millisecond
1687 ts.Config.ErrorLog = log.New(errLog, "", 0)
1688 },
1689 )
1690 ts := cst.ts
1691
1692 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1693 if err != nil {
1694 t.Fatalf("Dial: %v", err)
1695 }
1696 var buf [1]byte
1697 n, err := conn.Read(buf[:])
1698 if err == nil || n != 0 {
1699 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1700 }
1701 conn.Close()
1702
1703 cst.close()
1704 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1705 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1706 }
1707 }
1708
1709 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1710 func testTLSServer(t *testing.T, mode testMode) {
1711 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1712 if r.TLS != nil {
1713 w.Header().Set("X-TLS-Set", "true")
1714 if r.TLS.HandshakeComplete {
1715 w.Header().Set("X-TLS-HandshakeComplete", "true")
1716 }
1717 }
1718 }), func(ts *httptest.Server) {
1719 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1720 }).ts
1721
1722
1723
1724
1725
1726
1727 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1728 if err != nil {
1729 t.Fatalf("Dial: %v", err)
1730 }
1731 defer idleConn.Close()
1732
1733 if !strings.HasPrefix(ts.URL, "https://") {
1734 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1735 return
1736 }
1737 client := ts.Client()
1738 res, err := client.Get(ts.URL)
1739 if err != nil {
1740 t.Error(err)
1741 return
1742 }
1743 if res == nil {
1744 t.Errorf("got nil Response")
1745 return
1746 }
1747 defer res.Body.Close()
1748 if res.Header.Get("X-TLS-Set") != "true" {
1749 t.Errorf("expected X-TLS-Set response header")
1750 return
1751 }
1752 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1753 t.Errorf("expected X-TLS-HandshakeComplete header")
1754 }
1755 }
1756
// fakeConnectionStateConn is a net.Conn that additionally offers a
// ConnectionState method returning a fixed TLS state.
type fakeConnectionStateConn struct {
	net.Conn
}

// ConnectionState returns a state whose only populated field is ServerName,
// set to "example.com".
func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
	var state tls.ConnectionState
	state.ServerName = "example.com"
	return state
}
1766
1767 func TestTLSServerWithoutTLSConn(t *testing.T) {
1768
1769 pr, pw := net.Pipe()
1770 c := make(chan int)
1771 listener := &oneConnListener{&fakeConnectionStateConn{pr}}
1772 server := &Server{
1773 Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
1774 if request.TLS == nil {
1775 t.Fatal("request.TLS is nil, expected not nil")
1776 }
1777 if request.TLS.ServerName != "example.com" {
1778 t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
1779 }
1780 writer.Header().Set("X-TLS-ServerName", "example.com")
1781 }),
1782 }
1783
1784
1785 go func() {
1786 req, _ := NewRequest(MethodGet, "https://example.com", nil)
1787 req.Write(pw)
1788
1789 resp, _ := ReadResponse(bufio.NewReader(pw), req)
1790 if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
1791 t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
1792 }
1793 close(c)
1794 pw.Close()
1795 }()
1796
1797 server.Serve(listener)
1798
1799
1800 <-c
1801 pr.Close()
1802 }
1803
1804 func TestServeTLS(t *testing.T) {
1805 CondSkipHTTP2(t)
1806
1807 defer afterTest(t)
1808 defer SetTestHookServerServe(nil)
1809
1810 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1811 if err != nil {
1812 t.Fatal(err)
1813 }
1814 tlsConf := &tls.Config{
1815 Certificates: []tls.Certificate{cert},
1816 }
1817
1818 ln := newLocalListener(t)
1819 defer ln.Close()
1820 addr := ln.Addr().String()
1821
1822 serving := make(chan bool, 1)
1823 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1824 serving <- true
1825 })
1826 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1827 s := &Server{
1828 Addr: addr,
1829 TLSConfig: tlsConf,
1830 Handler: handler,
1831 }
1832 errc := make(chan error, 1)
1833 go func() { errc <- s.ServeTLS(ln, "", "") }()
1834 select {
1835 case err := <-errc:
1836 t.Fatalf("ServeTLS: %v", err)
1837 case <-serving:
1838 }
1839
1840 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1841 InsecureSkipVerify: true,
1842 NextProtos: []string{"h2", "http/1.1"},
1843 })
1844 if err != nil {
1845 t.Fatal(err)
1846 }
1847 defer c.Close()
1848 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1849 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1850 }
1851 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1852 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1853 }
1854 }
1855
1856
1857 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1858 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1859 }
1860 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1861 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1862 t.Error("unexpected HTTPS request")
1863 }), func(ts *httptest.Server) {
1864 var errBuf bytes.Buffer
1865 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1866 }).ts
1867 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1868 if err != nil {
1869 t.Fatal(err)
1870 }
1871 defer conn.Close()
1872 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1873 slurp, err := io.ReadAll(conn)
1874 if err != nil {
1875 t.Fatal(err)
1876 }
1877 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1878 if !strings.HasPrefix(string(slurp), wantPrefix) {
1879 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1880 }
1881 }
1882
1883
1884 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1885 testAutomaticHTTP2_Serve(t, nil, true)
1886 }
1887
1888 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1889 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1890 }
1891
1892 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1893 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1894 }
1895
1896 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1897 setParallel(t)
1898 defer afterTest(t)
1899 ln := newLocalListener(t)
1900 ln.Close()
1901 var s Server
1902 s.TLSConfig = tlsConf
1903 if err := s.Serve(ln); err == nil {
1904 t.Fatal("expected an error")
1905 }
1906 gotH2 := s.TLSNextProto["h2"] != nil
1907 if gotH2 != wantH2 {
1908 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1909 }
1910 }
1911
1912 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1913 setParallel(t)
1914 defer afterTest(t)
1915 ln := newLocalListener(t)
1916 ln.Close()
1917 var s Server
1918
1919
1920 s.TLSConfig = &tls.Config{
1921 NextProtos: []string{"h2"},
1922 }
1923 if err := s.Serve(ln); err == nil {
1924 t.Fatal("expected an error")
1925 }
1926 on := s.TLSNextProto["h2"] != nil
1927 if !on {
1928 t.Errorf("http2 wasn't automatically enabled")
1929 }
1930 }
1931
1932 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1933 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1934 if err != nil {
1935 t.Fatal(err)
1936 }
1937 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1938 Certificates: []tls.Certificate{cert},
1939 })
1940 }
1941
1942 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1943 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1944 if err != nil {
1945 t.Fatal(err)
1946 }
1947 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1948 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1949 return &cert, nil
1950 },
1951 })
1952 }
1953
1954 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1955 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1956 if err != nil {
1957 t.Fatal(err)
1958 }
1959 conf := &tls.Config{
1960
1961
1962 NextProtos: []string{"h2"},
1963 Certificates: []tls.Certificate{cert},
1964 }
1965 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1966 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1967 return conf, nil
1968 },
1969 })
1970 }
1971
// testAutomaticHTTP2_ListenAndServe starts a Server with ListenAndServeTLS
// using tlsConf and empty cert/key file names, then dials it with a TLS client
// offering "h2" and "http/1.1" and requires "h2" to be negotiated.
//
// Startup is attempted up to maxTries times; the test fails if none of the
// attempts reaches the serve hook.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	// The serve hook installed inside the loop is package-global; reset it on exit.
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
Try:
	for try := 0; try < maxTries; try++ {
		// Open a listener only to learn a usable address, then close it so
		// the server can bind that address itself. Presumably the address
		// can be taken by someone else in between, hence the retries.
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		lnc := make(chan net.Listener, 1)
		// The hook's parameters shadow the outer s and ln; it hands the
		// listener it receives back to this goroutine through lnc.
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			// ListenAndServeTLS returned before the hook fired: try again.
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			// The hook fired; ln now refers to the listener the hook received.
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
2026
// serverExpectTest describes one raw POST request sent by testServerExpect
// and the response it should produce.
type serverExpectTest struct {
	contentLength    int    // Content-Length header value, and the size of the body when one is written
	chunked          bool   // send "Transfer-Encoding: chunked" in place of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // sent as ?readbody=; the handler reads the body only when true
	expectedResponse string // text that must appear in the first response line
}

// expectTest returns a non-chunked serverExpectTest with the given fields.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
2043
2044 var serverExpectTests = []serverExpectTest{
2045
2046 expectTest(100, "100-continue", true, "100 Continue"),
2047 expectTest(100, "100-cOntInUE", true, "100 Continue"),
2048
2049
2050 expectTest(100, "", true, "200 OK"),
2051
2052
2053
2054 expectTest(100, "100-continue", false, "401 Unauthorized"),
2055
2056 expectTest(100, "", false, "401 Unauthorized"),
2057
2058
2059 expectTest(0, "a-pony", false, "417 Expectation Failed"),
2060
2061
2062 expectTest(0, "100-continue", true, "200 OK"),
2063
2064 expectTest(0, "100-continue", false, "401 Unauthorized"),
2065
2066 {
2067 expectation: "100-continue",
2068 readBody: true,
2069 chunked: true,
2070 expectedResponse: "100 Continue",
2071 },
2072 }
2073
2074
2075
// TestServerExpect runs testServerExpect in HTTP/1 mode only.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }

// testServerExpect sends each entry of serverExpectTests to the server as a
// raw POST over its own TCP connection and checks that the first line of the
// response contains the entry's expectedResponse.
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The query string decides the handler's behavior: with
		// readbody=true it consumes the request body and replies "Hi";
		// otherwise it replies 401 without touching the body.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest performs one request/response exchange for test.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// The body is sent right after the headers only when there is one
		// and the request does not carry a 100-continue expectation
		// (compared case-insensitively).
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred calls run in reverse order: wg.Wait runs before
		// conn.Close, so the writer goroutine has finished by the time the
		// connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Writer side: request headers, then optionally the body.
		go func() {
			defer wg.Done()

			// Despite its name, contentLen holds a whole header line.
			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// By default write straight to the connection through a
				// WriteCloser whose Close does nothing.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler never reads the body in this case,
						// so a failed body write is tolerated (presumably
						// the server may already have closed the
						// connection).
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		// Reader side: only the first line of the response is examined.
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// A body was being written that the handler never reads;
				// a failed read of the status line is tolerated in that
				// combination.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2171
2172
2173
2174 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2175 setParallel(t)
2176 defer afterTest(t)
2177 conn := new(testConn)
2178 body := strings.Repeat("x", 100<<10)
2179 conn.readBuf.Write([]byte(fmt.Sprintf(
2180 "POST / HTTP/1.1\r\n"+
2181 "Host: test\r\n"+
2182 "Content-Length: %d\r\n"+
2183 "\r\n", len(body))))
2184 conn.readBuf.Write([]byte(body))
2185
2186 done := make(chan bool)
2187
2188 readBufLen := func() int {
2189 conn.readMu.Lock()
2190 defer conn.readMu.Unlock()
2191 return conn.readBuf.Len()
2192 }
2193
2194 ls := &oneConnListener{conn}
2195 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2196 defer close(done)
2197 if bufLen := readBufLen(); bufLen < len(body)/2 {
2198 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2199 }
2200 rw.WriteHeader(200)
2201 rw.(Flusher).Flush()
2202 if g, e := readBufLen(), 0; g != e {
2203 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2204 }
2205 if c := rw.Header().Get("Connection"); c != "" {
2206 t.Errorf(`Connection header = %q; want ""`, c)
2207 }
2208 }))
2209 <-done
2210 }
2211
2212
2213
2214
2215 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2216 setParallel(t)
2217 if testing.Short() && testenv.Builder() == "" {
2218 t.Log("skipping in short mode")
2219 }
2220 conn := new(testConn)
2221 body := strings.Repeat("x", 1<<20)
2222 conn.readBuf.Write([]byte(fmt.Sprintf(
2223 "POST / HTTP/1.1\r\n"+
2224 "Host: test\r\n"+
2225 "Content-Length: %d\r\n"+
2226 "\r\n", len(body))))
2227 conn.readBuf.Write([]byte(body))
2228 conn.closec = make(chan bool, 1)
2229
2230 ls := &oneConnListener{conn}
2231 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2232 if conn.readBuf.Len() < len(body)/2 {
2233 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2234 }
2235 rw.WriteHeader(200)
2236 rw.(Flusher).Flush()
2237 if conn.readBuf.Len() < len(body)/2 {
2238 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2239 }
2240 }))
2241 <-conn.closec
2242
2243 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2244 t.Errorf("Expected a Connection: close header; got response: %s", res)
2245 }
2246 }
2247
// handlerBodyCloseTest describes one request fed to testHandlerBodyClose and
// what is expected once the handler closes the request body.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes in the POST
	bodyChunked  bool // send the body chunked instead of with a Content-Length
	reqConnClose bool // add "Connection: close" and queue no follow-up GET

	wantEOFSearch bool // whether closing the body should change the read buffer length
	wantNextReq   bool // presumably whether a second request should be served; the check is outside this view
}

// connectionHeader returns the "Connection: close" header line, CRLF
// included, when reqConnClose is set, and the empty string otherwise.
func (t handlerBodyCloseTest) connectionHeader() string {
	if !t.reqConnClose {
		return ""
	}
	return "Connection: close\r\n"
}
2263
2264 var handlerBodyCloseTests = [...]handlerBodyCloseTest{
2265
2266
2267 0: {
2268 bodySize: 20 << 10,
2269 bodyChunked: false,
2270 reqConnClose: false,
2271 wantEOFSearch: true,
2272 wantNextReq: true,
2273 },
2274
2275
2276
2277 1: {
2278 bodySize: 20 << 10,
2279 bodyChunked: true,
2280 reqConnClose: false,
2281 wantEOFSearch: true,
2282 wantNextReq: true,
2283 },
2284
2285
2286
2287
2288 2: {
2289 bodySize: 20 << 10,
2290 bodyChunked: false,
2291 reqConnClose: true,
2292 wantEOFSearch: false,
2293 wantNextReq: false,
2294 },
2295
2296
2297
2298
2299
2300
2301 3: {
2302 bodySize: 20 << 10,
2303 bodyChunked: true,
2304 reqConnClose: true,
2305 wantEOFSearch: true,
2306 wantNextReq: false,
2307 },
2308
2309
2310 4: {
2311 bodySize: 1 << 20,
2312 bodyChunked: false,
2313 reqConnClose: false,
2314 wantEOFSearch: false,
2315 wantNextReq: false,
2316 },
2317
2318
2319 5: {
2320 bodySize: 1 << 20,
2321 bodyChunked: true,
2322 reqConnClose: false,
2323 wantEOFSearch: true,
2324 wantNextReq: false,
2325 },
2326
2327
2328
2329
2330 6: {
2331 bodySize: 1 << 20,
2332 bodyChunked: true,
2333 reqConnClose: true,
2334 wantEOFSearch: true,
2335 wantNextReq: false,
2336 },
2337
2338
2339
2340 7: {
2341 bodySize: 1 << 20,
2342 bodyChunked: false,
2343 reqConnClose: true,
2344 wantEOFSearch: false,
2345 wantNextReq: false,
2346 },
2347 }
2348
2349 func TestHandlerBodyClose(t *testing.T) {
2350 setParallel(t)
2351 if testing.Short() && testenv.Builder() == "" {
2352 t.Skip("skipping in -short mode")
2353 }
2354 for i, tt := range handlerBodyCloseTests {
2355 testHandlerBodyClose(t, i, tt)
2356 }
2357 }
2358
2359 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2360 conn := new(testConn)
2361 body := strings.Repeat("x", tt.bodySize)
2362 if tt.bodyChunked {
2363 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2364 "Host: test\r\n" +
2365 tt.connectionHeader() +
2366 "Transfer-Encoding: chunked\r\n" +
2367 "\r\n")
2368 cw := internal.NewChunkedWriter(&conn.readBuf)
2369 io.WriteString(cw, body)
2370 cw.Close()
2371 conn.readBuf.WriteString("\r\n")
2372 } else {
2373 conn.readBuf.Write([]byte(fmt.Sprintf(
2374 "POST / HTTP/1.1\r\n"+
2375 "Host: test\r\n"+
2376 tt.connectionHeader()+
2377 "Content-Length: %d\r\n"+
2378 "\r\n", len(body))))
2379 conn.readBuf.Write([]byte(body))
2380 }
2381 if !tt.reqConnClose {
2382 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2383 }
2384 conn.closec = make(chan bool, 1)
2385
2386 readBufLen := func() int {
2387 conn.readMu.Lock()
2388 defer conn.readMu.Unlock()
2389 return conn.readBuf.Len()
2390 }
2391
2392 ls := &oneConnListener{conn}
2393 var numReqs int
2394 var size0, size1 int
2395 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2396 numReqs++
2397 if numReqs == 1 {
2398 size0 = readBufLen()
2399 req.Body.Close()
2400 size1 = readBufLen()
2401 }
2402 }))
2403 <-conn.closec
2404 if numReqs < 1 || numReqs > 2 {
2405 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2406 }
2407 didSearch := size0 != size1
2408 if didSearch != tt.wantEOFSearch {
2409 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2410 }
2411 if tt.wantNextReq && numReqs != 2 {
2412 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2413 }
2414 }
2415
2416
2417
// testHandlerBodyConsumer is a named way for a handler to treat a request
// body.
type testHandlerBodyConsumer struct {
	name string
	f    func(io.ReadCloser)
}

// testHandlerBodyConsumers lists the body treatments exercised by the
// connection-closing tests: leave the body alone, close it, or read it to
// the end.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(body io.ReadCloser) { body.Close() }},
	{name: "discard", f: func(body io.ReadCloser) { io.Copy(io.Discard, body) }},
}
2428
2429 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2430 setParallel(t)
2431 defer afterTest(t)
2432 for _, handler := range testHandlerBodyConsumers {
2433 conn := new(testConn)
2434 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2435 "Host: test\r\n" +
2436 "Transfer-Encoding: chunked\r\n" +
2437 "\r\n" +
2438 "hax\r\n" +
2439 "GET /secret HTTP/1.1\r\n" +
2440 "Host: test\r\n" +
2441 "\r\n")
2442
2443 conn.closec = make(chan bool, 1)
2444 ls := &oneConnListener{conn}
2445 var numReqs int
2446 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2447 numReqs++
2448 if strings.Contains(req.URL.Path, "secret") {
2449 t.Error("Request for /secret encountered, should not have happened.")
2450 }
2451 handler.f(req.Body)
2452 }))
2453 <-conn.closec
2454 if numReqs != 1 {
2455 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2456 }
2457 }
2458 }
2459
2460 func TestInvalidTrailerClosesConnection(t *testing.T) {
2461 setParallel(t)
2462 defer afterTest(t)
2463 for _, handler := range testHandlerBodyConsumers {
2464 conn := new(testConn)
2465 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2466 "Host: test\r\n" +
2467 "Trailer: hack\r\n" +
2468 "Transfer-Encoding: chunked\r\n" +
2469 "\r\n" +
2470 "3\r\n" +
2471 "hax\r\n" +
2472 "0\r\n" +
2473 "I'm not a valid trailer\r\n" +
2474 "GET /secret HTTP/1.1\r\n" +
2475 "Host: test\r\n" +
2476 "\r\n")
2477
2478 conn.closec = make(chan bool, 1)
2479 ln := &oneConnListener{conn}
2480 var numReqs int
2481 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2482 numReqs++
2483 if strings.Contains(req.URL.Path, "secret") {
2484 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2485 }
2486 handler.f(req.Body)
2487 }))
2488 <-conn.closec
2489 if numReqs != 1 {
2490 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2491 }
2492 }
2493 }
2494
2495
2496
2497
2498 type slowTestConn struct {
2499
2500 script []any
2501 closec chan bool
2502
2503 mu sync.Mutex
2504 rd, wd time.Time
2505 noopConn
2506 }
2507
2508 func (c *slowTestConn) SetDeadline(t time.Time) error {
2509 c.SetReadDeadline(t)
2510 c.SetWriteDeadline(t)
2511 return nil
2512 }
2513
2514 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2515 c.mu.Lock()
2516 defer c.mu.Unlock()
2517 c.rd = t
2518 return nil
2519 }
2520
2521 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2522 c.mu.Lock()
2523 defer c.mu.Unlock()
2524 c.wd = t
2525 return nil
2526 }
2527
2528 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2529 c.mu.Lock()
2530 defer c.mu.Unlock()
2531 restart:
2532 if !c.rd.IsZero() && time.Now().After(c.rd) {
2533 return 0, syscall.ETIMEDOUT
2534 }
2535 if len(c.script) == 0 {
2536 return 0, io.EOF
2537 }
2538
2539 switch cue := c.script[0].(type) {
2540 case time.Duration:
2541 if !c.rd.IsZero() {
2542
2543
2544 if remaining := time.Until(c.rd); remaining < cue {
2545 c.script[0] = cue - remaining
2546 time.Sleep(remaining)
2547 return 0, syscall.ETIMEDOUT
2548 }
2549 }
2550 c.script = c.script[1:]
2551 time.Sleep(cue)
2552 goto restart
2553
2554 case string:
2555 n = copy(b, cue)
2556
2557 if len(cue) > n {
2558 c.script[0] = cue[n:]
2559 } else {
2560 c.script = c.script[1:]
2561 }
2562
2563 default:
2564 panic("unknown cue in slowTestConn script")
2565 }
2566
2567 return
2568 }
2569
2570 func (c *slowTestConn) Close() error {
2571 select {
2572 case c.closec <- true:
2573 default:
2574 }
2575 return nil
2576 }
2577
2578 func (c *slowTestConn) Write(b []byte) (int, error) {
2579 if !c.wd.IsZero() && time.Now().After(c.wd) {
2580 return 0, syscall.ETIMEDOUT
2581 }
2582 return len(b), nil
2583 }
2584
2585 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2586 if testing.Short() {
2587 t.Skip("skipping in -short mode")
2588 }
2589 defer afterTest(t)
2590 for _, handler := range testHandlerBodyConsumers {
2591 conn := &slowTestConn{
2592 script: []any{
2593 "POST /public HTTP/1.1\r\n" +
2594 "Host: test\r\n" +
2595 "Content-Length: 10000\r\n" +
2596 "\r\n",
2597 "foo bar baz",
2598 600 * time.Millisecond,
2599 "GET /secret HTTP/1.1\r\n" +
2600 "Host: test\r\n" +
2601 "\r\n",
2602 },
2603 closec: make(chan bool, 1),
2604 }
2605 ls := &oneConnListener{conn}
2606
2607 var numReqs int
2608 s := Server{
2609 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2610 numReqs++
2611 if strings.Contains(req.URL.Path, "secret") {
2612 t.Error("Request for /secret encountered, should not have happened.")
2613 }
2614 handler.f(req.Body)
2615 }),
2616 ReadTimeout: 400 * time.Millisecond,
2617 }
2618 go s.Serve(ls)
2619 <-conn.closec
2620
2621 if numReqs != 1 {
2622 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2623 }
2624 }
2625 }
2626
2627
// cancelableTimeoutContext wraps a context so that any error of the wrapped
// context is reported as context.DeadlineExceeded.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the wrapped context has no error and
// context.DeadlineExceeded once it has one, whatever that error is.
func (c cancelableTimeoutContext) Err() error {
	if c.Context.Err() == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2638
2639 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2640 func testTimeoutHandler(t *testing.T, mode testMode) {
2641 sendHi := make(chan bool, 1)
2642 writeErrors := make(chan error, 1)
2643 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2644 <-sendHi
2645 _, werr := w.Write([]byte("hi"))
2646 writeErrors <- werr
2647 })
2648 ctx, cancel := context.WithCancel(context.Background())
2649 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2650 cst := newClientServerTest(t, mode, h)
2651
2652
2653 sendHi <- true
2654 res, err := cst.c.Get(cst.ts.URL)
2655 if err != nil {
2656 t.Error(err)
2657 }
2658 if g, e := res.StatusCode, StatusOK; g != e {
2659 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2660 }
2661 body, _ := io.ReadAll(res.Body)
2662 if g, e := string(body), "hi"; g != e {
2663 t.Errorf("got body %q; expected %q", g, e)
2664 }
2665 if g := <-writeErrors; g != nil {
2666 t.Errorf("got unexpected Write error on first request: %v", g)
2667 }
2668
2669
2670 cancel()
2671
2672 res, err = cst.c.Get(cst.ts.URL)
2673 if err != nil {
2674 t.Error(err)
2675 }
2676 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2677 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2678 }
2679 body, _ = io.ReadAll(res.Body)
2680 if !strings.Contains(string(body), "<title>Timeout</title>") {
2681 t.Errorf("expected timeout body; got %q", string(body))
2682 }
2683 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2684 t.Errorf("response content-type = %q; want %q", g, w)
2685 }
2686
2687
2688
2689 sendHi <- true
2690 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2691 t.Errorf("expected Write error of %v; got %v", e, g)
2692 }
2693 }
2694
2695
2696 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2697 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2698 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2699 ms, _ := strconv.Atoi(r.URL.Path[1:])
2700 if ms == 0 {
2701 ms = 1
2702 }
2703 for i := 0; i < ms; i++ {
2704 w.Write([]byte("hi"))
2705 time.Sleep(time.Millisecond)
2706 }
2707 })
2708
2709 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2710
2711 c := ts.Client()
2712
2713 var wg sync.WaitGroup
2714 gate := make(chan bool, 10)
2715 n := 50
2716 if testing.Short() {
2717 n = 10
2718 gate = make(chan bool, 3)
2719 }
2720 for i := 0; i < n; i++ {
2721 gate <- true
2722 wg.Add(1)
2723 go func() {
2724 defer wg.Done()
2725 defer func() { <-gate }()
2726 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2727 if err == nil {
2728 io.Copy(io.Discard, res.Body)
2729 res.Body.Close()
2730 }
2731 }()
2732 }
2733 wg.Wait()
2734 }
2735
2736
2737
2738 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2739 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2740 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2741 w.WriteHeader(204)
2742 })
2743
2744 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2745
2746 var wg sync.WaitGroup
2747 gate := make(chan bool, 50)
2748 n := 500
2749 if testing.Short() {
2750 n = 10
2751 }
2752
2753 c := ts.Client()
2754 for i := 0; i < n; i++ {
2755 gate <- true
2756 wg.Add(1)
2757 go func() {
2758 defer wg.Done()
2759 defer func() { <-gate }()
2760 res, err := c.Get(ts.URL)
2761 if err != nil {
2762
2763
2764 t.Log(err)
2765 return
2766 }
2767 defer res.Body.Close()
2768 io.Copy(io.Discard, res.Body)
2769 }()
2770 }
2771 wg.Wait()
2772 }
2773
2774
2775 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2776 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2777 sendHi := make(chan bool, 1)
2778 writeErrors := make(chan error, 1)
2779 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2780 w.Header().Set("Content-Type", "text/plain")
2781 <-sendHi
2782 _, werr := w.Write([]byte("hi"))
2783 writeErrors <- werr
2784 })
2785 ctx, cancel := context.WithCancel(context.Background())
2786 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2787 cst := newClientServerTest(t, mode, h)
2788
2789
2790 sendHi <- true
2791 res, err := cst.c.Get(cst.ts.URL)
2792 if err != nil {
2793 t.Error(err)
2794 }
2795 if g, e := res.StatusCode, StatusOK; g != e {
2796 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2797 }
2798 body, _ := io.ReadAll(res.Body)
2799 if g, e := string(body), "hi"; g != e {
2800 t.Errorf("got body %q; expected %q", g, e)
2801 }
2802 if g := <-writeErrors; g != nil {
2803 t.Errorf("got unexpected Write error on first request: %v", g)
2804 }
2805
2806
2807 cancel()
2808
2809 res, err = cst.c.Get(cst.ts.URL)
2810 if err != nil {
2811 t.Error(err)
2812 }
2813 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2814 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2815 }
2816 body, _ = io.ReadAll(res.Body)
2817 if !strings.Contains(string(body), "<title>Timeout</title>") {
2818 t.Errorf("expected timeout body; got %q", string(body))
2819 }
2820
2821
2822
2823 sendHi <- true
2824 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2825 t.Errorf("expected Write error of %v; got %v", e, g)
2826 }
2827 }
2828
2829
2830 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2831 run(t, testTimeoutHandlerStartTimerWhenServing)
2832 }
2833 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2834 if testing.Short() {
2835 t.Skip("skipping sleeping test in -short mode")
2836 }
2837 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2838 w.WriteHeader(StatusNoContent)
2839 }
2840 timeout := 300 * time.Millisecond
2841 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2842 defer ts.Close()
2843
2844 c := ts.Client()
2845
2846
2847
2848
2849 time.Sleep(2 * timeout)
2850 res, err := c.Get(ts.URL)
2851 if err != nil {
2852 t.Fatal(err)
2853 }
2854 defer res.Body.Close()
2855 if res.StatusCode != StatusNoContent {
2856 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2857 }
2858 }
2859
2860 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2861 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2862 writeErrors := make(chan error, 1)
2863 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2864 w.Header().Set("Content-Type", "text/plain")
2865 var err error
2866
2867
2868
2869 for i := 0; i < 100; i++ {
2870 _, err = w.Write([]byte("a"))
2871 if err != nil {
2872 break
2873 }
2874 time.Sleep(1 * time.Millisecond)
2875 }
2876 writeErrors <- err
2877 })
2878 ctx, cancel := context.WithCancel(context.Background())
2879 cancel()
2880 h := NewTestTimeoutHandler(sayHi, ctx)
2881 cst := newClientServerTest(t, mode, h)
2882 defer cst.close()
2883
2884 res, err := cst.c.Get(cst.ts.URL)
2885 if err != nil {
2886 t.Error(err)
2887 }
2888 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2889 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2890 }
2891 body, _ := io.ReadAll(res.Body)
2892 if g, e := string(body), ""; g != e {
2893 t.Errorf("got body %q; expected %q", g, e)
2894 }
2895 if g, e := <-writeErrors, context.Canceled; g != e {
2896 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2897 }
2898 }
2899
2900
2901 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2902 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2903 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2904
2905 }
2906 timeout := 300 * time.Millisecond
2907 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2908
2909 c := ts.Client()
2910
2911 res, err := c.Get(ts.URL)
2912 if err != nil {
2913 t.Fatal(err)
2914 }
2915 defer res.Body.Close()
2916 if res.StatusCode != StatusOK {
2917 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2918 }
2919 }
2920
2921
2922 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2923 wrapper := func(h Handler) Handler {
2924 return TimeoutHandler(h, time.Second, "")
2925 }
2926 run(t, func(t *testing.T, mode testMode) {
2927 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2928 }, testNotParallel, http3SkippedMode)
2929 }
2930
// TestRedirectBadPath calls Redirect with a request whose URL path is
// non-empty but has no leading slash, and checks that the requested status
// code is still recorded.
func TestRedirectBadPath(t *testing.T) {
	badURL := &url.URL{
		Scheme: "http",
		Path:   "not-empty-but-no-leading-slash",
	}
	req := &Request{Method: "GET", URL: badURL}

	rr := httptest.NewRecorder()
	Redirect(rr, req, "", 304)
	if got := rr.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2947
2948 func TestRedirectEscapedPath(t *testing.T) {
2949 baseURL, redirectURL := "http://example.com/foo%2Fbar/", "qux%2Fbaz"
2950 req := httptest.NewRequest("GET", baseURL, NoBody)
2951
2952 rr := httptest.NewRecorder()
2953 Redirect(rr, req, redirectURL, StatusMovedPermanently)
2954
2955 wantURL := "/foo%2Fbar/qux%2Fbaz"
2956 if got := rr.Result().Header.Get("Location"); got != wantURL {
2957 t.Errorf("Redirect(%s, %s) = %s, want = %s", baseURL, redirectURL, got, wantURL)
2958 }
2959 }
2960
2961
// TestRedirect checks the Location header that Redirect produces for a range
// of targets, all issued against a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	type redirectCase struct {
		in   string // target handed to Redirect
		want string // expected Location header
	}
	cases := []redirectCase{
		// Targets with a scheme.
		{in: "http://foobar.com/baz", want: "http://foobar.com/baz"},
		{in: "https://foobar.com/baz", want: "https://foobar.com/baz"},
		{in: "test://foobar.com/baz", want: "test://foobar.com/baz"},
		// Target starting with two slashes.
		{in: "//foobar.com/baz", want: "//foobar.com/baz"},
		// Target starting with one slash.
		{in: "/foobar.com/baz", want: "/foobar.com/baz"},
		// Targets without a leading slash.
		{in: "foobar.com/baz", want: "/qux/foobar.com/baz"},
		{in: "../quux/foobar.com/baz", want: "/quux/foobar.com/baz"},
		// Target starting with three slashes.
		{in: "///foobar.com/baz", want: "/foobar.com/baz"},

		// Targets whose query string contains a URL.
		{in: "/foo?next=http://bar.com/", want: "/foo?next=http://bar.com/"},
		{
			in:   "http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			want: "http://localhost:8080/_ah/login?continue=http://localhost:8080/",
		},

		// Targets containing non-ASCII characters.
		{in: "/фубар", want: "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{in: "http://foo.com/фубар", want: "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.in, 302)
		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.in, got, want)
		}
		if got := rec.Header().Get("Location"); got != tc.want {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.in, got, tc.want)
		}
	}
}
3006
3007
3008
// TestRedirectContentTypeAndBody checks which Content-Type and body Redirect
// writes, depending on the request method and on whether a Content-Type
// entry was already present in the response header map.
func TestRedirectContentTypeAndBody(t *testing.T) {
	// ctHeader wraps the preset Content-Type values so that a nil *ctHeader
	// (header untouched) is distinct from a present but empty or nil entry.
	type ctHeader struct {
		Values []string
	}

	tests := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{method: MethodGet, wantCT: "text/html; charset=utf-8", wantBody: "<a href=\"/foo\">Found</a>.\n\n"},
		{method: MethodHead, wantCT: "text/html; charset=utf-8"},
		{method: MethodPost},
		{method: MethodDelete},
		{method: "foo"},
		{method: MethodGet, ct: &ctHeader{[]string{"application/test"}}, wantCT: "application/test"},
		{method: MethodGet, ct: &ctHeader{[]string{}}},
		{method: MethodGet, ct: &ctHeader{nil}},
	}

	for _, tt := range tests {
		rec := httptest.NewRecorder()
		if tt.ct != nil {
			rec.Header()["Content-Type"] = tt.ct.Values
		}
		req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)
		Redirect(rec, req, "/foo", 302)

		if got, want := rec.Code, 302; got != want {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, got, want)
		}
		if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got, want := string(body), tt.wantBody; got != want {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, want)
		}
	}
}
3052
3053
3054
3055
3056
3057
3058
3059 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
3060
3061 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
3062 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
3063 all, err := io.ReadAll(r.Body)
3064 if err != nil {
3065 t.Fatalf("handler ReadAll: %v", err)
3066 }
3067 if len(all) != 0 {
3068 t.Errorf("handler got %d bytes; expected 0", len(all))
3069 }
3070 rw.Header().Set("Content-Length", "0")
3071 }))
3072
3073 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
3074 if err != nil {
3075 t.Fatal(err)
3076 }
3077 req.ContentLength = 0
3078
3079 var resp [5]*Response
3080 for i := range resp {
3081 resp[i], err = cst.c.Do(req)
3082 if err != nil {
3083 t.Fatalf("client post #%d: %v", i, err)
3084 }
3085 }
3086
3087 for i := range resp {
3088 all, err := io.ReadAll(resp[i].Body)
3089 if err != nil {
3090 t.Fatalf("req #%d: client ReadAll: %v", i, err)
3091 }
3092 if len(all) != 0 {
3093 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
3094 }
3095 }
3096 }
3097
3098 func TestHandlerPanicNil(t *testing.T) {
3099 run(t, func(t *testing.T, mode testMode) {
3100 testHandlerPanic(t, false, mode, nil, nil)
3101 }, testNotParallel, http3SkippedMode)
3102 }
3103
3104 func TestHandlerPanic(t *testing.T) {
3105 run(t, func(t *testing.T, mode testMode) {
3106 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
3107 }, testNotParallel, http3SkippedMode)
3108 }
3109
3110 func TestHandlerPanicWithHijack(t *testing.T) {
3111
3112 run(t, func(t *testing.T, mode testMode) {
3113 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
3114 }, []testMode{http1Mode})
3115 }
3116
3117 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
3118
3119
3120
3121
3122
3123
3124
3125
3126 pr, pw := io.Pipe()
3127 defer pw.Close()
3128
3129 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
3130 if withHijack {
3131 rwc, _, err := w.(Hijacker).Hijack()
3132 if err != nil {
3133 t.Logf("unexpected error: %v", err)
3134 }
3135 defer rwc.Close()
3136 }
3137 panic(panicValue)
3138 })
3139 if wrapper != nil {
3140 handler = wrapper(handler)
3141 }
3142 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
3143 ts.Config.ErrorLog = log.New(pw, "", 0)
3144 })
3145
3146
3147 done := make(chan bool, 1)
3148 go func() {
3149 buf := make([]byte, 4<<10)
3150 _, err := pr.Read(buf)
3151 pr.Close()
3152 if err != nil && err != io.EOF {
3153 t.Error(err)
3154 }
3155 done <- true
3156 }()
3157
3158 _, err := cst.c.Get(cst.ts.URL)
3159 if err == nil {
3160 t.Logf("expected an error")
3161 }
3162
3163 if panicValue == nil {
3164 return
3165 }
3166
3167 <-done
3168 }
3169
// terrorWriter is an io.Writer that reports everything written to it as a
// test error on t.
type terrorWriter struct{ t *testing.T }

// Write reports p via t.Errorf and claims that all of p was written.
func (w terrorWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Errorf("%s", p)
	return n, nil
}
3176
3177
3178
3179 func TestServerWriteHijackZeroBytes(t *testing.T) {
3180 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3181 }
3182 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3183 done := make(chan struct{})
3184 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3185 defer close(done)
3186 w.(Flusher).Flush()
3187 conn, _, err := w.(Hijacker).Hijack()
3188 if err != nil {
3189 t.Errorf("Hijack: %v", err)
3190 return
3191 }
3192 defer conn.Close()
3193 _, err = w.Write(nil)
3194 if err != ErrHijacked {
3195 t.Errorf("Write error = %v; want ErrHijacked", err)
3196 }
3197 }), func(ts *httptest.Server) {
3198 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3199 }).ts
3200
3201 c := ts.Client()
3202 res, err := c.Get(ts.URL)
3203 if err != nil {
3204 t.Fatal(err)
3205 }
3206 res.Body.Close()
3207 <-done
3208 }
3209
3210 func TestServerNoDate(t *testing.T) {
3211 run(t, func(t *testing.T, mode testMode) {
3212 testServerNoHeader(t, mode, "Date")
3213 })
3214 }
3215
3216 func TestServerContentType(t *testing.T) {
3217 run(t, func(t *testing.T, mode testMode) {
3218 testServerNoHeader(t, mode, "Content-Type")
3219 })
3220 }
3221
3222 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3223 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3224 w.Header()[header] = nil
3225 io.WriteString(w, "<html>foo</html>")
3226 }))
3227 res, err := cst.c.Get(cst.ts.URL)
3228 if err != nil {
3229 t.Fatal(err)
3230 }
3231 res.Body.Close()
3232 if got, ok := res.Header[header]; ok {
3233 t.Fatalf("Expected no %s header; got %q", header, got)
3234 }
3235 }
3236
3237 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3238 func testStripPrefix(t *testing.T, mode testMode) {
3239 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3240 w.Header().Set("X-Path", r.URL.Path)
3241 w.Header().Set("X-RawPath", r.URL.RawPath)
3242 })
3243 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3244
3245 c := ts.Client()
3246
3247 cases := []struct {
3248 reqPath string
3249 path string
3250 rawPath string
3251 }{
3252 {"/foo/bar/qux", "/qux", ""},
3253 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3254 {"/foo%2Fbar/qux", "", ""},
3255 {"/bar", "", ""},
3256 }
3257 for _, tc := range cases {
3258 t.Run(tc.reqPath, func(t *testing.T) {
3259 res, err := c.Get(ts.URL + tc.reqPath)
3260 if err != nil {
3261 t.Fatal(err)
3262 }
3263 res.Body.Close()
3264 if tc.path == "" {
3265 if res.StatusCode != StatusNotFound {
3266 t.Errorf("got %q, want 404 Not Found", res.Status)
3267 }
3268 return
3269 }
3270 if res.StatusCode != StatusOK {
3271 t.Fatalf("got %q, want 200 OK", res.Status)
3272 }
3273 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3274 t.Errorf("got Path %q, want %q", g, w)
3275 }
3276 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3277 t.Errorf("got RawPath %q, want %q", g, w)
3278 }
3279 })
3280 }
3281 }
3282
3283
// TestStripPrefixNotModifyRequest checks that serving a request through
// StripPrefix leaves the URL path of the caller's Request unchanged.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3292
3293 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit, http3SkippedMode) }
3294 func testRequestLimit(t *testing.T, mode testMode) {
3295 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3296 t.Fatalf("didn't expect to get request in Handler")
3297 }), optQuietLog)
3298 req, _ := NewRequest("GET", cst.ts.URL, nil)
3299 var bytesPerHeader = len("header12345: val12345\r\n")
3300 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3301 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3302 }
3303 res, err := cst.c.Do(req)
3304 if res != nil {
3305 defer res.Body.Close()
3306 }
3307 if mode == http2Mode {
3308
3309
3310
3311
3312 if err == nil && res.StatusCode != 431 {
3313 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3314 }
3315 } else {
3316
3317
3318
3319
3320 if err != nil {
3321 t.Fatalf("Do: %v", err)
3322 }
3323 if res.StatusCode != 431 {
3324 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3325 }
3326 }
3327 }
3328
// neverEnding is a reader that yields its own byte value forever.
type neverEnding byte

// Read fills all of p with the byte b and never returns an error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	fill := byte(b)
	for i := 0; i < len(p); i++ {
		p[i] = fill
	}
	return len(p), nil
}
3337
// bodyLimitReader is a request body that produces 'a' bytes, counts how many
// it has handed out, and refuses to produce more once that count exceeds
// limit or after it has been closed.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int           // bytes handed out so far
	limit  int           // reads fail once count is greater than this
	closed chan struct{} // closed by Close
}

// Read fills p with 'a' and adds len(p) to the running count. It returns an
// error, without touching p, when the reader has been closed or when the
// count already exceeds the limit.
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	select {
	case <-r.closed:
		return 0, errors.New("closed")
	default:
		// Still open.
	}
	if r.count > r.limit {
		return 0, errors.New("at limit")
	}

	n := len(p)
	for i := 0; i < n; i++ {
		p[i] = 'a'
	}
	r.count += n
	return n, nil
}

// Close marks the reader as closed by closing the closed channel.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	close(r.closed)
	r.mu.Unlock()
	return nil
}
3369
3370 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3371 func testRequestBodyLimit(t *testing.T, mode testMode) {
3372 const limit = 1 << 20
3373 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3374 r.Body = MaxBytesReader(w, r.Body, limit)
3375 n, err := io.Copy(io.Discard, r.Body)
3376 if err == nil {
3377 t.Errorf("expected error from io.Copy")
3378 }
3379 if n != limit {
3380 t.Errorf("io.Copy = %d, want %d", n, limit)
3381 }
3382 mbErr, ok := err.(*MaxBytesError)
3383 if !ok {
3384 t.Errorf("expected MaxBytesError, got %T", err)
3385 }
3386 if mbErr.Limit != limit {
3387 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3388 }
3389 }))
3390
3391 body := &bodyLimitReader{
3392 closed: make(chan struct{}),
3393 limit: limit * 200,
3394 }
3395 req, _ := NewRequest("POST", cst.ts.URL, body)
3396
3397
3398
3399
3400
3401
3402
3403
3404
3405
3406 resp, err := cst.c.Do(req)
3407 if err == nil {
3408 resp.Body.Close()
3409 }
3410
3411
3412 <-body.closed
3413
3414 if body.count > limit*100 {
3415 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3416 limit, body.count)
3417 }
3418 }
3419
3420
3421
3422 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown, http3SkippedMode) }
3423 func testClientWriteShutdown(t *testing.T, mode testMode) {
3424 if runtime.GOOS == "plan9" {
3425 t.Skip("skipping test; see https://golang.org/issue/17906")
3426 }
3427 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3428 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3429 if err != nil {
3430 t.Fatalf("Dial: %v", err)
3431 }
3432 err = conn.(*net.TCPConn).CloseWrite()
3433 if err != nil {
3434 t.Fatalf("CloseWrite: %v", err)
3435 }
3436
3437 bs, err := io.ReadAll(conn)
3438 if err != nil {
3439 t.Errorf("ReadAll: %v", err)
3440 }
3441 got := string(bs)
3442 if got != "" {
3443 t.Errorf("read %q from server; want nothing", got)
3444 }
3445 }
3446
3447
3448
3449 func TestServerBufferedChunking(t *testing.T) {
3450 conn := new(testConn)
3451 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3452 conn.closec = make(chan bool, 1)
3453 ls := &oneConnListener{conn}
3454 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3455 rw.(Flusher).Flush()
3456 rw.Write([]byte{'x'})
3457 rw.Write([]byte{'y'})
3458 rw.Write([]byte{'z'})
3459 }))
3460 <-conn.closec
3461 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3462 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3463 conn.writeBuf.Bytes())
3464 }
3465 }
3466
3467
3468
3469
3470
3471 func TestServerGracefulClose(t *testing.T) {
3472
3473 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3474 }
3475 func testServerGracefulClose(t *testing.T, mode testMode) {
3476 runTimeSensitiveTest(t, []time.Duration{
3477 1 * time.Millisecond,
3478 5 * time.Millisecond,
3479 10 * time.Millisecond,
3480 50 * time.Millisecond,
3481 100 * time.Millisecond,
3482 500 * time.Millisecond,
3483 time.Second,
3484 5 * time.Second,
3485 }, func(t *testing.T, timeout time.Duration) error {
3486 SetRSTAvoidanceDelay(t, timeout)
3487 t.Logf("set RST avoidance delay to %v", timeout)
3488
3489 const bodySize = 5 << 20
3490 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3491 for i := 0; i < bodySize; i++ {
3492 req = append(req, 'x')
3493 }
3494
3495 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3496 Error(w, "bye", StatusUnauthorized)
3497 }))
3498
3499
3500 defer cst.close()
3501 ts := cst.ts
3502
3503 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3504 if err != nil {
3505 return err
3506 }
3507 writeErr := make(chan error)
3508 go func() {
3509 _, err := conn.Write(req)
3510 writeErr <- err
3511 }()
3512 defer func() {
3513 conn.Close()
3514
3515
3516
3517 <-writeErr
3518 }()
3519
3520 br := bufio.NewReader(conn)
3521 lineNum := 0
3522 for {
3523 line, err := br.ReadString('\n')
3524 if err == io.EOF {
3525 break
3526 }
3527 if err != nil {
3528 return fmt.Errorf("ReadLine: %v", err)
3529 }
3530 lineNum++
3531 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3532 t.Errorf("Response line = %q; want a 401", line)
3533 }
3534 }
3535 return nil
3536 })
3537 }
3538
3539 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3540 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3541 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3542 if r.Method != "get" {
3543 t.Errorf(`Got method %q; want "get"`, r.Method)
3544 }
3545 }))
3546 defer cst.close()
3547 req, _ := NewRequest("get", cst.ts.URL, nil)
3548 res, err := cst.c.Do(req)
3549 if err != nil {
3550 t.Error(err)
3551 return
3552 }
3553
3554 res.Body.Close()
3555 }
3556
3557
3558
3559
3560
3561 func TestContentLengthZero(t *testing.T) {
3562 run(t, testContentLengthZero, []testMode{http1Mode})
3563 }
3564 func testContentLengthZero(t *testing.T, mode testMode) {
3565 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3566
3567 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3568 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3569 if err != nil {
3570 t.Fatalf("error dialing: %v", err)
3571 }
3572 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3573 if err != nil {
3574 t.Fatalf("error writing: %v", err)
3575 }
3576 req, _ := NewRequest("GET", "/", nil)
3577 res, err := ReadResponse(bufio.NewReader(conn), req)
3578 if err != nil {
3579 t.Fatalf("error reading response: %v", err)
3580 }
3581 if te := res.TransferEncoding; len(te) > 0 {
3582 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3583 }
3584 if cl := res.ContentLength; cl != 0 {
3585 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3586 }
3587 conn.Close()
3588 }
3589 }
3590
3591 func TestCloseNotifier(t *testing.T) {
3592 run(t, testCloseNotifier, []testMode{http1Mode})
3593 }
3594 func testCloseNotifier(t *testing.T, mode testMode) {
3595 gotReq := make(chan bool, 1)
3596 sawClose := make(chan bool, 1)
3597 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3598 gotReq <- true
3599 cc := rw.(CloseNotifier).CloseNotify()
3600 <-cc
3601 sawClose <- true
3602 })).ts
3603 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3604 if err != nil {
3605 t.Fatalf("error dialing: %v", err)
3606 }
3607 diec := make(chan bool)
3608 go func() {
3609 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3610 if err != nil {
3611 t.Error(err)
3612 return
3613 }
3614 <-diec
3615 conn.Close()
3616 }()
3617 For:
3618 for {
3619 select {
3620 case <-gotReq:
3621 diec <- true
3622 case <-sawClose:
3623 break For
3624 }
3625 }
3626 ts.Close()
3627 }
3628
3629
3630
3631
3632
3633 func TestCloseNotifierPipelined(t *testing.T) {
3634 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3635 }
3636 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3637 gotReq := make(chan bool, 2)
3638 sawClose := make(chan bool, 2)
3639 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3640 gotReq <- true
3641 cc := rw.(CloseNotifier).CloseNotify()
3642 select {
3643 case <-cc:
3644 t.Error("unexpected CloseNotify")
3645 case <-time.After(100 * time.Millisecond):
3646 }
3647 sawClose <- true
3648 })).ts
3649 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3650 if err != nil {
3651 t.Fatalf("error dialing: %v", err)
3652 }
3653 diec := make(chan bool, 1)
3654 defer close(diec)
3655 go func() {
3656 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3657 _, err = io.WriteString(conn, req+req)
3658 if err != nil {
3659 t.Error(err)
3660 return
3661 }
3662 <-diec
3663 conn.Close()
3664 }()
3665 reqs := 0
3666 closes := 0
3667 for {
3668 select {
3669 case <-gotReq:
3670 reqs++
3671 if reqs > 2 {
3672 t.Fatal("too many requests")
3673 }
3674 case <-sawClose:
3675 closes++
3676 if closes > 1 {
3677 return
3678 }
3679 }
3680 }
3681 }
3682
3683 func TestCloseNotifierChanLeak(t *testing.T) {
3684 defer afterTest(t)
3685 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3686 for i := 0; i < 20; i++ {
3687 var output bytes.Buffer
3688 conn := &rwTestConn{
3689 Reader: bytes.NewReader(req),
3690 Writer: &output,
3691 closec: make(chan bool, 1),
3692 }
3693 ln := &oneConnListener{conn: conn}
3694 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3695
3696
3697
3698 _ = rw.(CloseNotifier).CloseNotify()
3699 })
3700 go Serve(ln, handler)
3701 <-conn.closec
3702 }
3703 }
3704
3705
3706
3707
3708
3709
3710
3711
3712
3713
3714 func TestHijackAfterCloseNotifier(t *testing.T) {
3715 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3716 }
3717 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3718 script := make(chan string, 2)
3719 script <- "closenotify"
3720 script <- "hijack"
3721 close(script)
3722 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3723 plan := <-script
3724 switch plan {
3725 default:
3726 panic("bogus plan; too many requests")
3727 case "closenotify":
3728 w.(CloseNotifier).CloseNotify()
3729 w.Header().Set("X-Addr", r.RemoteAddr)
3730 case "hijack":
3731 c, _, err := w.(Hijacker).Hijack()
3732 if err != nil {
3733 t.Errorf("Hijack in Handler: %v", err)
3734 return
3735 }
3736 if _, ok := c.(*net.TCPConn); !ok {
3737
3738
3739 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3740 }
3741 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3742 c.Close()
3743 return
3744 }
3745 })).ts
3746 res1, err := ts.Client().Get(ts.URL)
3747 if err != nil {
3748 log.Fatal(err)
3749 }
3750 res2, err := ts.Client().Get(ts.URL)
3751 if err != nil {
3752 log.Fatal(err)
3753 }
3754 addr1 := res1.Header.Get("X-Addr")
3755 addr2 := res2.Header.Get("X-Addr")
3756 if addr1 == "" || addr1 != addr2 {
3757 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3758 }
3759 }
3760
3761 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3762 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3763 }
3764 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3765 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3766 bodyOkay := make(chan bool, 1)
3767 gotCloseNotify := make(chan bool, 1)
3768 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3769 defer close(bodyOkay)
3770
3771 reqBody := r.Body
3772 r.Body = nil
3773
3774 gone := w.(CloseNotifier).CloseNotify()
3775 slurp, err := io.ReadAll(reqBody)
3776 if err != nil {
3777 t.Errorf("Body read: %v", err)
3778 return
3779 }
3780 if len(slurp) != len(requestBody) {
3781 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3782 return
3783 }
3784 if !bytes.Equal(slurp, requestBody) {
3785 t.Error("Backend read wrong request body.")
3786 return
3787 }
3788 bodyOkay <- true
3789 <-gone
3790 gotCloseNotify <- true
3791 })).ts
3792
3793 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3794 if err != nil {
3795 t.Fatal(err)
3796 }
3797 defer conn.Close()
3798
3799 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3800 len(requestBody), requestBody)
3801 if !<-bodyOkay {
3802
3803 return
3804 }
3805 conn.Close()
3806 <-gotCloseNotify
3807 }
3808
3809 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3810 func testOptions(t *testing.T, mode testMode) {
3811 uric := make(chan string, 2)
3812 mux := NewServeMux()
3813 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3814 uric <- r.RequestURI
3815 })
3816 ts := newClientServerTest(t, mode, mux).ts
3817
3818 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3819 if err != nil {
3820 t.Fatal(err)
3821 }
3822 defer conn.Close()
3823
3824
3825 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3826 if err != nil {
3827 t.Fatal(err)
3828 }
3829 br := bufio.NewReader(conn)
3830 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3831 if err != nil {
3832 t.Fatal(err)
3833 }
3834 if res.StatusCode != 200 {
3835 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3836 }
3837
3838
3839 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3840 if err != nil {
3841 t.Fatal(err)
3842 }
3843 res, err = ReadResponse(br, &Request{Method: "GET"})
3844 if err != nil {
3845 t.Fatal(err)
3846 }
3847 if res.StatusCode != 400 {
3848 t.Errorf("Got non-400 response to GET *: %#v", res)
3849 }
3850
3851 res, err = Get(ts.URL + "/second")
3852 if err != nil {
3853 t.Fatal(err)
3854 }
3855 res.Body.Close()
3856 if got := <-uric; got != "/second" {
3857 t.Errorf("Handler saw request for %q; want /second", got)
3858 }
3859 }
3860
3861 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3862 func testOptionsHandler(t *testing.T, mode testMode) {
3863 rc := make(chan *Request, 1)
3864
3865 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3866 rc <- r
3867 }), func(ts *httptest.Server) {
3868 ts.Config.DisableGeneralOptionsHandler = true
3869 }).ts
3870
3871 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3872 if err != nil {
3873 t.Fatal(err)
3874 }
3875 defer conn.Close()
3876
3877 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3878 if err != nil {
3879 t.Fatal(err)
3880 }
3881
3882 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3883 t.Errorf("Expected OPTIONS * request, got %v", got)
3884 }
3885 }
3886
3887
3888
3889
3890
3891
3892
3893
3894
3895
3896 func TestHeaderToWire(t *testing.T) {
3897 tests := []struct {
3898 name string
3899 handler func(ResponseWriter, *Request)
3900 check func(got, logs string) error
3901 }{
3902 {
3903 name: "write without Header",
3904 handler: func(rw ResponseWriter, r *Request) {
3905 rw.Write([]byte("hello world"))
3906 },
3907 check: func(got, logs string) error {
3908 if !strings.Contains(got, "Content-Length:") {
3909 return errors.New("no content-length")
3910 }
3911 if !strings.Contains(got, "Content-Type: text/plain") {
3912 return errors.New("no content-type")
3913 }
3914 return nil
3915 },
3916 },
3917 {
3918 name: "Header mutation before write",
3919 handler: func(rw ResponseWriter, r *Request) {
3920 h := rw.Header()
3921 h.Set("Content-Type", "some/type")
3922 rw.Write([]byte("hello world"))
3923 h.Set("Too-Late", "bogus")
3924 },
3925 check: func(got, logs string) error {
3926 if !strings.Contains(got, "Content-Length:") {
3927 return errors.New("no content-length")
3928 }
3929 if !strings.Contains(got, "Content-Type: some/type") {
3930 return errors.New("wrong content-type")
3931 }
3932 if strings.Contains(got, "Too-Late") {
3933 return errors.New("don't want too-late header")
3934 }
3935 return nil
3936 },
3937 },
3938 {
3939 name: "write then useless Header mutation",
3940 handler: func(rw ResponseWriter, r *Request) {
3941 rw.Write([]byte("hello world"))
3942 rw.Header().Set("Too-Late", "Write already wrote headers")
3943 },
3944 check: func(got, logs string) error {
3945 if strings.Contains(got, "Too-Late") {
3946 return errors.New("header appeared from after WriteHeader")
3947 }
3948 return nil
3949 },
3950 },
3951 {
3952 name: "flush then write",
3953 handler: func(rw ResponseWriter, r *Request) {
3954 rw.(Flusher).Flush()
3955 rw.Write([]byte("post-flush"))
3956 rw.Header().Set("Too-Late", "Write already wrote headers")
3957 },
3958 check: func(got, logs string) error {
3959 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3960 return errors.New("not chunked")
3961 }
3962 if strings.Contains(got, "Too-Late") {
3963 return errors.New("header appeared from after WriteHeader")
3964 }
3965 return nil
3966 },
3967 },
3968 {
3969 name: "header then flush",
3970 handler: func(rw ResponseWriter, r *Request) {
3971 rw.Header().Set("Content-Type", "some/type")
3972 rw.(Flusher).Flush()
3973 rw.Write([]byte("post-flush"))
3974 rw.Header().Set("Too-Late", "Write already wrote headers")
3975 },
3976 check: func(got, logs string) error {
3977 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3978 return errors.New("not chunked")
3979 }
3980 if strings.Contains(got, "Too-Late") {
3981 return errors.New("header appeared from after WriteHeader")
3982 }
3983 if !strings.Contains(got, "Content-Type: some/type") {
3984 return errors.New("wrong content-type")
3985 }
3986 return nil
3987 },
3988 },
3989 {
3990 name: "sniff-on-first-write content-type",
3991 handler: func(rw ResponseWriter, r *Request) {
3992 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3993 rw.Header().Set("Content-Type", "x/wrong")
3994 },
3995 check: func(got, logs string) error {
3996 if !strings.Contains(got, "Content-Type: text/html") {
3997 return errors.New("wrong content-type; want html")
3998 }
3999 return nil
4000 },
4001 },
4002 {
4003 name: "explicit content-type wins",
4004 handler: func(rw ResponseWriter, r *Request) {
4005 rw.Header().Set("Content-Type", "some/type")
4006 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
4007 },
4008 check: func(got, logs string) error {
4009 if !strings.Contains(got, "Content-Type: some/type") {
4010 return errors.New("wrong content-type; want html")
4011 }
4012 return nil
4013 },
4014 },
4015 {
4016 name: "empty handler",
4017 handler: func(rw ResponseWriter, r *Request) {
4018 },
4019 check: func(got, logs string) error {
4020 if !strings.Contains(got, "Content-Length: 0") {
4021 return errors.New("want 0 content-length")
4022 }
4023 return nil
4024 },
4025 },
4026 {
4027 name: "only Header, no write",
4028 handler: func(rw ResponseWriter, r *Request) {
4029 rw.Header().Set("Some-Header", "some-value")
4030 },
4031 check: func(got, logs string) error {
4032 if !strings.Contains(got, "Some-Header") {
4033 return errors.New("didn't get header")
4034 }
4035 return nil
4036 },
4037 },
4038 {
4039 name: "WriteHeader call",
4040 handler: func(rw ResponseWriter, r *Request) {
4041 rw.WriteHeader(404)
4042 rw.Header().Set("Too-Late", "some-value")
4043 },
4044 check: func(got, logs string) error {
4045 if !strings.Contains(got, "404") {
4046 return errors.New("wrong status")
4047 }
4048 if strings.Contains(got, "Too-Late") {
4049 return errors.New("shouldn't have seen Too-Late")
4050 }
4051 return nil
4052 },
4053 },
4054 }
4055 for _, tc := range tests {
4056 ht := newHandlerTest(HandlerFunc(tc.handler))
4057 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
4058 logs := ht.logbuf.String()
4059 if err := tc.check(got, logs); err != nil {
4060 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
4061 }
4062 }
4063 }
4064
// errorListener is a listener stub that never yields a connection. Each
// Accept call reports the next queued error; once the queue is empty it
// reports io.EOF.
type errorListener struct {
	errs []error // errors still to be returned by Accept, in order
}

// Accept pops and returns the next queued error, or io.EOF when none is
// left. The returned connection is always nil.
func (l *errorListener) Accept() (net.Conn, error) {
	if len(l.errs) > 0 {
		next := l.errs[0]
		l.errs = l.errs[1:]
		return nil, next
	}
	return nil, io.EOF
}

// Close does nothing and reports success.
func (l *errorListener) Close() error {
	return nil
}
4081
4082 func (l *errorListener) Addr() net.Addr {
4083 return dummyAddr("test-address")
4084 }
4085
4086 func TestAcceptMaxFds(t *testing.T) {
4087 setParallel(t)
4088
4089 ln := &errorListener{[]error{
4090 &net.OpError{
4091 Op: "accept",
4092 Err: syscall.EMFILE,
4093 }}}
4094 server := &Server{
4095 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
4096 ErrorLog: log.New(io.Discard, "", 0),
4097 }
4098 err := server.Serve(ln)
4099 if err != io.EOF {
4100 t.Errorf("got error %v, want EOF", err)
4101 }
4102 }
4103
4104 func TestWriteAfterHijack(t *testing.T) {
4105 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4106 var buf strings.Builder
4107 wrotec := make(chan bool, 1)
4108 conn := &rwTestConn{
4109 Reader: bytes.NewReader(req),
4110 Writer: &buf,
4111 closec: make(chan bool, 1),
4112 }
4113 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4114 conn, bufrw, err := rw.(Hijacker).Hijack()
4115 if err != nil {
4116 t.Error(err)
4117 return
4118 }
4119 go func() {
4120 bufrw.Write([]byte("[hijack-to-bufw]"))
4121 bufrw.Flush()
4122 conn.Write([]byte("[hijack-to-conn]"))
4123 conn.Close()
4124 wrotec <- true
4125 }()
4126 })
4127 ln := &oneConnListener{conn: conn}
4128 go Serve(ln, handler)
4129 <-conn.closec
4130 <-wrotec
4131 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4132 t.Errorf("wrote %q; want %q", g, w)
4133 }
4134 }
4135
4136 func TestDoubleHijack(t *testing.T) {
4137 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4138 var buf bytes.Buffer
4139 conn := &rwTestConn{
4140 Reader: bytes.NewReader(req),
4141 Writer: &buf,
4142 closec: make(chan bool, 1),
4143 }
4144 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4145 conn, _, err := rw.(Hijacker).Hijack()
4146 if err != nil {
4147 t.Error(err)
4148 return
4149 }
4150 _, _, err = rw.(Hijacker).Hijack()
4151 if err == nil {
4152 t.Errorf("got err = nil; want err != nil")
4153 }
4154 conn.Close()
4155 })
4156 ln := &oneConnListener{conn: conn}
4157 go Serve(ln, handler)
4158 <-conn.closec
4159 }
4160
4161
4162
4163
4164
4165
4166
4167 func TestHTTP10ConnectionHeader(t *testing.T) {
4168 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4169 }
4170 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4171 mux := NewServeMux()
4172 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4173 ts := newClientServerTest(t, mode, mux).ts
4174
4175
4176 tests := []struct {
4177 req string
4178 expect []string
4179 }{
4180 {
4181 req: "GET / HTTP/1.0\r\n\r\n",
4182 expect: nil,
4183 },
4184 {
4185 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4186 expect: nil,
4187 },
4188 {
4189 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4190 expect: []string{"keep-alive"},
4191 },
4192 }
4193
4194 for _, tt := range tests {
4195 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4196 if err != nil {
4197 t.Fatal("dial err:", err)
4198 }
4199
4200 _, err = fmt.Fprint(conn, tt.req)
4201 if err != nil {
4202 t.Fatal("conn write err:", err)
4203 }
4204
4205 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4206 if err != nil {
4207 t.Fatal("ReadResponse err:", err)
4208 }
4209 conn.Close()
4210 resp.Body.Close()
4211
4212 got := resp.Header["Connection"]
4213 if !slices.Equal(got, tt.expect) {
4214 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4215 }
4216 }
4217 }
4218
4219
4220 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4221 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4222 pr, pw := io.Pipe()
4223 const size = 3 << 20
4224 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4225 rw.Header().Set("Content-Type", "text/plain")
4226 done := make(chan bool)
4227 go func() {
4228 io.Copy(rw, pr)
4229 close(done)
4230 }()
4231 time.Sleep(25 * time.Millisecond)
4232 n, err := io.Copy(io.Discard, req.Body)
4233 if err != nil {
4234 t.Errorf("handler Copy: %v", err)
4235 return
4236 }
4237 if n != size {
4238 t.Errorf("handler Copy = %d; want %d", n, size)
4239 }
4240 pw.Write([]byte("hi"))
4241 pw.Close()
4242 <-done
4243 }))
4244
4245 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4246 if err != nil {
4247 t.Fatal(err)
4248 }
4249 res, err := cst.c.Do(req)
4250 if err != nil {
4251 t.Fatal(err)
4252 }
4253 all, err := io.ReadAll(res.Body)
4254 if err != nil {
4255 t.Fatal(err)
4256 }
4257 res.Body.Close()
4258 if string(all) != "hi" {
4259 t.Errorf("Body = %q; want hi", all)
4260 }
4261 }
4262
4263
4264 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4265 for _, code := range []int{StatusNotModified, StatusNoContent} {
4266 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4267 if r.URL.Path == "/header" {
4268 w.Header().Set("Content-Length", "123")
4269 }
4270 w.WriteHeader(code)
4271 if r.URL.Path == "/more" {
4272 w.Write([]byte("stuff"))
4273 }
4274 }))
4275 for _, req := range []string{
4276 "GET / HTTP/1.0",
4277 "GET /header HTTP/1.0",
4278 "GET /more HTTP/1.0",
4279 "GET / HTTP/1.1\nHost: foo",
4280 "GET /header HTTP/1.1\nHost: foo",
4281 "GET /more HTTP/1.1\nHost: foo",
4282 } {
4283 got := ht.rawResponse(req)
4284 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4285 if !strings.Contains(got, wantStatus) {
4286 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4287 } else if strings.Contains(got, "Content-Length") {
4288 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4289 } else if strings.Contains(got, "stuff") {
4290 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4291 }
4292 }
4293 }
4294 }
4295
4296 func TestContentTypeOkayOn204(t *testing.T) {
4297 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4298 w.Header().Set("Content-Length", "123")
4299 w.Header().Set("Content-Type", "foo/bar")
4300 w.WriteHeader(204)
4301 }))
4302 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4303 if !strings.Contains(got, "Content-Type: foo/bar") {
4304 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4305 }
4306 if strings.Contains(got, "Content-Length: 123") {
4307 t.Errorf("Response = %q; don't want a Content-Length", got)
4308 }
4309 }
4310
4311
4312
4313
4314
4315
4316
// TestTransportAndServerSharedBodyRace exercises a proxy-style handler that
// passes its incoming request body straight to an outbound request, so the
// same body is used by both the server and the transport.
// NOTE(review): the name suggests the point is to surface data races on that
// shared body when run under the race detector — confirm.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel, http3SkippedMode)
}
// testTransportAndServerSharedBodyRace is the body of the test. It is run
// through runTimeSensitiveTest with a list of candidate delays; each attempt
// installs the delay with SetRSTAvoidanceDelay and returns a non-nil error
// if the client's request to the proxy fails.
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg counts backend handler invocations that are still running.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Add is called from inside the handler, so it can run
			// concurrently with the wg.Wait in the deferred cleanup below.
			// NOTE(review): presumably safe because the client request has
			// completed by the time Wait runs — confirm.
			wg.Add(1)
			defer wg.Done()

			// Echo the request body back, then stay in the handler until
			// the request's context is done.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))
		// Wait for the backend handler to return before closing the
		// backend.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		// proxy is declared before it is assigned because its handler
		// refers to proxy.c.
		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Forward the incoming body as the outbound request's body.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half of the backend's response.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request while the rest of its response
			// is still unread: through the Cancel channel in HTTP/2 mode,
			// through Transport.CancelRequest otherwise.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4406
4407
4408
4409
4410 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4411 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4412 }
4413 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4414 if testing.Short() {
4415 t.Skip("skipping in -short mode")
4416 }
4417
4418 readErrCh := make(chan error, 1)
4419 errCh := make(chan error, 2)
4420
4421 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4422 go func(body io.Reader) {
4423 _, err := body.Read(make([]byte, 100))
4424 readErrCh <- err
4425 }(req.Body)
4426 time.Sleep(500 * time.Millisecond)
4427 })).ts
4428
4429 closeConn := make(chan bool)
4430 defer close(closeConn)
4431 go func() {
4432 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4433 if err != nil {
4434 errCh <- err
4435 return
4436 }
4437 defer conn.Close()
4438 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4439 if err != nil {
4440 errCh <- err
4441 return
4442 }
4443
4444
4445 <-closeConn
4446 }()
4447 select {
4448 case err := <-readErrCh:
4449 if err == nil {
4450 t.Error("Read was nil. Expected error.")
4451 }
4452 case err := <-errCh:
4453 t.Error(err)
4454 }
4455 }
4456
4457
4458 func TestResponseWriterWriteString(t *testing.T) {
4459 okc := make(chan bool, 1)
4460 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4461 _, ok := w.(io.StringWriter)
4462 okc <- ok
4463 }))
4464 ht.rawResponse("GET / HTTP/1.0")
4465 select {
4466 case ok := <-okc:
4467 if !ok {
4468 t.Error("ResponseWriter did not implement io.StringWriter")
4469 }
4470 default:
4471 t.Error("handler was never called")
4472 }
4473 }
4474
// TestServerConnState checks the sequence of ConnState values the server
// reports through Server.ConnState for a range of connection lifecycles.
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
// testServerConnState drives one scenario at a time and compares the states
// the ConnState hook recorded for that scenario's connection with the
// expected sequence.
func testServerConnState(t *testing.T, mode testMode) {
	// handler maps request paths to the behavior under test.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// A stateLog records the states seen for one connection together with
	// the states expected for it.
	type stateLog struct {
		active   net.Conn         // connection being tracked; set on its StateNew
		got      []ConnState      // states observed so far
		want     []ConnState      // states expected
		complete chan<- struct{}  // closed once got is long enough or has diverged from want
	}
	// activeLog holds at most one stateLog. Whoever receives the log from
	// the channel has exclusive use of it until sending it back.
	activeLog := make(chan *stateLog, 1)

	// wantLog publishes a fresh log expecting want, runs doRequests, and
	// waits until the ConnState hook marks the log complete, then compares
	// the recorded states with want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot:  %v\nWant: %v", sl.got, sl.want)
		}
		// sl is deliberately not sent back to activeLog, so any later
		// ConnState callback blocks until the next wantLog call (or the
		// deferred cleanup) publishes a new log.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			// Take the current log; this blocks while no log is published.
			sl := <-activeLog
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion when enough states have been collected or
			// as soon as the observed prefix differs from want. The length
			// test comes first, so the slice expression is only evaluated
			// when len(got) < len(want).
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	defer func() {
		// Publish an empty log so that a ConnState callback firing during
		// shutdown can receive from activeLog instead of blocking.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet fetches url with optional header key/value pairs and reads
	// the whole response body.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		// headers is consumed two elements at a time: name, then value.
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Two requests on one connection; the handler asks for it to be closed
	// on the second.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Two requests on one connection; the client asks for it to be closed
	// on the second.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// A hijacked connection is expected to end in StateHijacked.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// Same as above, with the handler panicking after the hijack.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed by the client before sending anything.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A connection carrying a malformed request.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		// Block until the server sends something or closes the connection.
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One complete request, after which the client closes the idle
	// connection.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4637
4638 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4639 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4640 }
4641 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4642 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4643 }), func(ts *httptest.Server) {
4644 ts.Config.SetKeepAlivesEnabled(false)
4645 }).ts
4646 res, err := ts.Client().Get(ts.URL)
4647 if err != nil {
4648 t.Fatal(err)
4649 }
4650 defer res.Body.Close()
4651 if !res.Close {
4652 t.Errorf("Body.Close == false; want true")
4653 }
4654 }
4655
4656
4657 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4658 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4659 var n int32
4660 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4661 atomic.AddInt32(&n, 1)
4662 }), optQuietLog)
4663 var wg sync.WaitGroup
4664 const reqs = 20
4665 for i := 0; i < reqs; i++ {
4666 wg.Add(1)
4667 go func() {
4668 defer wg.Done()
4669 res, err := cst.c.Get(cst.ts.URL)
4670 if err != nil {
4671
4672
4673 time.Sleep(10 * time.Millisecond)
4674 res, err = cst.c.Get(cst.ts.URL)
4675 if err != nil {
4676 t.Error(err)
4677 return
4678 }
4679 }
4680 defer res.Body.Close()
4681 _, err = io.Copy(io.Discard, res.Body)
4682 if err != nil {
4683 t.Error(err)
4684 return
4685 }
4686 }()
4687 }
4688 wg.Wait()
4689 if got := atomic.LoadInt32(&n); got != reqs {
4690 t.Errorf("handler ran %d times; want %d", got, reqs)
4691 }
4692 }
4693
4694 func TestServerConnStateNew(t *testing.T) {
4695 sawNew := false
4696 srv := &Server{
4697 ConnState: func(c net.Conn, state ConnState) {
4698 if state == StateNew {
4699 sawNew = true
4700 }
4701 },
4702 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4703 }
4704 srv.Serve(&oneConnListener{
4705 conn: &rwTestConn{
4706 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4707 Writer: io.Discard,
4708 },
4709 })
4710 if !sawNew {
4711 t.Error("StateNew not seen")
4712 }
4713 }
4714
4715 type closeWriteTestConn struct {
4716 rwTestConn
4717 didCloseWrite bool
4718 }
4719
4720 func (c *closeWriteTestConn) CloseWrite() error {
4721 c.didCloseWrite = true
4722 return nil
4723 }
4724
4725 func TestCloseWrite(t *testing.T) {
4726 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4727
4728 var srv Server
4729 var testConn closeWriteTestConn
4730 c := ExportServerNewConn(&srv, &testConn)
4731 ExportCloseWriteAndWait(c)
4732 if !testConn.didCloseWrite {
4733 t.Error("didn't see CloseWrite call")
4734 }
4735 }
4736
4737
4738
4739
4740
4741
4742
4743
4744 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4745 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4746 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4747 io.WriteString(w, "Hello, ")
4748 w.(Flusher).Flush()
4749 conn, buf, _ := w.(Hijacker).Hijack()
4750 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4751 if err := buf.Flush(); err != nil {
4752 t.Error(err)
4753 }
4754 if err := conn.Close(); err != nil {
4755 t.Error(err)
4756 }
4757 })).ts
4758 res, err := Get(ts.URL)
4759 if err != nil {
4760 t.Fatal(err)
4761 }
4762 defer res.Body.Close()
4763 all, err := io.ReadAll(res.Body)
4764 if err != nil {
4765 t.Fatal(err)
4766 }
4767 if want := "Hello, world!"; string(all) != want {
4768 t.Errorf("Got %q; want %q", all, want)
4769 }
4770 }
4771
4772
4773
4774
4775
4776
4777
4778 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4779 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4780 }
4781 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4782 if testing.Short() {
4783 t.Skip("skipping in -short mode")
4784 }
4785 const numReq = 3
4786 addrc := make(chan string, numReq)
4787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4788 addrc <- r.RemoteAddr
4789 time.Sleep(500 * time.Millisecond)
4790 w.(Flusher).Flush()
4791 }), func(ts *httptest.Server) {
4792 ts.Config.WriteTimeout = 250 * time.Millisecond
4793 }).ts
4794
4795 errc := make(chan error, numReq)
4796 go func() {
4797 defer close(errc)
4798 for i := 0; i < numReq; i++ {
4799 res, err := Get(ts.URL)
4800 if res != nil {
4801 res.Body.Close()
4802 }
4803 errc <- err
4804 }
4805 }()
4806
4807 addrSeen := map[string]bool{}
4808 numOkay := 0
4809 for {
4810 select {
4811 case v := <-addrc:
4812 addrSeen[v] = true
4813 case err, ok := <-errc:
4814 if !ok {
4815 if len(addrSeen) != numReq {
4816 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4817 }
4818 if numOkay != 0 {
4819 t.Errorf("got %d successful client requests; want 0", numOkay)
4820 }
4821 return
4822 }
4823 if err == nil {
4824 numOkay++
4825 }
4826 }
4827 }
4828 }
4829
4830
4831
4832 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4833 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4834 }
4835 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4836 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4837 w.Header().Set("Transfer-Encoding", "foo")
4838 io.WriteString(w, "<html>")
4839 })).ts
4840 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4841 if err != nil {
4842 t.Fatalf("Dial: %v", err)
4843 }
4844 defer c.Close()
4845 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4846 t.Fatal(err)
4847 }
4848 bs := bufio.NewScanner(c)
4849 var got strings.Builder
4850 for bs.Scan() {
4851 if strings.TrimSpace(bs.Text()) == "" {
4852 break
4853 }
4854 got.WriteString(bs.Text())
4855 got.WriteByte('\n')
4856 }
4857 if err := bs.Err(); err != nil {
4858 t.Fatal(err)
4859 }
4860 if strings.Contains(got.String(), "Content-Length") {
4861 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4862 }
4863 if strings.Contains(got.String(), "Content-Type") {
4864 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4865 }
4866 }
4867
4868
4869
4870 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4871 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4872 "\r\n\r\n" +
4873 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4874 var buf bytes.Buffer
4875 conn := &rwTestConn{
4876 Reader: bytes.NewReader(req),
4877 Writer: &buf,
4878 closec: make(chan bool, 1),
4879 }
4880 ln := &oneConnListener{conn: conn}
4881 numReq := 0
4882 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4883 numReq++
4884 }))
4885 <-conn.closec
4886 if numReq != 2 {
4887 t.Errorf("num requests = %d; want 2", numReq)
4888 t.Logf("Res: %s", buf.Bytes())
4889 }
4890 }
4891
4892 func TestIssue13893_Expect100(t *testing.T) {
4893
4894 req := reqBytes(`PUT /readbody HTTP/1.1
4895 User-Agent: PycURL/7.22.0
4896 Host: 127.0.0.1:9000
4897 Accept: */*
4898 Expect: 100-continue
4899 Content-Length: 10
4900
4901 HelloWorld
4902
4903 `)
4904 var buf bytes.Buffer
4905 conn := &rwTestConn{
4906 Reader: bytes.NewReader(req),
4907 Writer: &buf,
4908 closec: make(chan bool, 1),
4909 }
4910 ln := &oneConnListener{conn: conn}
4911 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4912 if _, ok := r.Header["Expect"]; !ok {
4913 t.Error("Expect header should not be filtered out")
4914 }
4915 }))
4916 <-conn.closec
4917 }
4918
4919 func TestIssue11549_Expect100(t *testing.T) {
4920 req := reqBytes(`PUT /readbody HTTP/1.1
4921 User-Agent: PycURL/7.22.0
4922 Host: 127.0.0.1:9000
4923 Accept: */*
4924 Expect: 100-continue
4925 Content-Length: 10
4926
4927 HelloWorldPUT /noreadbody HTTP/1.1
4928 User-Agent: PycURL/7.22.0
4929 Host: 127.0.0.1:9000
4930 Accept: */*
4931 Expect: 100-continue
4932 Content-Length: 10
4933
4934 GET /should-be-ignored HTTP/1.1
4935 Host: foo
4936
4937 `)
4938 var buf strings.Builder
4939 conn := &rwTestConn{
4940 Reader: bytes.NewReader(req),
4941 Writer: &buf,
4942 closec: make(chan bool, 1),
4943 }
4944 ln := &oneConnListener{conn: conn}
4945 numReq := 0
4946 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4947 numReq++
4948 if r.URL.Path == "/readbody" {
4949 io.ReadAll(r.Body)
4950 }
4951 io.WriteString(w, "Hello world!")
4952 }))
4953 <-conn.closec
4954 if numReq != 2 {
4955 t.Errorf("num requests = %d; want 2", numReq)
4956 }
4957 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4958 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4959 }
4960 }
4961
4962
4963
4964 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4965 setParallel(t)
4966 conn := newTestConn()
4967 conn.readBuf.WriteString(
4968 "POST / HTTP/1.1\r\n" +
4969 "Host: test\r\n" +
4970 "Content-Length: 9999999999\r\n" +
4971 "\r\n" + strings.Repeat("a", 1<<20))
4972
4973 ls := &oneConnListener{conn}
4974 var inHandlerLen int
4975 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4976 inHandlerLen = conn.readBuf.Len()
4977 rw.WriteHeader(404)
4978 }))
4979 <-conn.closec
4980 afterHandlerLen := conn.readBuf.Len()
4981
4982 if afterHandlerLen != inHandlerLen {
4983 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4984 }
4985 }
4986
4987 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4988 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4989 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4990 r.Body = nil
4991 fmt.Fprintf(w, "%v", r.RemoteAddr)
4992 }))
4993 get := func() string {
4994 res, err := cst.c.Get(cst.ts.URL)
4995 if err != nil {
4996 t.Fatal(err)
4997 }
4998 defer res.Body.Close()
4999 slurp, err := io.ReadAll(res.Body)
5000 if err != nil {
5001 t.Fatal(err)
5002 }
5003 return string(slurp)
5004 }
5005 a, b := get(), get()
5006 if a != b {
5007 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
5008 }
5009 }
5010
5011
5012
5013 func TestServerValidatesHostHeader(t *testing.T) {
5014 tests := []struct {
5015 proto string
5016 host string
5017 want int
5018 }{
5019 {"HTTP/0.9", "", 505},
5020
5021 {"HTTP/1.1", "", 400},
5022 {"HTTP/1.1", "Host: \r\n", 200},
5023 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
5024 {"HTTP/1.1", "Host: foo.com\r\n", 200},
5025 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
5026 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
5027 {"HTTP/1.1", "Host: ::1\r\n", 200},
5028 {"HTTP/1.1", "Host: [::1]\r\n", 200},
5029 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
5030 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
5031 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
5032 {"HTTP/1.1", "Host: \x06\r\n", 400},
5033 {"HTTP/1.1", "Host: \xff\r\n", 400},
5034 {"HTTP/1.1", "Host: {\r\n", 400},
5035 {"HTTP/1.1", "Host: }\r\n", 400},
5036 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
5037
5038
5039
5040 {"HTTP/1.0", "", 200},
5041 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
5042 {"HTTP/1.0", "Host: \xff\r\n", 400},
5043
5044
5045 {"PRI * HTTP/2.0", "", 200},
5046
5047
5048 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
5049
5050
5051 {"PRI / HTTP/2.0", "", 505},
5052 {"GET / HTTP/2.0", "", 505},
5053 {"GET / HTTP/3.0", "", 505},
5054 }
5055 for _, tt := range tests {
5056 conn := newTestConn()
5057 methodTarget := "GET / "
5058 if !strings.HasPrefix(tt.proto, "HTTP/") {
5059 methodTarget = ""
5060 }
5061 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
5062
5063 ln := &oneConnListener{conn}
5064 srv := Server{
5065 ErrorLog: quietLog,
5066 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5067 }
5068 go srv.Serve(ln)
5069 <-conn.closec
5070 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5071 if err != nil {
5072 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
5073 continue
5074 }
5075 if res.StatusCode != tt.want {
5076 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
5077 }
5078 }
5079 }
5080
5081 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
5082 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
5083 }
5084 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
5085 const upgradeResponse = "upgrade here"
5086 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5087 conn, br, err := w.(Hijacker).Hijack()
5088 if err != nil {
5089 t.Error(err)
5090 return
5091 }
5092 defer conn.Close()
5093 if r.Method != "PRI" || r.RequestURI != "*" {
5094 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
5095 return
5096 }
5097 if !r.Close {
5098 t.Errorf("Request.Close = true; want false")
5099 }
5100 const want = "SM\r\n\r\n"
5101 buf := make([]byte, len(want))
5102 n, err := io.ReadFull(br, buf)
5103 if err != nil || string(buf[:n]) != want {
5104 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
5105 return
5106 }
5107 io.WriteString(conn, upgradeResponse)
5108 })).ts
5109
5110 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5111 if err != nil {
5112 t.Fatalf("Dial: %v", err)
5113 }
5114 defer c.Close()
5115 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
5116 slurp, err := io.ReadAll(c)
5117 if err != nil {
5118 t.Fatal(err)
5119 }
5120 if string(slurp) != upgradeResponse {
5121 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
5122 }
5123 }
5124
5125
5126
5127 func TestServerValidatesHeaders(t *testing.T) {
5128 setParallel(t)
5129 tests := []struct {
5130 header string
5131 want int
5132 }{
5133 {"", 200},
5134 {"Foo: bar\r\n", 200},
5135 {"X-Foo: bar\r\n", 200},
5136 {"Foo: a space\r\n", 200},
5137
5138 {"A space: foo\r\n", 400},
5139 {"foo\xffbar: foo\r\n", 400},
5140 {"foo\x00bar: foo\r\n", 400},
5141 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
5142
5143
5144 {"Foo : bar\r\n", 400},
5145 {"Foo\t: bar\r\n", 400},
5146
5147
5148
5149 {": empty key\r\n", 400},
5150
5151
5152
5153
5154 {"Content-Length: notdigits\r\n", 400},
5155 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
5156
5157 {"foo: foo foo\r\n", 200},
5158 {"foo: foo\tfoo\r\n", 200},
5159 {"foo: foo\x00foo\r\n", 400},
5160 {"foo: foo\x7ffoo\r\n", 400},
5161 {"foo: foo\xfffoo\r\n", 200},
5162 }
5163 for _, tt := range tests {
5164 conn := newTestConn()
5165 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
5166
5167 ln := &oneConnListener{conn}
5168 srv := Server{
5169 ErrorLog: quietLog,
5170 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5171 }
5172 go srv.Serve(ln)
5173 <-conn.closec
5174 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5175 if err != nil {
5176 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
5177 continue
5178 }
5179 if res.StatusCode != tt.want {
5180 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
5181 }
5182 }
5183 }
5184
5185 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
5186 run(t, testServerRequestContextCancel_ServeHTTPDone, http3SkippedMode)
5187 }
5188 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5189 ctxc := make(chan context.Context, 1)
5190 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5191 ctx := r.Context()
5192 select {
5193 case <-ctx.Done():
5194 t.Error("should not be Done in ServeHTTP")
5195 default:
5196 }
5197 ctxc <- ctx
5198 }))
5199 res, err := cst.c.Get(cst.ts.URL)
5200 if err != nil {
5201 t.Fatal(err)
5202 }
5203 res.Body.Close()
5204 ctx := <-ctxc
5205 select {
5206 case <-ctx.Done():
5207 default:
5208 t.Error("context should be done after ServeHTTP completes")
5209 }
5210 }
5211
5212
5213
5214
5215
5216 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5217 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5218 }
5219 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5220 inHandler := make(chan struct{})
5221 handlerDone := make(chan struct{})
5222 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5223 close(inHandler)
5224 <-r.Context().Done()
5225 close(handlerDone)
5226 })).ts
5227 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5228 if err != nil {
5229 t.Fatal(err)
5230 }
5231 defer c.Close()
5232 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5233 <-inHandler
5234 c.Close()
5235 <-handlerDone
5236 }
5237
5238 func TestServerContext_ServerContextKey(t *testing.T) {
5239 run(t, testServerContext_ServerContextKey, http3SkippedMode)
5240 }
5241 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5242 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5243 ctx := r.Context()
5244 got := ctx.Value(ServerContextKey)
5245 if _, ok := got.(*Server); !ok {
5246 t.Errorf("context value = %T; want *http.Server", got)
5247 }
5248 }))
5249 res, err := cst.c.Get(cst.ts.URL)
5250 if err != nil {
5251 t.Fatal(err)
5252 }
5253 res.Body.Close()
5254 }
5255
5256 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5257 run(t, testServerContext_LocalAddrContextKey, http3SkippedMode)
5258 }
5259 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5260 ch := make(chan any, 1)
5261 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5262 ch <- r.Context().Value(LocalAddrContextKey)
5263 }))
5264 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5265 t.Fatal(err)
5266 }
5267
5268 host := cst.ts.Listener.Addr().String()
5269 got := <-ch
5270 if addr, ok := got.(net.Addr); !ok {
5271 t.Errorf("local addr value = %T; want net.Addr", got)
5272 } else if fmt.Sprint(addr) != host {
5273 t.Errorf("local addr = %v; want %v", addr, host)
5274 }
5275 }
5276
5277
5278 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5279 setParallel(t)
5280 defer afterTest(t)
5281 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5282 w.Header().Set("Transfer-Encoding", "chunked")
5283 w.Write([]byte("hello"))
5284 }))
5285 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5286 const hdr = "Transfer-Encoding: chunked"
5287 if n := strings.Count(resp, hdr); n != 1 {
5288 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5289 }
5290 }
5291
5292
5293 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5294 setParallel(t)
5295 defer afterTest(t)
5296 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5297 w.Header().Set("Transfer-Encoding", "gzip")
5298 gz := gzip.NewWriter(w)
5299 gz.Write([]byte("hello"))
5300 gz.Close()
5301 }))
5302 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5303 for _, v := range []string{"gzip", "chunked"} {
5304 hdr := "Transfer-Encoding: " + v
5305 if n := strings.Count(resp, hdr); n != 1 {
5306 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5307 }
5308 }
5309 }
5310
5311 func BenchmarkClientServer(b *testing.B) {
5312 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5313 }
5314 func benchmarkClientServer(b *testing.B, mode testMode) {
5315 b.ReportAllocs()
5316 b.StopTimer()
5317 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5318 fmt.Fprintf(rw, "Hello world.\n")
5319 })).ts
5320 b.StartTimer()
5321
5322 c := ts.Client()
5323 for i := 0; i < b.N; i++ {
5324 res, err := c.Get(ts.URL)
5325 if err != nil {
5326 b.Fatal("Get:", err)
5327 }
5328 all, err := io.ReadAll(res.Body)
5329 res.Body.Close()
5330 if err != nil {
5331 b.Fatal("ReadAll:", err)
5332 }
5333 body := string(all)
5334 if body != "Hello world.\n" {
5335 b.Fatal("Got body:", body)
5336 }
5337 }
5338
5339 b.StopTimer()
5340 }
5341
5342 func BenchmarkClientServerParallel(b *testing.B) {
5343 for _, parallelism := range []int{4, 64} {
5344 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5345 run(b, func(b *testing.B, mode testMode) {
5346 benchmarkClientServerParallel(b, parallelism, mode)
5347 }, []testMode{http1Mode, https1Mode, http2Mode})
5348 })
5349 }
5350 }
5351
5352 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5353 b.ReportAllocs()
5354 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5355 fmt.Fprintf(rw, "Hello world.\n")
5356 })).ts
5357 b.ResetTimer()
5358 b.SetParallelism(parallelism)
5359 b.RunParallel(func(pb *testing.PB) {
5360 c := ts.Client()
5361 for pb.Next() {
5362 res, err := c.Get(ts.URL)
5363 if err != nil {
5364 b.Logf("Get: %v", err)
5365 continue
5366 }
5367 all, err := io.ReadAll(res.Body)
5368 res.Body.Close()
5369 if err != nil {
5370 b.Logf("ReadAll: %v", err)
5371 continue
5372 }
5373 body := string(all)
5374 if body != "Hello world.\n" {
5375 panic("Got body: " + body)
5376 }
5377 }
5378 })
5379 }
5380
5381
5382
5383
5384
5385
5386
5387
5388
5389
5390 func BenchmarkServer(b *testing.B) {
5391 b.ReportAllocs()
5392
5393 if url := os.Getenv("GO_TEST_BENCH_SERVER_URL"); url != "" {
5394 n, err := strconv.Atoi(os.Getenv("GO_TEST_BENCH_CLIENT_N"))
5395 if err != nil {
5396 panic(err)
5397 }
5398 for i := 0; i < n; i++ {
5399 res, err := Get(url)
5400 if err != nil {
5401 log.Panicf("Get: %v", err)
5402 }
5403 all, err := io.ReadAll(res.Body)
5404 res.Body.Close()
5405 if err != nil {
5406 log.Panicf("ReadAll: %v", err)
5407 }
5408 body := string(all)
5409 if body != "Hello world.\n" {
5410 log.Panicf("Got body: %q", body)
5411 }
5412 }
5413 os.Exit(0)
5414 return
5415 }
5416
5417 var res = []byte("Hello world.\n")
5418 b.StopTimer()
5419 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5420 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5421 rw.Write(res)
5422 }))
5423 defer ts.Close()
5424 b.StartTimer()
5425
5426 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5427 cmd.Env = append([]string{
5428 fmt.Sprintf("GO_TEST_BENCH_CLIENT_N=%d", b.N),
5429 fmt.Sprintf("GO_TEST_BENCH_SERVER_URL=%s", ts.URL),
5430 }, os.Environ()...)
5431 out, err := cmd.CombinedOutput()
5432 if err != nil {
5433 b.Errorf("Test failure: %v, with output: %s", err, out)
5434 }
5435 }
5436
5437
// getNoBody performs a GET on urlStr with the default client and closes the
// response body before returning the response. On failure it returns a nil
// response and the error from Get.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5446
5447
5448
5449 func BenchmarkClient(b *testing.B) {
5450 var data = []byte("Hello world.\n")
5451
5452 url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
5453 w.Header().Set("Content-Type", "text/html; charset=utf-8")
5454 w.Write(data)
5455 }))
5456
5457
5458 b.StartTimer()
5459 for i := 0; i < b.N; i++ {
5460 res, err := Get(url)
5461 if err != nil {
5462 b.Fatalf("Get: %v", err)
5463 }
5464 body, err := io.ReadAll(res.Body)
5465 res.Body.Close()
5466 if err != nil {
5467 b.Fatalf("ReadAll: %v", err)
5468 }
5469 if !bytes.Equal(body, data) {
5470 b.Fatalf("Got body: %q", body)
5471 }
5472 }
5473 b.StopTimer()
5474 }
5475
5476 func startClientBenchmarkServer(b *testing.B, handler Handler) string {
5477 b.ReportAllocs()
5478 b.StopTimer()
5479
5480 if server := os.Getenv("GO_TEST_BENCH_SERVER"); server != "" {
5481
5482 port := os.Getenv("GO_TEST_BENCH_SERVER_PORT")
5483 if port == "" {
5484 port = "0"
5485 }
5486 ln, err := net.Listen("tcp", "localhost:"+port)
5487 if err != nil {
5488 log.Fatal(err)
5489 }
5490 fmt.Println(ln.Addr().String())
5491
5492 HandleFunc("/", func(w ResponseWriter, r *Request) {
5493 r.ParseForm()
5494 if r.Form.Get("stop") != "" {
5495 os.Exit(0)
5496 }
5497 handler.ServeHTTP(w, r)
5498 })
5499 var srv Server
5500 log.Fatal(srv.Serve(ln))
5501 }
5502
5503
5504 ctx, cancel := context.WithCancel(context.Background())
5505 cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^"+b.Name()+"$")
5506 cmd.Env = append(cmd.Environ(), "GO_TEST_BENCH_SERVER=yes")
5507 cmd.Stderr = os.Stderr
5508 stdout, err := cmd.StdoutPipe()
5509 if err != nil {
5510 b.Fatal(err)
5511 }
5512 if err := cmd.Start(); err != nil {
5513 b.Fatalf("subprocess failed to start: %v", err)
5514 }
5515
5516 done := make(chan error, 1)
5517 go func() {
5518 done <- cmd.Wait()
5519 close(done)
5520 }()
5521
5522
5523
5524 bs := bufio.NewScanner(stdout)
5525 if !bs.Scan() {
5526 b.Fatalf("failed to read listening URL from child: %v", bs.Err())
5527 }
5528 url := "http://" + strings.TrimSpace(bs.Text()) + "/"
5529 if _, err := getNoBody(url); err != nil {
5530 b.Fatalf("initial probe of child process failed: %v", err)
5531 }
5532
5533
5534 b.Cleanup(func() {
5535 getNoBody(url + "?stop=yes")
5536 if err := <-done; err != nil {
5537 b.Fatalf("subprocess failed: %v", err)
5538 }
5539
5540 cancel()
5541 <-done
5542
5543 afterTest(b)
5544 })
5545
5546 return url
5547 }
5548
5549 func BenchmarkClientGzip(b *testing.B) {
5550 const responseSize = 1024 * 1024
5551
5552 var buf bytes.Buffer
5553 gz := gzip.NewWriter(&buf)
5554 if _, err := io.CopyN(gz, crand.Reader, responseSize); err != nil {
5555 b.Fatal(err)
5556 }
5557 gz.Close()
5558
5559 data := buf.Bytes()
5560
5561 url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
5562 w.Header().Set("Content-Encoding", "gzip")
5563 w.Write(data)
5564 }))
5565
5566
5567 b.StartTimer()
5568 for i := 0; i < b.N; i++ {
5569 res, err := Get(url)
5570 if err != nil {
5571 b.Fatalf("Get: %v", err)
5572 }
5573 n, err := io.Copy(io.Discard, res.Body)
5574 res.Body.Close()
5575 if err != nil {
5576 b.Fatalf("ReadAll: %v", err)
5577 }
5578 if n != responseSize {
5579 b.Fatalf("ReadAll: expected %d bytes, got %d", responseSize, n)
5580 }
5581 }
5582 b.StopTimer()
5583 }
5584
5585 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5586 b.ReportAllocs()
5587 req := reqBytes(`GET / HTTP/1.0
5588 Host: golang.org
5589 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5590 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5591 Accept-Encoding: gzip,deflate,sdch
5592 Accept-Language: en-US,en;q=0.8
5593 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5594 `)
5595 res := []byte("Hello world!\n")
5596
5597 conn := newTestConn()
5598 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5599 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5600 rw.Write(res)
5601 })
5602 ln := new(oneConnListener)
5603 for i := 0; i < b.N; i++ {
5604 conn.readBuf.Reset()
5605 conn.writeBuf.Reset()
5606 conn.readBuf.Write(req)
5607 ln.conn = conn
5608 Serve(ln, handler)
5609 <-conn.closec
5610 }
5611 }
5612
5613
// repeatReader is an io.Reader that yields content count times in a row and
// then reports io.EOF.
type repeatReader struct {
	content []byte // bytes produced on each repetition
	count   int    // repetitions still to be produced
	off     int    // read offset within the current repetition
}

// Read copies bytes from the current repetition of content into p. A single
// call never crosses a repetition boundary. When a repetition has been fully
// consumed, the remaining count drops by one and the offset rewinds. Once
// count is zero or less, Read returns 0, io.EOF.
func (r *repeatReader) Read(p []byte) (n int, err error) {
	if r.count < 1 {
		return 0, io.EOF
	}
	copied := copy(p, r.content[r.off:])
	if r.off += copied; r.off == len(r.content) {
		r.off = 0
		r.count--
	}
	return copied, nil
}
5632
5633 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5634 b.ReportAllocs()
5635
5636 req := reqBytes(`GET / HTTP/1.1
5637 Host: golang.org
5638 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5639 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5640 Accept-Encoding: gzip,deflate,sdch
5641 Accept-Language: en-US,en;q=0.8
5642 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5643 `)
5644 res := []byte("Hello world!\n")
5645
5646 conn := &rwTestConn{
5647 Reader: &repeatReader{content: req, count: b.N},
5648 Writer: io.Discard,
5649 closec: make(chan bool, 1),
5650 }
5651 handled := 0
5652 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5653 handled++
5654 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5655 rw.Write(res)
5656 })
5657 ln := &oneConnListener{conn: conn}
5658 go Serve(ln, handler)
5659 <-conn.closec
5660 if b.N != handled {
5661 b.Errorf("b.N=%d but handled %d", b.N, handled)
5662 }
5663 }
5664
5665
5666
5667 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5668 b.ReportAllocs()
5669
5670 req := reqBytes(`GET / HTTP/1.1
5671 Host: golang.org
5672 `)
5673 res := []byte("Hello world!\n")
5674
5675 conn := &rwTestConn{
5676 Reader: &repeatReader{content: req, count: b.N},
5677 Writer: io.Discard,
5678 closec: make(chan bool, 1),
5679 }
5680 handled := 0
5681 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5682 handled++
5683 rw.Write(res)
5684 })
5685 ln := &oneConnListener{conn: conn}
5686 go Serve(ln, handler)
5687 <-conn.closec
5688 if b.N != handled {
5689 b.Errorf("b.N=%d but handled %d", b.N, handled)
5690 }
5691 }
5692
// someResponse is the HTML fragment repeated to build response.
const someResponse = "<html>some response</html>"

// response is someResponse repeated as many whole times as fit in 2 KiB
// (2<<10 bytes), so its length never exceeds 2 KiB.
var response = bytes.Repeat([]byte(someResponse), (2<<10)/len(someResponse))
5697
5698
5699 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5700 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5701 w.Header().Set("Content-Type", "text/html")
5702 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5703 w.Write(response)
5704 }))
5705 }
5706
5707
5708 func BenchmarkServerHandlerNoLen(b *testing.B) {
5709 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5710 w.Header().Set("Content-Type", "text/html")
5711 w.Write(response)
5712 }))
5713 }
5714
5715
5716 func BenchmarkServerHandlerNoType(b *testing.B) {
5717 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5718 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5719 w.Write(response)
5720 }))
5721 }
5722
5723
5724 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5725 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5726 w.Write(response)
5727 }))
5728 }
5729
5730 func benchmarkHandler(b *testing.B, h Handler) {
5731 b.ReportAllocs()
5732 req := reqBytes(`GET / HTTP/1.1
5733 Host: golang.org
5734 `)
5735 conn := &rwTestConn{
5736 Reader: &repeatReader{content: req, count: b.N},
5737 Writer: io.Discard,
5738 closec: make(chan bool, 1),
5739 }
5740 handled := 0
5741 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5742 handled++
5743 h.ServeHTTP(rw, r)
5744 })
5745 ln := &oneConnListener{conn: conn}
5746 go Serve(ln, handler)
5747 <-conn.closec
5748 if b.N != handled {
5749 b.Errorf("b.N=%d but handled %d", b.N, handled)
5750 }
5751 }
5752
5753 func BenchmarkServerHijack(b *testing.B) {
5754 b.ReportAllocs()
5755 req := reqBytes(`GET / HTTP/1.1
5756 Host: golang.org
5757 `)
5758 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5759 conn, _, err := w.(Hijacker).Hijack()
5760 if err != nil {
5761 panic(err)
5762 }
5763 conn.Close()
5764 })
5765 conn := &rwTestConn{
5766 Writer: io.Discard,
5767 closec: make(chan bool, 1),
5768 }
5769 ln := &oneConnListener{conn: conn}
5770 for i := 0; i < b.N; i++ {
5771 conn.Reader = bytes.NewReader(req)
5772 ln.conn = conn
5773 Serve(ln, h)
5774 <-conn.closec
5775 }
5776 }
5777
5778 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5779 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5780 b.ReportAllocs()
5781 b.StopTimer()
5782 sawClose := make(chan bool)
5783 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5784 <-rw.(CloseNotifier).CloseNotify()
5785 sawClose <- true
5786 })).ts
5787 b.StartTimer()
5788 for i := 0; i < b.N; i++ {
5789 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5790 if err != nil {
5791 b.Fatalf("error dialing: %v", err)
5792 }
5793 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5794 if err != nil {
5795 b.Fatal(err)
5796 }
5797 conn.Close()
5798 <-sawClose
5799 }
5800 b.StopTimer()
5801 }
5802
5803
5804 func TestConcurrentServerServe(t *testing.T) {
5805 setParallel(t)
5806 for i := 0; i < 100; i++ {
5807 ln1 := &oneConnListener{conn: nil}
5808 ln2 := &oneConnListener{conn: nil}
5809 srv := Server{}
5810 go func() { srv.Serve(ln1) }()
5811 go func() { srv.Serve(ln2) }()
5812 }
5813 }
5814
// TestServerIdleTimeout runs testServerIdleTimeout in HTTP/1 mode only.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }

// testServerIdleTimeout exercises Server.IdleTimeout together with
// Server.ReadHeaderTimeout. It is skipped in short mode. The body is driven
// by runTimeSensitiveTest, which retries it with progressively longer
// timeouts; failures that may be timing related are therefore returned as
// errors instead of being reported on t directly.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler drains the request body and replies with the client's
		// remote address, which lets the test tell connections apart.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			// IdleTimeout is set to twice ReadHeaderTimeout.
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		// get performs one request and returns the response body, i.e. the
		// remote address the server saw.
		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// The response headers have already been received at this
				// point, so a body read error is reported as a hard failure
				// rather than returned for a retry with a longer timeout.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		// Two back-to-back requests are expected to share one connection.
		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Stay idle for 1.5x IdleTimeout; the next request must then arrive
		// on a different connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Check ReadHeaderTimeout on a raw connection: send a request whose
		// header block is never terminated (no final blank line), wait for
		// twice ReadHeaderTimeout, and expect reading a byte to fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5890
5891 func get(t *testing.T, c *Client, url string) string {
5892 res, err := c.Get(url)
5893 if err != nil {
5894 t.Fatal(err)
5895 }
5896 defer res.Body.Close()
5897 slurp, err := io.ReadAll(res.Body)
5898 if err != nil {
5899 t.Fatal(err)
5900 }
5901 return string(slurp)
5902 }
5903
5904
5905
5906 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5907 run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
5908 }
5909 func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
5910 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5911 io.WriteString(w, r.RemoteAddr)
5912 })).ts
5913
5914 c := ts.Client()
5915 tr := c.Transport.(*Transport)
5916
5917 get := func() string { return get(t, c, ts.URL) }
5918
5919 a1, a2 := get(), get()
5920 if a1 == a2 {
5921 t.Logf("made two requests from a single conn %q (as expected)", a1)
5922 } else {
5923 t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
5924 }
5925
5926
5927
5928
5929
5930 if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
5931 t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
5932 }
5933
5934
5935 ts.Config.SetKeepAlivesEnabled(false)
5936
5937 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5938 if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
5939 if d > 0 {
5940 t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
5941 }
5942 return false
5943 }
5944 return true
5945 })
5946
5947
5948
5949
5950 }
5951
// TestServerShutdown runs testServerShutdown for the modes in http3SkippedMode.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown, http3SkippedMode) }

// testServerShutdown calls Server.Shutdown from inside the first request's
// handler. It checks that the in-flight request still completes, that new
// requests fail once the OnShutdown hook has run, that Shutdown returns nil,
// and that the connection states snapshotted just before Shutdown show
// exactly one active connection.
func testServerShutdown(t *testing.T, mode testMode) {
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		once.Do(func() {
			// Snapshot the connection states, then start Shutdown in the
			// background; it cannot finish while this handler is running.
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait for the OnShutdown hook registered below to fire. The
			// failure messages in the loop state the expectation that the
			// listener is already closed by then.
			<-gotOnShutdown

			// New requests should now fail. For HTTP/2 an unexpected success
			// is only logged and retried (see https://go.dev/issue/59038);
			// for other modes it is an error.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		// Closing gotOnShutdown signals that the OnShutdown hook has run.
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// The first request triggers Shutdown but must still get its response.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// The snapshot was taken while the first request was being handled.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
6012
// TestServerShutdownStateNew runs testServerShutdownStateNew via runSynctest
// for the modes in http3SkippedMode.
func TestServerShutdownStateNew(t *testing.T) {
	runSynctest(t, testServerShutdownStateNew, http3SkippedMode)
}

// testServerShutdownStateNew opens a connection to the server without ever
// writing to it, then calls Shutdown. It checks that Shutdown does not
// complete and the connection is not closed until roughly expectTimeout has
// passed, and that both have happened a little later.
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler has nothing to do.
	}), func(ts *httptest.Server) {
		// Swap the server's listener for the fake one created above.
		ts.Listener.Close()
		ts.Listener = listener
		// Server error logs are discarded.
		// NOTE(review): presumably to hide errors caused by the silent
		// connection; confirm.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Connect, but never send anything on the connection.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// Shutdown is expected to leave the unused connection alone for this
	// long. NOTE(review): the value appears to mirror a constant inside the
	// server implementation; confirm against server.go.
	const expectTimeout = 5 * time.Second

	// Sleep until one nanosecond before the expected timeout (expectTimeout
	// is a time.Duration, so "- 1" subtracts 1ns). Neither the shutdown nor
	// the close may have happened yet.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Give the server a generous two more seconds; by then Shutdown must
	// have returned without error and the connection must have been closed.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
6070
6071
// TestServerCloseDeadlock calls Close twice on a zero-value Server. The test
// passes as long as both calls return.
func TestServerCloseDeadlock(t *testing.T) {
	srv := new(Server)
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
6077
6078
6079
6080 func TestServerKeepAlivesEnabled(t *testing.T) {
6081 runSynctest(t, testServerKeepAlivesEnabled, http3SkippedMode)
6082 }
6083 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
6084 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optFakeNet)
6085 defer cst.close()
6086 srv := cst.ts.Config
6087 srv.SetKeepAlivesEnabled(false)
6088 for try := range 2 {
6089 synctest.Wait()
6090 if !srv.ExportAllConnsIdle() {
6091 t.Fatalf("test server still has active conns before request %v", try)
6092 }
6093 conns := 0
6094 var info httptrace.GotConnInfo
6095 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6096 GotConn: func(v httptrace.GotConnInfo) {
6097 conns++
6098 info = v
6099 },
6100 })
6101 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6102 if err != nil {
6103 t.Fatal(err)
6104 }
6105 res, err := cst.c.Do(req)
6106 if err != nil {
6107 t.Fatal(err)
6108 }
6109 res.Body.Close()
6110 if conns != 1 {
6111 t.Fatalf("request %v: got %v conns, want 1", try, conns)
6112 }
6113 if info.Reused || info.WasIdle {
6114 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
6115 }
6116 }
6117 }
6118
6119
6120
6121
// TestServerCancelsReadTimeoutWhenIdle runs
// testServerCancelsReadTimeoutWhenIdle in the default set of modes.
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }

// testServerCancelsReadTimeoutWhenIdle sets Server.ReadTimeout and uses a
// handler that runs for twice that long. The request must still finish with
// the body "ok", which shows the handler's request context was not canceled
// in the meantime. The body is retried by runTimeSensitiveTest with longer
// timeouts when it returns an error.
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Reply "ok" after 2*timeout, or with the context error if the
			// request context is canceled first.
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		}), func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})
		defer cst.close()
		ts := cst.ts

		// The Proxy hook is used purely as a counter: the first call is let
		// through, any further call fails with "too many retries".
		// NOTE(review): this assumes the transport consults Proxy once per
		// connection attempt, so a second call means a retry; confirm.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		c := ts.Client()

		res, err := c.Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		if string(slurp) != "ok" {
			return fmt.Errorf("got: %q, want ok", slurp)
		}
		return nil
	})
}
6170
6171
6172
6173
6174 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
6175 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
6176 }
6177 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
6178 runTimeSensitiveTest(t, []time.Duration{
6179 10 * time.Millisecond,
6180 50 * time.Millisecond,
6181 250 * time.Millisecond,
6182 time.Second,
6183 2 * time.Second,
6184 }, func(t *testing.T, timeout time.Duration) error {
6185 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
6186 ts.Config.ReadHeaderTimeout = timeout
6187 ts.Config.IdleTimeout = 0
6188 })
6189 defer cst.close()
6190 ts := cst.ts
6191
6192
6193
6194 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6195 if err != nil {
6196 t.Fatalf("dial failed: %v", err)
6197 }
6198 br := bufio.NewReader(conn)
6199 defer conn.Close()
6200
6201 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6202 return fmt.Errorf("writing first request failed: %v", err)
6203 }
6204
6205 if _, err := ReadResponse(br, nil); err != nil {
6206 return fmt.Errorf("first response (before timeout) failed: %v", err)
6207 }
6208
6209
6210
6211 time.Sleep(timeout * 3 / 2)
6212
6213 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6214 return fmt.Errorf("writing second request failed: %v", err)
6215 }
6216
6217 if _, err := ReadResponse(br, nil); err != nil {
6218 return fmt.Errorf("second response (after timeout) failed: %v", err)
6219 }
6220
6221 return nil
6222 })
6223 }
6224
6225
6226
// runTimeSensitiveTest calls test once per entry of durations, in order, and
// stops at the first call that returns nil. A non-nil error is logged and
// the next duration is tried, unless it was the last duration or t has
// already failed, in which case t.Fatalf is called with that duration and
// error. With no durations, test is never called.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for idx := range durations {
		dur := durations[idx]
		err := test(t, dur)
		switch {
		case err == nil:
			return
		case idx == last || t.Failed():
			t.Fatalf("failed with duration %v: %v", dur, err)
		default:
			t.Logf("retrying after error with duration %v: %v", dur, err)
		}
	}
}
6239
6240
6241
// TestServerDuplicateBackgroundRead runs testServerDuplicateBackgroundRead
// in HTTP/1 mode only.
func TestServerDuplicateBackgroundRead(t *testing.T) {
	run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
}

// testServerDuplicateBackgroundRead is a stress test: several goroutines
// each open a raw connection and write many requests back to back without
// waiting for the responses, while a helper goroutine discards whatever the
// server sends. The only assertions are that dialing and writing succeed.
func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	// Use a smaller workload in short mode.
	goroutines := 5
	requests := 2000
	if testing.Short() {
		goroutines = 3
		requests = 100
	}

	hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts

	reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cn, err := net.Dial("tcp", hts.Listener.Addr().String())
			if err != nil {
				t.Error(err)
				return
			}
			defer cn.Close()

			// Drain responses in the background. This reader is added to wg
			// while the enclosing goroutine still holds its own count, and
			// it ends once the deferred cn.Close above makes io.Copy return.
			wg.Add(1)
			go func() {
				defer wg.Done()
				io.Copy(io.Discard, cn)
			}()

			for j := 0; j < requests; j++ {
				// Stop early once any goroutine has reported a failure.
				if t.Failed() {
					return
				}
				_, err := cn.Write(reqBytes)
				if err != nil {
					t.Error(err)
					return
				}
			}
		}()
	}
	wg.Wait()
}
6293
6294
6295
6296
6297
6298
// TestServerHijackGetsBackgroundByte runs testServerHijackGetsBackgroundByte
// in HTTP/1 mode only.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
}

// testServerHijackGetsBackgroundByte checks that data the client sends after
// its request, once the handler is already running, is available through the
// bufio reader returned by Hijack, and that the request context has not been
// canceled at that point. Skipped on plan9.
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Tell the client that the handler is running so it can send the
		// extra bytes after the request.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		// The three extra bytes must be readable from the hijacked reader.
		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		// Non-blocking check that the request context is still live.
		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	// Send "foo" only after the handler has signaled that it is running.
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close the connection: no more data will be written by the client.
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	// Wait for the handler to finish its checks.
	<-done
}
6351
6352
6353 func TestServerHijackGetsFullBody(t *testing.T) {
6354 run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
6355 }
6356 func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
6357 if runtime.GOOS == "plan9" {
6358 t.Skip("skipping test; see https://golang.org/issue/18657")
6359 }
6360 done := make(chan struct{})
6361 needle := strings.Repeat("x", 100*1024)
6362 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6363 defer close(done)
6364
6365 conn, buf, err := w.(Hijacker).Hijack()
6366 if err != nil {
6367 t.Error(err)
6368 return
6369 }
6370 defer conn.Close()
6371
6372 got := make([]byte, len(needle))
6373 n, err := io.ReadFull(buf.Reader, got)
6374 if n != len(needle) || string(got) != needle || err != nil {
6375 t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
6376 }
6377 })).ts
6378
6379 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6380 if err != nil {
6381 t.Fatal(err)
6382 }
6383 defer cn.Close()
6384 buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6385 buf = append(buf, []byte(needle)...)
6386 if _, err := cn.Write(buf); err != nil {
6387 t.Fatal(err)
6388 }
6389
6390 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6391 t.Fatal(err)
6392 }
6393 <-done
6394 }
6395
6396
6397
6398
6399 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6400 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6401 }
6402 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6403 if runtime.GOOS == "plan9" {
6404 t.Skip("skipping test; see https://golang.org/issue/18657")
6405 }
6406 done := make(chan struct{})
6407 const size = 8 << 10
6408 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6409 defer close(done)
6410
6411 conn, buf, err := w.(Hijacker).Hijack()
6412 if err != nil {
6413 t.Error(err)
6414 return
6415 }
6416 defer conn.Close()
6417 slurp, err := io.ReadAll(buf.Reader)
6418 if err != nil {
6419 t.Errorf("Copy: %v", err)
6420 }
6421 allX := true
6422 for _, v := range slurp {
6423 if v != 'x' {
6424 allX = false
6425 }
6426 }
6427 if len(slurp) != size {
6428 t.Errorf("read %d; want %d", len(slurp), size)
6429 } else if !allX {
6430 t.Errorf("read %q; want %d 'x'", slurp, size)
6431 }
6432 })).ts
6433
6434 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6435 if err != nil {
6436 t.Fatal(err)
6437 }
6438 defer cn.Close()
6439 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6440 strings.Repeat("x", size)); err != nil {
6441 t.Fatal(err)
6442 }
6443 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6444 t.Fatal(err)
6445 }
6446
6447 <-done
6448 }
6449
6450
6451 func TestServerValidatesMethod(t *testing.T) {
6452 tests := []struct {
6453 method string
6454 want int
6455 }{
6456 {"GET", 200},
6457 {"GE(T", 400},
6458 }
6459 for _, tt := range tests {
6460 conn := newTestConn()
6461 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6462
6463 ln := &oneConnListener{conn}
6464 go Serve(ln, serve(200))
6465 <-conn.closec
6466 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6467 if err != nil {
6468 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6469 continue
6470 }
6471 if res.StatusCode != tt.want {
6472 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6473 }
6474 }
6475 }
6476
6477
// eofListenerNotComparable is a net.Listener whose Accept always fails with
// io.EOF. Its underlying type is a slice, so values of this type cannot be
// compared with ==.
type eofListenerNotComparable []int

var _ net.Listener = eofListenerNotComparable(nil)

// Accept never yields a connection; it always returns io.EOF.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

// Addr returns a nil address.
func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

// Close does nothing and returns nil.
func (eofListenerNotComparable) Close() error {
	return nil
}
6483
6484
6485 func TestServerListenNotComparableListener(t *testing.T) {
6486 var s Server
6487 s.Serve(make(eofListenerNotComparable, 1))
6488 }
6489
6490
// countCloseListener wraps a net.Listener and counts how many times Close is
// called on the wrapper.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; updated atomically
}

// Close increments the call counter. Only the first call is forwarded to the
// wrapped listener, and only if one is set; every other call returns nil.
func (p *countCloseListener) Close() error {
	if atomic.AddInt32(&p.closes, 1) != 1 || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6503
6504
6505 func TestServerCloseListenerOnce(t *testing.T) {
6506 setParallel(t)
6507 defer afterTest(t)
6508
6509 ln := newLocalListener(t)
6510 defer ln.Close()
6511
6512 cl := &countCloseListener{Listener: ln}
6513 server := &Server{}
6514 sdone := make(chan bool, 1)
6515
6516 go func() {
6517 server.Serve(cl)
6518 sdone <- true
6519 }()
6520 time.Sleep(10 * time.Millisecond)
6521 server.Shutdown(context.Background())
6522 ln.Close()
6523 <-sdone
6524
6525 nclose := atomic.LoadInt32(&cl.closes)
6526 if nclose != 1 {
6527 t.Errorf("Close calls = %v; want 1", nclose)
6528 }
6529 }
6530
6531
6532 func TestServerShutdownThenServe(t *testing.T) {
6533 var srv Server
6534 cl := &countCloseListener{Listener: nil}
6535 srv.Shutdown(context.Background())
6536 got := srv.Serve(cl)
6537 if got != ErrServerClosed {
6538 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6539 }
6540 nclose := atomic.LoadInt32(&cl.closes)
6541 if nclose != 1 {
6542 t.Errorf("Close calls = %v; want 1", nclose)
6543 }
6544 }
6545
6546
// TestStripPortFromHost registers one pattern for host "example.com" and one
// for "example.com:9000", then sends a request for
// http://example.com:9000/ and expects the port-less pattern's handler to
// answer.
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()

	routes := []struct{ pattern, reply string }{
		{"example.com/", "OK"},
		{"example.com:9000/", "uh-oh!"},
	}
	for _, route := range routes {
		reply := route.reply
		mux.HandleFunc(route.pattern, func(w ResponseWriter, r *Request) {
			fmt.Fprint(w, reply)
		})
	}

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6567
6568 func TestServerContexts(t *testing.T) { run(t, testServerContexts, http3SkippedMode) }
6569 func testServerContexts(t *testing.T, mode testMode) {
6570 type baseKey struct{}
6571 type connKey struct{}
6572 ch := make(chan context.Context, 1)
6573 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6574 ch <- r.Context()
6575 }), func(ts *httptest.Server) {
6576 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6577 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6578 t.Errorf("unexpected onceClose listener type %T", ln)
6579 }
6580 return context.WithValue(context.Background(), baseKey{}, "base")
6581 }
6582 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6583 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6584 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6585 }
6586 return context.WithValue(ctx, connKey{}, "conn")
6587 }
6588 }).ts
6589 res, err := ts.Client().Get(ts.URL)
6590 if err != nil {
6591 t.Fatal(err)
6592 }
6593 res.Body.Close()
6594 ctx := <-ch
6595 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6596 t.Errorf("base context key = %#v; want %q", got, want)
6597 }
6598 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6599 t.Errorf("conn context key = %#v; want %q", got, want)
6600 }
6601 }
6602
6603
6604 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6605 run(t, testConnContextNotModifyingAllContexts)
6606 }
6607 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6608 type connKey struct{}
6609 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6610 rw.Header().Set("Connection", "close")
6611 }), func(ts *httptest.Server) {
6612 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6613 if got := ctx.Value(connKey{}); got != nil {
6614 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6615 }
6616 return context.WithValue(ctx, connKey{}, "conn")
6617 }
6618 }).ts
6619
6620 var res *Response
6621 var err error
6622
6623 res, err = ts.Client().Get(ts.URL)
6624 if err != nil {
6625 t.Fatal(err)
6626 }
6627 res.Body.Close()
6628
6629 res, err = ts.Client().Get(ts.URL)
6630 if err != nil {
6631 t.Fatal(err)
6632 }
6633 res.Body.Close()
6634 }
6635
6636
6637
6638 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6639 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6640 }
6641 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6642 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6643 w.Write([]byte("Hello, World!"))
6644 })).ts
6645
6646 serverURL, err := url.Parse(cst.URL)
6647 if err != nil {
6648 t.Fatalf("Failed to parse server URL: %v", err)
6649 }
6650
6651 unsupportedTEs := []string{
6652 "fugazi",
6653 "foo-bar",
6654 "unknown",
6655 `" chunked"`,
6656 }
6657
6658 for _, badTE := range unsupportedTEs {
6659 http1ReqBody := fmt.Sprintf(""+
6660 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6661 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6662
6663 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6664 if err != nil {
6665 t.Errorf("%q. unexpected error: %v", badTE, err)
6666 continue
6667 }
6668
6669 wantBody := fmt.Sprintf("" +
6670 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6671 "Connection: close\r\n\r\nUnsupported transfer encoding")
6672
6673 if string(gotBody) != wantBody {
6674 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6675 }
6676 }
6677 }
6678
6679
6680 func TestContentEncodingNoSniffing(t *testing.T) {
6681 run(t, testContentEncodingNoSniffing, http3SkippedMode)
6682 }
6683 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6684 type setting struct {
6685 name string
6686 body []byte
6687
6688
6689
6690
6691 contentEncoding any
6692 wantContentType string
6693 }
6694
6695 settings := []*setting{
6696 {
6697 name: "gzip content-encoding, gzipped",
6698 contentEncoding: "application/gzip",
6699 wantContentType: "",
6700 body: func() []byte {
6701 buf := new(bytes.Buffer)
6702 gzw := gzip.NewWriter(buf)
6703 gzw.Write([]byte("doctype html><p>Hello</p>"))
6704 gzw.Close()
6705 return buf.Bytes()
6706 }(),
6707 },
6708 {
6709 name: "zlib content-encoding, zlibbed",
6710 contentEncoding: "application/zlib",
6711 wantContentType: "",
6712 body: func() []byte {
6713 buf := new(bytes.Buffer)
6714 zw := zlib.NewWriter(buf)
6715 zw.Write([]byte("doctype html><p>Hello</p>"))
6716 zw.Close()
6717 return buf.Bytes()
6718 }(),
6719 },
6720 {
6721 name: "no content-encoding",
6722 wantContentType: "application/x-gzip",
6723 body: func() []byte {
6724 buf := new(bytes.Buffer)
6725 gzw := gzip.NewWriter(buf)
6726 gzw.Write([]byte("doctype html><p>Hello</p>"))
6727 gzw.Close()
6728 return buf.Bytes()
6729 }(),
6730 },
6731 {
6732 name: "phony content-encoding",
6733 contentEncoding: "foo/bar",
6734 body: []byte("doctype html><p>Hello</p>"),
6735 },
6736 {
6737 name: "empty but set content-encoding",
6738 contentEncoding: "",
6739 wantContentType: "audio/mpeg",
6740 body: []byte("ID3"),
6741 },
6742 }
6743
6744 for _, tt := range settings {
6745 t.Run(tt.name, func(t *testing.T) {
6746 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6747 if tt.contentEncoding != nil {
6748 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6749 }
6750 rw.Write(tt.body)
6751 }))
6752
6753 res, err := cst.c.Get(cst.ts.URL)
6754 if err != nil {
6755 t.Fatalf("Failed to fetch URL: %v", err)
6756 }
6757 defer res.Body.Close()
6758
6759 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6760 if w != nil {
6761 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6762 } else if g != "" {
6763 t.Errorf("Unexpected Content-Encoding %q", g)
6764 }
6765 }
6766
6767 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6768 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6769 }
6770 })
6771 }
6772 }
6773
6774
6775
// TestTimeoutHandlerSuperfluousLogs runs testTimeoutHandlerSuperfluousLogs
// in HTTP/1 mode only.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}

// testTimeoutHandlerSuperfluousLogs wraps a handler that calls WriteHeader
// four times in a TimeoutHandler and checks both the response and the server
// log: three "superfluous response.WriteHeader call" entries are expected,
// each naming the source line of the call that caused it. Skipped in short
// mode.
//
// The check depends on the relative source positions of the four WriteHeader
// calls and the runtime.Caller call in the inner handler: they must stay on
// consecutive lines, in that order.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Capture this function's name and file for the log pattern built below.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// Do not insert lines between the WriteHeader calls and the
			// runtime.Caller call: their line numbers are compared against
			// the log output further down.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// When no timeout is wanted, pre-fill the buffered channel so the
			// handler returns right away; otherwise it blocks until the
			// deferred close above.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// The handler is expected to finish in time; allow plenty.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop the Date and Content-Type headers before comparing the
			// dumped response with wantResp, which contains neither.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the dumped response with the expected wire format.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Of the four WriteHeader calls, the last three are expected to
			// be logged as superfluous, one log line each.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", "  ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastLine holds the line of the runtime.Caller call, which sits
			// directly after the fourth WriteHeader; the second WriteHeader
			// is therefore three lines above it.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each log entry must name the handler closure and the line of
			// the corresponding WriteHeader call.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6879
6880
6881
6882
// fetchWireResponse dials host over TCP, writes http1ReqBody to the
// connection as is, and returns every byte read from the connection until
// the read ends. The connection is closed before returning. A dial or write
// error is returned with a nil slice.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer c.Close()

	_, writeErr := c.Write(http1ReqBody)
	if writeErr != nil {
		return nil, writeErr
	}
	return io.ReadAll(c)
}
6895
6896 func BenchmarkResponseStatusLine(b *testing.B) {
6897 b.ReportAllocs()
6898 b.RunParallel(func(pb *testing.PB) {
6899 bw := bufio.NewWriter(io.Discard)
6900 var buf3 [3]byte
6901 for pb.Next() {
6902 Export_writeStatusLine(bw, true, 200, buf3[:])
6903 }
6904 })
6905 }
6906
6907 func TestDisableKeepAliveUpgrade(t *testing.T) {
6908 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6909 }
6910 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6911 if testing.Short() {
6912 t.Skip("skipping in short mode")
6913 }
6914
6915 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6916 w.Header().Set("Connection", "Upgrade")
6917 w.Header().Set("Upgrade", "someProto")
6918 w.WriteHeader(StatusSwitchingProtocols)
6919 c, buf, err := w.(Hijacker).Hijack()
6920 if err != nil {
6921 return
6922 }
6923 defer c.Close()
6924
6925
6926
6927 io.Copy(c, buf)
6928 }), func(ts *httptest.Server) {
6929 ts.Config.SetKeepAlivesEnabled(false)
6930 }).ts
6931
6932 cl := s.Client()
6933 cl.Transport.(*Transport).DisableKeepAlives = true
6934
6935 resp, err := cl.Get(s.URL)
6936 if err != nil {
6937 t.Fatalf("failed to perform request: %v", err)
6938 }
6939 defer resp.Body.Close()
6940
6941 if resp.StatusCode != StatusSwitchingProtocols {
6942 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6943 }
6944
6945 rwc, ok := resp.Body.(io.ReadWriteCloser)
6946 if !ok {
6947 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6948 }
6949
6950 _, err = rwc.Write([]byte("hello"))
6951 if err != nil {
6952 t.Fatalf("failed to write to body: %v", err)
6953 }
6954
6955 b := make([]byte, 5)
6956 _, err = io.ReadFull(rwc, b)
6957 if err != nil {
6958 t.Fatalf("failed to read from body: %v", err)
6959 }
6960
6961 if string(b) != "hello" {
6962 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6963 }
6964 }
6965
// tlogWriter is an io.Writer that forwards everything written to it to the
// log of a test.
type tlogWriter struct{ t *testing.T }

// Write logs p as a string through t.Log. It always reports that all of p
// was written and never returns an error.
func (w tlogWriter) Write(p []byte) (int, error) {
	msg := string(p)
	w.t.Log(msg)
	return len(p), nil
}
6972
6973 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6974 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6975 }
6976 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6977 const wantBody = "want"
6978 const wantUpgrade = "someProto"
6979 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6980 w.Header().Set("Connection", "Upgrade")
6981 w.Header().Set("Upgrade", wantUpgrade)
6982 w.WriteHeader(StatusSwitchingProtocols)
6983 NewResponseController(w).Flush()
6984
6985
6986 w.WriteHeader(200)
6987 if _, err := w.Write([]byte("x")); err == nil {
6988 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6989 }
6990
6991 c, _, err := NewResponseController(w).Hijack()
6992 if err != nil {
6993 t.Errorf("Hijack: %v", err)
6994 return
6995 }
6996 defer c.Close()
6997 if _, err := c.Write([]byte(wantBody)); err != nil {
6998 t.Errorf("Write to hijacked body: %v", err)
6999 }
7000 }), func(ts *httptest.Server) {
7001
7002 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
7003 }).ts
7004
7005 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
7006 if err != nil {
7007 t.Fatalf("net.Dial: %v", err)
7008 }
7009 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
7010 if err != nil {
7011 t.Fatalf("conn.Write: %v", err)
7012 }
7013 defer conn.Close()
7014
7015 r := bufio.NewReader(conn)
7016 res, err := ReadResponse(r, &Request{Method: "GET"})
7017 if err != nil {
7018 t.Fatal("ReadResponse error:", err)
7019 }
7020 if res.StatusCode != StatusSwitchingProtocols {
7021 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
7022 }
7023 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
7024 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
7025 }
7026 body, err := io.ReadAll(r)
7027 if err != nil {
7028 t.Error(err)
7029 }
7030 if string(body) != wantBody {
7031 t.Errorf("Response body = %q, want %q", string(body), wantBody)
7032 }
7033 }
7034
7035 func TestMuxRedirectRelative(t *testing.T) {
7036 setParallel(t)
7037 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
7038 if err != nil {
7039 t.Errorf("%s", err)
7040 }
7041 mux := NewServeMux()
7042 resp := httptest.NewRecorder()
7043 mux.ServeHTTP(resp, req)
7044 if got, want := resp.Header().Get("Location"), "/"; got != want {
7045 t.Errorf("Location header expected %q; got %q", want, got)
7046 }
7047 if got, want := resp.Code, StatusTemporaryRedirect; got != want {
7048 t.Errorf("Expected response code %d; got %d", want, got)
7049 }
7050 }
7051
7052
7053 func TestQuerySemicolon(t *testing.T) {
7054 t.Cleanup(func() { afterTest(t) })
7055
7056 tests := []struct {
7057 query string
7058 xNoSemicolons string
7059 xWithSemicolons string
7060 expectParseFormErr bool
7061 }{
7062 {"?a=1;x=bad&x=good", "good", "bad", true},
7063 {"?a=1;b=bad&x=good", "good", "good", true},
7064 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
7065 {"?a=1;x=good;x=bad", "", "good", true},
7066 }
7067
7068 run(t, func(t *testing.T, mode testMode) {
7069 for _, tt := range tests {
7070 t.Run(tt.query+"/allow=false", func(t *testing.T) {
7071 allowSemicolons := false
7072 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
7073 })
7074 t.Run(tt.query+"/allow=true", func(t *testing.T) {
7075 allowSemicolons, expectParseFormErr := true, false
7076 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
7077 })
7078 }
7079 })
7080 }
7081
7082 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
7083 writeBackX := func(w ResponseWriter, r *Request) {
7084 x := r.URL.Query().Get("x")
7085 if expectParseFormErr {
7086 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
7087 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
7088 }
7089 } else {
7090 if err := r.ParseForm(); err != nil {
7091 t.Errorf("expected no error from ParseForm, got %v", err)
7092 }
7093 }
7094 if got := r.FormValue("x"); x != got {
7095 t.Errorf("got %q from FormValue, want %q", got, x)
7096 }
7097 fmt.Fprintf(w, "%s", x)
7098 }
7099
7100 h := Handler(HandlerFunc(writeBackX))
7101 if allowSemicolons {
7102 h = AllowQuerySemicolons(h)
7103 }
7104
7105 logBuf := &strings.Builder{}
7106 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
7107 ts.Config.ErrorLog = log.New(logBuf, "", 0)
7108 }).ts
7109
7110 req, _ := NewRequest("GET", ts.URL+query, nil)
7111 res, err := ts.Client().Do(req)
7112 if err != nil {
7113 t.Fatal(err)
7114 }
7115 slurp, _ := io.ReadAll(res.Body)
7116 res.Body.Close()
7117 if got, want := res.StatusCode, 200; got != want {
7118 t.Errorf("Status = %d; want = %d", got, want)
7119 }
7120 if got, want := string(slurp), wantX; got != want {
7121 t.Errorf("Body = %q; want = %q", got, want)
7122 }
7123 }
7124
7125 func TestMaxBytesHandler(t *testing.T) {
7126
7127 defer afterTest(t)
7128
7129 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
7130 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
7131 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
7132 func(t *testing.T) {
7133 run(t, func(t *testing.T, mode testMode) {
7134 testMaxBytesHandler(t, mode, maxSize, requestSize)
7135 }, testNotParallel)
7136 })
7137 }
7138 }
7139 }
7140
7141 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
7142 runTimeSensitiveTest(t, []time.Duration{
7143 1 * time.Millisecond,
7144 5 * time.Millisecond,
7145 10 * time.Millisecond,
7146 50 * time.Millisecond,
7147 100 * time.Millisecond,
7148 500 * time.Millisecond,
7149 time.Second,
7150 5 * time.Second,
7151 }, func(t *testing.T, timeout time.Duration) error {
7152 SetRSTAvoidanceDelay(t, timeout)
7153 t.Logf("set RST avoidance delay to %v", timeout)
7154
7155 var (
7156 mu sync.Mutex
7157 handlerN int64
7158 handlerErr error
7159 )
7160 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
7161 mu.Lock()
7162 defer mu.Unlock()
7163 var buf bytes.Buffer
7164 handlerN, handlerErr = io.Copy(&buf, r.Body)
7165 io.Copy(w, &buf)
7166 })
7167
7168 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
7169
7170
7171 defer cst.close()
7172 ts := cst.ts
7173 c := ts.Client()
7174
7175 body := strings.Repeat("a", int(requestSize))
7176 var wg sync.WaitGroup
7177 defer wg.Wait()
7178 getBody := func() (io.ReadCloser, error) {
7179 wg.Add(1)
7180 body := &wgReadCloser{
7181 Reader: strings.NewReader(body),
7182 wg: &wg,
7183 }
7184 return body, nil
7185 }
7186 reqBody, _ := getBody()
7187 req, err := NewRequest("POST", ts.URL, reqBody)
7188 if err != nil {
7189 reqBody.Close()
7190 t.Fatal(err)
7191 }
7192 req.ContentLength = int64(len(body))
7193 req.GetBody = getBody
7194 req.Header.Set("Content-Type", "text/plain")
7195
7196 var buf strings.Builder
7197 res, err := c.Do(req)
7198 if err != nil {
7199 return fmt.Errorf("unexpected connection error: %v", err)
7200 } else {
7201 _, err = io.Copy(&buf, res.Body)
7202 res.Body.Close()
7203 if err != nil {
7204 return fmt.Errorf("unexpected read error: %v", err)
7205 }
7206 }
7207
7208
7209
7210
7211 mu.Lock()
7212 defer mu.Unlock()
7213 if handlerN > maxSize {
7214 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
7215 }
7216 if requestSize > maxSize && handlerErr == nil {
7217 t.Error("expected error on handler side; got nil")
7218 }
7219 if requestSize <= maxSize {
7220 if handlerErr != nil {
7221 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
7222 }
7223 if handlerN != requestSize {
7224 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
7225 }
7226 }
7227 if buf.Len() != int(handlerN) {
7228 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
7229 }
7230
7231 return nil
7232 })
7233 }
7234
7235 func TestEarlyHints(t *testing.T) {
7236 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7237 h := w.Header()
7238 h.Add("Link", "</style.css>; rel=preload; as=style")
7239 h.Add("Link", "</script.js>; rel=preload; as=script")
7240 w.WriteHeader(StatusEarlyHints)
7241
7242 h.Add("Link", "</foo.js>; rel=preload; as=script")
7243 w.WriteHeader(StatusEarlyHints)
7244
7245 w.Write([]byte("stuff"))
7246 }))
7247
7248 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7249 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7250 if !strings.Contains(got, expected) {
7251 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7252 }
7253 }
7254 func TestProcessing(t *testing.T) {
7255 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7256 w.WriteHeader(StatusProcessing)
7257 w.Write([]byte("stuff"))
7258 }))
7259
7260 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7261 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7262 if !strings.Contains(got, expected) {
7263 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7264 }
7265 }
7266
7267 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup, http3SkippedMode) }
7268 func testParseFormCleanup(t *testing.T, mode testMode) {
7269 if mode == http2Mode {
7270 t.Skip("https://go.dev/issue/20253")
7271 }
7272
7273 const maxMemory = 1024
7274 const key = "file"
7275
7276 if runtime.GOOS == "windows" {
7277
7278 t.Skip("https://go.dev/issue/25965")
7279 }
7280
7281 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7282 r.ParseMultipartForm(maxMemory)
7283 f, _, err := r.FormFile(key)
7284 if err != nil {
7285 t.Errorf("r.FormFile(%q) = %v", key, err)
7286 return
7287 }
7288 of, ok := f.(*os.File)
7289 if !ok {
7290 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7291 return
7292 }
7293 w.Write([]byte(of.Name()))
7294 }))
7295
7296 fBuf := new(bytes.Buffer)
7297 mw := multipart.NewWriter(fBuf)
7298 mf, err := mw.CreateFormFile(key, "myfile.txt")
7299 if err != nil {
7300 t.Fatal(err)
7301 }
7302 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7303 t.Fatal(err)
7304 }
7305 if err := mw.Close(); err != nil {
7306 t.Fatal(err)
7307 }
7308 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7309 if err != nil {
7310 t.Fatal(err)
7311 }
7312 req.Header.Set("Content-Type", mw.FormDataContentType())
7313 res, err := cst.c.Do(req)
7314 if err != nil {
7315 t.Fatal(err)
7316 }
7317 defer res.Body.Close()
7318 fname, err := io.ReadAll(res.Body)
7319 if err != nil {
7320 t.Fatal(err)
7321 }
7322 cst.close()
7323 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7324 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7325 }
7326 }
7327
7328 func TestHeadBody(t *testing.T) {
7329 const identityMode = false
7330 const chunkedMode = true
7331 run(t, func(t *testing.T, mode testMode) {
7332 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7333 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7334 })
7335 }
7336
7337 func TestGetBody(t *testing.T) {
7338 const identityMode = false
7339 const chunkedMode = true
7340 run(t, func(t *testing.T, mode testMode) {
7341 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7342 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7343 })
7344 }
7345
7346 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7347 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7348 b, err := io.ReadAll(r.Body)
7349 if err != nil {
7350 t.Errorf("server reading body: %v", err)
7351 return
7352 }
7353 w.Header().Set("X-Request-Body", string(b))
7354 w.Header().Set("Content-Length", "0")
7355 }))
7356 defer cst.close()
7357 for _, reqBody := range []string{
7358 "",
7359 "",
7360 "request_body",
7361 "",
7362 } {
7363 var bodyReader io.Reader
7364 if reqBody != "" {
7365 bodyReader = strings.NewReader(reqBody)
7366 if chunked {
7367 bodyReader = bufio.NewReader(bodyReader)
7368 }
7369 }
7370 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7371 if err != nil {
7372 t.Fatal(err)
7373 }
7374 res, err := cst.c.Do(req)
7375 if err != nil {
7376 t.Fatal(err)
7377 }
7378 res.Body.Close()
7379 if got, want := res.StatusCode, 200; got != want {
7380 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7381 }
7382 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7383 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7384 }
7385 }
7386 }
7387
7388
7389
7390 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7391 func testDisableContentLength(t *testing.T, mode testMode) {
7392 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7393 w.Header()["Content-Length"] = nil
7394 fmt.Fprintf(w, "OK")
7395 }))
7396
7397 res, err := noCL.c.Get(noCL.ts.URL)
7398 if err != nil {
7399 t.Fatal(err)
7400 }
7401 if got, haveCL := res.Header["Content-Length"]; haveCL {
7402 t.Errorf("Unexpected Content-Length: %q", got)
7403 }
7404 if err := res.Body.Close(); err != nil {
7405 t.Fatal(err)
7406 }
7407
7408 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7409 fmt.Fprintf(w, "OK")
7410 }))
7411
7412 res, err = withCL.c.Get(withCL.ts.URL)
7413 if err != nil {
7414 t.Fatal(err)
7415 }
7416
7417 if got := res.Header.Get("Content-Length"); got != "2" && mode != http3Mode {
7418 t.Errorf("Content-Length: %q; want 2", got)
7419 }
7420 if err := res.Body.Close(); err != nil {
7421 t.Fatal(err)
7422 }
7423 }
7424
7425 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7426 func testErrorContentLength(t *testing.T, mode testMode) {
7427 const errorBody = "an error occurred"
7428 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7429 w.Header().Set("Content-Length", "1000")
7430 Error(w, errorBody, 400)
7431 }))
7432 res, err := cst.c.Get(cst.ts.URL)
7433 if err != nil {
7434 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7435 }
7436 defer res.Body.Close()
7437 body, err := io.ReadAll(res.Body)
7438 if err != nil {
7439 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7440 }
7441 if string(body) != errorBody+"\n" {
7442 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7443 }
7444 }
7445
// TestError checks how Error treats headers that were set before it is
// called: Content-Length must be absent afterwards, Content-Type must be
// "text/plain; charset=utf-8", and X-Content-Type-Options must be "nosniff"
// regardless of the value set beforehand.
func TestError(t *testing.T) {
	rec := httptest.NewRecorder()
	preset := rec.Header()
	preset.Set("Content-Length", "1")
	preset.Set("X-Content-Type-Options", "scratch and sniff")
	preset.Set("Other", "foo")
	Error(rec, "oops", 432)

	after := rec.Header()

	// Headers that must have been removed.
	const removed = "Content-Length"
	if v, present := after[removed]; present {
		t.Errorf("%s: %q, want not present", removed, v)
	}

	// Headers that must have been overwritten.
	const wantType = "text/plain; charset=utf-8"
	if v := after.Get("Content-Type"); v != wantType {
		t.Errorf("Content-Type: %q, want %q", v, wantType)
	}
	const wantSniff = "nosniff"
	if v := after.Get("X-Content-Type-Options"); v != wantSniff {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, wantSniff)
	}
}
7466
7467 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7468 run(t, testServerReadAfterWriteHeader100Continue)
7469 }
7470 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7471 t.Skip("https://go.dev/issue/67555")
7472 body := []byte("body")
7473 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7474 w.WriteHeader(200)
7475 NewResponseController(w).Flush()
7476 io.ReadAll(r.Body)
7477 w.Write(body)
7478 }), func(tr *Transport) {
7479 tr.ExpectContinueTimeout = 24 * time.Hour
7480 })
7481
7482 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7483 req.Header.Set("Expect", "100-continue")
7484 res, err := cst.c.Do(req)
7485 if err != nil {
7486 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7487 }
7488 defer res.Body.Close()
7489 got, err := io.ReadAll(res.Body)
7490 if err != nil {
7491 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7492 }
7493 if !bytes.Equal(got, body) {
7494 t.Fatalf("response body = %q, want %q", got, body)
7495 }
7496 }
7497
7498 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7499 run(t, testServerReadAfterHandlerDone100Continue)
7500 }
7501 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7502 t.Skip("https://go.dev/issue/67555")
7503 readyc := make(chan struct{})
7504 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7505 go func() {
7506 <-readyc
7507 io.ReadAll(r.Body)
7508 <-readyc
7509 }()
7510 }), func(tr *Transport) {
7511 tr.ExpectContinueTimeout = 24 * time.Hour
7512 })
7513
7514 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7515 req.Header.Set("Expect", "100-continue")
7516 res, err := cst.c.Do(req)
7517 if err != nil {
7518 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7519 }
7520 res.Body.Close()
7521 readyc <- struct{}{}
7522 readyc <- struct{}{}
7523 }
7524
7525 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7526 run(t, testServerReadAfterHandlerAbort100Continue)
7527 }
7528 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7529 t.Skip("https://go.dev/issue/67555")
7530 readyc := make(chan struct{})
7531 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7532 go func() {
7533 <-readyc
7534 io.ReadAll(r.Body)
7535 <-readyc
7536 }()
7537 panic(ErrAbortHandler)
7538 }), func(tr *Transport) {
7539 tr.ExpectContinueTimeout = 24 * time.Hour
7540 })
7541
7542 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7543 req.Header.Set("Expect", "100-continue")
7544 res, err := cst.c.Do(req)
7545 if err == nil {
7546 res.Body.Close()
7547 }
7548 readyc <- struct{}{}
7549 readyc <- struct{}{}
7550 }
7551
7552 func TestInvalidChunkedBodies(t *testing.T) {
7553 for _, test := range []struct {
7554 name string
7555 b string
7556 }{{
7557 name: "bare LF in chunk size",
7558 b: "1\na\r\n0\r\n\r\n",
7559 }, {
7560 name: "bare LF at body end",
7561 b: "1\r\na\r\n0\r\n\n",
7562 }} {
7563 t.Run(test.name, func(t *testing.T) {
7564 reqc := make(chan error)
7565 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7566 got, err := io.ReadAll(r.Body)
7567 if err == nil {
7568 t.Logf("read body: %q", got)
7569 }
7570 reqc <- err
7571 })).ts
7572
7573 serverURL, err := url.Parse(ts.URL)
7574 if err != nil {
7575 t.Fatal(err)
7576 }
7577
7578 conn, err := net.Dial("tcp", serverURL.Host)
7579 if err != nil {
7580 t.Fatal(err)
7581 }
7582
7583 if _, err := conn.Write([]byte(
7584 "POST / HTTP/1.1\r\n" +
7585 "Host: localhost\r\n" +
7586 "Transfer-Encoding: chunked\r\n" +
7587 "Connection: close\r\n" +
7588 "\r\n" +
7589 test.b)); err != nil {
7590 t.Fatal(err)
7591 }
7592 conn.(*net.TCPConn).CloseWrite()
7593
7594 if err := <-reqc; err == nil {
7595 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7596 }
7597 })
7598 }
7599 }
7600
7601
7602 func TestServerTLSNextProtos(t *testing.T) {
7603 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7604 }
7605 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7606 CondSkipHTTP2(t)
7607
7608 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7609 if err != nil {
7610 t.Fatal(err)
7611 }
7612 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7613 if err != nil {
7614 t.Fatal(err)
7615 }
7616 certpool := x509.NewCertPool()
7617 certpool.AddCert(leafCert)
7618
7619 protos := new(Protocols)
7620 switch mode {
7621 case https1Mode:
7622 protos.SetHTTP1(true)
7623 case http2Mode:
7624 protos.SetHTTP2(true)
7625 }
7626
7627 wantNextProtos := []string{"http/1.1", "h2", "other"}
7628 nextProtos := slices.Clone(wantNextProtos)
7629
7630
7631 srv := &Server{
7632 TLSConfig: &tls.Config{
7633 Certificates: []tls.Certificate{cert},
7634 NextProtos: nextProtos,
7635 },
7636 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7637 Protocols: protos,
7638 }
7639 tr := &Transport{
7640 TLSClientConfig: &tls.Config{
7641 RootCAs: certpool,
7642 NextProtos: nextProtos,
7643 },
7644 Protocols: protos,
7645 }
7646
7647 listener := newLocalListener(t)
7648 srvc := make(chan error, 1)
7649 go func() {
7650 srvc <- srv.ServeTLS(listener, "", "")
7651 }()
7652 t.Cleanup(func() {
7653 srv.Close()
7654 <-srvc
7655 })
7656
7657 client := &Client{Transport: tr}
7658 resp, err := client.Get("https://" + listener.Addr().String())
7659 if err != nil {
7660 t.Fatal(err)
7661 }
7662 resp.Body.Close()
7663
7664 if !slices.Equal(nextProtos, wantNextProtos) {
7665 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7666 }
7667 }
7668
7669
7670
7671 func TestServerHTTP2Disabled(t *testing.T) {
7672 synctest.Test(t, func(t *testing.T) {
7673 li := fakeNetListen()
7674 srv := &Server{}
7675 srv.Protocols = new(Protocols)
7676 srv.Protocols.SetHTTP1(true)
7677 go srv.ServeTLS(li, "", "")
7678 synctest.Wait()
7679 srv.Shutdown(t.Context())
7680 })
7681 }
7682
View as plain text