Source file
src/net/http/client_test.go
1
2
3
4
5
6
7 package http_test
8
9 import (
10 "bytes"
11 "context"
12 "crypto/tls"
13 "encoding/base64"
14 "errors"
15 "fmt"
16 "internal/testenv"
17 "io"
18 "log"
19 "net"
20 . "net/http"
21 "net/http/cookiejar"
22 "net/http/httptest"
23 "net/url"
24 "reflect"
25 "runtime"
26 "strconv"
27 "strings"
28 "sync"
29 "sync/atomic"
30 "testing"
31 "time"
32 )
33
// robotsTxtHandler serves a small robots.txt-style body and always sets a
// Last-Modified header, so both GET (body) and HEAD (header) can be checked.
var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
	hdr := w.Header()
	hdr.Set("Last-Modified", "sometime")
	fmt.Fprintf(w, "User-agent: go\nDisallow: /something/")
})
38
39
40
// pedanticReadAll reads r to EOF like io.ReadAll, but also enforces parts of
// the io.Reader contract:
//   - a Read must never return (0, nil);
//   - after a Read has returned io.EOF, the next Read must return (0, io.EOF).
//
// Reads go through a 64-byte buffer, so larger inputs take several calls.
// A non-EOF error is returned together with the bytes read so far.
func pedanticReadAll(r io.Reader) (b []byte, err error) {
	var chunk [64]byte
	for {
		n, rerr := r.Read(chunk[:])
		if n == 0 && rerr == nil {
			return nil, fmt.Errorf("Read: n=0 with err=nil")
		}
		b = append(b, chunk[:n]...)
		switch {
		case rerr == io.EOF:
			// A well-behaved reader keeps reporting EOF with no data.
			if n2, err2 := r.Read(chunk[:]); n2 != 0 || err2 != io.EOF {
				return nil, fmt.Errorf("Read: n=%d err=%#v after EOF", n2, err2)
			}
			return b, nil
		case rerr != nil:
			return b, rerr
		}
	}
}
62
63 func TestClient(t *testing.T) { run(t, testClient) }
64 func testClient(t *testing.T, mode testMode) {
65 ts := newClientServerTest(t, mode, robotsTxtHandler).ts
66
67 c := ts.Client()
68 r, err := c.Get(ts.URL)
69 var b []byte
70 if err == nil {
71 b, err = pedanticReadAll(r.Body)
72 r.Body.Close()
73 }
74 if err != nil {
75 t.Error(err)
76 } else if s := string(b); !strings.HasPrefix(s, "User-agent:") {
77 t.Errorf("Incorrect page body (did not begin with User-agent): %q", s)
78 }
79 }
80
81 func TestClientHead(t *testing.T) { run(t, testClientHead) }
82 func testClientHead(t *testing.T, mode testMode) {
83 cst := newClientServerTest(t, mode, robotsTxtHandler)
84 r, err := cst.c.Head(cst.ts.URL)
85 if err != nil {
86 t.Fatal(err)
87 }
88 if _, ok := r.Header["Last-Modified"]; !ok {
89 t.Error("Last-Modified header not found.")
90 }
91 }
92
// recordingTransport is a RoundTripper that remembers the most recent request
// and always fails with a "dummy impl" error. Tests use it to inspect the
// request a Client builds without any network I/O.
type recordingTransport struct {
	// req is the last request passed to RoundTrip.
	req *Request
}

