Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 "crypto/tls"
16 "crypto/x509"
17 "encoding/json"
18 "errors"
19 "fmt"
20 "internal/synctest"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "time"
47 )
48
// dummyAddr is a net.Addr backed by a plain string; the same string is
// reported as both the network name and the address.
type dummyAddr string

// Network returns the underlying string.
func (a dummyAddr) Network() string { return string(a) }

// String returns the underlying string.
func (a dummyAddr) String() string { return string(a) }

// oneConnListener is a net.Listener that yields one pre-set connection.
type oneConnListener struct {
	conn net.Conn
}

// Accept hands out the stored connection and clears it. When no
// connection is stored it returns io.EOF.
func (l *oneConnListener) Accept() (net.Conn, error) {
	pending := l.conn
	if pending == nil {
		return nil, io.EOF
	}
	l.conn = nil
	return pending, nil
}

// Close does nothing and always reports success.
func (l *oneConnListener) Close() error { return nil }

// Addr returns a fixed placeholder address.
func (l *oneConnListener) Addr() net.Addr { return dummyAddr("test-address") }
80
81 type noopConn struct{}
82
83 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
84 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
85 func (noopConn) SetDeadline(t time.Time) error { return nil }
86 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
87 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
88
89 type rwTestConn struct {
90 io.Reader
91 io.Writer
92 noopConn
93
94 closeFunc func() error
95 closec chan bool
96 }
97
98 func (c *rwTestConn) Close() error {
99 if c.closeFunc != nil {
100 return c.closeFunc()
101 }
102 select {
103 case c.closec <- true:
104 default:
105 }
106 return nil
107 }
108
109 type testConn struct {
110 readMu sync.Mutex
111 readBuf bytes.Buffer
112 writeBuf bytes.Buffer
113 closec chan bool
114 noopConn
115 }
116
117 func newTestConn() *testConn {
118 return &testConn{closec: make(chan bool, 1)}
119 }
120
121 func (c *testConn) Read(b []byte) (int, error) {
122 c.readMu.Lock()
123 defer c.readMu.Unlock()
124 return c.readBuf.Read(b)
125 }
126
127 func (c *testConn) Write(b []byte) (int, error) {
128 return c.writeBuf.Write(b)
129 }
130
131 func (c *testConn) Close() error {
132 select {
133 case c.closec <- true:
134 default:
135 }
136 return nil
137 }
138
139
140
// reqBytes converts a request written with "\n" line endings into wire
// form: surrounding whitespace is trimmed, every "\n" becomes "\r\n",
// and a terminating "\r\n\r\n" is appended.
func reqBytes(req string) []byte {
	trimmed := strings.TrimSpace(req)
	wire := strings.ReplaceAll(trimmed, "\n", "\r\n")
	return []byte(wire + "\r\n\r\n")
}
144
145 type handlerTest struct {
146 logbuf bytes.Buffer
147 handler Handler
148 }
149
150 func newHandlerTest(h Handler) handlerTest {
151 return handlerTest{handler: h}
152 }
153
154 func (ht *handlerTest) rawResponse(req string) string {
155 reqb := reqBytes(req)
156 var output strings.Builder
157 conn := &rwTestConn{
158 Reader: bytes.NewReader(reqb),
159 Writer: &output,
160 closec: make(chan bool, 1),
161 }
162 ln := &oneConnListener{conn: conn}
163 srv := &Server{
164 ErrorLog: log.New(&ht.logbuf, "", 0),
165 Handler: ht.handler,
166 }
167 go srv.Serve(ln)
168 <-conn.closec
169 return output.String()
170 }
171
172 func TestConsumingBodyOnNextConn(t *testing.T) {
173 t.Parallel()
174 defer afterTest(t)
175 conn := new(testConn)
176 for i := 0; i < 2; i++ {
177 conn.readBuf.Write([]byte(
178 "POST / HTTP/1.1\r\n" +
179 "Host: test\r\n" +
180 "Content-Length: 11\r\n" +
181 "\r\n" +
182 "foo=1&bar=1"))
183 }
184
185 reqNum := 0
186 ch := make(chan *Request)
187 servech := make(chan error)
188 listener := &oneConnListener{conn}
189 handler := func(res ResponseWriter, req *Request) {
190 reqNum++
191 ch <- req
192 }
193
194 go func() {
195 servech <- Serve(listener, HandlerFunc(handler))
196 }()
197
198 var req *Request
199 req = <-ch
200 if req == nil {
201 t.Fatal("Got nil first request.")
202 }
203 if req.Method != "POST" {
204 t.Errorf("For request #1's method, got %q; expected %q",
205 req.Method, "POST")
206 }
207
208 req = <-ch
209 if req == nil {
210 t.Fatal("Got nil first request.")
211 }
212 if req.Method != "POST" {
213 t.Errorf("For request #2's method, got %q; expected %q",
214 req.Method, "POST")
215 }
216
217 if serveerr := <-servech; serveerr != io.EOF {
218 t.Errorf("Serve returned %q; expected EOF", serveerr)
219 }
220 }
221
// stringHandler is a Handler that reports its own string value in the
// "Result" response header.
type stringHandler string

// ServeHTTP sets the "Result" header to s and writes nothing else.
func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Result", string(s))
}
227
// handlers lists the patterns registered by testHostHandlers. Each
// pattern is served by a stringHandler that reports msg.
var handlers = []struct {
	pattern string
	msg     string
}{
	{pattern: "/", msg: "Default"},
	{pattern: "/someDir/", msg: "someDir"},
	{pattern: "/#/", msg: "hash"},
	{pattern: "someHost.com/someDir/", msg: "someHost.com/someDir"},
}

// vtests lists the URLs requested by testHostHandlers. expected is
// compared with the "Result" header on a 200 response and with the
// "Location" header on a 301 response.
var vtests = []struct {
	url      string
	expected string
}{
	{url: "http://localhost/someDir/apage", expected: "someDir"},
	{url: "http://localhost/%23/apage", expected: "hash"},
	{url: "http://localhost/otherDir/apage", expected: "Default"},
	{url: "http://someHost.com/someDir/apage", expected: "someHost.com/someDir"},
	{url: "http://otherHost.com/someDir/apage", expected: "someDir"},
	{url: "http://otherHost.com/aDir/apage", expected: "Default"},

	// Requests for a directory without its trailing slash; expected
	// holds a path ending in a slash.
	{url: "http://localhost/someDir", expected: "/someDir/"},
	{url: "http://localhost/%23", expected: "/%23/"},
	{url: "http://someHost.com/someDir", expected: "/someDir/"},
}
253
254 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
255 func testHostHandlers(t *testing.T, mode testMode) {
256 mux := NewServeMux()
257 for _, h := range handlers {
258 mux.Handle(h.pattern, stringHandler(h.msg))
259 }
260 ts := newClientServerTest(t, mode, mux).ts
261
262 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
263 if err != nil {
264 t.Fatal(err)
265 }
266 defer conn.Close()
267 cc := httputil.NewClientConn(conn, nil)
268 for _, vt := range vtests {
269 var r *Response
270 var req Request
271 if req.URL, err = url.Parse(vt.url); err != nil {
272 t.Errorf("cannot parse url: %v", err)
273 continue
274 }
275 if err := cc.Write(&req); err != nil {
276 t.Errorf("writing request: %v", err)
277 continue
278 }
279 r, err := cc.Read(&req)
280 if err != nil {
281 t.Errorf("reading response: %v", err)
282 continue
283 }
284 switch r.StatusCode {
285 case StatusOK:
286 s := r.Header.Get("Result")
287 if s != vt.expected {
288 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
289 }
290 case StatusMovedPermanently:
291 s := r.Header.Get("Location")
292 if s != vt.expected {
293 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
294 }
295 default:
296 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
297 }
298 }
299 }
300
301 var serveMuxRegister = []struct {
302 pattern string
303 h Handler
304 }{
305 {"/dir/", serve(200)},
306 {"/search", serve(201)},
307 {"codesearch.google.com/search", serve(202)},
308 {"codesearch.google.com/", serve(203)},
309 {"example.com/", HandlerFunc(checkQueryStringHandler)},
310 }
311
312
// serve returns a handler that replies with the given status code and
// no body.
func serve(code int) HandlerFunc {
	respond := func(w ResponseWriter, r *Request) {
		w.WriteHeader(code)
	}
	return HandlerFunc(respond)
}
318
319
320
321
// checkQueryStringHandler replies 200 when the request's raw query,
// prefixed with "http://", equals the request URL rebuilt with scheme
// "http", the request's Host, and no query string. Otherwise it
// replies 500.
func checkQueryStringHandler(w ResponseWriter, r *Request) {
	// Work on a copy so the request's own URL is left untouched.
	rebuilt := *r.URL
	rebuilt.Scheme = "http"
	rebuilt.Host = r.Host
	rebuilt.RawQuery = ""

	status := 500
	if "http://"+r.URL.RawQuery == rebuilt.String() {
		status = 200
	}
	w.WriteHeader(status)
}
333
// serveMuxTests lists requests for TestServeMuxHandler together with
// the status code and matched pattern each one is expected to produce
// against the handlers in serveMuxRegister.
var serveMuxTests = []struct {
	method  string
	host    string
	path    string
	code    int
	pattern string
}{
	{method: "GET", host: "google.com", path: "/", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir", code: 301, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/dir/", code: 200, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/dir/file", code: 200, pattern: "/dir/"},
	{method: "GET", host: "google.com", path: "/search", code: 201, pattern: "/search"},
	{method: "GET", host: "google.com", path: "/search/", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/search/foo", code: 404, pattern: ""},
	{method: "GET", host: "codesearch.google.com", path: "/search", code: 202, pattern: "codesearch.google.com/search"},
	{method: "GET", host: "codesearch.google.com", path: "/search/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com", path: "/search/foo", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com", path: "/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "codesearch.google.com:443", path: "/", code: 203, pattern: "codesearch.google.com/"},
	{method: "GET", host: "images.google.com", path: "/search", code: 201, pattern: "/search"},
	{method: "GET", host: "images.google.com", path: "/search/", code: 404, pattern: ""},
	{method: "GET", host: "images.google.com", path: "/search/foo", code: 404, pattern: ""},
	{method: "GET", host: "google.com", path: "/../search", code: 301, pattern: "/search"},
	{method: "GET", host: "google.com", path: "/dir/..", code: 301, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir/..", code: 301, pattern: ""},
	{method: "GET", host: "google.com", path: "/dir/./file", code: 301, pattern: "/dir/"},

	// CONNECT rows: the missing-trailing-slash redirect is still
	// expected (/dir -> 301), but paths containing ".." or "." are
	// expected to be matched as written rather than redirected.
	{method: "CONNECT", host: "google.com", path: "/dir", code: 301, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/../search", code: 404, pattern: ""},
	{method: "CONNECT", host: "google.com", path: "/dir/..", code: 200, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/dir/..", code: 200, pattern: "/dir/"},
	{method: "CONNECT", host: "google.com", path: "/dir/./file", code: 200, pattern: "/dir/"},
}
369
370 func TestServeMuxHandler(t *testing.T) {
371 setParallel(t)
372 mux := NewServeMux()
373 for _, e := range serveMuxRegister {
374 mux.Handle(e.pattern, e.h)
375 }
376
377 for _, tt := range serveMuxTests {
378 r := &Request{
379 Method: tt.method,
380 Host: tt.host,
381 URL: &url.URL{
382 Path: tt.path,
383 },
384 }
385 h, pattern := mux.Handler(r)
386 rr := httptest.NewRecorder()
387 h.ServeHTTP(rr, r)
388 if pattern != tt.pattern || rr.Code != tt.code {
389 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
390 }
391 }
392 }
393
394
395 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
396 setParallel(t)
397 defer func() {
398 if err := recover(); err == nil {
399 t.Error("expected call to mux.HandleFunc to panic")
400 }
401 }()
402 mux := NewServeMux()
403 mux.HandleFunc("/", nil)
404 }
405
// serveMuxTests2 lists requests for TestServeMuxHandlerRedirects. code
// is the status expected once any redirect has been followed; redirOk
// says whether a 301 on the way there is acceptable.
var serveMuxTests2 = []struct {
	method  string
	host    string
	url     string
	code    int
	redirOk bool
}{
	{method: "GET", host: "google.com", url: "/", code: 404, redirOk: false},
	{method: "GET", host: "example.com", url: "/test/?example.com/test/", code: 200, redirOk: false},
	{method: "GET", host: "example.com", url: "test/?example.com/test/", code: 200, redirOk: true},
}
417
418
419
420 func TestServeMuxHandlerRedirects(t *testing.T) {
421 setParallel(t)
422 mux := NewServeMux()
423 for _, e := range serveMuxRegister {
424 mux.Handle(e.pattern, e.h)
425 }
426
427 for _, tt := range serveMuxTests2 {
428 tries := 1
429 turl := tt.url
430 for {
431 u, e := url.Parse(turl)
432 if e != nil {
433 t.Fatal(e)
434 }
435 r := &Request{
436 Method: tt.method,
437 Host: tt.host,
438 URL: u,
439 }
440 h, _ := mux.Handler(r)
441 rr := httptest.NewRecorder()
442 h.ServeHTTP(rr, r)
443 if rr.Code != 301 {
444 if rr.Code != tt.code {
445 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
446 }
447 break
448 }
449 if !tt.redirOk {
450 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
451 break
452 }
453 turl = rr.HeaderMap.Get("Location")
454 tries--
455 }
456 if tries < 0 {
457 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
458 }
459 }
460 }
461
462
463 func TestMuxRedirectLeadingSlashes(t *testing.T) {
464 setParallel(t)
465 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
466 for _, path := range paths {
467 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
468 if err != nil {
469 t.Errorf("%s", err)
470 }
471 mux := NewServeMux()
472 resp := httptest.NewRecorder()
473
474 mux.ServeHTTP(resp, req)
475
476 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
477 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
478 return
479 }
480
481 if code, expected := resp.Code, StatusMovedPermanently; code != expected {
482 t.Errorf("Expected response code of StatusMovedPermanently; got %d", code)
483 return
484 }
485 }
486 }
487
488
489
490
491
492 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
493 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
494 }
495 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
496 writeBackQuery := func(w ResponseWriter, r *Request) {
497 fmt.Fprintf(w, "%s", r.URL.RawQuery)
498 }
499
500 mux := NewServeMux()
501 mux.HandleFunc("/testOne", writeBackQuery)
502 mux.HandleFunc("/testTwo/", writeBackQuery)
503 mux.HandleFunc("/testThree", writeBackQuery)
504 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
505 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
506 })
507
508 ts := newClientServerTest(t, mode, mux).ts
509
510 tests := [...]struct {
511 path string
512 method string
513 want string
514 statusOk bool
515 }{
516 0: {"/testOne?this=that", "GET", "this=that", true},
517 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
518 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
519 3: {"/testTwo?", "GET", "", true},
520 4: {"/testThree?foo", "GET", "foo", true},
521 5: {"/testThree/?foo", "GET", "foo:bar", true},
522 6: {"/testThree?foo", "CONNECT", "foo", true},
523 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
524
525
526 8: {"/testOne/foo/..?foo", "GET", "foo", true},
527 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
528 }
529
530 for i, tt := range tests {
531 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
532 res, err := ts.Client().Do(req)
533 if err != nil {
534 continue
535 }
536 slurp, _ := io.ReadAll(res.Body)
537 res.Body.Close()
538 if !tt.statusOk {
539 if got, want := res.StatusCode, 404; got != want {
540 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
541 }
542 }
543 if got, want := string(slurp), tt.want; got != want {
544 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
545 }
546 }
547 }
548
549 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
550 setParallel(t)
551
552 mux := NewServeMux()
553 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
554 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
555 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
556 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
557 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
558 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
559
560 tests := []struct {
561 method string
562 url string
563 code int
564 loc string
565 want string
566 }{
567 {"GET", "http://example.com/", 404, "", ""},
568 {"GET", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
569 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
570 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
571 {"GET", "http://example.com/pkg/baz", 301, "/pkg/baz/", ""},
572 {"GET", "http://example.com:3000/pkg/foo", 301, "/pkg/foo/", ""},
573 {"CONNECT", "http://example.com/", 404, "", ""},
574 {"CONNECT", "http://example.com:3000/", 404, "", ""},
575 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
576 {"CONNECT", "http://example.com/pkg/foo", 301, "/pkg/foo/", ""},
577 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
578 {"CONNECT", "http://example.com:3000/pkg/baz", 301, "/pkg/baz/", ""},
579 {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""},
580 }
581
582 for i, tt := range tests {
583 req, _ := NewRequest(tt.method, tt.url, nil)
584 w := httptest.NewRecorder()
585 mux.ServeHTTP(w, req)
586
587 if got, want := w.Code, tt.code; got != want {
588 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
589 }
590
591 if tt.code == 301 {
592 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
593 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
594 }
595 } else {
596 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
597 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
598 }
599 }
600 }
601 }
602
603
604
605
// TestMuxNoSlashRedirectWithTrailingSlash checks that a mux whose only
// pattern is "/{x}/" answers a GET for "/" with 404.
func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})

	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)

	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
618
619
620
621
// TestMuxNoSlash405WithTrailingSlash checks that a mux whose only
// pattern is "GET /{x}/" answers a GET for "/" with 404.
func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
		fmt.Fprintln(w, "ok")
	})

	rec := httptest.NewRecorder()
	req, _ := NewRequest("GET", "/", nil)
	mux.ServeHTTP(rec, req)

	const want = 404
	if got := rec.Code; got != want {
		t.Errorf("got %d, want %d", got, want)
	}
}
634
635 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
636 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
637 mux := NewServeMux()
638 newClientServerTest(t, mode, mux)
639 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
640 }
641
642 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
643 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
644 func benchmarkServeMux(b *testing.B, runHandler bool) {
645 type test struct {
646 path string
647 code int
648 req *Request
649 }
650
651
652 var tests []test
653 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
654 for _, e := range endpoints {
655 for i := 200; i < 230; i++ {
656 p := fmt.Sprintf("/%s/%d/", e, i)
657 tests = append(tests, test{
658 path: p,
659 code: i,
660 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
661 })
662 }
663 }
664 mux := NewServeMux()
665 for _, tt := range tests {
666 mux.Handle(tt.path, serve(tt.code))
667 }
668
669 rw := httptest.NewRecorder()
670 b.ReportAllocs()
671 b.ResetTimer()
672 for i := 0; i < b.N; i++ {
673 for _, tt := range tests {
674 *rw = httptest.ResponseRecorder{}
675 h, pattern := mux.Handler(tt.req)
676 if runHandler {
677 h.ServeHTTP(rw, tt.req)
678 if pattern != tt.path || rw.Code != tt.code {
679 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
680 }
681 }
682 }
683 }
684 }
685
686 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
687 func testServerTimeouts(t *testing.T, mode testMode) {
688 runTimeSensitiveTest(t, []time.Duration{
689 10 * time.Millisecond,
690 50 * time.Millisecond,
691 100 * time.Millisecond,
692 500 * time.Millisecond,
693 1 * time.Second,
694 }, func(t *testing.T, timeout time.Duration) error {
695 return testServerTimeoutsWithTimeout(t, timeout, mode)
696 })
697 }
698
699 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
700 var reqNum atomic.Int32
701 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
702 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
703 }), func(ts *httptest.Server) {
704 ts.Config.ReadTimeout = timeout
705 ts.Config.WriteTimeout = timeout
706 })
707 defer cst.close()
708 ts := cst.ts
709
710
711 c := ts.Client()
712 r, err := c.Get(ts.URL)
713 if err != nil {
714 return fmt.Errorf("http Get #1: %v", err)
715 }
716 got, err := io.ReadAll(r.Body)
717 expected := "req=1"
718 if string(got) != expected || err != nil {
719 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
720 string(got), err, expected)
721 }
722
723
724 t1 := time.Now()
725 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
726 if err != nil {
727 return fmt.Errorf("Dial: %v", err)
728 }
729 buf := make([]byte, 1)
730 n, err := conn.Read(buf)
731 conn.Close()
732 latency := time.Since(t1)
733 if n != 0 || err != io.EOF {
734 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
735 }
736 minLatency := timeout / 5 * 4
737 if latency < minLatency {
738 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
739 }
740
741
742
743
744 r, err = c.Get(ts.URL)
745 if err != nil {
746 return fmt.Errorf("http Get #2: %v", err)
747 }
748 got, err = io.ReadAll(r.Body)
749 r.Body.Close()
750 expected = "req=2"
751 if string(got) != expected || err != nil {
752 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
753 }
754
755 if !testing.Short() {
756 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
757 if err != nil {
758 return fmt.Errorf("long Dial: %v", err)
759 }
760 defer conn.Close()
761 go io.Copy(io.Discard, conn)
762 for i := 0; i < 5; i++ {
763 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
764 if err != nil {
765 return fmt.Errorf("on write %d: %v", i, err)
766 }
767 time.Sleep(timeout / 2)
768 }
769 }
770 return nil
771 }
772
773 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout) }
774 func testServerReadTimeout(t *testing.T, mode testMode) {
775 respBody := "response body"
776 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
777 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
778 _, err := io.Copy(io.Discard, req.Body)
779 if !errors.Is(err, os.ErrDeadlineExceeded) {
780 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
781 }
782 res.Write([]byte(respBody))
783 }), func(ts *httptest.Server) {
784 ts.Config.ReadHeaderTimeout = -1
785 ts.Config.ReadTimeout = timeout
786 t.Logf("Server.Config.ReadTimeout = %v", timeout)
787 })
788
789 var retries atomic.Int32
790 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
791 if retries.Add(1) != 1 {
792 return nil, errors.New("too many retries")
793 }
794 return nil, nil
795 }
796
797 pr, pw := io.Pipe()
798 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
799 if err != nil {
800 t.Logf("Get error, retrying: %v", err)
801 cst.close()
802 continue
803 }
804 defer res.Body.Close()
805 got, err := io.ReadAll(res.Body)
806 if string(got) != respBody || err != nil {
807 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
808 }
809 pw.Close()
810 break
811 }
812 }
813
814 func TestServerNoReadTimeout(t *testing.T) { run(t, testServerNoReadTimeout) }
815 func testServerNoReadTimeout(t *testing.T, mode testMode) {
816 reqBody := "Hello, Gophers!"
817 resBody := "Hi, Gophers!"
818 for _, timeout := range []time.Duration{0, -1} {
819 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
820 ctl := NewResponseController(res)
821 ctl.EnableFullDuplex()
822 res.WriteHeader(StatusOK)
823
824
825 if err := ctl.Flush(); err != nil {
826 t.Errorf("server flush response: %v", err)
827 return
828 }
829 got, err := io.ReadAll(req.Body)
830 if string(got) != reqBody || err != nil {
831 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
832 }
833 res.Write([]byte(resBody))
834 }), func(ts *httptest.Server) {
835 ts.Config.ReadTimeout = timeout
836 t.Logf("Server.Config.ReadTimeout = %d", timeout)
837 })
838
839 pr, pw := io.Pipe()
840 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
841 if err != nil {
842 t.Fatal(err)
843 }
844 defer res.Body.Close()
845
846
847 time.Sleep(10 * time.Millisecond)
848 pw.Write([]byte(reqBody))
849 pw.Close()
850
851 got, err := io.ReadAll(res.Body)
852 if string(got) != resBody || err != nil {
853 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
854 }
855 }
856 }
857
858 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout) }
859 func testServerWriteTimeout(t *testing.T, mode testMode) {
860 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
861 errc := make(chan error, 2)
862 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
863 errc <- nil
864 _, err := io.Copy(res, neverEnding('a'))
865 errc <- err
866 }), func(ts *httptest.Server) {
867 ts.Config.WriteTimeout = timeout
868 t.Logf("Server.Config.WriteTimeout = %v", timeout)
869 })
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888 var retries atomic.Int32
889 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
890 if retries.Add(1) != 1 {
891 return nil, errors.New("too many retries")
892 }
893 return nil, nil
894 }
895
896 res, err := cst.c.Get(cst.ts.URL)
897 if err != nil {
898
899 t.Logf("Get error, retrying: %v", err)
900 cst.close()
901 continue
902 }
903 defer res.Body.Close()
904 _, err = io.Copy(io.Discard, res.Body)
905 if err == nil {
906 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
907 }
908 select {
909 case <-errc:
910 err = <-errc
911 if !errors.Is(err, os.ErrDeadlineExceeded) {
912 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
913 }
914 return
915 default:
916
917 t.Logf("handler didn't run, retrying")
918 cst.close()
919 }
920 }
921 }
922
923 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
924 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
925 for _, timeout := range []time.Duration{0, -1} {
926 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
927 _, err := io.Copy(res, neverEnding('a'))
928 t.Logf("server write response: %v", err)
929 }), func(ts *httptest.Server) {
930 ts.Config.WriteTimeout = timeout
931 t.Logf("Server.Config.WriteTimeout = %d", timeout)
932 })
933
934 res, err := cst.c.Get(cst.ts.URL)
935 if err != nil {
936 t.Fatal(err)
937 }
938 defer res.Body.Close()
939 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
940 if n != 1<<20 || err != nil {
941 t.Errorf("client read response body: %d, %v", n, err)
942 }
943
944
945 res.Body.Close()
946 cst.ts.Config.Shutdown(context.Background())
947 }
948 }
949
950
951 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
952 run(t, testWriteDeadlineExtendedOnNewRequest)
953 }
954 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
955 if testing.Short() {
956 t.Skip("skipping in short mode")
957 }
958 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
959 func(ts *httptest.Server) {
960 ts.Config.WriteTimeout = 250 * time.Millisecond
961 },
962 ).ts
963
964 c := ts.Client()
965
966 for i := 1; i <= 3; i++ {
967 req, err := NewRequest("GET", ts.URL, nil)
968 if err != nil {
969 t.Fatal(err)
970 }
971
972 r, err := c.Do(req)
973 if err != nil {
974 t.Fatalf("http2 Get #%d: %v", i, err)
975 }
976 r.Body.Close()
977 time.Sleep(ts.Config.WriteTimeout / 2)
978 }
979 }
980
981
982
// tryTimeouts calls testFunc with 250ms, 500ms and 1s in turn and
// returns as soon as one call returns nil. Each failure is logged; if
// all three fail, the test is failed with t.Fatal.
func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
	schedule := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
	last := len(schedule) - 1
	for idx := range schedule {
		timeout := schedule[idx]
		err := testFunc(timeout)
		if err == nil {
			return
		}
		t.Logf("failed at %v: %v", timeout, err)
		if idx != last {
			t.Logf("retrying at %v ...", schedule[idx+1])
		}
	}
	t.Fatal("all attempts failed")
}
997
998
999 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1000 if testing.Short() {
1001 t.Skip("skipping in short mode")
1002 }
1003 setParallel(t)
1004 run(t, func(t *testing.T, mode testMode) {
1005 tryTimeouts(t, func(timeout time.Duration) error {
1006 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1007 })
1008 })
1009 }
1010
1011 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1012 firstRequest := make(chan bool, 1)
1013 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1014 select {
1015 case firstRequest <- true:
1016
1017 default:
1018
1019 time.Sleep(timeout)
1020 }
1021 }), func(ts *httptest.Server) {
1022 ts.Config.WriteTimeout = timeout / 2
1023 })
1024 defer cst.close()
1025 ts := cst.ts
1026
1027 c := ts.Client()
1028
1029 req, err := NewRequest("GET", ts.URL, nil)
1030 if err != nil {
1031 return fmt.Errorf("NewRequest: %v", err)
1032 }
1033 r, err := c.Do(req)
1034 if err != nil {
1035 return fmt.Errorf("Get #1: %v", err)
1036 }
1037 r.Body.Close()
1038
1039 req, err = NewRequest("GET", ts.URL, nil)
1040 if err != nil {
1041 return fmt.Errorf("NewRequest: %v", err)
1042 }
1043 r, err = c.Do(req)
1044 if err == nil {
1045 r.Body.Close()
1046 return fmt.Errorf("Get #2 expected error, got nil")
1047 }
1048 if mode == http2Mode {
1049 expected := "stream ID 3; INTERNAL_ERROR"
1050 if !strings.Contains(err.Error(), expected) {
1051 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1052 }
1053 }
1054 return nil
1055 }
1056
1057
1058 func TestNoWriteDeadline(t *testing.T) {
1059 if testing.Short() {
1060 t.Skip("skipping in short mode")
1061 }
1062 setParallel(t)
1063 defer afterTest(t)
1064 run(t, func(t *testing.T, mode testMode) {
1065 tryTimeouts(t, func(timeout time.Duration) error {
1066 return testNoWriteDeadline(t, mode, timeout)
1067 })
1068 })
1069 }
1070
1071 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1072 firstRequest := make(chan bool, 1)
1073 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1074 select {
1075 case firstRequest <- true:
1076
1077 default:
1078
1079 time.Sleep(timeout)
1080 }
1081 }))
1082 defer cst.close()
1083 ts := cst.ts
1084
1085 c := ts.Client()
1086
1087 for i := 0; i < 2; i++ {
1088 req, err := NewRequest("GET", ts.URL, nil)
1089 if err != nil {
1090 return fmt.Errorf("NewRequest: %v", err)
1091 }
1092 r, err := c.Do(req)
1093 if err != nil {
1094 return fmt.Errorf("Get #%d: %v", i, err)
1095 }
1096 r.Body.Close()
1097 }
1098 return nil
1099 }
1100
1101
1102
1103
1104 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1105 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1106 var (
1107 mu sync.RWMutex
1108 conn net.Conn
1109 )
1110 var afterTimeoutErrc = make(chan error, 1)
1111 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1112 buf := make([]byte, 512<<10)
1113 _, err := w.Write(buf)
1114 if err != nil {
1115 t.Errorf("handler Write error: %v", err)
1116 return
1117 }
1118 mu.RLock()
1119 defer mu.RUnlock()
1120 if conn == nil {
1121 t.Error("no established connection found")
1122 return
1123 }
1124 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1125 _, err = w.Write(buf)
1126 afterTimeoutErrc <- err
1127 }), func(ts *httptest.Server) {
1128 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1129 }).ts
1130
1131 c := ts.Client()
1132
1133 err := func() error {
1134 res, err := c.Get(ts.URL)
1135 if err != nil {
1136 return err
1137 }
1138 _, err = io.Copy(io.Discard, res.Body)
1139 res.Body.Close()
1140 return err
1141 }()
1142 if err == nil {
1143 t.Errorf("expected an error copying body from Get request")
1144 }
1145
1146 if err := <-afterTimeoutErrc; err == nil {
1147 t.Error("expected write error after timeout")
1148 }
1149 }
1150
1151
// trackLastConnListener wraps a net.Listener and stores each
// successfully accepted connection in *last.
type trackLastConnListener struct {
	net.Listener

	mu   *sync.RWMutex // held for writing while *last is updated
	last *net.Conn     // destination for the most recently accepted connection
}

