Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 crand "crypto/rand"
16 "crypto/tls"
17 "crypto/x509"
18 "encoding/json"
19 "errors"
20 "fmt"
21 "internal/synctest"
22 "internal/testenv"
23 "io"
24 "log"
25 "math/rand"
26 "mime/multipart"
27 "net"
28 . "net/http"
29 "net/http/httptest"
30 "net/http/httptrace"
31 "net/http/httputil"
32 "net/http/internal"
33 "net/http/internal/testcert"
34 "net/url"
35 "os"
36 "path/filepath"
37 "reflect"
38 "regexp"
39 "runtime"
40 "slices"
41 "strconv"
42 "strings"
43 "sync"
44 "sync/atomic"
45 "syscall"
46 "testing"
47 "time"
48 )
49
// dummyAddr is a placeholder net.Addr: both its network name and its string
// form are simply the underlying string.
type dummyAddr string

// oneConnListener is a net.Listener that yields a single, pre-supplied
// connection and then reports io.EOF from every subsequent Accept.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out the stored connection once. If the connection has already
// been taken (or was never set), it returns a nil conn and io.EOF.
func (l *oneConnListener) Accept() (c net.Conn, err error) {
	if l.conn == nil {
		return nil, io.EOF
	}
	c, l.conn = l.conn, nil
	return c, nil
}

// Close does nothing and always succeeds.
func (l *oneConnListener) Close() error { return nil }

// Addr reports the fixed placeholder address "test-address".
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }

// Network returns the address string itself.
func (a dummyAddr) Network() string { return string(a) }

// String returns the address string itself.
func (a dummyAddr) String() string { return string(a) }

// noopConn provides the address and deadline methods of net.Conn as no-ops.
// The test connections in this file embed it and implement only Read, Write
// and Close themselves.
type noopConn struct{}

func (noopConn) LocalAddr() net.Addr                { return dummyAddr("local-addr") }
func (noopConn) RemoteAddr() net.Addr               { return dummyAddr("remote-addr") }
func (noopConn) SetDeadline(t time.Time) error      { return nil }
func (noopConn) SetReadDeadline(t time.Time) error  { return nil }
func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
89
90 type rwTestConn struct {
91 io.Reader
92 io.Writer
93 noopConn
94
95 closeFunc func() error
96 closec chan bool
97 }
98
99 func (c *rwTestConn) Close() error {
100 if c.closeFunc != nil {
101 return c.closeFunc()
102 }
103 select {
104 case c.closec <- true:
105 default:
106 }
107 return nil
108 }
109
110 type testConn struct {
111 readMu sync.Mutex
112 readBuf bytes.Buffer
113 writeBuf bytes.Buffer
114 closec chan bool
115 noopConn
116 }
117
118 func newTestConn() *testConn {
119 return &testConn{closec: make(chan bool, 1)}
120 }
121
122 func (c *testConn) Read(b []byte) (int, error) {
123 c.readMu.Lock()
124 defer c.readMu.Unlock()
125 return c.readBuf.Read(b)
126 }
127
128 func (c *testConn) Write(b []byte) (int, error) {
129 return c.writeBuf.Write(b)
130 }
131
132 func (c *testConn) Close() error {
133 select {
134 case c.closec <- true:
135 default:
136 }
137 return nil
138 }
139
140
141
// reqBytes converts a request written with "\n" line endings into wire form.
// It trims surrounding whitespace, turns every "\n" into "\r\n", and appends
// "\r\n\r\n" to terminate the header section.
func reqBytes(req string) []byte {
	normalized := strings.ReplaceAll(strings.TrimSpace(req), "\n", "\r\n")
	return []byte(normalized + "\r\n\r\n")
}
145
146 type handlerTest struct {
147 logbuf bytes.Buffer
148 handler Handler
149 }
150
151 func newHandlerTest(h Handler) handlerTest {
152 return handlerTest{handler: h}
153 }
154
155 func (ht *handlerTest) rawResponse(req string) string {
156 reqb := reqBytes(req)
157 var output strings.Builder
158 conn := &rwTestConn{
159 Reader: bytes.NewReader(reqb),
160 Writer: &output,
161 closec: make(chan bool, 1),
162 }
163 ln := &oneConnListener{conn: conn}
164 srv := &Server{
165 ErrorLog: log.New(&ht.logbuf, "", 0),
166 Handler: ht.handler,
167 }
168 go srv.Serve(ln)
169 <-conn.closec
170 return output.String()
171 }
172
173 func TestConsumingBodyOnNextConn(t *testing.T) {
174 t.Parallel()
175 defer afterTest(t)
176 conn := new(testConn)
177 for i := 0; i < 2; i++ {
178 conn.readBuf.Write([]byte(
179 "POST / HTTP/1.1\r\n" +
180 "Host: test\r\n" +
181 "Content-Length: 11\r\n" +
182 "\r\n" +
183 "foo=1&bar=1"))
184 }
185
186 reqNum := 0
187 ch := make(chan *Request)
188 servech := make(chan error)
189 listener := &oneConnListener{conn}
190 handler := func(res ResponseWriter, req *Request) {
191 reqNum++
192 ch <- req
193 }
194
195 go func() {
196 servech <- Serve(listener, HandlerFunc(handler))
197 }()
198
199 var req *Request
200 req = <-ch
201 if req == nil {
202 t.Fatal("Got nil first request.")
203 }
204 if req.Method != "POST" {
205 t.Errorf("For request #1's method, got %q; expected %q",
206 req.Method, "POST")
207 }
208
209 req = <-ch
210 if req == nil {
211 t.Fatal("Got nil first request.")
212 }
213 if req.Method != "POST" {
214 t.Errorf("For request #2's method, got %q; expected %q",
215 req.Method, "POST")
216 }
217
218 if serveerr := <-servech; serveerr != io.EOF {
219 t.Errorf("Serve returned %q; expected EOF", serveerr)
220 }
221 }
222
// stringHandler is a Handler that reports its own value in the "Result"
// response header and writes no body.
type stringHandler string

func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
228
// handlers are the patterns testHostHandlers registers. Each is served by a
// stringHandler that echoes msg in the "Result" response header.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests are the requests testHostHandlers sends. For a 200 response,
// expected is the "Result" header; for a 307 redirect, it is the "Location"
// header.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Subtree roots requested without their trailing slash are redirected.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
254
255 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
256 func testHostHandlers(t *testing.T, mode testMode) {
257 mux := NewServeMux()
258 for _, h := range handlers {
259 mux.Handle(h.pattern, stringHandler(h.msg))
260 }
261 ts := newClientServerTest(t, mode, mux).ts
262
263 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
264 if err != nil {
265 t.Fatal(err)
266 }
267 defer conn.Close()
268 cc := httputil.NewClientConn(conn, nil)
269 for _, vt := range vtests {
270 var r *Response
271 var req Request
272 if req.URL, err = url.Parse(vt.url); err != nil {
273 t.Errorf("cannot parse url: %v", err)
274 continue
275 }
276 if err := cc.Write(&req); err != nil {
277 t.Errorf("writing request: %v", err)
278 continue
279 }
280 r, err := cc.Read(&req)
281 if err != nil {
282 t.Errorf("reading response: %v", err)
283 continue
284 }
285 switch r.StatusCode {
286 case StatusOK:
287 s := r.Header.Get("Result")
288 if s != vt.expected {
289 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
290 }
291 case StatusTemporaryRedirect:
292 s := r.Header.Get("Location")
293 if s != vt.expected {
294 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
295 }
296 default:
297 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
298 }
299 }
300 }
301
// serveMuxRegister lists the patterns the ServeMux tests register. The
// serve(N) handlers reply with distinct status codes, so the code in a
// response shows which pattern was chosen.
var serveMuxRegister = []struct {
	pattern string
	h       Handler
}{
	{pattern: "/dir/", h: serve(200)},
	{pattern: "/search", h: serve(201)},
	{pattern: "codesearch.google.com/search", h: serve(202)},
	{pattern: "codesearch.google.com/", h: serve(203)},
	{pattern: "example.com/", h: HandlerFunc(checkQueryStringHandler)},
}

// serve returns a handler that replies with the given status code and no body.
func serve(code int) HandlerFunc {
	return HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(code)
	})
}

// checkQueryStringHandler replies 200 when the request's raw query equals
// "http://" + RawQuery == the request URL re-rendered with scheme "http", the
// request's Host, and no query. Otherwise it replies 500.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	self := *r.URL
	self.Scheme, self.Host, self.RawQuery = "http", r.Host, ""
	status := 500
	if self.String() == "http://"+r.URL.RawQuery {
		status = 200
	}
	w.WriteHeader(status)
}
334
// serveMuxTests are replayed against a ServeMux built from serveMuxRegister.
// code is the status written by the handler the mux selects, and pattern is
// the pattern ServeMux.Handler reports. The 307 rows are redirects, either to
// the trailing-slash form of a subtree root or to the cleaned path.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{"GET", "google.com", "/", 404, ""},
	{"GET", "google.com", "/dir", 307, "/dir/"},
	{"GET", "google.com", "/dir/", 200, "/dir/"},
	{"GET", "google.com", "/dir/file", 200, "/dir/"},
	{"GET", "google.com", "/search", 201, "/search"},
	{"GET", "google.com", "/search/", 404, ""},
	{"GET", "google.com", "/search/foo", 404, ""},
	{"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
	{"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
	{"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
	{"GET", "images.google.com", "/search", 201, "/search"},
	{"GET", "images.google.com", "/search/", 404, ""},
	{"GET", "images.google.com", "/search/foo", 404, ""},
	{"GET", "google.com", "/../search", 307, "/search"},
	{"GET", "google.com", "/dir/..", 307, ""},
	{"GET", "google.com", "/dir/./file", 307, "/dir/"},

	// The "/dir" -> "/dir/" redirect also applies to CONNECT requests, but
	// path cleaning ("..", ".") does not: such paths are matched as written.
	{"CONNECT", "google.com", "/dir", 307, "/dir/"},
	{"CONNECT", "google.com", "/../search", 404, ""},
	{"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
	{"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
}
370
371 func TestServeMuxHandler(t *testing.T) {
372 setParallel(t)
373 mux := NewServeMux()
374 for _, e := range serveMuxRegister {
375 mux.Handle(e.pattern, e.h)
376 }
377
378 for _, tt := range serveMuxTests {
379 r := &Request{
380 Method: tt.method,
381 Host: tt.host,
382 URL: &url.URL{
383 Path: tt.path,
384 },
385 }
386 h, pattern := mux.Handler(r)
387 rr := httptest.NewRecorder()
388 h.ServeHTTP(rr, r)
389 if pattern != tt.pattern || rr.Code != tt.code {
390 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
391 }
392 }
393 }
394
395
396 func TestServeMuxHandlerTrailingSlash(t *testing.T) {
397 setParallel(t)
398 mux := NewServeMux()
399 const original = "/{x}/"
400 mux.Handle(original, NotFoundHandler())
401 r, _ := NewRequest("POST", "/foo", nil)
402 _, p := mux.Handler(r)
403 if p != original {
404 t.Errorf("got %q, want %q", p, original)
405 }
406 }
407
408
409 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
410 setParallel(t)
411 defer func() {
412 if err := recover(); err == nil {
413 t.Error("expected call to mux.HandleFunc to panic")
414 }
415 }()
416 mux := NewServeMux()
417 mux.HandleFunc("/", nil)
418 }
419
// serveMuxTests2 drives TestServeMuxHandlerRedirects. url is parsed as the
// request URL, code is the expected final status, and redirOk says whether
// the mux may answer with a redirect first.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
431
432
433
434 func TestServeMuxHandlerRedirects(t *testing.T) {
435 setParallel(t)
436 mux := NewServeMux()
437 for _, e := range serveMuxRegister {
438 mux.Handle(e.pattern, e.h)
439 }
440
441 for _, tt := range serveMuxTests2 {
442 tries := 1
443 turl := tt.url
444 for {
445 u, e := url.Parse(turl)
446 if e != nil {
447 t.Fatal(e)
448 }
449 r := &Request{
450 Method: tt.method,
451 Host: tt.host,
452 URL: u,
453 }
454 h, _ := mux.Handler(r)
455 rr := httptest.NewRecorder()
456 h.ServeHTTP(rr, r)
457 if rr.Code != 307 {
458 if rr.Code != tt.code {
459 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
460 }
461 break
462 }
463 if !tt.redirOk {
464 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
465 break
466 }
467 turl = rr.HeaderMap.Get("Location")
468 tries--
469 }
470 if tries < 0 {
471 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
472 }
473 }
474 }
475
476 func TestServeMuxHandlerRedirectPost(t *testing.T) {
477 setParallel(t)
478 mux := NewServeMux()
479 mux.HandleFunc("POST /test/", func(w ResponseWriter, r *Request) {
480 w.WriteHeader(200)
481 })
482
483 var code, retries int
484 startURL := "http://example.com/test"
485 reqURL := startURL
486 for retries = 0; retries <= 1; retries++ {
487 r := httptest.NewRequest("POST", reqURL, strings.NewReader("hello world"))
488 h, _ := mux.Handler(r)
489 rr := httptest.NewRecorder()
490 h.ServeHTTP(rr, r)
491 code = rr.Code
492 switch rr.Code {
493 case 307:
494 reqURL = rr.Result().Header.Get("Location")
495 continue
496 case 200:
497
498 default:
499 t.Errorf("unhandled response code: %v", rr.Code)
500 }
501 }
502 if code != 200 {
503 t.Errorf("POST %s = %d after %d retries, want = 200", startURL, code, retries)
504 }
505 }
506
507
508 func TestMuxRedirectLeadingSlashes(t *testing.T) {
509 setParallel(t)
510 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
511 for _, path := range paths {
512 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
513 if err != nil {
514 t.Errorf("%s", err)
515 }
516 mux := NewServeMux()
517 resp := httptest.NewRecorder()
518
519 mux.ServeHTTP(resp, req)
520
521 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
522 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
523 return
524 }
525
526 if code, expected := resp.Code, StatusTemporaryRedirect; code != expected {
527 t.Errorf("Expected response code of StatusPermanentRedirect; got %d", code)
528 return
529 }
530 }
531 }
532
533
534
535
536
537 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
538 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
539 }
540 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
541 writeBackQuery := func(w ResponseWriter, r *Request) {
542 fmt.Fprintf(w, "%s", r.URL.RawQuery)
543 }
544
545 mux := NewServeMux()
546 mux.HandleFunc("/testOne", writeBackQuery)
547 mux.HandleFunc("/testTwo/", writeBackQuery)
548 mux.HandleFunc("/testThree", writeBackQuery)
549 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
550 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
551 })
552
553 ts := newClientServerTest(t, mode, mux).ts
554
555 tests := [...]struct {
556 path string
557 method string
558 want string
559 statusOk bool
560 }{
561 0: {"/testOne?this=that", "GET", "this=that", true},
562 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
563 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
564 3: {"/testTwo?", "GET", "", true},
565 4: {"/testThree?foo", "GET", "foo", true},
566 5: {"/testThree/?foo", "GET", "foo:bar", true},
567 6: {"/testThree?foo", "CONNECT", "foo", true},
568 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
569
570
571 8: {"/testOne/foo/..?foo", "GET", "foo", true},
572 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
573 }
574
575 for i, tt := range tests {
576 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
577 res, err := ts.Client().Do(req)
578 if err != nil {
579 continue
580 }
581 slurp, _ := io.ReadAll(res.Body)
582 res.Body.Close()
583 if !tt.statusOk {
584 if got, want := res.StatusCode, 404; got != want {
585 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
586 }
587 }
588 if got, want := string(slurp), tt.want; got != want {
589 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
590 }
591 }
592 }
593
594 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
595 setParallel(t)
596
597 mux := NewServeMux()
598 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
599 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
600 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
601 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
602 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
603 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
604
605 tests := []struct {
606 method string
607 url string
608 code int
609 loc string
610 want string
611 }{
612 {"GET", "http://example.com/", 404, "", ""},
613 {"GET", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
614 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
615 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
616 {"GET", "http://example.com/pkg/baz", 307, "/pkg/baz/", ""},
617 {"GET", "http://example.com:3000/pkg/foo", 307, "/pkg/foo/", ""},
618 {"CONNECT", "http://example.com/", 404, "", ""},
619 {"CONNECT", "http://example.com:3000/", 404, "", ""},
620 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
621 {"CONNECT", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
622 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
623 {"CONNECT", "http://example.com:3000/pkg/baz", 307, "/pkg/baz/", ""},
624 {"CONNECT", "http://example.com:3000/pkg/connect", 307, "/pkg/connect/", ""},
625 }
626
627 for i, tt := range tests {
628 req, _ := NewRequest(tt.method, tt.url, nil)
629 w := httptest.NewRecorder()
630 mux.ServeHTTP(w, req)
631
632 if got, want := w.Code, tt.code; got != want {
633 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
634 }
635
636 if tt.code == 301 {
637 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
638 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
639 }
640 } else {
641 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
642 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
643 }
644 }
645 }
646 }
647
648
649
650
// TestMuxNoSlashRedirectWithTrailingSlash checks that a GET of "/" is not
// redirected toward the wildcard pattern "/{x}/" and instead gets a 404.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	// NewRequest cannot fail for this constant method and target.
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
663
664
665
666
// TestMuxNoSlash405WithTrailingSlash checks that a GET of "/" against the
// method-qualified pattern "GET /{x}/" gets a 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})
	// NewRequest cannot fail for this constant method and target.
	req, _ := NewRequest("GET", "/", nil)
	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, req)
	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
