1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 package http2
27
28 import (
29 "bufio"
30 "bytes"
31 "context"
32 "crypto/rand"
33 "crypto/tls"
34 "errors"
35 "fmt"
36 "io"
37 "log"
38 "math"
39 "net"
40 "net/http/internal"
41 "net/http/internal/httpcommon"
42 "net/textproto"
43 "net/url"
44 "os"
45 "reflect"
46 "runtime"
47 "strconv"
48 "strings"
49 "sync"
50 "time"
51
52 "golang.org/x/net/http/httpguts"
53 "golang.org/x/net/http2/hpack"
54 )
55
const (
	// prefaceTimeout bounds how long readPreface waits for the client
	// preface before giving up with errPrefaceTimeout.
	prefaceTimeout = 10 * time.Second
	// firstSettingsTimeout is the delay after which serve's settings timer
	// fires; serve stops that timer once a frame from the client has been
	// processed.
	firstSettingsTimeout = 2 * time.Second
	// handlerChunkWriteSize is the buffer size of the bufio.Writer attached
	// to each pooled responseWriterState.
	handlerChunkWriteSize = 4 << 10
	// defaultMaxStreams is presumably the fallback concurrent-stream limit;
	// it is not referenced in this part of the file — TODO confirm.
	defaultMaxStreams = 250

	// maxQueuedControlFrames is the number of queued control frames above
	// which the serve loop gives up and closes the connection.
	maxQueuedControlFrames = 10000
)

var (
	// errClientDisconnected is reported to handlers once doneServing is
	// closed, and is the error used to close streams when the connection
	// goes away.
	errClientDisconnected = errors.New("client disconnected")
	// errClosedBody is not referenced in this part of the file.
	errClosedBody = errors.New("body closed by handler")
	// errHandlerComplete is used to close a half-closed-remote stream after
	// the stream-ending frame has been written.
	errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
	// errStreamClosed is returned by writeDataFromHandler when the stream
	// closed and no write result is available.
	errStreamClosed = errors.New("http2: stream closed")
)

// responseWriterStatePool recycles responseWriterState values. Each new
// value gets a bufio.Writer of handlerChunkWriteSize bytes that writes
// through a chunkWriter wrapping that same state.
var responseWriterStatePool = sync.Pool{
	New: func() interface{} {
		rws := &responseWriterState{}
		rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize)
		return rws
	},
}

// Test hooks. They are nil unless something sets them.
var (
	// testHookOnConn is not referenced in this part of the file.
	testHookOnConn func()
	// testHookOnPanicMu, when non-nil, is held by notePanic while it runs.
	testHookOnPanicMu *sync.Mutex
	// testHookOnPanic, when non-nil, is called by notePanic with the
	// recovered value; returning true makes notePanic re-panic.
	testHookOnPanic func(sc *serverConn, panicVal interface{}) (rePanic bool)
)
89
90
// Server holds HTTP/2 server settings and, after Configure has been called,
// the internal state shared by the connections it serves.
//
// NOTE(review): most of the limits below are not read directly in this part
// of the file. serveConn obtains a Config from
// configFromServer(opts.BaseConfig, s), which presumably folds these fields
// in — confirm against that helper before relying on the field notes.
type Server struct {
	// MaxHandlers is presumably a cap on concurrently running handlers; it
	// is not referenced in this part of the file — TODO confirm.
	MaxHandlers int

	// MaxConcurrentStreams presumably feeds the concurrent-stream limit that
	// serve advertises in its initial SETTINGS frame — TODO confirm.
	MaxConcurrentStreams uint32

	// MaxDecoderHeaderTableSize presumably bounds the HPACK decoder table
	// advertised to the peer — TODO confirm.
	MaxDecoderHeaderTableSize uint32

	// MaxEncoderHeaderTableSize presumably bounds the HPACK encoder's
	// dynamic table — TODO confirm.
	MaxEncoderHeaderTableSize uint32

	// MaxReadFrameSize presumably sets the largest frame this server is
	// willing to read — TODO confirm.
	MaxReadFrameSize uint32

	// PermitProhibitedCipherSuites presumably disables the bad-cipher
	// rejection performed in serveConn — TODO confirm.
	PermitProhibitedCipherSuites bool

	// IdleTimeout, when positive, arms an idle timer in serve; when that
	// timer's message is handled the server sends GOAWAY with ErrCodeNo.
	IdleTimeout time.Duration

	// ReadIdleTimeout presumably controls when a health-check PING is sent
	// after no frames have been read — TODO confirm.
	ReadIdleTimeout time.Duration

	// PingTimeout presumably bounds the wait for a PING response — TODO
	// confirm.
	PingTimeout time.Duration

	// WriteByteTimeout presumably bounds how long a write may make no
	// progress — TODO confirm.
	WriteByteTimeout time.Duration

	// MaxUploadBufferPerConnection presumably sizes the connection-level
	// receive window — TODO confirm.
	MaxUploadBufferPerConnection int32

	// MaxUploadBufferPerStream presumably sizes the per-stream receive
	// window — TODO confirm.
	MaxUploadBufferPerStream int32

	// NewWriteScheduler, when non-nil, is called by serveConn to build the
	// connection's write scheduler, taking precedence over the built-in
	// choices.
	NewWriteScheduler func() WriteScheduler

	// CountError presumably receives a short error-type string whenever an
	// error is counted — TODO confirm.
	CountError func(errType string)

	// state is created by Configure. Its methods tolerate a nil receiver, so
	// a Server that was never configured can still serve connections.
	state *serverInternalState
}
179
// serverInternalState is the per-Server bookkeeping created by Configure.
// All of its methods accept a nil receiver.
type serverInternalState struct {
	// mu guards activeConns.
	mu sync.Mutex
	// activeConns is the set of connections currently registered.
	activeConns map[*serverConn]struct{}

	// errChanPool recycles the buffered error channels handed out by
	// getErrChan.
	errChanPool sync.Pool
}
188
189 func (s *serverInternalState) registerConn(sc *serverConn) {
190 if s == nil {
191 return
192 }
193 s.mu.Lock()
194 s.activeConns[sc] = struct{}{}
195 s.mu.Unlock()
196 }
197
198 func (s *serverInternalState) unregisterConn(sc *serverConn) {
199 if s == nil {
200 return
201 }
202 s.mu.Lock()
203 delete(s.activeConns, sc)
204 s.mu.Unlock()
205 }
206
207 func (s *serverInternalState) startGracefulShutdown() {
208 if s == nil {
209 return
210 }
211 s.mu.Lock()
212 for sc := range s.activeConns {
213 sc.startGracefulShutdown()
214 }
215 s.mu.Unlock()
216 }
217
218
219
// errChanPool is the package-level pool of one-element buffered error
// channels. getErrChan and putErrChan fall back to it when the
// serverInternalState receiver is nil.
var errChanPool = sync.Pool{
	New: func() any { return make(chan error, 1) },
}
223
224 func (s *serverInternalState) getErrChan() chan error {
225 if s == nil {
226 return errChanPool.Get().(chan error)
227 }
228 return s.errChanPool.Get().(chan error)
229 }
230
231 func (s *serverInternalState) putErrChan(ch chan error) {
232 if s == nil {
233 errChanPool.Put(ch)
234 return
235 }
236 s.errChanPool.Put(ch)
237 }
238
239 func (s *Server) Configure(conf ServerConfig, tcfg *tls.Config) error {
240 s.state = &serverInternalState{
241 activeConns: make(map[*serverConn]struct{}),
242 errChanPool: sync.Pool{New: func() any { return make(chan error, 1) }},
243 }
244
245 if tcfg.CipherSuites != nil && tcfg.MinVersion < tls.VersionTLS13 {
246
247
248
249 haveRequired := false
250 for _, cs := range tcfg.CipherSuites {
251 switch cs {
252 case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
253
254
255 tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
256 haveRequired = true
257 }
258 }
259 if !haveRequired {
260 return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)")
261 }
262 }
263
264
265
266
267
268
269
270
271 return nil
272 }
273
274 func (s *Server) GracefulShutdown() {
275 s.state.startGracefulShutdown()
276 }
277
278
// ServeConnOpts carries the optional per-connection inputs to ServeConn.
type ServeConnOpts struct {
	// Context is the parent of the connection's base context. When it is
	// nil, context.Background is used instead.
	Context context.Context

	// BaseConfig is passed to configFromServer and stored on the connection
	// as its ServerConfig.
	BaseConfig ServerConfig

	// Handler is stored on the connection as its request handler.
	Handler Handler

	// Settings, when non-nil, is treated as the payload of a SETTINGS frame
	// and applied before serving begins. serveConn sets it to nil once it
	// has been processed.
	Settings []byte

	// SawClientPreface reports that the client preface has already been
	// consumed, so the connection does not try to read it again.
	SawClientPreface bool
}
301
302 func (o *ServeConnOpts) context() context.Context {
303 if o != nil && o.Context != nil {
304 return o.Context
305 }
306 return context.Background()
307 }
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323 func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
324 if opts == nil {
325 opts = &ServeConnOpts{}
326 }
327
328 var newf func(*serverConn)
329 if inTests {
330
331 newf, _ = opts.Context.Value(NewConnContextKey).(func(*serverConn))
332 }
333
334 s.serveConn(c, opts, newf)
335 }
336
// contextKey is a string type declared alongside the context keys below.
type contextKey string

