Source file src/net/http/internal/http2/clientconn_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Infrastructure for testing ClientConn.RoundTrip.
     6  // Put actual tests in transport_test.go.
     7  
     8  package http2_test
     9  
    10  import (
    11  	"bytes"
    12  	"context"
    13  	"crypto/tls"
    14  	"fmt"
    15  	"internal/gate"
    16  	"io"
    17  	"net"
    18  	"net/http"
    19  	. "net/http/internal/http2"
    20  	"reflect"
    21  	"sync/atomic"
    22  	"testing"
    23  	"testing/synctest"
    24  	"time"
    25  	_ "unsafe" // for go:linkname
    26  
    27  	"golang.org/x/net/http2/hpack"
    28  )
    29  
    30  // TestTestClientConn demonstrates usage of testClientConn.
    31  func TestTestClientConn(t *testing.T) { synctestTest(t, testTestClientConn) }
// testTestClientConn demonstrates usage of testClientConn: it sends a PUT
// with a 10-byte body and answers with an empty 200 response.
func testTestClientConn(t testing.TB) {
	// newTestClientConn creates a *ClientConn and surrounding test infrastructure.
	tc := newTestClientConn(t)

	// tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames,
	// sends a SETTINGS frame and a SETTINGS acknowledgement to the client,
	// and then reads the client's acknowledgement of the server's SETTINGS.
	//
	// Additional settings may be provided as optional parameters to greet.
	tc.greet()

	// Request bodies must either be constant (bytes.Buffer, strings.Reader)
	// or created with newRequestBody.
	body := tc.newRequestBody()
	body.writeBytes(10)         // 10 arbitrary bytes...
	body.closeWithError(io.EOF) // ...followed by EOF.

	// tc.roundTrip calls RoundTrip, but does not wait for it to return.
	// It returns a testRoundTrip.
	//
	// The NewRequest error is ignored because the method and URL are
	// constant and valid.
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	// tc has a number of methods to check for expected frames sent.
	// Here, we look for headers and the request body.
	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: false,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"PUT"},
			":path":      []string{"/"},
		},
	})
	// Expect 10 bytes of request body in DATA frames.
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: true,
		size:      10,
		multiple:  true,
	})

	// tc.writeHeaders sends a HEADERS frame back to the client.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})

	// Now that we've received headers, RoundTrip has finished.
	// testRoundTrip has various methods to examine the response,
	// or to fetch the response and/or error returned by RoundTrip.
	rt.wantStatus(200)
	// wantBody(nil) matches an empty body (it compares with bytes.Equal).
	rt.wantBody(nil)
}
    88  
// A testClientConn allows testing ClientConn.RoundTrip against a fake server.
//
// A test using testClientConn consists of:
//   - actions on the client (calling RoundTrip, making data available to Request.Body);
//   - validation of frames sent by the client to the server; and
//   - providing frames from the server to the client.
//
// testClientConn manages synchronization, so tests can generally be written as
// a linear sequence of actions and validations without additional synchronization.
//
// Fields: tr is the Transport that created cc, the ClientConn under test.
// netconn is the test-controlled end of the in-memory pipe whose other end
// cc uses. fr is a Framer reading from and writing to netconn, and the
// embedded testConnFramer wraps it. enc is an HPACK encoder writing into
// encbuf, which makeHeaderBlockFragment uses. roundtrips records every
// RoundTrip started with roundTrip.
type testClientConn struct {
	t testing.TB

	tr *Transport
	fr *Framer
	cc *ClientConn
	testConnFramer

	encbuf bytes.Buffer
	enc    *hpack.Encoder

	roundtrips []*testRoundTrip

	netconn *synctestNetConn
}
   113  