// Accept accepts from the wrapped listener. On success it stores the
// connection in *last while holding mu; on failure *last is left as it
// was. The wrapped listener's results are returned unchanged.
func (l trackLastConnListener) Accept() (net.Conn, error) {
	accepted, err := l.Listener.Accept()
	if err != nil {
		return accepted, err
	}
	l.mu.Lock()
	*l.last = accepted
	l.mu.Unlock()
	return accepted, nil
}
1168
1169
1170 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1171 func testIdentityResponse(t *testing.T, mode testMode) {
1172 if mode == http2Mode {
1173 t.Skip("https://go.dev/issue/56019")
1174 }
1175
1176 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1177 rw.Header().Set("Content-Length", "3")
1178 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1179 switch {
1180 case req.FormValue("overwrite") == "1":
1181 _, err := rw.Write([]byte("foo TOO LONG"))
1182 if err != ErrContentLength {
1183 t.Errorf("expected ErrContentLength; got %v", err)
1184 }
1185 case req.FormValue("underwrite") == "1":
1186 rw.Header().Set("Content-Length", "500")
1187 rw.Write([]byte("too short"))
1188 default:
1189 rw.Write([]byte("foo"))
1190 }
1191 })
1192
1193 ts := newClientServerTest(t, mode, handler).ts
1194 c := ts.Client()
1195
1196
1197
1198
1199
1200 for _, te := range []string{"", "identity"} {
1201 url := ts.URL + "/?te=" + te
1202 res, err := c.Get(url)
1203 if err != nil {
1204 t.Fatalf("error with Get of %s: %v", url, err)
1205 }
1206 if cl, expected := res.ContentLength, int64(3); cl != expected {
1207 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1208 }
1209 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1210 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1211 }
1212 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1213 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1214 url, expected, tl, res.TransferEncoding)
1215 }
1216 res.Body.Close()
1217 }
1218
1219
1220 url := ts.URL + "/?overwrite=1"
1221 res, err := c.Get(url)
1222 if err != nil {
1223 t.Fatalf("error with Get of %s: %v", url, err)
1224 }
1225 res.Body.Close()
1226
1227 if mode != http1Mode {
1228 return
1229 }
1230
1231
1232
1233 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1234 if err != nil {
1235 t.Fatalf("error dialing: %v", err)
1236 }
1237 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1238 if err != nil {
1239 t.Fatalf("error writing: %v", err)
1240 }
1241
1242
1243 got, _ := io.ReadAll(conn)
1244 expectedSuffix := "\r\n\r\ntoo short"
1245 if !strings.HasSuffix(string(got), expectedSuffix) {
1246 t.Errorf("Expected output to end with %q; got response body %q",
1247 expectedSuffix, string(got))
1248 }
1249 }
1250
1251 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1252 setParallel(t)
1253 s := newClientServerTest(t, http1Mode, h).ts
1254
1255 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1256 if err != nil {
1257 t.Fatal("dial error:", err)
1258 }
1259 defer conn.Close()
1260
1261 _, err = fmt.Fprint(conn, req)
1262 if err != nil {
1263 t.Fatal("print error:", err)
1264 }
1265
1266 r := bufio.NewReader(conn)
1267 res, err := ReadResponse(r, &Request{Method: "GET"})
1268 if err != nil {
1269 t.Fatal("ReadResponse error:", err)
1270 }
1271
1272 _, err = io.ReadAll(r)
1273 if err != nil {
1274 t.Fatal("read error:", err)
1275 }
1276
1277 if !res.Close {
1278 t.Errorf("Response.Close = false; want true")
1279 }
1280 }
1281
1282 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1283 setParallel(t)
1284 ts := newClientServerTest(t, http1Mode, handler).ts
1285 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1286 if err != nil {
1287 t.Fatal(err)
1288 }
1289 defer conn.Close()
1290 br := bufio.NewReader(conn)
1291 for i := 0; i < 2; i++ {
1292 if _, err := io.WriteString(conn, req); err != nil {
1293 t.Fatal(err)
1294 }
1295 res, err := ReadResponse(br, nil)
1296 if err != nil {
1297 t.Fatalf("res %d: %v", i+1, err)
1298 }
1299 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1300 t.Fatalf("res %d body copy: %v", i+1, err)
1301 }
1302 res.Body.Close()
1303 }
1304 }
1305
1306
1307 func TestServeHTTP10Close(t *testing.T) {
1308 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1309 ServeFile(w, r, "testdata/file")
1310 }))
1311 }
1312
1313
1314 func TestClientCanClose(t *testing.T) {
1315 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1316
1317 }))
1318 }
1319
1320
1321
1322 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1323 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1324 w.Header().Set("Connection", "close")
1325 }))
1326 }
1327
1328 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1329 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1330 w.Header().Set("Connection", "close")
1331 }))
1332 }
1333
1334 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1335 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1336
1337
1338 }))
1339 }
1340
// send204 is a handler function that replies with status 204 and no body.
func send204(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNoContent)
}

// send304 is a handler function that replies with status 304 and no body.
func send304(w ResponseWriter, r *Request) {
	w.WriteHeader(StatusNotModified)
}
1343
1344
1345 func TestHTTP10KeepAlive204Response(t *testing.T) {
1346 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1347 }
1348
1349 func TestHTTP11KeepAlive204Response(t *testing.T) {
1350 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1351 }
1352
1353 func TestHTTP10KeepAlive304Response(t *testing.T) {
1354 testTCPConnectionStaysOpen(t,
1355 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1356 HandlerFunc(send304))
1357 }
1358
1359
1360 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1361 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1362 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1363 w.(Flusher).Flush()
1364 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1365 }))
1366 type data struct {
1367 Addr string
1368 }
1369 var addrs [2]data
1370 for i := range addrs {
1371 res, err := cst.c.Get(cst.ts.URL)
1372 if err != nil {
1373 t.Fatal(err)
1374 }
1375 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1376 t.Fatal(err)
1377 }
1378 if addrs[i].Addr == "" {
1379 t.Fatal("no address")
1380 }
1381 res.Body.Close()
1382 }
1383 if addrs[0] != addrs[1] {
1384 t.Fatalf("connection not reused")
1385 }
1386 }
1387
1388 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1389 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1390 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1391 fmt.Fprintf(w, "%s", r.RemoteAddr)
1392 }))
1393
1394 res, err := cst.c.Get(cst.ts.URL)
1395 if err != nil {
1396 t.Fatalf("Get error: %v", err)
1397 }
1398 body, err := io.ReadAll(res.Body)
1399 if err != nil {
1400 t.Fatalf("ReadAll error: %v", err)
1401 }
1402 ip := string(body)
1403 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1404 t.Fatalf("Expected local addr; got %q", ip)
1405 }
1406 }
1407
// blockingRemoteAddrListener wraps a net.Listener so that every accepted
// connection is a *blockingRemoteAddrConn, and announces each one on conns.
type blockingRemoteAddrListener struct {
	net.Listener
	conns chan<- net.Conn // receives each wrapped conn as it is accepted
}

// Accept accepts from the wrapped listener, wraps the connection in a
// blockingRemoteAddrConn with a 1-buffered address channel, sends the
// wrapper on l.conns (blocking until it is taken or buffered), and returns
// it. Errors from the wrapped listener are returned with a nil conn.
func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
	inner, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	wrapped := &blockingRemoteAddrConn{Conn: inner, addrs: make(chan net.Addr, 1)}
	l.conns <- wrapped
	return wrapped, nil
}

// blockingRemoteAddrConn is a net.Conn whose RemoteAddr blocks until a
// value is supplied on addrs.
type blockingRemoteAddrConn struct {
	net.Conn
	addrs chan net.Addr // each RemoteAddr call consumes one value
}

// RemoteAddr blocks until an address is available on c.addrs and returns it.
func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
	addr := <-c.addrs
	return addr
}
1434
1435
// TestServerAllowsBlockingRemoteAddr runs
// testServerAllowsBlockingRemoteAddr in HTTP/1 mode only.
func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
	run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
}

// testServerAllowsBlockingRemoteAddr verifies that a connection whose
// RemoteAddr call blocks does not prevent the server from accepting and
// serving another connection: two requests are started, the second one's
// address is released first and its response must arrive while the first
// connection's RemoteAddr is still blocked.
func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
	// Unbuffered: each Accept in blockingRemoteAddrListener hands its conn
	// directly to this test.
	conns := make(chan net.Conn)
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
	}), func(ts *httptest.Server) {
		ts.Listener = &blockingRemoteAddrListener{
			Listener: ts.Listener,
			conns:    conns,
		}
	}).ts

	c := ts.Client()
	// Force a separate connection for each request.
	c.Transport.(*Transport).DisableKeepAlives = true

	// fetch performs one GET and sends the body (or "" on failure) on
	// response; it always sends exactly one value.
	fetch := func(num int, response chan<- string) {
		resp, err := c.Get(ts.URL)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		defer resp.Body.Close()
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Request %d: %v", num, err)
			response <- ""
			return
		}
		response <- string(body)
	}

	// Start the first request.
	response1c := make(chan string, 1)
	go fetch(1, response1c)

	// Wait for the listener to accept it and grab its connection.
	conn1 := <-conns

	// Start a second request and grab its connection too.
	response2c := make(chan string, 1)
	go fetch(2, response2c)
	conn2 := <-conns

	// Release the remote address of connection 2 only.
	conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("12.12.12.12"), Port: 12}

	// ... and expect its response while connection 1 is still blocked.
	response2 := <-response2c
	if g, e := response2, "RA:12.12.12.12:12"; g != e {
		t.Fatalf("response 2 addr = %q; want %q", g, e)
	}

	// Now release the remote address of connection 1.
	conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
		IP: net.ParseIP("21.21.21.21"), Port: 21}

	// ... and expect its response.
	response1 := <-response1c
	if g, e := response1, "RA:21.21.21.21:21"; g != e {
		t.Fatalf("response 1 addr = %q; want %q", g, e)
	}
}
1503
1504
1505
1506 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1507 func testHeadResponses(t *testing.T, mode testMode) {
1508 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1509 _, err := w.Write([]byte("<html>"))
1510 if err != nil {
1511 t.Errorf("ResponseWriter.Write: %v", err)
1512 }
1513
1514
1515 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1516 if err != nil {
1517 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1518 }
1519 }))
1520 res, err := cst.c.Head(cst.ts.URL)
1521 if err != nil {
1522 t.Error(err)
1523 }
1524 if len(res.TransferEncoding) > 0 {
1525 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1526 }
1527 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1528 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1529 }
1530 if v := res.ContentLength; v != 10 {
1531 t.Errorf("Content-Length: %d; want 10", v)
1532 }
1533 body, err := io.ReadAll(res.Body)
1534 if err != nil {
1535 t.Error(err)
1536 }
1537 if len(body) > 0 {
1538 t.Errorf("got unexpected body %q", string(body))
1539 }
1540 }
1541
1542
1543
1544 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1545 func testHeadReaderFrom(t *testing.T, mode testMode) {
1546
1547 wantBody := strings.Repeat("a", 4096)
1548 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1549 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1550 }))
1551 res, err := cst.c.Head(cst.ts.URL)
1552 if err != nil {
1553 t.Fatal(err)
1554 }
1555 res.Body.Close()
1556 res, err = cst.c.Get(cst.ts.URL)
1557 if err != nil {
1558 t.Fatal(err)
1559 }
1560 gotBody, err := io.ReadAll(res.Body)
1561 res.Body.Close()
1562 if err != nil {
1563 t.Fatal(err)
1564 }
1565 if string(gotBody) != wantBody {
1566 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1567 }
1568 }
1569
1570 func TestTLSHandshakeTimeout(t *testing.T) {
1571 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1572 }
1573 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1574 errLog := new(strings.Builder)
1575 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1576 func(ts *httptest.Server) {
1577 ts.Config.ReadTimeout = 250 * time.Millisecond
1578 ts.Config.ErrorLog = log.New(errLog, "", 0)
1579 },
1580 )
1581 ts := cst.ts
1582
1583 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1584 if err != nil {
1585 t.Fatalf("Dial: %v", err)
1586 }
1587 var buf [1]byte
1588 n, err := conn.Read(buf[:])
1589 if err == nil || n != 0 {
1590 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1591 }
1592 conn.Close()
1593
1594 cst.close()
1595 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1596 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1597 }
1598 }
1599
1600 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1601 func testTLSServer(t *testing.T, mode testMode) {
1602 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1603 if r.TLS != nil {
1604 w.Header().Set("X-TLS-Set", "true")
1605 if r.TLS.HandshakeComplete {
1606 w.Header().Set("X-TLS-HandshakeComplete", "true")
1607 }
1608 }
1609 }), func(ts *httptest.Server) {
1610 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1611 }).ts
1612
1613
1614
1615
1616
1617
1618 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1619 if err != nil {
1620 t.Fatalf("Dial: %v", err)
1621 }
1622 defer idleConn.Close()
1623
1624 if !strings.HasPrefix(ts.URL, "https://") {
1625 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1626 return
1627 }
1628 client := ts.Client()
1629 res, err := client.Get(ts.URL)
1630 if err != nil {
1631 t.Error(err)
1632 return
1633 }
1634 if res == nil {
1635 t.Errorf("got nil Response")
1636 return
1637 }
1638 defer res.Body.Close()
1639 if res.Header.Get("X-TLS-Set") != "true" {
1640 t.Errorf("expected X-TLS-Set response header")
1641 return
1642 }
1643 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1644 t.Errorf("expected X-TLS-HandshakeComplete header")
1645 }
1646 }
1647
// fakeConnectionStateConn is a net.Conn that is not a *tls.Conn but still
// exposes a ConnectionState method.
type fakeConnectionStateConn struct {
	net.Conn
}

// ConnectionState returns a fixed state whose only populated field is
// ServerName ("example.com").
func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
	var state tls.ConnectionState
	state.ServerName = "example.com"
	return state
}
1657
1658 func TestTLSServerWithoutTLSConn(t *testing.T) {
1659
1660 pr, pw := net.Pipe()
1661 c := make(chan int)
1662 listener := &oneConnListener{&fakeConnectionStateConn{pr}}
1663 server := &Server{
1664 Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
1665 if request.TLS == nil {
1666 t.Fatal("request.TLS is nil, expected not nil")
1667 }
1668 if request.TLS.ServerName != "example.com" {
1669 t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
1670 }
1671 writer.Header().Set("X-TLS-ServerName", "example.com")
1672 }),
1673 }
1674
1675
1676 go func() {
1677 req, _ := NewRequest(MethodGet, "https://example.com", nil)
1678 req.Write(pw)
1679
1680 resp, _ := ReadResponse(bufio.NewReader(pw), req)
1681 if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
1682 t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
1683 }
1684 close(c)
1685 pw.Close()
1686 }()
1687
1688 server.Serve(listener)
1689
1690
1691 <-c
1692 pr.Close()
1693 }
1694
// TestServeTLS verifies that Server.ServeTLS, given a TLSConfig with a
// certificate and empty cert/key file names, serves TLS on the supplied
// listener and negotiates "h2" with a client offering h2 and http/1.1.
//
// It installs a server-serve test hook and clears it again on exit.
func TestServeTLS(t *testing.T) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	defer SetTestHookServerServe(nil)

	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
	if err != nil {
		t.Fatal(err)
	}
	tlsConf := &tls.Config{
		Certificates: []tls.Certificate{cert},
	}

	ln := newLocalListener(t)
	defer ln.Close()
	addr := ln.Addr().String()

	// The hook signals on serving; the buffer of 1 keeps the hook from
	// blocking.
	serving := make(chan bool, 1)
	SetTestHookServerServe(func(s *Server, ln net.Listener) {
		serving <- true
	})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
	s := &Server{
		Addr:      addr,
		TLSConfig: tlsConf,
		Handler:   handler,
	}
	errc := make(chan error, 1)
	go func() { errc <- s.ServeTLS(ln, "", "") }()
	// Proceed once the hook fires; fail if ServeTLS returns first.
	select {
	case err := <-errc:
		t.Fatalf("ServeTLS: %v", err)
	case <-serving:
	}

	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1746
1747
1748 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1749 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1750 }
1751 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1752 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1753 t.Error("unexpected HTTPS request")
1754 }), func(ts *httptest.Server) {
1755 var errBuf bytes.Buffer
1756 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1757 }).ts
1758 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1759 if err != nil {
1760 t.Fatal(err)
1761 }
1762 defer conn.Close()
1763 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1764 slurp, err := io.ReadAll(conn)
1765 if err != nil {
1766 t.Fatal(err)
1767 }
1768 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1769 if !strings.HasPrefix(string(slurp), wantPrefix) {
1770 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1771 }
1772 }
1773
1774
1775 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1776 testAutomaticHTTP2_Serve(t, nil, true)
1777 }
1778
1779 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1780 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1781 }
1782
1783 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1784 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1785 }
1786
1787 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1788 setParallel(t)
1789 defer afterTest(t)
1790 ln := newLocalListener(t)
1791 ln.Close()
1792 var s Server
1793 s.TLSConfig = tlsConf
1794 if err := s.Serve(ln); err == nil {
1795 t.Fatal("expected an error")
1796 }
1797 gotH2 := s.TLSNextProto["h2"] != nil
1798 if gotH2 != wantH2 {
1799 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1800 }
1801 }
1802
1803 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1804 setParallel(t)
1805 defer afterTest(t)
1806 ln := newLocalListener(t)
1807 ln.Close()
1808 var s Server
1809
1810
1811 s.TLSConfig = &tls.Config{
1812 NextProtos: []string{"h2"},
1813 }
1814 if err := s.Serve(ln); err == nil {
1815 t.Fatal("expected an error")
1816 }
1817 on := s.TLSNextProto["h2"] != nil
1818 if !on {
1819 t.Errorf("http2 wasn't automatically enabled")
1820 }
1821 }
1822
1823 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1824 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1825 if err != nil {
1826 t.Fatal(err)
1827 }
1828 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1829 Certificates: []tls.Certificate{cert},
1830 })
1831 }
1832
1833 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1834 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1835 if err != nil {
1836 t.Fatal(err)
1837 }
1838 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1839 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1840 return &cert, nil
1841 },
1842 })
1843 }
1844
1845 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1846 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1847 if err != nil {
1848 t.Fatal(err)
1849 }
1850 conf := &tls.Config{
1851
1852
1853 NextProtos: []string{"h2"},
1854 Certificates: []tls.Certificate{cert},
1855 }
1856 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1857 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1858 return conf, nil
1859 },
1860 })
1861 }
1862
// testAutomaticHTTP2_ListenAndServe starts a Server with the given TLS
// configuration via ListenAndServeTLS("", "") and verifies that a TLS
// client offering h2 and http/1.1 negotiates "h2".
//
// It installs a server-serve test hook and clears it again on exit.
func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
	CondSkipHTTP2(t)

	defer afterTest(t)
	defer SetTestHookServerServe(nil)
	var ok bool
	var s *Server
	const maxTries = 5
	var ln net.Listener
Try:
	for try := 0; try < maxTries; try++ {
		// Obtain an address by opening and immediately closing a local
		// listener, then let the server listen on that address itself.
		// If ListenAndServeTLS fails, the loop tries again with a new
		// address.
		ln = newLocalListener(t)
		addr := ln.Addr().String()
		ln.Close()
		t.Logf("Got %v", addr)
		lnc := make(chan net.Listener, 1)
		SetTestHookServerServe(func(s *Server, ln net.Listener) {
			lnc <- ln
		})
		s = &Server{
			Addr:      addr,
			TLSConfig: tlsConf,
		}
		errc := make(chan error, 1)
		go func() { errc <- s.ListenAndServeTLS("", "") }()
		select {
		case err := <-errc:
			t.Logf("On try #%v: %v", try+1, err)
			continue
		case ln = <-lnc:
			// The hook handed over the listener the server is using.
			ok = true
			t.Logf("Listening on %v", ln.Addr().String())
			break Try
		}
	}
	if !ok {
		t.Fatalf("Failed to start up after %d tries", maxTries)
	}
	defer ln.Close()
	c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2", "http/1.1"},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
	if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
		t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
	}
	if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
		t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
	}
}
1917
// serverExpectTest describes one request/response exchange for the
// Expect-header test.
type serverExpectTest struct {
	contentLength    int    // value sent in the request's Content-Length header
	chunked          bool   // send Transfer-Encoding: chunked instead of Content-Length
	expectation      string // value sent in the request's Expect header
	readBody         bool   // whether the handler reads the body (otherwise it replies 401)
	expectedResponse string // substring expected in the first response line
}

// expectTest builds a non-chunked serverExpectTest from its four
// positional values.
func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
	var test serverExpectTest
	test.contentLength = contentLength
	test.expectation = expectation
	test.readBody = readBody
	test.expectedResponse = expectedResponse
	return test
}
1934
// serverExpectTests is the table of exchanges replayed by testServerExpect.
var serverExpectTests = []serverExpectTest{
	// Expect: 100-continue with a body the handler reads; the header
	// value is matched regardless of case.
	expectTest(100, "100-continue", true, "100 Continue"),
	expectTest(100, "100-cOntInUE", true, "100 Continue"),

	// No Expect header: straight to the final response.
	expectTest(100, "", true, "200 OK"),

	// Expect: 100-continue, but the handler rejects the request without
	// reading the body.
	expectTest(100, "100-continue", false, "401 Unauthorized"),
	// Likewise without an Expect header.
	expectTest(100, "", false, "401 Unauthorized"),

	// An unrecognised expectation is rejected.
	expectTest(0, "a-pony", false, "417 Expectation Failed"),

	// Expect: 100-continue with a zero-length body that the handler reads.
	expectTest(0, "100-continue", true, "200 OK"),
	// Expect: 100-continue with a zero-length body the handler ignores.
	expectTest(0, "100-continue", false, "401 Unauthorized"),
	// Expect: 100-continue with no Content-Length but chunked encoding.
	{
		expectation:      "100-continue",
		readBody:         true,
		chunked:          true,
		expectedResponse: "100 Continue",
	},
}
1964
1965
1966
// TestServerExpect runs testServerExpect in HTTP/1 mode only.
func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }

// testServerExpect replays every entry of serverExpectTests against a test
// server over raw TCP and checks that the first line of the response
// contains the entry's expectedResponse.
func testServerExpect(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The decision is taken from the raw query string, so it is made
		// without touching the request body.
		if strings.Contains(r.URL.RawQuery, "readbody=true") {
			io.ReadAll(r.Body)
			w.Write([]byte("Hi"))
		} else {
			w.WriteHeader(StatusUnauthorized)
		}
	})).ts

	// runTest performs one exchange on its own connection.
	runTest := func(test serverExpectTest) {
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			t.Fatalf("Dial: %v", err)
		}
		defer conn.Close()

		// The body is sent right after the headers only when the request
		// has one and does not announce Expect: 100-continue.
		writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"

		// The deferred Wait makes runTest return only after the writer
		// goroutine has finished.
		wg := sync.WaitGroup{}
		wg.Add(1)
		defer wg.Wait()

		// Writer goroutine: sends the request headers and, if writeBody is
		// set, the body.
		go func() {
			defer wg.Done()

			contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
			if test.chunked {
				contentLen = "Transfer-Encoding: chunked"
			}
			_, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
				"Connection: close\r\n"+
				"%s\r\n"+
				"Expect: %s\r\nHost: foo\r\n\r\n",
				test.readBody, contentLen, test.expectation)
			if err != nil {
				t.Errorf("On test %#v, error writing request headers: %v", test, err)
				return
			}
			if writeBody {
				// By default write straight to the conn with a no-op
				// Close; for chunked requests use a chunked writer.
				var targ io.WriteCloser = struct {
					io.Writer
					io.Closer
				}{
					conn,
					io.NopCloser(nil),
				}
				if test.chunked {
					targ = httputil.NewChunkedWriter(conn)
				}
				body := strings.Repeat("A", test.contentLength)
				_, err = fmt.Fprint(targ, body)
				if err == nil {
					err = targ.Close()
				}
				if err != nil {
					if !test.readBody {
						// The handler is not reading this body, so a
						// failed write is tolerated and only logged.
						t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
						return
					}
					t.Errorf("On test %#v, error writing request body: %v", test, err)
				}
			}
		}()
		bufr := bufio.NewReader(conn)
		line, err := bufr.ReadString('\n')
		if err != nil {
			if writeBody && !test.readBody {
				// A body is being sent that the handler never reads, so a
				// read failure on this side is tolerated and only logged.
				t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
				return
			}
			t.Fatalf("On test %#v, ReadString: %v", test, err)
		}
		if !strings.Contains(line, test.expectedResponse) {
			t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
		}
	}

	for _, test := range serverExpectTests {
		runTest(test)
	}
}
2062
2063
2064
2065 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2066 setParallel(t)
2067 defer afterTest(t)
2068 conn := new(testConn)
2069 body := strings.Repeat("x", 100<<10)
2070 conn.readBuf.Write([]byte(fmt.Sprintf(
2071 "POST / HTTP/1.1\r\n"+
2072 "Host: test\r\n"+
2073 "Content-Length: %d\r\n"+
2074 "\r\n", len(body))))
2075 conn.readBuf.Write([]byte(body))
2076
2077 done := make(chan bool)
2078
2079 readBufLen := func() int {
2080 conn.readMu.Lock()
2081 defer conn.readMu.Unlock()
2082 return conn.readBuf.Len()
2083 }
2084
2085 ls := &oneConnListener{conn}
2086 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2087 defer close(done)
2088 if bufLen := readBufLen(); bufLen < len(body)/2 {
2089 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2090 }
2091 rw.WriteHeader(200)
2092 rw.(Flusher).Flush()
2093 if g, e := readBufLen(), 0; g != e {
2094 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2095 }
2096 if c := rw.Header().Get("Connection"); c != "" {
2097 t.Errorf(`Connection header = %q; want ""`, c)
2098 }
2099 }))
2100 <-done
2101 }
2102
2103
2104
2105
2106 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2107 setParallel(t)
2108 if testing.Short() && testenv.Builder() == "" {
2109 t.Log("skipping in short mode")
2110 }
2111 conn := new(testConn)
2112 body := strings.Repeat("x", 1<<20)
2113 conn.readBuf.Write([]byte(fmt.Sprintf(
2114 "POST / HTTP/1.1\r\n"+
2115 "Host: test\r\n"+
2116 "Content-Length: %d\r\n"+
2117 "\r\n", len(body))))
2118 conn.readBuf.Write([]byte(body))
2119 conn.closec = make(chan bool, 1)
2120
2121 ls := &oneConnListener{conn}
2122 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2123 if conn.readBuf.Len() < len(body)/2 {
2124 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2125 }
2126 rw.WriteHeader(200)
2127 rw.(Flusher).Flush()
2128 if conn.readBuf.Len() < len(body)/2 {
2129 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2130 }
2131 }))
2132 <-conn.closec
2133
2134 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2135 t.Errorf("Expected a Connection: close header; got response: %s", res)
2136 }
2137 }
2138
// handlerBodyCloseTest describes one request shape for
// testHandlerBodyClose together with the expected server behaviour.
type handlerBodyCloseTest struct {
	bodySize     int  // size of the request body in bytes
	bodyChunked  bool // send the body chunked instead of with Content-Length
	reqConnClose bool // send "Connection: close" with the request

	wantEOFSearch bool // whether closing the body should consume more input
	wantNextReq   bool // whether a second request should reach the handler
}

// connectionHeader returns the "Connection: close" request header line
// (CRLF-terminated) when the test asks for one, and "" otherwise.
func (tt handlerBodyCloseTest) connectionHeader() string {
	if !tt.reqConnClose {
		return ""
	}
	return "Connection: close\r\n"
}
2154
// handlerBodyCloseTests is the table driven by TestHandlerBodyClose. The
// explicit indices are the case numbers printed in failure messages.
var handlerBodyCloseTests = [...]handlerBodyCloseTest{
	// 20 KB body with Content-Length on a keep-alive connection: the
	// remaining body is consumed and the next request is served.
	0: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KB chunked body on a keep-alive connection: consumed as well,
	// and the next request is served.
	1: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   true,
	},

	// 20 KB body with Content-Length and Connection: close: nothing more
	// is read and no further request is expected.
	2: {
		bodySize:      20 << 10,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 20 KB chunked body with Connection: close: more input is still
	// consumed, but no further request is expected.
	3: {
		bodySize:      20 << 10,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB body with Content-Length on a keep-alive connection: nothing
	// more is read and no further request is served.
	4: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  false,
		wantEOFSearch: false,
		wantNextReq:   false,
	},

	// 1 MB chunked body on a keep-alive connection: some input is
	// consumed, but no further request is served.
	5: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  false,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB chunked body with Connection: close: some input is consumed,
	// no further request is expected.
	6: {
		bodySize:      1 << 20,
		bodyChunked:   true,
		reqConnClose:  true,
		wantEOFSearch: true,
		wantNextReq:   false,
	},

	// 1 MB body with Content-Length and Connection: close: nothing more
	// is read and no further request is expected.
	7: {
		bodySize:      1 << 20,
		bodyChunked:   false,
		reqConnClose:  true,
		wantEOFSearch: false,
		wantNextReq:   false,
	},
}
2239
2240 func TestHandlerBodyClose(t *testing.T) {
2241 setParallel(t)
2242 if testing.Short() && testenv.Builder() == "" {
2243 t.Skip("skipping in -short mode")
2244 }
2245 for i, tt := range handlerBodyCloseTests {
2246 testHandlerBodyClose(t, i, tt)
2247 }
2248 }
2249
// testHandlerBodyClose runs case i of handlerBodyCloseTests. It preloads a
// testConn with a POST shaped by tt (plus a trailing GET unless the POST
// asks for Connection: close), has the handler close the request body
// without reading it, and compares the amount of unread input before and
// after that Close with tt.wantEOFSearch. It also checks whether the
// second request was served when tt.wantNextReq is set.
func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
	conn := new(testConn)
	body := strings.Repeat("x", tt.bodySize)
	if tt.bodyChunked {
		conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
			"Host: test\r\n" +
			tt.connectionHeader() +
			"Transfer-Encoding: chunked\r\n" +
			"\r\n")
		cw := internal.NewChunkedWriter(&conn.readBuf)
		io.WriteString(cw, body)
		cw.Close()
		// Blank line ending the (empty) trailer section.
		conn.readBuf.WriteString("\r\n")
	} else {
		conn.readBuf.Write([]byte(fmt.Sprintf(
			"POST / HTTP/1.1\r\n"+
				"Host: test\r\n"+
				tt.connectionHeader()+
				"Content-Length: %d\r\n"+
				"\r\n", len(body))))
		conn.readBuf.Write([]byte(body))
	}
	if !tt.reqConnClose {
		// A follow-up request that is only reachable if the server gets
		// past the POST body.
		conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
	}
	conn.closec = make(chan bool, 1)

	// readBufLen reports the number of unread bytes in the connection,
	// under the connection's read lock.
	readBufLen := func() int {
		conn.readMu.Lock()
		defer conn.readMu.Unlock()
		return conn.readBuf.Len()
	}

	ls := &oneConnListener{conn}
	var numReqs int
	var size0, size1 int
	go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
		numReqs++
		if numReqs == 1 {
			// Measure the unread input immediately before and after
			// closing the body of the first request.
			size0 = readBufLen()
			req.Body.Close()
			size1 = readBufLen()
		}
	}))
	// The results are read only after the connection has been closed.
	<-conn.closec
	if numReqs < 1 || numReqs > 2 {
		t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
	}
	// Any change in the unread size means Close consumed input.
	didSearch := size0 != size1
	if didSearch != tt.wantEOFSearch {
		// Inside this branch !didSearch equals tt.wantEOFSearch.
		t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
	}
	if tt.wantNextReq && numReqs != 2 {
		t.Errorf("%d. numReq = %d; want 2", i, numReqs)
	}
}
2306
2307
2308
// testHandlerBodyConsumer is a named way for a handler to treat its
// request body.
type testHandlerBodyConsumer struct {
	name string               // label used in failure messages
	f    func(io.ReadCloser) // applied to the request body
}

// testHandlerBodyConsumers lists the body treatments exercised by the
// connection-closing tests: ignore the body, close it, or drain it.
var testHandlerBodyConsumers = []testHandlerBodyConsumer{
	{name: "nil", f: func(io.ReadCloser) {}},
	{name: "close", f: func(body io.ReadCloser) { body.Close() }},
	{name: "discard", f: func(body io.ReadCloser) { io.Copy(io.Discard, body) }},
}
2319
2320 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2321 setParallel(t)
2322 defer afterTest(t)
2323 for _, handler := range testHandlerBodyConsumers {
2324 conn := new(testConn)
2325 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2326 "Host: test\r\n" +
2327 "Transfer-Encoding: chunked\r\n" +
2328 "\r\n" +
2329 "hax\r\n" +
2330 "GET /secret HTTP/1.1\r\n" +
2331 "Host: test\r\n" +
2332 "\r\n")
2333
2334 conn.closec = make(chan bool, 1)
2335 ls := &oneConnListener{conn}
2336 var numReqs int
2337 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2338 numReqs++
2339 if strings.Contains(req.URL.Path, "secret") {
2340 t.Error("Request for /secret encountered, should not have happened.")
2341 }
2342 handler.f(req.Body)
2343 }))
2344 <-conn.closec
2345 if numReqs != 1 {
2346 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2347 }
2348 }
2349 }
2350
2351 func TestInvalidTrailerClosesConnection(t *testing.T) {
2352 setParallel(t)
2353 defer afterTest(t)
2354 for _, handler := range testHandlerBodyConsumers {
2355 conn := new(testConn)
2356 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2357 "Host: test\r\n" +
2358 "Trailer: hack\r\n" +
2359 "Transfer-Encoding: chunked\r\n" +
2360 "\r\n" +
2361 "3\r\n" +
2362 "hax\r\n" +
2363 "0\r\n" +
2364 "I'm not a valid trailer\r\n" +
2365 "GET /secret HTTP/1.1\r\n" +
2366 "Host: test\r\n" +
2367 "\r\n")
2368
2369 conn.closec = make(chan bool, 1)
2370 ln := &oneConnListener{conn}
2371 var numReqs int
2372 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2373 numReqs++
2374 if strings.Contains(req.URL.Path, "secret") {
2375 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2376 }
2377 handler.f(req.Body)
2378 }))
2379 <-conn.closec
2380 if numReqs != 1 {
2381 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2382 }
2383 }
2384 }
2385
2386
2387
2388
2389 type slowTestConn struct {
2390
2391 script []any
2392 closec chan bool
2393
2394 mu sync.Mutex
2395 rd, wd time.Time
2396 noopConn
2397 }
2398
2399 func (c *slowTestConn) SetDeadline(t time.Time) error {
2400 c.SetReadDeadline(t)
2401 c.SetWriteDeadline(t)
2402 return nil
2403 }
2404
2405 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2406 c.mu.Lock()
2407 defer c.mu.Unlock()
2408 c.rd = t
2409 return nil
2410 }
2411
2412 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2413 c.mu.Lock()
2414 defer c.mu.Unlock()
2415 c.wd = t
2416 return nil
2417 }
2418
2419 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2420 c.mu.Lock()
2421 defer c.mu.Unlock()
2422 restart:
2423 if !c.rd.IsZero() && time.Now().After(c.rd) {
2424 return 0, syscall.ETIMEDOUT
2425 }
2426 if len(c.script) == 0 {
2427 return 0, io.EOF
2428 }
2429
2430 switch cue := c.script[0].(type) {
2431 case time.Duration:
2432 if !c.rd.IsZero() {
2433
2434
2435 if remaining := time.Until(c.rd); remaining < cue {
2436 c.script[0] = cue - remaining
2437 time.Sleep(remaining)
2438 return 0, syscall.ETIMEDOUT
2439 }
2440 }
2441 c.script = c.script[1:]
2442 time.Sleep(cue)
2443 goto restart
2444
2445 case string:
2446 n = copy(b, cue)
2447
2448 if len(cue) > n {
2449 c.script[0] = cue[n:]
2450 } else {
2451 c.script = c.script[1:]
2452 }
2453
2454 default:
2455 panic("unknown cue in slowTestConn script")
2456 }
2457
2458 return
2459 }
2460
2461 func (c *slowTestConn) Close() error {
2462 select {
2463 case c.closec <- true:
2464 default:
2465 }
2466 return nil
2467 }
2468
2469 func (c *slowTestConn) Write(b []byte) (int, error) {
2470 if !c.wd.IsZero() && time.Now().After(c.wd) {
2471 return 0, syscall.ETIMEDOUT
2472 }
2473 return len(b), nil
2474 }
2475
2476 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2477 if testing.Short() {
2478 t.Skip("skipping in -short mode")
2479 }
2480 defer afterTest(t)
2481 for _, handler := range testHandlerBodyConsumers {
2482 conn := &slowTestConn{
2483 script: []any{
2484 "POST /public HTTP/1.1\r\n" +
2485 "Host: test\r\n" +
2486 "Content-Length: 10000\r\n" +
2487 "\r\n",
2488 "foo bar baz",
2489 600 * time.Millisecond,
2490 "GET /secret HTTP/1.1\r\n" +
2491 "Host: test\r\n" +
2492 "\r\n",
2493 },
2494 closec: make(chan bool, 1),
2495 }
2496 ls := &oneConnListener{conn}
2497
2498 var numReqs int
2499 s := Server{
2500 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2501 numReqs++
2502 if strings.Contains(req.URL.Path, "secret") {
2503 t.Error("Request for /secret encountered, should not have happened.")
2504 }
2505 handler.f(req.Body)
2506 }),
2507 ReadTimeout: 400 * time.Millisecond,
2508 }
2509 go s.Serve(ls)
2510 <-conn.closec
2511
2512 if numReqs != 1 {
2513 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2514 }
2515 }
2516 }
2517
2518
// cancelableTimeoutContext wraps a context so that any error reported by
// the wrapped context's Err is presented as context.DeadlineExceeded.
type cancelableTimeoutContext struct {
	context.Context
}

// Err returns nil while the wrapped context reports no error and
// context.DeadlineExceeded once it reports any error.
func (c cancelableTimeoutContext) Err() error {
	if c.Context.Err() == nil {
		return nil
	}
	return context.DeadlineExceeded
}
2529
2530 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2531 func testTimeoutHandler(t *testing.T, mode testMode) {
2532 sendHi := make(chan bool, 1)
2533 writeErrors := make(chan error, 1)
2534 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2535 <-sendHi
2536 _, werr := w.Write([]byte("hi"))
2537 writeErrors <- werr
2538 })
2539 ctx, cancel := context.WithCancel(context.Background())
2540 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2541 cst := newClientServerTest(t, mode, h)
2542
2543
2544 sendHi <- true
2545 res, err := cst.c.Get(cst.ts.URL)
2546 if err != nil {
2547 t.Error(err)
2548 }
2549 if g, e := res.StatusCode, StatusOK; g != e {
2550 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2551 }
2552 body, _ := io.ReadAll(res.Body)
2553 if g, e := string(body), "hi"; g != e {
2554 t.Errorf("got body %q; expected %q", g, e)
2555 }
2556 if g := <-writeErrors; g != nil {
2557 t.Errorf("got unexpected Write error on first request: %v", g)
2558 }
2559
2560
2561 cancel()
2562
2563 res, err = cst.c.Get(cst.ts.URL)
2564 if err != nil {
2565 t.Error(err)
2566 }
2567 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2568 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2569 }
2570 body, _ = io.ReadAll(res.Body)
2571 if !strings.Contains(string(body), "<title>Timeout</title>") {
2572 t.Errorf("expected timeout body; got %q", string(body))
2573 }
2574 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2575 t.Errorf("response content-type = %q; want %q", g, w)
2576 }
2577
2578
2579
2580 sendHi <- true
2581 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2582 t.Errorf("expected Write error of %v; got %v", e, g)
2583 }
2584 }
2585
2586
2587 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2588 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2589 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2590 ms, _ := strconv.Atoi(r.URL.Path[1:])
2591 if ms == 0 {
2592 ms = 1
2593 }
2594 for i := 0; i < ms; i++ {
2595 w.Write([]byte("hi"))
2596 time.Sleep(time.Millisecond)
2597 }
2598 })
2599
2600 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2601
2602 c := ts.Client()
2603
2604 var wg sync.WaitGroup
2605 gate := make(chan bool, 10)
2606 n := 50
2607 if testing.Short() {
2608 n = 10
2609 gate = make(chan bool, 3)
2610 }
2611 for i := 0; i < n; i++ {
2612 gate <- true
2613 wg.Add(1)
2614 go func() {
2615 defer wg.Done()
2616 defer func() { <-gate }()
2617 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2618 if err == nil {
2619 io.Copy(io.Discard, res.Body)
2620 res.Body.Close()
2621 }
2622 }()
2623 }
2624 wg.Wait()
2625 }
2626
2627
2628
2629 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2630 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2631 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2632 w.WriteHeader(204)
2633 })
2634
2635 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2636
2637 var wg sync.WaitGroup
2638 gate := make(chan bool, 50)
2639 n := 500
2640 if testing.Short() {
2641 n = 10
2642 }
2643
2644 c := ts.Client()
2645 for i := 0; i < n; i++ {
2646 gate <- true
2647 wg.Add(1)
2648 go func() {
2649 defer wg.Done()
2650 defer func() { <-gate }()
2651 res, err := c.Get(ts.URL)
2652 if err != nil {
2653
2654
2655 t.Log(err)
2656 return
2657 }
2658 defer res.Body.Close()
2659 io.Copy(io.Discard, res.Body)
2660 }()
2661 }
2662 wg.Wait()
2663 }
2664
2665
2666 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2667 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2668 sendHi := make(chan bool, 1)
2669 writeErrors := make(chan error, 1)
2670 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2671 w.Header().Set("Content-Type", "text/plain")
2672 <-sendHi
2673 _, werr := w.Write([]byte("hi"))
2674 writeErrors <- werr
2675 })
2676 ctx, cancel := context.WithCancel(context.Background())
2677 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2678 cst := newClientServerTest(t, mode, h)
2679
2680
2681 sendHi <- true
2682 res, err := cst.c.Get(cst.ts.URL)
2683 if err != nil {
2684 t.Error(err)
2685 }
2686 if g, e := res.StatusCode, StatusOK; g != e {
2687 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2688 }
2689 body, _ := io.ReadAll(res.Body)
2690 if g, e := string(body), "hi"; g != e {
2691 t.Errorf("got body %q; expected %q", g, e)
2692 }
2693 if g := <-writeErrors; g != nil {
2694 t.Errorf("got unexpected Write error on first request: %v", g)
2695 }
2696
2697
2698 cancel()
2699
2700 res, err = cst.c.Get(cst.ts.URL)
2701 if err != nil {
2702 t.Error(err)
2703 }
2704 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2705 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2706 }
2707 body, _ = io.ReadAll(res.Body)
2708 if !strings.Contains(string(body), "<title>Timeout</title>") {
2709 t.Errorf("expected timeout body; got %q", string(body))
2710 }
2711
2712
2713
2714 sendHi <- true
2715 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2716 t.Errorf("expected Write error of %v; got %v", e, g)
2717 }
2718 }
2719
2720
2721 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2722 run(t, testTimeoutHandlerStartTimerWhenServing)
2723 }
2724 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2725 if testing.Short() {
2726 t.Skip("skipping sleeping test in -short mode")
2727 }
2728 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2729 w.WriteHeader(StatusNoContent)
2730 }
2731 timeout := 300 * time.Millisecond
2732 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2733 defer ts.Close()
2734
2735 c := ts.Client()
2736
2737
2738
2739
2740 time.Sleep(2 * timeout)
2741 res, err := c.Get(ts.URL)
2742 if err != nil {
2743 t.Fatal(err)
2744 }
2745 defer res.Body.Close()
2746 if res.StatusCode != StatusNoContent {
2747 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2748 }
2749 }
2750
2751 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2752 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2753 writeErrors := make(chan error, 1)
2754 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2755 w.Header().Set("Content-Type", "text/plain")
2756 var err error
2757
2758
2759
2760 for i := 0; i < 100; i++ {
2761 _, err = w.Write([]byte("a"))
2762 if err != nil {
2763 break
2764 }
2765 time.Sleep(1 * time.Millisecond)
2766 }
2767 writeErrors <- err
2768 })
2769 ctx, cancel := context.WithCancel(context.Background())
2770 cancel()
2771 h := NewTestTimeoutHandler(sayHi, ctx)
2772 cst := newClientServerTest(t, mode, h)
2773 defer cst.close()
2774
2775 res, err := cst.c.Get(cst.ts.URL)
2776 if err != nil {
2777 t.Error(err)
2778 }
2779 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2780 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2781 }
2782 body, _ := io.ReadAll(res.Body)
2783 if g, e := string(body), ""; g != e {
2784 t.Errorf("got body %q; expected %q", g, e)
2785 }
2786 if g, e := <-writeErrors, context.Canceled; g != e {
2787 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2788 }
2789 }
2790
2791
2792 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2793 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2794 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2795
2796 }
2797 timeout := 300 * time.Millisecond
2798 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2799
2800 c := ts.Client()
2801
2802 res, err := c.Get(ts.URL)
2803 if err != nil {
2804 t.Fatal(err)
2805 }
2806 defer res.Body.Close()
2807 if res.StatusCode != StatusOK {
2808 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2809 }
2810 }
2811
2812
2813 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2814 wrapper := func(h Handler) Handler {
2815 return TimeoutHandler(h, time.Second, "")
2816 }
2817 run(t, func(t *testing.T, mode testMode) {
2818 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2819 }, testNotParallel)
2820 }
2821
// TestRedirectBadPath calls Redirect with an empty target on a request
// whose URL path is non-empty but lacks a leading slash, and checks that
// the requested status code is still recorded.
func TestRedirectBadPath(t *testing.T) {
	badURL := &url.URL{
		Scheme: "http",
		Path:   "not-empty-but-no-leading-slash",
	}
	req := &Request{Method: "GET", URL: badURL}

	rec := httptest.NewRecorder()
	Redirect(rec, req, "", 304)
	if got := rec.Code; got != 304 {
		t.Errorf("Code = %d; want 304", got)
	}
}
2838
2839
// TestRedirect checks the Location header that Redirect produces for a
// variety of targets, using a request for http://example.com/qux/.
func TestRedirect(t *testing.T) {
	req, _ := NewRequest("GET", "http://example.com/qux/", nil)

	cases := []struct {
		target   string // URL passed to Redirect
		location string // expected Location header
	}{
		// Absolute URLs with http, https and a custom scheme pass through.
		{target: "http://foobar.com/baz", location: "http://foobar.com/baz"},
		{target: "https://foobar.com/baz", location: "https://foobar.com/baz"},
		{target: "test://foobar.com/baz", location: "test://foobar.com/baz"},
		// Scheme-relative target.
		{target: "//foobar.com/baz", location: "//foobar.com/baz"},
		// Absolute path.
		{target: "/foobar.com/baz", location: "/foobar.com/baz"},
		// Relative paths are resolved against the request path /qux/.
		{target: "foobar.com/baz", location: "/qux/foobar.com/baz"},
		{target: "../quux/foobar.com/baz", location: "/quux/foobar.com/baz"},
		// Three leading slashes collapse to a single one.
		{target: "///foobar.com/baz", location: "/foobar.com/baz"},

		// A URL inside the query string is left alone.
		{target: "/foo?next=http://bar.com/", location: "/foo?next=http://bar.com/"},
		{
			target:   "http://localhost:8080/_ah/login?continue=http://localhost:8080/",
			location: "http://localhost:8080/_ah/login?continue=http://localhost:8080/",
		},

		// Non-ASCII bytes are percent-encoded.
		{target: "/фубар", location: "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
		{target: "http://foo.com/фубар", location: "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
	}

	for _, tc := range cases {
		rec := httptest.NewRecorder()
		Redirect(rec, req, tc.target, 302)
		if rec.Code != 302 {
			t.Errorf("Redirect(%q) generated status code %v; want %v", tc.target, rec.Code, 302)
		}
		if loc := rec.Header().Get("Location"); loc != tc.location {
			t.Errorf("Redirect(%q) generated Location header %q; want %q", tc.target, loc, tc.location)
		}
	}
}
2884
2885
2886
// TestRedirectContentTypeAndBody checks which Content-Type header and body
// Redirect emits, depending on the request method and on whether the
// response already carries a Content-Type entry (including an empty or nil
// value list) before Redirect is called.
func TestRedirectContentTypeAndBody(t *testing.T) {
	// ctHeader distinguishes "no Content-Type key at all" (nil pointer)
	// from "key present with these values" (possibly empty or nil).
	type ctHeader struct {
		Values []string
	}

	cases := []struct {
		method   string
		ct       *ctHeader
		wantCT   string
		wantBody string
	}{
		{method: MethodGet, wantCT: "text/html; charset=utf-8", wantBody: "<a href=\"/foo\">Found</a>.\n\n"},
		{method: MethodHead, wantCT: "text/html; charset=utf-8"},
		{method: MethodPost},
		{method: MethodDelete},
		{method: "foo"},
		{method: MethodGet, ct: &ctHeader{[]string{"application/test"}}, wantCT: "application/test"},
		{method: MethodGet, ct: &ctHeader{[]string{}}},
		{method: MethodGet, ct: &ctHeader{nil}},
	}

	for _, tc := range cases {
		req := httptest.NewRequest(tc.method, "http://example.com/qux/", nil)
		rec := httptest.NewRecorder()
		if preset := tc.ct; preset != nil {
			rec.Header()["Content-Type"] = preset.Values
		}

		Redirect(rec, req, "/foo", 302)

		if rec.Code != 302 {
			t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tc.method, tc.ct, rec.Code, 302)
		}
		if gotCT := rec.Header().Get("Content-Type"); gotCT != tc.wantCT {
			t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tc.method, tc.ct, gotCT, tc.wantCT)
		}
		raw, err := io.ReadAll(rec.Result().Body)
		if err != nil {
			t.Fatal(err)
		}
		if gotBody := string(raw); gotBody != tc.wantBody {
			t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tc.method, tc.ct, gotBody, tc.wantBody)
		}
	}
}
2930
2931
2932
2933
2934
2935
2936
2937 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
2938
2939 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
2940 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
2941 all, err := io.ReadAll(r.Body)
2942 if err != nil {
2943 t.Fatalf("handler ReadAll: %v", err)
2944 }
2945 if len(all) != 0 {
2946 t.Errorf("handler got %d bytes; expected 0", len(all))
2947 }
2948 rw.Header().Set("Content-Length", "0")
2949 }))
2950
2951 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
2952 if err != nil {
2953 t.Fatal(err)
2954 }
2955 req.ContentLength = 0
2956
2957 var resp [5]*Response
2958 for i := range resp {
2959 resp[i], err = cst.c.Do(req)
2960 if err != nil {
2961 t.Fatalf("client post #%d: %v", i, err)
2962 }
2963 }
2964
2965 for i := range resp {
2966 all, err := io.ReadAll(resp[i].Body)
2967 if err != nil {
2968 t.Fatalf("req #%d: client ReadAll: %v", i, err)
2969 }
2970 if len(all) != 0 {
2971 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
2972 }
2973 }
2974 }
2975
2976 func TestHandlerPanicNil(t *testing.T) {
2977 run(t, func(t *testing.T, mode testMode) {
2978 testHandlerPanic(t, false, mode, nil, nil)
2979 }, testNotParallel)
2980 }
2981
2982 func TestHandlerPanic(t *testing.T) {
2983 run(t, func(t *testing.T, mode testMode) {
2984 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
2985 }, testNotParallel)
2986 }
2987
2988 func TestHandlerPanicWithHijack(t *testing.T) {
2989
2990 run(t, func(t *testing.T, mode testMode) {
2991 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
2992 }, []testMode{http1Mode})
2993 }
2994
2995 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
2996
2997
2998
2999
3000
3001
3002
3003
3004 pr, pw := io.Pipe()
3005 defer pw.Close()
3006
3007 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
3008 if withHijack {
3009 rwc, _, err := w.(Hijacker).Hijack()
3010 if err != nil {
3011 t.Logf("unexpected error: %v", err)
3012 }
3013 defer rwc.Close()
3014 }
3015 panic(panicValue)
3016 })
3017 if wrapper != nil {
3018 handler = wrapper(handler)
3019 }
3020 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
3021 ts.Config.ErrorLog = log.New(pw, "", 0)
3022 })
3023
3024
3025 done := make(chan bool, 1)
3026 go func() {
3027 buf := make([]byte, 4<<10)
3028 _, err := pr.Read(buf)
3029 pr.Close()
3030 if err != nil && err != io.EOF {
3031 t.Error(err)
3032 }
3033 done <- true
3034 }()
3035
3036 _, err := cst.c.Get(cst.ts.URL)
3037 if err == nil {
3038 t.Logf("expected an error")
3039 }
3040
3041 if panicValue == nil {
3042 return
3043 }
3044
3045 <-done
3046 }
3047
// terrorWriter is an io.Writer that reports everything written to it as a
// test error on t.
type terrorWriter struct{ t *testing.T }