679
680 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
681 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
682 mux := NewServeMux()
683 newClientServerTest(t, mode, mux)
684 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
685 }
686
687 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
688 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
689 func benchmarkServeMux(b *testing.B, runHandler bool) {
690 type test struct {
691 path string
692 code int
693 req *Request
694 }
695
696
697 var tests []test
698 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
699 for _, e := range endpoints {
700 for i := 200; i < 230; i++ {
701 p := fmt.Sprintf("/%s/%d/", e, i)
702 tests = append(tests, test{
703 path: p,
704 code: i,
705 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
706 })
707 }
708 }
709 mux := NewServeMux()
710 for _, tt := range tests {
711 mux.Handle(tt.path, serve(tt.code))
712 }
713
714 rw := httptest.NewRecorder()
715 b.ReportAllocs()
716 b.ResetTimer()
717 for i := 0; i < b.N; i++ {
718 for _, tt := range tests {
719 *rw = httptest.ResponseRecorder{}
720 h, pattern := mux.Handler(tt.req)
721 if runHandler {
722 h.ServeHTTP(rw, tt.req)
723 if pattern != tt.path || rw.Code != tt.code {
724 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
725 }
726 }
727 }
728 }
729 }
730
731 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
732 func testServerTimeouts(t *testing.T, mode testMode) {
733 runTimeSensitiveTest(t, []time.Duration{
734 10 * time.Millisecond,
735 50 * time.Millisecond,
736 100 * time.Millisecond,
737 500 * time.Millisecond,
738 1 * time.Second,
739 }, func(t *testing.T, timeout time.Duration) error {
740 return testServerTimeoutsWithTimeout(t, timeout, mode)
741 })
742 }
743
744 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
745 var reqNum atomic.Int32
746 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
747 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
748 }), func(ts *httptest.Server) {
749 ts.Config.ReadTimeout = timeout
750 ts.Config.WriteTimeout = timeout
751 })
752 defer cst.close()
753 ts := cst.ts
754
755
756 c := ts.Client()
757 r, err := c.Get(ts.URL)
758 if err != nil {
759 return fmt.Errorf("http Get #1: %v", err)
760 }
761 got, err := io.ReadAll(r.Body)
762 expected := "req=1"
763 if string(got) != expected || err != nil {
764 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
765 string(got), err, expected)
766 }
767
768
769 t1 := time.Now()
770 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
771 if err != nil {
772 return fmt.Errorf("Dial: %v", err)
773 }
774 buf := make([]byte, 1)
775 n, err := conn.Read(buf)
776 conn.Close()
777 latency := time.Since(t1)
778 if n != 0 || err != io.EOF {
779 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
780 }
781 minLatency := timeout / 5 * 4
782 if latency < minLatency {
783 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
784 }
785
786
787
788
789 r, err = c.Get(ts.URL)
790 if err != nil {
791 return fmt.Errorf("http Get #2: %v", err)
792 }
793 got, err = io.ReadAll(r.Body)
794 r.Body.Close()
795 expected = "req=2"
796 if string(got) != expected || err != nil {
797 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
798 }
799
800 if !testing.Short() {
801 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
802 if err != nil {
803 return fmt.Errorf("long Dial: %v", err)
804 }
805 defer conn.Close()
806 go io.Copy(io.Discard, conn)
807 for i := 0; i < 5; i++ {
808 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
809 if err != nil {
810 return fmt.Errorf("on write %d: %v", i, err)
811 }
812 time.Sleep(timeout / 2)
813 }
814 }
815 return nil
816 }
817
818 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
819 func testServerReadTimeout(t *testing.T, mode testMode) {
820 respBody := "response body"
821 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
822 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
823 _, err := io.Copy(io.Discard, req.Body)
824 if !errors.Is(err, os.ErrDeadlineExceeded) {
825 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
826 }
827 res.Write([]byte(respBody))
828 }), func(ts *httptest.Server) {
829 ts.Config.ReadHeaderTimeout = -1
830 ts.Config.ReadTimeout = timeout
831 t.Logf("Server.Config.ReadTimeout = %v", timeout)
832 })
833
834 var retries atomic.Int32
835 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
836 if retries.Add(1) != 1 {
837 return nil, errors.New("too many retries")
838 }
839 return nil, nil
840 }
841
842 pr, pw := io.Pipe()
843 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
844 if err != nil {
845 t.Logf("Get error, retrying: %v", err)
846 cst.close()
847 continue
848 }
849 defer res.Body.Close()
850 got, err := io.ReadAll(res.Body)
851 if string(got) != respBody || err != nil {
852 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
853 }
854 pw.Close()
855 break
856 }
857 }
858
859 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
860 func testServerNoReadTimeout(t *testing.T, mode testMode) {
861 reqBody := "Hello, Gophers!"
862 resBody := "Hi, Gophers!"
863 for _, timeout := range []time.Duration{0, -1} {
864 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
865 ctl := NewResponseController(res)
866 ctl.EnableFullDuplex()
867 res.WriteHeader(StatusOK)
868
869
870 if err := ctl.Flush(); err != nil {
871 t.Errorf("server flush response: %v", err)
872 return
873 }
874 got, err := io.ReadAll(req.Body)
875 if string(got) != reqBody || err != nil {
876 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
877 }
878 res.Write([]byte(resBody))
879 }), func(ts *httptest.Server) {
880 ts.Config.ReadTimeout = timeout
881 t.Logf("Server.Config.ReadTimeout = %d", timeout)
882 })
883
884 pr, pw := io.Pipe()
885 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
886 if err != nil {
887 t.Fatal(err)
888 }
889 defer res.Body.Close()
890
891
892 time.Sleep(10 * time.Millisecond)
893 pw.Write([]byte(reqBody))
894 pw.Close()
895
896 got, err := io.ReadAll(res.Body)
897 if string(got) != resBody || err != nil {
898 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
899 }
900 }
901 }
902
903 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
904 func testServerWriteTimeout(t *testing.T, mode testMode) {
905 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
906 errc := make(chan error, 2)
907 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
908 errc <- nil
909 _, err := io.Copy(res, neverEnding('a'))
910 errc <- err
911 }), func(ts *httptest.Server) {
912 ts.Config.WriteTimeout = timeout
913 t.Logf("Server.Config.WriteTimeout = %v", timeout)
914 })
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933 var retries atomic.Int32
934 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
935 if retries.Add(1) != 1 {
936 return nil, errors.New("too many retries")
937 }
938 return nil, nil
939 }
940
941 res, err := cst.c.Get(cst.ts.URL)
942 if err != nil {
943
944 t.Logf("Get error, retrying: %v", err)
945 cst.close()
946 continue
947 }
948 defer res.Body.Close()
949 _, err = io.Copy(io.Discard, res.Body)
950 if err == nil {
951 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
952 }
953 select {
954 case <-errc:
955 err = <-errc
956 if !errors.Is(err, os.ErrDeadlineExceeded) {
957 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
958 }
959 return
960 default:
961
962 t.Logf("handler didn't run, retrying")
963 cst.close()
964 }
965 }
966 }
967
968 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
969 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
970 for _, timeout := range []time.Duration{0, -1} {
971 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
972 _, err := io.Copy(res, neverEnding('a'))
973 t.Logf("server write response: %v", err)
974 }), func(ts *httptest.Server) {
975 ts.Config.WriteTimeout = timeout
976 t.Logf("Server.Config.WriteTimeout = %d", timeout)
977 })
978
979 res, err := cst.c.Get(cst.ts.URL)
980 if err != nil {
981 t.Fatal(err)
982 }
983 defer res.Body.Close()
984 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
985 if n != 1<<20 || err != nil {
986 t.Errorf("client read response body: %d, %v", n, err)
987 }
988
989
990 res.Body.Close()
991 cst.ts.Config.Shutdown(context.Background())
992 }
993 }
994
995
996 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
997 run(t, testWriteDeadlineExtendedOnNewRequest)
998 }
999 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
1000 if testing.Short() {
1001 t.Skip("skipping in short mode")
1002 }
1003 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
1004 func(ts *httptest.Server) {
1005 ts.Config.WriteTimeout = 250 * time.Millisecond
1006 },
1007 ).ts
1008
1009 c := ts.Client()
1010
1011 for i := 1; i <= 3; i++ {
1012 req, err := NewRequest("GET", ts.URL, nil)
1013 if err != nil {
1014 t.Fatal(err)
1015 }
1016
1017 r, err := c.Do(req)
1018 if err != nil {
1019 t.Fatalf("http2 Get #%d: %v", i, err)
1020 }
1021 r.Body.Close()
1022 time.Sleep(ts.Config.WriteTimeout / 2)
1023 }
1024 }
1025
1026
1027
// tryTimeouts calls testFunc with 250ms, then 500ms, then 1s, and stops at the
// first attempt that returns nil. Each failed attempt is logged; if all three
// fail, the test is aborted with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	timeouts := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	for i, d := range timeouts {
		err := testFunc(d)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", d, err)
		if next := i + 1; next < len(timeouts) {
			t.Logf("retrying at %v ...", timeouts[next])
		}
	}
	t.Fatal("all attempts failed")
}
1042
1043
1044 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1045 if testing.Short() {
1046 t.Skip("skipping in short mode")
1047 }
1048 setParallel(t)
1049 run(t, func(t *testing.T, mode testMode) {
1050 tryTimeouts(t, func(timeout time.Duration) error {
1051 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1052 })
1053 })
1054 }
1055
1056 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1057 firstRequest := make(chan bool, 1)
1058 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1059 select {
1060 case firstRequest <- true:
1061
1062 default:
1063
1064 time.Sleep(timeout)
1065 }
1066 }), func(ts *httptest.Server) {
1067 ts.Config.WriteTimeout = timeout / 2
1068 })
1069 defer cst.close()
1070 ts := cst.ts
1071
1072 c := ts.Client()
1073
1074 req, err := NewRequest("GET", ts.URL, nil)
1075 if err != nil {
1076 return fmt.Errorf("NewRequest: %v", err)
1077 }
1078 r, err := c.Do(req)
1079 if err != nil {
1080 return fmt.Errorf("Get #1: %v", err)
1081 }
1082 r.Body.Close()
1083
1084 req, err = NewRequest("GET", ts.URL, nil)
1085 if err != nil {
1086 return fmt.Errorf("NewRequest: %v", err)
1087 }
1088 r, err = c.Do(req)
1089 if err == nil {
1090 r.Body.Close()
1091 return fmt.Errorf("Get #2 expected error, got nil")
1092 }
1093 if mode == http2Mode {
1094 expected := "stream ID 3; INTERNAL_ERROR"
1095 if !strings.Contains(err.Error(), expected) {
1096 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1097 }
1098 }
1099 return nil
1100 }
1101
1102
1103 func TestNoWriteDeadline(t *testing.T) {
1104 if testing.Short() {
1105 t.Skip("skipping in short mode")
1106 }
1107 setParallel(t)
1108 defer afterTest(t)
1109 run(t, func(t *testing.T, mode testMode) {
1110 tryTimeouts(t, func(timeout time.Duration) error {
1111 return testNoWriteDeadline(t, mode, timeout)
1112 })
1113 })
1114 }
1115
1116 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1117 firstRequest := make(chan bool, 1)
1118 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1119 select {
1120 case firstRequest <- true:
1121
1122 default:
1123
1124 time.Sleep(timeout)
1125 }
1126 }))
1127 defer cst.close()
1128 ts := cst.ts
1129
1130 c := ts.Client()
1131
1132 for i := 0; i < 2; i++ {
1133 req, err := NewRequest("GET", ts.URL, nil)
1134 if err != nil {
1135 return fmt.Errorf("NewRequest: %v", err)
1136 }
1137 r, err := c.Do(req)
1138 if err != nil {
1139 return fmt.Errorf("Get #%d: %v", i, err)
1140 }
1141 r.Body.Close()
1142 }
1143 return nil
1144 }
1145
1146
1147
1148
// TestOnlyWriteTimeout forces a connection's write deadline into the past
// mid-response. It checks that the handler's next Write fails and that the
// client sees an error while reading the body.
func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
func testOnlyWriteTimeout(t *testing.T, mode testMode) {
	var (
		connMu   sync.RWMutex
		lastConn net.Conn // most recently accepted conn, guarded by connMu
	)
	afterTimeoutErrc := make(chan error, 1)

	handler := HandlerFunc(func(w ResponseWriter, _ *Request) {
		chunk := make([]byte, 512<<10)
		if _, err := w.Write(chunk); err != nil {
			t.Errorf("handler Write error: %v", err)
			return
		}
		connMu.RLock()
		defer connMu.RUnlock()
		if lastConn == nil {
			t.Error("no established connection found")
			return
		}
		// Put the deadline 30s in the past so the next write cannot succeed.
		lastConn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
		_, err := w.Write(chunk)
		afterTimeoutErrc <- err
	})
	trackConns := func(ts *httptest.Server) {
		ts.Listener = trackLastConnListener{ts.Listener, &connMu, &lastConn}
	}
	ts := newClientServerTest(t, mode, handler, trackConns).ts

	client := ts.Client()
	fetchErr := func() error {
		res, err := client.Get(ts.URL)
		if err != nil {
			return err
		}
		_, err = io.Copy(io.Discard, res.Body)
		res.Body.Close()
		return err
	}()
	if fetchErr == nil {
		t.Errorf("expected an error copying body from Get request")
	}

	if err := <-afterTimeoutErrc; err == nil {
		t.Error("expected write error after timeout")
	}
}
1195
1196
1197 type trackLastConnListener struct {
1198 net.Listener
1199
1200 mu *sync.RWMutex
1201 last *net.Conn
1202 }
1203
1204 func (l trackLastConnListener) Accept() (c net.Conn, err error) {
1205 c, err = l.Listener.Accept()
1206 if err == nil {
1207 l.mu.Lock()
1208 *l.last = c
1209 l.mu.Unlock()
1210 }
1211 return
1212 }
1213
1214
// TestIdentityResponse checks responses whose handler declares
// Content-Length: 3. It covers three cases:
//   - an empty or "identity" Transfer-Encoding (must be sent with the
//     declared length and no transfer coding);
//   - a write longer than the declared length (must fail in the handler
//     with ErrContentLength);
//   - for HTTP/1 only, a write shorter than a declared length of 500.
func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
func testIdentityResponse(t *testing.T, mode testMode) {
	if mode == http2Mode {
		t.Skip("https://go.dev/issue/56019")
	}

	handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
		rw.Header().Set("Content-Length", "3")
		rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
		switch {
		case req.FormValue("overwrite") == "1":
			_, err := rw.Write([]byte("foo TOO LONG"))
			if err != ErrContentLength {
				t.Errorf("expected ErrContentLength; got %v", err)
			}
		case req.FormValue("underwrite") == "1":
			rw.Header().Set("Content-Length", "500")
			rw.Write([]byte("too short"))
		default:
			rw.Write([]byte("foo"))
		}
	})

	ts := newClientServerTest(t, mode, handler).ts
	c := ts.Client()

	// An empty or "identity" Transfer-Encoding must still yield a response
	// with the declared Content-Length and no transfer coding.
	for _, te := range []string{"", "identity"} {
		url := ts.URL + "/?te=" + te
		res, err := c.Get(url)
		if err != nil {
			t.Fatalf("error with Get of %s: %v", url, err)
		}
		if cl, expected := res.ContentLength, int64(3); cl != expected {
			t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
		}
		if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
			t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
		}
		if tl, expected := len(res.TransferEncoding), 0; tl != expected {
			t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
				url, expected, tl, res.TransferEncoding)
		}
		res.Body.Close()
	}

	// Writing past the declared length is checked inside the handler
	// (ErrContentLength). The client request itself must still succeed.
	url := ts.URL + "/?overwrite=1"
	res, err := c.Get(url)
	if err != nil {
		t.Fatalf("error with Get of %s: %v", url, err)
	}
	res.Body.Close()

	if mode != http1Mode {
		return
	}

	// Over a raw HTTP/1.1 connection, have the handler declare 500 bytes but
	// write only 9. The test relies on the server closing the connection
	// afterwards, so that io.ReadAll returns everything written so far.
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error dialing: %v", err)
	}
	// Release the client side of the connection on every exit path,
	// including the t.Fatalf calls below.
	defer conn.Close()
	_, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
	if err != nil {
		t.Fatalf("error writing: %v", err)
	}

	// If the server never closes the connection, the test hangs here.
	got, _ := io.ReadAll(conn)
	expectedSuffix := "\r\n\r\ntoo short"
	if !strings.HasSuffix(string(got), expectedSuffix) {
		t.Errorf("Expected output to end with %q; got response body %q",
			expectedSuffix, string(got))
	}
}
1295
// testTCPConnectionCloses sends the raw request string req over a fresh TCP
// connection to an HTTP/1 test server running h. It then:
//   - reads one response;
//   - drains the connection with io.ReadAll, which only returns once the
//     server closes its side;
//   - reports an error unless the response was marked Close.
func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
	setParallel(t)
	srv := newClientServerTest(t, http1Mode, h).ts

	conn, err := net.Dial("tcp", srv.Listener.Addr().String())
	if err != nil {
		t.Fatal("dial error:", err)
	}
	defer conn.Close()

	if _, err := fmt.Fprint(conn, req); err != nil {
		t.Fatal("print error:", err)
	}

	br := bufio.NewReader(conn)
	res, err := ReadResponse(br, &Request{Method: "GET"})
	if err != nil {
		t.Fatal("ReadResponse error:", err)
	}

	if _, err := io.ReadAll(br); err != nil {
		t.Fatal("read error:", err)
	}

	if !res.Close {
		t.Errorf("Response.Close = false; want true")
	}
}
1326
// testTCPConnectionStaysOpen writes the raw request req twice over a single
// TCP connection to an HTTP/1 test server running handler. It reads and
// drains each response in turn, so the second exchange only succeeds if the
// server kept the connection open after the first. Any failure is fatal.
func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
	setParallel(t)
	ts := newClientServerTest(t, http1Mode, handler).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	br := bufio.NewReader(conn)
	for n := 1; n <= 2; n++ {
		if _, err := io.WriteString(conn, req); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(br, nil)
		if err != nil {
			t.Fatalf("res %d: %v", n, err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatalf("res %d body copy: %v", n, err)
		}
		res.Body.Close()
	}
}
1350
1351
1352 func TestServeHTTP10Close(t *testing.T) {
1353 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1354 ServeFile(w, r, "testdata/file")
1355 }))
1356 }
1357
1358
1359 func TestClientCanClose(t *testing.T) {
1360 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1361
1362 }))
1363 }
1364
1365
1366
1367 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1368 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1369 w.Header().Set("Connection", "close")
1370 }))
1371 }
1372
1373 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1374 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1375 w.Header().Set("Connection", "close")
1376 }))
1377 }
1378
1379 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1380 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1381
1382
1383 }))
1384 }
1385
1386 func send204(w ResponseWriter, r *Request) { w.WriteHeader(204) }
1387 func send304(w ResponseWriter, r *Request) { w.WriteHeader(304) }
1388
1389
1390 func TestHTTP10KeepAlive204Response(t *testing.T) {
1391 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1392 }
1393
1394 func TestHTTP11KeepAlive204Response(t *testing.T) {
1395 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1396 }
1397
1398 func TestHTTP10KeepAlive304Response(t *testing.T) {
1399 testTCPConnectionStaysOpen(t,
1400 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1401 HandlerFunc(send304))
1402 }
1403
1404
// TestKeepAliveFinalChunkWithEOF issues two sequential GETs whose handler
// flushes before writing its body. It checks, through the RemoteAddr the
// handler echoes back as JSON, that both requests were served over the same
// client connection.
func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Flushing before any body bytes commits the headers without a
		// Content-Length. NOTE(review): for HTTP/1 this presumably makes the
		// body chunked, which is the "final chunk" case the test is named for.
		w.(Flusher).Flush()
		w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
	}))
	type data struct {
		Addr string
	}

	// fetchAddr performs one GET and decodes the echoed address; any problem
	// is fatal.
	fetchAddr := func() data {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		var d data
		if err := json.NewDecoder(res.Body).Decode(&d); err != nil {
			t.Fatal(err)
		}
		if d.Addr == "" {
			t.Fatal("no address")
		}
		res.Body.Close()
		return d
	}

	first, second := fetchAddr(), fetchAddr()
	if first != second {
		t.Fatalf("connection not reused")
	}
}
1432
// TestSetsRemoteAddr checks that the server fills in Request.RemoteAddr with
// the client's loopback host:port. The handler echoes it back as the body.
func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
func testSetsRemoteAddr(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "%s", r.RemoteAddr)
	}))

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	// Always close the body so the client connection is released, even when
	// one of the checks below fails.
	defer res.Body.Close()
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	ip := string(body)
	if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
		t.Fatalf("Expected local addr; got %q", ip)
	}
}
1452
1453 type blockingRemoteAddrListener struct {
1454 net.Listener
1455 conns chan<- net.Conn
1456 }
1457
1458 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1459 c, err := l.Listener.Accept()
1460 if err != nil {
1461 return nil, err
1462 }
1463 brac := &blockingRemoteAddrConn{
1464 Conn: c,
1465 addrs: make(chan net.Addr, 1),
1466 }
1467 l.conns <- brac
1468 return brac, nil
1469 }
1470
1471 type blockingRemoteAddrConn struct {
1472 net.Conn
1473 addrs chan net.Addr
1474 }
1475
1476 func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
1477 return <-c.addrs
1478 }
1479
1480
// TestServerAllowsBlockingRemoteAddr checks that a connection whose
// RemoteAddr method blocks does not stop the server from accepting and fully
// serving a second connection in the meantime.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	// Unbuffered: the listener's Accept waits until the test receives each
	// wrapped connection, which gives the test a sync point per connection.
	conns := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    conns,
		}
	}).ts

	c := ts.Client()
	// One TCP connection per request, so each fetch gets its own
	// blockingRemoteAddrConn.
	c.Transport.(*Transport).DisableKeepAlives = true

	// fetch issues a GET and sends the response body (or "" on error) on
	// response.
	fetch := func(num int, response chan<- string) {
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Start the first request.
	response1c := make(chan string, 1)
	go fetch(1, response1c)

	// Wait for the server to accept it and grab its connection. Any
	// RemoteAddr call on it blocks until an address is sent on its addrs.
	conn1 := <-conns

	// Start a second request and grab its connection too.
	response2c := make(chan string, 1)
	go fetch(2, response2c)
	conn2 := <-conns

	// Unblock only the second connection's RemoteAddr.
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("12.12.12.12"), Port: 12}

	// The second response must arrive while the first connection is still
	// blocked.
	response2 := <-response2c
	if g, e := response2, "RA:12.12.12.12:12"; g != e {
		t.Fatalf("response 2 addr = %q; want %q", g, e)
	}

	// Now unblock the first connection.
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("21.21.21.21"), Port: 21}

	// ...and its response completes with the address supplied above.
	response1 := <-response1c
	if g, e := response1, "RA:21.21.21.21:21"; g != e {
		t.Fatalf("response 1 addr = %q; want %q", g, e)
	}
}
1548
1549
1550
// TestHeadResponses checks a HEAD request to a handler that writes a body.
// The response must carry the headers a GET would (a sniffed Content-Type
// and a Content-Length of 10, the bytes the handler wrote) but no
// Transfer-Encoding and no body. Handler writes must still succeed even
// though the body is discarded.
func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
func testHeadResponses(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		_, err := w.Write([]byte("<html>"))
		if err != nil {
			t.Errorf("ResponseWriter.Write: %v", err)
		}

		// Wrapping the reader in an anonymous struct hides its WriterTo
		// method, so io.Copy goes through the ResponseWriter side of the
		// copy (presumably its ReaderFrom).
		_, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
		if err != nil {
			t.Errorf("Copy(ResponseWriter, ...): %v", err)
		}
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		// res is nil on error; continuing would panic on the field accesses
		// below.
		t.Fatal(err)
	}
	defer res.Body.Close()
	if len(res.TransferEncoding) > 0 {
		t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
	}
	if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
		t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
	}
	if v := res.ContentLength; v != 10 {
		t.Errorf("Content-Length: %d; want 10", v)
	}
	body, err := io.ReadAll(res.Body)
	if err != nil {
		t.Error(err)
	}
	if len(body) > 0 {
		t.Errorf("got unexpected body %q", string(body))
	}
}
1586
1587
1588
// TestHeadReaderFrom has the handler write its body through the
// ResponseWriter's io.ReaderFrom. It sends a HEAD request and then a GET,
// and checks that the GET still receives the complete 4096-byte body.
func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
func testHeadReaderFrom(t *testing.T, mode testMode) {
	wantBody := strings.Repeat("a", 4096)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
	}))

	head, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	head.Body.Close()

	get, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	gotBody, readErr := io.ReadAll(get.Body)
	get.Body.Close()
	if readErr != nil {
		t.Fatal(readErr)
	}
	if got := string(gotBody); got != wantBody {
		t.Errorf("got unexpected body len=%v, want %v", len(got), len(wantBody))
	}
}
1614
// TestTLSHandshakeTimeout opens a TCP connection to a TLS server and never
// starts the handshake. It expects the server to hang up (with ReadTimeout
// set to 250ms) so that the read fails with no bytes. It also expects the
// server's error log to mention a timeout or TLS handshake failure.
func TestTLSHandshakeTimeout(t *testing.T) {
	run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
}
func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
	var errLog strings.Builder
	noop := HandlerFunc(func(w ResponseWriter, r *Request) {})
	cst := newClientServerTest(t, mode, noop, func(ts *httptest.Server) {
		ts.Config.ReadTimeout = 250 * time.Millisecond
		ts.Config.ErrorLog = log.New(&errLog, "", 0)
	})

	conn, err := net.Dial("tcp", cst.ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Send nothing; the read should end with an error once the server gives up.
	var one [1]byte
	if n, err := conn.Read(one[:]); err == nil || n != 0 {
		t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
	}
	conn.Close()

	// Shut the server down before reading the log. NOTE(review): presumably
	// so that any pending log writes have finished.
	cst.close()
	if msg := errLog.String(); !strings.Contains(msg, "timeout") && !strings.Contains(msg, "TLS handshake") {
		t.Errorf("expected a TLS handshake timeout error; got %q", msg)
	}
}
1644
// TestTLSServer checks that a request served over TLS sees a non-nil
// Request.TLS with HandshakeComplete set. The handler reports both facts
// through response headers. The request is made while another raw TCP
// connection, which never performs a TLS handshake, is held open to the same
// server.
func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
func testTLSServer(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		if r.TLS == nil {
			return
		}
		w.Header().Set("X-TLS-Set", "true")
		if r.TLS.HandshakeComplete {
			w.Header().Set("X-TLS-HandshakeComplete", "true")
		}
	}), func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Hold an idle, un-handshaked connection open for the rest of the test.
	// NOTE(review): presumably to check that it does not prevent other
	// clients from being served.
	idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer idleConn.Close()

	if !strings.HasPrefix(ts.URL, "https://") {
		t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
		return
	}
	res, err := ts.Client().Get(ts.URL)
	switch {
	case err != nil:
		t.Error(err)
		return
	case res == nil:
		t.Errorf("got nil Response")
		return
	}
	defer res.Body.Close()
	if res.Header.Get("X-TLS-Set") != "true" {
		t.Errorf("expected X-TLS-Set response header")
		return
	}
	if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
		t.Errorf("expected X-TLS-HandshakeComplete header")
	}
}
1692
1693 type fakeConnectionStateConn struct {
1694 net.Conn
1695 }
1696
1697 func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
1698 return tls.ConnectionState{
1699 ServerName: "example.com",
1700 }
1701 }
1702
// TestTLSServerWithoutTLSConn serves a connection that is not a *tls.Conn but
// has a ConnectionState method. It checks that the server still fills in
// Request.TLS from that method. The test drives the client side of a
// net.Pipe by hand.
func TestTLSServerWithoutTLSConn(t *testing.T) {
	pr, pw := net.Pipe()
	c := make(chan int)
	listener := &oneConnListener{&fakeConnectionStateConn{pr}}
	server := &Server{
		Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
			// This runs on a server goroutine. t.Fatal must only be called
			// from the test goroutine, so report with t.Error and return.
			if request.TLS == nil {
				t.Error("request.TLS is nil, expected not nil")
				return
			}
			if request.TLS.ServerName != "example.com" {
				t.Errorf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
				return
			}
			writer.Header().Set("X-TLS-ServerName", "example.com")
		}),
	}

	// Act as the client on the other end of the pipe. Defers run LIFO, so
	// the pipe is closed first and c is closed last, on every path.
	go func() {
		defer close(c)
		defer pw.Close()

		req, err := NewRequest(MethodGet, "https://example.com", nil)
		if err != nil {
			t.Errorf("NewRequest: %v", err)
			return
		}
		if err := req.Write(pw); err != nil {
			t.Errorf("writing request: %v", err)
			return
		}
		// resp is nil when ReadResponse fails, so check before using it.
		resp, err := ReadResponse(bufio.NewReader(pw), req)
		if err != nil {
			t.Errorf("ReadResponse: %v", err)
			return
		}
		if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
			t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
		}
	}()

	// oneConnListener reports io.EOF on its second Accept. NOTE(review):
	// Serve presumably returns at that point while the accepted connection
	// keeps being served in the background.
	server.Serve(listener)

	// Wait for the client goroutine to finish before closing the server end.
	<-c
	pr.Close()
}
1739
// TestServeTLS calls Server.ServeTLS on a caller-supplied listener, with
// empty cert/key file names and a TLSConfig that already holds the
// certificate. A TLS client offering "h2" and "http/1.1" must negotiate
// "h2".
func TestServeTLS(t *testing.T) {
	CondSkipHTTP2(t)
	// NOTE(review): not parallel, presumably because
	// SetTestHookServerServe installs a package-level hook.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}

	ln := newLocalListener(t)
	defer ln.Close()

	// The hook fires once ServeTLS starts serving, so the dial below cannot
	// race server startup.
	serving := make(chan bool, 1)
	SetTestHookServerServe(func(s *Server, ln net.Listener) {
		serving <- true
	})
	srv := &Server{
		Addr:      ln.Addr().String(),
		TLSConfig: &tls.Config{Certificates: []tls.Certificate{cert}},
		Handler:   HandlerFunc(func(w ResponseWriter, r *Request) {}),
	}
	errc := make(chan error, 1)
	go func() { errc <- srv.ServeTLS(ln, "", "") }()
	select {
	case err := <-errc:
		t.Fatalf("ServeTLS: %v", err)
	case <-serving:
	}

	conn, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	state := conn.ConnectionState()
	if got, want := state.NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := state.NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1791
1792
// TestTLSServerRejectHTTPRequests sends a plaintext HTTP/1.1 request to a TLS
// server. The reply must start with "HTTP/1.0 400 Bad Request" and the
// handler must never run.
func TestTLSServerRejectHTTPRequests(t *testing.T) {
	run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
}
func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
	const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		t.Error("unexpected HTTPS request")
	}), func(ts *httptest.Server) {
		// Route server log output into a local buffer that is never read,
		// keeping it out of the test output.
		var errBuf bytes.Buffer
		ts.Config.ErrorLog = log.New(&errBuf, "", 0)
	}).ts

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	slurp, err := io.ReadAll(conn)
	if err != nil {
		t.Fatal(err)
	}
	if !bytes.HasPrefix(slurp, []byte(wantPrefix)) {
		t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
	}
}
1818
1819
1820 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1821 testAutomaticHTTP2_Serve(t, nil, true)
1822 }
1823
1824 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1825 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1826 }
1827
1828 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1829 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1830 }
1831
1832 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1833 setParallel(t)
1834 defer afterTest(t)
1835 ln := newLocalListener(t)
1836 ln.Close()
1837 var s Server
1838 s.TLSConfig = tlsConf
1839 if err := s.Serve(ln); err == nil {
1840 t.Fatal("expected an error")
1841 }
1842 gotH2 := s.TLSNextProto["h2"] != nil
1843 if gotH2 != wantH2 {
1844 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1845 }
1846 }
1847
1848 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1849 setParallel(t)
1850 defer afterTest(t)
1851 ln := newLocalListener(t)
1852 ln.Close()
1853 var s Server
1854
1855
1856 s.TLSConfig = &tls.Config{
1857 NextProtos: []string{"h2"},
1858 }
1859 if err := s.Serve(ln); err == nil {
1860 t.Fatal("expected an error")
1861 }
1862 on := s.TLSNextProto["h2"] != nil
1863 if !on {
1864 t.Errorf("http2 wasn't automatically enabled")
1865 }
1866 }
1867
// TestAutomaticHTTP2_ListenAndServe checks "h2" negotiation when the
// certificate is supplied directly in TLSConfig.Certificates.
func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	conf := &tls.Config{Certificates: []tls.Certificate{cert}}
	testAutomaticHTTP2_ListenAndServe(t, conf)
}

// TestAutomaticHTTP2_ListenAndServe_GetCertificate checks "h2" negotiation
// when the certificate is supplied through a GetCertificate callback.
func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	getCert := func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
		return &cert, nil
	}
	testAutomaticHTTP2_ListenAndServe(t, &tls.Config{GetCertificate: getCert})
}

// TestAutomaticHTTP2_ListenAndServe_GetConfigForClient checks "h2"
// negotiation when the whole per-connection config comes from a
// GetConfigForClient callback.
func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	// "h2" is listed explicitly in the returned config. NOTE(review):
	// presumably the server cannot add it to a config obtained from
	// GetConfigForClient.
	perClient := &tls.Config{
		NextProtos:   []string{"h2"},
		Certificates: []tls.Certificate{cert},
	}
	outer := &tls.Config{
		GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
			return perClient, nil
		},
	}
	testAutomaticHTTP2_ListenAndServe(t, outer)
}
1907
// testAutomaticHTTP2_ListenAndServe starts a server with
// ListenAndServeTLS("", "") using tlsConf. It then dials the server with TLS
// offering "h2" and "http/1.1" and checks that "h2" is negotiated.
//
// The listen address is found by opening and immediately closing a local
// listener. Another process could take that port in between, so startup is
// retried up to maxTries times with fresh addresses.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)
	// NOTE(review): not parallel, presumably because
	// SetTestHookServerServe installs a package-level hook.
	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
