Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 crand "crypto/rand"
16 "crypto/tls"
17 "crypto/x509"
18 "encoding/json"
19 "errors"
20 "fmt"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "testing/synctest"
47 "time"
48 )
49
// dummyAddr is a net.Addr whose network name and string form are both the
// underlying string value.
type dummyAddr string

// Network returns the address value itself.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address value itself.
func (a dummyAddr) String() string { return string(a) }

// oneConnListener is a net.Listener that hands out conn on the first Accept
// and reports io.EOF on every Accept after that.
type oneConnListener struct {
	conn net.Conn
}

// Accept returns the stored connection once, clearing it so that later calls
// fail with io.EOF.
func (l *oneConnListener) Accept() (net.Conn, error) {
	next := l.conn
	if next == nil {
		return nil, io.EOF
	}
	l.conn = nil
	return next, nil
}

// Close does nothing and always succeeds.
func (l *oneConnListener) Close() error { return nil }

// Addr reports a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// noopConn provides the address and deadline methods of net.Conn: fixed
// placeholder addresses and deadline setters that ignore their argument.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr              { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr             { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(time.Time) error      { return nil }
func (noopConn) SetReadDeadline(time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(time.Time) error { return nil }
89
90 type rwTestConn struct {
91 io.Reader
92 io.Writer
93 noopConn
94
95 closeFunc func() error
96 closec chan bool
97 }
98
99 func (c *rwTestConn) Close() error {
100 if c.closeFunc != nil {
101 return c.closeFunc()
102 }
103 select {
104 case c.closec <- true:
105 default:
106 }
107 return nil
108 }
109
110 type testConn struct {
111 readMu sync.Mutex
112 readBuf bytes.Buffer
113 writeBuf bytes.Buffer
114 closec chan bool
115 noopConn
116 }
117
118 func newTestConn() *testConn {
119 return &testConn{closec: make(chan bool, 1)}
120 }
121
122 func (c *testConn) Read(b []byte) (int, error) {
123 c.readMu.Lock()
124 defer c.readMu.Unlock()
125 return c.readBuf.Read(b)
126 }
127
128 func (c *testConn) Write(b []byte) (int, error) {
129 return c.writeBuf.Write(b)
130 }
131
132 func (c *testConn) Close() error {
133 select {
134 case c.closec <- true:
135 default:
136 }
137 return nil
138 }
139
140
141
// reqBytes converts a newline-delimited request into wire form. It trims
// surrounding whitespace, replaces every "\n" with "\r\n", and appends
// "\r\n\r\n" to terminate the header block.
func reqBytes(req string) []byte {
	trimmed := strings.TrimSpace(req)
	withCRLF := strings.ReplaceAll(trimmed, "\n", "\r\n")
	return []byte(withCRLF + "\r\n\r\n")
}
145
146 type handlerTest struct {
147 logbuf bytes.Buffer
148 handler Handler
149 }
150
151 func newHandlerTest(h Handler) handlerTest {
152 return handlerTest{handler: h}
153 }
154
155 func (ht *handlerTest) rawResponse(req string) string {
156 reqb := reqBytes(req)
157 var output strings.Builder
158 conn := &rwTestConn{
159 Reader: bytes.NewReader(reqb),
160 Writer: &output,
161 closec: make(chan bool, 1),
162 }
163 ln := &oneConnListener{conn: conn}
164 srv := &Server{
165 ErrorLog: log.New(&ht.logbuf, "", 0),
166 Handler: ht.handler,
167 }
168 go srv.Serve(ln)
169 <-conn.closec
170 return output.String()
171 }
172
173 func TestConsumingBodyOnNextConn(t *testing.T) {
174 t.Parallel()
175 defer afterTest(t)
176 conn := new(testConn)
177 for i := 0; i < 2; i++ {
178 conn.readBuf.Write([]byte(
179 "POST / HTTP/1.1\r\n" +
180 "Host: test\r\n" +
181 "Content-Length: 11\r\n" +
182 "\r\n" +
183 "foo=1&bar=1"))
184 }
185
186 reqNum := 0
187 ch := make(chan *Request)
188 servech := make(chan error)
189 listener := &oneConnListener{conn}
190 handler := func(res ResponseWriter, req *Request) {
191 reqNum++
192 ch <- req
193 }
194
195 go func() {
196 servech <- Serve(listener, HandlerFunc(handler))
197 }()
198
199 var req *Request
200 req = <-ch
201 if req == nil {
202 t.Fatal("Got nil first request.")
203 }
204 if req.Method != "POST" {
205 t.Errorf("For request #1's method, got %q; expected %q",
206 req.Method, "POST")
207 }
208
209 req = <-ch
210 if req == nil {
211 t.Fatal("Got nil first request.")
212 }
213 if req.Method != "POST" {
214 t.Errorf("For request #2's method, got %q; expected %q",
215 req.Method, "POST")
216 }
217
218 if serveerr := <-servech; serveerr != io.EOF {
219 t.Errorf("Serve returned %q; expected EOF", serveerr)
220 }
221 }
222
// stringHandler is a Handler that echoes its own value in the "Result"
// response header.
type stringHandler string

// ServeHTTP sets the "Result" header to s and writes nothing else.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
228
// handlers lists the ServeMux patterns registered by testHostHandlers. Each
// pattern is served by stringHandler(msg).
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests pairs request URLs with the value testHostHandlers expects. For a 200
// response that value is the "Result" header; for a 307 redirect it is the
// "Location" header.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Paths lacking the registered trailing slash; expected is a Location.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
254
255 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
256 func testHostHandlers(t *testing.T, mode testMode) {
257 mux := NewServeMux()
258 for _, h := range handlers {
259 mux.Handle(h.pattern, stringHandler(h.msg))
260 }
261 ts := newClientServerTest(t, mode, mux).ts
262
263 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
264 if err != nil {
265 t.Fatal(err)
266 }
267 defer conn.Close()
268 cc := httputil.NewClientConn(conn, nil)
269 for _, vt := range vtests {
270 var r *Response
271 var req Request
272 if req.URL, err = url.Parse(vt.url); err != nil {
273 t.Errorf("cannot parse url: %v", err)
274 continue
275 }
276 if err := cc.Write(&req); err != nil {
277 t.Errorf("writing request: %v", err)
278 continue
279 }
280 r, err := cc.Read(&req)
281 if err != nil {
282 t.Errorf("reading response: %v", err)
283 continue
284 }
285 switch r.StatusCode {
286 case StatusOK:
287 s := r.Header.Get("Result")
288 if s != vt.expected {
289 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
290 }
291 case StatusTemporaryRedirect:
292 s := r.Header.Get("Location")
293 if s != vt.expected {
294 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
295 }
296 default:
297 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
298 }
299 }
300 }
301
// serveMuxRegister is the pattern table that TestServeMuxHandler and
// TestServeMuxHandlerRedirects register on a fresh ServeMux.
var serveMuxRegister = []struct {
	pattern string
	h       Handler
}{
	{pattern: "/dir/", h: serve(200)},
	{pattern: "/search", h: serve(201)},
	{pattern: "codesearch.google.com/search", h: serve(202)},
	{pattern: "codesearch.google.com/", h: serve(203)},
	{pattern: "example.com/", h: HandlerFunc(checkQueryStringHandler)},
}

// serve returns a handler that only writes the given status code.
func serve(code int) HandlerFunc {
	return HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	})
}

// checkQueryStringHandler rebuilds the request URL with scheme "http", the
// request's Host and an empty query. It responds 200 when that equals
// "http://" followed by the original raw query, and 500 otherwise.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	rebuilt := *r.URL
	rebuilt.Scheme, rebuilt.Host, rebuilt.RawQuery = "http", r.Host, ""
	status := 500
	if rebuilt.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
334
// serveMuxTests gives, for requests routed through a ServeMux populated from
// serveMuxRegister, the expected status code and the pattern that
// ServeMux.Handler reports. Every row is distinct.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 307, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 307, "/search"},
	{"GET", "google.com", "/dir/..", 307, ""},
	{"GET", "google.com", "/dir/./file", 307, "/dir/"},

	// For CONNECT, "/dir" is still redirected to "/dir/". Paths containing
	// "." or ".." are matched as written rather than redirected.
	{"CONNECT", "google.com", "/dir", 307, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
370
371 func TestServeMuxHandler(t *testing.T) {
372 setParallel(t)
373 mux := NewServeMux()
374 for _, e := range serveMuxRegister {
375 mux.Handle(e.pattern, e.h)
376 }
377
378 for _, tt := range serveMuxTests {
379 r := &Request{
380 Method: tt.method,
381 Host: tt.host,
382 URL: &url.URL{
383 Path: tt.path,
384 },
385 }
386 h, pattern := mux.Handler(r)
387 rr := httptest.NewRecorder()
388 h.ServeHTTP(rr, r)
389 if pattern != tt.pattern || rr.Code != tt.code {
390 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
391 }
392 }
393 }
394
395
396 func TestServeMuxHandlerTrailingSlash(t *testing.T) {
397 setParallel(t)
398 mux := NewServeMux()
399 const original = "/{x}/"
400 mux.Handle(original, NotFoundHandler())
401 r, _ := NewRequest("POST", "/foo", nil)
402 _, p := mux.Handler(r)
403 if p != original {
404 t.Errorf("got %q, want %q", p, original)
405 }
406 }
407
408
409 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
410 setParallel(t)
411 defer func() {
412 if err := recover(); err == nil {
413 t.Error("expected call to mux.HandleFunc to panic")
414 }
415 }()
416 mux := NewServeMux()
417 mux.HandleFunc("/", nil)
418 }
419
// serveMuxTests2 drives TestServeMuxHandlerRedirects. Each row gives the
// expected final status code and whether a redirect is acceptable on the way.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
431
432
433
434 func TestServeMuxHandlerRedirects(t *testing.T) {
435 setParallel(t)
436 mux := NewServeMux()
437 for _, e := range serveMuxRegister {
438 mux.Handle(e.pattern, e.h)
439 }
440
441 for _, tt := range serveMuxTests2 {
442 tries := 1
443 turl := tt.url
444 for {
445 u, e := url.Parse(turl)
446 if e != nil {
447 t.Fatal(e)
448 }
449 r := &Request{
450 Method: tt.method,
451 Host: tt.host,
452 URL: u,
453 }
454 h, _ := mux.Handler(r)
455 rr := httptest.NewRecorder()
456 h.ServeHTTP(rr, r)
457 if rr.Code != 307 {
458 if rr.Code != tt.code {
459 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
460 }
461 break
462 }
463 if !tt.redirOk {
464 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
465 break
466 }
467 turl = rr.HeaderMap.Get("Location")
468 tries--
469 }
470 if tries < 0 {
471 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
472 }
473 }
474 }
475
476 func TestServeMuxHandlerRedirectPost(t *testing.T) {
477 setParallel(t)
478 mux := NewServeMux()
479 mux.HandleFunc("POST /test/", func(w ResponseWriter, r *Request) {
480 w.WriteHeader(200)
481 })
482
483 var code, retries int
484 startURL := "http://example.com/test"
485 reqURL := startURL
486 for retries = 0; retries <= 1; retries++ {
487 r := httptest.NewRequest("POST", reqURL, strings.NewReader("hello world"))
488 h, _ := mux.Handler(r)
489 rr := httptest.NewRecorder()
490 h.ServeHTTP(rr, r)
491 code = rr.Code
492 switch rr.Code {
493 case 307:
494 reqURL = rr.Result().Header.Get("Location")
495 continue
496 case 200:
497
498 default:
499 t.Errorf("unhandled response code: %v", rr.Code)
500 }
501 }
502 if code != 200 {
503 t.Errorf("POST %s = %d after %d retries, want = 200", startURL, code, retries)
504 }
505 }
506
507
508 func TestMuxRedirectLeadingSlashes(t *testing.T) {
509 setParallel(t)
510 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
511 for _, path := range paths {
512 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
513 if err != nil {
514 t.Errorf("%s", err)
515 }
516 mux := NewServeMux()
517 resp := httptest.NewRecorder()
518
519 mux.ServeHTTP(resp, req)
520
521 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
522 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
523 return
524 }
525
526 if code, expected := resp.Code, StatusTemporaryRedirect; code != expected {
527 t.Errorf("Expected response code of StatusPermanentRedirect; got %d", code)
528 return
529 }
530 }
531 }
532
533
534
535
536
537 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
538 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
539 }
540 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
541 writeBackQuery := func(w ResponseWriter, r *Request) {
542 fmt.Fprintf(w, "%s", r.URL.RawQuery)
543 }
544
545 mux := NewServeMux()
546 mux.HandleFunc("/testOne", writeBackQuery)
547 mux.HandleFunc("/testTwo/", writeBackQuery)
548 mux.HandleFunc("/testThree", writeBackQuery)
549 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
550 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
551 })
552
553 ts := newClientServerTest(t, mode, mux).ts
554
555 tests := [...]struct {
556 path string
557 method string
558 want string
559 statusOk bool
560 }{
561 0: {"/testOne?this=that", "GET", "this=that", true},
562 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
563 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
564 3: {"/testTwo?", "GET", "", true},
565 4: {"/testThree?foo", "GET", "foo", true},
566 5: {"/testThree/?foo", "GET", "foo:bar", true},
567 6: {"/testThree?foo", "CONNECT", "foo", true},
568 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
569
570
571 8: {"/testOne/foo/..?foo", "GET", "foo", true},
572 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
573 }
574
575 for i, tt := range tests {
576 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
577 res, err := ts.Client().Do(req)
578 if err != nil {
579 continue
580 }
581 slurp, _ := io.ReadAll(res.Body)
582 res.Body.Close()
583 if !tt.statusOk {
584 if got, want := res.StatusCode, 404; got != want {
585 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
586 }
587 }
588 if got, want := string(slurp), tt.want; got != want {
589 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
590 }
591 }
592 }
593
594 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
595 setParallel(t)
596
597 mux := NewServeMux()
598 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
599 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
600 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
601 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
602 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
603 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
604
605 tests := []struct {
606 method string
607 url string
608 code int
609 loc string
610 want string
611 }{
612 {"GET", "http://example.com/", 404, "", ""},
613 {"GET", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
614 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
615 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
616 {"GET", "http://example.com/pkg/baz", 307, "/pkg/baz/", ""},
617 {"GET", "http://example.com:3000/pkg/foo", 307, "/pkg/foo/", ""},
618 {"CONNECT", "http://example.com/", 404, "", ""},
619 {"CONNECT", "http://example.com:3000/", 404, "", ""},
620 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
621 {"CONNECT", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
622 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
623 {"CONNECT", "http://example.com:3000/pkg/baz", 307, "/pkg/baz/", ""},
624 {"CONNECT", "http://example.com:3000/pkg/connect", 307, "/pkg/connect/", ""},
625 }
626
627 for i, tt := range tests {
628 req, _ := NewRequest(tt.method, tt.url, nil)
629 w := httptest.NewRecorder()
630 mux.ServeHTTP(w, req)
631
632 if got, want := w.Code, tt.code; got != want {
633 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
634 }
635
636 if tt.code == 301 {
637 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
638 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
639 }
640 } else {
641 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
642 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
643 }
644 }
645 }
646 }
647
648
649
650
// TestMuxNoSlashRedirectWithTrailingSlash checks that a GET for "/" is neither
// served nor redirected by the single-segment wildcard pattern "/{x}/"; the
// mux answers 404.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
663
664
665
666
// TestMuxNoSlash405WithTrailingSlash checks that a method-restricted wildcard
// pattern "GET /{x}/" does not turn a GET for "/" into a redirect or a 405;
// the mux answers 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	if got, want := rec.Code, 404; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
679
680 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
681 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
682 mux := NewServeMux()
683 newClientServerTest(t, mode, mux)
684 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
685 }
686
687 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
688 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
689 func benchmarkServeMux(b *testing.B, runHandler bool) {
690 type test struct {
691 path string
692 code int
693 req *Request
694 }
695
696
697 var tests []test
698 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
699 for _, e := range endpoints {
700 for i := 200; i < 230; i++ {
701 p := fmt.Sprintf("/%s/%d/", e, i)
702 tests = append(tests, test{
703 path: p,
704 code: i,
705 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
706 })
707 }
708 }
709 mux := NewServeMux()
710 for _, tt := range tests {
711 mux.Handle(tt.path, serve(tt.code))
712 }
713
714 rw := httptest.NewRecorder()
715 b.ReportAllocs()
716 b.ResetTimer()
717 for i := 0; i < b.N; i++ {
718 for _, tt := range tests {
719 *rw = httptest.ResponseRecorder{}
720 h, pattern := mux.Handler(tt.req)
721 if runHandler {
722 h.ServeHTTP(rw, tt.req)
723 if pattern != tt.path || rw.Code != tt.code {
724 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
725 }
726 }
727 }
728 }
729 }
730
731 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
732 func testServerTimeouts(t *testing.T, mode testMode) {
733 runTimeSensitiveTest(t, []time.Duration{
734 10 * time.Millisecond,
735 50 * time.Millisecond,
736 100 * time.Millisecond,
737 500 * time.Millisecond,
738 1 * time.Second,
739 }, func(t *testing.T, timeout time.Duration) error {
740 return testServerTimeoutsWithTimeout(t, timeout, mode)
741 })
742 }
743
744 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
745 var reqNum atomic.Int32
746 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
747 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
748 }), func(ts *httptest.Server) {
749 ts.Config.ReadTimeout = timeout
750 ts.Config.WriteTimeout = timeout
751 })
752 defer cst.close()
753 ts := cst.ts
754
755
756 c := ts.Client()
757 r, err := c.Get(ts.URL)
758 if err != nil {
759 return fmt.Errorf("http Get #1: %v", err)
760 }
761 got, err := io.ReadAll(r.Body)
762 expected := "req=1"
763 if string(got) != expected || err != nil {
764 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
765 string(got), err, expected)
766 }
767
768
769 t1 := time.Now()
770 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
771 if err != nil {
772 return fmt.Errorf("Dial: %v", err)
773 }
774 buf := make([]byte, 1)
775 n, err := conn.Read(buf)
776 conn.Close()
777 latency := time.Since(t1)
778 if n != 0 || err != io.EOF {
779 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
780 }
781 minLatency := timeout / 5 * 4
782 if latency < minLatency {
783 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
784 }
785
786
787
788
789 r, err = c.Get(ts.URL)
790 if err != nil {
791 return fmt.Errorf("http Get #2: %v", err)
792 }
793 got, err = io.ReadAll(r.Body)
794 r.Body.Close()
795 expected = "req=2"
796 if string(got) != expected || err != nil {
797 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
798 }
799
800 if !testing.Short() {
801 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
802 if err != nil {
803 return fmt.Errorf("long Dial: %v", err)
804 }
805 defer conn.Close()
806 go io.Copy(io.Discard, conn)
807 for i := 0; i < 5; i++ {
808 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
809 if err != nil {
810 return fmt.Errorf("on write %d: %v", i, err)
811 }
812 time.Sleep(timeout / 2)
813 }
814 }
815 return nil
816 }
817
818 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
819 func testServerReadTimeout(t *testing.T, mode testMode) {
820 respBody := "response body"
821 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
822 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
823 _, err := io.Copy(io.Discard, req.Body)
824 if !errors.Is(err, os.ErrDeadlineExceeded) {
825 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
826 }
827 res.Write([]byte(respBody))
828 }), func(ts *httptest.Server) {
829 ts.Config.ReadHeaderTimeout = -1
830 ts.Config.ReadTimeout = timeout
831 t.Logf("Server.Config.ReadTimeout = %v", timeout)
832 })
833
834 var retries atomic.Int32
835 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
836 if retries.Add(1) != 1 {
837 return nil, errors.New("too many retries")
838 }
839 return nil, nil
840 }
841
842 pr, pw := io.Pipe()
843 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
844 if err != nil {
845 t.Logf("Get error, retrying: %v", err)
846 cst.close()
847 continue
848 }
849 defer res.Body.Close()
850 got, err := io.ReadAll(res.Body)
851 if string(got) != respBody || err != nil {
852 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
853 }
854 pw.Close()
855 break
856 }
857 }
858
859 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
860 func testServerNoReadTimeout(t *testing.T, mode testMode) {
861 reqBody := "Hello, Gophers!"
862 resBody := "Hi, Gophers!"
863 for _, timeout := range []time.Duration{0, -1} {
864 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
865 ctl := NewResponseController(res)
866 ctl.EnableFullDuplex()
867 res.WriteHeader(StatusOK)
868
869
870 if err := ctl.Flush(); err != nil {
871 t.Errorf("server flush response: %v", err)
872 return
873 }
874 got, err := io.ReadAll(req.Body)
875 if string(got) != reqBody || err != nil {
876 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
877 }
878 res.Write([]byte(resBody))
879 }), func(ts *httptest.Server) {
880 ts.Config.ReadTimeout = timeout
881 t.Logf("Server.Config.ReadTimeout = %d", timeout)
882 })
883
884 pr, pw := io.Pipe()
885 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
886 if err != nil {
887 t.Fatal(err)
888 }
889 defer res.Body.Close()
890
891
892 time.Sleep(10 * time.Millisecond)
893 pw.Write([]byte(reqBody))
894 pw.Close()
895
896 got, err := io.ReadAll(res.Body)
897 if string(got) != resBody || err != nil {
898 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
899 }
900 }
901 }
902
func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }

// testServerWriteTimeout checks what happens when Server.WriteTimeout expires
// while a handler streams an endless response:
//   - the handler's io.Copy fails with os.ErrDeadlineExceeded;
//   - the client sees an error reading the truncated body.
//
// The timeout starts at 5ms and doubles on every attempt that must be retried.
func testServerWriteTimeout(t *testing.T, mode testMode) {
	for timeout := 5 * time.Millisecond; ; timeout *= 2 {
		// The handler sends nil when it starts and then its io.Copy error.
		// Capacity 2 lets it send both without blocking.
		errc := make(chan error, 2)
		cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
			errc <- nil
			_, err := io.Copy(res, neverEnding('a'))
			errc <- err
		}), func(ts *httptest.Server) {
			ts.Config.WriteTimeout = timeout
			t.Logf("Server.Config.WriteTimeout = %v", timeout)
		})

		// Forbid transport-internal retries: the Proxy hook fails every
		// attempt after the first. A connection that dies before the request
		// is sent then surfaces as a Get error below, and the loop retries
		// with a longer timeout.
		// NOTE(review): presumably the WriteTimeout can also expire during
		// connection setup (e.g. a TLS handshake). The transport might
		// otherwise keep retrying that silently — confirm.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			// Most likely the write timeout expired before the handler ran.
			t.Logf("Get error, retrying: %v", err)
			cst.close()
			continue
		}
		defer res.Body.Close()
		_, err = io.Copy(io.Discard, res.Body)
		if err == nil {
			t.Errorf("client reading from truncated request body: got nil error, want non-nil")
		}
		select {
		case <-errc:
			// The handler started; its second value is the io.Copy error.
			err = <-errc
			if !errors.Is(err, os.ErrDeadlineExceeded) {
				t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
			}
			return
		default:
			// Nothing in errc: the handler never started, so retry with a
			// longer timeout.
			t.Logf("handler didn't run, retrying")
			cst.close()
		}
	}
}
967
968 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
969 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
970 for _, timeout := range []time.Duration{0, -1} {
971 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
972 _, err := io.Copy(res, neverEnding('a'))
973 t.Logf("server write response: %v", err)
974 }), func(ts *httptest.Server) {
975 ts.Config.WriteTimeout = timeout
976 t.Logf("Server.Config.WriteTimeout = %d", timeout)
977 })
978
979 res, err := cst.c.Get(cst.ts.URL)
980 if err != nil {
981 t.Fatal(err)
982 }
983 defer res.Body.Close()
984 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
985 if n != 1<<20 || err != nil {
986 t.Errorf("client read response body: %d, %v", n, err)
987 }
988
989
990 res.Body.Close()
991 cst.ts.Config.Shutdown(context.Background())
992 }
993 }
994
995
996 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
997 run(t, testWriteDeadlineExtendedOnNewRequest)
998 }
999 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
1000 if testing.Short() {
1001 t.Skip("skipping in short mode")
1002 }
1003 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
1004 func(ts *httptest.Server) {
1005 ts.Config.WriteTimeout = 250 * time.Millisecond
1006 },
1007 ).ts
1008
1009 c := ts.Client()
1010
1011 for i := 1; i <= 3; i++ {
1012 req, err := NewRequest("GET", ts.URL, nil)
1013 if err != nil {
1014 t.Fatal(err)
1015 }
1016
1017 r, err := c.Do(req)
1018 if err != nil {
1019 t.Fatalf("http2 Get #%d: %v", i, err)
1020 }
1021 r.Body.Close()
1022 time.Sleep(ts.Config.WriteTimeout / 2)
1023 }
1024 }
1025
1026
1027
// tryTimeouts calls testFunc with timeouts of 250ms, 500ms and 1s in turn and
// stops at the first nil error. Each failure is logged along with the next
// timeout to try. If every attempt fails, the test is stopped with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i, timeout := range tries {
		if err := testFunc(timeout); err != nil {
			t.Logf("failed at %v: %v", timeout, err)
			if next := i + 1; next < len(tries) {
				t.Logf("retrying at %v ...", tries[next])
			}
			continue
		}
		return
	}
	t.Fatal("all attempts failed")
}
1042
1043
1044 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1045 if testing.Short() {
1046 t.Skip("skipping in short mode")
1047 }
1048 setParallel(t)
1049 run(t, func(t *testing.T, mode testMode) {
1050 tryTimeouts(t, func(timeout time.Duration) error {
1051 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1052 })
1053 })
1054 }
1055
1056 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1057 firstRequest := make(chan bool, 1)
1058 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1059 select {
1060 case firstRequest <- true:
1061
1062 default:
1063
1064 time.Sleep(timeout)
1065 }
1066 }), func(ts *httptest.Server) {
1067 ts.Config.WriteTimeout = timeout / 2
1068 })
1069 defer cst.close()
1070 ts := cst.ts
1071
1072 c := ts.Client()
1073
1074 req, err := NewRequest("GET", ts.URL, nil)
1075 if err != nil {
1076 return fmt.Errorf("NewRequest: %v", err)
1077 }
1078 r, err := c.Do(req)
1079 if err != nil {
1080 return fmt.Errorf("Get #1: %v", err)
1081 }
1082 r.Body.Close()
1083
1084 req, err = NewRequest("GET", ts.URL, nil)
1085 if err != nil {
1086 return fmt.Errorf("NewRequest: %v", err)
1087 }
1088 r, err = c.Do(req)
1089 if err == nil {
1090 r.Body.Close()
1091 return fmt.Errorf("Get #2 expected error, got nil")
1092 }
1093 if mode == http2Mode {
1094 expected := "stream ID 3; INTERNAL_ERROR"
1095 if !strings.Contains(err.Error(), expected) {
1096 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1097 }
1098 }
1099 return nil
1100 }
1101
1102
1103 func TestNoWriteDeadline(t *testing.T) {
1104 if testing.Short() {
1105 t.Skip("skipping in short mode")
1106 }
1107 setParallel(t)
1108 defer afterTest(t)
1109 run(t, func(t *testing.T, mode testMode) {
1110 tryTimeouts(t, func(timeout time.Duration) error {
1111 return testNoWriteDeadline(t, mode, timeout)
1112 })
1113 })
1114 }
1115
1116 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1117 firstRequest := make(chan bool, 1)
1118 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1119 select {
1120 case firstRequest <- true:
1121
1122 default:
1123
1124 time.Sleep(timeout)
1125 }
1126 }))
1127 defer cst.close()
1128 ts := cst.ts
1129
1130 c := ts.Client()
1131
1132 for i := 0; i < 2; i++ {
1133 req, err := NewRequest("GET", ts.URL, nil)
1134 if err != nil {
1135 return fmt.Errorf("NewRequest: %v", err)
1136 }
1137 r, err := c.Do(req)
1138 if err != nil {
1139 return fmt.Errorf("Get #%d: %v", i, err)
1140 }
1141 r.Body.Close()
1142 }
1143 return nil
1144 }
1145
1146
1147
1148
// TestOnlyWriteTimeout verifies two things. After the server connection's
// write deadline is moved into the past, further ResponseWriter writes fail.
// The client also sees an error while reading the response body.
func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
func testOnlyWriteTimeout(t *testing.T, mode testMode) {
	var (
		mu   sync.RWMutex
		conn net.Conn
	)
	afterTimeoutErrc := make(chan error, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
		// Always deliver a result to the test goroutine, even when the
		// handler bails out early. Otherwise the receive on afterTimeoutErrc
		// below would block forever and hang the test instead of failing it.
		afterErr := errors.New("handler returned before the post-deadline write")
		defer func() {
			select {
			case afterTimeoutErrc <- afterErr:
			default:
				// A result was already delivered; never block the server.
			}
		}()

		buf := make([]byte, 512<<10)
		if _, err := w.Write(buf); err != nil {
			t.Errorf("handler Write error: %v", err)
			return
		}
		mu.RLock()
		defer mu.RUnlock()
		if conn == nil {
			t.Error("no established connection found")
			return
		}
		// Put the write deadline in the past so the next write must fail.
		conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
		_, afterErr = w.Write(buf)
	}), func(ts *httptest.Server) {
		ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
	}).ts

	c := ts.Client()

	err := func() error {
		res, err := c.Get(ts.URL)
		if err != nil {
			return err
		}
		_, err = io.Copy(io.Discard, res.Body)
		res.Body.Close()
		return err
	}()
	if err == nil {
		t.Errorf("expected an error copying body from Get request")
	}

	if err := <-afterTimeoutErrc; err == nil {
		t.Error("expected write error after timeout")
	}
}
1195
1196
1197 type trackLastConnListener struct {
1198 net.Listener
1199
1200 mu *sync.RWMutex
1201 last *net.Conn
1202 }
1203
1204 func (l trackLastConnListener) Accept() (c net.Conn, err error) {
1205 c, err = l.Listener.Accept()
1206 if err == nil {
1207 l.mu.Lock()
1208 *l.last = c
1209 l.mu.Unlock()
1210 }
1211 return
1212 }
1213
1214
// TestIdentityResponse checks how the server handles a Content-Length set by
// the handler together with a Transfer-Encoding set by the handler. It also
// checks what happens when the handler writes more or fewer bytes than it
// declared.
func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
func testIdentityResponse(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56019")
	}

	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Length", "3")
		rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
		switch {
		case req.FormValue("overwrite") == "1":
			// Writing more than the declared 3 bytes must fail.
			_, err := rw.Write([]byte("foo TOO LONG"))
			if err != ErrContentLength {
				t.Errorf("expected ErrContentLength; got %v", err)
			}
		case req.FormValue("underwrite") == "1":
			// Declare 500 bytes but write only 9.
			rw.Header().Set("Content-Length", "500")
			rw.Write([]byte("too short"))
		default:
			rw.Write([]byte("foo"))
		}
	})

	ts := newClientServerTest(t, mode, handler).ts
	c := ts.Client()

	// Whether the handler sets an empty or an "identity" Transfer-Encoding,
	// the client must see Content-Length: 3 and no transfer encodings.
	for _, te := range []string{"", "identity"} {
		url := ts.URL + "/?te=" + te
		res, err := c.Get(url)
		if err != nil {
			t.Fatalf("error with Get of %s: %v", url, err)
		}
		if cl, expected := res.ContentLength, int64(3); cl != expected {
			t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
		}
		if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
			t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
		}
		if tl, expected := len(res.TransferEncoding), 0; tl != expected {
			t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
				url, expected, tl, res.TransferEncoding)
		}
		res.Body.Close()
	}

	// Drive the overwrite path; the handler itself asserts ErrContentLength.
	url := ts.URL + "/?overwrite=1"
	res, err := c.Get(url)
	if err != nil {
		t.Fatalf("error with Get of %s: %v", url, err)
	}
	res.Body.Close()

	if mode != http1Mode {
		return
	}

	// Use a raw connection to see exactly what the server sends when the
	// handler writes fewer bytes than its declared Content-Length.
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error dialing: %v", err)
	}
	defer conn.Close()
	_, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
	if err != nil {
		t.Fatalf("error writing: %v", err)
	}

	// ReadAll returns at EOF or on error. The error is deliberately ignored
	// because only the received bytes matter. No deadline is set, so a server
	// that never closes the connection makes this hang.
	got, _ := io.ReadAll(conn)
	expectedSuffix := "\r\n\r\ntoo short"
	if !strings.HasSuffix(string(got), expectedSuffix) {
		t.Errorf("Expected output to end with %q; got response body %q",
			expectedSuffix, string(got))
	}
}
1295
// testTCPConnectionCloses writes the raw request req to an HTTP/1 test server
// running handler h. It reads one response, then drains everything after it
// until EOF. It fails unless the response was marked Close.
func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
	setParallel(t)
	srv := newClientServerTest(t, http1Mode, h).ts

	conn, err := net.Dial("tcp", srv.Listener.Addr().String())
	if err != nil {
		t.Fatal("dial error:", err)
	}
	defer conn.Close()

	if _, err := fmt.Fprint(conn, req); err != nil {
		t.Fatal("print error:", err)
	}

	br := bufio.NewReader(conn)
	res, err := ReadResponse(br, &Request{Method: "GET"})
	if err != nil {
		t.Fatal("ReadResponse error:", err)
	}

	// Drain the rest of the stream; ReadAll returns at EOF or on error.
	if _, err := io.ReadAll(br); err != nil {
		t.Fatal("read error:", err)
	}

	if !res.Close {
		t.Errorf("Response.Close = false; want true")
	}
}
1326
// testTCPConnectionStaysOpen sends req twice over one TCP connection to an
// HTTP/1 test server running handler, reading and draining each response.
// It fails if either round trip cannot complete on that connection.
func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
	setParallel(t)
	srv := newClientServerTest(t, http1Mode, handler).ts
	conn, err := net.Dial("tcp", srv.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	br := bufio.NewReader(conn)
	for n := 1; n <= 2; n++ {
		if _, err := io.WriteString(conn, req); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(br, nil)
		if err != nil {
			t.Fatalf("res %d: %v", n, err)
		}
		_, copyErr := io.Copy(io.Discard, res.Body)
		if copyErr != nil {
			t.Fatalf("res %d body copy: %v", n, copyErr)
		}
		res.Body.Close()
	}
}
1350
1351
1352 func TestServeHTTP10Close(t *testing.T) {
1353 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1354 ServeFile(w, r, "testdata/file")
1355 }))
1356 }
1357
1358
1359 func TestClientCanClose(t *testing.T) {
1360 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1361
1362 }))
1363 }
1364
1365
1366
1367 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1368 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1369 w.Header().Set("Connection", "close")
1370 }))
1371 }
1372
1373 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1374 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1375 w.Header().Set("Connection", "close")
1376 }))
1377 }
1378
1379 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1380 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1381
1382
1383 }))
1384 }
1385
1386 func send204(w ResponseWriter, r *Request) { w.WriteHeader(204) }
1387 func send304(w ResponseWriter, r *Request) { w.WriteHeader(304) }
1388
1389
1390 func TestHTTP10KeepAlive204Response(t *testing.T) {
1391 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1392 }
1393
1394 func TestHTTP11KeepAlive204Response(t *testing.T) {
1395 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1396 }
1397
1398 func TestHTTP10KeepAlive304Response(t *testing.T) {
1399 testTCPConnectionStaysOpen(t,
1400 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1401 HandlerFunc(send304))
1402 }
1403
1404
// TestKeepAliveFinalChunkWithEOF verifies that two sequential GETs are served
// on the same connection when the handler flushes before writing its body.
// The handler echoes r.RemoteAddr as JSON, and both replies must match.
func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Flush before writing so the body length is not known up front
		// (presumably forcing a chunked response over HTTP/1.1).
		w.(Flusher).Flush()
		w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
	}))
	type data struct {
		Addr string
	}
	var seen [2]data
	for i := range seen {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		var d data
		if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
			t.Fatal(err)
		}
		if d.Addr == "" {
			t.Fatal("no address")
		}
		seen[i] = d
		res.Body.Close()
	}
	if first, second := seen[0], seen[1]; first != second {
		t.Fatalf("connection not reused")
	}
}
1432
// TestSetsRemoteAddr verifies that Request.RemoteAddr is set to the client's
// loopback address (IPv4 or IPv6) including a port.
func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
func testSetsRemoteAddr(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%s", r.RemoteAddr)
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	// Response bodies must always be closed, including on the failure paths below.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	ip := string(body)
	if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
		t.Fatalf("Expected local addr; got %q", ip)
	}
}
1452
1453 type blockingRemoteAddrListener struct {
1454 net.Listener
1455 conns chan<- net.Conn
1456 }
1457
1458 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1459 c, err := l.Listener.Accept()
1460 if err != nil {
1461 return nil, err
1462 }
1463 brac := &blockingRemoteAddrConn{
1464 Conn: c,
1465 addrs: make(chan net.Addr, 1),
1466 }
1467 l.conns <- brac
1468 return brac, nil
1469 }
1470
1471 type blockingRemoteAddrConn struct {
1472 net.Conn
1473 addrs chan net.Addr
1474 }
1475
1476 func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
1477 return <-c.addrs
1478 }
1479
1480
1481 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1482 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1483 }
1484 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1485 conns := make(chan net.Conn)
1486 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1487 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1488 }), func(ts *httptest.Server) {
1489 ts.Listener = &blockingRemoteAddrListener{
1490 Listener: ts.Listener,
1491 conns: conns,
1492 }
1493 }).ts
1494
1495 c := ts.Client()
1496
1497 c.Transport.(*Transport).DisableKeepAlives = true
1498
1499 fetch := func(num int, response chan<- string) {
1500 resp, err := c.Get(ts.URL)
1501 if err != nil {
1502 t.Errorf("Request %d: %v", num, err)
1503 response <- ""
1504 return
1505 }
1506 defer resp.Body.Close()
1507 body, err := io.ReadAll(resp.Body)
1508 if err != nil {
1509 t.Errorf("Request %d: %v", num, err)
1510 response <- ""
1511 return
1512 }
1513 response <- string(body)
1514 }
1515
1516
1517 response1c := make(chan string, 1)
1518 go fetch(1, response1c)
1519
1520
1521 conn1 := <-conns
1522
1523
1524 response2c := make(chan string, 1)
1525 go fetch(2, response2c)
1526 conn2 := <-conns
1527
1528
1529 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1530 IP: net.ParseIP("12.12.12.12"), Port: 12}
1531
1532
1533 response2 := <-response2c
1534 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1535 t.Fatalf("response 2 addr = %q; want %q", g, e)
1536 }
1537
1538
1539 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1540 IP: net.ParseIP("21.21.21.21"), Port: 21}
1541
1542
1543 response1 := <-response1c
1544 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1545 t.Fatalf("response 1 addr = %q; want %q", g, e)
1546 }
1547 }
1548
1549
1550
// TestHeadResponses verifies how a HEAD request is handled. The handler's
// Write and io.Copy calls succeed, and the response uses no Transfer-Encoding.
// It reports the Content-Type and Content-Length (10 bytes) those writes
// imply, and carries no body.
func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
func testHeadResponses(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("<html>"))
		if err != nil {
			t.Errorf("ResponseWriter.Write: %v", err)
		}

		// Also exercise the io.Copy path. Wrapping the reader in an anonymous
		// struct exposes only its Read method, which hides strings.Reader's
		// WriteTo.
		_, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
		if err != nil {
			t.Errorf("Copy(ResponseWriter, ...): %v", err)
		}
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		// Fatal, not Error: res is nil here and is dereferenced below.
		t.Fatal(err)
	}
	defer res.Body.Close()
	if len(res.TransferEncoding) > 0 {
		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
	}
	if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
		t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
	}
	if v := res.ContentLength; v != 10 {
		t.Errorf("Content-Length: %d; want 10", v)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Error(err)
	}
	if len(body) > 0 {
		t.Errorf("got unexpected body %q", string(body))
	}
}
1586
1587
1588
// TestHeadReaderFrom sends a HEAD request and then a GET to a handler that
// writes its body via ResponseWriter's ReadFrom. The GET must still receive
// exactly the 4096-byte body. Presumably this catches a HEAD response that
// wrongly emits body bytes onto the reused connection.
func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
func testHeadReaderFrom(t *testing.T, mode testMode) {
	wantBody := strings.Repeat("a", 4096)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		rf := w.(io.ReaderFrom)
		rf.ReadFrom(strings.NewReader(wantBody))
	}))

	headRes, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	headRes.Body.Close()

	getRes, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	gotBody, err := io.ReadAll(getRes.Body)
	getRes.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	if got := string(gotBody); got != wantBody {
		t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
	}
}
1614
1615
1616
1617 func TestReaderFromTooLong(t *testing.T) { run(t, testReaderFromTooLong, []testMode{http1Mode}) }
1618 func testReaderFromTooLong(t *testing.T, mode testMode) {
1619 contentLen := 600
1620 tests := []struct {
1621 name string
1622 reader io.Reader
1623 wantHandlerErr error
1624 }{
1625 {
1626 name: "reader of correct length",
1627 reader: strings.NewReader(strings.Repeat("a", contentLen)),
1628 },
1629 {
1630 name: "wrapped reader of correct outer length",
1631 reader: io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(contentLen)),
1632 },
1633 {
1634 name: "wrapped reader of correct inner length",
1635 reader: io.LimitReader(io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(contentLen)), int64(2*contentLen)),
1636 },
1637 {
1638 name: "reader that is too long",
1639 reader: strings.NewReader(strings.Repeat("a", 2*contentLen)),
1640 wantHandlerErr: ErrContentLength,
1641 },
1642 {
1643 name: "wrapped reader that is too long",
1644 reader: io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(2*contentLen)),
1645 wantHandlerErr: ErrContentLength,
1646 },
1647 }
1648
1649 for _, tc := range tests {
1650 t.Run(tc.name, func(t *testing.T) {
1651 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1652 w.Header().Set("Content-Length", strconv.Itoa(contentLen))
1653 n, err := w.(io.ReaderFrom).ReadFrom(tc.reader)
1654 if int(n) != contentLen || !errors.Is(err, tc.wantHandlerErr) {
1655 t.Errorf("got %v, %v from w.ReadFrom; want %v, %v", n, err, contentLen, tc.wantHandlerErr)
1656 }
1657 }))
1658 res, err := cst.c.Get(cst.ts.URL)
1659 if err != nil {
1660 t.Fatal(err)
1661 }
1662 defer res.Body.Close()
1663 gotBody, err := io.ReadAll(res.Body)
1664 if err != nil {
1665 t.Fatal(err)
1666 }
1667 if len(gotBody) != contentLen {
1668 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), contentLen)
1669 }
1670 })
1671 }
1672 }
1673
// TestTLSHandshakeTimeout verifies that a TLS server with a 250ms ReadTimeout
// drops a client that connects but never starts a handshake. The server must
// log an error mentioning "timeout" or "TLS handshake".
func TestTLSHandshakeTimeout(t *testing.T) {
	run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
}
func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
	var errLog strings.Builder
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
		func(ts *httptest.Server) {
			ts.Config.ReadTimeout = 250 * time.Millisecond
			ts.Config.ErrorLog = log.New(&errLog, "", 0)
		},
	)

	conn, err := net.Dial("tcp", cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Send nothing and wait. The server should give up without writing any
	// bytes, so the read must fail with n == 0.
	buf := make([]byte, 1)
	if n, err := conn.Read(buf); err == nil || n != 0 {
		t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
	}
	conn.Close()

	// Close the server before inspecting errLog (presumably so all of its
	// logging has finished).
	cst.close()
	logged := errLog.String()
	if !strings.Contains(logged, "timeout") && !strings.Contains(logged, "TLS handshake") {
		t.Errorf("expected a TLS handshake timeout error; got %q", logged)
	}
}
1703
// TestTLSServer verifies that a TLS test server populates Request.TLS with a
// completed handshake. The check runs while an unrelated idle TCP connection
// is held open against the same server.
func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
func testTLSServer(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.TLS == nil {
			return
		}
		w.Header().Set("X-TLS-Set", "true")
		if r.TLS.HandshakeComplete {
			w.Header().Set("X-TLS-HandshakeComplete", "true")
		}
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Hold open a raw TCP connection that never sends a handshake, presumably
	// to check that it does not stop the real TLS client below from being
	// served.
	idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer idleConn.Close()

	if !strings.HasPrefix(ts.URL, "https://") {
		t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
		return
	}
	res, err := ts.Client().Get(ts.URL)
	switch {
	case err != nil:
		t.Error(err)
		return
	case res == nil:
		t.Errorf("got nil Response")
		return
	}
	defer res.Body.Close()
	if res.Header.Get("X-TLS-Set") != "true" {
		t.Errorf("expected X-TLS-Set response header")
		return
	}
	if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
		t.Errorf("expected X-TLS-HandshakeComplete header")
	}
}
1751
1752 type fakeConnectionStateConn struct {
1753 net.Conn
1754 }
1755
1756 func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
1757 return tls.ConnectionState{
1758 ServerName: "example.com",
1759 }
1760 }
1761
// TestTLSServerWithoutTLSConn verifies that the server fills in Request.TLS
// from a connection that is not a *tls.Conn but has a ConnectionState method
// (here fakeConnectionStateConn, reporting ServerName "example.com").
func TestTLSServerWithoutTLSConn(t *testing.T) {
	// The server side of the pipe reports fake TLS state.
	pr, pw := net.Pipe()
	c := make(chan int)
	listener := &oneConnListener{&fakeConnectionStateConn{pr}}
	server := &Server{
		Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
			// This runs on a server goroutine, where t.Fatal must not be
			// called; report with t.Error and return instead.
			if request.TLS == nil {
				t.Error("request.TLS is nil, expected not nil")
				return
			}
			if request.TLS.ServerName != "example.com" {
				t.Errorf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
				return
			}
			writer.Header().Set("X-TLS-ServerName", "example.com")
		}),
	}

	// Act as the client: write one request and read its response. Signal c
	// and close the client end on every exit path, so the test never waits
	// forever.
	go func() {
		defer pw.Close()
		defer close(c)
		req, err := NewRequest(MethodGet, "https://example.com", nil)
		if err != nil {
			t.Errorf("NewRequest: %v", err)
			return
		}
		if err := req.Write(pw); err != nil {
			t.Errorf("writing request: %v", err)
			return
		}
		resp, err := ReadResponse(bufio.NewReader(pw), req)
		if err != nil {
			t.Errorf("ReadResponse: %v", err)
			return
		}
		defer resp.Body.Close()
		if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
			t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
		}
	}()

	// oneConnListener fails every Accept after the first, so Serve returns
	// without waiting for the connection to finish.
	server.Serve(listener)

	// Wait until the client goroutine is done with the response.
	<-c
	pr.Close()
}
1798
// TestServeTLS verifies that Server.ServeTLS, given a TLSConfig that already
// holds the certificate (with empty certFile/keyFile), serves TLS on the
// listener. It must negotiate "h2" with a client offering "h2" and "http/1.1".
func TestServeTLS(t *testing.T) {
	CondSkipHTTP2(t)
	defer afterTest(t)
	defer SetTestHookServerServe(nil)

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}

	ln := newLocalListener(t)
	defer ln.Close()

	serving := make(chan bool, 1)
	SetTestHookServerServe(func(s *Server, ln net.Listener) {
		serving <- true
	})
	srv := &Server{
		Addr:      ln.Addr().String(),
		TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
		Handler:   HandlerFunc(func(w ResponseWriter, r *Request) {}),
	}

	// Wait until ServeTLS either fails or reaches the serve hook before dialing.
	errc := make(chan error, 1)
	go func() { errc <- srv.ServeTLS(ln, "", "") }()
	select {
	case err := <-errc:
		t.Fatalf("ServeTLS: %v", err)
	case <-serving:
	}

	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	state := c.ConnectionState()
	if got, want := state.NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := state.NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1850