// Write records p as a test error and reports all of p as written.
func (w terrorWriter) Write(p []byte) (int, error) {
	written := len(p)
	w.t.Errorf("%s", p)
	return written, nil
}
3054
3055
3056
3057 func TestServerWriteHijackZeroBytes(t *testing.T) {
3058 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3059 }
3060 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3061 done := make(chan struct{})
3062 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3063 defer close(done)
3064 w.(Flusher).Flush()
3065 conn, _, err := w.(Hijacker).Hijack()
3066 if err != nil {
3067 t.Errorf("Hijack: %v", err)
3068 return
3069 }
3070 defer conn.Close()
3071 _, err = w.Write(nil)
3072 if err != ErrHijacked {
3073 t.Errorf("Write error = %v; want ErrHijacked", err)
3074 }
3075 }), func(ts *httptest.Server) {
3076 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3077 }).ts
3078
3079 c := ts.Client()
3080 res, err := c.Get(ts.URL)
3081 if err != nil {
3082 t.Fatal(err)
3083 }
3084 res.Body.Close()
3085 <-done
3086 }
3087
3088 func TestServerNoDate(t *testing.T) {
3089 run(t, func(t *testing.T, mode testMode) {
3090 testServerNoHeader(t, mode, "Date")
3091 })
3092 }
3093
3094 func TestServerContentType(t *testing.T) {
3095 run(t, func(t *testing.T, mode testMode) {
3096 testServerNoHeader(t, mode, "Content-Type")
3097 })
3098 }
3099
3100 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3101 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3102 w.Header()[header] = nil
3103 io.WriteString(w, "<html>foo</html>")
3104 }))
3105 res, err := cst.c.Get(cst.ts.URL)
3106 if err != nil {
3107 t.Fatal(err)
3108 }
3109 res.Body.Close()
3110 if got, ok := res.Header[header]; ok {
3111 t.Fatalf("Expected no %s header; got %q", header, got)
3112 }
3113 }
3114
3115 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3116 func testStripPrefix(t *testing.T, mode testMode) {
3117 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3118 w.Header().Set("X-Path", r.URL.Path)
3119 w.Header().Set("X-RawPath", r.URL.RawPath)
3120 })
3121 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3122
3123 c := ts.Client()
3124
3125 cases := []struct {
3126 reqPath string
3127 path string
3128 rawPath string
3129 }{
3130 {"/foo/bar/qux", "/qux", ""},
3131 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3132 {"/foo%2Fbar/qux", "", ""},
3133 {"/bar", "", ""},
3134 }
3135 for _, tc := range cases {
3136 t.Run(tc.reqPath, func(t *testing.T) {
3137 res, err := c.Get(ts.URL + tc.reqPath)
3138 if err != nil {
3139 t.Fatal(err)
3140 }
3141 res.Body.Close()
3142 if tc.path == "" {
3143 if res.StatusCode != StatusNotFound {
3144 t.Errorf("got %q, want 404 Not Found", res.Status)
3145 }
3146 return
3147 }
3148 if res.StatusCode != StatusOK {
3149 t.Fatalf("got %q, want 200 OK", res.Status)
3150 }
3151 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3152 t.Errorf("got Path %q, want %q", g, w)
3153 }
3154 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3155 t.Errorf("got RawPath %q, want %q", g, w)
3156 }
3157 })
3158 }
3159 }
3160
3161
// TestStripPrefixNotModifyRequest checks that serving a request through
// StripPrefix leaves the caller's Request URL path untouched.
func TestStripPrefixNotModifyRequest(t *testing.T) {
	const original = "/foo/bar"
	req := httptest.NewRequest("GET", original, nil)

	StripPrefix("/foo", NotFoundHandler()).ServeHTTP(httptest.NewRecorder(), req)

	if req.URL.Path == original {
		return
	}
	t.Errorf("StripPrefix should not modify the provided Request, but it did")
}
3170
3171 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) }
3172 func testRequestLimit(t *testing.T, mode testMode) {
3173 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3174 t.Fatalf("didn't expect to get request in Handler")
3175 }), optQuietLog)
3176 req, _ := NewRequest("GET", cst.ts.URL, nil)
3177 var bytesPerHeader = len("header12345: val12345\r\n")
3178 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3179 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3180 }
3181 res, err := cst.c.Do(req)
3182 if res != nil {
3183 defer res.Body.Close()
3184 }
3185 if mode == http2Mode {
3186
3187
3188
3189
3190 if err == nil && res.StatusCode != 431 {
3191 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3192 }
3193 } else {
3194
3195
3196
3197
3198 if err != nil {
3199 t.Fatalf("Do: %v", err)
3200 }
3201 if res.StatusCode != 431 {
3202 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3203 }
3204 }
3205 }
3206
// neverEnding is an io.Reader that yields an endless stream of one byte
// value.
type neverEnding byte

// Read fills all of p with the byte b and never returns an error.
func (b neverEnding) Read(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Seed the first byte, then repeatedly double the filled prefix.
	p[0] = byte(b)
	for filled := 1; filled < len(p); filled *= 2 {
		copy(p[filled:], p[:filled])
	}
	return len(p), nil
}
3215
// bodyLimitReader is a request body that produces 'a' bytes, counts how
// many it has handed out, and starts failing once that count has exceeded
// limit or once it has been closed.
type bodyLimitReader struct {
	mu     sync.Mutex
	count  int           // bytes handed out so far
	limit  int           // reads fail once count exceeds this
	closed chan struct{} // closed by Close
}

// Read fills p with 'a' and adds len(p) to the running count. It returns an
// error instead if the reader has been closed or if count already exceeds
// limit when the call starts.
func (r *bodyLimitReader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	select {
	case <-r.closed:
		return 0, errors.New("closed")
	default:
	}

	if r.count > r.limit {
		return 0, errors.New("at limit")
	}

	n := len(p)
	for i := 0; i < n; i++ {
		p[i] = 'a'
	}
	r.count += n
	return n, nil
}

// Close marks the reader as closed by closing the closed channel.
func (r *bodyLimitReader) Close() error {
	r.mu.Lock()
	defer r.mu.Unlock()

	close(r.closed)
	return nil
}
3247
3248 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3249 func testRequestBodyLimit(t *testing.T, mode testMode) {
3250 const limit = 1 << 20
3251 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3252 r.Body = MaxBytesReader(w, r.Body, limit)
3253 n, err := io.Copy(io.Discard, r.Body)
3254 if err == nil {
3255 t.Errorf("expected error from io.Copy")
3256 }
3257 if n != limit {
3258 t.Errorf("io.Copy = %d, want %d", n, limit)
3259 }
3260 mbErr, ok := err.(*MaxBytesError)
3261 if !ok {
3262 t.Errorf("expected MaxBytesError, got %T", err)
3263 }
3264 if mbErr.Limit != limit {
3265 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3266 }
3267 }))
3268
3269 body := &bodyLimitReader{
3270 closed: make(chan struct{}),
3271 limit: limit * 200,
3272 }
3273 req, _ := NewRequest("POST", cst.ts.URL, body)
3274
3275
3276
3277
3278
3279
3280
3281
3282
3283
3284 resp, err := cst.c.Do(req)
3285 if err == nil {
3286 resp.Body.Close()
3287 }
3288
3289
3290 <-body.closed
3291
3292 if body.count > limit*100 {
3293 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3294 limit, body.count)
3295 }
3296 }
3297
3298
3299
3300 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) }
3301 func testClientWriteShutdown(t *testing.T, mode testMode) {
3302 if runtime.GOOS == "plan9" {
3303 t.Skip("skipping test; see https://golang.org/issue/17906")
3304 }
3305 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3306 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3307 if err != nil {
3308 t.Fatalf("Dial: %v", err)
3309 }
3310 err = conn.(*net.TCPConn).CloseWrite()
3311 if err != nil {
3312 t.Fatalf("CloseWrite: %v", err)
3313 }
3314
3315 bs, err := io.ReadAll(conn)
3316 if err != nil {
3317 t.Errorf("ReadAll: %v", err)
3318 }
3319 got := string(bs)
3320 if got != "" {
3321 t.Errorf("read %q from server; want nothing", got)
3322 }
3323 }
3324
3325
3326
3327 func TestServerBufferedChunking(t *testing.T) {
3328 conn := new(testConn)
3329 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3330 conn.closec = make(chan bool, 1)
3331 ls := &oneConnListener{conn}
3332 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3333 rw.(Flusher).Flush()
3334 rw.Write([]byte{'x'})
3335 rw.Write([]byte{'y'})
3336 rw.Write([]byte{'z'})
3337 }))
3338 <-conn.closec
3339 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3340 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3341 conn.writeBuf.Bytes())
3342 }
3343 }
3344
3345
3346
3347
3348
3349 func TestServerGracefulClose(t *testing.T) {
3350
3351 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3352 }
3353 func testServerGracefulClose(t *testing.T, mode testMode) {
3354 runTimeSensitiveTest(t, []time.Duration{
3355 1 * time.Millisecond,
3356 5 * time.Millisecond,
3357 10 * time.Millisecond,
3358 50 * time.Millisecond,
3359 100 * time.Millisecond,
3360 500 * time.Millisecond,
3361 time.Second,
3362 5 * time.Second,
3363 }, func(t *testing.T, timeout time.Duration) error {
3364 SetRSTAvoidanceDelay(t, timeout)
3365 t.Logf("set RST avoidance delay to %v", timeout)
3366
3367 const bodySize = 5 << 20
3368 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3369 for i := 0; i < bodySize; i++ {
3370 req = append(req, 'x')
3371 }
3372
3373 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3374 Error(w, "bye", StatusUnauthorized)
3375 }))
3376
3377
3378 defer cst.close()
3379 ts := cst.ts
3380
3381 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3382 if err != nil {
3383 return err
3384 }
3385 writeErr := make(chan error)
3386 go func() {
3387 _, err := conn.Write(req)
3388 writeErr <- err
3389 }()
3390 defer func() {
3391 conn.Close()
3392
3393
3394
3395 <-writeErr
3396 }()
3397
3398 br := bufio.NewReader(conn)
3399 lineNum := 0
3400 for {
3401 line, err := br.ReadString('\n')
3402 if err == io.EOF {
3403 break
3404 }
3405 if err != nil {
3406 return fmt.Errorf("ReadLine: %v", err)
3407 }
3408 lineNum++
3409 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3410 t.Errorf("Response line = %q; want a 401", line)
3411 }
3412 }
3413 return nil
3414 })
3415 }
3416
3417 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3418 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3419 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3420 if r.Method != "get" {
3421 t.Errorf(`Got method %q; want "get"`, r.Method)
3422 }
3423 }))
3424 defer cst.close()
3425 req, _ := NewRequest("get", cst.ts.URL, nil)
3426 res, err := cst.c.Do(req)
3427 if err != nil {
3428 t.Error(err)
3429 return
3430 }
3431
3432 res.Body.Close()
3433 }
3434
3435
3436
3437
3438
3439 func TestContentLengthZero(t *testing.T) {
3440 run(t, testContentLengthZero, []testMode{http1Mode})
3441 }
3442 func testContentLengthZero(t *testing.T, mode testMode) {
3443 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3444
3445 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3446 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3447 if err != nil {
3448 t.Fatalf("error dialing: %v", err)
3449 }
3450 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3451 if err != nil {
3452 t.Fatalf("error writing: %v", err)
3453 }
3454 req, _ := NewRequest("GET", "/", nil)
3455 res, err := ReadResponse(bufio.NewReader(conn), req)
3456 if err != nil {
3457 t.Fatalf("error reading response: %v", err)
3458 }
3459 if te := res.TransferEncoding; len(te) > 0 {
3460 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3461 }
3462 if cl := res.ContentLength; cl != 0 {
3463 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3464 }
3465 conn.Close()
3466 }
3467 }
3468
3469 func TestCloseNotifier(t *testing.T) {
3470 run(t, testCloseNotifier, []testMode{http1Mode})
3471 }
3472 func testCloseNotifier(t *testing.T, mode testMode) {
3473 gotReq := make(chan bool, 1)
3474 sawClose := make(chan bool, 1)
3475 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3476 gotReq <- true
3477 cc := rw.(CloseNotifier).CloseNotify()
3478 <-cc
3479 sawClose <- true
3480 })).ts
3481 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3482 if err != nil {
3483 t.Fatalf("error dialing: %v", err)
3484 }
3485 diec := make(chan bool)
3486 go func() {
3487 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3488 if err != nil {
3489 t.Error(err)
3490 return
3491 }
3492 <-diec
3493 conn.Close()
3494 }()
3495 For:
3496 for {
3497 select {
3498 case <-gotReq:
3499 diec <- true
3500 case <-sawClose:
3501 break For
3502 }
3503 }
3504 ts.Close()
3505 }
3506
3507
3508
3509
3510
3511 func TestCloseNotifierPipelined(t *testing.T) {
3512 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3513 }
3514 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3515 gotReq := make(chan bool, 2)
3516 sawClose := make(chan bool, 2)
3517 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3518 gotReq <- true
3519 cc := rw.(CloseNotifier).CloseNotify()
3520 select {
3521 case <-cc:
3522 t.Error("unexpected CloseNotify")
3523 case <-time.After(100 * time.Millisecond):
3524 }
3525 sawClose <- true
3526 })).ts
3527 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3528 if err != nil {
3529 t.Fatalf("error dialing: %v", err)
3530 }
3531 diec := make(chan bool, 1)
3532 defer close(diec)
3533 go func() {
3534 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3535 _, err = io.WriteString(conn, req+req)
3536 if err != nil {
3537 t.Error(err)
3538 return
3539 }
3540 <-diec
3541 conn.Close()
3542 }()
3543 reqs := 0
3544 closes := 0
3545 for {
3546 select {
3547 case <-gotReq:
3548 reqs++
3549 if reqs > 2 {
3550 t.Fatal("too many requests")
3551 }
3552 case <-sawClose:
3553 closes++
3554 if closes > 1 {
3555 return
3556 }
3557 }
3558 }
3559 }
3560
3561 func TestCloseNotifierChanLeak(t *testing.T) {
3562 defer afterTest(t)
3563 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3564 for i := 0; i < 20; i++ {
3565 var output bytes.Buffer
3566 conn := &rwTestConn{
3567 Reader: bytes.NewReader(req),
3568 Writer: &output,
3569 closec: make(chan bool, 1),
3570 }
3571 ln := &oneConnListener{conn: conn}
3572 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3573
3574
3575
3576 _ = rw.(CloseNotifier).CloseNotify()
3577 })
3578 go Serve(ln, handler)
3579 <-conn.closec
3580 }
3581 }
3582
3583
3584
3585
3586
3587
3588
3589
3590
3591
3592 func TestHijackAfterCloseNotifier(t *testing.T) {
3593 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3594 }
3595 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3596 script := make(chan string, 2)
3597 script <- "closenotify"
3598 script <- "hijack"
3599 close(script)
3600 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3601 plan := <-script
3602 switch plan {
3603 default:
3604 panic("bogus plan; too many requests")
3605 case "closenotify":
3606 w.(CloseNotifier).CloseNotify()
3607 w.Header().Set("X-Addr", r.RemoteAddr)
3608 case "hijack":
3609 c, _, err := w.(Hijacker).Hijack()
3610 if err != nil {
3611 t.Errorf("Hijack in Handler: %v", err)
3612 return
3613 }
3614 if _, ok := c.(*net.TCPConn); !ok {
3615
3616
3617 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3618 }
3619 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3620 c.Close()
3621 return
3622 }
3623 })).ts
3624 res1, err := ts.Client().Get(ts.URL)
3625 if err != nil {
3626 log.Fatal(err)
3627 }
3628 res2, err := ts.Client().Get(ts.URL)
3629 if err != nil {
3630 log.Fatal(err)
3631 }
3632 addr1 := res1.Header.Get("X-Addr")
3633 addr2 := res2.Header.Get("X-Addr")
3634 if addr1 == "" || addr1 != addr2 {
3635 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3636 }
3637 }
3638
3639 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3640 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3641 }
3642 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3643 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3644 bodyOkay := make(chan bool, 1)
3645 gotCloseNotify := make(chan bool, 1)
3646 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3647 defer close(bodyOkay)
3648
3649 reqBody := r.Body
3650 r.Body = nil
3651
3652 gone := w.(CloseNotifier).CloseNotify()
3653 slurp, err := io.ReadAll(reqBody)
3654 if err != nil {
3655 t.Errorf("Body read: %v", err)
3656 return
3657 }
3658 if len(slurp) != len(requestBody) {
3659 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3660 return
3661 }
3662 if !bytes.Equal(slurp, requestBody) {
3663 t.Error("Backend read wrong request body.")
3664 return
3665 }
3666 bodyOkay <- true
3667 <-gone
3668 gotCloseNotify <- true
3669 })).ts
3670
3671 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3672 if err != nil {
3673 t.Fatal(err)
3674 }
3675 defer conn.Close()
3676
3677 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3678 len(requestBody), requestBody)
3679 if !<-bodyOkay {
3680
3681 return
3682 }
3683 conn.Close()
3684 <-gotCloseNotify
3685 }
3686
3687 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3688 func testOptions(t *testing.T, mode testMode) {
3689 uric := make(chan string, 2)
3690 mux := NewServeMux()
3691 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3692 uric <- r.RequestURI
3693 })
3694 ts := newClientServerTest(t, mode, mux).ts
3695
3696 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3697 if err != nil {
3698 t.Fatal(err)
3699 }
3700 defer conn.Close()
3701
3702
3703 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3704 if err != nil {
3705 t.Fatal(err)
3706 }
3707 br := bufio.NewReader(conn)
3708 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3709 if err != nil {
3710 t.Fatal(err)
3711 }
3712 if res.StatusCode != 200 {
3713 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3714 }
3715
3716
3717 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3718 if err != nil {
3719 t.Fatal(err)
3720 }
3721 res, err = ReadResponse(br, &Request{Method: "GET"})
3722 if err != nil {
3723 t.Fatal(err)
3724 }
3725 if res.StatusCode != 400 {
3726 t.Errorf("Got non-400 response to GET *: %#v", res)
3727 }
3728
3729 res, err = Get(ts.URL + "/second")
3730 if err != nil {
3731 t.Fatal(err)
3732 }
3733 res.Body.Close()
3734 if got := <-uric; got != "/second" {
3735 t.Errorf("Handler saw request for %q; want /second", got)
3736 }
3737 }
3738
3739 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3740 func testOptionsHandler(t *testing.T, mode testMode) {
3741 rc := make(chan *Request, 1)
3742
3743 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3744 rc <- r
3745 }), func(ts *httptest.Server) {
3746 ts.Config.DisableGeneralOptionsHandler = true
3747 }).ts
3748
3749 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3750 if err != nil {
3751 t.Fatal(err)
3752 }
3753 defer conn.Close()
3754
3755 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3756 if err != nil {
3757 t.Fatal(err)
3758 }
3759
3760 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3761 t.Errorf("Expected OPTIONS * request, got %v", got)
3762 }
3763 }
3764
3765
3766
3767
3768
3769
3770
3771
3772
3773
3774 func TestHeaderToWire(t *testing.T) {
3775 tests := []struct {
3776 name string
3777 handler func(ResponseWriter, *Request)
3778 check func(got, logs string) error
3779 }{
3780 {
3781 name: "write without Header",
3782 handler: func(rw ResponseWriter, r *Request) {
3783 rw.Write([]byte("hello world"))
3784 },
3785 check: func(got, logs string) error {
3786 if !strings.Contains(got, "Content-Length:") {
3787 return errors.New("no content-length")
3788 }
3789 if !strings.Contains(got, "Content-Type: text/plain") {
3790 return errors.New("no content-type")
3791 }
3792 return nil
3793 },
3794 },
3795 {
3796 name: "Header mutation before write",
3797 handler: func(rw ResponseWriter, r *Request) {
3798 h := rw.Header()
3799 h.Set("Content-Type", "some/type")
3800 rw.Write([]byte("hello world"))
3801 h.Set("Too-Late", "bogus")
3802 },
3803 check: func(got, logs string) error {
3804 if !strings.Contains(got, "Content-Length:") {
3805 return errors.New("no content-length")
3806 }
3807 if !strings.Contains(got, "Content-Type: some/type") {
3808 return errors.New("wrong content-type")
3809 }
3810 if strings.Contains(got, "Too-Late") {
3811 return errors.New("don't want too-late header")
3812 }
3813 return nil
3814 },
3815 },
3816 {
3817 name: "write then useless Header mutation",
3818 handler: func(rw ResponseWriter, r *Request) {
3819 rw.Write([]byte("hello world"))
3820 rw.Header().Set("Too-Late", "Write already wrote headers")
3821 },
3822 check: func(got, logs string) error {
3823 if strings.Contains(got, "Too-Late") {
3824 return errors.New("header appeared from after WriteHeader")
3825 }
3826 return nil
3827 },
3828 },
3829 {
3830 name: "flush then write",
3831 handler: func(rw ResponseWriter, r *Request) {
3832 rw.(Flusher).Flush()
3833 rw.Write([]byte("post-flush"))
3834 rw.Header().Set("Too-Late", "Write already wrote headers")
3835 },
3836 check: func(got, logs string) error {
3837 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3838 return errors.New("not chunked")
3839 }
3840 if strings.Contains(got, "Too-Late") {
3841 return errors.New("header appeared from after WriteHeader")
3842 }
3843 return nil
3844 },
3845 },
3846 {
3847 name: "header then flush",
3848 handler: func(rw ResponseWriter, r *Request) {
3849 rw.Header().Set("Content-Type", "some/type")
3850 rw.(Flusher).Flush()
3851 rw.Write([]byte("post-flush"))
3852 rw.Header().Set("Too-Late", "Write already wrote headers")
3853 },
3854 check: func(got, logs string) error {
3855 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3856 return errors.New("not chunked")
3857 }
3858 if strings.Contains(got, "Too-Late") {
3859 return errors.New("header appeared from after WriteHeader")
3860 }
3861 if !strings.Contains(got, "Content-Type: some/type") {
3862 return errors.New("wrong content-type")
3863 }
3864 return nil
3865 },
3866 },
3867 {
3868 name: "sniff-on-first-write content-type",
3869 handler: func(rw ResponseWriter, r *Request) {
3870 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3871 rw.Header().Set("Content-Type", "x/wrong")
3872 },
3873 check: func(got, logs string) error {
3874 if !strings.Contains(got, "Content-Type: text/html") {
3875 return errors.New("wrong content-type; want html")
3876 }
3877 return nil
3878 },
3879 },
3880 {
3881 name: "explicit content-type wins",
3882 handler: func(rw ResponseWriter, r *Request) {
3883 rw.Header().Set("Content-Type", "some/type")
3884 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3885 },
3886 check: func(got, logs string) error {
3887 if !strings.Contains(got, "Content-Type: some/type") {
3888 return errors.New("wrong content-type; want html")
3889 }
3890 return nil
3891 },
3892 },
3893 {
3894 name: "empty handler",
3895 handler: func(rw ResponseWriter, r *Request) {
3896 },
3897 check: func(got, logs string) error {
3898 if !strings.Contains(got, "Content-Length: 0") {
3899 return errors.New("want 0 content-length")
3900 }
3901 return nil
3902 },
3903 },
3904 {
3905 name: "only Header, no write",
3906 handler: func(rw ResponseWriter, r *Request) {
3907 rw.Header().Set("Some-Header", "some-value")
3908 },
3909 check: func(got, logs string) error {
3910 if !strings.Contains(got, "Some-Header") {
3911 return errors.New("didn't get header")
3912 }
3913 return nil
3914 },
3915 },
3916 {
3917 name: "WriteHeader call",
3918 handler: func(rw ResponseWriter, r *Request) {
3919 rw.WriteHeader(404)
3920 rw.Header().Set("Too-Late", "some-value")
3921 },
3922 check: func(got, logs string) error {
3923 if !strings.Contains(got, "404") {
3924 return errors.New("wrong status")
3925 }
3926 if strings.Contains(got, "Too-Late") {
3927 return errors.New("shouldn't have seen Too-Late")
3928 }
3929 return nil
3930 },
3931 },
3932 }
3933 for _, tc := range tests {
3934 ht := newHandlerTest(HandlerFunc(tc.handler))
3935 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
3936 logs := ht.logbuf.String()
3937 if err := tc.check(got, logs); err != nil {
3938 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
3939 }
3940 }
3941 }
3942
// errorListener is a listener stub that never yields a connection: each
// Accept call hands back the next queued error.
type errorListener struct {
	errs []error // errors still to be returned by Accept, in order
}