// newTestClientConnFromClientConn wraps cc, which was created by tr, in a
// testClientConn. cc's net.Conn is replaced with one end of an in-memory
// synctest pipe, and the returned testClientConn reads and writes frames on
// the other end. The test's end of the pipe is closed at test cleanup.
func newTestClientConnFromClientConn(t testing.TB, tr *Transport, cc *ClientConn) *testClientConn {
	tc := &testClientConn{
		t:  t,
		tr: tr,
		cc: cc,
	}

	// cli is the conn used by the client under test, srv is the side controlled by the test.
	// We replace the conn being used by the client (possibly a *tls.Conn) with a new one,
	// to avoid dealing with encryption in tests.
	cli, srv := synctestNetPipe()
	cc.TestSetNetConn(cli)

	// NOTE(review): a read deadline of "now" presumably makes the test's reads
	// fail immediately instead of blocking when the client has written
	// nothing. Confirm against synctestNetConn.
	srv.SetReadDeadline(time.Now())
	tc.netconn = srv
	tc.enc = hpack.NewEncoder(&tc.encbuf)
	tc.fr = NewFramer(srv, srv)
	tc.testConnFramer = testConnFramer{
		t:   t,
		fr:  tc.fr,
		dec: hpack.NewDecoder(InitialHeaderTableSize, nil),
	}
	// Accept frames of up to 10 MiB from the client.
	tc.fr.SetMaxReadFrameSize(10 << 20)
	t.Cleanup(func() {
		tc.closeWrite()
	})

	return tc
}
   143  
   144  func (tc *testClientConn) readClientPreface() {
   145  	tc.t.Helper()
   146  	// Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
   147  	buf := make([]byte, len(ClientPreface))
   148  	if _, err := io.ReadFull(tc.netconn, buf); err != nil {
   149  		tc.t.Fatalf("reading preface: %v", err)
   150  	}
   151  	if !bytes.Equal(buf, []byte(ClientPreface)) {
   152  		tc.t.Fatalf("client preface: %q, want %q", buf, ClientPreface)
   153  	}
   154  }
   155  
   156  func newTestClientConn(t testing.TB, opts ...any) *testClientConn {
   157  	t.Helper()
   158  
   159  	tt := newTestTransport(t, opts...)
   160  	const singleUse = false
   161  	tr := transportFromH1Transport(tt.tr1).(*Transport)
   162  	_, err := tr.TestNewClientConn(nil, singleUse, nil)
   163  	if err != nil {
   164  		t.Fatalf("newClientConn: %v", err)
   165  	}
   166  
   167  	return tt.getConn()
   168  }
   169  
   170  // hasFrame reports whether a frame is available to be read.
   171  func (tc *testClientConn) hasFrame() bool {
   172  	synctest.Wait()
   173  	return len(tc.netconn.Peek()) > 0
   174  }
   175  
   176  // isClosed reports whether the peer has closed the connection.
   177  func (tc *testClientConn) isClosed() bool {
   178  	synctest.Wait()
   179  	return tc.netconn.IsClosedByPeer()
   180  }
   181  
   182  // closeWrite causes the net.Conn used by the ClientConn to return a error
   183  // from Read calls.
   184  func (tc *testClientConn) closeWrite() {
   185  	tc.netconn.Close()
   186  }
   187  
   188  // closeWrite causes the net.Conn used by the ClientConn to return a error
   189  // from Write calls.
   190  func (tc *testClientConn) closeWriteWithError(err error) {
   191  	tc.netconn.loc.setReadError(io.EOF)
   192  	tc.netconn.loc.setWriteError(err)
   193  }
   194  