Try:
	for try := 0; try < maxTries; try++ {
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		// The hook hands back the listener the server is actually serving on.
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			// ListenAndServeTLS failed before serving (e.g. the port was
			// taken); try again with a fresh address.
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	// ln now refers to the serving listener; closing it stops the server.
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1962
// serverExpectTest describes one raw request sent by TestServerExpect and
// the status text expected in the first line of the reply.
type serverExpectTest struct {
	contentLength    int    // Content-Length value and number of body bytes to send
	chunked          bool   // send "Transfer-Encoding: chunked" instead of Content-Length
	expectation      string // value of the Expect request header
	readBody         bool   // whether the handler reads the body (sent as ?readbody=)
	expectedResponse string // substring the first response line must contain
}

// expectTest builds a serverExpectTest with a Content-Length body, i.e. with
// chunked left false.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var tt serverExpectTest
	tt.contentLength = contentLength
	tt.expectation = expectation
	tt.readBody = readBody
	tt.expectedResponse = expectedResponse
	return tt
}
1979
// serverExpectTests are the cases run by TestServerExpect.
var serverExpectTests = []serverExpectTest{
	// "100-continue" with a body the handler reads gets an interim
	// "100 Continue". The second entry checks case-insensitive matching.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// Empty Expect value: the body is sent up front and the handler
	// replies 200.
	expectTest(100, "", true, "200 OK"),

	// "100-continue", but the handler rejects the request without reading
	// the body, so its 401 is the first line rather than 100 Continue.
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// The same rejection with an empty Expect value.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognized expectation is refused with 417.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// "100-continue" with an empty body: the handler's 200 is the first line.
	expectTest(0, "100-continue", true, "200 OK"),
	// "100-continue" with an empty body that the handler does not read.
	expectTest(0, "100-continue", false, "401 Unauthorized"),

	// "100-continue" with chunked Transfer-Encoding and a handler that reads
	// the body.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
2009
2010
2011
// TestServerExpect sends each serverExpectTests case as a raw HTTP/1.1 POST
// on its own connection. It checks that the first response line contains the
// case's expectedResponse.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Decide from the raw query string rather than r.FormValue.
		// NOTE(review): presumably because FormValue could parse (and so
		// read) the POST body, which must only happen when readbody=true.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest dials a fresh connection. A goroutine writes the request
	// headers (and the body when writeBody is set) while this goroutine reads
	// and checks the first response line.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// Send the body immediately only when there is one and the client is
		// not waiting for "100 Continue" first. The comparison is lower-cased
		// to match entries such as "100-cOntInUE".
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// Deferred after conn.Close, so (LIFO) the writer goroutine is
		// waited for before the connection is closed.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// By default write straight to conn with a no-op Close.
				// Chunked bodies go through httputil's chunked writer, whose
				// Close emits the terminating zero-length chunk.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler never reads the body, so the server may
						// already have replied and hung up while we were still
						// writing; tolerate the failure.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// We were still sending a body the server never reads. If it
				// replied and closed first, the TCP stack may reset the
				// connection and make this read fail before the response line
				// is seen. NOTE(review): treated as acceptable rather than as
				// a failure.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2107
2108
2109
// TestServerUnreadRequestBodyLittle feeds a POST with a 100 KB body that the
// handler never reads. It checks that the unread body is still buffered when
// the handler starts, and that by the time the headers are flushed it has
// been fully consumed from the connection. It also checks that no
// "Connection" response header is set.
func TestServerUnreadRequestBodyLittle(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	conn := new(testConn)
	body := strings.Repeat("x", 100<<10)
	fmt.Fprintf(&conn.readBuf, "POST / HTTP/1.1\r\n"+
		"Host: test\r\n"+
		"Content-Length: %d\r\n"+
		"\r\n", len(body))
	conn.readBuf.WriteString(body)

	// readBufLen reports how much input is still unread, taking readMu
	// because the server reads from conn on another goroutine.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	done := make(chan bool)
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		defer close(done)
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("on request, read buffer length is %d; expected about 100 KB", n)
		}
		rw.WriteHeader(200)
		rw.(Flusher).Flush()
		if got, want := readBufLen(), 0; got != want {
			t.Errorf("after WriteHeader, read buffer length is %d; want %d", got, want)
		}
		if c := rw.Header().Get("Connection"); c != "" {
			t.Errorf(`Connection header = %q; want ""`, c)
		}
	}))
	<-done
}
2147
2148
2149
2150
// TestServerUnreadRequestBodyLarge feeds a POST with a 1 MB body that the
// handler never reads. It checks that the body stays mostly unread even after
// the headers are flushed, and that the server answers with
// "Connection: close" and closes the connection. It is skipped in -short
// mode unless running on a builder, like TestHandlerBodyClose.
func TestServerUnreadRequestBodyLarge(t *testing.T) {
	setParallel(t)
	if testing.Short() && testenv.Builder() == "" {
		// Previously this only logged the message and ran the test anyway.
		t.Skip("skipping in short mode")
	}
	conn := new(testConn)
	body := strings.Repeat("x", 1<<20)
	conn.readBuf.Write([]byte(fmt.Sprintf(
		"POST / HTTP/1.1\r\n"+
			"Host: test\r\n"+
			"Content-Length: %d\r\n"+
			"\r\n", len(body))))
	conn.readBuf.Write([]byte(body))
	conn.closec = make(chan bool, 1)

	// Inspect readBuf under readMu, as TestServerUnreadRequestBodyLittle
	// does. The server goroutine may be reading from conn at the same time.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	ls := &oneConnListener{conn}
	go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("on request, read buffer length is %d; expected about 1MB", n)
		}
		rw.WriteHeader(200)
		rw.(Flusher).Flush()
		if n := readBufLen(); n < len(body)/2 {
			t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", n)
		}
	}))
	<-conn.closec

	if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
		t.Errorf("Expected a Connection: close header; got response: %s", res)
	}
}
2183
// handlerBodyCloseTest describes one scenario for testHandlerBodyClose. It
// covers how the first request's body is framed, whether that request asks
// for the connection to be closed, and what should happen when the handler
// closes the body without reading it.
type handlerBodyCloseTest struct {
	bodySize     int  // size in bytes of the first request's body
	bodyChunked  bool // send the body with chunked Transfer-Encoding
	reqConnClose bool // add "Connection: close" and send no follow-up request

	wantEOFSearch bool // Body.Close is expected to consume buffered input
	wantNextReq   bool // the follow-up GET is expected to reach the handler
}

// connectionHeader returns the raw header line "Connection: close\r\n" when
// t.reqConnClose is set, and the empty string otherwise.
func (t handlerBodyCloseTest) connectionHeader() string {
	if !t.reqConnClose {
		return ""
	}
	return "Connection: close\r\n"
}
2199
// handlerBodyCloseTests are the scenarios run by TestHandlerBodyClose. They
// cover 20 KB ("small") and 1 MB ("big") bodies, framed with Content-Length
// or chunked, with and without "Connection: close".
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// Small body with Content-Length: Close is expected to read past the
	// body, and the follow-up request to be served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small chunked body: likewise read past, and the next request served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// Small body with Content-Length and "Connection: close": no follow-up
	// request is sent, and Close is expected not to read at all.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Small chunked body with "Connection: close": Close is still expected
	// to read. NOTE(review): presumably because a chunked body may be
	// followed by trailers.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big body with Content-Length: Close is expected not to read, and the
	// next request is not served (the declared size is known up front).
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// Big chunked body: Close is expected to read some input but give up
	// before reaching the next request.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big chunked body with "Connection: close": Close is still expected to
	// read. NOTE(review): presumably looking for trailers.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// Big body with Content-Length and "Connection: close": no reads on
	// Close.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2284
// TestHandlerBodyClose runs every handlerBodyCloseTests scenario through
// testHandlerBodyClose. It is skipped in -short mode unless running on a
// builder.
func TestHandlerBodyClose(t *testing.T) {
	setParallel(t)
	if testing.Short() && testenv.Builder() == "" {
		t.Skip("skipping in -short mode")
	}
	for i := range handlerBodyCloseTests {
		testHandlerBodyClose(t, i, handlerBodyCloseTests[i])
	}
}
2294
// testHandlerBodyClose builds the request described by tt in a testConn,
// followed by a second GET unless tt.reqConnClose is set, and serves it.
// For the first request the handler records the unread input size, closes
// the body without reading it, and records the size again. The test then
// checks:
//   - whether Close consumed input, against tt.wantEOFSearch;
//   - whether the follow-up request was served, against tt.wantNextReq.
//
// i only labels error messages.
func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
	conn := new(testConn)
	body := strings.Repeat("x", tt.bodySize)
	if tt.bodyChunked {
		conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
			"Host: test\r\n" +
			tt.connectionHeader() +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n")
		cw := internal.NewChunkedWriter(&conn.readBuf)
		io.WriteString(cw, body)
		cw.Close()
		conn.readBuf.WriteString("\r\n")
	} else {
		fmt.Fprintf(&conn.readBuf, "POST / HTTP/1.1\r\n"+
			"Host: test\r\n"+
			tt.connectionHeader()+
			"Content-Length: %d\r\n"+
			"\r\n", len(body))
		conn.readBuf.WriteString(body)
	}
	if !tt.reqConnClose {
		conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
	}
	conn.closec = make(chan bool, 1)

	// readBufLen reports the unread input size under readMu.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	var (
		numReqs      int
		size0, size1 int
	)
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		numReqs++
		if numReqs != 1 {
			return
		}
		size0 = readBufLen()
		req.Body.Close()
		size1 = readBufLen()
	}))
	<-conn.closec

	if numReqs < 1 || numReqs > 2 {
		t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
	}
	if didSearch := size0 != size1; didSearch != tt.wantEOFSearch {
		t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, tt.wantEOFSearch, size0, size1)
	}
	if tt.wantNextReq && numReqs != 2 {
		t.Errorf("%d. numReq = %d; want 2", i, numReqs)
	}
}
2351
2352
2353
2354 type testHandlerBodyConsumer struct {
2355 name string
2356 f func(io.ReadCloser)
2357 }
2358
2359 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2360 {"nil", func(io.ReadCloser) {}},
2361 {"close", func(r io.ReadCloser) { r.Close() }},
2362 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2363 }
2364
// TestRequestBodyReadErrorClosesConnection sends a chunked POST whose first
// chunk-size line ("hax") is not valid hex, followed by a pipelined
// GET /secret. For every body consumer, only the POST may reach the handler
// and the server must then close the connection.
func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const rawRequests = "POST /public HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Transfer-Encoding: chunked\r\n" +
		"\r\n" +
		"hax\r\n" +
		"GET /secret HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"\r\n"
	for _, consumer := range testHandlerBodyConsumers {
		conn := new(testConn)
		conn.readBuf.WriteString(rawRequests)
		conn.closec = make(chan bool, 1)

		reqs := 0
		go Serve(&oneConnListener{conn}, HandlerFunc(func(_ ResponseWriter, req *Request) {
			reqs++
			if strings.Contains(req.URL.Path, "secret") {
				t.Error("Request for /secret encountered, should not have happened.")
			}
			consumer.f(req.Body)
		}))
		<-conn.closec
		if reqs != 1 {
			t.Errorf("Handler %v: got %d reqs; want 1", consumer.name, reqs)
		}
	}
}
2395
// TestInvalidTrailerClosesConnection sends a chunked POST whose trailer
// section contains a malformed line, followed by a pipelined GET /secret.
// Whatever each consumer does with the body, only the POST may reach the
// handler and the server must then close the connection.
func TestInvalidTrailerClosesConnection(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	const rawRequests = "POST /public HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"Trailer: hack\r\n" +
		"Transfer-Encoding: chunked\r\n" +
		"\r\n" +
		"3\r\n" +
		"hax\r\n" +
		"0\r\n" +
		"I'm not a valid trailer\r\n" +
		"GET /secret HTTP/1.1\r\n" +
		"Host: test\r\n" +
		"\r\n"
	for _, consumer := range testHandlerBodyConsumers {
		conn := new(testConn)
		conn.readBuf.WriteString(rawRequests)
		conn.closec = make(chan bool, 1)

		reqs := 0
		go Serve(&oneConnListener{conn}, HandlerFunc(func(_ ResponseWriter, req *Request) {
			reqs++
			if strings.Contains(req.URL.Path, "secret") {
				t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", consumer.name)
			}
			consumer.f(req.Body)
		}))
		<-conn.closec
		if reqs != 1 {
			t.Errorf("Handler %s: got %d reqs; want 1", consumer.name, reqs)
		}
	}
}
2430
2431
2432
2433
2434 type slowTestConn struct {
2435
2436 script []any
2437 closec chan bool
2438
2439 mu sync.Mutex
2440 rd, wd time.Time
2441 noopConn
2442 }
2443
// SetDeadline applies t as both the read and the write deadline.
func (c *slowTestConn) SetDeadline(t time.Time) error {
	_ = c.SetReadDeadline(t) // always returns nil
	_ = c.SetWriteDeadline(t)
	return nil
}
2449
// SetReadDeadline records t as the deadline Read checks; it never fails.
func (c *slowTestConn) SetReadDeadline(t time.Time) error {
	c.mu.Lock()
	c.rd = t
	c.mu.Unlock()
	return nil
}
2456
// SetWriteDeadline records t as the deadline Write checks; it never fails.
func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
	c.mu.Lock()
	c.wd = t
	c.mu.Unlock()
	return nil
}
2463
// Read plays back the script. A string cue is copied into b (a partially
// consumed string stays at the head of the script); a time.Duration cue sleeps
// and then moves on to the next cue. If the read deadline is (or would be)
// reached, Read fails with syscall.ETIMEDOUT; an exhausted script yields
// io.EOF. Any other cue type panics.
//
// NOTE(review): c.mu stays locked for the entire call, including the sleeps,
// so the Set*Deadline methods block while Read is pausing.
func (c *slowTestConn) Read(b []byte) (n int, err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
restart:
	// Deadline already passed: fail immediately.
	if !c.rd.IsZero() && time.Now().After(c.rd) {
		return 0, syscall.ETIMEDOUT
	}
	if len(c.script) == 0 {
		return 0, io.EOF
	}

	switch cue := c.script[0].(type) {
	case time.Duration:
		if !c.rd.IsZero() {
			// The pause would outlast the read deadline: sleep only until the
			// deadline, keep the unslept remainder at the head of the script
			// for a later Read, and report a timeout.
			if remaining := time.Until(c.rd); remaining < cue {
				c.script[0] = cue - remaining
				time.Sleep(remaining)
				return 0, syscall.ETIMEDOUT
			}
		}
		// The whole pause fits; consume it and re-evaluate the next cue.
		c.script = c.script[1:]
		time.Sleep(cue)
		goto restart

	case string:
		n = copy(b, cue)
		// Keep any bytes that did not fit in b for the next Read.
		if len(cue) > n {
			c.script[0] = cue[n:]
		} else {
			c.script = c.script[1:]
		}

	default:
		panic("unknown cue in slowTestConn script")
	}

	return
}
2505
2506 func (c *slowTestConn) Close() error {
2507 select {
2508 case c.closec <- true:
2509 default:
2510 }
2511 return nil
2512 }
2513
// Write discards b and reports it fully written, unless a write deadline is
// set and has already passed, in which case it fails with syscall.ETIMEDOUT.
// NOTE(review): wd is read here without holding mu, unlike in SetWriteDeadline.
func (c *slowTestConn) Write(b []byte) (int, error) {
	if deadline := c.wd; deadline.IsZero() || !time.Now().After(deadline) {
		return len(b), nil
	}
	return 0, syscall.ETIMEDOUT
}
2520
2521 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2522 if testing.Short() {
2523 t.Skip("skipping in -short mode")
2524 }
2525 defer afterTest(t)
2526 for _, handler := range testHandlerBodyConsumers {
2527 conn := &slowTestConn{
2528 script: []any{
2529 "POST /public HTTP/1.1\r\n" +
2530 "Host: test\r\n" +
2531 "Content-Length: 10000\r\n" +
2532 "\r\n",
2533 "foo bar baz",
2534 600 * time.Millisecond,
2535 "GET /secret HTTP/1.1\r\n" +
2536 "Host: test\r\n" +
2537 "\r\n",
2538 },
2539 closec: make(chan bool, 1),
2540 }
2541 ls := &oneConnListener{conn}
2542
2543 var numReqs int
2544 s := Server{
2545 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2546 numReqs++
2547 if strings.Contains(req.URL.Path, "secret") {
2548 t.Error("Request for /secret encountered, should not have happened.")
2549 }
2550 handler.f(req.Body)
2551 }),
2552 ReadTimeout: 400 * time.Millisecond,
2553 }
2554 go s.Serve(ls)
2555 <-conn.closec
2556
2557 if numReqs != 1 {
2558 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2559 }
2560 }
2561 }
2562
2563
2564 type cancelableTimeoutContext struct {
2565 context.Context
2566 }
2567
2568 func (c cancelableTimeoutContext) Err() error {
2569 if c.Context.Err() != nil {
2570 return context.DeadlineExceeded
2571 }
2572 return nil
2573 }
2574
func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }

// testTimeoutHandler checks TimeoutHandler end to end: a handler released
// before its context is done is passed through unchanged, and once the
// context reports context.DeadlineExceeded the client gets a 503 with the
// built-in timeout page (text/html), while the inner handler's later Write
// fails with ErrHandlerTimeout.
func testTimeoutHandler(t *testing.T, mode testMode) {
	sendHi := make(chan bool, 1)
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		<-sendHi
		_, werr := w.Write([]byte("hi"))
		writeErrors <- werr
	})
	ctx, cancel := context.WithCancel(context.Background())
	// cancelableTimeoutContext makes cancel() look like a deadline expiry.
	h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
	cst := newClientServerTest(t, mode, h)

	// First request: the handler is released up front, so its response is
	// delivered normally.
	sendHi <- true
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil; the checks below would dereference it.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusOK; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), "hi"; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g := <-writeErrors; g != nil {
		t.Errorf("got unexpected Write error on first request: %v", g)
	}

	// Second request: the context is already done while the inner handler
	// is still waiting on sendHi, so TimeoutHandler must answer with 503.
	cancel()

	res, err = cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ = io.ReadAll(res.Body)
	res.Body.Close()
	if !strings.Contains(string(body), "<title>Timeout</title>") {
		t.Errorf("expected timeout body; got %q", string(body))
	}
	if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
		t.Errorf("response content-type = %q; want %q", g, w)
	}

	// Release the abandoned inner handler; its Write must now be rejected.
	sendHi <- true
	if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
		t.Errorf("expected Write error of %v; got %v", e, g)
	}
}
2630
2631
func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }

// testTimeoutHandlerRace fires many concurrent requests at a 20ms
// TimeoutHandler whose inner handler writes "hi" once per millisecond, the
// number of writes coming from the request path (at least one). Some
// responses therefore finish and others time out mid-write. Errors are
// ignored; the test exists to surface data races and panics.
func testTimeoutHandlerRace(t *testing.T, mode testMode) {
	delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		ms, _ := strconv.Atoi(r.URL.Path[1:])
		if ms == 0 {
			ms = 1
		}
		for range ms {
			w.Write([]byte("hi"))
			time.Sleep(time.Millisecond)
		}
	})

	ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
	client := ts.Client()

	total, inFlight := 50, 10
	if testing.Short() {
		total, inFlight = 10, 3
	}
	// gate bounds the number of requests in flight at once.
	gate := make(chan bool, inFlight)

	var wg sync.WaitGroup
	for range total {
		gate <- true
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-gate }()
			res, err := client.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
			if err != nil {
				return
			}
			io.Copy(io.Discard, res.Body)
			res.Body.Close()
		}()
	}
	wg.Wait()
}
2671
2672
2673
func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }

// testTimeoutHandlerRaceHeader races a handler that immediately writes a 204
// header against a TimeoutHandler with a 1ns timeout, across many concurrent
// requests. Request errors are only logged; the test targets races and panics.
func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
	delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.WriteHeader(204)
	})

	ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts

	// gate bounds the number of requests in flight at once.
	gate := make(chan bool, 50)
	total := 500
	if testing.Short() {
		total = 10
	}

	client := ts.Client()
	var wg sync.WaitGroup
	for range total {
		gate <- true
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer func() { <-gate }()
			res, err := client.Get(ts.URL)
			if err != nil {
				// Failures are tolerated here; just record them.
				t.Log(err)
				return
			}
			defer res.Body.Close()
			io.Copy(io.Discard, res.Body)
		}()
	}
	wg.Wait()
}
2709
2710
func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }

// testTimeoutHandlerRaceHeaderTimeout is like testTimeoutHandler, but the
// inner handler sets a Content-Type header before blocking, exercising the
// path where header mutation races with the timeout response. The first
// request must succeed, the second (after the simulated timeout) must get a
// 503 timeout page, and the abandoned handler's Write must fail with
// ErrHandlerTimeout.
func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
	sendHi := make(chan bool, 1)
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/plain")
		<-sendHi
		_, werr := w.Write([]byte("hi"))
		writeErrors <- werr
	})
	ctx, cancel := context.WithCancel(context.Background())
	// cancelableTimeoutContext makes cancel() look like a deadline expiry.
	h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
	cst := newClientServerTest(t, mode, h)

	// First request: released up front, delivered normally.
	sendHi <- true
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil; the checks below would dereference it.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusOK; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), "hi"; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	if g := <-writeErrors; g != nil {
		t.Errorf("got unexpected Write error on first request: %v", g)
	}

	// Second request: the context is done while the handler waits on sendHi.
	cancel()

	res, err = cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ = io.ReadAll(res.Body)
	res.Body.Close()
	if !strings.Contains(string(body), "<title>Timeout</title>") {
		t.Errorf("expected timeout body; got %q", string(body))
	}

	// Release the abandoned inner handler; its Write must now be rejected.
	sendHi <- true
	if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
		t.Errorf("expected Write error of %v; got %v", e, g)
	}
}
2764
2765
func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
	run(t, testTimeoutHandlerStartTimerWhenServing)
}

// testTimeoutHandlerStartTimerWhenServing idles for twice the timeout before
// sending the first request and expects it to succeed, i.e. the timeout must
// be measured per request rather than from when TimeoutHandler was created.
func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping sleeping test in -short mode")
	}
	const timeout = 300 * time.Millisecond
	noContent := HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.WriteHeader(StatusNoContent)
	})
	ts := newClientServerTest(t, mode, TimeoutHandler(noContent, timeout, "")).ts
	defer ts.Close()

	client := ts.Client()

	// Let more than one full timeout elapse before the first request.
	time.Sleep(2 * timeout)
	res, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if got := res.StatusCode; got != StatusNoContent {
		t.Errorf("got res.StatusCode %d, want %v", got, StatusNoContent)
	}
}
2795
func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }

// testTimeoutHandlerContextCanceled runs a TimeoutHandler whose context is
// canceled before any request arrives. The client must get a 503 with an
// empty body, and the inner handler's writes must fail with context.Canceled.
func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
	writeErrors := make(chan error, 1)
	sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/plain")
		var err error
		// Keep writing until a Write fails, giving up after 100 attempts
		// spaced 1ms apart; the first failure is what gets reported.
		for i := 0; i < 100; i++ {
			_, err = w.Write([]byte("a"))
			if err != nil {
				break
			}
			time.Sleep(1 * time.Millisecond)
		}
		writeErrors <- err
	})
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	h := NewTestTimeoutHandler(sayHi, ctx)
	cst := newClientServerTest(t, mode, h)
	defer cst.close()

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		// res is nil; the checks below would dereference it.
		t.Fatal(err)
	}
	if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
		t.Errorf("got res.StatusCode %d; expected %d", g, e)
	}
	body, _ := io.ReadAll(res.Body)
	res.Body.Close()
	if g, e := string(body), ""; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
	// %v, not %g: e is an error value, and %g would print a bad-verb marker.
	if g, e := <-writeErrors, context.Canceled; g != e {
		t.Errorf("got unexpected Write in handler: %v, want %v", g, e)
	}
}
2835
2836
2837 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2838 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2839 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2840
2841 }
2842 timeout := 300 * time.Millisecond
2843 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2844
2845 c := ts.Client()
2846
2847 res, err := c.Get(ts.URL)
2848 if err != nil {
2849 t.Fatal(err)
2850 }
2851 defer res.Body.Close()
2852 if res.StatusCode != StatusOK {
2853 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2854 }
2855 }
2856
2857
// TestTimeoutHandlerPanicRecovery runs the shared handler-panic test with the
// panicking handler wrapped in a one-second TimeoutHandler.
func TestTimeoutHandlerPanicRecovery(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) {
		withTimeout := func(h Handler) Handler { return TimeoutHandler(h, time.Second, "") }
		testHandlerPanic(t, false, mode, withTimeout, "intentional death for testing")
	}, testNotParallel)
}
2866
// TestRedirectBadPath checks that Redirect copes with a request whose URL
// path lacks a leading slash (invalid input) and still writes the requested
// status code.
func TestRedirectBadPath(t *testing.T) {
	req := &Request{
		Method: "GET",
		URL:    &url.URL{Scheme: "http", Path: "not-empty-but-no-leading-slash"},
	}
	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if rec.Code != 304 {
		t.Errorf("Code = %d; want 304", rec.Code)
	}
}
2883
2884 func TestRedirectEscapedPath(t *testing.T) {
2885 baseURL, redirectURL := "http://example.com/foo%2Fbar/", "qux%2Fbaz"
2886 req := httptest.NewRequest("GET", baseURL, NoBody)
2887
2888 rr := httptest.NewRecorder()
2889 Redirect(rr, req, redirectURL, StatusMovedPermanently)
2890
2891 wantURL := "/foo%2Fbar/qux%2Fbaz"
2892 if got := rr.Result().Header.Get("Location"); got != wantURL {
2893 t.Errorf("Redirect(%s, %s) = %s, want = %s", baseURL, redirectURL, got, wantURL)
2894 }
2895 }
2896
2897
2898 func TestRedirect(t *testing.T) {
2899 req, _ := NewRequest("GET", "http://example.com/qux/", nil)
2900
2901 var tests = []struct {
2902 in string
2903 want string
2904 }{
2905
2906 {"http://foobar.com/baz", "http://foobar.com/baz"},
2907
2908 {"https://foobar.com/baz", "https://foobar.com/baz"},
2909
2910 {"test://foobar.com/baz", "test://foobar.com/baz"},
2911
2912 {"//foobar.com/baz", "//foobar.com/baz"},
2913
2914 {"/foobar.com/baz", "/foobar.com/baz"},
2915
2916 {"foobar.com/baz", "/qux/foobar.com/baz"},
2917
2918 {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
2919
2920 {"///foobar.com/baz", "/foobar.com/baz"},
2921
2922
2923 {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
2924 {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
2925 "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
2926
2927 {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2928 {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2929 }
2930
2931 for _, tt := range tests {
2932 rec := httptest.NewRecorder()
2933 Redirect(rec, req, tt.in, 302)
2934 if got, want := rec.Code, 302; got != want {
2935 t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
2936 }
2937 if got := rec.Header().Get("Location"); got != tt.want {
2938 t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
2939 }
2940 }
2941 }
2942
2943
2944
// TestRedirectContentTypeAndBody checks that Redirect adds a text/html
// Content-Type only for GET and HEAD, writes the short HTML body only for GET,
// and writes no body when the handler already put a Content-Type key (even an
// empty or nil one) in the header map.
func TestRedirectContentTypeAndBody(t *testing.T) {
	type ctHeader struct {
		Values []string
	}

	var tests = []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
		{MethodHead, nil, "text/html; charset=utf-8", ""},
		{MethodPost, nil, "", ""},
		{MethodDelete, nil, "", ""},
		{"foo", nil, "", ""},
		{MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
		{MethodGet, &ctHeader{[]string{}}, "", ""},
		{MethodGet, &ctHeader{nil}, "", ""},
	}
	for _, tc := range tests {
		rec := httptest.NewRecorder()
		if tc.ct != nil {
			rec.Header()["Content-Type"] = tc.ct.Values
		}
		Redirect(rec, httptest.NewRequest(tc.method, "http://example.com/qux/", nil), "/foo", 302)

		if got := rec.Code; got != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, got, 302)
		}
		if got := rec.Header().Get("Content-Type"); got != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, got, tc.wantCT)
		}
		body, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if got := string(body); got != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, got, tc.wantBody)
		}
	}
}
2988
2989
2990
2991
2992
2993
2994
func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }

// testZeroLengthPostAndResponse sends the same zero-length POST five times
// (collecting every response before reading any) and expects the handler to
// see an empty body and each response body, declared as Content-Length: 0,
// to be empty.
func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		all, err := io.ReadAll(r.Body)
		if err != nil {
			// Handlers run on a server goroutine, where t.Fatal must not be
			// called (FailNow only works from the test goroutine).
			t.Errorf("handler ReadAll: %v", err)
			return
		}
		if len(all) != 0 {
			t.Errorf("handler got %d bytes; expected 0", len(all))
		}
		rw.Header().Set("Content-Length", "0")
	}))

	req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
	if err != nil {
		t.Fatal(err)
	}
	req.ContentLength = 0

	var resp [5]*Response
	for i := range resp {
		resp[i], err = cst.c.Do(req)
		if err != nil {
			t.Fatalf("client post #%d: %v", i, err)
		}
	}

	for i := range resp {
		all, err := io.ReadAll(resp[i].Body)
		resp[i].Body.Close()
		if err != nil {
			t.Fatalf("req #%d: client ReadAll: %v", i, err)
		}
		if len(all) != 0 {
			t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
		}
	}
}
3033
3034 func TestHandlerPanicNil(t *testing.T) {
3035 run(t, func(t *testing.T, mode testMode) {
3036 testHandlerPanic(t, false, mode, nil, nil)
3037 }, testNotParallel)
3038 }
3039
3040 func TestHandlerPanic(t *testing.T) {
3041 run(t, func(t *testing.T, mode testMode) {
3042 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
3043 }, testNotParallel)
3044 }
3045
// TestHandlerPanicWithHijack is like TestHandlerPanic, but the handler
// hijacks the connection before panicking. Only HTTP/1 mode is run, since
// Hijacker is an HTTP/1 feature.
func TestHandlerPanicWithHijack(t *testing.T) {
	modes := []testMode{http1Mode}
	run(t, func(t *testing.T, mode testMode) {
		testHandlerPanic(t, true, mode, nil, "intentional death for testing")
	}, modes)
}
3052
// testHandlerPanic starts a server (in the given mode) whose handler panics
// with panicValue, optionally after hijacking the connection, and whose
// ErrorLog writes into an in-memory pipe. wrapper, if non-nil, wraps the
// panicking handler (for example in a TimeoutHandler). The client request is
// expected to fail (a success is only logged), and unless panicValue is nil
// the test then waits until the server has written something to its error log.
func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
	// The server's ErrorLog is redirected into this pipe so the test can
	// observe when the panic has been logged. pw is closed on return, which
	// unblocks the reader goroutine below with io.EOF if nothing was logged.
	pr, pw := io.Pipe()
	defer pw.Close()

	var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
		if withHijack {
			rwc, _, err := w.(Hijacker).Hijack()
			if err != nil {
				t.Logf("unexpected error: %v", err)
			}
			// NOTE(review): if Hijack fails, rwc is nil and this deferred
			// Close would itself panic, superseding panicValue.
			defer rwc.Close()
		}
		panic(panicValue)
	})
	if wrapper != nil {
		handler = wrapper(handler)
	}
	cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(pw, "", 0)
	})

	// A single Read is enough to know the server logged something. Closing
	// pr afterwards makes any further log writes fail instead of blocking on
	// the synchronous pipe.
	done := make(chan bool, 1)
	go func() {
		buf := make([]byte, 4<<10)
		_, err := pr.Read(buf)
		pr.Close()
		if err != nil && err != io.EOF {
			t.Error(err)
		}
		done <- true
	}()

	// NOTE(review): on an unexpected success the response body is not closed.
	_, err := cst.c.Get(cst.ts.URL)
	if err == nil {
		t.Logf("expected an error")
	}

	// No log output is waited for when the panic value is nil.
	if panicValue == nil {
		return
	}

	<-done
}
3105
3106 type terrorWriter struct{ t *testing.T }
3107
3108 func (w terrorWriter) Write(p []byte) (int, error) {
3109 w.t.Errorf("%s", p)
3110 return len(p), nil
3111 }
3112
3113
3114
func TestServerWriteHijackZeroBytes(t *testing.T) {
	run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
}

// testServerWriteHijackZeroBytes flushes, hijacks the connection, and then
// calls Write(nil) on the ResponseWriter, expecting ErrHijacked. Any server
// ErrorLog output fails the test through terrorWriter.
func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
	handlerDone := make(chan struct{})
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(handlerDone)
		w.(Flusher).Flush()
		conn, _, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Errorf("Hijack: %v", err)
			return
		}
		defer conn.Close()
		if _, err := w.Write(nil); err != ErrHijacked {
			t.Errorf("Write error = %v; want ErrHijacked", err)
		}
	})
	ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
		ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
	}).ts

	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	<-handlerDone
}
3145
// TestServerNoDate checks that a handler can suppress the Date header by
// setting it to nil.
func TestServerNoDate(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) { testServerNoHeader(t, mode, "Date") })
}
3151
// TestServerContentType checks that a handler can suppress the Content-Type
// header by setting it to nil.
func TestServerContentType(t *testing.T) {
	run(t, func(t *testing.T, mode testMode) { testServerNoHeader(t, mode, "Content-Type") })
}
3157
// testServerNoHeader has the handler set header to a nil value before writing
// an HTML body, and expects the response to omit that header entirely.
func testServerNoHeader(t *testing.T, mode testMode, header string) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header()[header] = nil
		io.WriteString(w, "<html>foo</html>")
	})
	cst := newClientServerTest(t, mode, h)

	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if got, present := res.Header[header]; present {
		t.Fatalf("Expected no %s header; got %q", header, got)
	}
}
3172
func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }

// testStripPrefix serves StripPrefix("/foo/bar", echo), where echo reports the
// resulting URL.Path and URL.RawPath in response headers. Matching requests
// must see the prefix removed from both; non-matching ones (including a
// request whose prefix slash is escaped as %2F) must get 404.
func testStripPrefix(t *testing.T, mode testMode) {
	echo := HandlerFunc(func(w ResponseWriter, r *Request) {
		hdr := w.Header()
		hdr.Set("X-Path", r.URL.Path)
		hdr.Set("X-RawPath", r.URL.RawPath)
	})
	ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", echo)).ts
	client := ts.Client()

	cases := []struct {
		reqPath string
		path    string // expected stripped Path; "" means expect 404
		rawPath string
	}{
		{"/foo/bar/qux", "/qux", ""},
		{"/foo/bar%2Fqux", "/qux", "%2Fqux"},
		{"/foo%2Fbar/qux", "", ""},
		{"/bar", "", ""},
	}
	for _, tc := range cases {
		t.Run(tc.reqPath, func(t *testing.T) {
			res, err := client.Get(ts.URL + tc.reqPath)
			if err != nil {
				t.Fatal(err)
			}
			res.Body.Close()
			if tc.path == "" {
				if res.StatusCode != StatusNotFound {
					t.Errorf("got %q, want 404 Not Found", res.Status)
				}
				return
			}
			if res.StatusCode != StatusOK {
				t.Fatalf("got %q, want 200 OK", res.Status)
			}
			if got := res.Header.Get("X-Path"); got != tc.path {
				t.Errorf("got Path %q, want %q", got, tc.path)
			}
			if got := res.Header.Get("X-RawPath"); got != tc.rawPath {
				t.Errorf("got RawPath %q, want %q", got, tc.rawPath)
			}
		})
	}
}
3218
3219
// TestStripPrefixNotModifyRequest checks that StripPrefix leaves the caller's
// Request untouched: its URL.Path must still include the prefix afterwards.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	req := httptest.NewRequest("GET", "/foo/bar", nil)
	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)
	if got := req.URL.Path; got != "/foo/bar" {
		t.Errorf("StripPrefix should not modify the provided Request, but it did")
	}
}
3228
func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }

// testRequestLimit sends a request carrying enough headers to exceed
// DefaultMaxHeaderBytes+4096 bytes and expects a 431 status without the
// handler ever being invoked.
func testRequestLimit(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Handlers run on a server goroutine, where t.Fatal must not be
		// called (FailNow only works from the test goroutine); Errorf
		// still fails the test.
		t.Errorf("didn't expect to get request in Handler")
	}), optQuietLog)
	req, _ := NewRequest("GET", cst.ts.URL, nil)
	const bytesPerHeader = len("header12345: val12345\r\n")
	for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
		req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
	}
	res, err := cst.c.Do(req)
	if res != nil {
		defer res.Body.Close()
	}
	if mode == http2Mode {
		// In HTTP/2 a transport error is tolerated (presumably the client
		// may refuse to send oversized headers itself); any response that
		// does arrive must be a 431.
		if err == nil && res.StatusCode != 431 {
			t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
		}
	} else {
		// HTTP/1 must always deliver a 431 response.
		if err != nil {
			t.Fatalf("Do: %v", err)
		}
		if res.StatusCode != 431 {
			t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
		}
	}
}
3264
// neverEnding is an io.Reader yielding an endless stream of its own byte value.
type neverEnding byte

// Read fills all of p with b and never returns an error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	for n < len(p) {
		p[n] = byte(b)
		n++
	}
	return n, nil
}
3273
3274 type bodyLimitReader struct {
3275 mu sync.Mutex
3276 count int
3277 limit int
3278 closed chan struct{}
3279 }
3280
3281 func (r *bodyLimitReader) Read(p []byte) (int, error) {
3282 r.mu.Lock()
3283 defer r.mu.Unlock()
3284 select {
3285 case <-r.closed:
3286 return 0, errors.New("closed")
3287 default:
3288 }
3289 if r.count > r.limit {
3290 return 0, errors.New("at limit")
3291 }
3292 r.count += len(p)
3293 for i := range p {
3294 p[i] = 'a'
3295 }
3296 return len(p), nil
3297 }
3298
3299 func (r *bodyLimitReader) Close() error {
3300 r.mu.Lock()
3301 defer r.mu.Unlock()
3302 close(r.closed)
3303 return nil
3304 }
3305
func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }

// testRequestBodyLimit checks that a handler wrapping r.Body in
// MaxBytesReader reads exactly limit bytes and then gets a *MaxBytesError
// carrying that limit, and that the client stops sending long before its
// effectively unbounded body (capped at 200*limit) runs out.
func testRequestBodyLimit(t *testing.T, mode testMode) {
	const limit = 1 << 20
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = MaxBytesReader(w, r.Body, limit)
		n, err := io.Copy(io.Discard, r.Body)
		if err == nil {
			t.Errorf("expected error from io.Copy")
		}
		if n != limit {
			t.Errorf("io.Copy = %d, want %d", n, limit)
		}
		var mbErr *MaxBytesError
		if !errors.As(err, &mbErr) {
			// mbErr is nil here; reading mbErr.Limit would panic.
			t.Errorf("expected MaxBytesError, got %T", err)
			return
		}
		if mbErr.Limit != limit {
			t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
		}
	}))

	body := &bodyLimitReader{closed: make(chan struct{}), limit: limit * 200}
	req, _ := NewRequest("POST", cst.ts.URL, body)

	// The request may fail once the server stops reading; only a successful
	// response needs its body closed.
	resp, err := cst.c.Do(req)
	if err == nil {
		resp.Body.Close()
	}

	// Wait until the client transport has closed the request body, so that
	// body.count is final.
	<-body.closed

	if body.count > limit*100 {
		t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
			limit, body.count)
	}
}
3355
3356
3357
func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }

// testClientWriteShutdown opens a raw TCP connection, half-closes its write
// side without sending anything, and expects to read nothing back before the
// server ends the connection.
func testClientWriteShutdown(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/17906")
	}
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Release the client socket on every exit path.
	defer conn.Close()
	err = conn.(*net.TCPConn).CloseWrite()
	if err != nil {
		t.Fatalf("CloseWrite: %v", err)
	}

	bs, err := io.ReadAll(conn)
	if err != nil {
		t.Errorf("ReadAll: %v", err)
	}
	got := string(bs)
	if got != "" {
		t.Errorf("read %q from server; want nothing", got)
	}
}
3382
3383
3384
3385 func TestServerBufferedChunking(t *testing.T) {
3386 conn := new(testConn)
3387 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3388 conn.closec = make(chan bool, 1)
3389 ls := &oneConnListener{conn}
3390 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3391 rw.(Flusher).Flush()
3392 rw.Write([]byte{'x'})
3393 rw.Write([]byte{'y'})
3394 rw.Write([]byte{'z'})
3395 }))
3396 <-conn.closec
3397 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3398 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3399 conn.writeBuf.Bytes())
3400 }
3401 }
3402
3403
3404
3405
3406
func TestServerGracefulClose(t *testing.T) {
	// Not parallel: each attempt calls SetRSTAvoidanceDelay.
	run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
}

// testServerGracefulClose sends a 5 MiB POST to a handler that rejects it
// with 401 without reading the body, and checks that the client can still
// read the 401 status line. runTimeSensitiveTest is given a list of RST
// avoidance delays to try (presumably retrying with the next one on error).
func testServerGracefulClose(t *testing.T, mode testMode) {
	// The request bytes are identical for every attempt, so build them once,
	// in a single allocation-friendly append rather than byte by byte.
	const bodySize = 5 << 20
	req := fmt.Appendf(nil, "POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize)
	req = append(req, strings.Repeat("x", bodySize)...)

	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			Error(w, "bye", StatusUnauthorized)
		}))
		// Shut the server down at the end of each attempt.
		defer cst.close()
		ts := cst.ts

		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		writeErr := make(chan error)
		go func() {
			_, err := conn.Write(req)
			writeErr <- err
		}()
		defer func() {
			conn.Close()
			// Closing conn unblocks any pending Write; wait for the writer
			// goroutine so it does not outlive this attempt.
			<-writeErr
		}()

		br := bufio.NewReader(conn)
		lineNum := 0
		for {
			line, err := br.ReadString('\n')
			if err == io.EOF {
				break
			}
			if err != nil {
				return fmt.Errorf("ReadLine: %v", err)
			}
			lineNum++
			if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
				t.Errorf("Response line = %q; want a 401", line)
			}
		}
		return nil
	})
}
3474
3475 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3476 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3477 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3478 if r.Method != "get" {
3479 t.Errorf(`Got method %q; want "get"`, r.Method)
3480 }
3481 }))
3482 defer cst.close()
3483 req, _ := NewRequest("get", cst.ts.URL, nil)
3484 res, err := cst.c.Do(req)
3485 if err != nil {
3486 t.Error(err)
3487 return
3488 }
3489
3490 res.Body.Close()
3491 }
3492
3493
3494
3495
3496
func TestContentLengthZero(t *testing.T) {
	run(t, testContentLengthZero, []testMode{http1Mode})
}

// testContentLengthZero sends a keep-alive GET to a handler that writes
// nothing, over a raw connection, for both HTTP/1.0 and HTTP/1.1, and checks
// that the response uses no Transfer-Encoding and a Content-Length of 0.
func testContentLengthZero(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts

	for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
		// One closure per version so the deferred Close also runs when a
		// t.Fatalf aborts the iteration, instead of leaking the connection.
		func() {
			conn, err := net.Dial("tcp", ts.Listener.Addr().String())
			if err != nil {
				t.Fatalf("error dialing: %v", err)
			}
			defer conn.Close()
			_, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
			if err != nil {
				t.Fatalf("error writing: %v", err)
			}
			req, _ := NewRequest("GET", "/", nil)
			res, err := ReadResponse(bufio.NewReader(conn), req)
			if err != nil {
				t.Fatalf("error reading response: %v", err)
			}
			if te := res.TransferEncoding; len(te) > 0 {
				t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
			}
			if cl := res.ContentLength; cl != 0 {
				t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
			}
		}()
	}
}
3526
func TestCloseNotifier(t *testing.T) {
	run(t, testCloseNotifier, []testMode{http1Mode})
}