// Accept pops and returns the next queued error with a nil connection.
// Once the queue is empty it returns io.EOF.
func (l *errorListener) Accept() (c net.Conn, err error) {
	if len(l.errs) > 0 {
		err, l.errs = l.errs[0], l.errs[1:]
		return nil, err
	}
	return nil, io.EOF
}

// Close does nothing and always reports success.
func (l *errorListener) Close() error { return nil }
3959
3960 func (l *errorListener) Addr() net.Addr {
3961 return dummyAddr("test-address")
3962 }
3963
3964 func TestAcceptMaxFds(t *testing.T) {
3965 setParallel(t)
3966
3967 ln := &errorListener{[]error{
3968 &net.OpError{
3969 Op: "accept",
3970 Err: syscall.EMFILE,
3971 }}}
3972 server := &Server{
3973 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
3974 ErrorLog: log.New(io.Discard, "", 0),
3975 }
3976 err := server.Serve(ln)
3977 if err != io.EOF {
3978 t.Errorf("got error %v, want EOF", err)
3979 }
3980 }
3981
3982 func TestWriteAfterHijack(t *testing.T) {
3983 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
3984 var buf strings.Builder
3985 wrotec := make(chan bool, 1)
3986 conn := &rwTestConn{
3987 Reader: bytes.NewReader(req),
3988 Writer: &buf,
3989 closec: make(chan bool, 1),
3990 }
3991 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3992 conn, bufrw, err := rw.(Hijacker).Hijack()
3993 if err != nil {
3994 t.Error(err)
3995 return
3996 }
3997 go func() {
3998 bufrw.Write([]byte("[hijack-to-bufw]"))
3999 bufrw.Flush()
4000 conn.Write([]byte("[hijack-to-conn]"))
4001 conn.Close()
4002 wrotec <- true
4003 }()
4004 })
4005 ln := &oneConnListener{conn: conn}
4006 go Serve(ln, handler)
4007 <-conn.closec
4008 <-wrotec
4009 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4010 t.Errorf("wrote %q; want %q", g, w)
4011 }
4012 }
4013
4014 func TestDoubleHijack(t *testing.T) {
4015 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4016 var buf bytes.Buffer
4017 conn := &rwTestConn{
4018 Reader: bytes.NewReader(req),
4019 Writer: &buf,
4020 closec: make(chan bool, 1),
4021 }
4022 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4023 conn, _, err := rw.(Hijacker).Hijack()
4024 if err != nil {
4025 t.Error(err)
4026 return
4027 }
4028 _, _, err = rw.(Hijacker).Hijack()
4029 if err == nil {
4030 t.Errorf("got err = nil; want err != nil")
4031 }
4032 conn.Close()
4033 })
4034 ln := &oneConnListener{conn: conn}
4035 go Serve(ln, handler)
4036 <-conn.closec
4037 }
4038
4039
4040
4041
4042
4043
4044
4045 func TestHTTP10ConnectionHeader(t *testing.T) {
4046 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4047 }
4048 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4049 mux := NewServeMux()
4050 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4051 ts := newClientServerTest(t, mode, mux).ts
4052
4053
4054 tests := []struct {
4055 req string
4056 expect []string
4057 }{
4058 {
4059 req: "GET / HTTP/1.0\r\n\r\n",
4060 expect: nil,
4061 },
4062 {
4063 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4064 expect: nil,
4065 },
4066 {
4067 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4068 expect: []string{"keep-alive"},
4069 },
4070 }
4071
4072 for _, tt := range tests {
4073 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4074 if err != nil {
4075 t.Fatal("dial err:", err)
4076 }
4077
4078 _, err = fmt.Fprint(conn, tt.req)
4079 if err != nil {
4080 t.Fatal("conn write err:", err)
4081 }
4082
4083 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4084 if err != nil {
4085 t.Fatal("ReadResponse err:", err)
4086 }
4087 conn.Close()
4088 resp.Body.Close()
4089
4090 got := resp.Header["Connection"]
4091 if !slices.Equal(got, tt.expect) {
4092 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4093 }
4094 }
4095 }
4096
4097
4098 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4099 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4100 pr, pw := io.Pipe()
4101 const size = 3 << 20
4102 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4103 rw.Header().Set("Content-Type", "text/plain")
4104 done := make(chan bool)
4105 go func() {
4106 io.Copy(rw, pr)
4107 close(done)
4108 }()
4109 time.Sleep(25 * time.Millisecond)
4110 n, err := io.Copy(io.Discard, req.Body)
4111 if err != nil {
4112 t.Errorf("handler Copy: %v", err)
4113 return
4114 }
4115 if n != size {
4116 t.Errorf("handler Copy = %d; want %d", n, size)
4117 }
4118 pw.Write([]byte("hi"))
4119 pw.Close()
4120 <-done
4121 }))
4122
4123 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4124 if err != nil {
4125 t.Fatal(err)
4126 }
4127 res, err := cst.c.Do(req)
4128 if err != nil {
4129 t.Fatal(err)
4130 }
4131 all, err := io.ReadAll(res.Body)
4132 if err != nil {
4133 t.Fatal(err)
4134 }
4135 res.Body.Close()
4136 if string(all) != "hi" {
4137 t.Errorf("Body = %q; want hi", all)
4138 }
4139 }
4140
4141
4142 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4143 for _, code := range []int{StatusNotModified, StatusNoContent} {
4144 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4145 if r.URL.Path == "/header" {
4146 w.Header().Set("Content-Length", "123")
4147 }
4148 w.WriteHeader(code)
4149 if r.URL.Path == "/more" {
4150 w.Write([]byte("stuff"))
4151 }
4152 }))
4153 for _, req := range []string{
4154 "GET / HTTP/1.0",
4155 "GET /header HTTP/1.0",
4156 "GET /more HTTP/1.0",
4157 "GET / HTTP/1.1\nHost: foo",
4158 "GET /header HTTP/1.1\nHost: foo",
4159 "GET /more HTTP/1.1\nHost: foo",
4160 } {
4161 got := ht.rawResponse(req)
4162 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4163 if !strings.Contains(got, wantStatus) {
4164 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4165 } else if strings.Contains(got, "Content-Length") {
4166 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4167 } else if strings.Contains(got, "stuff") {
4168 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4169 }
4170 }
4171 }
4172 }
4173
4174 func TestContentTypeOkayOn204(t *testing.T) {
4175 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4176 w.Header().Set("Content-Length", "123")
4177 w.Header().Set("Content-Type", "foo/bar")
4178 w.WriteHeader(204)
4179 }))
4180 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4181 if !strings.Contains(got, "Content-Type: foo/bar") {
4182 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4183 }
4184 if strings.Contains(got, "Content-Length: 123") {
4185 t.Errorf("Response = %q; don't want a Content-Length", got)
4186 }
4187 }
4188
4189
4190
4191
4192
4193
4194
4195 func TestTransportAndServerSharedBodyRace(t *testing.T) {
4196 run(t, testTransportAndServerSharedBodyRace, testNotParallel)
4197 }
4198 func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
4199
4200
4201
4202
4203 runTimeSensitiveTest(t, []time.Duration{
4204 1 * time.Millisecond,
4205 5 * time.Millisecond,
4206 10 * time.Millisecond,
4207 50 * time.Millisecond,
4208 100 * time.Millisecond,
4209 500 * time.Millisecond,
4210 time.Second,
4211 5 * time.Second,
4212 }, func(t *testing.T, timeout time.Duration) error {
4213 SetRSTAvoidanceDelay(t, timeout)
4214 t.Logf("set RST avoidance delay to %v", timeout)
4215
4216 const bodySize = 1 << 20
4217
4218 var wg sync.WaitGroup
4219 backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4220
4221
4222
4223
4224
4225
4226
4227
4228 wg.Add(1)
4229 defer wg.Done()
4230
4231 n, err := io.CopyN(rw, req.Body, bodySize)
4232 t.Logf("backend CopyN: %v, %v", n, err)
4233 <-req.Context().Done()
4234 }))
4235
4236
4237 defer func() {
4238 wg.Wait()
4239 backend.close()
4240 }()
4241
4242 var proxy *clientServerTest
4243 proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4244 req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
4245 req2.ContentLength = bodySize
4246 cancel := make(chan struct{})
4247 req2.Cancel = cancel
4248
4249 bresp, err := proxy.c.Do(req2)
4250 if err != nil {
4251 t.Errorf("Proxy outbound request: %v", err)
4252 return
4253 }
4254 _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
4255 if err != nil {
4256 t.Errorf("Proxy copy error: %v", err)
4257 return
4258 }
4259 t.Cleanup(func() { bresp.Body.Close() })
4260
4261
4262
4263
4264
4265
4266 if mode == http2Mode {
4267 close(cancel)
4268 } else {
4269 proxy.c.Transport.(*Transport).CancelRequest(req2)
4270 }
4271 rw.Write([]byte("OK"))
4272 }))
4273 defer proxy.close()
4274
4275 req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
4276 res, err := proxy.c.Do(req)
4277 if err != nil {
4278 return fmt.Errorf("original request: %v", err)
4279 }
4280 res.Body.Close()
4281 return nil
4282 })
4283 }
4284
4285
4286
4287
4288 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4289 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4290 }
4291 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4292 if testing.Short() {
4293 t.Skip("skipping in -short mode")
4294 }
4295
4296 readErrCh := make(chan error, 1)
4297 errCh := make(chan error, 2)
4298
4299 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4300 go func(body io.Reader) {
4301 _, err := body.Read(make([]byte, 100))
4302 readErrCh <- err
4303 }(req.Body)
4304 time.Sleep(500 * time.Millisecond)
4305 })).ts
4306
4307 closeConn := make(chan bool)
4308 defer close(closeConn)
4309 go func() {
4310 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4311 if err != nil {
4312 errCh <- err
4313 return
4314 }
4315 defer conn.Close()
4316 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4317 if err != nil {
4318 errCh <- err
4319 return
4320 }
4321
4322
4323 <-closeConn
4324 }()
4325 select {
4326 case err := <-readErrCh:
4327 if err == nil {
4328 t.Error("Read was nil. Expected error.")
4329 }
4330 case err := <-errCh:
4331 t.Error(err)
4332 }
4333 }
4334
4335
4336 func TestResponseWriterWriteString(t *testing.T) {
4337 okc := make(chan bool, 1)
4338 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4339 _, ok := w.(io.StringWriter)
4340 okc <- ok
4341 }))
4342 ht.rawResponse("GET / HTTP/1.0")
4343 select {
4344 case ok := <-okc:
4345 if !ok {
4346 t.Error("ResponseWriter did not implement io.StringWriter")
4347 }
4348 default:
4349 t.Error("handler was never called")
4350 }
4351 }
4352
4353 func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
4354 func testServerConnState(t *testing.T, mode testMode) {
4355 handler := map[string]func(w ResponseWriter, r *Request){
4356 "/": func(w ResponseWriter, r *Request) {
4357 fmt.Fprintf(w, "Hello.")
4358 },
4359 "/close": func(w ResponseWriter, r *Request) {
4360 w.Header().Set("Connection", "close")
4361 fmt.Fprintf(w, "Hello.")
4362 },
4363 "/hijack": func(w ResponseWriter, r *Request) {
4364 c, _, _ := w.(Hijacker).Hijack()
4365 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4366 c.Close()
4367 },
4368 "/hijack-panic": func(w ResponseWriter, r *Request) {
4369 c, _, _ := w.(Hijacker).Hijack()
4370 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4371 c.Close()
4372 panic("intentional panic")
4373 },
4374 }
4375
4376
4377 type stateLog struct {
4378 active net.Conn
4379 got []ConnState
4380 want []ConnState
4381 complete chan<- struct{}
4382 }
4383 activeLog := make(chan *stateLog, 1)
4384
4385
4386
4387
4388 wantLog := func(doRequests func(), want ...ConnState) {
4389 t.Helper()
4390 complete := make(chan struct{})
4391 activeLog <- &stateLog{want: want, complete: complete}
4392
4393 doRequests()
4394
4395 <-complete
4396 sl := <-activeLog
4397 if !slices.Equal(sl.got, sl.want) {
4398 t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
4399 }
4400
4401
4402
4403 }
4404
4405 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4406 handler[r.URL.Path](w, r)
4407 }), func(ts *httptest.Server) {
4408 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
4409 ts.Config.ConnState = func(c net.Conn, state ConnState) {
4410 if c == nil {
4411 t.Errorf("nil conn seen in state %s", state)
4412 return
4413 }
4414 sl := <-activeLog
4415 if sl.active == nil && state == StateNew {
4416 sl.active = c
4417 } else if sl.active != c {
4418 t.Errorf("unexpected conn in state %s", state)
4419 activeLog <- sl
4420 return
4421 }
4422 sl.got = append(sl.got, state)
4423 if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
4424 close(sl.complete)
4425 sl.complete = nil
4426 }
4427 activeLog <- sl
4428 }
4429 }).ts
4430 defer func() {
4431 activeLog <- &stateLog{}
4432 ts.Close()
4433 }()
4434
4435 c := ts.Client()
4436
4437 mustGet := func(url string, headers ...string) {
4438 t.Helper()
4439 req, err := NewRequest("GET", url, nil)
4440 if err != nil {
4441 t.Fatal(err)
4442 }
4443 for len(headers) > 0 {
4444 req.Header.Add(headers[0], headers[1])
4445 headers = headers[2:]
4446 }
4447 res, err := c.Do(req)
4448 if err != nil {
4449 t.Errorf("Error fetching %s: %v", url, err)
4450 return
4451 }
4452 _, err = io.ReadAll(res.Body)
4453 defer res.Body.Close()
4454 if err != nil {
4455 t.Errorf("Error reading %s: %v", url, err)
4456 }
4457 }
4458
4459 wantLog(func() {
4460 mustGet(ts.URL + "/")
4461 mustGet(ts.URL + "/close")
4462 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4463
4464 wantLog(func() {
4465 mustGet(ts.URL + "/")
4466 mustGet(ts.URL+"/", "Connection", "close")
4467 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4468
4469 wantLog(func() {
4470 mustGet(ts.URL + "/hijack")
4471 }, StateNew, StateActive, StateHijacked)
4472
4473 wantLog(func() {
4474 mustGet(ts.URL + "/hijack-panic")
4475 }, StateNew, StateActive, StateHijacked)
4476
4477 wantLog(func() {
4478 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4479 if err != nil {
4480 t.Fatal(err)
4481 }
4482 c.Close()
4483 }, StateNew, StateClosed)
4484
4485 wantLog(func() {
4486 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4487 if err != nil {
4488 t.Fatal(err)
4489 }
4490 if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
4491 t.Fatal(err)
4492 }
4493 c.Read(make([]byte, 1))
4494 c.Close()
4495 }, StateNew, StateActive, StateClosed)
4496
4497 wantLog(func() {
4498 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4499 if err != nil {
4500 t.Fatal(err)
4501 }
4502 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4503 t.Fatal(err)
4504 }
4505 res, err := ReadResponse(bufio.NewReader(c), nil)
4506 if err != nil {
4507 t.Fatal(err)
4508 }
4509 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4510 t.Fatal(err)
4511 }
4512 c.Close()
4513 }, StateNew, StateActive, StateIdle, StateClosed)
4514 }
4515
4516 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4517 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4518 }
4519 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4520 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4521 }), func(ts *httptest.Server) {
4522 ts.Config.SetKeepAlivesEnabled(false)
4523 }).ts
4524 res, err := ts.Client().Get(ts.URL)
4525 if err != nil {
4526 t.Fatal(err)
4527 }
4528 defer res.Body.Close()
4529 if !res.Close {
4530 t.Errorf("Body.Close == false; want true")
4531 }
4532 }
4533
4534
4535 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4536 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4537 var n int32
4538 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4539 atomic.AddInt32(&n, 1)
4540 }), optQuietLog)
4541 var wg sync.WaitGroup
4542 const reqs = 20
4543 for i := 0; i < reqs; i++ {
4544 wg.Add(1)
4545 go func() {
4546 defer wg.Done()
4547 res, err := cst.c.Get(cst.ts.URL)
4548 if err != nil {
4549
4550
4551 time.Sleep(10 * time.Millisecond)
4552 res, err = cst.c.Get(cst.ts.URL)
4553 if err != nil {
4554 t.Error(err)
4555 return
4556 }
4557 }
4558 defer res.Body.Close()
4559 _, err = io.Copy(io.Discard, res.Body)
4560 if err != nil {
4561 t.Error(err)
4562 return
4563 }
4564 }()
4565 }
4566 wg.Wait()
4567 if got := atomic.LoadInt32(&n); got != reqs {
4568 t.Errorf("handler ran %d times; want %d", got, reqs)
4569 }
4570 }
4571
4572 func TestServerConnStateNew(t *testing.T) {
4573 sawNew := false
4574 srv := &Server{
4575 ConnState: func(c net.Conn, state ConnState) {
4576 if state == StateNew {
4577 sawNew = true
4578 }
4579 },
4580 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4581 }
4582 srv.Serve(&oneConnListener{
4583 conn: &rwTestConn{
4584 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4585 Writer: io.Discard,
4586 },
4587 })
4588 if !sawNew {
4589 t.Error("StateNew not seen")
4590 }
4591 }
4592
4593 type closeWriteTestConn struct {
4594 rwTestConn
4595 didCloseWrite bool
4596 }
4597
4598 func (c *closeWriteTestConn) CloseWrite() error {
4599 c.didCloseWrite = true
4600 return nil
4601 }
4602
4603 func TestCloseWrite(t *testing.T) {
4604 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4605
4606 var srv Server
4607 var testConn closeWriteTestConn
4608 c := ExportServerNewConn(&srv, &testConn)
4609 ExportCloseWriteAndWait(c)
4610 if !testConn.didCloseWrite {
4611 t.Error("didn't see CloseWrite call")
4612 }
4613 }
4614
4615
4616
4617
4618
4619
4620
4621
4622 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4623 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4624 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4625 io.WriteString(w, "Hello, ")
4626 w.(Flusher).Flush()
4627 conn, buf, _ := w.(Hijacker).Hijack()
4628 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4629 if err := buf.Flush(); err != nil {
4630 t.Error(err)
4631 }
4632 if err := conn.Close(); err != nil {
4633 t.Error(err)
4634 }
4635 })).ts
4636 res, err := Get(ts.URL)
4637 if err != nil {
4638 t.Fatal(err)
4639 }
4640 defer res.Body.Close()
4641 all, err := io.ReadAll(res.Body)
4642 if err != nil {
4643 t.Fatal(err)
4644 }
4645 if want := "Hello, world!"; string(all) != want {
4646 t.Errorf("Got %q; want %q", all, want)
4647 }
4648 }
4649
4650
4651
4652
4653
4654
4655
4656 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4657 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4658 }
4659 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4660 if testing.Short() {
4661 t.Skip("skipping in -short mode")
4662 }
4663 const numReq = 3
4664 addrc := make(chan string, numReq)
4665 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4666 addrc <- r.RemoteAddr
4667 time.Sleep(500 * time.Millisecond)
4668 w.(Flusher).Flush()
4669 }), func(ts *httptest.Server) {
4670 ts.Config.WriteTimeout = 250 * time.Millisecond
4671 }).ts
4672
4673 errc := make(chan error, numReq)
4674 go func() {
4675 defer close(errc)
4676 for i := 0; i < numReq; i++ {
4677 res, err := Get(ts.URL)
4678 if res != nil {
4679 res.Body.Close()
4680 }
4681 errc <- err
4682 }
4683 }()
4684
4685 addrSeen := map[string]bool{}
4686 numOkay := 0
4687 for {
4688 select {
4689 case v := <-addrc:
4690 addrSeen[v] = true
4691 case err, ok := <-errc:
4692 if !ok {
4693 if len(addrSeen) != numReq {
4694 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4695 }
4696 if numOkay != 0 {
4697 t.Errorf("got %d successful client requests; want 0", numOkay)
4698 }
4699 return
4700 }
4701 if err == nil {
4702 numOkay++
4703 }
4704 }
4705 }
4706 }
4707
4708
4709
4710 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4711 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4712 }
4713 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4714 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4715 w.Header().Set("Transfer-Encoding", "foo")
4716 io.WriteString(w, "<html>")
4717 })).ts
4718 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4719 if err != nil {
4720 t.Fatalf("Dial: %v", err)
4721 }
4722 defer c.Close()
4723 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4724 t.Fatal(err)
4725 }
4726 bs := bufio.NewScanner(c)
4727 var got strings.Builder
4728 for bs.Scan() {
4729 if strings.TrimSpace(bs.Text()) == "" {
4730 break
4731 }
4732 got.WriteString(bs.Text())
4733 got.WriteByte('\n')
4734 }
4735 if err := bs.Err(); err != nil {
4736 t.Fatal(err)
4737 }
4738 if strings.Contains(got.String(), "Content-Length") {
4739 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4740 }
4741 if strings.Contains(got.String(), "Content-Type") {
4742 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4743 }
4744 }
4745
4746
4747
4748 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4749 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4750 "\r\n\r\n" +
4751 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4752 var buf bytes.Buffer
4753 conn := &rwTestConn{
4754 Reader: bytes.NewReader(req),
4755 Writer: &buf,
4756 closec: make(chan bool, 1),
4757 }
4758 ln := &oneConnListener{conn: conn}
4759 numReq := 0
4760 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4761 numReq++
4762 }))
4763 <-conn.closec
4764 if numReq != 2 {
4765 t.Errorf("num requests = %d; want 2", numReq)
4766 t.Logf("Res: %s", buf.Bytes())
4767 }
4768 }
4769
4770 func TestIssue13893_Expect100(t *testing.T) {
4771
4772 req := reqBytes(`PUT /readbody HTTP/1.1
4773 User-Agent: PycURL/7.22.0
4774 Host: 127.0.0.1:9000
4775 Accept: */*
4776 Expect: 100-continue
4777 Content-Length: 10
4778
4779 HelloWorld
4780
4781 `)
4782 var buf bytes.Buffer
4783 conn := &rwTestConn{
4784 Reader: bytes.NewReader(req),
4785 Writer: &buf,
4786 closec: make(chan bool, 1),
4787 }
4788 ln := &oneConnListener{conn: conn}
4789 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4790 if _, ok := r.Header["Expect"]; !ok {
4791 t.Error("Expect header should not be filtered out")
4792 }
4793 }))
4794 <-conn.closec
4795 }
4796
4797 func TestIssue11549_Expect100(t *testing.T) {
4798 req := reqBytes(`PUT /readbody HTTP/1.1
4799 User-Agent: PycURL/7.22.0
4800 Host: 127.0.0.1:9000
4801 Accept: */*
4802 Expect: 100-continue
4803 Content-Length: 10
4804
4805 HelloWorldPUT /noreadbody HTTP/1.1
4806 User-Agent: PycURL/7.22.0
4807 Host: 127.0.0.1:9000
4808 Accept: */*
4809 Expect: 100-continue
4810 Content-Length: 10
4811
4812 GET /should-be-ignored HTTP/1.1
4813 Host: foo
4814
4815 `)
4816 var buf strings.Builder
4817 conn := &rwTestConn{
4818 Reader: bytes.NewReader(req),
4819 Writer: &buf,
4820 closec: make(chan bool, 1),
4821 }
4822 ln := &oneConnListener{conn: conn}
4823 numReq := 0
4824 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4825 numReq++
4826 if r.URL.Path == "/readbody" {
4827 io.ReadAll(r.Body)
4828 }
4829 io.WriteString(w, "Hello world!")
4830 }))
4831 <-conn.closec
4832 if numReq != 2 {
4833 t.Errorf("num requests = %d; want 2", numReq)
4834 }
4835 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4836 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4837 }
4838 }
4839
4840
4841
4842 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4843 setParallel(t)
4844 conn := newTestConn()
4845 conn.readBuf.WriteString(
4846 "POST / HTTP/1.1\r\n" +
4847 "Host: test\r\n" +
4848 "Content-Length: 9999999999\r\n" +
4849 "\r\n" + strings.Repeat("a", 1<<20))
4850
4851 ls := &oneConnListener{conn}
4852 var inHandlerLen int
4853 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4854 inHandlerLen = conn.readBuf.Len()
4855 rw.WriteHeader(404)
4856 }))
4857 <-conn.closec
4858 afterHandlerLen := conn.readBuf.Len()
4859
4860 if afterHandlerLen != inHandlerLen {
4861 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4862 }
4863 }
4864
4865 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4866 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4867 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4868 r.Body = nil
4869 fmt.Fprintf(w, "%v", r.RemoteAddr)
4870 }))
4871 get := func() string {
4872 res, err := cst.c.Get(cst.ts.URL)
4873 if err != nil {
4874 t.Fatal(err)
4875 }
4876 defer res.Body.Close()
4877 slurp, err := io.ReadAll(res.Body)
4878 if err != nil {
4879 t.Fatal(err)
4880 }
4881 return string(slurp)
4882 }
4883 a, b := get(), get()
4884 if a != b {
4885 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
4886 }
4887 }
4888
4889
4890
4891 func TestServerValidatesHostHeader(t *testing.T) {
4892 tests := []struct {
4893 proto string
4894 host string
4895 want int
4896 }{
4897 {"HTTP/0.9", "", 505},
4898
4899 {"HTTP/1.1", "", 400},
4900 {"HTTP/1.1", "Host: \r\n", 200},
4901 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4902 {"HTTP/1.1", "Host: foo.com\r\n", 200},
4903 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
4904 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
4905 {"HTTP/1.1", "Host: ::1\r\n", 200},
4906 {"HTTP/1.1", "Host: [::1]\r\n", 200},
4907 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
4908 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
4909 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
4910 {"HTTP/1.1", "Host: \x06\r\n", 400},
4911 {"HTTP/1.1", "Host: \xff\r\n", 400},
4912 {"HTTP/1.1", "Host: {\r\n", 400},
4913 {"HTTP/1.1", "Host: }\r\n", 400},
4914 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
4915
4916
4917
4918 {"HTTP/1.0", "", 200},
4919 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
4920 {"HTTP/1.0", "Host: \xff\r\n", 400},
4921
4922
4923 {"PRI * HTTP/2.0", "", 200},
4924
4925
4926 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
4927
4928
4929 {"PRI / HTTP/2.0", "", 505},
4930 {"GET / HTTP/2.0", "", 505},
4931 {"GET / HTTP/3.0", "", 505},
4932 }
4933 for _, tt := range tests {
4934 conn := newTestConn()
4935 methodTarget := "GET / "
4936 if !strings.HasPrefix(tt.proto, "HTTP/") {
4937 methodTarget = ""
4938 }
4939 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
4940
4941 ln := &oneConnListener{conn}
4942 srv := Server{
4943 ErrorLog: quietLog,
4944 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
4945 }
4946 go srv.Serve(ln)
4947 <-conn.closec
4948 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
4949 if err != nil {
4950 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
4951 continue
4952 }
4953 if res.StatusCode != tt.want {
4954 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
4955 }
4956 }
4957 }
4958
4959 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
4960 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
4961 }
4962 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
4963 const upgradeResponse = "upgrade here"
4964 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4965 conn, br, err := w.(Hijacker).Hijack()
4966 if err != nil {
4967 t.Error(err)
4968 return
4969 }
4970 defer conn.Close()
4971 if r.Method != "PRI" || r.RequestURI != "*" {
4972 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
4973 return
4974 }
4975 if !r.Close {
4976 t.Errorf("Request.Close = true; want false")
4977 }
4978 const want = "SM\r\n\r\n"
4979 buf := make([]byte, len(want))
4980 n, err := io.ReadFull(br, buf)
4981 if err != nil || string(buf[:n]) != want {
4982 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
4983 return
4984 }
4985 io.WriteString(conn, upgradeResponse)
4986 })).ts
4987
4988 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4989 if err != nil {
4990 t.Fatalf("Dial: %v", err)
4991 }
4992 defer c.Close()
4993 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
4994 slurp, err := io.ReadAll(c)
4995 if err != nil {
4996 t.Fatal(err)
4997 }
4998 if string(slurp) != upgradeResponse {
4999 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
5000 }
5001 }
5002
5003
5004
5005 func TestServerValidatesHeaders(t *testing.T) {
5006 setParallel(t)
5007 tests := []struct {
5008 header string
5009 want int
5010 }{
5011 {"", 200},
5012 {"Foo: bar\r\n", 200},
5013 {"X-Foo: bar\r\n", 200},
5014 {"Foo: a space\r\n", 200},
5015
5016 {"A space: foo\r\n", 400},
5017 {"foo\xffbar: foo\r\n", 400},
5018 {"foo\x00bar: foo\r\n", 400},
5019 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
5020
5021
5022 {"Foo : bar\r\n", 400},
5023 {"Foo\t: bar\r\n", 400},
5024
5025
5026
5027 {": empty key\r\n", 400},
5028
5029
5030
5031
5032 {"Content-Length: notdigits\r\n", 400},
5033 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
5034
5035 {"foo: foo foo\r\n", 200},
5036 {"foo: foo\tfoo\r\n", 200},
5037 {"foo: foo\x00foo\r\n", 400},
5038 {"foo: foo\x7ffoo\r\n", 400},
5039 {"foo: foo\xfffoo\r\n", 200},
5040 }
5041 for _, tt := range tests {
5042 conn := newTestConn()
5043 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
5044
5045 ln := &oneConnListener{conn}
5046 srv := Server{
5047 ErrorLog: quietLog,
5048 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5049 }
5050 go srv.Serve(ln)
5051 <-conn.closec
5052 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5053 if err != nil {
5054 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
5055 continue
5056 }
5057 if res.StatusCode != tt.want {
5058 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
5059 }
5060 }
5061 }
5062
5063 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
5064 run(t, testServerRequestContextCancel_ServeHTTPDone)
5065 }
5066 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5067 ctxc := make(chan context.Context, 1)
5068 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5069 ctx := r.Context()
5070 select {
5071 case <-ctx.Done():
5072 t.Error("should not be Done in ServeHTTP")
5073 default:
5074 }
5075 ctxc <- ctx
5076 }))
5077 res, err := cst.c.Get(cst.ts.URL)
5078 if err != nil {
5079 t.Fatal(err)
5080 }
5081 res.Body.Close()
5082 ctx := <-ctxc
5083 select {
5084 case <-ctx.Done():
5085 default:
5086 t.Error("context should be done after ServeHTTP completes")
5087 }
5088 }
5089
5090
5091
5092
5093
5094 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5095 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5096 }
5097 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5098 inHandler := make(chan struct{})
5099 handlerDone := make(chan struct{})
5100 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5101 close(inHandler)
5102 <-r.Context().Done()
5103 close(handlerDone)
5104 })).ts
5105 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5106 if err != nil {
5107 t.Fatal(err)
5108 }
5109 defer c.Close()
5110 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5111 <-inHandler
5112 c.Close()
5113 <-handlerDone
5114 }
5115
5116 func TestServerContext_ServerContextKey(t *testing.T) {
5117 run(t, testServerContext_ServerContextKey)
5118 }
5119 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5120 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5121 ctx := r.Context()
5122 got := ctx.Value(ServerContextKey)
5123 if _, ok := got.(*Server); !ok {
5124 t.Errorf("context value = %T; want *http.Server", got)
5125 }
5126 }))
5127 res, err := cst.c.Get(cst.ts.URL)
5128 if err != nil {
5129 t.Fatal(err)
5130 }
5131 res.Body.Close()
5132 }
5133
5134 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5135 run(t, testServerContext_LocalAddrContextKey)
5136 }
5137 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5138 ch := make(chan any, 1)
5139 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5140 ch <- r.Context().Value(LocalAddrContextKey)
5141 }))
5142 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5143 t.Fatal(err)
5144 }
5145
5146 host := cst.ts.Listener.Addr().String()
5147 got := <-ch
5148 if addr, ok := got.(net.Addr); !ok {
5149 t.Errorf("local addr value = %T; want net.Addr", got)
5150 } else if fmt.Sprint(addr) != host {
5151 t.Errorf("local addr = %v; want %v", addr, host)
5152 }
5153 }
5154
5155
5156 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5157 setParallel(t)
5158 defer afterTest(t)
5159 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5160 w.Header().Set("Transfer-Encoding", "chunked")
5161 w.Write([]byte("hello"))
5162 }))
5163 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5164 const hdr = "Transfer-Encoding: chunked"
5165 if n := strings.Count(resp, hdr); n != 1 {
5166 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5167 }
5168 }
5169
5170
5171 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5172 setParallel(t)
5173 defer afterTest(t)
5174 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5175 w.Header().Set("Transfer-Encoding", "gzip")
5176 gz := gzip.NewWriter(w)
5177 gz.Write([]byte("hello"))
5178 gz.Close()
5179 }))
5180 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5181 for _, v := range []string{"gzip", "chunked"} {
5182 hdr := "Transfer-Encoding: " + v
5183 if n := strings.Count(resp, hdr); n != 1 {
5184 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5185 }
5186 }
5187 }
5188
5189 func BenchmarkClientServer(b *testing.B) {
5190 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5191 }
5192 func benchmarkClientServer(b *testing.B, mode testMode) {
5193 b.ReportAllocs()
5194 b.StopTimer()
5195 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5196 fmt.Fprintf(rw, "Hello world.\n")
5197 })).ts
5198 b.StartTimer()
5199
5200 c := ts.Client()
5201 for i := 0; i < b.N; i++ {
5202 res, err := c.Get(ts.URL)
5203 if err != nil {
5204 b.Fatal("Get:", err)
5205 }
5206 all, err := io.ReadAll(res.Body)
5207 res.Body.Close()
5208 if err != nil {
5209 b.Fatal("ReadAll:", err)
5210 }
5211 body := string(all)
5212 if body != "Hello world.\n" {
5213 b.Fatal("Got body:", body)
5214 }
5215 }
5216
5217 b.StopTimer()
5218 }
5219
5220 func BenchmarkClientServerParallel(b *testing.B) {
5221 for _, parallelism := range []int{4, 64} {
5222 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5223 run(b, func(b *testing.B, mode testMode) {
5224 benchmarkClientServerParallel(b, parallelism, mode)
5225 }, []testMode{http1Mode, https1Mode, http2Mode})
5226 })
5227 }
5228 }
5229
5230 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5231 b.ReportAllocs()
5232 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5233 fmt.Fprintf(rw, "Hello world.\n")
5234 })).ts
5235 b.ResetTimer()
5236 b.SetParallelism(parallelism)
5237 b.RunParallel(func(pb *testing.PB) {
5238 c := ts.Client()
5239 for pb.Next() {
5240 res, err := c.Get(ts.URL)
5241 if err != nil {
5242 b.Logf("Get: %v", err)
5243 continue
5244 }
5245 all, err := io.ReadAll(res.Body)
5246 res.Body.Close()
5247 if err != nil {
5248 b.Logf("ReadAll: %v", err)
5249 continue
5250 }
5251 body := string(all)
5252 if body != "Hello world.\n" {
5253 panic("Got body: " + body)
5254 }
5255 }
5256 })
5257 }
5258
5259
5260
5261
5262
5263
5264
5265
5266
5267
// BenchmarkServer measures the server by driving it from a separate client
// process.
//
// The parent starts an httptest server and re-executes the test binary with
// TEST_BENCH_SERVER_URL and TEST_BENCH_CLIENT_N set in the environment. The
// child, detected by a non-empty TEST_BENCH_SERVER_URL, issues that many GET
// requests against the URL, verifies each body, and exits.
func BenchmarkServer(b *testing.B) {
	b.ReportAllocs()

	// Child process: act as the HTTP client.
	if url := os.Getenv("TEST_BENCH_SERVER_URL"); url != "" {
		n, err := strconv.Atoi(os.Getenv("TEST_BENCH_CLIENT_N"))
		if err != nil {
			panic(err)
		}
		for i := 0; i < n; i++ {
			res, err := Get(url)
			if err != nil {
				log.Panicf("Get: %v", err)
			}
			all, err := io.ReadAll(res.Body)
			res.Body.Close()
			if err != nil {
				log.Panicf("ReadAll: %v", err)
			}
			body := string(all)
			if body != "Hello world.\n" {
				log.Panicf("Got body: %q", body)
			}
		}
		os.Exit(0)
		return
	}

	// Parent process: start the server outside the timed region.
	var res = []byte("Hello world.\n")
	b.StopTimer()
	ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
		rw.Header().Set("Content-Type", "text/html; charset=utf-8")
		rw.Write(res)
	}))
	defer ts.Close()
	b.StartTimer()

	// Re-run only this benchmark in the child, telling it how many requests
	// to make (b.N) and where the server is listening.
	cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
	cmd.Env = append([]string{
		fmt.Sprintf("TEST_BENCH_CLIENT_N=%d", b.N),
		fmt.Sprintf("TEST_BENCH_SERVER_URL=%s", ts.URL),
	}, os.Environ()...)
	out, err := cmd.CombinedOutput()
	if err != nil {
		b.Errorf("Test failure: %v, with output: %s", err, out)
	}
}
5314
5315
// getNoBody issues a GET for urlStr and closes the response body before
// returning the response. On failure it returns a nil response and the error.
func getNoBody(urlStr string) (*Response, error) {
	resp, getErr := Get(urlStr)
	if getErr != nil {
		return nil, getErr
	}
	// Callers only care about the response metadata, so release the body now.
	resp.Body.Close()
	return resp, nil
}
5324
5325
5326
// BenchmarkClient measures the client by driving it against a server running
// in a separate process.
//
// The parent re-executes the test binary with TEST_BENCH_SERVER set. The
// child listens on localhost (port from TEST_BENCH_SERVER_PORT, or an
// ephemeral port when unset), prints its address on stdout and serves a fixed
// body until it receives a request with a non-empty "stop" form value. The
// parent reads the address, probes the server once, times b.N GET requests,
// then asks the child to stop and waits for it to exit.
func BenchmarkClient(b *testing.B) {
	b.ReportAllocs()
	b.StopTimer()
	defer afterTest(b)

	var data = []byte("Hello world.\n")
	if server := os.Getenv("TEST_BENCH_SERVER"); server != "" {
		// Server process mode.
		port := os.Getenv("TEST_BENCH_SERVER_PORT")
		if port == "" {
			port = "0"
		}
		ln, err := net.Listen("tcp", "localhost:"+port)
		if err != nil {
			fmt.Fprintln(os.Stderr, err.Error())
			os.Exit(1)
		}
		// Tell the parent where we are listening.
		fmt.Println(ln.Addr().String())
		HandleFunc("/", func(w ResponseWriter, r *Request) {
			r.ParseForm()
			if r.Form.Get("stop") != "" {
				os.Exit(0)
			}
			w.Header().Set("Content-Type", "text/html; charset=utf-8")
			w.Write(data)
		})
		var srv Server
		log.Fatal(srv.Serve(ln))
	}

	// Start the server process; canceling ctx is the fallback way to stop it.
	ctx, cancel := context.WithCancel(context.Background())
	cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkClient$")
	cmd.Env = append(cmd.Environ(), "TEST_BENCH_SERVER=yes")
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		b.Fatal(err)
	}
	if err := cmd.Start(); err != nil {
		b.Fatalf("subprocess failed to start: %v", err)
	}

	// done delivers the child's exit status once and is then closed, so the
	// deferred receive below never blocks after the explicit receive at the
	// end of the function.
	done := make(chan error, 1)
	go func() {
		done <- cmd.Wait()
		close(done)
	}()
	defer func() {
		cancel()
		<-done
	}()

	// The child's first line of output is its listening address.
	bs := bufio.NewScanner(stdout)
	if !bs.Scan() {
		b.Fatalf("failed to read listening URL from child: %v", bs.Err())
	}
	url := "http://" + strings.TrimSpace(bs.Text()) + "/"
	if _, err := getNoBody(url); err != nil {
		b.Fatalf("initial probe of child process failed: %v", err)
	}

	// Timed region: one GET per iteration, verifying the body.
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		res, err := Get(url)
		if err != nil {
			b.Fatalf("Get: %v", err)
		}
		body, err := io.ReadAll(res.Body)
		res.Body.Close()
		if err != nil {
			b.Fatalf("ReadAll: %v", err)
		}
		if !bytes.Equal(body, data) {
			b.Fatalf("Got body: %q", body)
		}
	}
	b.StopTimer()

	// Ask the child to exit and wait for it. The request's own result is
	// ignored because the child exits while handling it.
	getNoBody(url + "?stop=yes")
	if err := <-done; err != nil {
		b.Fatalf("subprocess failed: %v", err)
	}
}
5415
5416 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5417 b.ReportAllocs()
5418 req := reqBytes(`GET / HTTP/1.0
5419 Host: golang.org
5420 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5421 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5422 Accept-Encoding: gzip,deflate,sdch
5423 Accept-Language: en-US,en;q=0.8
5424 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5425 `)
5426 res := []byte("Hello world!\n")
5427
5428 conn := newTestConn()
5429 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5430 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5431 rw.Write(res)
5432 })
5433 ln := new(oneConnListener)
5434 for i := 0; i < b.N; i++ {
5435 conn.readBuf.Reset()
5436 conn.writeBuf.Reset()
5437 conn.readBuf.Write(req)
5438 ln.conn = conn
5439 Serve(ln, handler)
5440 <-conn.closec
5441 }
5442 }
5443
5444
// repeatReader is an io.Reader that yields content count times in a row and
// then reports io.EOF.
type repeatReader struct {
	content []byte // bytes to emit on each pass
	count   int    // number of full passes still to emit
	off     int    // read offset within the current pass
}