// Context keys consulted only when inTests is set. Each key is a freshly
// allocated pointer, so it cannot collide with any other context key.
var (
	// NewConnContextKey may map to a func(*serverConn) that ServeConn passes
	// to serveConn, which calls it with the newly built connection.
	NewConnContextKey = new("NewConnContextKey")
	// ConnectionStateContextKey may map to a func() tls.ConnectionState whose
	// result replaces the connection's recorded TLS state in serveConn.
	ConnectionStateContextKey = new("ConnectionStateContextKey")
)
343
344 func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
345 baseCtx, cancel := serverConnBaseContext(c, opts)
346 defer cancel()
347
348 conf := configFromServer(opts.BaseConfig, s)
349 sc := &serverConn{
350 srv: s,
351 hs: opts.BaseConfig,
352 conn: c,
353 baseCtx: baseCtx,
354 remoteAddrStr: c.RemoteAddr().String(),
355 bw: newBufferedWriter(c, conf.WriteByteTimeout),
356 handler: opts.Handler,
357 streams: make(map[uint32]*stream),
358 readFrameCh: make(chan readFrameResult),
359 wantWriteFrameCh: make(chan FrameWriteRequest, 8),
360 serveMsgCh: make(chan interface{}, 8),
361 wroteFrameCh: make(chan frameWriteResult, 1),
362 bodyReadCh: make(chan bodyReadMsg),
363 doneServing: make(chan struct{}),
364 clientMaxStreams: math.MaxUint32,
365 advMaxStreams: uint32(conf.MaxConcurrentStreams),
366 initialStreamSendWindowSize: initialWindowSize,
367 initialStreamRecvWindowSize: int32(conf.MaxReceiveBufferPerStream),
368 maxFrameSize: initialMaxFrameSize,
369 pingTimeout: conf.PingTimeout,
370 countErrorFunc: conf.CountError,
371 serveG: newGoroutineLock(),
372 pushEnabled: true,
373 sawClientPreface: opts.SawClientPreface,
374 }
375 if newf != nil {
376 newf(sc)
377 }
378
379 s.state.registerConn(sc)
380 defer s.state.unregisterConn(sc)
381
382
383
384
385
386
387 if sc.hs.WriteTimeout() > 0 {
388 sc.conn.SetWriteDeadline(time.Time{})
389 }
390
391 switch {
392 case s.NewWriteScheduler != nil:
393 sc.writeSched = s.NewWriteScheduler()
394 case sc.hs.DisableClientPriority():
395 sc.writeSched = newRoundRobinWriteScheduler()
396 default:
397 sc.writeSched = newPriorityWriteSchedulerRFC9218()
398 }
399
400
401
402
403 sc.flow.add(initialWindowSize)
404 sc.inflow.init(initialWindowSize)
405 sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
406 sc.hpackEncoder.SetMaxDynamicTableSizeLimit(uint32(conf.MaxEncoderHeaderTableSize))
407
408 fr := NewFramer(sc.bw, c)
409 if conf.CountError != nil {
410 fr.countError = conf.CountError
411 }
412 fr.ReadMetaHeaders = hpack.NewDecoder(uint32(conf.MaxDecoderHeaderTableSize), nil)
413 fr.MaxHeaderListSize = sc.maxHeaderListSize()
414 fr.SetMaxReadFrameSize(uint32(conf.MaxReadFrameSize))
415 sc.framer = fr
416
417 if tc, ok := c.(connectionStater); ok {
418 sc.tlsState = new(tls.ConnectionState)
419 *sc.tlsState = tc.ConnectionState()
420
421
422 if inTests {
423 f, ok := opts.Context.Value(ConnectionStateContextKey).(func() tls.ConnectionState)
424 if ok {
425 *sc.tlsState = f()
426 }
427 }
428
429
430
431
432
433
434
435
436
437
438
439 if sc.tlsState.Version < tls.VersionTLS12 {
440 sc.rejectConn(ErrCodeInadequateSecurity, "TLS version too low")
441 return
442 }
443
444 if sc.tlsState.ServerName == "" {
445
446
447
448
449
450
451
452
453
454 }
455
456 if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
457
458
459
460
461
462
463
464
465
466
467 sc.rejectConn(ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite))
468 return
469 }
470 }
471
472 if opts.Settings != nil {
473 fr := &SettingsFrame{
474 FrameHeader: FrameHeader{valid: true},
475 p: opts.Settings,
476 }
477 if err := fr.ForeachSetting(sc.processSetting); err != nil {
478 sc.rejectConn(ErrCodeProtocol, "invalid settings")
479 return
480 }
481 opts.Settings = nil
482 }
483
484 sc.serve(conf)
485 }
486
487 func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
488 return context.WithCancel(opts.context())
489 }
490
491 func (sc *serverConn) rejectConn(err ErrCode, debug string) {
492 sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
493
494 sc.framer.WriteGoAway(0, err, []byte(debug))
495 sc.bw.Flush()
496 sc.conn.Close()
497 }
498
// serverConn is the state of a single HTTP/2 server connection.
//
// NOTE(review): the first group of fields is populated in serveConn before
// the serve loop starts; the second group is read and written by methods
// that call sc.serveG.check(), so it is presumably owned by the serve
// goroutine — confirm before touching it from elsewhere.
type serverConn struct {
	// Set up in serveConn.
	srv              *Server
	hs               ServerConfig
	conn             net.Conn
	bw               *bufferedWriter // buffered writer the framer writes through
	handler          Handler
	baseCtx          context.Context
	framer           *Framer
	doneServing      chan struct{}          // closed when serve returns
	readFrameCh      chan readFrameResult   // readFrames -> serve
	wantWriteFrameCh chan FrameWriteRequest // handlers -> serve
	wroteFrameCh     chan frameWriteResult  // writeFrameAsync -> serve
	bodyReadCh       chan bodyReadMsg       // body-read notifications -> serve
	serveMsgCh       chan interface{}       // timers and other messages -> serve
	flow             outflow                // connection-level send flow control
	inflow           inflow                 // connection-level receive flow control
	tlsState         *tls.ConnectionState   // set only when conn exposes TLS state
	remoteAddrStr    string
	writeSched       WriteScheduler
	countErrorFunc   func(errType string)

	// State used by serve-goroutine methods.
	serveG                      goroutineLock
	pushEnabled                 bool
	sawClientPreface            bool // the preface was consumed before serve started
	sawFirstSettings            bool // the first frame (which must be SETTINGS) was seen
	needToSendSettingsAck       bool
	unackedSettings             int // incremented when serve queues its SETTINGS frame
	queuedControlFrames         int // control frames pushed to writeSched and not yet popped
	clientMaxStreams            uint32
	advMaxStreams               uint32 // value sent as SettingMaxConcurrentStreams
	curClientStreams            uint32
	curPushedStreams            uint32
	curHandlers                 uint32
	maxClientStreamID           uint32
	maxPushPromiseID            uint32
	streams                     map[uint32]*stream
	unstartedHandlers           []unstartedHandler
	initialStreamSendWindowSize int32
	initialStreamRecvWindowSize int32
	maxFrameSize                int32
	peerMaxHeaderListSize       uint32
	canonHeader                 map[string]string // lazily created cache used by canonicalHeader
	canonHeaderKeysSize         int               // estimated size of canonHeader
	writingFrame                bool              // a frame write has started and not yet completed
	writingFrameAsync           bool              // that write is running on another goroutine
	needsFrameFlush             bool              // a flush is owed after the last frame write
	inGoAway                    bool              // goAway has been called
	inFrameScheduleLoop         bool              // scheduleFrameWrite is already running
	needToSendGoAway            bool              // a GOAWAY frame still has to be written
	pingSent                    bool              // a health-check PING is outstanding
	sentPingData                [8]byte           // payload of the outstanding PING
	goAwayCode                  ErrCode
	shutdownTimer               *time.Timer // nil until shutDownIn is called
	idleTimer                   *time.Timer // nil unless the server's IdleTimeout is positive
	readIdleTimeout             time.Duration
	pingTimeout                 time.Duration
	readIdleTimer               *time.Timer // nil unless conf.SendPingTimeout is positive

	// Header encoding state.
	headerWriteBuf bytes.Buffer
	hpackEncoder   *hpack.Encoder

	// shutdownOnce ensures the graceful-shutdown message is sent at most
	// once.
	shutdownOnce sync.Once

	// Not referenced in this part of the file.
	hasIntermediary bool
	priorityAware   bool
}
570
571 func (sc *serverConn) writeSchedIgnoresRFC7540() bool {
572 switch sc.writeSched.(type) {
573 case *priorityWriteSchedulerRFC9218:
574 return true
575 case *randomWriteScheduler:
576 return true
577 case *roundRobinWriteScheduler:
578 return true
579 default:
580 return false
581 }
582 }
583
584 const DefaultMaxHeaderBytes = 1 << 20
585
586 func (sc *serverConn) maxHeaderListSize() uint32 {
587 n := sc.hs.MaxHeaderBytes()
588 if n <= 0 {
589 n = DefaultMaxHeaderBytes
590 }
591 return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
592 }
593
594 func (sc *serverConn) curOpenStreams() uint32 {
595 sc.serveG.check()
596 return sc.curClientStreams + sc.curPushedStreams
597 }
598
599
600
601
602
603
604
605
// stream is the server's state for one HTTP/2 stream.
//
// NOTE(review): the split below between fixed fields and serve-goroutine
// state is inferred from how the fields are used in this part of the file —
// confirm.
type stream struct {
	// Identity and plumbing.
	sc        *serverConn
	id        uint32
	body      *pipe
	cw        closeWaiter // becomes receivable when the stream closes (see writeDataFromHandler)
	ctx       context.Context
	cancelCtx func()

	// Mutable state.
	bodyBytes        int64
	declBodyBytes    int64
	flow             outflow // stream-level send flow control
	inflow           inflow  // stream-level receive flow control
	state            streamState
	resetQueued      bool // set by resetStream once a reset has been queued
	gotTrailerHeader bool
	wroteHeaders     bool // set by writeFrame when response headers are queued
	readDeadline     *time.Timer
	writeDeadline    *time.Timer
	closeErr         error

	trailer    Header
	reqTrailer Header
}
631
632 func (sc *serverConn) Framer() *Framer { return sc.framer }
633 func (sc *serverConn) CloseConn() error { return sc.conn.Close() }
634 func (sc *serverConn) Flush() error { return sc.bw.Flush() }
635 func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
636 return sc.hpackEncoder, &sc.headerWriteBuf
637 }
638
639 func (sc *serverConn) state(streamID uint32) (streamState, *stream) {
640 sc.serveG.check()
641
642 if st, ok := sc.streams[streamID]; ok {
643 return st.state, st
644 }
645
646
647
648
649
650
651 if streamID%2 == 1 {
652 if streamID <= sc.maxClientStreamID {
653 return stateClosed, nil
654 }
655 } else {
656 if streamID <= sc.maxPushPromiseID {
657 return stateClosed, nil
658 }
659 }
660 return stateIdle, nil
661 }
662
663
664
665
666 func (sc *serverConn) setConnState(state ConnState) {
667 sc.hs.ConnState(sc.conn, state)
668 }
669
670 func (sc *serverConn) vlogf(format string, args ...interface{}) {
671 if VerboseLogs {
672 sc.logf(format, args...)
673 }
674 }
675
676 func (sc *serverConn) logf(format string, args ...interface{}) {
677 if lg := sc.hs.ErrorLog(); lg != nil {
678 lg.Printf(format, args...)
679 } else {
680 log.Printf(format, args...)
681 }
682 }
683
684
685
686
687
// errno returns v's numeric value when v's dynamic type has kind uintptr
// (as a syscall error number does), and 0 for anything else, including nil.
// It uses reflection so that no platform-specific type has to be named.
func errno(v error) uintptr {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Uintptr {
		return 0
	}
	return uintptr(rv.Uint())
}
694
695
696
697 func isClosedConnError(err error) bool {
698 if err == nil {
699 return false
700 }
701
702 if errors.Is(err, net.ErrClosed) {
703 return true
704 }
705
706
707
708
709
710 if runtime.GOOS == "windows" {
711 if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
712 if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
713 const WSAECONNABORTED = 10053
714 const WSAECONNRESET = 10054
715 if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
716 return true
717 }
718 }
719 }
720 }
721 return false
722 }
723
724 func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
725 if err == nil {
726 return
727 }
728 if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) || err == errPrefaceTimeout {
729
730 sc.vlogf(format, args...)
731 } else {
732 sc.logf(format, args...)
733 }
734 }
735
736
737
738
739
740
741 const maxCachedCanonicalHeadersKeysSize = 2048
742
743 func (sc *serverConn) canonicalHeader(v string) string {
744 sc.serveG.check()
745 cv, ok := httpcommon.CachedCanonicalHeader(v)
746 if ok {
747 return cv
748 }
749 cv, ok = sc.canonHeader[v]
750 if ok {
751 return cv
752 }
753 if sc.canonHeader == nil {
754 sc.canonHeader = make(map[string]string)
755 }
756 cv = textproto.CanonicalMIMEHeaderKey(v)
757 size := 100 + len(v)*2
758 if sc.canonHeaderKeysSize+size <= maxCachedCanonicalHeadersKeysSize {
759 sc.canonHeader[v] = cv
760 sc.canonHeaderKeysSize += size
761 }
762 return cv
763 }
764
// readFrameResult is what readFrames sends to the serve loop for each call
// to Framer.ReadFrame.
type readFrameResult struct {
	f   Frame // the frame read; meaningful only when err is nil
	err error

	// readMore unblocks readFrames so it can read the next frame. The serve
	// loop calls it once it has finished processing this result.
	readMore func()
}
774
775
776
777
778
779 func (sc *serverConn) readFrames() {
780 gate := make(chan struct{})
781 gateDone := func() { gate <- struct{}{} }
782 for {
783 f, err := sc.framer.ReadFrame()
784 select {
785 case sc.readFrameCh <- readFrameResult{f, err, gateDone}:
786 case <-sc.doneServing:
787 return
788 }
789 select {
790 case <-gate:
791 case <-sc.doneServing:
792 return
793 }
794 if terminalReadFrameError(err) {
795 return
796 }
797 }
798 }
799
800
// frameWriteResult is the outcome of one frame write, delivered to
// wroteFrame either directly or via wroteFrameCh.
type frameWriteResult struct {
	_   incomparable
	wr  FrameWriteRequest // the request that was written
	err error             // result of the write
}
806
807
808
809
810
811 func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
812 var err error
813 if wd == nil {
814 err = wr.write.writeFrame(sc)
815 } else {
816 err = sc.framer.endWrite()
817 }
818 sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err}
819 }
820
821 func (sc *serverConn) closeAllStreamsOnConnClose() {
822 sc.serveG.check()
823 for _, st := range sc.streams {
824 sc.closeStream(st, errClientDisconnected)
825 }
826 }
827
828 func (sc *serverConn) stopShutdownTimer() {
829 sc.serveG.check()
830 if t := sc.shutdownTimer; t != nil {
831 t.Stop()
832 }
833 }
834
835 func (sc *serverConn) notePanic() {
836
837 if testHookOnPanicMu != nil {
838 testHookOnPanicMu.Lock()
839 defer testHookOnPanicMu.Unlock()
840 }
841 if testHookOnPanic != nil {
842 if e := recover(); e != nil {
843 if testHookOnPanic(sc, e) {
844 panic(e)
845 }
846 }
847 }
848 }
849
// serve is the connection's main loop. It queues the initial SETTINGS frame,
// reads the client preface, starts the frame reader, and then processes
// write requests, write completions, incoming frames, body-read
// notifications and serve messages until the connection is finished.
//
// It must run on the serve goroutine. conf supplies the limits advertised to
// the peer and the ping timeout settings.
func (sc *serverConn) serve(conf Config) {
	sc.serveG.check()
	// Deferred calls run in reverse order: doneServing is closed first, then
	// the shutdown timer is stopped, streams are closed, the connection is
	// closed, and finally notePanic runs.
	defer sc.notePanic()
	defer sc.conn.Close()
	defer sc.closeAllStreamsOnConnClose()
	defer sc.stopShutdownTimer()
	defer close(sc.doneServing)

	if VerboseLogs {
		sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
	}

	// Queue the server's initial SETTINGS frame.
	settings := writeSettings{
		{SettingMaxFrameSize, uint32(conf.MaxReadFrameSize)},
		{SettingMaxConcurrentStreams, sc.advMaxStreams},
		{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
		{SettingHeaderTableSize, uint32(conf.MaxDecoderHeaderTableSize)},
		{SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
	}
	if !disableExtendedConnectProtocol {
		settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
	}
	if sc.writeSchedIgnoresRFC7540() {
		settings = append(settings, Setting{SettingNoRFC7540Priorities, 1})
	}
	sc.writeFrame(FrameWriteRequest{
		write: settings,
	})
	sc.unackedSettings++

	// The connection-level receive window starts at initialWindowSize (see
	// serveConn); grant the peer the difference when a larger buffer is
	// configured.
	if diff := conf.MaxReceiveBufferPerConnection - initialWindowSize; diff > 0 {
		sc.sendWindowUpdate(nil, int(diff))
	}

	if err := sc.readPreface(); err != nil {
		sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
		return
	}

	// Report the connection as active and then idle once the preface has
	// been read.
	sc.setConnState(ConnStateActive)
	sc.setConnState(ConnStateIdle)

	if sc.srv.IdleTimeout > 0 {
		sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
		defer sc.idleTimer.Stop()
	}

	if conf.SendPingTimeout > 0 {
		sc.readIdleTimeout = conf.SendPingTimeout
		sc.readIdleTimer = time.AfterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
		defer sc.readIdleTimer.Stop()
	}

	// readFrames exits when doneServing is closed or after a terminal read
	// error.
	go sc.readFrames()

	// The settings timer ends the connection if no frame from the client has
	// been processed before it fires. The deferred Stop is bound to the
	// timer value here, so setting the variable to nil later does not
	// affect it.
	settingsTimer := time.AfterFunc(firstSettingsTimeout, sc.onSettingsTimer)
	defer settingsTimer.Stop()

	lastFrameTime := time.Now()
	loopNum := 0
	for {
		loopNum++
		select {
		case wr := <-sc.wantWriteFrameCh:
			// A StreamError sent by a handler is a request to reset the
			// stream rather than an ordinary frame write.
			if se, ok := wr.write.(StreamError); ok {
				sc.resetStream(se)
				break
			}
			sc.writeFrame(wr)
		case res := <-sc.wroteFrameCh:
			sc.wroteFrame(res)
		case res := <-sc.readFrameCh:
			lastFrameTime = time.Now()
			// If an asynchronous write has already finished, account for it
			// before processing the newly read frame.
			if sc.writingFrameAsync {
				select {
				case wroteRes := <-sc.wroteFrameCh:
					sc.wroteFrame(wroteRes)
				default:
				}
			}
			if !sc.processFrameFromReader(res) {
				return
			}
			// Let readFrames read the next frame.
			res.readMore()
			if settingsTimer != nil {
				settingsTimer.Stop()
				settingsTimer = nil
			}
		case m := <-sc.bodyReadCh:
			sc.noteBodyRead(m.st, m.n)
		case msg := <-sc.serveMsgCh:
			switch v := msg.(type) {
			case func(int):
				v(loopNum)
			case *serverMessage:
				// serverMessage values are compared by pointer identity
				// against the package-level sentinels.
				switch v {
				case settingsTimerMsg:
					sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
					return
				case idleTimerMsg:
					sc.vlogf("connection is idle")
					sc.goAway(ErrCodeNo)
				case readIdleTimerMsg:
					sc.handlePingTimer(lastFrameTime)
				case shutdownTimerMsg:
					sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
					return
				case gracefulShutdownMsg:
					sc.startGracefulShutdownInternal()
				case handlerDoneMsg:
					sc.handlerDone()
				default:
					panic("unknown timer")
				}
			case *startPushRequest:
				sc.startPush(v)
			case func(*serverConn):
				v(sc)
			default:
				panic(fmt.Sprintf("unexpected type %T", v))
			}
		}

		// Give up on the connection when too many control frames are queued
		// for sending.
		if sc.queuedControlFrames > maxQueuedControlFrames {
			sc.vlogf("http2: too many control frames in send queue, closing connection")
			return
		}

		// Once the GOAWAY frame has been written, start the shutdown timer.
		// For a GOAWAY without an error code, wait until no streams remain
		// open before starting it.
		sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame
		gracefulShutdownComplete := sc.goAwayCode == ErrCodeNo && sc.curOpenStreams() == 0
		if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != ErrCodeNo || gracefulShutdownComplete) {
			sc.shutDownIn(goAwayTimeout)
		}
	}
}
998
999 func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
1000 if sc.pingSent {
1001 sc.logf("timeout waiting for PING response")
1002 if f := sc.countErrorFunc; f != nil {
1003 f("conn_close_lost_ping")
1004 }
1005 sc.conn.Close()
1006 return
1007 }
1008
1009 pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
1010 now := time.Now()
1011 if pingAt.After(now) {
1012
1013
1014 sc.readIdleTimer.Reset(pingAt.Sub(now))
1015 return
1016 }
1017
1018 sc.pingSent = true
1019
1020
1021 _, _ = rand.Read(sc.sentPingData[:])
1022 sc.writeFrame(FrameWriteRequest{
1023 write: &writePing{data: sc.sentPingData},
1024 })
1025 sc.readIdleTimer.Reset(sc.pingTimeout)
1026 }
1027
// serverMessage is the type of the sentinel messages sent to the serve loop
// over serveMsgCh.
type serverMessage int