// RoundTrip stores req and returns a nil Response with an error.
func (t *recordingTransport) RoundTrip(req *Request) (*Response, error) {
	t.req = req
	return nil, errors.New("dummy impl")
}
101
102 func TestGetRequestFormat(t *testing.T) {
103 setParallel(t)
104 defer afterTest(t)
105 tr := &recordingTransport{}
106 client := &Client{Transport: tr}
107 url := "http://dummy.faketld/"
108 client.Get(url)
109 if tr.req.Method != "GET" {
110 t.Errorf("expected method %q; got %q", "GET", tr.req.Method)
111 }
112 if tr.req.URL.String() != url {
113 t.Errorf("expected URL %q; got %q", url, tr.req.URL.String())
114 }
115 if tr.req.Header == nil {
116 t.Errorf("expected non-nil request Header")
117 }
118 }
119
120 func TestPostRequestFormat(t *testing.T) {
121 defer afterTest(t)
122 tr := &recordingTransport{}
123 client := &Client{Transport: tr}
124
125 url := "http://dummy.faketld/"
126 json := `{"key":"value"}`
127 b := strings.NewReader(json)
128 client.Post(url, "application/json", b)
129
130 if tr.req.Method != "POST" {
131 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
132 }
133 if tr.req.URL.String() != url {
134 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
135 }
136 if tr.req.Header == nil {
137 t.Fatalf("expected non-nil request Header")
138 }
139 if tr.req.Close {
140 t.Error("got Close true, want false")
141 }
142 if g, e := tr.req.ContentLength, int64(len(json)); g != e {
143 t.Errorf("got ContentLength %d, want %d", g, e)
144 }
145 }
146
147 func TestPostFormRequestFormat(t *testing.T) {
148 defer afterTest(t)
149 tr := &recordingTransport{}
150 client := &Client{Transport: tr}
151
152 urlStr := "http://dummy.faketld/"
153 form := make(url.Values)
154 form.Set("foo", "bar")
155 form.Add("foo", "bar2")
156 form.Set("bar", "baz")
157 client.PostForm(urlStr, form)
158
159 if tr.req.Method != "POST" {
160 t.Errorf("got method %q, want %q", tr.req.Method, "POST")
161 }
162 if tr.req.URL.String() != urlStr {
163 t.Errorf("got URL %q, want %q", tr.req.URL.String(), urlStr)
164 }
165 if tr.req.Header == nil {
166 t.Fatalf("expected non-nil request Header")
167 }
168 if g, e := tr.req.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; g != e {
169 t.Errorf("got Content-Type %q, want %q", g, e)
170 }
171 if tr.req.Close {
172 t.Error("got Close true, want false")
173 }
174
175 expectedBody := "foo=bar&foo=bar2&bar=baz"
176 expectedBody1 := "bar=baz&foo=bar&foo=bar2"
177 if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e {
178 t.Errorf("got ContentLength %d, want %d", g, e)
179 }
180 bodyb, err := io.ReadAll(tr.req.Body)
181 if err != nil {
182 t.Fatalf("ReadAll on req.Body: %v", err)
183 }
184 if g := string(bodyb); g != expectedBody && g != expectedBody1 {
185 t.Errorf("got body %q, want %q or %q", g, expectedBody, expectedBody1)
186 }
187 }
188
189 func TestClientRedirects(t *testing.T) { run(t, testClientRedirects) }
190 func testClientRedirects(t *testing.T, mode testMode) {
191 var ts *httptest.Server
192 ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
193 n, _ := strconv.Atoi(r.FormValue("n"))
194
195 if n == 7 {
196 if g, e := r.Referer(), ts.URL+"/?n=6"; e != g {
197 t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g)
198 }
199 }
200 if n < 15 {
201 Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect)
202 return
203 }
204 fmt.Fprintf(w, "n=%d", n)
205 })).ts
206
207 c := ts.Client()
208 _, err := c.Get(ts.URL)
209 if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
210 t.Errorf("with default client Get, expected error %q, got %q", e, g)
211 }
212
213
214 _, err = c.Head(ts.URL)
215 if e, g := `Head "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
216 t.Errorf("with default client Head, expected error %q, got %q", e, g)
217 }
218
219
220 greq, _ := NewRequest("GET", ts.URL, nil)
221 _, err = c.Do(greq)
222 if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
223 t.Errorf("with default client Do, expected error %q, got %q", e, g)
224 }
225
226
227 greq.Method = ""
228 _, err = c.Do(greq)
229 if e, g := `Get "/?n=10": stopped after 10 redirects`, fmt.Sprintf("%v", err); e != g {
230 t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g)
231 }
232
233 var checkErr error
234 var lastVia []*Request
235 var lastReq *Request
236 c.CheckRedirect = func(req *Request, via []*Request) error {
237 lastReq = req
238 lastVia = via
239 return checkErr
240 }
241 res, err := c.Get(ts.URL)
242 if err != nil {
243 t.Fatalf("Get error: %v", err)
244 }
245 res.Body.Close()
246 finalURL := res.Request.URL.String()
247 if e, g := "<nil>", fmt.Sprintf("%v", err); e != g {
248 t.Errorf("with custom client, expected error %q, got %q", e, g)
249 }
250 if !strings.HasSuffix(finalURL, "/?n=15") {
251 t.Errorf("expected final url to end in /?n=15; got url %q", finalURL)
252 }
253 if e, g := 15, len(lastVia); e != g {
254 t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
255 }
256
257
258 creq, _ := NewRequest("HEAD", ts.URL, nil)
259 cancel := make(chan struct{})
260 creq.Cancel = cancel
261 if _, err := c.Do(creq); err != nil {
262 t.Fatal(err)
263 }
264 if lastReq == nil {
265 t.Fatal("didn't see redirect")
266 }
267 if lastReq.Cancel != cancel {
268 t.Errorf("expected lastReq to have the cancel channel set on the initial req")
269 }
270
271 checkErr = errors.New("no redirects allowed")
272 res, err = c.Get(ts.URL)
273 if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
274 t.Errorf("with redirects forbidden, expected a *url.Error with our 'no redirects allowed' error inside; got %#v (%q)", err, err)
275 }
276 if res == nil {
277 t.Fatalf("Expected a non-nil Response on CheckRedirect failure (https://golang.org/issue/3795)")
278 }
279 res.Body.Close()
280 if res.Header.Get("Location") == "" {
281 t.Errorf("no Location header in Response")
282 }
283 }
284
285
286 func TestClientRedirectsContext(t *testing.T) { run(t, testClientRedirectsContext) }
287 func testClientRedirectsContext(t *testing.T, mode testMode) {
288 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
289 Redirect(w, r, "/", StatusTemporaryRedirect)
290 })).ts
291
292 ctx, cancel := context.WithCancel(context.Background())
293 c := ts.Client()
294 c.CheckRedirect = func(req *Request, via []*Request) error {
295 cancel()
296 select {
297 case <-req.Context().Done():
298 return nil
299 case <-time.After(5 * time.Second):
300 return errors.New("redirected request's context never expired after root request canceled")
301 }
302 }
303 req, _ := NewRequestWithContext(ctx, "GET", ts.URL, nil)
304 _, err := c.Do(req)
305 ue, ok := err.(*url.Error)
306 if !ok {
307 t.Fatalf("got error %T; want *url.Error", err)
308 }
309 if ue.Err != context.Canceled {
310 t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled)
311 }
312 }
313
// redirectTest is one row of a testRedirectsByMethod table.
// Field order matters: tables use positional composite literals.
type redirectTest struct {
	// suffix is the path and query appended to the test server URL.
	suffix string
	// want is the expected status code of the final response.
	want int
	// redirectBody is the body sent with the initial request.
	redirectBody string
}
319
320 func TestPostRedirects(t *testing.T) {
321 postRedirectTests := []redirectTest{
322 {"/", 200, "first"},
323 {"/?code=301&next=302", 200, "c301"},
324 {"/?code=302&next=302", 200, "c302"},
325 {"/?code=303&next=301", 200, "c303wc301"},
326 {"/?code=304", 304, "c304"},
327 {"/?code=305", 305, "c305"},
328 {"/?code=307&next=303,308,302", 200, "c307"},
329 {"/?code=308&next=302,301", 200, "c308"},
330 {"/?code=404", 404, "c404"},
331 }
332
333 wantSegments := []string{
334 `POST / "first"`,
335 `POST /?code=301&next=302 "c301"`,
336 `GET /?code=302 ""`,
337 `GET / ""`,
338 `POST /?code=302&next=302 "c302"`,
339 `GET /?code=302 ""`,
340 `GET / ""`,
341 `POST /?code=303&next=301 "c303wc301"`,
342 `GET /?code=301 ""`,
343 `GET / ""`,
344 `POST /?code=304 "c304"`,
345 `POST /?code=305 "c305"`,
346 `POST /?code=307&next=303,308,302 "c307"`,
347 `POST /?code=303&next=308,302 "c307"`,
348 `GET /?code=308&next=302 ""`,
349 `GET /?code=302 ""`,
350 `GET / ""`,
351 `POST /?code=308&next=302,301 "c308"`,
352 `POST /?code=302&next=301 "c308"`,
353 `GET /?code=301 ""`,
354 `GET / ""`,
355 `POST /?code=404 "c404"`,
356 }
357 want := strings.Join(wantSegments, "\n")
358 run(t, func(t *testing.T, mode testMode) {
359 testRedirectsByMethod(t, mode, "POST", postRedirectTests, want)
360 })
361 }
362
363 func TestDeleteRedirects(t *testing.T) {
364 deleteRedirectTests := []redirectTest{
365 {"/", 200, "first"},
366 {"/?code=301&next=302,308", 200, "c301"},
367 {"/?code=302&next=302", 200, "c302"},
368 {"/?code=303", 200, "c303"},
369 {"/?code=307&next=301,308,303,302,304", 304, "c307"},
370 {"/?code=308&next=307", 200, "c308"},
371 {"/?code=404", 404, "c404"},
372 }
373
374 wantSegments := []string{
375 `DELETE / "first"`,
376 `DELETE /?code=301&next=302,308 "c301"`,
377 `GET /?code=302&next=308 ""`,
378 `GET /?code=308 ""`,
379 `GET / ""`,
380 `DELETE /?code=302&next=302 "c302"`,
381 `GET /?code=302 ""`,
382 `GET / ""`,
383 `DELETE /?code=303 "c303"`,
384 `GET / ""`,
385 `DELETE /?code=307&next=301,308,303,302,304 "c307"`,
386 `DELETE /?code=301&next=308,303,302,304 "c307"`,
387 `GET /?code=308&next=303,302,304 ""`,
388 `GET /?code=303&next=302,304 ""`,
389 `GET /?code=302&next=304 ""`,
390 `GET /?code=304 ""`,
391 `DELETE /?code=308&next=307 "c308"`,
392 `DELETE /?code=307 "c308"`,
393 `DELETE / "c308"`,
394 `DELETE /?code=404 "c404"`,
395 }
396 want := strings.Join(wantSegments, "\n")
397 run(t, func(t *testing.T, mode testMode) {
398 testRedirectsByMethod(t, mode, "DELETE", deleteRedirectTests, want)
399 })
400 }
401
402 func testRedirectsByMethod(t *testing.T, mode testMode, method string, table []redirectTest, want string) {
403 var log struct {
404 sync.Mutex
405 bytes.Buffer
406 }
407 var ts *httptest.Server
408 ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
409 log.Lock()
410 slurp, _ := io.ReadAll(r.Body)
411 fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp)
412 if cl := r.Header.Get("Content-Length"); r.Method == "GET" && len(slurp) == 0 && (r.ContentLength != 0 || cl != "") {
413 fmt.Fprintf(&log.Buffer, " (but with body=%T, content-length = %v, %q)", r.Body, r.ContentLength, cl)
414 }
415 log.WriteByte('\n')
416 log.Unlock()
417 urlQuery := r.URL.Query()
418 if v := urlQuery.Get("code"); v != "" {
419 location := ts.URL
420 if final := urlQuery.Get("next"); final != "" {
421 first, rest, _ := strings.Cut(final, ",")
422 location = fmt.Sprintf("%s?code=%s", location, first)
423 if rest != "" {
424 location = fmt.Sprintf("%s&next=%s", location, rest)
425 }
426 }
427 code, _ := strconv.Atoi(v)
428 if code/100 == 3 {
429 w.Header().Set("Location", location)
430 }
431 w.WriteHeader(code)
432 }
433 })).ts
434
435 c := ts.Client()
436 for _, tt := range table {
437 content := tt.redirectBody
438 req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content))
439 req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(content)), nil }
440 res, err := c.Do(req)
441
442 if err != nil {
443 t.Fatal(err)
444 }
445 if res.StatusCode != tt.want {
446 t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
447 }
448 }
449 log.Lock()
450 got := log.String()
451 log.Unlock()
452
453 got = strings.TrimSpace(got)
454 want = strings.TrimSpace(want)
455
456 if got != want {
457 got, want, lines := removeCommonLines(got, want)
458 t.Errorf("Log differs after %d common lines.\n\nGot:\n%s\n\nWant:\n%s\n", lines, got, want)
459 }
460 }
461
// removeCommonLines strips the leading lines that a and b share and reports
// how many were removed. Only newline-terminated lines count. Stripping
// stops at the first line of a that b does not also start with, or when a
// has no newline left. It returns the remaining suffixes of a and b.
func removeCommonLines(a, b string) (asuffix, bsuffix string, commonLines int) {
	asuffix, bsuffix = a, b
	for {
		head, _, found := strings.Cut(asuffix, "\n")
		if !found {
			return asuffix, bsuffix, commonLines
		}
		line := head + "\n"
		if !strings.HasPrefix(bsuffix, line) {
			return asuffix, bsuffix, commonLines
		}
		commonLines++
		asuffix = asuffix[len(line):]
		bsuffix = bsuffix[len(line):]
	}
}
477
478 func TestClientRedirectUseResponse(t *testing.T) { run(t, testClientRedirectUseResponse) }
479 func testClientRedirectUseResponse(t *testing.T, mode testMode) {
480 const body = "Hello, world."
481 var ts *httptest.Server
482 ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
483 if strings.Contains(r.URL.Path, "/other") {
484 io.WriteString(w, "wrong body")
485 } else {
486 w.Header().Set("Location", ts.URL+"/other")
487 w.WriteHeader(StatusFound)
488 io.WriteString(w, body)
489 }
490 })).ts
491
492 c := ts.Client()
493 c.CheckRedirect = func(req *Request, via []*Request) error {
494 if req.Response == nil {
495 t.Error("expected non-nil Request.Response")
496 }
497 return ErrUseLastResponse
498 }
499 res, err := c.Get(ts.URL)
500 if err != nil {
501 t.Fatal(err)
502 }
503 if res.StatusCode != StatusFound {
504 t.Errorf("status = %d; want %d", res.StatusCode, StatusFound)
505 }
506 defer res.Body.Close()
507 slurp, err := io.ReadAll(res.Body)
508 if err != nil {
509 t.Fatal(err)
510 }
511 if string(slurp) != body {
512 t.Errorf("body = %q; want %q", slurp, body)
513 }
514 }
515
516
517
518 func TestClientRedirectNoLocation(t *testing.T) { run(t, testClientRedirectNoLocation) }
519 func testClientRedirectNoLocation(t *testing.T, mode testMode) {
520 for _, code := range []int{301, 308} {
521 t.Run(fmt.Sprint(code), func(t *testing.T) {
522 setParallel(t)
523 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
524 w.Header().Set("Foo", "Bar")
525 w.WriteHeader(code)
526 }))
527 res, err := cst.c.Get(cst.ts.URL)
528 if err != nil {
529 t.Fatal(err)
530 }
531 res.Body.Close()
532 if res.StatusCode != code {
533 t.Errorf("status = %d; want %d", res.StatusCode, code)
534 }
535 if got := res.Header.Get("Foo"); got != "Bar" {
536 t.Errorf("Foo header = %q; want Bar", got)
537 }
538 })
539 }
540 }
541
542
543 func TestClientRedirect308NoGetBody(t *testing.T) { run(t, testClientRedirect308NoGetBody) }
544 func testClientRedirect308NoGetBody(t *testing.T, mode testMode) {
545 const fakeURL = "https://localhost:1234/"
546 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
547 w.Header().Set("Location", fakeURL)
548 w.WriteHeader(308)
549 })).ts
550 req, err := NewRequest("POST", ts.URL, strings.NewReader("some body"))
551 if err != nil {
552 t.Fatal(err)
553 }
554 c := ts.Client()
555 req.GetBody = nil
556 res, err := c.Do(req)
557 if err != nil {
558 t.Fatal(err)
559 }
560 res.Body.Close()
561 if res.StatusCode != 308 {
562 t.Errorf("status = %d; want %d", res.StatusCode, 308)
563 }
564 if got := res.Header.Get("Location"); got != fakeURL {
565 t.Errorf("Location header = %q; want %q", got, fakeURL)
566 }
567 }
568
// expectedCookies is the full set of cookies the cookie tests expect.
// Index 0 is preloaded into a jar, index 1 is set by the redirecting "/"
// response, and index 2 by the final response.
var expectedCookies = []*Cookie{
	{Name: "ChocolateChip", Value: "tasty"}, // [0]
	{Name: "First", Value: "Hit"},           // [1]
	{Name: "Second", Value: "Hit"},          // [2]
}
574
575 var echoCookiesRedirectHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
576 for _, cookie := range r.Cookies() {
577 SetCookie(w, cookie)
578 }
579 if r.URL.Path == "/" {
580 SetCookie(w, expectedCookies[1])
581 Redirect(w, r, "/second", StatusMovedPermanently)
582 } else {
583 SetCookie(w, expectedCookies[2])
584 w.Write([]byte("hello"))
585 }
586 })
587
588 func TestClientSendsCookieFromJar(t *testing.T) {
589 defer afterTest(t)
590 tr := &recordingTransport{}
591 client := &Client{Transport: tr}
592 client.Jar = &TestJar{perURL: make(map[string][]*Cookie)}
593 us := "http://dummy.faketld/"
594 u, _ := url.Parse(us)
595 client.Jar.SetCookies(u, expectedCookies)
596
597 client.Get(us)
598 matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
599
600 client.Head(us)
601 matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
602
603 client.Post(us, "text/plain", strings.NewReader("body"))
604 matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
605
606 client.PostForm(us, url.Values{})
607 matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
608
609 req, _ := NewRequest("GET", us, nil)
610 client.Do(req)
611 matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
612
613 req, _ = NewRequest("POST", us, nil)
614 client.Do(req)
615 matchReturnedCookies(t, expectedCookies, tr.req.Cookies())
616 }
617
618
619
// TestJar is a minimal CookieJar for the redirect tests.
//   - Cookies are keyed only by URL.Host, with no path or domain matching.
//   - Each SetCookies call replaces whatever was stored for that host.
//   - All access to perURL is guarded by m.
type TestJar struct {
	m      sync.Mutex
	perURL map[string][]*Cookie
}

// SetCookies replaces the cookies stored for u.Host, creating the map on
// first use so the zero TestJar is ready to use.
func (j *TestJar) SetCookies(u *url.URL, cookies []*Cookie) {
	j.m.Lock()
	defer j.m.Unlock()
	if j.perURL == nil {
		j.perURL = map[string][]*Cookie{}
	}
	j.perURL[u.Host] = cookies
}

// Cookies returns the cookies last stored for u.Host, or nil if none.
func (j *TestJar) Cookies(u *url.URL) []*Cookie {
	j.m.Lock()
	stored := j.perURL[u.Host]
	j.m.Unlock()
	return stored
}
639
640 func TestRedirectCookiesJar(t *testing.T) { run(t, testRedirectCookiesJar) }
641 func testRedirectCookiesJar(t *testing.T, mode testMode) {
642 var ts *httptest.Server
643 ts = newClientServerTest(t, mode, echoCookiesRedirectHandler).ts
644 c := ts.Client()
645 c.Jar = new(TestJar)
646 u, _ := url.Parse(ts.URL)
647 c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]})
648 resp, err := c.Get(ts.URL)
649 if err != nil {
650 t.Fatalf("Get: %v", err)
651 }
652 resp.Body.Close()
653 matchReturnedCookies(t, expectedCookies, resp.Cookies())
654 }
655
// matchReturnedCookies reports test errors in two cases:
//   - given and expected differ in length;
//   - some expected cookie has no entry in given with the same Name and Value.
//
// Other Cookie fields are ignored.
func matchReturnedCookies(t *testing.T, expected, given []*Cookie) {
	if len(given) != len(expected) {
		t.Logf("Received cookies: %v", given)
		t.Errorf("Expected %d cookies, got %d", len(expected), len(given))
	}
nextExpected:
	for _, ec := range expected {
		for _, c := range given {
			if c.Name == ec.Name && c.Value == ec.Value {
				continue nextExpected
			}
		}
		t.Errorf("Missing cookie %v", ec)
	}
}
674
675 func TestJarCalls(t *testing.T) { run(t, testJarCalls, []testMode{http1Mode}) }
676 func testJarCalls(t *testing.T, mode testMode) {
677 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
678 pathSuffix := r.RequestURI[1:]
679 if r.RequestURI == "/nosetcookie" {
680 return
681 }
682 SetCookie(w, &Cookie{Name: "name" + pathSuffix, Value: "val" + pathSuffix})
683 if r.RequestURI == "/" {
684 Redirect(w, r, "http://secondhost.fake/secondpath", 302)
685 }
686 })).ts
687 jar := new(RecordingJar)
688 c := ts.Client()
689 c.Jar = jar
690 c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) {
691 return net.Dial("tcp", ts.Listener.Addr().String())
692 }
693 _, err := c.Get("http://firsthost.fake/")
694 if err != nil {
695 t.Fatal(err)
696 }
697 _, err = c.Get("http://firsthost.fake/nosetcookie")
698 if err != nil {
699 t.Fatal(err)
700 }
701 got := jar.log.String()
702 want := `Cookies("http://firsthost.fake/")
703 SetCookie("http://firsthost.fake/", [name=val])
704 Cookies("http://secondhost.fake/secondpath")
705 SetCookie("http://secondhost.fake/secondpath", [namesecondpath=valsecondpath])
706 Cookies("http://firsthost.fake/nosetcookie")
707 `
708 if got != want {
709 t.Errorf("Got Jar calls:\n%s\nWant:\n%s", got, want)
710 }
711 }
712
713
714
// RecordingJar is a CookieJar that stores nothing. It appends one line to
// log per call, so tests can compare the exact sequence of jar calls.
type RecordingJar struct {
	// mu guards log.
	mu sync.Mutex
	// log holds one line per SetCookies/Cookies call.
	log bytes.Buffer
}

// SetCookies records `SetCookie("<url>", <cookies>)` and stores nothing.
func (j *RecordingJar) SetCookies(u *url.URL, cookies []*Cookie) {
	j.logf("SetCookie(%q, %v)\n", u, cookies)
}

// Cookies records `Cookies("<url>")` and always returns nil.
func (j *RecordingJar) Cookies(u *url.URL) []*Cookie {
	j.logf("Cookies(%q)\n", u)
	return nil
}

// logf formats an entry and appends it to j.log while holding j.mu.
func (j *RecordingJar) logf(format string, args ...any) {
	entry := fmt.Sprintf(format, args...)
	j.mu.Lock()
	defer j.mu.Unlock()
	j.log.WriteString(entry)
}
734
735 func TestStreamingGet(t *testing.T) { run(t, testStreamingGet) }
736 func testStreamingGet(t *testing.T, mode testMode) {
737 say := make(chan string)
738 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
739 w.(Flusher).Flush()
740 for str := range say {
741 w.Write([]byte(str))
742 w.(Flusher).Flush()
743 }
744 }))
745
746 c := cst.c
747 res, err := c.Get(cst.ts.URL)
748 if err != nil {
749 t.Fatal(err)
750 }
751 var buf [10]byte
752 for _, str := range []string{"i", "am", "also", "known", "as", "comet"} {
753 say <- str
754 n, err := io.ReadFull(res.Body, buf[:len(str)])
755 if err != nil {
756 t.Fatalf("ReadFull on %q: %v", str, err)
757 }
758 if n != len(str) {
759 t.Fatalf("Receiving %q, only read %d bytes", str, n)
760 }
761 got := string(buf[0:n])
762 if got != str {
763 t.Fatalf("Expected %q, got %q", str, got)
764 }
765 }
766 close(say)
767 _, err = io.ReadFull(res.Body, buf[0:1])
768 if err != io.EOF {
769 t.Fatalf("at end expected EOF, got %v", err)
770 }
771 }
772
// writeCountingConn wraps a net.Conn and adds one to *count on every Write
// call, whatever the size or outcome of the write. The counter is a plain
// int with no locking of its own.
type writeCountingConn struct {
	// Conn is the wrapped connection; methods other than Write pass through.
	net.Conn
	// count is incremented once per Write call.
	count *int
}

// Write counts the call, then forwards p to the wrapped connection.
func (wc *writeCountingConn) Write(p []byte) (n int, err error) {
	*wc.count += 1
	n, err = wc.Conn.Write(p)
	return n, err
}
782
783
784
785 func TestClientWrites(t *testing.T) { run(t, testClientWrites, []testMode{http1Mode}) }
786 func testClientWrites(t *testing.T, mode testMode) {
787 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
788 })).ts
789
790 writes := 0
791 dialer := func(netz string, addr string) (net.Conn, error) {
792 c, err := net.Dial(netz, addr)
793 if err == nil {
794 c = &writeCountingConn{c, &writes}
795 }
796 return c, err
797 }
798 c := ts.Client()
799 c.Transport.(*Transport).Dial = dialer
800
801 _, err := c.Get(ts.URL)
802 if err != nil {
803 t.Fatal(err)
804 }
805 if writes != 1 {
806 t.Errorf("Get request did %d Write calls, want 1", writes)
807 }
808
809 writes = 0
810 _, err = c.PostForm(ts.URL, url.Values{"foo": {"bar"}})
811 if err != nil {
812 t.Fatal(err)
813 }
814 if writes != 1 {
815 t.Errorf("Post request did %d Write calls, want 1", writes)
816 }
817 }
818
819 func TestClientInsecureTransport(t *testing.T) {
820 run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode})
821 }
822 func testClientInsecureTransport(t *testing.T, mode testMode) {
823 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
824 w.Write([]byte("Hello"))
825 }))
826 ts := cst.ts
827 errLog := new(strings.Builder)
828 ts.Config.ErrorLog = log.New(errLog, "", 0)
829
830
831
832
833 c := ts.Client()
834 for _, insecure := range []bool{true, false} {
835 c.Transport.(*Transport).TLSClientConfig = &tls.Config{
836 InsecureSkipVerify: insecure,
837 NextProtos: cst.tr.TLSClientConfig.NextProtos,
838 }
839 req, _ := NewRequest("GET", ts.URL, nil)
840 req.Header.Set("Connection", "close")
841 res, err := c.Do(req)
842 if (err == nil) != insecure {
843 t.Errorf("insecure=%v: got unexpected err=%v", insecure, err)
844 }
845 if res != nil {
846 res.Body.Close()
847 }
848 }
849
850 cst.close()
851 if !strings.Contains(errLog.String(), "TLS handshake error") {
852 t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog)
853 }
854 }
855
856 func TestClientErrorWithRequestURI(t *testing.T) {
857 defer afterTest(t)
858 req, _ := NewRequest("GET", "http://localhost:1234/", nil)
859 req.RequestURI = "/this/field/is/illegal/and/should/error/"
860 _, err := DefaultClient.Do(req)
861 if err == nil {
862 t.Fatalf("expected an error")
863 }
864 if !strings.Contains(err.Error(), "RequestURI") {
865 t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
866 }
867 }
868
869 func TestClientWithCorrectTLSServerName(t *testing.T) {
870 run(t, testClientWithCorrectTLSServerName, []testMode{https1Mode, http2Mode})
871 }
872 func testClientWithCorrectTLSServerName(t *testing.T, mode testMode) {
873 const serverName = "example.com"
874 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
875 if r.TLS.ServerName != serverName {
876 t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName)
877 }
878 })).ts
879
880 c := ts.Client()
881 c.Transport.(*Transport).TLSClientConfig.ServerName = serverName
882 if _, err := c.Get(ts.URL); err != nil {
883 t.Fatalf("expected successful TLS connection, got error: %v", err)
884 }
885 }
886
887 func TestClientWithIncorrectTLSServerName(t *testing.T) {
888 run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode})
889 }
890 func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) {
891 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}))
892 ts := cst.ts
893 errLog := new(strings.Builder)
894 ts.Config.ErrorLog = log.New(errLog, "", 0)
895
896 c := ts.Client()
897 c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver"
898 _, err := c.Get(ts.URL)
899 if err == nil {
900 t.Fatalf("expected an error")
901 }
902 if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
903 t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
904 }
905
906 cst.close()
907 if !strings.Contains(errLog.String(), "TLS handshake error") {
908 t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog)
909 }
910 }
911
912
913
914
915
916
917
918
919
920
921 func TestTransportUsesTLSConfigServerName(t *testing.T) {
922 run(t, testTransportUsesTLSConfigServerName, []testMode{https1Mode, http2Mode})
923 }
924 func testTransportUsesTLSConfigServerName(t *testing.T, mode testMode) {
925 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
926 w.Write([]byte("Hello"))
927 })).ts
928
929 c := ts.Client()
930 tr := c.Transport.(*Transport)
931 tr.TLSClientConfig.ServerName = "example.com"
932 tr.Dial = func(netw, addr string) (net.Conn, error) {
933 return net.Dial(netw, ts.Listener.Addr().String())
934 }
935 res, err := c.Get("https://some-other-host.tld/")
936 if err != nil {
937 t.Fatal(err)
938 }
939 res.Body.Close()
940 }
941
942 func TestResponseSetsTLSConnectionState(t *testing.T) {
943 run(t, testResponseSetsTLSConnectionState, []testMode{https1Mode})
944 }
945 func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) {
946 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
947 w.Write([]byte("Hello"))
948 })).ts
949
950 c := ts.Client()
951 tr := c.Transport.(*Transport)
952 tr.TLSClientConfig.CipherSuites = []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
953 tr.TLSClientConfig.MaxVersion = tls.VersionTLS12
954 tr.Dial = func(netw, addr string) (net.Conn, error) {
955 return net.Dial(netw, ts.Listener.Addr().String())
956 }
957 res, err := c.Get("https://example.com/")
958 if err != nil {
959 t.Fatal(err)
960 }
961 defer res.Body.Close()
962 if res.TLS == nil {
963 t.Fatal("Response didn't set TLS Connection State.")
964 }
965 if got, want := res.TLS.CipherSuite, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256; got != want {
966 t.Errorf("TLS Cipher Suite = %d; want %d", got, want)
967 }
968 }
969
970
971
972
973 func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
974 run(t, testHTTPSClientDetectsHTTPServer, []testMode{http1Mode})
975 }
976 func testHTTPSClientDetectsHTTPServer(t *testing.T, mode testMode) {
977 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts
978 ts.Config.ErrorLog = quietLog
979
980 _, err := Get(strings.Replace(ts.URL, "http", "https", 1))
981 if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") {
982 t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got)
983 }
984 }
985
986
987 func TestClientHeadContentLength(t *testing.T) { run(t, testClientHeadContentLength) }
988 func testClientHeadContentLength(t *testing.T, mode testMode) {
989 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
990 if v := r.FormValue("cl"); v != "" {
991 w.Header().Set("Content-Length", v)
992 }
993 }))
994 tests := []struct {
995 suffix string
996 want int64
997 }{
998 {"/?cl=1234", 1234},
999 {"/?cl=0", 0},
1000 {"", -1},
1001 }
1002 for _, tt := range tests {
1003 req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil)
1004 res, err := cst.c.Do(req)
1005 if err != nil {
1006 t.Fatal(err)
1007 }
1008 if res.ContentLength != tt.want {
1009 t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want)
1010 }
1011 bs, err := io.ReadAll(res.Body)
1012 if err != nil {
1013 t.Fatal(err)
1014 }
1015 if len(bs) != 0 {
1016 t.Errorf("Unexpected content: %q", bs)
1017 }
1018 }
1019 }
1020
1021 func TestEmptyPasswordAuth(t *testing.T) { run(t, testEmptyPasswordAuth) }
1022 func testEmptyPasswordAuth(t *testing.T, mode testMode) {
1023 gopher := "gopher"
1024 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1025 auth := r.Header.Get("Authorization")
1026 if strings.HasPrefix(auth, "Basic ") {
1027 encoded := auth[6:]
1028 decoded, err := base64.StdEncoding.DecodeString(encoded)
1029 if err != nil {
1030 t.Fatal(err)
1031 }
1032 expected := gopher + ":"
1033 s := string(decoded)
1034 if expected != s {
1035 t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
1036 }
1037 } else {
1038 t.Errorf("Invalid auth %q", auth)
1039 }
1040 })).ts
1041 defer ts.Close()
1042 req, err := NewRequest("GET", ts.URL, nil)
1043 if err != nil {
1044 t.Fatal(err)
1045 }
1046 req.URL.User = url.User(gopher)
1047 c := ts.Client()
1048 resp, err := c.Do(req)
1049 if err != nil {
1050 t.Fatal(err)
1051 }
1052 defer resp.Body.Close()
1053 }
1054
1055 func TestBasicAuth(t *testing.T) {
1056 defer afterTest(t)
1057 tr := &recordingTransport{}
1058 client := &Client{Transport: tr}
1059
1060 url := "http://My%20User:My%20Pass@dummy.faketld/"
1061 expected := "My User:My Pass"
1062 client.Get(url)
1063
1064 if tr.req.Method != "GET" {
1065 t.Errorf("got method %q, want %q", tr.req.Method, "GET")
1066 }
1067 if tr.req.URL.String() != url {
1068 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
1069 }
1070 if tr.req.Header == nil {
1071 t.Fatalf("expected non-nil request Header")
1072 }
1073 auth := tr.req.Header.Get("Authorization")
1074 if strings.HasPrefix(auth, "Basic ") {
1075 encoded := auth[6:]
1076 decoded, err := base64.StdEncoding.DecodeString(encoded)
1077 if err != nil {
1078 t.Fatal(err)
1079 }
1080 s := string(decoded)
1081 if expected != s {
1082 t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
1083 }
1084 } else {
1085 t.Errorf("Invalid auth %q", auth)
1086 }
1087 }
1088
1089 func TestBasicAuthHeadersPreserved(t *testing.T) {
1090 defer afterTest(t)
1091 tr := &recordingTransport{}
1092 client := &Client{Transport: tr}
1093
1094
1095 url := "http://My%20User@dummy.faketld/"
1096 req, err := NewRequest("GET", url, nil)
1097 if err != nil {
1098 t.Fatal(err)
1099 }
1100 req.SetBasicAuth("My User", "My Pass")
1101 expected := "My User:My Pass"
1102 client.Do(req)
1103
1104 if tr.req.Method != "GET" {
1105 t.Errorf("got method %q, want %q", tr.req.Method, "GET")
1106 }
1107 if tr.req.URL.String() != url {
1108 t.Errorf("got URL %q, want %q", tr.req.URL.String(), url)
1109 }
1110 if tr.req.Header == nil {
1111 t.Fatalf("expected non-nil request Header")
1112 }
1113 auth := tr.req.Header.Get("Authorization")
1114 if strings.HasPrefix(auth, "Basic ") {
1115 encoded := auth[6:]
1116 decoded, err := base64.StdEncoding.DecodeString(encoded)
1117 if err != nil {
1118 t.Fatal(err)
1119 }
1120 s := string(decoded)
1121 if expected != s {
1122 t.Errorf("Invalid Authorization header. Got %q, wanted %q", s, expected)
1123 }
1124 } else {
1125 t.Errorf("Invalid auth %q", auth)
1126 }
1127
1128 }
1129
1130 func TestStripPasswordFromError(t *testing.T) {
1131 client := &Client{Transport: &recordingTransport{}}
1132 testCases := []struct {
1133 desc string
1134 in string
1135 out string
1136 }{
1137 {
1138 desc: "Strip password from error message",
1139 in: "http://user:password@dummy.faketld/",
1140 out: `Get "http://user:***@dummy.faketld/": dummy impl`,
1141 },
1142 {
1143 desc: "Don't Strip password from domain name",
1144 in: "http://user:password@password.faketld/",
1145 out: `Get "http://user:***@password.faketld/": dummy impl`,
1146 },
1147 {
1148 desc: "Don't Strip password from path",
1149 in: "http://user:password@dummy.faketld/password",
1150 out: `Get "http://user:***@dummy.faketld/password": dummy impl`,
1151 },
1152 {
1153 desc: "Strip escaped password",
1154 in: "http://user:pa%2Fssword@dummy.faketld/",
1155 out: `Get "http://user:***@dummy.faketld/": dummy impl`,
1156 },
1157 }
1158 for _, tC := range testCases {
1159 t.Run(tC.desc, func(t *testing.T) {
1160 _, err := client.Get(tC.in)
1161 if err.Error() != tC.out {
1162 t.Errorf("Unexpected output for %q: expected %q, actual %q",
1163 tC.in, tC.out, err.Error())
1164 }
1165 })
1166 }
1167 }
1168
1169 func TestClientTimeout(t *testing.T) { run(t, testClientTimeout) }
1170 func testClientTimeout(t *testing.T, mode testMode) {
1171 var (
1172 mu sync.Mutex
1173 nonce string
1174 sawSlowNonce bool
1175 )
1176 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1177 _ = r.ParseForm()
1178 if r.URL.Path == "/" {
1179 Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound)
1180 return
1181 }
1182 if r.URL.Path == "/slow" {
1183 mu.Lock()
1184 if r.Form.Get("nonce") == nonce {
1185 sawSlowNonce = true
1186 } else {
1187 t.Logf("mismatched nonce: received %s, want %s", r.Form.Get("nonce"), nonce)
1188 }
1189 mu.Unlock()
1190
1191 w.Write([]byte("Hello"))
1192 w.(Flusher).Flush()
1193 <-r.Context().Done()
1194 return
1195 }
1196 }))
1197
1198
1199
1200
1201
1202
1203
1204 timeout := 10 * time.Millisecond
1205 nextNonce := 0
1206 for ; ; timeout *= 2 {
1207 if timeout <= 0 {
1208
1209
1210 t.Fatalf("timeout overflow")
1211 }
1212 if deadline, ok := t.Deadline(); ok && !time.Now().Add(timeout).Before(deadline) {
1213 t.Fatalf("failed to produce expected timeout before test deadline")
1214 }
1215 t.Logf("attempting test with timeout %v", timeout)
1216 cst.c.Timeout = timeout
1217
1218 mu.Lock()
1219 nonce = fmt.Sprint(nextNonce)
1220 nextNonce++
1221 sawSlowNonce = false
1222 mu.Unlock()
1223 res, err := cst.c.Get(cst.ts.URL + "/?nonce=" + nonce)
1224 if err != nil {
1225 if strings.Contains(err.Error(), "Client.Timeout") {
1226
1227 t.Logf("timeout before response received")
1228 continue
1229 }
1230 if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
1231 testenv.SkipFlaky(t, 43120)
1232 }
1233 t.Fatal(err)
1234 }
1235
1236 mu.Lock()
1237 ok := sawSlowNonce
1238 mu.Unlock()
1239 if !ok {
1240 t.Fatal("handler never got /slow request, but client returned response")
1241 }
1242
1243 _, err = io.ReadAll(res.Body)
1244 res.Body.Close()
1245
1246 if err == nil {
1247 t.Fatal("expected error from ReadAll")
1248 }
1249 ne, ok := err.(net.Error)
1250 if !ok {
1251 t.Errorf("error value from ReadAll was %T; expected some net.Error", err)
1252 } else if !ne.Timeout() {
1253 t.Errorf("net.Error.Timeout = false; want true")
1254 }
1255 if !errors.Is(err, context.DeadlineExceeded) {
1256 t.Errorf("ReadAll error = %q; expected some context.DeadlineExceeded", err)
1257 }
1258 if got := ne.Error(); !strings.Contains(got, "(Client.Timeout") {
1259 if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
1260 testenv.SkipFlaky(t, 43120)
1261 }
1262 t.Errorf("error string = %q; missing timeout substring", got)
1263 }
1264
1265 break
1266 }
1267 }
1268
1269
1270 func TestClientTimeout_Headers(t *testing.T) { run(t, testClientTimeout_Headers) }
1271 func testClientTimeout_Headers(t *testing.T, mode testMode) {
1272 donec := make(chan bool, 1)
1273 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1274 <-donec
1275 }), optQuietLog)
1276
1277
1278
1279
1280
1281
1282
1283 defer func() { donec <- true }()
1284
1285 cst.c.Timeout = 5 * time.Millisecond
1286 res, err := cst.c.Get(cst.ts.URL)
1287 if err == nil {
1288 res.Body.Close()
1289 t.Fatal("got response from Get; expected error")
1290 }
1291 if _, ok := err.(*url.Error); !ok {
1292 t.Fatalf("Got error of type %T; want *url.Error", err)
1293 }
1294 ne, ok := err.(net.Error)
1295 if !ok {
1296 t.Fatalf("Got error of type %T; want some net.Error", err)
1297 }
1298 if !ne.Timeout() {
1299 t.Error("net.Error.Timeout = false; want true")
1300 }
1301 if !errors.Is(err, context.DeadlineExceeded) {
1302 t.Errorf("ReadAll error = %q; expected some context.DeadlineExceeded", err)
1303 }
1304 if got := ne.Error(); !strings.Contains(got, "Client.Timeout exceeded") {
1305 if runtime.GOOS == "windows" && strings.HasPrefix(runtime.GOARCH, "arm") {
1306 testenv.SkipFlaky(t, 43120)
1307 }
1308 t.Errorf("error string = %q; missing timeout substring", got)
1309 }
1310 }
1311
1312
1313
1314 func TestClientTimeoutCancel(t *testing.T) { run(t, testClientTimeoutCancel) }
1315 func testClientTimeoutCancel(t *testing.T, mode testMode) {
1316 testDone := make(chan struct{})
1317 ctx, cancel := context.WithCancel(context.Background())
1318
1319 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1320 w.(Flusher).Flush()
1321 <-testDone
1322 }))
1323 defer close(testDone)
1324
1325 cst.c.Timeout = 1 * time.Hour
1326 req, _ := NewRequest("GET", cst.ts.URL, nil)
1327 req.Cancel = ctx.Done()
1328 res, err := cst.c.Do(req)
1329 if err != nil {
1330 t.Fatal(err)
1331 }
1332 cancel()
1333 _, err = io.Copy(io.Discard, res.Body)
1334 if err != ExportErrRequestCanceled {
1335 t.Fatalf("error = %v; want errRequestCanceled", err)
1336 }
1337 }
1338
1339
1340 func TestClientTimeoutDoesNotExpire(t *testing.T) { run(t, testClientTimeoutDoesNotExpire) }
1341 func testClientTimeoutDoesNotExpire(t *testing.T, mode testMode) {
1342 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1343 w.Write([]byte("body"))
1344 }))
1345
1346 cst.c.Timeout = 1 * time.Hour
1347 req, _ := NewRequest("GET", cst.ts.URL, nil)
1348 res, err := cst.c.Do(req)
1349 if err != nil {
1350 t.Fatal(err)
1351 }
1352 if _, err = io.Copy(io.Discard, res.Body); err != nil {
1353 t.Fatalf("io.Copy(io.Discard, res.Body) = %v, want nil", err)
1354 }
1355 if err = res.Body.Close(); err != nil {
1356 t.Fatalf("res.Body.Close() = %v, want nil", err)
1357 }
1358 }
1359
1360 func TestClientRedirectEatsBody_h1(t *testing.T) { run(t, testClientRedirectEatsBody) }
1361 func testClientRedirectEatsBody(t *testing.T, mode testMode) {
1362 saw := make(chan string, 2)
1363 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1364 saw <- r.RemoteAddr
1365 if r.URL.Path == "/" {
1366 Redirect(w, r, "/foo", StatusFound)
1367 }
1368 }))
1369
1370 res, err := cst.c.Get(cst.ts.URL)
1371 if err != nil {
1372 t.Fatal(err)
1373 }
1374 _, err = io.ReadAll(res.Body)
1375 res.Body.Close()
1376 if err != nil {
1377 t.Fatal(err)
1378 }
1379
1380 var first string
1381 select {
1382 case first = <-saw:
1383 default:
1384 t.Fatal("server didn't see a request")
1385 }
1386
1387 var second string
1388 select {
1389 case second = <-saw:
1390 default:
1391 t.Fatal("server didn't see a second request")
1392 }
1393
1394 if first != second {
1395 t.Fatal("server saw different client ports before & after the redirect")
1396 }
1397 }
1398
1399
// eofReaderFunc is an io.Reader that calls the wrapped function on every Read
// and always reports io.EOF without producing any data.
type eofReaderFunc func()

// Read invokes fn, ignores the buffer and returns (0, io.EOF).
func (fn eofReaderFunc) Read(_ []byte) (int, error) {
	fn()
	return 0, io.EOF
}
1406
1407 func TestReferer(t *testing.T) {
1408 tests := []struct {
1409 lastReq, newReq, explicitRef string
1410 want string
1411 }{
1412
1413 {lastReq: "http://gopher@test.com", newReq: "http://link.com", want: "http://test.com"},
1414 {lastReq: "https://gopher@test.com", newReq: "https://link.com", want: "https://test.com"},
1415
1416
1417 {lastReq: "http://gopher:go@test.com", newReq: "http://link.com", want: "http://test.com"},
1418 {lastReq: "https://gopher:go@test.com", newReq: "https://link.com", want: "https://test.com"},
1419
1420
1421 {lastReq: "http://test.com", newReq: "http://link.com", want: "http://test.com"},
1422 {lastReq: "https://test.com", newReq: "https://link.com", want: "https://test.com"},
1423
1424
1425 {lastReq: "https://test.com", newReq: "http://link.com", want: ""},
1426 {lastReq: "https://gopher:go@test.com", newReq: "http://link.com", want: ""},
1427
1428
1429 {lastReq: "https://test.com", newReq: "http://link.com", explicitRef: "https://foo.com", want: ""},
1430 {lastReq: "https://gopher:go@test.com", newReq: "http://link.com", explicitRef: "https://foo.com", want: ""},
1431
1432
1433 {lastReq: "https://test.com", newReq: "https://link.com", explicitRef: "https://foo.com", want: "https://foo.com"},
1434 {lastReq: "https://gopher:go@test.com", newReq: "https://link.com", explicitRef: "https://foo.com", want: "https://foo.com"},
1435 }
1436 for _, tt := range tests {
1437 l, err := url.Parse(tt.lastReq)
1438 if err != nil {
1439 t.Fatal(err)
1440 }
1441 n, err := url.Parse(tt.newReq)
1442 if err != nil {
1443 t.Fatal(err)
1444 }
1445 r := ExportRefererForURL(l, n, tt.explicitRef)
1446 if r != tt.want {
1447 t.Errorf("refererForURL(%q, %q) = %q; want %q", tt.lastReq, tt.newReq, r, tt.want)
1448 }
1449 }
1450 }
1451
1452
1453
// issue15577Tripper is a RoundTripper that answers every request with a 303
// redirect to http://www.example.com/. It leaves Response.Request unset and
// returns an empty body.
type issue15577Tripper struct{}

func (issue15577Tripper) RoundTrip(*Request) (*Response, error) {
	return &Response{
		StatusCode: 303,
		Header:     map[string][]string{"Location": {"http://www.example.com/"}},
		Body:       io.NopCloser(strings.NewReader("")),
	}, nil
}
1464
1465
1466 func TestClientRedirectResponseWithoutRequest(t *testing.T) {
1467 c := &Client{
1468 CheckRedirect: func(*Request, []*Request) error { return fmt.Errorf("no redirects!") },
1469 Transport: issue15577Tripper{},
1470 }
1471
1472 c.Get("http://dummy.tld")
1473 }
1474
1475
1476
1477
1478
1479 func TestClientCopyHeadersOnRedirect(t *testing.T) { run(t, testClientCopyHeadersOnRedirect) }
1480 func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) {
1481 const (
1482 ua = "some-agent/1.2"
1483 xfoo = "foo-val"
1484 )
1485 var ts2URL string
1486 ts1 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1487 want := Header{
1488 "User-Agent": []string{ua},
1489 "X-Foo": []string{xfoo},
1490 "Referer": []string{ts2URL},
1491 "Accept-Encoding": []string{"gzip"},
1492 "Cookie": []string{"foo=bar"},
1493 "Authorization": []string{"secretpassword"},
1494 }
1495 if !reflect.DeepEqual(r.Header, want) {
1496 t.Errorf("Request.Header = %#v; want %#v", r.Header, want)
1497 }
1498 if t.Failed() {
1499 w.Header().Set("Result", "got errors")
1500 } else {
1501 w.Header().Set("Result", "ok")
1502 }
1503 })).ts
1504 ts2 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1505 Redirect(w, r, ts1.URL, StatusFound)
1506 })).ts
1507 ts2URL = ts2.URL
1508
1509 c := ts1.Client()
1510 c.CheckRedirect = func(r *Request, via []*Request) error {
1511 want := Header{
1512 "User-Agent": []string{ua},
1513 "X-Foo": []string{xfoo},
1514 "Referer": []string{ts2URL},
1515 "Cookie": []string{"foo=bar"},
1516 "Authorization": []string{"secretpassword"},
1517 }
1518 if !reflect.DeepEqual(r.Header, want) {
1519 t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want)
1520 }
1521 return nil
1522 }
1523
1524 req, _ := NewRequest("GET", ts2.URL, nil)
1525 req.Header.Add("User-Agent", ua)
1526 req.Header.Add("X-Foo", xfoo)
1527 req.Header.Add("Cookie", "foo=bar")
1528 req.Header.Add("Authorization", "secretpassword")
1529 res, err := c.Do(req)
1530 if err != nil {
1531 t.Fatal(err)
1532 }
1533 defer res.Body.Close()
1534 if res.StatusCode != 200 {
1535 t.Fatal(res.Status)
1536 }
1537 if got := res.Header.Get("Result"); got != "ok" {
1538 t.Errorf("result = %q; want ok", got)
1539 }
1540 }
1541
1542
1543
1544 func TestClientStripHeadersOnRepeatedRedirect(t *testing.T) {
1545 run(t, testClientStripHeadersOnRepeatedRedirect)
1546 }
1547 func testClientStripHeadersOnRepeatedRedirect(t *testing.T, mode testMode) {
1548 var proto string
1549 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1550 if r.Host+r.URL.Path != "a.example.com/" {
1551 if h := r.Header.Get("Authorization"); h != "" {
1552 t.Errorf("on request to %v%v, Authorization=%q, want no header", r.Host, r.URL.Path, h)
1553 }
1554 }
1555
1556
1557
1558 switch r.Host + r.URL.Path {
1559 case "a.example.com/":
1560 Redirect(w, r, proto+"://b.example.com/", StatusFound)
1561 case "b.example.com/":
1562 Redirect(w, r, proto+"://b.example.com/redirect", StatusFound)
1563 case "b.example.com/redirect":
1564 Redirect(w, r, proto+"://a.example.com/redirect", StatusFound)
1565 case "a.example.com/redirect":
1566 w.Header().Set("X-Done", "true")
1567 default:
1568 t.Errorf("unexpected request to %v", r.URL)
1569 }
1570 })).ts
1571 proto, _, _ = strings.Cut(ts.URL, ":")
1572
1573 c := ts.Client()
1574 c.Transport.(*Transport).Dial = func(_ string, _ string) (net.Conn, error) {
1575 return net.Dial("tcp", ts.Listener.Addr().String())
1576 }
1577
1578 req, _ := NewRequest("GET", proto+"://a.example.com/", nil)
1579 req.Header.Add("Cookie", "foo=bar")
1580 req.Header.Add("Authorization", "secretpassword")
1581 res, err := c.Do(req)
1582 if err != nil {
1583 t.Fatal(err)
1584 }
1585 defer res.Body.Close()
1586 if res.Header.Get("X-Done") != "true" {
1587 t.Fatalf("response missing expected header: X-Done=true")
1588 }
1589 }
1590
1591
1592 func TestClientCopyHostOnRedirect(t *testing.T) { run(t, testClientCopyHostOnRedirect) }
1593 func testClientCopyHostOnRedirect(t *testing.T, mode testMode) {
1594
1595 virtual := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1596 t.Errorf("Virtual host received request %v", r.URL)
1597 w.WriteHeader(403)
1598 io.WriteString(w, "should not see this response")
1599 })).ts
1600 defer virtual.Close()
1601 virtualHost := strings.TrimPrefix(virtual.URL, "http://")
1602 virtualHost = strings.TrimPrefix(virtualHost, "https://")
1603 t.Logf("Virtual host is %v", virtualHost)
1604
1605
1606 const wantBody = "response body"
1607 var tsURL string
1608 var tsHost string
1609 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1610 switch r.URL.Path {
1611 case "/":
1612
1613 if r.Host != virtualHost {
1614 t.Errorf("Serving /: Request.Host = %#v; want %#v", r.Host, virtualHost)
1615 w.WriteHeader(404)
1616 return
1617 }
1618 w.Header().Set("Location", "/hop")
1619 w.WriteHeader(302)
1620 case "/hop":
1621
1622 if r.Host != virtualHost {
1623 t.Errorf("Serving /hop: Request.Host = %#v; want %#v", r.Host, virtualHost)
1624 w.WriteHeader(404)
1625 return
1626 }
1627 w.Header().Set("Location", tsURL+"/final")
1628 w.WriteHeader(302)
1629 case "/final":
1630 if r.Host != tsHost {
1631 t.Errorf("Serving /final: Request.Host = %#v; want %#v", r.Host, tsHost)
1632 w.WriteHeader(404)
1633 return
1634 }
1635 w.WriteHeader(200)
1636 io.WriteString(w, wantBody)
1637 default:
1638 t.Errorf("Serving unexpected path %q", r.URL.Path)
1639 w.WriteHeader(404)
1640 }
1641 })).ts
1642 tsURL = ts.URL
1643 tsHost = strings.TrimPrefix(ts.URL, "http://")
1644 tsHost = strings.TrimPrefix(tsHost, "https://")
1645 t.Logf("Server host is %v", tsHost)
1646
1647 c := ts.Client()
1648 req, _ := NewRequest("GET", ts.URL, nil)
1649 req.Host = virtualHost
1650 resp, err := c.Do(req)
1651 if err != nil {
1652 t.Fatal(err)
1653 }
1654 defer resp.Body.Close()
1655 if resp.StatusCode != 200 {
1656 t.Fatal(resp.Status)
1657 }
1658 if got, err := io.ReadAll(resp.Body); err != nil || string(got) != wantBody {
1659 t.Errorf("body = %q; want %q", got, wantBody)
1660 }
1661 }
1662
1663
1664 func TestClientAltersCookiesOnRedirect(t *testing.T) { run(t, testClientAltersCookiesOnRedirect) }
1665 func testClientAltersCookiesOnRedirect(t *testing.T, mode testMode) {
1666 cookieMap := func(cs []*Cookie) map[string][]string {
1667 m := make(map[string][]string)
1668 for _, c := range cs {
1669 m[c.Name] = append(m[c.Name], c.Value)
1670 }
1671 return m
1672 }
1673
1674 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1675 var want map[string][]string
1676 got := cookieMap(r.Cookies())
1677
1678 c, _ := r.Cookie("Cycle")
1679 switch c.Value {
1680 case "0":
1681 want = map[string][]string{
1682 "Cookie1": {"OldValue1a", "OldValue1b"},
1683 "Cookie2": {"OldValue2"},
1684 "Cookie3": {"OldValue3a", "OldValue3b"},
1685 "Cookie4": {"OldValue4"},
1686 "Cycle": {"0"},
1687 }
1688 SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"})
1689 SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1})
1690 Redirect(w, r, "/", StatusFound)
1691 case "1":
1692 want = map[string][]string{
1693 "Cookie1": {"OldValue1a", "OldValue1b"},
1694 "Cookie3": {"OldValue3a", "OldValue3b"},
1695 "Cookie4": {"OldValue4"},
1696 "Cycle": {"1"},
1697 }
1698 SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"})
1699 SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"})
1700 SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"})
1701 Redirect(w, r, "/", StatusFound)
1702 case "2":
1703 want = map[string][]string{
1704 "Cookie1": {"OldValue1a", "OldValue1b"},
1705 "Cookie3": {"NewValue3"},
1706 "Cookie4": {"NewValue4"},
1707 "Cycle": {"2"},
1708 }
1709 SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"})
1710 SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"})
1711 Redirect(w, r, "/", StatusFound)
1712 case "3":
1713 want = map[string][]string{
1714 "Cookie1": {"OldValue1a", "OldValue1b"},
1715 "Cookie3": {"NewValue3"},
1716 "Cookie4": {"NewValue4"},
1717 "Cookie5": {"NewValue5"},
1718 "Cycle": {"3"},
1719 }
1720
1721 default:
1722 t.Errorf("unexpected redirect cycle")
1723 return
1724 }
1725
1726 if !reflect.DeepEqual(got, want) {
1727 t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want)
1728 }
1729 })).ts
1730
1731 jar, _ := cookiejar.New(nil)
1732 c := ts.Client()
1733 c.Jar = jar
1734
1735 u, _ := url.Parse(ts.URL)
1736 req, _ := NewRequest("GET", ts.URL, nil)
1737 req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"})
1738 req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"})
1739 req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"})
1740 req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"})
1741 req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"})
1742 jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}})
1743 jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}})
1744 res, err := c.Do(req)
1745 if err != nil {
1746 t.Fatal(err)
1747 }
1748 defer res.Body.Close()
1749 if res.StatusCode != 200 {
1750 t.Fatal(res.Status)
1751 }
1752 }
1753
1754
1755 func TestShouldCopyHeaderOnRedirect(t *testing.T) {
1756 tests := []struct {
1757 initialURL string
1758 destURL string
1759 want bool
1760 }{
1761
1762 {"http://foo.com/", "http://bar.com/", false},
1763 {"http://foo.com/", "http://bar.com/", false},
1764 {"http://foo.com/", "http://bar.com/", false},
1765 {"http://foo.com/", "https://foo.com/", true},
1766 {"http://foo.com:1234/", "http://foo.com:4321/", true},
1767 {"http://foo.com/", "http://bar.com/", false},
1768 {"http://foo.com/", "http://[::1%25.foo.com]/", false},
1769
1770
1771 {"http://foo.com/", "http://foo.com/", true},
1772 {"http://foo.com/", "http://sub.foo.com/", true},
1773 {"http://foo.com/", "http://notfoo.com/", false},
1774 {"http://foo.com/", "https://foo.com/", true},
1775 {"http://foo.com:80/", "http://foo.com/", true},
1776 {"http://foo.com:80/", "http://sub.foo.com/", true},
1777 {"http://foo.com:443/", "https://foo.com/", true},
1778 {"http://foo.com:443/", "https://sub.foo.com/", true},
1779 {"http://foo.com:1234/", "http://foo.com/", true},
1780
1781 {"http://foo.com/", "http://foo.com/", true},
1782 {"http://foo.com/", "http://sub.foo.com/", true},
1783 {"http://foo.com/", "http://notfoo.com/", false},
1784 {"http://foo.com/", "https://foo.com/", true},
1785 {"http://foo.com:80/", "http://foo.com/", true},
1786 {"http://foo.com:80/", "http://sub.foo.com/", true},
1787 {"http://foo.com:443/", "https://foo.com/", true},
1788 {"http://foo.com:443/", "https://sub.foo.com/", true},
1789 {"http://foo.com:1234/", "http://foo.com/", true},
1790 }
1791 for i, tt := range tests {
1792 u0, err := url.Parse(tt.initialURL)
1793 if err != nil {
1794 t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err)
1795 continue
1796 }
1797 u1, err := url.Parse(tt.destURL)
1798 if err != nil {
1799 t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err)
1800 continue
1801 }
1802 got := Export_shouldCopyHeaderOnRedirect(u0, u1)
1803 if got != tt.want {
1804 t.Errorf("%d. shouldCopyHeaderOnRedirect(%q => %q) = %v; want %v",
1805 i, tt.initialURL, tt.destURL, got, tt.want)
1806 }
1807 }
1808 }
1809
1810 func TestClientRedirectTypes(t *testing.T) { run(t, testClientRedirectTypes) }
1811 func testClientRedirectTypes(t *testing.T, mode testMode) {
1812 tests := [...]struct {
1813 method string
1814 serverStatus int
1815 wantMethod string
1816 }{
1817 0: {method: "POST", serverStatus: 301, wantMethod: "GET"},
1818 1: {method: "POST", serverStatus: 302, wantMethod: "GET"},
1819 2: {method: "POST", serverStatus: 303, wantMethod: "GET"},
1820 3: {method: "POST", serverStatus: 307, wantMethod: "POST"},
1821 4: {method: "POST", serverStatus: 308, wantMethod: "POST"},
1822
1823 5: {method: "HEAD", serverStatus: 301, wantMethod: "HEAD"},
1824 6: {method: "HEAD", serverStatus: 302, wantMethod: "HEAD"},
1825 7: {method: "HEAD", serverStatus: 303, wantMethod: "HEAD"},
1826 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"},
1827 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"},
1828
1829 10: {method: "GET", serverStatus: 301, wantMethod: "GET"},
1830 11: {method: "GET", serverStatus: 302, wantMethod: "GET"},
1831 12: {method: "GET", serverStatus: 303, wantMethod: "GET"},
1832 13: {method: "GET", serverStatus: 307, wantMethod: "GET"},
1833 14: {method: "GET", serverStatus: 308, wantMethod: "GET"},
1834
1835 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"},
1836 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"},
1837 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"},
1838 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"},
1839 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"},
1840
1841 20: {method: "PUT", serverStatus: 301, wantMethod: "GET"},
1842 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"},
1843 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"},
1844 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"},
1845 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"},
1846
1847 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"},
1848 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"},
1849 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"},
1850 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"},
1851 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"},
1852 }
1853
1854 handlerc := make(chan HandlerFunc, 1)
1855
1856 ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
1857 h := <-handlerc
1858 h(rw, req)
1859 })).ts
1860
1861 c := ts.Client()
1862 for i, tt := range tests {
1863 handlerc <- func(w ResponseWriter, r *Request) {
1864 w.Header().Set("Location", ts.URL)
1865 w.WriteHeader(tt.serverStatus)
1866 }
1867
1868 req, err := NewRequest(tt.method, ts.URL, nil)
1869 if err != nil {
1870 t.Errorf("#%d: NewRequest: %v", i, err)
1871 continue
1872 }
1873
1874 c.CheckRedirect = func(req *Request, via []*Request) error {
1875 if got, want := req.Method, tt.wantMethod; got != want {
1876 return fmt.Errorf("#%d: got next method %q; want %q", i, got, want)
1877 }
1878 handlerc <- func(rw ResponseWriter, req *Request) {
1879
1880 }
1881 return nil
1882 }
1883
1884 res, err := c.Do(req)
1885 if err != nil {
1886 t.Errorf("#%d: Response: %v", i, err)
1887 continue
1888 }
1889
1890 res.Body.Close()
1891 }
1892 }
1893
1894
1895
1896
// issue18239Body is a request body whose Read returns readErr without
// producing data. Read and Close calls are counted through the shared int32
// pointers, which are updated atomically.
type issue18239Body struct {
	readCalls  *int32
	closeCalls *int32
	readErr    error
}

// Read bumps the read counter and returns (0, readErr).
func (body issue18239Body) Read([]byte) (int, error) {
	atomic.AddInt32(body.readCalls, 1)
	return 0, body.readErr
}

// Close bumps the close counter and always succeeds.
func (body issue18239Body) Close() error {
	atomic.AddInt32(body.closeCalls, 1)
	return nil
}
1912
1913
1914
1915 func TestTransportBodyReadError(t *testing.T) { run(t, testTransportBodyReadError) }
1916 func testTransportBodyReadError(t *testing.T, mode testMode) {
1917 ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
1918 if r.URL.Path == "/ping" {
1919 return
1920 }
1921 buf := make([]byte, 1)
1922 n, err := r.Body.Read(buf)
1923 w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err))
1924 })).ts
1925 c := ts.Client()
1926 tr := c.Transport.(*Transport)
1927
1928
1929
1930
1931 res, err := c.Get(ts.URL + "/ping")
1932 if err != nil {
1933 t.Fatal(err)
1934 }
1935 res.Body.Close()
1936
1937 var readCallsAtomic int32
1938 var closeCallsAtomic int32
1939 someErr := errors.New("some body read error")
1940 body := issue18239Body{&readCallsAtomic, &closeCallsAtomic, someErr}
1941
1942 req, err := NewRequest("POST", ts.URL, body)
1943 if err != nil {
1944 t.Fatal(err)
1945 }
1946 req = req.WithT(t)
1947 _, err = tr.RoundTrip(req)
1948 if err != someErr {
1949 t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr)
1950 }
1951
1952
1953
1954
1955 readCalls := atomic.LoadInt32(&readCallsAtomic)
1956 closeCalls := atomic.LoadInt32(&closeCallsAtomic)
1957 if readCalls != 1 {
1958 t.Errorf("read calls = %d; want 1", readCalls)
1959 }
1960 if closeCalls != 1 {
1961 t.Errorf("close calls = %d; want 1", closeCalls)
1962 }
1963 }
1964
// roundTripperWithoutCloseIdle is a RoundTripper that does not implement
// CloseIdleConnections. Its RoundTrip is never expected to run.
type roundTripperWithoutCloseIdle struct{}

func (roundTripperWithoutCloseIdle) RoundTrip(*Request) (*Response, error) { panic("unused") }

// roundTripperWithCloseIdle is a RoundTripper whose CloseIdleConnections
// calls the wrapped function. Its RoundTrip is never expected to run.
type roundTripperWithCloseIdle func()

func (roundTripperWithCloseIdle) RoundTrip(*Request) (*Response, error) { panic("unused") }

func (fn roundTripperWithCloseIdle) CloseIdleConnections() {
	fn()
}
1973
1974 func TestClientCloseIdleConnections(t *testing.T) {
1975 c := &Client{Transport: roundTripperWithoutCloseIdle{}}
1976 c.CloseIdleConnections()
1977
1978 closed := false
1979 var tr RoundTripper = roundTripperWithCloseIdle(func() {
1980 closed = true
1981 })
1982 c = &Client{Transport: tr}
1983 c.CloseIdleConnections()
1984 if !closed {
1985 t.Error("not closed")
1986 }
1987 }
1988
// testRoundTripper adapts an ordinary function to the RoundTripper interface.
type testRoundTripper func(*Request) (*Response, error)

// RoundTrip passes req to the underlying function and returns its results.
func (fn testRoundTripper) RoundTrip(req *Request) (*Response, error) {
	return fn(req)
}
1994
1995 func TestClientPropagatesTimeoutToContext(t *testing.T) {
1996 c := &Client{
1997 Timeout: 5 * time.Second,
1998 Transport: testRoundTripper(func(req *Request) (*Response, error) {
1999 ctx := req.Context()
2000 deadline, ok := ctx.Deadline()
2001 if !ok {
2002 t.Error("no deadline")
2003 } else {
2004 t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10))
2005 }
2006 return nil, errors.New("not actually making a request")
2007 }),
2008 }
2009 c.Get("https://example.tld/")
2010 }
2011
2012
2013
2014 func TestClientDoCanceledVsTimeout(t *testing.T) { run(t, testClientDoCanceledVsTimeout) }
2015 func testClientDoCanceledVsTimeout(t *testing.T, mode testMode) {
2016 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2017 w.Write([]byte("Hello, World!"))
2018 }))
2019
2020 cases := []string{"timeout", "canceled"}
2021
2022 for _, name := range cases {
2023 t.Run(name, func(t *testing.T) {
2024 var ctx context.Context
2025 var cancel func()
2026 if name == "timeout" {
2027 ctx, cancel = context.WithTimeout(context.Background(), -time.Nanosecond)
2028 } else {
2029 ctx, cancel = context.WithCancel(context.Background())
2030 cancel()
2031 }
2032 defer cancel()
2033
2034 req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
2035 _, err := cst.c.Do(req)
2036 if err == nil {
2037 t.Fatal("Unexpectedly got a nil error")
2038 }
2039
2040 ue := err.(*url.Error)
2041
2042 var wantIsTimeout bool
2043 var wantErr error = context.Canceled
2044 if name == "timeout" {
2045 wantErr = context.DeadlineExceeded
2046 wantIsTimeout = true
2047 }
2048 if g, w := ue.Timeout(), wantIsTimeout; g != w {
2049 t.Fatalf("url.Timeout() = %t, want %t", g, w)
2050 }
2051 if g, w := ue.Err, wantErr; g != w {
2052 t.Errorf("url.Error.Err = %v; want %v", g, w)
2053 }
2054 if got := errors.Is(err, context.DeadlineExceeded); got != wantIsTimeout {
2055 t.Errorf("errors.Is(err, context.DeadlineExceeded) = %v, want %v", got, wantIsTimeout)
2056 }
2057 })
2058 }
2059 }
2060
// nilBodyRoundTripper is a RoundTripper that answers every request with a
// 200 OK response whose Body is nil and whose Request is the given request.
type nilBodyRoundTripper struct{}

func (nilBodyRoundTripper) RoundTrip(req *Request) (*Response, error) {
	res := &Response{Request: req}
	res.StatusCode = StatusOK
	res.Status = StatusText(res.StatusCode)
	// Body is deliberately left nil.
	return res, nil
}
2071
2072 func TestClientPopulatesNilResponseBody(t *testing.T) {
2073 c := &Client{Transport: nilBodyRoundTripper{}}
2074
2075 resp, err := c.Get("http://localhost/anything")
2076 if err != nil {
2077 t.Fatalf("Client.Get rejected Response with nil Body: %v", err)
2078 }
2079
2080 if resp.Body == nil {
2081 t.Fatalf("Client failed to provide a non-nil Body as documented")
2082 }
2083 defer func() {
2084 if err := resp.Body.Close(); err != nil {
2085 t.Fatalf("error from Close on substitute Response.Body: %v", err)
2086 }
2087 }()
2088
2089 if b, err := io.ReadAll(resp.Body); err != nil {
2090 t.Errorf("read error from substitute Response.Body: %v", err)
2091 } else if len(b) != 0 {
2092 t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b)
2093 }
2094 }
2095
2096
2097 func TestClientCallsCloseOnlyOnce(t *testing.T) { run(t, testClientCallsCloseOnlyOnce) }
2098 func testClientCallsCloseOnlyOnce(t *testing.T, mode testMode) {
2099 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2100 w.WriteHeader(StatusNoContent)
2101 }))
2102
2103
2104
2105 for i := 0; i < 50 && !t.Failed(); i++ {
2106 body := &issue40382Body{t: t, n: 300000}
2107 req, err := NewRequest(MethodPost, cst.ts.URL, body)
2108 if err != nil {
2109 t.Fatal(err)
2110 }
2111 resp, err := cst.tr.RoundTrip(req)
2112 if err != nil {
2113 t.Fatal(err)
2114 }
2115 resp.Body.Close()
2116 }
2117 }
2118
2119
2120
2121
// issue40382Body is a request body that produces n bytes of 'x' and then
// io.EOF. Its Close reports a test error through t on the second call.
type issue40382Body struct {
	t                *testing.T
	n                int
	closeCallsAtomic int32
}

// Read fills p with 'x', never returning more bytes than remain in n.
func (b *issue40382Body) Read(p []byte) (int, error) {
	if b.n == 0 {
		return 0, io.EOF
	}
	if len(p) > b.n {
		p = p[:b.n]
	}
	for i := range p {
		p[i] = 'x'
	}
	b.n -= len(p)
	return len(p), nil
}

// Close counts calls atomically and flags the test when the count reaches two.
func (b *issue40382Body) Close() error {
	if calls := atomic.AddInt32(&b.closeCallsAtomic, 1); calls == 2 {
		b.t.Error("Body closed more than once")
	}
	return nil
}
2150
2151 func TestProbeZeroLengthBody(t *testing.T) { run(t, testProbeZeroLengthBody) }
2152 func testProbeZeroLengthBody(t *testing.T, mode testMode) {
2153 reqc := make(chan struct{})
2154 cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
2155 close(reqc)
2156 if _, err := io.Copy(w, r.Body); err != nil {
2157 t.Errorf("error copying request body: %v", err)
2158 }
2159 }))
2160
2161 bodyr, bodyw := io.Pipe()
2162 var gotBody string
2163 var wg sync.WaitGroup
2164 wg.Add(1)
2165 go func() {
2166 defer wg.Done()
2167 req, _ := NewRequest("GET", cst.ts.URL, bodyr)
2168 res, err := cst.c.Do(req)
2169 b, err := io.ReadAll(res.Body)
2170 if err != nil {
2171 t.Error(err)
2172 }
2173 gotBody = string(b)
2174 }()
2175
2176 select {
2177 case <-reqc:
2178
2179 case <-time.After(60 * time.Second):
2180 t.Errorf("request not sent after 60s")
2181 }
2182
2183
2184 const content = "body"
2185 bodyw.Write([]byte(content))
2186 bodyw.Close()
2187 wg.Wait()
2188 if gotBody != content {
2189 t.Fatalf("server got body %q, want %q", gotBody, content)
2190 }
2191 }
2192
View as plain text