// Read copies the unread remainder of the current pass into p. A single call
// never crosses a pass boundary; when a pass is fully consumed the offset is
// rewound and the remaining pass count is decremented.
func (r *repeatReader) Read(p []byte) (n int, err error) {
	if r.count < 1 {
		return 0, io.EOF
	}
	remaining := r.content[r.off:]
	n = copy(p, remaining)
	if n == len(remaining) {
		// This pass is finished; start the next one from the beginning.
		r.off = 0
		r.count--
	} else {
		r.off += n
	}
	return n, nil
}
5463
5464 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5465 b.ReportAllocs()
5466
5467 req := reqBytes(`GET / HTTP/1.1
5468 Host: golang.org
5469 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5470 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5471 Accept-Encoding: gzip,deflate,sdch
5472 Accept-Language: en-US,en;q=0.8
5473 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5474 `)
5475 res := []byte("Hello world!\n")
5476
5477 conn := &rwTestConn{
5478 Reader: &repeatReader{content: req, count: b.N},
5479 Writer: io.Discard,
5480 closec: make(chan bool, 1),
5481 }
5482 handled := 0
5483 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5484 handled++
5485 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5486 rw.Write(res)
5487 })
5488 ln := &oneConnListener{conn: conn}
5489 go Serve(ln, handler)
5490 <-conn.closec
5491 if b.N != handled {
5492 b.Errorf("b.N=%d but handled %d", b.N, handled)
5493 }
5494 }
5495
5496
5497
5498 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5499 b.ReportAllocs()
5500
5501 req := reqBytes(`GET / HTTP/1.1
5502 Host: golang.org
5503 `)
5504 res := []byte("Hello world!\n")
5505
5506 conn := &rwTestConn{
5507 Reader: &repeatReader{content: req, count: b.N},
5508 Writer: io.Discard,
5509 closec: make(chan bool, 1),
5510 }
5511 handled := 0
5512 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5513 handled++
5514 rw.Write(res)
5515 })
5516 ln := &oneConnListener{conn: conn}
5517 go Serve(ln, handler)
5518 <-conn.closec
5519 if b.N != handled {
5520 b.Errorf("b.N=%d but handled %d", b.N, handled)
5521 }
5522 }
5523
// someResponse is a small HTML snippet from which the benchmark response body
// is built.
const someResponse = "<html>some response</html>"

// response is someResponse repeated as many whole times as fit in 2<<10
// (2048) bytes.
var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5528
5529
5530 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5531 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5532 w.Header().Set("Content-Type", "text/html")
5533 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5534 w.Write(response)
5535 }))
5536 }
5537
5538
5539 func BenchmarkServerHandlerNoLen(b *testing.B) {
5540 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5541 w.Header().Set("Content-Type", "text/html")
5542 w.Write(response)
5543 }))
5544 }
5545
5546
5547 func BenchmarkServerHandlerNoType(b *testing.B) {
5548 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5549 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5550 w.Write(response)
5551 }))
5552 }
5553
5554
5555 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5556 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5557 w.Write(response)
5558 }))
5559 }
5560
5561 func benchmarkHandler(b *testing.B, h Handler) {
5562 b.ReportAllocs()
5563 req := reqBytes(`GET / HTTP/1.1
5564 Host: golang.org
5565 `)
5566 conn := &rwTestConn{
5567 Reader: &repeatReader{content: req, count: b.N},
5568 Writer: io.Discard,
5569 closec: make(chan bool, 1),
5570 }
5571 handled := 0
5572 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5573 handled++
5574 h.ServeHTTP(rw, r)
5575 })
5576 ln := &oneConnListener{conn: conn}
5577 go Serve(ln, handler)
5578 <-conn.closec
5579 if b.N != handled {
5580 b.Errorf("b.N=%d but handled %d", b.N, handled)
5581 }
5582 }
5583
5584 func BenchmarkServerHijack(b *testing.B) {
5585 b.ReportAllocs()
5586 req := reqBytes(`GET / HTTP/1.1
5587 Host: golang.org
5588 `)
5589 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5590 conn, _, err := w.(Hijacker).Hijack()
5591 if err != nil {
5592 panic(err)
5593 }
5594 conn.Close()
5595 })
5596 conn := &rwTestConn{
5597 Writer: io.Discard,
5598 closec: make(chan bool, 1),
5599 }
5600 ln := &oneConnListener{conn: conn}
5601 for i := 0; i < b.N; i++ {
5602 conn.Reader = bytes.NewReader(req)
5603 ln.conn = conn
5604 Serve(ln, h)
5605 <-conn.closec
5606 }
5607 }
5608
5609 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5610 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5611 b.ReportAllocs()
5612 b.StopTimer()
5613 sawClose := make(chan bool)
5614 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5615 <-rw.(CloseNotifier).CloseNotify()
5616 sawClose <- true
5617 })).ts
5618 b.StartTimer()
5619 for i := 0; i < b.N; i++ {
5620 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5621 if err != nil {
5622 b.Fatalf("error dialing: %v", err)
5623 }
5624 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5625 if err != nil {
5626 b.Fatal(err)
5627 }
5628 conn.Close()
5629 <-sawClose
5630 }
5631 b.StopTimer()
5632 }
5633
5634
5635 func TestConcurrentServerServe(t *testing.T) {
5636 setParallel(t)
5637 for i := 0; i < 100; i++ {
5638 ln1 := &oneConnListener{conn: nil}
5639 ln2 := &oneConnListener{conn: nil}
5640 srv := Server{}
5641 go func() { srv.Serve(ln1) }()
5642 go func() { srv.Serve(ln2) }()
5643 }
5644 }
5645
// TestServerIdleTimeout runs testServerIdleTimeout in HTTP/1 mode only.
func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }

// testServerIdleTimeout exercises Server.IdleTimeout and
// Server.ReadHeaderTimeout. For each candidate duration it configures
// ReadHeaderTimeout = d and IdleTimeout = 2*d, then checks that:
//   - two back-to-back requests share a connection (same RemoteAddr),
//   - after sleeping 1.5x the idle timeout, a third request uses a new
//     connection, and
//   - a raw connection that sends an unterminated header block gets an error
//     on read after 2x ReadHeaderTimeout.
//
// The callback returns an error for timing-dependent failures.
// NOTE(review): runTimeSensitiveTest is not visible here; presumably it
// retries with the next, longer duration on error — confirm.
func testServerIdleTimeout(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}
	runTimeSensitiveTest(t, []time.Duration{
		10 * time.Millisecond,
		100 * time.Millisecond,
		1 * time.Second,
		10 * time.Second,
	}, func(t *testing.T, readHeaderTimeout time.Duration) error {
		cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
			io.Copy(io.Discard, r.Body)
			io.WriteString(w, r.RemoteAddr)
		}), func(ts *httptest.Server) {
			ts.Config.ReadHeaderTimeout = readHeaderTimeout
			ts.Config.IdleTimeout = 2 * readHeaderTimeout
		})
		defer cst.close()
		ts := cst.ts
		t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
		t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
		c := ts.Client()

		// get returns the client address as seen by the server. Request
		// errors are returned to the caller, but a failure while reading
		// the body is treated as fatal.
		get := func() (string, error) {
			res, err := c.Get(ts.URL)
			if err != nil {
				return "", err
			}
			defer res.Body.Close()
			slurp, err := io.ReadAll(res.Body)
			if err != nil {
				t.Fatal(err)
			}
			return string(slurp), nil
		}

		a1, err := get()
		if err != nil {
			return err
		}
		a2, err := get()
		if err != nil {
			return err
		}
		if a1 != a2 {
			return fmt.Errorf("did requests on different connections")
		}
		// Stay idle for longer than IdleTimeout; the next request must
		// arrive on a different connection.
		time.Sleep(ts.Config.IdleTimeout * 3 / 2)
		a3, err := get()
		if err != nil {
			return err
		}
		if a2 == a3 {
			return fmt.Errorf("request three unexpectedly on same connection")
		}

		// Send a header block without its terminating blank line and wait
		// past ReadHeaderTimeout; reading from the connection must then fail.
		conn, err := net.Dial("tcp", ts.Listener.Addr().String())
		if err != nil {
			return err
		}
		defer conn.Close()
		conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
		time.Sleep(ts.Config.ReadHeaderTimeout * 2)
		if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
			return fmt.Errorf("copy byte succeeded; want err")
		}

		return nil
	})
}
5721
// get performs a GET of url with client c and returns the response body as a
// string. Any request or read error fails the test immediately.
func get(t *testing.T, c *Client, url string) string {
	resp, err := c.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	body, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		t.Fatal(readErr)
	}
	return string(body)
}
5734
5735
5736
// TestServerSetKeepAlivesEnabledClosesConns runs the check below in HTTP/1
// mode only.
func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
	run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
}

// testServerSetKeepAlivesEnabledClosesConns checks that calling
// SetKeepAlivesEnabled(false) on the server results in the client transport
// having no idle connections left. It first confirms that two requests share
// one connection and that the transport holds exactly one idle connection.
func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		io.WriteString(w, r.RemoteAddr)
	})).ts

	c := ts.Client()
	tr := c.Transport.(*Transport)

	// The local get wraps the package-level get helper; inside the function
	// literal the name still refers to the package-level function.
	get := func() string { return get(t, c, ts.URL) }

	a1, a2 := get(), get()
	if a1 == a2 {
		t.Logf("made two requests from a single conn %q (as expected)", a1)
	} else {
		t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
	}

	// Before disabling keep-alives the transport should hold one idle conn.
	if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
		t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
	}

	ts.Config.SetKeepAlivesEnabled(false)

	// Poll until the transport reports no idle connections.
	waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
		if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
			if d > 0 {
				t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
			}
			return false
		}
		return true
	})
}
5782
// TestServerShutdown runs testServerShutdown through the run helper.
func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) }

// testServerShutdown calls Server.Shutdown from inside the first request's
// handler and checks that:
//   - once the OnShutdown callback has fired, new requests to the server fail
//     (in HTTP/2 mode an unexpected success is only logged and retried),
//   - Shutdown returns nil after the in-flight request completes, and
//   - the connection-state snapshot taken in the handler shows exactly one
//     connection in StateActive.
func testServerShutdown(t *testing.T, mode testMode) {
	var cst *clientServerTest

	var once sync.Once
	statesRes := make(chan map[ConnState]int, 1)
	shutdownRes := make(chan error, 1)
	gotOnShutdown := make(chan struct{})
	handler := HandlerFunc(func(w ResponseWriter, r *Request) {
		first := false
		// Only the first request snapshots the states and starts Shutdown.
		once.Do(func() {
			statesRes <- cst.ts.Config.ExportAllConnsByState()
			go func() {
				shutdownRes <- cst.ts.Config.Shutdown(context.Background())
			}()
			first = true
		})

		if first {
			// Wait until the OnShutdown callback has run before probing.
			<-gotOnShutdown

			// Requests issued now are expected to fail. Keep trying until
			// one does, or until the test has already been marked failed.
			for !t.Failed() {
				res, err := cst.c.Get(cst.ts.URL)
				if err != nil {
					break
				}
				out, _ := io.ReadAll(res.Body)
				res.Body.Close()
				if mode == http2Mode {
					t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
					t.Logf("Retrying to work around https://go.dev/issue/59038.")
					continue
				}
				t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
			}
		}

		io.WriteString(w, r.RemoteAddr)
	})

	cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
		srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
	})

	// This request triggers the shutdown from within its handler.
	out := get(t, cst.c, cst.ts.URL)
	t.Logf("%v: %q", cst.ts.URL, out)

	if err := <-shutdownRes; err != nil {
		t.Fatalf("Shutdown: %v", err)
	}
	<-gotOnShutdown

	if states := <-statesRes; states[StateActive] != 1 {
		t.Errorf("connection in wrong state, %v", states)
	}
}
5843
// TestServerShutdownStateNew runs testServerShutdownStateNew through runSynctest.
func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }

// testServerShutdownStateNew opens a connection to the server without sending
// any request, then calls Shutdown. It checks that Shutdown has not returned
// and the connection is still open just before expectTimeout (5s) elapses,
// and that two seconds later Shutdown has completed without error and the
// server has closed the connection.
// NOTE(review): fakeNetListen and runAsync are defined elsewhere; the notes
// below describe only how they are used here.
func testServerShutdownStateNew(t testing.TB, mode testMode) {
	if testing.Short() {
		t.Skip("test takes 5-6 seconds; skipping in short mode")
	}

	listener := fakeNetListen()
	defer listener.Close()

	ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
		// The handler body is empty; this test never sends a request.
	}), func(ts *httptest.Server) {
		// Swap the server's listener for the fake one and silence its error log.
		ts.Listener.Close()
		ts.Listener = listener
		ts.Config.ErrorLog = log.New(io.Discard, "", 0)
	}).ts

	// Connect but send nothing.
	c := listener.connect()
	defer c.Close()
	synctest.Wait()

	shutdownRes := runAsync(func() (struct{}, error) {
		return struct{}{}, ts.Config.Shutdown(context.Background())
	})

	const expectTimeout = 5 * time.Second

	// Sleep until one nanosecond short of expectTimeout; neither Shutdown nor
	// the connection close may have happened yet.
	time.Sleep(expectTimeout - 1)
	synctest.Wait()
	if shutdownRes.done() {
		t.Fatal("shutdown too soon")
	}
	if c.IsClosedByPeer() {
		t.Fatal("connection was closed by server too soon")
	}

	// Two more seconds takes us past expectTimeout; both must now have happened.
	time.Sleep(2 * time.Second)
	synctest.Wait()
	if _, err := shutdownRes.result(); err != nil {
		t.Fatalf("Shutdown() = %v, want complete", err)
	}
	if !c.IsClosedByPeer() {
		t.Fatalf("connection was not closed by server after shutdown")
	}
}
5899
5900
// TestServerCloseDeadlock calls Close twice on a zero-value Server that was
// never started. The test passes as long as both calls return.
func TestServerCloseDeadlock(t *testing.T) {
	srv := new(Server)
	for i := 0; i < 2; i++ {
		srv.Close()
	}
}
5906
5907
5908
5909 func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) }
5910 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
5911 if mode == http2Mode {
5912 restore := ExportSetH2GoawayTimeout(10 * time.Millisecond)
5913 defer restore()
5914 }
5915
5916 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
5917 defer cst.close()
5918 srv := cst.ts.Config
5919 srv.SetKeepAlivesEnabled(false)
5920 for try := 0; try < 2; try++ {
5921 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5922 if !srv.ExportAllConnsIdle() {
5923 if d > 0 {
5924 t.Logf("test server still has active conns after %v", d)
5925 }
5926 return false
5927 }
5928 return true
5929 })
5930 conns := 0
5931 var info httptrace.GotConnInfo
5932 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
5933 GotConn: func(v httptrace.GotConnInfo) {
5934 conns++
5935 info = v
5936 },
5937 })
5938 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
5939 if err != nil {
5940 t.Fatal(err)
5941 }
5942 res, err := cst.c.Do(req)
5943 if err != nil {
5944 t.Fatal(err)
5945 }
5946 res.Body.Close()
5947 if conns != 1 {
5948 t.Fatalf("request %v: got %v conns, want 1", try, conns)
5949 }
5950 if info.Reused || info.WasIdle {
5951 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
5952 }
5953 }
5954 }
5955
5956
5957
5958
5959 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
5960 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
5961 runTimeSensitiveTest(t, []time.Duration{
5962 10 * time.Millisecond,
5963 50 * time.Millisecond,
5964 250 * time.Millisecond,
5965 time.Second,
5966 2 * time.Second,
5967 }, func(t *testing.T, timeout time.Duration) error {
5968 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5969 select {
5970 case <-time.After(2 * timeout):
5971 fmt.Fprint(w, "ok")
5972 case <-r.Context().Done():
5973 fmt.Fprint(w, r.Context().Err())
5974 }
5975 }), func(ts *httptest.Server) {
5976 ts.Config.ReadTimeout = timeout
5977 t.Logf("Server.Config.ReadTimeout = %v", timeout)
5978 })
5979 defer cst.close()
5980 ts := cst.ts
5981
5982 var retries atomic.Int32
5983 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
5984 if retries.Add(1) != 1 {
5985 return nil, errors.New("too many retries")
5986 }
5987 return nil, nil
5988 }
5989
5990 c := ts.Client()
5991
5992 res, err := c.Get(ts.URL)
5993 if err != nil {
5994 return fmt.Errorf("Get: %v", err)
5995 }
5996 slurp, err := io.ReadAll(res.Body)
5997 res.Body.Close()
5998 if err != nil {
5999 return fmt.Errorf("Body ReadAll: %v", err)
6000 }
6001 if string(slurp) != "ok" {
6002 return fmt.Errorf("got: %q, want ok", slurp)
6003 }
6004 return nil
6005 })
6006 }
6007
6008
6009
6010
6011 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
6012 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
6013 }
6014 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
6015 runTimeSensitiveTest(t, []time.Duration{
6016 10 * time.Millisecond,
6017 50 * time.Millisecond,
6018 250 * time.Millisecond,
6019 time.Second,
6020 2 * time.Second,
6021 }, func(t *testing.T, timeout time.Duration) error {
6022 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
6023 ts.Config.ReadHeaderTimeout = timeout
6024 ts.Config.IdleTimeout = 0
6025 })
6026 defer cst.close()
6027 ts := cst.ts
6028
6029
6030
6031 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6032 if err != nil {
6033 t.Fatalf("dial failed: %v", err)
6034 }
6035 br := bufio.NewReader(conn)
6036 defer conn.Close()
6037
6038 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6039 return fmt.Errorf("writing first request failed: %v", err)
6040 }
6041
6042 if _, err := ReadResponse(br, nil); err != nil {
6043 return fmt.Errorf("first response (before timeout) failed: %v", err)
6044 }
6045
6046
6047
6048 time.Sleep(timeout * 3 / 2)
6049
6050 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6051 return fmt.Errorf("writing second request failed: %v", err)
6052 }
6053
6054 if _, err := ReadResponse(br, nil); err != nil {
6055 return fmt.Errorf("second response (after timeout) failed: %v", err)
6056 }
6057
6058 return nil
6059 })
6060 }
6061
6062
6063
// runTimeSensitiveTest calls test once per entry of durations, in order, and
// stops at the first call that returns nil. A non-nil error is only logged
// and the next duration is tried, unless it came from the last duration or t
// has already been marked failed; in those cases the error is fatal.
func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
	last := len(durations) - 1
	for idx, dur := range durations {
		err := test(t, dur)
		switch {
		case err == nil:
			return
		case idx == last || t.Failed():
			t.Fatalf("failed with duration %v: %v", dur, err)
		default:
			t.Logf("retrying after error with duration %v: %v", dur, err)
		}
	}
}
6076
6077
6078
6079 func TestServerDuplicateBackgroundRead(t *testing.T) {
6080 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6081 }
6082 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6083 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6084 testenv.SkipFlaky(t, 24826)
6085 }
6086
6087 goroutines := 5
6088 requests := 2000
6089 if testing.Short() {
6090 goroutines = 3
6091 requests = 100
6092 }
6093
6094 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6095
6096 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6097
6098 var wg sync.WaitGroup
6099 for i := 0; i < goroutines; i++ {
6100 wg.Add(1)
6101 go func() {
6102 defer wg.Done()
6103 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6104 if err != nil {
6105 t.Error(err)
6106 return
6107 }
6108 defer cn.Close()
6109
6110 wg.Add(1)
6111 go func() {
6112 defer wg.Done()
6113 io.Copy(io.Discard, cn)
6114 }()
6115
6116 for j := 0; j < requests; j++ {
6117 if t.Failed() {
6118 return
6119 }
6120 _, err := cn.Write(reqBytes)
6121 if err != nil {
6122 t.Error(err)
6123 return
6124 }
6125 }
6126 }()
6127 }
6128 wg.Wait()
6129 }
6130
6131
6132
6133
6134
6135
6136 func TestServerHijackGetsBackgroundByte(t *testing.T) {
6137 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
6138 }
6139 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
6140 if runtime.GOOS == "plan9" {
6141 t.Skip("skipping test; see https://golang.org/issue/18657")
6142 }
6143 done := make(chan struct{})
6144 inHandler := make(chan bool, 1)
6145 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6146 defer close(done)
6147
6148
6149 inHandler <- true
6150
6151 conn, buf, err := w.(Hijacker).Hijack()
6152 if err != nil {
6153 t.Error(err)
6154 return
6155 }
6156 defer conn.Close()
6157
6158 peek, err := buf.Reader.Peek(3)
6159 if string(peek) != "foo" || err != nil {
6160 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
6161 }
6162
6163 select {
6164 case <-r.Context().Done():
6165 t.Error("context unexpectedly canceled")
6166 default:
6167 }
6168 })).ts
6169
6170 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6171 if err != nil {
6172 t.Fatal(err)
6173 }
6174 defer cn.Close()
6175 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6176 t.Fatal(err)
6177 }
6178 <-inHandler
6179 if _, err := cn.Write([]byte("foo")); err != nil {
6180 t.Fatal(err)
6181 }
6182
6183 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6184 t.Fatal(err)
6185 }
6186 <-done
6187 }
6188
6189
6190 func TestServerHijackGetsFullBody(t *testing.T) {
6191 run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
6192 }
6193 func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
6194 if runtime.GOOS == "plan9" {
6195 t.Skip("skipping test; see https://golang.org/issue/18657")
6196 }
6197 done := make(chan struct{})
6198 needle := strings.Repeat("x", 100*1024)
6199 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6200 defer close(done)
6201
6202 conn, buf, err := w.(Hijacker).Hijack()
6203 if err != nil {
6204 t.Error(err)
6205 return
6206 }
6207 defer conn.Close()
6208
6209 got := make([]byte, len(needle))
6210 n, err := io.ReadFull(buf.Reader, got)
6211 if n != len(needle) || string(got) != needle || err != nil {
6212 t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
6213 }
6214 })).ts
6215
6216 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6217 if err != nil {
6218 t.Fatal(err)
6219 }
6220 defer cn.Close()
6221 buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6222 buf = append(buf, []byte(needle)...)
6223 if _, err := cn.Write(buf); err != nil {
6224 t.Fatal(err)
6225 }
6226
6227 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6228 t.Fatal(err)
6229 }
6230 <-done
6231 }
6232
6233
6234
6235
6236 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6237 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6238 }
6239 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6240 if runtime.GOOS == "plan9" {
6241 t.Skip("skipping test; see https://golang.org/issue/18657")
6242 }
6243 done := make(chan struct{})
6244 const size = 8 << 10
6245 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6246 defer close(done)
6247
6248 conn, buf, err := w.(Hijacker).Hijack()
6249 if err != nil {
6250 t.Error(err)
6251 return
6252 }
6253 defer conn.Close()
6254 slurp, err := io.ReadAll(buf.Reader)
6255 if err != nil {
6256 t.Errorf("Copy: %v", err)
6257 }
6258 allX := true
6259 for _, v := range slurp {
6260 if v != 'x' {
6261 allX = false
6262 }
6263 }
6264 if len(slurp) != size {
6265 t.Errorf("read %d; want %d", len(slurp), size)
6266 } else if !allX {
6267 t.Errorf("read %q; want %d 'x'", slurp, size)
6268 }
6269 })).ts
6270
6271 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6272 if err != nil {
6273 t.Fatal(err)
6274 }
6275 defer cn.Close()
6276 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6277 strings.Repeat("x", size)); err != nil {
6278 t.Fatal(err)
6279 }
6280 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6281 t.Fatal(err)
6282 }
6283
6284 <-done
6285 }
6286
6287
6288 func TestServerValidatesMethod(t *testing.T) {
6289 tests := []struct {
6290 method string
6291 want int
6292 }{
6293 {"GET", 200},
6294 {"GE(T", 400},
6295 }
6296 for _, tt := range tests {
6297 conn := newTestConn()
6298 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6299
6300 ln := &oneConnListener{conn}
6301 go Serve(ln, serve(200))
6302 <-conn.closec
6303 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6304 if err != nil {
6305 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6306 continue
6307 }
6308 if res.StatusCode != tt.want {
6309 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6310 }
6311 }
6312 }
6313
6314
// eofListenerNotComparable is a listener whose Accept always fails with
// io.EOF. Its underlying type is a slice, so values of it cannot be compared
// with ==.
type eofListenerNotComparable []int

// Accept never yields a connection; it always returns io.EOF.
func (eofListenerNotComparable) Accept() (net.Conn, error) {
	return nil, io.EOF
}

// Addr returns a nil address.
func (eofListenerNotComparable) Addr() net.Addr {
	return nil
}

// Close does nothing and reports success.
func (eofListenerNotComparable) Close() error {
	return nil
}
6320
6321
6322 func TestServerListenNotComparableListener(t *testing.T) {
6323 var s Server
6324 s.Serve(make(eofListenerNotComparable, 1))
6325 }
6326
6327
// countCloseListener wraps a net.Listener and counts calls to Close.
type countCloseListener struct {
	net.Listener
	closes int32 // number of Close calls; updated atomically
}

// Close increments the call counter. Only the first call is forwarded to the
// wrapped listener, and only if one is set; every other call returns nil.
func (p *countCloseListener) Close() error {
	first := atomic.AddInt32(&p.closes, 1) == 1
	if !first || p.Listener == nil {
		return nil
	}
	return p.Listener.Close()
}
6340
6341
6342 func TestServerCloseListenerOnce(t *testing.T) {
6343 setParallel(t)
6344 defer afterTest(t)
6345
6346 ln := newLocalListener(t)
6347 defer ln.Close()
6348
6349 cl := &countCloseListener{Listener: ln}
6350 server := &Server{}
6351 sdone := make(chan bool, 1)
6352
6353 go func() {
6354 server.Serve(cl)
6355 sdone <- true
6356 }()
6357 time.Sleep(10 * time.Millisecond)
6358 server.Shutdown(context.Background())
6359 ln.Close()
6360 <-sdone
6361
6362 nclose := atomic.LoadInt32(&cl.closes)
6363 if nclose != 1 {
6364 t.Errorf("Close calls = %v; want 1", nclose)
6365 }
6366 }
6367
6368
6369 func TestServerShutdownThenServe(t *testing.T) {
6370 var srv Server
6371 cl := &countCloseListener{Listener: nil}
6372 srv.Shutdown(context.Background())
6373 got := srv.Serve(cl)
6374 if got != ErrServerClosed {
6375 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6376 }
6377 nclose := atomic.LoadInt32(&cl.closes)
6378 if nclose != 1 {
6379 t.Errorf("Close calls = %v; want 1", nclose)
6380 }
6381 }
6382
6383
// TestStripPortFromHost registers one handler for the host pattern
// "example.com/" and another for "example.com:9000/", then sends a request
// for http://example.com:9000/ through the mux. The handler registered
// without the port must be the one that answers.
func TestStripPortFromHost(t *testing.T) {
	mux := NewServeMux()
	mux.HandleFunc("example.com/", func(w ResponseWriter, _ *Request) {
		fmt.Fprintf(w, "OK")
	})
	mux.HandleFunc("example.com:9000/", func(w ResponseWriter, _ *Request) {
		fmt.Fprintf(w, "uh-oh!")
	})

	rec := httptest.NewRecorder()
	mux.ServeHTTP(rec, httptest.NewRequest("GET", "http://example.com:9000/", nil))

	if got := rec.Body.String(); got != "OK" {
		t.Errorf("Response gotten was %q", got)
	}
}
6404
6405 func TestServerContexts(t *testing.T) { run(t, testServerContexts) }
6406 func testServerContexts(t *testing.T, mode testMode) {
6407 type baseKey struct{}
6408 type connKey struct{}
6409 ch := make(chan context.Context, 1)
6410 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6411 ch <- r.Context()
6412 }), func(ts *httptest.Server) {
6413 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6414 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6415 t.Errorf("unexpected onceClose listener type %T", ln)
6416 }
6417 return context.WithValue(context.Background(), baseKey{}, "base")
6418 }
6419 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6420 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6421 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6422 }
6423 return context.WithValue(ctx, connKey{}, "conn")
6424 }
6425 }).ts
6426 res, err := ts.Client().Get(ts.URL)
6427 if err != nil {
6428 t.Fatal(err)
6429 }
6430 res.Body.Close()
6431 ctx := <-ch
6432 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6433 t.Errorf("base context key = %#v; want %q", got, want)
6434 }
6435 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6436 t.Errorf("conn context key = %#v; want %q", got, want)
6437 }
6438 }
6439
6440
6441 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6442 run(t, testConnContextNotModifyingAllContexts)
6443 }
6444 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6445 type connKey struct{}
6446 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6447 rw.Header().Set("Connection", "close")
6448 }), func(ts *httptest.Server) {
6449 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6450 if got := ctx.Value(connKey{}); got != nil {
6451 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6452 }
6453 return context.WithValue(ctx, connKey{}, "conn")
6454 }
6455 }).ts
6456
6457 var res *Response
6458 var err error
6459
6460 res, err = ts.Client().Get(ts.URL)
6461 if err != nil {
6462 t.Fatal(err)
6463 }
6464 res.Body.Close()
6465
6466 res, err = ts.Client().Get(ts.URL)
6467 if err != nil {
6468 t.Fatal(err)
6469 }
6470 res.Body.Close()
6471 }
6472
6473
6474
6475 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6476 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6477 }
6478 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6479 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6480 w.Write([]byte("Hello, World!"))
6481 })).ts
6482
6483 serverURL, err := url.Parse(cst.URL)
6484 if err != nil {
6485 t.Fatalf("Failed to parse server URL: %v", err)
6486 }
6487
6488 unsupportedTEs := []string{
6489 "fugazi",
6490 "foo-bar",
6491 "unknown",
6492 `" chunked"`,
6493 }
6494
6495 for _, badTE := range unsupportedTEs {
6496 http1ReqBody := fmt.Sprintf(""+
6497 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6498 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6499
6500 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6501 if err != nil {
6502 t.Errorf("%q. unexpected error: %v", badTE, err)
6503 continue
6504 }
6505
6506 wantBody := fmt.Sprintf("" +
6507 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6508 "Connection: close\r\n\r\nUnsupported transfer encoding")
6509
6510 if string(gotBody) != wantBody {
6511 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6512 }
6513 }
6514 }
6515
6516
6517 func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) }
6518 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6519 type setting struct {
6520 name string
6521 body []byte
6522
6523
6524
6525
6526 contentEncoding any
6527 wantContentType string
6528 }
6529
6530 settings := []*setting{
6531 {
6532 name: "gzip content-encoding, gzipped",
6533 contentEncoding: "application/gzip",
6534 wantContentType: "",
6535 body: func() []byte {
6536 buf := new(bytes.Buffer)
6537 gzw := gzip.NewWriter(buf)
6538 gzw.Write([]byte("doctype html><p>Hello</p>"))
6539 gzw.Close()
6540 return buf.Bytes()
6541 }(),
6542 },
6543 {
6544 name: "zlib content-encoding, zlibbed",
6545 contentEncoding: "application/zlib",
6546 wantContentType: "",
6547 body: func() []byte {
6548 buf := new(bytes.Buffer)
6549 zw := zlib.NewWriter(buf)
6550 zw.Write([]byte("doctype html><p>Hello</p>"))
6551 zw.Close()
6552 return buf.Bytes()
6553 }(),
6554 },
6555 {
6556 name: "no content-encoding",
6557 wantContentType: "application/x-gzip",
6558 body: func() []byte {
6559 buf := new(bytes.Buffer)
6560 gzw := gzip.NewWriter(buf)
6561 gzw.Write([]byte("doctype html><p>Hello</p>"))
6562 gzw.Close()
6563 return buf.Bytes()
6564 }(),
6565 },
6566 {
6567 name: "phony content-encoding",
6568 contentEncoding: "foo/bar",
6569 body: []byte("doctype html><p>Hello</p>"),
6570 },
6571 {
6572 name: "empty but set content-encoding",
6573 contentEncoding: "",
6574 wantContentType: "audio/mpeg",
6575 body: []byte("ID3"),
6576 },
6577 }
6578
6579 for _, tt := range settings {
6580 t.Run(tt.name, func(t *testing.T) {
6581 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6582 if tt.contentEncoding != nil {
6583 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6584 }
6585 rw.Write(tt.body)
6586 }))
6587
6588 res, err := cst.c.Get(cst.ts.URL)
6589 if err != nil {
6590 t.Fatalf("Failed to fetch URL: %v", err)
6591 }
6592 defer res.Body.Close()
6593
6594 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6595 if w != nil {
6596 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6597 } else if g != "" {
6598 t.Errorf("Unexpected Content-Encoding %q", g)
6599 }
6600 }
6601
6602 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6603 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6604 }
6605 })
6606 }
6607 }
6608
6609
6610
// TestTimeoutHandlerSuperfluousLogs runs testTimeoutHandlerSuperfluousLogs in
// HTTP/1 mode only.
func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
	run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
}