// Sentinel messages. The serve loop distinguishes them by pointer identity,
// so each one is a distinct allocation.
var (
	settingsTimerMsg    = new(serverMessage) // the first-SETTINGS timer fired
	idleTimerMsg        = new(serverMessage) // the idle timer fired
	readIdleTimerMsg    = new(serverMessage) // the read-idle (ping) timer fired
	shutdownTimerMsg    = new(serverMessage) // the post-GOAWAY shutdown timer fired
	gracefulShutdownMsg = new(serverMessage) // a graceful shutdown was requested
	handlerDoneMsg      = new(serverMessage) // makes the serve loop call handlerDone
)
1039
1040 func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
1041 func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
1042 func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
1043 func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
1044
1045 func (sc *serverConn) sendServeMsg(msg interface{}) {
1046 sc.serveG.checkNotOn()
1047 select {
1048 case sc.serveMsgCh <- msg:
1049 case <-sc.doneServing:
1050 }
1051 }
1052
1053 var errPrefaceTimeout = errors.New("timeout waiting for client preface")
1054
1055
1056
1057
1058 func (sc *serverConn) readPreface() error {
1059 if sc.sawClientPreface {
1060 return nil
1061 }
1062 errc := make(chan error, 1)
1063 go func() {
1064
1065 buf := make([]byte, len(ClientPreface))
1066 if _, err := io.ReadFull(sc.conn, buf); err != nil {
1067 errc <- err
1068 } else if !bytes.Equal(buf, clientPreface) {
1069 errc <- fmt.Errorf("bogus greeting %q", buf)
1070 } else {
1071 errc <- nil
1072 }
1073 }()
1074 timer := time.NewTimer(prefaceTimeout)
1075 defer timer.Stop()
1076 select {
1077 case <-timer.C:
1078 return errPrefaceTimeout
1079 case err := <-errc:
1080 if err == nil {
1081 if VerboseLogs {
1082 sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
1083 }
1084 }
1085 return err
1086 }
1087 }
1088
// writeDataPool recycles the writeData values used by writeDataFromHandler.
var writeDataPool = sync.Pool{
	New: func() interface{} { return new(writeData) },
}

// writeDataFromHandler queues a DATA frame for stream and blocks until the
// write completes, the connection stops serving, or the stream closes.
//
// It returns the write's result, errClientDisconnected when the connection
// is done, or errStreamClosed when the stream closed without a write result
// being available.
func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error {
	ch := sc.srv.state.getErrChan()
	writeArg := writeDataPool.Get().(*writeData)
	*writeArg = writeData{stream.id, data, endStream}
	err := sc.writeFrameFromHandler(FrameWriteRequest{
		write:  writeArg,
		stream: stream,
		done:   ch,
	})
	if err != nil {
		return err
	}
	// frameWriteDone records that a result was received on ch, i.e. the
	// write finished (successfully or not).
	var frameWriteDone bool
	select {
	case err = <-ch:
		frameWriteDone = true
	case <-sc.doneServing:
		return errClientDisconnected
	case <-stream.cw:
		// The stream closed. A write result may be ready at the same time;
		// prefer it when it is, and report a closed stream otherwise.
		select {
		case err = <-ch:
			frameWriteDone = true
		default:
			return errStreamClosed
		}
	}
	// ch and writeArg go back to their pools only on this path, after a
	// result has been received; the early returns above leave them
	// unpooled.
	sc.srv.state.putErrChan(ch)
	if frameWriteDone {
		writeDataPool.Put(writeArg)
	}
	return err
}
1134
1135
1136
1137
1138
1139
1140
1141
1142 func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error {
1143 sc.serveG.checkNotOn()
1144 select {
1145 case sc.wantWriteFrameCh <- wr:
1146 return nil
1147 case <-sc.doneServing:
1148
1149
1150 return errClientDisconnected
1151 }
1152 }
1153
1154
1155
1156
1157
1158
1159
1160
1161
// writeFrame queues wr with the write scheduler and then tries to schedule a
// frame write. Writes addressed to an already-closed stream (other than
// stream resets) and redundant 100-continue headers are dropped instead of
// being queued. It must be called on the serve goroutine.
func (sc *serverConn) writeFrame(wr FrameWriteRequest) {
	sc.serveG.check()

	// ignoreWrite is set when the frame should be dropped rather than
	// queued.
	var ignoreWrite bool

	// Drop frames for streams that are already closed, but still let a
	// StreamError (stream reset) through.
	if wr.StreamID() != 0 {
		_, isReset := wr.write.(StreamError)
		if state, _ := sc.state(wr.StreamID()); state == stateClosed && !isReset {
			ignoreWrite = true
		}
	}

	// Track whether response headers have been queued, and skip a
	// 100-continue that would come after them.
	switch wr.write.(type) {
	case *writeResHeaders:
		wr.stream.wroteHeaders = true
	case write100ContinueHeadersFrame:
		if wr.stream.wroteHeaders {
			// Dropping the write means nothing would ever be sent on a
			// done channel, so one must not be attached.
			if wr.done != nil {
				panic("wr.done != nil for write100ContinueHeadersFrame")
			}
			ignoreWrite = true
		}
	}

	if !ignoreWrite {
		if wr.isControl() {
			sc.queuedControlFrames++
			// A negative count means the counter wrapped around; close the
			// connection.
			if sc.queuedControlFrames < 0 {
				sc.conn.Close()
			}
		}
		sc.writeSched.Push(wr)
	}
	sc.scheduleFrameWrite()
}
1222
1223
1224
1225
// startFrameWrite begins writing wr. Frames that fit in the buffered
// writer's free space are written synchronously; everything else is written
// on a separate goroutine, with completion reported on wroteFrameCh.
//
// It panics if a frame write is already in progress, or if wr targets a
// stream whose state does not allow the frame. It must be called on the
// serve goroutine.
func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) {
	sc.serveG.check()
	if sc.writingFrame {
		panic("internal error: can only be writing one frame at a time")
	}

	st := wr.stream
	if st != nil {
		switch st.state {
		case stateHalfClosedLocal:
			switch wr.write.(type) {
			case StreamError, handlerPanicRST, writeWindowUpdate:
				// Resets and window updates are the only frames allowed on
				// a half-closed (local) stream.
			default:
				panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr))
			}
		case stateClosed:
			panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr))
		}
	}
	// A push promise gets its promised stream ID assigned just before it is
	// written; on failure the writer is told and nothing is written.
	if wpp, ok := wr.write.(*writePushPromise); ok {
		var err error
		wpp.promisedID, err = wpp.allocatePromisedID()
		if err != nil {
			sc.writingFrameAsync = false
			wr.replyToWriter(err)
			return
		}
	}

	sc.writingFrame = true
	sc.needsFrameFlush = true
	if wr.write.staysWithinBuffer(sc.bw.Available()) {
		// The frame fits in the write buffer, so write it here and report
		// completion immediately.
		sc.writingFrameAsync = false
		err := wr.write.writeFrame(sc)
		sc.wroteFrame(frameWriteResult{wr: wr, err: err})
	} else if wd, ok := wr.write.(*writeData); ok {
		// Start the DATA frame on the serve goroutine, where the payload
		// wd.p is handed to the framer; writeFrameAsync then only has to
		// finish the write with endWrite.
		sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil)
		sc.writingFrameAsync = true
		go sc.writeFrameAsync(wr, wd)
	} else {
		sc.writingFrameAsync = true
		go sc.writeFrameAsync(wr, nil)
	}
}
1275
1276
1277
1278
// errHandlerPanicked is the error used to close a stream after a
// handlerPanicRST frame has been written.
var errHandlerPanicked = errors.New("http2: handler panicked")