1851
// TestTLSServerRejectHTTPRequests verifies that a TLS server answers a
// plaintext HTTP/1.1 request with "HTTP/1.0 400 Bad Request" and never runs
// the handler.
func TestTLSServerRejectHTTPRequests(t *testing.T) {
	run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
}
func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Error("unexpected HTTPS request")
	}), func(ts *httptest.Server) {
		// Send the server's error log to a throwaway buffer.
		ts.Config.ErrorLog = log.New(new(bytes.Buffer), "", 0)
	}).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	// Speak plaintext HTTP/1.1 to the TLS port and read until the server hangs up.
	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	slurp, err := io.ReadAll(conn)
	if err != nil {
		t.Fatal(err)
	}
	const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
	if got := string(slurp); !strings.HasPrefix(got, wantPrefix) {
		t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
	}
}
1877
1878
1879 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1880 testAutomaticHTTP2_Serve(t, nil, true)
1881 }
1882
1883 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1884 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1885 }
1886
1887 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1888 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1889 }
1890
1891 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1892 setParallel(t)
1893 defer afterTest(t)
1894 ln := newLocalListener(t)
1895 ln.Close()
1896 var s Server
1897 s.TLSConfig = tlsConf
1898 if err := s.Serve(ln); err == nil {
1899 t.Fatal("expected an error")
1900 }
1901 gotH2 := s.TLSNextProto["h2"] != nil
1902 if gotH2 != wantH2 {
1903 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1904 }
1905 }
1906
1907 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1908 setParallel(t)
1909 defer afterTest(t)
1910 ln := newLocalListener(t)
1911 ln.Close()
1912 var s Server
1913
1914
1915 s.TLSConfig = &tls.Config{
1916 NextProtos: []string{"h2"},
1917 }
1918 if err := s.Serve(ln); err == nil {
1919 t.Fatal("expected an error")
1920 }
1921 on := s.TLSNextProto["h2"] != nil
1922 if !on {
1923 t.Errorf("http2 wasn't automatically enabled")
1924 }
1925 }
1926
// TestAutomaticHTTP2_ListenAndServe checks "h2" negotiation when the
// certificate is given in TLSConfig.Certificates.
func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	conf := &tls.Config{Certificates: []tls.Certificate{cert}}
	testAutomaticHTTP2_ListenAndServe(t, conf)
}

// TestAutomaticHTTP2_ListenAndServe_GetCertificate checks "h2" negotiation
// when the certificate comes from a GetCertificate callback.
func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	getCert := func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		return &cert, nil
	}
	testAutomaticHTTP2_ListenAndServe(t, &tls.Config{GetCertificate: getCert})
}

// TestAutomaticHTTP2_ListenAndServe_GetConfigForClient checks "h2"
// negotiation when the whole per-client config comes from a
// GetConfigForClient callback.
func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	// The returned config lists "h2" itself. Presumably the server does not
	// add it to configs supplied through this callback.
	perClient := &tls.Config{
		NextProtos:   []string{"h2"},
		Certificates: []tls.Certificate{cert},
	}
	getConfig := func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
		return perClient, nil
	}
	testAutomaticHTTP2_ListenAndServe(t, &tls.Config{GetConfigForClient: getConfig})
}
1966
// testAutomaticHTTP2_ListenAndServe starts a Server with tlsConf via
// ListenAndServeTLS("", "") on a freshly picked local address. Startup is
// retried up to five times if it fails. It then dials with TLS offering "h2"
// and "http/1.1" and requires "h2" to be negotiated.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)
	defer afterTest(t)
	defer SetTestHookServerServe(nil)

	const maxTries = 5
	var (
		ln      net.Listener
		started bool
	)
	for try := 0; try < maxTries && !started; try++ {
		// Reserve a free port, then release it so ListenAndServeTLS can bind
		// it. Presumably the retries cover the port being taken in between.
		probe := newLocalListener(t)
		addr := probe.Addr().String()
		probe.Close()
		t.Logf("Got %v", addr)

		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		srv := &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- srv.ListenAndServeTLS("", "") }()

		select {
		case err := <-errc:
			t.Logf("On try #%v: %v", try+1, err)
		case ln = <-lnc:
			started = true
			t.Logf("Listening on %v", ln.Addr().String())
		}
	}
	if !started {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	defer ln.Close()

	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()

	state := c.ConnectionState()
	if got, want := state.NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := state.NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
2021
// serverExpectTest describes one TestServerExpect case: how the request body
// is framed, the Expect header value, whether the handler reads the body, and
// text that must appear in the first line of the response.
type serverExpectTest struct {
	contentLength    int    // body length to declare and (if sent) write
	chunked          bool   // use Transfer-Encoding: chunked instead of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // whether the handler reads the request body
	expectedResponse string // substring expected in the response's first line
}

// expectTest builds a non-chunked serverExpectTest from its arguments.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tc serverExpectTest
	tc.contentLength = contentLength
	tc.expectation = expectation
	tc.readBody = readBody
	tc.expectedResponse = expectedResponse
	return tc
}
2038
// serverExpectTests lists the cases run by TestServerExpect.
var serverExpectTests = []serverExpectTest{
	// 100-continue with a body the handler reads: the first response line is
	// "100 Continue". The Expect value is matched case-insensitively.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect header: the client sends the body right away and the handler
	// replies 200 OK.
	expectTest(100, "", true, "200 OK"),

	// 100-continue, but the handler never reads the body and replies 401, so
	// the client should see 401 rather than "100 Continue".
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Likewise without an Expect header.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognized expectation is rejected with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// 100-continue requested with an empty (Content-Length: 0) body that the
	// handler reads: the handler's 200 OK is the first line.
	expectTest(0, "100-continue", true, "200 OK"),
	// 100-continue with an empty body that the handler does not read.
	expectTest(0, "100-continue", false, "401 Unauthorized"),

	// 100-continue with a chunked body that the handler reads.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
2068
2069
2070
// TestServerExpect runs every case in serverExpectTests against an HTTP/1
// server over raw TCP connections. For each case it checks the first line of
// the response.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Inspect the raw query rather than calling r.FormValue. For a POST,
		// FormValue may parse (and so read) the request body, which only the
		// readBody cases should do.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest sends the request described by test on a fresh connection and
	// checks the first line of the response.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body immediately only when there is one and the client is
		// not waiting for "100 Continue" (compared case-insensitively).
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred after conn.Close, so it runs first: wait for the writer
		// goroutine before the connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// Write the body straight to conn with a no-op Close. Chunked
				// cases instead wrap conn in a chunked writer, whose Close
				// writes the final zero-length chunk.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler never reads the body, so the server
						// may close the connection before the body is
						// fully written; a write error is tolerated.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// The client was still sending a body that the handler never
				// reads. The server may hang up mid-write, and the resulting
				// connection reset can surface here as a read error before
				// the response is seen. Tolerated as a likely TCP race.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2166
2167
2168
// TestServerUnreadRequestBodyLittle verifies what happens when a handler
// ignores a 100 KB request body. After the handler writes and flushes its
// header, the server has consumed the rest of the input, and the handler's
// response header map has no Connection entry.
func TestServerUnreadRequestBodyLittle(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	body := strings.Repeat("x", 100<<10)
	conn := new(testConn)
	fmt.Fprintf(&conn.readBuf,
		"POST / HTTP/1.1\r\n"+
			"Host: test\r\n"+
			"Content-Length: %d\r\n"+
			"\r\n", len(body))
	conn.readBuf.WriteString(body)

	// readBufLen reports how many request bytes the server has not read yet.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	done := make(chan bool)
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		defer close(done)
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("on request, read buffer length is %d; expected about 100 KB", n)
		}
		rw.WriteHeader(200)
		rw.(Flusher).Flush()
		if got, want := readBufLen(), 0; got != want {
			t.Errorf("after WriteHeader, read buffer length is %d; want %d", got, want)
		}
		if c := rw.Header().Get("Connection"); c != "" {
			t.Errorf(`Connection header = %q; want ""`, c)
		}
	}))
	<-done
}
2206
2207
2208
2209
// TestServerUnreadRequestBodyLarge verifies what happens when a handler
// ignores a 1 MB request body. Most of the body must still be unread after
// the handler flushes its header. The connection is closed, and the response
// carries "Connection: close".
func TestServerUnreadRequestBodyLarge(t *testing.T) {
	setParallel(t)
	if testing.Short() && testenv.Builder() == "" {
		// Actually skip (not merely log), as TestHandlerBodyClose does.
		t.Skip("skipping in short mode")
	}
	conn := new(testConn)
	body := strings.Repeat("x", 1<<20)
	conn.readBuf.Write([]byte(fmt.Sprintf(
		"POST / HTTP/1.1\r\n"+
			"Host: test\r\n"+
			"Content-Length: %d\r\n"+
			"\r\n", len(body))))
	conn.readBuf.Write([]byte(body))
	conn.closec = make(chan bool, 1)

	// readBufLen returns the number of unread request bytes. It holds
	// readMu because the server goroutine may be reading conn.readBuf
	// concurrently.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	ls := &oneConnListener{conn}
	go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("on request, read buffer length is %d; expected about 1MB", n)
		}
		rw.WriteHeader(200)
		rw.(Flusher).Flush()
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", n)
		}
	}))
	// Wait for the connection's Close to signal on closec.
	<-conn.closec

	if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
		t.Errorf("Expected a Connection: close header; got response: %s", res)
	}
}
2242
// handlerBodyCloseTest describes one TestHandlerBodyClose case: how the POST
// body is framed, whether the request asks to close the connection, and the
// expected outcomes.
type handlerBodyCloseTest struct {
	bodySize     int  // number of body bytes sent
	bodyChunked  bool // send the body with Transfer-Encoding: chunked
	reqConnClose bool // send "Connection: close" (and no follow-up request)

	wantEOFSearch bool // Request.Body.Close is expected to consume input
	wantNextReq   bool // a follow-up request is expected to be served
}

// connectionHeader returns the "Connection: close" header line (with CRLF)
// when the case requests it, and "" otherwise.
func (t handlerBodyCloseTest) connectionHeader() string {
	header := ""
	if t.reqConnClose {
		header = "Connection: close\r\n"
	}
	return header
}
2258
// handlerBodyCloseTests are the cases run by TestHandlerBodyClose. In each
// case the handler closes a POST body without reading it. wantEOFSearch says
// whether that Close is expected to consume more input. wantNextReq, when
// true, requires the following pipelined GET to be served; when false, that
// is not required.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small (20 KB) body with Content-Length: Close reads past the body, and
	// the next request must be served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body: Close reads past the body, and the next request
	// must be served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small body with Content-Length and "Connection: close": no next request
	// is sent, and Close is expected to read nothing.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with "Connection: close": Close is still expected to
	// read (presumably to reach the end of the chunked stream, which may
	// carry trailers), but no next request is sent.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big (1 MB) body with Content-Length: Close is expected to read nothing.
	// The next request is not required.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Big chunked body: Close is expected to read some input. The next
	// request is not required.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big chunked body with "Connection: close": Close is still expected to
	// read (presumably searching for trailers); no next request is sent.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big body with Content-Length and "Connection: close": Close is expected
	// to read nothing, and no next request is sent.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2343
// TestHandlerBodyClose runs every case in handlerBodyCloseTests. It is skipped
// in -short mode unless running on a builder.
func TestHandlerBodyClose(t *testing.T) {
	setParallel(t)
	if testing.Short() && testenv.Builder() == "" {
		t.Skip("skipping in -short mode")
	}
	for i := range handlerBodyCloseTests {
		testHandlerBodyClose(t, i, handlerBodyCloseTests[i])
	}
}
2353
// testHandlerBodyClose runs scenario tt against Serve over an in-memory
// testConn. The index i is used only in failure messages.
//
// It queues a POST whose body is framed as tt describes and, for keep-alive
// scenarios, a pipelined GET. It then records how many bytes remain unread in
// conn.readBuf just before and just after the handler closes the unread POST
// body. A change in that count means Close consumed body bytes from the
// connection (an "EOF search").
func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
	conn := new(testConn)
	body := strings.Repeat("x", tt.bodySize)
	if tt.bodyChunked {
		conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
			"Host: test\r\n" +
			tt.connectionHeader() +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n")
		cw := internal.NewChunkedWriter(&conn.readBuf)
		io.WriteString(cw, body)
		cw.Close()
		conn.readBuf.WriteString("\r\n")
	} else {
		// Format straight into the buffer rather than through a temporary
		// []byte(fmt.Sprintf(...)).
		fmt.Fprintf(&conn.readBuf,
			"POST / HTTP/1.1\r\n"+
				"Host: test\r\n"+
				tt.connectionHeader()+
				"Content-Length: %d\r\n"+
				"\r\n", len(body))
		conn.readBuf.WriteString(body)
	}
	// Only keep-alive scenarios queue a second request. Its arrival at the
	// handler shows the connection was reused.
	if !tt.reqConnClose {
		conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
	}
	conn.closec = make(chan bool, 1)

	// readBufLen reports how many bytes are still unread in conn.readBuf.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	ls := &oneConnListener{conn}
	var numReqs int
	var size0, size1 int
	go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
		numReqs++
		if numReqs == 1 {
			size0 = readBufLen()
			req.Body.Close()
			size1 = readBufLen()
		}
	}))
	// Wait for the connection to be closed (signaled on conn.closec).
	<-conn.closec
	if numReqs < 1 || numReqs > 2 {
		t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
	}
	didSearch := size0 != size1
	if didSearch != tt.wantEOFSearch {
		t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, tt.wantEOFSearch, size0, size1)
	}
	if tt.wantNextReq && numReqs != 2 {
		t.Errorf("%d. numReq = %d; want 2", i, numReqs)
	}
}
2410
2411
2412
2413 type testHandlerBodyConsumer struct {
2414 name string
2415 f func(io.ReadCloser)
2416 }
2417
2418 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2419 {"nil", func(io.ReadCloser) {}},
2420 {"close", func(r io.ReadCloser) { r.Close() }},
2421 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2422 }
2423
// TestRequestBodyReadErrorClosesConnection sends a chunked POST whose body
// starts with "hax", which is not a hexadecimal chunk size. A GET /secret is
// pipelined right after it. For every body-consumption strategy, only the
// POST may reach the handler: after the body read error, the remaining bytes
// must not be parsed as another request.
func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	for _, handler := range testHandlerBodyConsumers {
		conn := new(testConn)
		conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n" +
			"hax\r\n" + // invalid chunk-size line
			"GET /secret HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"\r\n")

		conn.closec = make(chan bool, 1)
		ls := &oneConnListener{conn}
		var numReqs int
		go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
			numReqs++
			if strings.Contains(req.URL.Path, "secret") {
				// Name the consumer so a failure can be attributed to it.
				t.Errorf("Handler %v: Request for /secret encountered, should not have happened.", handler.name)
			}
			handler.f(req.Body)
		}))
		// Wait for the server to close the connection.
		<-conn.closec
		if numReqs != 1 {
			t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
		}
	}
}
2454
// TestInvalidTrailerClosesConnection sends a chunked POST that declares a
// "hack" trailer but then supplies a line that is not a valid header. A
// GET /secret is pipelined after it. For every body-consumption strategy,
// exactly one request (the POST) may reach the handler.
func TestInvalidTrailerClosesConnection(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	const rawRequests = "POST /public HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Trailer: hack\r\n" +
		"Transfer-Encoding: chunked\r\n" +
		"\r\n" +
		"3\r\n" +
		"hax\r\n" +
		"0\r\n" +
		"I'm not a valid trailer\r\n" +
		"GET /secret HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"\r\n"

	for _, consumer := range testHandlerBodyConsumers {
		c := new(testConn)
		c.readBuf.WriteString(rawRequests)
		c.closec = make(chan bool, 1)

		served := 0
		go Serve(&oneConnListener{c}, HandlerFunc(func(_ ResponseWriter, req *Request) {
			served++
			if strings.Contains(req.URL.Path, "secret") {
				t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", consumer.name)
			}
			consumer.f(req.Body)
		}))
		<-c.closec

		if served != 1 {
			t.Errorf("Handler %s: got %d reqs; want 1", consumer.name, served)
		}
	}
}
2489
2490
2491
2492
// slowTestConn is a fake connection. Its Read side replays a script of cues,
// and its Write side discards data. Each script element is either a string,
// returned (possibly across several Reads) as connection data, or a
// time.Duration, which makes Read pause before continuing. Read and Write
// fail with syscall.ETIMEDOUT once their respective deadlines have passed.
type slowTestConn struct {
	// script is the remaining sequence of read cues: string data or
	// time.Duration pauses.
	script []any
	// closec receives true, without blocking, when Close is called.
	closec chan bool

	// mu is held by Read and by the Set*Deadline methods while they access
	// rd and wd (the read and write deadlines; zero means none).
	mu     sync.Mutex
	rd, wd time.Time
	noopConn
}
2502
// SetDeadline applies t as the read deadline and then as the write deadline.
// It always returns nil.
func (c *slowTestConn) SetDeadline(t time.Time) error {
	for _, set := range []func(time.Time) error{c.SetReadDeadline, c.SetWriteDeadline} {
		set(t)
	}
	return nil
}
2508
// SetReadDeadline records t as the read deadline under c.mu; the zero Time
// disables it. It always returns nil.
func (c *slowTestConn) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	c.rd = t
	c.mu.Unlock()
	return nil
}
2515
// SetWriteDeadline records t as the write deadline under c.mu; the zero Time
// disables it. It always returns nil.
func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
	c.mu.Lock()
	c.wd = t
	c.mu.Unlock()
	return nil
}
2522
// Read consumes cues from c.script while holding c.mu.
//
// A string cue is copied into b. Whatever does not fit stays at the front of
// the script for the next Read.
//
// A time.Duration cue makes Read sleep and then continue with the next cue.
// If a read deadline is set and would expire during the pause, Read sleeps
// only until the deadline. It then leaves the unslept remainder of the pause
// at the front of the script and returns syscall.ETIMEDOUT.
//
// Read returns syscall.ETIMEDOUT if the read deadline has already passed, and
// io.EOF once the script is exhausted. c.mu stays locked while Read sleeps,
// so the deadline setters block until Read returns.
func (c *slowTestConn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
restart:
	if !c.rd.IsZero() && time.Now().After(c.rd) {
		return 0, syscall.ETIMEDOUT
	}
	if len(c.script) == 0 {
		return 0, io.EOF
	}

	switch cue := c.script[0].(type) {
	case time.Duration:
		if !c.rd.IsZero() {
			// The deadline would pass during this pause: sleep only up to it,
			// keep the rest of the pause for a later Read, and time out.
			if remaining := time.Until(c.rd); remaining < cue {
				c.script[0] = cue - remaining
				time.Sleep(remaining)
				return 0, syscall.ETIMEDOUT
			}
		}
		c.script = c.script[1:]
		time.Sleep(cue)
		goto restart

	case string:
		n = copy(b, cue)
		// Keep any unread tail of this cue for the next Read.
		if len(cue) > n {
			c.script[0] = cue[n:]
		} else {
			c.script = c.script[1:]
		}

	default:
		panic("unknown cue in slowTestConn script")
	}

	return
}
2564
2565 func (c *slowTestConn) Close() error {
2566 select {
2567 case c.closec <- true:
2568 default:
2569 }
2570 return nil
2571 }
2572
// Write discards b and reports it as fully written. If the write deadline has
// already passed, it fails with syscall.ETIMEDOUT instead.
//
// The deadline is read under c.mu because SetWriteDeadline updates it under
// that lock. Note that Read holds c.mu while it sleeps, so a Write that
// overlaps a pausing Read waits for that Read to return.
func (c *slowTestConn) Write(b []byte) (int, error) {
	c.mu.Lock()
	wd := c.wd
	c.mu.Unlock()
	if !wd.IsZero() && time.Now().After(wd) {
		return 0, syscall.ETIMEDOUT
	}
	return len(b), nil
}
2579
2580 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2581 if testing.Short() {
2582 t.Skip("skipping in -short mode")
2583 }
2584 defer afterTest(t)
2585 for _, handler := range testHandlerBodyConsumers {
2586 conn := &slowTestConn{
2587 script: []any{
2588 "POST /public HTTP/1.1\r\n" +
2589 "Host: test\r\n" +
2590 "Content-Length: 10000\r\n" +
2591 "\r\n",
2592 "foo bar baz",
2593 600 * time.Millisecond,
2594 "GET /secret HTTP/1.1\r\n" +
2595 "Host: test\r\n" +
2596 "\r\n",
2597 },
2598 closec: make(chan bool, 1),
2599 }
2600 ls := &oneConnListener{conn}
2601
2602 var numReqs int
2603 s := Server{
2604 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2605 numReqs++
2606 if strings.Contains(req.URL.Path, "secret") {
2607 t.Error("Request for /secret encountered, should not have happened.")
2608 }
2609 handler.f(req.Body)
2610 }),
2611 ReadTimeout: 400 * time.Millisecond,
2612 }
2613 go s.Serve(ls)
2614 <-conn.closec
2615
2616 if numReqs != 1 {
2617 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2618 }
2619 }
2620 }
2621
2622
2623 type cancelableTimeoutContext struct {
2624 context.Context
2625 }
2626
2627 func (c cancelableTimeoutContext) Err() error {
2628 if c.Context.Err() != nil {
2629 return context.DeadlineExceeded
2630 }
2631 return nil
2632 }
2633
func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }

// testTimeoutHandler exercises NewTestTimeoutHandler with a context whose
// cancellation is reported as a timeout (cancelableTimeoutContext).
//
// While the context is live, the wrapped handler's "hi" response passes
// through. After cancel(), the client gets 503 with an HTML timeout page,
// and the wrapped handler's late Write fails with ErrHandlerTimeout.
func testTimeoutHandler(t *testing.T, mode testMode) {
	sendHi := make(chan bool, 1)
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		<-sendHi
		_, werr := w.Write([]byte("hi"))
		writeErrors <- werr
	})
	ctx, cancel := context.WithCancel(context.Background())
	h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
	cst := newClientServerTest(t, mode, h)

	// Succeed without timing out: release the handler before the request.
	sendHi <- true
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil when err != nil; stop instead of dereferencing it.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusOK; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), "hi"; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g := <-writeErrors; g != nil {
		t.Errorf("got unexpected Write error on first request: %v", g)
	}

	// Time out: cancel the context; the next handler invocation blocks on
	// sendHi until the end of the test.
	cancel()

	res, err = cst.c.Get(cst.ts.URL)
	if err != nil {
		// Unblock the handler so it does not stay parked on sendHi.
		sendHi <- true
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ = io.ReadAll(res.Body)
	res.Body.Close()
	if !strings.Contains(string(body), "<title>Timeout</title>") {
		t.Errorf("expected timeout body; got %q", string(body))
	}
	if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
		t.Errorf("response content-type = %q; want %q", g, w)
	}

	// Release the timed-out handler; its Write must now fail.
	sendHi <- true
	if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
		t.Errorf("expected Write error of %v; got %v", e, g)
	}
}
2689
2690
func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }

// testTimeoutHandlerRace fires many concurrent requests at a handler wrapped
// in a 20ms TimeoutHandler. Each request path is a number in [0, 50), with 0
// treated as 1. The handler writes "hi" and sleeps 1ms that many times, so
// some requests finish in time and others time out mid-write. Request
// failures are ignored; the test only has to complete cleanly.
func testTimeoutHandlerRace(t *testing.T, mode testMode) {
	delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		iterations, _ := strconv.Atoi(r.URL.Path[1:])
		if iterations == 0 {
			iterations = 1
		}
		for range iterations {
			w.Write([]byte("hi"))
			time.Sleep(time.Millisecond)
		}
	})

	srv := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
	client := srv.Client()

	// inFlight bounds the number of concurrent requests.
	total, inFlight := 50, make(chan bool, 10)
	if testing.Short() {
		total, inFlight = 10, make(chan bool, 3)
	}

	var wg sync.WaitGroup
	for range total {
		inFlight <- true
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-inFlight }()
			res, err := client.Get(fmt.Sprintf("%s/%d", srv.URL, rand.Intn(50)))
			if err != nil {
				return
			}
			io.Copy(io.Discard, res.Body)
			res.Body.Close()
		}()
	}
	wg.Wait()
}
2730
2731
2732
func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }

// testTimeoutHandlerRaceHeader wraps a handler that only calls
// WriteHeader(204) in a TimeoutHandler with a 1ns timeout. The handler and
// the timeout therefore race to produce the response header. Many concurrent
// requests (at most 50 in flight) exercise that race. Client errors are
// logged rather than failing the test.
func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
	delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(204)
	})

	ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts

	requests := 500
	if testing.Short() {
		requests = 10
	}
	slots := make(chan bool, 50)
	c := ts.Client()

	var wg sync.WaitGroup
	for range requests {
		slots <- true
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-slots }()
			res, err := c.Get(ts.URL)
			if err == nil {
				defer res.Body.Close()
				io.Copy(io.Discard, res.Body)
				return
			}
			// A failed request is tolerated here (NOTE(review): presumably a
			// side effect of the 1ns timeout — confirm); just log it.
			t.Log(err)
		}()
	}
	wg.Wait()
}
2768
2769
func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }

// testTimeoutHandlerRaceHeaderTimeout runs a timeout-handler round trip in
// which the wrapped handler sets a Content-Type header and then blocks on
// sendHi.
//
// While the context is live, the "hi" response passes through. After
// cancel(), the client gets 503 with an HTML timeout page, and the handler's
// late Write fails with ErrHandlerTimeout.
func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
	sendHi := make(chan bool, 1)
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/plain")
		<-sendHi
		_, werr := w.Write([]byte("hi"))
		writeErrors <- werr
	})
	ctx, cancel := context.WithCancel(context.Background())
	h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
	cst := newClientServerTest(t, mode, h)

	// Succeed without timing out: release the handler before the request.
	sendHi <- true
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil when err != nil; stop instead of dereferencing it.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusOK; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), "hi"; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g := <-writeErrors; g != nil {
		t.Errorf("got unexpected Write error on first request: %v", g)
	}

	// Time out: cancel the context; the next handler invocation blocks on
	// sendHi until the end of the test.
	cancel()

	res, err = cst.c.Get(cst.ts.URL)
	if err != nil {
		// Unblock the handler so it does not stay parked on sendHi.
		sendHi <- true
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ = io.ReadAll(res.Body)
	res.Body.Close()
	if !strings.Contains(string(body), "<title>Timeout</title>") {
		t.Errorf("expected timeout body; got %q", string(body))
	}

	// Release the timed-out handler; its Write must now fail.
	sendHi <- true
	if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
		t.Errorf("expected Write error of %v; got %v", e, g)
	}
}
2823
2824
func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
	run(t, testTimeoutHandlerStartTimerWhenServing)
}

// testTimeoutHandlerStartTimerWhenServing starts a server whose handler
// replies 204 inside a 300ms TimeoutHandler. It idles for twice that timeout
// before sending the first request and expects the handler's 204, not a
// timeout. The timeout must therefore be measured per request rather than
// from when the handler was constructed.
func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping sleeping test in -short mode")
	}
	const timeout = 300 * time.Millisecond
	noContent := HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(StatusNoContent)
	})
	srv := newClientServerTest(t, mode, TimeoutHandler(noContent, timeout, "")).ts
	defer srv.Close()

	client := srv.Client()

	// Idle for longer than the timeout before the first request arrives.
	time.Sleep(2 * timeout)
	res, err := client.Get(srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if got := res.StatusCode; got != StatusNoContent {
		t.Errorf("got res.StatusCode %d, want %v", got, StatusNoContent)
	}
}
2854
func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }

// testTimeoutHandlerContextCanceled wraps a handler in NewTestTimeoutHandler
// with a context that is already canceled. The client must receive 503 with
// an empty body, and the handler's writes must fail with context.Canceled.
func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/plain")
		var err error
		// Keep writing one byte per millisecond, up to 100 times, and stop at
		// the first Write error; report that error (or nil).
		for i := 0; i < 100; i++ {
			_, err = w.Write([]byte("a"))
			if err != nil {
				break
			}
			time.Sleep(1 * time.Millisecond)
		}
		writeErrors <- err
	})
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	h := NewTestTimeoutHandler(sayHi, ctx)
	cst := newClientServerTest(t, mode, h)
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil when err != nil; stop instead of dereferencing it.
		t.Fatal(err)
	}
	defer res.Body.Close()
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	if g, e := string(body), ""; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g, e := <-writeErrors, context.Canceled; g != e {
		// %v, not %g: e is an error value, not a float.
		t.Errorf("got unexpected Write in handler: %v, want %v", g, e)
	}
}
2894
2895
func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }

// testTimeoutHandlerEmptyResponse wraps a handler that writes nothing in a
// 300ms TimeoutHandler and expects the client to receive 200 OK.
func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
	const timeout = 300 * time.Millisecond
	silent := HandlerFunc(func(ResponseWriter, *Request) {})
	ts := newClientServerTest(t, mode, TimeoutHandler(silent, timeout, "")).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if got := res.StatusCode; got != StatusOK {
		t.Errorf("got res.StatusCode %d, want %v", got, StatusOK)
	}
}
2915
2916
// TestTimeoutHandlerPanicRecovery runs the shared handler-panic check against
// a panicking handler wrapped in a TimeoutHandler with a 1s timeout. It uses
// testNotParallel, as the other handler-panic tests do.
func TestTimeoutHandlerPanicRecovery(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		withTimeout := func(h Handler) Handler {
			return TimeoutHandler(h, time.Second, "")
		}
		testHandlerPanic(t, false, mode, withTimeout, "intentional death for testing")
	}, testNotParallel)
}
2925
// TestRedirectBadPath calls Redirect with an empty target on a request whose
// URL path has no leading slash. It checks that the requested status code
// (304) is still recorded.
func TestRedirectBadPath(t *testing.T) {
	rec := httptest.NewRecorder()
	Redirect(rec, &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2942
2943 func TestRedirectEscapedPath(t *testing.T) {
2944 baseURL, redirectURL := "http://example.com/foo%2Fbar/", "qux%2Fbaz"
2945 req := httptest.NewRequest("GET", baseURL, NoBody)
2946
2947 rr := httptest.NewRecorder()
2948 Redirect(rr, req, redirectURL, StatusMovedPermanently)
2949
2950 wantURL := "/foo%2Fbar/qux%2Fbaz"
2951 if got := rr.Result().Header.Get("Location"); got != wantURL {
2952 t.Errorf("Redirect(%s, %s) = %s, want = %s", baseURL, redirectURL, got, wantURL)
2953 }
2954 }
2955
2956
2957 func TestRedirect(t *testing.T) {
2958 req, _ := NewRequest("GET", "http://example.com/qux/", nil)
2959
2960 var tests = []struct {
2961 in string
2962 want string
2963 }{
2964
2965 {"http://foobar.com/baz", "http://foobar.com/baz"},
2966
2967 {"https://foobar.com/baz", "https://foobar.com/baz"},
2968
2969 {"test://foobar.com/baz", "test://foobar.com/baz"},
2970
2971 {"//foobar.com/baz", "//foobar.com/baz"},
2972
2973 {"/foobar.com/baz", "/foobar.com/baz"},
2974
2975 {"foobar.com/baz", "/qux/foobar.com/baz"},
2976
2977 {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
2978
2979 {"///foobar.com/baz", "/foobar.com/baz"},
2980
2981
2982 {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
2983 {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
2984 "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
2985
2986 {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2987 {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2988 }
2989
2990 for _, tt := range tests {
2991 rec := httptest.NewRecorder()
2992 Redirect(rec, req, tt.in, 302)
2993 if got, want := rec.Code, 302; got != want {
2994 t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
2995 }
2996 if got := rec.Header().Get("Location"); got != tt.want {
2997 t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
2998 }
2999 }
3000 }
3001
3002
3003
// TestRedirectContentTypeAndBody checks the Content-Type header and body that
// Redirect writes, depending on the request method and on any Content-Type
// the handler set beforehand:
//   - Only GET and HEAD get the default "text/html; charset=utf-8" type.
//   - Only GET gets the HTML link body.
//   - A Content-Type key already present in the header map, even with an
//     empty or nil value, suppresses both the default type and the body.
//     A non-empty pre-set value is kept.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	for _, tc := range []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	} {
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		Redirect(rec, req, "/foo", 302)

		if rec.Code != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, rec.Code, 302)
		}
		if ct := rec.Header().Get("Content-Type"); ct != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, ct, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if string(body) != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, string(body), tc.wantBody)
		}
	}
}
3047
3048
3049
3050
3051
3052
3053
func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }

// testZeroLengthPostAndResponse sends the same zero-length POST five times.
// The handler must see an empty body and replies with Content-Length: 0.
// After all five round trips, each response body must read back empty.
func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		all, err := io.ReadAll(r.Body)
		if err != nil {
			// The handler runs on a server goroutine, where t.Fatalf must not
			// be called; record the failure and bail out instead.
			t.Errorf("handler ReadAll: %v", err)
			return
		}
		if len(all) != 0 {
			t.Errorf("handler got %d bytes; expected 0", len(all))
		}
		rw.Header().Set("Content-Length", "0")
	}))

	req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
	if err != nil {
		t.Fatal(err)
	}
	req.ContentLength = 0

	var resp [5]*Response
	for i := range resp {
		resp[i], err = cst.c.Do(req)
		if err != nil {
			t.Fatalf("client post #%d: %v", i, err)
		}
		// Close every obtained body on any exit path, including a later
		// t.Fatalf.
		defer resp[i].Body.Close()
	}

	for i := range resp {
		all, err := io.ReadAll(resp[i].Body)
		if err != nil {
			t.Fatalf("req #%d: client ReadAll: %v", i, err)
		}
		if len(all) != 0 {
			t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
		}
	}
}
3092
// TestHandlerPanicNil runs the shared handler-panic check with a handler that
// calls panic(nil), in every mode, marked testNotParallel.
func TestHandlerPanicNil(t *testing.T) {
	panicNil := func(t *testing.T, mode testMode) {
		testHandlerPanic(t, false, mode, nil, nil)
	}
	run(t, panicNil, testNotParallel)
}
3098
3099 func TestHandlerPanic(t *testing.T) {
3100 run(t, func(t *testing.T, mode testMode) {
3101 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
3102 }, testNotParallel)
3103 }
3104
// TestHandlerPanicWithHijack runs the shared handler-panic check with a
// handler that hijacks its connection before panicking. It runs in HTTP/1
// mode only.
func TestHandlerPanicWithHijack(t *testing.T) {
	hijackThenPanic := func(t *testing.T, mode testMode) {
		testHandlerPanic(t, true, mode, nil, "intentional death for testing")
	}
	run(t, hijackThenPanic, []testMode{http1Mode})
}
3111
// testHandlerPanic requests a handler that panics with panicValue. If
// withHijack is set, the handler first hijacks its connection; if wrapper is
// non-nil, the handler is wrapped by it.
//
// The server's ErrorLog is redirected into a pipe that is read in the
// background. A read error other than io.EOF fails the test. Unless
// panicValue is nil, the test waits for that first read to complete. A
// successful client response is only logged, since an error is expected.
func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
	pr, pw := io.Pipe()
	defer pw.Close()

	var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
		if withHijack {
			rwc, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Logf("unexpected error: %v", err)
			} else {
				// Only close a connection we actually obtained. Deferring
				// Close on a nil net.Conn would panic with a nil dereference
				// and mask panicValue.
				defer rwc.Close()
			}
		}
		panic(panicValue)
	})
	if wrapper != nil {
		handler = wrapper(handler)
	}
	cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(pw, "", 0)
	})

	// Read the first chunk of the server's error log in the background and
	// signal done once that read returns.
	done := make(chan bool, 1)
	go func() {
		buf := make([]byte, 4<<10)
		_, err := pr.Read(buf)
		pr.Close()
		if err != nil && err != io.EOF {
			t.Error(err)
		}
		done <- true
	}()

	res, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		res.Body.Close()
		t.Logf("expected an error")
	}

	// A nil panicValue is not waited on.
	if panicValue == nil {
		return
	}

	<-done
}
3164
3165 type terrorWriter struct{ t *testing.T }
3166
3167 func (w terrorWriter) Write(p []byte) (int, error) {
3168 w.t.Errorf("%s", p)
3169 return len(p), nil
3170 }
3171
3172
3173
func TestServerWriteHijackZeroBytes(t *testing.T) {
	run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
}

// testServerWriteHijackZeroBytes checks that after a handler flushes and then
// hijacks its connection, a zero-byte Write on the ResponseWriter fails with
// ErrHijacked. Nothing may be written to the server's ErrorLog; any output
// there fails the test through terrorWriter.
func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
	handlerDone := make(chan struct{})
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(handlerDone)
		w.(Flusher).Flush()
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		defer conn.Close()
		if _, err := w.Write(nil); err != ErrHijacked {
			t.Errorf("Write error = %v; want ErrHijacked", err)
		}
	})
	failOnLog := func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
	}
	ts := newClientServerTest(t, mode, h, failOnLog).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	<-handlerDone
}
3204
3205 func TestServerNoDate(t *testing.T) {
3206 run(t, func(t *testing.T, mode testMode) {
3207 testServerNoHeader(t, mode, "Date")
3208 })
3209 }
3210
3211 func TestServerContentType(t *testing.T) {
3212 run(t, func(t *testing.T, mode testMode) {
3213 testServerNoHeader(t, mode, "Content-Type")
3214 })
3215 }
3216
// testServerNoHeader has the handler set header to nil in its header map and
// then write a small HTML body. The client's response must not contain that
// header at all.
func testServerNoHeader(t *testing.T, mode testMode, header string) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header()[header] = nil
		io.WriteString(w, "<html>foo</html>")
	})
	cst := newClientServerTest(t, mode, h)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	got, present := res.Header[header]
	if present {
		t.Fatalf("Expected no %s header; got %q", header, got)
	}
}
3231
3232 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3233 func testStripPrefix(t *testing.T, mode testMode) {
3234 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3235 w.Header().Set("X-Path", r.URL.Path)
3236 w.Header().Set("X-RawPath", r.URL.RawPath)
3237 })
3238 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3239
3240 c := ts.Client()
3241
3242 cases := []struct {
3243 reqPath string
3244 path string
3245 rawPath string
3246 }{
3247 {"/foo/bar/qux", "/qux", ""},
3248 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3249 {"/foo%2Fbar/qux", "", ""},
3250 {"/bar", "", ""},
3251 }
3252 for _, tc := range cases {
3253 t.Run(tc.reqPath, func(t *testing.T) {
3254 res, err := c.Get(ts.URL + tc.reqPath)
3255 if err != nil {
3256 t.Fatal(err)
3257 }
3258 res.Body.Close()
3259 if tc.path == "" {
3260 if res.StatusCode != StatusNotFound {
3261 t.Errorf("got %q, want 404 Not Found", res.Status)
3262 }
3263 return
3264 }
3265 if res.StatusCode != StatusOK {
3266 t.Fatalf("got %q, want 200 OK", res.Status)
3267 }
3268 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3269 t.Errorf("got Path %q, want %q", g, w)
3270 }
3271 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3272 t.Errorf("got RawPath %q, want %q", g, w)
3273 }
3274 })
3275 }
3276 }
3277
3278
// TestStripPrefixNotModifyRequest checks that serving a request through
// StripPrefix leaves the caller's Request URL path untouched.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if path := req.URL.Path; path != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3287
func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }

// testRequestLimit sends a request whose header block exceeds
// DefaultMaxHeaderBytes+4096 bytes. The request must be rejected before it
// reaches the handler.
//
// In HTTP/1 mode the client must get a 431 response. In HTTP/2 mode a
// client-side error is also accepted, but any response that does arrive must
// be a 431.
func testRequestLimit(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Handlers run on server goroutines, where t.Fatalf must not be
		// called; report the failure with t.Errorf instead.
		t.Errorf("didn't expect to get request in Handler")
	}), optQuietLog)
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	var bytesPerHeader = len("header12345: val12345\r\n")
	for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
		req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
	}
	res, err := cst.c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if mode == http2Mode {
		// A transport error is acceptable here (NOTE(review): presumably the
		// HTTP/2 client may reject the oversized header block itself — confirm).
		// A response that does arrive must still be a 431.
		if err == nil && res.StatusCode != 431 {
			t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
		}
	} else {
		// HTTP/1 must always produce a 431 response.
		if err != nil {
			t.Fatalf("Do: %v", err)
		}
		if res.StatusCode != 431 {
			t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
		}
	}
}
3323
// neverEnding is an io.Reader that produces an endless stream of one byte
// value. Every Read fills all of p and never returns an error.
type neverEnding byte

func (b neverEnding) Read(p []byte) (n int, err error) {
	n = len(p)
	for i := 0; i < n; i++ {
		p[i] = byte(b)
	}
	return n, nil
}
3332
3333 type bodyLimitReader struct {
3334 mu sync.Mutex
3335 count int
3336 limit int
3337 closed chan struct{}
3338 }
3339
3340 func (r *bodyLimitReader) Read(p []byte) (int, error) {
3341 r.mu.Lock()
3342 defer r.mu.Unlock()
3343 select {
3344 case <-r.closed:
3345 return 0, errors.New("closed")
3346 default:
3347 }
3348 if r.count > r.limit {
3349 return 0, errors.New("at limit")
3350 }
3351 r.count += len(p)
3352 for i := range p {
3353 p[i] = 'a'
3354 }
3355 return len(p), nil
3356 }
3357
3358 func (r *bodyLimitReader) Close() error {
3359 r.mu.Lock()
3360 defer r.mu.Unlock()
3361 close(r.closed)
3362 return nil
3363 }
3364
func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }

// testRequestBodyLimit has the handler wrap its request body in a 1 MiB
// MaxBytesReader and drain it. Reading must stop at exactly the limit with a
// *MaxBytesError carrying that limit.
//
// The client streams a body (bodyLimitReader) that would allow up to
// 200× the limit. The test then checks that no more than 100× the limit was
// pulled from it before the transport closed it.
func testRequestBodyLimit(t *testing.T, mode testMode) {
	const limit = 1 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = MaxBytesReader(w, r.Body, limit)
		n, err := io.Copy(io.Discard, r.Body)
		if err == nil {
			t.Errorf("expected error from io.Copy")
		}
		if n != limit {
			t.Errorf("io.Copy = %d, want %d", n, limit)
		}
		mbErr, ok := err.(*MaxBytesError)
		if !ok {
			t.Errorf("expected MaxBytesError, got %T", err)
			// mbErr is nil here; inspecting it would panic.
			return
		}
		if mbErr.Limit != limit {
			t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
		}
	}))

	body := &bodyLimitReader{
		closed: make(chan struct{}),
		limit:  limit * 200,
	}
	req, _ := NewRequest("POST", cst.ts.URL, body)

	// Whether the request itself succeeds is not checked (NOTE(review):
	// presumably the server may reply and close while the client is still
	// uploading — confirm). Only the amount read from body matters.
	resp, err := cst.c.Do(req)
	if err == nil {
		resp.Body.Close()
	}

	// Wait until the transport has closed the request body.
	<-body.closed

	if body.count > limit*100 {
		t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
			limit, body.count)
	}
}
3414
3415
3416
func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }

// testClientWriteShutdown opens a raw TCP connection to the server and
// half-closes its write side without sending a request. It then checks that
// the server writes nothing back before the read side ends.
func testClientWriteShutdown(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/17906")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Release the client socket on every exit path.
	defer conn.Close()
	err = conn.(*net.TCPConn).CloseWrite()
	if err != nil {
		t.Fatalf("CloseWrite: %v", err)
	}

	bs, err := io.ReadAll(conn)
	if err != nil {
		t.Errorf("ReadAll: %v", err)
	}
	got := string(bs)
	if got != "" {
		t.Errorf("read %q from server; want nothing", got)
	}
}
3441
3442
3443
3444 func TestServerBufferedChunking(t *testing.T) {
3445 conn := new(testConn)
3446 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3447 conn.closec = make(chan bool, 1)
3448 ls := &oneConnListener{conn}
3449 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3450 rw.(Flusher).Flush()
3451 rw.Write([]byte{'x'})
3452 rw.Write([]byte{'y'})
3453 rw.Write([]byte{'z'})
3454 }))
3455 <-conn.closec
3456 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3457 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3458 conn.writeBuf.Bytes())
3459 }
3460 }
3461
3462
3463
3464
3465
func TestServerGracefulClose(t *testing.T) {
	// HTTP/1 only, and not in parallel: each attempt calls
	// SetRSTAvoidanceDelay (NOTE(review): presumably shared package-level
	// state — confirm).
	run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
}