// testRequestBody is a Request.Body for use in tests.
//
// The test supplies its contents incrementally with Write, writeBytes, and
// closeWithError. Read waits on gate (via WaitAndLock), and unlock leaves
// the gate set only when bytes or an error are pending. The fields below
// gate are accessed only while the gate is held.
type testRequestBody struct {
	tc   *testClientConn
	gate gate.Gate

	// At most one of buf or bytes can be set at any given time:
	buf   bytes.Buffer // specific bytes to read from the body
	bytes int          // body contains this many arbitrary bytes

	err error // read error (comes after any available bytes)
}
   206  
   207  func (tc *testClientConn) newRequestBody() *testRequestBody {
   208  	b := &testRequestBody{
   209  		tc:   tc,
   210  		gate: gate.New(false),
   211  	}
   212  	return b
   213  }
   214  
   215  func (b *testRequestBody) unlock() {
   216  	b.gate.Unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
   217  }
   218  
   219  // Read is called by the ClientConn to read from a request body.
   220  func (b *testRequestBody) Read(p []byte) (n int, _ error) {
   221  	if err := b.gate.WaitAndLock(context.Background()); err != nil {
   222  		return 0, err
   223  	}
   224  	defer b.unlock()
   225  	switch {
   226  	case b.buf.Len() > 0:
   227  		return b.buf.Read(p)
   228  	case b.bytes > 0:
   229  		if len(p) > b.bytes {
   230  			p = p[:b.bytes]
   231  		}
   232  		b.bytes -= len(p)
   233  		for i := range p {
   234  			p[i] = 'A'
   235  		}
   236  		return len(p), nil
   237  	default:
   238  		return 0, b.err
   239  	}
   240  }
   241  
   242  // Close is called by the ClientConn when it is done reading from a request body.
   243  func (b *testRequestBody) Close() error {
   244  	return nil
   245  }
   246  
   247  // writeBytes adds n arbitrary bytes to the body.
   248  func (b *testRequestBody) writeBytes(n int) {
   249  	defer synctest.Wait()
   250  	b.gate.Lock()
   251  	defer b.unlock()
   252  	b.bytes += n
   253  	b.checkWrite()
   254  	synctest.Wait()
   255  }
   256  
// Write adds bytes to the body, to be returned verbatim by Read, and then
// waits for the client goroutines to settle. Mixing it with writeBytes, or
// calling it after closeWithError, fails the test (see checkWrite).
func (b *testRequestBody) Write(p []byte) (int, error) {
	// Deferred first so it runs last: after unlock has let Read proceed.
	defer synctest.Wait()
	b.gate.Lock()
	defer b.unlock()
	n, err := b.buf.Write(p)
	b.checkWrite()
	return n, err
}
   266  
   267  func (b *testRequestBody) checkWrite() {
   268  	if b.bytes > 0 && b.buf.Len() > 0 {
   269  		b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
   270  	}
   271  	if b.err != nil {
   272  		b.tc.t.Fatalf("can't write to request body after closeWithError")
   273  	}
   274  }
   275  
// closeWithError sets an error which will be returned by Read once any
// pending bytes have been consumed, and then waits for the client goroutines
// to settle.
func (b *testRequestBody) closeWithError(err error) {
	// Deferred first so it runs last: after unlock has let Read proceed.
	defer synctest.Wait()
	b.gate.Lock()
	defer b.unlock()
	b.err = err
}
   283  
// roundTrip starts a RoundTrip call.
//
// (Note that the RoundTrip won't complete until response headers are received,
// the request times out, or some other terminal condition is reached.)
//
// The request is given a cancelable context, and the cancel func is stored
// in the returned testRoundTrip. The RoundTrip runs in its own goroutine,
// and roundTrip calls synctest.Wait before returning. If the RoundTrip has
// completed by test cleanup, the response body (if any) is closed.
func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
	ctx, cancel := context.WithCancel(req.Context())
	req = req.WithContext(ctx)
	rt := &testRoundTrip{
		t:      tc.t,
		donec:  make(chan struct{}),
		cancel: cancel,
	}
	tc.roundtrips = append(tc.roundtrips, rt)
	go func() {
		// TODO: This duplicates too much of the net/http RoundTrip flow.
		// We need to do that here because many of the http2 Transport tests
		// rely on having a ClientConn to operate on.
		//
		// We should switch to using net/http.Transport.NewClientConn to create
		// single-target client connections, and move any http2 tests which
		// exercise pooling behavior into net/http.
		defer close(rt.donec)
		// cresp is allocated up front so a pointer to its Trailer field can be
		// passed to the ClientConn as ResTrailer.
		// NOTE(review): cresp.Trailer is overwritten with resp.Trailer below;
		// confirm which of the two the ClientConn is expected to populate.
		cresp := &http.Response{}
		creq := &ClientRequest{
			Context:       req.Context(),
			Method:        req.Method,
			URL:           req.URL,
			Header:        Header(req.Header),
			Trailer:       Header(req.Trailer),
			Body:          req.Body,
			Host:          req.Host,
			GetBody:       req.GetBody,
			ContentLength: req.ContentLength,
			Cancel:        req.Cancel,
			Close:         req.Close,
			ResTrailer:    (*Header)(&cresp.Trailer),
		}
		// The callback records the request's stream ID so that rt.streamID
		// can report it.
		resp, err := tc.cc.TestRoundTrip(creq, func(id uint32) {
			rt.id.Store(id)
		})
		rt.respErr = err
		if resp != nil {
			// Translate the http2 response into an *http.Response.
			cresp.Status = resp.Status + " " + http.StatusText(resp.StatusCode)
			cresp.StatusCode = resp.StatusCode
			cresp.Proto = "HTTP/2.0"
			cresp.ProtoMajor = 2
			cresp.ProtoMinor = 0
			cresp.ContentLength = resp.ContentLength
			cresp.Uncompressed = resp.Uncompressed
			cresp.Header = http.Header(resp.Header)
			cresp.Trailer = http.Header(resp.Trailer)
			cresp.Body = resp.Body
			cresp.TLS = resp.TLS
			cresp.Request = req
			rt.resp = cresp
		}
	}()
	synctest.Wait()

	tc.t.Cleanup(func() {
		// A RoundTrip still in progress has produced no body to close.
		if !rt.done() {
			return
		}
		res, _ := rt.result()
		if res != nil {
			res.Body.Close()
		}
	})

	return rt
}
   355  