// testCloseNotifier sends a keep-alive GET over a raw connection, closes the
// connection once the handler is running, and expects the handler's
// CloseNotify channel to fire.
func testCloseNotifier(t *testing.T, mode testMode) {
	gotReq := make(chan bool, 1)
	sawClose := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		gotReq <- true
		<-rw.(CloseNotifier).CloseNotify()
		sawClose <- true
	})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("error dialing: %v", err)
	}
	die := make(chan bool)
	go func() {
		if _, err := fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			t.Error(err)
			return
		}
		// Hold the connection open until the handler has the request.
		<-die
		conn.Close()
	}()
	for closed := false; !closed; {
		select {
		case <-gotReq:
			die <- true
		case <-sawClose:
			closed = true
		}
	}
	ts.Close()
}
3564
3565
3566
3567
3568
3569 func TestCloseNotifierPipelined(t *testing.T) {
3570 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3571 }
3572 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3573 gotReq := make(chan bool, 2)
3574 sawClose := make(chan bool, 2)
3575 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3576 gotReq <- true
3577 cc := rw.(CloseNotifier).CloseNotify()
3578 select {
3579 case <-cc:
3580 t.Error("unexpected CloseNotify")
3581 case <-time.After(100 * time.Millisecond):
3582 }
3583 sawClose <- true
3584 })).ts
3585 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3586 if err != nil {
3587 t.Fatalf("error dialing: %v", err)
3588 }
3589 diec := make(chan bool, 1)
3590 defer close(diec)
3591 go func() {
3592 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3593 _, err = io.WriteString(conn, req+req)
3594 if err != nil {
3595 t.Error(err)
3596 return
3597 }
3598 <-diec
3599 conn.Close()
3600 }()
3601 reqs := 0
3602 closes := 0
3603 for {
3604 select {
3605 case <-gotReq:
3606 reqs++
3607 if reqs > 2 {
3608 t.Fatal("too many requests")
3609 }
3610 case <-sawClose:
3611 closes++
3612 if closes > 1 {
3613 return
3614 }
3615 }
3616 }
3617 }
3618
3619 func TestCloseNotifierChanLeak(t *testing.T) {
3620 defer afterTest(t)
3621 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3622 for i := 0; i < 20; i++ {
3623 var output bytes.Buffer
3624 conn := &rwTestConn{
3625 Reader: bytes.NewReader(req),
3626 Writer: &output,
3627 closec: make(chan bool, 1),
3628 }
3629 ln := &oneConnListener{conn: conn}
3630 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3631
3632
3633
3634 _ = rw.(CloseNotifier).CloseNotify()
3635 })
3636 go Serve(ln, handler)
3637 <-conn.closec
3638 }
3639 }
3640
3641
3642
3643
3644
3645
3646
3647
3648
3649
3650 func TestHijackAfterCloseNotifier(t *testing.T) {
3651 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3652 }
3653 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3654 script := make(chan string, 2)
3655 script <- "closenotify"
3656 script <- "hijack"
3657 close(script)
3658 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3659 plan := <-script
3660 switch plan {
3661 default:
3662 panic("bogus plan; too many requests")
3663 case "closenotify":
3664 w.(CloseNotifier).CloseNotify()
3665 w.Header().Set("X-Addr", r.RemoteAddr)
3666 case "hijack":
3667 c, _, err := w.(Hijacker).Hijack()
3668 if err != nil {
3669 t.Errorf("Hijack in Handler: %v", err)
3670 return
3671 }
3672 if _, ok := c.(*net.TCPConn); !ok {
3673
3674
3675 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3676 }
3677 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3678 c.Close()
3679 return
3680 }
3681 })).ts
3682 res1, err := ts.Client().Get(ts.URL)
3683 if err != nil {
3684 log.Fatal(err)
3685 }
3686 res2, err := ts.Client().Get(ts.URL)
3687 if err != nil {
3688 log.Fatal(err)
3689 }
3690 addr1 := res1.Header.Get("X-Addr")
3691 addr2 := res2.Header.Get("X-Addr")
3692 if addr1 == "" || addr1 != addr2 {
3693 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3694 }
3695 }
3696
3697 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3698 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3699 }
3700 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3701 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3702 bodyOkay := make(chan bool, 1)
3703 gotCloseNotify := make(chan bool, 1)
3704 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3705 defer close(bodyOkay)
3706
3707 reqBody := r.Body
3708 r.Body = nil
3709
3710 gone := w.(CloseNotifier).CloseNotify()
3711 slurp, err := io.ReadAll(reqBody)
3712 if err != nil {
3713 t.Errorf("Body read: %v", err)
3714 return
3715 }
3716 if len(slurp) != len(requestBody) {
3717 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3718 return
3719 }
3720 if !bytes.Equal(slurp, requestBody) {
3721 t.Error("Backend read wrong request body.")
3722 return
3723 }
3724 bodyOkay <- true
3725 <-gone
3726 gotCloseNotify <- true
3727 })).ts
3728
3729 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3730 if err != nil {
3731 t.Fatal(err)
3732 }
3733 defer conn.Close()
3734
3735 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3736 len(requestBody), requestBody)
3737 if !<-bodyOkay {
3738
3739 return
3740 }
3741 conn.Close()
3742 <-gotCloseNotify
3743 }
3744
3745 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3746 func testOptions(t *testing.T, mode testMode) {
3747 uric := make(chan string, 2)
3748 mux := NewServeMux()
3749 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3750 uric <- r.RequestURI
3751 })
3752 ts := newClientServerTest(t, mode, mux).ts
3753
3754 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3755 if err != nil {
3756 t.Fatal(err)
3757 }
3758 defer conn.Close()
3759
3760
3761 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3762 if err != nil {
3763 t.Fatal(err)
3764 }
3765 br := bufio.NewReader(conn)
3766 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3767 if err != nil {
3768 t.Fatal(err)
3769 }
3770 if res.StatusCode != 200 {
3771 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3772 }
3773
3774
3775 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3776 if err != nil {
3777 t.Fatal(err)
3778 }
3779 res, err = ReadResponse(br, &Request{Method: "GET"})
3780 if err != nil {
3781 t.Fatal(err)
3782 }
3783 if res.StatusCode != 400 {
3784 t.Errorf("Got non-400 response to GET *: %#v", res)
3785 }
3786
3787 res, err = Get(ts.URL + "/second")
3788 if err != nil {
3789 t.Fatal(err)
3790 }
3791 res.Body.Close()
3792 if got := <-uric; got != "/second" {
3793 t.Errorf("Handler saw request for %q; want /second", got)
3794 }
3795 }
3796
3797 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3798 func testOptionsHandler(t *testing.T, mode testMode) {
3799 rc := make(chan *Request, 1)
3800
3801 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3802 rc <- r
3803 }), func(ts *httptest.Server) {
3804 ts.Config.DisableGeneralOptionsHandler = true
3805 }).ts
3806
3807 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3808 if err != nil {
3809 t.Fatal(err)
3810 }
3811 defer conn.Close()
3812
3813 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3814 if err != nil {
3815 t.Fatal(err)
3816 }
3817
3818 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3819 t.Errorf("Expected OPTIONS * request, got %v", got)
3820 }
3821 }
3822
3823
3824
3825
3826
3827
3828
3829
3830
3831
3832 func TestHeaderToWire(t *testing.T) {
3833 tests := []struct {
3834 name string
3835 handler func(ResponseWriter, *Request)
3836 check func(got, logs string) error
3837 }{
3838 {
3839 name: "write without Header",
3840 handler: func(rw ResponseWriter, r *Request) {
3841 rw.Write([]byte("hello world"))
3842 },
3843 check: func(got, logs string) error {
3844 if !strings.Contains(got, "Content-Length:") {
3845 return errors.New("no content-length")
3846 }
3847 if !strings.Contains(got, "Content-Type: text/plain") {
3848 return errors.New("no content-type")
3849 }
3850 return nil
3851 },
3852 },
3853 {
3854 name: "Header mutation before write",
3855 handler: func(rw ResponseWriter, r *Request) {
3856 h := rw.Header()
3857 h.Set("Content-Type", "some/type")
3858 rw.Write([]byte("hello world"))
3859 h.Set("Too-Late", "bogus")
3860 },
3861 check: func(got, logs string) error {
3862 if !strings.Contains(got, "Content-Length:") {
3863 return errors.New("no content-length")
3864 }
3865 if !strings.Contains(got, "Content-Type: some/type") {
3866 return errors.New("wrong content-type")
3867 }
3868 if strings.Contains(got, "Too-Late") {
3869 return errors.New("don't want too-late header")
3870 }
3871 return nil
3872 },
3873 },
3874 {
3875 name: "write then useless Header mutation",
3876 handler: func(rw ResponseWriter, r *Request) {
3877 rw.Write([]byte("hello world"))
3878 rw.Header().Set("Too-Late", "Write already wrote headers")
3879 },
3880 check: func(got, logs string) error {
3881 if strings.Contains(got, "Too-Late") {
3882 return errors.New("header appeared from after WriteHeader")
3883 }
3884 return nil
3885 },
3886 },
3887 {
3888 name: "flush then write",
3889 handler: func(rw ResponseWriter, r *Request) {
3890 rw.(Flusher).Flush()
3891 rw.Write([]byte("post-flush"))
3892 rw.Header().Set("Too-Late", "Write already wrote headers")
3893 },
3894 check: func(got, logs string) error {
3895 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3896 return errors.New("not chunked")
3897 }
3898 if strings.Contains(got, "Too-Late") {
3899 return errors.New("header appeared from after WriteHeader")
3900 }
3901 return nil
3902 },
3903 },
3904 {
3905 name: "header then flush",
3906 handler: func(rw ResponseWriter, r *Request) {
3907 rw.Header().Set("Content-Type", "some/type")
3908 rw.(Flusher).Flush()
3909 rw.Write([]byte("post-flush"))
3910 rw.Header().Set("Too-Late", "Write already wrote headers")
3911 },
3912 check: func(got, logs string) error {
3913 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3914 return errors.New("not chunked")
3915 }
3916 if strings.Contains(got, "Too-Late") {
3917 return errors.New("header appeared from after WriteHeader")
3918 }
3919 if !strings.Contains(got, "Content-Type: some/type") {
3920 return errors.New("wrong content-type")
3921 }
3922 return nil
3923 },
3924 },
3925 {
3926 name: "sniff-on-first-write content-type",
3927 handler: func(rw ResponseWriter, r *Request) {
3928 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3929 rw.Header().Set("Content-Type", "x/wrong")
3930 },
3931 check: func(got, logs string) error {
3932 if !strings.Contains(got, "Content-Type: text/html") {
3933 return errors.New("wrong content-type; want html")
3934 }
3935 return nil
3936 },
3937 },
3938 {
3939 name: "explicit content-type wins",
3940 handler: func(rw ResponseWriter, r *Request) {
3941 rw.Header().Set("Content-Type", "some/type")
3942 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3943 },
3944 check: func(got, logs string) error {
3945 if !strings.Contains(got, "Content-Type: some/type") {
3946 return errors.New("wrong content-type; want html")
3947 }
3948 return nil
3949 },
3950 },
3951 {
3952 name: "empty handler",
3953 handler: func(rw ResponseWriter, r *Request) {
3954 },
3955 check: func(got, logs string) error {
3956 if !strings.Contains(got, "Content-Length: 0") {
3957 return errors.New("want 0 content-length")
3958 }
3959 return nil
3960 },
3961 },
3962 {
3963 name: "only Header, no write",
3964 handler: func(rw ResponseWriter, r *Request) {
3965 rw.Header().Set("Some-Header", "some-value")
3966 },
3967 check: func(got, logs string) error {
3968 if !strings.Contains(got, "Some-Header") {
3969 return errors.New("didn't get header")
3970 }
3971 return nil
3972 },
3973 },
3974 {
3975 name: "WriteHeader call",
3976 handler: func(rw ResponseWriter, r *Request) {
3977 rw.WriteHeader(404)
3978 rw.Header().Set("Too-Late", "some-value")
3979 },
3980 check: func(got, logs string) error {
3981 if !strings.Contains(got, "404") {
3982 return errors.New("wrong status")
3983 }
3984 if strings.Contains(got, "Too-Late") {
3985 return errors.New("shouldn't have seen Too-Late")
3986 }
3987 return nil
3988 },
3989 },
3990 }
3991 for _, tc := range tests {
3992 ht := newHandlerTest(HandlerFunc(tc.handler))
3993 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3994 logs := ht.logbuf.String()
3995 if err := tc.check(got, logs); err != nil {
3996 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3997 }
3998 }
3999 }
4000
4001 type errorListener struct {
4002 errs []error
4003 }
4004
4005 func (l *errorListener) Accept() (c net.Conn, err error) {
4006 if len(l.errs) == 0 {
4007 return nil, io.EOF
4008 }
4009 err = l.errs[0]
4010 l.errs = l.errs[1:]
4011 return
4012 }
4013
4014 func (l *errorListener) Close() error {
4015 return nil
4016 }
4017
4018 func (l *errorListener) Addr() net.Addr {
4019 return dummyAddr("test-address")
4020 }
4021
4022 func TestAcceptMaxFds(t *testing.T) {
4023 setParallel(t)
4024
4025 ln := &errorListener{[]error{
4026 &net.OpError{
4027 Op: "accept",
4028 Err: syscall.EMFILE,
4029 }}}
4030 server := &Server{
4031 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
4032 ErrorLog: log.New(io.Discard, "", 0),
4033 }
4034 err := server.Serve(ln)
4035 if err != io.EOF {
4036 t.Errorf("got error %v, want EOF", err)
4037 }
4038 }
4039
4040 func TestWriteAfterHijack(t *testing.T) {
4041 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4042 var buf strings.Builder
4043 wrotec := make(chan bool, 1)
4044 conn := &rwTestConn{
4045 Reader: bytes.NewReader(req),
4046 Writer: &buf,
4047 closec: make(chan bool, 1),
4048 }
4049 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4050 conn, bufrw, err := rw.(Hijacker).Hijack()
4051 if err != nil {
4052 t.Error(err)
4053 return
4054 }
4055 go func() {
4056 bufrw.Write([]byte("[hijack-to-bufw]"))
4057 bufrw.Flush()
4058 conn.Write([]byte("[hijack-to-conn]"))
4059 conn.Close()
4060 wrotec <- true
4061 }()
4062 })
4063 ln := &oneConnListener{conn: conn}
4064 go Serve(ln, handler)
4065 <-conn.closec
4066 <-wrotec
4067 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4068 t.Errorf("wrote %q; want %q", g, w)
4069 }
4070 }
4071
4072 func TestDoubleHijack(t *testing.T) {
4073 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4074 var buf bytes.Buffer
4075 conn := &rwTestConn{
4076 Reader: bytes.NewReader(req),
4077 Writer: &buf,
4078 closec: make(chan bool, 1),
4079 }
4080 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4081 conn, _, err := rw.(Hijacker).Hijack()
4082 if err != nil {
4083 t.Error(err)
4084 return
4085 }
4086 _, _, err = rw.(Hijacker).Hijack()
4087 if err == nil {
4088 t.Errorf("got err = nil; want err != nil")
4089 }
4090 conn.Close()
4091 })
4092 ln := &oneConnListener{conn: conn}
4093 go Serve(ln, handler)
4094 <-conn.closec
4095 }
4096
4097
4098
4099
4100
4101
4102
4103 func TestHTTP10ConnectionHeader(t *testing.T) {
4104 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4105 }
4106 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4107 mux := NewServeMux()
4108 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4109 ts := newClientServerTest(t, mode, mux).ts
4110
4111
4112 tests := []struct {
4113 req string
4114 expect []string
4115 }{
4116 {
4117 req: "GET / HTTP/1.0\r\n\r\n",
4118 expect: nil,
4119 },
4120 {
4121 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4122 expect: nil,
4123 },
4124 {
4125 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4126 expect: []string{"keep-alive"},
4127 },
4128 }
4129
4130 for _, tt := range tests {
4131 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4132 if err != nil {
4133 t.Fatal("dial err:", err)
4134 }
4135
4136 _, err = fmt.Fprint(conn, tt.req)
4137 if err != nil {
4138 t.Fatal("conn write err:", err)
4139 }
4140
4141 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4142 if err != nil {
4143 t.Fatal("ReadResponse err:", err)
4144 }
4145 conn.Close()
4146 resp.Body.Close()
4147
4148 got := resp.Header["Connection"]
4149 if !slices.Equal(got, tt.expect) {
4150 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4151 }
4152 }
4153 }
4154
4155
4156 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4157 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4158 pr, pw := io.Pipe()
4159 const size = 3 << 20
4160 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4161 rw.Header().Set("Content-Type", "text/plain")
4162 done := make(chan bool)
4163 go func() {
4164 io.Copy(rw, pr)
4165 close(done)
4166 }()
4167 time.Sleep(25 * time.Millisecond)
4168 n, err := io.Copy(io.Discard, req.Body)
4169 if err != nil {
4170 t.Errorf("handler Copy: %v", err)
4171 return
4172 }
4173 if n != size {
4174 t.Errorf("handler Copy = %d; want %d", n, size)
4175 }
4176 pw.Write([]byte("hi"))
4177 pw.Close()
4178 <-done
4179 }))
4180
4181 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4182 if err != nil {
4183 t.Fatal(err)
4184 }
4185 res, err := cst.c.Do(req)
4186 if err != nil {
4187 t.Fatal(err)
4188 }
4189 all, err := io.ReadAll(res.Body)
4190 if err != nil {
4191 t.Fatal(err)
4192 }
4193 res.Body.Close()
4194 if string(all) != "hi" {
4195 t.Errorf("Body = %q; want hi", all)
4196 }
4197 }
4198
4199
4200 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4201 for _, code := range []int{StatusNotModified, StatusNoContent} {
4202 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4203 if r.URL.Path == "/header" {
4204 w.Header().Set("Content-Length", "123")
4205 }
4206 w.WriteHeader(code)
4207 if r.URL.Path == "/more" {
4208 w.Write([]byte("stuff"))
4209 }
4210 }))
4211 for _, req := range []string{
4212 "GET / HTTP/1.0",
4213 "GET /header HTTP/1.0",
4214 "GET /more HTTP/1.0",
4215 "GET / HTTP/1.1\nHost: foo",
4216 "GET /header HTTP/1.1\nHost: foo",
4217 "GET /more HTTP/1.1\nHost: foo",
4218 } {
4219 got := ht.rawResponse(req)
4220 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4221 if !strings.Contains(got, wantStatus) {
4222 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4223 } else if strings.Contains(got, "Content-Length") {
4224 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4225 } else if strings.Contains(got, "stuff") {
4226 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4227 }
4228 }
4229 }
4230 }
4231
4232 func TestContentTypeOkayOn204(t *testing.T) {
4233 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4234 w.Header().Set("Content-Length", "123")
4235 w.Header().Set("Content-Type", "foo/bar")
4236 w.WriteHeader(204)
4237 }))
4238 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4239 if !strings.Contains(got, "Content-Type: foo/bar") {
4240 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4241 }
4242 if strings.Contains(got, "Content-Length: 123") {
4243 t.Errorf("Response = %q; don't want a Content-Length", got)
4244 }
4245 }
4246
4247
4248
4249
4250
4251
4252
// TestTransportAndServerSharedBodyRace runs a proxy whose handler forwards
// its inbound request body, unchanged, as the body of an outbound request to
// a backend. The proxy then cancels that outbound request partway through the
// response. The proxy's Server and its Transport therefore both touch the
// same req.Body. This is presumably a regression test for a data race on
// that shared body, so it is most meaningful under -race.
func TestTransportAndServerSharedBodyRace(t *testing.T) {
	run(t, testTransportAndServerSharedBodyRace, testNotParallel)
}
func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
	// Each attempt uses the next, longer RST-avoidance delay. The helper
	// presumably retries with the next duration when an attempt returns a
	// non-nil error (NOTE(review): confirm against runTimeSensitiveTest).
	runTimeSensitiveTest(t, []time.Duration{
		1 * time.Millisecond,
		5 * time.Millisecond,
		10 * time.Millisecond,
		50 * time.Millisecond,
		100 * time.Millisecond,
		500 * time.Millisecond,
		time.Second,
		5 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		SetRSTAvoidanceDelay(t, timeout)
		t.Logf("set RST avoidance delay to %v", timeout)

		const bodySize = 1 << 20

		// wg tracks backend handler invocations so the deferred cleanup can
		// wait for them before closing the backend.
		var wg sync.WaitGroup
		backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// NOTE(review): Add runs inside the handler rather than before
			// it starts. This appears to be ordered before the deferred
			// wg.Wait only because the proxy's Do waits for this handler's
			// response before the outer request can finish. Confirm.
			wg.Add(1)
			defer wg.Done()

			// Echo up to bodySize bytes of the request body back as the
			// response. Then keep the handler alive until the request
			// context is done, e.g. when the proxy cancels.
			n, err := io.CopyN(rw, req.Body, bodySize)
			t.Logf("backend CopyN: %v, %v", n, err)
			<-req.Context().Done()
		}))
		// Wait for in-flight backend handlers, then shut the backend down.
		defer func() {
			wg.Wait()
			backend.close()
		}()

		var proxy *clientServerTest
		proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
			// Reuse the inbound body directly as the outbound body.
			req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
			req2.ContentLength = bodySize
			cancel := make(chan struct{})
			req2.Cancel = cancel

			bresp, err := proxy.c.Do(req2)
			if err != nil {
				t.Errorf("Proxy outbound request: %v", err)
				return
			}
			// Read only half the echoed body, so the outbound request is
			// still in flight when it is cancelled below.
			_, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
			if err != nil {
				t.Errorf("Proxy copy error: %v", err)
				return
			}
			t.Cleanup(func() { bresp.Body.Close() })

			// Cancel the outbound request mid-stream. HTTP/2 mode uses the
			// Request.Cancel channel; HTTP/1 mode uses
			// Transport.CancelRequest.
			if mode == http2Mode {
				close(cancel)
			} else {
				proxy.c.Transport.(*Transport).CancelRequest(req2)
			}
			rw.Write([]byte("OK"))
		}))
		defer proxy.close()

		// Drive the proxy with a bodySize-byte request body of 'a's.
		req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
		res, err := proxy.c.Do(req)
		if err != nil {
			return fmt.Errorf("original request: %v", err)
		}
		res.Body.Close()
		return nil
	})
}
4342
4343
4344
4345
4346 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4347 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4348 }
4349 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4350 if testing.Short() {
4351 t.Skip("skipping in -short mode")
4352 }
4353
4354 readErrCh := make(chan error, 1)
4355 errCh := make(chan error, 2)
4356
4357 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4358 go func(body io.Reader) {
4359 _, err := body.Read(make([]byte, 100))
4360 readErrCh <- err
4361 }(req.Body)
4362 time.Sleep(500 * time.Millisecond)
4363 })).ts
4364
4365 closeConn := make(chan bool)
4366 defer close(closeConn)
4367 go func() {
4368 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4369 if err != nil {
4370 errCh <- err
4371 return
4372 }
4373 defer conn.Close()
4374 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4375 if err != nil {
4376 errCh <- err
4377 return
4378 }
4379
4380
4381 <-closeConn
4382 }()
4383 select {
4384 case err := <-readErrCh:
4385 if err == nil {
4386 t.Error("Read was nil. Expected error.")
4387 }
4388 case err := <-errCh:
4389 t.Error(err)
4390 }
4391 }
4392
4393
4394 func TestResponseWriterWriteString(t *testing.T) {
4395 okc := make(chan bool, 1)
4396 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4397 _, ok := w.(io.StringWriter)
4398 okc <- ok
4399 }))
4400 ht.rawResponse("GET / HTTP/1.0")
4401 select {
4402 case ok := <-okc:
4403 if !ok {
4404 t.Error("ResponseWriter did not implement io.StringWriter")
4405 }
4406 default:
4407 t.Error("handler was never called")
4408 }
4409 }
4410
// TestServerConnState checks the sequence of ConnState transitions that the
// server reports for several connection lifecycles:
//   - keep-alive reuse followed by a server-requested close
//   - keep-alive reuse followed by a client-requested close
//   - hijacking, including a handler that panics after hijacking
//   - a client that connects and immediately hangs up
//   - a malformed request
//   - a single request followed by the client closing the idle connection
func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
func testServerConnState(t *testing.T, mode testMode) {
	// handler maps each request path to the behavior under test for it.
	handler := map[string]func(w ResponseWriter, r *Request){
		"/": func(w ResponseWriter, r *Request) {
			fmt.Fprintf(w, "Hello.")
		},
		"/close": func(w ResponseWriter, r *Request) {
			w.Header().Set("Connection", "close")
			fmt.Fprintf(w, "Hello.")
		},
		"/hijack": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
		},
		"/hijack-panic": func(w ResponseWriter, r *Request) {
			c, _, _ := w.(Hijacker).Hijack()
			c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
			c.Close()
			panic("intentional panic")
		},
	}

	// stateLog accumulates the states reported for the single connection
	// that the current wantLog call is tracking.
	type stateLog struct {
		active   net.Conn        // tracked conn; set on its StateNew
		got      []ConnState     // states observed so far
		want     []ConnState     // expected sequence
		complete chan<- struct{} // closed once got reaches len(want) or stops matching want
	}
	// activeLog holds the current *stateLog. Receiving from it takes
	// exclusive ownership and sending it back releases it, which serializes
	// the ConnState hook with wantLog.
	activeLog := make(chan *stateLog, 1)

	// wantLog installs a fresh stateLog expecting want and runs doRequests.
	// It waits until the ConnState hook signals completion, then compares
	// the observed state sequence with want.
	wantLog := func(doRequests func(), want ...ConnState) {
		t.Helper()
		complete := make(chan struct{})
		activeLog <- &stateLog{want: want, complete: complete}

		doRequests()

		<-complete
		sl := <-activeLog
		if !slices.Equal(sl.got, sl.want) {
			t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
		}
		// activeLog is left empty here. Until the next wantLog call (or the
		// deferred cleanup) refills it, any ConnState callback blocks on
		// receiving from it.
	}

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		handler[r.URL.Path](w, r)
	}), func(ts *httptest.Server) {
		// Silence server error logs, e.g. from the intentional panic.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
		ts.Config.ConnState = func(c net.Conn, state ConnState) {
			if c == nil {
				t.Errorf("nil conn seen in state %s", state)
				return
			}
			// Take ownership of the current log.
			sl := <-activeLog
			if sl.active == nil && state == StateNew {
				// The first new conn after wantLog installed this log is
				// the one tracked.
				sl.active = c
			} else if sl.active != c {
				t.Errorf("unexpected conn in state %s", state)
				activeLog <- sl
				return
			}
			sl.got = append(sl.got, state)
			// Wake wantLog once enough states have arrived, or as soon as
			// the observed sequence diverges from want. The length test
			// short-circuits, so the slice expression below stays in range.
			if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
				close(sl.complete)
				sl.complete = nil
			}
			activeLog <- sl
		}
	}).ts
	defer func() {
		// Refill activeLog so that ConnState callbacks arriving after the
		// last wantLog do not block forever on the empty channel.
		activeLog <- &stateLog{}
		ts.Close()
	}()

	c := ts.Client()

	// mustGet issues a GET with optional header name/value pairs and drains
	// the response body, reporting (not aborting on) fetch or read errors.
	mustGet := func(url string, headers ...string) {
		t.Helper()
		req, err := NewRequest("GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		for len(headers) > 0 {
			req.Header.Add(headers[0], headers[1])
			headers = headers[2:]
		}
		res, err := c.Do(req)
		if err != nil {
			t.Errorf("Error fetching %s: %v", url, err)
			return
		}
		_, err = io.ReadAll(res.Body)
		defer res.Body.Close()
		if err != nil {
			t.Errorf("Error reading %s: %v", url, err)
		}
	}

	// Keep-alive reuse, then the handler asks for Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL + "/close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Keep-alive reuse, then the client sends Connection: close.
	wantLog(func() {
		mustGet(ts.URL + "/")
		mustGet(ts.URL+"/", "Connection", "close")
	}, StateNew, StateActive, StateIdle, StateActive, StateClosed)

	// Tracking ends at StateHijacked; the server reports no close.
	wantLog(func() {
		mustGet(ts.URL + "/hijack")
	}, StateNew, StateActive, StateHijacked)

	// A panic after hijacking adds no further states.
	wantLog(func() {
		mustGet(ts.URL + "/hijack-panic")
	}, StateNew, StateActive, StateHijacked)

	// Connect and hang up without sending anything.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateClosed)

	// A malformed request. The 1-byte Read presumably waits for the server's
	// error reply before the client closes.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		c.Read(make([]byte, 1))
		c.Close()
	}, StateNew, StateActive, StateClosed)

	// One full request, then the client closes the now-idle conn.
	wantLog(func() {
		c, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
			t.Fatal(err)
		}
		res, err := ReadResponse(bufio.NewReader(c), nil)
		if err != nil {
			t.Fatal(err)
		}
		if _, err := io.Copy(io.Discard, res.Body); err != nil {
			t.Fatal(err)
		}
		c.Close()
	}, StateNew, StateActive, StateIdle, StateClosed)
}
4573
4574 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4575 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4576 }
4577 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4578 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4579 }), func(ts *httptest.Server) {
4580 ts.Config.SetKeepAlivesEnabled(false)
4581 }).ts
4582 res, err := ts.Client().Get(ts.URL)
4583 if err != nil {
4584 t.Fatal(err)
4585 }
4586 defer res.Body.Close()
4587 if !res.Close {
4588 t.Errorf("Body.Close == false; want true")
4589 }
4590 }
4591
4592
4593 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4594 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4595 var n int32
4596 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4597 atomic.AddInt32(&n, 1)
4598 }), optQuietLog)
4599 var wg sync.WaitGroup
4600 const reqs = 20
4601 for i := 0; i < reqs; i++ {
4602 wg.Add(1)
4603 go func() {
4604 defer wg.Done()
4605 res, err := cst.c.Get(cst.ts.URL)
4606 if err != nil {
4607
4608
4609 time.Sleep(10 * time.Millisecond)
4610 res, err = cst.c.Get(cst.ts.URL)
4611 if err != nil {
4612 t.Error(err)
4613 return
4614 }
4615 }
4616 defer res.Body.Close()
4617 _, err = io.Copy(io.Discard, res.Body)
4618 if err != nil {
4619 t.Error(err)
4620 return
4621 }
4622 }()
4623 }
4624 wg.Wait()
4625 if got := atomic.LoadInt32(&n); got != reqs {
4626 t.Errorf("handler ran %d times; want %d", got, reqs)
4627 }
4628 }
4629
4630 func TestServerConnStateNew(t *testing.T) {
4631 sawNew := false
4632 srv := &Server{
4633 ConnState: func(c net.Conn, state ConnState) {
4634 if state == StateNew {
4635 sawNew = true
4636 }
4637 },
4638 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4639 }
4640 srv.Serve(&oneConnListener{
4641 conn: &rwTestConn{
4642 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4643 Writer: io.Discard,
4644 },
4645 })
4646 if !sawNew {
4647 t.Error("StateNew not seen")
4648 }
4649 }
4650
// closeWriteTestConn is an rwTestConn that also has a CloseWrite method. It
// records whether that method was called.
type closeWriteTestConn struct {
	rwTestConn
	didCloseWrite bool // set to true by CloseWrite
}

