Source file src/net/http/internal/http2/server_push_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2_test
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"reflect"
    13  	"strconv"
    14  	"testing"
    15  	"testing/synctest"
    16  	"time"
    17  
    18  	. "net/http/internal/http2"
    19  )
    20  
    21  func TestServer_Push_Success(t *testing.T) { synctestTest(t, testServer_Push_Success) }
    22  func testServer_Push_Success(t testing.TB) {
    23  	const (
    24  		mainBody   = "<html>index page</html>"
    25  		pushedBody = "<html>pushed page</html>"
    26  		userAgent  = "testagent"
    27  		cookie     = "testcookie"
    28  	)
    29  
    30  	var stURL string
    31  	checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
    32  		if got, want := r.Method, wantMethod; got != want {
    33  			return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
    34  		}
    35  		if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
    36  			return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
    37  		}
    38  		if got, want := "https://"+r.Host, stURL; got != want {
    39  			return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
    40  		}
    41  		if r.Body == nil {
    42  			return fmt.Errorf("nil Body")
    43  		}
    44  		if buf, err := io.ReadAll(r.Body); err != nil || len(buf) != 0 {
    45  			return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
    46  		}
    47  		return nil
    48  	}
    49  
    50  	errc := make(chan error, 3)
    51  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
    52  		switch r.URL.RequestURI() {
    53  		case "/":
    54  			// Push "/pushed?get" as a GET request, using an absolute URL.
    55  			opt := &http.PushOptions{
    56  				Header: http.Header{
    57  					"User-Agent": {userAgent},
    58  				},
    59  			}
    60  			if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
    61  				errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
    62  				return
    63  			}
    64  			// Push "/pushed?head" as a HEAD request, using a path.
    65  			opt = &http.PushOptions{
    66  				Method: "HEAD",
    67  				Header: http.Header{
    68  					"User-Agent": {userAgent},
    69  					"Cookie":     {cookie},
    70  				},
    71  			}
    72  			if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
    73  				errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
    74  				return
    75  			}
    76  			w.Header().Set("Content-Type", "text/html")
    77  			w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
    78  			w.WriteHeader(200)
    79  			io.WriteString(w, mainBody)
    80  			errc <- nil
    81  
    82  		case "/pushed?get":
    83  			wantH := http.Header{}
    84  			wantH.Set("User-Agent", userAgent)
    85  			if err := checkPromisedReq(r, "GET", wantH); err != nil {
    86  				errc <- fmt.Errorf("/pushed?get: %v", err)
    87  				return
    88  			}
    89  			w.Header().Set("Content-Type", "text/html")
    90  			w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
    91  			w.WriteHeader(200)
    92  			io.WriteString(w, pushedBody)
    93  			errc <- nil
    94  
    95  		case "/pushed?head":
    96  			wantH := http.Header{}
    97  			wantH.Set("User-Agent", userAgent)
    98  			wantH.Set("Cookie", cookie)
    99  			if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
   100  				errc <- fmt.Errorf("/pushed?head: %v", err)
   101  				return
   102  			}
   103  			w.WriteHeader(204)
   104  			errc <- nil
   105  
   106  		default:
   107  			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   108  		}
   109  	})
   110  	stURL = "https://" + st.authority()
   111  
   112  	// Send one request, which should push two responses.
   113  	st.greet()
   114  	getSlash(st)
   115  	for k := 0; k < 3; k++ {
   116  		select {
   117  		case <-time.After(2 * time.Second):
   118  			t.Errorf("timeout waiting for handler %d to finish", k)
   119  		case err := <-errc:
   120  			if err != nil {
   121  				t.Fatal(err)
   122  			}
   123  		}
   124  	}
   125  
   126  	checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
   127  		pp, ok := f.(*PushPromiseFrame)
   128  		if !ok {
   129  			return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
   130  		}
   131  		if !pp.HeadersEnded() {
   132  			return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
   133  		}
   134  		if got, want := pp.PromiseID, promiseID; got != want {
   135  			return fmt.Errorf("got PromiseID %v; want %v", got, want)
   136  		}
   137  		gotH := st.decodeHeader(pp.HeaderBlockFragment())
   138  		if !reflect.DeepEqual(gotH, wantH) {
   139  			return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
   140  		}
   141  		return nil
   142  	}
   143  	checkHeaders := func(f Frame, wantH [][2]string) error {
   144  		hf, ok := f.(*HeadersFrame)
   145  		if !ok {
   146  			return fmt.Errorf("got a %T; want *HeadersFrame", f)
   147  		}
   148  		gotH := st.decodeHeader(hf.HeaderBlockFragment())
   149  		if !reflect.DeepEqual(gotH, wantH) {
   150  			return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
   151  		}
   152  		return nil
   153  	}
   154  	checkData := func(f Frame, wantData string) error {
   155  		df, ok := f.(*DataFrame)
   156  		if !ok {
   157  			return fmt.Errorf("got a %T; want *DataFrame", f)
   158  		}
   159  		if gotData := string(df.Data()); gotData != wantData {
   160  			return fmt.Errorf("got response data %q; want %q", gotData, wantData)
   161  		}
   162  		return nil
   163  	}
   164  
   165  	// Stream 1 has 2 PUSH_PROMISE + HEADERS + DATA
   166  	// Stream 2 has HEADERS + DATA
   167  	// Stream 4 has HEADERS
   168  	expected := map[uint32][]func(Frame) error{
   169  		1: {
   170  			func(f Frame) error {
   171  				return checkPushPromise(f, 2, [][2]string{
   172  					{":method", "GET"},
   173  					{":scheme", "https"},
   174  					{":authority", st.authority()},
   175  					{":path", "/pushed?get"},
   176  					{"user-agent", userAgent},
   177  				})
   178  			},
   179  			func(f Frame) error {
   180  				return checkPushPromise(f, 4, [][2]string{
   181  					{":method", "HEAD"},
   182  					{":scheme", "https"},
   183  					{":authority", st.authority()},
   184  					{":path", "/pushed?head"},
   185  					{"cookie", cookie},
   186  					{"user-agent", userAgent},
   187  				})
   188  			},
   189  			func(f Frame) error {
   190  				return checkHeaders(f, [][2]string{
   191  					{":status", "200"},
   192  					{"content-type", "text/html"},
   193  					{"content-length", strconv.Itoa(len(mainBody))},
   194  				})
   195  			},
   196  			func(f Frame) error {
   197  				return checkData(f, mainBody)
   198  			},
   199  		},
   200  		2: {
   201  			func(f Frame) error {
   202  				return checkHeaders(f, [][2]string{
   203  					{":status", "200"},
   204  					{"content-type", "text/html"},
   205  					{"content-length", strconv.Itoa(len(pushedBody))},
   206  				})
   207  			},
   208  			func(f Frame) error {
   209  				return checkData(f, pushedBody)
   210  			},
   211  		},
   212  		4: {
   213  			func(f Frame) error {
   214  				return checkHeaders(f, [][2]string{
   215  					{":status", "204"},
   216  				})
   217  			},
   218  		},
   219  	}
   220  
   221  	consumed := map[uint32]int{}
   222  	for k := 0; len(expected) > 0; k++ {
   223  		f := st.readFrame()
   224  		if f == nil {
   225  			for id, left := range expected {
   226  				t.Errorf("stream %d: missing %d frames", id, len(left))
   227  			}
   228  			break
   229  		}
   230  		id := f.Header().StreamID
   231  		label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
   232  		if len(expected[id]) == 0 {
   233  			t.Fatalf("%s: unexpected frame %#+v", label, f)
   234  		}
   235  		check := expected[id][0]
   236  		expected[id] = expected[id][1:]
   237  		if len(expected[id]) == 0 {
   238  			delete(expected, id)
   239  		}
   240  		if err := check(f); err != nil {
   241  			t.Fatalf("%s: %v", label, err)
   242  		}
   243  		consumed[id]++
   244  	}
   245  }
   246  
   247  func TestServer_Push_SuccessNoRace(t *testing.T) { synctestTest(t, testServer_Push_SuccessNoRace) }
   248  func testServer_Push_SuccessNoRace(t testing.TB) {
   249  	// Regression test for issue #18326. Ensure the request handler can mutate
   250  	// pushed request headers without racing with the PUSH_PROMISE write.
   251  	errc := make(chan error, 2)
   252  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   253  		switch r.URL.RequestURI() {
   254  		case "/":
   255  			opt := &http.PushOptions{
   256  				Header: http.Header{"User-Agent": {"testagent"}},
   257  			}
   258  			if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
   259  				errc <- fmt.Errorf("error pushing: %v", err)
   260  				return
   261  			}
   262  			w.WriteHeader(200)
   263  			errc <- nil
   264  
   265  		case "/pushed":
   266  			// Update request header, ensure there is no race.
   267  			r.Header.Set("User-Agent", "newagent")
   268  			r.Header.Set("Cookie", "cookie")
   269  			w.WriteHeader(200)
   270  			errc <- nil
   271  
   272  		default:
   273  			errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   274  		}
   275  	})
   276  
   277  	// Send one request, which should push one response.
   278  	st.greet()
   279  	getSlash(st)
   280  	for k := 0; k < 2; k++ {
   281  		select {
   282  		case <-time.After(2 * time.Second):
   283  			t.Errorf("timeout waiting for handler %d to finish", k)
   284  		case err := <-errc:
   285  			if err != nil {
   286  				t.Fatal(err)
   287  			}
   288  		}
   289  	}
   290  }
   291  
   292  func TestServer_Push_RejectRecursivePush(t *testing.T) {
   293  	synctestTest(t, testServer_Push_RejectRecursivePush)
   294  }
   295  func testServer_Push_RejectRecursivePush(t testing.TB) {
   296  	// Expect two requests, but might get three if there's a bug and the second push succeeds.
   297  	errc := make(chan error, 3)
   298  	handler := func(w http.ResponseWriter, r *http.Request) error {
   299  		baseURL := "https://" + r.Host
   300  		switch r.URL.Path {
   301  		case "/":
   302  			if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
   303  				return fmt.Errorf("first Push()=%v, want nil", err)
   304  			}
   305  			return nil
   306  
   307  		case "/push1":
   308  			if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
   309  				return fmt.Errorf("Push()=%v, want %v", got, want)
   310  			}
   311  			return nil
   312  
   313  		default:
   314  			return fmt.Errorf("unexpected path: %q", r.URL.Path)
   315  		}
   316  	}
   317  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   318  		errc <- handler(w, r)
   319  	})
   320  	defer st.Close()
   321  	st.greet()
   322  	getSlash(st)
   323  	if err := <-errc; err != nil {
   324  		t.Errorf("First request failed: %v", err)
   325  	}
   326  	if err := <-errc; err != nil {
   327  		t.Errorf("Second request failed: %v", err)
   328  	}
   329  }
   330  
   331  func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
   332  	synctestTest(t, func(t testing.TB) {
   333  		testServer_Push_RejectSingleRequest_Bubble(t, doPush, settings...)
   334  	})
   335  }
   336  func testServer_Push_RejectSingleRequest_Bubble(t testing.TB, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
   337  	// Expect one request, but might get two if there's a bug and the push succeeds.
   338  	errc := make(chan error, 2)
   339  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   340  		errc <- doPush(w.(http.Pusher), r)
   341  	})
   342  	defer st.Close()
   343  	st.greet()
   344  	if err := st.fr.WriteSettings(settings...); err != nil {
   345  		st.t.Fatalf("WriteSettings: %v", err)
   346  	}
   347  	st.wantSettingsAck()
   348  	getSlash(st)
   349  	if err := <-errc; err != nil {
   350  		t.Error(err)
   351  	}
   352  	// Should not get a PUSH_PROMISE frame.
   353  	st.wantHeaders(wantHeader{
   354  		streamID:  1,
   355  		endStream: true,
   356  	})
   357  }
   358  
   359  func TestServer_Push_RejectIfDisabled(t *testing.T) {
   360  	testServer_Push_RejectSingleRequest(t,
   361  		func(p http.Pusher, r *http.Request) error {
   362  			if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
   363  				return fmt.Errorf("Push()=%v, want %v", got, want)
   364  			}
   365  			return nil
   366  		},
   367  		Setting{SettingEnablePush, 0})
   368  }
   369  
   370  func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
   371  	testServer_Push_RejectSingleRequest(t,
   372  		func(p http.Pusher, r *http.Request) error {
   373  			if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
   374  				return fmt.Errorf("Push()=%v, want %v", got, want)
   375  			}
   376  			return nil
   377  		},
   378  		Setting{SettingMaxConcurrentStreams, 0})
   379  }
   380  
   381  func TestServer_Push_RejectWrongScheme(t *testing.T) {
   382  	testServer_Push_RejectSingleRequest(t,
   383  		func(p http.Pusher, r *http.Request) error {
   384  			if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
   385  				return errors.New("Push() should have failed (push target URL is http)")
   386  			}
   387  			return nil
   388  		})
   389  }
   390  
   391  func TestServer_Push_RejectMissingHost(t *testing.T) {
   392  	testServer_Push_RejectSingleRequest(t,
   393  		func(p http.Pusher, r *http.Request) error {
   394  			if err := p.Push("https:pushed", nil); err == nil {
   395  				return errors.New("Push() should have failed (push target URL missing host)")
   396  			}
   397  			return nil
   398  		})
   399  }
   400  
   401  func TestServer_Push_RejectRelativePath(t *testing.T) {
   402  	testServer_Push_RejectSingleRequest(t,
   403  		func(p http.Pusher, r *http.Request) error {
   404  			if err := p.Push("../test", nil); err == nil {
   405  				return errors.New("Push() should have failed (push target is a relative path)")
   406  			}
   407  			return nil
   408  		})
   409  }
   410  
   411  func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
   412  	testServer_Push_RejectSingleRequest(t,
   413  		func(p http.Pusher, r *http.Request) error {
   414  			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
   415  				return errors.New("Push() should have failed (cannot promise a POST)")
   416  			}
   417  			return nil
   418  		})
   419  }
   420  
   421  func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
   422  	testServer_Push_RejectSingleRequest(t,
   423  		func(p http.Pusher, r *http.Request) error {
   424  			header := http.Header{
   425  				"Content-Length":   {"10"},
   426  				"Content-Encoding": {"gzip"},
   427  				"Trailer":          {"Foo"},
   428  				"Te":               {"trailers"},
   429  				"Host":             {"test.com"},
   430  				":authority":       {"test.com"},
   431  			}
   432  			if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
   433  				return errors.New("Push() should have failed (forbidden headers)")
   434  			}
   435  			return nil
   436  		})
   437  }
   438  
   439  func TestServer_Push_StateTransitions(t *testing.T) {
   440  	synctestTest(t, testServer_Push_StateTransitions)
   441  }
   442  func testServer_Push_StateTransitions(t testing.TB) {
   443  	const body = "foo"
   444  
   445  	gotPromise := make(chan bool)
   446  	finishedPush := make(chan bool)
   447  
   448  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   449  		switch r.URL.RequestURI() {
   450  		case "/":
   451  			if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
   452  				t.Errorf("Push error: %v", err)
   453  			}
   454  			// Don't finish this request until the push finishes so we don't
   455  			// nondeterministically interleave output frames with the push.
   456  			<-finishedPush
   457  		case "/pushed":
   458  			<-gotPromise
   459  		}
   460  		w.Header().Set("Content-Type", "text/html")
   461  		w.Header().Set("Content-Length", strconv.Itoa(len(body)))
   462  		w.WriteHeader(200)
   463  		io.WriteString(w, body)
   464  	})
   465  	defer st.Close()
   466  
   467  	st.greet()
   468  	if st.streamExists(2) {
   469  		t.Fatal("stream 2 should be empty")
   470  	}
   471  	if got, want := st.streamState(2), StateIdle; got != want {
   472  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   473  	}
   474  	getSlash(st)
   475  	// After the PUSH_PROMISE is sent, the stream should be stateHalfClosedRemote.
   476  	_ = readFrame[*PushPromiseFrame](t, st)
   477  	if got, want := st.streamState(2), StateHalfClosedRemote; got != want {
   478  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   479  	}
   480  	// We stall the HTTP handler for "/pushed" until the above check. If we don't
   481  	// stall the handler, then the handler might write HEADERS and DATA and finish
   482  	// the stream before we check st.streamState(2) -- should that happen, we'll
   483  	// see stateClosed and fail the above check.
   484  	close(gotPromise)
   485  	st.wantHeaders(wantHeader{
   486  		streamID:  2,
   487  		endStream: false,
   488  	})
   489  	if got, want := st.streamState(2), StateClosed; got != want {
   490  		t.Fatalf("streamState(2)=%v, want %v", got, want)
   491  	}
   492  	close(finishedPush)
   493  }
   494  
   495  func TestServer_Push_RejectAfterGoAway(t *testing.T) {
   496  	synctestTest(t, testServer_Push_RejectAfterGoAway)
   497  }
   498  func testServer_Push_RejectAfterGoAway(t testing.TB) {
   499  	ready := make(chan struct{})
   500  	errc := make(chan error, 2)
   501  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   502  		<-ready
   503  		if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
   504  			errc <- fmt.Errorf("Push()=%v, want %v", got, want)
   505  		}
   506  		errc <- nil
   507  	})
   508  	defer st.Close()
   509  	st.greet()
   510  	getSlash(st)
   511  
   512  	// Send GOAWAY and wait for it to be processed.
   513  	st.fr.WriteGoAway(1, ErrCodeNo, nil)
   514  	synctest.Wait()
   515  	close(ready)
   516  	if err := <-errc; err != nil {
   517  		t.Error(err)
   518  	}
   519  }
   520  
   521  func TestServer_Push_Underflow(t *testing.T) { synctestTest(t, testServer_Push_Underflow) }
   522  func testServer_Push_Underflow(t testing.TB) {
   523  	// Test for #63511: Send several requests which generate PUSH_PROMISE responses,
   524  	// verify they all complete successfully.
   525  	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
   526  		switch r.URL.RequestURI() {
   527  		case "/":
   528  			opt := &http.PushOptions{
   529  				Header: http.Header{"User-Agent": {"testagent"}},
   530  			}
   531  			if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
   532  				t.Errorf("error pushing: %v", err)
   533  			}
   534  			w.WriteHeader(200)
   535  		case "/pushed":
   536  			r.Header.Set("User-Agent", "newagent")
   537  			r.Header.Set("Cookie", "cookie")
   538  			w.WriteHeader(200)
   539  		default:
   540  			t.Errorf("unknown RequestURL %q", r.URL.RequestURI())
   541  		}
   542  	})
   543  	// Send several requests.
   544  	st.greet()
   545  	const numRequests = 4
   546  	for i := 0; i < numRequests; i++ {
   547  		st.writeHeaders(HeadersFrameParam{
   548  			StreamID:      uint32(1 + i*2), // clients send odd numbers
   549  			BlockFragment: st.encodeHeader(),
   550  			EndStream:     true,
   551  			EndHeaders:    true,
   552  		})
   553  	}
   554  	// Each request should result in one PUSH_PROMISE and two responses.
   555  	numPushPromises := 0
   556  	numHeaders := 0
   557  	for numHeaders < numRequests*2 || numPushPromises < numRequests {
   558  		f := st.readFrame()
   559  		if f == nil {
   560  			st.t.Fatal("conn is idle, want frame")
   561  		}
   562  		switch f := f.(type) {
   563  		case *HeadersFrame:
   564  			if !f.Flags.Has(FlagHeadersEndStream) {
   565  				t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f)
   566  			}
   567  			numHeaders++
   568  		case *PushPromiseFrame:
   569  			numPushPromises++
   570  		}
   571  	}
   572  }
   573  

View as plain text