net/http: set Request.TLS when net.Conn implements ConnectionState

Fixes #56104

Change-Id: I8fbbb00379e51323e2782144070cbcad650eb6f1
GitHub-Last-Rev: 62d7a8064e4f2173f0d8e02ed91a7e8de7f13fca
GitHub-Pull-Request: golang/go#56110
Reviewed-on: https://go-review.googlesource.com/c/go/+/440795
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
Weidi Deng 2022-11-02 01:19:16 +00:00 committed by Gopher Robot
parent 786e62bcd3
commit 5413abc440
2 changed files with 59 additions and 0 deletions

View File

@ -1645,6 +1645,53 @@ func testTLSServer(t *testing.T, mode testMode) {
}
}
type fakeConnectionStateConn struct {
net.Conn
}
func (fcsc *fakeConnectionStateConn) ConnectionState() tls.ConnectionState {
return tls.ConnectionState{
ServerName: "example.com",
}
}
func TestTLSServerWithoutTLSConn(t *testing.T) {
//set up
pr, pw := net.Pipe()
c := make(chan int)
listener := &oneConnListener{&fakeConnectionStateConn{pr}}
server := &Server{
Handler: HandlerFunc(func(writer ResponseWriter, request *Request) {
if request.TLS == nil {
t.Fatal("request.TLS is nil, expected not nil")
}
if request.TLS.ServerName != "example.com" {
t.Fatalf("request.TLS.ServerName is %s, expected %s", request.TLS.ServerName, "example.com")
}
writer.Header().Set("X-TLS-ServerName", "example.com")
}),
}
// write request and read response
go func() {
req, _ := NewRequest(MethodGet, "https://example.com", nil)
req.Write(pw)
resp, _ := ReadResponse(bufio.NewReader(pw), req)
if hdr := resp.Header.Get("X-TLS-ServerName"); hdr != "example.com" {
t.Errorf("response header X-TLS-ServerName is %s, expected %s", hdr, "example.com")
}
close(c)
pw.Close()
}()
server.Serve(listener)
// oneConnListener returns error after one accept, wait util response is read
<-c
pr.Close()
}
func TestServeTLS(t *testing.T) {
CondSkipHTTP2(t)
// Not parallel: uses global test hooks.

View File

@ -1924,6 +1924,10 @@ func isCommonNetReadError(err error) bool {
return false
}
type connectionStater interface {
ConnectionState() tls.ConnectionState
}
// Serve a new connection.
func (c *conn) serve(ctx context.Context) {
if ra := c.rwc.RemoteAddr(); ra != nil {
@ -1996,6 +2000,14 @@ func (c *conn) serve(ctx context.Context) {
// HTTP/1.x from here on.
// Set Request.TLS if the conn is not a *tls.Conn, but implements ConnectionState.
if c.tlsState == nil {
if tc, ok := c.rwc.(connectionStater); ok {
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tc.ConnectionState()
}
}
ctx, cancelCtx := context.WithCancel(ctx)
c.cancelCtx = cancelCtx
defer cancelCtx()