Browse Source

fix: don't prefer local ports from other addresses when dialing (#1673)

* fix: don't prefer local ports from other addresses when dialing

This address may already be in-use (on that other address) somewhere
else.

Thanks to @schomatis for figuring this out.

fixes #1611

* chore: document reuseport dialer logic
release-v022
Steven Allen 2 years ago
committed by GitHub
parent
commit
bbd28365c5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 57
      p2p/net/reuseport/dial.go
  2. 114
      p2p/net/reuseport/dialer.go
  3. 90
      p2p/net/reuseport/multidialer.go
  4. 16
      p2p/net/reuseport/singledialer.go
  5. 2
      p2p/net/reuseport/transport.go
  6. 8
      p2p/net/reuseport/transport_test.go

57
p2p/net/reuseport/dial.go

@ -2,18 +2,11 @@ package reuseport
import (
"context"
"net"
"github.com/libp2p/go-reuseport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
)
type dialer interface {
Dial(network, addr string) (net.Conn, error)
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
// Dial dials the given multiaddr, reusing ports we're currently listening on if
// possible.
//
@ -31,7 +24,7 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.
if err != nil {
return nil, err
}
var d dialer
var d *dialer
switch network {
case "tcp4":
d = t.v4.getDialer(network)
@ -52,7 +45,7 @@ func (t *Transport) DialContext(ctx context.Context, raddr ma.Multiaddr) (manet.
return maconn, nil
}
func (n *network) getDialer(network string) dialer {
func (n *network) getDialer(network string) *dialer {
n.mu.RLock()
d := n.dialer
n.mu.RUnlock()
@ -61,53 +54,9 @@ func (n *network) getDialer(network string) dialer {
defer n.mu.Unlock()
if n.dialer == nil {
n.dialer = n.makeDialer(network)
n.dialer = newDialer(n.listeners)
}
d = n.dialer
}
return d
}
func (n *network) makeDialer(network string) dialer {
if !reuseport.Available() {
log.Debug("reuseport not available")
return &net.Dialer{}
}
var unspec net.IP
switch network {
case "tcp4":
unspec = net.IPv4zero
case "tcp6":
unspec = net.IPv6unspecified
default:
panic("invalid network: must be either tcp4 or tcp6")
}
// How many ports are we listening on.
var port = 0
for l := range n.listeners {
newPort := l.Addr().(*net.TCPAddr).Port
switch {
case newPort == 0: // Any port, ignore (really, we shouldn't get this case...).
case port == 0: // Haven't selected a port yet, choose this one.
port = newPort
case newPort == port: // Same as the selected port, continue...
default: // Multiple ports, use the multi dialer
return newMultiDialer(unspec, n.listeners)
}
}
// None.
if port == 0 {
return &net.Dialer{}
}
// One. Always dial from the single port we're listening on.
laddr := &net.TCPAddr{
IP: unspec,
Port: port,
}
return (*singleDialer)(laddr)
}

114
p2p/net/reuseport/dialer.go

@ -0,0 +1,114 @@
package reuseport
import (
"context"
"fmt"
"math/rand"
"net"
"github.com/libp2p/go-netroute"
)
type dialer struct {
// All address that are _not_ loopback or unspecified (0.0.0.0 or ::).
specific []*net.TCPAddr
// All loopback addresses (127.*.*.*, ::1).
loopback []*net.TCPAddr
// Unspecified addresses (0.0.0.0, ::)
unspecified []*net.TCPAddr
}
func (d *dialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func randAddr(addrs []*net.TCPAddr) *net.TCPAddr {
if len(addrs) > 0 {
return addrs[rand.Intn(len(addrs))]
}
return nil
}
// DialContext dials a target addr.
//
// In-order:
//
// 1. If we're _explicitly_ listening on the prefered source address for the destination address
// (per the system's routes), we'll use that listener's port as the source port.
// 2. If we're listening on one or more _unspecified_ addresses (zero address), we'll pick a source
// port from one of these listener's.
// 3. Otherwise, we'll let the system pick the source port.
func (d *dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
// We only check this case if the user is listening on a specific address (loopback or
// otherwise). Generally, users will listen on the "unspecified" address (0.0.0.0 or ::) and
// we can skip this section.
//
// This lets us avoid resolving the address twice, in most cases.
if len(d.specific) > 0 || len(d.loopback) > 0 {
tcpAddr, err := net.ResolveTCPAddr(network, addr)
if err != nil {
return nil, err
}
ip := tcpAddr.IP
if !ip.IsLoopback() && !ip.IsGlobalUnicast() {
return nil, fmt.Errorf("undialable IP: %s", ip)
}
// If we're listening on some specific address and that specific address happens to
// be the preferred source address for the target destination address, we try to
// dial with that address/port.
//
// We skip this check if we _aren't_ listening on any specific addresses, because
// checking routing tables can be expensive and users rarely listen on specific IP
// addresses.
if len(d.specific) > 0 {
if router, err := netroute.New(); err == nil {
if _, _, preferredSrc, err := router.Route(ip); err == nil {
for _, optAddr := range d.specific {
if optAddr.IP.Equal(preferredSrc) {
return reuseDial(ctx, optAddr, network, addr)
}
}
}
}
}
// Otherwise, if we are listening on a loopback address and the destination is also
// a loopback address, use the port from our loopback listener.
if len(d.loopback) > 0 && ip.IsLoopback() {
return reuseDial(ctx, randAddr(d.loopback), network, addr)
}
}
// If we're listening on any uspecified addresses, use a randomly chosen port from one of
// these listeners.
if len(d.unspecified) > 0 {
return reuseDial(ctx, randAddr(d.unspecified), network, addr)
}
// Finally, just pick a random port.
var dialer net.Dialer
return dialer.DialContext(ctx, network, addr)
}
func newDialer(listeners map[*listener]struct{}) *dialer {
specific := make([]*net.TCPAddr, 0)
loopback := make([]*net.TCPAddr, 0)
unspecified := make([]*net.TCPAddr, 0)
for l := range listeners {
addr := l.Addr().(*net.TCPAddr)
if addr.IP.IsLoopback() {
loopback = append(loopback, addr)
} else if addr.IP.IsUnspecified() {
unspecified = append(unspecified, addr)
} else {
specific = append(specific, addr)
}
}
return &dialer{
specific: specific,
loopback: loopback,
unspecified: unspecified,
}
}

90
p2p/net/reuseport/multidialer.go

@ -1,90 +0,0 @@
package reuseport
import (
"context"
"fmt"
"math/rand"
"net"
"github.com/libp2p/go-netroute"
)
type multiDialer struct {
listeningAddresses []*net.TCPAddr
loopback []*net.TCPAddr
unspecified []*net.TCPAddr
fallback net.TCPAddr
}
func (d *multiDialer) Dial(network, addr string) (net.Conn, error) {
return d.DialContext(context.Background(), network, addr)
}
func randAddr(addrs []*net.TCPAddr) *net.TCPAddr {
if len(addrs) > 0 {
return addrs[rand.Intn(len(addrs))]
}
return nil
}
// DialContext dials a target addr.
// Dialing preference is
// * If there is a listener on the local interface the OS expects to use to route towards addr, use that.
// * If there is a listener on a loopback address, addr is loopback, use that.
// * If there is a listener on an undefined address (0.0.0.0 or ::), use that.
// * Use the fallback IP specified during construction, with a port that's already being listened on, if one exists.
func (d *multiDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
tcpAddr, err := net.ResolveTCPAddr(network, addr)
if err != nil {
return nil, err
}
ip := tcpAddr.IP
if !ip.IsLoopback() && !ip.IsGlobalUnicast() {
return nil, fmt.Errorf("undialable IP: %s", ip)
}
if router, err := netroute.New(); err == nil {
if _, _, preferredSrc, err := router.Route(ip); err == nil {
for _, optAddr := range d.listeningAddresses {
if optAddr.IP.Equal(preferredSrc) {
return reuseDial(ctx, optAddr, network, addr)
}
}
}
}
if ip.IsLoopback() && len(d.loopback) > 0 {
return reuseDial(ctx, randAddr(d.loopback), network, addr)
}
if len(d.unspecified) == 0 {
return reuseDial(ctx, &d.fallback, network, addr)
}
return reuseDial(ctx, randAddr(d.unspecified), network, addr)
}
func newMultiDialer(unspec net.IP, listeners map[*listener]struct{}) (m dialer) {
addrs := make([]*net.TCPAddr, 0)
loopback := make([]*net.TCPAddr, 0)
unspecified := make([]*net.TCPAddr, 0)
existingPort := 0
for l := range listeners {
addr := l.Addr().(*net.TCPAddr)
addrs = append(addrs, addr)
if addr.IP.IsLoopback() {
loopback = append(loopback, addr)
} else if addr.IP.IsGlobalUnicast() && existingPort == 0 {
existingPort = addr.Port
} else if addr.IP.IsUnspecified() {
unspecified = append(unspecified, addr)
}
}
m = &multiDialer{
listeningAddresses: addrs,
loopback: loopback,
unspecified: unspecified,
fallback: net.TCPAddr{IP: unspec, Port: existingPort},
}
return
}

16
p2p/net/reuseport/singledialer.go

@ -1,16 +0,0 @@
package reuseport
import (
"context"
"net"
)
type singleDialer net.TCPAddr
func (d *singleDialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func (d *singleDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return reuseDial(ctx, (*net.TCPAddr)(d), network, address)
}

2
p2p/net/reuseport/transport.go

@ -31,5 +31,5 @@ type Transport struct {
type network struct {
mu sync.RWMutex
listeners map[*listener]struct{}
dialer dialer
dialer *dialer
}

8
p2p/net/reuseport/transport_test.go

@ -141,7 +141,6 @@ func TestGlobalPreferenceV4(t *testing.T) {
testPrefer(t, loopbackV4, loopbackV4, globalV4)
t.Logf("when listening on %v, should prefer %v over %v", loopbackV4, unspecV4, globalV4)
testPrefer(t, loopbackV4, unspecV4, globalV4)
t.Logf("when listening on %v, should prefer %v over %v", globalV4, unspecV4, loopbackV4)
testPrefer(t, globalV4, unspecV4, loopbackV4)
}
@ -177,8 +176,6 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) {
}
defer listenerB1.Close()
dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port)
listenerB2, err := trB.Listen(prefer)
if err != nil {
t.Fatal(err)
@ -186,11 +183,6 @@ func testPrefer(t *testing.T, listen, prefer, avoid ma.Multiaddr) {
defer listenerB2.Close()
dialOne(t, &trB, listenerA, listenerB2.Addr().(*net.TCPAddr).Port)
// Closing the listener should reset the dialer.
listenerB2.Close()
dialOne(t, &trB, listenerA, listenerB1.Addr().(*net.TCPAddr).Port)
}
func TestV6V4(t *testing.T) {

Loading…
Cancel
Save