Source file src/crypto/tls/handshake_server_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto"
    11  	"crypto/ecdh"
    12  	"crypto/elliptic"
    13  	"crypto/rand"
    14  	"crypto/tls/internal/fips140tls"
    15  	"crypto/x509"
    16  	"crypto/x509/pkix"
    17  	"encoding/pem"
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"os"
    23  	"os/exec"
    24  	"path/filepath"
    25  	"runtime"
    26  	"slices"
    27  	"strings"
    28  	"sync/atomic"
    29  	"testing"
    30  	"time"
    31  )
    32  
    33  func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) {
    34  	t.Helper()
    35  	testClientHelloFailure(t, serverConfig, m, "")
    36  }
    37  
    38  // testFatal is a hack to prevent the compiler from complaining that there is a
    39  // call to t.Fatal from a non-test goroutine
    40  func testFatal(t *testing.T, err error) {
    41  	t.Helper()
    42  	t.Fatal(err)
    43  }
    44  
    45  func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) {
    46  	c, s := localPipe(t)
    47  	go func() {
    48  		cli := Client(c, testConfig)
    49  		if ch, ok := m.(*clientHelloMsg); ok {
    50  			cli.vers = ch.vers
    51  		}
    52  		if _, err := cli.writeHandshakeRecord(m, nil); err != nil {
    53  			testFatal(t, err)
    54  		}
    55  		c.Close()
    56  	}()
    57  	ctx := context.Background()
    58  	conn := Server(s, serverConfig)
    59  	ch, ech, err := conn.readClientHello(ctx)
    60  	if conn.vers == VersionTLS13 {
    61  		hs := serverHandshakeStateTLS13{
    62  			c:           conn,
    63  			ctx:         ctx,
    64  			clientHello: ch,
    65  			echContext:  ech,
    66  		}
    67  		if err == nil {
    68  			err = hs.processClientHello()
    69  		}
    70  		if err == nil {
    71  			err = hs.checkForResumption()
    72  		}
    73  		if err == nil {
    74  			err = hs.pickCertificate()
    75  		}
    76  	} else {
    77  		hs := serverHandshakeState{
    78  			c:           conn,
    79  			ctx:         ctx,
    80  			clientHello: ch,
    81  		}
    82  		if err == nil {
    83  			err = hs.processClientHello()
    84  		}
    85  		if err == nil {
    86  			err = hs.pickCipherSuite()
    87  		}
    88  	}
    89  	s.Close()
    90  	t.Helper()
    91  	if len(expectedSubStr) == 0 {
    92  		if err != nil && err != io.EOF {
    93  			t.Errorf("Got error: %s; expected to succeed", err)
    94  		}
    95  	} else if err == nil || !strings.Contains(err.Error(), expectedSubStr) {
    96  		t.Errorf("Got error: %v; expected to match substring '%s'", err, expectedSubStr)
    97  	}
    98  }
    99  
   100  func TestSimpleError(t *testing.T) {
   101  	testClientHelloFailure(t, testConfig, &serverHelloDoneMsg{}, "unexpected handshake message")
   102  }
   103  
   104  var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205, VersionSSL30}
   105  
   106  func TestRejectBadProtocolVersion(t *testing.T) {
   107  	config := testConfig.Clone()
   108  	config.MinVersion = VersionSSL30
   109  	for _, v := range badProtocolVersions {
   110  		testClientHelloFailure(t, config, &clientHelloMsg{
   111  			vers:   v,
   112  			random: make([]byte, 32),
   113  		}, "unsupported versions")
   114  	}
   115  	testClientHelloFailure(t, config, &clientHelloMsg{
   116  		vers:              VersionTLS12,
   117  		supportedVersions: badProtocolVersions,
   118  		random:            make([]byte, 32),
   119  	}, "unsupported versions")
   120  }
   121  
   122  func TestNoSuiteOverlap(t *testing.T) {
   123  	clientHello := &clientHelloMsg{
   124  		vers:               VersionTLS12,
   125  		random:             make([]byte, 32),
   126  		cipherSuites:       []uint16{0xff00},
   127  		compressionMethods: []uint8{compressionNone},
   128  	}
   129  	testClientHelloFailure(t, testConfig, clientHello, "no cipher suite supported by both client and server")
   130  }
   131  
   132  func TestNoCompressionOverlap(t *testing.T) {
   133  	clientHello := &clientHelloMsg{
   134  		vers:               VersionTLS12,
   135  		random:             make([]byte, 32),
   136  		cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   137  		compressionMethods: []uint8{0xff},
   138  	}
   139  	testClientHelloFailure(t, testConfig, clientHello, "client does not support uncompressed connections")
   140  }
   141  
   142  func TestNoRC4ByDefault(t *testing.T) {
   143  	clientHello := &clientHelloMsg{
   144  		vers:               VersionTLS12,
   145  		random:             make([]byte, 32),
   146  		cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
   147  		compressionMethods: []uint8{compressionNone},
   148  	}
   149  	serverConfig := testConfig.Clone()
   150  	// Reset the enabled cipher suites to nil in order to test the
   151  	// defaults.
   152  	serverConfig.CipherSuites = nil
   153  	testClientHelloFailure(t, serverConfig, clientHello, "no cipher suite supported by both client and server")
   154  }
   155  
   156  func TestRejectSNIWithTrailingDot(t *testing.T) {
   157  	testClientHelloFailure(t, testConfig, &clientHelloMsg{
   158  		vers:       VersionTLS12,
   159  		random:     make([]byte, 32),
   160  		serverName: "foo.com.",
   161  	}, "decoding message")
   162  }
   163  
   164  func TestDontSelectECDSAWithRSAKey(t *testing.T) {
   165  	// Test that, even when both sides support an ECDSA cipher suite, it
   166  	// won't be selected if the server's private key doesn't support it.
   167  	clientHello := &clientHelloMsg{
   168  		vers:               VersionTLS12,
   169  		random:             make([]byte, 32),
   170  		cipherSuites:       []uint16{TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
   171  		compressionMethods: []uint8{compressionNone},
   172  		supportedCurves:    []CurveID{CurveP256},
   173  		supportedPoints:    []uint8{pointFormatUncompressed},
   174  	}
   175  	serverConfig := testConfig.Clone()
   176  	serverConfig.CipherSuites = clientHello.cipherSuites
   177  	serverConfig.Certificates = make([]Certificate, 1)
   178  	serverConfig.Certificates[0].Certificate = [][]byte{testECDSACertificate}
   179  	serverConfig.Certificates[0].PrivateKey = testECDSAPrivateKey
   180  	serverConfig.BuildNameToCertificate()
   181  	// First test that it *does* work when the server's key is ECDSA.
   182  	testClientHello(t, serverConfig, clientHello)
   183  
   184  	// Now test that switching to an RSA key causes the expected error (and
   185  	// not an internal error about a signing failure).
   186  	serverConfig.Certificates = testConfig.Certificates
   187  	testClientHelloFailure(t, serverConfig, clientHello, "no cipher suite supported by both client and server")
   188  }
   189  
   190  func TestDontSelectRSAWithECDSAKey(t *testing.T) {
   191  	// Test that, even when both sides support an RSA cipher suite, it
   192  	// won't be selected if the server's private key doesn't support it.
   193  	clientHello := &clientHelloMsg{
   194  		vers:               VersionTLS12,
   195  		random:             make([]byte, 32),
   196  		cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   197  		compressionMethods: []uint8{compressionNone},
   198  		supportedCurves:    []CurveID{CurveP256},
   199  		supportedPoints:    []uint8{pointFormatUncompressed},
   200  	}
   201  	serverConfig := testConfig.Clone()
   202  	serverConfig.CipherSuites = clientHello.cipherSuites
   203  	// First test that it *does* work when the server's key is RSA.
   204  	testClientHello(t, serverConfig, clientHello)
   205  
   206  	// Now test that switching to an ECDSA key causes the expected error
   207  	// (and not an internal error about a signing failure).
   208  	serverConfig.Certificates = make([]Certificate, 1)
   209  	serverConfig.Certificates[0].Certificate = [][]byte{testECDSACertificate}
   210  	serverConfig.Certificates[0].PrivateKey = testECDSAPrivateKey
   211  	serverConfig.BuildNameToCertificate()
   212  	testClientHelloFailure(t, serverConfig, clientHello, "no cipher suite supported by both client and server")
   213  }
   214  
   215  func TestRenegotiationExtension(t *testing.T) {
   216  	skipFIPS(t) // #70505
   217  
   218  	clientHello := &clientHelloMsg{
   219  		vers:                         VersionTLS12,
   220  		compressionMethods:           []uint8{compressionNone},
   221  		random:                       make([]byte, 32),
   222  		secureRenegotiationSupported: true,
   223  		cipherSuites:                 []uint16{TLS_RSA_WITH_RC4_128_SHA},
   224  	}
   225  
   226  	bufChan := make(chan []byte, 1)
   227  	c, s := localPipe(t)
   228  
   229  	go func() {
   230  		cli := Client(c, testConfig)
   231  		cli.vers = clientHello.vers
   232  		if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
   233  			testFatal(t, err)
   234  		}
   235  
   236  		buf := make([]byte, 1024)
   237  		n, err := c.Read(buf)
   238  		if err != nil {
   239  			t.Errorf("Server read returned error: %s", err)
   240  		}
   241  		c.Close()
   242  		bufChan <- buf[:n]
   243  	}()
   244  
   245  	Server(s, testConfig).Handshake()
   246  	buf := <-bufChan
   247  
   248  	if len(buf) < 5+4 {
   249  		t.Fatalf("Server returned short message of length %d", len(buf))
   250  	}
   251  	// buf contains a TLS record, with a 5 byte record header and a 4 byte
   252  	// handshake header. The length of the ServerHello is taken from the
   253  	// handshake header.
   254  	serverHelloLen := int(buf[6])<<16 | int(buf[7])<<8 | int(buf[8])
   255  
   256  	var serverHello serverHelloMsg
   257  	// unmarshal expects to be given the handshake header, but
   258  	// serverHelloLen doesn't include it.
   259  	if !serverHello.unmarshal(buf[5 : 9+serverHelloLen]) {
   260  		t.Fatalf("Failed to parse ServerHello")
   261  	}
   262  
   263  	if !serverHello.secureRenegotiationSupported {
   264  		t.Errorf("Secure renegotiation extension was not echoed.")
   265  	}
   266  }
   267  
   268  func TestTLS12OnlyCipherSuites(t *testing.T) {
   269  	skipFIPS(t) // No TLS 1.1 in FIPS mode.
   270  
   271  	// Test that a Server doesn't select a TLS 1.2-only cipher suite when
   272  	// the client negotiates TLS 1.1.
   273  	clientHello := &clientHelloMsg{
   274  		vers:   VersionTLS11,
   275  		random: make([]byte, 32),
   276  		cipherSuites: []uint16{
   277  			// The Server, by default, will use the client's
   278  			// preference order. So the GCM cipher suite
   279  			// will be selected unless it's excluded because
   280  			// of the version in this ClientHello.
   281  			TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
   282  			TLS_RSA_WITH_RC4_128_SHA,
   283  		},
   284  		compressionMethods: []uint8{compressionNone},
   285  		supportedCurves:    []CurveID{CurveP256, CurveP384, CurveP521},
   286  		supportedPoints:    []uint8{pointFormatUncompressed},
   287  	}
   288  
   289  	c, s := localPipe(t)
   290  	replyChan := make(chan any)
   291  	go func() {
   292  		cli := Client(c, testConfig)
   293  		cli.vers = clientHello.vers
   294  		if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
   295  			testFatal(t, err)
   296  		}
   297  		reply, err := cli.readHandshake(nil)
   298  		c.Close()
   299  		if err != nil {
   300  			replyChan <- err
   301  		} else {
   302  			replyChan <- reply
   303  		}
   304  	}()
   305  	config := testConfig.Clone()
   306  	config.CipherSuites = clientHello.cipherSuites
   307  	Server(s, config).Handshake()
   308  	s.Close()
   309  	reply := <-replyChan
   310  	if err, ok := reply.(error); ok {
   311  		t.Fatal(err)
   312  	}
   313  	serverHello, ok := reply.(*serverHelloMsg)
   314  	if !ok {
   315  		t.Fatalf("didn't get ServerHello message in reply. Got %v\n", reply)
   316  	}
   317  	if s := serverHello.cipherSuite; s != TLS_RSA_WITH_RC4_128_SHA {
   318  		t.Fatalf("bad cipher suite from server: %x", s)
   319  	}
   320  }
   321  
   322  func TestTLSPointFormats(t *testing.T) {
   323  	// Test that a Server returns the ec_point_format extension when ECC is
   324  	// negotiated, and not on a RSA handshake or if ec_point_format is missing.
   325  	tests := []struct {
   326  		name                string
   327  		cipherSuites        []uint16
   328  		supportedCurves     []CurveID
   329  		supportedPoints     []uint8
   330  		wantSupportedPoints bool
   331  	}{
   332  		{"ECC", []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []CurveID{CurveP256}, []uint8{pointFormatUncompressed}, true},
   333  		{"ECC without ec_point_format", []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []CurveID{CurveP256}, nil, false},
   334  		{"ECC with extra values", []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []CurveID{CurveP256}, []uint8{13, 37, pointFormatUncompressed, 42}, true},
   335  		{"RSA", []uint16{TLS_RSA_WITH_AES_256_GCM_SHA384}, nil, nil, false},
   336  		{"RSA with ec_point_format", []uint16{TLS_RSA_WITH_AES_256_GCM_SHA384}, nil, []uint8{pointFormatUncompressed}, false},
   337  	}
   338  	for _, tt := range tests {
   339  		// The RSA subtests should be enabled for FIPS 140 required mode: #70505
   340  		if strings.HasPrefix(tt.name, "RSA") && fips140tls.Required() {
   341  			t.Logf("skipping in FIPS mode.")
   342  			continue
   343  		}
   344  		t.Run(tt.name, func(t *testing.T) {
   345  			clientHello := &clientHelloMsg{
   346  				vers:               VersionTLS12,
   347  				random:             make([]byte, 32),
   348  				cipherSuites:       tt.cipherSuites,
   349  				compressionMethods: []uint8{compressionNone},
   350  				supportedCurves:    tt.supportedCurves,
   351  				supportedPoints:    tt.supportedPoints,
   352  			}
   353  
   354  			c, s := localPipe(t)
   355  			replyChan := make(chan any)
   356  			go func() {
   357  				clientConfig := testConfig.Clone()
   358  				clientConfig.Certificates = []Certificate{{Certificate: [][]byte{testRSA2048Certificate}, PrivateKey: testRSA2048PrivateKey}}
   359  				cli := Client(c, clientConfig)
   360  				cli.vers = clientHello.vers
   361  				if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
   362  					testFatal(t, err)
   363  				}
   364  				reply, err := cli.readHandshake(nil)
   365  				c.Close()
   366  				if err != nil {
   367  					replyChan <- err
   368  				} else {
   369  					replyChan <- reply
   370  				}
   371  			}()
   372  			serverConfig := testConfig.Clone()
   373  			serverConfig.Certificates = []Certificate{{Certificate: [][]byte{testRSA2048Certificate}, PrivateKey: testRSA2048PrivateKey}}
   374  			serverConfig.CipherSuites = clientHello.cipherSuites
   375  			Server(s, serverConfig).Handshake()
   376  			s.Close()
   377  			reply := <-replyChan
   378  			if err, ok := reply.(error); ok {
   379  				t.Fatal(err)
   380  			}
   381  			serverHello, ok := reply.(*serverHelloMsg)
   382  			if !ok {
   383  				t.Fatalf("didn't get ServerHello message in reply. Got %v\n", reply)
   384  			}
   385  			if tt.wantSupportedPoints {
   386  				if !bytes.Equal(serverHello.supportedPoints, []uint8{pointFormatUncompressed}) {
   387  					t.Fatal("incorrect ec_point_format extension from server")
   388  				}
   389  			} else {
   390  				if len(serverHello.supportedPoints) != 0 {
   391  					t.Fatalf("unexpected ec_point_format extension from server: %v", serverHello.supportedPoints)
   392  				}
   393  			}
   394  		})
   395  	}
   396  }
   397  
   398  func TestAlertForwarding(t *testing.T) {
   399  	c, s := localPipe(t)
   400  	go func() {
   401  		Client(c, testConfig).sendAlert(alertUnknownCA)
   402  		c.Close()
   403  	}()
   404  
   405  	err := Server(s, testConfig).Handshake()
   406  	s.Close()
   407  	if opErr, ok := errors.AsType[*net.OpError](err); !ok || opErr.Err != error(alertUnknownCA) {
   408  		t.Errorf("Got error: %s; expected: %s", err, error(alertUnknownCA))
   409  	}
   410  }
   411  
   412  func TestClose(t *testing.T) {
   413  	c, s := localPipe(t)
   414  	go c.Close()
   415  
   416  	err := Server(s, testConfig).Handshake()
   417  	s.Close()
   418  	if err != io.EOF {
   419  		t.Errorf("Got error: %s; expected: %s", err, io.EOF)
   420  	}
   421  }
   422  
   423  func TestVersion(t *testing.T) {
   424  	serverConfig := &Config{
   425  		Certificates: testConfig.Certificates,
   426  		MaxVersion:   VersionTLS13,
   427  	}
   428  	clientConfig := &Config{
   429  		InsecureSkipVerify: true,
   430  		MinVersion:         VersionTLS12,
   431  	}
   432  	state, _, err := testHandshake(t, clientConfig, serverConfig)
   433  	if err != nil {
   434  		t.Fatalf("handshake failed: %s", err)
   435  	}
   436  	if state.Version != VersionTLS13 {
   437  		t.Fatalf("incorrect version %x, should be %x", state.Version, VersionTLS11)
   438  	}
   439  
   440  	clientConfig.MinVersion = 0
   441  	serverConfig.MaxVersion = VersionTLS11
   442  	_, _, err = testHandshake(t, clientConfig, serverConfig)
   443  	if err == nil {
   444  		t.Fatalf("expected failure to connect with TLS 1.0/1.1")
   445  	}
   446  }
   447  
   448  func TestCipherSuitePreference(t *testing.T) {
   449  	skipFIPS(t) // No RC4 or CHACHA20_POLY1305 in FIPS mode.
   450  
   451  	serverConfig := &Config{
   452  		CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA, TLS_AES_128_GCM_SHA256,
   453  			TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256},
   454  		Certificates: testConfig.Certificates,
   455  		MaxVersion:   VersionTLS12,
   456  		GetConfigForClient: func(chi *ClientHelloInfo) (*Config, error) {
   457  			if chi.CipherSuites[0] != TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 {
   458  				t.Error("the advertised order should not depend on Config.CipherSuites")
   459  			}
   460  			if len(chi.CipherSuites) != 2+len(defaultCipherSuitesTLS13) {
   461  				t.Error("the advertised TLS 1.2 suites should be filtered by Config.CipherSuites")
   462  			}
   463  			return nil, nil
   464  		},
   465  	}
   466  	clientConfig := &Config{
   467  		CipherSuites:       []uint16{TLS_RSA_WITH_AES_128_CBC_SHA, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256},
   468  		InsecureSkipVerify: true,
   469  	}
   470  	state, _, err := testHandshake(t, clientConfig, serverConfig)
   471  	if err != nil {
   472  		t.Fatalf("handshake failed: %s", err)
   473  	}
   474  	if state.CipherSuite != TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 {
   475  		t.Error("the preference order should not depend on Config.CipherSuites")
   476  	}
   477  }
   478  
   479  func TestSCTHandshake(t *testing.T) {
   480  	t.Run("TLSv12", func(t *testing.T) { testSCTHandshake(t, VersionTLS12) })
   481  	t.Run("TLSv13", func(t *testing.T) { testSCTHandshake(t, VersionTLS13) })
   482  }
   483  
   484  func testSCTHandshake(t *testing.T, version uint16) {
   485  	expected := [][]byte{[]byte("certificate"), []byte("transparency")}
   486  	serverConfig := &Config{
   487  		Certificates: []Certificate{{
   488  			Certificate:                 [][]byte{testRSACertificate},
   489  			PrivateKey:                  testRSAPrivateKey,
   490  			SignedCertificateTimestamps: expected,
   491  		}},
   492  		MaxVersion: version,
   493  	}
   494  	clientConfig := &Config{
   495  		InsecureSkipVerify: true,
   496  	}
   497  	_, state, err := testHandshake(t, clientConfig, serverConfig)
   498  	if err != nil {
   499  		t.Fatalf("handshake failed: %s", err)
   500  	}
   501  	actual := state.SignedCertificateTimestamps
   502  	if len(actual) != len(expected) {
   503  		t.Fatalf("got %d scts, want %d", len(actual), len(expected))
   504  	}
   505  	for i, sct := range expected {
   506  		if !bytes.Equal(sct, actual[i]) {
   507  			t.Fatalf("SCT #%d was %x, but expected %x", i, actual[i], sct)
   508  		}
   509  	}
   510  }
   511  
   512  func TestCrossVersionResume(t *testing.T) {
   513  	t.Run("TLSv12", func(t *testing.T) { testCrossVersionResume(t, VersionTLS12) })
   514  	t.Run("TLSv13", func(t *testing.T) { testCrossVersionResume(t, VersionTLS13) })
   515  }
   516  
   517  func testCrossVersionResume(t *testing.T, version uint16) {
   518  	serverConfig := &Config{
   519  		CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   520  		Certificates: testConfig.Certificates,
   521  		Time:         testTime,
   522  	}
   523  	clientConfig := &Config{
   524  		CipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
   525  		InsecureSkipVerify: true,
   526  		ClientSessionCache: NewLRUClientSessionCache(1),
   527  		ServerName:         "servername",
   528  		MinVersion:         VersionTLS12,
   529  		Time:               testTime,
   530  	}
   531  
   532  	// Establish a session at TLS 1.3.
   533  	clientConfig.MaxVersion = VersionTLS13
   534  	_, _, err := testHandshake(t, clientConfig, serverConfig)
   535  	if err != nil {
   536  		t.Fatalf("handshake failed: %s", err)
   537  	}
   538  
   539  	// The client session cache now contains a TLS 1.3 session.
   540  	state, _, err := testHandshake(t, clientConfig, serverConfig)
   541  	if err != nil {
   542  		t.Fatalf("handshake failed: %s", err)
   543  	}
   544  	if !state.DidResume {
   545  		t.Fatalf("handshake did not resume at the same version")
   546  	}
   547  
   548  	// Test that the server will decline to resume at a lower version.
   549  	clientConfig.MaxVersion = VersionTLS12
   550  	state, _, err = testHandshake(t, clientConfig, serverConfig)
   551  	if err != nil {
   552  		t.Fatalf("handshake failed: %s", err)
   553  	}
   554  	if state.DidResume {
   555  		t.Fatalf("handshake resumed at a lower version")
   556  	}
   557  
   558  	// The client session cache now contains a TLS 1.2 session.
   559  	state, _, err = testHandshake(t, clientConfig, serverConfig)
   560  	if err != nil {
   561  		t.Fatalf("handshake failed: %s", err)
   562  	}
   563  	if !state.DidResume {
   564  		t.Fatalf("handshake did not resume at the same version")
   565  	}
   566  
   567  	// Test that the server will decline to resume at a higher version.
   568  	clientConfig.MaxVersion = VersionTLS13
   569  	state, _, err = testHandshake(t, clientConfig, serverConfig)
   570  	if err != nil {
   571  		t.Fatalf("handshake failed: %s", err)
   572  	}
   573  	if state.DidResume {
   574  		t.Fatalf("handshake resumed at a higher version")
   575  	}
   576  }
   577  
   578  // Note: see comment in handshake_test.go for details of how the reference
   579  // tests work.
   580  
   581  // serverTest represents a test of the TLS server handshake against a reference
   582  // implementation.
   583  type serverTest struct {
   584  	// name is a freeform string identifying the test and the file in which
   585  	// the expected results will be stored.
   586  	name string
   587  	// command, if not empty, contains a series of arguments for the
   588  	// command to run for the reference server.
   589  	command []string
   590  	// expectedPeerCerts contains a list of PEM blocks of expected
   591  	// certificates from the client.
   592  	expectedPeerCerts []string
   593  	// config, if not nil, contains a custom Config to use for this test.
   594  	config *Config
   595  	// expectHandshakeErrorIncluding, when not empty, contains a string
   596  	// that must be a substring of the error resulting from the handshake.
   597  	expectHandshakeErrorIncluding string
   598  	// validate, if not nil, is a function that will be called with the
   599  	// ConnectionState of the resulting connection. It returns false if the
   600  	// ConnectionState is unacceptable.
   601  	validate func(ConnectionState) error
   602  }
   603  
   604  var defaultClientCommand []string
   605  
   606  // connFromCommand starts opens a listening socket and starts the reference
   607  // client to connect to it. It returns a recordingConn that wraps the resulting
   608  // connection.
   609  func (test *serverTest) connFromCommand() (conn *recordingConn, child *exec.Cmd, exit <-chan error, err error) {
   610  	l, err := net.ListenTCP("tcp", &net.TCPAddr{
   611  		IP:   net.IPv4(127, 0, 0, 1),
   612  		Port: 0,
   613  	})
   614  	if err != nil {
   615  		return nil, nil, nil, err
   616  	}
   617  	defer l.Close()
   618  
   619  	port := l.Addr().(*net.TCPAddr).Port
   620  
   621  	var command []string
   622  	command = append(command, test.command...)
   623  	if len(command) == 0 {
   624  		command = defaultClientCommand
   625  	}
   626  	command = append(command, "-connect")
   627  	command = append(command, fmt.Sprintf("127.0.0.1:%d", port))
   628  	cmd := exec.Command(command[0], command[1:]...)
   629  	cmd.Stdin = nil
   630  	var output bytes.Buffer
   631  	cmd.Stdout = &output
   632  	cmd.Stderr = &output
   633  	if err := cmd.Start(); err != nil {
   634  		return nil, nil, nil, err
   635  	}
   636  
   637  	exitChan := make(chan error, 1)
   638  	go func() {
   639  		exitChan <- cmd.Wait()
   640  	}()
   641  
   642  	connChan := make(chan any, 1)
   643  	go func() {
   644  		tcpConn, err := l.Accept()
   645  		if err != nil {
   646  			connChan <- err
   647  			return
   648  		}
   649  		connChan <- tcpConn
   650  	}()
   651  
   652  	var tcpConn net.Conn
   653  	select {
   654  	case connOrError := <-connChan:
   655  		if err, ok := connOrError.(error); ok {
   656  			return nil, nil, nil, err
   657  		}
   658  		tcpConn = connOrError.(net.Conn)
   659  	case err := <-exitChan:
   660  		return nil, nil, nil, fmt.Errorf("child process exited before connecting: %v\n%s", err, output.String())
   661  	case <-time.After(2 * time.Second):
   662  		cmd.Process.Kill()
   663  		return nil, nil, nil, fmt.Errorf("timed out waiting for connection from child process\n%s", output.String())
   664  	}
   665  
   666  	record := &recordingConn{
   667  		Conn: tcpConn,
   668  	}
   669  
   670  	return record, cmd, exitChan, nil
   671  }
   672  
   673  func (test *serverTest) dataPath() string {
   674  	return filepath.Join("testdata", "Server-"+test.name)
   675  }
   676  
   677  func (test *serverTest) loadData() (flows [][]byte, err error) {
   678  	in, err := os.Open(test.dataPath())
   679  	if err != nil {
   680  		return nil, err
   681  	}
   682  	defer in.Close()
   683  	return parseTestData(in)
   684  }
   685  
   686  func (test *serverTest) run(t *testing.T, write bool) {
   687  	var serverConn net.Conn
   688  	var recordingConn *recordingConn
   689  	var childProcess *exec.Cmd
   690  	var childExit <-chan error
   691  
   692  	if write {
   693  		var err error
   694  		recordingConn, childProcess, childExit, err = test.connFromCommand()
   695  		if err != nil {
   696  			t.Fatalf("Failed to start subcommand: %s", err)
   697  		}
   698  		serverConn = recordingConn
   699  	} else {
   700  		flows, err := test.loadData()
   701  		if err != nil {
   702  			t.Fatalf("Failed to load data from %s", test.dataPath())
   703  		}
   704  		serverConn = &replayingConn{t: t, flows: flows, reading: true}
   705  	}
   706  	config := test.config
   707  	if config == nil {
   708  		config = testConfigServer
   709  	}
   710  	config = config.Clone()
   711  	server := Server(serverConn, config)
   712  
   713  	_, err := server.Write([]byte("hello, world\n"))
   714  	if len(test.expectHandshakeErrorIncluding) > 0 {
   715  		if err == nil {
   716  			t.Errorf("Error expected, but no error returned")
   717  		} else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) {
   718  			t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s)
   719  		}
   720  	} else {
   721  		if err != nil {
   722  			t.Errorf("Error from Server.Write: '%s'", err)
   723  		}
   724  	}
   725  	server.Close()
   726  
   727  	connState := server.ConnectionState()
   728  	peerCerts := connState.PeerCertificates
   729  	if len(peerCerts) == len(test.expectedPeerCerts) {
   730  		for i, peerCert := range peerCerts {
   731  			block, _ := pem.Decode([]byte(test.expectedPeerCerts[i]))
   732  			if !bytes.Equal(block.Bytes, peerCert.Raw) {
   733  				t.Fatalf("%s: mismatch on peer cert %d", test.name, i+1)
   734  			}
   735  		}
   736  	} else {
   737  		t.Fatalf("%s: mismatch on peer list length: %d (wanted) != %d (got)", test.name, len(test.expectedPeerCerts), len(peerCerts))
   738  	}
   739  
   740  	if test.validate != nil && !t.Failed() {
   741  		if err := test.validate(connState); err != nil {
   742  			t.Fatalf("validate callback returned error: %s", err)
   743  		}
   744  	}
   745  
   746  	if write {
   747  		serverConn.Close()
   748  		recordingConn.Close()
   749  		if err := <-childExit; err != nil && len(test.expectHandshakeErrorIncluding) == 0 {
   750  			t.Errorf("OpenSSL exited with error: %s", err)
   751  		}
   752  		if t.Failed() {
   753  			t.Logf("OpenSSL output:\n\n%s", childProcess.Stdout)
   754  			return
   755  		}
   756  		if len(recordingConn.flows) < 3 {
   757  			if len(test.expectHandshakeErrorIncluding) == 0 {
   758  				t.Fatalf("Handshake failed")
   759  			}
   760  		}
   761  		path := test.dataPath()
   762  		out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
   763  		if err != nil {
   764  			t.Fatalf("Failed to create output file: %s", err)
   765  		}
   766  		defer out.Close()
   767  		recordingConn.WriteTo(out)
   768  		t.Logf("Wrote %s\n", path)
   769  	}
   770  }
   771  
   772  func runServerTestForVersion(t *testing.T, template *serverTest, version, option string) {
   773  	test := *template
   774  	if template.config != nil {
   775  		test.config = template.config.Clone()
   776  	}
   777  	test.name = version + "-" + test.name
   778  	if len(test.command) == 0 {
   779  		test.command = defaultClientCommand
   780  	}
   781  	test.command = append([]string(nil), test.command...)
   782  	test.command = append(test.command, option)
   783  
   784  	runTestAndUpdateIfNeeded(t, version, test.run)
   785  }
   786  
   787  func runServerTestTLS10(t *testing.T, template *serverTest) {
   788  	if template.config == nil {
   789  		template.config = testConfigServer.Clone()
   790  	}
   791  	if template.config.MinVersion == 0 {
   792  		template.config.MinVersion = VersionTLS10
   793  	}
   794  	runServerTestForVersion(t, template, "TLSv10", "-tls1")
   795  }
   796  
   797  func runServerTestTLS11(t *testing.T, template *serverTest) {
   798  	if template.config == nil {
   799  		template.config = testConfigServer.Clone()
   800  	}
   801  	if template.config.MinVersion == 0 {
   802  		template.config.MinVersion = VersionTLS11
   803  	}
   804  	runServerTestForVersion(t, template, "TLSv11", "-tls1_1")
   805  }
   806  
   807  func runServerTestTLS12(t *testing.T, template *serverTest) {
   808  	runServerTestForVersion(t, template, "TLSv12", "-tls1_2")
   809  }
   810  
   811  func runServerTestTLS13(t *testing.T, template *serverTest) {
   812  	runServerTestForVersion(t, template, "TLSv13", "-tls1_3")
   813  }
   814  
   815  func checkCipherSuite(want uint16) func(ConnectionState) error {
   816  	return func(state ConnectionState) error {
   817  		if state.CipherSuite != want {
   818  			return fmt.Errorf("got cipher suite %x, want %x", state.CipherSuite, want)
   819  		}
   820  		return nil
   821  	}
   822  }
   823  
   824  func TestHandshakeServerRSARC4(t *testing.T) {
   825  	config := testConfigServer.Clone()
   826  	config.CipherSuites = []uint16{TLS_RSA_WITH_RC4_128_SHA}
   827  	test := &serverTest{
   828  		name:     "RSA-RC4",
   829  		command:  append(defaultClientCommand, "-cipher", "RC4-SHA"),
   830  		config:   config,
   831  		validate: checkCipherSuite(TLS_RSA_WITH_RC4_128_SHA),
   832  	}
   833  	runServerTestTLS10(t, test)
   834  	runServerTestTLS11(t, test)
   835  	runServerTestTLS12(t, test)
   836  }
   837  
   838  func TestHandshakeServerRSA3DES(t *testing.T) {
   839  	config := testConfigServer.Clone()
   840  	config.CipherSuites = []uint16{TLS_RSA_WITH_3DES_EDE_CBC_SHA}
   841  	test := &serverTest{
   842  		name:     "RSA-3DES",
   843  		command:  append(defaultClientCommand, "-cipher", "DES-CBC3-SHA"),
   844  		config:   config,
   845  		validate: checkCipherSuite(TLS_RSA_WITH_3DES_EDE_CBC_SHA),
   846  	}
   847  	runServerTestTLS10(t, test)
   848  	runServerTestTLS12(t, test)
   849  }
   850  
   851  func TestHandshakeServerRSAAES(t *testing.T) {
   852  	config := testConfigServer.Clone()
   853  	config.CipherSuites = []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}
   854  	test := &serverTest{
   855  		name:     "RSA-AES",
   856  		command:  append(defaultClientCommand, "-cipher", "AES128-SHA"),
   857  		config:   config,
   858  		validate: checkCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA),
   859  	}
   860  	runServerTestTLS10(t, test)
   861  	runServerTestTLS12(t, test)
   862  }
   863  
   864  func TestHandshakeServerAESGCM(t *testing.T) {
   865  	test := &serverTest{
   866  		name:     "RSA-AES-GCM",
   867  		command:  append(defaultClientCommand, "-cipher", "ECDHE-RSA-AES128-GCM-SHA256"),
   868  		validate: checkCipherSuite(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
   869  	}
   870  	runServerTestTLS12(t, test)
   871  }
   872  
   873  func TestHandshakeServerAES256GCMSHA384(t *testing.T) {
   874  	test := &serverTest{
   875  		name:     "RSA-AES256-GCM-SHA384",
   876  		command:  append(defaultClientCommand, "-cipher", "ECDHE-RSA-AES256-GCM-SHA384"),
   877  		validate: checkCipherSuite(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
   878  	}
   879  	runServerTestTLS12(t, test)
   880  }
   881  
   882  func TestHandshakeServerAES128SHA256(t *testing.T) {
   883  	test := &serverTest{
   884  		name:     "AES128-SHA256",
   885  		command:  append(defaultClientCommand, "-ciphersuites", "TLS_AES_128_GCM_SHA256"),
   886  		validate: checkCipherSuite(TLS_AES_128_GCM_SHA256),
   887  	}
   888  	runServerTestTLS13(t, test)
   889  }
   890  
   891  func TestHandshakeServerAES256SHA384(t *testing.T) {
   892  	test := &serverTest{
   893  		name:     "AES256-SHA384",
   894  		command:  append(defaultClientCommand, "-ciphersuites", "TLS_AES_256_GCM_SHA384"),
   895  		validate: checkCipherSuite(TLS_AES_256_GCM_SHA384),
   896  	}
   897  	runServerTestTLS13(t, test)
   898  }
   899  
   900  func TestHandshakeServerCHACHA20SHA256(t *testing.T) {
   901  	test := &serverTest{
   902  		name:     "CHACHA20-SHA256",
   903  		command:  append(defaultClientCommand, "-ciphersuites", "TLS_CHACHA20_POLY1305_SHA256"),
   904  		validate: checkCipherSuite(TLS_CHACHA20_POLY1305_SHA256),
   905  	}
   906  	runServerTestTLS13(t, test)
   907  }
   908  
   909  func TestHandshakeServerECDHEECDSAAES(t *testing.T) {
   910  	test := &serverTest{
   911  		name:    "ECDHE-ECDSA-AES",
   912  		command: append(defaultClientCommand, "-sigalgs", "ecdsa_secp256r1_sha256"),
   913  	}
   914  	runServerTestTLS10(t, test)
   915  	runServerTestTLS12(t, test)
   916  	runServerTestTLS13(t, test)
   917  }
   918  
   919  func checkCurveID(want CurveID) func(ConnectionState) error {
   920  	return func(state ConnectionState) error {
   921  		if state.CurveID != want {
   922  			return fmt.Errorf("got curve %d, want %d", state.CurveID, want)
   923  		}
   924  		return nil
   925  	}
   926  }
   927  
   928  func TestHandshakeServerX25519(t *testing.T) {
   929  	test := &serverTest{
   930  		name:     "X25519",
   931  		command:  append(defaultClientCommand, "-curves", "X25519"),
   932  		validate: checkCurveID(X25519),
   933  	}
   934  	runServerTestTLS12(t, test)
   935  	runServerTestTLS13(t, test)
   936  }
   937  
   938  func TestHandshakeServerP256(t *testing.T) {
   939  	test := &serverTest{
   940  		name:     "P256",
   941  		command:  append(defaultClientCommand, "-curves", "P-256"),
   942  		validate: checkCurveID(CurveP256),
   943  	}
   944  	runServerTestTLS12(t, test)
   945  	runServerTestTLS13(t, test)
   946  }
   947  
   948  func TestHandshakeServerHelloRetryRequest(t *testing.T) {
   949  	config := testConfigServer.Clone()
   950  	config.CurvePreferences = []CurveID{CurveP256}
   951  
   952  	var clientHelloInfoHRR bool
   953  	var getCertificateCalled bool
   954  	config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
   955  		getCertificateCalled = true
   956  		clientHelloInfoHRR = clientHello.HelloRetryRequest
   957  		return nil, nil
   958  	}
   959  
   960  	test := &serverTest{
   961  		name:    "HelloRetryRequest",
   962  		command: append(defaultClientCommand, "-curves", "X25519:P-256"),
   963  		config:  config,
   964  		validate: func(cs ConnectionState) error {
   965  			if !cs.HelloRetryRequest {
   966  				return errors.New("expected HelloRetryRequest")
   967  			}
   968  			if !getCertificateCalled {
   969  				return errors.New("expected GetCertificate to be called")
   970  			}
   971  			if !clientHelloInfoHRR {
   972  				return errors.New("expected ClientHelloInfo.HelloRetryRequest to be true")
   973  			}
   974  			return nil
   975  		},
   976  	}
   977  	runServerTestTLS13(t, test)
   978  }
   979  
   980  // TestHandshakeServerKeySharePreference checks that we prefer a key share even
   981  // if it's later in the CurvePreferences order, and that the client hello HRR
   982  // field is correctly represented.
   983  func TestHandshakeServerKeySharePreference(t *testing.T) {
   984  	config := testConfigServer.Clone()
   985  	config.CurvePreferences = []CurveID{X25519, CurveP256}
   986  
   987  	// We also use this test as a convenient place to assert the ClientHelloInfo
   988  	// HelloRetryRequest field is _not_ set for a non-HRR hello.
   989  	var clientHelloInfoHRR bool
   990  	var getCertificateCalled bool
   991  	config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
   992  		getCertificateCalled = true
   993  		clientHelloInfoHRR = clientHello.HelloRetryRequest
   994  		return &config.Certificates[0], nil
   995  	}
   996  
   997  	test := &serverTest{
   998  		name:    "KeySharePreference",
   999  		command: append(defaultClientCommand, "-curves", "P-256:X25519"),
  1000  		config:  config,
  1001  		validate: func(cs ConnectionState) error {
  1002  			if cs.HelloRetryRequest {
  1003  				return errors.New("unexpected HelloRetryRequest")
  1004  			}
  1005  			if !getCertificateCalled {
  1006  				return errors.New("expected GetCertificate to be called")
  1007  			}
  1008  			if clientHelloInfoHRR {
  1009  				return errors.New("expected ClientHelloInfo.HelloRetryRequest to be false")
  1010  			}
  1011  			return nil
  1012  		},
  1013  	}
  1014  	runServerTestTLS13(t, test)
  1015  }
  1016  
  1017  func checkNegotiatedProtocol(want string) func(ConnectionState) error {
  1018  	return func(state ConnectionState) error {
  1019  		if state.NegotiatedProtocol != want {
  1020  			return fmt.Errorf("got protocol %q, want %q", state.NegotiatedProtocol, want)
  1021  		}
  1022  		return nil
  1023  	}
  1024  }
  1025  
  1026  func TestHandshakeServerALPN(t *testing.T) {
  1027  	config := testConfigServer.Clone()
  1028  	config.NextProtos = []string{"proto1", "proto2"}
  1029  
  1030  	test := &serverTest{
  1031  		name:    "ALPN",
  1032  		command: append(defaultClientCommand, "-alpn", "proto2,proto1"),
  1033  		config:  config,
  1034  		// The server's preferences should override the client.
  1035  		validate: checkNegotiatedProtocol("proto1"),
  1036  	}
  1037  	runServerTestTLS12(t, test)
  1038  	runServerTestTLS13(t, test)
  1039  }
  1040  
  1041  func TestHandshakeServerALPNNoMatch(t *testing.T) {
  1042  	config := testConfigServer.Clone()
  1043  	config.NextProtos = []string{"proto3"}
  1044  
  1045  	test := &serverTest{
  1046  		name:                          "ALPN-NoMatch",
  1047  		command:                       append(defaultClientCommand, "-alpn", "proto2,proto1"),
  1048  		config:                        config,
  1049  		expectHandshakeErrorIncluding: "client requested unsupported application protocol",
  1050  	}
  1051  	runServerTestTLS12(t, test)
  1052  	runServerTestTLS13(t, test)
  1053  }
  1054  
  1055  func TestHandshakeServerALPNNotConfigured(t *testing.T) {
  1056  	config := testConfigServer.Clone()
  1057  	config.NextProtos = nil
  1058  
  1059  	test := &serverTest{
  1060  		name:     "ALPN-NotConfigured",
  1061  		command:  append(defaultClientCommand, "-alpn", "proto2,proto1"),
  1062  		config:   config,
  1063  		validate: checkNegotiatedProtocol(""),
  1064  	}
  1065  	runServerTestTLS12(t, test)
  1066  	runServerTestTLS13(t, test)
  1067  }
  1068  
  1069  func TestHandshakeServerALPNFallback(t *testing.T) {
  1070  	config := testConfigServer.Clone()
  1071  	config.NextProtos = []string{"proto1", "h2", "proto2"}
  1072  
  1073  	test := &serverTest{
  1074  		name:     "ALPN-Fallback",
  1075  		command:  append(defaultClientCommand, "-alpn", "proto3,http/1.1,proto4"),
  1076  		config:   config,
  1077  		validate: checkNegotiatedProtocol(""),
  1078  	}
  1079  	runServerTestTLS12(t, test)
  1080  	runServerTestTLS13(t, test)
  1081  }
  1082  
  1083  func checkServerName(want string) func(ConnectionState) error {
  1084  	return func(state ConnectionState) error {
  1085  		if state.ServerName != want {
  1086  			return fmt.Errorf("got ServerName %q, want %q", state.ServerName, want)
  1087  		}
  1088  		return nil
  1089  	}
  1090  }
  1091  
  1092  // TestHandshakeServerSNI involves a client sending an SNI extension that
  1093  // matches a later certificate in Config.Certificates. The test verifies that
  1094  // the server correctly selects that certificate.
  1095  func TestHandshakeServerSNI(t *testing.T) {
  1096  	command := slices.Clone(defaultClientCommand)
  1097  	command[slices.Index(command, "-servername")+1] = "different.example.com"
  1098  	test := &serverTest{
  1099  		name:     "SNI",
  1100  		command:  command,
  1101  		validate: checkServerName("different.example.com"),
  1102  	}
  1103  	runServerTestTLS12(t, test)
  1104  	runServerTestTLS13(t, test)
  1105  }
  1106  
  1107  // TestHandshakeServerSNIGetCertificate is similar to TestHandshakeServerSNI, but
  1108  // tests the dynamic GetCertificate method
  1109  func TestHandshakeServerSNIGetCertificate(t *testing.T) {
  1110  	config := testConfigServer.Clone()
  1111  	config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
  1112  		return &testSNICert, nil
  1113  	}
  1114  	command := slices.Clone(defaultClientCommand)
  1115  	command[slices.Index(command, "-servername")+1] = "different.example.com"
  1116  	test := &serverTest{
  1117  		name:     "SNI-GetCertificate",
  1118  		command:  command,
  1119  		config:   config,
  1120  		validate: checkServerName("different.example.com"),
  1121  	}
  1122  	runServerTestTLS12(t, test)
  1123  	runServerTestTLS13(t, test)
  1124  }
  1125  
  1126  // TestHandshakeServerSNIGetCertificateNotFound is similar to
  1127  // TestHandshakeServerSNICertForName, but tests to make sure that when the
  1128  // GetCertificate method doesn't return a cert, we fall back to what's in
  1129  // the NameToCertificate map.
  1130  func TestHandshakeServerSNIGetCertificateNotFound(t *testing.T) {
  1131  	config := testConfigServer.Clone()
  1132  	config.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
  1133  		return nil, nil
  1134  	}
  1135  	command := slices.Clone(defaultClientCommand)
  1136  	command[slices.Index(command, "-servername")+1] = "different.example.com"
  1137  	test := &serverTest{
  1138  		name:     "SNI-GetCertificateNotFound",
  1139  		command:  command,
  1140  		config:   config,
  1141  		validate: checkServerName("different.example.com"),
  1142  	}
  1143  	runServerTestTLS12(t, test)
  1144  	runServerTestTLS13(t, test)
  1145  }
  1146  
  1147  // TestHandshakeServerGetCertificateExtensions tests to make sure that the
  1148  // Extensions passed to GetCertificate match what we expect based on the
  1149  // clientHelloMsg
  1150  func TestHandshakeServerGetCertificateExtensions(t *testing.T) {
  1151  	const errMsg = "TestHandshakeServerGetCertificateExtensions error"
  1152  	// ensure the test condition inside our GetCertificate callback
  1153  	// is actually invoked
  1154  	var called atomic.Int32
  1155  
  1156  	testVersions := []uint16{VersionTLS12, VersionTLS13}
  1157  	for _, vers := range testVersions {
  1158  		t.Run(fmt.Sprintf("TLS version %04x", vers), func(t *testing.T) {
  1159  			pk, _ := ecdh.P256().GenerateKey(rand.Reader)
  1160  			clientHello := &clientHelloMsg{
  1161  				vers:                         vers,
  1162  				random:                       make([]byte, 32),
  1163  				cipherSuites:                 []uint16{TLS_AES_128_GCM_SHA256},
  1164  				compressionMethods:           []uint8{compressionNone},
  1165  				serverName:                   "test",
  1166  				keyShares:                    []keyShare{{group: CurveP256, data: pk.PublicKey().Bytes()}},
  1167  				supportedCurves:              []CurveID{CurveP256},
  1168  				supportedSignatureAlgorithms: []SignatureScheme{ECDSAWithP256AndSHA256},
  1169  			}
  1170  
  1171  			// the clientHelloMsg initialized just above is serialized with
  1172  			// two extensions: server_name(0) and application_layer_protocol_negotiation(16)
  1173  			expectedExtensions := []uint16{
  1174  				extensionServerName,
  1175  				extensionSupportedCurves,
  1176  				extensionSignatureAlgorithms,
  1177  				extensionKeyShare,
  1178  			}
  1179  
  1180  			if vers == VersionTLS13 {
  1181  				clientHello.supportedVersions = []uint16{VersionTLS13}
  1182  				expectedExtensions = append(expectedExtensions, extensionSupportedVersions)
  1183  			}
  1184  
  1185  			// Go's TLS client presents extensions in the ClientHello sorted by extension ID
  1186  			slices.Sort(expectedExtensions)
  1187  
  1188  			serverConfig := testConfig.Clone()
  1189  			serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
  1190  				if !slices.Equal(expectedExtensions, clientHello.Extensions) {
  1191  					t.Errorf("expected extensions on ClientHelloInfo (%v) to match clientHelloMsg (%v)", expectedExtensions, clientHello.Extensions)
  1192  				}
  1193  				called.Add(1)
  1194  
  1195  				return nil, errors.New(errMsg)
  1196  			}
  1197  			testClientHelloFailure(t, serverConfig, clientHello, errMsg)
  1198  		})
  1199  	}
  1200  
  1201  	if int(called.Load()) != len(testVersions) {
  1202  		t.Error("expected our GetCertificate test to be called twice")
  1203  	}
  1204  }
  1205  
  1206  // TestHandshakeServerSNIGetCertificateError tests to make sure that errors in
  1207  // GetCertificate result in a tls alert.
  1208  func TestHandshakeServerSNIGetCertificateError(t *testing.T) {
  1209  	const errMsg = "TestHandshakeServerSNIGetCertificateError error"
  1210  
  1211  	serverConfig := testConfig.Clone()
  1212  	serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
  1213  		return nil, errors.New(errMsg)
  1214  	}
  1215  
  1216  	clientHello := &clientHelloMsg{
  1217  		vers:               VersionTLS12,
  1218  		random:             make([]byte, 32),
  1219  		cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  1220  		compressionMethods: []uint8{compressionNone},
  1221  		serverName:         "test",
  1222  	}
  1223  	testClientHelloFailure(t, serverConfig, clientHello, errMsg)
  1224  }
  1225  
  1226  // TestHandshakeServerEmptyCertificates tests that GetCertificates is called in
  1227  // the case that Certificates is empty, even without SNI.
  1228  func TestHandshakeServerEmptyCertificates(t *testing.T) {
  1229  	const errMsg = "TestHandshakeServerEmptyCertificates error"
  1230  
  1231  	serverConfig := testConfig.Clone()
  1232  	serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
  1233  		return nil, errors.New(errMsg)
  1234  	}
  1235  	serverConfig.Certificates = nil
  1236  
  1237  	clientHello := &clientHelloMsg{
  1238  		vers:               VersionTLS12,
  1239  		random:             make([]byte, 32),
  1240  		cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  1241  		compressionMethods: []uint8{compressionNone},
  1242  	}
  1243  	testClientHelloFailure(t, serverConfig, clientHello, errMsg)
  1244  
  1245  	// With an empty Certificates and a nil GetCertificate, the server
  1246  	// should always return a “no certificates” error.
  1247  	serverConfig.GetCertificate = nil
  1248  
  1249  	clientHello = &clientHelloMsg{
  1250  		vers:               VersionTLS12,
  1251  		random:             make([]byte, 32),
  1252  		cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  1253  		compressionMethods: []uint8{compressionNone},
  1254  	}
  1255  	testClientHelloFailure(t, serverConfig, clientHello, "no certificates")
  1256  }
  1257  
  1258  func checkDidResume(want bool) func(ConnectionState) error {
  1259  	return func(state ConnectionState) error {
  1260  		if state.DidResume != want {
  1261  			return fmt.Errorf("got DidResume %t, want %t", state.DidResume, want)
  1262  		}
  1263  		return nil
  1264  	}
  1265  }
  1266  
  1267  func TestServerResumption(t *testing.T) {
  1268  	sessionFilePath := tempFile("")
  1269  	defer os.Remove(sessionFilePath)
  1270  
  1271  	command := slices.Clone(defaultClientCommand)
  1272  	command = slices.DeleteFunc(command, func(s string) bool { return s == "-no_ticket" })
  1273  
  1274  	testIssue := &serverTest{
  1275  		name:    "IssueTicket",
  1276  		command: append(command, "-sess_out", sessionFilePath),
  1277  	}
  1278  	testResume := &serverTest{
  1279  		name:     "Resume",
  1280  		command:  append(command, "-sess_in", sessionFilePath),
  1281  		validate: checkDidResume(true),
  1282  	}
  1283  
  1284  	runServerTestTLS12(t, testIssue)
  1285  	runServerTestTLS12(t, testResume)
  1286  
  1287  	runServerTestTLS13(t, testIssue)
  1288  	runServerTestTLS13(t, testResume)
  1289  
  1290  	config := testConfigServer.Clone()
  1291  	config.CurvePreferences = []CurveID{CurveP256}
  1292  
  1293  	testResumeHRR := &serverTest{
  1294  		name:    "Resume-HelloRetryRequest",
  1295  		command: append(command, "-curves", "X25519:P-256", "-sess_in", sessionFilePath),
  1296  		config:  config,
  1297  		validate: func(state ConnectionState) error {
  1298  			if !state.DidResume {
  1299  				return errors.New("did not resume")
  1300  			}
  1301  			if !state.HelloRetryRequest {
  1302  				return errors.New("expected HelloRetryRequest")
  1303  			}
  1304  			return nil
  1305  		},
  1306  	}
  1307  
  1308  	runServerTestTLS13(t, testResumeHRR)
  1309  }
  1310  
  1311  func TestServerResumptionDisabled(t *testing.T) {
  1312  	sessionFilePath := tempFile("")
  1313  	defer os.Remove(sessionFilePath)
  1314  
  1315  	config := testConfigServer.Clone()
  1316  	command := slices.Clone(defaultClientCommand)
  1317  	command = slices.DeleteFunc(command, func(s string) bool { return s == "-no_ticket" })
  1318  
  1319  	testIssue := &serverTest{
  1320  		name:    "IssueTicketPreDisable",
  1321  		command: append(command, "-sess_out", sessionFilePath),
  1322  		config:  config,
  1323  	}
  1324  	testResume := &serverTest{
  1325  		name:     "ResumeDisabled",
  1326  		command:  append(command, "-sess_in", sessionFilePath),
  1327  		config:   config,
  1328  		validate: checkDidResume(false),
  1329  	}
  1330  
  1331  	config.SessionTicketsDisabled = false
  1332  	runServerTestTLS12(t, testIssue)
  1333  	config.SessionTicketsDisabled = true
  1334  	runServerTestTLS12(t, testResume)
  1335  
  1336  	config.SessionTicketsDisabled = false
  1337  	runServerTestTLS13(t, testIssue)
  1338  	config.SessionTicketsDisabled = true
  1339  	runServerTestTLS13(t, testResume)
  1340  }
  1341  
  1342  func TestFallbackSCSV(t *testing.T) {
  1343  	test := &serverTest{
  1344  		name:                          "FallbackSCSV",
  1345  		command:                       append(defaultClientCommand, "--fallback_scsv"),
  1346  		expectHandshakeErrorIncluding: "inappropriate protocol fallback",
  1347  	}
  1348  	runServerTestTLS11(t, test)
  1349  }
  1350  
  1351  func TestHandshakeServerExportKeyingMaterial(t *testing.T) {
  1352  	test := &serverTest{
  1353  		name: "ExportKeyingMaterial",
  1354  		validate: func(state ConnectionState) error {
  1355  			if km, err := state.ExportKeyingMaterial("test", nil, 42); err != nil {
  1356  				return fmt.Errorf("ExportKeyingMaterial failed: %v", err)
  1357  			} else if len(km) != 42 {
  1358  				return fmt.Errorf("Got %d bytes from ExportKeyingMaterial, wanted %d", len(km), 42)
  1359  			}
  1360  			return nil
  1361  		},
  1362  	}
  1363  	runServerTestTLS10(t, test)
  1364  	runServerTestTLS12(t, test)
  1365  	runServerTestTLS13(t, test)
  1366  }
  1367  
  1368  func TestHandshakeServerRSAPKCS1v15(t *testing.T) {
  1369  	test := &serverTest{
  1370  		name:    "RSA-RSAPKCS1v15",
  1371  		command: append(defaultClientCommand, "-sigalgs", "rsa_pkcs1_sha256"),
  1372  	}
  1373  	runServerTestTLS12(t, test)
  1374  }
  1375  
  1376  func TestHandshakeServerRSAPSS(t *testing.T) {
  1377  	config := testConfigServer.Clone()
  1378  	config.Certificates = []Certificate{testRSA1024Cert}
  1379  
  1380  	// We send rsa_pss_rsae_sha512 first, as the test key won't fit, and we
  1381  	// verify the server implementation will disregard the client preference in
  1382  	// that case. See Issue 29793.
  1383  	test := &serverTest{
  1384  		name:    "RSA-RSAPSS",
  1385  		config:  config,
  1386  		command: append(defaultClientCommand, "-sigalgs", "rsa_pss_rsae_sha512:rsa_pss_rsae_sha256", "-auth_level", "0"),
  1387  	}
  1388  	runServerTestTLS12(t, test)
  1389  	runServerTestTLS13(t, test)
  1390  
  1391  	test = &serverTest{
  1392  		name:                          "RSA-RSAPSS-TooSmall",
  1393  		config:                        config,
  1394  		command:                       append(defaultClientCommand, "-sigalgs", "rsa_pss_rsae_sha512", "-auth_level", "0"),
  1395  		expectHandshakeErrorIncluding: "peer doesn't support any of the certificate's signature algorithms",
  1396  	}
  1397  	runServerTestTLS13(t, test)
  1398  }
  1399  
  1400  func TestHandshakeServerEd25519(t *testing.T) {
  1401  	test := &serverTest{
  1402  		name:    "Ed25519",
  1403  		command: append(defaultClientCommand, "-sigalgs", "ed25519"),
  1404  	}
  1405  	runServerTestTLS12(t, test)
  1406  	runServerTestTLS13(t, test)
  1407  }
  1408  
  1409  func benchmarkHandshakeServer(b *testing.B, version uint16, cipherSuite uint16, curve CurveID, cert []byte, key crypto.PrivateKey) {
  1410  	config := testConfig.Clone()
  1411  	config.CipherSuites = []uint16{cipherSuite}
  1412  	config.CurvePreferences = []CurveID{curve}
  1413  	config.Certificates = make([]Certificate, 1)
  1414  	config.Certificates[0].Certificate = [][]byte{cert}
  1415  	config.Certificates[0].PrivateKey = key
  1416  	config.BuildNameToCertificate()
  1417  
  1418  	clientConn, serverConn := localPipe(b)
  1419  	serverConn = &recordingConn{Conn: serverConn}
  1420  	go func() {
  1421  		config := testConfig.Clone()
  1422  		config.MaxVersion = version
  1423  		config.CurvePreferences = []CurveID{curve}
  1424  		client := Client(clientConn, config)
  1425  		client.Handshake()
  1426  	}()
  1427  	server := Server(serverConn, config)
  1428  	if err := server.Handshake(); err != nil {
  1429  		b.Fatalf("handshake failed: %v", err)
  1430  	}
  1431  	serverConn.Close()
  1432  	flows := serverConn.(*recordingConn).flows
  1433  
  1434  	b.ResetTimer()
  1435  	for i := 0; i < b.N; i++ {
  1436  		replay := &replayingConn{t: b, flows: slices.Clone(flows), reading: true}
  1437  		server := Server(replay, config)
  1438  		if err := server.Handshake(); err != nil {
  1439  			b.Fatalf("handshake failed: %v", err)
  1440  		}
  1441  	}
  1442  }
  1443  
  1444  func BenchmarkHandshakeServer(b *testing.B) {
  1445  	b.Run("RSA", func(b *testing.B) {
  1446  		benchmarkHandshakeServer(b, VersionTLS12, TLS_RSA_WITH_AES_128_GCM_SHA256,
  1447  			0, testRSACertificate, testRSAPrivateKey)
  1448  	})
  1449  	b.Run("ECDHE-P256-RSA", func(b *testing.B) {
  1450  		b.Run("TLSv13", func(b *testing.B) {
  1451  			benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1452  				CurveP256, testRSACertificate, testRSAPrivateKey)
  1453  		})
  1454  		b.Run("TLSv12", func(b *testing.B) {
  1455  			benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1456  				CurveP256, testRSACertificate, testRSAPrivateKey)
  1457  		})
  1458  	})
  1459  	b.Run("ECDHE-P256-ECDSA-P256", func(b *testing.B) {
  1460  		b.Run("TLSv13", func(b *testing.B) {
  1461  			benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  1462  				CurveP256, testP256Certificate, testP256PrivateKey)
  1463  		})
  1464  		b.Run("TLSv12", func(b *testing.B) {
  1465  			benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  1466  				CurveP256, testP256Certificate, testP256PrivateKey)
  1467  		})
  1468  	})
  1469  	b.Run("ECDHE-X25519-ECDSA-P256", func(b *testing.B) {
  1470  		b.Run("TLSv13", func(b *testing.B) {
  1471  			benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  1472  				X25519, testP256Certificate, testP256PrivateKey)
  1473  		})
  1474  		b.Run("TLSv12", func(b *testing.B) {
  1475  			benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  1476  				X25519, testP256Certificate, testP256PrivateKey)
  1477  		})
  1478  	})
  1479  	b.Run("ECDHE-P521-ECDSA-P521", func(b *testing.B) {
  1480  		if testECDSAPrivateKey.PublicKey.Curve != elliptic.P521() {
  1481  			b.Fatal("test ECDSA key doesn't use curve P-521")
  1482  		}
  1483  		b.Run("TLSv13", func(b *testing.B) {
  1484  			benchmarkHandshakeServer(b, VersionTLS13, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  1485  				CurveP521, testECDSACertificate, testECDSAPrivateKey)
  1486  		})
  1487  		b.Run("TLSv12", func(b *testing.B) {
  1488  			benchmarkHandshakeServer(b, VersionTLS12, TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
  1489  				CurveP521, testECDSACertificate, testECDSAPrivateKey)
  1490  		})
  1491  	})
  1492  }
  1493  
  1494  func TestClientAuth(t *testing.T) {
  1495  	var certPath, keyPath, ecdsaCertPath, ecdsaKeyPath, ed25519CertPath, ed25519KeyPath string
  1496  
  1497  	if *update {
  1498  		certPath = tempFile(clientCertificatePEM)
  1499  		defer os.Remove(certPath)
  1500  		keyPath = tempFile(clientKeyPEM)
  1501  		defer os.Remove(keyPath)
  1502  		ecdsaCertPath = tempFile(clientECDSACertificatePEM)
  1503  		defer os.Remove(ecdsaCertPath)
  1504  		ecdsaKeyPath = tempFile(clientECDSAKeyPEM)
  1505  		defer os.Remove(ecdsaKeyPath)
  1506  		ed25519CertPath = tempFile(clientEd25519CertificatePEM)
  1507  		defer os.Remove(ed25519CertPath)
  1508  		ed25519KeyPath = tempFile(clientEd25519KeyPEM)
  1509  		defer os.Remove(ed25519KeyPath)
  1510  	}
  1511  
  1512  	config := testConfigServer.Clone()
  1513  	config.ClientAuth = RequestClientCert
  1514  
  1515  	test := &serverTest{
  1516  		name:   "ClientAuthRequestedNotGiven",
  1517  		config: config,
  1518  	}
  1519  	runServerTestTLS12(t, test)
  1520  	runServerTestTLS13(t, test)
  1521  
  1522  	test = &serverTest{
  1523  		name:              "ClientAuthRequestedAndGiven",
  1524  		command:           append(defaultClientCommand, "-cert", certPath, "-key", keyPath, "-client_sigalgs", "rsa_pss_rsae_sha256"),
  1525  		config:            config,
  1526  		expectedPeerCerts: []string{clientCertificatePEM},
  1527  	}
  1528  	runServerTestTLS12(t, test)
  1529  	runServerTestTLS13(t, test)
  1530  
  1531  	test = &serverTest{
  1532  		name:              "ClientAuthRequestedAndECDSAGiven",
  1533  		command:           append(defaultClientCommand, "-cert", ecdsaCertPath, "-key", ecdsaKeyPath),
  1534  		config:            config,
  1535  		expectedPeerCerts: []string{clientECDSACertificatePEM},
  1536  	}
  1537  	runServerTestTLS12(t, test)
  1538  	runServerTestTLS13(t, test)
  1539  
  1540  	test = &serverTest{
  1541  		name:              "ClientAuthRequestedAndEd25519Given",
  1542  		command:           append(defaultClientCommand, "-cert", ed25519CertPath, "-key", ed25519KeyPath),
  1543  		config:            config,
  1544  		expectedPeerCerts: []string{clientEd25519CertificatePEM},
  1545  	}
  1546  	runServerTestTLS12(t, test)
  1547  	runServerTestTLS13(t, test)
  1548  
  1549  	test = &serverTest{
  1550  		name:              "ClientAuthRequestedAndPKCS1v15Given",
  1551  		command:           append(defaultClientCommand, "-cert", certPath, "-key", keyPath, "-client_sigalgs", "rsa_pkcs1_sha256"),
  1552  		config:            config,
  1553  		expectedPeerCerts: []string{clientCertificatePEM},
  1554  	}
  1555  	runServerTestTLS12(t, test)
  1556  }
  1557  
  1558  func TestSNIGivenOnFailure(t *testing.T) {
  1559  	const expectedServerName = "test.testing"
  1560  
  1561  	clientHello := &clientHelloMsg{
  1562  		vers:               VersionTLS12,
  1563  		random:             make([]byte, 32),
  1564  		cipherSuites:       []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
  1565  		compressionMethods: []uint8{compressionNone},
  1566  		serverName:         expectedServerName,
  1567  	}
  1568  
  1569  	serverConfig := testConfig.Clone()
  1570  	// Erase the server's cipher suites to ensure the handshake fails.
  1571  	serverConfig.CipherSuites = nil
  1572  
  1573  	c, s := localPipe(t)
  1574  	go func() {
  1575  		cli := Client(c, testConfig)
  1576  		cli.vers = clientHello.vers
  1577  		if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
  1578  			testFatal(t, err)
  1579  		}
  1580  		c.Close()
  1581  	}()
  1582  	conn := Server(s, serverConfig)
  1583  	ctx := context.Background()
  1584  	ch, _, err := conn.readClientHello(ctx)
  1585  	hs := serverHandshakeState{
  1586  		c:           conn,
  1587  		ctx:         ctx,
  1588  		clientHello: ch,
  1589  	}
  1590  	if err == nil {
  1591  		err = hs.processClientHello()
  1592  	}
  1593  	if err == nil {
  1594  		err = hs.pickCipherSuite()
  1595  	}
  1596  	defer s.Close()
  1597  
  1598  	if err == nil {
  1599  		t.Error("No error reported from server")
  1600  	}
  1601  
  1602  	cs := hs.c.ConnectionState()
  1603  	if cs.HandshakeComplete {
  1604  		t.Error("Handshake registered as complete")
  1605  	}
  1606  
  1607  	if cs.ServerName != expectedServerName {
  1608  		t.Errorf("Expected ServerName of %q, but got %q", expectedServerName, cs.ServerName)
  1609  	}
  1610  }
  1611  
  1612  var getConfigForClientTests = []struct {
  1613  	setup          func(config *Config)
  1614  	callback       func(clientHello *ClientHelloInfo) (*Config, error)
  1615  	errorSubstring string
  1616  	verify         func(config *Config) error
  1617  }{
  1618  	{
  1619  		nil,
  1620  		func(clientHello *ClientHelloInfo) (*Config, error) {
  1621  			return nil, nil
  1622  		},
  1623  		"",
  1624  		nil,
  1625  	},
  1626  	{
  1627  		nil,
  1628  		func(clientHello *ClientHelloInfo) (*Config, error) {
  1629  			return nil, errors.New("should bubble up")
  1630  		},
  1631  		"should bubble up",
  1632  		nil,
  1633  	},
  1634  	{
  1635  		nil,
  1636  		func(clientHello *ClientHelloInfo) (*Config, error) {
  1637  			config := testConfig.Clone()
  1638  			// Setting a maximum version of TLS 1.1 should cause
  1639  			// the handshake to fail, as the client MinVersion is TLS 1.2.
  1640  			config.MaxVersion = VersionTLS11
  1641  			return config, nil
  1642  		},
  1643  		"client offered only unsupported versions",
  1644  		nil,
  1645  	},
  1646  	{
  1647  		func(config *Config) {
  1648  			for i := range config.SessionTicketKey {
  1649  				config.SessionTicketKey[i] = byte(i)
  1650  			}
  1651  			config.sessionTicketKeys = nil
  1652  		},
  1653  		func(clientHello *ClientHelloInfo) (*Config, error) {
  1654  			config := testConfig.Clone()
  1655  			clear(config.SessionTicketKey[:])
  1656  			config.sessionTicketKeys = nil
  1657  			return config, nil
  1658  		},
  1659  		"",
  1660  		func(config *Config) error {
  1661  			if config.SessionTicketKey == [32]byte{} {
  1662  				return fmt.Errorf("expected SessionTicketKey to be set")
  1663  			}
  1664  			return nil
  1665  		},
  1666  	},
  1667  	{
  1668  		func(config *Config) {
  1669  			var dummyKey [32]byte
  1670  			for i := range dummyKey {
  1671  				dummyKey[i] = byte(i)
  1672  			}
  1673  
  1674  			config.SetSessionTicketKeys([][32]byte{dummyKey})
  1675  		},
  1676  		func(clientHello *ClientHelloInfo) (*Config, error) {
  1677  			config := testConfig.Clone()
  1678  			config.sessionTicketKeys = nil
  1679  			return config, nil
  1680  		},
  1681  		"",
  1682  		func(config *Config) error {
  1683  			if config.SessionTicketKey == [32]byte{} {
  1684  				return fmt.Errorf("expected SessionTicketKey to be set")
  1685  			}
  1686  			return nil
  1687  		},
  1688  	},
  1689  }
  1690  
  1691  func TestGetConfigForClient(t *testing.T) {
  1692  	serverConfig := testConfig.Clone()
  1693  	clientConfig := testConfig.Clone()
  1694  	clientConfig.MinVersion = VersionTLS12
  1695  
  1696  	for i, test := range getConfigForClientTests {
  1697  		if test.setup != nil {
  1698  			test.setup(serverConfig)
  1699  		}
  1700  
  1701  		var configReturned *Config
  1702  		serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
  1703  			config, err := test.callback(clientHello)
  1704  			configReturned = config
  1705  			return config, err
  1706  		}
  1707  		c, s := localPipe(t)
  1708  		done := make(chan error)
  1709  
  1710  		go func() {
  1711  			defer s.Close()
  1712  			done <- Server(s, serverConfig).Handshake()
  1713  		}()
  1714  
  1715  		clientErr := Client(c, clientConfig).Handshake()
  1716  		c.Close()
  1717  
  1718  		serverErr := <-done
  1719  
  1720  		if len(test.errorSubstring) == 0 {
  1721  			if serverErr != nil || clientErr != nil {
  1722  				t.Errorf("test[%d]: expected no error but got serverErr: %q, clientErr: %q", i, serverErr, clientErr)
  1723  			}
  1724  			if test.verify != nil {
  1725  				if err := test.verify(configReturned); err != nil {
  1726  					t.Errorf("test[%d]: verify returned error: %v", i, err)
  1727  				}
  1728  			}
  1729  		} else {
  1730  			if serverErr == nil {
  1731  				t.Errorf("test[%d]: expected error containing %q but got no error", i, test.errorSubstring)
  1732  			} else if !strings.Contains(serverErr.Error(), test.errorSubstring) {
  1733  				t.Errorf("test[%d]: expected error to contain %q but it was %q", i, test.errorSubstring, serverErr)
  1734  			}
  1735  		}
  1736  	}
  1737  }
  1738  
  1739  func TestCloseServerConnectionOnIdleClient(t *testing.T) {
  1740  	clientConn, serverConn := localPipe(t)
  1741  	server := Server(serverConn, testConfig.Clone())
  1742  	go func() {
  1743  		clientConn.Write([]byte{'0'})
  1744  		server.Close()
  1745  	}()
  1746  	server.SetReadDeadline(time.Now().Add(time.Minute))
  1747  	err := server.Handshake()
  1748  	if err != nil {
  1749  		if err, ok := err.(net.Error); ok && err.Timeout() {
  1750  			t.Errorf("Expected a closed network connection error but got '%s'", err.Error())
  1751  		}
  1752  	} else {
  1753  		t.Errorf("Error expected, but no error returned")
  1754  	}
  1755  }
  1756  
  1757  func TestCloneHash(t *testing.T) {
  1758  	h1 := crypto.SHA256.New()
  1759  	h1.Write([]byte("test"))
  1760  	s1 := h1.Sum(nil)
  1761  	h2 := cloneHash(h1, crypto.SHA256)
  1762  	s2 := h2.Sum(nil)
  1763  	if !bytes.Equal(s1, s2) {
  1764  		t.Error("cloned hash generated a different sum")
  1765  	}
  1766  }
  1767  
  1768  func expectError(t *testing.T, err error, sub string) {
  1769  	if err == nil {
  1770  		t.Errorf(`expected error %q, got nil`, sub)
  1771  	} else if !strings.Contains(err.Error(), sub) {
  1772  		t.Errorf(`expected error %q, got %q`, sub, err)
  1773  	}
  1774  }
  1775  
  1776  func TestKeyTooSmallForRSAPSS(t *testing.T) {
  1777  	t.Setenv("GODEBUG", os.Getenv("GODEBUG")+",rsa1024min=0")
  1778  	clientConn, serverConn := localPipe(t)
  1779  	client := Client(clientConn, testConfigClient)
  1780  	done := make(chan struct{})
  1781  	go func() {
  1782  		config := testConfigServer.Clone()
  1783  		config.Certificates = []Certificate{testRSA512Cert}
  1784  		config.MinVersion = VersionTLS13
  1785  		server := Server(serverConn, config)
  1786  		err := server.Handshake()
  1787  		expectError(t, err, "key size too small")
  1788  		close(done)
  1789  	}()
  1790  	err := client.Handshake()
  1791  	expectError(t, err, "handshake failure")
  1792  	<-done
  1793  }
  1794  
  1795  func TestMultipleCertificates(t *testing.T) {
  1796  	clientConfig := testConfig.Clone()
  1797  	clientConfig.CipherSuites = []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}
  1798  	clientConfig.MaxVersion = VersionTLS12
  1799  
  1800  	serverConfig := testConfig.Clone()
  1801  	serverConfig.Certificates = []Certificate{{
  1802  		Certificate: [][]byte{testECDSACertificate},
  1803  		PrivateKey:  testECDSAPrivateKey,
  1804  	}, {
  1805  		Certificate: [][]byte{testRSACertificate},
  1806  		PrivateKey:  testRSAPrivateKey,
  1807  	}}
  1808  
  1809  	_, clientState, err := testHandshake(t, clientConfig, serverConfig)
  1810  	if err != nil {
  1811  		t.Fatal(err)
  1812  	}
  1813  	if got := clientState.PeerCertificates[0].PublicKeyAlgorithm; got != x509.RSA {
  1814  		t.Errorf("expected RSA certificate, got %v", got)
  1815  	}
  1816  }
  1817  
  1818  func TestAESCipherReordering(t *testing.T) {
  1819  	skipFIPS(t) // No CHACHA20_POLY1305 for FIPS.
  1820  
  1821  	currentAESSupport := hasAESGCMHardwareSupport
  1822  	defer func() { hasAESGCMHardwareSupport = currentAESSupport }()
  1823  
  1824  	tests := []struct {
  1825  		name            string
  1826  		clientCiphers   []uint16
  1827  		serverHasAESGCM bool
  1828  		serverCiphers   []uint16
  1829  		expectedCipher  uint16
  1830  	}{
  1831  		{
  1832  			name: "server has hardware AES, client doesn't (pick ChaCha)",
  1833  			clientCiphers: []uint16{
  1834  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1835  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1836  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1837  			},
  1838  			serverHasAESGCM: true,
  1839  			expectedCipher:  TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1840  		},
  1841  		{
  1842  			name: "client prefers AES-GCM, server doesn't have hardware AES (pick ChaCha)",
  1843  			clientCiphers: []uint16{
  1844  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1845  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1846  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1847  			},
  1848  			serverHasAESGCM: false,
  1849  			expectedCipher:  TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1850  		},
  1851  		{
  1852  			name: "client prefers AES-GCM, server has hardware AES (pick AES-GCM)",
  1853  			clientCiphers: []uint16{
  1854  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1855  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1856  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1857  			},
  1858  			serverHasAESGCM: true,
  1859  			expectedCipher:  TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1860  		},
  1861  		{
  1862  			name: "client prefers AES-GCM and sends GREASE, server has hardware AES (pick AES-GCM)",
  1863  			clientCiphers: []uint16{
  1864  				0x0A0A, // GREASE value
  1865  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1866  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1867  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1868  			},
  1869  			serverHasAESGCM: true,
  1870  			expectedCipher:  TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1871  		},
  1872  		{
  1873  			name: "client prefers AES-GCM and doesn't support ChaCha, server doesn't have hardware AES (pick AES-GCM)",
  1874  			clientCiphers: []uint16{
  1875  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1876  				TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
  1877  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1878  			},
  1879  			serverHasAESGCM: false,
  1880  			expectedCipher:  TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1881  		},
  1882  		{
  1883  			name: "client prefers AES-GCM and AES-CBC over ChaCha, server doesn't have hardware AES (pick ChaCha)",
  1884  			clientCiphers: []uint16{
  1885  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1886  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1887  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1888  			},
  1889  			serverHasAESGCM: false,
  1890  			expectedCipher:  TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1891  		},
  1892  		{
  1893  			name: "client prefers AES-GCM over ChaCha and sends GREASE, server doesn't have hardware AES (pick ChaCha)",
  1894  			clientCiphers: []uint16{
  1895  				0x0A0A, // GREASE value
  1896  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1897  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1898  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1899  			},
  1900  			serverHasAESGCM: false,
  1901  			expectedCipher:  TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1902  		},
  1903  		{
  1904  			name: "client supports multiple AES-GCM, server doesn't have hardware AES and doesn't support ChaCha (AES-GCM)",
  1905  			clientCiphers: []uint16{
  1906  				TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  1907  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1908  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1909  			},
  1910  			serverHasAESGCM: false,
  1911  			serverCiphers: []uint16{
  1912  				TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
  1913  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1914  			},
  1915  			expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1916  		},
  1917  		{
  1918  			name: "client prefers AES-GCM, server has hardware but doesn't support AES (pick ChaCha)",
  1919  			clientCiphers: []uint16{
  1920  				TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
  1921  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1922  				TLS_RSA_WITH_AES_128_CBC_SHA,
  1923  			},
  1924  			serverHasAESGCM: true,
  1925  			serverCiphers: []uint16{
  1926  				TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1927  			},
  1928  			expectedCipher: TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
  1929  		},
  1930  	}
  1931  
  1932  	for _, tc := range tests {
  1933  		t.Run(tc.name, func(t *testing.T) {
  1934  			hasAESGCMHardwareSupport = tc.serverHasAESGCM
  1935  			hs := &serverHandshakeState{
  1936  				c: &Conn{
  1937  					config: &Config{
  1938  						CipherSuites: tc.serverCiphers,
  1939  					},
  1940  					vers: VersionTLS12,
  1941  				},
  1942  				clientHello: &clientHelloMsg{
  1943  					cipherSuites: tc.clientCiphers,
  1944  					vers:         VersionTLS12,
  1945  				},
  1946  				ecdheOk:      true,
  1947  				rsaSignOk:    true,
  1948  				rsaDecryptOk: true,
  1949  			}
  1950  
  1951  			err := hs.pickCipherSuite()
  1952  			if err != nil {
  1953  				t.Errorf("pickCipherSuite failed: %s", err)
  1954  			}
  1955  
  1956  			if tc.expectedCipher != hs.suite.id {
  1957  				t.Errorf("unexpected cipher chosen: want %d, got %d", tc.expectedCipher, hs.suite.id)
  1958  			}
  1959  		})
  1960  	}
  1961  }
  1962  
  1963  func TestAESCipherReorderingTLS13(t *testing.T) {
  1964  	skipFIPS(t) // No CHACHA20_POLY1305 for FIPS.
  1965  
  1966  	currentAESSupport := hasAESGCMHardwareSupport
  1967  	defer func() { hasAESGCMHardwareSupport = currentAESSupport }()
  1968  
  1969  	tests := []struct {
  1970  		name            string
  1971  		clientCiphers   []uint16
  1972  		serverHasAESGCM bool
  1973  		expectedCipher  uint16
  1974  	}{
  1975  		{
  1976  			name: "server has hardware AES, client doesn't (pick ChaCha)",
  1977  			clientCiphers: []uint16{
  1978  				TLS_CHACHA20_POLY1305_SHA256,
  1979  				TLS_AES_128_GCM_SHA256,
  1980  			},
  1981  			serverHasAESGCM: true,
  1982  			expectedCipher:  TLS_CHACHA20_POLY1305_SHA256,
  1983  		},
  1984  		{
  1985  			name: "neither server nor client have hardware AES (pick ChaCha)",
  1986  			clientCiphers: []uint16{
  1987  				TLS_CHACHA20_POLY1305_SHA256,
  1988  				TLS_AES_128_GCM_SHA256,
  1989  			},
  1990  			serverHasAESGCM: false,
  1991  			expectedCipher:  TLS_CHACHA20_POLY1305_SHA256,
  1992  		},
  1993  		{
  1994  			name: "client prefers AES, server doesn't have hardware (pick ChaCha)",
  1995  			clientCiphers: []uint16{
  1996  				TLS_AES_128_GCM_SHA256,
  1997  				TLS_CHACHA20_POLY1305_SHA256,
  1998  			},
  1999  			serverHasAESGCM: false,
  2000  			expectedCipher:  TLS_CHACHA20_POLY1305_SHA256,
  2001  		},
  2002  		{
  2003  			name: "client prefers AES and sends GREASE, server doesn't have hardware (pick ChaCha)",
  2004  			clientCiphers: []uint16{
  2005  				0x0A0A, // GREASE value
  2006  				TLS_AES_128_GCM_SHA256,
  2007  				TLS_CHACHA20_POLY1305_SHA256,
  2008  			},
  2009  			serverHasAESGCM: false,
  2010  			expectedCipher:  TLS_CHACHA20_POLY1305_SHA256,
  2011  		},
  2012  		{
  2013  			name: "client prefers AES, server has hardware AES (pick AES)",
  2014  			clientCiphers: []uint16{
  2015  				TLS_AES_128_GCM_SHA256,
  2016  				TLS_CHACHA20_POLY1305_SHA256,
  2017  			},
  2018  			serverHasAESGCM: true,
  2019  			expectedCipher:  TLS_AES_128_GCM_SHA256,
  2020  		},
  2021  		{
  2022  			name: "client prefers AES and sends GREASE, server has hardware AES (pick AES)",
  2023  			clientCiphers: []uint16{
  2024  				0x0A0A, // GREASE value
  2025  				TLS_AES_128_GCM_SHA256,
  2026  				TLS_CHACHA20_POLY1305_SHA256,
  2027  			},
  2028  			serverHasAESGCM: true,
  2029  			expectedCipher:  TLS_AES_128_GCM_SHA256,
  2030  		},
  2031  	}
  2032  
  2033  	for _, tc := range tests {
  2034  		t.Run(tc.name, func(t *testing.T) {
  2035  			hasAESGCMHardwareSupport = tc.serverHasAESGCM
  2036  			pk, _ := ecdh.X25519().GenerateKey(rand.Reader)
  2037  			hs := &serverHandshakeStateTLS13{
  2038  				c: &Conn{
  2039  					config: &Config{},
  2040  					vers:   VersionTLS13,
  2041  				},
  2042  				clientHello: &clientHelloMsg{
  2043  					cipherSuites:       tc.clientCiphers,
  2044  					supportedVersions:  []uint16{VersionTLS13},
  2045  					compressionMethods: []uint8{compressionNone},
  2046  					keyShares:          []keyShare{{group: X25519, data: pk.PublicKey().Bytes()}},
  2047  					supportedCurves:    []CurveID{X25519},
  2048  				},
  2049  			}
  2050  
  2051  			err := hs.processClientHello()
  2052  			if err != nil {
  2053  				t.Errorf("pickCipherSuite failed: %s", err)
  2054  			}
  2055  
  2056  			if tc.expectedCipher != hs.suite.id {
  2057  				t.Errorf("unexpected cipher chosen: want %d, got %d", tc.expectedCipher, hs.suite.id)
  2058  			}
  2059  		})
  2060  	}
  2061  }
  2062  
  2063  // TestServerHandshakeContextCancellation tests that canceling
  2064  // the context given to the server side conn.HandshakeContext
  2065  // interrupts the in-progress handshake.
  2066  func TestServerHandshakeContextCancellation(t *testing.T) {
  2067  	c, s := localPipe(t)
  2068  	ctx, cancel := context.WithCancel(context.Background())
  2069  	unblockClient := make(chan struct{})
  2070  	defer close(unblockClient)
  2071  	go func() {
  2072  		cancel()
  2073  		<-unblockClient
  2074  		_ = c.Close()
  2075  	}()
  2076  	conn := Server(s, testConfig)
  2077  	// Initiates server side handshake, which will block until a client hello is read
  2078  	// unless the cancellation works.
  2079  	err := conn.HandshakeContext(ctx)
  2080  	if err == nil {
  2081  		t.Fatal("Server handshake did not error when the context was canceled")
  2082  	}
  2083  	if err != context.Canceled {
  2084  		t.Errorf("Unexpected server handshake error: %v", err)
  2085  	}
  2086  	if runtime.GOOS == "js" || runtime.GOOS == "wasip1" {
  2087  		t.Skip("conn.Close does not error as expected when called multiple times on GOOS=js or GOOS=wasip1")
  2088  	}
  2089  	err = conn.Close()
  2090  	if err == nil {
  2091  		t.Error("Server connection was not closed when the context was canceled")
  2092  	}
  2093  }
  2094  
  2095  // TestHandshakeContextHierarchy tests whether the contexts
  2096  // available to GetClientCertificate and GetCertificate are
  2097  // derived from the context provided to HandshakeContext, and
  2098  // that those contexts are canceled after HandshakeContext has
  2099  // returned.
  2100  func TestHandshakeContextHierarchy(t *testing.T) {
  2101  	c, s := localPipe(t)
  2102  	clientErr := make(chan error, 1)
  2103  	clientConfig := testConfig.Clone()
  2104  	serverConfig := testConfig.Clone()
  2105  	ctx, cancel := context.WithCancel(context.Background())
  2106  	defer cancel()
  2107  	key := struct{}{}
  2108  	ctx = context.WithValue(ctx, key, true)
  2109  	go func() {
  2110  		defer close(clientErr)
  2111  		defer c.Close()
  2112  		var innerCtx context.Context
  2113  		clientConfig.Certificates = nil
  2114  		clientConfig.GetClientCertificate = func(certificateRequest *CertificateRequestInfo) (*Certificate, error) {
  2115  			if val, ok := certificateRequest.Context().Value(key).(bool); !ok || !val {
  2116  				t.Errorf("GetClientCertificate context was not child of HandshakeContext")
  2117  			}
  2118  			innerCtx = certificateRequest.Context()
  2119  			return &Certificate{
  2120  				Certificate: [][]byte{testRSACertificate},
  2121  				PrivateKey:  testRSAPrivateKey,
  2122  			}, nil
  2123  		}
  2124  		cli := Client(c, clientConfig)
  2125  		err := cli.HandshakeContext(ctx)
  2126  		if err != nil {
  2127  			clientErr <- err
  2128  			return
  2129  		}
  2130  		select {
  2131  		case <-innerCtx.Done():
  2132  		default:
  2133  			t.Errorf("GetClientCertificate context was not canceled after HandshakeContext returned.")
  2134  		}
  2135  	}()
  2136  	var innerCtx context.Context
  2137  	serverConfig.Certificates = nil
  2138  	serverConfig.ClientAuth = RequestClientCert
  2139  	serverConfig.GetCertificate = func(clientHello *ClientHelloInfo) (*Certificate, error) {
  2140  		if val, ok := clientHello.Context().Value(key).(bool); !ok || !val {
  2141  			t.Errorf("GetClientCertificate context was not child of HandshakeContext")
  2142  		}
  2143  		innerCtx = clientHello.Context()
  2144  		return &Certificate{
  2145  			Certificate: [][]byte{testRSACertificate},
  2146  			PrivateKey:  testRSAPrivateKey,
  2147  		}, nil
  2148  	}
  2149  	conn := Server(s, serverConfig)
  2150  	err := conn.HandshakeContext(ctx)
  2151  	if err != nil {
  2152  		t.Errorf("Unexpected server handshake error: %v", err)
  2153  	}
  2154  	select {
  2155  	case <-innerCtx.Done():
  2156  	default:
  2157  		t.Errorf("GetCertificate context was not canceled after HandshakeContext returned.")
  2158  	}
  2159  	if err := <-clientErr; err != nil {
  2160  		t.Errorf("Unexpected client error: %v", err)
  2161  	}
  2162  }
  2163  
  2164  func TestHandshakeChainExpiryResumption(t *testing.T) {
  2165  	t.Run("TLS1.2", func(t *testing.T) {
  2166  		testHandshakeChainExpiryResumption(t, VersionTLS12)
  2167  	})
  2168  	t.Run("TLS1.3", func(t *testing.T) {
  2169  		testHandshakeChainExpiryResumption(t, VersionTLS13)
  2170  	})
  2171  }
  2172  
  2173  func testHandshakeChainExpiryResumption(t *testing.T, version uint16) {
  2174  	now := time.Now()
  2175  
  2176  	createChain := func(leafNotAfter, rootNotAfter time.Time) (leafDER, expiredLeafDER []byte, root *x509.Certificate) {
  2177  		tmpl := &x509.Certificate{
  2178  			Subject:               pkix.Name{CommonName: "root"},
  2179  			NotBefore:             rootNotAfter.Add(-time.Hour * 24),
  2180  			NotAfter:              rootNotAfter,
  2181  			IsCA:                  true,
  2182  			BasicConstraintsValid: true,
  2183  		}
  2184  		rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2185  		if err != nil {
  2186  			t.Fatalf("CreateCertificate: %v", err)
  2187  		}
  2188  		root, err = x509.ParseCertificate(rootDER)
  2189  		if err != nil {
  2190  			t.Fatalf("ParseCertificate: %v", err)
  2191  		}
  2192  
  2193  		tmpl = &x509.Certificate{
  2194  			Subject:   pkix.Name{},
  2195  			DNSNames:  []string{"expired-resume.example.com"},
  2196  			NotBefore: leafNotAfter.Add(-time.Hour * 24),
  2197  			NotAfter:  leafNotAfter,
  2198  			KeyUsage:  x509.KeyUsageDigitalSignature,
  2199  		}
  2200  		leafCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, root, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2201  		if err != nil {
  2202  			t.Fatalf("CreateCertificate: %v", err)
  2203  		}
  2204  		tmpl.NotBefore, tmpl.NotAfter = leafNotAfter.Add(-time.Hour*24*365), leafNotAfter.Add(-time.Hour*24*364)
  2205  		expiredLeafDERCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, root, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2206  		if err != nil {
  2207  			t.Fatalf("CreateCertificate: %v", err)
  2208  		}
  2209  
  2210  		return leafCertDER, expiredLeafDERCertDER, root
  2211  	}
  2212  	testExpiration := func(name string, leafNotAfter, rootNotAfter time.Time) {
  2213  		t.Run(name, func(t *testing.T) {
  2214  			initialLeafDER, expiredLeafDER, initialRoot := createChain(leafNotAfter, rootNotAfter)
  2215  
  2216  			serverConfig := testConfig.Clone()
  2217  			serverConfig.MaxVersion = version
  2218  			serverConfig.Certificates = []Certificate{{
  2219  				Certificate: [][]byte{initialLeafDER, expiredLeafDER},
  2220  				PrivateKey:  testECDSAPrivateKey,
  2221  			}}
  2222  			serverConfig.ClientCAs = x509.NewCertPool()
  2223  			serverConfig.ClientCAs.AddCert(initialRoot)
  2224  			serverConfig.ClientAuth = RequireAndVerifyClientCert
  2225  			serverConfig.Time = func() time.Time {
  2226  				return now
  2227  			}
  2228  			serverConfig.InsecureSkipVerify = false
  2229  			serverConfig.ServerName = "expired-resume.example.com"
  2230  
  2231  			clientConfig := testConfig.Clone()
  2232  			clientConfig.MaxVersion = version
  2233  			clientConfig.Certificates = []Certificate{{
  2234  				Certificate: [][]byte{initialLeafDER, expiredLeafDER},
  2235  				PrivateKey:  testECDSAPrivateKey,
  2236  			}}
  2237  			clientConfig.RootCAs = x509.NewCertPool()
  2238  			clientConfig.RootCAs.AddCert(initialRoot)
  2239  			clientConfig.ServerName = "expired-resume.example.com"
  2240  			clientConfig.ClientSessionCache = NewLRUClientSessionCache(32)
  2241  			clientConfig.InsecureSkipVerify = false
  2242  			clientConfig.ServerName = "expired-resume.example.com"
  2243  			clientConfig.Time = func() time.Time {
  2244  				return now
  2245  			}
  2246  
  2247  			testResume := func(t *testing.T, sc, cc *Config, expectResume bool) {
  2248  				t.Helper()
  2249  				ss, cs, err := testHandshake(t, cc, sc)
  2250  				if err != nil {
  2251  					t.Fatalf("handshake: %v", err)
  2252  				}
  2253  				if cs.DidResume != expectResume {
  2254  					t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
  2255  				}
  2256  				if ss.DidResume != expectResume {
  2257  					t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
  2258  				}
  2259  			}
  2260  
  2261  			testResume(t, serverConfig, clientConfig, false)
  2262  			testResume(t, serverConfig, clientConfig, true)
  2263  
  2264  			expiredNow := time.Unix(0, min(leafNotAfter.UnixNano(), rootNotAfter.UnixNano())).Add(time.Minute)
  2265  
  2266  			freshLeafDER, expiredLeafDER, freshRoot := createChain(expiredNow.Add(time.Hour), expiredNow.Add(time.Hour))
  2267  			clientConfig.Certificates = []Certificate{{
  2268  				Certificate: [][]byte{freshLeafDER, expiredLeafDER},
  2269  				PrivateKey:  testECDSAPrivateKey,
  2270  			}}
  2271  			serverConfig.Time = func() time.Time {
  2272  				return expiredNow
  2273  			}
  2274  			serverConfig.ClientCAs = x509.NewCertPool()
  2275  			serverConfig.ClientCAs.AddCert(freshRoot)
  2276  
  2277  			testResume(t, serverConfig, clientConfig, false)
  2278  		})
  2279  	}
  2280  
  2281  	testExpiration("LeafExpiresBeforeRoot", now.Add(2*time.Hour), now.Add(3*time.Hour))
  2282  	testExpiration("LeafExpiresAfterRoot", now.Add(2*time.Hour), now.Add(time.Hour))
  2283  }
  2284  
  2285  func TestHandshakeGetConfigForClientDifferentClientCAs(t *testing.T) {
  2286  	t.Run("TLS1.2", func(t *testing.T) {
  2287  		testHandshakeGetConfigForClientDifferentClientCAs(t, VersionTLS12)
  2288  	})
  2289  	t.Run("TLS1.3", func(t *testing.T) {
  2290  		testHandshakeGetConfigForClientDifferentClientCAs(t, VersionTLS13)
  2291  	})
  2292  }
  2293  
  2294  func testHandshakeGetConfigForClientDifferentClientCAs(t *testing.T, version uint16) {
  2295  	now := time.Now()
  2296  	tmpl := &x509.Certificate{
  2297  		Subject:               pkix.Name{CommonName: "root"},
  2298  		NotBefore:             now.Add(-time.Hour * 24),
  2299  		NotAfter:              now.Add(time.Hour * 24),
  2300  		IsCA:                  true,
  2301  		BasicConstraintsValid: true,
  2302  	}
  2303  	rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2304  	if err != nil {
  2305  		t.Fatalf("CreateCertificate: %v", err)
  2306  	}
  2307  	rootA, err := x509.ParseCertificate(rootDER)
  2308  	if err != nil {
  2309  		t.Fatalf("ParseCertificate: %v", err)
  2310  	}
  2311  	rootDER, err = x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testRSA2048PrivateKey.PublicKey, testRSA2048PrivateKey)
  2312  	if err != nil {
  2313  		t.Fatalf("CreateCertificate: %v", err)
  2314  	}
  2315  	rootB, err := x509.ParseCertificate(rootDER)
  2316  	if err != nil {
  2317  		t.Fatalf("ParseCertificate: %v", err)
  2318  	}
  2319  
  2320  	tmpl = &x509.Certificate{
  2321  		Subject:   pkix.Name{},
  2322  		DNSNames:  []string{"example.com"},
  2323  		NotBefore: now.Add(-time.Hour * 24),
  2324  		NotAfter:  now.Add(time.Hour * 24),
  2325  		KeyUsage:  x509.KeyUsageDigitalSignature,
  2326  	}
  2327  	certA, err := x509.CreateCertificate(rand.Reader, tmpl, rootA, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2328  	if err != nil {
  2329  		t.Fatalf("CreateCertificate: %v", err)
  2330  	}
  2331  	certB, err := x509.CreateCertificate(rand.Reader, tmpl, rootB, &testECDSAPrivateKey.PublicKey, testRSA2048PrivateKey)
  2332  	if err != nil {
  2333  		t.Fatalf("CreateCertificate: %v", err)
  2334  	}
  2335  
  2336  	serverConfig := testConfig.Clone()
  2337  	serverConfig.MaxVersion = version
  2338  	serverConfig.Certificates = []Certificate{{
  2339  		Certificate: [][]byte{certA},
  2340  		PrivateKey:  testECDSAPrivateKey,
  2341  	}}
  2342  	serverConfig.Time = func() time.Time {
  2343  		return now
  2344  	}
  2345  	serverConfig.ClientCAs = x509.NewCertPool()
  2346  	serverConfig.ClientCAs.AddCert(rootA)
  2347  	serverConfig.ClientAuth = RequireAndVerifyClientCert
  2348  	switchConfig := false
  2349  	serverConfig.GetConfigForClient = func(clientHello *ClientHelloInfo) (*Config, error) {
  2350  		if !switchConfig {
  2351  			return nil, nil
  2352  		}
  2353  		cfg := serverConfig.Clone()
  2354  		cfg.ClientCAs = x509.NewCertPool()
  2355  		cfg.ClientCAs.AddCert(rootB)
  2356  		return cfg, nil
  2357  	}
  2358  	serverConfig.InsecureSkipVerify = false
  2359  	serverConfig.ServerName = "example.com"
  2360  
  2361  	clientConfig := testConfig.Clone()
  2362  	clientConfig.MaxVersion = version
  2363  	clientConfig.Certificates = []Certificate{{
  2364  		Certificate: [][]byte{certA},
  2365  		PrivateKey:  testECDSAPrivateKey,
  2366  	}}
  2367  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(32)
  2368  	clientConfig.RootCAs = x509.NewCertPool()
  2369  	clientConfig.RootCAs.AddCert(rootA)
  2370  	clientConfig.Time = func() time.Time {
  2371  		return now
  2372  	}
  2373  	clientConfig.InsecureSkipVerify = false
  2374  	clientConfig.ServerName = "example.com"
  2375  
  2376  	testResume := func(t *testing.T, sc, cc *Config, expectResume bool) {
  2377  		t.Helper()
  2378  		ss, cs, err := testHandshake(t, cc, sc)
  2379  		if err != nil {
  2380  			t.Fatalf("handshake: %v", err)
  2381  		}
  2382  		if cs.DidResume != expectResume {
  2383  			t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
  2384  		}
  2385  		if ss.DidResume != expectResume {
  2386  			t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
  2387  		}
  2388  	}
  2389  
  2390  	testResume(t, serverConfig, clientConfig, false)
  2391  	testResume(t, serverConfig, clientConfig, true)
  2392  
  2393  	clientConfig.Certificates[0].Certificate = [][]byte{certB}
  2394  
  2395  	// Cause GetConfigForClient to return a config cloned from the base config,
  2396  	// but with a different ClientCAs pool. This should cause resumption to fail.
  2397  	switchConfig = true
  2398  
  2399  	testResume(t, serverConfig, clientConfig, false)
  2400  	testResume(t, serverConfig, clientConfig, true)
  2401  }
  2402  
  2403  func TestHandshakeChangeRootCAsResumption(t *testing.T) {
  2404  	t.Run("TLS1.2", func(t *testing.T) {
  2405  		testHandshakeChangeRootCAsResumption(t, VersionTLS12)
  2406  	})
  2407  	t.Run("TLS1.3", func(t *testing.T) {
  2408  		testHandshakeChangeRootCAsResumption(t, VersionTLS13)
  2409  	})
  2410  }
  2411  
  2412  func testHandshakeChangeRootCAsResumption(t *testing.T, version uint16) {
  2413  	now := time.Now()
  2414  	tmpl := &x509.Certificate{
  2415  		Subject:               pkix.Name{CommonName: "root"},
  2416  		NotBefore:             now.Add(-time.Hour * 24),
  2417  		NotAfter:              now.Add(time.Hour * 24),
  2418  		IsCA:                  true,
  2419  		BasicConstraintsValid: true,
  2420  	}
  2421  	rootDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2422  	if err != nil {
  2423  		t.Fatalf("CreateCertificate: %v", err)
  2424  	}
  2425  	rootA, err := x509.ParseCertificate(rootDER)
  2426  	if err != nil {
  2427  		t.Fatalf("ParseCertificate: %v", err)
  2428  	}
  2429  	rootDER, err = x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testRSA2048PrivateKey.PublicKey, testRSA2048PrivateKey)
  2430  	if err != nil {
  2431  		t.Fatalf("CreateCertificate: %v", err)
  2432  	}
  2433  	rootB, err := x509.ParseCertificate(rootDER)
  2434  	if err != nil {
  2435  		t.Fatalf("ParseCertificate: %v", err)
  2436  	}
  2437  
  2438  	tmpl = &x509.Certificate{
  2439  		Subject:   pkix.Name{},
  2440  		DNSNames:  []string{"example.com"},
  2441  		NotBefore: now.Add(-time.Hour * 24),
  2442  		NotAfter:  now.Add(time.Hour * 24),
  2443  		KeyUsage:  x509.KeyUsageDigitalSignature,
  2444  	}
  2445  	certA, err := x509.CreateCertificate(rand.Reader, tmpl, rootA, &testECDSAPrivateKey.PublicKey, testECDSAPrivateKey)
  2446  	if err != nil {
  2447  		t.Fatalf("CreateCertificate: %v", err)
  2448  	}
  2449  	certB, err := x509.CreateCertificate(rand.Reader, tmpl, rootB, &testECDSAPrivateKey.PublicKey, testRSA2048PrivateKey)
  2450  	if err != nil {
  2451  		t.Fatalf("CreateCertificate: %v", err)
  2452  	}
  2453  
  2454  	serverConfig := testConfig.Clone()
  2455  	serverConfig.MaxVersion = version
  2456  	serverConfig.Certificates = []Certificate{{
  2457  		Certificate: [][]byte{certA},
  2458  		PrivateKey:  testECDSAPrivateKey,
  2459  	}}
  2460  	serverConfig.Time = func() time.Time {
  2461  		return now
  2462  	}
  2463  	serverConfig.ClientCAs = x509.NewCertPool()
  2464  	serverConfig.ClientCAs.AddCert(rootA)
  2465  	serverConfig.ClientAuth = RequireAndVerifyClientCert
  2466  	serverConfig.InsecureSkipVerify = false
  2467  	serverConfig.ServerName = "example.com"
  2468  
  2469  	clientConfig := testConfig.Clone()
  2470  	clientConfig.MaxVersion = version
  2471  	clientConfig.Certificates = []Certificate{{
  2472  		Certificate: [][]byte{certA},
  2473  		PrivateKey:  testECDSAPrivateKey,
  2474  	}}
  2475  	clientConfig.ClientSessionCache = NewLRUClientSessionCache(32)
  2476  	clientConfig.RootCAs = x509.NewCertPool()
  2477  	clientConfig.RootCAs.AddCert(rootA)
  2478  	clientConfig.Time = func() time.Time {
  2479  		return now
  2480  	}
  2481  	clientConfig.InsecureSkipVerify = false
  2482  	clientConfig.ServerName = "example.com"
  2483  
  2484  	testResume := func(t *testing.T, sc, cc *Config, expectResume bool) {
  2485  		t.Helper()
  2486  		ss, cs, err := testHandshake(t, cc, sc)
  2487  		if err != nil {
  2488  			t.Fatalf("handshake: %v", err)
  2489  		}
  2490  		if cs.DidResume != expectResume {
  2491  			t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
  2492  		}
  2493  		if ss.DidResume != expectResume {
  2494  			t.Fatalf("DidResume = %v; want %v", cs.DidResume, expectResume)
  2495  		}
  2496  	}
  2497  
  2498  	testResume(t, serverConfig, clientConfig, false)
  2499  	testResume(t, serverConfig, clientConfig, true)
  2500  
  2501  	clientConfig = clientConfig.Clone()
  2502  	clientConfig.RootCAs = x509.NewCertPool()
  2503  	clientConfig.RootCAs.AddCert(rootB)
  2504  
  2505  	serverConfig.Certificates[0].Certificate = [][]byte{certB}
  2506  
  2507  	testResume(t, serverConfig, clientConfig, false)
  2508  	testResume(t, serverConfig, clientConfig, true)
  2509  }
  2510  

View as plain text