diff --git a/net/mock/mock_notif_test.go b/net/mock/mock_notif_test.go index 92a415f16..611f36f55 100644 --- a/net/mock/mock_notif_test.go +++ b/net/mock/mock_notif_test.go @@ -111,10 +111,8 @@ func TestNotifications(t *testing.T) { } } - streams := make(chan inet.Stream) for _, s := range nets { s.SetStreamHandler(func(s inet.Stream) { - streams <- s s.Close() }) } @@ -123,10 +121,17 @@ func TestNotifications(t *testing.T) { // unsure where these are coming from for i := range nets { n := notifiees[i] - testOCStream(n, nil) - testOCStream(n, nil) - testOCStream(n, nil) - testOCStream(n, nil) + for j := 0; j < len(nets)-1; j++ { + testOCStream(n, nil) + } + } + + streams := make(chan inet.Stream) + for _, s := range nets { + s.SetStreamHandler(func(s inet.Stream) { + streams <- s + s.Close() + }) } // open a streams in each conn diff --git a/protocol/identify/id.go b/protocol/identify/id.go index d9cc0b1e0..7c4fc0b73 100644 --- a/protocol/identify/id.go +++ b/protocol/identify/id.go @@ -80,6 +80,8 @@ func (ids *IDService) IdentifyConn(c inet.Conn) { if err != nil { log.Debugf("error opening initial stream for %s", ID) log.Event(context.TODO(), "IdentifyOpenFailed", c.RemotePeer()) + c.Close() + return } else { bwc := ids.Host.GetBandwidthReporter() s = mstream.WrapStream(s, ID, bwc) @@ -88,6 +90,9 @@ func (ids *IDService) IdentifyConn(c inet.Conn) { if err := protocol.WriteHeader(s, ID); err != nil { log.Debugf("error writing stream header for %s", ID) log.Event(context.TODO(), "IdentifyOpenFailed", c.RemotePeer()) + s.Close() + c.Close() + return } ids.ResponseHandler(s) }