Source file src/net/http/internal/http2/connframes_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2_test
     6  
     7  import (
     8  	"bytes"
     9  	"io"
    10  	"net/http"
    11  	"os"
    12  	"reflect"
    13  	"slices"
    14  	"testing"
    15  	"testing/synctest"
    16  
    17  	. "net/http/internal/http2"
    18  
    19  	"golang.org/x/net/http2/hpack"
    20  )
    21  
// testConnFramer is a test-side wrapper around a Framer. Tests use it to read
// frames sent by the code under test and to write frames to it, reporting any
// failure through t.
type testConnFramer struct {
	t   testing.TB     // test that owns this framer; all failures are reported here
	fr  *Framer        // framer used for every frame read and write
	dec *hpack.Decoder // decodes header blocks carried by received HEADERS/CONTINUATION frames
}
    27  
    28  // readFrame reads the next frame.
    29  // It returns nil if the conn is closed or no frames are available.
    30  func (tf *testConnFramer) readFrame() Frame {
    31  	tf.t.Helper()
    32  	synctest.Wait()
    33  	fr, err := tf.fr.ReadFrame()
    34  	if err == io.EOF || err == os.ErrDeadlineExceeded {
    35  		return nil
    36  	}
    37  	if err != nil {
    38  		tf.t.Fatalf("ReadFrame: %v", err)
    39  	}
    40  	return fr
    41  }
    42  
// readFramer is implemented by anything that can read the next frame,
// returning nil when no frame is available (for example *testConnFramer).
// The generic readFrame helper accepts any readFramer.
type readFramer interface {
	readFrame() Frame
}
    46  
    47  // readFrame reads a frame of a specific type.
    48  func readFrame[T any](t testing.TB, framer readFramer) T {
    49  	t.Helper()
    50  	var v T
    51  	fr := framer.readFrame()
    52  	if fr == nil {
    53  		t.Fatalf("got no frame, want frame %T", v)
    54  	}
    55  	v, ok := fr.(T)
    56  	if !ok {
    57  		t.Fatalf("got frame %T, want %T", fr, v)
    58  	}
    59  	return v
    60  }
    61  
    62  // wantFrameType reads the next frame.
    63  // It produces an error if the frame type is not the expected value.
    64  func (tf *testConnFramer) wantFrameType(want FrameType) {
    65  	tf.t.Helper()
    66  	fr := tf.readFrame()
    67  	if fr == nil {
    68  		tf.t.Fatalf("got no frame, want frame %v", want)
    69  	}
    70  	if got := fr.Header().Type; got != want {
    71  		tf.t.Fatalf("got frame %v, want %v", got, want)
    72  	}
    73  }
    74  
    75  // wantUnorderedFrames reads frames until every condition in want has been satisfied.
    76  //
    77  // want is a list of func(*SomeFrame) bool.
    78  // wantUnorderedFrames will call each func with frames of the appropriate type
    79  // until the func returns true.
    80  // It calls t.Fatal if an unexpected frame is received (no func has that frame type,
    81  // or all funcs with that type have returned true), or if the framer runs out of frames
    82  // with unsatisfied funcs.
    83  //
    84  // Example:
    85  //
    86  //	// Read a SETTINGS frame, and any number of DATA frames for a stream.
    87  //	// The SETTINGS frame may appear anywhere in the sequence.
    88  //	// The last DATA frame must indicate the end of the stream.
    89  //	tf.wantUnorderedFrames(
    90  //		func(f *SettingsFrame) bool {
    91  //			return true
    92  //		},
    93  //		func(f *DataFrame) bool {
    94  //			return f.StreamEnded()
    95  //		},
    96  //	)
    97  func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
    98  	tf.t.Helper()
    99  	want = slices.Clone(want)
   100  	seen := 0
   101  frame:
   102  	for seen < len(want) && !tf.t.Failed() {
   103  		fr := tf.readFrame()
   104  		if fr == nil {
   105  			break
   106  		}
   107  		for i, f := range want {
   108  			if f == nil {
   109  				continue
   110  			}
   111  			typ := reflect.TypeOf(f)
   112  			if typ.Kind() != reflect.Func ||
   113  				typ.NumIn() != 1 ||
   114  				typ.NumOut() != 1 ||
   115  				typ.Out(0) != reflect.TypeFor[bool]() {
   116  				tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
   117  			}
   118  			if typ.In(0) == reflect.TypeOf(fr) {
   119  				out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
   120  				if out[0].Bool() {
   121  					want[i] = nil
   122  					seen++
   123  				}
   124  				continue frame
   125  			}
   126  		}
   127  		tf.t.Errorf("got unexpected frame type %T", fr)
   128  	}
   129  	if seen < len(want) {
   130  		for _, f := range want {
   131  			if f == nil {
   132  				continue
   133  			}
   134  			tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
   135  		}
   136  		tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
   137  	}
   138  }
   139  