// testServerGracefulClose posts a 5 MiB body to a handler that replies
// 401 Unauthorized without reading it. A separate goroutine writes the whole
// request while the response is read line by line until EOF; the first line
// must be the 401 status.
//
// runTimeSensitiveTest retries the attempt with RST avoidance delays from
// 1ms up to 5s, setting the delay before each try.
func testServerGracefulClose(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 5 << 20
		req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
		for i := 0; i < bodySize; i++ {
			req = append(req, 'x')
		}

		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, "bye", StatusUnauthorized)
		}))
		// Shut the server down at the end of each attempt so a retry starts
		// fresh.
		defer cst.close()
		ts := cst.ts

		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		// Write the request concurrently: the server may respond (and start
		// closing) before the whole body has been sent.
		writeErr := make(chan error)
		go func() {
			_, err := conn.Write(req)
			writeErr <- err
		}()
		defer func() {
			conn.Close()
			// Closing conn makes any still-blocked Write return; wait for
			// the writer goroutine so it does not outlive this attempt.
			<-writeErr
		}()

		br := bufio.NewReader(conn)
		lineNum := 0
		for {
			line, err := br.ReadString('\n')
			if err == io.EOF {
				break
			}
			if err != nil {
				return fmt.Errorf("ReadLine: %v", err)
			}
			lineNum++
			if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
				t.Errorf("Response line = %q; want a 401", line)
			}
		}
		return nil
	})
}
3533
3534 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3535 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3536 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3537 if r.Method != "get" {
3538 t.Errorf(`Got method %q; want "get"`, r.Method)
3539 }
3540 }))
3541 defer cst.close()
3542 req, _ := NewRequest("get", cst.ts.URL, nil)
3543 res, err := cst.c.Do(req)
3544 if err != nil {
3545 t.Error(err)
3546 return
3547 }
3548
3549 res.Body.Close()
3550 }
3551
3552
3553
3554
3555
3556 func TestContentLengthZero(t *testing.T) {
3557 run(t, testContentLengthZero, []testMode{http1Mode})
3558 }
3559 func testContentLengthZero(t *testing.T, mode testMode) {
3560 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3561
3562 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3563 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3564 if err != nil {
3565 t.Fatalf("error dialing: %v", err)
3566 }
3567 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3568 if err != nil {
3569 t.Fatalf("error writing: %v", err)
3570 }
3571 req, _ := NewRequest("GET", "/", nil)
3572 res, err := ReadResponse(bufio.NewReader(conn), req)
3573 if err != nil {
3574 t.Fatalf("error reading response: %v", err)
3575 }
3576 if te := res.TransferEncoding; len(te) > 0 {
3577 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3578 }
3579 if cl := res.ContentLength; cl != 0 {
3580 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3581 }
3582 conn.Close()
3583 }
3584 }
3585
3586 func TestCloseNotifier(t *testing.T) {
3587 run(t, testCloseNotifier, []testMode{http1Mode})
3588 }
3589 func testCloseNotifier(t *testing.T, mode testMode) {
3590 gotReq := make(chan bool, 1)
3591 sawClose := make(chan bool, 1)
3592 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3593 gotReq <- true
3594 cc := rw.(CloseNotifier).CloseNotify()
3595 <-cc
3596 sawClose <- true
3597 })).ts
3598 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3599 if err != nil {
3600 t.Fatalf("error dialing: %v", err)
3601 }
3602 diec := make(chan bool)
3603 go func() {
3604 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3605 if err != nil {
3606 t.Error(err)
3607 return
3608 }
3609 <-diec
3610 conn.Close()
3611 }()
3612 For:
3613 for {
3614 select {
3615 case <-gotReq:
3616 diec <- true
3617 case <-sawClose:
3618 break For
3619 }
3620 }
3621 ts.Close()
3622 }
3623
3624
3625
3626
3627
3628 func TestCloseNotifierPipelined(t *testing.T) {
3629 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3630 }
3631 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3632 gotReq := make(chan bool, 2)
3633 sawClose := make(chan bool, 2)
3634 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3635 gotReq <- true
3636 cc := rw.(CloseNotifier).CloseNotify()
3637 select {
3638 case <-cc:
3639 t.Error("unexpected CloseNotify")
3640 case <-time.After(100 * time.Millisecond):
3641 }
3642 sawClose <- true
3643 })).ts
3644 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3645 if err != nil {
3646 t.Fatalf("error dialing: %v", err)
3647 }
3648 diec := make(chan bool, 1)
3649 defer close(diec)
3650 go func() {
3651 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3652 _, err = io.WriteString(conn, req+req)
3653 if err != nil {
3654 t.Error(err)
3655 return
3656 }
3657 <-diec
3658 conn.Close()
3659 }()
3660 reqs := 0
3661 closes := 0
3662 for {
3663 select {
3664 case <-gotReq:
3665 reqs++
3666 if reqs > 2 {
3667 t.Fatal("too many requests")
3668 }
3669 case <-sawClose:
3670 closes++
3671 if closes > 1 {
3672 return
3673 }
3674 }
3675 }
3676 }
3677
3678 func TestCloseNotifierChanLeak(t *testing.T) {
3679 defer afterTest(t)
3680 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3681 for i := 0; i < 20; i++ {
3682 var output bytes.Buffer
3683 conn := &rwTestConn{
3684 Reader: bytes.NewReader(req),
3685 Writer: &output,
3686 closec: make(chan bool, 1),
3687 }
3688 ln := &oneConnListener{conn: conn}
3689 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3690
3691
3692
3693 _ = rw.(CloseNotifier).CloseNotify()
3694 })
3695 go Serve(ln, handler)
3696 <-conn.closec
3697 }
3698 }
3699
3700
3701
3702
3703
3704
3705
3706
3707
3708
3709 func TestHijackAfterCloseNotifier(t *testing.T) {
3710 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3711 }
3712 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3713 script := make(chan string, 2)
3714 script <- "closenotify"
3715 script <- "hijack"
3716 close(script)
3717 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3718 plan := <-script
3719 switch plan {
3720 default:
3721 panic("bogus plan; too many requests")
3722 case "closenotify":
3723 w.(CloseNotifier).CloseNotify()
3724 w.Header().Set("X-Addr", r.RemoteAddr)
3725 case "hijack":
3726 c, _, err := w.(Hijacker).Hijack()
3727 if err != nil {
3728 t.Errorf("Hijack in Handler: %v", err)
3729 return
3730 }
3731 if _, ok := c.(*net.TCPConn); !ok {
3732
3733
3734 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3735 }
3736 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3737 c.Close()
3738 return
3739 }
3740 })).ts
3741 res1, err := ts.Client().Get(ts.URL)
3742 if err != nil {
3743 log.Fatal(err)
3744 }
3745 res2, err := ts.Client().Get(ts.URL)
3746 if err != nil {
3747 log.Fatal(err)
3748 }
3749 addr1 := res1.Header.Get("X-Addr")
3750 addr2 := res2.Header.Get("X-Addr")
3751 if addr1 == "" || addr1 != addr2 {
3752 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3753 }
3754 }
3755
// TestHijackBeforeRequestBodyRead checks the following sequence: a handler
// detaches the request body (nil-ing r.Body), subscribes to CloseNotify, and
// reads the full 1 MiB body through its own reference. It must still receive
// the close notification once the client closes the connection.
func TestHijackBeforeRequestBodyRead(t *testing.T) {
	run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
}
func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
	var requestBody = bytes.Repeat([]byte("a"), 1<<20)
	bodyOkay := make(chan bool, 1)
	gotCloseNotify := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Closing on every exit path means the client side receives false
		// (instead of blocking) if the handler bails out early.
		defer close(bodyOkay)

		reqBody := r.Body
		// Presumably checks that the server does not rely on r.Body after
		// the handler has replaced it.
		r.Body = nil

		gone := w.(CloseNotifier).CloseNotify()
		slurp, err := io.ReadAll(reqBody)
		if err != nil {
			t.Errorf("Body read: %v", err)
			return
		}
		if len(slurp) != len(requestBody) {
			t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
			return
		}
		if !bytes.Equal(slurp, requestBody) {
			// The body is 1 MiB, so don't print it.
			t.Error("Backend read wrong request body.")
			return
		}
		bodyOkay <- true
		<-gone
		gotCloseNotify <- true
	})).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
		len(requestBody), requestBody)
	if !<-bodyOkay {
		// The handler already reported the failure.
		return
	}
	// Closing the client side should trigger the handler's CloseNotify.
	conn.Close()
	<-gotCloseNotify
}
3803
3804 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3805 func testOptions(t *testing.T, mode testMode) {
3806 uric := make(chan string, 2)
3807 mux := NewServeMux()
3808 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3809 uric <- r.RequestURI
3810 })
3811 ts := newClientServerTest(t, mode, mux).ts
3812
3813 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3814 if err != nil {
3815 t.Fatal(err)
3816 }
3817 defer conn.Close()
3818
3819
3820 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3821 if err != nil {
3822 t.Fatal(err)
3823 }
3824 br := bufio.NewReader(conn)
3825 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3826 if err != nil {
3827 t.Fatal(err)
3828 }
3829 if res.StatusCode != 200 {
3830 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3831 }
3832
3833
3834 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3835 if err != nil {
3836 t.Fatal(err)
3837 }
3838 res, err = ReadResponse(br, &Request{Method: "GET"})
3839 if err != nil {
3840 t.Fatal(err)
3841 }
3842 if res.StatusCode != 400 {
3843 t.Errorf("Got non-400 response to GET *: %#v", res)
3844 }
3845
3846 res, err = Get(ts.URL + "/second")
3847 if err != nil {
3848 t.Fatal(err)
3849 }
3850 res.Body.Close()
3851 if got := <-uric; got != "/second" {
3852 t.Errorf("Handler saw request for %q; want /second", got)
3853 }
3854 }
3855
3856 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3857 func testOptionsHandler(t *testing.T, mode testMode) {
3858 rc := make(chan *Request, 1)
3859
3860 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3861 rc <- r
3862 }), func(ts *httptest.Server) {
3863 ts.Config.DisableGeneralOptionsHandler = true
3864 }).ts
3865
3866 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3867 if err != nil {
3868 t.Fatal(err)
3869 }
3870 defer conn.Close()
3871
3872 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3873 if err != nil {
3874 t.Fatal(err)
3875 }
3876
3877 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3878 t.Errorf("Expected OPTIONS * request, got %v", got)
3879 }
3880 }
3881
3882
3883
3884
3885
3886
3887
3888
3889
3890
3891 func TestHeaderToWire(t *testing.T) {
3892 tests := []struct {
3893 name string
3894 handler func(ResponseWriter, *Request)
3895 check func(got, logs string) error
3896 }{
3897 {
3898 name: "write without Header",
3899 handler: func(rw ResponseWriter, r *Request) {
3900 rw.Write([]byte("hello world"))
3901 },
3902 check: func(got, logs string) error {
3903 if !strings.Contains(got, "Content-Length:") {
3904 return errors.New("no content-length")
3905 }
3906 if !strings.Contains(got, "Content-Type: text/plain") {
3907 return errors.New("no content-type")
3908 }
3909 return nil
3910 },
3911 },
3912 {
3913 name: "Header mutation before write",
3914 handler: func(rw ResponseWriter, r *Request) {
3915 h := rw.Header()
3916 h.Set("Content-Type", "some/type")
3917 rw.Write([]byte("hello world"))
3918 h.Set("Too-Late", "bogus")
3919 },
3920 check: func(got, logs string) error {
3921 if !strings.Contains(got, "Content-Length:") {
3922 return errors.New("no content-length")
3923 }
3924 if !strings.Contains(got, "Content-Type: some/type") {
3925 return errors.New("wrong content-type")
3926 }
3927 if strings.Contains(got, "Too-Late") {
3928 return errors.New("don't want too-late header")
3929 }
3930 return nil
3931 },
3932 },
3933 {
3934 name: "write then useless Header mutation",
3935 handler: func(rw ResponseWriter, r *Request) {
3936 rw.Write([]byte("hello world"))
3937 rw.Header().Set("Too-Late", "Write already wrote headers")
3938 },
3939 check: func(got, logs string) error {
3940 if strings.Contains(got, "Too-Late") {
3941 return errors.New("header appeared from after WriteHeader")
3942 }
3943 return nil
3944 },
3945 },
3946 {
3947 name: "flush then write",
3948 handler: func(rw ResponseWriter, r *Request) {
3949 rw.(Flusher).Flush()
3950 rw.Write([]byte("post-flush"))
3951 rw.Header().Set("Too-Late", "Write already wrote headers")
3952 },
3953 check: func(got, logs string) error {
3954 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3955 return errors.New("not chunked")
3956 }
3957 if strings.Contains(got, "Too-Late") {
3958 return errors.New("header appeared from after WriteHeader")
3959 }
3960 return nil
3961 },
3962 },
3963 {
3964 name: "header then flush",
3965 handler: func(rw ResponseWriter, r *Request) {
3966 rw.Header().Set("Content-Type", "some/type")
3967 rw.(Flusher).Flush()
3968 rw.Write([]byte("post-flush"))
3969 rw.Header().Set("Too-Late", "Write already wrote headers")
3970 },
3971 check: func(got, logs string) error {
3972 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3973 return errors.New("not chunked")
3974 }
3975 if strings.Contains(got, "Too-Late") {
3976 return errors.New("header appeared from after WriteHeader")
3977 }
3978 if !strings.Contains(got, "Content-Type: some/type") {
3979 return errors.New("wrong content-type")
3980 }
3981 return nil
3982 },
3983 },
3984 {
3985 name: "sniff-on-first-write content-type",
3986 handler: func(rw ResponseWriter, r *Request) {
3987 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3988 rw.Header().Set("Content-Type", "x/wrong")
3989 },
3990 check: func(got, logs string) error {
3991 if !strings.Contains(got, "Content-Type: text/html") {
3992 return errors.New("wrong content-type; want html")
3993 }
3994 return nil
3995 },
3996 },
3997 {
3998 name: "explicit content-type wins",
3999 handler: func(rw ResponseWriter, r *Request) {
4000 rw.Header().Set("Content-Type", "some/type")
4001 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
4002 },
4003 check: func(got, logs string) error {
4004 if !strings.Contains(got, "Content-Type: some/type") {
4005 return errors.New("wrong content-type; want html")
4006 }
4007 return nil
4008 },
4009 },
4010 {
4011 name: "empty handler",
4012 handler: func(rw ResponseWriter, r *Request) {
4013 },
4014 check: func(got, logs string) error {
4015 if !strings.Contains(got, "Content-Length: 0") {
4016 return errors.New("want 0 content-length")
4017 }
4018 return nil
4019 },
4020 },
4021 {
4022 name: "only Header, no write",
4023 handler: func(rw ResponseWriter, r *Request) {
4024 rw.Header().Set("Some-Header", "some-value")
4025 },
4026 check: func(got, logs string) error {
4027 if !strings.Contains(got, "Some-Header") {
4028 return errors.New("didn't get header")
4029 }
4030 return nil
4031 },
4032 },
4033 {
4034 name: "WriteHeader call",
4035 handler: func(rw ResponseWriter, r *Request) {
4036 rw.WriteHeader(404)
4037 rw.Header().Set("Too-Late", "some-value")
4038 },
4039 check: func(got, logs string) error {
4040 if !strings.Contains(got, "404") {
4041 return errors.New("wrong status")
4042 }
4043 if strings.Contains(got, "Too-Late") {
4044 return errors.New("shouldn't have seen Too-Late")
4045 }
4046 return nil
4047 },
4048 },
4049 }
4050 for _, tc := range tests {
4051 ht := newHandlerTest(HandlerFunc(tc.handler))
4052 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
4053 logs := ht.logbuf.String()
4054 if err := tc.check(got, logs); err != nil {
4055 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
4056 }
4057 }
4058 }
4059
// errorListener is a fake listener. Each Accept call returns the next queued
// error (and no connection), in order; once the queue is empty, Accept
// returns io.EOF.
type errorListener struct {
	errs []error
}

// Accept pops and returns the next queued error, or io.EOF when none remain.
// It never yields a connection.
func (l *errorListener) Accept() (c net.Conn, err error) {
	if len(l.errs) > 0 {
		err, l.errs = l.errs[0], l.errs[1:]
		return nil, err
	}
	return nil, io.EOF
}

// Close is a no-op that always succeeds.
func (l *errorListener) Close() error {
	return nil
}
4076
4077 func (l *errorListener) Addr() net.Addr {
4078 return dummyAddr("test-address")
4079 }
4080
4081 func TestAcceptMaxFds(t *testing.T) {
4082 setParallel(t)
4083
4084 ln := &errorListener{[]error{
4085 &net.OpError{
4086 Op: "accept",
4087 Err: syscall.EMFILE,
4088 }}}
4089 server := &Server{
4090 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
4091 ErrorLog: log.New(io.Discard, "", 0),
4092 }
4093 err := server.Serve(ln)
4094 if err != io.EOF {
4095 t.Errorf("got error %v, want EOF", err)
4096 }
4097 }
4098
4099 func TestWriteAfterHijack(t *testing.T) {
4100 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4101 var buf strings.Builder
4102 wrotec := make(chan bool, 1)
4103 conn := &rwTestConn{
4104 Reader: bytes.NewReader(req),
4105 Writer: &buf,
4106 closec: make(chan bool, 1),
4107 }
4108 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4109 conn, bufrw, err := rw.(Hijacker).Hijack()
4110 if err != nil {
4111 t.Error(err)
4112 return
4113 }
4114 go func() {
4115 bufrw.Write([]byte("[hijack-to-bufw]"))
4116 bufrw.Flush()
4117 conn.Write([]byte("[hijack-to-conn]"))
4118 conn.Close()
4119 wrotec <- true
4120 }()
4121 })
4122 ln := &oneConnListener{conn: conn}
4123 go Serve(ln, handler)
4124 <-conn.closec
4125 <-wrotec
4126 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4127 t.Errorf("wrote %q; want %q", g, w)
4128 }
4129 }
4130
4131 func TestDoubleHijack(t *testing.T) {
4132 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4133 var buf bytes.Buffer
4134 conn := &rwTestConn{
4135 Reader: bytes.NewReader(req),
4136 Writer: &buf,
4137 closec: make(chan bool, 1),
4138 }
4139 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4140 conn, _, err := rw.(Hijacker).Hijack()
4141 if err != nil {
4142 t.Error(err)
4143 return
4144 }
4145 _, _, err = rw.(Hijacker).Hijack()
4146 if err == nil {
4147 t.Errorf("got err = nil; want err != nil")
4148 }
4149 conn.Close()
4150 })
4151 ln := &oneConnListener{conn: conn}
4152 go Serve(ln, handler)
4153 <-conn.closec
4154 }
4155
4156
4157
4158
4159
4160
4161
4162 func TestHTTP10ConnectionHeader(t *testing.T) {
4163 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4164 }
4165 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4166 mux := NewServeMux()
4167 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4168 ts := newClientServerTest(t, mode, mux).ts
4169
4170
4171 tests := []struct {
4172 req string
4173 expect []string
4174 }{
4175 {
4176 req: "GET / HTTP/1.0\r\n\r\n",
4177 expect: nil,
4178 },
4179 {
4180 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4181 expect: nil,
4182 },
4183 {
4184 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4185 expect: []string{"keep-alive"},
4186 },
4187 }
4188
4189 for _, tt := range tests {
4190 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4191 if err != nil {
4192 t.Fatal("dial err:", err)
4193 }
4194
4195 _, err = fmt.Fprint(conn, tt.req)
4196 if err != nil {
4197 t.Fatal("conn write err:", err)
4198 }
4199
4200 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4201 if err != nil {
4202 t.Fatal("ReadResponse err:", err)
4203 }
4204 conn.Close()
4205 resp.Body.Close()
4206
4207 got := resp.Header["Connection"]
4208 if !slices.Equal(got, tt.expect) {
4209 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4210 }
4211 }
4212 }
4213
4214
// TestServerReaderFromOrder has a handler stream its response from a pipe
// into the ResponseWriter (via io.Copy, which may use the writer's ReadFrom).
// Meanwhile the handler drains the 3 MiB request body. The test checks that
// the whole body is read and that the response is exactly "hi".
func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
func testServerReaderFromOrder(t *testing.T, mode testMode) {
	pr, pw := io.Pipe()
	const size = 3 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// An explicit Content-Type presumably keeps the response off the
		// content-sniffing path.
		rw.Header().Set("Content-Type", "text/plain")
		done := make(chan bool)
		go func() {
			io.Copy(rw, pr)
			close(done)
		}()
		// Give the copy goroutine a head start, so it is presumably already
		// blocked on pr before the handler touches the request body.
		time.Sleep(25 * time.Millisecond)
		n, err := io.Copy(io.Discard, req.Body)
		if err != nil {
			t.Errorf("handler Copy: %v", err)
			return
		}
		if n != size {
			t.Errorf("handler Copy = %d; want %d", n, size)
		}
		// Only now feed the response and end it.
		pw.Write([]byte("hi"))
		pw.Close()
		<-done
	}))

	req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
	if err != nil {
		t.Fatal(err)
	}
	res, err := cst.c.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if string(all) != "hi" {
		t.Errorf("Body = %q; want hi", all)
	}
}
4257
4258
4259 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4260 for _, code := range []int{StatusNotModified, StatusNoContent} {
4261 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4262 if r.URL.Path == "/header" {
4263 w.Header().Set("Content-Length", "123")
4264 }
4265 w.WriteHeader(code)
4266 if r.URL.Path == "/more" {
4267 w.Write([]byte("stuff"))
4268 }
4269 }))
4270 for _, req := range []string{
4271 "GET / HTTP/1.0",
4272 "GET /header HTTP/1.0",
4273 "GET /more HTTP/1.0",
4274 "GET / HTTP/1.1\nHost: foo",
4275 "GET /header HTTP/1.1\nHost: foo",
4276 "GET /more HTTP/1.1\nHost: foo",
4277 } {
4278 got := ht.rawResponse(req)
4279 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4280 if !strings.Contains(got, wantStatus) {
4281 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4282 } else if strings.Contains(got, "Content-Length") {
4283 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4284 } else if strings.Contains(got, "stuff") {
4285 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4286 }
4287 }
4288 }
4289 }
4290
4291 func TestContentTypeOkayOn204(t *testing.T) {
4292 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4293 w.Header().Set("Content-Length", "123")
4294 w.Header().Set("Content-Type", "foo/bar")
4295 w.WriteHeader(204)
4296 }))
4297 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4298 if !strings.Contains(got, "Content-Type: foo/bar") {
4299 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4300 }
4301 if strings.Contains(got, "Content-Length: 123") {
4302 t.Errorf("Response = %q; don't want a Content-Length", got)
4303 }
4304 }
4305
4306
4307
4308
4309
4310
4311
// TestTransportAndServerSharedBodyRace runs a proxy that forwards its inbound
// request body, unchanged, as the body of an outbound request to a backend.
// The proxy reads half of the backend's response and then cancels the
// outbound request. The test only requires that the original request
// completes; it presumably relies on the race detector to catch unsafe
// sharing of the body between the server and the transport.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// Retry with progressively longer RST avoidance delays.
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg tracks running backend handlers, so the deferred cleanup below
		// can wait for them before closing the backend.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			wg.Add(1)
			defer wg.Done()

			// Echo the request body back. Then block until the request's
			// context is done, presumably when the proxy cancels the
			// outbound request or the connection goes away.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))

		// Wait for backend handlers to return before closing the backend.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		// proxy is declared first so its handler can use proxy.c.
		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Reuse the inbound body as the outbound request body.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			// Close the backend response only at test cleanup, not when
			// this handler returns.
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request midway through the response, while
			// the transport may still be reading the shared request body.
			// HTTP/2 uses the Cancel channel; HTTP/1 uses the Transport's
			// CancelRequest.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4401
4402
4403
4404
// TestRequestBodyCloseDoesntBlock sends a POST that declares a 100000-byte
// body but never sends it, and keeps the client connection open. The handler
// starts a goroutine reading from the body and returns after 500ms. The test
// expects that read to end with a non-nil error rather than block forever.
func TestRequestBodyCloseDoesntBlock(t *testing.T) {
	run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
}
func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}

	readErrCh := make(chan error, 1)
	errCh := make(chan error, 2)

	server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		// This read outlives the handler.
		go func(body io.Reader) {
			_, err := body.Read(make([]byte, 100))
			readErrCh <- err
		}(req.Body)
		time.Sleep(500 * time.Millisecond)
	})).ts

	// closeConn keeps the client connection open until the test returns.
	closeConn := make(chan bool)
	defer close(closeConn)
	go func() {
		conn, err := net.Dial("tcp", server.Listener.Addr().String())
		if err != nil {
			errCh <- err
			return
		}
		defer conn.Close()
		_, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
		if err != nil {
			errCh <- err
			return
		}
		// Send no body bytes; just hold the connection open.
		<-closeConn
	}()
	select {
	case err := <-readErrCh:
		if err == nil {
			t.Error("Read was nil. Expected error.")
		}
	case err := <-errCh:
		t.Error(err)
	}
}
4451
4452
4453 func TestResponseWriterWriteString(t *testing.T) {
4454 okc := make(chan bool, 1)
4455 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4456 _, ok := w.(io.StringWriter)
4457 okc <- ok
4458 }))
4459 ht.rawResponse("GET / HTTP/1.0")
4460 select {
4461 case ok := <-okc:
4462 if !ok {
4463 t.Error("ResponseWriter did not implement io.StringWriter")
4464 }
4465 default:
4466 t.Error("handler was never called")
4467 }
4468 }
4469
// TestServerConnState checks the sequence of ConnState transitions the server
// reports for several connection lifecycles:
//   - keep-alive reuse;
//   - server- and client-requested close;
//   - hijacking, with and without a handler panic;
//   - a connection closed without a request;
//   - a malformed request;
//   - a raw connection closed after one response.
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
func testServerConnState(t *testing.T, mode testMode) {
	// Per-path handlers, dispatched by the server handler below.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog records the states observed for the connection under test.
	type stateLog struct {
		active   net.Conn // the tracked conn; set by the first StateNew seen
		got      []ConnState
		want     []ConnState
		complete chan<- struct{} // closed once got reaches len(want) or stops being a prefix of want
	}
	// activeLog holds the current stateLog. Receiving from it acts as a lock;
	// the receiver must send the log back when it is done.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want, runs doRequests, waits
	// until the ConnState hook signals completion, then compares got with want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// sl is deliberately not sent back. A later ConnState callback blocks
		// until the next stateLog is installed. There, a non-StateNew state
		// does not match the (nil) active conn and is reported as unexpected.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		// Discard server log output (e.g. from the intentional panic).
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			sl := <-activeLog
			if sl.active == nil && state == StateNew {
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Signal completion once enough states arrived, or as soon as
			// the sequence diverges. Nil out complete to avoid double close.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	defer func() {
		// Install an empty log so ConnState callbacks during shutdown do not
		// block on activeLog.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET with optional header name/value pairs and reads
	// the whole response body.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive reuse, then a server-requested close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive reuse, then a client-requested close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// A panic after hijacking must not produce any further state.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// A connection closed without sending anything.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request. The 1-byte Read waits for the server's reply
	// (or EOF) before the client closes.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One full request/response on a raw connection, then a client close
	// while idle.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4632
4633 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4634 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4635 }
4636 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4637 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4638 }), func(ts *httptest.Server) {
4639 ts.Config.SetKeepAlivesEnabled(false)
4640 }).ts
4641 res, err := ts.Client().Get(ts.URL)
4642 if err != nil {
4643 t.Fatal(err)
4644 }
4645 defer res.Body.Close()
4646 if !res.Close {
4647 t.Errorf("Body.Close == false; want true")
4648 }
4649 }
4650
4651
4652 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4653 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4654 var n int32
4655 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4656 atomic.AddInt32(&n, 1)
4657 }), optQuietLog)
4658 var wg sync.WaitGroup
4659 const reqs = 20
4660 for i := 0; i < reqs; i++ {
4661 wg.Add(1)
4662 go func() {
4663 defer wg.Done()
4664 res, err := cst.c.Get(cst.ts.URL)
4665 if err != nil {
4666
4667
4668 time.Sleep(10 * time.Millisecond)
4669 res, err = cst.c.Get(cst.ts.URL)
4670 if err != nil {
4671 t.Error(err)
4672 return
4673 }
4674 }
4675 defer res.Body.Close()
4676 _, err = io.Copy(io.Discard, res.Body)
4677 if err != nil {
4678 t.Error(err)
4679 return
4680 }
4681 }()
4682 }
4683 wg.Wait()
4684 if got := atomic.LoadInt32(&n); got != reqs {
4685 t.Errorf("handler ran %d times; want %d", got, reqs)
4686 }
4687 }
4688
4689 func TestServerConnStateNew(t *testing.T) {
4690 sawNew := false
4691 srv := &Server{
4692 ConnState: func(c net.Conn, state ConnState) {
4693 if state == StateNew {
4694 sawNew = true
4695 }
4696 },
4697 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4698 }
4699 srv.Serve(&oneConnListener{
4700 conn: &rwTestConn{
4701 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4702 Writer: io.Discard,
4703 },
4704 })
4705 if !sawNew {
4706 t.Error("StateNew not seen")
4707 }
4708 }
4709
4710 type closeWriteTestConn struct {
4711 rwTestConn
4712 didCloseWrite bool
4713 }
4714
4715 func (c *closeWriteTestConn) CloseWrite() error {
4716 c.didCloseWrite = true
4717 return nil
4718 }
4719
// TestCloseWrite checks that ExportCloseWriteAndWait on a server-side conn
// calls CloseWrite on an underlying connection that implements it.
func TestCloseWrite(t *testing.T) {
	// Keep the RST-avoidance wait short so the test stays fast.
	SetRSTAvoidanceDelay(t, time.Millisecond)

	var (
		srv      Server
		testConn closeWriteTestConn
	)
	ExportCloseWriteAndWait(ExportServerNewConn(&srv, &testConn))

	if !testConn.didCloseWrite {
		t.Error("didn't see CloseWrite call")
	}
}
4731
4732
4733
4734
4735
4736
4737
4738
func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }

// testServerFlushAndHijack checks that data flushed before a Hijack reaches
// the client, and that the handler can finish the chunked body by writing the
// remaining chunk framing itself over the hijacked connection.
func testServerFlushAndHijack(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, "Hello, ")
		w.(Flusher).Flush()
		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			// Without this check a failed Hijack would dereference nil below.
			t.Error(err)
			return
		}
		// Hand-write a 6-byte data chunk and the terminating zero-length chunk.
		buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
		if err := buf.Flush(); err != nil {
			t.Error(err)
		}
		if err := conn.Close(); err != nil {
			t.Error(err)
		}
	})).ts
	res, err := Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	all, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if want := "Hello, world!"; string(all) != want {
		t.Errorf("Got %q; want %q", all, want)
	}
}
4766
4767
4768
4769
4770
4771
4772
func TestServerKeepAliveAfterWriteError(t *testing.T) {
	run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
}

// testServerKeepAliveAfterWriteError uses a handler that outlives the server's
// WriteTimeout. It checks that every client request fails and that each one
// reaches the server from a distinct client address, i.e. the failed
// connection is not reused.
func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	const numReq = 3
	addrc := make(chan string, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrc <- r.RemoteAddr
		time.Sleep(500 * time.Millisecond)
		w.(Flusher).Flush()
	}), func(ts *httptest.Server) {
		ts.Config.WriteTimeout = 250 * time.Millisecond
	}).ts

	// Issue the requests sequentially; errc is closed once all have finished.
	errc := make(chan error, numReq)
	go func() {
		defer close(errc)
		for range numReq {
			res, err := Get(ts.URL)
			if res != nil {
				res.Body.Close()
			}
			errc <- err
		}
	}()

	seen := make(map[string]bool)
	succeeded := 0
	for {
		select {
		case addr := <-addrc:
			seen[addr] = true
		case err, more := <-errc:
			if more {
				if err == nil {
					succeeded++
				}
				continue
			}
			if len(seen) != numReq {
				t.Errorf("saw %d unique client addresses; want %d", len(seen), numReq)
			}
			if succeeded != 0 {
				t.Errorf("got %d successful client requests; want 0", succeeded)
			}
			return
		}
	}
}
4824
4825
4826
func TestNoContentLengthIfTransferEncoding(t *testing.T) {
	run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
}

// testNoContentLengthIfTransferEncoding checks that when a handler sets its own
// Transfer-Encoding header, the response head contains neither a
// Content-Length nor a Content-Type header.
func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "foo")
		io.WriteString(w, "<html>")
	})).ts
	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
		t.Fatal(err)
	}

	// Collect the status line and header lines, stopping at the first blank line.
	var head strings.Builder
	sc := bufio.NewScanner(c)
	for sc.Scan() {
		line := sc.Text()
		if strings.TrimSpace(line) == "" {
			break
		}
		head.WriteString(line)
		head.WriteByte('\n')
	}
	if err := sc.Err(); err != nil {
		t.Fatal(err)
	}

	got := head.String()
	if strings.Contains(got, "Content-Length") {
		t.Errorf("Unexpected Content-Length in response headers: %s", got)
	}
	if strings.Contains(got, "Content-Type") {
		t.Errorf("Unexpected Content-Type in response headers: %s", got)
	}
}
4862
4863
4864
// TestTolerateCRLFBeforeRequestLine checks that extra CRLFs between two
// requests on one connection are skipped, so both requests reach the handler.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	const (
		first  = "POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC"
		extra  = "\r\n\r\n"
		second = "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n"
	)
	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader([]byte(first + extra + second)),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		served++
	}))
	<-conn.closec

	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
		t.Logf("Res: %s", out.Bytes())
	}
}
4886
// TestIssue13893_Expect100 checks that the server passes the Expect request
// header through to the handler rather than filtering it out.
func TestIssue13893_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorld

