From f56473fb013ededdc63021ca80467be784e45358 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Wed, 8 Aug 2018 13:36:01 -0700 Subject: [PATCH 1/2] make sure reset works on half-closed streams --- p2p/net/mock/mock_stream.go | 73 ++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/p2p/net/mock/mock_stream.go b/p2p/net/mock/mock_stream.go index 79a3834c0..6d056c8ea 100644 --- a/p2p/net/mock/mock_stream.go +++ b/p2p/net/mock/mock_stream.go @@ -22,7 +22,7 @@ type stream struct { close chan struct{} closed chan struct{} - state error + writeErr error protocol protocol.ID } @@ -56,7 +56,7 @@ func (s *stream) Write(p []byte) (n int, err error) { t := time.Now().Add(delay) select { case <-s.closed: // bail out if we're closing. - return 0, s.state + return 0, s.writeErr case s.toDeliver <- &transportObject{msg: p, arrivalTime: t}: } return len(p), nil @@ -76,30 +76,28 @@ func (s *stream) Close() error { default: } <-s.closed - if s.state != ErrClosed { - return s.state + if s.writeErr != ErrClosed { + return s.writeErr } return nil } func (s *stream) Reset() error { - // Cancel any pending writes. - s.write.Close() + // Cancel any pending reads/writes with an error. + s.write.CloseWithError(ErrReset) + s.read.CloseWithError(ErrReset) select { case s.reset <- struct{}{}: default: } <-s.closed - if s.state != ErrReset { - return s.state - } + + // No meaningful error case here. return nil } func (s *stream) teardown() { - s.write.Close() - // at this point, no streams are writing. s.conn.removeStream(s) @@ -151,20 +149,21 @@ func (s *stream) transport() { // writeBuf writes the contents of buf through to the s.Writer. // done only when arrival time makes sense. - drainBuf := func() { + drainBuf := func() error { if buf.Len() > 0 { _, err := s.write.Write(buf.Bytes()) if err != nil { - return + return err } buf.Reset() } + return nil } // deliverOrWait is a helper func that processes // an incoming packet. it waits until the arrival time, // and then writes things out. - deliverOrWait := func(o *transportObject) { + deliverOrWait := func(o *transportObject) error { buffered := len(o.msg) + buf.Len() // Yes, we can end up extending a timer multiple times if we @@ -189,43 +188,65 @@ func (s *stream) transport() { select { case <-timer.C: case <-s.reset: - s.reset <- struct{}{} - return + select { + case s.reset <- struct{}{}: + default: + } + return ErrReset + } + if err := drainBuf(); err != nil { + return err } - drainBuf() // write this message. _, err := s.write.Write(o.msg) if err != nil { - log.Error("mock_stream", err) + return err } } else { buf.Write(o.msg) } + return nil } for { // Reset takes precedent. select { case <-s.reset: - s.state = ErrReset - s.read.CloseWithError(ErrReset) + s.writeErr = ErrReset return default: } select { case <-s.reset: - s.state = ErrReset - s.read.CloseWithError(ErrReset) + s.writeErr = ErrReset return case <-s.close: - s.state = ErrClosed - drainBuf() + if err := drainBuf(); err != nil { + s.resetWith(err) + return + } + s.writeErr = s.write.Close() + if s.writeErr == nil { + s.writeErr = ErrClosed + } return case o := <-s.toDeliver: - deliverOrWait(o) + if err := deliverOrWait(o); err != nil { + s.resetWith(err) + return + } case <-timer.C: // ok, due to write it out. - drainBuf() + if err := drainBuf(); err != nil { + s.resetWith(err) + return + } } } } + +func (s *stream) resetWith(err error) { + s.write.CloseWithError(err) + s.read.CloseWithError(err) + s.writeErr = err +} From 784ff9c5a540c42de696c9cc0472b71662cea2f2 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Wed, 8 Aug 2018 13:37:29 -0700 Subject: [PATCH 2/2] fix closing of mock streams in test --- p2p/net/mock/mock_notif_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/p2p/net/mock/mock_notif_test.go b/p2p/net/mock/mock_notif_test.go index 3521d6f86..13758ff75 100644 --- a/p2p/net/mock/mock_notif_test.go +++ b/p2p/net/mock/mock_notif_test.go @@ -114,7 +114,7 @@ func TestNotifications(t *testing.T) { for _, s := range nets { s.SetStreamHandler(func(s inet.Stream) { - s.Close() + inet.FullClose(s) }) } @@ -131,7 +131,7 @@ func TestNotifications(t *testing.T) { for _, s := range nets { s.SetStreamHandler(func(s inet.Stream) { streams <- s - s.Close() + inet.FullClose(s) }) } @@ -146,7 +146,7 @@ func TestNotifications(t *testing.T) { } else { t.Logf("%s %s <--%p--> %s %s", c.LocalPeer(), c.LocalMultiaddr(), st1, c.RemotePeer(), c.RemoteMultiaddr()) // st1.Write([]byte("hello")) - st1.Close() + go inet.FullClose(st1) st2 := <-streams t.Logf("%s %s <--%p--> %s %s", c2.LocalPeer(), c2.LocalMultiaddr(), st2, c2.RemotePeer(), c2.RemoteMultiaddr()) testOCStream(notifiees[i], st1)