// wantHeader describes the HEADERS frame (plus any CONTINUATION frames) that
// (*testConnFramer).wantHeaders expects to read.
type wantHeader struct {
	streamID  uint32      // required stream ID of the HEADERS frame
	endStream bool        // whether END_STREAM must be set on the HEADERS frame
	header    http.Header // fields that must be present with exactly these values; other fields are ignored
}
   145  
   146  // wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
   147  // and asserts that they contain the expected headers.
   148  func (tf *testConnFramer) wantHeaders(want wantHeader) {
   149  	tf.t.Helper()
   150  
   151  	hf := readFrame[*HeadersFrame](tf.t, tf)
   152  	if got, want := hf.StreamID, want.streamID; got != want {
   153  		tf.t.Fatalf("got stream ID %v, want %v", got, want)
   154  	}
   155  	if got, want := hf.StreamEnded(), want.endStream; got != want {
   156  		tf.t.Fatalf("got stream ended %v, want %v", got, want)
   157  	}
   158  
   159  	gotHeader := make(http.Header)
   160  	tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
   161  		gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value)
   162  	})
   163  	defer tf.dec.SetEmitFunc(nil)
   164  	if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
   165  		tf.t.Fatalf("decoding HEADERS frame: %v", err)
   166  	}
   167  	headersEnded := hf.HeadersEnded()
   168  	for !headersEnded {
   169  		cf := readFrame[*ContinuationFrame](tf.t, tf)
   170  		if cf == nil {
   171  			tf.t.Fatalf("got end of frames, want CONTINUATION")
   172  		}
   173  		if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
   174  			tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
   175  		}
   176  		headersEnded = cf.HeadersEnded()
   177  	}
   178  	if err := tf.dec.Close(); err != nil {
   179  		tf.t.Fatalf("hpack decoding error: %v", err)
   180  	}
   181  
   182  	for k, v := range want.header {
   183  		if !reflect.DeepEqual(v, gotHeader[k]) {
   184  			tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
   185  		}
   186  	}
   187  }
   188  
   189  // decodeHeader supports some older server tests.
   190  // TODO: rewrite those tests to use newer, more convenient test APIs.
   191  func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
   192  	tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
   193  		if hf.Name == "date" {
   194  			return
   195  		}
   196  		pairs = append(pairs, [2]string{hf.Name, hf.Value})
   197  	})
   198  	defer tf.dec.SetEmitFunc(nil)
   199  	if _, err := tf.dec.Write(headerBlock); err != nil {
   200  		tf.t.Fatalf("hpack decoding error: %v", err)
   201  	}
   202  	if err := tf.dec.Close(); err != nil {
   203  		tf.t.Fatalf("hpack decoding error: %v", err)
   204  	}
   205  	return pairs
   206  }
   207  
