diff options
Diffstat (limited to 'libgo/go/net')
-rw-r--r-- | libgo/go/net/dial_unix_test.go | 108 | ||||
-rw-r--r-- | libgo/go/net/dnsclient_unix.go | 24 | ||||
-rw-r--r-- | libgo/go/net/dnsclient_unix_test.go | 105 | ||||
-rw-r--r-- | libgo/go/net/fd_unix.go | 55 | ||||
-rw-r--r-- | libgo/go/net/hook_unix.go | 3 | ||||
-rw-r--r-- | libgo/go/net/http/h2_bundle.go | 433 | ||||
-rw-r--r-- | libgo/go/net/http/serve_test.go | 11 | ||||
-rw-r--r-- | libgo/go/net/http/server.go | 38 | ||||
-rw-r--r-- | libgo/go/net/http/transport.go | 70 | ||||
-rw-r--r-- | libgo/go/net/http/transport_internal_test.go | 9 | ||||
-rw-r--r-- | libgo/go/net/http/transport_test.go | 94 |
11 files changed, 751 insertions, 199 deletions
diff --git a/libgo/go/net/dial_unix_test.go b/libgo/go/net/dial_unix_test.go new file mode 100644 index 00000000000..47052547284 --- /dev/null +++ b/libgo/go/net/dial_unix_test.go @@ -0,0 +1,108 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux netbsd openbsd solaris + +package net + +import ( + "context" + "syscall" + "testing" + "time" +) + +// Issue 16523 +func TestDialContextCancelRace(t *testing.T) { + oldConnectFunc := connectFunc + oldGetsockoptIntFunc := getsockoptIntFunc + oldTestHookCanceledDial := testHookCanceledDial + defer func() { + connectFunc = oldConnectFunc + getsockoptIntFunc = oldGetsockoptIntFunc + testHookCanceledDial = oldTestHookCanceledDial + }() + + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + listenerDone := make(chan struct{}) + go func() { + defer close(listenerDone) + c, err := ln.Accept() + if err == nil { + c.Close() + } + }() + defer func() { <-listenerDone }() + defer ln.Close() + + sawCancel := make(chan bool, 1) + testHookCanceledDial = func() { + sawCancel <- true + } + + ctx, cancelCtx := context.WithCancel(context.Background()) + + connectFunc = func(fd int, addr syscall.Sockaddr) error { + err := oldConnectFunc(fd, addr) + t.Logf("connect(%d, addr) = %v", fd, err) + if err == nil { + // On some operating systems, localhost + // connects _sometimes_ succeed immediately. + // Prevent that, so we exercise the code path + // we're interested in testing. This seems + // harmless. It makes FreeBSD 10.10 work when + // run with many iterations. It failed about + // half the time previously. + return syscall.EINPROGRESS + } + return err + } + + getsockoptIntFunc = func(fd, level, opt int) (val int, err error) { + val, err = oldGetsockoptIntFunc(fd, level, opt) + t.Logf("getsockoptIntFunc(%d, %d, %d) = (%v, %v)", fd, level, opt, val, err) + if level == syscall.SOL_SOCKET && opt == syscall.SO_ERROR && err == nil && val == 0 { + t.Logf("canceling context") + + // Cancel the context at just the moment which + // caused the race in issue 16523. + cancelCtx() + + // And wait for the "interrupter" goroutine to + // cancel the dial by messing with its write + // timeout before returning. + select { + case <-sawCancel: + t.Logf("saw cancel") + case <-time.After(5 * time.Second): + t.Errorf("didn't see cancel after 5 seconds") + } + } + return + } + + var d Dialer + c, err := d.DialContext(ctx, "tcp", ln.Addr().String()) + if err == nil { + c.Close() + t.Fatal("unexpected successful dial; want context canceled error") + } + + select { + case <-ctx.Done(): + case <-time.After(5 * time.Second): + t.Fatal("expected context to be canceled") + } + + oe, ok := err.(*OpError) + if !ok || oe.Op != "dial" { + t.Fatalf("Dial error = %#v; want dial *OpError", err) + } + if oe.Err != ctx.Err() { + t.Errorf("DialContext = (%v, %v); want OpError with error %v", c, err, ctx.Err()) + } +} diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go index 8f2dff46751..b5b6ffb1c50 100644 --- a/libgo/go/net/dnsclient_unix.go +++ b/libgo/go/net/dnsclient_unix.go @@ -141,7 +141,7 @@ func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, } // exchange sends a query on the connection and hopes for a response. -func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, error) { +func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { d := testHookDNSDialer() out := dnsMsg{ dnsMsgHdr: dnsMsgHdr{ @@ -152,6 +152,12 @@ func exchange(ctx context.Context, server, name string, qtype uint16) (*dnsMsg, }, } for _, network := range []string{"udp", "tcp"} { + // TODO(mdempsky): Refactor so defers from UDP-based + // exchanges happen before TCP-based exchange. + + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + c, err := d.dialDNS(ctx, network, server) if err != nil { return nil, err @@ -180,17 +186,10 @@ func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) return "", nil, &DNSError{Err: "no DNS servers", Name: name} } - deadline := time.Now().Add(cfg.timeout) - if old, ok := ctx.Deadline(); !ok || deadline.Before(old) { - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, deadline) - defer cancel() - } - var lastErr error for i := 0; i < cfg.attempts; i++ { for _, server := range cfg.servers { - msg, err := exchange(ctx, server, name, qtype) + msg, err := exchange(ctx, server, name, qtype, cfg.timeout) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -338,8 +337,9 @@ func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs [ } // avoidDNS reports whether this is a hostname for which we should not -// use DNS. Currently this includes only .onion and .local names, -// per RFC 7686 and RFC 6762, respectively. See golang.org/issue/13705. +// use DNS. Currently this includes only .onion, per RFC 7686. See +// golang.org/issue/13705. Does not cover .local names (RFC 6762), +// see golang.org/issue/16739. func avoidDNS(name string) bool { if name == "" { return true @@ -347,7 +347,7 @@ func avoidDNS(name string) bool { if name[len(name)-1] == '.' { name = name[:len(name)-1] } - return stringsHasSuffixFold(name, ".onion") || stringsHasSuffixFold(name, ".local") + return stringsHasSuffixFold(name, ".onion") } // nameList returns a list of names for sequential DNS queries. diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go index 09bbd488660..6ebeeaeb8f4 100644 --- a/libgo/go/net/dnsclient_unix_test.go +++ b/libgo/go/net/dnsclient_unix_test.go @@ -40,9 +40,9 @@ func TestDNSTransportFallback(t *testing.T) { testenv.MustHaveExternalNetwork(t) for _, tt := range dnsTransportFallbackTests { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(tt.timeout)*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - msg, err := exchange(ctx, tt.server, tt.name, tt.qtype) + msg, err := exchange(ctx, tt.server, tt.name, tt.qtype, time.Second) if err != nil { t.Error(err) continue @@ -82,9 +82,9 @@ func TestSpecialDomainName(t *testing.T) { server := "8.8.8.8:53" for _, tt := range specialDomainNameTests { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - msg, err := exchange(ctx, server, tt.name, tt.qtype) + msg, err := exchange(ctx, server, tt.name, tt.qtype, 3*time.Second) if err != nil { t.Error(err) continue @@ -112,10 +112,11 @@ func TestAvoidDNSName(t *testing.T) { {"foo.ONION", true}, {"foo.ONION.", true}, - {"foo.local.", true}, - {"foo.local", true}, - {"foo.LOCAL", true}, - {"foo.LOCAL.", true}, + // But do resolve *.local address; Issue 16739 + {"foo.local.", false}, + {"foo.local", false}, + {"foo.LOCAL", false}, + {"foo.LOCAL.", false}, {"", true}, // will be rejected earlier too @@ -500,7 +501,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { d := &fakeDNSDialer{} testHookDNSDialer = func() dnsDialer { return d } - d.rh = func(s string, q *dnsMsg) (*dnsMsg, error) { + d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ id: q.id, @@ -539,14 +540,15 @@ func TestIgnoreLameReferrals(t *testing.T) { } defer conf.teardown() - if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", "nameserver 192.0.2.2"}); err != nil { + if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral + "nameserver 192.0.2.2"}); err != nil { t.Fatal(err) } d := &fakeDNSDialer{} testHookDNSDialer = func() dnsDialer { return d } - d.rh = func(s string, q *dnsMsg) (*dnsMsg, error) { + d.rh = func(s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { t.Log(s, q) r := &dnsMsg{ dnsMsgHdr: dnsMsgHdr{ @@ -633,28 +635,30 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { type fakeDNSDialer struct { // reply handler - rh func(s string, q *dnsMsg) (*dnsMsg, error) + rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error) } func (f *fakeDNSDialer) dialDNS(_ context.Context, n, s string) (dnsConn, error) { - return &fakeDNSConn{f.rh, s}, nil + return &fakeDNSConn{f.rh, s, time.Time{}}, nil } type fakeDNSConn struct { - rh func(s string, q *dnsMsg) (*dnsMsg, error) + rh func(s string, q *dnsMsg, t time.Time) (*dnsMsg, error) s string + t time.Time } func (f *fakeDNSConn) Close() error { return nil } -func (f *fakeDNSConn) SetDeadline(time.Time) error { +func (f *fakeDNSConn) SetDeadline(t time.Time) error { + f.t = t return nil } func (f *fakeDNSConn) dnsRoundTrip(q *dnsMsg) (*dnsMsg, error) { - return f.rh(f.s, q) + return f.rh(f.s, q, f.t) } // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). @@ -724,3 +728,72 @@ func TestIgnoreDNSForgeries(t *testing.T) { t.Errorf("got address %v, want %v", got, TestAddr) } } + +// Issue 16865. If a name server times out, continue to the next. +func TestRetryTimeout(t *testing.T) { + origTestHookDNSDialer := testHookDNSDialer + defer func() { testHookDNSDialer = origTestHookDNSDialer }() + + conf, err := newResolvConfTest() + if err != nil { + t.Fatal(err) + } + defer conf.teardown() + + if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will timeout + "nameserver 192.0.2.2"}); err != nil { + t.Fatal(err) + } + + d := &fakeDNSDialer{} + testHookDNSDialer = func() dnsDialer { return d } + + var deadline0 time.Time + + d.rh = func(s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + t.Log(s, q, deadline) + + if deadline.IsZero() { + t.Error("zero deadline") + } + + if s == "192.0.2.1:53" { + deadline0 = deadline + time.Sleep(10 * time.Millisecond) + return nil, errTimeout + } + + if deadline == deadline0 { + t.Error("deadline didn't change") + } + + r := &dnsMsg{ + dnsMsgHdr: dnsMsgHdr{ + id: q.id, + response: true, + recursion_available: true, + }, + question: q.question, + answer: []dnsRR{ + &dnsRR_CNAME{ + Hdr: dnsRR_Header{ + Name: q.question[0].Name, + Rrtype: dnsTypeCNAME, + Class: dnsClassINET, + }, + Cname: "golang.org", + }, + }, + } + return r, nil + } + + _, err = goLookupCNAME(context.Background(), "www.golang.org") + if err != nil { + t.Fatal(err) + } + + if deadline0.IsZero() { + t.Error("deadline0 still zero", deadline0) + } +} diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go index 6fbb9cbf114..0309db08ebc 100644 --- a/libgo/go/net/fd_unix.go +++ b/libgo/go/net/fd_unix.go @@ -64,7 +64,7 @@ func (fd *netFD) name() string { return fd.net + ":" + ls + "->" + rs } -func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { +func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (ret error) { // Do not need to call fd.writeLock here, // because fd is not yet accessible to user, // so no concurrent operations are possible. @@ -101,21 +101,44 @@ func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) error { defer fd.setWriteDeadline(noDeadline) } - // Wait for the goroutine converting context.Done into a write timeout - // to exist, otherwise our caller might cancel the context and - // cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial. - done := make(chan bool) // must be unbuffered - defer func() { done <- true }() - go func() { - select { - case <-ctx.Done(): - // Force the runtime's poller to immediately give - // up waiting for writability. - fd.setWriteDeadline(aLongTimeAgo) - <-done - case <-done: - } - }() + // Start the "interrupter" goroutine, if this context might be canceled. + // (The background context cannot) + // + // The interrupter goroutine waits for the context to be done and + // interrupts the dial (by altering the fd's write deadline, which + // wakes up waitWrite). + if ctx != context.Background() { + // Wait for the interrupter goroutine to exit before returning + // from connect. + done := make(chan struct{}) + interruptRes := make(chan error) + defer func() { + close(done) + if ctxErr := <-interruptRes; ctxErr != nil && ret == nil { + // The interrupter goroutine called setWriteDeadline, + // but the connect code below had returned from + // waitWrite already and did a successful connect (ret + // == nil). Because we've now poisoned the connection + // by making it unwritable, don't return a successful + // dial. This was issue 16523. + ret = ctxErr + fd.Close() // prevent a leak + } + }() + go func() { + select { + case <-ctx.Done(): + // Force the runtime's poller to immediately give up + // waiting for writability, unblocking waitWrite + // below. + fd.setWriteDeadline(aLongTimeAgo) + testHookCanceledDial() + interruptRes <- ctx.Err() + case <-done: + interruptRes <- nil + } + }() + } for { // Performing multiple connect system calls on a diff --git a/libgo/go/net/hook_unix.go b/libgo/go/net/hook_unix.go index 361ca5980c3..cf52567fcfd 100644 --- a/libgo/go/net/hook_unix.go +++ b/libgo/go/net/hook_unix.go @@ -9,7 +9,8 @@ package net import "syscall" var ( - testHookDialChannel = func() {} // see golang.org/issue/5349 + testHookDialChannel = func() {} // for golang.org/issue/5349 + testHookCanceledDial = func() {} // for golang.org/issue/16523 // Placeholders for socket system calls. socketFunc func(int, int, int) (int, error) = syscall.Socket diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index db774554b2c..5826bb7d858 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -28,6 +28,7 @@ import ( "io" "io/ioutil" "log" + "math" "net" "net/http/httptrace" "net/textproto" @@ -85,7 +86,16 @@ const ( http2noDialOnMiss = false ) -func (p *http2clientConnPool) getClientConn(_ *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { +func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) { + if http2isConnectionCloseRequest(req) && dialOnMiss { + // It gets its own connection. + const singleUse = true + cc, err := p.t.dialClientConn(addr, singleUse) + if err != nil { + return nil, err + } + return cc, nil + } p.mu.Lock() for _, cc := range p.conns[addr] { if cc.CanTakeNewRequest() { @@ -128,7 +138,8 @@ func (p *http2clientConnPool) getStartDialLocked(addr string) *http2dialCall { // run in its own goroutine. func (c *http2dialCall) dial(addr string) { - c.res, c.err = c.p.t.dialClientConn(addr) + const singleUse = false // shared conn + c.res, c.err = c.p.t.dialClientConn(addr, singleUse) close(c.done) c.p.mu.Lock() @@ -393,9 +404,17 @@ func (e http2ConnectionError) Error() string { type http2StreamError struct { StreamID uint32 Code http2ErrCode + Cause error // optional additional detail +} + +func http2streamError(id uint32, code http2ErrCode) http2StreamError { + return http2StreamError{StreamID: id, Code: code} } func (e http2StreamError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("stream error: stream ID %d; %v; %v", e.StreamID, e.Code, e.Cause) + } return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code) } @@ -1105,6 +1124,7 @@ func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error var ( http2errStreamID = errors.New("invalid stream ID") http2errDepStreamID = errors.New("invalid dependent stream ID") + http2errPadLength = errors.New("pad length too large") ) func http2validStreamIDOrZero(streamID uint32) bool { @@ -1118,18 +1138,40 @@ func http2validStreamID(streamID uint32) bool { // WriteData writes a DATA frame. // // It will perform exactly one Write to the underlying Writer. -// It is the caller's responsibility to not call other Write methods concurrently. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error { + return f.WriteDataPadded(streamID, endStream, data, nil) +} +// WriteData writes a DATA frame with optional padding. +// +// If pad is nil, the padding bit is not sent. +// The length of pad must not exceed 255 bytes. +// +// It will perform exactly one Write to the underlying Writer. +// It is the caller's responsibility not to violate the maximum frame size +// and to not call other Write methods concurrently. +func (f *http2Framer) WriteDataPadded(streamID uint32, endStream bool, data, pad []byte) error { if !http2validStreamID(streamID) && !f.AllowIllegalWrites { return http2errStreamID } + if len(pad) > 255 { + return http2errPadLength + } var flags http2Flags if endStream { flags |= http2FlagDataEndStream } + if pad != nil { + flags |= http2FlagDataPadded + } f.startWrite(http2FrameData, flags, streamID) + if pad != nil { + f.wbuf = append(f.wbuf, byte(len(pad))) + } f.wbuf = append(f.wbuf, data...) + f.wbuf = append(f.wbuf, pad...) return f.endWrite() } @@ -1333,7 +1375,7 @@ func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, err if fh.StreamID == 0 { return nil, http2ConnectionError(http2ErrCodeProtocol) } - return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol} + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) } return &http2WindowUpdateFrame{ http2FrameHeader: fh, @@ -1411,7 +1453,7 @@ func http2parseHeadersFrame(fh http2FrameHeader, p []byte) (_ http2Frame, err er } } if len(p)-int(padLength) <= 0 { - return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol} + return nil, http2streamError(fh.StreamID, http2ErrCodeProtocol) } hf.headerFragBuf = p[:len(p)-int(padLength)] return hf, nil @@ -1878,6 +1920,9 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr hdec.SetEmitEnabled(true) hdec.SetMaxStringLength(fr.maxHeaderStringLen()) hdec.SetEmitFunc(func(hf hpack.HeaderField) { + if http2VerboseLogs && http2logFrameReads { + log.Printf("http2: decoded hpack field %+v", hf) + } if !httplex.ValidHeaderFieldValue(hf.Value) { invalid = http2headerFieldValueError(hf.Value) } @@ -1936,11 +1981,17 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr } if invalid != nil { fr.errDetail = invalid - return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol} + if http2VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, invalid} } if err := mh.checkPseudos(); err != nil { fr.errDetail = err - return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol} + if http2VerboseLogs { + log.Printf("http2: invalid pseudo headers: %v", err) + } + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol, err} } return mh, nil } @@ -3571,7 +3622,7 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { case http2stateOpen: st.state = http2stateHalfClosedLocal - errCancel := http2StreamError{st.id, http2ErrCodeCancel} + errCancel := http2streamError(st.id, http2ErrCodeCancel) sc.resetStream(errCancel) case http2stateHalfClosedRemote: sc.closeStream(st, http2errHandlerComplete) @@ -3764,7 +3815,7 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error return nil } if !st.flow.add(int32(f.Increment)) { - return http2StreamError{f.StreamID, http2ErrCodeFlowControl} + return http2streamError(f.StreamID, http2ErrCodeFlowControl) } default: if !sc.flow.add(int32(f.Increment)) { @@ -3786,7 +3837,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { if st != nil { st.gotReset = true st.cancelCtx() - sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode}) + sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) } return nil } @@ -3803,6 +3854,9 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { } delete(sc.streams, st.id) if p := st.body; p != nil { + + sc.sendWindowUpdate(nil, p.Len()) + p.CloseWithError(err) } st.cw.Close() @@ -3879,36 +3933,51 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.serveG.check() + data := f.Data() id := f.Header().StreamID st, ok := sc.streams[id] if !ok || st.state != http2stateOpen || st.gotTrailerHeader { - return http2StreamError{id, http2ErrCodeStreamClosed} + if sc.inflow.available() < int32(f.Length) { + return http2streamError(id, http2ErrCodeFlowControl) + } + + sc.inflow.take(int32(f.Length)) + sc.sendWindowUpdate(nil, int(f.Length)) + + return http2streamError(id, http2ErrCodeStreamClosed) } if st.body == nil { panic("internal error: should have a body in this state") } - data := f.Data() if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes { st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) - return http2StreamError{id, http2ErrCodeStreamClosed} + return http2streamError(id, http2ErrCodeStreamClosed) } - if len(data) > 0 { + if f.Length > 0 { - if int(st.inflow.available()) < len(data) { - return http2StreamError{id, http2ErrCodeFlowControl} + if st.inflow.available() < int32(f.Length) { + return http2streamError(id, http2ErrCodeFlowControl) } - st.inflow.take(int32(len(data))) - wrote, err := st.body.Write(data) - if err != nil { - return http2StreamError{id, http2ErrCodeStreamClosed} + st.inflow.take(int32(f.Length)) + + if len(data) > 0 { + wrote, err := st.body.Write(data) + if err != nil { + return http2streamError(id, http2ErrCodeStreamClosed) + } + if wrote != len(data) { + panic("internal error: bad Writer") + } + st.bodyBytes += int64(len(data)) } - if wrote != len(data) { - panic("internal error: bad Writer") + + if pad := int32(f.Length) - int32(len(data)); pad > 0 { + sc.sendWindowUpdate32(nil, pad) + sc.sendWindowUpdate32(st, pad) } - st.bodyBytes += int64(len(data)) } if f.StreamEnded() { st.endStream() @@ -3995,10 +4064,10 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if sc.unackedSettings == 0 { - return http2StreamError{st.id, http2ErrCodeProtocol} + return http2streamError(st.id, http2ErrCodeProtocol) } - return http2StreamError{st.id, http2ErrCodeRefusedStream} + return http2streamError(st.id, http2ErrCodeRefusedStream) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -4032,18 +4101,18 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { } st.gotTrailerHeader = true if !f.StreamEnded() { - return http2StreamError{st.id, http2ErrCodeProtocol} + return http2streamError(st.id, http2ErrCodeProtocol) } if len(f.PseudoFields()) > 0 { - return http2StreamError{st.id, http2ErrCodeProtocol} + return http2streamError(st.id, http2ErrCodeProtocol) } if st.trailer != nil { for _, hf := range f.RegularFields() { key := sc.canonicalHeader(hf.Name) if !http2ValidTrailerHeader(key) { - return http2StreamError{st.id, http2ErrCodeProtocol} + return http2streamError(st.id, http2ErrCodeProtocol) } st.trailer[key] = append(st.trailer[key], hf.Value) } @@ -4097,18 +4166,18 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead isConnect := method == "CONNECT" if isConnect { if path != "" || scheme != "" || authority == "" { - return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } } else if method == "" || path == "" || (scheme != "https" && scheme != "http") { - return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } bodyOpen := !f.StreamEnded() if method == "HEAD" && bodyOpen { - return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } var tlsState *tls.ConnectionState // nil if not scheme https @@ -4165,7 +4234,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead var err error url_, err = url.ParseRequestURI(path) if err != nil { - return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} + return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } requestURI = path } @@ -4919,35 +4988,37 @@ func (t *http2Transport) initConnPool() { // ClientConn is the state of a single HTTP/2 client connection to an // HTTP/2 server. type http2ClientConn struct { - t *http2Transport - tconn net.Conn // usually *tls.Conn, except specialized impls - tlsState *tls.ConnectionState // nil only for specialized impls + t *http2Transport + tconn net.Conn // usually *tls.Conn, except specialized impls + tlsState *tls.ConnectionState // nil only for specialized impls + singleUse bool // whether being used for a single http.Request // readLoop goroutine fields: readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed - mu sync.Mutex // guards following - cond *sync.Cond // hold mu; broadcast on flow/closed changes - flow http2flow // our conn-level flow control quota (cs.flow is per stream) - inflow http2flow // peer's conn-level flow control - closed bool - goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received - goAwayDebug string // goAway frame's debug data, retained as a string - streams map[uint32]*http2clientStream // client-initiated - nextStreamID uint32 - bw *bufio.Writer - br *bufio.Reader - fr *http2Framer - lastActive time.Time - - // Settings from peer: + mu sync.Mutex // guards following + cond *sync.Cond // hold mu; broadcast on flow/closed changes + flow http2flow // our conn-level flow control quota (cs.flow is per stream) + inflow http2flow // peer's conn-level flow control + closed bool + wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back + goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string + streams map[uint32]*http2clientStream // client-initiated + nextStreamID uint32 + bw *bufio.Writer + br *bufio.Reader + fr *http2Framer + lastActive time.Time + // Settings from peer: (also guarded by mu) maxFrameSize uint32 maxConcurrentStreams uint32 initialWindowSize uint32 - hbuf bytes.Buffer // HPACK encoder writes into this - henc *hpack.Encoder - freeBuf [][]byte + + hbuf bytes.Buffer // HPACK encoder writes into this + henc *hpack.Encoder + freeBuf [][]byte wmu sync.Mutex // held while writing; acquire AFTER mu if holding both werr error // first write error that has occurred @@ -5117,7 +5188,7 @@ func http2shouldRetryRequest(req *Request, err error) bool { return err == http2errClientConnUnusable } -func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) { +func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -5126,7 +5197,7 @@ func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) { if err != nil { return nil, err } - return t.NewClientConn(tconn) + return t.newClientConn(tconn, singleUse) } func (t *http2Transport) newTLSConfig(host string) *tls.Config { @@ -5187,14 +5258,10 @@ func (t *http2Transport) expectContinueTimeout() time.Duration { } func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { - if http2VerboseLogs { - t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) - } - if _, err := c.Write(http2clientPreface); err != nil { - t.vlogf("client preface write error: %v", err) - return nil, err - } + return t.newClientConn(c, false) +} +func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) { cc := &http2ClientConn{ t: t, tconn: c, @@ -5204,7 +5271,13 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { initialWindowSize: 65535, maxConcurrentStreams: 1000, streams: make(map[uint32]*http2clientStream), + singleUse: singleUse, + wantSettingsAck: true, } + if http2VerboseLogs { + t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) + } + cc.cond = sync.NewCond(&cc.mu) cc.flow.add(int32(http2initialWindowSize)) @@ -5228,6 +5301,8 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) } + + cc.bw.Write(http2clientPreface) cc.fr.WriteSettings(initialSettings...) cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow) cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize) @@ -5236,32 +5311,6 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { return nil, cc.werr } - f, err := cc.fr.ReadFrame() - if err != nil { - return nil, err - } - sf, ok := f.(*http2SettingsFrame) - if !ok { - return nil, fmt.Errorf("expected settings frame, got: %T", f) - } - cc.fr.WriteSettingsAck() - cc.bw.Flush() - - sf.ForeachSetting(func(s http2Setting) error { - switch s.ID { - case http2SettingMaxFrameSize: - cc.maxFrameSize = s.Val - case http2SettingMaxConcurrentStreams: - cc.maxConcurrentStreams = s.Val - case http2SettingInitialWindowSize: - cc.initialWindowSize = s.Val - default: - - t.vlogf("Unhandled Setting: %v", s) - } - return nil - }) - go cc.readLoop() return cc, nil } @@ -5288,9 +5337,12 @@ func (cc *http2ClientConn) CanTakeNewRequest() bool { } func (cc *http2ClientConn) canTakeNewRequestLocked() bool { + if cc.singleUse && cc.nextStreamID > 1 { + return false + } return cc.goAway == nil && !cc.closed && int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) && - cc.nextStreamID < 2147483647 + cc.nextStreamID < math.MaxInt32 } func (cc *http2ClientConn) closeIfIdle() { @@ -5300,9 +5352,13 @@ func (cc *http2ClientConn) closeIfIdle() { return } cc.closed = true + nextID := cc.nextStreamID cc.mu.Unlock() + if http2VerboseLogs { + cc.vlogf("http2: Transport closing idle conn %p (forSingleUse=%v, maxStream=%v)", cc, cc.singleUse, nextID-2) + } cc.tconn.Close() } @@ -5404,12 +5460,15 @@ func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) { // We have a body but a zero content length. Test to see if // it's actually zero or just unset. var buf [1]byte - n, rerr := io.ReadFull(body, buf[:]) + n, rerr := body.Read(buf[:]) if rerr != nil && rerr != io.EOF { return http2errorReader{rerr}, -1 } if n == 1 { + if rerr == io.EOF { + return bytes.NewReader(buf[:]), 1 + } return io.MultiReader(bytes.NewReader(buf[:]), body), -1 } @@ -5494,9 +5553,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { bodyWritten := false ctx := http2reqContext(req) - reFunc := func(re http2resAndError) (*Response, error) { + handleReadLoopResponse := func(re http2resAndError) (*Response, error) { res := re.res if re.err != nil || res.StatusCode > 299 { + bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } @@ -5512,7 +5572,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { for { select { case re := <-readLoopResCh: - return reFunc(re) + return handleReadLoopResponse(re) case <-respHeaderTimer: cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { @@ -5525,7 +5585,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { case <-ctx.Done(): select { case re := <-readLoopResCh: - return reFunc(re) + return handleReadLoopResponse(re) default: } cc.forgetStreamID(cs.ID) @@ -5539,7 +5599,7 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { case <-req.Cancel: select { case re := <-readLoopResCh: - return reFunc(re) + return handleReadLoopResponse(re) default: } cc.forgetStreamID(cs.ID) @@ -5553,14 +5613,15 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { case <-cs.peerReset: select { case re := <-readLoopResCh: - return reFunc(re) + return handleReadLoopResponse(re) default: } return nil, cs.resetErr case err := <-bodyWriter.resc: + select { case re := <-readLoopResCh: - return reFunc(re) + return handleReadLoopResponse(re) default: } if err != nil { @@ -5670,26 +5731,29 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos } } + if sentEnd { + + return nil + } + + var trls []byte + if hasTrailers { + cc.mu.Lock() + defer cc.mu.Unlock() + trls = cc.encodeTrailers(req) + } + cc.wmu.Lock() - if !sentEnd { - var trls []byte - if hasTrailers { - cc.mu.Lock() - trls = cc.encodeTrailers(req) - cc.mu.Unlock() - } + defer cc.wmu.Unlock() - if len(trls) > 0 { - err = cc.writeHeaders(cs.ID, true, trls) - } else { - err = cc.fr.WriteData(cs.ID, true, nil) - } + if len(trls) > 0 { + err = cc.writeHeaders(cs.ID, true, trls) + } else { + err = cc.fr.WriteData(cs.ID, true, nil) } if ferr := cc.bw.Flush(); ferr != nil && err == nil { err = ferr } - cc.wmu.Unlock() - return err } @@ -5918,6 +5982,14 @@ func (e http2GoAwayError) Error() string { e.LastStreamID, e.ErrCode, e.DebugData) } +func http2isEOFOrNetReadError(err error) bool { + if err == io.EOF { + return true + } + ne, ok := err.(*net.OpError) + return ok && ne.Op == "read" +} + func (rl *http2clientConnReadLoop) cleanup() { cc := rl.cc defer cc.tconn.Close() @@ -5926,16 +5998,14 @@ func (rl *http2clientConnReadLoop) cleanup() { err := cc.readerErr cc.mu.Lock() - if err == io.EOF { - if cc.goAway != nil { - err = http2GoAwayError{ - LastStreamID: cc.goAway.LastStreamID, - ErrCode: cc.goAway.ErrCode, - DebugData: cc.goAwayDebug, - } - } else { - err = io.ErrUnexpectedEOF + if cc.goAway != nil && http2isEOFOrNetReadError(err) { + err = http2GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, } + } else if err == io.EOF { + err = io.ErrUnexpectedEOF } for _, cs := range rl.activeRes { cs.bufPipe.CloseWithError(err) @@ -5954,16 +6024,21 @@ func (rl *http2clientConnReadLoop) cleanup() { func (rl *http2clientConnReadLoop) run() error { cc := rl.cc - rl.closeWhenIdle = cc.t.disableKeepAlives() + rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse gotReply := false + gotSettings := false for { f, err := cc.fr.ReadFrame() if err != nil { - cc.vlogf("Transport readFrame error: (%T) %v", err, err) + cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err) } if se, ok := err.(http2StreamError); ok { if cs := cc.streamByID(se.StreamID, true); cs != nil { - rl.endStreamError(cs, cc.fr.errDetail) + cs.cc.writeStreamReset(cs.ID, se.Code, err) + if se.Cause == nil { + se.Cause = cc.fr.errDetail + } + rl.endStreamError(cs, se) } continue } else if err != nil { @@ -5972,6 +6047,13 @@ func (rl *http2clientConnReadLoop) run() error { if http2VerboseLogs { cc.vlogf("http2: Transport received %s", http2summarizeFrame(f)) } + if !gotSettings { + if _, ok := f.(*http2SettingsFrame); !ok { + cc.logf("protocol error: received %T before a SETTINGS frame", f) + return http2ConnectionError(http2ErrCodeProtocol) + } + gotSettings = true + } maybeIdle := false switch f := f.(type) { @@ -6000,6 +6082,9 @@ func (rl *http2clientConnReadLoop) run() error { cc.logf("Transport: unhandled response frame type %T", f) } if err != nil { + if http2VerboseLogs { + cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, http2summarizeFrame(f), err) + } return err } if rl.closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 { @@ -6238,10 +6323,27 @@ var http2errClosedResponseBody = errors.New("http2: response body closed") func (b http2transportResponseBody) Close() error { cs := b.cs - if cs.bufPipe.Err() != io.EOF { + cc := cs.cc - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + serverSentStreamEnd := cs.bufPipe.Err() == io.EOF + unread := cs.bufPipe.Len() + + if unread > 0 || !serverSentStreamEnd { + cc.mu.Lock() + cc.wmu.Lock() + if !serverSentStreamEnd { + cc.fr.WriteRSTStream(cs.ID, http2ErrCodeCancel) + } + + if unread > 0 { + cc.inflow.add(int32(unread)) + cc.fr.WriteWindowUpdate(0, uint32(unread)) + } + cc.bw.Flush() + cc.wmu.Unlock() + cc.mu.Unlock() } + cs.bufPipe.BreakWithError(http2errClosedResponseBody) return nil } @@ -6249,6 +6351,7 @@ func (b http2transportResponseBody) Close() error { func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc := rl.cc cs := cc.streamByID(f.StreamID, f.StreamEnded()) + data := f.Data() if cs == nil { cc.mu.Lock() neverSent := cc.nextStreamID @@ -6259,27 +6362,49 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } + if f.Length > 0 { + cc.mu.Lock() + cc.inflow.add(int32(f.Length)) + cc.mu.Unlock() + + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(f.Length)) + cc.bw.Flush() + cc.wmu.Unlock() + } return nil } - if data := f.Data(); len(data) > 0 { - if cs.bufPipe.b == nil { + if f.Length > 0 { + if len(data) > 0 && cs.bufPipe.b == nil { cc.logf("http2: Transport received DATA frame for closed stream; closing connection") return http2ConnectionError(http2ErrCodeProtocol) } cc.mu.Lock() - if cs.inflow.available() >= int32(len(data)) { - cs.inflow.take(int32(len(data))) + if cs.inflow.available() >= int32(f.Length) { + cs.inflow.take(int32(f.Length)) } else { cc.mu.Unlock() return http2ConnectionError(http2ErrCodeFlowControl) } + + if pad := int32(f.Length) - int32(len(data)); pad > 0 { + cs.inflow.add(pad) + cc.inflow.add(pad) + cc.wmu.Lock() + cc.fr.WriteWindowUpdate(0, uint32(pad)) + cc.fr.WriteWindowUpdate(cs.ID, uint32(pad)) + cc.bw.Flush() + cc.wmu.Unlock() + } cc.mu.Unlock() - if _, err := cs.bufPipe.Write(data); err != nil { - rl.endStreamError(cs, err) - return err + if len(data) > 0 { + if _, err := cs.bufPipe.Write(data); err != nil { + rl.endStreamError(cs, err) + return err + } } } @@ -6304,9 +6429,14 @@ func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err err } cs.bufPipe.closeWithErrorAndCode(err, code) delete(rl.activeRes, cs.ID) - if cs.req.Close || cs.req.Header.Get("Connection") == "close" { + if http2isConnectionCloseRequest(cs.req) { rl.closeWhenIdle = true } + + select { + case cs.resc <- http2resAndError{err: err}: + default: + } } func (cs *http2clientStream) copyTrailers() { @@ -6334,7 +6464,16 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error cc := rl.cc cc.mu.Lock() defer cc.mu.Unlock() - return f.ForeachSetting(func(s http2Setting) error { + + if f.IsAck() { + if cc.wantSettingsAck { + cc.wantSettingsAck = false + return nil + } + return http2ConnectionError(http2ErrCodeProtocol) + } + + err := f.ForeachSetting(func(s http2Setting) error { switch s.ID { case http2SettingMaxFrameSize: cc.maxFrameSize = s.Val @@ -6342,6 +6481,16 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error cc.maxConcurrentStreams = s.Val case http2SettingInitialWindowSize: + if s.Val > math.MaxInt32 { + return http2ConnectionError(http2ErrCodeFlowControl) + } + + delta := int32(s.Val) - int32(cc.initialWindowSize) + for _, cs := range cc.streams { + cs.flow.add(delta) + } + cc.cond.Broadcast() + cc.initialWindowSize = s.Val default: @@ -6349,6 +6498,16 @@ func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error } return nil }) + if err != nil { + return err + } + + cc.wmu.Lock() + defer cc.wmu.Unlock() + + cc.fr.WriteSettingsAck() + cc.bw.Flush() + return cc.werr } func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error { @@ -6382,7 +6541,7 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er case <-cs.peerReset: default: - err := http2StreamError{cs.ID, f.ErrCode} + err := http2streamError(cs.ID, f.ErrCode) cs.resetErr = err close(cs.peerReset) cs.bufPipe.CloseWithError(err) @@ -6560,6 +6719,12 @@ func (s http2bodyWriterState) scheduleBodyWrite() { } } +// isConnectionCloseRequest reports whether req should use its own +// connection for a single request and then close the connection. +func http2isConnectionCloseRequest(req *Request) bool { + return req.Close || httplex.HeaderValuesContainsToken(req.Header["Connection"], "close") +} + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 139ce3eafc7..13e5f283e4c 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -4716,3 +4716,14 @@ func BenchmarkCloseNotifier(b *testing.B) { } b.StopTimer() } + +// Verify this doesn't race (Issue 16505) +func TestConcurrentServerServe(t *testing.T) { + for i := 0; i < 100; i++ { + ln1 := &oneConnListener{conn: nil} + ln2 := &oneConnListener{conn: nil} + srv := Server{} + go func() { srv.Serve(ln1) }() + go func() { srv.Serve(ln2) }() + } +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 7b2b4b2f423..89574a8b36e 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -2129,8 +2129,8 @@ type Server struct { ErrorLog *log.Logger disableKeepAlives int32 // accessed atomically. - nextProtoOnce sync.Once // guards initialization of TLSNextProto in Serve - nextProtoErr error + nextProtoOnce sync.Once // guards setupHTTP2_* init + nextProtoErr error // result of http2.ConfigureServer if used } // A ConnState represents the state of a client connection to a server. @@ -2260,10 +2260,8 @@ func (srv *Server) Serve(l net.Listener) error { } var tempDelay time.Duration // how long to sleep on accept failure - if srv.shouldConfigureHTTP2ForServe() { - if err := srv.setupHTTP2(); err != nil { - return err - } + if err := srv.setupHTTP2_Serve(); err != nil { + return err } // TODO: allow changing base context? can't imagine concrete @@ -2408,7 +2406,7 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig // before we clone it and create the TLS Listener. - if err := srv.setupHTTP2(); err != nil { + if err := srv.setupHTTP2_ListenAndServeTLS(); err != nil { return err } @@ -2436,14 +2434,36 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error { return srv.Serve(tlsListener) } -func (srv *Server) setupHTTP2() error { +// setupHTTP2_ListenAndServeTLS conditionally configures HTTP/2 on +// srv and returns whether there was an error setting it up. If it is +// not configured for policy reasons, nil is returned. +func (srv *Server) setupHTTP2_ListenAndServeTLS() error { srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults) return srv.nextProtoErr } +// setupHTTP2_Serve is called from (*Server).Serve and conditionally +// configures HTTP/2 on srv using a more conservative policy than +// setupHTTP2_ListenAndServeTLS because Serve may be called +// concurrently. +// +// The tests named TestTransportAutomaticHTTP2* and +// TestConcurrentServerServe in server_test.go demonstrate some +// of the supported use cases and motivations. +func (srv *Server) setupHTTP2_Serve() error { + srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults_Serve) + return srv.nextProtoErr +} + +func (srv *Server) onceSetNextProtoDefaults_Serve() { + if srv.shouldConfigureHTTP2ForServe() { + srv.onceSetNextProtoDefaults() + } +} + // onceSetNextProtoDefaults configures HTTP/2, if the user hasn't // configured otherwise. (by setting srv.TLSNextProto non-nil) -// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2). +// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*). func (srv *Server) onceSetNextProtoDefaults() { if strings.Contains(os.Getenv("GODEBUG"), "http2server=0") { return diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 9164d0d827c..1f0763471b8 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -383,6 +383,11 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { return resp, nil } if !pconn.shouldRetryRequest(req, err) { + // Issue 16465: return underlying net.Conn.Read error from peek, + // as we've historically done. + if e, ok := err.(transportReadFromServerError); ok { + err = e.err + } return nil, err } testHookRoundTripRetried() @@ -393,6 +398,15 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // HTTP request on a new connection. The non-nil input error is the // error from roundTrip. func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { + if err == http2ErrNoCachedConn { + // Issue 16582: if the user started a bunch of + // requests at once, they can all pick the same conn + // and violate the server's max concurrent streams. + // Instead, match the HTTP/1 behavior for now and dial + // again to get a new TCP connection, rather than failing + // this request. + return true + } if err == errMissingHost { // User error. return false @@ -415,11 +429,19 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // first, per golang.org/issue/15723 return false } - if _, ok := err.(nothingWrittenError); ok { + switch err.(type) { + case nothingWrittenError: // We never wrote anything, so it's safe to retry. return true + case transportReadFromServerError: + // We got some non-EOF net.Conn.Read failure reading + // the 1st response byte from the server. + return true } - if err == errServerClosedIdle || err == errServerClosedConn { + if err == errServerClosedIdle { + // The server replied with io.EOF while we were trying to + // read the response. Probably an unfortunately keep-alive + // timeout, just as the client was writing a request. return true } return false // conservatively @@ -476,8 +498,9 @@ func (t *Transport) CloseIdleConnections() { // CancelRequest cancels an in-flight request by closing its connection. // CancelRequest should only be called after RoundTrip has returned. // -// Deprecated: Use Request.Cancel instead. CancelRequest cannot cancel -// HTTP/2 requests. +// Deprecated: Use Request.WithContext to create a request with a +// cancelable context instead. CancelRequest cannot cancel HTTP/2 +// requests. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() cancel := t.reqCanceler[req] @@ -566,10 +589,26 @@ var ( errCloseIdleConns = errors.New("http: CloseIdleConnections called") errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") errServerClosedIdle = errors.New("http: server closed idle connection") - errServerClosedConn = errors.New("http: server closed connection") errIdleConnTimeout = errors.New("http: idle connection timeout") + errNotCachingH2Conn = errors.New("http: not caching alternate protocol's connections") ) +// transportReadFromServerError is used by Transport.readLoop when the +// 1 byte peek read fails and we're actually anticipating a response. +// Usually this is just due to the inherent keep-alive shut down race, +// where the server closed the connection at the same time the client +// wrote. The underlying err field is usually io.EOF or some +// ECONNRESET sort of thing which varies by platform. But it might be +// the user's custom net.Conn.Read error too, so we carry it along for +// them to return from Transport.RoundTrip. +type transportReadFromServerError struct { + err error +} + +func (e transportReadFromServerError) Error() string { + return fmt.Sprintf("net/http: Transport failed to read from server: %v", e.err) +} + func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { if err := t.tryPutIdleConn(pconn); err != nil { pconn.close(err) @@ -595,6 +634,9 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if pconn.isBroken() { return errConnBroken } + if pconn.alt != nil { + return errNotCachingH2Conn + } pconn.markReused() key := pconn.cacheKey @@ -1293,7 +1335,10 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er if pc.isCanceled() { return errRequestCanceled } - if err == errServerClosedIdle || err == errServerClosedConn { + if err == errServerClosedIdle { + return err + } + if _, ok := err.(transportReadFromServerError); ok { return err } if pc.isBroken() { @@ -1314,7 +1359,11 @@ func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) err return errRequestCanceled } err := pc.closed - if err == errServerClosedIdle || err == errServerClosedConn { + if err == errServerClosedIdle { + // Don't decorate + return err + } + if _, ok := err.(transportReadFromServerError); ok { // Don't decorate return err } @@ -1383,7 +1432,7 @@ func (pc *persistConn) readLoop() { if err == nil { resp, err = pc.readResponse(rc, trace) } else { - err = errServerClosedConn + err = transportReadFromServerError{err} closeErr = err } @@ -1784,6 +1833,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err var re responseAndError var respHeaderTimer <-chan time.Time cancelChan := req.Request.Cancel + ctxDoneChan := req.Context().Done() WaitResponse: for { testHookWaitResLoop() @@ -1815,9 +1865,11 @@ WaitResponse: case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil - case <-req.Context().Done(): + ctxDoneChan = nil + case <-ctxDoneChan: pc.t.CancelRequest(req.Request) cancelChan = nil + ctxDoneChan = nil } } diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go index a157d906300..a05ca6ed0d8 100644 --- a/libgo/go/net/http/transport_internal_test.go +++ b/libgo/go/net/http/transport_internal_test.go @@ -46,17 +46,22 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) { conn.Close() // simulate the server hanging up on the client _, err = pc.roundTrip(treq) - if err != errServerClosedConn && err != errServerClosedIdle { + if !isTransportReadFromServerError(err) && err != errServerClosedIdle { t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) } <-pc.closech err = pc.closed - if err != errServerClosedConn && err != errServerClosedIdle { + if !isTransportReadFromServerError(err) && err != errServerClosedIdle { t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) } } +func isTransportReadFromServerError(err error) bool { + _, ok := err.(transportReadFromServerError) + return ok +} + func newLocalListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 72b98f16d7e..298682d04de 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -3511,6 +3511,100 @@ func TestTransportIdleConnTimeout(t *testing.T) { } } +// Issue 16208: Go 1.7 crashed after Transport.IdleConnTimeout if an +// HTTP/2 connection was established but but its caller no longer +// wanted it. (Assuming the connection cache was enabled, which it is +// by default) +// +// This test reproduced the crash by setting the IdleConnTimeout low +// (to make the test reasonable) and then making a request which is +// canceled by the DialTLS hook, which then also waits to return the +// real connection until after the RoundTrip saw the error. Then we +// know the successful tls.Dial from DialTLS will need to go into the +// idle pool. Then we give it a of time to explode. +func TestIdleConnH2Crash(t *testing.T) { + cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing + })) + defer cst.close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + gotErr := make(chan bool, 1) + + cst.tr.IdleConnTimeout = 5 * time.Millisecond + cst.tr.DialTLS = func(network, addr string) (net.Conn, error) { + cancel() + <-gotErr + c, err := tls.Dial(network, addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + }) + if err != nil { + t.Error(err) + return nil, err + } + if cs := c.ConnectionState(); cs.NegotiatedProtocol != "h2" { + t.Errorf("protocol = %q; want %q", cs.NegotiatedProtocol, "h2") + c.Close() + return nil, errors.New("bogus") + } + return c, nil + } + + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(ctx) + res, err := cst.c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + gotErr <- true + + // Wait for the explosion. + time.Sleep(cst.tr.IdleConnTimeout * 10) +} + +type funcConn struct { + net.Conn + read func([]byte) (int, error) + write func([]byte) (int, error) +} + +func (c funcConn) Read(p []byte) (int, error) { return c.read(p) } +func (c funcConn) Write(p []byte) (int, error) { return c.write(p) } +func (c funcConn) Close() error { return nil } + +// Issue 16465: Transport.RoundTrip should return the raw net.Conn.Read error from Peek +// back to the caller. +func TestTransportReturnsPeekError(t *testing.T) { + errValue := errors.New("specific error value") + + wrote := make(chan struct{}) + var wroteOnce sync.Once + + tr := &Transport{ + Dial: func(network, addr string) (net.Conn, error) { + c := funcConn{ + read: func([]byte) (int, error) { + <-wrote + return 0, errValue + }, + write: func(p []byte) (int, error) { + wroteOnce.Do(func() { close(wrote) }) + return len(p), nil + }, + } + return c, nil + }, + } + _, err := tr.RoundTrip(httptest.NewRequest("GET", "http://fake.tld/", nil)) + if err != errValue { + t.Errorf("error = %#v; want %v", err, errValue) + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() |