From d8bb57c042ba8080d6bae749dac860be7fa58ee2 Mon Sep 17 00:00:00 2001 From: sukun Date: Wed, 26 Jun 2024 14:02:28 +0530 Subject: [PATCH] metricshelper: improve checks for ip and transport (#2849) --- p2p/metricshelper/conn.go | 15 +++++--- p2p/metricshelper/conn_test.go | 68 ++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 5 deletions(-) create mode 100644 p2p/metricshelper/conn_test.go diff --git a/p2p/metricshelper/conn.go b/p2p/metricshelper/conn.go index b07016ce8..5fd8a114a 100644 --- a/p2p/metricshelper/conn.go +++ b/p2p/metricshelper/conn.go @@ -5,6 +5,9 @@ import ma "github.com/multiformats/go-multiaddr" var transports = [...]int{ma.P_CIRCUIT, ma.P_WEBRTC, ma.P_WEBRTC_DIRECT, ma.P_WEBTRANSPORT, ma.P_QUIC, ma.P_QUIC_V1, ma.P_WSS, ma.P_WS, ma.P_TCP} func GetTransport(a ma.Multiaddr) string { + if a == nil { + return "other" + } for _, t := range transports { if _, err := a.ValueForProtocol(t); err == nil { return ma.ProtocolWithCode(t).Name @@ -15,15 +18,17 @@ func GetTransport(a ma.Multiaddr) string { func GetIPVersion(addr ma.Multiaddr) string { version := "unknown" + if addr == nil { + return version + } ma.ForEach(addr, func(c ma.Component) bool { - if c.Protocol().Code == ma.P_IP4 { + switch c.Protocol().Code { + case ma.P_IP4, ma.P_DNS4: version = "ip4" - return false - } else if c.Protocol().Code == ma.P_IP6 { + case ma.P_IP6, ma.P_DNS6: version = "ip6" - return false } - return true + return false }) return version } diff --git a/p2p/metricshelper/conn_test.go b/p2p/metricshelper/conn_test.go new file mode 100644 index 000000000..0650bfec2 --- /dev/null +++ b/p2p/metricshelper/conn_test.go @@ -0,0 +1,68 @@ +package metricshelper + +import ( + "fmt" + "testing" + + ma "github.com/multiformats/go-multiaddr" +) + +func TestGetTransport(t *testing.T) { + cases := []struct { + addr ma.Multiaddr + result string + }{ + { + addr: ma.StringCast("/ip4/1.1.1.1/tcp/1"), + result: "tcp", + }, + { + addr: ma.StringCast("/ip4/1.1.1.1/udp/10"), + result: "other", + }, + { + addr: nil, + result: "other", + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + got := GetTransport(tc.addr) + if got != tc.result { + t.Fatalf("invalid transport for %s\ngot:%v\nwant:%v", tc.addr, got, tc.result) + } + }) + } +} + +func TestIPVersion(t *testing.T) { + cases := []struct { + addr ma.Multiaddr + result string + }{ + { + addr: ma.StringCast("/ip4/1.1.1.1/tcp/1"), + result: "ip4", + }, + { + addr: ma.StringCast("/ip4/1.1.1.1/udp/10"), + result: "ip4", + }, + { + addr: nil, + result: "unknown", + }, + { + addr: ma.StringCast("/dns/hello.world/tcp/10"), + result: "unknown", + }, + } + for i, tc := range cases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + got := GetIPVersion(tc.addr) + if got != tc.result { + t.Fatalf("invalid ip version for %s\ngot:%v\nwant:%v", tc.addr, got, tc.result) + } + }) + } +}