// wantData describes the DATA frames that (*testConnFramer).wantData expects
// to read.
type wantData struct {
	streamID  uint32 // stream the expected DATA frames belong to
	endStream bool   // whether the last DATA frame read must carry END_STREAM
	size      int    // total payload bytes expected; replaced by len(data) when data is non-nil
	data      []byte // if non-nil, the exact expected concatenation of the DATA payloads
	multiple  bool   // data may be spread across multiple DATA frames
}
   215  
   216  // wantData reads zero or more DATA frames, and asserts that they match the expectation.
   217  func (tf *testConnFramer) wantData(want wantData) {
   218  	tf.t.Helper()
   219  	gotSize := 0
   220  	gotEndStream := false
   221  	if want.data != nil {
   222  		want.size = len(want.data)
   223  	}
   224  	var gotData []byte
   225  	for {
   226  		fr := tf.readFrame()
   227  		if fr == nil {
   228  			break
   229  		}
   230  		data, ok := fr.(*DataFrame)
   231  		if !ok {
   232  			tf.t.Fatalf("got frame %T, want DataFrame", fr)
   233  		}
   234  		if want.data != nil {
   235  			gotData = append(gotData, data.Data()...)
   236  		}
   237  		gotSize += len(data.Data())
   238  		if data.StreamEnded() {
   239  			gotEndStream = true
   240  			break
   241  		}
   242  		if !want.endStream && gotSize >= want.size {
   243  			break
   244  		}
   245  		if !want.multiple {
   246  			break
   247  		}
   248  	}
   249  	if gotSize != want.size {
   250  		tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
   251  	}
   252  	if gotEndStream != want.endStream {
   253  		tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
   254  	}
   255  	if want.data != nil && !bytes.Equal(gotData, want.data) {
   256  		tf.t.Fatalf("got data %q, want %q", gotData, want.data)
   257  	}
   258  }
   259  
   260  func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
   261  	tf.t.Helper()
   262  	fr := readFrame[*RSTStreamFrame](tf.t, tf)
   263  	if fr.StreamID != streamID || fr.ErrCode != code {
   264  		tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", SummarizeFrame(fr), streamID, code)
   265  	}
   266  }
   267  
   268  func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
   269  	fr := readFrame[*SettingsFrame](tf.t, tf)
   270  	if fr.Header().Flags.Has(FlagSettingsAck) {
   271  		tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
   272  	}
   273  	for wantID, wantVal := range want {
   274  		gotVal, ok := fr.Value(wantID)
   275  		if !ok {
   276  			tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
   277  		} else if gotVal != wantVal {
   278  			tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
   279  		}
   280  	}
   281  	if tf.t.Failed() {
   282  		tf.t.Fatalf("%v", fr)
   283  	}
   284  }
   285  
   286  func (tf *testConnFramer) wantSettingsAck() {
   287  	tf.t.Helper()
   288  	fr := readFrame[*SettingsFrame](tf.t, tf)
   289  	if !fr.Header().Flags.Has(FlagSettingsAck) {
   290  		tf.t.Fatal("Settings Frame didn't have ACK set")
   291  	}
   292  }
   293  
   294  func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
   295  	tf.t.Helper()
   296  	fr := readFrame[*GoAwayFrame](tf.t, tf)
   297  	if fr.LastStreamID != maxStreamID || fr.ErrCode != code {
   298  		tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", SummarizeFrame(fr), maxStreamID, code)
   299  	}
   300  }
   301  
   302  func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
   303  	tf.t.Helper()
   304  	wu := readFrame[*WindowUpdateFrame](tf.t, tf)
   305  	if wu.FrameHeader.StreamID != streamID {
   306  		tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
   307  	}
   308  	if wu.Increment != incr {
   309  		tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
   310  	}
   311  }
   312  
   313  func (tf *testConnFramer) wantClosed() {
   314  	tf.t.Helper()
   315  	synctest.Wait()
   316  	fr, err := tf.fr.ReadFrame()
   317  	if err == nil {
   318  		tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
   319  	}
   320  	if err == os.ErrDeadlineExceeded {
   321  		tf.t.Fatalf("connection is not closed; want it to be")
   322  	}
   323  }
   324  
   325  func (tf *testConnFramer) wantIdle() {
   326  	tf.t.Helper()
   327  	synctest.Wait()
   328  	fr, err := tf.fr.ReadFrame()
   329  	if err == nil {
   330  		tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
   331  	}
   332  	if err != os.ErrDeadlineExceeded {
   333  		tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
   334  	}
   335  }
   336  
   337  func (tf *testConnFramer) writeSettings(settings ...Setting) {
   338  	tf.t.Helper()
   339  	if err := tf.fr.WriteSettings(settings...); err != nil {
   340  		tf.t.Fatal(err)
   341  	}
   342  }
   343  
   344  func (tf *testConnFramer) writeSettingsAck() {
   345  	tf.t.Helper()
   346  	if err := tf.fr.WriteSettingsAck(); err != nil {
   347  		tf.t.Fatal(err)
   348  	}
   349  }
   350  
   351  func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
   352  	tf.t.Helper()
   353  	if err := tf.fr.WriteData(streamID, endStream, data); err != nil {
   354  		tf.t.Fatal(err)
   355  	}
   356  }
   357  
   358  func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
   359  	tf.t.Helper()
   360  	if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
   361  		tf.t.Fatal(err)
   362  	}
   363  }
   364  
   365  func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
   366  	tf.t.Helper()
   367  	if err := tf.fr.WriteHeaders(p); err != nil {
   368  		tf.t.Fatal(err)
   369  	}
   370  }
   371  
   372  // writeHeadersMode writes header frames, as modified by mode:
   373  //
   374  //   - noHeader: Don't write the header.
   375  //   - oneHeader: Write a single HEADERS frame.
   376  //   - splitHeader: Write a HEADERS frame and CONTINUATION frame.
   377  func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
   378  	tf.t.Helper()
   379  	switch mode {
   380  	case noHeader:
   381  	case oneHeader:
   382  		tf.writeHeaders(p)
   383  	case splitHeader:
   384  		if len(p.BlockFragment) < 2 {
   385  			panic("too small")
   386  		}
   387  		contData := p.BlockFragment[1:]
   388  		contEnd := p.EndHeaders
   389  		p.BlockFragment = p.BlockFragment[:1]
   390  		p.EndHeaders = false
   391  		tf.writeHeaders(p)
   392  		tf.writeContinuation(p.StreamID, contEnd, contData)
   393  	default:
   394  		panic("bogus mode")
   395  	}
   396  }
   397  
   398  func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
   399  	tf.t.Helper()
   400  	if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
   401  		tf.t.Fatal(err)
   402  	}
   403  }
   404  
   405  func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
   406  	if err := tf.fr.WritePriority(id, p); err != nil {
   407  		tf.t.Fatal(err)
   408  	}
   409  }
   410  
   411  func (tf *testConnFramer) writePriorityUpdate(id uint32, p string) {
   412  	if err := tf.fr.WritePriorityUpdate(id, p); err != nil {
   413  		tf.t.Fatal(err)
   414  	}
   415  }
   416  
   417  func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
   418  	tf.t.Helper()
   419  	if err := tf.fr.WriteRSTStream(streamID, code); err != nil {
   420  		tf.t.Fatal(err)
   421  	}
   422  }
   423  
   424  func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
   425  	tf.t.Helper()
   426  	if err := tf.fr.WritePing(ack, data); err != nil {
   427  		tf.t.Fatal(err)
   428  	}
   429  }
   430  
   431  func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
   432  	tf.t.Helper()
   433  	if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
   434  		tf.t.Fatal(err)
   435  	}
   436  }
   437  
   438  func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
   439  	tf.t.Helper()
   440  	if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil {
   441  		tf.t.Fatal(err)
   442  	}
   443  }
   444  

View as plain text