// wroteFrame records the completion of a frame write started by
// startFrameWrite: it updates stream state, replies to the writer, and
// schedules the next write.
//
// It panics if no frame write was in progress. It must be called on the
// serve goroutine.
func (sc *serverConn) wroteFrame(res frameWriteResult) {
	sc.serveG.check()
	if !sc.writingFrame {
		panic("internal error: expected to be already writing a frame")
	}
	sc.writingFrame = false
	sc.writingFrameAsync = false

	// A failed write closes the connection.
	if res.err != nil {
		sc.conn.Close()
	}

	wr := res.wr

	if writeEndsStream(wr.write) {
		st := wr.stream
		if st == nil {
			panic("internal error: expecting non-nil stream")
		}
		switch st.state {
		case stateOpen:
			// The server has finished its side while the client's side is
			// still open: mark the stream half-closed (local) and queue a
			// reset with ErrCodeNo.
			st.state = stateHalfClosedLocal
			sc.resetStream(streamError(st.id, ErrCodeNo))
		case stateHalfClosedRemote:
			// Both sides are now finished.
			sc.closeStream(st, errHandlerComplete)
		}
	} else {
		switch v := wr.write.(type) {
		case StreamError:
			// The stream may not be tracked, so look it up before closing.
			if st, ok := sc.streams[v.StreamID]; ok {
				sc.closeStream(st, v)
			}
		case handlerPanicRST:
			sc.closeStream(wr.stream, errHandlerPanicked)
		}
	}

	// Pass the write result back to whoever issued the request.
	wr.replyToWriter(res.err)

	sc.scheduleFrameWrite()
}
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
// scheduleFrameWrite starts frame writes until one goes asynchronous or
// there is nothing left to write.
//
// Each iteration picks, in priority order: a pending GOAWAY, a pending
// SETTINGS ack, the next frame from the write scheduler (skipped once a
// GOAWAY with an error code is in effect), and finally a flush of the
// buffered writer. It returns immediately when a write is in progress or
// when it is already running further up the call stack. It must be called
// on the serve goroutine.
func (sc *serverConn) scheduleFrameWrite() {
	sc.serveG.check()
	if sc.writingFrame || sc.inFrameScheduleLoop {
		return
	}
	// Synchronous writes complete through wroteFrame, which calls back into
	// this function; the flag turns those nested calls into no-ops.
	sc.inFrameScheduleLoop = true
	for !sc.writingFrameAsync {
		if sc.needToSendGoAway {
			sc.needToSendGoAway = false
			sc.startFrameWrite(FrameWriteRequest{
				write: &writeGoAway{
					maxStreamID: sc.maxClientStreamID,
					code:        sc.goAwayCode,
				},
			})
			continue
		}
		if sc.needToSendSettingsAck {
			sc.needToSendSettingsAck = false
			sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}})
			continue
		}
		if !sc.inGoAway || sc.goAwayCode == ErrCodeNo {
			if wr, ok := sc.writeSched.Pop(); ok {
				if wr.isControl() {
					sc.queuedControlFrames--
				}
				sc.startFrameWrite(wr)
				continue
			}
		}
		if sc.needsFrameFlush {
			sc.startFrameWrite(FrameWriteRequest{write: flushFrameWriter{}})
			// Cleared after the call because startFrameWrite sets the flag
			// to true.
			sc.needsFrameFlush = false
			continue
		}
		break
	}
	sc.inFrameScheduleLoop = false
}
1391
1392
1393
1394
1395
1396
1397
1398
1399 func (sc *serverConn) startGracefulShutdown() {
1400 sc.serveG.checkNotOn()
1401 sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) })
1402 }
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420 var goAwayTimeout = 1 * time.Second
1421
1422 func (sc *serverConn) startGracefulShutdownInternal() {
1423 sc.goAway(ErrCodeNo)
1424 }
1425
1426 func (sc *serverConn) goAway(code ErrCode) {
1427 sc.serveG.check()
1428 if sc.inGoAway {
1429 if sc.goAwayCode == ErrCodeNo {
1430 sc.goAwayCode = code
1431 }
1432 return
1433 }
1434 sc.inGoAway = true
1435 sc.needToSendGoAway = true
1436 sc.goAwayCode = code
1437 sc.scheduleFrameWrite()
1438 }
1439
1440 func (sc *serverConn) shutDownIn(d time.Duration) {
1441 sc.serveG.check()
1442 sc.shutdownTimer = time.AfterFunc(d, sc.onShutdownTimer)
1443 }
1444
1445 func (sc *serverConn) resetStream(se StreamError) {
1446 sc.serveG.check()
1447 sc.writeFrame(FrameWriteRequest{write: se})
1448 if st, ok := sc.streams[se.StreamID]; ok {
1449 st.resetQueued = true
1450 }
1451 }
1452
1453
1454
1455
// processFrameFromReader handles one result from readFrames and reports
// whether the serve loop should keep running.
//
// Read errors either start a GOAWAY (oversized frame) or end the loop.
// Frames are passed to processFrame, and any error it returns is turned into
// a stream reset, a GOAWAY, or the end of the loop depending on its type. It
// must be called on the serve goroutine.
func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
	sc.serveG.check()
	err := res.err
	if err != nil {
		if err == ErrFrameTooLarge {
			sc.goAway(ErrCodeFrameSize)
			return true
		}
		clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err)
		if clientGone {
			// The client went away; stop serving.
			return false
		}
	} else {
		f := res.f
		if VerboseLogs {
			sc.vlogf("http2: server read frame %v", summarizeFrame(f))
		}
		err = sc.processFrame(f)
		if err == nil {
			return true
		}
	}

	// err is now either a read error that was not handled above or an error
	// returned by processFrame.
	switch ev := err.(type) {
	case StreamError:
		sc.resetStream(ev)
		return true
	case goAwayFlowError:
		sc.goAway(ErrCodeFlowControl)
		return true
	case ConnectionError:
		// Raise maxClientStreamID to the offending frame's stream ID before
		// the GOAWAY is scheduled; the GOAWAY frame carries that ID.
		if res.f != nil {
			if id := res.f.Header().StreamID; id > sc.maxClientStreamID {
				sc.maxClientStreamID = id
			}
		}
		sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
		sc.goAway(ErrCode(ev))
		return true
	default:
		if res.err != nil {
			sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
		} else {
			sc.logf("http2: server closing client connection: %v", err)
		}
		return false
	}
}
1512
// processFrame dispatches a frame read from the client to the handler for
// its type and returns that handler's error.
//
// The first frame on a connection must be a SETTINGS frame. While the
// connection is going away, frames are discarded if an error code has been
// sent or if they belong to a stream above maxClientStreamID. It must be
// called on the serve goroutine.
func (sc *serverConn) processFrame(f Frame) error {
	sc.serveG.check()

	// Anything other than SETTINGS as the first frame is a protocol error.
	if !sc.sawFirstSettings {
		if _, ok := f.(*SettingsFrame); !ok {
			return sc.countError("first_settings", ConnectionError(ErrCodeProtocol))
		}
		sc.sawFirstSettings = true
	}

	// Discarded DATA frames are still charged against the connection-level
	// receive window and credited back to the peer.
	if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {

		if f, ok := f.(*DataFrame); ok {
			if !sc.inflow.take(f.Length) {
				return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
			}
			sc.sendWindowUpdate(nil, int(f.Length))
		}
		return nil
	}

	switch f := f.(type) {
	case *SettingsFrame:
		return sc.processSettings(f)
	case *MetaHeadersFrame:
		return sc.processHeaders(f)
	case *WindowUpdateFrame:
		return sc.processWindowUpdate(f)
	case *PingFrame:
		return sc.processPing(f)
	case *DataFrame:
		return sc.processData(f)
	case *RSTStreamFrame:
		return sc.processResetStream(f)
	case *PriorityFrame:
		return sc.processPriority(f)
	case *GoAwayFrame:
		return sc.processGoAway(f)
	case *PushPromiseFrame:
		// A PUSH_PROMISE received by the server is rejected as a
		// connection-level protocol error.
		return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
	case *PriorityUpdateFrame:
		return sc.processPriorityUpdate(f)
	default:
		sc.vlogf("http2: server ignoring frame: %v", f.Header())
		return nil
	}
}
1567
1568 func (sc *serverConn) processPing(f *PingFrame) error {
1569 sc.serveG.check()
1570 if f.IsAck() {
1571 if sc.pingSent && sc.sentPingData == f.Data {
1572
1573 sc.pingSent = false
1574 sc.readIdleTimer.Reset(sc.readIdleTimeout)
1575 }
1576
1577
1578 return nil
1579 }
1580 if f.StreamID != 0 {
1581
1582
1583
1584
1585
1586 return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
1587 }
1588 sc.writeFrame(FrameWriteRequest{write: writePingAck{f}})
1589 return nil
1590 }
1591
1592 func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
1593 sc.serveG.check()
1594 switch {
1595 case f.StreamID != 0:
1596 state, st := sc.state(f.StreamID)
1597 if state == stateIdle {
1598
1599
1600
1601
1602 return sc.countError("stream_idle", ConnectionError(ErrCodeProtocol))
1603 }
1604 if st == nil {
1605
1606
1607
1608
1609
1610 return nil
1611 }
1612 if !st.flow.add(int32(f.Increment)) {
1613 return sc.countError("bad_flow", streamError(f.StreamID, ErrCodeFlowControl))
1614 }
1615 default:
1616 if !sc.flow.add(int32(f.Increment)) {
1617 return goAwayFlowError{}
1618 }
1619 }
1620 sc.scheduleFrameWrite()
1621 return nil
1622 }
1623
1624 func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
1625 sc.serveG.check()
1626
1627 state, st := sc.state(f.StreamID)
1628 if state == stateIdle {
1629
1630
1631
1632
1633
1634 return sc.countError("reset_idle_stream", ConnectionError(ErrCodeProtocol))
1635 }
1636 if st != nil {
1637 st.cancelCtx()
1638 sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
1639 }
1640 return nil
1641 }
1642
// closeStream moves st to the closed state and releases what the
// connection holds for it: its deadline timers, its entry in sc.streams and
// the open-stream counters, its request body pipe, and its entry in the
// write scheduler. It must run on the serve goroutine.
//
// err is the reason for the close. It is passed unchanged to the request
// body pipe. Before being stored in st.closeErr, a StreamError is replaced
// by its Cause, or by errStreamClosed when it has no Cause.
//
// closeStream panics if st is already idle or closed.
func (sc *serverConn) closeStream(st *stream, err error) {
	sc.serveG.check()
	if st.state == stateIdle || st.state == stateClosed {
		panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
	}
	st.state = stateClosed
	// Stop per-stream timers so their callbacks don't fire for a closed stream.
	if st.readDeadline != nil {
		st.readDeadline.Stop()
	}
	if st.writeDeadline != nil {
		st.writeDeadline.Stop()
	}
	// Streams are counted separately depending on who initiated them.
	if st.isPushed() {
		sc.curPushedStreams--
	} else {
		sc.curClientStreams--
	}
	delete(sc.streams, st.id)
	if len(sc.streams) == 0 {
		// That was the last stream: report the connection as idle, restart
		// the idle timer if one is configured, and begin a graceful
		// shutdown if the server has keep-alives disabled.
		sc.setConnState(ConnStateIdle)
		if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
			sc.idleTimer.Reset(sc.srv.IdleTimeout)
		}
		if h1ServerKeepAlivesDisabled(sc.hs) {
			sc.startGracefulShutdownInternal()
		}
	}
	if p := st.body; p != nil {
		// Bytes still sitting unread in the body pipe will never be read by
		// the handler, so return their connection-level flow control now.
		sc.sendWindowUpdate(nil, p.Len())

		p.CloseWithError(err)
	}
	// Unwrap a StreamError before recording the close reason on the stream.
	if e, ok := err.(StreamError); ok {
		if e.Cause != nil {
			err = e.Cause
		} else {
			err = errStreamClosed
		}
	}
	st.closeErr = err
	st.cancelCtx()
	// Closing st.cw wakes anything waiting on it (writeHeaders, CloseNotify,
	// FlushError all select or wait on it).
	st.cw.Close()
	sc.writeSched.CloseStream(st.id)
}
1689
1690 func (sc *serverConn) processSettings(f *SettingsFrame) error {
1691 sc.serveG.check()
1692 if f.IsAck() {
1693 sc.unackedSettings--
1694 if sc.unackedSettings < 0 {
1695
1696
1697
1698 return sc.countError("ack_mystery", ConnectionError(ErrCodeProtocol))
1699 }
1700 return nil
1701 }
1702 if f.NumSettings() > 100 || f.HasDuplicates() {
1703
1704
1705
1706 return sc.countError("settings_big_or_dups", ConnectionError(ErrCodeProtocol))
1707 }
1708 if err := f.ForeachSetting(sc.processSetting); err != nil {
1709 return err
1710 }
1711
1712
1713 sc.needToSendSettingsAck = true
1714 sc.scheduleFrameWrite()
1715 return nil
1716 }
1717
1718 func (sc *serverConn) processSetting(s Setting) error {
1719 sc.serveG.check()
1720 if err := s.Valid(); err != nil {
1721 return err
1722 }
1723 if VerboseLogs {
1724 sc.vlogf("http2: server processing setting %v", s)
1725 }
1726 switch s.ID {
1727 case SettingHeaderTableSize:
1728 sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
1729 case SettingEnablePush:
1730 sc.pushEnabled = s.Val != 0
1731 case SettingMaxConcurrentStreams:
1732 sc.clientMaxStreams = s.Val
1733 case SettingInitialWindowSize:
1734 return sc.processSettingInitialWindowSize(s.Val)
1735 case SettingMaxFrameSize:
1736 sc.maxFrameSize = int32(s.Val)
1737 case SettingMaxHeaderListSize:
1738 sc.peerMaxHeaderListSize = s.Val
1739 case SettingEnableConnectProtocol:
1740
1741
1742 case SettingNoRFC7540Priorities:
1743 if s.Val > 1 {
1744 return ConnectionError(ErrCodeProtocol)
1745 }
1746 default:
1747
1748
1749
1750 if VerboseLogs {
1751 sc.vlogf("http2: server ignoring unknown setting %v", s)
1752 }
1753 }
1754 return nil
1755 }
1756
1757 func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
1758 sc.serveG.check()
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768 old := sc.initialStreamSendWindowSize
1769 sc.initialStreamSendWindowSize = int32(val)
1770 growth := int32(val) - old
1771 for _, st := range sc.streams {
1772 if !st.flow.add(growth) {
1773
1774
1775
1776
1777
1778
1779 return sc.countError("setting_win_size", ConnectionError(ErrCodeFlowControl))
1780 }
1781 }
1782 return nil
1783 }
1784
// processData handles a DATA frame from the client: it validates the
// target stream, charges the frame against the receive windows, appends the
// payload to the stream's request body pipe, and ends the stream when the
// frame carries END_STREAM. It must run on the serve goroutine.
//
// Flow control is charged for the whole frame length (f.Length), which may
// be larger than len(f.Data()); the difference is treated as padding.
func (sc *serverConn) processData(f *DataFrame) error {
	sc.serveG.check()
	id := f.Header().StreamID

	data := f.Data()
	state, st := sc.state(id)
	if id == 0 || state == stateIdle {
		// DATA on stream 0, or on a stream that was never opened, is a
		// connection-level protocol error.
		return sc.countError("data_on_idle", ConnectionError(ErrCodeProtocol))
	}

	// The stream cannot take more data: it is no longer tracked, is not in
	// the open state, has already received trailers, or has a reset queued.
	if st == nil || state != stateOpen || st.gotTrailerHeader || st.resetQueued {
		// The frame is discarded, but it still consumed connection-level
		// window. Charge it and immediately return the credit so the
		// connection window does not shrink because of dropped data.
		if !sc.inflow.take(f.Length) {
			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
		}
		sc.sendWindowUpdate(nil, int(f.Length))

		if st != nil && st.resetQueued {
			// A reset for this stream is already queued; don't produce
			// another stream error.
			return nil
		}
		return sc.countError("closed", streamError(id, ErrCodeStreamClosed))
	}
	if st.body == nil {
		panic("internal error: should have a body in this state")
	}

	// Reject payload that would exceed the declared Content-Length
	// (declBodyBytes is -1 when no length was declared).
	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
		// As above: account for the discarded frame at connection level.
		if !sc.inflow.take(f.Length) {
			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
		}
		sc.sendWindowUpdate(nil, int(f.Length))

		// Surface the violation to a handler reading the body, then reset
		// the stream with a protocol error.
		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))

		return sc.countError("send_too_much", streamError(id, ErrCodeProtocol))
	}
	if f.Length > 0 {
		// Charge both the connection and the stream receive windows.
		if !takeInflows(&sc.inflow, &st.inflow, f.Length) {
			return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl))
		}

		if len(data) > 0 {
			st.bodyBytes += int64(len(data))
			wrote, err := st.body.Write(data)
			if err != nil {
				// The body pipe refused the write. Return connection-level
				// credit for everything that was not written; stream-level
				// credit is not returned.
				sc.sendWindowUpdate(nil, int(f.Length)-wrote)
				return nil
			}
			if wrote != len(data) {
				panic("internal error: bad Writer")
			}
		}

		// Padding is never read by the handler, so its credit is returned
		// here. The calls are made even when pad is zero.
		// NOTE(review): presumably a zero-byte add can still release credit
		// buffered inside inflow — confirm against inflow.add.
		pad := int32(f.Length) - int32(len(data))
		sc.sendWindowUpdate32(nil, pad)
		sc.sendWindowUpdate32(st, pad)
	}
	if f.StreamEnded() {
		st.endStream()
	}
	return nil
}
1881
1882 func (sc *serverConn) processGoAway(f *GoAwayFrame) error {
1883 sc.serveG.check()
1884 if f.ErrCode != ErrCodeNo {
1885 sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f)
1886 } else {
1887 sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
1888 }
1889 sc.startGracefulShutdownInternal()
1890
1891
1892 sc.pushEnabled = false
1893 return nil
1894 }
1895
1896
1897 func (st *stream) isPushed() bool {
1898 return st.id%2 == 0
1899 }
1900
1901
1902
1903 func (st *stream) endStream() {
1904 sc := st.sc
1905 sc.serveG.check()
1906
1907 if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
1908 st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
1909 st.declBodyBytes, st.bodyBytes))
1910 } else {
1911 st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
1912 st.body.CloseWithError(io.EOF)
1913 }
1914 st.state = stateHalfClosedRemote
1915 }
1916
1917
1918
1919 func (st *stream) copyTrailersToHandlerRequest() {
1920 for k, vv := range st.trailer {
1921 if _, ok := st.reqTrailer[k]; ok {
1922
1923 st.reqTrailer[k] = vv
1924 }
1925 }
1926 }
1927
1928
1929
1930 func (st *stream) onReadTimeout() {
1931 if st.body != nil {
1932
1933
1934 st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
1935 }
1936 }
1937
1938
1939
1940 func (st *stream) onWriteTimeout() {
1941 st.sc.writeFrameFromHandler(FrameWriteRequest{write: StreamError{
1942 StreamID: st.id,
1943 Code: ErrCodeInternal,
1944 Cause: os.ErrDeadlineExceeded,
1945 }})
1946 }
1947
// processHeaders handles a complete header block from the client. It must
// run on the serve goroutine.
//
// If the stream already exists the block is treated as trailers. Otherwise
// a new stream is opened: the ID and concurrency limits are checked, the
// stream and its request are created, and a handler is scheduled for it.
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
	sc.serveG.check()
	id := f.StreamID

	// Client-initiated streams must have odd IDs; anything else is a
	// connection-level protocol error.
	if id%2 != 1 {
		return sc.countError("headers_even", ConnectionError(ErrCodeProtocol))
	}

	// A header block for a stream that is already tracked can only be
	// trailers. Let the stream handle it.
	if st := sc.streams[f.StreamID]; st != nil {
		if st.resetQueued {
			// A reset for this stream is already queued, so there is no
			// point in processing the frame.
			return nil
		}

		// The client already ended its side of the stream; more headers
		// are answered with a STREAM_CLOSED stream error.
		if st.state == stateHalfClosedRemote {
			return sc.countError("headers_half_closed", streamError(id, ErrCodeStreamClosed))
		}
		return st.processTrailerHeaders(f)
	}

	// New stream IDs must be strictly greater than every ID seen so far
	// from this client.
	if id <= sc.maxClientStreamID {
		return sc.countError("stream_went_down", ConnectionError(ErrCodeProtocol))
	}
	sc.maxClientStreamID = id

	// The connection is about to have an active stream, so it is no longer
	// idle.
	if sc.idleTimer != nil {
		sc.idleTimer.Stop()
	}

	// Enforce the advertised concurrent stream limit.
	if sc.curClientStreams+1 > sc.advMaxStreams {
		if sc.unackedSettings == 0 {
			// Every SETTINGS frame we sent has been acknowledged, so the
			// client knows the limit and exceeded it anyway.
			return sc.countError("over_max_streams", streamError(id, ErrCodeProtocol))
		}

		// A SETTINGS frame is still unacknowledged, so the client may not
		// have seen the current limit yet. Refuse the stream instead of
		// calling it a protocol error.
		return sc.countError("over_max_streams_race", streamError(id, ErrCodeRefusedStream))
	}

	initialState := stateOpen
	if f.StreamEnded() {
		initialState = stateHalfClosedRemote
	}

	// Choose the stream's initial priority. The default depends on whether
	// the peer has been seen to be priority-aware and whether an
	// intermediary was detected. When the RFC 9218 scheduler is in use and
	// no intermediary has been detected, the priority is read from the
	// request headers instead, and the connection-wide flags are updated.
	// priorityAware is only ever switched on here, never off.
	initialPriority := defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary)
	if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); ok && !sc.hasIntermediary {
		headerPriority, priorityAware, hasIntermediary := f.rfc9218Priority(sc.priorityAware)
		initialPriority = headerPriority
		sc.hasIntermediary = hasIntermediary
		if priorityAware {
			sc.priorityAware = true
		}
	}
	st := sc.newStream(id, 0, initialState, initialPriority)

	// Priority information carried in the HEADERS frame itself is
	// validated, and applied unless the scheduler ignores RFC 7540
	// priorities.
	if f.HasPriority() {
		if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
			return err
		}
		if !sc.writeSchedIgnoresRFC7540() {
			sc.writeSched.AdjustStream(st.id, f.Priority)
		}
	}

	rw, req, err := sc.newWriterAndRequest(st, f)
	if err != nil {
		return err
	}
	st.reqTrailer = req.Trailer
	if st.reqTrailer != nil {
		// Trailers were announced; allocate the map that received trailer
		// values are collected into.
		st.trailer = make(Header)
	}
	// The pipe is nil when the request has no body (the header block ended
	// the stream).
	st.body = req.Body.(*requestBody).pipe
	st.declBodyBytes = req.ContentLength

	// Pick the handler: the user's handler normally, or a canned error
	// handler when the header list was truncated or contains invalid
	// request headers.
	handler := sc.handler.ServeHTTP
	if f.Truncated {
		// The header list was too long; respond with a 431.
		handler = handleHeaderListTooLong
	} else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil {
		handler = serve400Handler{err}.ServeHTTP
	}

	// With a read timeout configured, clear any read deadline on the
	// underlying connection and enforce the timeout per stream instead.
	if sc.hs.ReadTimeout() > 0 {
		sc.conn.SetReadDeadline(time.Time{})
		st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout(), st.onReadTimeout)
	}

	return sc.scheduleHandler(id, rw, req, handler)
}
2078
2079 func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
2080 sc := st.sc
2081 sc.serveG.check()
2082 if st.gotTrailerHeader {
2083 return sc.countError("dup_trailers", ConnectionError(ErrCodeProtocol))
2084 }
2085 st.gotTrailerHeader = true
2086 if !f.StreamEnded() {
2087 return sc.countError("trailers_not_ended", streamError(st.id, ErrCodeProtocol))
2088 }
2089
2090 if len(f.PseudoFields()) > 0 {
2091 return sc.countError("trailers_pseudo", streamError(st.id, ErrCodeProtocol))
2092 }
2093 if st.trailer != nil {
2094 for _, hf := range f.RegularFields() {
2095 key := sc.canonicalHeader(hf.Name)
2096 if !httpguts.ValidTrailerHeader(key) {
2097
2098
2099
2100 return sc.countError("trailers_bogus", streamError(st.id, ErrCodeProtocol))
2101 }
2102 st.trailer[key] = append(st.trailer[key], hf.Value)
2103 }
2104 }
2105 st.endStream()
2106 return nil
2107 }
2108
2109 func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error {
2110 if streamID == p.StreamDep {
2111
2112
2113
2114
2115 return sc.countError("priority", streamError(streamID, ErrCodeProtocol))
2116 }
2117 return nil
2118 }
2119
2120 func (sc *serverConn) processPriority(f *PriorityFrame) error {
2121 if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
2122 return err
2123 }
2124
2125
2126
2127
2128
2129 if sc.writeSchedIgnoresRFC7540() {
2130 return nil
2131 }
2132 sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
2133 return nil
2134 }
2135
2136 func (sc *serverConn) processPriorityUpdate(f *PriorityUpdateFrame) error {
2137 sc.priorityAware = true
2138 if _, ok := sc.writeSched.(*priorityWriteSchedulerRFC9218); !ok {
2139 return nil
2140 }
2141 p, ok := parseRFC9218Priority(f.Priority, sc.priorityAware)
2142 if !ok {
2143 return sc.countError("unparsable_priority_update", streamError(f.PrioritizedStreamID, ErrCodeProtocol))
2144 }
2145 sc.writeSched.AdjustStream(f.PrioritizedStreamID, p)
2146 return nil
2147 }
2148
2149 func (sc *serverConn) newStream(id, pusherID uint32, state streamState, priority PriorityParam) *stream {
2150 sc.serveG.check()
2151 if id == 0 {
2152 panic("internal error: cannot create stream with id 0")
2153 }
2154
2155 ctx, cancelCtx := context.WithCancel(sc.baseCtx)
2156 st := &stream{
2157 sc: sc,
2158 id: id,
2159 state: state,
2160 ctx: ctx,
2161 cancelCtx: cancelCtx,
2162 }
2163 st.cw.Init()
2164 st.flow.conn = &sc.flow
2165 st.flow.add(sc.initialStreamSendWindowSize)
2166 st.inflow.init(sc.initialStreamRecvWindowSize)
2167 if writeTimeout := sc.hs.WriteTimeout(); writeTimeout > 0 {
2168 st.writeDeadline = time.AfterFunc(writeTimeout, st.onWriteTimeout)
2169 }
2170
2171 sc.streams[id] = st
2172 sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID, priority: priority})
2173 if st.isPushed() {
2174 sc.curPushedStreams++
2175 } else {
2176 sc.curClientStreams++
2177 }
2178 if sc.curOpenStreams() == 1 {
2179 sc.setConnState(ConnStateActive)
2180 }
2181
2182 return st
2183 }
2184
2185 func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *ServerRequest, error) {
2186 sc.serveG.check()
2187
2188 rp := httpcommon.ServerRequestParam{
2189 Method: f.PseudoValue("method"),
2190 Scheme: f.PseudoValue("scheme"),
2191 Authority: f.PseudoValue("authority"),
2192 Path: f.PseudoValue("path"),
2193 Protocol: f.PseudoValue("protocol"),
2194 }
2195
2196
2197 if disableExtendedConnectProtocol && rp.Protocol != "" {
2198 return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
2199 }
2200
2201 isConnect := rp.Method == "CONNECT"
2202 if isConnect {
2203 if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") {
2204 return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
2205 }
2206 } else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") {
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217 return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
2218 }
2219
2220 header := make(Header)
2221 rp.Header = header
2222 for _, hf := range f.RegularFields() {
2223 header.Add(sc.canonicalHeader(hf.Name), hf.Value)
2224 }
2225 if rp.Authority == "" {
2226 rp.Authority = header.Get("Host")
2227 }
2228 if rp.Protocol != "" {
2229 header.Set(":protocol", rp.Protocol)
2230 }
2231
2232 rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
2233 if err != nil {
2234 return nil, nil, err
2235 }
2236 bodyOpen := !f.StreamEnded()
2237 if bodyOpen {
2238 if vv, ok := rp.Header["Content-Length"]; ok {
2239 if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
2240 req.ContentLength = int64(cl)
2241 } else {
2242 req.ContentLength = 0
2243 }
2244 } else {
2245 req.ContentLength = -1
2246 }
2247 req.Body.(*requestBody).pipe = &pipe{
2248 b: &dataBuffer{expected: req.ContentLength},
2249 }
2250 }
2251 return rw, req, nil
2252 }
2253
2254 func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *ServerRequest, error) {
2255 sc.serveG.check()
2256
2257 var tlsState *tls.ConnectionState
2258 if rp.Scheme == "https" {
2259 tlsState = sc.tlsState
2260 }
2261
2262 res := httpcommon.NewServerRequest(rp)
2263 if res.InvalidReason != "" {
2264 return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol))
2265 }
2266
2267 body := &requestBody{
2268 conn: sc,
2269 stream: st,
2270 needsContinue: res.NeedsContinue,
2271 }
2272 rw := sc.newResponseWriter(st)
2273 rw.rws.req = ServerRequest{
2274 Context: st.ctx,
2275 Method: rp.Method,
2276 URL: res.URL,
2277 RemoteAddr: sc.remoteAddrStr,
2278 Header: rp.Header,
2279 RequestURI: res.RequestURI,
2280 Proto: "HTTP/2.0",
2281 ProtoMajor: 2,
2282 ProtoMinor: 0,
2283 TLS: tlsState,
2284 Host: rp.Authority,
2285 Body: body,
2286 Trailer: res.Trailer,
2287 }
2288 return rw, &rw.rws.req, nil
2289 }
2290
2291 func (sc *serverConn) newResponseWriter(st *stream) *responseWriter {
2292 rws := responseWriterStatePool.Get().(*responseWriterState)
2293 bwSave := rws.bw
2294 *rws = responseWriterState{}
2295 rws.conn = sc
2296 rws.bw = bwSave
2297 rws.bw.Reset(chunkWriter{rws})
2298 rws.stream = st
2299 return &responseWriter{rws: rws}
2300 }
2301
// unstartedHandler records a handler invocation that scheduleHandler could
// not start immediately because the handler limit was reached. Entries are
// started later by serverConn.handlerDone.
type unstartedHandler struct {
	streamID uint32          // stream the request arrived on; used to detect streams closed while queued
	rw       *responseWriter // response writer to pass to the handler
	req      *ServerRequest  // request to pass to the handler
	handler  func(*ResponseWriter, *ServerRequest)
}
2308
2309
2310
2311 func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *ServerRequest, handler func(*ResponseWriter, *ServerRequest)) error {
2312 sc.serveG.check()
2313 maxHandlers := sc.advMaxStreams
2314 if sc.curHandlers < maxHandlers {
2315 sc.curHandlers++
2316 go sc.runHandler(rw, req, handler)
2317 return nil
2318 }
2319 if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
2320 return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
2321 }
2322 sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
2323 streamID: streamID,
2324 rw: rw,
2325 req: req,
2326 handler: handler,
2327 })
2328 return nil
2329 }
2330
2331 func (sc *serverConn) handlerDone() {
2332 sc.serveG.check()
2333 sc.curHandlers--
2334 i := 0
2335 maxHandlers := sc.advMaxStreams
2336 for ; i < len(sc.unstartedHandlers); i++ {
2337 u := sc.unstartedHandlers[i]
2338 if sc.streams[u.streamID] == nil {
2339
2340 continue
2341 }
2342 if sc.curHandlers >= maxHandlers {
2343 break
2344 }
2345 sc.curHandlers++
2346 go sc.runHandler(u.rw, u.req, u.handler)
2347 sc.unstartedHandlers[i] = unstartedHandler{}
2348 }
2349 sc.unstartedHandlers = sc.unstartedHandlers[i:]
2350 if len(sc.unstartedHandlers) == 0 {
2351 sc.unstartedHandlers = nil
2352 }
2353 }
2354
2355
// runHandler runs handler for one request and cleans up afterwards. It is
// started in its own goroutine by scheduleHandler and handlerDone.
//
// A panic in the handler is recovered: a handlerPanicRST is queued for the
// stream and, unless the panic value is nil or ErrAbortHandler, the panic
// and a stack trace are logged. The panic is not re-raised.
func (sc *serverConn) runHandler(rw *responseWriter, req *ServerRequest, handler func(*ResponseWriter, *ServerRequest)) {
	// Registered first, so it runs last: the serve loop is told the handler
	// is done only after the cleanup below has finished.
	defer sc.sendServeMsg(handlerDoneMsg)
	// didPanic stays true unless handler returns normally.
	didPanic := true
	defer func() {
		rw.rws.stream.cancelCtx()
		// Remove temporary files created while parsing a multipart form.
		if req.MultipartForm != nil {
			req.MultipartForm.RemoveAll()
		}
		if didPanic {
			e := recover()
			sc.writeFrameFromHandler(FrameWriteRequest{
				write:  handlerPanicRST{rw.rws.stream.id},
				stream: rw.rws.stream,
			})

			// ErrAbortHandler is treated as a deliberate abort and is not
			// logged.
			if e != nil && e != ErrAbortHandler {
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
			}
			return
		}
		// Normal return: flush the response and release the writer state.
		rw.handlerDone()
	}()
	handler(rw, req)
	didPanic = false
}
2384
2385 func handleHeaderListTooLong(w *ResponseWriter, r *ServerRequest) {
2386
2387
2388
2389
2390 const statusRequestHeaderFieldsTooLarge = 431
2391 w.WriteHeader(statusRequestHeaderFieldsTooLarge)
2392 io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
2393 }
2394
2395
2396
2397 func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error {
2398 sc.serveG.checkNotOn()
2399 var errc chan error
2400 if headerData.h != nil {
2401
2402
2403
2404
2405 errc = sc.srv.state.getErrChan()
2406 }
2407 if err := sc.writeFrameFromHandler(FrameWriteRequest{
2408 write: headerData,
2409 stream: st,
2410 done: errc,
2411 }); err != nil {
2412 return err
2413 }
2414 if errc != nil {
2415 select {
2416 case err := <-errc:
2417 sc.srv.state.putErrChan(errc)
2418 return err
2419 case <-sc.doneServing:
2420 return errClientDisconnected
2421 case <-st.cw:
2422 return errStreamClosed
2423 }
2424 }
2425 return nil
2426 }
2427
2428
2429 func (sc *serverConn) write100ContinueHeaders(st *stream) {
2430 sc.writeFrameFromHandler(FrameWriteRequest{
2431 write: write100ContinueHeadersFrame{st.id},
2432 stream: st,
2433 })
2434 }
2435
2436
2437
// bodyReadMsg is the message sent on serverConn.bodyReadCh by
// noteBodyReadFromHandler to report that a handler read bytes from a
// request body.
type bodyReadMsg struct {
	st *stream // stream whose body was read
	n  int     // number of bytes read
}
2442
2443
2444
2445
2446 func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) {
2447 sc.serveG.checkNotOn()
2448 if n > 0 {
2449 select {
2450 case sc.bodyReadCh <- bodyReadMsg{st, n}:
2451 case <-sc.doneServing:
2452 }
2453 }
2454 }
2455
2456 func (sc *serverConn) noteBodyRead(st *stream, n int) {
2457 sc.serveG.check()
2458 sc.sendWindowUpdate(nil, n)
2459 if st.state != stateHalfClosedRemote && st.state != stateClosed {
2460
2461
2462 sc.sendWindowUpdate(st, n)
2463 }
2464 }
2465
2466
// sendWindowUpdate32 is sendWindowUpdate for an int32 byte count; it
// converts n to int and forwards the call unchanged.
func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) {
	sc.sendWindowUpdate(st, int(n))
}
2470
2471
2472 func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
2473 sc.serveG.check()
2474 var streamID uint32
2475 var send int32
2476 if st == nil {
2477 send = sc.inflow.add(n)
2478 } else {
2479 streamID = st.id
2480 send = st.inflow.add(n)
2481 }
2482 if send == 0 {
2483 return
2484 }
2485 sc.writeFrame(FrameWriteRequest{
2486 write: writeWindowUpdate{streamID: streamID, n: uint32(send)},
2487 stream: st,
2488 })
2489 }
2490
2491
2492
// requestBody is the value stored in a request's Body field. Reads are
// served from pipe, and each read is reported back to the connection so
// flow-control credit can be returned.
type requestBody struct {
	_             incomparable // NOTE(review): presumably prevents == comparison of requestBody values — confirm against the incomparable definition
	stream        *stream
	conn          *serverConn
	closeOnce     sync.Once // ensures the pipe is broken at most once by Close
	sawEOF        bool      // set once a Read has returned io.EOF; later Reads return io.EOF immediately
	pipe          *pipe     // nil when the request has no body
	needsContinue bool      // when true, the first Read sends a 100-continue response first
}
2502
2503 func (b *requestBody) Close() error {
2504 b.closeOnce.Do(func() {
2505 if b.pipe != nil {
2506 b.pipe.BreakWithError(errClosedBody)
2507 }
2508 })
2509 return nil
2510 }
2511
2512 func (b *requestBody) Read(p []byte) (n int, err error) {
2513 if b.needsContinue {
2514 b.needsContinue = false
2515 b.conn.write100ContinueHeaders(b.stream)
2516 }
2517 if b.pipe == nil || b.sawEOF {
2518 return 0, io.EOF
2519 }
2520 n, err = b.pipe.Read(p)
2521 if err == io.EOF {
2522 b.sawEOF = true
2523 }
2524 if b.conn == nil {
2525 return
2526 }
2527 b.conn.noteBodyReadFromHandler(b.stream, n, err)
2528 return
2529 }
2530
2531
2532
2533
2534
2535
2536
// responseWriter is the response writer handed to handlers. It is a thin
// handle around a pooled responseWriterState. handlerDone sets rws to nil
// once the handler has finished, after which the methods that check for a
// nil rws panic.
type responseWriter struct {
	rws *responseWriterState
}
2540
// responseWriterState holds the per-response state behind a responseWriter.
// Values are recycled through responseWriterStatePool; newResponseWriter
// resets every field except bw.
type responseWriterState struct {
	// Set once by newResponseWriter / newWriterAndRequestNoBody.
	stream *stream
	req    ServerRequest
	conn   *serverConn

	// bw buffers handler writes and flushes them to a chunkWriter wrapping
	// this state.
	bw *bufio.Writer

	// Response bookkeeping.
	handlerHeader Header   // header map returned by Header(); nil until first requested
	snapHeader    Header   // copy of handlerHeader taken by writeHeader for a non-1xx status
	trailers      []string // trailer keys recorded by declareTrailer
	status        int      // status code recorded by writeHeader
	wroteHeader   bool     // writeHeader has recorded a final (non-1xx) status
	sentHeader    bool     // writeChunk has sent the response header frame
	handlerDone   bool     // the handler has returned

	sentContentLen int64 // parsed Content-Length taken from snapHeader; 0 when none was set
	wroteBytes     int64 // total bytes passed to Write/WriteString so far

	closeNotifierMu sync.Mutex // guards closeNotifierCh
	closeNotifierCh chan bool  // created on the first CloseNotify call
}
2565
// chunkWriter is the io.Writer underneath responseWriterState.bw; its Write
// method forwards to writeChunk.
type chunkWriter struct{ rws *responseWriterState }
2567
2568 func (cw chunkWriter) Write(p []byte) (n int, err error) {
2569 n, err = cw.rws.writeChunk(p)
2570 if err == errStreamClosed {
2571
2572
2573 err = cw.rws.stream.closeErr
2574 }
2575 return n, err
2576 }
2577
2578 func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
2579
2580 func (rws *responseWriterState) hasNonemptyTrailers() bool {
2581 for _, trailer := range rws.trailers {
2582 if _, ok := rws.handlerHeader[trailer]; ok {
2583 return true
2584 }
2585 }
2586 return false
2587 }
2588
2589
2590
2591
2592 func (rws *responseWriterState) declareTrailer(k string) {
2593 k = textproto.CanonicalMIMEHeaderKey(k)
2594 if !httpguts.ValidTrailerHeader(k) {
2595
2596 rws.conn.logf("ignoring invalid trailer %q", k)
2597 return
2598 }
2599 if !strSliceContains(rws.trailers, k) {
2600 rws.trailers = append(rws.trailers, k)
2601 }
2602 }
2603
// TimeFormat is the time layout writeChunk uses for the Date response
// header it generates. The layout ends in a literal "GMT", so times should
// be converted to UTC before formatting, as writeChunk does.
const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
2605
2606
2607
2608
2609
2610
2611
// writeChunk writes p to the client as part of the response body. It is
// called through chunkWriter when the buffered writer flushes, and with a
// nil p by FlushError.
//
// On its first call it sends the response headers, filling in
// Content-Length, Content-Type and Date when the handler did not set them.
// After the handler has finished it also ends the stream, either with the
// final DATA frame or with a trailer header frame.
//
// It returns the number of bytes of p that were accepted.
func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
	// A write without an explicit WriteHeader implies status 200.
	if !rws.wroteHeader {
		rws.writeHeader(200)
	}

	// Once the handler is done, header entries using TrailerPrefix are
	// turned into declared trailers.
	if rws.handlerDone {
		rws.promoteUndeclaredTrailers()
	}

	isHeadResp := rws.req.Method == "HEAD"
	if !rws.sentHeader {
		rws.sentHeader = true
		var ctype, clen string
		// Take Content-Length out of the header snapshot; it is passed to
		// the header frame separately. Values that don't parse are dropped.
		if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
			rws.snapHeader.Del("Content-Length")
			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
				rws.sentContentLen = int64(cl)
			} else {
				clen = ""
			}
		}
		// If the handler has already finished and set no Content-Length,
		// p is the entire body, so its length can be declared. For HEAD
		// responses this is only done when something was written.
		_, hasContentLength := rws.snapHeader["Content-Length"]
		if !hasContentLength && clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
			clen = strconv.Itoa(len(p))
		}
		_, hasContentType := rws.snapHeader["Content-Type"]

		// Sniff the content type from the first chunk, but only when the
		// handler set neither Content-Type nor Content-Encoding.
		ce := rws.snapHeader.Get("Content-Encoding")
		hasCE := len(ce) > 0
		if !hasCE && !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 {
			ctype = internal.DetectContentType(p)
		}
		var date string
		if _, ok := rws.snapHeader["Date"]; !ok {
			// The handler didn't set Date; supply the current time.
			date = time.Now().UTC().Format(TimeFormat)
		}

		// Register trailer keys announced through the Trailer header.
		for _, v := range rws.snapHeader["Trailer"] {
			foreachHeaderElement(v, rws.declareTrailer)
		}

		// The Connection header is never sent. A value of "close" is
		// honored by starting a graceful shutdown of the connection.
		if _, ok := rws.snapHeader["Connection"]; ok {
			v := rws.snapHeader.Get("Connection")
			delete(rws.snapHeader, "Connection")
			if v == "close" {
				rws.conn.startGracefulShutdown()
			}
		}

		// The header frame ends the stream when nothing can follow it: a
		// HEAD response, or a finished handler with no body bytes and no
		// declared trailers.
		endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:      rws.stream.id,
			httpResCode:   rws.status,
			h:             rws.snapHeader,
			endStream:     endStream,
			contentType:   ctype,
			contentLength: clen,
			date:          date,
		})
		if err != nil {
			return 0, err
		}
		if endStream {
			return 0, nil
		}
	}
	// HEAD responses carry no body: report the bytes as accepted and drop
	// them.
	if isHeadResp {
		return len(p), nil
	}
	// An empty write before the handler finishes has nothing to send.
	if len(p) == 0 && !rws.handlerDone {
		return 0, nil
	}

	// Trailers are only sent when the handler actually set a value for at
	// least one declared key; otherwise the last DATA frame ends the stream.
	hasNonemptyTrailers := rws.hasNonemptyTrailers()
	endStream := rws.handlerDone && !hasNonemptyTrailers
	if len(p) > 0 || endStream {
		// A zero-length DATA frame is only written when it ends the stream.
		if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
			return 0, err
		}
	}

	if rws.handlerDone && hasNonemptyTrailers {
		// Send the trailers from the live header map, ending the stream.
		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:  rws.stream.id,
			h:         rws.handlerHeader,
			trailers:  rws.trailers,
			endStream: true,
		})
		return len(p), err
	}
	return len(p), nil
}
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
2726
2727
// TrailerPrefix is a prefix for keys in the response header map that marks
// the entry as a trailer rather than a header. When the handler has
// finished, promoteUndeclaredTrailers strips the prefix from each such key,
// declares the remainder as a trailer, and stores the values under its
// canonical form.
const TrailerPrefix = "Trailer:"
2729
2730
2731
2732
2733
2734
2735
2736
2737
2738
2739
2740
2741
2742
2743
2744
2745
2746
2747
2748
2749
2750
2751 func (rws *responseWriterState) promoteUndeclaredTrailers() {
2752 for k, vv := range rws.handlerHeader {
2753 if !strings.HasPrefix(k, TrailerPrefix) {
2754 continue
2755 }
2756 trailerKey := strings.TrimPrefix(k, TrailerPrefix)
2757 rws.declareTrailer(trailerKey)
2758 rws.handlerHeader[textproto.CanonicalMIMEHeaderKey(trailerKey)] = vv
2759 }
2760
2761 if len(rws.trailers) > 1 {
2762 sorter := sorterPool.Get().(*sorter)
2763 sorter.SortStrings(rws.trailers)
2764 sorterPool.Put(sorter)
2765 }
2766 }
2767
2768 func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
2769 st := w.rws.stream
2770 if !deadline.IsZero() && deadline.Before(time.Now()) {
2771
2772
2773 st.onReadTimeout()
2774 return nil
2775 }
2776 w.rws.conn.sendServeMsg(func(sc *serverConn) {
2777 if st.readDeadline != nil {
2778 if !st.readDeadline.Stop() {
2779
2780 return
2781 }
2782 }
2783 if deadline.IsZero() {
2784 st.readDeadline = nil
2785 } else if st.readDeadline == nil {
2786 st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
2787 } else {
2788 st.readDeadline.Reset(deadline.Sub(time.Now()))
2789 }
2790 })
2791 return nil
2792 }
2793
2794 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
2795 st := w.rws.stream
2796 if !deadline.IsZero() && deadline.Before(time.Now()) {
2797
2798
2799 st.onWriteTimeout()
2800 return nil
2801 }
2802 w.rws.conn.sendServeMsg(func(sc *serverConn) {
2803 if st.writeDeadline != nil {
2804 if !st.writeDeadline.Stop() {
2805
2806 return
2807 }
2808 }
2809 if deadline.IsZero() {
2810 st.writeDeadline = nil
2811 } else if st.writeDeadline == nil {
2812 st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
2813 } else {
2814 st.writeDeadline.Reset(deadline.Sub(time.Now()))
2815 }
2816 })
2817 return nil
2818 }
2819
// EnableFullDuplex does nothing and always returns nil.
// NOTE(review): presumably a no-op because HTTP/2 streams already allow
// reading the request while writing the response — confirm.
func (w *responseWriter) EnableFullDuplex() error {

	return nil
}
2824
// Flush calls FlushError and discards its result; callers that need the
// error should call FlushError directly.
func (w *responseWriter) Flush() {
	w.FlushError()
}
2828
2829 func (w *responseWriter) FlushError() error {
2830 rws := w.rws
2831 if rws == nil {
2832 panic("Header called after Handler finished")
2833 }
2834 var err error
2835 if rws.bw.Buffered() > 0 {
2836 err = rws.bw.Flush()
2837 } else {
2838
2839
2840
2841
2842 _, err = chunkWriter{rws}.Write(nil)
2843 if err == nil {
2844 select {
2845 case <-rws.stream.cw:
2846 err = rws.stream.closeErr
2847 default:
2848 }
2849 }
2850 }
2851 return err
2852 }
2853
2854 func (w *responseWriter) CloseNotify() <-chan bool {
2855 rws := w.rws
2856 if rws == nil {
2857 panic("CloseNotify called after Handler finished")
2858 }
2859 rws.closeNotifierMu.Lock()
2860 ch := rws.closeNotifierCh
2861 if ch == nil {
2862 ch = make(chan bool, 1)
2863 rws.closeNotifierCh = ch
2864 cw := rws.stream.cw
2865 go func() {
2866 cw.Wait()
2867 ch <- true
2868 }()
2869 }
2870 rws.closeNotifierMu.Unlock()
2871 return ch
2872 }
2873
2874 func (w *responseWriter) Header() Header {
2875 rws := w.rws
2876 if rws == nil {
2877 panic("Header called after Handler finished")
2878 }
2879 if rws.handlerHeader == nil {
2880 rws.handlerHeader = make(Header)
2881 }
2882 return rws.handlerHeader
2883 }
2884
2885
// checkWriteHeaderCode panics if code is not a plausible HTTP status code,
// meaning it lies outside the range 100 through 999 inclusive. Codes inside
// that range are accepted without further checks.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
2901
2902 func (w *responseWriter) WriteHeader(code int) {
2903 rws := w.rws
2904 if rws == nil {
2905 panic("WriteHeader called after Handler finished")
2906 }
2907 rws.writeHeader(code)
2908 }
2909
2910 func (rws *responseWriterState) writeHeader(code int) {
2911 if rws.wroteHeader {
2912 return
2913 }
2914
2915 checkWriteHeaderCode(code)
2916
2917
2918 if code >= 100 && code <= 199 {
2919
2920 h := rws.handlerHeader
2921
2922 _, cl := h["Content-Length"]
2923 _, te := h["Transfer-Encoding"]
2924 if cl || te {
2925 h = cloneHeader(h)
2926 h.Del("Content-Length")
2927 h.Del("Transfer-Encoding")
2928 }
2929
2930 rws.conn.writeHeaders(rws.stream, &writeResHeaders{
2931 streamID: rws.stream.id,
2932 httpResCode: code,
2933 h: h,
2934 endStream: rws.handlerDone && !rws.hasTrailers(),
2935 })
2936
2937 return
2938 }
2939
2940 rws.wroteHeader = true
2941 rws.status = code
2942 if len(rws.handlerHeader) > 0 {
2943 rws.snapHeader = cloneHeader(rws.handlerHeader)
2944 }
2945 }
2946
2947 func cloneHeader(h Header) Header {
2948 h2 := make(Header, len(h))
2949 for k, vv := range h {
2950 vv2 := make([]string, len(vv))
2951 copy(vv2, vv)
2952 h2[k] = vv2
2953 }
2954 return h2
2955 }
2956
2957
2958
2959
2960
2961
2962
2963
2964
2965 func (w *responseWriter) Write(p []byte) (n int, err error) {
2966 return w.write(len(p), p, "")
2967 }
2968
2969 func (w *responseWriter) WriteString(s string) (n int, err error) {
2970 return w.write(len(s), nil, s)
2971 }
2972
2973
2974 func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
2975 rws := w.rws
2976 if rws == nil {
2977 panic("Write called after Handler finished")
2978 }
2979 if !rws.wroteHeader {
2980 w.WriteHeader(200)
2981 }
2982 if !bodyAllowedForStatus(rws.status) {
2983 return 0, ErrBodyNotAllowed
2984 }
2985 rws.wroteBytes += int64(len(dataB)) + int64(len(dataS))
2986 if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
2987
2988 return 0, errors.New("http2: handler wrote more than declared Content-Length")
2989 }
2990
2991 if dataB != nil {
2992 return rws.bw.Write(dataB)
2993 } else {
2994 return rws.bw.WriteString(dataS)
2995 }
2996 }
2997
2998 func (w *responseWriter) handlerDone() {
2999 rws := w.rws
3000 rws.handlerDone = true
3001 w.Flush()
3002 w.rws = nil
3003 responseWriterStatePool.Put(rws)
3004 }
3005
3006
// ErrRecursivePush is returned by Push when it is called on a stream that
// was itself pushed.
var ErrRecursivePush = errors.New("http2: recursive push not allowed")