// greet performs the server side of connection setup. It expects the
// client's initial SETTINGS and WINDOW_UPDATE frames, writes a SETTINGS
// frame (containing any provided settings) and a SETTINGS ACK, and then
// expects the client's SETTINGS acknowledgement.
func (tc *testClientConn) greet(settings ...Setting) {
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.writeSettings(settings...)
	tc.writeSettingsAck()
	tc.wantFrameType(FrameSettings) // acknowledgement
}
   363  
   364  // makeHeaderBlockFragment encodes headers in a form suitable for inclusion
   365  // in a HEADERS or CONTINUATION frame.
   366  //
   367  // It takes a list of alternating names and values.
   368  func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
   369  	if len(s)%2 != 0 {
   370  		tc.t.Fatalf("uneven list of header name/value pairs")
   371  	}
   372  	tc.encbuf.Reset()
   373  	for i := 0; i < len(s); i += 2 {
   374  		tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
   375  	}
   376  	return tc.encbuf.Bytes()
   377  }
   378  
   379  // inflowWindow returns the amount of inbound flow control available for a stream,
   380  // or for the connection if streamID is 0.
   381  func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
   382  	w, err := tc.cc.TestInflowWindow(streamID)
   383  	if err != nil {
   384  		tc.t.Error(err)
   385  	}
   386  	return w
   387  }
   388  
