Source file
src/net/http/serve_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bufio"
11 "bytes"
12 "compress/gzip"
13 "compress/zlib"
14 "context"
15 crand "crypto/rand"
16 "crypto/tls"
17 "crypto/x509"
18 "encoding/json"
19 "errors"
20 "fmt"
21 "internal/testenv"
22 "io"
23 "log"
24 "math/rand"
25 "mime/multipart"
26 "net"
27 . "net/http"
28 "net/http/httptest"
29 "net/http/httptrace"
30 "net/http/httputil"
31 "net/http/internal"
32 "net/http/internal/testcert"
33 "net/url"
34 "os"
35 "path/filepath"
36 "reflect"
37 "regexp"
38 "runtime"
39 "slices"
40 "strconv"
41 "strings"
42 "sync"
43 "sync/atomic"
44 "syscall"
45 "testing"
46 "testing/synctest"
47 "time"
48 )
49
50 type dummyAddr string
51 type oneConnListener struct {
52 conn net.Conn
53 }
54
55 func (l *oneConnListener) Accept() (c net.Conn, err error) {
56 c = l.conn
57 if c == nil {
58 err = io.EOF
59 return
60 }
61 err = nil
62 l.conn = nil
63 return
64 }
65
66 func (l *oneConnListener) Close() error {
67 return nil
68 }
69
70 func (l *oneConnListener) Addr() net.Addr {
71 return dummyAddr("test-address")
72 }
73
74 func (a dummyAddr) Network() string {
75 return string(a)
76 }
77
78 func (a dummyAddr) String() string {
79 return string(a)
80 }
81
82 type noopConn struct{}
83
84 func (noopConn) LocalAddr() net.Addr { return dummyAddr("local-addr") }
85 func (noopConn) RemoteAddr() net.Addr { return dummyAddr("remote-addr") }
86 func (noopConn) SetDeadline(t time.Time) error { return nil }
87 func (noopConn) SetReadDeadline(t time.Time) error { return nil }
88 func (noopConn) SetWriteDeadline(t time.Time) error { return nil }
89
90 type rwTestConn struct {
91 io.Reader
92 io.Writer
93 noopConn
94
95 closeFunc func() error
96 closec chan bool
97 }
98
99 func (c *rwTestConn) Close() error {
100 if c.closeFunc != nil {
101 return c.closeFunc()
102 }
103 select {
104 case c.closec <- true:
105 default:
106 }
107 return nil
108 }
109
110 type testConn struct {
111 readMu sync.Mutex
112 readBuf bytes.Buffer
113 writeBuf bytes.Buffer
114 closec chan bool
115 noopConn
116 }
117
118 func newTestConn() *testConn {
119 return &testConn{closec: make(chan bool, 1)}
120 }
121
122 func (c *testConn) Read(b []byte) (int, error) {
123 c.readMu.Lock()
124 defer c.readMu.Unlock()
125 return c.readBuf.Read(b)
126 }
127
128 func (c *testConn) Write(b []byte) (int, error) {
129 return c.writeBuf.Write(b)
130 }
131
132 func (c *testConn) Close() error {
133 select {
134 case c.closec <- true:
135 default:
136 }
137 return nil
138 }
139
140
141
142 func reqBytes(req string) []byte {
143 return []byte(strings.ReplaceAll(strings.TrimSpace(req), "\n", "\r\n") + "\r\n\r\n")
144 }
145
146 type handlerTest struct {
147 logbuf bytes.Buffer
148 handler Handler
149 }
150
151 func newHandlerTest(h Handler) handlerTest {
152 return handlerTest{handler: h}
153 }
154
155 func (ht *handlerTest) rawResponse(req string) string {
156 reqb := reqBytes(req)
157 var output strings.Builder
158 conn := &rwTestConn{
159 Reader: bytes.NewReader(reqb),
160 Writer: &output,
161 closec: make(chan bool, 1),
162 }
163 ln := &oneConnListener{conn: conn}
164 srv := &Server{
165 ErrorLog: log.New(&ht.logbuf, "", 0),
166 Handler: ht.handler,
167 }
168 go srv.Serve(ln)
169 <-conn.closec
170 return output.String()
171 }
172
173 func TestConsumingBodyOnNextConn(t *testing.T) {
174 t.Parallel()
175 defer afterTest(t)
176 conn := new(testConn)
177 for i := 0; i < 2; i++ {
178 conn.readBuf.Write([]byte(
179 "POST / HTTP/1.1\r\n" +
180 "Host: test\r\n" +
181 "Content-Length: 11\r\n" +
182 "\r\n" +
183 "foo=1&bar=1"))
184 }
185
186 reqNum := 0
187 ch := make(chan *Request)
188 servech := make(chan error)
189 listener := &oneConnListener{conn}
190 handler := func(res ResponseWriter, req *Request) {
191 reqNum++
192 ch <- req
193 }
194
195 go func() {
196 servech <- Serve(listener, HandlerFunc(handler))
197 }()
198
199 var req *Request
200 req = <-ch
201 if req == nil {
202 t.Fatal("Got nil first request.")
203 }
204 if req.Method != "POST" {
205 t.Errorf("For request #1's method, got %q; expected %q",
206 req.Method, "POST")
207 }
208
209 req = <-ch
210 if req == nil {
211 t.Fatal("Got nil first request.")
212 }
213 if req.Method != "POST" {
214 t.Errorf("For request #2's method, got %q; expected %q",
215 req.Method, "POST")
216 }
217
218 if serveerr := <-servech; serveerr != io.EOF {
219 t.Errorf("Serve returned %q; expected EOF", serveerr)
220 }
221 }
222
223 type stringHandler string
224
225 func (s stringHandler) ServeHTTP(w ResponseWriter, r *Request) {
226 w.Header().Set("Result", string(s))
227 }
228
229 var handlers = []struct {
230 pattern string
231 msg string
232 }{
233 {"/", "Default"},
234 {"/someDir/", "someDir"},
235 {"/#/", "hash"},
236 {"someHost.com/someDir/", "someHost.com/someDir"},
237 }
238
239 var vtests = []struct {
240 url string
241 expected string
242 }{
243 {"http://localhost/someDir/apage", "someDir"},
244 {"http://localhost/%23/apage", "hash"},
245 {"http://localhost/otherDir/apage", "Default"},
246 {"http://someHost.com/someDir/apage", "someHost.com/someDir"},
247 {"http://otherHost.com/someDir/apage", "someDir"},
248 {"http://otherHost.com/aDir/apage", "Default"},
249
250 {"http://localhost/someDir", "/someDir/"},
251 {"http://localhost/%23", "/%23/"},
252 {"http://someHost.com/someDir", "/someDir/"},
253 }
254
255 func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) }
256 func testHostHandlers(t *testing.T, mode testMode) {
257 mux := NewServeMux()
258 for _, h := range handlers {
259 mux.Handle(h.pattern, stringHandler(h.msg))
260 }
261 ts := newClientServerTest(t, mode, mux).ts
262
263 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
264 if err != nil {
265 t.Fatal(err)
266 }
267 defer conn.Close()
268 cc := httputil.NewClientConn(conn, nil)
269 for _, vt := range vtests {
270 var r *Response
271 var req Request
272 if req.URL, err = url.Parse(vt.url); err != nil {
273 t.Errorf("cannot parse url: %v", err)
274 continue
275 }
276 if err := cc.Write(&req); err != nil {
277 t.Errorf("writing request: %v", err)
278 continue
279 }
280 r, err := cc.Read(&req)
281 if err != nil {
282 t.Errorf("reading response: %v", err)
283 continue
284 }
285 switch r.StatusCode {
286 case StatusOK:
287 s := r.Header.Get("Result")
288 if s != vt.expected {
289 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
290 }
291 case StatusTemporaryRedirect:
292 s := r.Header.Get("Location")
293 if s != vt.expected {
294 t.Errorf("Get(%q) = %q, want %q", vt.url, s, vt.expected)
295 }
296 default:
297 t.Errorf("Get(%q) unhandled status code %d", vt.url, r.StatusCode)
298 }
299 }
300 }
301
302 var serveMuxRegister = []struct {
303 pattern string
304 h Handler
305 }{
306 {"/dir/", serve(200)},
307 {"/search", serve(201)},
308 {"codesearch.google.com/search", serve(202)},
309 {"codesearch.google.com/", serve(203)},
310 {"example.com/", HandlerFunc(checkQueryStringHandler)},
311 {"/pkg/bar/extra%2fpath", serve(200)},
312 }
313
314
315 func serve(code int) HandlerFunc {
316 return func(w ResponseWriter, r *Request) {
317 w.WriteHeader(code)
318 }
319 }
320
321
322
323
324 func checkQueryStringHandler(w ResponseWriter, r *Request) {
325 u := *r.URL
326 u.Scheme = "http"
327 u.Host = r.Host
328 u.RawQuery = ""
329 if "http://"+r.URL.RawQuery == u.String() {
330 w.WriteHeader(200)
331 } else {
332 w.WriteHeader(500)
333 }
334 }
335
336 var serveMuxTests = []struct {
337 method string
338 host string
339 path string
340 code int
341 pattern string
342 }{
343 {"GET", "google.com", "/", 404, ""},
344 {"GET", "google.com", "/dir", 307, "/dir/"},
345 {"GET", "google.com", "/dir/", 200, "/dir/"},
346 {"GET", "google.com", "/dir/file", 200, "/dir/"},
347 {"GET", "google.com", "/search", 201, "/search"},
348 {"GET", "google.com", "/search/", 404, ""},
349 {"GET", "google.com", "/search/foo", 404, ""},
350 {"GET", "codesearch.google.com", "/search", 202, "codesearch.google.com/search"},
351 {"GET", "codesearch.google.com", "/search/", 203, "codesearch.google.com/"},
352 {"GET", "codesearch.google.com", "/search/foo", 203, "codesearch.google.com/"},
353 {"GET", "codesearch.google.com", "/", 203, "codesearch.google.com/"},
354 {"GET", "codesearch.google.com:443", "/", 203, "codesearch.google.com/"},
355 {"GET", "images.google.com", "/search", 201, "/search"},
356 {"GET", "images.google.com", "/search/", 404, ""},
357 {"GET", "images.google.com", "/search/foo", 404, ""},
358 {"GET", "google.com", "/../search", 307, "/search"},
359 {"GET", "google.com", "/dir/..", 307, ""},
360 {"GET", "google.com", "/dir/..", 307, ""},
361 {"GET", "google.com", "/dir/./file", 307, "/dir/"},
362
363
364
365 {"CONNECT", "google.com", "/dir", 307, "/dir/"},
366 {"CONNECT", "google.com", "/../search", 404, ""},
367 {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
368 {"CONNECT", "google.com", "/dir/..", 200, "/dir/"},
369 {"CONNECT", "google.com", "/dir/./file", 200, "/dir/"},
370 }
371
372 func TestServeMuxHandler(t *testing.T) {
373 setParallel(t)
374 mux := NewServeMux()
375 for _, e := range serveMuxRegister {
376 mux.Handle(e.pattern, e.h)
377 }
378
379 for _, tt := range serveMuxTests {
380 r := &Request{
381 Method: tt.method,
382 Host: tt.host,
383 URL: &url.URL{
384 Path: tt.path,
385 },
386 }
387 h, pattern := mux.Handler(r)
388 rr := httptest.NewRecorder()
389 h.ServeHTTP(rr, r)
390 if pattern != tt.pattern || rr.Code != tt.code {
391 t.Errorf("%s %s %s = %d, %q, want %d, %q", tt.method, tt.host, tt.path, rr.Code, pattern, tt.code, tt.pattern)
392 }
393 }
394 }
395
396
397 func TestServeMuxHandlerTrailingSlash(t *testing.T) {
398 setParallel(t)
399 mux := NewServeMux()
400 const original = "/{x}/"
401 mux.Handle(original, NotFoundHandler())
402 r, _ := NewRequest("POST", "/foo", nil)
403 _, p := mux.Handler(r)
404 if p != original {
405 t.Errorf("got %q, want %q", p, original)
406 }
407 }
408
409
410 func TestServeMuxHandleFuncWithNilHandler(t *testing.T) {
411 setParallel(t)
412 defer func() {
413 if err := recover(); err == nil {
414 t.Error("expected call to mux.HandleFunc to panic")
415 }
416 }()
417 mux := NewServeMux()
418 mux.HandleFunc("/", nil)
419 }
420
421 var serveMuxTests2 = []struct {
422 method string
423 host string
424 url string
425 code int
426 redirOk bool
427 }{
428 {"GET", "google.com", "/", 404, false},
429 {"GET", "example.com", "/test/?example.com/test/", 200, false},
430 {"GET", "example.com", "test/?example.com/test/", 200, true},
431 {"GET", "google.com", "/pkg/bar//extra%2fpath", 200, true},
432 {"GET", "google.com", "/dir/b%2fc/..", 200, true},
433 {"GET", "google.com", "/doesnotexist/b%2fc/..", 404, true},
434 }
435
436
437
438 func TestServeMuxHandlerRedirects(t *testing.T) {
439 setParallel(t)
440 mux := NewServeMux()
441 for _, e := range serveMuxRegister {
442 mux.Handle(e.pattern, e.h)
443 }
444
445 for _, tt := range serveMuxTests2 {
446 tries := 1
447 turl := tt.url
448 for {
449 u, e := url.Parse(turl)
450 if e != nil {
451 t.Fatal(e)
452 }
453 r := &Request{
454 Method: tt.method,
455 Host: tt.host,
456 URL: u,
457 }
458 h, _ := mux.Handler(r)
459 rr := httptest.NewRecorder()
460 h.ServeHTTP(rr, r)
461 if rr.Code != 307 {
462 if rr.Code != tt.code {
463 t.Errorf("%s %s %s = %d, want %d", tt.method, tt.host, tt.url, rr.Code, tt.code)
464 }
465 break
466 }
467 if !tt.redirOk {
468 t.Errorf("%s %s %s, unexpected redirect", tt.method, tt.host, tt.url)
469 break
470 }
471 turl = rr.HeaderMap.Get("Location")
472 tries--
473 }
474 if tries < 0 {
475 t.Errorf("%s %s %s, too many redirects", tt.method, tt.host, tt.url)
476 }
477 }
478 }
479
480 func TestServeMuxHandlerRedirectPost(t *testing.T) {
481 setParallel(t)
482 mux := NewServeMux()
483 mux.HandleFunc("POST /test/", func(w ResponseWriter, r *Request) {
484 w.WriteHeader(200)
485 })
486
487 var code, retries int
488 startURL := "http://example.com/test"
489 reqURL := startURL
490 for retries = 0; retries <= 1; retries++ {
491 r := httptest.NewRequest("POST", reqURL, strings.NewReader("hello world"))
492 h, _ := mux.Handler(r)
493 rr := httptest.NewRecorder()
494 h.ServeHTTP(rr, r)
495 code = rr.Code
496 switch rr.Code {
497 case 307:
498 reqURL = rr.Result().Header.Get("Location")
499 continue
500 case 200:
501
502 default:
503 t.Errorf("unhandled response code: %v", rr.Code)
504 }
505 }
506 if code != 200 {
507 t.Errorf("POST %s = %d after %d retries, want = 200", startURL, code, retries)
508 }
509 }
510
511
512 func TestMuxRedirectLeadingSlashes(t *testing.T) {
513 setParallel(t)
514 paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"}
515 for _, path := range paths {
516 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n")))
517 if err != nil {
518 t.Errorf("%s", err)
519 }
520 mux := NewServeMux()
521 resp := httptest.NewRecorder()
522
523 mux.ServeHTTP(resp, req)
524
525 if loc, expected := resp.Header().Get("Location"), "/foo.txt"; loc != expected {
526 t.Errorf("Expected Location header set to %q; got %q", expected, loc)
527 return
528 }
529
530 if code, expected := resp.Code, StatusTemporaryRedirect; code != expected {
531 t.Errorf("Expected response code of StatusPermanentRedirect; got %d", code)
532 return
533 }
534 }
535 }
536
537
538
539
540
541 func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) {
542 run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode})
543 }
544 func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) {
545 writeBackQuery := func(w ResponseWriter, r *Request) {
546 fmt.Fprintf(w, "%s", r.URL.RawQuery)
547 }
548
549 mux := NewServeMux()
550 mux.HandleFunc("/testOne", writeBackQuery)
551 mux.HandleFunc("/testTwo/", writeBackQuery)
552 mux.HandleFunc("/testThree", writeBackQuery)
553 mux.HandleFunc("/testThree/", func(w ResponseWriter, r *Request) {
554 fmt.Fprintf(w, "%s:bar", r.URL.RawQuery)
555 })
556
557 ts := newClientServerTest(t, mode, mux).ts
558
559 tests := [...]struct {
560 path string
561 method string
562 want string
563 statusOk bool
564 }{
565 0: {"/testOne?this=that", "GET", "this=that", true},
566 1: {"/testTwo?foo=bar", "GET", "foo=bar", true},
567 2: {"/testTwo?a=1&b=2&a=3", "GET", "a=1&b=2&a=3", true},
568 3: {"/testTwo?", "GET", "", true},
569 4: {"/testThree?foo", "GET", "foo", true},
570 5: {"/testThree/?foo", "GET", "foo:bar", true},
571 6: {"/testThree?foo", "CONNECT", "foo", true},
572 7: {"/testThree/?foo", "CONNECT", "foo:bar", true},
573
574
575 8: {"/testOne/foo/..?foo", "GET", "foo", true},
576 9: {"/testOne/foo/..?foo", "CONNECT", "404 page not found\n", false},
577 }
578
579 for i, tt := range tests {
580 req, _ := NewRequest(tt.method, ts.URL+tt.path, nil)
581 res, err := ts.Client().Do(req)
582 if err != nil {
583 continue
584 }
585 slurp, _ := io.ReadAll(res.Body)
586 res.Body.Close()
587 if !tt.statusOk {
588 if got, want := res.StatusCode, 404; got != want {
589 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
590 }
591 }
592 if got, want := string(slurp), tt.want; got != want {
593 t.Errorf("#%d: Body = %q; want = %q", i, got, want)
594 }
595 }
596 }
597
598 func TestServeWithSlashRedirectForHostPatterns(t *testing.T) {
599 setParallel(t)
600
601 mux := NewServeMux()
602 mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/"))
603 mux.Handle("example.com/pkg/bar", stringHandler("example.com/pkg/bar"))
604 mux.Handle("example.com/pkg/bar/", stringHandler("example.com/pkg/bar/"))
605 mux.Handle("example.com:3000/pkg/connect/", stringHandler("example.com:3000/pkg/connect/"))
606 mux.Handle("example.com:9000/", stringHandler("example.com:9000/"))
607 mux.Handle("/pkg/baz/", stringHandler("/pkg/baz/"))
608 mux.Handle("example.com/a%2fb/", stringHandler("example.com/a%2fb/"))
609
610 tests := []struct {
611 method string
612 url string
613 code int
614 loc string
615 want string
616 }{
617 {"GET", "http://example.com/", 404, "", ""},
618 {"GET", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
619 {"GET", "http://example.com/pkg/bar", 200, "", "example.com/pkg/bar"},
620 {"GET", "http://example.com/pkg/bar/", 200, "", "example.com/pkg/bar/"},
621 {"GET", "http://example.com/pkg/baz", 307, "/pkg/baz/", ""},
622 {"GET", "http://example.com:3000/pkg/foo", 307, "/pkg/foo/", ""},
623 {"CONNECT", "http://example.com/", 404, "", ""},
624 {"CONNECT", "http://example.com:3000/", 404, "", ""},
625 {"CONNECT", "http://example.com:9000/", 200, "", "example.com:9000/"},
626 {"CONNECT", "http://example.com/pkg/foo", 307, "/pkg/foo/", ""},
627 {"CONNECT", "http://example.com:3000/pkg/foo", 404, "", ""},
628 {"CONNECT", "http://example.com:3000/pkg/baz", 307, "/pkg/baz/", ""},
629 {"CONNECT", "http://example.com:3000/pkg/connect", 307, "/pkg/connect/", ""},
630 {"GET", "http://example.com/a%2fb", 307, "/a%2fb/", ""},
631 }
632
633 for i, tt := range tests {
634 req, _ := NewRequest(tt.method, tt.url, nil)
635 w := httptest.NewRecorder()
636 mux.ServeHTTP(w, req)
637
638 if got, want := w.Code, tt.code; got != want {
639 t.Errorf("#%d: Status = %d; want = %d", i, got, want)
640 }
641
642 if tt.code == 307 {
643 if got, want := w.HeaderMap.Get("Location"), tt.loc; got != want {
644 t.Errorf("#%d: Location = %q; want = %q", i, got, want)
645 }
646 } else {
647 if got, want := w.HeaderMap.Get("Result"), tt.want; got != want {
648 t.Errorf("#%d: Result = %q; want = %q", i, got, want)
649 }
650 }
651 }
652 }
653
654
655
656
657 func TestMuxNoSlashRedirectWithTrailingSlash(t *testing.T) {
658 mux := NewServeMux()
659 mux.HandleFunc("/{x}/", func(w ResponseWriter, r *Request) {
660 fmt.Fprintln(w, "ok")
661 })
662 w := httptest.NewRecorder()
663 req, _ := NewRequest("GET", "/", nil)
664 mux.ServeHTTP(w, req)
665 if g, w := w.Code, 404; g != w {
666 t.Errorf("got %d, want %d", g, w)
667 }
668 }
669
670
671
672
673 func TestMuxNoSlash405WithTrailingSlash(t *testing.T) {
674 mux := NewServeMux()
675 mux.HandleFunc("GET /{x}/", func(w ResponseWriter, r *Request) {
676 fmt.Fprintln(w, "ok")
677 })
678 w := httptest.NewRecorder()
679 req, _ := NewRequest("GET", "/", nil)
680 mux.ServeHTTP(w, req)
681 if g, w := w.Code, 404; g != w {
682 t.Errorf("got %d, want %d", g, w)
683 }
684 }
685
686 func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) }
687 func testShouldRedirectConcurrency(t *testing.T, mode testMode) {
688 mux := NewServeMux()
689 newClientServerTest(t, mode, mux)
690 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {})
691 }
692
693 func BenchmarkServeMux(b *testing.B) { benchmarkServeMux(b, true) }
694 func BenchmarkServeMux_SkipServe(b *testing.B) { benchmarkServeMux(b, false) }
695 func benchmarkServeMux(b *testing.B, runHandler bool) {
696 type test struct {
697 path string
698 code int
699 req *Request
700 }
701
702
703 var tests []test
704 endpoints := []string{"search", "dir", "file", "change", "count", "s"}
705 for _, e := range endpoints {
706 for i := 200; i < 230; i++ {
707 p := fmt.Sprintf("/%s/%d/", e, i)
708 tests = append(tests, test{
709 path: p,
710 code: i,
711 req: &Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: p}},
712 })
713 }
714 }
715 mux := NewServeMux()
716 for _, tt := range tests {
717 mux.Handle(tt.path, serve(tt.code))
718 }
719
720 rw := httptest.NewRecorder()
721 b.ReportAllocs()
722 b.ResetTimer()
723 for i := 0; i < b.N; i++ {
724 for _, tt := range tests {
725 *rw = httptest.ResponseRecorder{}
726 h, pattern := mux.Handler(tt.req)
727 if runHandler {
728 h.ServeHTTP(rw, tt.req)
729 if pattern != tt.path || rw.Code != tt.code {
730 b.Fatalf("got %d, %q, want %d, %q", rw.Code, pattern, tt.code, tt.path)
731 }
732 }
733 }
734 }
735 }
736
737 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
738 func testServerTimeouts(t *testing.T, mode testMode) {
739 runTimeSensitiveTest(t, []time.Duration{
740 10 * time.Millisecond,
741 50 * time.Millisecond,
742 100 * time.Millisecond,
743 500 * time.Millisecond,
744 1 * time.Second,
745 }, func(t *testing.T, timeout time.Duration) error {
746 return testServerTimeoutsWithTimeout(t, timeout, mode)
747 })
748 }
749
750 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
751 var reqNum atomic.Int32
752 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
753 fmt.Fprintf(res, "req=%d", reqNum.Add(1))
754 }), func(ts *httptest.Server) {
755 ts.Config.ReadTimeout = timeout
756 ts.Config.WriteTimeout = timeout
757 })
758 defer cst.close()
759 ts := cst.ts
760
761
762 c := ts.Client()
763 r, err := c.Get(ts.URL)
764 if err != nil {
765 return fmt.Errorf("http Get #1: %v", err)
766 }
767 got, err := io.ReadAll(r.Body)
768 expected := "req=1"
769 if string(got) != expected || err != nil {
770 return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil",
771 string(got), err, expected)
772 }
773
774
775 t1 := time.Now()
776 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
777 if err != nil {
778 return fmt.Errorf("Dial: %v", err)
779 }
780 buf := make([]byte, 1)
781 n, err := conn.Read(buf)
782 conn.Close()
783 latency := time.Since(t1)
784 if n != 0 || err != io.EOF {
785 return fmt.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF)
786 }
787 minLatency := timeout / 5 * 4
788 if latency < minLatency {
789 return fmt.Errorf("got EOF after %s, want >= %s", latency, minLatency)
790 }
791
792
793
794
795 r, err = c.Get(ts.URL)
796 if err != nil {
797 return fmt.Errorf("http Get #2: %v", err)
798 }
799 got, err = io.ReadAll(r.Body)
800 r.Body.Close()
801 expected = "req=2"
802 if string(got) != expected || err != nil {
803 return fmt.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected)
804 }
805
806 if !testing.Short() {
807 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
808 if err != nil {
809 return fmt.Errorf("long Dial: %v", err)
810 }
811 defer conn.Close()
812 go io.Copy(io.Discard, conn)
813 for i := 0; i < 5; i++ {
814 _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
815 if err != nil {
816 return fmt.Errorf("on write %d: %v", i, err)
817 }
818 time.Sleep(timeout / 2)
819 }
820 }
821 return nil
822 }
823
824 func TestServerReadTimeout(t *testing.T) { run(t, testServerReadTimeout, http3SkippedMode) }
825 func testServerReadTimeout(t *testing.T, mode testMode) {
826 respBody := "response body"
827 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
828 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
829 _, err := io.Copy(io.Discard, req.Body)
830 if !errors.Is(err, os.ErrDeadlineExceeded) {
831 t.Errorf("server timed out reading request body: got err %v; want os.ErrDeadlineExceeded", err)
832 }
833 res.Write([]byte(respBody))
834 }), func(ts *httptest.Server) {
835 ts.Config.ReadHeaderTimeout = -1
836 ts.Config.ReadTimeout = timeout
837 t.Logf("Server.Config.ReadTimeout = %v", timeout)
838 })
839
840 var retries atomic.Int32
841 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
842 if retries.Add(1) != 1 {
843 return nil, errors.New("too many retries")
844 }
845 return nil, nil
846 }
847
848 pr, pw := io.Pipe()
849 res, err := cst.c.Post(cst.ts.URL, "text/apocryphal", pr)
850 if err != nil {
851 t.Logf("Get error, retrying: %v", err)
852 cst.close()
853 continue
854 }
855 defer res.Body.Close()
856 got, err := io.ReadAll(res.Body)
857 if string(got) != respBody || err != nil {
858 t.Errorf("client read response body: %q, %v; want %q, nil", string(got), err, respBody)
859 }
860 pw.Close()
861 break
862 }
863 }
864
865 func TestServerNoReadTimeout(t *testing.T) {
866
867 run(t, testServerNoReadTimeout, http3SkippedMode)
868 }
869 func testServerNoReadTimeout(t *testing.T, mode testMode) {
870 reqBody := "Hello, Gophers!"
871 resBody := "Hi, Gophers!"
872 for _, timeout := range []time.Duration{0, -1} {
873 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
874 ctl := NewResponseController(res)
875 ctl.EnableFullDuplex()
876 res.WriteHeader(StatusOK)
877
878
879 if err := ctl.Flush(); err != nil {
880 t.Errorf("server flush response: %v", err)
881 return
882 }
883 got, err := io.ReadAll(req.Body)
884 if string(got) != reqBody || err != nil {
885 t.Errorf("server read request body: %v; got %q, want %q", err, got, reqBody)
886 }
887 res.Write([]byte(resBody))
888 }), func(ts *httptest.Server) {
889 ts.Config.ReadTimeout = timeout
890 t.Logf("Server.Config.ReadTimeout = %d", timeout)
891 })
892
893 pr, pw := io.Pipe()
894 res, err := cst.c.Post(cst.ts.URL, "text/plain", pr)
895 if err != nil {
896 t.Fatal(err)
897 }
898 defer res.Body.Close()
899
900
901 time.Sleep(10 * time.Millisecond)
902 pw.Write([]byte(reqBody))
903 pw.Close()
904
905 got, err := io.ReadAll(res.Body)
906 if string(got) != resBody || err != nil {
907 t.Errorf("client read response body: %v; got %v, want %q", err, got, resBody)
908 }
909 }
910 }
911
912 func TestServerWriteTimeout(t *testing.T) { run(t, testServerWriteTimeout, http3SkippedMode) }
913 func testServerWriteTimeout(t *testing.T, mode testMode) {
914 for timeout := 5 * time.Millisecond; ; timeout *= 2 {
915 errc := make(chan error, 2)
916 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
917 errc <- nil
918 _, err := io.Copy(res, neverEnding('a'))
919 errc <- err
920 }), func(ts *httptest.Server) {
921 ts.Config.WriteTimeout = timeout
922 t.Logf("Server.Config.WriteTimeout = %v", timeout)
923 })
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942 var retries atomic.Int32
943 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
944 if retries.Add(1) != 1 {
945 return nil, errors.New("too many retries")
946 }
947 return nil, nil
948 }
949
950 res, err := cst.c.Get(cst.ts.URL)
951 if err != nil {
952
953 t.Logf("Get error, retrying: %v", err)
954 cst.close()
955 continue
956 }
957 defer res.Body.Close()
958 _, err = io.Copy(io.Discard, res.Body)
959 if err == nil {
960 t.Errorf("client reading from truncated request body: got nil error, want non-nil")
961 }
962 select {
963 case <-errc:
964 err = <-errc
965 if !errors.Is(err, os.ErrDeadlineExceeded) {
966 t.Errorf("server timed out writing request body: got err %v; want os.ErrDeadlineExceeded", err)
967 }
968 return
969 default:
970
971 t.Logf("handler didn't run, retrying")
972 cst.close()
973 }
974 }
975 }
976
977 func TestServerNoWriteTimeout(t *testing.T) { run(t, testServerNoWriteTimeout) }
978 func testServerNoWriteTimeout(t *testing.T, mode testMode) {
979 for _, timeout := range []time.Duration{0, -1} {
980 handlerDone := make(chan struct{})
981 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
982 defer close(handlerDone)
983 _, err := io.Copy(res, neverEnding('a'))
984 t.Logf("server write response: %v", err)
985 }), func(ts *httptest.Server) {
986 ts.Config.WriteTimeout = timeout
987 t.Logf("Server.Config.WriteTimeout = %d", timeout)
988 })
989
990 res, err := cst.c.Get(cst.ts.URL)
991 if err != nil {
992 t.Fatal(err)
993 }
994 n, err := io.CopyN(io.Discard, res.Body, 1<<20)
995 if n != 1<<20 || err != nil {
996 t.Errorf("client read response body: %d, %v", n, err)
997 }
998 res.Body.Close()
999
1000 cst.ts.Config.Shutdown(context.Background())
1001 <-handlerDone
1002 }
1003 }
1004
1005
1006 func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) {
1007 run(t, testWriteDeadlineExtendedOnNewRequest)
1008 }
1009 func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) {
1010 if testing.Short() {
1011 t.Skip("skipping in short mode")
1012 }
1013 ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}),
1014 func(ts *httptest.Server) {
1015 ts.Config.WriteTimeout = 250 * time.Millisecond
1016 },
1017 ).ts
1018
1019 c := ts.Client()
1020
1021 for i := 1; i <= 3; i++ {
1022 req, err := NewRequest("GET", ts.URL, nil)
1023 if err != nil {
1024 t.Fatal(err)
1025 }
1026
1027 r, err := c.Do(req)
1028 if err != nil {
1029 t.Fatalf("http2 Get #%d: %v", i, err)
1030 }
1031 r.Body.Close()
1032 time.Sleep(ts.Config.WriteTimeout / 2)
1033 }
1034 }
1035
1036
1037
1038 func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) {
1039 tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
1040 for i, timeout := range tries {
1041 err := testFunc(timeout)
1042 if err == nil {
1043 return
1044 }
1045 t.Logf("failed at %v: %v", timeout, err)
1046 if i != len(tries)-1 {
1047 t.Logf("retrying at %v ...", tries[i+1])
1048 }
1049 }
1050 t.Fatal("all attempts failed")
1051 }
1052
1053
1054 func TestWriteDeadlineEnforcedPerStream(t *testing.T) {
1055 if testing.Short() {
1056 t.Skip("skipping in short mode")
1057 }
1058 setParallel(t)
1059 run(t, func(t *testing.T, mode testMode) {
1060 tryTimeouts(t, func(timeout time.Duration) error {
1061 return testWriteDeadlineEnforcedPerStream(t, mode, timeout)
1062 })
1063 }, http3SkippedMode)
1064 }
1065
1066 func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error {
1067 firstRequest := make(chan bool, 1)
1068 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1069 select {
1070 case firstRequest <- true:
1071
1072 default:
1073
1074 time.Sleep(timeout)
1075 }
1076 }), func(ts *httptest.Server) {
1077 ts.Config.WriteTimeout = timeout / 2
1078 })
1079 defer cst.close()
1080 ts := cst.ts
1081
1082 c := ts.Client()
1083
1084 req, err := NewRequest("GET", ts.URL, nil)
1085 if err != nil {
1086 return fmt.Errorf("NewRequest: %v", err)
1087 }
1088 r, err := c.Do(req)
1089 if err != nil {
1090 return fmt.Errorf("Get #1: %v", err)
1091 }
1092 r.Body.Close()
1093
1094 req, err = NewRequest("GET", ts.URL, nil)
1095 if err != nil {
1096 return fmt.Errorf("NewRequest: %v", err)
1097 }
1098 r, err = c.Do(req)
1099 if err == nil {
1100 r.Body.Close()
1101 return fmt.Errorf("Get #2 expected error, got nil")
1102 }
1103 if mode == http2Mode {
1104 expected := "stream ID 3; INTERNAL_ERROR"
1105 if !strings.Contains(err.Error(), expected) {
1106 return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err)
1107 }
1108 }
1109 return nil
1110 }
1111
1112
1113 func TestNoWriteDeadline(t *testing.T) {
1114 if testing.Short() {
1115 t.Skip("skipping in short mode")
1116 }
1117 setParallel(t)
1118 defer afterTest(t)
1119 run(t, func(t *testing.T, mode testMode) {
1120 tryTimeouts(t, func(timeout time.Duration) error {
1121 return testNoWriteDeadline(t, mode, timeout)
1122 })
1123 })
1124 }
1125
1126 func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error {
1127 firstRequest := make(chan bool, 1)
1128 cst := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {
1129 select {
1130 case firstRequest <- true:
1131
1132 default:
1133
1134 time.Sleep(timeout)
1135 }
1136 }))
1137 defer cst.close()
1138 ts := cst.ts
1139
1140 c := ts.Client()
1141
1142 for i := 0; i < 2; i++ {
1143 req, err := NewRequest("GET", ts.URL, nil)
1144 if err != nil {
1145 return fmt.Errorf("NewRequest: %v", err)
1146 }
1147 r, err := c.Do(req)
1148 if err != nil {
1149 return fmt.Errorf("Get #%d: %v", i, err)
1150 }
1151 r.Body.Close()
1152 }
1153 return nil
1154 }
1155
1156
1157
1158
1159 func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) }
1160 func testOnlyWriteTimeout(t *testing.T, mode testMode) {
1161 var (
1162 mu sync.RWMutex
1163 conn net.Conn
1164 )
1165 var afterTimeoutErrc = make(chan error, 1)
1166 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) {
1167 buf := make([]byte, 512<<10)
1168 _, err := w.Write(buf)
1169 if err != nil {
1170 t.Errorf("handler Write error: %v", err)
1171 return
1172 }
1173 mu.RLock()
1174 defer mu.RUnlock()
1175 if conn == nil {
1176 t.Error("no established connection found")
1177 return
1178 }
1179 conn.SetWriteDeadline(time.Now().Add(-30 * time.Second))
1180 _, err = w.Write(buf)
1181 afterTimeoutErrc <- err
1182 }), func(ts *httptest.Server) {
1183 ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn}
1184 }).ts
1185
1186 c := ts.Client()
1187
1188 err := func() error {
1189 res, err := c.Get(ts.URL)
1190 if err != nil {
1191 return err
1192 }
1193 _, err = io.Copy(io.Discard, res.Body)
1194 res.Body.Close()
1195 return err
1196 }()
1197 if err == nil {
1198 t.Errorf("expected an error copying body from Get request")
1199 }
1200
1201 if err := <-afterTimeoutErrc; err == nil {
1202 t.Error("expected write error after timeout")
1203 }
1204 }
1205
1206
1207 type trackLastConnListener struct {
1208 net.Listener
1209
1210 mu *sync.RWMutex
1211 last *net.Conn
1212 }
1213
1214 func (l trackLastConnListener) Accept() (c net.Conn, err error) {
1215 c, err = l.Listener.Accept()
1216 if err == nil {
1217 l.mu.Lock()
1218 *l.last = c
1219 l.mu.Unlock()
1220 }
1221 return
1222 }
1223
1224
1225 func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) }
1226 func testIdentityResponse(t *testing.T, mode testMode) {
1227 if mode == http2Mode {
1228 t.Skip("https://go.dev/issue/56019")
1229 }
1230
1231 handler := HandlerFunc(func(rw ResponseWriter, req *Request) {
1232 rw.Header().Set("Content-Length", "3")
1233 rw.Header().Set("Transfer-Encoding", req.FormValue("te"))
1234 switch {
1235 case req.FormValue("overwrite") == "1":
1236 _, err := rw.Write([]byte("foo TOO LONG"))
1237 if err != ErrContentLength {
1238 t.Errorf("expected ErrContentLength; got %v", err)
1239 }
1240 case req.FormValue("underwrite") == "1":
1241 rw.Header().Set("Content-Length", "500")
1242 rw.Write([]byte("too short"))
1243 default:
1244 rw.Write([]byte("foo"))
1245 }
1246 })
1247
1248 ts := newClientServerTest(t, mode, handler).ts
1249 c := ts.Client()
1250
1251
1252
1253
1254
1255 for _, te := range []string{"", "identity"} {
1256 url := ts.URL + "/?te=" + te
1257 res, err := c.Get(url)
1258 if err != nil {
1259 t.Fatalf("error with Get of %s: %v", url, err)
1260 }
1261 if cl, expected := res.ContentLength, int64(3); cl != expected {
1262 t.Errorf("for %s expected res.ContentLength of %d; got %d", url, expected, cl)
1263 }
1264 if cl, expected := res.Header.Get("Content-Length"), "3"; cl != expected {
1265 t.Errorf("for %s expected Content-Length header of %q; got %q", url, expected, cl)
1266 }
1267 if tl, expected := len(res.TransferEncoding), 0; tl != expected {
1268 t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
1269 url, expected, tl, res.TransferEncoding)
1270 }
1271 res.Body.Close()
1272 }
1273
1274
1275 url := ts.URL + "/?overwrite=1"
1276 res, err := c.Get(url)
1277 if err != nil {
1278 t.Fatalf("error with Get of %s: %v", url, err)
1279 }
1280 res.Body.Close()
1281
1282 if mode != http1Mode {
1283 return
1284 }
1285
1286
1287
1288 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1289 if err != nil {
1290 t.Fatalf("error dialing: %v", err)
1291 }
1292 _, err = conn.Write([]byte("GET /?underwrite=1 HTTP/1.1\r\nHost: foo\r\n\r\n"))
1293 if err != nil {
1294 t.Fatalf("error writing: %v", err)
1295 }
1296
1297
1298 got, _ := io.ReadAll(conn)
1299 expectedSuffix := "\r\n\r\ntoo short"
1300 if !strings.HasSuffix(string(got), expectedSuffix) {
1301 t.Errorf("Expected output to end with %q; got response body %q",
1302 expectedSuffix, string(got))
1303 }
1304 }
1305
1306 func testTCPConnectionCloses(t *testing.T, req string, h Handler) {
1307 setParallel(t)
1308 s := newClientServerTest(t, http1Mode, h).ts
1309
1310 conn, err := net.Dial("tcp", s.Listener.Addr().String())
1311 if err != nil {
1312 t.Fatal("dial error:", err)
1313 }
1314 defer conn.Close()
1315
1316 _, err = fmt.Fprint(conn, req)
1317 if err != nil {
1318 t.Fatal("print error:", err)
1319 }
1320
1321 r := bufio.NewReader(conn)
1322 res, err := ReadResponse(r, &Request{Method: "GET"})
1323 if err != nil {
1324 t.Fatal("ReadResponse error:", err)
1325 }
1326
1327 _, err = io.ReadAll(r)
1328 if err != nil {
1329 t.Fatal("read error:", err)
1330 }
1331
1332 if !res.Close {
1333 t.Errorf("Response.Close = false; want true")
1334 }
1335 }
1336
1337 func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) {
1338 setParallel(t)
1339 ts := newClientServerTest(t, http1Mode, handler).ts
1340 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1341 if err != nil {
1342 t.Fatal(err)
1343 }
1344 defer conn.Close()
1345 br := bufio.NewReader(conn)
1346 for i := 0; i < 2; i++ {
1347 if _, err := io.WriteString(conn, req); err != nil {
1348 t.Fatal(err)
1349 }
1350 res, err := ReadResponse(br, nil)
1351 if err != nil {
1352 t.Fatalf("res %d: %v", i+1, err)
1353 }
1354 if _, err := io.Copy(io.Discard, res.Body); err != nil {
1355 t.Fatalf("res %d body copy: %v", i+1, err)
1356 }
1357 res.Body.Close()
1358 }
1359 }
1360
1361
1362 func TestServeHTTP10Close(t *testing.T) {
1363 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1364 ServeFile(w, r, "testdata/file")
1365 }))
1366 }
1367
1368
1369 func TestClientCanClose(t *testing.T) {
1370 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1371
1372 }))
1373 }
1374
1375
1376
1377 func TestHandlersCanSetConnectionClose11(t *testing.T) {
1378 testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1379 w.Header().Set("Connection", "close")
1380 }))
1381 }
1382
1383 func TestHandlersCanSetConnectionClose10(t *testing.T) {
1384 testTCPConnectionCloses(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1385 w.Header().Set("Connection", "close")
1386 }))
1387 }
1388
1389 func TestHTTP2UpgradeClosesConnection(t *testing.T) {
1390 testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) {
1391
1392
1393 }))
1394 }
1395
1396 func send204(w ResponseWriter, r *Request) { w.WriteHeader(204) }
1397 func send304(w ResponseWriter, r *Request) { w.WriteHeader(304) }
1398
1399
1400 func TestHTTP10KeepAlive204Response(t *testing.T) {
1401 testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204))
1402 }
1403
1404 func TestHTTP11KeepAlive204Response(t *testing.T) {
1405 testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204))
1406 }
1407
1408 func TestHTTP10KeepAlive304Response(t *testing.T) {
1409 testTCPConnectionStaysOpen(t,
1410 "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n",
1411 HandlerFunc(send304))
1412 }
1413
1414
1415 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) }
1416 func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) {
1417 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1418 w.(Flusher).Flush()
1419 w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}"))
1420 }))
1421 type data struct {
1422 Addr string
1423 }
1424 var addrs [2]data
1425 for i := range addrs {
1426 res, err := cst.c.Get(cst.ts.URL)
1427 if err != nil {
1428 t.Fatal(err)
1429 }
1430 if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil {
1431 t.Fatal(err)
1432 }
1433 if addrs[i].Addr == "" {
1434 t.Fatal("no address")
1435 }
1436 res.Body.Close()
1437 }
1438 if addrs[0] != addrs[1] {
1439 t.Fatalf("connection not reused")
1440 }
1441 }
1442
1443 func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) }
1444 func testSetsRemoteAddr(t *testing.T, mode testMode) {
1445 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1446 fmt.Fprintf(w, "%s", r.RemoteAddr)
1447 }))
1448
1449 res, err := cst.c.Get(cst.ts.URL)
1450 if err != nil {
1451 t.Fatalf("Get error: %v", err)
1452 }
1453 body, err := io.ReadAll(res.Body)
1454 if err != nil {
1455 t.Fatalf("ReadAll error: %v", err)
1456 }
1457 ip := string(body)
1458 if !strings.HasPrefix(ip, "127.0.0.1:") && !strings.HasPrefix(ip, "[::1]:") {
1459 t.Fatalf("Expected local addr; got %q", ip)
1460 }
1461 }
1462
1463 type blockingRemoteAddrListener struct {
1464 net.Listener
1465 conns chan<- net.Conn
1466 }
1467
1468 func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
1469 c, err := l.Listener.Accept()
1470 if err != nil {
1471 return nil, err
1472 }
1473 brac := &blockingRemoteAddrConn{
1474 Conn: c,
1475 addrs: make(chan net.Addr, 1),
1476 }
1477 l.conns <- brac
1478 return brac, nil
1479 }
1480
1481 type blockingRemoteAddrConn struct {
1482 net.Conn
1483 addrs chan net.Addr
1484 }
1485
1486 func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
1487 return <-c.addrs
1488 }
1489
1490
1491 func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
1492 run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode})
1493 }
1494 func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) {
1495 conns := make(chan net.Conn)
1496 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1497 fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
1498 }), func(ts *httptest.Server) {
1499 ts.Listener = &blockingRemoteAddrListener{
1500 Listener: ts.Listener,
1501 conns: conns,
1502 }
1503 }).ts
1504
1505 c := ts.Client()
1506
1507 c.Transport.(*Transport).DisableKeepAlives = true
1508
1509 fetch := func(num int, response chan<- string) {
1510 resp, err := c.Get(ts.URL)
1511 if err != nil {
1512 t.Errorf("Request %d: %v", num, err)
1513 response <- ""
1514 return
1515 }
1516 defer resp.Body.Close()
1517 body, err := io.ReadAll(resp.Body)
1518 if err != nil {
1519 t.Errorf("Request %d: %v", num, err)
1520 response <- ""
1521 return
1522 }
1523 response <- string(body)
1524 }
1525
1526
1527 response1c := make(chan string, 1)
1528 go fetch(1, response1c)
1529
1530
1531 conn1 := <-conns
1532
1533
1534 response2c := make(chan string, 1)
1535 go fetch(2, response2c)
1536 conn2 := <-conns
1537
1538
1539 conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1540 IP: net.ParseIP("12.12.12.12"), Port: 12}
1541
1542
1543 response2 := <-response2c
1544 if g, e := response2, "RA:12.12.12.12:12"; g != e {
1545 t.Fatalf("response 2 addr = %q; want %q", g, e)
1546 }
1547
1548
1549 conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
1550 IP: net.ParseIP("21.21.21.21"), Port: 21}
1551
1552
1553 response1 := <-response1c
1554 if g, e := response1, "RA:21.21.21.21:21"; g != e {
1555 t.Fatalf("response 1 addr = %q; want %q", g, e)
1556 }
1557 }
1558
1559
1560
1561 func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) }
1562 func testHeadResponses(t *testing.T, mode testMode) {
1563 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1564 _, err := w.Write([]byte("<html>"))
1565 if err != nil {
1566 t.Errorf("ResponseWriter.Write: %v", err)
1567 }
1568
1569
1570 _, err = io.Copy(w, struct{ io.Reader }{strings.NewReader("789a")})
1571 if err != nil {
1572 t.Errorf("Copy(ResponseWriter, ...): %v", err)
1573 }
1574 }))
1575 res, err := cst.c.Head(cst.ts.URL)
1576 if err != nil {
1577 t.Error(err)
1578 }
1579 if len(res.TransferEncoding) > 0 {
1580 t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
1581 }
1582 if ct := res.Header.Get("Content-Type"); ct != "text/html; charset=utf-8" {
1583 t.Errorf("Content-Type: %q; want text/html; charset=utf-8", ct)
1584 }
1585
1586 if v := res.ContentLength; v != 10 && mode != http3Mode {
1587 t.Errorf("Content-Length: %d; want 10", v)
1588 }
1589 body, err := io.ReadAll(res.Body)
1590 if err != nil {
1591 t.Error(err)
1592 }
1593 if len(body) > 0 {
1594 t.Errorf("got unexpected body %q", string(body))
1595 }
1596 }
1597
1598
1599
1600 func TestHeadReaderFrom(t *testing.T) { run(t, testHeadReaderFrom, []testMode{http1Mode}) }
1601 func testHeadReaderFrom(t *testing.T, mode testMode) {
1602
1603 wantBody := strings.Repeat("a", 4096)
1604 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1605 w.(io.ReaderFrom).ReadFrom(strings.NewReader(wantBody))
1606 }))
1607 res, err := cst.c.Head(cst.ts.URL)
1608 if err != nil {
1609 t.Fatal(err)
1610 }
1611 res.Body.Close()
1612 res, err = cst.c.Get(cst.ts.URL)
1613 if err != nil {
1614 t.Fatal(err)
1615 }
1616 gotBody, err := io.ReadAll(res.Body)
1617 res.Body.Close()
1618 if err != nil {
1619 t.Fatal(err)
1620 }
1621 if string(gotBody) != wantBody {
1622 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), len(wantBody))
1623 }
1624 }
1625
1626
1627
1628 func TestReaderFromTooLong(t *testing.T) { run(t, testReaderFromTooLong, []testMode{http1Mode}) }
1629 func testReaderFromTooLong(t *testing.T, mode testMode) {
1630 contentLen := 600
1631 tests := []struct {
1632 name string
1633 reader io.Reader
1634 wantHandlerErr error
1635 }{
1636 {
1637 name: "reader of correct length",
1638 reader: strings.NewReader(strings.Repeat("a", contentLen)),
1639 },
1640 {
1641 name: "wrapped reader of correct outer length",
1642 reader: io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(contentLen)),
1643 },
1644 {
1645 name: "wrapped reader of correct inner length",
1646 reader: io.LimitReader(io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(contentLen)), int64(2*contentLen)),
1647 },
1648 {
1649 name: "reader that is too long",
1650 reader: strings.NewReader(strings.Repeat("a", 2*contentLen)),
1651 wantHandlerErr: ErrContentLength,
1652 },
1653 {
1654 name: "wrapped reader that is too long",
1655 reader: io.LimitReader(strings.NewReader(strings.Repeat("a", 2*contentLen)), int64(2*contentLen)),
1656 wantHandlerErr: ErrContentLength,
1657 },
1658 }
1659
1660 for _, tc := range tests {
1661 t.Run(tc.name, func(t *testing.T) {
1662 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1663 w.Header().Set("Content-Length", strconv.Itoa(contentLen))
1664 n, err := w.(io.ReaderFrom).ReadFrom(tc.reader)
1665 if int(n) != contentLen || !errors.Is(err, tc.wantHandlerErr) {
1666 t.Errorf("got %v, %v from w.ReadFrom; want %v, %v", n, err, contentLen, tc.wantHandlerErr)
1667 }
1668 }))
1669 res, err := cst.c.Get(cst.ts.URL)
1670 if err != nil {
1671 t.Fatal(err)
1672 }
1673 defer res.Body.Close()
1674 gotBody, err := io.ReadAll(res.Body)
1675 if err != nil {
1676 t.Fatal(err)
1677 }
1678 if len(gotBody) != contentLen {
1679 t.Errorf("got unexpected body len=%v, want %v", len(gotBody), contentLen)
1680 }
1681 })
1682 }
1683 }
1684
1685 func TestTLSHandshakeTimeout(t *testing.T) {
1686 run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode})
1687 }
1688 func testTLSHandshakeTimeout(t *testing.T, mode testMode) {
1689 errLog := new(strings.Builder)
1690 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}),
1691 func(ts *httptest.Server) {
1692 ts.Config.ReadTimeout = 250 * time.Millisecond
1693 ts.Config.ErrorLog = log.New(errLog, "", 0)
1694 },
1695 )
1696 ts := cst.ts
1697
1698 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1699 if err != nil {
1700 t.Fatalf("Dial: %v", err)
1701 }
1702 var buf [1]byte
1703 n, err := conn.Read(buf[:])
1704 if err == nil || n != 0 {
1705 t.Errorf("Read = %d, %v; want an error and no bytes", n, err)
1706 }
1707 conn.Close()
1708
1709 cst.close()
1710 if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") {
1711 t.Errorf("expected a TLS handshake timeout error; got %q", v)
1712 }
1713 }
1714
1715 func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) }
1716 func testTLSServer(t *testing.T, mode testMode) {
1717 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1718 if r.TLS != nil {
1719 w.Header().Set("X-TLS-Set", "true")
1720 if r.TLS.HandshakeComplete {
1721 w.Header().Set("X-TLS-HandshakeComplete", "true")
1722 }
1723 }
1724 }), func(ts *httptest.Server) {
1725 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
1726 }).ts
1727
1728
1729
1730
1731
1732
1733 idleConn, err := net.Dial("tcp", ts.Listener.Addr().String())
1734 if err != nil {
1735 t.Fatalf("Dial: %v", err)
1736 }
1737 defer idleConn.Close()
1738
1739 if !strings.HasPrefix(ts.URL, "https://") {
1740 t.Errorf("expected test TLS server to start with https://, got %q", ts.URL)
1741 return
1742 }
1743 client := ts.Client()
1744 res, err := client.Get(ts.URL)
1745 if err != nil {
1746 t.Error(err)
1747 return
1748 }
1749 if res == nil {
1750 t.Errorf("got nil Response")
1751 return
1752 }
1753 defer res.Body.Close()
1754 if res.Header.Get("X-TLS-Set") != "true" {
1755 t.Errorf("expected X-TLS-Set response header")
1756 return
1757 }
1758 if res.Header.Get("X-TLS-HandshakeComplete") != "true" {
1759 t.Errorf("expected X-TLS-HandshakeComplete header")
1760 }
1761 }
1762
1763 type fakeConnectionStateConn struct {
1764 net.Conn
1765 }
1766
1767 func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
1768 return tls.ConnectionState{
1769 ServerName: "example.com",
1770 }
1771 }
1772
1773 func TestTLSServerWithoutTLSConn(t *testing.T) {
1774
1775 pr, pw := net.Pipe()
1776 c := make(chan int)
1777 listener := &oneConnListener{&fakeConnectionStateConn{pr}}
1778 server := &Server{
1779 Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
1780 if request.TLS == nil {
1781 t.Fatal("request.TLS is nil, expected not nil")
1782 }
1783 if request.TLS.ServerName != "example.com" {
1784 t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
1785 }
1786 writer.Header().Set("X-TLS-ServerName", "example.com")
1787 }),
1788 }
1789
1790
1791 go func() {
1792 req, _ := NewRequest(MethodGet, "https://example.com", nil)
1793 req.Write(pw)
1794
1795 resp, _ := ReadResponse(bufio.NewReader(pw), req)
1796 if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
1797 t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
1798 }
1799 close(c)
1800 pw.Close()
1801 }()
1802
1803 server.Serve(listener)
1804
1805
1806 <-c
1807 pr.Close()
1808 }
1809
1810 func TestServeTLS(t *testing.T) {
1811 CondSkipHTTP2(t)
1812
1813 defer afterTest(t)
1814 defer SetTestHookServerServe(nil)
1815
1816 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1817 if err != nil {
1818 t.Fatal(err)
1819 }
1820 tlsConf := &tls.Config{
1821 Certificates: []tls.Certificate{cert},
1822 }
1823
1824 ln := newLocalListener(t)
1825 defer ln.Close()
1826 addr := ln.Addr().String()
1827
1828 serving := make(chan bool, 1)
1829 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1830 serving <- true
1831 })
1832 handler := HandlerFunc(func(w ResponseWriter, r *Request) {})
1833 s := &Server{
1834 Addr: addr,
1835 TLSConfig: tlsConf,
1836 Handler: handler,
1837 }
1838 errc := make(chan error, 1)
1839 go func() { errc <- s.ServeTLS(ln, "", "") }()
1840 select {
1841 case err := <-errc:
1842 t.Fatalf("ServeTLS: %v", err)
1843 case <-serving:
1844 }
1845
1846 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
1847 InsecureSkipVerify: true,
1848 NextProtos: []string{"h2", "http/1.1"},
1849 })
1850 if err != nil {
1851 t.Fatal(err)
1852 }
1853 defer c.Close()
1854 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
1855 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
1856 }
1857 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
1858 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
1859 }
1860 }
1861
1862
1863 func TestTLSServerRejectHTTPRequests(t *testing.T) {
1864 run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode})
1865 }
1866 func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) {
1867 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1868 t.Error("unexpected HTTPS request")
1869 }), func(ts *httptest.Server) {
1870 var errBuf bytes.Buffer
1871 ts.Config.ErrorLog = log.New(&errBuf, "", 0)
1872 }).ts
1873 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
1874 if err != nil {
1875 t.Fatal(err)
1876 }
1877 defer conn.Close()
1878 io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
1879 slurp, err := io.ReadAll(conn)
1880 if err != nil {
1881 t.Fatal(err)
1882 }
1883 const wantPrefix = "HTTP/1.0 400 Bad Request\r\n"
1884 if !strings.HasPrefix(string(slurp), wantPrefix) {
1885 t.Errorf("response = %q; wanted prefix %q", slurp, wantPrefix)
1886 }
1887 }
1888
1889
1890 func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) {
1891 testAutomaticHTTP2_Serve(t, nil, true)
1892 }
1893
1894 func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) {
1895 testAutomaticHTTP2_Serve(t, &tls.Config{}, false)
1896 }
1897
1898 func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) {
1899 testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true)
1900 }
1901
1902 func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) {
1903 setParallel(t)
1904 defer afterTest(t)
1905 ln := newLocalListener(t)
1906 ln.Close()
1907 var s Server
1908 s.TLSConfig = tlsConf
1909 if err := s.Serve(ln); err == nil {
1910 t.Fatal("expected an error")
1911 }
1912 gotH2 := s.TLSNextProto["h2"] != nil
1913 if gotH2 != wantH2 {
1914 t.Errorf("http2 configured = %v; want %v", gotH2, wantH2)
1915 }
1916 }
1917
1918 func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) {
1919 setParallel(t)
1920 defer afterTest(t)
1921 ln := newLocalListener(t)
1922 ln.Close()
1923 var s Server
1924
1925
1926 s.TLSConfig = &tls.Config{
1927 NextProtos: []string{"h2"},
1928 }
1929 if err := s.Serve(ln); err == nil {
1930 t.Fatal("expected an error")
1931 }
1932 on := s.TLSNextProto["h2"] != nil
1933 if !on {
1934 t.Errorf("http2 wasn't automatically enabled")
1935 }
1936 }
1937
1938 func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
1939 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1940 if err != nil {
1941 t.Fatal(err)
1942 }
1943 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1944 Certificates: []tls.Certificate{cert},
1945 })
1946 }
1947
1948 func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) {
1949 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1950 if err != nil {
1951 t.Fatal(err)
1952 }
1953 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1954 GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
1955 return &cert, nil
1956 },
1957 })
1958 }
1959
1960 func TestAutomaticHTTP2_ListenAndServe_GetConfigForClient(t *testing.T) {
1961 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
1962 if err != nil {
1963 t.Fatal(err)
1964 }
1965 conf := &tls.Config{
1966
1967
1968 NextProtos: []string{"h2"},
1969 Certificates: []tls.Certificate{cert},
1970 }
1971 testAutomaticHTTP2_ListenAndServe(t, &tls.Config{
1972 GetConfigForClient: func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
1973 return conf, nil
1974 },
1975 })
1976 }
1977
1978 func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) {
1979 CondSkipHTTP2(t)
1980
1981 defer afterTest(t)
1982 defer SetTestHookServerServe(nil)
1983 var ok bool
1984 var s *Server
1985 const maxTries = 5
1986 var ln net.Listener
1987 Try:
1988 for try := 0; try < maxTries; try++ {
1989 ln = newLocalListener(t)
1990 addr := ln.Addr().String()
1991 ln.Close()
1992 t.Logf("Got %v", addr)
1993 lnc := make(chan net.Listener, 1)
1994 SetTestHookServerServe(func(s *Server, ln net.Listener) {
1995 lnc <- ln
1996 })
1997 s = &Server{
1998 Addr: addr,
1999 TLSConfig: tlsConf,
2000 }
2001 errc := make(chan error, 1)
2002 go func() { errc <- s.ListenAndServeTLS("", "") }()
2003 select {
2004 case err := <-errc:
2005 t.Logf("On try #%v: %v", try+1, err)
2006 continue
2007 case ln = <-lnc:
2008 ok = true
2009 t.Logf("Listening on %v", ln.Addr().String())
2010 break Try
2011 }
2012 }
2013 if !ok {
2014 t.Fatalf("Failed to start up after %d tries", maxTries)
2015 }
2016 defer ln.Close()
2017 c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
2018 InsecureSkipVerify: true,
2019 NextProtos: []string{"h2", "http/1.1"},
2020 })
2021 if err != nil {
2022 t.Fatal(err)
2023 }
2024 defer c.Close()
2025 if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
2026 t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
2027 }
2028 if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
2029 t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
2030 }
2031 }
2032
2033 type serverExpectTest struct {
2034 contentLength int
2035 chunked bool
2036 expectation string
2037 readBody bool
2038 expectedResponse string
2039 }
2040
2041 func expectTest(contentLength int, expectation string, readBody bool, expectedResponse string) serverExpectTest {
2042 return serverExpectTest{
2043 contentLength: contentLength,
2044 expectation: expectation,
2045 readBody: readBody,
2046 expectedResponse: expectedResponse,
2047 }
2048 }
2049
2050 var serverExpectTests = []serverExpectTest{
2051
2052 expectTest(100, "100-continue", true, "100 Continue"),
2053 expectTest(100, "100-cOntInUE", true, "100 Continue"),
2054
2055
2056 expectTest(100, "", true, "200 OK"),
2057
2058
2059
2060 expectTest(100, "100-continue", false, "401 Unauthorized"),
2061
2062 expectTest(100, "", false, "401 Unauthorized"),
2063
2064
2065 expectTest(0, "a-pony", false, "417 Expectation Failed"),
2066
2067
2068 expectTest(0, "100-continue", true, "200 OK"),
2069
2070 expectTest(0, "100-continue", false, "401 Unauthorized"),
2071
2072 {
2073 expectation: "100-continue",
2074 readBody: true,
2075 chunked: true,
2076 expectedResponse: "100 Continue",
2077 },
2078 }
2079
2080
2081
2082 func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) }
2083 func testServerExpect(t *testing.T, mode testMode) {
2084 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2085
2086
2087
2088 if strings.Contains(r.URL.RawQuery, "readbody=true") {
2089 io.ReadAll(r.Body)
2090 w.Write([]byte("Hi"))
2091 } else {
2092 w.WriteHeader(StatusUnauthorized)
2093 }
2094 })).ts
2095
2096 runTest := func(test serverExpectTest) {
2097 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
2098 if err != nil {
2099 t.Fatalf("Dial: %v", err)
2100 }
2101 defer conn.Close()
2102
2103
2104
2105 writeBody := test.contentLength != 0 && strings.ToLower(test.expectation) != "100-continue"
2106
2107 wg := sync.WaitGroup{}
2108 wg.Add(1)
2109 defer wg.Wait()
2110
2111 go func() {
2112 defer wg.Done()
2113
2114 contentLen := fmt.Sprintf("Content-Length: %d", test.contentLength)
2115 if test.chunked {
2116 contentLen = "Transfer-Encoding: chunked"
2117 }
2118 _, err := fmt.Fprintf(conn, "POST /?readbody=%v HTTP/1.1\r\n"+
2119 "Connection: close\r\n"+
2120 "%s\r\n"+
2121 "Expect: %s\r\nHost: foo\r\n\r\n",
2122 test.readBody, contentLen, test.expectation)
2123 if err != nil {
2124 t.Errorf("On test %#v, error writing request headers: %v", test, err)
2125 return
2126 }
2127 if writeBody {
2128 var targ io.WriteCloser = struct {
2129 io.Writer
2130 io.Closer
2131 }{
2132 conn,
2133 io.NopCloser(nil),
2134 }
2135 if test.chunked {
2136 targ = httputil.NewChunkedWriter(conn)
2137 }
2138 body := strings.Repeat("A", test.contentLength)
2139 _, err = fmt.Fprint(targ, body)
2140 if err == nil {
2141 err = targ.Close()
2142 }
2143 if err != nil {
2144 if !test.readBody {
2145
2146
2147 t.Logf("On test %#v, acceptable error writing request body: %v", test, err)
2148 return
2149 }
2150 t.Errorf("On test %#v, error writing request body: %v", test, err)
2151 }
2152 }
2153 }()
2154 bufr := bufio.NewReader(conn)
2155 line, err := bufr.ReadString('\n')
2156 if err != nil {
2157 if writeBody && !test.readBody {
2158
2159
2160
2161
2162
2163 t.Logf("On test %#v, acceptable error from ReadString: %v", test, err)
2164 return
2165 }
2166 t.Fatalf("On test %#v, ReadString: %v", test, err)
2167 }
2168 if !strings.Contains(line, test.expectedResponse) {
2169 t.Errorf("On test %#v, got first line = %q; want %q", test, line, test.expectedResponse)
2170 }
2171 }
2172
2173 for _, test := range serverExpectTests {
2174 runTest(test)
2175 }
2176 }
2177
2178
2179
2180 func TestServerUnreadRequestBodyLittle(t *testing.T) {
2181 setParallel(t)
2182 defer afterTest(t)
2183 conn := new(testConn)
2184 body := strings.Repeat("x", 100<<10)
2185 conn.readBuf.Write([]byte(fmt.Sprintf(
2186 "POST / HTTP/1.1\r\n"+
2187 "Host: test\r\n"+
2188 "Content-Length: %d\r\n"+
2189 "\r\n", len(body))))
2190 conn.readBuf.Write([]byte(body))
2191
2192 done := make(chan bool)
2193
2194 readBufLen := func() int {
2195 conn.readMu.Lock()
2196 defer conn.readMu.Unlock()
2197 return conn.readBuf.Len()
2198 }
2199
2200 ls := &oneConnListener{conn}
2201 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2202 defer close(done)
2203 if bufLen := readBufLen(); bufLen < len(body)/2 {
2204 t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
2205 }
2206 rw.WriteHeader(200)
2207 rw.(Flusher).Flush()
2208 if g, e := readBufLen(), 0; g != e {
2209 t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
2210 }
2211 if c := rw.Header().Get("Connection"); c != "" {
2212 t.Errorf(`Connection header = %q; want ""`, c)
2213 }
2214 }))
2215 <-done
2216 }
2217
2218
2219
2220
2221 func TestServerUnreadRequestBodyLarge(t *testing.T) {
2222 setParallel(t)
2223 if testing.Short() && testenv.Builder() == "" {
2224 t.Log("skipping in short mode")
2225 }
2226 conn := new(testConn)
2227 body := strings.Repeat("x", 1<<20)
2228 conn.readBuf.Write([]byte(fmt.Sprintf(
2229 "POST / HTTP/1.1\r\n"+
2230 "Host: test\r\n"+
2231 "Content-Length: %d\r\n"+
2232 "\r\n", len(body))))
2233 conn.readBuf.Write([]byte(body))
2234 conn.closec = make(chan bool, 1)
2235
2236 ls := &oneConnListener{conn}
2237 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2238 if conn.readBuf.Len() < len(body)/2 {
2239 t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2240 }
2241 rw.WriteHeader(200)
2242 rw.(Flusher).Flush()
2243 if conn.readBuf.Len() < len(body)/2 {
2244 t.Errorf("post-WriteHeader, read buffer length is %d; expected about 1MB", conn.readBuf.Len())
2245 }
2246 }))
2247 <-conn.closec
2248
2249 if res := conn.writeBuf.String(); !strings.Contains(res, "Connection: close") {
2250 t.Errorf("Expected a Connection: close header; got response: %s", res)
2251 }
2252 }
2253
2254 type handlerBodyCloseTest struct {
2255 bodySize int
2256 bodyChunked bool
2257 reqConnClose bool
2258
2259 wantEOFSearch bool
2260 wantNextReq bool
2261 }
2262
2263 func (t handlerBodyCloseTest) connectionHeader() string {
2264 if t.reqConnClose {
2265 return "Connection: close\r\n"
2266 }
2267 return ""
2268 }
2269
2270 var handlerBodyCloseTests = [...]handlerBodyCloseTest{
2271
2272
2273 0: {
2274 bodySize: 20 << 10,
2275 bodyChunked: false,
2276 reqConnClose: false,
2277 wantEOFSearch: true,
2278 wantNextReq: true,
2279 },
2280
2281
2282
2283 1: {
2284 bodySize: 20 << 10,
2285 bodyChunked: true,
2286 reqConnClose: false,
2287 wantEOFSearch: true,
2288 wantNextReq: true,
2289 },
2290
2291
2292
2293
2294 2: {
2295 bodySize: 20 << 10,
2296 bodyChunked: false,
2297 reqConnClose: true,
2298 wantEOFSearch: false,
2299 wantNextReq: false,
2300 },
2301
2302
2303
2304
2305
2306
2307 3: {
2308 bodySize: 20 << 10,
2309 bodyChunked: true,
2310 reqConnClose: true,
2311 wantEOFSearch: true,
2312 wantNextReq: false,
2313 },
2314
2315
2316 4: {
2317 bodySize: 1 << 20,
2318 bodyChunked: false,
2319 reqConnClose: false,
2320 wantEOFSearch: false,
2321 wantNextReq: false,
2322 },
2323
2324
2325 5: {
2326 bodySize: 1 << 20,
2327 bodyChunked: true,
2328 reqConnClose: false,
2329 wantEOFSearch: true,
2330 wantNextReq: false,
2331 },
2332
2333
2334
2335
2336 6: {
2337 bodySize: 1 << 20,
2338 bodyChunked: true,
2339 reqConnClose: true,
2340 wantEOFSearch: true,
2341 wantNextReq: false,
2342 },
2343
2344
2345
2346 7: {
2347 bodySize: 1 << 20,
2348 bodyChunked: false,
2349 reqConnClose: true,
2350 wantEOFSearch: false,
2351 wantNextReq: false,
2352 },
2353 }
2354
2355 func TestHandlerBodyClose(t *testing.T) {
2356 setParallel(t)
2357 if testing.Short() && testenv.Builder() == "" {
2358 t.Skip("skipping in -short mode")
2359 }
2360 for i, tt := range handlerBodyCloseTests {
2361 testHandlerBodyClose(t, i, tt)
2362 }
2363 }
2364
2365 func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
2366 conn := new(testConn)
2367 body := strings.Repeat("x", tt.bodySize)
2368 if tt.bodyChunked {
2369 conn.readBuf.WriteString("POST / HTTP/1.1\r\n" +
2370 "Host: test\r\n" +
2371 tt.connectionHeader() +
2372 "Transfer-Encoding: chunked\r\n" +
2373 "\r\n")
2374 cw := internal.NewChunkedWriter(&conn.readBuf)
2375 io.WriteString(cw, body)
2376 cw.Close()
2377 conn.readBuf.WriteString("\r\n")
2378 } else {
2379 conn.readBuf.Write([]byte(fmt.Sprintf(
2380 "POST / HTTP/1.1\r\n"+
2381 "Host: test\r\n"+
2382 tt.connectionHeader()+
2383 "Content-Length: %d\r\n"+
2384 "\r\n", len(body))))
2385 conn.readBuf.Write([]byte(body))
2386 }
2387 if !tt.reqConnClose {
2388 conn.readBuf.WriteString("GET / HTTP/1.1\r\nHost: test\r\n\r\n")
2389 }
2390 conn.closec = make(chan bool, 1)
2391
2392 readBufLen := func() int {
2393 conn.readMu.Lock()
2394 defer conn.readMu.Unlock()
2395 return conn.readBuf.Len()
2396 }
2397
2398 ls := &oneConnListener{conn}
2399 var numReqs int
2400 var size0, size1 int
2401 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
2402 numReqs++
2403 if numReqs == 1 {
2404 size0 = readBufLen()
2405 req.Body.Close()
2406 size1 = readBufLen()
2407 }
2408 }))
2409 <-conn.closec
2410 if numReqs < 1 || numReqs > 2 {
2411 t.Fatalf("%d. bug in test. unexpected number of requests = %d", i, numReqs)
2412 }
2413 didSearch := size0 != size1
2414 if didSearch != tt.wantEOFSearch {
2415 t.Errorf("%d. did EOF search = %v; want %v (size went from %d to %d)", i, didSearch, !didSearch, size0, size1)
2416 }
2417 if tt.wantNextReq && numReqs != 2 {
2418 t.Errorf("%d. numReq = %d; want 2", i, numReqs)
2419 }
2420 }
2421
2422
2423
2424 type testHandlerBodyConsumer struct {
2425 name string
2426 f func(io.ReadCloser)
2427 }
2428
2429 var testHandlerBodyConsumers = []testHandlerBodyConsumer{
2430 {"nil", func(io.ReadCloser) {}},
2431 {"close", func(r io.ReadCloser) { r.Close() }},
2432 {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }},
2433 }
2434
2435 func TestRequestBodyReadErrorClosesConnection(t *testing.T) {
2436 setParallel(t)
2437 defer afterTest(t)
2438 for _, handler := range testHandlerBodyConsumers {
2439 conn := new(testConn)
2440 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2441 "Host: test\r\n" +
2442 "Transfer-Encoding: chunked\r\n" +
2443 "\r\n" +
2444 "hax\r\n" +
2445 "GET /secret HTTP/1.1\r\n" +
2446 "Host: test\r\n" +
2447 "\r\n")
2448
2449 conn.closec = make(chan bool, 1)
2450 ls := &oneConnListener{conn}
2451 var numReqs int
2452 go Serve(ls, HandlerFunc(func(_ ResponseWriter, req *Request) {
2453 numReqs++
2454 if strings.Contains(req.URL.Path, "secret") {
2455 t.Error("Request for /secret encountered, should not have happened.")
2456 }
2457 handler.f(req.Body)
2458 }))
2459 <-conn.closec
2460 if numReqs != 1 {
2461 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2462 }
2463 }
2464 }
2465
2466 func TestInvalidTrailerClosesConnection(t *testing.T) {
2467 setParallel(t)
2468 defer afterTest(t)
2469 for _, handler := range testHandlerBodyConsumers {
2470 conn := new(testConn)
2471 conn.readBuf.WriteString("POST /public HTTP/1.1\r\n" +
2472 "Host: test\r\n" +
2473 "Trailer: hack\r\n" +
2474 "Transfer-Encoding: chunked\r\n" +
2475 "\r\n" +
2476 "3\r\n" +
2477 "hax\r\n" +
2478 "0\r\n" +
2479 "I'm not a valid trailer\r\n" +
2480 "GET /secret HTTP/1.1\r\n" +
2481 "Host: test\r\n" +
2482 "\r\n")
2483
2484 conn.closec = make(chan bool, 1)
2485 ln := &oneConnListener{conn}
2486 var numReqs int
2487 go Serve(ln, HandlerFunc(func(_ ResponseWriter, req *Request) {
2488 numReqs++
2489 if strings.Contains(req.URL.Path, "secret") {
2490 t.Errorf("Handler %s, Request for /secret encountered, should not have happened.", handler.name)
2491 }
2492 handler.f(req.Body)
2493 }))
2494 <-conn.closec
2495 if numReqs != 1 {
2496 t.Errorf("Handler %s: got %d reqs; want 1", handler.name, numReqs)
2497 }
2498 }
2499 }
2500
2501
2502
2503
2504 type slowTestConn struct {
2505
2506 script []any
2507 closec chan bool
2508
2509 mu sync.Mutex
2510 rd, wd time.Time
2511 noopConn
2512 }
2513
2514 func (c *slowTestConn) SetDeadline(t time.Time) error {
2515 c.SetReadDeadline(t)
2516 c.SetWriteDeadline(t)
2517 return nil
2518 }
2519
2520 func (c *slowTestConn) SetReadDeadline(t time.Time) error {
2521 c.mu.Lock()
2522 defer c.mu.Unlock()
2523 c.rd = t
2524 return nil
2525 }
2526
2527 func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
2528 c.mu.Lock()
2529 defer c.mu.Unlock()
2530 c.wd = t
2531 return nil
2532 }
2533
2534 func (c *slowTestConn) Read(b []byte) (n int, err error) {
2535 c.mu.Lock()
2536 defer c.mu.Unlock()
2537 restart:
2538 if !c.rd.IsZero() && time.Now().After(c.rd) {
2539 return 0, syscall.ETIMEDOUT
2540 }
2541 if len(c.script) == 0 {
2542 return 0, io.EOF
2543 }
2544
2545 switch cue := c.script[0].(type) {
2546 case time.Duration:
2547 if !c.rd.IsZero() {
2548
2549
2550 if remaining := time.Until(c.rd); remaining < cue {
2551 c.script[0] = cue - remaining
2552 time.Sleep(remaining)
2553 return 0, syscall.ETIMEDOUT
2554 }
2555 }
2556 c.script = c.script[1:]
2557 time.Sleep(cue)
2558 goto restart
2559
2560 case string:
2561 n = copy(b, cue)
2562
2563 if len(cue) > n {
2564 c.script[0] = cue[n:]
2565 } else {
2566 c.script = c.script[1:]
2567 }
2568
2569 default:
2570 panic("unknown cue in slowTestConn script")
2571 }
2572
2573 return
2574 }
2575
2576 func (c *slowTestConn) Close() error {
2577 select {
2578 case c.closec <- true:
2579 default:
2580 }
2581 return nil
2582 }
2583
2584 func (c *slowTestConn) Write(b []byte) (int, error) {
2585 if !c.wd.IsZero() && time.Now().After(c.wd) {
2586 return 0, syscall.ETIMEDOUT
2587 }
2588 return len(b), nil
2589 }
2590
2591 func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
2592 if testing.Short() {
2593 t.Skip("skipping in -short mode")
2594 }
2595 defer afterTest(t)
2596 for _, handler := range testHandlerBodyConsumers {
2597 conn := &slowTestConn{
2598 script: []any{
2599 "POST /public HTTP/1.1\r\n" +
2600 "Host: test\r\n" +
2601 "Content-Length: 10000\r\n" +
2602 "\r\n",
2603 "foo bar baz",
2604 600 * time.Millisecond,
2605 "GET /secret HTTP/1.1\r\n" +
2606 "Host: test\r\n" +
2607 "\r\n",
2608 },
2609 closec: make(chan bool, 1),
2610 }
2611 ls := &oneConnListener{conn}
2612
2613 var numReqs int
2614 s := Server{
2615 Handler: HandlerFunc(func(_ ResponseWriter, req *Request) {
2616 numReqs++
2617 if strings.Contains(req.URL.Path, "secret") {
2618 t.Error("Request for /secret encountered, should not have happened.")
2619 }
2620 handler.f(req.Body)
2621 }),
2622 ReadTimeout: 400 * time.Millisecond,
2623 }
2624 go s.Serve(ls)
2625 <-conn.closec
2626
2627 if numReqs != 1 {
2628 t.Errorf("Handler %v: got %d reqs; want 1", handler.name, numReqs)
2629 }
2630 }
2631 }
2632
2633
2634 type cancelableTimeoutContext struct {
2635 context.Context
2636 }
2637
2638 func (c cancelableTimeoutContext) Err() error {
2639 if c.Context.Err() != nil {
2640 return context.DeadlineExceeded
2641 }
2642 return nil
2643 }
2644
2645 func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) }
2646 func testTimeoutHandler(t *testing.T, mode testMode) {
2647 sendHi := make(chan bool, 1)
2648 writeErrors := make(chan error, 1)
2649 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2650 <-sendHi
2651 _, werr := w.Write([]byte("hi"))
2652 writeErrors <- werr
2653 })
2654 ctx, cancel := context.WithCancel(context.Background())
2655 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2656 cst := newClientServerTest(t, mode, h)
2657
2658
2659 sendHi <- true
2660 res, err := cst.c.Get(cst.ts.URL)
2661 if err != nil {
2662 t.Error(err)
2663 }
2664 if g, e := res.StatusCode, StatusOK; g != e {
2665 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2666 }
2667 body, _ := io.ReadAll(res.Body)
2668 if g, e := string(body), "hi"; g != e {
2669 t.Errorf("got body %q; expected %q", g, e)
2670 }
2671 if g := <-writeErrors; g != nil {
2672 t.Errorf("got unexpected Write error on first request: %v", g)
2673 }
2674
2675
2676 cancel()
2677
2678 res, err = cst.c.Get(cst.ts.URL)
2679 if err != nil {
2680 t.Error(err)
2681 }
2682 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2683 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2684 }
2685 body, _ = io.ReadAll(res.Body)
2686 if !strings.Contains(string(body), "<title>Timeout</title>") {
2687 t.Errorf("expected timeout body; got %q", string(body))
2688 }
2689 if g, w := res.Header.Get("Content-Type"), "text/html; charset=utf-8"; g != w {
2690 t.Errorf("response content-type = %q; want %q", g, w)
2691 }
2692
2693
2694
2695 sendHi <- true
2696 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2697 t.Errorf("expected Write error of %v; got %v", e, g)
2698 }
2699 }
2700
2701
2702 func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) }
2703 func testTimeoutHandlerRace(t *testing.T, mode testMode) {
2704 delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2705 ms, _ := strconv.Atoi(r.URL.Path[1:])
2706 if ms == 0 {
2707 ms = 1
2708 }
2709 for i := 0; i < ms; i++ {
2710 w.Write([]byte("hi"))
2711 time.Sleep(time.Millisecond)
2712 }
2713 })
2714
2715 ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts
2716
2717 c := ts.Client()
2718
2719 var wg sync.WaitGroup
2720 gate := make(chan bool, 10)
2721 n := 50
2722 if testing.Short() {
2723 n = 10
2724 gate = make(chan bool, 3)
2725 }
2726 for i := 0; i < n; i++ {
2727 gate <- true
2728 wg.Add(1)
2729 go func() {
2730 defer wg.Done()
2731 defer func() { <-gate }()
2732 res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
2733 if err == nil {
2734 io.Copy(io.Discard, res.Body)
2735 res.Body.Close()
2736 }
2737 }()
2738 }
2739 wg.Wait()
2740 }
2741
2742
2743
2744 func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) }
2745 func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) {
2746 delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
2747 w.WriteHeader(204)
2748 })
2749
2750 ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts
2751
2752 var wg sync.WaitGroup
2753 gate := make(chan bool, 50)
2754 n := 500
2755 if testing.Short() {
2756 n = 10
2757 }
2758
2759 c := ts.Client()
2760 for i := 0; i < n; i++ {
2761 gate <- true
2762 wg.Add(1)
2763 go func() {
2764 defer wg.Done()
2765 defer func() { <-gate }()
2766 res, err := c.Get(ts.URL)
2767 if err != nil {
2768
2769
2770 t.Log(err)
2771 return
2772 }
2773 defer res.Body.Close()
2774 io.Copy(io.Discard, res.Body)
2775 }()
2776 }
2777 wg.Wait()
2778 }
2779
2780
2781 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) }
2782 func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) {
2783 sendHi := make(chan bool, 1)
2784 writeErrors := make(chan error, 1)
2785 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2786 w.Header().Set("Content-Type", "text/plain")
2787 <-sendHi
2788 _, werr := w.Write([]byte("hi"))
2789 writeErrors <- werr
2790 })
2791 ctx, cancel := context.WithCancel(context.Background())
2792 h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
2793 cst := newClientServerTest(t, mode, h)
2794
2795
2796 sendHi <- true
2797 res, err := cst.c.Get(cst.ts.URL)
2798 if err != nil {
2799 t.Error(err)
2800 }
2801 if g, e := res.StatusCode, StatusOK; g != e {
2802 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2803 }
2804 body, _ := io.ReadAll(res.Body)
2805 if g, e := string(body), "hi"; g != e {
2806 t.Errorf("got body %q; expected %q", g, e)
2807 }
2808 if g := <-writeErrors; g != nil {
2809 t.Errorf("got unexpected Write error on first request: %v", g)
2810 }
2811
2812
2813 cancel()
2814
2815 res, err = cst.c.Get(cst.ts.URL)
2816 if err != nil {
2817 t.Error(err)
2818 }
2819 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2820 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2821 }
2822 body, _ = io.ReadAll(res.Body)
2823 if !strings.Contains(string(body), "<title>Timeout</title>") {
2824 t.Errorf("expected timeout body; got %q", string(body))
2825 }
2826
2827
2828
2829 sendHi <- true
2830 if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
2831 t.Errorf("expected Write error of %v; got %v", e, g)
2832 }
2833 }
2834
2835
2836 func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
2837 run(t, testTimeoutHandlerStartTimerWhenServing)
2838 }
2839 func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) {
2840 if testing.Short() {
2841 t.Skip("skipping sleeping test in -short mode")
2842 }
2843 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2844 w.WriteHeader(StatusNoContent)
2845 }
2846 timeout := 300 * time.Millisecond
2847 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2848 defer ts.Close()
2849
2850 c := ts.Client()
2851
2852
2853
2854
2855 time.Sleep(2 * timeout)
2856 res, err := c.Get(ts.URL)
2857 if err != nil {
2858 t.Fatal(err)
2859 }
2860 defer res.Body.Close()
2861 if res.StatusCode != StatusNoContent {
2862 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent)
2863 }
2864 }
2865
2866 func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) }
2867 func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) {
2868 writeErrors := make(chan error, 1)
2869 sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
2870 w.Header().Set("Content-Type", "text/plain")
2871 var err error
2872
2873
2874
2875 for i := 0; i < 100; i++ {
2876 _, err = w.Write([]byte("a"))
2877 if err != nil {
2878 break
2879 }
2880 time.Sleep(1 * time.Millisecond)
2881 }
2882 writeErrors <- err
2883 })
2884 ctx, cancel := context.WithCancel(context.Background())
2885 cancel()
2886 h := NewTestTimeoutHandler(sayHi, ctx)
2887 cst := newClientServerTest(t, mode, h)
2888 defer cst.close()
2889
2890 res, err := cst.c.Get(cst.ts.URL)
2891 if err != nil {
2892 t.Error(err)
2893 }
2894 if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
2895 t.Errorf("got res.StatusCode %d; expected %d", g, e)
2896 }
2897 body, _ := io.ReadAll(res.Body)
2898 if g, e := string(body), ""; g != e {
2899 t.Errorf("got body %q; expected %q", g, e)
2900 }
2901 if g, e := <-writeErrors, context.Canceled; g != e {
2902 t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
2903 }
2904 }
2905
2906
2907 func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) }
2908 func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) {
2909 var handler HandlerFunc = func(w ResponseWriter, _ *Request) {
2910
2911 }
2912 timeout := 300 * time.Millisecond
2913 ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts
2914
2915 c := ts.Client()
2916
2917 res, err := c.Get(ts.URL)
2918 if err != nil {
2919 t.Fatal(err)
2920 }
2921 defer res.Body.Close()
2922 if res.StatusCode != StatusOK {
2923 t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK)
2924 }
2925 }
2926
2927
2928 func TestTimeoutHandlerPanicRecovery(t *testing.T) {
2929 wrapper := func(h Handler) Handler {
2930 return TimeoutHandler(h, time.Second, "")
2931 }
2932 run(t, func(t *testing.T, mode testMode) {
2933 testHandlerPanic(t, false, mode, wrapper, "intentional death for testing")
2934 }, testNotParallel, http3SkippedMode)
2935 }
2936
2937 func TestRedirectBadPath(t *testing.T) {
2938
2939
2940 rr := httptest.NewRecorder()
2941 req := &Request{
2942 Method: "GET",
2943 URL: &url.URL{
2944 Scheme: "http",
2945 Path: "not-empty-but-no-leading-slash",
2946 },
2947 }
2948 Redirect(rr, req, "", 304)
2949 if rr.Code != 304 {
2950 t.Errorf("Code = %d; want 304", rr.Code)
2951 }
2952 }
2953
2954 func TestRedirectEscapedPath(t *testing.T) {
2955 baseURL, redirectURL := "http://example.com/foo%2Fbar/", "qux%2Fbaz"
2956 req := httptest.NewRequest("GET", baseURL, NoBody)
2957
2958 rr := httptest.NewRecorder()
2959 Redirect(rr, req, redirectURL, StatusMovedPermanently)
2960
2961 wantURL := "/foo%2Fbar/qux%2Fbaz"
2962 if got := rr.Result().Header.Get("Location"); got != wantURL {
2963 t.Errorf("Redirect(%s, %s) = %s, want = %s", baseURL, redirectURL, got, wantURL)
2964 }
2965 }
2966
2967
2968 func TestRedirect(t *testing.T) {
2969 req, _ := NewRequest("GET", "http://example.com/qux/", nil)
2970
2971 var tests = []struct {
2972 in string
2973 want string
2974 }{
2975
2976 {"http://foobar.com/baz", "http://foobar.com/baz"},
2977
2978 {"https://foobar.com/baz", "https://foobar.com/baz"},
2979
2980 {"test://foobar.com/baz", "test://foobar.com/baz"},
2981
2982 {"//foobar.com/baz", "//foobar.com/baz"},
2983
2984 {"/foobar.com/baz", "/foobar.com/baz"},
2985
2986 {"foobar.com/baz", "/qux/foobar.com/baz"},
2987
2988 {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
2989
2990 {"///foobar.com/baz", "/foobar.com/baz"},
2991
2992
2993 {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"},
2994 {"http://localhost:8080/_ah/login?continue=http://localhost:8080/",
2995 "http://localhost:8080/_ah/login?continue=http://localhost:8080/"},
2996
2997 {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2998 {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"},
2999 }
3000
3001 for _, tt := range tests {
3002 rec := httptest.NewRecorder()
3003 Redirect(rec, req, tt.in, 302)
3004 if got, want := rec.Code, 302; got != want {
3005 t.Errorf("Redirect(%q) generated status code %v; want %v", tt.in, got, want)
3006 }
3007 if got := rec.Header().Get("Location"); got != tt.want {
3008 t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
3009 }
3010 }
3011 }
3012
3013
3014
3015 func TestRedirectContentTypeAndBody(t *testing.T) {
3016 type ctHeader struct {
3017 Values []string
3018 }
3019
3020 var tests = []struct {
3021 method string
3022 ct *ctHeader
3023 wantCT string
3024 wantBody string
3025 }{
3026 {MethodGet, nil, "text/html; charset=utf-8", "<a href=\"/foo\">Found</a>.\n\n"},
3027 {MethodHead, nil, "text/html; charset=utf-8", ""},
3028 {MethodPost, nil, "", ""},
3029 {MethodDelete, nil, "", ""},
3030 {"foo", nil, "", ""},
3031 {MethodGet, &ctHeader{[]string{"application/test"}}, "application/test", ""},
3032 {MethodGet, &ctHeader{[]string{}}, "", ""},
3033 {MethodGet, &ctHeader{nil}, "", ""},
3034 }
3035 for _, tt := range tests {
3036 req := httptest.NewRequest(tt.method, "http://example.com/qux/", nil)
3037 rec := httptest.NewRecorder()
3038 if tt.ct != nil {
3039 rec.Header()["Content-Type"] = tt.ct.Values
3040 }
3041 Redirect(rec, req, "/foo", 302)
3042 if got, want := rec.Code, 302; got != want {
3043 t.Errorf("Redirect(%q, %#v) generated status code %v; want %v", tt.method, tt.ct, got, want)
3044 }
3045 if got, want := rec.Header().Get("Content-Type"), tt.wantCT; got != want {
3046 t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want)
3047 }
3048 resp := rec.Result()
3049 body, err := io.ReadAll(resp.Body)
3050 if err != nil {
3051 t.Fatal(err)
3052 }
3053 if got, want := string(body), tt.wantBody; got != want {
3054 t.Errorf("Redirect(%q, %#v) generated Body %q; want %q", tt.method, tt.ct, got, want)
3055 }
3056 }
3057 }
3058
3059
3060
3061
3062
3063
3064
3065 func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) }
3066
3067 func testZeroLengthPostAndResponse(t *testing.T, mode testMode) {
3068 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
3069 all, err := io.ReadAll(r.Body)
3070 if err != nil {
3071 t.Fatalf("handler ReadAll: %v", err)
3072 }
3073 if len(all) != 0 {
3074 t.Errorf("handler got %d bytes; expected 0", len(all))
3075 }
3076 rw.Header().Set("Content-Length", "0")
3077 }))
3078
3079 req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
3080 if err != nil {
3081 t.Fatal(err)
3082 }
3083 req.ContentLength = 0
3084
3085 var resp [5]*Response
3086 for i := range resp {
3087 resp[i], err = cst.c.Do(req)
3088 if err != nil {
3089 t.Fatalf("client post #%d: %v", i, err)
3090 }
3091 }
3092
3093 for i := range resp {
3094 all, err := io.ReadAll(resp[i].Body)
3095 if err != nil {
3096 t.Fatalf("req #%d: client ReadAll: %v", i, err)
3097 }
3098 if len(all) != 0 {
3099 t.Errorf("req #%d: client got %d bytes; expected 0", i, len(all))
3100 }
3101 }
3102 }
3103
3104 func TestHandlerPanicNil(t *testing.T) {
3105 run(t, func(t *testing.T, mode testMode) {
3106 testHandlerPanic(t, false, mode, nil, nil)
3107 }, testNotParallel, http3SkippedMode)
3108 }
3109
3110 func TestHandlerPanic(t *testing.T) {
3111 run(t, func(t *testing.T, mode testMode) {
3112 testHandlerPanic(t, false, mode, nil, "intentional death for testing")
3113 }, testNotParallel, http3SkippedMode)
3114 }
3115
3116 func TestHandlerPanicWithHijack(t *testing.T) {
3117
3118 run(t, func(t *testing.T, mode testMode) {
3119 testHandlerPanic(t, true, mode, nil, "intentional death for testing")
3120 }, []testMode{http1Mode})
3121 }
3122
3123 func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) {
3124
3125
3126
3127
3128
3129
3130
3131
3132 pr, pw := io.Pipe()
3133 defer pw.Close()
3134
3135 var handler Handler = HandlerFunc(func(w ResponseWriter, r *Request) {
3136 if withHijack {
3137 rwc, _, err := w.(Hijacker).Hijack()
3138 if err != nil {
3139 t.Logf("unexpected error: %v", err)
3140 }
3141 defer rwc.Close()
3142 }
3143 panic(panicValue)
3144 })
3145 if wrapper != nil {
3146 handler = wrapper(handler)
3147 }
3148 cst := newClientServerTest(t, mode, handler, func(ts *httptest.Server) {
3149 ts.Config.ErrorLog = log.New(pw, "", 0)
3150 })
3151
3152
3153 done := make(chan bool, 1)
3154 go func() {
3155 buf := make([]byte, 4<<10)
3156 _, err := pr.Read(buf)
3157 pr.Close()
3158 if err != nil && err != io.EOF {
3159 t.Error(err)
3160 }
3161 done <- true
3162 }()
3163
3164 _, err := cst.c.Get(cst.ts.URL)
3165 if err == nil {
3166 t.Logf("expected an error")
3167 }
3168
3169 if panicValue == nil {
3170 return
3171 }
3172
3173 <-done
3174 }
3175
3176 type terrorWriter struct{ t *testing.T }
3177
3178 func (w terrorWriter) Write(p []byte) (int, error) {
3179 w.t.Errorf("%s", p)
3180 return len(p), nil
3181 }
3182
3183
3184
3185 func TestServerWriteHijackZeroBytes(t *testing.T) {
3186 run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode})
3187 }
3188 func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) {
3189 done := make(chan struct{})
3190 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3191 defer close(done)
3192 w.(Flusher).Flush()
3193 conn, _, err := w.(Hijacker).Hijack()
3194 if err != nil {
3195 t.Errorf("Hijack: %v", err)
3196 return
3197 }
3198 defer conn.Close()
3199 _, err = w.Write(nil)
3200 if err != ErrHijacked {
3201 t.Errorf("Write error = %v; want ErrHijacked", err)
3202 }
3203 }), func(ts *httptest.Server) {
3204 ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0)
3205 }).ts
3206
3207 c := ts.Client()
3208 res, err := c.Get(ts.URL)
3209 if err != nil {
3210 t.Fatal(err)
3211 }
3212 res.Body.Close()
3213 <-done
3214 }
3215
3216 func TestServerNoDate(t *testing.T) {
3217 run(t, func(t *testing.T, mode testMode) {
3218 testServerNoHeader(t, mode, "Date")
3219 })
3220 }
3221
3222 func TestServerContentType(t *testing.T) {
3223 run(t, func(t *testing.T, mode testMode) {
3224 testServerNoHeader(t, mode, "Content-Type")
3225 })
3226 }
3227
3228 func testServerNoHeader(t *testing.T, mode testMode, header string) {
3229 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3230 w.Header()[header] = nil
3231 io.WriteString(w, "<html>foo</html>")
3232 }))
3233 res, err := cst.c.Get(cst.ts.URL)
3234 if err != nil {
3235 t.Fatal(err)
3236 }
3237 res.Body.Close()
3238 if got, ok := res.Header[header]; ok {
3239 t.Fatalf("Expected no %s header; got %q", header, got)
3240 }
3241 }
3242
3243 func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) }
3244 func testStripPrefix(t *testing.T, mode testMode) {
3245 h := HandlerFunc(func(w ResponseWriter, r *Request) {
3246 w.Header().Set("X-Path", r.URL.Path)
3247 w.Header().Set("X-RawPath", r.URL.RawPath)
3248 })
3249 ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts
3250
3251 c := ts.Client()
3252
3253 cases := []struct {
3254 reqPath string
3255 path string
3256 rawPath string
3257 }{
3258 {"/foo/bar/qux", "/qux", ""},
3259 {"/foo/bar%2Fqux", "/qux", "%2Fqux"},
3260 {"/foo%2Fbar/qux", "", ""},
3261 {"/bar", "", ""},
3262 }
3263 for _, tc := range cases {
3264 t.Run(tc.reqPath, func(t *testing.T) {
3265 res, err := c.Get(ts.URL + tc.reqPath)
3266 if err != nil {
3267 t.Fatal(err)
3268 }
3269 res.Body.Close()
3270 if tc.path == "" {
3271 if res.StatusCode != StatusNotFound {
3272 t.Errorf("got %q, want 404 Not Found", res.Status)
3273 }
3274 return
3275 }
3276 if res.StatusCode != StatusOK {
3277 t.Fatalf("got %q, want 200 OK", res.Status)
3278 }
3279 if g, w := res.Header.Get("X-Path"), tc.path; g != w {
3280 t.Errorf("got Path %q, want %q", g, w)
3281 }
3282 if g, w := res.Header.Get("X-RawPath"), tc.rawPath; g != w {
3283 t.Errorf("got RawPath %q, want %q", g, w)
3284 }
3285 })
3286 }
3287 }
3288
3289
3290 func TestStripPrefixNotModifyRequest(t *testing.T) {
3291 h := StripPrefix("/foo", NotFoundHandler())
3292 req := httptest.NewRequest("GET", "/foo/bar", nil)
3293 h.ServeHTTP(httptest.NewRecorder(), req)
3294 if req.URL.Path != "/foo/bar" {
3295 t.Errorf("StripPrefix should not modify the provided Request, but it did")
3296 }
3297 }
3298
3299 func TestRequestLimit(t *testing.T) { run(t, testRequestLimit, http3SkippedMode) }
3300 func testRequestLimit(t *testing.T, mode testMode) {
3301 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3302 t.Fatalf("didn't expect to get request in Handler")
3303 }), optQuietLog)
3304 req, _ := NewRequest("GET", cst.ts.URL, nil)
3305 var bytesPerHeader = len("header12345: val12345\r\n")
3306 for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
3307 req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
3308 }
3309 res, err := cst.c.Do(req)
3310 if res != nil {
3311 defer res.Body.Close()
3312 }
3313 if mode == http2Mode {
3314
3315
3316
3317
3318 if err == nil && res.StatusCode != 431 {
3319 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3320 }
3321 } else {
3322
3323
3324
3325
3326 if err != nil {
3327 t.Fatalf("Do: %v", err)
3328 }
3329 if res.StatusCode != 431 {
3330 t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
3331 }
3332 }
3333 }
3334
3335 type neverEnding byte
3336
3337 func (b neverEnding) Read(p []byte) (n int, err error) {
3338 for i := range p {
3339 p[i] = byte(b)
3340 }
3341 return len(p), nil
3342 }
3343
3344 type bodyLimitReader struct {
3345 mu sync.Mutex
3346 count int
3347 limit int
3348 closed chan struct{}
3349 }
3350
3351 func (r *bodyLimitReader) Read(p []byte) (int, error) {
3352 r.mu.Lock()
3353 defer r.mu.Unlock()
3354 select {
3355 case <-r.closed:
3356 return 0, errors.New("closed")
3357 default:
3358 }
3359 if r.count > r.limit {
3360 return 0, errors.New("at limit")
3361 }
3362 r.count += len(p)
3363 for i := range p {
3364 p[i] = 'a'
3365 }
3366 return len(p), nil
3367 }
3368
3369 func (r *bodyLimitReader) Close() error {
3370 r.mu.Lock()
3371 defer r.mu.Unlock()
3372 close(r.closed)
3373 return nil
3374 }
3375
3376 func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) }
3377 func testRequestBodyLimit(t *testing.T, mode testMode) {
3378 const limit = 1 << 20
3379 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3380 r.Body = MaxBytesReader(w, r.Body, limit)
3381 n, err := io.Copy(io.Discard, r.Body)
3382 if err == nil {
3383 t.Errorf("expected error from io.Copy")
3384 }
3385 if n != limit {
3386 t.Errorf("io.Copy = %d, want %d", n, limit)
3387 }
3388 mbErr, ok := err.(*MaxBytesError)
3389 if !ok {
3390 t.Errorf("expected MaxBytesError, got %T", err)
3391 }
3392 if mbErr.Limit != limit {
3393 t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit)
3394 }
3395 }))
3396
3397 body := &bodyLimitReader{
3398 closed: make(chan struct{}),
3399 limit: limit * 200,
3400 }
3401 req, _ := NewRequest("POST", cst.ts.URL, body)
3402
3403
3404
3405
3406
3407
3408
3409
3410
3411
3412 resp, err := cst.c.Do(req)
3413 if err == nil {
3414 resp.Body.Close()
3415 }
3416
3417
3418 <-body.closed
3419
3420 if body.count > limit*100 {
3421 t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
3422 limit, body.count)
3423 }
3424 }
3425
3426
3427
3428 func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown, http3SkippedMode) }
3429 func testClientWriteShutdown(t *testing.T, mode testMode) {
3430 if runtime.GOOS == "plan9" {
3431 t.Skip("skipping test; see https://golang.org/issue/17906")
3432 }
3433 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
3434 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3435 if err != nil {
3436 t.Fatalf("Dial: %v", err)
3437 }
3438 err = conn.(*net.TCPConn).CloseWrite()
3439 if err != nil {
3440 t.Fatalf("CloseWrite: %v", err)
3441 }
3442
3443 bs, err := io.ReadAll(conn)
3444 if err != nil {
3445 t.Errorf("ReadAll: %v", err)
3446 }
3447 got := string(bs)
3448 if got != "" {
3449 t.Errorf("read %q from server; want nothing", got)
3450 }
3451 }
3452
3453
3454
3455 func TestServerBufferedChunking(t *testing.T) {
3456 conn := new(testConn)
3457 conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
3458 conn.closec = make(chan bool, 1)
3459 ls := &oneConnListener{conn}
3460 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
3461 rw.(Flusher).Flush()
3462 rw.Write([]byte{'x'})
3463 rw.Write([]byte{'y'})
3464 rw.Write([]byte{'z'})
3465 }))
3466 <-conn.closec
3467 if !bytes.HasSuffix(conn.writeBuf.Bytes(), []byte("\r\n\r\n3\r\nxyz\r\n0\r\n\r\n")) {
3468 t.Errorf("response didn't end with a single 3 byte 'xyz' chunk; got:\n%q",
3469 conn.writeBuf.Bytes())
3470 }
3471 }
3472
3473
3474
3475
3476
3477 func TestServerGracefulClose(t *testing.T) {
3478
3479 run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
3480 }
3481 func testServerGracefulClose(t *testing.T, mode testMode) {
3482 runTimeSensitiveTest(t, []time.Duration{
3483 1 * time.Millisecond,
3484 5 * time.Millisecond,
3485 10 * time.Millisecond,
3486 50 * time.Millisecond,
3487 100 * time.Millisecond,
3488 500 * time.Millisecond,
3489 time.Second,
3490 5 * time.Second,
3491 }, func(t *testing.T, timeout time.Duration) error {
3492 SetRSTAvoidanceDelay(t, timeout)
3493 t.Logf("set RST avoidance delay to %v", timeout)
3494
3495 const bodySize = 5 << 20
3496 req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
3497 for i := 0; i < bodySize; i++ {
3498 req = append(req, 'x')
3499 }
3500
3501 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3502 Error(w, "bye", StatusUnauthorized)
3503 }))
3504
3505
3506 defer cst.close()
3507 ts := cst.ts
3508
3509 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3510 if err != nil {
3511 return err
3512 }
3513 writeErr := make(chan error)
3514 go func() {
3515 _, err := conn.Write(req)
3516 writeErr <- err
3517 }()
3518 defer func() {
3519 conn.Close()
3520
3521
3522
3523 <-writeErr
3524 }()
3525
3526 br := bufio.NewReader(conn)
3527 lineNum := 0
3528 for {
3529 line, err := br.ReadString('\n')
3530 if err == io.EOF {
3531 break
3532 }
3533 if err != nil {
3534 return fmt.Errorf("ReadLine: %v", err)
3535 }
3536 lineNum++
3537 if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
3538 t.Errorf("Response line = %q; want a 401", line)
3539 }
3540 }
3541 return nil
3542 })
3543 }
3544
3545 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
3546 func testCaseSensitiveMethod(t *testing.T, mode testMode) {
3547 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3548 if r.Method != "get" {
3549 t.Errorf(`Got method %q; want "get"`, r.Method)
3550 }
3551 }))
3552 defer cst.close()
3553 req, _ := NewRequest("get", cst.ts.URL, nil)
3554 res, err := cst.c.Do(req)
3555 if err != nil {
3556 t.Error(err)
3557 return
3558 }
3559
3560 res.Body.Close()
3561 }
3562
3563
3564
3565
3566
3567 func TestContentLengthZero(t *testing.T) {
3568 run(t, testContentLengthZero, []testMode{http1Mode})
3569 }
3570 func testContentLengthZero(t *testing.T, mode testMode) {
3571 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts
3572
3573 for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} {
3574 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3575 if err != nil {
3576 t.Fatalf("error dialing: %v", err)
3577 }
3578 _, err = fmt.Fprintf(conn, "GET / %v\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n", version)
3579 if err != nil {
3580 t.Fatalf("error writing: %v", err)
3581 }
3582 req, _ := NewRequest("GET", "/", nil)
3583 res, err := ReadResponse(bufio.NewReader(conn), req)
3584 if err != nil {
3585 t.Fatalf("error reading response: %v", err)
3586 }
3587 if te := res.TransferEncoding; len(te) > 0 {
3588 t.Errorf("For version %q, Transfer-Encoding = %q; want none", version, te)
3589 }
3590 if cl := res.ContentLength; cl != 0 {
3591 t.Errorf("For version %q, Content-Length = %v; want 0", version, cl)
3592 }
3593 conn.Close()
3594 }
3595 }
3596
3597 func TestCloseNotifier(t *testing.T) {
3598 run(t, testCloseNotifier, []testMode{http1Mode})
3599 }
3600 func testCloseNotifier(t *testing.T, mode testMode) {
3601 gotReq := make(chan bool, 1)
3602 sawClose := make(chan bool, 1)
3603 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3604 gotReq <- true
3605 cc := rw.(CloseNotifier).CloseNotify()
3606 <-cc
3607 sawClose <- true
3608 })).ts
3609 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3610 if err != nil {
3611 t.Fatalf("error dialing: %v", err)
3612 }
3613 diec := make(chan bool)
3614 go func() {
3615 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
3616 if err != nil {
3617 t.Error(err)
3618 return
3619 }
3620 <-diec
3621 conn.Close()
3622 }()
3623 For:
3624 for {
3625 select {
3626 case <-gotReq:
3627 diec <- true
3628 case <-sawClose:
3629 break For
3630 }
3631 }
3632 ts.Close()
3633 }
3634
3635
3636
3637
3638
3639 func TestCloseNotifierPipelined(t *testing.T) {
3640 run(t, testCloseNotifierPipelined, []testMode{http1Mode})
3641 }
3642 func testCloseNotifierPipelined(t *testing.T, mode testMode) {
3643 gotReq := make(chan bool, 2)
3644 sawClose := make(chan bool, 2)
3645 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
3646 gotReq <- true
3647 cc := rw.(CloseNotifier).CloseNotify()
3648 select {
3649 case <-cc:
3650 t.Error("unexpected CloseNotify")
3651 case <-time.After(100 * time.Millisecond):
3652 }
3653 sawClose <- true
3654 })).ts
3655 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3656 if err != nil {
3657 t.Fatalf("error dialing: %v", err)
3658 }
3659 diec := make(chan bool, 1)
3660 defer close(diec)
3661 go func() {
3662 const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
3663 _, err = io.WriteString(conn, req+req)
3664 if err != nil {
3665 t.Error(err)
3666 return
3667 }
3668 <-diec
3669 conn.Close()
3670 }()
3671 reqs := 0
3672 closes := 0
3673 for {
3674 select {
3675 case <-gotReq:
3676 reqs++
3677 if reqs > 2 {
3678 t.Fatal("too many requests")
3679 }
3680 case <-sawClose:
3681 closes++
3682 if closes > 1 {
3683 return
3684 }
3685 }
3686 }
3687 }
3688
3689 func TestCloseNotifierChanLeak(t *testing.T) {
3690 defer afterTest(t)
3691 req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
3692 for i := 0; i < 20; i++ {
3693 var output bytes.Buffer
3694 conn := &rwTestConn{
3695 Reader: bytes.NewReader(req),
3696 Writer: &output,
3697 closec: make(chan bool, 1),
3698 }
3699 ln := &oneConnListener{conn: conn}
3700 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
3701
3702
3703
3704 _ = rw.(CloseNotifier).CloseNotify()
3705 })
3706 go Serve(ln, handler)
3707 <-conn.closec
3708 }
3709 }
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720 func TestHijackAfterCloseNotifier(t *testing.T) {
3721 run(t, testHijackAfterCloseNotifier, []testMode{http1Mode})
3722 }
3723 func testHijackAfterCloseNotifier(t *testing.T, mode testMode) {
3724 script := make(chan string, 2)
3725 script <- "closenotify"
3726 script <- "hijack"
3727 close(script)
3728 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3729 plan := <-script
3730 switch plan {
3731 default:
3732 panic("bogus plan; too many requests")
3733 case "closenotify":
3734 w.(CloseNotifier).CloseNotify()
3735 w.Header().Set("X-Addr", r.RemoteAddr)
3736 case "hijack":
3737 c, _, err := w.(Hijacker).Hijack()
3738 if err != nil {
3739 t.Errorf("Hijack in Handler: %v", err)
3740 return
3741 }
3742 if _, ok := c.(*net.TCPConn); !ok {
3743
3744
3745 t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
3746 }
3747 fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
3748 c.Close()
3749 return
3750 }
3751 })).ts
3752 res1, err := ts.Client().Get(ts.URL)
3753 if err != nil {
3754 log.Fatal(err)
3755 }
3756 res2, err := ts.Client().Get(ts.URL)
3757 if err != nil {
3758 log.Fatal(err)
3759 }
3760 addr1 := res1.Header.Get("X-Addr")
3761 addr2 := res2.Header.Get("X-Addr")
3762 if addr1 == "" || addr1 != addr2 {
3763 t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
3764 }
3765 }
3766
3767 func TestHijackBeforeRequestBodyRead(t *testing.T) {
3768 run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode})
3769 }
3770 func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) {
3771 var requestBody = bytes.Repeat([]byte("a"), 1<<20)
3772 bodyOkay := make(chan bool, 1)
3773 gotCloseNotify := make(chan bool, 1)
3774 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3775 defer close(bodyOkay)
3776
3777 reqBody := r.Body
3778 r.Body = nil
3779
3780 gone := w.(CloseNotifier).CloseNotify()
3781 slurp, err := io.ReadAll(reqBody)
3782 if err != nil {
3783 t.Errorf("Body read: %v", err)
3784 return
3785 }
3786 if len(slurp) != len(requestBody) {
3787 t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
3788 return
3789 }
3790 if !bytes.Equal(slurp, requestBody) {
3791 t.Error("Backend read wrong request body.")
3792 return
3793 }
3794 bodyOkay <- true
3795 <-gone
3796 gotCloseNotify <- true
3797 })).ts
3798
3799 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3800 if err != nil {
3801 t.Fatal(err)
3802 }
3803 defer conn.Close()
3804
3805 fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
3806 len(requestBody), requestBody)
3807 if !<-bodyOkay {
3808
3809 return
3810 }
3811 conn.Close()
3812 <-gotCloseNotify
3813 }
3814
3815 func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) }
3816 func testOptions(t *testing.T, mode testMode) {
3817 uric := make(chan string, 2)
3818 mux := NewServeMux()
3819 mux.HandleFunc("/", func(w ResponseWriter, r *Request) {
3820 uric <- r.RequestURI
3821 })
3822 ts := newClientServerTest(t, mode, mux).ts
3823
3824 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3825 if err != nil {
3826 t.Fatal(err)
3827 }
3828 defer conn.Close()
3829
3830
3831 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3832 if err != nil {
3833 t.Fatal(err)
3834 }
3835 br := bufio.NewReader(conn)
3836 res, err := ReadResponse(br, &Request{Method: "OPTIONS"})
3837 if err != nil {
3838 t.Fatal(err)
3839 }
3840 if res.StatusCode != 200 {
3841 t.Errorf("Got non-200 response to OPTIONS *: %#v", res)
3842 }
3843
3844
3845 _, err = conn.Write([]byte("GET * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3846 if err != nil {
3847 t.Fatal(err)
3848 }
3849 res, err = ReadResponse(br, &Request{Method: "GET"})
3850 if err != nil {
3851 t.Fatal(err)
3852 }
3853 if res.StatusCode != 400 {
3854 t.Errorf("Got non-400 response to GET *: %#v", res)
3855 }
3856
3857 res, err = Get(ts.URL + "/second")
3858 if err != nil {
3859 t.Fatal(err)
3860 }
3861 res.Body.Close()
3862 if got := <-uric; got != "/second" {
3863 t.Errorf("Handler saw request for %q; want /second", got)
3864 }
3865 }
3866
3867 func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) }
3868 func testOptionsHandler(t *testing.T, mode testMode) {
3869 rc := make(chan *Request, 1)
3870
3871 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
3872 rc <- r
3873 }), func(ts *httptest.Server) {
3874 ts.Config.DisableGeneralOptionsHandler = true
3875 }).ts
3876
3877 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
3878 if err != nil {
3879 t.Fatal(err)
3880 }
3881 defer conn.Close()
3882
3883 _, err = conn.Write([]byte("OPTIONS * HTTP/1.1\r\nHost: foo.com\r\n\r\n"))
3884 if err != nil {
3885 t.Fatal(err)
3886 }
3887
3888 if got := <-rc; got.Method != "OPTIONS" || got.RequestURI != "*" {
3889 t.Errorf("Expected OPTIONS * request, got %v", got)
3890 }
3891 }
3892
3893
3894
3895
3896
3897
3898
3899
3900
3901
3902 func TestHeaderToWire(t *testing.T) {
3903 tests := []struct {
3904 name string
3905 handler func(ResponseWriter, *Request)
3906 check func(got, logs string) error
3907 }{
3908 {
3909 name: "write without Header",
3910 handler: func(rw ResponseWriter, r *Request) {
3911 rw.Write([]byte("hello world"))
3912 },
3913 check: func(got, logs string) error {
3914 if !strings.Contains(got, "Content-Length:") {
3915 return errors.New("no content-length")
3916 }
3917 if !strings.Contains(got, "Content-Type: text/plain") {
3918 return errors.New("no content-type")
3919 }
3920 return nil
3921 },
3922 },
3923 {
3924 name: "Header mutation before write",
3925 handler: func(rw ResponseWriter, r *Request) {
3926 h := rw.Header()
3927 h.Set("Content-Type", "some/type")
3928 rw.Write([]byte("hello world"))
3929 h.Set("Too-Late", "bogus")
3930 },
3931 check: func(got, logs string) error {
3932 if !strings.Contains(got, "Content-Length:") {
3933 return errors.New("no content-length")
3934 }
3935 if !strings.Contains(got, "Content-Type: some/type") {
3936 return errors.New("wrong content-type")
3937 }
3938 if strings.Contains(got, "Too-Late") {
3939 return errors.New("don't want too-late header")
3940 }
3941 return nil
3942 },
3943 },
3944 {
3945 name: "write then useless Header mutation",
3946 handler: func(rw ResponseWriter, r *Request) {
3947 rw.Write([]byte("hello world"))
3948 rw.Header().Set("Too-Late", "Write already wrote headers")
3949 },
3950 check: func(got, logs string) error {
3951 if strings.Contains(got, "Too-Late") {
3952 return errors.New("header appeared from after WriteHeader")
3953 }
3954 return nil
3955 },
3956 },
3957 {
3958 name: "flush then write",
3959 handler: func(rw ResponseWriter, r *Request) {
3960 rw.(Flusher).Flush()
3961 rw.Write([]byte("post-flush"))
3962 rw.Header().Set("Too-Late", "Write already wrote headers")
3963 },
3964 check: func(got, logs string) error {
3965 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3966 return errors.New("not chunked")
3967 }
3968 if strings.Contains(got, "Too-Late") {
3969 return errors.New("header appeared from after WriteHeader")
3970 }
3971 return nil
3972 },
3973 },
3974 {
3975 name: "header then flush",
3976 handler: func(rw ResponseWriter, r *Request) {
3977 rw.Header().Set("Content-Type", "some/type")
3978 rw.(Flusher).Flush()
3979 rw.Write([]byte("post-flush"))
3980 rw.Header().Set("Too-Late", "Write already wrote headers")
3981 },
3982 check: func(got, logs string) error {
3983 if !strings.Contains(got, "Transfer-Encoding: chunked") {
3984 return errors.New("not chunked")
3985 }
3986 if strings.Contains(got, "Too-Late") {
3987 return errors.New("header appeared from after WriteHeader")
3988 }
3989 if !strings.Contains(got, "Content-Type: some/type") {
3990 return errors.New("wrong content-type")
3991 }
3992 return nil
3993 },
3994 },
3995 {
3996 name: "sniff-on-first-write content-type",
3997 handler: func(rw ResponseWriter, r *Request) {
3998 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
3999 rw.Header().Set("Content-Type", "x/wrong")
4000 },
4001 check: func(got, logs string) error {
4002 if !strings.Contains(got, "Content-Type: text/html") {
4003 return errors.New("wrong content-type; want html")
4004 }
4005 return nil
4006 },
4007 },
4008 {
4009 name: "explicit content-type wins",
4010 handler: func(rw ResponseWriter, r *Request) {
4011 rw.Header().Set("Content-Type", "some/type")
4012 rw.Write([]byte("<html><head></head><body>some html</body></html>"))
4013 },
4014 check: func(got, logs string) error {
4015 if !strings.Contains(got, "Content-Type: some/type") {
4016 return errors.New("wrong content-type; want html")
4017 }
4018 return nil
4019 },
4020 },
4021 {
4022 name: "empty handler",
4023 handler: func(rw ResponseWriter, r *Request) {
4024 },
4025 check: func(got, logs string) error {
4026 if !strings.Contains(got, "Content-Length: 0") {
4027 return errors.New("want 0 content-length")
4028 }
4029 return nil
4030 },
4031 },
4032 {
4033 name: "only Header, no write",
4034 handler: func(rw ResponseWriter, r *Request) {
4035 rw.Header().Set("Some-Header", "some-value")
4036 },
4037 check: func(got, logs string) error {
4038 if !strings.Contains(got, "Some-Header") {
4039 return errors.New("didn't get header")
4040 }
4041 return nil
4042 },
4043 },
4044 {
4045 name: "WriteHeader call",
4046 handler: func(rw ResponseWriter, r *Request) {
4047 rw.WriteHeader(404)
4048 rw.Header().Set("Too-Late", "some-value")
4049 },
4050 check: func(got, logs string) error {
4051 if !strings.Contains(got, "404") {
4052 return errors.New("wrong status")
4053 }
4054 if strings.Contains(got, "Too-Late") {
4055 return errors.New("shouldn't have seen Too-Late")
4056 }
4057 return nil
4058 },
4059 },
4060 }
4061 for _, tc := range tests {
4062 ht := newHandlerTest(HandlerFunc(tc.handler))
4063 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
4064 logs := ht.logbuf.String()
4065 if err := tc.check(got, logs); err != nil {
4066 t.Errorf("%s: %v\nGot response:\n%s\n\n%s", tc.name, err, got, logs)
4067 }
4068 }
4069 }
4070
4071 type errorListener struct {
4072 errs []error
4073 }
4074
4075 func (l *errorListener) Accept() (c net.Conn, err error) {
4076 if len(l.errs) == 0 {
4077 return nil, io.EOF
4078 }
4079 err = l.errs[0]
4080 l.errs = l.errs[1:]
4081 return
4082 }
4083
4084 func (l *errorListener) Close() error {
4085 return nil
4086 }
4087
4088 func (l *errorListener) Addr() net.Addr {
4089 return dummyAddr("test-address")
4090 }
4091
4092 func TestAcceptMaxFds(t *testing.T) {
4093 setParallel(t)
4094
4095 ln := &errorListener{[]error{
4096 &net.OpError{
4097 Op: "accept",
4098 Err: syscall.EMFILE,
4099 }}}
4100 server := &Server{
4101 Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})),
4102 ErrorLog: log.New(io.Discard, "", 0),
4103 }
4104 err := server.Serve(ln)
4105 if err != io.EOF {
4106 t.Errorf("got error %v, want EOF", err)
4107 }
4108 }
4109
4110 func TestWriteAfterHijack(t *testing.T) {
4111 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4112 var buf strings.Builder
4113 wrotec := make(chan bool, 1)
4114 conn := &rwTestConn{
4115 Reader: bytes.NewReader(req),
4116 Writer: &buf,
4117 closec: make(chan bool, 1),
4118 }
4119 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4120 conn, bufrw, err := rw.(Hijacker).Hijack()
4121 if err != nil {
4122 t.Error(err)
4123 return
4124 }
4125 go func() {
4126 bufrw.Write([]byte("[hijack-to-bufw]"))
4127 bufrw.Flush()
4128 conn.Write([]byte("[hijack-to-conn]"))
4129 conn.Close()
4130 wrotec <- true
4131 }()
4132 })
4133 ln := &oneConnListener{conn: conn}
4134 go Serve(ln, handler)
4135 <-conn.closec
4136 <-wrotec
4137 if g, w := buf.String(), "[hijack-to-bufw][hijack-to-conn]"; g != w {
4138 t.Errorf("wrote %q; want %q", g, w)
4139 }
4140 }
4141
4142 func TestDoubleHijack(t *testing.T) {
4143 req := reqBytes("GET / HTTP/1.1\nHost: golang.org")
4144 var buf bytes.Buffer
4145 conn := &rwTestConn{
4146 Reader: bytes.NewReader(req),
4147 Writer: &buf,
4148 closec: make(chan bool, 1),
4149 }
4150 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
4151 conn, _, err := rw.(Hijacker).Hijack()
4152 if err != nil {
4153 t.Error(err)
4154 return
4155 }
4156 _, _, err = rw.(Hijacker).Hijack()
4157 if err == nil {
4158 t.Errorf("got err = nil; want err != nil")
4159 }
4160 conn.Close()
4161 })
4162 ln := &oneConnListener{conn: conn}
4163 go Serve(ln, handler)
4164 <-conn.closec
4165 }
4166
4167
4168
4169
4170
4171
4172
4173 func TestHTTP10ConnectionHeader(t *testing.T) {
4174 run(t, testHTTP10ConnectionHeader, []testMode{http1Mode})
4175 }
4176 func testHTTP10ConnectionHeader(t *testing.T, mode testMode) {
4177 mux := NewServeMux()
4178 mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
4179 ts := newClientServerTest(t, mode, mux).ts
4180
4181
4182 tests := []struct {
4183 req string
4184 expect []string
4185 }{
4186 {
4187 req: "GET / HTTP/1.0\r\n\r\n",
4188 expect: nil,
4189 },
4190 {
4191 req: "OPTIONS * HTTP/1.0\r\n\r\n",
4192 expect: nil,
4193 },
4194 {
4195 req: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n",
4196 expect: []string{"keep-alive"},
4197 },
4198 }
4199
4200 for _, tt := range tests {
4201 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
4202 if err != nil {
4203 t.Fatal("dial err:", err)
4204 }
4205
4206 _, err = fmt.Fprint(conn, tt.req)
4207 if err != nil {
4208 t.Fatal("conn write err:", err)
4209 }
4210
4211 resp, err := ReadResponse(bufio.NewReader(conn), &Request{Method: "GET"})
4212 if err != nil {
4213 t.Fatal("ReadResponse err:", err)
4214 }
4215 conn.Close()
4216 resp.Body.Close()
4217
4218 got := resp.Header["Connection"]
4219 if !slices.Equal(got, tt.expect) {
4220 t.Errorf("wrong Connection headers for request %q. Got %q expect %q", tt.req, got, tt.expect)
4221 }
4222 }
4223 }
4224
4225
4226 func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) }
4227 func testServerReaderFromOrder(t *testing.T, mode testMode) {
4228 pr, pw := io.Pipe()
4229 const size = 3 << 20
4230 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4231 rw.Header().Set("Content-Type", "text/plain")
4232 done := make(chan bool)
4233 go func() {
4234 io.Copy(rw, pr)
4235 close(done)
4236 }()
4237 time.Sleep(25 * time.Millisecond)
4238 n, err := io.Copy(io.Discard, req.Body)
4239 if err != nil {
4240 t.Errorf("handler Copy: %v", err)
4241 return
4242 }
4243 if n != size {
4244 t.Errorf("handler Copy = %d; want %d", n, size)
4245 }
4246 pw.Write([]byte("hi"))
4247 pw.Close()
4248 <-done
4249 }))
4250
4251 req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
4252 if err != nil {
4253 t.Fatal(err)
4254 }
4255 res, err := cst.c.Do(req)
4256 if err != nil {
4257 t.Fatal(err)
4258 }
4259 all, err := io.ReadAll(res.Body)
4260 if err != nil {
4261 t.Fatal(err)
4262 }
4263 res.Body.Close()
4264 if string(all) != "hi" {
4265 t.Errorf("Body = %q; want hi", all)
4266 }
4267 }
4268
4269
4270 func TestCodesPreventingContentTypeAndBody(t *testing.T) {
4271 for _, code := range []int{StatusNotModified, StatusNoContent} {
4272 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4273 if r.URL.Path == "/header" {
4274 w.Header().Set("Content-Length", "123")
4275 }
4276 w.WriteHeader(code)
4277 if r.URL.Path == "/more" {
4278 w.Write([]byte("stuff"))
4279 }
4280 }))
4281 for _, req := range []string{
4282 "GET / HTTP/1.0",
4283 "GET /header HTTP/1.0",
4284 "GET /more HTTP/1.0",
4285 "GET / HTTP/1.1\nHost: foo",
4286 "GET /header HTTP/1.1\nHost: foo",
4287 "GET /more HTTP/1.1\nHost: foo",
4288 } {
4289 got := ht.rawResponse(req)
4290 wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
4291 if !strings.Contains(got, wantStatus) {
4292 t.Errorf("Code %d: Wanted %q Modified for %q: %s", code, wantStatus, req, got)
4293 } else if strings.Contains(got, "Content-Length") {
4294 t.Errorf("Code %d: Got a Content-Length from %q: %s", code, req, got)
4295 } else if strings.Contains(got, "stuff") {
4296 t.Errorf("Code %d: Response contains a body from %q: %s", code, req, got)
4297 }
4298 }
4299 }
4300 }
4301
4302 func TestContentTypeOkayOn204(t *testing.T) {
4303 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4304 w.Header().Set("Content-Length", "123")
4305 w.Header().Set("Content-Type", "foo/bar")
4306 w.WriteHeader(204)
4307 }))
4308 got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
4309 if !strings.Contains(got, "Content-Type: foo/bar") {
4310 t.Errorf("Response = %q; want Content-Type: foo/bar", got)
4311 }
4312 if strings.Contains(got, "Content-Length: 123") {
4313 t.Errorf("Response = %q; don't want a Content-Length", got)
4314 }
4315 }
4316
4317
4318
4319
4320
4321
4322
4323 func TestTransportAndServerSharedBodyRace(t *testing.T) {
4324 run(t, testTransportAndServerSharedBodyRace, testNotParallel, http3SkippedMode)
4325 }
4326 func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
4327
4328
4329
4330
4331 runTimeSensitiveTest(t, []time.Duration{
4332 1 * time.Millisecond,
4333 5 * time.Millisecond,
4334 10 * time.Millisecond,
4335 50 * time.Millisecond,
4336 100 * time.Millisecond,
4337 500 * time.Millisecond,
4338 time.Second,
4339 5 * time.Second,
4340 }, func(t *testing.T, timeout time.Duration) error {
4341 SetRSTAvoidanceDelay(t, timeout)
4342 t.Logf("set RST avoidance delay to %v", timeout)
4343
4344 const bodySize = 1 << 20
4345
4346 var wg sync.WaitGroup
4347 backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4348
4349
4350
4351
4352
4353
4354
4355
4356 wg.Add(1)
4357 defer wg.Done()
4358
4359 n, err := io.CopyN(rw, req.Body, bodySize)
4360 t.Logf("backend CopyN: %v, %v", n, err)
4361 <-req.Context().Done()
4362 }))
4363
4364
4365 defer func() {
4366 wg.Wait()
4367 backend.close()
4368 }()
4369
4370 var proxy *clientServerTest
4371 proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4372 req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
4373 req2.ContentLength = bodySize
4374 cancel := make(chan struct{})
4375 req2.Cancel = cancel
4376
4377 bresp, err := proxy.c.Do(req2)
4378 if err != nil {
4379 t.Errorf("Proxy outbound request: %v", err)
4380 return
4381 }
4382 _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
4383 if err != nil {
4384 t.Errorf("Proxy copy error: %v", err)
4385 return
4386 }
4387 t.Cleanup(func() { bresp.Body.Close() })
4388
4389
4390
4391
4392
4393
4394 if mode == http2Mode {
4395 close(cancel)
4396 } else {
4397 proxy.c.Transport.(*Transport).CancelRequest(req2)
4398 }
4399 rw.Write([]byte("OK"))
4400 }))
4401 defer proxy.close()
4402
4403 req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
4404 res, err := proxy.c.Do(req)
4405 if err != nil {
4406 return fmt.Errorf("original request: %v", err)
4407 }
4408 res.Body.Close()
4409 return nil
4410 })
4411 }
4412
4413
4414
4415
4416 func TestRequestBodyCloseDoesntBlock(t *testing.T) {
4417 run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode})
4418 }
4419 func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) {
4420 if testing.Short() {
4421 t.Skip("skipping in -short mode")
4422 }
4423
4424 readErrCh := make(chan error, 1)
4425 errCh := make(chan error, 2)
4426
4427 server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4428 go func(body io.Reader) {
4429 _, err := body.Read(make([]byte, 100))
4430 readErrCh <- err
4431 }(req.Body)
4432 time.Sleep(500 * time.Millisecond)
4433 })).ts
4434
4435 closeConn := make(chan bool)
4436 defer close(closeConn)
4437 go func() {
4438 conn, err := net.Dial("tcp", server.Listener.Addr().String())
4439 if err != nil {
4440 errCh <- err
4441 return
4442 }
4443 defer conn.Close()
4444 _, err = conn.Write([]byte("POST / HTTP/1.1\r\nConnection: close\r\nHost: foo\r\nContent-Length: 100000\r\n\r\n"))
4445 if err != nil {
4446 errCh <- err
4447 return
4448 }
4449
4450
4451 <-closeConn
4452 }()
4453 select {
4454 case err := <-readErrCh:
4455 if err == nil {
4456 t.Error("Read was nil. Expected error.")
4457 }
4458 case err := <-errCh:
4459 t.Error(err)
4460 }
4461 }
4462
4463
4464 func TestResponseWriterWriteString(t *testing.T) {
4465 okc := make(chan bool, 1)
4466 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
4467 _, ok := w.(io.StringWriter)
4468 okc <- ok
4469 }))
4470 ht.rawResponse("GET / HTTP/1.0")
4471 select {
4472 case ok := <-okc:
4473 if !ok {
4474 t.Error("ResponseWriter did not implement io.StringWriter")
4475 }
4476 default:
4477 t.Error("handler was never called")
4478 }
4479 }
4480
4481 func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) }
4482 func testServerConnState(t *testing.T, mode testMode) {
4483 handler := map[string]func(w ResponseWriter, r *Request){
4484 "/": func(w ResponseWriter, r *Request) {
4485 fmt.Fprintf(w, "Hello.")
4486 },
4487 "/close": func(w ResponseWriter, r *Request) {
4488 w.Header().Set("Connection", "close")
4489 fmt.Fprintf(w, "Hello.")
4490 },
4491 "/hijack": func(w ResponseWriter, r *Request) {
4492 c, _, _ := w.(Hijacker).Hijack()
4493 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4494 c.Close()
4495 },
4496 "/hijack-panic": func(w ResponseWriter, r *Request) {
4497 c, _, _ := w.(Hijacker).Hijack()
4498 c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello."))
4499 c.Close()
4500 panic("intentional panic")
4501 },
4502 }
4503
4504
4505 type stateLog struct {
4506 active net.Conn
4507 got []ConnState
4508 want []ConnState
4509 complete chan<- struct{}
4510 }
4511 activeLog := make(chan *stateLog, 1)
4512
4513
4514
4515
4516 wantLog := func(doRequests func(), want ...ConnState) {
4517 t.Helper()
4518 complete := make(chan struct{})
4519 activeLog <- &stateLog{want: want, complete: complete}
4520
4521 doRequests()
4522
4523 <-complete
4524 sl := <-activeLog
4525 if !slices.Equal(sl.got, sl.want) {
4526 t.Errorf("Request(s) produced unexpected state sequence.\nGot: %v\nWant: %v", sl.got, sl.want)
4527 }
4528
4529
4530
4531 }
4532
4533 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4534 handler[r.URL.Path](w, r)
4535 }), func(ts *httptest.Server) {
4536 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
4537 ts.Config.ConnState = func(c net.Conn, state ConnState) {
4538 if c == nil {
4539 t.Errorf("nil conn seen in state %s", state)
4540 return
4541 }
4542 sl := <-activeLog
4543 if sl.active == nil && state == StateNew {
4544 sl.active = c
4545 } else if sl.active != c {
4546 t.Errorf("unexpected conn in state %s", state)
4547 activeLog <- sl
4548 return
4549 }
4550 sl.got = append(sl.got, state)
4551 if sl.complete != nil && (len(sl.got) >= len(sl.want) || !slices.Equal(sl.got, sl.want[:len(sl.got)])) {
4552 close(sl.complete)
4553 sl.complete = nil
4554 }
4555 activeLog <- sl
4556 }
4557 }).ts
4558 defer func() {
4559 activeLog <- &stateLog{}
4560 ts.Close()
4561 }()
4562
4563 c := ts.Client()
4564
4565 mustGet := func(url string, headers ...string) {
4566 t.Helper()
4567 req, err := NewRequest("GET", url, nil)
4568 if err != nil {
4569 t.Fatal(err)
4570 }
4571 for len(headers) > 0 {
4572 req.Header.Add(headers[0], headers[1])
4573 headers = headers[2:]
4574 }
4575 res, err := c.Do(req)
4576 if err != nil {
4577 t.Errorf("Error fetching %s: %v", url, err)
4578 return
4579 }
4580 _, err = io.ReadAll(res.Body)
4581 defer res.Body.Close()
4582 if err != nil {
4583 t.Errorf("Error reading %s: %v", url, err)
4584 }
4585 }
4586
4587 wantLog(func() {
4588 mustGet(ts.URL + "/")
4589 mustGet(ts.URL + "/close")
4590 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4591
4592 wantLog(func() {
4593 mustGet(ts.URL + "/")
4594 mustGet(ts.URL+"/", "Connection", "close")
4595 }, StateNew, StateActive, StateIdle, StateActive, StateClosed)
4596
4597 wantLog(func() {
4598 mustGet(ts.URL + "/hijack")
4599 }, StateNew, StateActive, StateHijacked)
4600
4601 wantLog(func() {
4602 mustGet(ts.URL + "/hijack-panic")
4603 }, StateNew, StateActive, StateHijacked)
4604
4605 wantLog(func() {
4606 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4607 if err != nil {
4608 t.Fatal(err)
4609 }
4610 c.Close()
4611 }, StateNew, StateClosed)
4612
4613 wantLog(func() {
4614 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4615 if err != nil {
4616 t.Fatal(err)
4617 }
4618 if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
4619 t.Fatal(err)
4620 }
4621 c.Read(make([]byte, 1))
4622 c.Close()
4623 }, StateNew, StateActive, StateClosed)
4624
4625 wantLog(func() {
4626 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4627 if err != nil {
4628 t.Fatal(err)
4629 }
4630 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4631 t.Fatal(err)
4632 }
4633 res, err := ReadResponse(bufio.NewReader(c), nil)
4634 if err != nil {
4635 t.Fatal(err)
4636 }
4637 if _, err := io.Copy(io.Discard, res.Body); err != nil {
4638 t.Fatal(err)
4639 }
4640 c.Close()
4641 }, StateNew, StateActive, StateIdle, StateClosed)
4642 }
4643
4644 func TestServerKeepAlivesEnabledResultClose(t *testing.T) {
4645 run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode})
4646 }
4647 func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) {
4648 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4649 }), func(ts *httptest.Server) {
4650 ts.Config.SetKeepAlivesEnabled(false)
4651 }).ts
4652 res, err := ts.Client().Get(ts.URL)
4653 if err != nil {
4654 t.Fatal(err)
4655 }
4656 defer res.Body.Close()
4657 if !res.Close {
4658 t.Errorf("Body.Close == false; want true")
4659 }
4660 }
4661
4662
4663 func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) }
4664 func testServerEmptyBodyRace(t *testing.T, mode testMode) {
4665 var n int32
4666 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
4667 atomic.AddInt32(&n, 1)
4668 }), optQuietLog)
4669 var wg sync.WaitGroup
4670 const reqs = 20
4671 for i := 0; i < reqs; i++ {
4672 wg.Add(1)
4673 go func() {
4674 defer wg.Done()
4675 res, err := cst.c.Get(cst.ts.URL)
4676 if err != nil {
4677
4678
4679 time.Sleep(10 * time.Millisecond)
4680 res, err = cst.c.Get(cst.ts.URL)
4681 if err != nil {
4682 t.Error(err)
4683 return
4684 }
4685 }
4686 defer res.Body.Close()
4687 _, err = io.Copy(io.Discard, res.Body)
4688 if err != nil {
4689 t.Error(err)
4690 return
4691 }
4692 }()
4693 }
4694 wg.Wait()
4695 if got := atomic.LoadInt32(&n); got != reqs {
4696 t.Errorf("handler ran %d times; want %d", got, reqs)
4697 }
4698 }
4699
4700 func TestServerConnStateNew(t *testing.T) {
4701 sawNew := false
4702 srv := &Server{
4703 ConnState: func(c net.Conn, state ConnState) {
4704 if state == StateNew {
4705 sawNew = true
4706 }
4707 },
4708 Handler: HandlerFunc(func(w ResponseWriter, r *Request) {}),
4709 }
4710 srv.Serve(&oneConnListener{
4711 conn: &rwTestConn{
4712 Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"),
4713 Writer: io.Discard,
4714 },
4715 })
4716 if !sawNew {
4717 t.Error("StateNew not seen")
4718 }
4719 }
4720
4721 type closeWriteTestConn struct {
4722 rwTestConn
4723 didCloseWrite bool
4724 }
4725
4726 func (c *closeWriteTestConn) CloseWrite() error {
4727 c.didCloseWrite = true
4728 return nil
4729 }
4730
4731 func TestCloseWrite(t *testing.T) {
4732 SetRSTAvoidanceDelay(t, 1*time.Millisecond)
4733
4734 var srv Server
4735 var testConn closeWriteTestConn
4736 c := ExportServerNewConn(&srv, &testConn)
4737 ExportCloseWriteAndWait(c)
4738 if !testConn.didCloseWrite {
4739 t.Error("didn't see CloseWrite call")
4740 }
4741 }
4742
4743
4744
4745
4746
4747
4748
4749
4750 func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) }
4751 func testServerFlushAndHijack(t *testing.T, mode testMode) {
4752 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4753 io.WriteString(w, "Hello, ")
4754 w.(Flusher).Flush()
4755 conn, buf, _ := w.(Hijacker).Hijack()
4756 buf.WriteString("6\r\nworld!\r\n0\r\n\r\n")
4757 if err := buf.Flush(); err != nil {
4758 t.Error(err)
4759 }
4760 if err := conn.Close(); err != nil {
4761 t.Error(err)
4762 }
4763 })).ts
4764 res, err := Get(ts.URL)
4765 if err != nil {
4766 t.Fatal(err)
4767 }
4768 defer res.Body.Close()
4769 all, err := io.ReadAll(res.Body)
4770 if err != nil {
4771 t.Fatal(err)
4772 }
4773 if want := "Hello, world!"; string(all) != want {
4774 t.Errorf("Got %q; want %q", all, want)
4775 }
4776 }
4777
4778
4779
4780
4781
4782
4783
4784 func TestServerKeepAliveAfterWriteError(t *testing.T) {
4785 run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode})
4786 }
4787 func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) {
4788 if testing.Short() {
4789 t.Skip("skipping in -short mode")
4790 }
4791 const numReq = 3
4792 addrc := make(chan string, numReq)
4793 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4794 addrc <- r.RemoteAddr
4795 time.Sleep(500 * time.Millisecond)
4796 w.(Flusher).Flush()
4797 }), func(ts *httptest.Server) {
4798 ts.Config.WriteTimeout = 250 * time.Millisecond
4799 }).ts
4800
4801 errc := make(chan error, numReq)
4802 go func() {
4803 defer close(errc)
4804 for i := 0; i < numReq; i++ {
4805 res, err := Get(ts.URL)
4806 if res != nil {
4807 res.Body.Close()
4808 }
4809 errc <- err
4810 }
4811 }()
4812
4813 addrSeen := map[string]bool{}
4814 numOkay := 0
4815 for {
4816 select {
4817 case v := <-addrc:
4818 addrSeen[v] = true
4819 case err, ok := <-errc:
4820 if !ok {
4821 if len(addrSeen) != numReq {
4822 t.Errorf("saw %d unique client addresses; want %d", len(addrSeen), numReq)
4823 }
4824 if numOkay != 0 {
4825 t.Errorf("got %d successful client requests; want 0", numOkay)
4826 }
4827 return
4828 }
4829 if err == nil {
4830 numOkay++
4831 }
4832 }
4833 }
4834 }
4835
4836
4837
4838 func TestNoContentLengthIfTransferEncoding(t *testing.T) {
4839 run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode})
4840 }
4841 func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) {
4842 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4843 w.Header().Set("Transfer-Encoding", "foo")
4844 io.WriteString(w, "<html>")
4845 })).ts
4846 c, err := net.Dial("tcp", ts.Listener.Addr().String())
4847 if err != nil {
4848 t.Fatalf("Dial: %v", err)
4849 }
4850 defer c.Close()
4851 if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil {
4852 t.Fatal(err)
4853 }
4854 bs := bufio.NewScanner(c)
4855 var got strings.Builder
4856 for bs.Scan() {
4857 if strings.TrimSpace(bs.Text()) == "" {
4858 break
4859 }
4860 got.WriteString(bs.Text())
4861 got.WriteByte('\n')
4862 }
4863 if err := bs.Err(); err != nil {
4864 t.Fatal(err)
4865 }
4866 if strings.Contains(got.String(), "Content-Length") {
4867 t.Errorf("Unexpected Content-Length in response headers: %s", got.String())
4868 }
4869 if strings.Contains(got.String(), "Content-Type") {
4870 t.Errorf("Unexpected Content-Type in response headers: %s", got.String())
4871 }
4872 }
4873
4874
4875
4876 func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
4877 req := []byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
4878 "\r\n\r\n" +
4879 "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")
4880 var buf bytes.Buffer
4881 conn := &rwTestConn{
4882 Reader: bytes.NewReader(req),
4883 Writer: &buf,
4884 closec: make(chan bool, 1),
4885 }
4886 ln := &oneConnListener{conn: conn}
4887 numReq := 0
4888 go Serve(ln, HandlerFunc(func(rw ResponseWriter, r *Request) {
4889 numReq++
4890 }))
4891 <-conn.closec
4892 if numReq != 2 {
4893 t.Errorf("num requests = %d; want 2", numReq)
4894 t.Logf("Res: %s", buf.Bytes())
4895 }
4896 }
4897
4898 func TestIssue13893_Expect100(t *testing.T) {
4899
4900 req := reqBytes(`PUT /readbody HTTP/1.1
4901 User-Agent: PycURL/7.22.0
4902 Host: 127.0.0.1:9000
4903 Accept: */*
4904 Expect: 100-continue
4905 Content-Length: 10
4906
4907 HelloWorld
4908
4909 `)
4910 var buf bytes.Buffer
4911 conn := &rwTestConn{
4912 Reader: bytes.NewReader(req),
4913 Writer: &buf,
4914 closec: make(chan bool, 1),
4915 }
4916 ln := &oneConnListener{conn: conn}
4917 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4918 if _, ok := r.Header["Expect"]; !ok {
4919 t.Error("Expect header should not be filtered out")
4920 }
4921 }))
4922 <-conn.closec
4923 }
4924
4925 func TestIssue11549_Expect100(t *testing.T) {
4926 req := reqBytes(`PUT /readbody HTTP/1.1
4927 User-Agent: PycURL/7.22.0
4928 Host: 127.0.0.1:9000
4929 Accept: */*
4930 Expect: 100-continue
4931 Content-Length: 10
4932
4933 HelloWorldPUT /noreadbody HTTP/1.1
4934 User-Agent: PycURL/7.22.0
4935 Host: 127.0.0.1:9000
4936 Accept: */*
4937 Expect: 100-continue
4938 Content-Length: 10
4939
4940 GET /should-be-ignored HTTP/1.1
4941 Host: foo
4942
4943 `)
4944 var buf strings.Builder
4945 conn := &rwTestConn{
4946 Reader: bytes.NewReader(req),
4947 Writer: &buf,
4948 closec: make(chan bool, 1),
4949 }
4950 ln := &oneConnListener{conn: conn}
4951 numReq := 0
4952 go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
4953 numReq++
4954 if r.URL.Path == "/readbody" {
4955 io.ReadAll(r.Body)
4956 }
4957 io.WriteString(w, "Hello world!")
4958 }))
4959 <-conn.closec
4960 if numReq != 2 {
4961 t.Errorf("num requests = %d; want 2", numReq)
4962 }
4963 if !strings.Contains(buf.String(), "Connection: close\r\n") {
4964 t.Errorf("expected 'Connection: close' in response; got: %s", buf.String())
4965 }
4966 }
4967
4968
4969
4970 func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
4971 setParallel(t)
4972 conn := newTestConn()
4973 conn.readBuf.WriteString(
4974 "POST / HTTP/1.1\r\n" +
4975 "Host: test\r\n" +
4976 "Content-Length: 9999999999\r\n" +
4977 "\r\n" + strings.Repeat("a", 1<<20))
4978
4979 ls := &oneConnListener{conn}
4980 var inHandlerLen int
4981 go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
4982 inHandlerLen = conn.readBuf.Len()
4983 rw.WriteHeader(404)
4984 }))
4985 <-conn.closec
4986 afterHandlerLen := conn.readBuf.Len()
4987
4988 if afterHandlerLen != inHandlerLen {
4989 t.Errorf("unexpected implicit read. Read buffer went from %d -> %d", inHandlerLen, afterHandlerLen)
4990 }
4991 }
4992
4993 func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) }
4994 func testHandlerSetsBodyNil(t *testing.T, mode testMode) {
4995 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
4996 r.Body = nil
4997 fmt.Fprintf(w, "%v", r.RemoteAddr)
4998 }))
4999 get := func() string {
5000 res, err := cst.c.Get(cst.ts.URL)
5001 if err != nil {
5002 t.Fatal(err)
5003 }
5004 defer res.Body.Close()
5005 slurp, err := io.ReadAll(res.Body)
5006 if err != nil {
5007 t.Fatal(err)
5008 }
5009 return string(slurp)
5010 }
5011 a, b := get(), get()
5012 if a != b {
5013 t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
5014 }
5015 }
5016
5017
5018
5019 func TestServerValidatesHostHeader(t *testing.T) {
5020 tests := []struct {
5021 proto string
5022 host string
5023 want int
5024 }{
5025 {"HTTP/0.9", "", 505},
5026
5027 {"HTTP/1.1", "", 400},
5028 {"HTTP/1.1", "Host: \r\n", 200},
5029 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
5030 {"HTTP/1.1", "Host: foo.com\r\n", 200},
5031 {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
5032 {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
5033 {"HTTP/1.1", "Host: ::1\r\n", 200},
5034 {"HTTP/1.1", "Host: [::1]\r\n", 200},
5035 {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
5036 {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
5037 {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
5038 {"HTTP/1.1", "Host: \x06\r\n", 400},
5039 {"HTTP/1.1", "Host: \xff\r\n", 400},
5040 {"HTTP/1.1", "Host: {\r\n", 400},
5041 {"HTTP/1.1", "Host: }\r\n", 400},
5042 {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
5043
5044
5045
5046 {"HTTP/1.0", "", 200},
5047 {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
5048 {"HTTP/1.0", "Host: \xff\r\n", 400},
5049
5050
5051 {"PRI * HTTP/2.0", "", 200},
5052
5053
5054 {"CONNECT golang.org:443 HTTP/1.1", "", 200},
5055
5056
5057 {"PRI / HTTP/2.0", "", 505},
5058 {"GET / HTTP/2.0", "", 505},
5059 {"GET / HTTP/3.0", "", 505},
5060 }
5061 for _, tt := range tests {
5062 conn := newTestConn()
5063 methodTarget := "GET / "
5064 if !strings.HasPrefix(tt.proto, "HTTP/") {
5065 methodTarget = ""
5066 }
5067 io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n")
5068
5069 ln := &oneConnListener{conn}
5070 srv := Server{
5071 ErrorLog: quietLog,
5072 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5073 }
5074 go srv.Serve(ln)
5075 <-conn.closec
5076 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5077 if err != nil {
5078 t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
5079 continue
5080 }
5081 if res.StatusCode != tt.want {
5082 t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
5083 }
5084 }
5085 }
5086
5087 func TestServerHandlersCanHandleH2PRI(t *testing.T) {
5088 run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode})
5089 }
5090 func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) {
5091 const upgradeResponse = "upgrade here"
5092 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5093 conn, br, err := w.(Hijacker).Hijack()
5094 if err != nil {
5095 t.Error(err)
5096 return
5097 }
5098 defer conn.Close()
5099 if r.Method != "PRI" || r.RequestURI != "*" {
5100 t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI)
5101 return
5102 }
5103 if !r.Close {
5104 t.Errorf("Request.Close = true; want false")
5105 }
5106 const want = "SM\r\n\r\n"
5107 buf := make([]byte, len(want))
5108 n, err := io.ReadFull(br, buf)
5109 if err != nil || string(buf[:n]) != want {
5110 t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want)
5111 return
5112 }
5113 io.WriteString(conn, upgradeResponse)
5114 })).ts
5115
5116 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5117 if err != nil {
5118 t.Fatalf("Dial: %v", err)
5119 }
5120 defer c.Close()
5121 io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
5122 slurp, err := io.ReadAll(c)
5123 if err != nil {
5124 t.Fatal(err)
5125 }
5126 if string(slurp) != upgradeResponse {
5127 t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse)
5128 }
5129 }
5130
5131
5132
5133 func TestServerValidatesHeaders(t *testing.T) {
5134 setParallel(t)
5135 tests := []struct {
5136 header string
5137 want int
5138 }{
5139 {"", 200},
5140 {"Foo: bar\r\n", 200},
5141 {"X-Foo: bar\r\n", 200},
5142 {"Foo: a space\r\n", 200},
5143
5144 {"A space: foo\r\n", 400},
5145 {"foo\xffbar: foo\r\n", 400},
5146 {"foo\x00bar: foo\r\n", 400},
5147 {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431},
5148
5149
5150 {"Foo : bar\r\n", 400},
5151 {"Foo\t: bar\r\n", 400},
5152
5153
5154
5155 {": empty key\r\n", 400},
5156
5157
5158
5159
5160 {"Content-Length: notdigits\r\n", 400},
5161 {"Content-Length: notdigits\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n", 400},
5162
5163 {"foo: foo foo\r\n", 200},
5164 {"foo: foo\tfoo\r\n", 200},
5165 {"foo: foo\x00foo\r\n", 400},
5166 {"foo: foo\x7ffoo\r\n", 400},
5167 {"foo: foo\xfffoo\r\n", 200},
5168 }
5169 for _, tt := range tests {
5170 conn := newTestConn()
5171 io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
5172
5173 ln := &oneConnListener{conn}
5174 srv := Server{
5175 ErrorLog: quietLog,
5176 Handler: HandlerFunc(func(ResponseWriter, *Request) {}),
5177 }
5178 go srv.Serve(ln)
5179 <-conn.closec
5180 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
5181 if err != nil {
5182 t.Errorf("For %q, ReadResponse: %v", tt.header, res)
5183 continue
5184 }
5185 if res.StatusCode != tt.want {
5186 t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
5187 }
5188 }
5189 }
5190
5191 func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
5192 run(t, testServerRequestContextCancel_ServeHTTPDone, http3SkippedMode)
5193 }
5194 func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) {
5195 ctxc := make(chan context.Context, 1)
5196 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5197 ctx := r.Context()
5198 select {
5199 case <-ctx.Done():
5200 t.Error("should not be Done in ServeHTTP")
5201 default:
5202 }
5203 ctxc <- ctx
5204 }))
5205 res, err := cst.c.Get(cst.ts.URL)
5206 if err != nil {
5207 t.Fatal(err)
5208 }
5209 res.Body.Close()
5210 ctx := <-ctxc
5211 select {
5212 case <-ctx.Done():
5213 default:
5214 t.Error("context should be done after ServeHTTP completes")
5215 }
5216 }
5217
5218
5219
5220
5221
5222 func TestServerRequestContextCancel_ConnClose(t *testing.T) {
5223 run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode})
5224 }
5225 func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) {
5226 inHandler := make(chan struct{})
5227 handlerDone := make(chan struct{})
5228 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5229 close(inHandler)
5230 <-r.Context().Done()
5231 close(handlerDone)
5232 })).ts
5233 c, err := net.Dial("tcp", ts.Listener.Addr().String())
5234 if err != nil {
5235 t.Fatal(err)
5236 }
5237 defer c.Close()
5238 io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n")
5239 <-inHandler
5240 c.Close()
5241 <-handlerDone
5242 }
5243
5244 func TestServerContext_ServerContextKey(t *testing.T) {
5245 run(t, testServerContext_ServerContextKey, http3SkippedMode)
5246 }
5247 func testServerContext_ServerContextKey(t *testing.T, mode testMode) {
5248 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5249 ctx := r.Context()
5250 got := ctx.Value(ServerContextKey)
5251 if _, ok := got.(*Server); !ok {
5252 t.Errorf("context value = %T; want *http.Server", got)
5253 }
5254 }))
5255 res, err := cst.c.Get(cst.ts.URL)
5256 if err != nil {
5257 t.Fatal(err)
5258 }
5259 res.Body.Close()
5260 }
5261
5262 func TestServerContext_LocalAddrContextKey(t *testing.T) {
5263 run(t, testServerContext_LocalAddrContextKey, http3SkippedMode)
5264 }
5265 func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) {
5266 ch := make(chan any, 1)
5267 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5268 ch <- r.Context().Value(LocalAddrContextKey)
5269 }))
5270 if _, err := cst.c.Head(cst.ts.URL); err != nil {
5271 t.Fatal(err)
5272 }
5273
5274 host := cst.ts.Listener.Addr().String()
5275 got := <-ch
5276 if addr, ok := got.(net.Addr); !ok {
5277 t.Errorf("local addr value = %T; want net.Addr", got)
5278 } else if fmt.Sprint(addr) != host {
5279 t.Errorf("local addr = %v; want %v", addr, host)
5280 }
5281 }
5282
5283
5284 func TestHandlerSetTransferEncodingChunked(t *testing.T) {
5285 setParallel(t)
5286 defer afterTest(t)
5287 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5288 w.Header().Set("Transfer-Encoding", "chunked")
5289 w.Write([]byte("hello"))
5290 }))
5291 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5292 const hdr = "Transfer-Encoding: chunked"
5293 if n := strings.Count(resp, hdr); n != 1 {
5294 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5295 }
5296 }
5297
5298
5299 func TestHandlerSetTransferEncodingGzip(t *testing.T) {
5300 setParallel(t)
5301 defer afterTest(t)
5302 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
5303 w.Header().Set("Transfer-Encoding", "gzip")
5304 gz := gzip.NewWriter(w)
5305 gz.Write([]byte("hello"))
5306 gz.Close()
5307 }))
5308 resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
5309 for _, v := range []string{"gzip", "chunked"} {
5310 hdr := "Transfer-Encoding: " + v
5311 if n := strings.Count(resp, hdr); n != 1 {
5312 t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp)
5313 }
5314 }
5315 }
5316
5317 func BenchmarkClientServer(b *testing.B) {
5318 run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode})
5319 }
5320 func benchmarkClientServer(b *testing.B, mode testMode) {
5321 b.ReportAllocs()
5322 b.StopTimer()
5323 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5324 fmt.Fprintf(rw, "Hello world.\n")
5325 })).ts
5326 b.StartTimer()
5327
5328 c := ts.Client()
5329 for i := 0; i < b.N; i++ {
5330 res, err := c.Get(ts.URL)
5331 if err != nil {
5332 b.Fatal("Get:", err)
5333 }
5334 all, err := io.ReadAll(res.Body)
5335 res.Body.Close()
5336 if err != nil {
5337 b.Fatal("ReadAll:", err)
5338 }
5339 body := string(all)
5340 if body != "Hello world.\n" {
5341 b.Fatal("Got body:", body)
5342 }
5343 }
5344
5345 b.StopTimer()
5346 }
5347
5348 func BenchmarkClientServerParallel(b *testing.B) {
5349 for _, parallelism := range []int{4, 64} {
5350 b.Run(fmt.Sprint(parallelism), func(b *testing.B) {
5351 run(b, func(b *testing.B, mode testMode) {
5352 benchmarkClientServerParallel(b, parallelism, mode)
5353 }, []testMode{http1Mode, https1Mode, http2Mode})
5354 })
5355 }
5356 }
5357
5358 func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) {
5359 b.ReportAllocs()
5360 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
5361 fmt.Fprintf(rw, "Hello world.\n")
5362 })).ts
5363 b.ResetTimer()
5364 b.SetParallelism(parallelism)
5365 b.RunParallel(func(pb *testing.PB) {
5366 c := ts.Client()
5367 for pb.Next() {
5368 res, err := c.Get(ts.URL)
5369 if err != nil {
5370 b.Logf("Get: %v", err)
5371 continue
5372 }
5373 all, err := io.ReadAll(res.Body)
5374 res.Body.Close()
5375 if err != nil {
5376 b.Logf("ReadAll: %v", err)
5377 continue
5378 }
5379 body := string(all)
5380 if body != "Hello world.\n" {
5381 panic("Got body: " + body)
5382 }
5383 }
5384 })
5385 }
5386
5387
5388
5389
5390
5391
5392
5393
5394
5395
5396 func BenchmarkServer(b *testing.B) {
5397 b.ReportAllocs()
5398
5399 if url := os.Getenv("GO_TEST_BENCH_SERVER_URL"); url != "" {
5400 n, err := strconv.Atoi(os.Getenv("GO_TEST_BENCH_CLIENT_N"))
5401 if err != nil {
5402 panic(err)
5403 }
5404 for i := 0; i < n; i++ {
5405 res, err := Get(url)
5406 if err != nil {
5407 log.Panicf("Get: %v", err)
5408 }
5409 all, err := io.ReadAll(res.Body)
5410 res.Body.Close()
5411 if err != nil {
5412 log.Panicf("ReadAll: %v", err)
5413 }
5414 body := string(all)
5415 if body != "Hello world.\n" {
5416 log.Panicf("Got body: %q", body)
5417 }
5418 }
5419 os.Exit(0)
5420 return
5421 }
5422
5423 var res = []byte("Hello world.\n")
5424 b.StopTimer()
5425 ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
5426 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5427 rw.Write(res)
5428 }))
5429 defer ts.Close()
5430 b.StartTimer()
5431
5432 cmd := testenv.Command(b, os.Args[0], "-test.run=^$", "-test.bench=^BenchmarkServer$")
5433 cmd.Env = append([]string{
5434 fmt.Sprintf("GO_TEST_BENCH_CLIENT_N=%d", b.N),
5435 fmt.Sprintf("GO_TEST_BENCH_SERVER_URL=%s", ts.URL),
5436 }, os.Environ()...)
5437 out, err := cmd.CombinedOutput()
5438 if err != nil {
5439 b.Errorf("Test failure: %v, with output: %s", err, out)
5440 }
5441 }
5442
5443
5444 func getNoBody(urlStr string) (*Response, error) {
5445 res, err := Get(urlStr)
5446 if err != nil {
5447 return nil, err
5448 }
5449 res.Body.Close()
5450 return res, nil
5451 }
5452
5453
5454
5455 func BenchmarkClient(b *testing.B) {
5456 var data = []byte("Hello world.\n")
5457
5458 url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
5459 w.Header().Set("Content-Type", "text/html; charset=utf-8")
5460 w.Write(data)
5461 }))
5462
5463
5464 b.StartTimer()
5465 for i := 0; i < b.N; i++ {
5466 res, err := Get(url)
5467 if err != nil {
5468 b.Fatalf("Get: %v", err)
5469 }
5470 body, err := io.ReadAll(res.Body)
5471 res.Body.Close()
5472 if err != nil {
5473 b.Fatalf("ReadAll: %v", err)
5474 }
5475 if !bytes.Equal(body, data) {
5476 b.Fatalf("Got body: %q", body)
5477 }
5478 }
5479 b.StopTimer()
5480 }
5481
5482 func startClientBenchmarkServer(b *testing.B, handler Handler) string {
5483 b.ReportAllocs()
5484 b.StopTimer()
5485
5486 if server := os.Getenv("GO_TEST_BENCH_SERVER"); server != "" {
5487
5488 port := os.Getenv("GO_TEST_BENCH_SERVER_PORT")
5489 if port == "" {
5490 port = "0"
5491 }
5492 ln, err := net.Listen("tcp", "localhost:"+port)
5493 if err != nil {
5494 log.Fatal(err)
5495 }
5496 fmt.Println(ln.Addr().String())
5497
5498 HandleFunc("/", func(w ResponseWriter, r *Request) {
5499 r.ParseForm()
5500 if r.Form.Get("stop") != "" {
5501 os.Exit(0)
5502 }
5503 handler.ServeHTTP(w, r)
5504 })
5505 var srv Server
5506 log.Fatal(srv.Serve(ln))
5507 }
5508
5509
5510 ctx, cancel := context.WithCancel(context.Background())
5511 cmd := testenv.CommandContext(b, ctx, os.Args[0], "-test.run=^$", "-test.bench=^"+b.Name()+"$")
5512 cmd.Env = append(cmd.Environ(), "GO_TEST_BENCH_SERVER=yes")
5513 cmd.Stderr = os.Stderr
5514 stdout, err := cmd.StdoutPipe()
5515 if err != nil {
5516 b.Fatal(err)
5517 }
5518 if err := cmd.Start(); err != nil {
5519 b.Fatalf("subprocess failed to start: %v", err)
5520 }
5521
5522 done := make(chan error, 1)
5523 go func() {
5524 done <- cmd.Wait()
5525 close(done)
5526 }()
5527
5528
5529
5530 bs := bufio.NewScanner(stdout)
5531 if !bs.Scan() {
5532 b.Fatalf("failed to read listening URL from child: %v", bs.Err())
5533 }
5534 url := "http://" + strings.TrimSpace(bs.Text()) + "/"
5535 if _, err := getNoBody(url); err != nil {
5536 b.Fatalf("initial probe of child process failed: %v", err)
5537 }
5538
5539
5540 b.Cleanup(func() {
5541 getNoBody(url + "?stop=yes")
5542 if err := <-done; err != nil {
5543 b.Fatalf("subprocess failed: %v", err)
5544 }
5545
5546 cancel()
5547 <-done
5548
5549 afterTest(b)
5550 })
5551
5552 return url
5553 }
5554
5555 func BenchmarkClientGzip(b *testing.B) {
5556 const responseSize = 1024 * 1024
5557
5558 var buf bytes.Buffer
5559 gz := gzip.NewWriter(&buf)
5560 if _, err := io.CopyN(gz, crand.Reader, responseSize); err != nil {
5561 b.Fatal(err)
5562 }
5563 gz.Close()
5564
5565 data := buf.Bytes()
5566
5567 url := startClientBenchmarkServer(b, HandlerFunc(func(w ResponseWriter, _ *Request) {
5568 w.Header().Set("Content-Encoding", "gzip")
5569 w.Write(data)
5570 }))
5571
5572
5573 b.StartTimer()
5574 for i := 0; i < b.N; i++ {
5575 res, err := Get(url)
5576 if err != nil {
5577 b.Fatalf("Get: %v", err)
5578 }
5579 n, err := io.Copy(io.Discard, res.Body)
5580 res.Body.Close()
5581 if err != nil {
5582 b.Fatalf("ReadAll: %v", err)
5583 }
5584 if n != responseSize {
5585 b.Fatalf("ReadAll: expected %d bytes, got %d", responseSize, n)
5586 }
5587 }
5588 b.StopTimer()
5589 }
5590
5591 func BenchmarkServerFakeConnNoKeepAlive(b *testing.B) {
5592 b.ReportAllocs()
5593 req := reqBytes(`GET / HTTP/1.0
5594 Host: golang.org
5595 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5596 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5597 Accept-Encoding: gzip,deflate,sdch
5598 Accept-Language: en-US,en;q=0.8
5599 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5600 `)
5601 res := []byte("Hello world!\n")
5602
5603 conn := newTestConn()
5604 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5605 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5606 rw.Write(res)
5607 })
5608 ln := new(oneConnListener)
5609 for i := 0; i < b.N; i++ {
5610 conn.readBuf.Reset()
5611 conn.writeBuf.Reset()
5612 conn.readBuf.Write(req)
5613 ln.conn = conn
5614 Serve(ln, handler)
5615 <-conn.closec
5616 }
5617 }
5618
5619
5620 type repeatReader struct {
5621 content []byte
5622 count int
5623 off int
5624 }
5625
5626 func (r *repeatReader) Read(p []byte) (n int, err error) {
5627 if r.count <= 0 {
5628 return 0, io.EOF
5629 }
5630 n = copy(p, r.content[r.off:])
5631 r.off += n
5632 if r.off == len(r.content) {
5633 r.count--
5634 r.off = 0
5635 }
5636 return
5637 }
5638
5639 func BenchmarkServerFakeConnWithKeepAlive(b *testing.B) {
5640 b.ReportAllocs()
5641
5642 req := reqBytes(`GET / HTTP/1.1
5643 Host: golang.org
5644 Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8
5645 User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_2) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.52 Safari/537.17
5646 Accept-Encoding: gzip,deflate,sdch
5647 Accept-Language: en-US,en;q=0.8
5648 Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3
5649 `)
5650 res := []byte("Hello world!\n")
5651
5652 conn := &rwTestConn{
5653 Reader: &repeatReader{content: req, count: b.N},
5654 Writer: io.Discard,
5655 closec: make(chan bool, 1),
5656 }
5657 handled := 0
5658 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5659 handled++
5660 rw.Header().Set("Content-Type", "text/html; charset=utf-8")
5661 rw.Write(res)
5662 })
5663 ln := &oneConnListener{conn: conn}
5664 go Serve(ln, handler)
5665 <-conn.closec
5666 if b.N != handled {
5667 b.Errorf("b.N=%d but handled %d", b.N, handled)
5668 }
5669 }
5670
5671
5672
5673 func BenchmarkServerFakeConnWithKeepAliveLite(b *testing.B) {
5674 b.ReportAllocs()
5675
5676 req := reqBytes(`GET / HTTP/1.1
5677 Host: golang.org
5678 `)
5679 res := []byte("Hello world!\n")
5680
5681 conn := &rwTestConn{
5682 Reader: &repeatReader{content: req, count: b.N},
5683 Writer: io.Discard,
5684 closec: make(chan bool, 1),
5685 }
5686 handled := 0
5687 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5688 handled++
5689 rw.Write(res)
5690 })
5691 ln := &oneConnListener{conn: conn}
5692 go Serve(ln, handler)
5693 <-conn.closec
5694 if b.N != handled {
5695 b.Errorf("b.N=%d but handled %d", b.N, handled)
5696 }
5697 }
5698
5699 const someResponse = "<html>some response</html>"
5700
5701
5702 var response = bytes.Repeat([]byte(someResponse), 2<<10/len(someResponse))
5703
5704
5705 func BenchmarkServerHandlerTypeLen(b *testing.B) {
5706 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5707 w.Header().Set("Content-Type", "text/html")
5708 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5709 w.Write(response)
5710 }))
5711 }
5712
5713
5714 func BenchmarkServerHandlerNoLen(b *testing.B) {
5715 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5716 w.Header().Set("Content-Type", "text/html")
5717 w.Write(response)
5718 }))
5719 }
5720
5721
5722 func BenchmarkServerHandlerNoType(b *testing.B) {
5723 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5724 w.Header().Set("Content-Length", strconv.Itoa(len(response)))
5725 w.Write(response)
5726 }))
5727 }
5728
5729
5730 func BenchmarkServerHandlerNoHeader(b *testing.B) {
5731 benchmarkHandler(b, HandlerFunc(func(w ResponseWriter, r *Request) {
5732 w.Write(response)
5733 }))
5734 }
5735
5736 func benchmarkHandler(b *testing.B, h Handler) {
5737 b.ReportAllocs()
5738 req := reqBytes(`GET / HTTP/1.1
5739 Host: golang.org
5740 `)
5741 conn := &rwTestConn{
5742 Reader: &repeatReader{content: req, count: b.N},
5743 Writer: io.Discard,
5744 closec: make(chan bool, 1),
5745 }
5746 handled := 0
5747 handler := HandlerFunc(func(rw ResponseWriter, r *Request) {
5748 handled++
5749 h.ServeHTTP(rw, r)
5750 })
5751 ln := &oneConnListener{conn: conn}
5752 go Serve(ln, handler)
5753 <-conn.closec
5754 if b.N != handled {
5755 b.Errorf("b.N=%d but handled %d", b.N, handled)
5756 }
5757 }
5758
5759 func BenchmarkServerHijack(b *testing.B) {
5760 b.ReportAllocs()
5761 req := reqBytes(`GET / HTTP/1.1
5762 Host: golang.org
5763 `)
5764 h := HandlerFunc(func(w ResponseWriter, r *Request) {
5765 conn, _, err := w.(Hijacker).Hijack()
5766 if err != nil {
5767 panic(err)
5768 }
5769 conn.Close()
5770 })
5771 conn := &rwTestConn{
5772 Writer: io.Discard,
5773 closec: make(chan bool, 1),
5774 }
5775 ln := &oneConnListener{conn: conn}
5776 for i := 0; i < b.N; i++ {
5777 conn.Reader = bytes.NewReader(req)
5778 ln.conn = conn
5779 Serve(ln, h)
5780 <-conn.closec
5781 }
5782 }
5783
5784 func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) }
5785 func benchmarkCloseNotifier(b *testing.B, mode testMode) {
5786 b.ReportAllocs()
5787 b.StopTimer()
5788 sawClose := make(chan bool)
5789 ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
5790 <-rw.(CloseNotifier).CloseNotify()
5791 sawClose <- true
5792 })).ts
5793 b.StartTimer()
5794 for i := 0; i < b.N; i++ {
5795 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5796 if err != nil {
5797 b.Fatalf("error dialing: %v", err)
5798 }
5799 _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
5800 if err != nil {
5801 b.Fatal(err)
5802 }
5803 conn.Close()
5804 <-sawClose
5805 }
5806 b.StopTimer()
5807 }
5808
5809
5810 func TestConcurrentServerServe(t *testing.T) {
5811 setParallel(t)
5812 for i := 0; i < 100; i++ {
5813 ln1 := &oneConnListener{conn: nil}
5814 ln2 := &oneConnListener{conn: nil}
5815 srv := Server{}
5816 go func() { srv.Serve(ln1) }()
5817 go func() { srv.Serve(ln2) }()
5818 }
5819 }
5820
5821 func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) }
5822 func testServerIdleTimeout(t *testing.T, mode testMode) {
5823 if testing.Short() {
5824 t.Skip("skipping in short mode")
5825 }
5826 runTimeSensitiveTest(t, []time.Duration{
5827 10 * time.Millisecond,
5828 100 * time.Millisecond,
5829 1 * time.Second,
5830 10 * time.Second,
5831 }, func(t *testing.T, readHeaderTimeout time.Duration) error {
5832 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5833 io.Copy(io.Discard, r.Body)
5834 io.WriteString(w, r.RemoteAddr)
5835 }), func(ts *httptest.Server) {
5836 ts.Config.ReadHeaderTimeout = readHeaderTimeout
5837 ts.Config.IdleTimeout = 2 * readHeaderTimeout
5838 })
5839 defer cst.close()
5840 ts := cst.ts
5841 t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
5842 t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
5843 c := ts.Client()
5844
5845 get := func() (string, error) {
5846 res, err := c.Get(ts.URL)
5847 if err != nil {
5848 return "", err
5849 }
5850 defer res.Body.Close()
5851 slurp, err := io.ReadAll(res.Body)
5852 if err != nil {
5853
5854
5855
5856 t.Fatal(err)
5857 }
5858 return string(slurp), nil
5859 }
5860
5861 a1, err := get()
5862 if err != nil {
5863 return err
5864 }
5865 a2, err := get()
5866 if err != nil {
5867 return err
5868 }
5869 if a1 != a2 {
5870 return fmt.Errorf("did requests on different connections")
5871 }
5872 time.Sleep(ts.Config.IdleTimeout * 3 / 2)
5873 a3, err := get()
5874 if err != nil {
5875 return err
5876 }
5877 if a2 == a3 {
5878 return fmt.Errorf("request three unexpectedly on same connection")
5879 }
5880
5881
5882 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
5883 if err != nil {
5884 return err
5885 }
5886 defer conn.Close()
5887 conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
5888 time.Sleep(ts.Config.ReadHeaderTimeout * 2)
5889 if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
5890 return fmt.Errorf("copy byte succeeded; want err")
5891 }
5892
5893 return nil
5894 })
5895 }
5896
5897 func get(t *testing.T, c *Client, url string) string {
5898 res, err := c.Get(url)
5899 if err != nil {
5900 t.Fatal(err)
5901 }
5902 defer res.Body.Close()
5903 slurp, err := io.ReadAll(res.Body)
5904 if err != nil {
5905 t.Fatal(err)
5906 }
5907 return string(slurp)
5908 }
5909
5910
5911
5912 func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) {
5913 run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode})
5914 }
5915 func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) {
5916 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
5917 io.WriteString(w, r.RemoteAddr)
5918 })).ts
5919
5920 c := ts.Client()
5921 tr := c.Transport.(*Transport)
5922
5923 get := func() string { return get(t, c, ts.URL) }
5924
5925 a1, a2 := get(), get()
5926 if a1 == a2 {
5927 t.Logf("made two requests from a single conn %q (as expected)", a1)
5928 } else {
5929 t.Errorf("server reported requests from %q and %q; expected same connection", a1, a2)
5930 }
5931
5932
5933
5934
5935
5936 if conns := tr.IdleConnStrsForTesting(); len(conns) != 1 {
5937 t.Errorf("found %d idle conns (%q); want 1", len(conns), conns)
5938 }
5939
5940
5941 ts.Config.SetKeepAlivesEnabled(false)
5942
5943 waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
5944 if conns := tr.IdleConnStrsForTesting(); len(conns) > 0 {
5945 if d > 0 {
5946 t.Logf("idle conns %v after SetKeepAlivesEnabled called = %q; waiting for empty", d, conns)
5947 }
5948 return false
5949 }
5950 return true
5951 })
5952
5953
5954
5955
5956 }
5957
5958 func TestServerShutdown(t *testing.T) { run(t, testServerShutdown, http3SkippedMode) }
5959 func testServerShutdown(t *testing.T, mode testMode) {
5960 var cst *clientServerTest
5961
5962 var once sync.Once
5963 statesRes := make(chan map[ConnState]int, 1)
5964 shutdownRes := make(chan error, 1)
5965 gotOnShutdown := make(chan struct{})
5966 handler := HandlerFunc(func(w ResponseWriter, r *Request) {
5967 first := false
5968 once.Do(func() {
5969 statesRes <- cst.ts.Config.ExportAllConnsByState()
5970 go func() {
5971 shutdownRes <- cst.ts.Config.Shutdown(context.Background())
5972 }()
5973 first = true
5974 })
5975
5976 if first {
5977
5978
5979
5980 <-gotOnShutdown
5981
5982
5983 for !t.Failed() {
5984 res, err := cst.c.Get(cst.ts.URL)
5985 if err != nil {
5986 break
5987 }
5988 out, _ := io.ReadAll(res.Body)
5989 res.Body.Close()
5990 if mode == http2Mode {
5991 t.Logf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
5992 t.Logf("Retrying to work around https://go.dev/issue/59038.")
5993 continue
5994 }
5995 t.Errorf("%v: unexpected success (%q). Listener should be closed before OnShutdown is called.", cst.ts.URL, out)
5996 }
5997 }
5998
5999 io.WriteString(w, r.RemoteAddr)
6000 })
6001
6002 cst = newClientServerTest(t, mode, handler, func(srv *httptest.Server) {
6003 srv.Config.RegisterOnShutdown(func() { close(gotOnShutdown) })
6004 })
6005
6006 out := get(t, cst.c, cst.ts.URL)
6007 t.Logf("%v: %q", cst.ts.URL, out)
6008
6009 if err := <-shutdownRes; err != nil {
6010 t.Fatalf("Shutdown: %v", err)
6011 }
6012 <-gotOnShutdown
6013
6014 if states := <-statesRes; states[StateActive] != 1 {
6015 t.Errorf("connection in wrong state, %v", states)
6016 }
6017 }
6018
6019 func TestServerShutdownStateNew(t *testing.T) {
6020 runSynctest(t, testServerShutdownStateNew, http3SkippedMode)
6021 }
6022 func testServerShutdownStateNew(t *testing.T, mode testMode) {
6023 if testing.Short() {
6024 t.Skip("test takes 5-6 seconds; skipping in short mode")
6025 }
6026
6027 listener := fakeNetListen()
6028 defer listener.Close()
6029
6030 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6031
6032 }), func(ts *httptest.Server) {
6033 ts.Listener.Close()
6034 ts.Listener = listener
6035
6036 ts.Config.ErrorLog = log.New(io.Discard, "", 0)
6037 }).ts
6038
6039
6040 c := listener.connect()
6041 defer c.Close()
6042 synctest.Wait()
6043
6044 shutdownRes := runAsync(func() (struct{}, error) {
6045 return struct{}{}, ts.Config.Shutdown(context.Background())
6046 })
6047
6048
6049
6050
6051 const expectTimeout = 5 * time.Second
6052
6053
6054 time.Sleep(expectTimeout - 1)
6055 synctest.Wait()
6056 if shutdownRes.done() {
6057 t.Fatal("shutdown too soon")
6058 }
6059 if c.IsClosedByPeer() {
6060 t.Fatal("connection was closed by server too soon")
6061 }
6062
6063
6064
6065
6066
6067 time.Sleep(2 * time.Second)
6068 synctest.Wait()
6069 if _, err := shutdownRes.result(); err != nil {
6070 t.Fatalf("Shutdown() = %v, want complete", err)
6071 }
6072 if !c.IsClosedByPeer() {
6073 t.Fatalf("connection was not closed by server after shutdown")
6074 }
6075 }
6076
6077
6078 func TestServerCloseDeadlock(t *testing.T) {
6079 var s Server
6080 s.Close()
6081 s.Close()
6082 }
6083
6084
6085
6086 func TestServerKeepAlivesEnabled(t *testing.T) {
6087 runSynctest(t, testServerKeepAlivesEnabled, http3SkippedMode)
6088 }
6089 func testServerKeepAlivesEnabled(t *testing.T, mode testMode) {
6090 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optFakeNet)
6091 defer cst.close()
6092 srv := cst.ts.Config
6093 srv.SetKeepAlivesEnabled(false)
6094 for try := range 2 {
6095 synctest.Wait()
6096 if !srv.ExportAllConnsIdle() {
6097 t.Fatalf("test server still has active conns before request %v", try)
6098 }
6099 conns := 0
6100 var info httptrace.GotConnInfo
6101 ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
6102 GotConn: func(v httptrace.GotConnInfo) {
6103 conns++
6104 info = v
6105 },
6106 })
6107 req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
6108 if err != nil {
6109 t.Fatal(err)
6110 }
6111 res, err := cst.c.Do(req)
6112 if err != nil {
6113 t.Fatal(err)
6114 }
6115 res.Body.Close()
6116 if conns != 1 {
6117 t.Fatalf("request %v: got %v conns, want 1", try, conns)
6118 }
6119 if info.Reused || info.WasIdle {
6120 t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
6121 }
6122 }
6123 }
6124
6125
6126
6127
6128 func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) }
6129 func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) {
6130 runTimeSensitiveTest(t, []time.Duration{
6131 10 * time.Millisecond,
6132 50 * time.Millisecond,
6133 250 * time.Millisecond,
6134 time.Second,
6135 2 * time.Second,
6136 }, func(t *testing.T, timeout time.Duration) error {
6137 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6138 select {
6139 case <-time.After(2 * timeout):
6140 fmt.Fprint(w, "ok")
6141 case <-r.Context().Done():
6142 fmt.Fprint(w, r.Context().Err())
6143 }
6144 }), func(ts *httptest.Server) {
6145 ts.Config.ReadTimeout = timeout
6146 t.Logf("Server.Config.ReadTimeout = %v", timeout)
6147 })
6148 defer cst.close()
6149 ts := cst.ts
6150
6151 var retries atomic.Int32
6152 cst.c.Transport.(*Transport).Proxy = func(*Request) (*url.URL, error) {
6153 if retries.Add(1) != 1 {
6154 return nil, errors.New("too many retries")
6155 }
6156 return nil, nil
6157 }
6158
6159 c := ts.Client()
6160
6161 res, err := c.Get(ts.URL)
6162 if err != nil {
6163 return fmt.Errorf("Get: %v", err)
6164 }
6165 slurp, err := io.ReadAll(res.Body)
6166 res.Body.Close()
6167 if err != nil {
6168 return fmt.Errorf("Body ReadAll: %v", err)
6169 }
6170 if string(slurp) != "ok" {
6171 return fmt.Errorf("got: %q, want ok", slurp)
6172 }
6173 return nil
6174 })
6175 }
6176
6177
6178
6179
6180 func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) {
6181 run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode})
6182 }
6183 func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) {
6184 runTimeSensitiveTest(t, []time.Duration{
6185 10 * time.Millisecond,
6186 50 * time.Millisecond,
6187 250 * time.Millisecond,
6188 time.Second,
6189 2 * time.Second,
6190 }, func(t *testing.T, timeout time.Duration) error {
6191 cst := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) {
6192 ts.Config.ReadHeaderTimeout = timeout
6193 ts.Config.IdleTimeout = 0
6194 })
6195 defer cst.close()
6196 ts := cst.ts
6197
6198
6199
6200 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
6201 if err != nil {
6202 t.Fatalf("dial failed: %v", err)
6203 }
6204 br := bufio.NewReader(conn)
6205 defer conn.Close()
6206
6207 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6208 return fmt.Errorf("writing first request failed: %v", err)
6209 }
6210
6211 if _, err := ReadResponse(br, nil); err != nil {
6212 return fmt.Errorf("first response (before timeout) failed: %v", err)
6213 }
6214
6215
6216
6217 time.Sleep(timeout * 3 / 2)
6218
6219 if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6220 return fmt.Errorf("writing second request failed: %v", err)
6221 }
6222
6223 if _, err := ReadResponse(br, nil); err != nil {
6224 return fmt.Errorf("second response (after timeout) failed: %v", err)
6225 }
6226
6227 return nil
6228 })
6229 }
6230
6231
6232
6233 func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *testing.T, d time.Duration) error) {
6234 for i, d := range durations {
6235 err := test(t, d)
6236 if err == nil {
6237 return
6238 }
6239 if i == len(durations)-1 || t.Failed() {
6240 t.Fatalf("failed with duration %v: %v", d, err)
6241 }
6242 t.Logf("retrying after error with duration %v: %v", d, err)
6243 }
6244 }
6245
6246
6247
6248 func TestServerDuplicateBackgroundRead(t *testing.T) {
6249 run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode})
6250 }
6251 func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) {
6252 if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" {
6253 testenv.SkipFlaky(t, 24826)
6254 }
6255
6256 goroutines := 5
6257 requests := 2000
6258 if testing.Short() {
6259 goroutines = 3
6260 requests = 100
6261 }
6262
6263 hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts
6264
6265 reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6266
6267 var wg sync.WaitGroup
6268 for i := 0; i < goroutines; i++ {
6269 wg.Add(1)
6270 go func() {
6271 defer wg.Done()
6272 cn, err := net.Dial("tcp", hts.Listener.Addr().String())
6273 if err != nil {
6274 t.Error(err)
6275 return
6276 }
6277 defer cn.Close()
6278
6279 wg.Add(1)
6280 go func() {
6281 defer wg.Done()
6282 io.Copy(io.Discard, cn)
6283 }()
6284
6285 for j := 0; j < requests; j++ {
6286 if t.Failed() {
6287 return
6288 }
6289 _, err := cn.Write(reqBytes)
6290 if err != nil {
6291 t.Error(err)
6292 return
6293 }
6294 }
6295 }()
6296 }
6297 wg.Wait()
6298 }
6299
6300
6301
6302
6303
6304
6305 func TestServerHijackGetsBackgroundByte(t *testing.T) {
6306 run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode})
6307 }
6308 func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) {
6309 if runtime.GOOS == "plan9" {
6310 t.Skip("skipping test; see https://golang.org/issue/18657")
6311 }
6312 done := make(chan struct{})
6313 inHandler := make(chan bool, 1)
6314 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6315 defer close(done)
6316
6317
6318 inHandler <- true
6319
6320 conn, buf, err := w.(Hijacker).Hijack()
6321 if err != nil {
6322 t.Error(err)
6323 return
6324 }
6325 defer conn.Close()
6326
6327 peek, err := buf.Reader.Peek(3)
6328 if string(peek) != "foo" || err != nil {
6329 t.Errorf("Peek = %q, %v; want foo, nil", peek, err)
6330 }
6331
6332 select {
6333 case <-r.Context().Done():
6334 t.Error("context unexpectedly canceled")
6335 default:
6336 }
6337 })).ts
6338
6339 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6340 if err != nil {
6341 t.Fatal(err)
6342 }
6343 defer cn.Close()
6344 if _, err := cn.Write([]byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")); err != nil {
6345 t.Fatal(err)
6346 }
6347 <-inHandler
6348 if _, err := cn.Write([]byte("foo")); err != nil {
6349 t.Fatal(err)
6350 }
6351
6352 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6353 t.Fatal(err)
6354 }
6355 <-done
6356 }
6357
6358
6359 func TestServerHijackGetsFullBody(t *testing.T) {
6360 run(t, testServerHijackGetsFullBody, []testMode{http1Mode})
6361 }
6362 func testServerHijackGetsFullBody(t *testing.T, mode testMode) {
6363 if runtime.GOOS == "plan9" {
6364 t.Skip("skipping test; see https://golang.org/issue/18657")
6365 }
6366 done := make(chan struct{})
6367 needle := strings.Repeat("x", 100*1024)
6368 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6369 defer close(done)
6370
6371 conn, buf, err := w.(Hijacker).Hijack()
6372 if err != nil {
6373 t.Error(err)
6374 return
6375 }
6376 defer conn.Close()
6377
6378 got := make([]byte, len(needle))
6379 n, err := io.ReadFull(buf.Reader, got)
6380 if n != len(needle) || string(got) != needle || err != nil {
6381 t.Errorf("Peek = %q, %v; want 'x'*4096, nil", got, err)
6382 }
6383 })).ts
6384
6385 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6386 if err != nil {
6387 t.Fatal(err)
6388 }
6389 defer cn.Close()
6390 buf := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n")
6391 buf = append(buf, []byte(needle)...)
6392 if _, err := cn.Write(buf); err != nil {
6393 t.Fatal(err)
6394 }
6395
6396 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6397 t.Fatal(err)
6398 }
6399 <-done
6400 }
6401
6402
6403
6404
6405 func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
6406 run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode})
6407 }
6408 func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) {
6409 if runtime.GOOS == "plan9" {
6410 t.Skip("skipping test; see https://golang.org/issue/18657")
6411 }
6412 done := make(chan struct{})
6413 const size = 8 << 10
6414 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6415 defer close(done)
6416
6417 conn, buf, err := w.(Hijacker).Hijack()
6418 if err != nil {
6419 t.Error(err)
6420 return
6421 }
6422 defer conn.Close()
6423 slurp, err := io.ReadAll(buf.Reader)
6424 if err != nil {
6425 t.Errorf("Copy: %v", err)
6426 }
6427 allX := true
6428 for _, v := range slurp {
6429 if v != 'x' {
6430 allX = false
6431 }
6432 }
6433 if len(slurp) != size {
6434 t.Errorf("read %d; want %d", len(slurp), size)
6435 } else if !allX {
6436 t.Errorf("read %q; want %d 'x'", slurp, size)
6437 }
6438 })).ts
6439
6440 cn, err := net.Dial("tcp", ts.Listener.Addr().String())
6441 if err != nil {
6442 t.Fatal(err)
6443 }
6444 defer cn.Close()
6445 if _, err := fmt.Fprintf(cn, "GET / HTTP/1.1\r\nHost: e.com\r\n\r\n%s",
6446 strings.Repeat("x", size)); err != nil {
6447 t.Fatal(err)
6448 }
6449 if err := cn.(*net.TCPConn).CloseWrite(); err != nil {
6450 t.Fatal(err)
6451 }
6452
6453 <-done
6454 }
6455
6456
6457 func TestServerValidatesMethod(t *testing.T) {
6458 tests := []struct {
6459 method string
6460 want int
6461 }{
6462 {"GET", 200},
6463 {"GE(T", 400},
6464 }
6465 for _, tt := range tests {
6466 conn := newTestConn()
6467 io.WriteString(&conn.readBuf, tt.method+" / HTTP/1.1\r\nHost: foo.example\r\n\r\n")
6468
6469 ln := &oneConnListener{conn}
6470 go Serve(ln, serve(200))
6471 <-conn.closec
6472 res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
6473 if err != nil {
6474 t.Errorf("For %s, ReadResponse: %v", tt.method, res)
6475 continue
6476 }
6477 if res.StatusCode != tt.want {
6478 t.Errorf("For %s, Status = %d; want %d", tt.method, res.StatusCode, tt.want)
6479 }
6480 }
6481 }
6482
6483
6484 type eofListenerNotComparable []int
6485
6486 func (eofListenerNotComparable) Accept() (net.Conn, error) { return nil, io.EOF }
6487 func (eofListenerNotComparable) Addr() net.Addr { return nil }
6488 func (eofListenerNotComparable) Close() error { return nil }
6489
6490
6491 func TestServerListenNotComparableListener(t *testing.T) {
6492 var s Server
6493 s.Serve(make(eofListenerNotComparable, 1))
6494 }
6495
6496
6497 type countCloseListener struct {
6498 net.Listener
6499 closes int32
6500 }
6501
6502 func (p *countCloseListener) Close() error {
6503 var err error
6504 if n := atomic.AddInt32(&p.closes, 1); n == 1 && p.Listener != nil {
6505 err = p.Listener.Close()
6506 }
6507 return err
6508 }
6509
6510
6511 func TestServerCloseListenerOnce(t *testing.T) {
6512 setParallel(t)
6513 defer afterTest(t)
6514
6515 ln := newLocalListener(t)
6516 defer ln.Close()
6517
6518 cl := &countCloseListener{Listener: ln}
6519 server := &Server{}
6520 sdone := make(chan bool, 1)
6521
6522 go func() {
6523 server.Serve(cl)
6524 sdone <- true
6525 }()
6526 time.Sleep(10 * time.Millisecond)
6527 server.Shutdown(context.Background())
6528 ln.Close()
6529 <-sdone
6530
6531 nclose := atomic.LoadInt32(&cl.closes)
6532 if nclose != 1 {
6533 t.Errorf("Close calls = %v; want 1", nclose)
6534 }
6535 }
6536
6537
6538 func TestServerShutdownThenServe(t *testing.T) {
6539 var srv Server
6540 cl := &countCloseListener{Listener: nil}
6541 srv.Shutdown(context.Background())
6542 got := srv.Serve(cl)
6543 if got != ErrServerClosed {
6544 t.Errorf("Serve err = %v; want ErrServerClosed", got)
6545 }
6546 nclose := atomic.LoadInt32(&cl.closes)
6547 if nclose != 1 {
6548 t.Errorf("Close calls = %v; want 1", nclose)
6549 }
6550 }
6551
6552
6553 func TestStripPortFromHost(t *testing.T) {
6554 mux := NewServeMux()
6555
6556 mux.HandleFunc("example.com/", func(w ResponseWriter, r *Request) {
6557 fmt.Fprintf(w, "OK")
6558 })
6559 mux.HandleFunc("example.com:9000/", func(w ResponseWriter, r *Request) {
6560 fmt.Fprintf(w, "uh-oh!")
6561 })
6562
6563 req := httptest.NewRequest("GET", "http://example.com:9000/", nil)
6564 rw := httptest.NewRecorder()
6565
6566 mux.ServeHTTP(rw, req)
6567
6568 response := rw.Body.String()
6569 if response != "OK" {
6570 t.Errorf("Response gotten was %q", response)
6571 }
6572 }
6573
6574 func TestServerContexts(t *testing.T) { run(t, testServerContexts, http3SkippedMode) }
6575 func testServerContexts(t *testing.T, mode testMode) {
6576 type baseKey struct{}
6577 type connKey struct{}
6578 ch := make(chan context.Context, 1)
6579 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6580 ch <- r.Context()
6581 }), func(ts *httptest.Server) {
6582 ts.Config.BaseContext = func(ln net.Listener) context.Context {
6583 if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
6584 t.Errorf("unexpected onceClose listener type %T", ln)
6585 }
6586 return context.WithValue(context.Background(), baseKey{}, "base")
6587 }
6588 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6589 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6590 t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
6591 }
6592 return context.WithValue(ctx, connKey{}, "conn")
6593 }
6594 }).ts
6595 res, err := ts.Client().Get(ts.URL)
6596 if err != nil {
6597 t.Fatal(err)
6598 }
6599 res.Body.Close()
6600 ctx := <-ch
6601 if got, want := ctx.Value(baseKey{}), "base"; got != want {
6602 t.Errorf("base context key = %#v; want %q", got, want)
6603 }
6604 if got, want := ctx.Value(connKey{}), "conn"; got != want {
6605 t.Errorf("conn context key = %#v; want %q", got, want)
6606 }
6607 }
6608
6609
6610 func TestConnContextNotModifyingAllContexts(t *testing.T) {
6611 run(t, testConnContextNotModifyingAllContexts)
6612 }
6613 func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) {
6614 type connKey struct{}
6615 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6616 rw.Header().Set("Connection", "close")
6617 }), func(ts *httptest.Server) {
6618 ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
6619 if got := ctx.Value(connKey{}); got != nil {
6620 t.Errorf("in ConnContext, unexpected context key = %#v", got)
6621 }
6622 return context.WithValue(ctx, connKey{}, "conn")
6623 }
6624 }).ts
6625
6626 var res *Response
6627 var err error
6628
6629 res, err = ts.Client().Get(ts.URL)
6630 if err != nil {
6631 t.Fatal(err)
6632 }
6633 res.Body.Close()
6634
6635 res, err = ts.Client().Get(ts.URL)
6636 if err != nil {
6637 t.Fatal(err)
6638 }
6639 res.Body.Close()
6640 }
6641
6642
6643
6644 func TestUnsupportedTransferEncodingsReturn501(t *testing.T) {
6645 run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode})
6646 }
6647 func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) {
6648 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6649 w.Write([]byte("Hello, World!"))
6650 })).ts
6651
6652 serverURL, err := url.Parse(cst.URL)
6653 if err != nil {
6654 t.Fatalf("Failed to parse server URL: %v", err)
6655 }
6656
6657 unsupportedTEs := []string{
6658 "fugazi",
6659 "foo-bar",
6660 "unknown",
6661 `" chunked"`,
6662 }
6663
6664 for _, badTE := range unsupportedTEs {
6665 http1ReqBody := fmt.Sprintf(""+
6666 "POST / HTTP/1.1\r\nConnection: close\r\n"+
6667 "Host: localhost\r\nTransfer-Encoding: %s\r\n\r\n", badTE)
6668
6669 gotBody, err := fetchWireResponse(serverURL.Host, []byte(http1ReqBody))
6670 if err != nil {
6671 t.Errorf("%q. unexpected error: %v", badTE, err)
6672 continue
6673 }
6674
6675 wantBody := fmt.Sprintf("" +
6676 "HTTP/1.1 501 Not Implemented\r\nContent-Type: text/plain; charset=utf-8\r\n" +
6677 "Connection: close\r\n\r\nUnsupported transfer encoding")
6678
6679 if string(gotBody) != wantBody {
6680 t.Errorf("%q. body\ngot\n%q\nwant\n%q", badTE, gotBody, wantBody)
6681 }
6682 }
6683 }
6684
6685
6686 func TestContentEncodingNoSniffing(t *testing.T) {
6687 run(t, testContentEncodingNoSniffing, http3SkippedMode)
6688 }
6689 func testContentEncodingNoSniffing(t *testing.T, mode testMode) {
6690 type setting struct {
6691 name string
6692 body []byte
6693
6694
6695
6696
6697 contentEncoding any
6698 wantContentType string
6699 }
6700
6701 settings := []*setting{
6702 {
6703 name: "gzip content-encoding, gzipped",
6704 contentEncoding: "application/gzip",
6705 wantContentType: "",
6706 body: func() []byte {
6707 buf := new(bytes.Buffer)
6708 gzw := gzip.NewWriter(buf)
6709 gzw.Write([]byte("doctype html><p>Hello</p>"))
6710 gzw.Close()
6711 return buf.Bytes()
6712 }(),
6713 },
6714 {
6715 name: "zlib content-encoding, zlibbed",
6716 contentEncoding: "application/zlib",
6717 wantContentType: "",
6718 body: func() []byte {
6719 buf := new(bytes.Buffer)
6720 zw := zlib.NewWriter(buf)
6721 zw.Write([]byte("doctype html><p>Hello</p>"))
6722 zw.Close()
6723 return buf.Bytes()
6724 }(),
6725 },
6726 {
6727 name: "no content-encoding",
6728 wantContentType: "application/x-gzip",
6729 body: func() []byte {
6730 buf := new(bytes.Buffer)
6731 gzw := gzip.NewWriter(buf)
6732 gzw.Write([]byte("doctype html><p>Hello</p>"))
6733 gzw.Close()
6734 return buf.Bytes()
6735 }(),
6736 },
6737 {
6738 name: "phony content-encoding",
6739 contentEncoding: "foo/bar",
6740 body: []byte("doctype html><p>Hello</p>"),
6741 },
6742 {
6743 name: "empty but set content-encoding",
6744 contentEncoding: "",
6745 wantContentType: "audio/mpeg",
6746 body: []byte("ID3"),
6747 },
6748 }
6749
6750 for _, tt := range settings {
6751 t.Run(tt.name, func(t *testing.T) {
6752 cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) {
6753 if tt.contentEncoding != nil {
6754 rw.Header().Set("Content-Encoding", tt.contentEncoding.(string))
6755 }
6756 rw.Write(tt.body)
6757 }))
6758
6759 res, err := cst.c.Get(cst.ts.URL)
6760 if err != nil {
6761 t.Fatalf("Failed to fetch URL: %v", err)
6762 }
6763 defer res.Body.Close()
6764
6765 if g, w := res.Header.Get("Content-Encoding"), tt.contentEncoding; g != w {
6766 if w != nil {
6767 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
6768 } else if g != "" {
6769 t.Errorf("Unexpected Content-Encoding %q", g)
6770 }
6771 }
6772
6773 if g, w := res.Header.Get("Content-Type"), tt.wantContentType; g != w {
6774 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
6775 }
6776 })
6777 }
6778 }
6779
6780
6781
6782 func TestTimeoutHandlerSuperfluousLogs(t *testing.T) {
6783 run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode})
6784 }
6785 func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) {
6786 if testing.Short() {
6787 t.Skip("skipping in short mode")
6788 }
6789
6790 pc, curFile, _, _ := runtime.Caller(0)
6791 curFileBaseName := filepath.Base(curFile)
6792 testFuncName := runtime.FuncForPC(pc).Name()
6793
6794 timeoutMsg := "timed out here!"
6795
6796 tests := []struct {
6797 name string
6798 mustTimeout bool
6799 wantResp string
6800 }{
6801 {
6802 name: "return before timeout",
6803 wantResp: "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n",
6804 },
6805 {
6806 name: "return after timeout",
6807 mustTimeout: true,
6808 wantResp: fmt.Sprintf("HTTP/1.1 503 Service Unavailable\r\nContent-Length: %d\r\n\r\n%s",
6809 len(timeoutMsg), timeoutMsg),
6810 },
6811 }
6812
6813 for _, tt := range tests {
6814 t.Run(tt.name, func(t *testing.T) {
6815 exitHandler := make(chan bool, 1)
6816 defer close(exitHandler)
6817 lastLine := make(chan int, 1)
6818
6819 sh := HandlerFunc(func(w ResponseWriter, r *Request) {
6820 w.WriteHeader(404)
6821 w.WriteHeader(404)
6822 w.WriteHeader(404)
6823 w.WriteHeader(404)
6824 _, _, line, _ := runtime.Caller(0)
6825 lastLine <- line
6826 <-exitHandler
6827 })
6828
6829 if !tt.mustTimeout {
6830 exitHandler <- true
6831 }
6832
6833 logBuf := new(strings.Builder)
6834 srvLog := log.New(logBuf, "", 0)
6835
6836 dur := 20 * time.Millisecond
6837 if !tt.mustTimeout {
6838
6839 dur = 10 * time.Second
6840 }
6841 th := TimeoutHandler(sh, dur, timeoutMsg)
6842 cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog))
6843 defer cst.close()
6844
6845 res, err := cst.c.Get(cst.ts.URL)
6846 if err != nil {
6847 t.Fatalf("Unexpected error: %v", err)
6848 }
6849
6850
6851
6852 res.Header.Del("Date")
6853 res.Header.Del("Content-Type")
6854
6855
6856 blob, _ := httputil.DumpResponse(res, true)
6857 if g, w := string(blob), tt.wantResp; g != w {
6858 t.Errorf("Response mismatch\nGot\n%q\n\nWant\n%q", g, w)
6859 }
6860
6861
6862
6863 logEntries := strings.Split(strings.TrimSpace(logBuf.String()), "\n")
6864 if g, w := len(logEntries), 3; g != w {
6865 blob, _ := json.MarshalIndent(logEntries, "", " ")
6866 t.Fatalf("Server logs count mismatch\ngot %d, want %d\n\nGot\n%s\n", g, w, blob)
6867 }
6868
6869 lastSpuriousLine := <-lastLine
6870 firstSpuriousLine := lastSpuriousLine - 3
6871
6872
6873 for i, logEntry := range logEntries {
6874 wantLine := firstSpuriousLine + i
6875 pat := fmt.Sprintf("^http: superfluous response.WriteHeader call from %s.func\\d+.\\d+ \\(%s:%d\\)$",
6876 testFuncName, curFileBaseName, wantLine)
6877 re := regexp.MustCompile(pat)
6878 if !re.MatchString(logEntry) {
6879 t.Errorf("Log entry mismatch\n\t%s\ndoes not match\n\t%s", logEntry, pat)
6880 }
6881 }
6882 })
6883 }
6884 }
6885
6886
6887
6888
6889 func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) {
6890 conn, err := net.Dial("tcp", host)
6891 if err != nil {
6892 return nil, err
6893 }
6894 defer conn.Close()
6895
6896 if _, err := conn.Write(http1ReqBody); err != nil {
6897 return nil, err
6898 }
6899 return io.ReadAll(conn)
6900 }
6901
6902 func BenchmarkResponseStatusLine(b *testing.B) {
6903 b.ReportAllocs()
6904 b.RunParallel(func(pb *testing.PB) {
6905 bw := bufio.NewWriter(io.Discard)
6906 var buf3 [3]byte
6907 for pb.Next() {
6908 Export_writeStatusLine(bw, true, 200, buf3[:])
6909 }
6910 })
6911 }
6912
6913 func TestDisableKeepAliveUpgrade(t *testing.T) {
6914 run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode})
6915 }
6916 func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) {
6917 if testing.Short() {
6918 t.Skip("skipping in short mode")
6919 }
6920
6921 s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6922 w.Header().Set("Connection", "Upgrade")
6923 w.Header().Set("Upgrade", "someProto")
6924 w.WriteHeader(StatusSwitchingProtocols)
6925 c, buf, err := w.(Hijacker).Hijack()
6926 if err != nil {
6927 return
6928 }
6929 defer c.Close()
6930
6931
6932
6933 io.Copy(c, buf)
6934 }), func(ts *httptest.Server) {
6935 ts.Config.SetKeepAlivesEnabled(false)
6936 }).ts
6937
6938 cl := s.Client()
6939 cl.Transport.(*Transport).DisableKeepAlives = true
6940
6941 resp, err := cl.Get(s.URL)
6942 if err != nil {
6943 t.Fatalf("failed to perform request: %v", err)
6944 }
6945 defer resp.Body.Close()
6946
6947 if resp.StatusCode != StatusSwitchingProtocols {
6948 t.Fatalf("unexpected status code: %v", resp.StatusCode)
6949 }
6950
6951 rwc, ok := resp.Body.(io.ReadWriteCloser)
6952 if !ok {
6953 t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
6954 }
6955
6956 _, err = rwc.Write([]byte("hello"))
6957 if err != nil {
6958 t.Fatalf("failed to write to body: %v", err)
6959 }
6960
6961 b := make([]byte, 5)
6962 _, err = io.ReadFull(rwc, b)
6963 if err != nil {
6964 t.Fatalf("failed to read from body: %v", err)
6965 }
6966
6967 if string(b) != "hello" {
6968 t.Fatalf("unexpected value read from body:\ngot: %q\nwant: %q", b, "hello")
6969 }
6970 }
6971
6972 type tlogWriter struct{ t *testing.T }
6973
6974 func (w tlogWriter) Write(p []byte) (int, error) {
6975 w.t.Log(string(p))
6976 return len(p), nil
6977 }
6978
6979 func TestWriteHeaderSwitchingProtocols(t *testing.T) {
6980 run(t, testWriteHeaderSwitchingProtocols, []testMode{http1Mode})
6981 }
6982 func testWriteHeaderSwitchingProtocols(t *testing.T, mode testMode) {
6983 const wantBody = "want"
6984 const wantUpgrade = "someProto"
6985 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
6986 w.Header().Set("Connection", "Upgrade")
6987 w.Header().Set("Upgrade", wantUpgrade)
6988 w.WriteHeader(StatusSwitchingProtocols)
6989 NewResponseController(w).Flush()
6990
6991
6992 w.WriteHeader(200)
6993 if _, err := w.Write([]byte("x")); err == nil {
6994 t.Errorf("Write to body after 101 Switching Protocols unexpectedly succeeded")
6995 }
6996
6997 c, _, err := NewResponseController(w).Hijack()
6998 if err != nil {
6999 t.Errorf("Hijack: %v", err)
7000 return
7001 }
7002 defer c.Close()
7003 if _, err := c.Write([]byte(wantBody)); err != nil {
7004 t.Errorf("Write to hijacked body: %v", err)
7005 }
7006 }), func(ts *httptest.Server) {
7007
7008 ts.Config.ErrorLog = log.New(tlogWriter{t}, "log: ", 0)
7009 }).ts
7010
7011 conn, err := net.Dial("tcp", ts.Listener.Addr().String())
7012 if err != nil {
7013 t.Fatalf("net.Dial: %v", err)
7014 }
7015 _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
7016 if err != nil {
7017 t.Fatalf("conn.Write: %v", err)
7018 }
7019 defer conn.Close()
7020
7021 r := bufio.NewReader(conn)
7022 res, err := ReadResponse(r, &Request{Method: "GET"})
7023 if err != nil {
7024 t.Fatal("ReadResponse error:", err)
7025 }
7026 if res.StatusCode != StatusSwitchingProtocols {
7027 t.Errorf("Response StatusCode=%v, want 101", res.StatusCode)
7028 }
7029 if got := res.Header.Get("Upgrade"); got != wantUpgrade {
7030 t.Errorf("Response Upgrade header = %q, want %q", got, wantUpgrade)
7031 }
7032 body, err := io.ReadAll(r)
7033 if err != nil {
7034 t.Error(err)
7035 }
7036 if string(body) != wantBody {
7037 t.Errorf("Response body = %q, want %q", string(body), wantBody)
7038 }
7039 }
7040
7041 func TestMuxRedirectRelative(t *testing.T) {
7042 setParallel(t)
7043 req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET http://example.com HTTP/1.1\r\nHost: test\r\n\r\n")))
7044 if err != nil {
7045 t.Errorf("%s", err)
7046 }
7047 mux := NewServeMux()
7048 resp := httptest.NewRecorder()
7049 mux.ServeHTTP(resp, req)
7050 if got, want := resp.Header().Get("Location"), "/"; got != want {
7051 t.Errorf("Location header expected %q; got %q", want, got)
7052 }
7053 if got, want := resp.Code, StatusTemporaryRedirect; got != want {
7054 t.Errorf("Expected response code %d; got %d", want, got)
7055 }
7056 }
7057
7058
7059 func TestQuerySemicolon(t *testing.T) {
7060 t.Cleanup(func() { afterTest(t) })
7061
7062 tests := []struct {
7063 query string
7064 xNoSemicolons string
7065 xWithSemicolons string
7066 expectParseFormErr bool
7067 }{
7068 {"?a=1;x=bad&x=good", "good", "bad", true},
7069 {"?a=1;b=bad&x=good", "good", "good", true},
7070 {"?a=1%3Bx=bad&x=good%3B", "good;", "good;", false},
7071 {"?a=1;x=good;x=bad", "", "good", true},
7072 }
7073
7074 run(t, func(t *testing.T, mode testMode) {
7075 for _, tt := range tests {
7076 t.Run(tt.query+"/allow=false", func(t *testing.T) {
7077 allowSemicolons := false
7078 testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.expectParseFormErr)
7079 })
7080 t.Run(tt.query+"/allow=true", func(t *testing.T) {
7081 allowSemicolons, expectParseFormErr := true, false
7082 testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectParseFormErr)
7083 })
7084 }
7085 })
7086 }
7087
7088 func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectParseFormErr bool) {
7089 writeBackX := func(w ResponseWriter, r *Request) {
7090 x := r.URL.Query().Get("x")
7091 if expectParseFormErr {
7092 if err := r.ParseForm(); err == nil || !strings.Contains(err.Error(), "semicolon") {
7093 t.Errorf("expected error mentioning semicolons from ParseForm, got %v", err)
7094 }
7095 } else {
7096 if err := r.ParseForm(); err != nil {
7097 t.Errorf("expected no error from ParseForm, got %v", err)
7098 }
7099 }
7100 if got := r.FormValue("x"); x != got {
7101 t.Errorf("got %q from FormValue, want %q", got, x)
7102 }
7103 fmt.Fprintf(w, "%s", x)
7104 }
7105
7106 h := Handler(HandlerFunc(writeBackX))
7107 if allowSemicolons {
7108 h = AllowQuerySemicolons(h)
7109 }
7110
7111 logBuf := &strings.Builder{}
7112 ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) {
7113 ts.Config.ErrorLog = log.New(logBuf, "", 0)
7114 }).ts
7115
7116 req, _ := NewRequest("GET", ts.URL+query, nil)
7117 res, err := ts.Client().Do(req)
7118 if err != nil {
7119 t.Fatal(err)
7120 }
7121 slurp, _ := io.ReadAll(res.Body)
7122 res.Body.Close()
7123 if got, want := res.StatusCode, 200; got != want {
7124 t.Errorf("Status = %d; want = %d", got, want)
7125 }
7126 if got, want := string(slurp), wantX; got != want {
7127 t.Errorf("Body = %q; want = %q", got, want)
7128 }
7129 }
7130
7131 func TestMaxBytesHandler(t *testing.T) {
7132
7133 defer afterTest(t)
7134
7135 for _, maxSize := range []int64{100, 1_000, 1_000_000} {
7136 for _, requestSize := range []int64{100, 1_000, 1_000_000} {
7137 t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
7138 func(t *testing.T) {
7139 run(t, func(t *testing.T, mode testMode) {
7140 testMaxBytesHandler(t, mode, maxSize, requestSize)
7141 }, testNotParallel)
7142 })
7143 }
7144 }
7145 }
7146
7147 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
7148 runTimeSensitiveTest(t, []time.Duration{
7149 1 * time.Millisecond,
7150 5 * time.Millisecond,
7151 10 * time.Millisecond,
7152 50 * time.Millisecond,
7153 100 * time.Millisecond,
7154 500 * time.Millisecond,
7155 time.Second,
7156 5 * time.Second,
7157 }, func(t *testing.T, timeout time.Duration) error {
7158 SetRSTAvoidanceDelay(t, timeout)
7159 t.Logf("set RST avoidance delay to %v", timeout)
7160
7161 var (
7162 mu sync.Mutex
7163 handlerN int64
7164 handlerErr error
7165 )
7166 echo := HandlerFunc(func(w ResponseWriter, r *Request) {
7167 mu.Lock()
7168 defer mu.Unlock()
7169 var buf bytes.Buffer
7170 handlerN, handlerErr = io.Copy(&buf, r.Body)
7171 io.Copy(w, &buf)
7172 })
7173
7174 cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
7175
7176
7177 defer cst.close()
7178 ts := cst.ts
7179 c := ts.Client()
7180
7181 body := strings.Repeat("a", int(requestSize))
7182 var wg sync.WaitGroup
7183 defer wg.Wait()
7184 getBody := func() (io.ReadCloser, error) {
7185 wg.Add(1)
7186 body := &wgReadCloser{
7187 Reader: strings.NewReader(body),
7188 wg: &wg,
7189 }
7190 return body, nil
7191 }
7192 reqBody, _ := getBody()
7193 req, err := NewRequest("POST", ts.URL, reqBody)
7194 if err != nil {
7195 reqBody.Close()
7196 t.Fatal(err)
7197 }
7198 req.ContentLength = int64(len(body))
7199 req.GetBody = getBody
7200 req.Header.Set("Content-Type", "text/plain")
7201
7202 var buf strings.Builder
7203 res, err := c.Do(req)
7204 if err != nil {
7205 return fmt.Errorf("unexpected connection error: %v", err)
7206 } else {
7207 _, err = io.Copy(&buf, res.Body)
7208 res.Body.Close()
7209 if err != nil {
7210 return fmt.Errorf("unexpected read error: %v", err)
7211 }
7212 }
7213
7214
7215
7216
7217 mu.Lock()
7218 defer mu.Unlock()
7219 if handlerN > maxSize {
7220 t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
7221 }
7222 if requestSize > maxSize && handlerErr == nil {
7223 t.Error("expected error on handler side; got nil")
7224 }
7225 if requestSize <= maxSize {
7226 if handlerErr != nil {
7227 t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
7228 }
7229 if handlerN != requestSize {
7230 t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
7231 }
7232 }
7233 if buf.Len() != int(handlerN) {
7234 t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
7235 }
7236
7237 return nil
7238 })
7239 }
7240
7241 func TestEarlyHints(t *testing.T) {
7242 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7243 h := w.Header()
7244 h.Add("Link", "</style.css>; rel=preload; as=style")
7245 h.Add("Link", "</script.js>; rel=preload; as=script")
7246 w.WriteHeader(StatusEarlyHints)
7247
7248 h.Add("Link", "</foo.js>; rel=preload; as=script")
7249 w.WriteHeader(StatusEarlyHints)
7250
7251 w.Write([]byte("stuff"))
7252 }))
7253
7254 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7255 expected := "HTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 103 Early Hints\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\n\r\nHTTP/1.1 200 OK\r\nLink: </style.css>; rel=preload; as=style\r\nLink: </script.js>; rel=preload; as=script\r\nLink: </foo.js>; rel=preload; as=script\r\nDate: "
7256 if !strings.Contains(got, expected) {
7257 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7258 }
7259 }
7260 func TestProcessing(t *testing.T) {
7261 ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
7262 w.WriteHeader(StatusProcessing)
7263 w.Write([]byte("stuff"))
7264 }))
7265
7266 got := ht.rawResponse("GET / HTTP/1.1\nHost: golang.org")
7267 expected := "HTTP/1.1 102 Processing\r\n\r\nHTTP/1.1 200 OK\r\nDate: "
7268 if !strings.Contains(got, expected) {
7269 t.Errorf("unexpected response; got %q; should start by %q", got, expected)
7270 }
7271 }
7272
7273 func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup, http3SkippedMode) }
7274 func testParseFormCleanup(t *testing.T, mode testMode) {
7275 if mode == http2Mode {
7276 t.Skip("https://go.dev/issue/20253")
7277 }
7278
7279 const maxMemory = 1024
7280 const key = "file"
7281
7282 if runtime.GOOS == "windows" {
7283
7284 t.Skip("https://go.dev/issue/25965")
7285 }
7286
7287 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7288 r.ParseMultipartForm(maxMemory)
7289 f, _, err := r.FormFile(key)
7290 if err != nil {
7291 t.Errorf("r.FormFile(%q) = %v", key, err)
7292 return
7293 }
7294 of, ok := f.(*os.File)
7295 if !ok {
7296 t.Errorf("r.FormFile(%q) returned type %T, want *os.File", key, f)
7297 return
7298 }
7299 w.Write([]byte(of.Name()))
7300 }))
7301
7302 fBuf := new(bytes.Buffer)
7303 mw := multipart.NewWriter(fBuf)
7304 mf, err := mw.CreateFormFile(key, "myfile.txt")
7305 if err != nil {
7306 t.Fatal(err)
7307 }
7308 if _, err := mf.Write(bytes.Repeat([]byte("A"), maxMemory*2)); err != nil {
7309 t.Fatal(err)
7310 }
7311 if err := mw.Close(); err != nil {
7312 t.Fatal(err)
7313 }
7314 req, err := NewRequest("POST", cst.ts.URL, fBuf)
7315 if err != nil {
7316 t.Fatal(err)
7317 }
7318 req.Header.Set("Content-Type", mw.FormDataContentType())
7319 res, err := cst.c.Do(req)
7320 if err != nil {
7321 t.Fatal(err)
7322 }
7323 defer res.Body.Close()
7324 fname, err := io.ReadAll(res.Body)
7325 if err != nil {
7326 t.Fatal(err)
7327 }
7328 cst.close()
7329 if _, err := os.Stat(string(fname)); !errors.Is(err, os.ErrNotExist) {
7330 t.Errorf("file %q exists after HTTP handler returned", string(fname))
7331 }
7332 }
7333
7334 func TestHeadBody(t *testing.T) {
7335 const identityMode = false
7336 const chunkedMode = true
7337 run(t, func(t *testing.T, mode testMode) {
7338 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") })
7339 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") })
7340 })
7341 }
7342
7343 func TestGetBody(t *testing.T) {
7344 const identityMode = false
7345 const chunkedMode = true
7346 run(t, func(t *testing.T, mode testMode) {
7347 t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") })
7348 t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") })
7349 })
7350 }
7351
7352 func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) {
7353 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7354 b, err := io.ReadAll(r.Body)
7355 if err != nil {
7356 t.Errorf("server reading body: %v", err)
7357 return
7358 }
7359 w.Header().Set("X-Request-Body", string(b))
7360 w.Header().Set("Content-Length", "0")
7361 }))
7362 defer cst.close()
7363 for _, reqBody := range []string{
7364 "",
7365 "",
7366 "request_body",
7367 "",
7368 } {
7369 var bodyReader io.Reader
7370 if reqBody != "" {
7371 bodyReader = strings.NewReader(reqBody)
7372 if chunked {
7373 bodyReader = bufio.NewReader(bodyReader)
7374 }
7375 }
7376 req, err := NewRequest(method, cst.ts.URL, bodyReader)
7377 if err != nil {
7378 t.Fatal(err)
7379 }
7380 res, err := cst.c.Do(req)
7381 if err != nil {
7382 t.Fatal(err)
7383 }
7384 res.Body.Close()
7385 if got, want := res.StatusCode, 200; got != want {
7386 t.Errorf("%v request with %d-byte body: StatusCode = %v, want %v", method, len(reqBody), got, want)
7387 }
7388 if got, want := res.Header.Get("X-Request-Body"), reqBody; got != want {
7389 t.Errorf("%v request with %d-byte body: handler read body %q, want %q", method, len(reqBody), got, want)
7390 }
7391 }
7392 }
7393
7394
7395
7396 func TestDisableContentLength(t *testing.T) { run(t, testDisableContentLength) }
7397 func testDisableContentLength(t *testing.T, mode testMode) {
7398 noCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7399 w.Header()["Content-Length"] = nil
7400 fmt.Fprintf(w, "OK")
7401 }))
7402
7403 res, err := noCL.c.Get(noCL.ts.URL)
7404 if err != nil {
7405 t.Fatal(err)
7406 }
7407 if got, haveCL := res.Header["Content-Length"]; haveCL {
7408 t.Errorf("Unexpected Content-Length: %q", got)
7409 }
7410 if err := res.Body.Close(); err != nil {
7411 t.Fatal(err)
7412 }
7413
7414 withCL := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7415 fmt.Fprintf(w, "OK")
7416 }))
7417
7418 res, err = withCL.c.Get(withCL.ts.URL)
7419 if err != nil {
7420 t.Fatal(err)
7421 }
7422
7423 if got := res.Header.Get("Content-Length"); got != "2" && mode != http3Mode {
7424 t.Errorf("Content-Length: %q; want 2", got)
7425 }
7426 if err := res.Body.Close(); err != nil {
7427 t.Fatal(err)
7428 }
7429 }
7430
7431 func TestErrorContentLength(t *testing.T) { run(t, testErrorContentLength) }
7432 func testErrorContentLength(t *testing.T, mode testMode) {
7433 const errorBody = "an error occurred"
7434 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7435 w.Header().Set("Content-Length", "1000")
7436 Error(w, errorBody, 400)
7437 }))
7438 res, err := cst.c.Get(cst.ts.URL)
7439 if err != nil {
7440 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7441 }
7442 defer res.Body.Close()
7443 body, err := io.ReadAll(res.Body)
7444 if err != nil {
7445 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7446 }
7447 if string(body) != errorBody+"\n" {
7448 t.Fatalf("read body: %q, want %q", string(body), errorBody)
7449 }
7450 }
7451
7452 func TestError(t *testing.T) {
7453 w := httptest.NewRecorder()
7454 w.Header().Set("Content-Length", "1")
7455 w.Header().Set("X-Content-Type-Options", "scratch and sniff")
7456 w.Header().Set("Other", "foo")
7457 Error(w, "oops", 432)
7458
7459 h := w.Header()
7460 for _, hdr := range []string{"Content-Length"} {
7461 if v, ok := h[hdr]; ok {
7462 t.Errorf("%s: %q, want not present", hdr, v)
7463 }
7464 }
7465 if v := h.Get("Content-Type"); v != "text/plain; charset=utf-8" {
7466 t.Errorf("Content-Type: %q, want %q", v, "text/plain; charset=utf-8")
7467 }
7468 if v := h.Get("X-Content-Type-Options"); v != "nosniff" {
7469 t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff")
7470 }
7471 }
7472
7473 func TestServerReadAfterWriteHeader100Continue(t *testing.T) {
7474 run(t, testServerReadAfterWriteHeader100Continue)
7475 }
7476 func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) {
7477 t.Skip("https://go.dev/issue/67555")
7478 body := []byte("body")
7479 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7480 w.WriteHeader(200)
7481 NewResponseController(w).Flush()
7482 io.ReadAll(r.Body)
7483 w.Write(body)
7484 }), func(tr *Transport) {
7485 tr.ExpectContinueTimeout = 24 * time.Hour
7486 })
7487
7488 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7489 req.Header.Set("Expect", "100-continue")
7490 res, err := cst.c.Do(req)
7491 if err != nil {
7492 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7493 }
7494 defer res.Body.Close()
7495 got, err := io.ReadAll(res.Body)
7496 if err != nil {
7497 t.Fatalf("io.ReadAll(res.Body) = %v", err)
7498 }
7499 if !bytes.Equal(got, body) {
7500 t.Fatalf("response body = %q, want %q", got, body)
7501 }
7502 }
7503
7504 func TestServerReadAfterHandlerDone100Continue(t *testing.T) {
7505 run(t, testServerReadAfterHandlerDone100Continue)
7506 }
7507 func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) {
7508 t.Skip("https://go.dev/issue/67555")
7509 readyc := make(chan struct{})
7510 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7511 go func() {
7512 <-readyc
7513 io.ReadAll(r.Body)
7514 <-readyc
7515 }()
7516 }), func(tr *Transport) {
7517 tr.ExpectContinueTimeout = 24 * time.Hour
7518 })
7519
7520 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7521 req.Header.Set("Expect", "100-continue")
7522 res, err := cst.c.Do(req)
7523 if err != nil {
7524 t.Fatalf("Get(%q) = %v", cst.ts.URL, err)
7525 }
7526 res.Body.Close()
7527 readyc <- struct{}{}
7528 readyc <- struct{}{}
7529 }
7530
7531 func TestServerReadAfterHandlerAbort100Continue(t *testing.T) {
7532 run(t, testServerReadAfterHandlerAbort100Continue)
7533 }
7534 func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) {
7535 t.Skip("https://go.dev/issue/67555")
7536 readyc := make(chan struct{})
7537 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7538 go func() {
7539 <-readyc
7540 io.ReadAll(r.Body)
7541 <-readyc
7542 }()
7543 panic(ErrAbortHandler)
7544 }), func(tr *Transport) {
7545 tr.ExpectContinueTimeout = 24 * time.Hour
7546 })
7547
7548 req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body"))
7549 req.Header.Set("Expect", "100-continue")
7550 res, err := cst.c.Do(req)
7551 if err == nil {
7552 res.Body.Close()
7553 }
7554 readyc <- struct{}{}
7555 readyc <- struct{}{}
7556 }
7557
7558 func TestInvalidChunkedBodies(t *testing.T) {
7559 for _, test := range []struct {
7560 name string
7561 b string
7562 }{{
7563 name: "bare LF in chunk size",
7564 b: "1\na\r\n0\r\n\r\n",
7565 }, {
7566 name: "bare LF at body end",
7567 b: "1\r\na\r\n0\r\n\n",
7568 }} {
7569 t.Run(test.name, func(t *testing.T) {
7570 reqc := make(chan error)
7571 ts := newClientServerTest(t, http1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
7572 got, err := io.ReadAll(r.Body)
7573 if err == nil {
7574 t.Logf("read body: %q", got)
7575 }
7576 reqc <- err
7577 })).ts
7578
7579 serverURL, err := url.Parse(ts.URL)
7580 if err != nil {
7581 t.Fatal(err)
7582 }
7583
7584 conn, err := net.Dial("tcp", serverURL.Host)
7585 if err != nil {
7586 t.Fatal(err)
7587 }
7588
7589 if _, err := conn.Write([]byte(
7590 "POST / HTTP/1.1\r\n" +
7591 "Host: localhost\r\n" +
7592 "Transfer-Encoding: chunked\r\n" +
7593 "Connection: close\r\n" +
7594 "\r\n" +
7595 test.b)); err != nil {
7596 t.Fatal(err)
7597 }
7598 conn.(*net.TCPConn).CloseWrite()
7599
7600 if err := <-reqc; err == nil {
7601 t.Errorf("server handler: io.ReadAll(r.Body) succeeded, want error")
7602 }
7603 })
7604 }
7605 }
7606
7607
7608 func TestServerTLSNextProtos(t *testing.T) {
7609 run(t, testServerTLSNextProtos, []testMode{https1Mode, http2Mode})
7610 }
7611 func testServerTLSNextProtos(t *testing.T, mode testMode) {
7612 CondSkipHTTP2(t)
7613
7614 cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
7615 if err != nil {
7616 t.Fatal(err)
7617 }
7618 leafCert, err := x509.ParseCertificate(cert.Certificate[0])
7619 if err != nil {
7620 t.Fatal(err)
7621 }
7622 certpool := x509.NewCertPool()
7623 certpool.AddCert(leafCert)
7624
7625 protos := new(Protocols)
7626 switch mode {
7627 case https1Mode:
7628 protos.SetHTTP1(true)
7629 case http2Mode:
7630 protos.SetHTTP2(true)
7631 }
7632
7633 wantNextProtos := []string{"http/1.1", "h2", "other"}
7634 nextProtos := slices.Clone(wantNextProtos)
7635
7636
7637 srv := &Server{
7638 TLSConfig: &tls.Config{
7639 Certificates: []tls.Certificate{cert},
7640 NextProtos: nextProtos,
7641 },
7642 Handler: HandlerFunc(func(w ResponseWriter, req *Request) {}),
7643 Protocols: protos,
7644 }
7645 tr := &Transport{
7646 TLSClientConfig: &tls.Config{
7647 RootCAs: certpool,
7648 NextProtos: nextProtos,
7649 },
7650 Protocols: protos,
7651 }
7652
7653 listener := newLocalListener(t)
7654 srvc := make(chan error, 1)
7655 go func() {
7656 srvc <- srv.ServeTLS(listener, "", "")
7657 }()
7658 t.Cleanup(func() {
7659 srv.Close()
7660 <-srvc
7661 })
7662
7663 client := &Client{Transport: tr}
7664 resp, err := client.Get("https://" + listener.Addr().String())
7665 if err != nil {
7666 t.Fatal(err)
7667 }
7668 resp.Body.Close()
7669
7670 if !slices.Equal(nextProtos, wantNextProtos) {
7671 t.Fatalf("after running test: original NextProtos slice = %v, want %v", nextProtos, wantNextProtos)
7672 }
7673 }
7674
7675
7676
7677 func TestServerHTTP2Disabled(t *testing.T) {
7678 synctest.Test(t, func(t *testing.T) {
7679 li := fakeNetListen()
7680 srv := &Server{}
7681 srv.Protocols = new(Protocols)
7682 srv.Protocols.SetHTTP1(true)
7683 go srv.ServeTLS(li, "", "")
7684 synctest.Wait()
7685 srv.Shutdown(t.Context())
7686 })
7687 }
7688
View as plain text