1
2
3
4
5
6
7 package http2
8
9 import (
10 "context"
11 "errors"
12 "net"
13 "sync"
14 )
15
16
// ClientConnPool manages a pool of HTTP/2 client connections.
type ClientConnPool interface {
	// GetClientConn returns a connection to addr that can be used for req.
	// In clientConnPool (this file), a pooled connection is returned only
	// after its ReserveNewRequest call has succeeded, so capacity for the
	// upcoming request is already reserved when the caller receives it.
	// NOTE(review): the "Connection: close" path in clientConnPool returns a
	// freshly dialed single-use conn without calling ReserveNewRequest here;
	// confirm the reservation contract for that case.
	GetClientConn(req *ClientRequest, addr string) (*ClientConn, error)
	// MarkDead tells the pool that cc must no longer be handed out;
	// clientConnPool removes it from every key it is registered under.
	MarkDead(*ClientConn)
}
27
28
29
// clientConnPoolIdleCloser is implemented by ClientConnPools that can also
// close their idle connections (both pool types in this file do).
type clientConnPoolIdleCloser interface {
	ClientConnPool
	// closeIdleConnections asks every pooled connection to close itself if
	// it is idle.
	closeIdleConnections()
}
34
35 var (
36 _ clientConnPoolIdleCloser = (*clientConnPool)(nil)
37 _ clientConnPoolIdleCloser = noDialClientConnPool{}
38 )
39
40
// clientConnPool is the default ClientConnPool. It caches ClientConns per
// key, coalesces concurrent dials to the same addr, and coalesces concurrent
// addConnIfNeeded offers for the same key.
type clientConnPool struct {
	// t is the Transport used to dial new connections.
	t *Transport

	// mu guards the maps below. In this file it is also held whenever
	// ClientConn.getConnCalled is read or written.
	mu sync.Mutex

	// conns holds the pooled connections for each key.
	// NOTE(review): keys appear to be the addr passed to GetClientConn;
	// confirm the exact format with callers of addConnIfNeeded.
	conns map[string][]*ClientConn
	// dialing holds the in-flight dial for each addr, so that concurrent
	// misses share one dial.
	dialing map[string]*dialCall
	// keys is the reverse index of conns: every key a conn is listed under.
	keys map[*ClientConn][]string
	// addConnCalls holds the in-flight addConnIfNeeded work for each key.
	addConnCalls map[string]*addConnCall
}
52
53 func (p *clientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
54 return p.getClientConn(req, addr, dialOnMiss)
55 }
56
// dialOnMiss and noDialOnMiss are the two values of getClientConn's
// dialOnMiss argument: whether a cache miss may start a new dial, or must
// instead fail with ErrNoCachedConn.
const dialOnMiss, noDialOnMiss = true, false
61
62 func (p *clientConnPool) getClientConn(req *ClientRequest, addr string, dialOnMiss bool) (*ClientConn, error) {
63
64 if isConnectionCloseRequest(req) && dialOnMiss {
65
66 traceGetConn(req, addr)
67 const singleUse = true
68 cc, err := p.t.dialClientConn(req.Context, addr, singleUse)
69 if err != nil {
70 return nil, err
71 }
72 return cc, nil
73 }
74 for {
75 p.mu.Lock()
76 for _, cc := range p.conns[addr] {
77 if cc.ReserveNewRequest() {
78
79
80
81 if !cc.getConnCalled {
82 traceGetConn(req, addr)
83 }
84 cc.getConnCalled = false
85 p.mu.Unlock()
86 return cc, nil
87 }
88 }
89 if !dialOnMiss {
90 p.mu.Unlock()
91 return nil, ErrNoCachedConn
92 }
93 traceGetConn(req, addr)
94 call := p.getStartDialLocked(req.Context, addr)
95 p.mu.Unlock()
96 <-call.done
97 if shouldRetryDial(call, req) {
98 continue
99 }
100 cc, err := call.res, call.err
101 if err != nil {
102 return nil, err
103 }
104 if cc.ReserveNewRequest() {
105 return cc, nil
106 }
107 }
108 }
109
110
// dialCall is a single in-flight dial, shared by every getClientConn caller
// that misses on the same addr while it runs.
type dialCall struct {
	_ incomparable
	// p is the pool that receives the connection if the dial succeeds.
	p *clientConnPool

	// ctx is the context the dial runs under: that of the caller that
	// started it (see getStartDialLocked), not of later callers that join.
	ctx context.Context
	// done is closed once res and err have been set.
	done chan struct{}
	// res is the dialed connection; read it only after done is closed.
	res *ClientConn
	// err is the dial error; read it only after done is closed.
	err error
}
121
122
123 func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
124 if call, ok := p.dialing[addr]; ok {
125
126 return call
127 }
128 call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
129 if p.dialing == nil {
130 p.dialing = make(map[string]*dialCall)
131 }
132 p.dialing[addr] = call
133 go call.dial(call.ctx, addr)
134 return call
135 }
136
137
138 func (c *dialCall) dial(ctx context.Context, addr string) {
139 const singleUse = false
140 c.res, c.err = c.p.t.dialClientConn(ctx, addr, singleUse)
141
142 c.p.mu.Lock()
143 delete(c.p.dialing, addr)
144 if c.err == nil {
145 c.p.addConnLocked(addr, c.res)
146 }
147 c.p.mu.Unlock()
148
149 close(c.done)
150 }
151
152
153
154
155
156
157
158
159
160 func (p *clientConnPool) addConnIfNeeded(key string, t *Transport, c net.Conn) (used bool, err error) {
161 p.mu.Lock()
162 for _, cc := range p.conns[key] {
163 if cc.CanTakeNewRequest() {
164 p.mu.Unlock()
165 return false, nil
166 }
167 }
168 call, dup := p.addConnCalls[key]
169 if !dup {
170 if p.addConnCalls == nil {
171 p.addConnCalls = make(map[string]*addConnCall)
172 }
173 call = &addConnCall{
174 p: p,
175 done: make(chan struct{}),
176 }
177 p.addConnCalls[key] = call
178 go call.run(t, key, c)
179 }
180 p.mu.Unlock()
181
182 <-call.done
183 if call.err != nil {
184 return false, call.err
185 }
186 return !dup, nil
187 }
188
// addConnCall is one in-flight addConnIfNeeded operation for a key. It is
// shared by every concurrent addConnIfNeeded caller for that key.
type addConnCall struct {
	_ incomparable
	// p is the pool the wrapped connection is added to.
	p *clientConnPool
	// done is closed when run has finished; err is valid afterwards.
	done chan struct{}
	// err is the error from wrapping the net.Conn, if any.
	err error
}
195
196 func (c *addConnCall) run(t *Transport, key string, nc net.Conn) {
197 cc, err := t.newClientConn(nc, t.disableKeepAlives(), nil)
198
199 p := c.p
200 p.mu.Lock()
201 if err != nil {
202 c.err = err
203 } else {
204 cc.getConnCalled = true
205 p.addConnLocked(key, cc)
206 }
207 delete(p.addConnCalls, key)
208 p.mu.Unlock()
209 close(c.done)
210 }
211
212
213 func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
214 for _, v := range p.conns[key] {
215 if v == cc {
216 return
217 }
218 }
219 if p.conns == nil {
220 p.conns = make(map[string][]*ClientConn)
221 }
222 if p.keys == nil {
223 p.keys = make(map[*ClientConn][]string)
224 }
225 p.conns[key] = append(p.conns[key], cc)
226 p.keys[cc] = append(p.keys[cc], key)
227 }
228
229 func (p *clientConnPool) MarkDead(cc *ClientConn) {
230 p.mu.Lock()
231 defer p.mu.Unlock()
232 for _, key := range p.keys[cc] {
233 vv, ok := p.conns[key]
234 if !ok {
235 continue
236 }
237 newList := filterOutClientConn(vv, cc)
238 if len(newList) > 0 {
239 p.conns[key] = newList
240 } else {
241 delete(p.conns, key)
242 }
243 }
244 delete(p.keys, cc)
245 }
246
247 func (p *clientConnPool) closeIdleConnections() {
248 p.mu.Lock()
249 defer p.mu.Unlock()
250
251
252
253
254
255
256 for _, vv := range p.conns {
257 for _, cc := range vv {
258 cc.closeIfIdle()
259 }
260 }
261 }
262
263 func filterOutClientConn(in []*ClientConn, exclude *ClientConn) []*ClientConn {
264 out := in[:0]
265 for _, v := range in {
266 if v != exclude {
267 out = append(out, v)
268 }
269 }
270
271
272 if len(in) != len(out) {
273 in[len(in)-1] = nil
274 }
275 return out
276 }
277
278
279
280
// noDialClientConnPool is a ClientConnPool that never dials. Its
// GetClientConn hands out only already-pooled connections and otherwise
// fails with ErrNoCachedConn. MarkDead and closeIdleConnections come from
// the embedded *clientConnPool.
type noDialClientConnPool struct{ *clientConnPool }
282
283 func (p noDialClientConnPool) GetClientConn(req *ClientRequest, addr string) (*ClientConn, error) {
284 return p.getClientConn(req, addr, noDialOnMiss)
285 }
286
287
288
289
290
291 func shouldRetryDial(call *dialCall, req *ClientRequest) bool {
292 if call.err == nil {
293
294 return false
295 }
296 if call.ctx == req.Context {
297
298
299
300 return false
301 }
302 if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
303
304
305 return false
306 }
307
308
309 return call.ctx.Err() != nil
310 }
311
View as plain text