// testRoundTrip manages a RoundTrip in progress.
//
// donec is closed when RoundTrip returns, after which resp and respErr hold
// its results. id holds the request's stream ID once it is known; it stays
// zero for RoundTrips started with testTransport.roundTrip. cancel cancels
// the request's context.
type testRoundTrip struct {
	t       testing.TB
	resp    *http.Response
	respErr error
	donec   chan struct{}
	id      atomic.Uint32
	cancel  context.CancelFunc
}
   398  
   399  // streamID returns the HTTP/2 stream ID of the request.
   400  func (rt *testRoundTrip) streamID() uint32 {
   401  	id := rt.id.Load()
   402  	if id == 0 {
   403  		panic("stream ID unknown")
   404  	}
   405  	return id
   406  }
   407  
   408  // done reports whether RoundTrip has returned.
   409  func (rt *testRoundTrip) done() bool {
   410  	synctest.Wait()
   411  	select {
   412  	case <-rt.donec:
   413  		return true
   414  	default:
   415  		return false
   416  	}
   417  }
   418  
   419  // result returns the result of the RoundTrip.
   420  func (rt *testRoundTrip) result() (*http.Response, error) {
   421  	t := rt.t
   422  	t.Helper()
   423  	synctest.Wait()
   424  	select {
   425  	case <-rt.donec:
   426  	default:
   427  		t.Fatalf("RoundTrip is not done; want it to be")
   428  	}
   429  	return rt.resp, rt.respErr
   430  }
   431  
   432  // response returns the response of a successful RoundTrip.
   433  // If the RoundTrip unexpectedly failed, it calls t.Fatal.
   434  func (rt *testRoundTrip) response() *http.Response {
   435  	t := rt.t
   436  	t.Helper()
   437  	resp, err := rt.result()
   438  	if err != nil {
   439  		t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
   440  	}
   441  	if resp == nil {
   442  		t.Fatalf("RoundTrip returned nil *Response and nil error")
   443  	}
   444  	return resp
   445  }
   446  
   447  // err returns the (possibly nil) error result of RoundTrip.
   448  func (rt *testRoundTrip) err() error {
   449  	t := rt.t
   450  	t.Helper()
   451  	_, err := rt.result()
   452  	return err
   453  }
   454  
   455  // wantStatus indicates the expected response StatusCode.
   456  func (rt *testRoundTrip) wantStatus(want int) {
   457  	t := rt.t
   458  	t.Helper()
   459  	if got := rt.response().StatusCode; got != want {
   460  		t.Fatalf("got response status %v, want %v", got, want)
   461  	}
   462  }
   463  
   464  // readBody reads the contents of the response body.
   465  func (rt *testRoundTrip) readBody() ([]byte, error) {
   466  	t := rt.t
   467  	t.Helper()
   468  	return io.ReadAll(rt.response().Body)
   469  }
   470  
   471  // wantBody indicates the expected response body.
   472  // (Note that this consumes the body.)
   473  func (rt *testRoundTrip) wantBody(want []byte) {
   474  	t := rt.t
   475  	t.Helper()
   476  	got, err := rt.readBody()
   477  	if err != nil {
   478  		t.Fatalf("unexpected error reading response body: %v", err)
   479  	}
   480  	if !bytes.Equal(got, want) {
   481  		t.Fatalf("unexpected response body:\ngot:  %q\nwant: %q", got, want)
   482  	}
   483  }
   484  
   485  // wantHeaders indicates the expected response headers.
   486  func (rt *testRoundTrip) wantHeaders(want http.Header) {
   487  	t := rt.t
   488  	t.Helper()
   489  	res := rt.response()
   490  	if diff := diffHeaders(res.Header, want); diff != "" {
   491  		t.Fatalf("unexpected response headers:\n%v", diff)
   492  	}
   493  }
   494  
   495  // wantTrailers indicates the expected response trailers.
   496  func (rt *testRoundTrip) wantTrailers(want http.Header) {
   497  	t := rt.t
   498  	t.Helper()
   499  	res := rt.response()
   500  	if diff := diffHeaders(res.Trailer, want); diff != "" {
   501  		t.Fatalf("unexpected response trailers:\n%v", diff)
   502  	}
   503  }
   504  
   505  func diffHeaders(got, want http.Header) string {
   506  	// nil and 0-length non-nil are equal.
   507  	if len(got) == 0 && len(want) == 0 {
   508  		return ""
   509  	}
   510  	// We could do a more sophisticated diff here.
   511  	// DeepEqual is good enough for now.
   512  	if reflect.DeepEqual(got, want) {
   513  		return ""
   514  	}
   515  	return fmt.Sprintf("got:  %v\nwant: %v", got, want)
   516  }
   517  
// A testTransport allows testing Transport.RoundTrip against fake servers.
// Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
// should use testClientConn instead.
//
// tr1 is the net/http Transport under test. Its HTTP/2 Transport is obtained
// with transportFromH1Transport. ccs queues, oldest first, the testClientConns
// wrapping ClientConns the transport has created but the test has not yet
// claimed with getConn.
type testTransport struct {
	t   testing.TB
	tr1 *http.Transport

	ccs []*testClientConn
}
   527  