// CloseWrite records that it was called and reports success. It does not
// shut anything down.
func (c *closeWriteTestConn) CloseWrite() error {
	c.didCloseWrite = true
	return nil
}
4660
4661 func TestCloseWrite(t *testing.T) {
4662 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4663
4664 var srv Server
4665 var testConn closeWriteTestConn
4666 c := ExportServerNewConn(&srv, &testConn)
4667 ExportCloseWriteAndWait(c)
4668 if !testConn.didCloseWrite {
4669 t.Error("didn't see CloseWrite call")
4670 }
4671 }
4672
4673
4674
4675
4676
4677
4678
4679
4680 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4681 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4682 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4683 io.WriteString(w, "Hello, ")
4684 w.(Flusher).Flush()
4685 conn, buf, _ := w.(Hijacker).Hijack()
4686 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4687 if err := buf.Flush(); err != nil {
4688 t.Error(err)
4689 }
4690 if err := conn.Close(); err != nil {
4691 t.Error(err)
4692 }
4693 })).ts
4694 res, err := Get(ts.URL)
4695 if err != nil {
4696 t.Fatal(err)
4697 }
4698 defer res.Body.Close()
4699 all, err := io.ReadAll(res.Body)
4700 if err != nil {
4701 t.Fatal(err)
4702 }
4703 if want := "Hello, world!"; string(all) != want {
4704 t.Errorf("Got %q; want %q", all, want)
4705 }
4706 }
4707
4708
4709
4710
4711
4712
4713
4714 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4715 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4716 }
// testServerKeepAliveAfterWriteError checks that a connection whose response
// write failed is not reused. The handler sleeps past the server's
// WriteTimeout before flushing, so every request should fail on the client
// side. Each of the numReq sequential requests should also arrive on a
// distinct connection, as observed through distinct RemoteAddr values.
func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in -short mode")
	}
	const numReq = 3
	// Buffered for numReq so the handler never blocks reporting its peer.
	addrc := make(chan string, numReq)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		addrc <- r.RemoteAddr
		// Outlast the 250ms WriteTimeout configured below before writing.
		time.Sleep(500 * time.Millisecond)
		w.(Flusher).Flush()
	}), func(ts *httptest.Server) {
		ts.Config.WriteTimeout = 250 * time.Millisecond
	}).ts

	// Issue the requests one after another in the background and report
	// each result on errc. errc is closed once all numReq have completed.
	errc := make(chan error, numReq)
	go func() {
		defer close(errc)
		for i := 0; i < numReq; i++ {
			res, err := Get(ts.URL)
			if res != nil {
				res.Body.Close()
			}
			errc <- err
		}
	}()

	// Record each distinct client address seen by the handler, and count
	// successful requests, until the request goroutine finishes.
	addrSeen := map[string]bool{}
	numOkay := 0
	for {
		select {
		case v := <-addrc:
			addrSeen[v] = true
		case err, ok := <-errc:
			if !ok {
				// All requests are done: none may have shared a connection,
				// and none may have succeeded.
				if len(addrSeen) != numReq {
					t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
				}
				if numOkay != 0 {
					t.Errorf("got %d successful client requests; want 0", numOkay)
				}
				return
			}
			if err == nil {
				numOkay++
			}
		}
	}
}
4765
4766
4767
// TestNoContentLengthIfTransferEncoding runs testNoContentLengthIfTransferEncoding in HTTP/1 mode.
func TestNoContentLengthIfTransferEncoding(t *testing.T) { run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode}) }
// testNoContentLengthIfTransferEncoding checks that when a handler sets its
// own Transfer-Encoding header, the server adds neither Content-Length nor
// Content-Type to the response headers it writes.
func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "foo")
		io.WriteString(w, "<html>")
	})).ts

	// Speak raw HTTP/1.1 so the response headers can be inspected verbatim.
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer conn.Close()
	if _, err := io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
		t.Fatal(err)
	}

	// Gather the status line and header lines, stopping at the first blank line.
	var hdrs strings.Builder
	sc := bufio.NewScanner(conn)
	for sc.Scan() {
		line := sc.Text()
		if strings.TrimSpace(line) == "" {
			break
		}
		hdrs.WriteString(line)
		hdrs.WriteByte('\n')
	}
	if err := sc.Err(); err != nil {
		t.Fatal(err)
	}

	got := hdrs.String()
	if strings.Contains(got, "Content-Length") {
		t.Errorf("Unexpected Content-Length in response headers: %s", got)
	}
	if strings.Contains(got, "Content-Type") {
		t.Errorf("Unexpected Content-Type in response headers: %s", got)
	}
}
4803
4804
4805
// TestTolerateCRLFBeforeRequestLine checks that an extra CRLF pair between
// two requests on one connection is skipped, so both requests are served.
func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
	req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
		"\r\n\r\n" +
		"GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(ResponseWriter, *Request) {
		served++
	}))
	// The conn signals closec once the server closes it, after both handlers ran.
	<-conn.closec
	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
		t.Logf("Res: %s", out.Bytes())
	}
}
4827
// TestIssue13893_Expect100 checks that the server leaves the "Expect" header
// visible to handlers in Request.Header (golang.org/issue/13893).
func TestIssue13893_Expect100(t *testing.T) {
	// A single request carrying "Expect: 100-continue", served over a fake
	// single-connection listener.
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorld

`)
	var out bytes.Buffer
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		if _, present := r.Header["Expect"]; !present {
			t.Error("Expect header should not be filtered out")
		}
	})
	go Serve(&oneConnListener{conn: conn}, handler)
	<-conn.closec
}
4854
// TestIssue11549_Expect100 sends three requests back to back on one
// connection, the first two with "Expect: 100-continue". The first handler
// reads its body and the second does not. The server must serve exactly
// those two requests, answer with "Connection: close", and never dispatch
// the third request (/should-be-ignored).
func TestIssue11549_Expect100(t *testing.T) {
	req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

HelloWorldPUT /noreadbody HTTP/1.1
User-Agent: PycURL/7.22.0
Host: 127.0.0.1:9000
Accept: */*
Expect: 100-continue
Content-Length: 10

GET /should-be-ignored HTTP/1.1
Host: foo

`)
	var out strings.Builder
	conn := &rwTestConn{
		Reader: bytes.NewReader(req),
		Writer: &out,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(w ResponseWriter, r *Request) {
		served++
		if r.URL.Path == "/readbody" {
			io.ReadAll(r.Body)
		}
		io.WriteString(w, "Hello world!")
	}))
	<-conn.closec
	if served != 2 {
		t.Errorf("num requests = %d; want 2", served)
	}
	if resp := out.String(); !strings.Contains(resp, "Connection: close\r\n") {
		t.Errorf("expected 'Connection: close' in response; got: %s", resp)
	}
}
4897
4898
4899
// TestHandlerFinishSkipBigContentLengthRead checks that when a handler
// ignores a request body with a huge declared Content-Length, the server
// reads no more of the connection's input after the handler returns.
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
	setParallel(t)
	conn := newTestConn()
	conn.readBuf.WriteString(
		"POST / HTTP/1.1\r\n" +
			"Host: test\r\n" +
			"Content-Length: 9999999999\r\n" +
			"\r\n" + strings.Repeat("a", 1<<20))

	// Snapshot how much input remains unread while the handler runs.
	var unreadInHandler int
	go Serve(&oneConnListener{conn}, HandlerFunc(func(rw ResponseWriter, req *Request) {
		unreadInHandler = conn.readBuf.Len()
		rw.WriteHeader(404)
	}))
	<-conn.closec

	if unreadAfter := conn.readBuf.Len(); unreadAfter != unreadInHandler {
		t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", unreadInHandler, unreadAfter)
	}
}
4922
// TestHandlerSetsBodyNil runs testHandlerSetsBodyNil in the default modes.
func TestHandlerSetsBodyNil(t *testing.T) {
	run(t, testHandlerSetsBodyNil)
}
// testHandlerSetsBodyNil checks that a handler setting r.Body to nil does not
// stop the server from reusing the connection. The handler echoes
// RemoteAddr, so two sequential requests must report the same address.
func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		r.Body = nil
		fmt.Fprintf(w, "%v", r.RemoteAddr)
	}))
	fetch := func() string {
		res, err := cst.c.Get(cst.ts.URL)
		if err != nil {
			t.Fatal(err)
		}
		defer res.Body.Close()
		body, err := io.ReadAll(res.Body)
		if err != nil {
			t.Fatal(err)
		}
		return string(body)
	}
	first := fetch()
	second := fetch()
	if first != second {
		t.Errorf("Failed to reuse connections between requests: %v vs %v", first, second)
	}
}
4946
4947
4948
// TestServerValidatesHostHeader checks, for each request line and Host header
// combination in the table, the response status the server produces. The
// table covers missing, empty, well-formed, malformed, and duplicated Host
// values, plus the protocol versions the server refuses.
func TestServerValidatesHostHeader(t *testing.T) {
	tests := []struct {
		proto string
		host  string
		want  int
	}{
		// HTTP/0.9 is refused outright.
		{"HTTP/0.9", "", 505},

		// HTTP/1.1 requires a Host header. Its value may be empty, but it must
		// be well formed and must not be repeated.
		{"HTTP/1.1", "", 400},
		{"HTTP/1.1", "Host: \r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: foo.com\r\n", 200},
		{"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
		{"HTTP/1.1", "Host: foo.com:80\r\n", 200},
		{"HTTP/1.1", "Host: ::1\r\n", 200},
		{"HTTP/1.1", "Host: [::1]\r\n", 200},
		{"HTTP/1.1", "Host: [::1]:80\r\n", 200},
		{"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
		{"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
		{"HTTP/1.1", "Host: \x06\r\n", 400},
		{"HTTP/1.1", "Host: \xff\r\n", 400},
		{"HTTP/1.1", "Host: {\r\n", 400},
		{"HTTP/1.1", "Host: }\r\n", 400},
		{"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},

		// HTTP/1.0 may omit Host, but a Host it does send is validated too.
		{"HTTP/1.0", "", 200},
		{"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
		{"HTTP/1.0", "Host: \xff\r\n", 400},

		// The HTTP/2 connection preface is accepted without a Host, so a
		// handler can take over the connection.
		{"PRI * HTTP/2.0", "", 200},

		// CONNECT requests are also accepted without a Host header.
		{"CONNECT golang.org:443 HTTP/1.1", "", 200},

		// Other HTTP/2 and later request lines are refused.
		{"PRI / HTTP/2.0", "", 505},
		{"GET / HTTP/2.0", "", 505},
		{"GET / HTTP/3.0", "", 505},
	}
	for _, tt := range tests {
		conn := newTestConn()
		// A proto starting with "HTTP/" is only a version and gets a GET
		// request line. Any other proto is a full request line used verbatim.
		methodTarget := "GET / "
		if !strings.HasPrefix(tt.proto, "HTTP/") {
			methodTarget = ""
		}
		io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res is nil here.
			t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
		}
	}
}
5016
// TestServerHandlersCanHandleH2PRI runs testServerHandlersCanHandleH2PRI in HTTP/1 mode.
func TestServerHandlersCanHandleH2PRI(t *testing.T) { run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode}) }
// testServerHandlersCanHandleH2PRI checks that an HTTP/2 connection preface
// ("PRI * HTTP/2.0" followed by "SM") reaches the handler as a PRI request
// with Request.Close set. The handler must also be able to hijack the
// connection, read the rest of the preface, and write its own reply.
func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
	const upgradeResponse = "upgrade here"
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		conn, br, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()
		if r.Method != "PRI" || r.RequestURI != "*" {
			t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
			return
		}
		if !r.Close {
			t.Errorf("Request.Close = false; want true")
		}
		// The remainder of the preface is still unread in the hijacked buffer.
		const want = "SM\r\n\r\n"
		buf := make([]byte, len(want))
		n, err := io.ReadFull(br, buf)
		if err != nil || string(buf[:n]) != want {
			t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
			return
		}
		io.WriteString(conn, upgradeResponse)
	})).ts

	c, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	if _, err := io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"); err != nil {
		t.Fatal(err)
	}
	slurp, err := io.ReadAll(c)
	if err != nil {
		t.Fatal(err)
	}
	if string(slurp) != upgradeResponse {
		t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
	}
}
5060
5061
5062
// TestServerValidatesHeaders sends a GET request carrying each header block
// in the table and checks the response status the server returns. The table
// covers valid headers, malformed header names, invalid values, oversized
// headers, and invalid Content-Length.
func TestServerValidatesHeaders(t *testing.T) {
	setParallel(t)
	tests := []struct {
		header string
		want   int
	}{
		{"", 200},
		{"Foo: bar\r\n", 200},
		{"X-Foo: bar\r\n", 200},
		{"Foo: a space\r\n", 200},

		{"A space: foo\r\n", 400},                            // space in header name
		{"foo\xffbar: foo\r\n", 400},                         // non-ASCII byte in header name
		{"foo\x00bar: foo\r\n", 400},                         // NUL in header name
		{"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431}, // header too large

		// Whitespace between a header name and its colon is rejected.
		{"Foo : bar\r\n", 400},
		{"Foo\t: bar\r\n", 400},

		// An empty header name is rejected.
		{": empty key\r\n", 400},

		// An invalid Content-Length is rejected whether or not a
		// Transfer-Encoding header is also present.
		{"Content-Length: notdigits\r\n", 400},
		{"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},

		{"foo: foo foo\r\n", 200},    // space in value is fine
		{"foo: foo\tfoo\r\n", 200},   // tab in value is fine
		{"foo: foo\x00foo\r\n", 400}, // NUL in value is rejected
		{"foo: foo\x7ffoo\r\n", 400}, // DEL in value is rejected
		{"foo: foo\xfffoo\r\n", 200}, // high (non-ASCII) byte in value is accepted
	}
	for _, tt := range tests {
		conn := newTestConn()
		io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")

		ln := &oneConnListener{conn}
		srv := Server{
			ErrorLog: quietLog,
			Handler:  HandlerFunc(func(ResponseWriter, *Request) {}),
		}
		go srv.Serve(ln)
		<-conn.closec
		res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
		if err != nil {
			// Report the parse error itself; res is nil here.
			t.Errorf("For %q, ReadResponse: %v", tt.header, err)
			continue
		}
		if res.StatusCode != tt.want {
			t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
		}
	}
}
5120
// TestServerRequestContextCancel_ServeHTTPDone runs its test in the default modes.
func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { run(t, testServerRequestContextCancel_ServeHTTPDone) }
// testServerRequestContextCancel_ServeHTTPDone checks that a request's
// context is not done while ServeHTTP runs, and is done once the response
// has been received.
func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
	// isDone reports, without blocking, whether ctx's Done channel is closed.
	isDone := func(ctx context.Context) bool {
		select {
		case <-ctx.Done():
			return true
		default:
			return false
		}
	}

	ctxc := make(chan context.Context, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		reqCtx := r.Context()
		if isDone(reqCtx) {
			t.Error("should not be Done in ServeHTTP")
		}
		ctxc <- reqCtx
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	if !isDone(<-ctxc) {
		t.Error("context should be done after ServeHTTP completes")
	}
}
5147
5148
5149
5150
5151
// TestServerRequestContextCancel_ConnClose runs its test in HTTP/1 mode.
func TestServerRequestContextCancel_ConnClose(t *testing.T) { run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode}) }
// testServerRequestContextCancel_ConnClose checks that closing the client
// connection while a handler is running cancels that request's context.
func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
	started := make(chan struct{})
	finished := make(chan struct{})
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		close(started)
		<-r.Context().Done()
		close(finished)
	})).ts
	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()
	io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
	<-started
	// Drop the connection mid-request. The handler can only finish if its
	// context gets canceled as a result.
	conn.Close()
	<-finished
}
5173
// TestServerContext_ServerContextKey runs its test in the default modes.
func TestServerContext_ServerContextKey(t *testing.T) { run(t, testServerContext_ServerContextKey) }
// testServerContext_ServerContextKey checks that a request's context carries
// the serving *Server under ServerContextKey.
func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		v := r.Context().Value(ServerContextKey)
		if _, isServer := v.(*Server); !isServer {
			t.Errorf("context value = %T; want *http.Server", v)
		}
	}))
	res, err := cst.c.Get(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
5191
// TestServerContext_LocalAddrContextKey runs its test in the default modes.
func TestServerContext_LocalAddrContextKey(t *testing.T) { run(t, testServerContext_LocalAddrContextKey) }
// testServerContext_LocalAddrContextKey checks that a request's context
// carries, under LocalAddrContextKey, the net.Addr of the listener that
// accepted the connection.
func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
	ch := make(chan any, 1)
	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		ch <- r.Context().Value(LocalAddrContextKey)
	}))
	res, err := cst.c.Head(cst.ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// The HEAD response body is empty, but it must still be closed.
	res.Body.Close()

	host := cst.ts.Listener.Addr().String()
	got := <-ch
	if addr, ok := got.(net.Addr); !ok {
		t.Errorf("local addr value = %T; want net.Addr", got)
	} else if fmt.Sprint(addr) != host {
		t.Errorf("local addr = %v; want %v", addr, host)
	}
}
5212
5213
// TestHandlerSetTransferEncodingChunked checks that when a handler explicitly
// sets "Transfer-Encoding: chunked", the header appears exactly once in the
// raw response.
func TestHandlerSetTransferEncodingChunked(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "chunked")
		w.Write([]byte("hello"))
	}))
	const hdr = "Transfer-Encoding: chunked"
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
	n := strings.Count(resp, hdr)
	if n != 1 {
		t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
	}
}
5227
5228
// TestHandlerSetTransferEncodingGzip checks that when a handler sets
// "Transfer-Encoding: gzip", the raw response contains exactly one
// "Transfer-Encoding: gzip" line and exactly one "Transfer-Encoding: chunked"
// line.
func TestHandlerSetTransferEncodingGzip(t *testing.T) {
	setParallel(t)
	defer afterTest(t)
	ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Transfer-Encoding", "gzip")
		zw := gzip.NewWriter(w)
		zw.Write([]byte("hello"))
		zw.Close()
	}))
	resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
	for _, enc := range []string{"gzip", "chunked"} {
		want := "Transfer-Encoding: " + enc
		if got := strings.Count(resp, want); got != 1 {
			t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", want, got, resp)
		}
	}
}
5246
// BenchmarkClientServer runs benchmarkClientServer over HTTP/1, HTTPS/1 and HTTP/2.
func BenchmarkClientServer(b *testing.B) { run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode}) }
// benchmarkClientServer measures sequential GET round trips, with the body
// fully read, against a server that always replies "Hello world.\n". Server
// setup is excluded from the timing.
func benchmarkClientServer(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.StartTimer()

	client := ts.Client()
	for n := b.N; n > 0; n-- {
		res, err := client.Get(ts.URL)
		if err != nil {
			b.Fatal("Get:", err)
		}
		all, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatal("ReadAll:", err)
		}
		if body := string(all); body != "Hello world.\n" {
			b.Fatal("Got body:", body)
		}
	}

	b.StopTimer()
}
5277
// BenchmarkClientServerParallel runs benchmarkClientServerParallel at each
// b.SetParallelism level in levels, in HTTP/1, HTTPS/1 and HTTP/2 modes.
func BenchmarkClientServerParallel(b *testing.B) {
	levels := []int{4, 64}
	for _, p := range levels {
		b.Run(fmt.Sprint(p), func(b *testing.B) {
			bench := func(b *testing.B, mode testMode) {
				benchmarkClientServerParallel(b, p, mode)
			}
			run(b, bench, []testMode{http1Mode, https1Mode, http2Mode})
		})
	}
}
5287
// benchmarkClientServerParallel measures concurrent GET round trips using
// b.RunParallel with the given parallelism multiplier. Transient request
// errors are logged and skipped. A wrong body panics, because RunParallel
// workers run outside the benchmark's goroutine.
func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
	b.ReportAllocs()
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		fmt.Fprintf(rw, "Hello world.\n")
	})).ts
	b.ResetTimer()
	b.SetParallelism(parallelism)
	b.RunParallel(func(pb *testing.PB) {
		client := ts.Client()
		for pb.Next() {
			res, err := client.Get(ts.URL)
			if err != nil {
				b.Logf("Get: %v", err)
				continue
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				b.Logf("ReadAll: %v", err)
				continue
			}
			if body := string(all); body != "Hello world.\n" {
				panic("Got body: " + body)
			}
		}
	})
}
5316
5317
5318
5319
5320
5321
5322
5323
5324
5325
5326 func BenchmarkServer(b *testing.B) {
5327 b.ReportAllocs()
5328
5329 if url := os.Getenv("GO_TEST_BENCH_SERVER_URL"); url != "" {
5330 n, err := strconv.Atoi(os.Getenv("GO_TEST_BENCH_CLIENT_N"))
5331 if err != nil {
5332 panic(err)
5333 }
5334 for i := 0; i < n; i++ {
5335 res, err := Get(url)
5336 if err != nil {
5337 log.Panicf("Get: %v", err)
5338 }
5339 all, err := io.ReadAll(res.Body)
5340 res.Body.Close()
5341 if err != nil {
5342 log.Panicf("ReadAll: %v", err)
5343 }
5344 body := string(all)
5345 if body != "Hello world.\n" {
5346 log.Panicf("Got body: %q", body)
5347 }
5348 }
5349 os.Exit(0)
5350 return
5351 }
5352
5353 var res = []byte("Hello world.\n")
5354 b.StopTimer()
5355 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5356 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5357 rw.Write(res)
5358 }))
5359 defer ts.Close()
5360 b.StartTimer()
5361
5362 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5363 cmd.Env = append([]string{
5364 fmt.Sprintf("GO_TEST_BENCH_CLIENT_N=%d", b.N),
5365 fmt.Sprintf("GO_TEST_BENCH_SERVER_URL=%s", ts.URL),
5366 }, os.Environ()...)
5367 out, err := cmd.CombinedOutput()
5368 if err != nil {
5369 b.Errorf("Test failure: %v, with output: %s", err, out)
5370 }
5371 }
5372
5373
// getNoBody issues a GET to urlStr with the default client and closes the
// response body without reading it. It returns the response so callers can
// inspect the status and headers.
func getNoBody(urlStr string) (*Response, error) {
	res, err := Get(urlStr)
	if err == nil {
		res.Body.Close()
		return res, nil
	}
	return nil, err
}
5382
5383
5384
5385 func BenchmarkClient(b *testing.B) {
5386 var data = []byte("Hello world.\n")
5387
5388 url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
5389 w.Header().Set("Content-Type", "text/html; charset=utf-8")
5390 w.Write(data)
5391 }))
5392
5393
5394 b.StartTimer()
5395 for i := 0; i < b.N; i++ {
5396 res, err := Get(url)
5397 if err != nil {
5398 b.Fatalf("Get: %v", err)
5399 }
5400 body, err := io.ReadAll(res.Body)
5401 res.Body.Close()
5402 if err != nil {
5403 b.Fatalf("ReadAll: %v", err)
5404 }
5405 if !bytes.Equal(body, data) {
5406 b.Fatalf("Got body: %q", body)
5407 }
5408 }
5409 b.StopTimer()
5410 }
5411
// startClientBenchmarkServer serves handler from a separate process for client
// benchmarks and returns the base URL to request. Before returning, it turns
// on allocation reporting and stops the benchmark timer.
//
// It re-executes the test binary for the current benchmark with
// GO_TEST_BENCH_SERVER=yes set. In that child process this function never
// returns. The child listens on localhost (port GO_TEST_BENCH_SERVER_PORT, or
// an ephemeral port if unset), prints the listen address on stdout, and
// serves handler until a request with a non-empty "stop" form value makes it
// exit. In the parent, a cleanup registered on b asks the child to stop and
// checks its exit status.
func startClientBenchmarkServer(b *testing.B, handler Handler) string {
	b.ReportAllocs()
	b.StopTimer()

	if server := os.Getenv("GO_TEST_BENCH_SERVER"); server != "" {
		// Child process: run the server.
		port := os.Getenv("GO_TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			log.Fatal(err)
		}
		// The parent reads this first stdout line to learn the address.
		fmt.Println(ln.Addr().String())

		// Registered on DefaultServeMux, which the zero-value Server below
		// uses because its Handler is nil.
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			handler.ServeHTTP(w, r)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Parent process: start the child server.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^"+b.Name()+"$")
	cmd.Env = append(cmd.Environ(), "GO_TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// Collect the child's exit status. done yields that one value and is
	// then closed, so any later receive returns nil immediately.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()

	// Wait for the child to print its listening address, then confirm it
	// is answering requests.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// After the benchmark, ask the child to exit and check how it exited.
	b.Cleanup(func() {
		getNoBody(url + "?stop=yes")
		if err := <-done; err != nil {
			b.Fatalf("subprocess failed: %v", err)
		}

		// The child has exited by now. cancel releases ctx, and the receive
		// observes the already-closed done channel.
		cancel()
		<-done

		afterTest(b)
	})

	return url
}
5484
// BenchmarkClientGzip measures the client fetching a gzip-encoded response,
// served from a child process, whose decoded body is 1 MiB of random bytes.
func BenchmarkClientGzip(b *testing.B) {
	const responseSize = 1024 * 1024

	// Compress responseSize random bytes once, up front.
	var compressed bytes.Buffer
	zw := gzip.NewWriter(&compressed)
	if _, err := io.CopyN(zw, crand.Reader, responseSize); err != nil {
		b.Fatal(err)
	}
	zw.Close()
	data := compressed.Bytes()

	url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
		w.Header().Set("Content-Encoding", "gzip")
		w.Write(data)
	}))

	// Only the request loop is timed.
	b.StartTimer()
	for n := b.N; n > 0; n-- {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		got, err := io.Copy(io.Discard, res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if got != responseSize {
			b.Fatalf("ReadAll: expected %d bytes, got %d", responseSize, got)
		}
	}
	b.StopTimer()
}
5520
// BenchmarkServerFakeConnNoKeepAlive measures serving one HTTP/1.0 request per
// connection over an in-memory testConn. Each iteration refills the conn's
// buffers, re-arms the single-connection listener, and waits for the server
// to close the conn.
func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.0
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	body := []byte("Hello world!\n")

	conn := newTestConn()
	handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(body)
	})
	ln := new(oneConnListener)
	for n := b.N; n > 0; n-- {
		conn.readBuf.Reset()
		conn.writeBuf.Reset()
		conn.readBuf.Write(req)
		ln.conn = conn
		Serve(ln, handler)
		<-conn.closec
	}
}
5548
5549
5550 type repeatReader struct {
5551 content []byte
5552 count int
5553 off int
5554 }
5555
5556 func (r *repeatReader) Read(p []byte) (n int, err error) {
5557 if r.count <= 0 {
5558 return 0, io.EOF
5559 }
5560 n = copy(p, r.content[r.off:])
5561 r.off += n
5562 if r.off == len(r.content) {
5563 r.count--
5564 r.off = 0
5565 }
5566 return
5567 }
5568
// BenchmarkServerFakeConnWithKeepAlive measures serving b.N browser-like
// HTTP/1.1 requests back to back over one in-memory keep-alive connection,
// and checks that every request reached the handler.
func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
Accept-Encoding: gzip,deflate,sdch
Accept-Language: en-US,en;q=0.8
Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
`)
	body := []byte("Hello world!\n")

	// The reader replays req b.N times, then hits EOF, ending the connection.
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		served++
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(body)
	}))
	<-conn.closec
	if b.N != served {
		b.Errorf("b.N=%d but handled %d", b.N, served)
	}
}
5600
5601
5602
// BenchmarkServerFakeConnWithKeepAliveLite is like
// BenchmarkServerFakeConnWithKeepAlive, but it uses a minimal request and a
// handler that sets no headers.
func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
	b.ReportAllocs()

	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	body := []byte("Hello world!\n")

	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	served := 0
	go Serve(&oneConnListener{conn: conn}, HandlerFunc(func(rw ResponseWriter, r *Request) {
		served++
		rw.Write(body)
	}))
	<-conn.closec
	if b.N != served {
		b.Errorf("b.N=%d but handled %d", b.N, served)
	}
}
5628
5629 const someResponse = "<html>some response</html>"
5630
5631
5632 var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5633
5634
// BenchmarkServerHandlerTypeLen benchmarks a handler that sets both
// Content-Type and Content-Length before writing response.
func BenchmarkServerHandlerTypeLen(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Header().Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5642
5643
// BenchmarkServerHandlerNoLen benchmarks a handler that sets Content-Type but
// leaves Content-Length for the server to determine.
func BenchmarkServerHandlerNoLen(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Type", "text/html")
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5650
5651
// BenchmarkServerHandlerNoType benchmarks a handler that sets Content-Length
// but leaves Content-Type for the server to determine.
func BenchmarkServerHandlerNoType(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Content-Length", strconv.Itoa(len(response)))
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5658
5659
// BenchmarkServerHandlerNoHeader benchmarks a handler that writes response
// without setting any headers itself.
func BenchmarkServerHandlerNoHeader(b *testing.B) {
	h := HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Write(response)
	})
	benchmarkHandler(b, h)
}
5665
// benchmarkHandler measures h serving b.N minimal GET requests, read back to
// back from one in-memory keep-alive connection, and checks that h was
// invoked b.N times.
func benchmarkHandler(b *testing.B, h Handler) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	conn := &rwTestConn{
		Reader: &repeatReader{content: req, count: b.N},
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	calls := 0
	counting := HandlerFunc(func(rw ResponseWriter, r *Request) {
		calls++
		h.ServeHTTP(rw, r)
	})
	go Serve(&oneConnListener{conn: conn}, counting)
	<-conn.closec
	if b.N != calls {
		b.Errorf("b.N=%d but handled %d", b.N, calls)
	}
}
5688
// BenchmarkServerHijack measures serving a request whose handler immediately
// hijacks and closes the connection. Each iteration re-arms the
// single-connection listener with a fresh reader over the same request bytes.
func BenchmarkServerHijack(b *testing.B) {
	b.ReportAllocs()
	req := reqBytes(`GET / HTTP/1.1
