1
2
3
4
5 package http2_test
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "reflect"
13 "strconv"
14 "testing"
15 "testing/synctest"
16 "time"
17
18 . "net/http/internal/http2"
19 )
20
21 func TestServer_Push_Success(t *testing.T) { synctestTest(t, testServer_Push_Success) }
22 func testServer_Push_Success(t testing.TB) {
23 const (
24 mainBody = "<html>index page</html>"
25 pushedBody = "<html>pushed page</html>"
26 userAgent = "testagent"
27 cookie = "testcookie"
28 )
29
30 var stURL string
31 checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
32 if got, want := r.Method, wantMethod; got != want {
33 return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
34 }
35 if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
36 return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
37 }
38 if got, want := "https://"+r.Host, stURL; got != want {
39 return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
40 }
41 if r.Body == nil {
42 return fmt.Errorf("nil Body")
43 }
44 if buf, err := io.ReadAll(r.Body); err != nil || len(buf) != 0 {
45 return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
46 }
47 return nil
48 }
49
50 errc := make(chan error, 3)
51 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
52 switch r.URL.RequestURI() {
53 case "/":
54
55 opt := &http.PushOptions{
56 Header: http.Header{
57 "User-Agent": {userAgent},
58 },
59 }
60 if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
61 errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
62 return
63 }
64
65 opt = &http.PushOptions{
66 Method: "HEAD",
67 Header: http.Header{
68 "User-Agent": {userAgent},
69 "Cookie": {cookie},
70 },
71 }
72 if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
73 errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
74 return
75 }
76 w.Header().Set("Content-Type", "text/html")
77 w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
78 w.WriteHeader(200)
79 io.WriteString(w, mainBody)
80 errc <- nil
81
82 case "/pushed?get":
83 wantH := http.Header{}
84 wantH.Set("User-Agent", userAgent)
85 if err := checkPromisedReq(r, "GET", wantH); err != nil {
86 errc <- fmt.Errorf("/pushed?get: %v", err)
87 return
88 }
89 w.Header().Set("Content-Type", "text/html")
90 w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
91 w.WriteHeader(200)
92 io.WriteString(w, pushedBody)
93 errc <- nil
94
95 case "/pushed?head":
96 wantH := http.Header{}
97 wantH.Set("User-Agent", userAgent)
98 wantH.Set("Cookie", cookie)
99 if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
100 errc <- fmt.Errorf("/pushed?head: %v", err)
101 return
102 }
103 w.WriteHeader(204)
104 errc <- nil
105
106 default:
107 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
108 }
109 })
110 stURL = "https://" + st.authority()
111
112
113 st.greet()
114 getSlash(st)
115 for k := 0; k < 3; k++ {
116 select {
117 case <-time.After(2 * time.Second):
118 t.Errorf("timeout waiting for handler %d to finish", k)
119 case err := <-errc:
120 if err != nil {
121 t.Fatal(err)
122 }
123 }
124 }
125
126 checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
127 pp, ok := f.(*PushPromiseFrame)
128 if !ok {
129 return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
130 }
131 if !pp.HeadersEnded() {
132 return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
133 }
134 if got, want := pp.PromiseID, promiseID; got != want {
135 return fmt.Errorf("got PromiseID %v; want %v", got, want)
136 }
137 gotH := st.decodeHeader(pp.HeaderBlockFragment())
138 if !reflect.DeepEqual(gotH, wantH) {
139 return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
140 }
141 return nil
142 }
143 checkHeaders := func(f Frame, wantH [][2]string) error {
144 hf, ok := f.(*HeadersFrame)
145 if !ok {
146 return fmt.Errorf("got a %T; want *HeadersFrame", f)
147 }
148 gotH := st.decodeHeader(hf.HeaderBlockFragment())
149 if !reflect.DeepEqual(gotH, wantH) {
150 return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
151 }
152 return nil
153 }
154 checkData := func(f Frame, wantData string) error {
155 df, ok := f.(*DataFrame)
156 if !ok {
157 return fmt.Errorf("got a %T; want *DataFrame", f)
158 }
159 if gotData := string(df.Data()); gotData != wantData {
160 return fmt.Errorf("got response data %q; want %q", gotData, wantData)
161 }
162 return nil
163 }
164
165
166
167
168 expected := map[uint32][]func(Frame) error{
169 1: {
170 func(f Frame) error {
171 return checkPushPromise(f, 2, [][2]string{
172 {":method", "GET"},
173 {":scheme", "https"},
174 {":authority", st.authority()},
175 {":path", "/pushed?get"},
176 {"user-agent", userAgent},
177 })
178 },
179 func(f Frame) error {
180 return checkPushPromise(f, 4, [][2]string{
181 {":method", "HEAD"},
182 {":scheme", "https"},
183 {":authority", st.authority()},
184 {":path", "/pushed?head"},
185 {"cookie", cookie},
186 {"user-agent", userAgent},
187 })
188 },
189 func(f Frame) error {
190 return checkHeaders(f, [][2]string{
191 {":status", "200"},
192 {"content-type", "text/html"},
193 {"content-length", strconv.Itoa(len(mainBody))},
194 })
195 },
196 func(f Frame) error {
197 return checkData(f, mainBody)
198 },
199 },
200 2: {
201 func(f Frame) error {
202 return checkHeaders(f, [][2]string{
203 {":status", "200"},
204 {"content-type", "text/html"},
205 {"content-length", strconv.Itoa(len(pushedBody))},
206 })
207 },
208 func(f Frame) error {
209 return checkData(f, pushedBody)
210 },
211 },
212 4: {
213 func(f Frame) error {
214 return checkHeaders(f, [][2]string{
215 {":status", "204"},
216 })
217 },
218 },
219 }
220
221 consumed := map[uint32]int{}
222 for k := 0; len(expected) > 0; k++ {
223 f := st.readFrame()
224 if f == nil {
225 for id, left := range expected {
226 t.Errorf("stream %d: missing %d frames", id, len(left))
227 }
228 break
229 }
230 id := f.Header().StreamID
231 label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
232 if len(expected[id]) == 0 {
233 t.Fatalf("%s: unexpected frame %#+v", label, f)
234 }
235 check := expected[id][0]
236 expected[id] = expected[id][1:]
237 if len(expected[id]) == 0 {
238 delete(expected, id)
239 }
240 if err := check(f); err != nil {
241 t.Fatalf("%s: %v", label, err)
242 }
243 consumed[id]++
244 }
245 }
246
247 func TestServer_Push_SuccessNoRace(t *testing.T) { synctestTest(t, testServer_Push_SuccessNoRace) }
248 func testServer_Push_SuccessNoRace(t testing.TB) {
249
250
251 errc := make(chan error, 2)
252 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
253 switch r.URL.RequestURI() {
254 case "/":
255 opt := &http.PushOptions{
256 Header: http.Header{"User-Agent": {"testagent"}},
257 }
258 if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
259 errc <- fmt.Errorf("error pushing: %v", err)
260 return
261 }
262 w.WriteHeader(200)
263 errc <- nil
264
265 case "/pushed":
266
267 r.Header.Set("User-Agent", "newagent")
268 r.Header.Set("Cookie", "cookie")
269 w.WriteHeader(200)
270 errc <- nil
271
272 default:
273 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
274 }
275 })
276
277
278 st.greet()
279 getSlash(st)
280 for k := 0; k < 2; k++ {
281 select {
282 case <-time.After(2 * time.Second):
283 t.Errorf("timeout waiting for handler %d to finish", k)
284 case err := <-errc:
285 if err != nil {
286 t.Fatal(err)
287 }
288 }
289 }
290 }
291
292 func TestServer_Push_RejectRecursivePush(t *testing.T) {
293 synctestTest(t, testServer_Push_RejectRecursivePush)
294 }
295 func testServer_Push_RejectRecursivePush(t testing.TB) {
296
297 errc := make(chan error, 3)
298 handler := func(w http.ResponseWriter, r *http.Request) error {
299 baseURL := "https://" + r.Host
300 switch r.URL.Path {
301 case "/":
302 if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
303 return fmt.Errorf("first Push()=%v, want nil", err)
304 }
305 return nil
306
307 case "/push1":
308 if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
309 return fmt.Errorf("Push()=%v, want %v", got, want)
310 }
311 return nil
312
313 default:
314 return fmt.Errorf("unexpected path: %q", r.URL.Path)
315 }
316 }
317 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
318 errc <- handler(w, r)
319 })
320 defer st.Close()
321 st.greet()
322 getSlash(st)
323 if err := <-errc; err != nil {
324 t.Errorf("First request failed: %v", err)
325 }
326 if err := <-errc; err != nil {
327 t.Errorf("Second request failed: %v", err)
328 }
329 }
330
331 func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
332 synctestTest(t, func(t testing.TB) {
333 testServer_Push_RejectSingleRequest_Bubble(t, doPush, settings...)
334 })
335 }
336 func testServer_Push_RejectSingleRequest_Bubble(t testing.TB, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
337
338 errc := make(chan error, 2)
339 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
340 errc <- doPush(w.(http.Pusher), r)
341 })
342 defer st.Close()
343 st.greet()
344 if err := st.fr.WriteSettings(settings...); err != nil {
345 st.t.Fatalf("WriteSettings: %v", err)
346 }
347 st.wantSettingsAck()
348 getSlash(st)
349 if err := <-errc; err != nil {
350 t.Error(err)
351 }
352
353 st.wantHeaders(wantHeader{
354 streamID: 1,
355 endStream: true,
356 })
357 }
358
359 func TestServer_Push_RejectIfDisabled(t *testing.T) {
360 testServer_Push_RejectSingleRequest(t,
361 func(p http.Pusher, r *http.Request) error {
362 if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
363 return fmt.Errorf("Push()=%v, want %v", got, want)
364 }
365 return nil
366 },
367 Setting{SettingEnablePush, 0})
368 }
369
370 func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
371 testServer_Push_RejectSingleRequest(t,
372 func(p http.Pusher, r *http.Request) error {
373 if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
374 return fmt.Errorf("Push()=%v, want %v", got, want)
375 }
376 return nil
377 },
378 Setting{SettingMaxConcurrentStreams, 0})
379 }
380
381 func TestServer_Push_RejectWrongScheme(t *testing.T) {
382 testServer_Push_RejectSingleRequest(t,
383 func(p http.Pusher, r *http.Request) error {
384 if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
385 return errors.New("Push() should have failed (push target URL is http)")
386 }
387 return nil
388 })
389 }
390
391 func TestServer_Push_RejectMissingHost(t *testing.T) {
392 testServer_Push_RejectSingleRequest(t,
393 func(p http.Pusher, r *http.Request) error {
394 if err := p.Push("https:pushed", nil); err == nil {
395 return errors.New("Push() should have failed (push target URL missing host)")
396 }
397 return nil
398 })
399 }
400
401 func TestServer_Push_RejectRelativePath(t *testing.T) {
402 testServer_Push_RejectSingleRequest(t,
403 func(p http.Pusher, r *http.Request) error {
404 if err := p.Push("../test", nil); err == nil {
405 return errors.New("Push() should have failed (push target is a relative path)")
406 }
407 return nil
408 })
409 }
410
411 func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
412 testServer_Push_RejectSingleRequest(t,
413 func(p http.Pusher, r *http.Request) error {
414 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
415 return errors.New("Push() should have failed (cannot promise a POST)")
416 }
417 return nil
418 })
419 }
420
421 func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
422 testServer_Push_RejectSingleRequest(t,
423 func(p http.Pusher, r *http.Request) error {
424 header := http.Header{
425 "Content-Length": {"10"},
426 "Content-Encoding": {"gzip"},
427 "Trailer": {"Foo"},
428 "Te": {"trailers"},
429 "Host": {"test.com"},
430 ":authority": {"test.com"},
431 }
432 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
433 return errors.New("Push() should have failed (forbidden headers)")
434 }
435 return nil
436 })
437 }
438
439 func TestServer_Push_StateTransitions(t *testing.T) {
440 synctestTest(t, testServer_Push_StateTransitions)
441 }
442 func testServer_Push_StateTransitions(t testing.TB) {
443 const body = "foo"
444
445 gotPromise := make(chan bool)
446 finishedPush := make(chan bool)
447
448 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
449 switch r.URL.RequestURI() {
450 case "/":
451 if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
452 t.Errorf("Push error: %v", err)
453 }
454
455
456 <-finishedPush
457 case "/pushed":
458 <-gotPromise
459 }
460 w.Header().Set("Content-Type", "text/html")
461 w.Header().Set("Content-Length", strconv.Itoa(len(body)))
462 w.WriteHeader(200)
463 io.WriteString(w, body)
464 })
465 defer st.Close()
466
467 st.greet()
468 if st.streamExists(2) {
469 t.Fatal("stream 2 should be empty")
470 }
471 if got, want := st.streamState(2), StateIdle; got != want {
472 t.Fatalf("streamState(2)=%v, want %v", got, want)
473 }
474 getSlash(st)
475
476 _ = readFrame[*PushPromiseFrame](t, st)
477 if got, want := st.streamState(2), StateHalfClosedRemote; got != want {
478 t.Fatalf("streamState(2)=%v, want %v", got, want)
479 }
480
481
482
483
484 close(gotPromise)
485 st.wantHeaders(wantHeader{
486 streamID: 2,
487 endStream: false,
488 })
489 if got, want := st.streamState(2), StateClosed; got != want {
490 t.Fatalf("streamState(2)=%v, want %v", got, want)
491 }
492 close(finishedPush)
493 }
494
495 func TestServer_Push_RejectAfterGoAway(t *testing.T) {
496 synctestTest(t, testServer_Push_RejectAfterGoAway)
497 }
498 func testServer_Push_RejectAfterGoAway(t testing.TB) {
499 ready := make(chan struct{})
500 errc := make(chan error, 2)
501 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
502 <-ready
503 if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
504 errc <- fmt.Errorf("Push()=%v, want %v", got, want)
505 }
506 errc <- nil
507 })
508 defer st.Close()
509 st.greet()
510 getSlash(st)
511
512
513 st.fr.WriteGoAway(1, ErrCodeNo, nil)
514 synctest.Wait()
515 close(ready)
516 if err := <-errc; err != nil {
517 t.Error(err)
518 }
519 }
520
521 func TestServer_Push_Underflow(t *testing.T) { synctestTest(t, testServer_Push_Underflow) }
522 func testServer_Push_Underflow(t testing.TB) {
523
524
525 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
526 switch r.URL.RequestURI() {
527 case "/":
528 opt := &http.PushOptions{
529 Header: http.Header{"User-Agent": {"testagent"}},
530 }
531 if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
532 t.Errorf("error pushing: %v", err)
533 }
534 w.WriteHeader(200)
535 case "/pushed":
536 r.Header.Set("User-Agent", "newagent")
537 r.Header.Set("Cookie", "cookie")
538 w.WriteHeader(200)
539 default:
540 t.Errorf("unknown RequestURL %q", r.URL.RequestURI())
541 }
542 })
543
544 st.greet()
545 const numRequests = 4
546 for i := 0; i < numRequests; i++ {
547 st.writeHeaders(HeadersFrameParam{
548 StreamID: uint32(1 + i*2),
549 BlockFragment: st.encodeHeader(),
550 EndStream: true,
551 EndHeaders: true,
552 })
553 }
554
555 numPushPromises := 0
556 numHeaders := 0
557 for numHeaders < numRequests*2 || numPushPromises < numRequests {
558 f := st.readFrame()
559 if f == nil {
560 st.t.Fatal("conn is idle, want frame")
561 }
562 switch f := f.(type) {
563 case *HeadersFrame:
564 if !f.Flags.Has(FlagHeadersEndStream) {
565 t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f)
566 }
567 numHeaders++
568 case *PushPromiseFrame:
569 numPushPromises++
570 }
571 }
572 }
573
View as plain text