// newTestTransport returns a testTransport wrapping an *http.Transport that
// speaks only "h2". Its DialContext returns one end of an in-memory synctest
// pipe, and a goroutine performs the server side of the TLS handshake on the
// other end.
//
// Each element of opts may be nil (ignored), a func(*http.Transport) applied
// to the transport, or a func(*http.HTTP2Config) applied to the transport's
// HTTP2 config, which is allocated if needed. Any other type fails the test.
//
// Every ClientConn created by the HTTP/2 Transport is wrapped in a
// testClientConn and queued for getConn. At cleanup the test fails if any
// queued conn was never claimed.
func newTestTransport(t testing.TB, opts ...any) *testTransport {
	t.Helper()
	tt := &testTransport{
		t: t,
	}

	tr1 := &http.Transport{
		DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
			// This connection will be replaced by newTestClientConnFromClientConn.
			// net/http will perform a TLS handshake on it, though.
			//
			// TODO: We can simplify connection handling if we support
			// returning a non-*tls.Conn from Transport.DialTLSContext,
			// in which case we could have a DialTLSContext function that
			// returns an unencrypted conn.
			cli, srv := synctestNetPipe()
			go func() {
				tlsSrv := tls.Server(srv, testServerTLSConfig)
				if err := tlsSrv.Handshake(); err != nil {
					t.Errorf("unexpected TLS server handshake error: %v", err)
				}
			}()
			return cli, nil
		},
		Protocols:       protocols("h2"),
		TLSClientConfig: testClientTLSConfig,
	}
	// Apply caller-supplied configuration.
	for _, o := range opts {
		switch o := o.(type) {
		case nil:
		case func(*http.Transport):
			o(tr1)
		case func(*http.HTTP2Config):
			if tr1.HTTP2 == nil {
				tr1.HTTP2 = &http.HTTP2Config{}
			}
			o(tr1.HTTP2)
		default:
			t.Fatalf("unknown newTestTransport option type %T", o)
		}
	}
	tt.tr1 = tr1

	// Wrap every new ClientConn in a testClientConn and queue it for getConn.
	tr2 := transportFromH1Transport(tr1).(*Transport)
	tr2.TestSetNewClientConnHook(func(cc *ClientConn) {
		tc := newTestClientConnFromClientConn(t, tr2, cc)
		tt.ccs = append(tt.ccs, tc)
	})

	t.Cleanup(func() {
		// Let pending goroutines settle so conns created late are counted.
		synctest.Wait()
		if len(tt.ccs) > 0 {
			t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
		}
	})

	return tt
}
   586  
   587  func (tt *testTransport) hasConn() bool {
   588  	return len(tt.ccs) > 0
   589  }
   590  
   591  func (tt *testTransport) getConn() *testClientConn {
   592  	tt.t.Helper()
   593  	synctest.Wait()
   594  	if len(tt.ccs) == 0 {
   595  		tt.t.Fatalf("no new ClientConns created; wanted one")
   596  	}
   597  	tc := tt.ccs[0]
   598  	tt.ccs = tt.ccs[1:]
   599  	tc.readClientPreface()
   600  	synctest.Wait()
   601  	return tc
   602  }
   603  
// roundTrip starts a RoundTrip through the full net/http Transport (tt.tr1)
// in its own goroutine and calls synctest.Wait before returning.
//
// The request is given a cancelable context, and the cancel func is stored
// in the returned testRoundTrip. Unlike testClientConn.roundTrip, the stream
// ID is never recorded, so the result's streamID method panics. If the
// RoundTrip has completed by test cleanup, the response body (if any) is
// closed.
func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
	ctx, cancel := context.WithCancel(req.Context())
	req = req.WithContext(ctx)
	rt := &testRoundTrip{
		t:      tt.t,
		donec:  make(chan struct{}),
		cancel: cancel,
	}
	go func() {
		defer close(rt.donec)
		rt.resp, rt.respErr = tt.tr1.RoundTrip(req)
	}()
	synctest.Wait()

	tt.t.Cleanup(func() {
		// A RoundTrip still in progress has produced no body to close.
		if !rt.done() {
			return
		}
		res, _ := rt.result()
		if res != nil {
			res.Body.Close()
		}
	})

	return rt
}
   630  

View as plain text