Host: golang.org
`)
	hijackAndClose := HandlerFunc(func(w ResponseWriter, r *Request) {
		c, _, err := w.(Hijacker).Hijack()
		if err != nil {
			panic(err)
		}
		c.Close()
	})
	conn := &rwTestConn{
		Writer: io.Discard,
		closec: make(chan bool, 1),
	}
	ln := &oneConnListener{conn: conn}
	for n := b.N; n > 0; n-- {
		conn.Reader = bytes.NewReader(req)
		ln.conn = conn
		Serve(ln, hijackAndClose)
		<-conn.closec
	}
}
5713
// BenchmarkCloseNotifier runs benchmarkCloseNotifier in HTTP/1 mode.
func BenchmarkCloseNotifier(b *testing.B) {
	run(b, benchmarkCloseNotifier, []testMode{http1Mode})
}
// benchmarkCloseNotifier measures how quickly a handler blocked on
// CloseNotify observes the client closing its connection. Each iteration
// dials a new connection, sends one request, closes the connection, and
// waits for the handler to report the close.
func benchmarkCloseNotifier(b *testing.B, mode testMode) {
	b.ReportAllocs()
	b.StopTimer()
	sawClose := make(chan bool)
	ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
		<-rw.(CloseNotifier).CloseNotify()
		sawClose <- true
	})).ts
	b.StartTimer()
	for n := b.N; n > 0; n-- {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			b.Fatalf("error dialing: %v", err)
		}
		if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"); err != nil {
			b.Fatal(err)
		}
		conn.Close()
		<-sawClose
	}
	b.StopTimer()
}
5738
5739
5740 func TestConcurrentServerServe(t *testing.T) {
5741 setParallel(t)
5742 for i := 0; i < 100; i++ {
5743 ln1 := &oneConnListener{conn: nil}
5744 ln2 := &oneConnListener{conn: nil}
5745 srv := Server{}
5746 go func() { srv.Serve(ln1) }()
5747 go func() { srv.Serve(ln2) }()
5748 }
5749 }
5750
// TestServerIdleTimeout runs testServerIdleTimeout in HTTP/1 mode.
func TestServerIdleTimeout(t *testing.T) {
	run(t, testServerIdleTimeout, []testMode{http1Mode})
}
// testServerIdleTimeout checks two server timeouts. IdleTimeout must close a
// kept-alive connection that sits unused, and ReadHeaderTimeout must close a
// connection whose request headers are never completed. Each attempt uses
// one ReadHeaderTimeout from the list, with IdleTimeout set to twice that
// value. Attempts report timing-dependent failures by returning an error.
// NOTE(review): runTimeSensitiveTest presumably retries with the next,
// longer duration when an attempt returns an error; confirm against its
// definition.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		// The handler echoes the client's address, which identifies the
		// connection each request used.
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				// The response headers have already arrived, so neither
				// timeout should be able to interrupt the body read. Treat a
				// failure here as fatal rather than timing-related.
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		// Two back-to-back requests should share one kept-alive connection.
		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// After idling past IdleTimeout, the next request needs a new connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send an incomplete request (no terminating blank line) and wait past
		// ReadHeaderTimeout. The server should have closed the connection
		// without replying, so reading even one byte must fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5826
// get fetches url with c and returns the whole response body as a string.
// Any request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	res, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	body, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Fatal(err)
	}
	return string(body)
}
5839
5840
5841
// TestServerSetKeepAlivesEnabledClosesConns runs its test in HTTP/1 mode.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode}) }
// testServerSetKeepAlivesEnabledClosesConns checks that calling
// SetKeepAlivesEnabled(false) on a running server closes its idle kept-alive
// connections. The effect is observed through the client Transport's pool of
// idle connections.
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	// The handler echoes the client's address, which identifies the connection.
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// Inside the function literal, get still refers to the package-level
	// helper; the local name is not in scope until after this statement.
	get := func() string { return get(t, c, ts.URL) }

	a1, a2 := get(), get()
	if a1 == a2 {
		t.Logf("made two requests from a single conn %q (as expected)", a1)
	} else {
		t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
	}

	// Both requests completed one after another on the same connection, so
	// the Transport should now hold exactly one idle connection.
	if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
	}

	// Disabling keep-alives should make the server close that idle connection.
	ts.Config.SetKeepAlivesEnabled(false)

	// Wait until the client's idle pool is empty, logging what remains while
	// waiting. NOTE(review): waitCondition presumably re-invokes the callback
	// with growing durations until it returns true; confirm against its
	// definition.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
			if d > 0 {
				t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
			}
			return false
		}
		return true
	})
}
5887
// TestServerShutdown runs testServerShutdown with run's default set of modes.
func TestServerShutdown(t *testing.T) {
	run(t, testServerShutdown)
}
// testServerShutdown exercises Server.Shutdown while a request is in flight.
// The first request to reach the handler:
//   - snapshots the server's connections by state,
//   - starts Shutdown in the background,
//   - waits for the registered OnShutdown hook to run,
//   - checks that new requests are now refused, since the listener should be
//     closed before OnShutdown runs,
//   - finishes its own response.
//
// The test then checks that Shutdown returned nil and that the snapshot
// showed exactly one active connection.
func testServerShutdown(t *testing.T, mode testMode) {
	// cst is assigned below. The handler only runs once the first get is
	// issued, which happens after the assignment.
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until Shutdown has run the OnShutdown hooks.
			<-gotOnShutdown

			// Keep trying new requests until one fails. Any success means the
			// listener was still accepting after OnShutdown ran. In HTTP/2
			// mode that is only logged and retried (go.dev/issue/59038).
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	// When Shutdown began, the in-flight request's connection was the only
	// active one.
	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5948
// TestServerShutdownStateNew checks how Shutdown treats a connection that
// has been accepted but has never sent a request. Shutdown must not complete,
// and must not close that connection, before expectTimeout has elapsed. Once
// that timeout has passed, Shutdown must complete and the connection must be
// closed by the server.
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
func testServerShutdownStateNew(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// Never reached: the client below connects but sends no request.
	}), func(ts *httptest.Server) {
		// Serve on the fake listener instead of the real one, so the test
		// controls the client side of the connection directly.
		ts.Listener.Close()
		ts.Listener = listener

		// Silence server error logs. Presumably the forced close of the
		// idle new connection is logged — confirm.
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Open a connection and let the server accept it. No request is ever
	// written on it.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	// How long Shutdown is expected to wait on a connection that never sent
	// a request before closing it.
	const expectTimeout = 5 * time.Second

	// Advance to 1ns short of expectTimeout. Neither Shutdown nor the
	// connection close may have happened yet. (runSynctest presumably runs
	// this in a synctest bubble with a fake clock, so these sleeps are not
	// real wall time — confirm.)
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Move well past expectTimeout. Shutdown should now have finished
	// without error, and the server should have closed the connection.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
6004
6005
6006 func TestServerCloseDeadlock(t *testing.T) {
6007 var s Server
6008 s.Close()
6009 s.Close()
6010 }
6011
6012
6013
// TestServerKeepAlivesEnabled checks that after SetKeepAlivesEnabled(false),
// every request is served on a fresh connection. For each request, the
// client trace must report exactly one connection, neither reused nor taken
// from the idle pool.
func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
	if mode == http2Mode {
		// Shorten the HTTP/2 GOAWAY timeout for this test. restore
		// presumably puts the previous value back, and testNotParallel above
		// presumably guards this shared override — confirm.
		restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
		defer restore()
	}

	cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
	defer cst.close()
	srv := cst.ts.Config
	srv.SetKeepAlivesEnabled(false)
	for try := 0; try < 2; try++ {
		// Before each request, wait until the server reports all of its
		// connections idle.
		waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
			if !srv.ExportAllConnsIdle() {
				if d > 0 {
					t.Logf("test server still has active conns after %v", d)
				}
				return false
			}
			return true
		})
		// Count the connections the client obtains for this request and
		// keep the details of the last one.
		conns := 0
		var info httptrace.GotConnInfo
		ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
			GotConn: func(v httptrace.GotConnInfo) {
				conns++
				info = v
			},
		})
		req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		res, err := cst.c.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		if conns != 1 {
			t.Fatalf("request %v: got %v conns, want 1", try, conns)
		}
		if info.Reused || info.WasIdle {
			t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
		}
	}
}
6060
6061
6062
6063
// TestServerCancelsReadTimeoutWhenIdle checks that Server.ReadTimeout does
// not cut short a handler that keeps running after the request has been
// read. The handler waits twice the read timeout and must still be able to
// respond "ok" without its request context being canceled.
func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			// Reply "ok" only if the request context survives for twice
			// the read timeout. Otherwise reply with the context's error.
			select {
			case <-time.After(2 * timeout):
				fmt.Fprint(w, "ok")
			case <-r.Context().Done():
				fmt.Fprint(w, r.Context().Err())
			}
		}), func(ts *httptest.Server) {
			ts.Config.ReadTimeout = timeout
			t.Logf("Server.Config.ReadTimeout = %v", timeout)
		})
		defer cst.close()
		ts := cst.ts

		// Make every Proxy lookup after the first fail. A transparent
		// client retry of an aborted request therefore surfaces as an
		// error instead of masking the problem.
		var retries atomic.Int32
		cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
			if retries.Add(1) != 1 {
				return nil, errors.New("too many retries")
			}
			return nil, nil
		}

		// NOTE(review): the Proxy hook above is installed on cst.c's
		// Transport, but the request below goes through ts.Client(). This
		// assumes both share one Transport — confirm in newClientServerTest.
		c := ts.Client()

		res, err := c.Get(ts.URL)
		if err != nil {
			return fmt.Errorf("Get: %v", err)
		}
		slurp, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			return fmt.Errorf("Body ReadAll: %v", err)
		}
		if string(slurp) != "ok" {
			return fmt.Errorf("got: %q, want ok", slurp)
		}
		return nil
	})
}
6112
6113
6114
6115
// TestServerCancelsReadHeaderTimeoutWhenIdle checks that ReadHeaderTimeout
// is not enforced while a keep-alive connection sits idle between requests.
// After one request, the connection stays quiet for 1.5x the timeout; it
// must still accept and answer a second request.
func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
	run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
}
func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		50 * time.Millisecond,
		250 * time.Millisecond,
		time.Second,
		2 * time.Second,
	}, func(t *testing.T, timeout time.Duration) error {
		cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = timeout
			// Per the Server docs, a zero IdleTimeout falls back to
			// ReadTimeout, and both zero means no idle timeout. That leaves
			// ReadHeaderTimeout as the only deadline that could close the
			// idle connection (assuming ReadTimeout is unset here — confirm).
			ts.Config.IdleTimeout = 0
		})
		defer cst.close()
		ts := cst.ts

		// Use a raw connection so both requests are known to travel over
		// the same TCP connection.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("dial failed: %v", err)
		}
		br := bufio.NewReader(conn)
		defer conn.Close()

		if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
			return fmt.Errorf("writing first request failed: %v", err)
		}

		if _, err := ReadResponse(br, nil); err != nil {
			return fmt.Errorf("first response (before timeout) failed: %v", err)
		}

		// Stay idle for longer than ReadHeaderTimeout before sending the
		// next request.
		time.Sleep(timeout * 3 / 2)

		if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
			return fmt.Errorf("writing second request failed: %v", err)
		}

		if _, err := ReadResponse(br, nil); err != nil {
			return fmt.Errorf("second response (after timeout) failed: %v", err)
		}

		return nil
	})
}
6166
6167
6168
// runTimeSensitiveTest runs test once per entry in durations, in order, and
// stops at the first run that returns nil.
//
// An error from any run except the last is logged, and the next (typically
// longer) duration is tried. runTimeSensitiveTest calls t.Fatalf in two
// cases: the last duration also fails, or the test has already been marked
// failed. An empty durations list is a caller bug. It fails the test instead
// of silently passing without running anything.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	t.Helper()
	if len(durations) == 0 {
		t.Fatal("runTimeSensitiveTest: no durations to try")
	}
	for i, d := range durations {
		err := test(t, d)
		if err == nil {
			return
		}
		// Give up once the schedule is exhausted, or once the test has
		// recorded a hard failure that a longer duration will not clear.
		if i == len(durations)-1 || t.Failed() {
			t.Fatalf("failed with duration %v: %v", d, err)
		}
		t.Logf("retrying after error with duration %v: %v", d, err)
	}
}
6181
6182
6183
// TestServerDuplicateBackgroundRead pipelines many GET requests over several
// concurrent raw connections, each with its own goroutine draining
// responses. The test itself only checks that dialing and writing succeed.
// Presumably it targets a race in the server's background connection reads
// (see the test name) — confirm against the server implementation.
func TestServerDuplicateBackgroundRead(t *testing.T) {
	run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
}
func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
	if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
		testenv.SkipFlaky(t, 24826)
	}

	goroutines := 5
	requests := 2000
	if testing.Short() {
		goroutines = 3
		requests = 100
	}

	hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts

	reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")

	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cn, err := net.Dial("tcp", hts.Listener.Addr().String())
			if err != nil {
				t.Error(err)
				return
			}
			defer cn.Close()

			// Drain responses concurrently so the connection's read side is
			// never left unread. This Add is safe even while wg.Wait may be
			// running: this goroutine's own Done has not run yet, so the
			// counter is non-zero. The copy ends when the deferred
			// cn.Close above runs, which happens before that deferred Done.
			wg.Add(1)
			go func() {
				defer wg.Done()
				io.Copy(io.Discard, cn)
			}()

			// Pipeline requests without waiting for their responses.
			for j := 0; j < requests; j++ {
				if t.Failed() {
					return
				}
				_, err := cn.Write(reqBytes)
				if err != nil {
					t.Error(err)
					return
				}
			}
		}()
	}
	wg.Wait()
}
6235
6236
6237
6238
6239
6240
// TestServerHijackGetsBackgroundByte checks two things when the client sends
// bytes after the request headers while the handler is already running:
//   - those bytes are visible through the bufio.Reader returned by Hijack;
//   - the request context is not canceled.
func TestServerHijackGetsBackgroundByte(t *testing.T) {
	run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
}
func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
	if runtime.GOOS == "plan9" {
		t.Skip("skipping test; see https://golang.org/issue/18657")
	}
	done := make(chan struct{})
	inHandler := make(chan bool, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		defer close(done)

		// Tell the client it may now send the extra bytes, which
		// therefore arrive while the handler is running.
		inHandler <- true

		conn, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			t.Error(err)
			return
		}
		defer conn.Close()

		peek, err := buf.Reader.Peek(3)
		if string(peek) != "foo" || err != nil {
			t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
		}

		// Hijacking must not have canceled the request context.
		select {
		case <-r.Context().Done():
			t.Error("context unexpectedly canceled")
		default:
		}
	})).ts

	cn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer cn.Close()
	if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
		t.Fatal(err)
	}
	<-inHandler
	if _, err := cn.Write([]byte("foo")); err != nil {
		t.Fatal(err)
	}

	// Half-close the write side so the server sees EOF after "foo".
	if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
		t.Fatal(err)
	}
	<-done
}
6293
6294
6295 func TestServerHijackGetsFullBody(t *testing.T) {
6296 run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
6297 }
6298 func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
6299 if runtime.GOOS == "plan9" {
6300 t.Skip("skipping test; see https://golang.org/issue/18657")
6301 }
6302 done := make(chan struct{})
6303 needle := strings.Repeat("x", 100*1024)
6304 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6305 defer close(done)
6306
6307 conn, buf, err := w.(Hijacker).Hijack()
6308 if err != nil {
6309 t.Error(err)
6310 return
6311 }
6312 defer conn.Close()
6313
6314 got := make([]byte, len(needle))
6315 n, err := io.ReadFull(buf.Reader, got)
6316 if n != len(needle) || string(got) != needle || err != nil {
6317 t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
6318 }
6319 })).ts
6320
6321 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6322 if err != nil {
6323 t.Fatal(err)
6324 }
6325 defer cn.Close()
6326 buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6327 buf = append(buf, []byte(needle)...)
6328 if _, err := cn.Write(buf); err != nil {
6329 t.Fatal(err)
6330 }
6331
6332 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6333 t.Fatal(err)
6334 }
6335 <-done
6336 }
6337
6338
6339
6340
6341 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6342 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6343 }
6344 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6345 if runtime.GOOS == "plan9" {
6346 t.Skip("skipping test; see https://golang.org/issue/18657")
6347 }
6348 done := make(chan struct{})
6349 const size = 8 << 10
6350 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6351 defer close(done)
6352
6353 conn, buf, err := w.(Hijacker).Hijack()
6354 if err != nil {
6355 t.Error(err)
6356 return
6357 }
6358 defer conn.Close()
6359 slurp, err := io.ReadAll(buf.Reader)
6360 if err != nil {
6361 t.Errorf("Copy: %v", err)
6362 }
6363 allX := true
6364 for _, v := range slurp {
6365 if v != 'x' {
6366 allX = false
6367 }
6368 }
6369 if len(slurp) != size {
6370 t.Errorf("read %d; want %d", len(slurp), size)
6371 } else if !allX {
6372 t.Errorf("read %q; want %d 'x'", slurp, size)
6373 }
6374 })).ts
6375
6376 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6377 if err != nil {
6378 t.Fatal(err)
6379 }
6380 defer cn.Close()
6381 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6382 strings.Repeat("x", size)); err != nil {
6383 t.Fatal(err)
6384 }
6385 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6386 t.Fatal(err)
6387 }
6388
6389 <-done
6390 }
6391
6392
6393 func TestServerValidatesMethod(t *testing.T) {
6394 tests := []struct {
6395 method string
6396 want int
6397 }{
6398 {"GET", 200},
6399 {"GE(T", 400},
6400 }
6401 for _, tt := range tests {
6402 conn := newTestConn()
6403 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6404
6405 ln := &oneConnListener{conn}
6406 go Serve(ln, serve(200))
6407 <-conn.closec
6408 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6409 if err != nil {
6410 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6411 continue
6412 }
6413 if res.StatusCode != tt.want {
6414 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6415 }
6416 }
6417 }
6418
6419
// eofListenerNotComparable is a net.Listener whose dynamic type is a slice,
// so listener values of this type cannot be compared with ==. Accept always
// reports io.EOF; Addr returns nil and Close does nothing.
type eofListenerNotComparable []int

func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

func (eofListenerNotComparable) Close() error {
	return nil
}
6425
6426
6427 func TestServerListenNotComparableListener(t *testing.T) {
6428 var s Server
6429 s.Serve(make(eofListenerNotComparable, 1))
6430 }
6431
6432
// countCloseListener wraps a net.Listener and counts calls to Close. Only the
// first Close is forwarded to the wrapped Listener, and only when it is
// non-nil; later calls just bump the counter and return nil. closes is
// updated atomically so tests can read it with atomic.LoadInt32.
type countCloseListener struct {
	net.Listener
	closes int32
}

func (l *countCloseListener) Close() error {
	if atomic.AddInt32(&l.closes, 1) != 1 || l.Listener == nil {
		return nil
	}
	return l.Listener.Close()
}
6445
6446
// TestServerCloseListenerOnce checks that a Server which is Serving on a
// listener and is then Shutdown calls that listener's Close exactly once.
func TestServerCloseListenerOnce(t *testing.T) {
	setParallel(t)
	defer afterTest(t)

	ln := newLocalListener(t)
	defer ln.Close()

	cl := &countCloseListener{Listener: ln}
	server := &Server{}
	sdone := make(chan bool, 1)

	go func() {
		server.Serve(cl)
		sdone <- true
	}()
	// Best-effort delay to let Serve start before Shutdown. This is not a
	// synchronization guarantee.
	time.Sleep(10 * time.Millisecond)
	server.Shutdown(context.Background())
	// Close the underlying listener directly, which bypasses cl's counter,
	// so Serve is sure to return.
	ln.Close()
	<-sdone

	nclose := atomic.LoadInt32(&cl.closes)
	if nclose != 1 {
		t.Errorf("Close calls = %v; want 1", nclose)
	}
}
6472
6473
6474 func TestServerShutdownThenServe(t *testing.T) {
6475 var srv Server
6476 cl := &countCloseListener{Listener: nil}
6477 srv.Shutdown(context.Background())
6478 got := srv.Serve(cl)
6479 if got != ErrServerClosed {
6480 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6481 }
6482 nclose := atomic.LoadInt32(&cl.closes)
6483 if nclose != 1 {
6484 t.Errorf("Close calls = %v; want 1", nclose)
6485 }
6486 }
6487
6488
6489 func TestStripPortFromHost(t *testing.T) {
6490 mux := NewServeMux()
6491
6492 mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
6493 fmt.Fprintf(w, "OK")
6494 })
6495 mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
6496 fmt.Fprintf(w, "uh-oh!")
6497 })
6498
6499 req := httptest.NewRequest("GET", "http://example.com:9000/", nil)
6500 rw := httptest.NewRecorder()
6501
6502 mux.ServeHTTP(rw, req)
6503
6504 response := rw.Body.String()
6505 if response != "OK" {
6506 t.Errorf("Response gotten was %q", response)
6507 }
6508 }
6509
// TestServerContexts checks that Server.BaseContext and Server.ConnContext
// both feed the request context:
//   - ConnContext receives the context produced by BaseContext;
//   - the handler's request context carries the values added by both hooks.
func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
func testServerContexts(t *testing.T, mode testMode) {
	type baseKey struct{}
	type connKey struct{}
	ch := make(chan context.Context, 1)
	ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
		ch <- r.Context()
	}), func(ts *httptest.Server) {
		ts.Config.BaseContext = func(ln net.Listener) context.Context {
			// BaseContext should see the listener as passed in, not an
			// internal onceClose wrapper around it.
			if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
				t.Errorf("unexpected onceClose listener type %T", ln)
			}
			return context.WithValue(context.Background(), baseKey{}, "base")
		}
		ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
			if got, want := ctx.Value(baseKey{}), "base"; got != want {
				t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
			}
			return context.WithValue(ctx, connKey{}, "conn")
		}
	}).ts
	res, err := ts.Client().Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	ctx := <-ch
	if got, want := ctx.Value(baseKey{}), "base"; got != want {
		t.Errorf("base context key = %#v; want %q", got, want)
	}
	if got, want := ctx.Value(connKey{}), "conn"; got != want {
		t.Errorf("conn context key = %#v; want %q", got, want)
	}
}
6544
6545
6546 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6547 run(t, testConnContextNotModifyingAllContexts)
6548 }
6549 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6550 type connKey struct{}
6551 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6552 rw.Header().Set("Connection", "close")
6553 }), func(ts *httptest.Server) {
6554 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6555 if got := ctx.Value(connKey{}); got != nil {
6556 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6557 }
6558 return context.WithValue(ctx, connKey{}, "conn")
6559 }
6560 }).ts
6561
6562 var res *Response
6563 var err error
6564
6565 res, err = ts.Client().Get(ts.URL)
6566 if err != nil {
6567 t.Fatal(err)
6568 }
6569 res.Body.Close()
6570
6571 res, err = ts.Client().Get(ts.URL)
6572 if err != nil {
6573 t.Fatal(err)
6574 }
6575 res.Body.Close()
6576 }
6577
6578
6579
6580 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6581 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6582 }
6583 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6584 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6585 w.Write([]byte("Hello, World!"))
6586 })).ts
6587
6588 serverURL, err := url.Parse(cst.URL)
6589 if err != nil {
6590 t.Fatalf("Failed to parse server URL: %v", err)
6591 }
6592
6593 unsupportedTEs := []string{
6594 "fugazi",
6595 "foo-bar",
6596 "unknown",
6597 `" chunked"`,
6598 }
6599
6600 for _, badTE := range unsupportedTEs {
6601 http1ReqBody := fmt.Sprintf(""+
6602 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6603 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6604
6605 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6606 if err != nil {
6607 t.Errorf("%q. unexpected error: %v", badTE, err)
6608 continue
6609 }
6610
6611 wantBody := fmt.Sprintf("" +
6612 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6613 "Connection: close\r\n\r\nUnsupported transfer encoding")
6614
6615 if string(gotBody) != wantBody {
6616 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6617 }
6618 }
6619 }
6620
6621
// TestContentEncodingNoSniffing checks Content-Type sniffing when the
// handler sets Content-Encoding but no Content-Type:
//   - a non-empty Content-Encoding: the response gets no sniffed
//     Content-Type, because the body bytes are encoded;
//   - the header absent, or set to "": sniffing proceeds as usual.
func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
	type setting struct {
		name string
		body []byte

		// contentEncoding is nil when the handler should not set the
		// Content-Encoding header at all. Otherwise it holds the string
		// value to set, which may be empty.
		contentEncoding any
		wantContentType string
	}

	settings := []*setting{
		{
			name:            "gzip content-encoding, gzipped",
			contentEncoding: "application/gzip",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "zlib content-encoding, zlibbed",
			contentEncoding: "application/zlib",
			wantContentType: "",
			body: func() []byte {
				buf := new(bytes.Buffer)
				zw := zlib.NewWriter(buf)
				zw.Write([]byte("doctype html><p>Hello</p>"))
				zw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "no content-encoding",
			wantContentType: "application/x-gzip",
			body: func() []byte {
				buf := new(bytes.Buffer)
				gzw := gzip.NewWriter(buf)
				gzw.Write([]byte("doctype html><p>Hello</p>"))
				gzw.Close()
				return buf.Bytes()
			}(),
		},
		{
			name:            "phony content-encoding",
			contentEncoding: "foo/bar",
			body:            []byte("doctype html><p>Hello</p>"),
		},
		{
			name:            "empty but set content-encoding",
			contentEncoding: "",
			wantContentType: "audio/mpeg",
			body:            []byte("ID3"),
		},
	}

	for _, tt := range settings {
		t.Run(tt.name, func(t *testing.T) {
			cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
				if tt.contentEncoding != nil {
					rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
				}
				rw.Write(tt.body)
			}))

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Failed to fetch URL: %v", err)
			}
			defer res.Body.Close()

			// A nil want means the handler never set the header; in that
			// case only a non-empty response value is an error.
			if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
				if w != nil {
					t.Errorf("Content-Encoding mismatch\n\tgot:  %q\n\twant: %q", g, w)
				} else if g != "" {
					t.Errorf("Unexpected Content-Encoding %q", g)
				}
			}

			if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
				t.Errorf("Content-Type mismatch\n\tgot:  %q\n\twant: %q", g, w)
			}
		})
	}
}
6713
6714
6715
// TestTimeoutHandlerSuperfluousLogs checks the "superfluous
// response.WriteHeader call" log lines. They are produced when a handler
// wrapped by TimeoutHandler calls WriteHeader repeatedly. Two cases are
// covered: the handler returns before the timeout, and the handler is still
// running after it. Each log line must attribute its call to the exact
// file:line of the offending WriteHeader in this test.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Record this file's base name and this function's fully qualified
	// name; the expected log patterns are built from them.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			exitHandler := make(chan bool, 1)
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// The four WriteHeader calls and the runtime.Caller call below
			// must stay on consecutive lines: expected log positions are
			// computed by counting back from the Caller line. Do not insert
			// lines between them.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			// In the no-timeout case, preload exitHandler so the handler
			// returns immediately after writing.
			if !tt.mustTimeout {
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)

			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Use a generous timeout so the handler cannot lose the race.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Drop headers that wantResp does not pin so the dump can be
			// compared byte-for-byte.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// Expect one log entry per superfluous call. The first
			// WriteHeader is legitimate; the other three are not.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", " ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastSpuriousLine is really the runtime.Caller line, just
			// after the fourth WriteHeader. Subtracting 3 gives the line of
			// the second WriteHeader, the first superfluous call.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Each entry must name the handler closure and the line of the
			// matching superfluous call, in order.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6819
6820
6821
6822
// fetchWireResponse dials host over TCP and writes http1ReqBody verbatim. It
// then returns every byte the peer sends until it closes the connection (or
// a read fails). The connection is always closed before returning.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, err := net.Dial("tcp", host)
	if err != nil {
		return nil, err
	}
	defer c.Close()

	_, err = c.Write(http1ReqBody)
	if err != nil {
		return nil, err
	}
	return io.ReadAll(c)
}
6835
6836 func BenchmarkResponseStatusLine(b *testing.B) {
6837 b.ReportAllocs()
6838 b.RunParallel(func(pb *testing.PB) {
6839 bw := bufio.NewWriter(io.Discard)
6840 var buf3 [3]byte
6841 for pb.Next() {
6842 Export_writeStatusLine(bw, true, 200, buf3[:])
6843 }
6844 })
6845 }
6846
// TestDisableKeepAliveUpgrade checks that a 101 Switching Protocols upgrade
// still works end to end with keep-alives disabled on both the server and
// the client transport. The response body must be an io.ReadWriteCloser,
// and bytes written to it must be echoed back by the hijacked server side.
func TestDisableKeepAliveUpgrade(t *testing.T) {
	run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
}
func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		w.Header().Set("Connection", "Upgrade")
		w.Header().Set("Upgrade", "someProto")
		w.WriteHeader(StatusSwitchingProtocols)
		c, buf, err := w.(Hijacker).Hijack()
		if err != nil {
			return
		}
		defer c.Close()

		// Echo the upgraded stream: copy everything read through the
		// hijacked buffered reader back out on the raw connection.
		io.Copy(c, buf)
	}), func(ts *httptest.Server) {
		ts.Config.SetKeepAlivesEnabled(false)
	}).ts

	cl := s.Client()
	cl.Transport.(*Transport).DisableKeepAlives = true

	resp, err := cl.Get(s.URL)
	if err != nil {
		t.Fatalf("failed to perform request: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != StatusSwitchingProtocols {
		t.Fatalf("unexpected status code: %v", resp.StatusCode)
	}

	// For a 101 response, the body is expected to be a bidirectional
	// stream.
	rwc, ok := resp.Body.(io.ReadWriteCloser)
	if !ok {
		t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
	}

	_, err = rwc.Write([]byte("hello"))
	if err != nil {
		t.Fatalf("failed to write to body: %v", err)
	}

	b := make([]byte, 5)
	_, err = io.ReadFull(rwc, b)
	if err != nil {
		t.Fatalf("failed to read from body: %v", err)
	}

	if string(b) != "hello" {
		t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
	}
}
6905
// tlogWriter is an io.Writer that forwards each write to t.Log, so a
// log.Logger built on it reports through the test's own output. Writes never
// fail and always report the full length as written.
type tlogWriter struct{ t *testing.T }

func (w tlogWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Log(string(p))
	return n, nil
}
6912
6913 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6914 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6915 }
6916 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6917 const wantBody = "want"
6918 const wantUpgrade = "someProto"
6919 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6920 w.Header().Set("Connection", "Upgrade")
6921 w.Header().Set("Upgrade", wantUpgrade)
6922 w.WriteHeader(StatusSwitchingProtocols)
6923 NewResponseController(w).Flush()
6924
6925
6926 w.WriteHeader(200)
6927 if _, err := w.Write([]byte("x")); err == nil {
6928 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6929 }
6930
6931 c, _, err := NewResponseController(w).Hijack()
6932 if err != nil {
6933 t.Errorf("Hijack: %v", err)
6934 return
6935 }
6936 defer c.Close()
6937 if _, err := c.Write([]byte(wantBody)); err != nil {
6938 t.Errorf("Write to hijacked body: %v", err)
6939 }
6940 }), func(ts *httptest.Server) {
6941
6942 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6943 }).ts
6944
6945 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6946 if err != nil {
6947 t.Fatalf("net.Dial: %v", err)
6948 }
6949 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6950 if err != nil {
6951 t.Fatalf("conn.Write: %v", err)
6952 }
6953 defer conn.Close()
6954
6955 r := bufio.NewReader(conn)
6956 res, err := ReadResponse(r, &Request{Method: "GET"})
6957 if err != nil {
6958 t.Fatal("ReadResponse error:", err)
6959 }
6960 if res.StatusCode != StatusSwitchingProtocols {
6961 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6962 }
6963 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6964 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6965 }
6966 body, err := io.ReadAll(r)
6967 if err != nil {
6968 t.Error(err)
6969 }
6970 if string(body) != wantBody {
6971 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6972 }
6973 }
6974
6975 func TestMuxRedirectRelative(t *testing.T) {
6976 setParallel(t)
6977 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6978 if err != nil {
6979 t.Errorf("%s", err)
6980 }
6981 mux := NewServeMux()
6982 resp := httptest.NewRecorder()
6983 mux.ServeHTTP(resp, req)
6984 if got, want := resp.Header().Get("Location"), "/"; got != want {
6985 t.Errorf("Location header expected %q; got %q", want, got)
6986 }
6987 if got, want := resp.Code, StatusTemporaryRedirect; got != want {
6988 t.Errorf("Expected response code %d; got %d", want, got)
6989 }
6990 }
6991
6992
6993 func TestQuerySemicolon(t *testing.T) {
6994 t.Cleanup(func() { afterTest(t) })
6995
6996 tests := []struct {
6997 query string
6998 xNoSemicolons string
6999 xWithSemicolons string
7000 expectParseFormErr bool
7001 }{
7002 {"?a=1;x=bad&x=good", "good", "bad", true},
7003 {"?a=1;b=bad&x=good", "good", "good", true},
7004 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
7005 {"?a=1;x=good;x=bad", "", "good", true},
7006 }
7007
7008 run(t, func(t *testing.T, mode testMode) {
7009 for _, tt := range tests {
7010 t.Run(tt.query+"/allow=false", func(t *testing.T) {
7011 allowSemicolons := false
7012 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
7013 })
7014 t.Run(tt.query+"/allow=true", func(t *testing.T) {
7015 allowSemicolons, expectParseFormErr := true, false
7016 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
7017 })
7018 }
7019 })
7020 }
7021
7022 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
7023 writeBackX := func(w ResponseWriter, r *Request) {
7024 x := r.URL.Query().Get("x")
7025 if expectParseFormErr {
7026 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
7027 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
7028 }
7029 } else {
7030 if err := r.ParseForm(); err != nil {
7031 t.Errorf("expected no error from ParseForm, got %v", err)
7032 }
7033 }
7034 if got := r.FormValue("x"); x != got {
7035 t.Errorf("got %q from FormValue, want %q", got, x)
7036 }
7037 fmt.Fprintf(w, "%s", x)
7038 }
7039
7040 h := Handler(HandlerFunc(writeBackX))
7041 if allowSemicolons {
7042 h = AllowQuerySemicolons(h)
7043 }
7044
7045 logBuf := &strings.Builder{}
7046 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
7047 ts.Config.ErrorLog = log.New(logBuf, "", 0)
7048 }).ts
7049
7050 req, _ := NewRequest("GET", ts.URL+query, nil)
7051 res, err := ts.Client().Do(req)
7052 if err != nil {
7053 t.Fatal(err)
7054 }
7055 slurp, _ := io.ReadAll(res.Body)
7056 res.Body.Close()
7057 if got, want := res.StatusCode, 200; got != want {
7058 t.Errorf("Status = %d; want = %d", got, want)
7059 }
7060 if got, want := string(slurp), wantX; got != want {
7061 t.Errorf("Body = %q; want = %q", got, want)
7062 }
7063 }
7064
7065 func TestMaxBytesHandler(t *testing.T) {
7066
7067 defer afterTest(t)
7068
7069 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
7070 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
7071 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
7072 func(t *testing.T) {
7073 run(t, func(t *testing.T, mode testMode) {
7074 testMaxBytesHandler(t, mode, maxSize, requestSize)
7075 }, testNotParallel)
7076 })
7077 }
7078 }
7079 }
7080
7081 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
7082 runTimeSensitiveTest(t, []time.Duration{
7083 1 * time.Millisecond,
7084 5 * time.Millisecond,
7085 10 * time.Millisecond,
7086 50 * time.Millisecond,
7087 100 * time.Millisecond,
7088 500 * time.Millisecond,
7089 time.Second,
7090 5 * time.Second,
7091 }, func(t *testing.T, timeout time.Duration) error {
7092 SetRSTAvoidanceDelay(t, timeout)
7093 t.Logf("set RST avoidance delay to %v", timeout)
7094
7095 var (
7096 handlerN int64
7097 handlerErr error
7098 )
7099 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
7100 var buf bytes.Buffer
7101 handlerN, handlerErr = io.Copy(&buf, r.Body)
7102 io.Copy(w, &buf)
7103 })
7104
7105 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
7106
7107
7108 defer cst.close()
7109 ts := cst.ts
7110 c := ts.Client()
7111
7112 body := strings.Repeat("a", int(requestSize))
7113 var wg sync.WaitGroup
7114 defer wg.Wait()
7115 getBody := func() (io.ReadCloser, error) {
7116 wg.Add(1)
7117 body := &wgReadCloser{
7118 Reader: strings.NewReader(body),
7119 wg: &wg,
7120 }
7121 return body, nil
7122 }
7123 reqBody, _ := getBody()
7124 req, err := NewRequest("POST", ts.URL, reqBody)
7125 if err != nil {
7126 reqBody.Close()
7127 t.Fatal(err)
7128 }
7129 req.ContentLength = int64(len(body))
7130 req.GetBody = getBody
7131 req.Header.Set("Content-Type", "text/plain")
7132
7133 var buf strings.Builder
7134 res, err := c.Do(req)
7135 if err != nil {
7136 return fmt.Errorf("unexpected connection error: %v", err)
7137 } else {
7138 _, err = io.Copy(&buf, res.Body)
7139 res.Body.Close()
7140 if err != nil {
7141 return fmt.Errorf("unexpected read error: %v", err)
7142 }
7143 }
7144
7145
7146
7147
7148 if handlerN > maxSize {
7149 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
7150 }
7151 if requestSize > maxSize && handlerErr == nil {
7152 t.Error("expected error on handler side; got nil")
7153 }
7154 if requestSize <= maxSize {
7155 if handlerErr != nil {
7156 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
7157 }
7158 if handlerN != requestSize {
7159 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
7160 }
7161 }
7162 if buf.Len() != int(handlerN) {
7163 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
7164 }
7165
7166 return nil
7167 })
7168 }
7169
7170 func TestEarlyHints(t *testing.T) {
7171 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7172 h := w.Header()
7173 h.Add("Link", "</style.css>; rel=preload; as=style")
7174 h.Add("Link", "</script.js>; rel=preload; as=script")
7175 w.WriteHeader(StatusEarlyHints)
7176
7177 h.Add("Link", "</foo.js>; rel=preload; as=script")
7178 w.WriteHeader(StatusEarlyHints)
7179
7180 w.Write([]byte("stuff"))
7181 }))
7182
7183 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7184 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7185 if !strings.Contains(got, expected) {
7186 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7187 }
7188 }
7189 func TestProcessing(t *testing.T) {
7190 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7191 w.WriteHeader(StatusProcessing)
7192 w.Write([]byte("stuff"))
7193 }))
7194
7195 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7196 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7197 if !strings.Contains(got, expected) {
7198 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7199 }
7200 }
7201
7202 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7203 func testParseFormCleanup(t *testing.T, mode testMode) {
7204 if mode == http2Mode {
7205 t.Skip("https://go.dev/issue/20253")
7206 }
7207
7208 const maxMemory = 1024
7209 const key = "file"
7210
7211 if runtime.GOOS == "windows" {
7212
7213 t.Skip("https://go.dev/issue/25965")
7214 }
7215
7216 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7217 r.ParseMultipartForm(maxMemory)
7218 f, _, err := r.FormFile(key)
7219 if err != nil {
7220 t.Errorf("r.FormFile(%q) = %v", key, err)
7221 return
7222 }
7223 of, ok := f.(*os.File)
7224 if !ok {
7225 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7226 return
7227 }
7228 w.Write([]byte(of.Name()))
7229 }))
7230
7231 fBuf := new(bytes.Buffer)
7232 mw := multipart.NewWriter(fBuf)
7233 mf, err := mw.CreateFormFile(key, "myfile.txt")
7234 if err != nil {
7235 t.Fatal(err)
7236 }
7237 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7238 t.Fatal(err)
7239 }
7240 if err := mw.Close(); err != nil {
7241 t.Fatal(err)
7242 }
7243 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7244 if err != nil {
7245 t.Fatal(err)
7246 }
7247 req.Header.Set("Content-Type", mw.FormDataContentType())
7248 res, err := cst.c.Do(req)
7249 if err != nil {
7250 t.Fatal(err)
7251 }
7252 defer res.Body.Close()
7253 fname, err := io.ReadAll(res.Body)
7254 if err != nil {
7255 t.Fatal(err)
7256 }
7257 cst.close()
7258 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7259 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7260 }
7261 }
7262
7263 func TestHeadBody(t *testing.T) {
7264 const identityMode = false
7265 const chunkedMode = true
7266 run(t, func(t *testing.T, mode testMode) {
7267 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7268 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7269 })
7270 }
7271
7272 func TestGetBody(t *testing.T) {
7273 const identityMode = false
7274 const chunkedMode = true
7275 run(t, func(t *testing.T, mode testMode) {
7276 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7277 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7278 })
7279 }
7280
7281 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7282 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7283 b, err := io.ReadAll(r.Body)
7284 if err != nil {
7285 t.Errorf("server reading body: %v", err)
7286 return
7287 }
7288 w.Header().Set("X-Request-Body", string(b))
7289 w.Header().Set("Content-Length", "0")
7290 }))
7291 defer cst.close()
7292 for _, reqBody := range []string{
7293 "",
7294 "",
7295 "request_body",
7296 "",
7297 } {
7298 var bodyReader io.Reader
7299 if reqBody != "" {
7300 bodyReader = strings.NewReader(reqBody)
7301 if chunked {
7302 bodyReader = bufio.NewReader(bodyReader)
7303 }
7304 }
7305 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7306 if err != nil {
7307 t.Fatal(err)
7308 }
7309 res, err := cst.c.Do(req)
7310 if err != nil {
7311 t.Fatal(err)
7312 }
7313 res.Body.Close()
7314 if got, want := res.StatusCode, 200; got != want {
7315 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7316 }
7317 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7318 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7319 }
7320 }
7321 }
7322
7323
7324
7325 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7326 func testDisableContentLength(t *testing.T, mode testMode) {
7327 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7328 w.Header()["Content-Length"] = nil
7329 fmt.Fprintf(w, "OK")
7330 }))
7331
7332 res, err := noCL.c.Get(noCL.ts.URL)
7333 if err != nil {
7334 t.Fatal(err)
7335 }
7336 if got, haveCL := res.Header["Content-Length"]; haveCL {
7337 t.Errorf("Unexpected Content-Length: %q", got)
7338 }
7339 if err := res.Body.Close(); err != nil {
7340 t.Fatal(err)
7341 }
7342
7343 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7344 fmt.Fprintf(w, "OK")
7345 }))
7346
7347 res, err = withCL.c.Get(withCL.ts.URL)
7348 if err != nil {
7349 t.Fatal(err)
7350 }
7351 if got := res.Header.Get("Content-Length"); got != "2" {
7352 t.Errorf("Content-Length: %q; want 2", got)
7353 }
7354 if err := res.Body.Close(); err != nil {
7355 t.Fatal(err)
7356 }
7357 }
7358
7359 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7360 func testErrorContentLength(t *testing.T, mode testMode) {
7361 const errorBody = "an error occurred"
7362 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7363 w.Header().Set("Content-Length", "1000")
7364 Error(w, errorBody, 400)
7365 }))
7366 res, err := cst.c.Get(cst.ts.URL)
7367 if err != nil {
7368 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7369 }
7370 defer res.Body.Close()
7371 body, err := io.ReadAll(res.Body)
7372 if err != nil {
7373 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7374 }
7375 if string(body) != errorBody+"\n" {
7376 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7377 }
7378 }
7379
// TestError checks how Error rewrites the response headers and what it writes:
//   - it drops a stale Content-Length;
//   - it forces Content-Type to "text/plain; charset=utf-8" and
//     X-Content-Type-Options to "nosniff";
//   - it leaves unrelated headers untouched;
//   - it writes the given status code and the message followed by a newline.
func TestError(t *testing.T) {
	w := httptest.NewRecorder()
	w.Header().Set("Content-Length", "1")
	w.Header().Set("X-Content-Type-Options", "scratch and sniff")
	w.Header().Set("Other", "foo")
	Error(w, "oops", 432)

	h := w.Header()
	// Headers that Error must remove.
	for _, hdr := range []string{"Content-Length"} {
		if v, ok := h[hdr]; ok {
			t.Errorf("%s: %q, want not present", hdr, v)
		}
	}
	if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
		t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
	}
	if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
		t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
	}
	// Unrelated headers set before Error must survive.
	if v := h.Get("Other"); v != "foo" {
		t.Errorf("Other: %q, want %q", v, "foo")
	}
	if w.Code != 432 {
		t.Errorf("status code: %d, want %d", w.Code, 432)
	}
	if got, want := w.Body.String(), "oops\n"; got != want {
		t.Errorf("body: %q, want %q", got, want)
	}
}
7400
7401 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7402 run(t, testServerReadAfterWriteHeader100Continue)
7403 }
7404 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7405 t.Skip("https://go.dev/issue/67555")
7406 body := []byte("body")
7407 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7408 w.WriteHeader(200)
7409 NewResponseController(w).Flush()
7410 io.ReadAll(r.Body)
7411 w.Write(body)
7412 }), func(tr *Transport) {
7413 tr.ExpectContinueTimeout = 24 * time.Hour
7414 })
7415
7416 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7417 req.Header.Set("Expect", "100-continue")
7418 res, err := cst.c.Do(req)
7419 if err != nil {
7420 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7421 }
7422 defer res.Body.Close()
7423 got, err := io.ReadAll(res.Body)
7424 if err != nil {
7425 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7426 }
7427 if !bytes.Equal(got, body) {
7428 t.Fatalf("response body = %q, want %q", got, body)
7429 }
7430 }
7431
7432 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7433 run(t, testServerReadAfterHandlerDone100Continue)
7434 }
7435 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7436 t.Skip("https://go.dev/issue/67555")
7437 readyc := make(chan struct{})
7438 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7439 go func() {
7440 <-readyc
7441 io.ReadAll(r.Body)
7442 <-readyc
7443 }()
7444 }), func(tr *Transport) {
7445 tr.ExpectContinueTimeout = 24 * time.Hour
7446 })
7447
7448 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7449 req.Header.Set("Expect", "100-continue")
7450 res, err := cst.c.Do(req)
7451 if err != nil {
7452 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7453 }
7454 res.Body.Close()
7455 readyc <- struct{}{}
7456 readyc <- struct{}{}
7457 }
7458
7459 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7460 run(t, testServerReadAfterHandlerAbort100Continue)
7461 }
7462 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7463 t.Skip("https://go.dev/issue/67555")
7464 readyc := make(chan struct{})
7465 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7466 go func() {
7467 <-readyc
7468 io.ReadAll(r.Body)
7469 <-readyc
7470 }()
7471 panic(ErrAbortHandler)
7472 }), func(tr *Transport) {
7473 tr.ExpectContinueTimeout = 24 * time.Hour
7474 })
7475
7476 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7477 req.Header.Set("Expect", "100-continue")
7478 res, err := cst.c.Do(req)
7479 if err == nil {
7480 res.Body.Close()
7481 }
7482 readyc <- struct{}{}
7483 readyc <- struct{}{}
7484 }
7485
7486 func TestInvalidChunkedBodies(t *testing.T) {
7487 for _, test := range []struct {
7488 name string
7489 b string
7490 }{{
7491 name: "bare LF in chunk size",
7492 b: "1\na\r\n0\r\n\r\n",
7493 }, {
7494 name: "bare LF at body end",
7495 b: "1\r\na\r\n0\r\n\n",
7496 }} {
7497 t.Run(test.name, func(t *testing.T) {
7498 reqc := make(chan error)
7499 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7500 got, err := io.ReadAll(r.Body)
7501 if err == nil {
7502 t.Logf("read body: %q", got)
7503 }
7504 reqc <- err
7505 })).ts
7506
7507 serverURL, err := url.Parse(ts.URL)
7508 if err != nil {
7509 t.Fatal(err)
7510 }
7511
7512 conn, err := net.Dial("tcp", serverURL.Host)
7513 if err != nil {
7514 t.Fatal(err)
7515 }
7516
7517 if _, err := conn.Write([]byte(
7518 "POST / HTTP/1.1\r\n" +
7519 "Host: localhost\r\n" +
7520 "Transfer-Encoding: chunked\r\n" +
7521 "Connection: close\r\n" +
7522 "\r\n" +
7523 test.b)); err != nil {
7524 t.Fatal(err)
7525 }
7526 conn.(*net.TCPConn).CloseWrite()
7527
7528 if err := <-reqc; err == nil {
7529 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7530 }
7531 })
7532 }
7533 }
7534
7535
7536 func TestServerTLSNextProtos(t *testing.T) {
7537 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7538 }
7539 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7540 CondSkipHTTP2(t)
7541
7542 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7543 if err != nil {
7544 t.Fatal(err)
7545 }
7546 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7547 if err != nil {
7548 t.Fatal(err)
7549 }
7550 certpool := x509.NewCertPool()
7551 certpool.AddCert(leafCert)
7552
7553 protos := new(Protocols)
7554 switch mode {
7555 case https1Mode:
7556 protos.SetHTTP1(true)
7557 case http2Mode:
7558 protos.SetHTTP2(true)
7559 }
7560
7561 wantNextProtos := []string{"http/1.1", "h2", "other"}
7562 nextProtos := slices.Clone(wantNextProtos)
7563
7564
7565 srv := &Server{
7566 TLSConfig: &tls.Config{
7567 Certificates: []tls.Certificate{cert},
7568 NextProtos: nextProtos,
7569 },
7570 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7571 Protocols: protos,
7572 }
7573 tr := &Transport{
7574 TLSClientConfig: &tls.Config{
7575 RootCAs: certpool,
7576 NextProtos: nextProtos,
7577 },
7578 Protocols: protos,
7579 }
7580
7581 listener := newLocalListener(t)
7582 srvc := make(chan error, 1)
7583 go func() {
7584 srvc <- srv.ServeTLS(listener, "", "")
7585 }()
7586 t.Cleanup(func() {
7587 srv.Close()
7588 <-srvc
7589 })
7590
7591 client := &Client{Transport: tr}
7592 resp, err := client.Get("https://" + listener.Addr().String())
7593 if err != nil {
7594 t.Fatal(err)
7595 }
7596 resp.Body.Close()
7597
7598 if !slices.Equal(nextProtos, wantNextProtos) {
7599 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7600 }
7601 }
7602
View as plain text