Browse Source

fixup for transport refactor

Also, use the now standardized transport tests (and fix bugs around closing the
listener).
pull/1422/head
Steven Allen 7 years ago
parent
commit
42780280b4
  1. 42
      p2p/transport/websocket/dialer.go
  2. 31
      p2p/transport/websocket/dialer_test.go
  3. 88
      p2p/transport/websocket/listener.go
  4. 98
      p2p/transport/websocket/websocket.go
  5. 58
      p2p/transport/websocket/websocket_test.go

42
p2p/transport/websocket/dialer.go

@ -1,42 +0,0 @@
package websocket
import (
"context"
ws "github.com/gorilla/websocket"
tpt "github.com/libp2p/go-libp2p-transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
)
type dialer struct{}
func (d *dialer) Dial(raddr ma.Multiaddr) (tpt.Conn, error) {
return d.DialContext(context.Background(), raddr)
}
func (d *dialer) DialContext(ctx context.Context, raddr ma.Multiaddr) (tpt.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}
wscon, _, err := ws.DefaultDialer.Dial(wsurl, nil)
if err != nil {
return nil, err
}
mnc, err := manet.WrapNetConn(NewConn(wscon, nil))
if err != nil {
wscon.Close()
return nil, err
}
return &wsConn{
Conn: mnc,
}, nil
}
func (d *dialer) Matches(a ma.Multiaddr) bool {
return WsFmt.Matches(a)
}

31
p2p/transport/websocket/dialer_test.go

@ -1,31 +0,0 @@
package websocket
import (
"testing"
ma "github.com/multiformats/go-multiaddr"
)
func TestDialerMatches(t *testing.T) {
addrWs, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555/ws")
if err != nil {
t.Fatal(err)
}
addrTcp, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555")
if err != nil {
t.Fatal(err)
}
d := &dialer{}
matchTrue := d.Matches(addrWs)
matchFalse := d.Matches(addrTcp)
if !matchTrue {
t.Fatal("expected to match websocket maddr, but did not")
}
if matchFalse {
t.Fatal("expected to not match tcp maddr, but did")
}
}

88
p2p/transport/websocket/listener.go