`)
	var buf bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, present := r.Header["Expect"]; !present {
			t.Error("Expect header should not be filtered out")
		}
	})
	go Serve(&oneConnListener{conn: conn}, handler)
	<-conn.closec
}
4913
// TestIssue11549_Expect100 sends two "Expect: 100-continue" requests and then a
// third request on one connection. The handler reads the body of the first
// request but not the second. The test expects exactly two handler calls and a
// "Connection: close" header somewhere in the output.
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var buf strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &buf,
		closec: make(chan bool, 1),
	}
	numReq := 0
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		numReq++
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	})
	go Serve(&oneConnListener{conn: conn}, handler)
	<-conn.closec

	if numReq != 2 {
		t.Errorf("num requests = %d; want 2", numReq)
	}
	if out := buf.String(); !strings.Contains(out, "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", out)
	}
}
4956
4957
4958
// TestHandlerFinishSkipBigContentLengthRead checks that when a handler ignores
// a request body with a huge declared Content-Length, the server reads no more
// of the connection after the handler returns.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
	setParallel(t)
	conn := newTestConn()
	conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Content-Length: 9999999999\r\n" +
		"\r\n" + strings.Repeat("a", 1<<20))

	// Snapshot the unread byte count from inside the handler.
	var inHandlerLen int
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		inHandlerLen = conn.readBuf.Len()
		rw.WriteHeader(404)
	}))
	<-conn.closec

	if afterHandlerLen := conn.readBuf.Len(); afterHandlerLen != inHandlerLen {
		t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
	}
}
4981
func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }

// testHandlerSetsBodyNil checks that a handler setting r.Body = nil does not
// stop the next request from reusing the same client connection. The handler
// echoes r.RemoteAddr, so equal responses mean the same connection.
func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = nil
		fmt.Fprintf(w, "%v", r.RemoteAddr)
	}))

	fetch := func() string {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		defer res.Body.Close()
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		return string(body)
	}

	first := fetch()
	second := fetch()
	if first != second {
		t.Errorf("Failed to reuse connections between requests: %v vs %v", first, second)
	}
}
5005
5006
5007
// TestServerValidatesHostHeader checks the server's validation of the request
// protocol version and Host header. For each table entry it sends one raw
// request and compares the response status code with the expected one.
func TestServerValidatesHostHeader(t *testing.T) {
	tests := []struct {
		proto string
		host  string
		want  int
	}{
		{"HTTP/0.9", "", 505},

		// HTTP/1.1 needs exactly one Host header (it may be empty) with no
		// invalid bytes in it.
		{"HTTP/1.1", "", 400},
		{"HTTP/1.1", "Host: \r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: foo.com\r\n", 200},
		{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
		{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
		{"HTTP/1.1", "Host: ::1\r\n", 200},
		{"HTTP/1.1", "Host: [::1]\r\n", 200},
		{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
		{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: \x06\r\n", 400},
		{"HTTP/1.1", "Host: \xff\r\n", 400},
		{"HTTP/1.1", "Host: {\r\n", 400},
		{"HTTP/1.1", "Host: }\r\n", 400},
		{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},

		// HTTP/1.0 may omit Host, but a Host header that is present is
		// still validated.
		{"HTTP/1.0", "", 200},
		{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
		{"HTTP/1.0", "Host: \xff\r\n", 400},

		// A "PRI * HTTP/2.0" request line without Host is accepted...
		{"PRI * HTTP/2.0", "", 200},

		// ...and so is a CONNECT request without Host...
		{"CONNECT golang.org:443 HTTP/1.1", "", 200},

		// ...but other HTTP/2+ request lines are rejected.
		{"PRI / HTTP/2.0", "", 505},
		{"GET / HTTP/2.0", "", 505},
		{"GET / HTTP/3.0", "", 505},
	}
	for _, tt := range tests {
		conn := newTestConn()
		// A bare "HTTP/x.y" proto gets a "GET / " prefix; other entries
		// already hold a complete request line.
		methodTarget := "GET / "
		if !strings.HasPrefix(tt.proto, "HTTP/") {
			methodTarget = ""
		}
		io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the error itself; res is nil on failure.
			t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
		}
	}
}
5075
func TestServerHandlersCanHandleH2PRI(t *testing.T) {
	run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
}

// testServerHandlersCanHandleH2PRI checks that on an HTTP/1 server a
// "PRI * HTTP/2.0" request reaches the handler. The handler must be able to
// hijack the connection, read the remaining "SM\r\n\r\n" preface bytes from
// the hijacked reader, and reply on the raw connection.
func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
	const upgradeResponse = "upgrade here"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, br, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		if r.Method != "PRI" || r.RequestURI != "*" {
			t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
			return
		}
		// The PRI request is expected to be marked as closing the connection.
		if !r.Close {
			t.Errorf("Request.Close = false; want true")
		}
		const want = "SM\r\n\r\n"
		buf := make([]byte, len(want))
		n, err := io.ReadFull(br, buf)
		if err != nil || string(buf[:n]) != want {
			t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
			return
		}
		io.WriteString(conn, upgradeResponse)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); err != nil {
		t.Fatal(err)
	}
	slurp, err := io.ReadAll(c)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != upgradeResponse {
		t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
	}
}
5119
5120
5121
// TestServerValidatesHeaders sends one GET request per table entry, each with
// an extra raw header block. It checks the status code returned for valid
// and malformed header names and values.
func TestServerValidatesHeaders(t *testing.T) {
	setParallel(t)
	tests := []struct {
		header string
		want   int
	}{
		{"", 200},
		{"Foo: bar\r\n", 200},
		{"X-Foo: bar\r\n", 200},
		{"Foo: a space\r\n", 200},

		// Invalid bytes in header names, and an oversized header.
		{"A space: foo\r\n", 400},
		{"foo\xffbar: foo\r\n", 400},
		{"foo\x00bar: foo\r\n", 400},
		{"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},

		// Whitespace between the name and the colon is rejected.
		{"Foo : bar\r\n", 400},
		{"Foo\t: bar\r\n", 400},

		// An empty header name is rejected.
		{": empty key\r\n", 400},

		// A non-numeric Content-Length is rejected, even alongside
		// Transfer-Encoding: chunked.
		{"Content-Length: notdigits\r\n", 400},
		{"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},

		// Header values: spaces, tabs and 0xff pass; NUL and DEL do not.
		{"foo: foo foo\r\n", 200},
		{"foo: foo\tfoo\r\n", 200},
		{"foo: foo\x00foo\r\n", 400},
		{"foo: foo\x7ffoo\r\n", 400},
		{"foo: foo\xfffoo\r\n", 200},
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the error itself; res is nil on failure.
			t.Errorf("For %q, ReadResponse: %v", tt.header, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
		}
	}
}
5179
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
	run(t, testServerRequestContextCancel_ServeHTTPDone)
}

// testServerRequestContextCancel_ServeHTTPDone checks two things about a
// request's context: it is still live while ServeHTTP runs, and it is done
// once the response has been received and closed.
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
	ctxc := make(chan context.Context, 1)
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		ctx := r.Context()
		select {
		case <-ctx.Done():
			t.Error("should not be Done in ServeHTTP")
		default:
		}
		ctxc <- ctx
	})
	cst := newClientServerTest(t, mode, handler)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()

	reqCtx := <-ctxc
	select {
	case <-reqCtx.Done():
	default:
		t.Error("context should be done after ServeHTTP completes")
	}
}
5206
5207
5208
5209
5210
func TestServerRequestContextCancel_ConnClose(t *testing.T) {
	run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
}

// testServerRequestContextCancel_ConnClose checks that a running handler's
// request context is canceled when the client closes the connection.
func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
	inHandler, handlerDone := make(chan struct{}), make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(inHandler)
		<-r.Context().Done()
		close(handlerDone)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")

	<-inHandler
	c.Close() // hang up while the handler is blocked on its context
	<-handlerDone
}
5232
func TestServerContext_ServerContextKey(t *testing.T) {
	run(t, testServerContext_ServerContextKey)
}

// testServerContext_ServerContextKey checks that the request context stores a
// *Server under ServerContextKey.
func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		got := r.Context().Value(ServerContextKey)
		if _, isServer := got.(*Server); !isServer {
			t.Errorf("context value = %T; want *http.Server", got)
		}
	})
	cst := newClientServerTest(t, mode, handler)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
5250
func TestServerContext_LocalAddrContextKey(t *testing.T) {
	run(t, testServerContext_LocalAddrContextKey)
}

// testServerContext_LocalAddrContextKey checks that the request context holds a
// net.Addr under LocalAddrContextKey whose string form matches the test
// listener's address.
func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
	ch := make(chan any, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ch <- r.Context().Value(LocalAddrContextKey)
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// Always close response bodies, even for HEAD, so the transport can
	// release the connection.
	res.Body.Close()

	host := cst.ts.Listener.Addr().String()
	got := <-ch
	if addr, ok := got.(net.Addr); !ok {
		t.Errorf("local addr value = %T; want net.Addr", got)
	} else if fmt.Sprint(addr) != host {
		t.Errorf("local addr = %v; want %v", addr, host)
	}
}
5271
5272
// TestHandlerSetTransferEncodingChunked checks that when a handler sets
// "Transfer-Encoding: chunked" itself, the header appears exactly once in the
// raw response.
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	const hdr = "Transfer-Encoding: chunked"
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Write([]byte("hello"))
	}))
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")

	if n := strings.Count(resp, hdr); n != 1 {
		t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
	}
}
5286
5287
// TestHandlerSetTransferEncodingGzip checks that when a handler sets
// "Transfer-Encoding: gzip" and writes a gzip stream, the raw response
// contains "Transfer-Encoding: gzip" and "Transfer-Encoding: chunked" exactly
// once each.
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		zw.Write([]byte("hello"))
		zw.Close()
	}))
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")

	for _, hdr := range []string{"Transfer-Encoding: gzip", "Transfer-Encoding: chunked"} {
		if n := strings.Count(resp, hdr); n != 1 {
			t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
		}
	}
}
5305
func BenchmarkClientServer(b *testing.B) {
	run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
}

// benchmarkClientServer times sequential GET round trips from the test
// client to the test server and verifies every response body.
func benchmarkClientServer(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.StartTimer()

	client := ts.Client()
	for range b.N {
		res, err := client.Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		all, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if body := string(all); body != "Hello world.\n" {
			b.Fatal("Got body:", body)
		}
	}

	b.StopTimer()
}
5336
// BenchmarkClientServerParallel runs benchmarkClientServerParallel as one
// sub-benchmark per parallelism level, each across the HTTP/1, HTTPS/1 and
// HTTP/2 modes.
func BenchmarkClientServerParallel(b *testing.B) {
	for _, p := range []int{4, 64} {
		b.Run(fmt.Sprint(p), func(b *testing.B) {
			run(b, func(b *testing.B, mode testMode) {
				benchmarkClientServerParallel(b, p, mode)
			}, []testMode{http1Mode, https1Mode, http2Mode})
		})
	}
}
5346
// benchmarkClientServerParallel times concurrent GET round trips, using
// b.SetParallelism(parallelism). Get and read errors are logged and skipped.
// An unexpected body panics, because RunParallel goroutines must not call
// b.Fatal.
func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
	b.ReportAllocs()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.ResetTimer()
	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			res, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if body := string(all); body != "Hello world.\n" {
				panic("Got body: " + body)
			}
		}
	})
}
5375
5376
5377
5378
5379
5380
5381
5382
5383
5384
5385 func BenchmarkServer(b *testing.B) {
5386 b.ReportAllocs()
5387
5388 if url := os.Getenv("GO_TEST_BENCH_SERVER_URL"); url != "" {
5389 n, err := strconv.Atoi(os.Getenv("GO_TEST_BENCH_CLIENT_N"))
5390 if err != nil {
5391 panic(err)
5392 }
5393 for i := 0; i < n; i++ {
5394 res, err := Get(url)
5395 if err != nil {
5396 log.Panicf("Get: %v", err)
5397 }
5398 all, err := io.ReadAll(res.Body)
5399 res.Body.Close()
5400 if err != nil {
5401 log.Panicf("ReadAll: %v", err)
5402 }
5403 body := string(all)
5404 if body != "Hello world.\n" {
5405 log.Panicf("Got body: %q", body)
5406 }
5407 }
5408 os.Exit(0)
5409 return
5410 }
5411
5412 var res = []byte("Hello world.\n")
5413 b.StopTimer()
5414 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5415 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5416 rw.Write(res)
5417 }))
5418 defer ts.Close()
5419 b.StartTimer()
5420
5421 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5422 cmd.Env = append([]string{
5423 fmt.Sprintf("GO_TEST_BENCH_CLIENT_N=%d", b.N),
5424 fmt.Sprintf("GO_TEST_BENCH_SERVER_URL=%s", ts.URL),
5425 }, os.Environ()...)
5426 out, err := cmd.CombinedOutput()
5427 if err != nil {
5428 b.Errorf("Test failure: %v, with output: %s", err, out)
5429 }
5430 }
5431
5432
// getNoBody issues a GET to urlStr and closes the response body without
// reading it. It returns the response, whose Body is then closed, or the
// request error.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5441
5442
5443
5444 func BenchmarkClient(b *testing.B) {
5445 var data = []byte("Hello world.\n")
5446
5447 url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
5448 w.Header().Set("Content-Type", "text/html; charset=utf-8")
5449 w.Write(data)
5450 }))
5451
5452
5453 b.StartTimer()
5454 for i := 0; i < b.N; i++ {
5455 res, err := Get(url)
5456 if err != nil {
5457 b.Fatalf("Get: %v", err)
5458 }
5459 body, err := io.ReadAll(res.Body)
5460 res.Body.Close()
5461 if err != nil {
5462 b.Fatalf("ReadAll: %v", err)
5463 }
5464 if !bytes.Equal(body, data) {
5465 b.Fatalf("Got body: %q", body)
5466 }
5467 }
5468 b.StopTimer()
5469 }
5470
// startClientBenchmarkServer runs handler in a child copy of the current test
// binary and returns the child's base URL, leaving the benchmark timer stopped.
//
// When GO_TEST_BENCH_SERVER is set, this process is the child. It listens on
// localhost (port from GO_TEST_BENCH_SERVER_PORT, or any free port), prints
// the listener address on stdout, and serves handler. It exits when a request
// carries a non-empty "stop" form value. In this mode the function never
// returns.
//
// Otherwise it starts the child, reads the printed address, and probes it
// once. It then registers a cleanup that asks the child to stop and waits
// for it to exit. If setup fails after the child has started, the child is
// killed before the benchmark fails, so it is not left running.
func startClientBenchmarkServer(b *testing.B, handler Handler) string {
	b.ReportAllocs()
	b.StopTimer()

	if server := os.Getenv("GO_TEST_BENCH_SERVER"); server != "" {
		// Child process: serve until told to stop.
		port := os.Getenv("GO_TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			log.Fatal(err)
		}
		// The parent reads this first stdout line to learn our address.
		fmt.Println(ln.Addr().String())

		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			handler.ServeHTTP(w, r)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent process: re-run this benchmark in a child with GO_TEST_BENCH_SERVER set.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^"+b.Name()+"$")
	cmd.Env = append(cmd.Environ(), "GO_TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		cancel()
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		cancel()
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// done delivers cmd.Wait's result once and is then closed, so any later
	// receive returns immediately.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()

	// abort kills the child via ctx and waits for it to be reaped. It is
	// used on setup failures so the child is not leaked.
	abort := func() {
		cancel()
		<-done
	}

	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		abort()
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	serverURL := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(serverURL); err != nil {
		abort()
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	b.Cleanup(func() {
		// A CancelFunc may be called repeatedly; the defer also covers the
		// Fatalf path below, which previously skipped cancel.
		defer cancel()
		getNoBody(serverURL + "?stop=yes")
		if err := <-done; err != nil {
			b.Fatalf("subprocess failed: %v", err)
		}
		cancel()
		<-done

		afterTest(b)
	})

	return serverURL
}
5543
// BenchmarkClientGzip times client GETs of a 1 MiB random payload that a
// child-process server sends gzip-compressed. Each iteration checks that the
// body read back has the uncompressed size.
func BenchmarkClientGzip(b *testing.B) {
	const responseSize = 1024 * 1024

	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	if _, err := io.CopyN(zw, crand.Reader, responseSize); err != nil {
		b.Fatal(err)
	}
	zw.Close()
	data := compressed.Bytes()

	target := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(data)
	}))

	// startClientBenchmarkServer leaves the timer stopped; time only the loop.
	b.StartTimer()
	for range b.N {
		res, err := Get(target)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		n, err := io.Copy(io.Discard, res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if n != responseSize {
			b.Fatalf("ReadAll: expected %d bytes, got %d", responseSize, n)
		}
	}
	b.StopTimer()
}
5579
// BenchmarkServerFakeConnNoKeepAlive serves one HTTP/1.0 request per
// iteration over a reused in-memory testConn, calling Serve afresh each time.
func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.0
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	res := []byte("Hello world!\n")

	conn := newTestConn()
	ln := new(oneConnListener)
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	})
	for range b.N {
		// Reset both directions and re-arm the listener with the same conn.
		conn.readBuf.Reset()
		conn.writeBuf.Reset()
		conn.readBuf.Write(req)
		ln.conn = conn
		Serve(ln, handler)
		<-conn.closec
	}
}
5607
5608
5609 type repeatReader struct {
5610 content []byte
5611 count int
5612 off int
5613 }
5614
5615 func (r *repeatReader) Read(p []byte) (n int, err error) {
5616 if r.count <= 0 {
5617 return 0, io.EOF
5618 }
5619 n = copy(p, r.content[r.off:])
5620 r.off += n
5621 if r.off == len(r.content) {
5622 r.count--
5623 r.off = 0
5624 }
5625 return
5626 }
5627
// BenchmarkServerFakeConnWithKeepAlive serves b.N identical HTTP/1.1 requests
// read back-to-back from one in-memory connection. It checks that the handler
// ran once per request.
func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	res := []byte("Hello world!\n")

	handled := 0
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	})
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	go Serve(&oneConnListener{conn: conn}, handler)
	<-conn.closec

	if handled != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5659
5660
5661
// BenchmarkServerFakeConnWithKeepAliveLite is like
// BenchmarkServerFakeConnWithKeepAlive but with a minimal request and a handler
// that sets no headers.
func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	res := []byte("Hello world!\n")

	handled := 0
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		rw.Write(res)
	})
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	go Serve(&oneConnListener{conn: conn}, handler)
	<-conn.closec

	if handled != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5687
5688 const someResponse = "<html>some response</html>"
5689
5690
5691 var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5692
5693
// BenchmarkServerHandlerTypeLen benchmarks a handler that sets both
// Content-Type and Content-Length before writing the body.
func BenchmarkServerHandlerTypeLen(b *testing.B) {
	benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
		h := w.Header()
		h.Set("Content-Type", "text/html")
		h.Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	}))
}
5701
5702
// BenchmarkServerHandlerNoLen benchmarks a handler that sets Content-Type but
// leaves Content-Length to the server.
func BenchmarkServerHandlerNoLen(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5709
5710
// BenchmarkServerHandlerNoType benchmarks a handler that sets Content-Length
// but leaves Content-Type to the server.
func BenchmarkServerHandlerNoType(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5717
5718
// BenchmarkServerHandlerNoHeader benchmarks a handler that writes the body
// without setting any headers.
func BenchmarkServerHandlerNoHeader(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5724
// benchmarkHandler serves b.N copies of a minimal GET request, read
// back-to-back from one in-memory connection. Each request goes to h, and the
// function reports an error unless h ran exactly b.N times.
func benchmarkHandler(b *testing.B, h Handler) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	handled := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		handled++
		h.ServeHTTP(rw, r)
	}))
	<-conn.closec

	if handled != b.N {
		b.Errorf("b.N=%d but handled %d", b.N, handled)
	}
}
5747
// BenchmarkServerHijack serves one request per iteration with a handler that
// immediately hijacks the connection and closes it.
func BenchmarkServerHijack(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	hijackAndClose := HandlerFunc(func(w ResponseWriter, r *Request) {
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			panic(err)
		}
		c.Close()
	})
	conn := &rwTestConn{
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	for range b.N {
		// Give the reused conn a fresh request and re-arm the listener.
		conn.Reader = bytes.NewReader(req)
		ln.conn = conn
		Serve(ln, hijackAndClose)
		<-conn.closec
	}
}
5772
func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }

// benchmarkCloseNotifier measures, per iteration, the time from the client
// sending a request and closing its connection to a handler blocked on
// CloseNotify observing that close.
func benchmarkCloseNotifier(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	sawClose := make(chan bool)
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		sawClose <- true
	})).ts
	addr := ts.Listener.Addr().String()
	b.StartTimer()
	for range b.N {
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		conn.Close()
		<-sawClose
	}
	b.StopTimer()
}
5797
5798
5799 func TestConcurrentServerServe(t *testing.T) {
5800 setParallel(t)
5801 for i := 0; i < 100; i++ {
5802 ln1 := &oneConnListener{conn: nil}
5803 ln2 := &oneConnListener{conn: nil}
5804 srv := Server{}
5805 go func() { srv.Serve(ln1) }()
5806 go func() { srv.Serve(ln2) }()
5807 }
5808 }
5809
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }

// testServerIdleTimeout checks the server's IdleTimeout and ReadHeaderTimeout.
// Back-to-back requests share a connection. A request made after the idle
// timeout uses a new connection. A connection that never finishes sending its
// headers gets cut off. runTimeSensitiveTest retries with longer timeouts.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	timeouts := []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}
	runTimeSensitiveTest(t, timeouts, func(t *testing.T, readHeaderTimeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		client := ts.Client()

		// remoteAddr fetches ts.URL. The handler echoes the client's address
		// as the server saw it, which identifies the connection used.
		remoteAddr := func() (string, error) {
			res, err := client.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			if err != nil {
				// The headers have already arrived, so this is presumably not
				// one of the timeouts under test; treat it as a hard failure.
				t.Fatal(err)
			}
			return string(body), nil
		}

		first, err := remoteAddr()
		if err != nil {
			return err
		}
		second, err := remoteAddr()
		if err != nil {
			return err
		}
		if first != second {
			return fmt.Errorf("did requests on different connections")
		}
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		third, err := remoteAddr()
		if err != nil {
			return err
		}
		if second == third {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send an unterminated header block, wait past ReadHeaderTimeout, and
		// expect the server to have closed the connection.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5885
// get fetches url with c and returns the complete response body as a string.
// A request failure or a body read failure aborts the test via t.Fatal.
func get(t *testing.T, c *Client, url string) string {
	resp, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	var body []byte
	if body, err = io.ReadAll(resp.Body); err != nil {
		t.Fatal(err)
	}
	return string(body)
}
5898
5899
5900
// TestServerSetKeepAlivesEnabledClosesConns checks, for HTTP/1 only, that
// calling Server.SetKeepAlivesEnabled(false) makes the client transport's
// idle keep-alive connection go away.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
}
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	// The handler echoes r.RemoteAddr so the test can tell which connection
	// served each request.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// The local get is not in scope until this statement ends, so the call
	// inside the closure refers to the package-level get helper.
	get := func() string { return get(t, c, ts.URL) }

	a1, a2 := get(), get()
	if a1 == a2 {
		t.Logf("made two requests from a single conn %q (as expected)", a1)
	} else {
		t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
	}

	// The two sequential requests should have shared one connection, so the
	// transport should now hold exactly one idle connection.
	if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
	}

	// Disabling keep-alives on the server should close that idle connection.
	ts.Config.SetKeepAlivesEnabled(false)

	// Poll until the transport no longer reports any idle connections.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
			// NOTE(review): d appears to be the time waited so far (zero on
			// the first check); confirm against waitCondition.
			if d > 0 {
				t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
			}
			return false
		}
		return true
	})
}
5946
// TestServerShutdown checks Server.Shutdown when it is called while a request
// is in flight. The listener should be closed by the time the
// RegisterOnShutdown hooks run. The in-flight request should still complete,
// and Shutdown should return nil.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }
func testServerShutdown(t *testing.T, mode testMode) {
	// cst is assigned below; the handler reads it only when a request
	// arrives, which is after the assignment.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// Only the first request snapshots the connection states and starts
		// Shutdown in the background.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until the RegisterOnShutdown hook has run; by then the
			// listener is expected to be closed.
			<-gotOnShutdown

			// New requests should now fail. Keep trying until one fails (or
			// the test has already failed). For HTTP/2 a request may still
			// succeed (go.dev/issue/59038), so log and retry instead of
			// failing.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		// The in-flight request still gets a normal response.
		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// While the first request was being handled, exactly one connection
	// should have been active.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
6007
// TestServerShutdownStateNew checks how Shutdown treats a connection that was
// accepted but never sent a request. Shutdown should not complete, and the
// connection should stay open, until about expectTimeout has elapsed; after
// that, Shutdown should finish and the server should close the connection.
// The test runs under synctest, so the sleeps below advance a fake clock.
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Not expected to run: the client below never sends a request.
	}), func(ts *httptest.Server) {
		// Serve on the fake listener instead of the default one.
		ts.Listener.Close()
		ts.Listener = listener
		// NOTE(review): server error logs are discarded, presumably to keep
		// shutdown-related log noise out of the test output; confirm.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Open a connection but never write to it, so the server only ever sees
	// it as a new connection.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	// Start Shutdown in the background.
	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// NOTE(review): 5s presumably matches the server's grace period for
	// never-used connections during Shutdown; confirm in server.go.
	const expectTimeout = 5 * time.Second

	// Just before the deadline, Shutdown should still be running and the
	// connection should still be open.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Well past the deadline, Shutdown should have returned nil and the
	// server should have closed the connection.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
6063
6064
// TestServerCloseDeadlock checks that Close can be called twice on a Server
// that was never started; the test hangs if either call blocks.
func TestServerCloseDeadlock(t *testing.T) {
	srv := new(Server)
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
6070
6071
6072
// TestServerKeepAlivesEnabled checks behavior after
// SetKeepAlivesEnabled(false). Every request should use a fresh connection:
// the client's GotConn trace should fire once per request and never report a
// reused or previously idle connection. The server should also report all of
// its connections idle before each request.
func TestServerKeepAlivesEnabled(t *testing.T) { runSynctest(t, testServerKeepAlivesEnabled) }
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optFakeNet)
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)
	for try := range 2 {
		// Let the server settle before checking its connection states.
		synctest.Wait()
		if !srv.ExportAllConnsIdle() {
			t.Fatalf("test server still has active conns before request %v", try)
		}
		// Record how many times, and with what info, the client obtains a
		// connection for this request.
		conns := 0
		var info httptrace.GotConnInfo
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(v httptrace.GotConnInfo) {
				conns++
				info = v
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		if conns != 1 {
			t.Fatalf("request %v: got %v conns, want 1", try, conns)
		}
		if info.Reused || info.WasIdle {
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
		}
	}
}
6109
6110
6111
6112
// TestServerCancelsReadTimeoutWhenIdle checks that Server.ReadTimeout does not
// cut off a request whose handler keeps running past the timeout after the
// request has been read. The handler waits 2*timeout and must still reply
// "ok" rather than observing a canceled request context. Because the check
// is timing sensitive, it is retried with progressively longer timeouts.
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		// Reply "ok" after 2*timeout, or the context error if the request
		// context ends first.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		}), func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})
		defer cst.close()
		ts := cst.ts

		// Allow only one Proxy lookup. NOTE(review): this relies on Transport
		// consulting Proxy once per round-trip attempt, so a silent retry on a
		// new connection shows up as an error instead of masking a
		// server-side timeout.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		c := ts.Client()

		res, err := c.Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		if string(slurp) != "ok" {
			return fmt.Errorf("got: %q, want ok", slurp)
		}
		return nil
	})
}
6161
6162
6163
6164
6165 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
6166 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
6167 }
6168 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
6169 runTimeSensitiveTest(t, []time.Duration{
6170 10 * time.Millisecond,
6171 50 * time.Millisecond,
6172 250 * time.Millisecond,
6173 time.Second,
6174 2 * time.Second,
6175 }, func(t *testing.T, timeout time.Duration) error {
6176 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
6177 ts.Config.ReadHeaderTimeout = timeout
6178 ts.Config.IdleTimeout = 0
6179 })
6180 defer cst.close()
6181 ts := cst.ts
6182
6183
6184
6185 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6186 if err != nil {
6187 t.Fatalf("dial failed: %v", err)
6188 }
6189 br := bufio.NewReader(conn)
6190 defer conn.Close()
6191
6192 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6193 return fmt.Errorf("writing first request failed: %v", err)
6194 }
6195
6196 if _, err := ReadResponse(br, nil); err != nil {
6197 return fmt.Errorf("first response (before timeout) failed: %v", err)
6198 }
6199
6200
6201
6202 time.Sleep(timeout * 3 / 2)
6203
6204 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6205 return fmt.Errorf("writing second request failed: %v", err)
6206 }
6207
6208 if _, err := ReadResponse(br, nil); err != nil {
6209 return fmt.Errorf("second response (after timeout) failed: %v", err)
6210 }
6211
6212 return nil
6213 })
6214 }
6215
6216
6217
// runTimeSensitiveTest calls test with each duration in order, stopping at
// the first call that returns nil. A non-nil error is logged and the next,
// presumably longer, duration is tried. The failure becomes fatal when it
// comes from the last duration, or when the test has already been marked
// failed.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for i, d := range durations {
		err := test(t, d)
		switch {
		case err == nil:
			return
		case i == last, t.Failed():
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6230
6231
6232
6233 func TestServerDuplicateBackgroundRead(t *testing.T) {
6234 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6235 }
6236 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6237 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6238 testenv.SkipFlaky(t, 24826)
6239 }
6240
6241 goroutines := 5
6242 requests := 2000
6243 if testing.Short() {
6244 goroutines = 3
6245 requests = 100
6246 }
6247
6248 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6249
6250 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6251
6252 var wg sync.WaitGroup
6253 for i := 0; i < goroutines; i++ {
6254 wg.Add(1)
6255 go func() {
6256 defer wg.Done()
6257 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6258 if err != nil {
6259 t.Error(err)
6260 return
6261 }
6262 defer cn.Close()
6263
6264 wg.Add(1)
6265 go func() {
6266 defer wg.Done()
6267 io.Copy(io.Discard, cn)
6268 }()
6269
6270 for j := 0; j < requests; j++ {
6271 if t.Failed() {
6272 return
6273 }
6274 _, err := cn.Write(reqBytes)
6275 if err != nil {
6276 t.Error(err)
6277 return
6278 }
6279 }
6280 }()
6281 }
6282 wg.Wait()
6283 }
6284
6285
6286
6287
6288
6289
// TestServerHijackGetsBackgroundByte checks bytes that the client sends after
// the request has been read. NOTE(review): these bytes may already have been
// picked up by the server's background read. They must still be readable
// through the hijacked connection's bufio.Reader. Hijacking must also not
// cancel the request context.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
}
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Tell the client the handler is running, so it sends the extra
		// bytes only after the request has been read.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		// The extra bytes must be visible through the hijacked reader.
		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close the write side so the server reads EOF after "foo".
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-done
}
6342
6343
6344 func TestServerHijackGetsFullBody(t *testing.T) {
6345 run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
6346 }
6347 func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
6348 if runtime.GOOS == "plan9" {
6349 t.Skip("skipping test; see https://golang.org/issue/18657")
6350 }
6351 done := make(chan struct{})
6352 needle := strings.Repeat("x", 100*1024)
6353 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6354 defer close(done)
6355
6356 conn, buf, err := w.(Hijacker).Hijack()
6357 if err != nil {
6358 t.Error(err)
6359 return
6360 }
6361 defer conn.Close()
6362
6363 got := make([]byte, len(needle))
6364 n, err := io.ReadFull(buf.Reader, got)
6365 if n != len(needle) || string(got) != needle || err != nil {
6366 t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
6367 }
6368 })).ts
6369
6370 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6371 if err != nil {
6372 t.Fatal(err)
6373 }
6374 defer cn.Close()
6375 buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6376 buf = append(buf, []byte(needle)...)
6377 if _, err := cn.Write(buf); err != nil {
6378 t.Fatal(err)
6379 }
6380
6381 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6382 t.Fatal(err)
6383 }
6384 <-done
6385 }
6386
6387
6388
6389
6390 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6391 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6392 }
6393 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6394 if runtime.GOOS == "plan9" {
6395 t.Skip("skipping test; see https://golang.org/issue/18657")
6396 }
6397 done := make(chan struct{})
6398 const size = 8 << 10
6399 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6400 defer close(done)
6401
6402 conn, buf, err := w.(Hijacker).Hijack()
6403 if err != nil {
6404 t.Error(err)
6405 return
6406 }
6407 defer conn.Close()
6408 slurp, err := io.ReadAll(buf.Reader)
6409 if err != nil {
6410 t.Errorf("Copy: %v", err)
6411 }
6412 allX := true
6413 for _, v := range slurp {
6414 if v != 'x' {
6415 allX = false
6416 }
6417 }
6418 if len(slurp) != size {
6419 t.Errorf("read %d; want %d", len(slurp), size)
6420 } else if !allX {
6421 t.Errorf("read %q; want %d 'x'", slurp, size)
6422 }
6423 })).ts
6424
6425 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6426 if err != nil {
6427 t.Fatal(err)
6428 }
6429 defer cn.Close()
6430 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6431 strings.Repeat("x", size)); err != nil {
6432 t.Fatal(err)
6433 }
6434 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6435 t.Fatal(err)
6436 }
6437
6438 <-done
6439 }
6440
6441
6442 func TestServerValidatesMethod(t *testing.T) {
6443 tests := []struct {
6444 method string
6445 want int
6446 }{
6447 {"GET", 200},
6448 {"GE(T", 400},
6449 }
6450 for _, tt := range tests {
6451 conn := newTestConn()
6452 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6453
6454 ln := &oneConnListener{conn}
6455 go Serve(ln, serve(200))
6456 <-conn.closec
6457 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6458 if err != nil {
6459 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6460 continue
6461 }
6462 if res.StatusCode != tt.want {
6463 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6464 }
6465 }
6466 }
6467
6468
// eofListenerNotComparable is a net.Listener built on a slice type, so its
// values are not comparable with ==. It never yields a connection.
type eofListenerNotComparable []int

// Accept always fails with io.EOF and returns a nil connection.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	var none net.Conn
	return none, io.EOF
}

// Addr reports no address.
func (eofListenerNotComparable) Addr() net.Addr {
	var none net.Addr
	return none
}

// Close is a no-op that always succeeds.
func (eofListenerNotComparable) Close() error {
	var ok error
	return ok
}
6474
6475
6476 func TestServerListenNotComparableListener(t *testing.T) {
6477 var s Server
6478 s.Serve(make(eofListenerNotComparable, 1))
6479 }
6480
6481
// countCloseListener wraps a net.Listener and counts Close calls. Only the
// first call is forwarded to the wrapped listener (when it is non-nil); later
// calls are just counted and return nil.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; updated atomically
}

// Close records the call and, on the first call only, closes the wrapped
// listener and returns its error.
func (l *countCloseListener) Close() error {
	if atomic.AddInt32(&l.closes, 1) != 1 || l.Listener == nil {
		return nil
	}
	return l.Listener.Close()
}
6494
6495
// TestServerCloseListenerOnce checks that when Shutdown is called on a
// running server, the listener passed to Serve is closed exactly once.
func TestServerCloseListenerOnce(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	// Wrap the listener so the number of Close calls made by the server can
	// be counted.
	cl := &countCloseListener{Listener: ln}
	server := &Server{}
	sdone := make(chan bool, 1)

	go func() {
		server.Serve(cl)
		sdone <- true
	}()
	// Give Serve a moment to start. NOTE(review): this is a timing guess,
	// not a synchronization point.
	time.Sleep(10 * time.Millisecond)
	server.Shutdown(context.Background())
	// Closing the underlying listener directly bypasses the counting
	// wrapper, so it does not change cl.closes.
	ln.Close()
	<-sdone

	nclose := atomic.LoadInt32(&cl.closes)
	if nclose != 1 {
		t.Errorf("Close calls = %v; want 1", nclose)
	}
}
6521
6522
6523 func TestServerShutdownThenServe(t *testing.T) {
6524 var srv Server
6525 cl := &countCloseListener{Listener: nil}
6526 srv.Shutdown(context.Background())
6527 got := srv.Serve(cl)
6528 if got != ErrServerClosed {
6529 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6530 }
6531 nclose := atomic.LoadInt32(&cl.closes)
6532 if nclose != 1 {
6533 t.Errorf("Close calls = %v; want 1", nclose)
6534 }
6535 }
6536
6537
// TestStripPortFromHost checks that ServeMux ignores the request's port when
// matching host-qualified patterns. A request for example.com:9000 must reach
// the "example.com/" handler, not the "example.com:9000/" one.
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, _ *Request) {
		fmt.Fprint(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, _ *Request) {
		fmt.Fprint(w, "uh-oh!")
	})

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6558
6559 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6560 func testServerContexts(t *testing.T, mode testMode) {
6561 type baseKey struct{}
6562 type connKey struct{}
6563 ch := make(chan context.Context, 1)
6564 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6565 ch <- r.Context()
6566 }), func(ts *httptest.Server) {
6567 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6568 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6569 t.Errorf("unexpected onceClose listener type %T", ln)
6570 }
6571 return context.WithValue(context.Background(), baseKey{}, "base")
6572 }
6573 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6574 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6575 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6576 }
6577 return context.WithValue(ctx, connKey{}, "conn")
6578 }
6579 }).ts
6580 res, err := ts.Client().Get(ts.URL)
6581 if err != nil {
6582 t.Fatal(err)
6583 }
6584 res.Body.Close()
6585 ctx := <-ch
6586 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6587 t.Errorf("base context key = %#v; want %q", got, want)
6588 }
6589 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6590 t.Errorf("conn context key = %#v; want %q", got, want)
6591 }
6592 }
6593
6594
6595 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6596 run(t, testConnContextNotModifyingAllContexts)
6597 }
6598 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6599 type connKey struct{}
6600 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6601 rw.Header().Set("Connection", "close")
6602 }), func(ts *httptest.Server) {
6603 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6604 if got := ctx.Value(connKey{}); got != nil {
6605 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6606 }
6607 return context.WithValue(ctx, connKey{}, "conn")
6608 }
6609 }).ts
6610
6611 var res *Response
6612 var err error
6613
6614 res, err = ts.Client().Get(ts.URL)
6615 if err != nil {
6616 t.Fatal(err)
6617 }
6618 res.Body.Close()
6619
6620 res, err = ts.Client().Get(ts.URL)
6621 if err != nil {
6622 t.Fatal(err)
6623 }
6624 res.Body.Close()
6625 }
6626
6627
6628
6629 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6630 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6631 }
6632 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6633 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6634 w.Write([]byte("Hello, World!"))
6635 })).ts
6636
6637 serverURL, err := url.Parse(cst.URL)
6638 if err != nil {
6639 t.Fatalf("Failed to parse server URL: %v", err)
6640 }
6641
6642 unsupportedTEs := []string{
6643 "fugazi",
6644 "foo-bar",
6645 "unknown",
6646 `" chunked"`,
6647 }
6648
6649 for _, badTE := range unsupportedTEs {
6650 http1ReqBody := fmt.Sprintf(""+
6651 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6652 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6653
6654 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6655 if err != nil {
6656 t.Errorf("%q. unexpected error: %v", badTE, err)
6657 continue
6658 }
6659
6660 wantBody := fmt.Sprintf("" +
6661 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6662 "Connection: close\r\n\r\nUnsupported transfer encoding")
6663
6664 if string(gotBody) != wantBody {
6665 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6666 }
6667 }
6668 }
6669
6670
// TestContentEncodingNoSniffing checks that the server does not sniff a
// Content-Type for a response that sets a non-empty Content-Encoding header.
// It still sniffs when the header is absent or explicitly set to "".
func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
	type setting struct {
		name string
		body []byte

		// contentEncoding is the Content-Encoding value the handler sets.
		// A nil value means the handler does not set the header at all;
		// any string, including "", is set verbatim. It is typed any so
		// that "unset" (nil) and "set to empty" ("") can be told apart.
		contentEncoding any
		wantContentType string
	}

	settings := []*setting{
		{
			name:            "gzip content-encoding, gzipped",
			contentEncoding: "application/gzip",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "zlib content-encoding, zlibbed",
			contentEncoding: "application/zlib",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				zw := zlib.NewWriter(buf)
				zw.Write([]byte("doctype html><p>Hello</p>"))
				zw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "no content-encoding",
			wantContentType: "application/x-gzip",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "phony content-encoding",
			contentEncoding: "foo/bar",
			body:            []byte("doctype html><p>Hello</p>"),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	for _, tt := range settings {
		t.Run(tt.name, func(t *testing.T) {
			cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
				if tt.contentEncoding != nil {
					rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				rw.Write(tt.body)
			}))

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()

			// When the handler set no header (w == nil), any non-empty
			// Content-Encoding in the response is unexpected.
			if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
				if w != nil {
					t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
				} else if g != "" {
					t.Errorf("Unexpected Content-Encoding %q", g)
				}
			}

			if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
				t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
			}
		})
	}
}
6762
6763
6764
// TestTimeoutHandlerSuperfluousLogs checks superfluous WriteHeader calls made
// by a handler wrapped in TimeoutHandler. Each call should be logged once,
// naming the calling function and its file:line. This must hold both when the
// handler returns before the timeout and when it is still running after the
// timeout fires.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// The expected log lines name this function and this file, so look them
	// up at run time.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// exitHandler releases the handler. Closing it when the subtest
			// returns also unblocks a handler still running after a timeout.
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// The first WriteHeader call is legitimate; the next three are
			// superfluous. They must stay on the three lines directly above
			// the runtime.Caller call, because the expected log line numbers
			// are computed relative to the line it reports.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// When no timeout is wanted, pre-load exitHandler so the handler
			// returns immediately.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a generous timeout so the handler reliably finishes
				// first.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop headers that are nondeterministic (Date) or that the
			// expected responses omit (Content-Type), so the dump can be
			// compared with a fixed string.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the whole response in its wire format.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect exactly one log entry per superfluous WriteHeader call.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", "  ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each entry should name the handler closure inside this function
			// and the line of the matching WriteHeader call.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6868
6869
6870
6871
// fetchWireResponse dials host over TCP, writes http1ReqBody verbatim, and
// returns every byte the peer sends until it closes the connection or a read
// error occurs.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	conn, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	_, err = conn.Write(http1ReqBody)
	if err != nil {
		return nil, err
	}
	resp, err := io.ReadAll(conn)
	return resp, err
}
6884
// BenchmarkResponseStatusLine measures writing a 200 status line via the
// test-exported Export_writeStatusLine into a bufio.Writer that discards its
// output. It runs in parallel and reports allocations.
func BenchmarkResponseStatusLine(b *testing.B) {
	b.ReportAllocs()
	b.RunParallel(func(pb *testing.PB) {
		bw := bufio.NewWriter(io.Discard)
		// NOTE(review): buf3 looks like reusable 3-byte scratch space for the
		// status code, and `true` presumably selects HTTP/1.1; confirm
		// against writeStatusLine.
		var buf3 [3]byte
		for pb.Next() {
			Export_writeStatusLine(bw, true, 200, buf3[:])
		}
	})
}
6895
// TestDisableKeepAliveUpgrade checks that a 101 Switching Protocols upgrade
// still works, and gives the client a writable response body, when
// keep-alives are disabled on both the server and the client transport.
func TestDisableKeepAliveUpgrade(t *testing.T) {
	run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
}
func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", "someProto")
		w.WriteHeader(StatusSwitchingProtocols)
		// On Hijack failure the handler just returns; the client-side
		// checks below will then fail.
		c, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			return
		}
		defer c.Close()

		// Echo the client's bytes back. Read from buf, which may already
		// hold data the server read ahead, and write to c directly so the
		// output is not held in buf's write buffer.
		io.Copy(c, buf)
	}), func(ts *httptest.Server) {
		ts.Config.SetKeepAlivesEnabled(false)
	}).ts

	cl := s.Client()
	cl.Transport.(*Transport).DisableKeepAlives = true

	resp, err := cl.Get(s.URL)
	if err != nil {
		t.Fatalf("failed to perform request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != StatusSwitchingProtocols {
		t.Fatalf("unexpected status code: %v", resp.StatusCode)
	}

	// For the 101 response, the body must be usable for both reading and
	// writing on the upgraded connection.
	rwc, ok := resp.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
	}

	_, err = rwc.Write([]byte("hello"))
	if err != nil {
		t.Fatalf("failed to write to body: %v", err)
	}

	b := make([]byte, 5)
	_, err = io.ReadFull(rwc, b)
	if err != nil {
		t.Fatalf("failed to read from body: %v", err)
	}

	if string(b) != "hello" {
		t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
	}
}
6954
// tlogWriter is an io.Writer that sends each Write to t.Log. This lets a
// *log.Logger (such as a server's ErrorLog) write into the test log.
type tlogWriter struct{ t *testing.T }

// Write logs p in a single t.Log call and always reports complete success.
func (lw tlogWriter) Write(p []byte) (int, error) {
	msg := string(p)
	lw.t.Log(msg)
	return len(p), nil
}
6961
6962 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6963 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6964 }
6965 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6966 const wantBody = "want"
6967 const wantUpgrade = "someProto"
6968 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6969 w.Header().Set("Connection", "Upgrade")
6970 w.Header().Set("Upgrade", wantUpgrade)
6971 w.WriteHeader(StatusSwitchingProtocols)
6972 NewResponseController(w).Flush()
6973
6974
6975 w.WriteHeader(200)
6976 if _, err := w.Write([]byte("x")); err == nil {
6977 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6978 }
6979
6980 c, _, err := NewResponseController(w).Hijack()
6981 if err != nil {
6982 t.Errorf("Hijack: %v", err)
6983 return
6984 }
6985 defer c.Close()
6986 if _, err := c.Write([]byte(wantBody)); err != nil {
6987 t.Errorf("Write to hijacked body: %v", err)
6988 }
6989 }), func(ts *httptest.Server) {
6990
6991 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6992 }).ts
6993
6994 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6995 if err != nil {
6996 t.Fatalf("net.Dial: %v", err)
6997 }
6998 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6999 if err != nil {
7000 t.Fatalf("conn.Write: %v", err)
7001 }
7002 defer conn.Close()
7003
7004 r := bufio.NewReader(conn)
7005 res, err := ReadResponse(r, &Request{Method: "GET"})
7006 if err != nil {
7007 t.Fatal("ReadResponse error:", err)
7008 }
7009 if res.StatusCode != StatusSwitchingProtocols {
7010 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
7011 }
7012 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
7013 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
7014 }
7015 body, err := io.ReadAll(r)
7016 if err != nil {
7017 t.Error(err)
7018 }
7019 if string(body) != wantBody {
7020 t.Errorf("Response body = %q, want %q", string(body), wantBody)
7021 }
7022 }
7023
7024 func TestMuxRedirectRelative(t *testing.T) {
7025 setParallel(t)
7026 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
7027 if err != nil {
7028 t.Errorf("%s", err)
7029 }
7030 mux := NewServeMux()
7031 resp := httptest.NewRecorder()
7032 mux.ServeHTTP(resp, req)
7033 if got, want := resp.Header().Get("Location"), "/"; got != want {
7034 t.Errorf("Location header expected %q; got %q", want, got)
7035 }
7036 if got, want := resp.Code, StatusTemporaryRedirect; got != want {
7037 t.Errorf("Expected response code %d; got %d", want, got)
7038 }
7039 }
7040
7041
7042 func TestQuerySemicolon(t *testing.T) {
7043 t.Cleanup(func() { afterTest(t) })
7044
7045 tests := []struct {
7046 query string
7047 xNoSemicolons string
7048 xWithSemicolons string
7049 expectParseFormErr bool
7050 }{
7051 {"?a=1;x=bad&x=good", "good", "bad", true},
7052 {"?a=1;b=bad&x=good", "good", "good", true},
7053 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
7054 {"?a=1;x=good;x=bad", "", "good", true},
7055 }
7056
7057 run(t, func(t *testing.T, mode testMode) {
7058 for _, tt := range tests {
7059 t.Run(tt.query+"/allow=false", func(t *testing.T) {
7060 allowSemicolons := false
7061 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
7062 })
7063 t.Run(tt.query+"/allow=true", func(t *testing.T) {
7064 allowSemicolons, expectParseFormErr := true, false
7065 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
7066 })
7067 }
7068 })
7069 }
7070
7071 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
7072 writeBackX := func(w ResponseWriter, r *Request) {
7073 x := r.URL.Query().Get("x")
7074 if expectParseFormErr {
7075 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
7076 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
7077 }
7078 } else {
7079 if err := r.ParseForm(); err != nil {
7080 t.Errorf("expected no error from ParseForm, got %v", err)
7081 }
7082 }
7083 if got := r.FormValue("x"); x != got {
7084 t.Errorf("got %q from FormValue, want %q", got, x)
7085 }
7086 fmt.Fprintf(w, "%s", x)
7087 }
7088
7089 h := Handler(HandlerFunc(writeBackX))
7090 if allowSemicolons {
7091 h = AllowQuerySemicolons(h)
7092 }
7093
7094 logBuf := &strings.Builder{}
7095 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
7096 ts.Config.ErrorLog = log.New(logBuf, "", 0)
7097 }).ts
7098
7099 req, _ := NewRequest("GET", ts.URL+query, nil)
7100 res, err := ts.Client().Do(req)
7101 if err != nil {
7102 t.Fatal(err)
7103 }
7104 slurp, _ := io.ReadAll(res.Body)
7105 res.Body.Close()
7106 if got, want := res.StatusCode, 200; got != want {
7107 t.Errorf("Status = %d; want = %d", got, want)
7108 }
7109 if got, want := string(slurp), wantX; got != want {
7110 t.Errorf("Body = %q; want = %q", got, want)
7111 }
7112 }
7113
7114 func TestMaxBytesHandler(t *testing.T) {
7115
7116 defer afterTest(t)
7117
7118 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
7119 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
7120 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
7121 func(t *testing.T) {
7122 run(t, func(t *testing.T, mode testMode) {
7123 testMaxBytesHandler(t, mode, maxSize, requestSize)
7124 }, testNotParallel)
7125 })
7126 }
7127 }
7128 }
7129
7130 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
7131 runTimeSensitiveTest(t, []time.Duration{
7132 1 * time.Millisecond,
7133 5 * time.Millisecond,
7134 10 * time.Millisecond,
7135 50 * time.Millisecond,
7136 100 * time.Millisecond,
7137 500 * time.Millisecond,
7138 time.Second,
7139 5 * time.Second,
7140 }, func(t *testing.T, timeout time.Duration) error {
7141 SetRSTAvoidanceDelay(t, timeout)
7142 t.Logf("set RST avoidance delay to %v", timeout)
7143
7144 var (
7145 handlerN int64
7146 handlerErr error
7147 )
7148 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
7149 var buf bytes.Buffer
7150 handlerN, handlerErr = io.Copy(&buf, r.Body)
7151 io.Copy(w, &buf)
7152 })
7153
7154 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
7155
7156
7157 defer cst.close()
7158 ts := cst.ts
7159 c := ts.Client()
7160
7161 body := strings.Repeat("a", int(requestSize))
7162 var wg sync.WaitGroup
7163 defer wg.Wait()
7164 getBody := func() (io.ReadCloser, error) {
7165 wg.Add(1)
7166 body := &wgReadCloser{
7167 Reader: strings.NewReader(body),
7168 wg: &wg,
7169 }
7170 return body, nil
7171 }
7172 reqBody, _ := getBody()
7173 req, err := NewRequest("POST", ts.URL, reqBody)
7174 if err != nil {
7175 reqBody.Close()
7176 t.Fatal(err)
7177 }
7178 req.ContentLength = int64(len(body))
7179 req.GetBody = getBody
7180 req.Header.Set("Content-Type", "text/plain")
7181
7182 var buf strings.Builder
7183 res, err := c.Do(req)
7184 if err != nil {
7185 return fmt.Errorf("unexpected connection error: %v", err)
7186 } else {
7187 _, err = io.Copy(&buf, res.Body)
7188 res.Body.Close()
7189 if err != nil {
7190 return fmt.Errorf("unexpected read error: %v", err)
7191 }
7192 }
7193
7194
7195
7196
7197 if handlerN > maxSize {
7198 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
7199 }
7200 if requestSize > maxSize && handlerErr == nil {
7201 t.Error("expected error on handler side; got nil")
7202 }
7203 if requestSize <= maxSize {
7204 if handlerErr != nil {
7205 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
7206 }
7207 if handlerN != requestSize {
7208 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
7209 }
7210 }
7211 if buf.Len() != int(handlerN) {
7212 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
7213 }
7214
7215 return nil
7216 })
7217 }
7218
7219 func TestEarlyHints(t *testing.T) {
7220 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7221 h := w.Header()
7222 h.Add("Link", "</style.css>; rel=preload; as=style")
7223 h.Add("Link", "</script.js>; rel=preload; as=script")
7224 w.WriteHeader(StatusEarlyHints)
7225
7226 h.Add("Link", "</foo.js>; rel=preload; as=script")
7227 w.WriteHeader(StatusEarlyHints)
7228
7229 w.Write([]byte("stuff"))
7230 }))
7231
7232 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7233 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7234 if !strings.Contains(got, expected) {
7235 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7236 }
7237 }
7238 func TestProcessing(t *testing.T) {
7239 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7240 w.WriteHeader(StatusProcessing)
7241 w.Write([]byte("stuff"))
7242 }))
7243
7244 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7245 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7246 if !strings.Contains(got, expected) {
7247 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7248 }
7249 }
7250
7251 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7252 func testParseFormCleanup(t *testing.T, mode testMode) {
7253 if mode == http2Mode {
7254 t.Skip("https://go.dev/issue/20253")
7255 }
7256
7257 const maxMemory = 1024
7258 const key = "file"
7259
7260 if runtime.GOOS == "windows" {
7261
7262 t.Skip("https://go.dev/issue/25965")
7263 }
7264
7265 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7266 r.ParseMultipartForm(maxMemory)
7267 f, _, err := r.FormFile(key)
7268 if err != nil {
7269 t.Errorf("r.FormFile(%q) = %v", key, err)
7270 return
7271 }
7272 of, ok := f.(*os.File)
7273 if !ok {
7274 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7275 return
7276 }
7277 w.Write([]byte(of.Name()))
7278 }))
7279
7280 fBuf := new(bytes.Buffer)
7281 mw := multipart.NewWriter(fBuf)
7282 mf, err := mw.CreateFormFile(key, "myfile.txt")
7283 if err != nil {
7284 t.Fatal(err)
7285 }
7286 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7287 t.Fatal(err)
7288 }
7289 if err := mw.Close(); err != nil {
7290 t.Fatal(err)
7291 }
7292 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7293 if err != nil {
7294 t.Fatal(err)
7295 }
7296 req.Header.Set("Content-Type", mw.FormDataContentType())
7297 res, err := cst.c.Do(req)
7298 if err != nil {
7299 t.Fatal(err)
7300 }
7301 defer res.Body.Close()
7302 fname, err := io.ReadAll(res.Body)
7303 if err != nil {
7304 t.Fatal(err)
7305 }
7306 cst.close()
7307 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7308 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7309 }
7310 }
7311
7312 func TestHeadBody(t *testing.T) {
7313 const identityMode = false
7314 const chunkedMode = true
7315 run(t, func(t *testing.T, mode testMode) {
7316 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7317 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7318 })
7319 }
7320
7321 func TestGetBody(t *testing.T) {
7322 const identityMode = false
7323 const chunkedMode = true
7324 run(t, func(t *testing.T, mode testMode) {
7325 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7326 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7327 })
7328 }
7329
7330 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7331 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7332 b, err := io.ReadAll(r.Body)
7333 if err != nil {
7334 t.Errorf("server reading body: %v", err)
7335 return
7336 }
7337 w.Header().Set("X-Request-Body", string(b))
7338 w.Header().Set("Content-Length", "0")
7339 }))
7340 defer cst.close()
7341 for _, reqBody := range []string{
7342 "",
7343 "",
7344 "request_body",
7345 "",
7346 } {
7347 var bodyReader io.Reader
7348 if reqBody != "" {
7349 bodyReader = strings.NewReader(reqBody)
7350 if chunked {
7351 bodyReader = bufio.NewReader(bodyReader)
7352 }
7353 }
7354 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7355 if err != nil {
7356 t.Fatal(err)
7357 }
7358 res, err := cst.c.Do(req)
7359 if err != nil {
7360 t.Fatal(err)
7361 }
7362 res.Body.Close()
7363 if got, want := res.StatusCode, 200; got != want {
7364 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7365 }
7366 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7367 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7368 }
7369 }
7370 }
7371
7372
7373
7374 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7375 func testDisableContentLength(t *testing.T, mode testMode) {
7376 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7377 w.Header()["Content-Length"] = nil
7378 fmt.Fprintf(w, "OK")
7379 }))
7380
7381 res, err := noCL.c.Get(noCL.ts.URL)
7382 if err != nil {
7383 t.Fatal(err)
7384 }
7385 if got, haveCL := res.Header["Content-Length"]; haveCL {
7386 t.Errorf("Unexpected Content-Length: %q", got)
7387 }
7388 if err := res.Body.Close(); err != nil {
7389 t.Fatal(err)
7390 }
7391
7392 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7393 fmt.Fprintf(w, "OK")
7394 }))
7395
7396 res, err = withCL.c.Get(withCL.ts.URL)
7397 if err != nil {
7398 t.Fatal(err)
7399 }
7400 if got := res.Header.Get("Content-Length"); got != "2" {
7401 t.Errorf("Content-Length: %q; want 2", got)
7402 }
7403 if err := res.Body.Close(); err != nil {
7404 t.Fatal(err)
7405 }
7406 }
7407
7408 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7409 func testErrorContentLength(t *testing.T, mode testMode) {
7410 const errorBody = "an error occurred"
7411 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7412 w.Header().Set("Content-Length", "1000")
7413 Error(w, errorBody, 400)
7414 }))
7415 res, err := cst.c.Get(cst.ts.URL)
7416 if err != nil {
7417 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7418 }
7419 defer res.Body.Close()
7420 body, err := io.ReadAll(res.Body)
7421 if err != nil {
7422 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7423 }
7424 if string(body) != errorBody+"\n" {
7425 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7426 }
7427 }
7428
// TestError checks how Error modifies a ResponseWriter's headers:
//   - any pre-existing Content-Length is removed;
//   - Content-Type is set to "text/plain; charset=utf-8";
//   - X-Content-Type-Options is forced to "nosniff";
//   - unrelated headers are left alone.
//
// It also checks the recorded status code and the body Error writes.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	h := w.Header()
	for _, hdr := range []string{"Content-Length"} {
		if v, ok := h[hdr]; ok {
			t.Errorf("%s: %q, want not present", hdr, v)
		}
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
	// Headers Error does not manage must survive unchanged.
	if v := h.Get("Other"); v != "foo" {
		t.Errorf("Other: %q, want %q", v, "foo")
	}
	if w.Code != 432 {
		t.Errorf("Code = %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("Body = %q, want %q", got, want)
	}
}
7449
7450 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7451 run(t, testServerReadAfterWriteHeader100Continue)
7452 }
7453 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7454 t.Skip("https://go.dev/issue/67555")
7455 body := []byte("body")
7456 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7457 w.WriteHeader(200)
7458 NewResponseController(w).Flush()
7459 io.ReadAll(r.Body)
7460 w.Write(body)
7461 }), func(tr *Transport) {
7462 tr.ExpectContinueTimeout = 24 * time.Hour
7463 })
7464
7465 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7466 req.Header.Set("Expect", "100-continue")
7467 res, err := cst.c.Do(req)
7468 if err != nil {
7469 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7470 }
7471 defer res.Body.Close()
7472 got, err := io.ReadAll(res.Body)
7473 if err != nil {
7474 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7475 }
7476 if !bytes.Equal(got, body) {
7477 t.Fatalf("response body = %q, want %q", got, body)
7478 }
7479 }
7480
7481 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7482 run(t, testServerReadAfterHandlerDone100Continue)
7483 }
7484 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7485 t.Skip("https://go.dev/issue/67555")
7486 readyc := make(chan struct{})
7487 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7488 go func() {
7489 <-readyc
7490 io.ReadAll(r.Body)
7491 <-readyc
7492 }()
7493 }), func(tr *Transport) {
7494 tr.ExpectContinueTimeout = 24 * time.Hour
7495 })
7496
7497 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7498 req.Header.Set("Expect", "100-continue")
7499 res, err := cst.c.Do(req)
7500 if err != nil {
7501 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7502 }
7503 res.Body.Close()
7504 readyc <- struct{}{}
7505 readyc <- struct{}{}
7506 }
7507
7508 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7509 run(t, testServerReadAfterHandlerAbort100Continue)
7510 }
7511 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7512 t.Skip("https://go.dev/issue/67555")
7513 readyc := make(chan struct{})
7514 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7515 go func() {
7516 <-readyc
7517 io.ReadAll(r.Body)
7518 <-readyc
7519 }()
7520 panic(ErrAbortHandler)
7521 }), func(tr *Transport) {
7522 tr.ExpectContinueTimeout = 24 * time.Hour
7523 })
7524
7525 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7526 req.Header.Set("Expect", "100-continue")
7527 res, err := cst.c.Do(req)
7528 if err == nil {
7529 res.Body.Close()
7530 }
7531 readyc <- struct{}{}
7532 readyc <- struct{}{}
7533 }
7534
7535 func TestInvalidChunkedBodies(t *testing.T) {
7536 for _, test := range []struct {
7537 name string
7538 b string
7539 }{{
7540 name: "bare LF in chunk size",
7541 b: "1\na\r\n0\r\n\r\n",
7542 }, {
7543 name: "bare LF at body end",
7544 b: "1\r\na\r\n0\r\n\n",
7545 }} {
7546 t.Run(test.name, func(t *testing.T) {
7547 reqc := make(chan error)
7548 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7549 got, err := io.ReadAll(r.Body)
7550 if err == nil {
7551 t.Logf("read body: %q", got)
7552 }
7553 reqc <- err
7554 })).ts
7555
7556 serverURL, err := url.Parse(ts.URL)
7557 if err != nil {
7558 t.Fatal(err)
7559 }
7560
7561 conn, err := net.Dial("tcp", serverURL.Host)
7562 if err != nil {
7563 t.Fatal(err)
7564 }
7565
7566 if _, err := conn.Write([]byte(
7567 "POST / HTTP/1.1\r\n" +
7568 "Host: localhost\r\n" +
7569 "Transfer-Encoding: chunked\r\n" +
7570 "Connection: close\r\n" +
7571 "\r\n" +
7572 test.b)); err != nil {
7573 t.Fatal(err)
7574 }
7575 conn.(*net.TCPConn).CloseWrite()
7576
7577 if err := <-reqc; err == nil {
7578 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7579 }
7580 })
7581 }
7582 }
7583
7584
7585 func TestServerTLSNextProtos(t *testing.T) {
7586 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7587 }
7588 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7589 CondSkipHTTP2(t)
7590
7591 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7592 if err != nil {
7593 t.Fatal(err)
7594 }
7595 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7596 if err != nil {
7597 t.Fatal(err)
7598 }
7599 certpool := x509.NewCertPool()
7600 certpool.AddCert(leafCert)
7601
7602 protos := new(Protocols)
7603 switch mode {
7604 case https1Mode:
7605 protos.SetHTTP1(true)
7606 case http2Mode:
7607 protos.SetHTTP2(true)
7608 }
7609
7610 wantNextProtos := []string{"http/1.1", "h2", "other"}
7611 nextProtos := slices.Clone(wantNextProtos)
7612
7613
7614 srv := &Server{
7615 TLSConfig: &tls.Config{
7616 Certificates: []tls.Certificate{cert},
7617 NextProtos: nextProtos,
7618 },
7619 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7620 Protocols: protos,
7621 }
7622 tr := &Transport{
7623 TLSClientConfig: &tls.Config{
7624 RootCAs: certpool,
7625 NextProtos: nextProtos,
7626 },
7627 Protocols: protos,
7628 }
7629
7630 listener := newLocalListener(t)
7631 srvc := make(chan error, 1)
7632 go func() {
7633 srvc <- srv.ServeTLS(listener, "", "")
7634 }()
7635 t.Cleanup(func() {
7636 srv.Close()
7637 <-srvc
7638 })
7639
7640 client := &Client{Transport: tr}
7641 resp, err := client.Get("https://" + listener.Addr().String())
7642 if err != nil {
7643 t.Fatal(err)
7644 }
7645 resp.Body.Close()
7646
7647 if !slices.Equal(nextProtos, wantNextProtos) {
7648 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7649 }
7650 }
7651
7652
7653
7654 func TestServerHTTP2Disabled(t *testing.T) {
7655 synctest.Test(t, func(t *testing.T) {
7656 li := fakeNetListen()
7657 srv := &Server{}
7658 srv.Protocols = new(Protocols)
7659 srv.Protocols.SetHTTP1(true)
7660 go srv.ServeTLS(li, "", "")
7661 synctest.Wait()
7662 srv.Shutdown(t.Context())
7663 })
7664 }
7665
View as plain text