diff --git a/p2p/transport/webtransport/multiaddr.go b/p2p/transport/webtransport/multiaddr.go index 5560d90f6..9d7999193 100644 --- a/p2p/transport/webtransport/multiaddr.go +++ b/p2p/transport/webtransport/multiaddr.go @@ -1,6 +1,7 @@ package libp2pwebtransport import ( + "errors" "net" ma "github.com/multiformats/go-multiaddr" @@ -13,9 +14,12 @@ var webtransportMA = ma.StringCast("/quic/webtransport") var webtransportMatcher = mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC), mafmt.Base(ma.P_WEBTRANSPORT)) func toWebtransportMultiaddr(na net.Addr) (ma.Multiaddr, error) { - udpMA, err := manet.FromNetAddr(na) + addr, err := manet.FromNetAddr(na) if err != nil { return nil, err } - return udpMA.Encapsulate(webtransportMA), nil + if _, err := addr.ValueForProtocol(ma.P_UDP); err != nil { + return nil, errors.New("not a UDP address") + } + return addr.Encapsulate(webtransportMA), nil } diff --git a/p2p/transport/webtransport/multiaddr_test.go b/p2p/transport/webtransport/multiaddr_test.go new file mode 100644 index 000000000..2098babbd --- /dev/null +++ b/p2p/transport/webtransport/multiaddr_test.go @@ -0,0 +1,21 @@ +package libp2pwebtransport + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestWebtransportMultiaddr(t *testing.T) { + t.Run("valid", func(t *testing.T) { + addr, err := toWebtransportMultiaddr(&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) + require.NoError(t, err) + require.Equal(t, "/ip4/127.0.0.1/udp/1337/quic/webtransport", addr.String()) + }) + + t.Run("invalid", func(t *testing.T) { + _, err := toWebtransportMultiaddr(&net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}) + require.EqualError(t, err, "not a UDP address") + }) +}