@ -3,33 +3,25 @@ package websocket
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
tpt "github.com/libp2p/go-libp2p-transport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
)
type wsConn struct {
manet.Conn
t tpt.Transport
}
var _ tpt.Conn = (*wsConn)(nil)
func (c *wsConn) Transport() tpt.Transport {
return c.t
}
type listener struct {
manet.Listener
net.Listener
incoming chan *Conn
laddr ma.Multiaddr
tpt tpt.Transport
closed chan struct{}
incoming chan *Conn
}
origin *url.URL
func (l *listener) serve() {
defer close(l.closed)
http.Serve(l.Listener, l)
}
func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -40,35 +32,55 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
ctx, cancel := context.WithCancel(context.Background())
l.incoming <- NewConn(c, cancel)
// wait until conn gets closed, otherwise the handler closes it early
<-ctx.Done()
}
func (l *listener) Accept() (tpt.Conn, error) {
c, ok := <-l.incoming
if !ok {
return nil, fmt.Errorf("listener is closed")
var cnCh <-chan bool
if cn, ok := w.(http.CloseNotifier); ok {
cnCh = cn.CloseNotify()
}
mnc, err := manet.WrapNetConn(c)
if err != nil {
wscon := NewConn(c, cancel)
// Just to make sure.
defer wscon.Close()
select {
case l.incoming <- wscon:
case <-l.closed:
c.Close()
return nil, err
return
case <-cnCh:
return
}
return &wsConn{
Conn: mnc,
t: l.tpt,
}, nil
// wait until conn gets closed, otherwise the handler closes it early
select {
case <-ctx.Done():
case <-l.closed:
c.Close()
return
case <-cnCh:
return
}
}
func (l *listener) Multiaddr() ma.Multiaddr {
wsma, err := ma.NewMultiaddr("/ws")
if err != nil {
panic(err)
func (l *listener) Accept() (manet.Conn, error) {
select {
case c, ok := <-l.incoming:
if !ok {
return nil, fmt.Errorf("listener is closed")
}
mnc, err := manet.WrapNetConn(c)
if err != nil {
c.Close()
return nil, err
}
return mnc, nil
case <-l.closed:
return nil, fmt.Errorf("listener is closed")
}
}
return l.Listener.Multiaddr().Encapsulate(wsma)
func (l *listener) Multiaddr() ma.Multiaddr {
return l.laddr
}

98
p2p/transport/websocket/websocket.go

@ -2,15 +2,18 @@
package websocket
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
ws "github.com/gorilla/websocket"
peer "github.com/libp2p/go-libp2p-peer"
tpt "github.com/libp2p/go-libp2p-transport"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr-net"
mafmt "github.com/whyrusleeping/mafmt"
)
@ -50,41 +53,106 @@ func init() {
}
// WebsocketTransport is the actual go-libp2p transport
type WebsocketTransport struct{}
type WebsocketTransport struct {
Upgrader *tptu.Upgrader
}
func New(u *tptu.Upgrader) *WebsocketTransport {
return &WebsocketTransport{u}
}
var _ tpt.Transport = (*WebsocketTransport)(nil)
func (t *WebsocketTransport) Matches(a ma.Multiaddr) bool {
func (t *WebsocketTransport) CanDial(a ma.Multiaddr) bool {
return WsFmt.Matches(a)
}
func (t *WebsocketTransport) Dialer(_ ma.Multiaddr, opts ...tpt.DialOpt) (tpt.Dialer, error) {
return &dialer{}, nil
func (t *WebsocketTransport) Protocols() []int {
return []int{WsProtocol.Code}
}
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) {
list, err := manet.Listen(a)
func (t *WebsocketTransport) Proxy() bool {
return false
}
func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (manet.Conn, error) {
wsurl, err := parseMultiaddr(raddr)
if err != nil {
return nil, err
}
wscon, _, err := ws.DefaultDialer.Dial(wsurl, nil)
if err != nil {
return nil, err
}
mnc, err := manet.WrapNetConn(NewConn(wscon, nil))
if err != nil {
wscon.Close()
return nil, err
}
return mnc, nil
}
func (t *WebsocketTransport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.Conn, error) {
macon, err := t.maDial(ctx, raddr)
if err != nil {
return nil, err
}
return t.Upgrader.UpgradeOutbound(ctx, t, macon, p)
}
func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) {
lnet, lnaddr, err := manet.DialArgs(a)
if err != nil {
return nil, err
}
nl, err := net.Listen(lnet, lnaddr)
if err != nil {
return nil, err
}
u, err := url.Parse("http://" + list.Addr().String())
u, err := url.Parse("http://" + nl.Addr().String())
if err != nil {
nl.Close()
return nil, err
}
tlist := t.wrapListener(list, u)
malist, err := t.wrapListener(nl, u)
if err != nil {
nl.Close()
return nil, err
}
go http.Serve(list.NetListener(), tlist)
go malist.serve()
return tlist, nil
return malist, nil
}
func (t *WebsocketTransport) wrapListener(l manet.Listener, origin *url.URL) *listener {
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) {
malist, err := t.maListen(a)
if err != nil {
return nil, err
}
return t.Upgrader.UpgradeListener(t, malist), nil
}
func (t *WebsocketTransport) wrapListener(l net.Listener, origin *url.URL) (*listener, error) {
laddr, err := manet.FromNetAddr(l.Addr())
if err != nil {
return nil, err
}
wsma, err := ma.NewMultiaddr("/ws")
if err != nil {
return nil, err
}
laddr = laddr.Encapsulate(wsma)
return &listener{
laddr: laddr,
Listener: l,
incoming: make(chan *Conn),
tpt: t,
origin: origin,
}
closed: make(chan struct{}),
}, nil
}

58
p2p/transport/websocket/websocket_test.go

@ -2,14 +2,57 @@ package websocket
import (
"bytes"
"context"
"io"
"io/ioutil"
"testing"
"testing/iotest"
insecure "github.com/libp2p/go-conn-security/insecure"
tptu "github.com/libp2p/go-libp2p-transport-upgrader"
utils "github.com/libp2p/go-libp2p-transport/test"
ma "github.com/multiformats/go-multiaddr"
mplex "github.com/whyrusleeping/go-smux-multiplex"
)
func TestCanDial(t *testing.T) {
addrWs, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555/ws")
if err != nil {
t.Fatal(err)
}
addrTCP, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/5555")
if err != nil {
t.Fatal(err)
}
d := &WebsocketTransport{}
matchTrue := d.CanDial(addrWs)
matchFalse := d.CanDial(addrTCP)
if !matchTrue {
t.Fatal("expected to match websocket maddr, but did not")
}
if matchFalse {
t.Fatal("expected to not match tcp maddr, but did")
}
}
func TestWebsocketTransport(t *testing.T) {
ta := New(&tptu.Upgrader{
Secure: insecure.New("peerA"),
Muxer: new(mplex.Transport),
})
tb := New(&tptu.Upgrader{
Secure: insecure.New("peerB"),
Muxer: new(mplex.Transport),
})
zero := "/ip4/127.0.0.1/tcp/0/ws"
utils.SubtestTransport(t, ta, tb, zero, "peerA")
}
func TestWebsocketListen(t *testing.T) {
zero, err := ma.NewMultiaddr("/ip4/127.0.0.1/tcp/0/ws")
if err != nil {
@ -17,7 +60,7 @@ func TestWebsocketListen(t *testing.T) {
}
tpt := &WebsocketTransport{}
l, err := tpt.Listen(zero)
l, err := tpt.maListen(zero)
if err != nil {
t.Fatal(err)
}
@ -26,8 +69,7 @@ func TestWebsocketListen(t *testing.T) {
msg := []byte("HELLO WORLD")
go func() {
d, _ := tpt.Dialer(nil)
c, err := d.Dial(l.Multiaddr())
c, err := tpt.maDial(context.Background(), l.Multiaddr())
if err != nil {
t.Error(err)
return
@ -62,7 +104,7 @@ func TestConcurrentClose(t *testing.T) {
}
tpt := &WebsocketTransport{}
l, err := tpt.Listen(zero)
l, err := tpt.maListen(zero)
if err != nil {
t.Fatal(err)
}
@ -71,9 +113,8 @@ func TestConcurrentClose(t *testing.T) {
msg := []byte("HELLO WORLD")
go func() {
d, _ := tpt.Dialer(nil)
for i := 0; i < 100; i++ {
c, err := d.Dial(l.Multiaddr())
c, err := tpt.maDial(context.Background(), l.Multiaddr())
if err != nil {
t.Error(err)
return
@ -100,7 +141,7 @@ func TestWriteZero(t *testing.T) {
}
tpt := &WebsocketTransport{}
l, err := tpt.Listen(zero)
l, err := tpt.maListen(zero)
if err != nil {
t.Fatal(err)
}
@ -109,8 +150,7 @@ func TestWriteZero(t *testing.T) {
msg := []byte(nil)
go func() {
d, _ := tpt.Dialer(nil)
c, err := d.Dial(l.Multiaddr())
c, err := tpt.maDial(context.Background(), l.Multiaddr())
defer c.Close()
if err != nil {
t.Error(err)

Loading…
Cancel
Save