// testTimeoutHandlerSuperfluousLogs wraps a handler that calls WriteHeader
// four times in a TimeoutHandler. For both the case where the handler returns
// before the timeout and the case where it does not, it checks the dumped
// response and that the server log holds exactly three
// "superfluous response.WriteHeader call" entries attributed to this
// function, this file, and the lines of the repeated calls.
//
// The expected line numbers are computed from runtime.Caller inside the
// handler, so the four WriteHeader calls and the runtime.Caller call must
// stay on consecutive source lines.
func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
	if testing.Short() {
		t.Skip("skipping in short mode")
	}

	// Capture this function's name and file for the log pattern built below.
	pc, curFile, _, _ := runtime.Caller(0)
	curFileBaseName := filepath.Base(curFile)
	testFuncName := runtime.FuncForPC(pc).Name()

	timeoutMsg := "timed out here!"

	tests := []struct {
		name        string
		mustTimeout bool
		wantResp    string
	}{
		{
			name:     "return before timeout",
			wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
		},
		{
			name:        "return after timeout",
			mustTimeout: true,
			wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
				len(timeoutMsg), timeoutMsg),
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			exitHandler := make(chan bool, 1)
			// Closing the channel on return releases a handler that is still
			// blocked on the receive below.
			defer close(exitHandler)
			lastLine := make(chan int, 1)

			// Do not insert lines between the WriteHeader calls and the
			// runtime.Caller call: the line arithmetic further down relies on
			// them being consecutive.
			sh := HandlerFunc(func(w ResponseWriter, r *Request) {
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				w.WriteHeader(404)
				_, _, line, _ := runtime.Caller(0)
				lastLine <- line
				<-exitHandler
			})

			if !tt.mustTimeout {
				// Pre-load the channel so the handler returns as soon as it
				// reaches the receive.
				exitHandler <- true
			}

			logBuf := new(strings.Builder)
			srvLog := log.New(logBuf, "", 0)
			// Short timeout for the case that must time out: the handler
			// stays blocked on exitHandler until the test returns.
			dur := 20 * time.Millisecond
			if !tt.mustTimeout {
				// Long timeout for the case where the handler is expected to
				// return first.
				dur = 10 * time.Second
			}
			th := TimeoutHandler(sh, dur, timeoutMsg)
			cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
			defer cst.close()

			res, err := cst.c.Get(cst.ts.URL)
			if err != nil {
				t.Fatalf("Unexpected error: %v", err)
			}

			// Remove the Date and Content-Type headers before dumping;
			// wantResp contains neither.
			res.Header.Del("Date")
			res.Header.Del("Content-Type")

			// Compare the dumped response, body included, with wantResp.
			blob, _ := httputil.DumpResponse(res, true)
			if g, w := string(blob), tt.wantResp; g != w {
				t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
			}

			// The first WriteHeader call is legitimate; the other three
			// should each have produced one log line.
			logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
			if g, w := len(logEntries), 3; g != w {
				blob, _ := json.MarshalIndent(logEntries, "", "  ")
				t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
			}

			// lastLine carries the line of the runtime.Caller call, which
			// directly follows the four WriteHeader calls. Subtracting 3
			// gives the line of the second WriteHeader call.
			lastSpuriousLine := <-lastLine
			firstSpuriousLine := lastSpuriousLine - 3

			// Entry i must name this function's closure, this file, and the
			// line of the (i+2)th WriteHeader call.
			for i, logEntry := range logEntries {
				wantLine := firstSpuriousLine + i
				pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
					testFuncName, curFileBaseName, wantLine)
				re := regexp.MustCompile(pat)
				if !re.MatchString(logEntry) {
					t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
				}
			}
		})
	}
}
6715
6716
6717
6718
// fetchWireResponse dials host over TCP, writes http1ReqBody to the
// connection as-is, and returns every byte read from the connection until it
// reaches EOF. The connection is closed before returning.
func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
	c, dialErr := net.Dial("tcp", host)
	if dialErr != nil {
		return nil, dialErr
	}
	defer c.Close()

	_, writeErr := c.Write(http1ReqBody)
	if writeErr != nil {
		return nil, writeErr
	}
	return io.ReadAll(c)
}
6731
6732 func BenchmarkResponseStatusLine(b *testing.B) {
6733 b.ReportAllocs()
6734 b.RunParallel(func(pb *testing.PB) {
6735 bw := bufio.NewWriter(io.Discard)
6736 var buf3 [3]byte
6737 for pb.Next() {
6738 Export_writeStatusLine(bw, true, 200, buf3[:])
6739 }
6740 })
6741 }
6742
6743 func TestDisableKeepAliveUpgrade(t *testing.T) {
6744 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6745 }
6746 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6747 if testing.Short() {
6748 t.Skip("skipping in short mode")
6749 }
6750
6751 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6752 w.Header().Set("Connection", "Upgrade")
6753 w.Header().Set("Upgrade", "someProto")
6754 w.WriteHeader(StatusSwitchingProtocols)
6755 c, buf, err := w.(Hijacker).Hijack()
6756 if err != nil {
6757 return
6758 }
6759 defer c.Close()
6760
6761
6762
6763 io.Copy(c, buf)
6764 }), func(ts *httptest.Server) {
6765 ts.Config.SetKeepAlivesEnabled(false)
6766 }).ts
6767
6768 cl := s.Client()
6769 cl.Transport.(*Transport).DisableKeepAlives = true
6770
6771 resp, err := cl.Get(s.URL)
6772 if err != nil {
6773 t.Fatalf("failed to perform request: %v", err)
6774 }
6775 defer resp.Body.Close()
6776
6777 if resp.StatusCode != StatusSwitchingProtocols {
6778 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6779 }
6780
6781 rwc, ok := resp.Body.(io.ReadWriteCloser)
6782 if !ok {
6783 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6784 }
6785
6786 _, err = rwc.Write([]byte("hello"))
6787 if err != nil {
6788 t.Fatalf("failed to write to body: %v", err)
6789 }
6790
6791 b := make([]byte, 5)
6792 _, err = io.ReadFull(rwc, b)
6793 if err != nil {
6794 t.Fatalf("failed to read from body: %v", err)
6795 }
6796
6797 if string(b) != "hello" {
6798 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6799 }
6800 }
6801
// tlogWriter is an io.Writer that forwards everything written to it to the
// log of a test.
type tlogWriter struct {
	t *testing.T
}

// Write logs p as a string through t.Log and reports all of p as written.
func (w tlogWriter) Write(p []byte) (int, error) {
	n := len(p)
	w.t.Log(string(p))
	return n, nil
}
6808
6809 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6810 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6811 }
6812 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6813 const wantBody = "want"
6814 const wantUpgrade = "someProto"
6815 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6816 w.Header().Set("Connection", "Upgrade")
6817 w.Header().Set("Upgrade", wantUpgrade)
6818 w.WriteHeader(StatusSwitchingProtocols)
6819 NewResponseController(w).Flush()
6820
6821
6822 w.WriteHeader(200)
6823 if _, err := w.Write([]byte("x")); err == nil {
6824 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6825 }
6826
6827 c, _, err := NewResponseController(w).Hijack()
6828 if err != nil {
6829 t.Errorf("Hijack: %v", err)
6830 return
6831 }
6832 defer c.Close()
6833 if _, err := c.Write([]byte(wantBody)); err != nil {
6834 t.Errorf("Write to hijacked body: %v", err)
6835 }
6836 }), func(ts *httptest.Server) {
6837
6838 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
6839 }).ts
6840
6841 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6842 if err != nil {
6843 t.Fatalf("net.Dial: %v", err)
6844 }
6845 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
6846 if err != nil {
6847 t.Fatalf("conn.Write: %v", err)
6848 }
6849 defer conn.Close()
6850
6851 r := bufio.NewReader(conn)
6852 res, err := ReadResponse(r, &Request{Method: "GET"})
6853 if err != nil {
6854 t.Fatal("ReadResponse error:", err)
6855 }
6856 if res.StatusCode != StatusSwitchingProtocols {
6857 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
6858 }
6859 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
6860 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
6861 }
6862 body, err := io.ReadAll(r)
6863 if err != nil {
6864 t.Error(err)
6865 }
6866 if string(body) != wantBody {
6867 t.Errorf("Response body = %q, want %q", string(body), wantBody)
6868 }
6869 }
6870
6871 func TestMuxRedirectRelative(t *testing.T) {
6872 setParallel(t)
6873 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
6874 if err != nil {
6875 t.Errorf("%s", err)
6876 }
6877 mux := NewServeMux()
6878 resp := httptest.NewRecorder()
6879 mux.ServeHTTP(resp, req)
6880 if got, want := resp.Header().Get("Location"), "/"; got != want {
6881 t.Errorf("Location header expected %q; got %q", want, got)
6882 }
6883 if got, want := resp.Code, StatusMovedPermanently; got != want {
6884 t.Errorf("Expected response code %d; got %d", want, got)
6885 }
6886 }
6887
6888
6889 func TestQuerySemicolon(t *testing.T) {
6890 t.Cleanup(func() { afterTest(t) })
6891
6892 tests := []struct {
6893 query string
6894 xNoSemicolons string
6895 xWithSemicolons string
6896 expectParseFormErr bool
6897 }{
6898 {"?a=1;x=bad&x=good", "good", "bad", true},
6899 {"?a=1;b=bad&x=good", "good", "good", true},
6900 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
6901 {"?a=1;x=good;x=bad", "", "good", true},
6902 }
6903
6904 run(t, func(t *testing.T, mode testMode) {
6905 for _, tt := range tests {
6906 t.Run(tt.query+"/allow=false", func(t *testing.T) {
6907 allowSemicolons := false
6908 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
6909 })
6910 t.Run(tt.query+"/allow=true", func(t *testing.T) {
6911 allowSemicolons, expectParseFormErr := true, false
6912 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
6913 })
6914 }
6915 })
6916 }
6917
6918 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
6919 writeBackX := func(w ResponseWriter, r *Request) {
6920 x := r.URL.Query().Get("x")
6921 if expectParseFormErr {
6922 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
6923 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
6924 }
6925 } else {
6926 if err := r.ParseForm(); err != nil {
6927 t.Errorf("expected no error from ParseForm, got %v", err)
6928 }
6929 }
6930 if got := r.FormValue("x"); x != got {
6931 t.Errorf("got %q from FormValue, want %q", got, x)
6932 }
6933 fmt.Fprintf(w, "%s", x)
6934 }
6935
6936 h := Handler(HandlerFunc(writeBackX))
6937 if allowSemicolons {
6938 h = AllowQuerySemicolons(h)
6939 }
6940
6941 logBuf := &strings.Builder{}
6942 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
6943 ts.Config.ErrorLog = log.New(logBuf, "", 0)
6944 }).ts
6945
6946 req, _ := NewRequest("GET", ts.URL+query, nil)
6947 res, err := ts.Client().Do(req)
6948 if err != nil {
6949 t.Fatal(err)
6950 }
6951 slurp, _ := io.ReadAll(res.Body)
6952 res.Body.Close()
6953 if got, want := res.StatusCode, 200; got != want {
6954 t.Errorf("Status = %d; want = %d", got, want)
6955 }
6956 if got, want := string(slurp), wantX; got != want {
6957 t.Errorf("Body = %q; want = %q", got, want)
6958 }
6959 }
6960
6961 func TestMaxBytesHandler(t *testing.T) {
6962
6963 defer afterTest(t)
6964
6965 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
6966 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
6967 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
6968 func(t *testing.T) {
6969 run(t, func(t *testing.T, mode testMode) {
6970 testMaxBytesHandler(t, mode, maxSize, requestSize)
6971 }, testNotParallel)
6972 })
6973 }
6974 }
6975 }
6976
6977 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
6978 runTimeSensitiveTest(t, []time.Duration{
6979 1 * time.Millisecond,
6980 5 * time.Millisecond,
6981 10 * time.Millisecond,
6982 50 * time.Millisecond,
6983 100 * time.Millisecond,
6984 500 * time.Millisecond,
6985 time.Second,
6986 5 * time.Second,
6987 }, func(t *testing.T, timeout time.Duration) error {
6988 SetRSTAvoidanceDelay(t, timeout)
6989 t.Logf("set RST avoidance delay to %v", timeout)
6990
6991 var (
6992 handlerN int64
6993 handlerErr error
6994 )
6995 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
6996 var buf bytes.Buffer
6997 handlerN, handlerErr = io.Copy(&buf, r.Body)
6998 io.Copy(w, &buf)
6999 })
7000
7001 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
7002
7003
7004 defer cst.close()
7005 ts := cst.ts
7006 c := ts.Client()
7007
7008 body := strings.Repeat("a", int(requestSize))
7009 var wg sync.WaitGroup
7010 defer wg.Wait()
7011 getBody := func() (io.ReadCloser, error) {
7012 wg.Add(1)
7013 body := &wgReadCloser{
7014 Reader: strings.NewReader(body),
7015 wg: &wg,
7016 }
7017 return body, nil
7018 }
7019 reqBody, _ := getBody()
7020 req, err := NewRequest("POST", ts.URL, reqBody)
7021 if err != nil {
7022 reqBody.Close()
7023 t.Fatal(err)
7024 }
7025 req.ContentLength = int64(len(body))
7026 req.GetBody = getBody
7027 req.Header.Set("Content-Type", "text/plain")
7028
7029 var buf strings.Builder
7030 res, err := c.Do(req)
7031 if err != nil {
7032 return fmt.Errorf("unexpected connection error: %v", err)
7033 } else {
7034 _, err = io.Copy(&buf, res.Body)
7035 res.Body.Close()
7036 if err != nil {
7037 return fmt.Errorf("unexpected read error: %v", err)
7038 }
7039 }
7040
7041
7042
7043
7044 if handlerN > maxSize {
7045 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
7046 }
7047 if requestSize > maxSize && handlerErr == nil {
7048 t.Error("expected error on handler side; got nil")
7049 }
7050 if requestSize <= maxSize {
7051 if handlerErr != nil {
7052 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
7053 }
7054 if handlerN != requestSize {
7055 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
7056 }
7057 }
7058 if buf.Len() != int(handlerN) {
7059 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
7060 }
7061
7062 return nil
7063 })
7064 }
7065
7066 func TestEarlyHints(t *testing.T) {
7067 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7068 h := w.Header()
7069 h.Add("Link", "</style.css>; rel=preload; as=style")
7070 h.Add("Link", "</script.js>; rel=preload; as=script")
7071 w.WriteHeader(StatusEarlyHints)
7072
7073 h.Add("Link", "</foo.js>; rel=preload; as=script")
7074 w.WriteHeader(StatusEarlyHints)
7075
7076 w.Write([]byte("stuff"))
7077 }))
7078
7079 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7080 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7081 if !strings.Contains(got, expected) {
7082 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7083 }
7084 }
7085 func TestProcessing(t *testing.T) {
7086 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7087 w.WriteHeader(StatusProcessing)
7088 w.Write([]byte("stuff"))
7089 }))
7090
7091 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7092 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7093 if !strings.Contains(got, expected) {
7094 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7095 }
7096 }
7097
7098 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) }
7099 func testParseFormCleanup(t *testing.T, mode testMode) {
7100 if mode == http2Mode {
7101 t.Skip("https://go.dev/issue/20253")
7102 }
7103
7104 const maxMemory = 1024
7105 const key = "file"
7106
7107 if runtime.GOOS == "windows" {
7108
7109 t.Skip("https://go.dev/issue/25965")
7110 }
7111
7112 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7113 r.ParseMultipartForm(maxMemory)
7114 f, _, err := r.FormFile(key)
7115 if err != nil {
7116 t.Errorf("r.FormFile(%q) = %v", key, err)
7117 return
7118 }
7119 of, ok := f.(*os.File)
7120 if !ok {
7121 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7122 return
7123 }
7124 w.Write([]byte(of.Name()))
7125 }))
7126
7127 fBuf := new(bytes.Buffer)
7128 mw := multipart.NewWriter(fBuf)
7129 mf, err := mw.CreateFormFile(key, "myfile.txt")
7130 if err != nil {
7131 t.Fatal(err)
7132 }
7133 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7134 t.Fatal(err)
7135 }
7136 if err := mw.Close(); err != nil {
7137 t.Fatal(err)
7138 }
7139 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7140 if err != nil {
7141 t.Fatal(err)
7142 }
7143 req.Header.Set("Content-Type", mw.FormDataContentType())
7144 res, err := cst.c.Do(req)
7145 if err != nil {
7146 t.Fatal(err)
7147 }
7148 defer res.Body.Close()
7149 fname, err := io.ReadAll(res.Body)
7150 if err != nil {
7151 t.Fatal(err)
7152 }
7153 cst.close()
7154 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7155 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7156 }
7157 }
7158
7159 func TestHeadBody(t *testing.T) {
7160 const identityMode = false
7161 const chunkedMode = true
7162 run(t, func(t *testing.T, mode testMode) {
7163 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7164 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7165 })
7166 }
7167
7168 func TestGetBody(t *testing.T) {
7169 const identityMode = false
7170 const chunkedMode = true
7171 run(t, func(t *testing.T, mode testMode) {
7172 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7173 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7174 })
7175 }
7176
7177 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7178 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7179 b, err := io.ReadAll(r.Body)
7180 if err != nil {
7181 t.Errorf("server reading body: %v", err)
7182 return
7183 }
7184 w.Header().Set("X-Request-Body", string(b))
7185 w.Header().Set("Content-Length", "0")
7186 }))
7187 defer cst.close()
7188 for _, reqBody := range []string{
7189 "",
7190 "",
7191 "request_body",
7192 "",
7193 } {
7194 var bodyReader io.Reader
7195 if reqBody != "" {
7196 bodyReader = strings.NewReader(reqBody)
7197 if chunked {
7198 bodyReader = bufio.NewReader(bodyReader)
7199 }
7200 }
7201 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7202 if err != nil {
7203 t.Fatal(err)
7204 }
7205 res, err := cst.c.Do(req)
7206 if err != nil {
7207 t.Fatal(err)
7208 }
7209 res.Body.Close()
7210 if got, want := res.StatusCode, 200; got != want {
7211 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7212 }
7213 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7214 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7215 }
7216 }
7217 }
7218
7219
7220
7221 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7222 func testDisableContentLength(t *testing.T, mode testMode) {
7223 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7224 w.Header()["Content-Length"] = nil
7225 fmt.Fprintf(w, "OK")
7226 }))
7227
7228 res, err := noCL.c.Get(noCL.ts.URL)
7229 if err != nil {
7230 t.Fatal(err)
7231 }
7232 if got, haveCL := res.Header["Content-Length"]; haveCL {
7233 t.Errorf("Unexpected Content-Length: %q", got)
7234 }
7235 if err := res.Body.Close(); err != nil {
7236 t.Fatal(err)
7237 }
7238
7239 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7240 fmt.Fprintf(w, "OK")
7241 }))
7242
7243 res, err = withCL.c.Get(withCL.ts.URL)
7244 if err != nil {
7245 t.Fatal(err)
7246 }
7247 if got := res.Header.Get("Content-Length"); got != "2" {
7248 t.Errorf("Content-Length: %q; want 2", got)
7249 }
7250 if err := res.Body.Close(); err != nil {
7251 t.Fatal(err)
7252 }
7253 }
7254
7255 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7256 func testErrorContentLength(t *testing.T, mode testMode) {
7257 const errorBody = "an error occurred"
7258 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7259 w.Header().Set("Content-Length", "1000")
7260 Error(w, errorBody, 400)
7261 }))
7262 res, err := cst.c.Get(cst.ts.URL)
7263 if err != nil {
7264 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7265 }
7266 defer res.Body.Close()
7267 body, err := io.ReadAll(res.Body)
7268 if err != nil {
7269 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7270 }
7271 if string(body) != errorBody+"\n" {
7272 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7273 }
7274 }
7275
// TestError calls Error on a recorder whose headers were pre-populated and
// checks the recorder's headers afterwards: Content-Length must be absent,
// Content-Type must be "text/plain; charset=utf-8", and
// X-Content-Type-Options must be "nosniff".
func TestError(t *testing.T) {
	rec := httptest.NewRecorder()
	rec.Header().Set("Content-Length", "1")
	rec.Header().Set("X-Content-Type-Options", "scratch and sniff")
	rec.Header().Set("Other", "foo")
	Error(rec, "oops", 432)

	got := rec.Header()

	// Headers that must not be present at all after Error.
	for _, name := range []string{"Content-Length"} {
		values, present := got[name]
		if present {
			t.Errorf("%s: %q, want not present", name, values)
		}
	}

	const wantType = "text/plain; charset=utf-8"
	if ct := got.Get("Content-Type"); ct != wantType {
		t.Errorf("Content-Type: %q, want %q", ct, wantType)
	}

	const wantSniff = "nosniff"
	if opt := got.Get("X-Content-Type-Options"); opt != wantSniff {
		t.Errorf("X-Content-Type-Options: %q, want %q", opt, wantSniff)
	}
}
7296
7297 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7298 run(t, testServerReadAfterWriteHeader100Continue)
7299 }
7300 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7301 t.Skip("https://go.dev/issue/67555")
7302 body := []byte("body")
7303 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7304 w.WriteHeader(200)
7305 NewResponseController(w).Flush()
7306 io.ReadAll(r.Body)
7307 w.Write(body)
7308 }), func(tr *Transport) {
7309 tr.ExpectContinueTimeout = 24 * time.Hour
7310 })
7311
7312 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7313 req.Header.Set("Expect", "100-continue")
7314 res, err := cst.c.Do(req)
7315 if err != nil {
7316 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7317 }
7318 defer res.Body.Close()
7319 got, err := io.ReadAll(res.Body)
7320 if err != nil {
7321 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7322 }
7323 if !bytes.Equal(got, body) {
7324 t.Fatalf("response body = %q, want %q", got, body)
7325 }
7326 }
7327
7328 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7329 run(t, testServerReadAfterHandlerDone100Continue)
7330 }
7331 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7332 t.Skip("https://go.dev/issue/67555")
7333 readyc := make(chan struct{})
7334 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7335 go func() {
7336 <-readyc
7337 io.ReadAll(r.Body)
7338 <-readyc
7339 }()
7340 }), func(tr *Transport) {
7341 tr.ExpectContinueTimeout = 24 * time.Hour
7342 })
7343
7344 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7345 req.Header.Set("Expect", "100-continue")
7346 res, err := cst.c.Do(req)
7347 if err != nil {
7348 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7349 }
7350 res.Body.Close()
7351 readyc <- struct{}{}
7352 readyc <- struct{}{}
7353 }
7354
7355 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7356 run(t, testServerReadAfterHandlerAbort100Continue)
7357 }
7358 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7359 t.Skip("https://go.dev/issue/67555")
7360 readyc := make(chan struct{})
7361 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7362 go func() {
7363 <-readyc
7364 io.ReadAll(r.Body)
7365 <-readyc
7366 }()
7367 panic(ErrAbortHandler)
7368 }), func(tr *Transport) {
7369 tr.ExpectContinueTimeout = 24 * time.Hour
7370 })
7371
7372 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7373 req.Header.Set("Expect", "100-continue")
7374 res, err := cst.c.Do(req)
7375 if err == nil {
7376 res.Body.Close()
7377 }
7378 readyc <- struct{}{}
7379 readyc <- struct{}{}
7380 }
7381
7382 func TestInvalidChunkedBodies(t *testing.T) {
7383 for _, test := range []struct {
7384 name string
7385 b string
7386 }{{
7387 name: "bare LF in chunk size",
7388 b: "1\na\r\n0\r\n\r\n",
7389 }, {
7390 name: "bare LF at body end",
7391 b: "1\r\na\r\n0\r\n\n",
7392 }} {
7393 t.Run(test.name, func(t *testing.T) {
7394 reqc := make(chan error)
7395 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7396 got, err := io.ReadAll(r.Body)
7397 if err == nil {
7398 t.Logf("read body: %q", got)
7399 }
7400 reqc <- err
7401 })).ts
7402
7403 serverURL, err := url.Parse(ts.URL)
7404 if err != nil {
7405 t.Fatal(err)
7406 }
7407
7408 conn, err := net.Dial("tcp", serverURL.Host)
7409 if err != nil {
7410 t.Fatal(err)
7411 }
7412
7413 if _, err := conn.Write([]byte(
7414 "POST / HTTP/1.1\r\n" +
7415 "Host: localhost\r\n" +
7416 "Transfer-Encoding: chunked\r\n" +
7417 "Connection: close\r\n" +
7418 "\r\n" +
7419 test.b)); err != nil {
7420 t.Fatal(err)
7421 }
7422 conn.(*net.TCPConn).CloseWrite()
7423
7424 if err := <-reqc; err == nil {
7425 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7426 }
7427 })
7428 }
7429 }
7430
7431
7432 func TestServerTLSNextProtos(t *testing.T) {
7433 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7434 }
7435 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7436 CondSkipHTTP2(t)
7437
7438 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7439 if err != nil {
7440 t.Fatal(err)
7441 }
7442 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7443 if err != nil {
7444 t.Fatal(err)
7445 }
7446 certpool := x509.NewCertPool()
7447 certpool.AddCert(leafCert)
7448
7449 protos := new(Protocols)
7450 switch mode {
7451 case https1Mode:
7452 protos.SetHTTP1(true)
7453 case http2Mode:
7454 protos.SetHTTP2(true)
7455 }
7456
7457 wantNextProtos := []string{"http/1.1", "h2", "other"}
7458 nextProtos := slices.Clone(wantNextProtos)
7459
7460
7461 srv := &Server{
7462 TLSConfig: &tls.Config{
7463 Certificates: []tls.Certificate{cert},
7464 NextProtos: nextProtos,
7465 },
7466 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7467 Protocols: protos,
7468 }
7469 tr := &Transport{
7470 TLSClientConfig: &tls.Config{
7471 RootCAs: certpool,
7472 NextProtos: nextProtos,
7473 },
7474 Protocols: protos,
7475 }
7476
7477 listener := newLocalListener(t)
7478 srvc := make(chan error, 1)
7479 go func() {
7480 srvc <- srv.ServeTLS(listener, "", "")
7481 }()
7482 t.Cleanup(func() {
7483 srv.Close()
7484 <-srvc
7485 })
7486
7487 client := &Client{Transport: tr}
7488 resp, err := client.Get("https://" + listener.Addr().String())
7489 if err != nil {
7490 t.Fatal(err)
7491 }
7492 resp.Body.Close()
7493
7494 if !slices.Equal(nextProtos, wantNextProtos) {
7495 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7496 }
7497 }
7498
View as plain text