// ErrPushLimitReached is reported for a push that would exceed the peer's
// concurrent-stream limit, or for which no promised stream ID is left.
var ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
3011
3012 func (w *responseWriter) Push(target, method string, header Header) error {
3013 st := w.rws.stream
3014 sc := st.sc
3015 sc.serveG.checkNotOn()
3016
3017
3018
3019 if st.isPushed() {
3020 return ErrRecursivePush
3021 }
3022
3023
3024 if method == "" {
3025 method = "GET"
3026 }
3027 if header == nil {
3028 header = Header{}
3029 }
3030 wantScheme := "http"
3031 if w.rws.req.TLS != nil {
3032 wantScheme = "https"
3033 }
3034
3035
3036 u, err := url.Parse(target)
3037 if err != nil {
3038 return err
3039 }
3040 if u.Scheme == "" {
3041 if !strings.HasPrefix(target, "/") {
3042 return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
3043 }
3044 u.Scheme = wantScheme
3045 u.Host = w.rws.req.Host
3046 } else {
3047 if u.Scheme != wantScheme {
3048 return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
3049 }
3050 if u.Host == "" {
3051 return errors.New("URL must have a host")
3052 }
3053 }
3054 for k := range header {
3055 if strings.HasPrefix(k, ":") {
3056 return fmt.Errorf("promised request headers cannot include pseudo header %q", k)
3057 }
3058
3059
3060
3061
3062 if asciiEqualFold(k, "content-length") ||
3063 asciiEqualFold(k, "content-encoding") ||
3064 asciiEqualFold(k, "trailer") ||
3065 asciiEqualFold(k, "te") ||
3066 asciiEqualFold(k, "expect") ||
3067 asciiEqualFold(k, "host") {
3068 return fmt.Errorf("promised request headers cannot include %q", k)
3069 }
3070 }
3071 if err := checkValidHTTP2RequestHeaders(header); err != nil {
3072 return err
3073 }
3074
3075
3076
3077
3078 if method != "GET" && method != "HEAD" {
3079 return fmt.Errorf("method %q must be GET or HEAD", method)
3080 }
3081
3082 msg := &startPushRequest{
3083 parent: st,
3084 method: method,
3085 url: u,
3086 header: cloneHeader(header),
3087 done: sc.srv.state.getErrChan(),
3088 }
3089
3090 select {
3091 case <-sc.doneServing:
3092 return errClientDisconnected
3093 case <-st.cw:
3094 return errStreamClosed
3095 case sc.serveMsgCh <- msg:
3096 }
3097
3098 select {
3099 case <-sc.doneServing:
3100 return errClientDisconnected
3101 case <-st.cw:
3102 return errStreamClosed
3103 case err := <-msg.done:
3104 sc.srv.state.putErrChan(msg.done)
3105 return err
3106 }
3107 }
3108
3109 type startPushRequest struct {
3110 parent *stream
3111 method string
3112 url *url.URL
3113 header Header
3114 done chan error
3115 }
3116
3117 func (sc *serverConn) startPush(msg *startPushRequest) {
3118 sc.serveG.check()
3119
3120
3121
3122
3123 if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote {
3124
3125 msg.done <- errStreamClosed
3126 return
3127 }
3128
3129
3130 if !sc.pushEnabled {
3131 msg.done <- ErrNotSupported
3132 return
3133 }
3134
3135
3136
3137
3138 allocatePromisedID := func() (uint32, error) {
3139 sc.serveG.check()
3140
3141
3142
3143 if !sc.pushEnabled {
3144 return 0, ErrNotSupported
3145 }
3146
3147 if sc.curPushedStreams+1 > sc.clientMaxStreams {
3148 return 0, ErrPushLimitReached
3149 }
3150
3151
3152
3153
3154
3155 if sc.maxPushPromiseID+2 >= 1<<31 {
3156 sc.startGracefulShutdownInternal()
3157 return 0, ErrPushLimitReached
3158 }
3159 sc.maxPushPromiseID += 2
3160 promisedID := sc.maxPushPromiseID
3161
3162
3163
3164
3165
3166
3167 promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote, defaultRFC9218Priority(sc.priorityAware && !sc.hasIntermediary))
3168 rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{
3169 Method: msg.method,
3170 Scheme: msg.url.Scheme,
3171 Authority: msg.url.Host,
3172 Path: msg.url.RequestURI(),
3173 Header: cloneHeader(msg.header),
3174 })
3175 if err != nil {
3176
3177 panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
3178 }
3179
3180 sc.curHandlers++
3181 go sc.runHandler(rw, req, sc.handler.ServeHTTP)
3182 return promisedID, nil
3183 }
3184
3185 sc.writeFrame(FrameWriteRequest{
3186 write: &writePushPromise{
3187 streamID: msg.parent.id,
3188 method: msg.method,
3189 url: msg.url,
3190 h: msg.header,
3191 allocatePromisedID: allocatePromisedID,
3192 },
3193 stream: msg.parent,
3194 done: msg.done,
3195 })
3196 }
3197
3198
3199
// foreachHeaderElement splits v on commas and calls fn, in order, for each
// element after trimming surrounding ASCII whitespace. Elements that are
// empty after trimming are skipped, so fn is never called for an empty or
// all-whitespace v.
func foreachHeaderElement(v string, fn func(string)) {
	rest := textproto.TrimString(v)
	if rest == "" {
		return
	}
	for {
		elem, tail, more := strings.Cut(rest, ",")
		if elem = textproto.TrimString(elem); elem != "" {
			fn(elem)
		}
		if !more {
			return
		}
		rest = tail
	}
}
3215
3216
// connHeaders lists the header names, in canonical form, that
// checkValidHTTP2RequestHeaders refuses to accept on an HTTP/2 request.
var connHeaders = []string{"Connection", "Keep-Alive", "Proxy-Connection", "Transfer-Encoding", "Upgrade"}
3224
3225
3226
3227
3228 func checkValidHTTP2RequestHeaders(h Header) error {
3229 for _, k := range connHeaders {
3230 if _, ok := h[k]; ok {
3231 return fmt.Errorf("request header %q is not valid in HTTP/2", k)
3232 }
3233 }
3234 te := h["Te"]
3235 if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
3236 return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
3237 }
3238 return nil
3239 }
3240
// serve400Handler is a handler that answers every request with a 400
// response whose body is the text of err.
type serve400Handler struct{ err error }
3244
3245 func (handler serve400Handler) ServeHTTP(w *ResponseWriter, r *ServerRequest) {
3246 const statusBadRequest = 400
3247
3248
3249 h := w.Header()
3250 h.Del("Content-Length")
3251 h.Set("Content-Type", "text/plain; charset=utf-8")
3252 h.Set("X-Content-Type-Options", "nosniff")
3253 w.WriteHeader(statusBadRequest)
3254 fmt.Fprintln(w, handler.err.Error())
3255 }
3256
3257
3258
3259
3260 func h1ServerKeepAlivesDisabled(hs ServerConfig) bool {
3261 return !hs.DoKeepAlives()
3262 }
3263
3264 func (sc *serverConn) countError(name string, err error) error {
3265 if sc == nil || sc.srv == nil {
3266 return err
3267 }
3268 f := sc.countErrorFunc
3269 if f == nil {
3270 return err
3271 }
3272 var typ string
3273 var code ErrCode
3274 switch e := err.(type) {
3275 case ConnectionError:
3276 typ = "conn"
3277 code = ErrCode(e)
3278 case StreamError:
3279 typ = "stream"
3280 code = ErrCode(e.Code)
3281 default:
3282 return err
3283 }
3284 codeStr := errCodeName[code]
3285 if codeStr == "" {
3286 codeStr = strconv.Itoa(int(code))
3287 }
3288 f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
3289 return err
3290 }
3291
View as plain text