|
|
@ -6,7 +6,9 @@ import ( |
|
|
|
"net" |
|
|
|
"net/http" |
|
|
|
"net/url" |
|
|
|
"time" |
|
|
|
|
|
|
|
wsGorilla "github.com/gorilla/websocket" |
|
|
|
tpt "github.com/libp2p/go-libp2p-transport" |
|
|
|
ma "github.com/multiformats/go-multiaddr" |
|
|
|
manet "github.com/multiformats/go-multiaddr-net" |
|
|
@ -29,6 +31,9 @@ var WsCodec = &manet.NetCodec{ |
|
|
|
ParseNetAddr: ParseWebsocketNetAddr, |
|
|
|
} |
|
|
|
|
|
|
|
// Default gorilla upgrader
|
|
|
|
var upgrader = wsGorilla.Upgrader{} |
|
|
|
|
|
|
|
func init() { |
|
|
|
err := ma.AddProtocol(WsProtocol) |
|
|
|
if err != nil { |
|
|
@ -107,12 +112,15 @@ func (d *dialer) DialContext(ctx context.Context, raddr ma.Multiaddr) (tpt.Conn, |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
wscon, err := ws.Dial(wsurl, "", "http://127.0.0.1:0/") |
|
|
|
// TODO: figure out origins, probably don't work for us
|
|
|
|
// header := http.Header{}
|
|
|
|
// header.Set("Origin", "http://127.0.0.1:0/")
|
|
|
|
wscon, _, err := wsGorilla.DefaultDialer.Dial(wsurl, nil) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
mnc, err := manet.WrapNetConn(wscon) |
|
|
|
mnc, err := manet.WrapNetConn(NewGorillaNetConn(wscon)) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
@ -141,17 +149,19 @@ type listener struct { |
|
|
|
incoming chan *conn |
|
|
|
|
|
|
|
tpt tpt.Transport |
|
|
|
|
|
|
|
origin *url.URL |
|
|
|
} |
|
|
|
|
|
|
|
type conn struct { |
|
|
|
*ws.Conn |
|
|
|
*GorillaNetConn |
|
|
|
|
|
|
|
done func() |
|
|
|
} |
|
|
|
|
|
|
|
func (c *conn) Close() error { |
|
|
|
c.done() |
|
|
|
return c.Conn.Close() |
|
|
|
return c.GorillaNetConn.Close() |
|
|
|
} |
|
|
|
|
|
|
|
func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) { |
|
|
@ -160,38 +170,41 @@ func (t *WebsocketTransport) Listen(a ma.Multiaddr) (tpt.Listener, error) { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
tlist := t.wrapListener(list) |
|
|
|
|
|
|
|
u, err := url.Parse("ws://" + list.Addr().String()) |
|
|
|
u, err := url.Parse("http://" + list.Addr().String()) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
|
|
|
|
|
s := &ws.Server{ |
|
|
|
Handler: tlist.handleWsConn, |
|
|
|
Config: ws.Config{Origin: u}, |
|
|
|
} |
|
|
|
tlist := t.wrapListener(list, u) |
|
|
|
|
|
|
|
go http.Serve(list.NetListener(), s) |
|
|
|
http.HandleFunc("/", tlist.handleWsConn) |
|
|
|
go http.Serve(list.NetListener(), nil) |
|
|
|
|
|
|
|
return tlist, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (t *WebsocketTransport) wrapListener(l manet.Listener) *listener { |
|
|
|
func (t *WebsocketTransport) wrapListener(l manet.Listener, origin *url.URL) *listener { |
|
|
|
return &listener{ |
|
|
|
Listener: l, |
|
|
|
incoming: make(chan *conn), |
|
|
|
tpt: t, |
|
|
|
origin: origin, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (l *listener) handleWsConn(s *ws.Conn) { |
|
|
|
func (l *listener) handleWsConn(w http.ResponseWriter, r *http.Request) { |
|
|
|
c, err := upgrader.Upgrade(w, r, nil) |
|
|
|
if err != nil { |
|
|
|
http.Error(w, "Failed to upgrade websocket", 400) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
|
|
s.PayloadType = ws.BinaryFrame |
|
|
|
|
|
|
|
wrapped := NewGorillaNetConn(c) |
|
|
|
l.incoming <- &conn{ |
|
|
|
Conn: s, |
|
|
|
done: cancel, |
|
|
|
GorillaNetConn: &wrapped, |
|
|
|
done: cancel, |
|
|
|
} |
|
|
|
|
|
|
|
// wait until conn gets closed, otherwise the handler closes it early
|
|
|
@ -225,3 +238,62 @@ func (l *listener) Multiaddr() ma.Multiaddr { |
|
|
|
} |
|
|
|
|
|
|
|
var _ tpt.Transport = (*WebsocketTransport)(nil) |
|
|
|
|
|
|
|
type GorillaNetConn struct { |
|
|
|
Inner *wsGorilla.Conn |
|
|
|
DefaultMessageType int |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) Read(b []byte) (n int, err error) { |
|
|
|
fmt.Println("reading") |
|
|
|
_, r, err := c.Inner.NextReader() |
|
|
|
if err != nil { |
|
|
|
return 0, err |
|
|
|
} |
|
|
|
|
|
|
|
return r.Read(b) |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) Write(b []byte) (n int, err error) { |
|
|
|
fmt.Printf("write %s\n", string(b)) |
|
|
|
if err := c.Inner.WriteMessage(c.DefaultMessageType, b); err != nil { |
|
|
|
return 0, err |
|
|
|
} |
|
|
|
|
|
|
|
return len(b), nil |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) Close() error { |
|
|
|
return c.Inner.Close() |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) LocalAddr() net.Addr { |
|
|
|
return c.Inner.LocalAddr() |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) RemoteAddr() net.Addr { |
|
|
|
return c.Inner.RemoteAddr() |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) SetDeadline(t time.Time) error { |
|
|
|
if err := c.SetReadDeadline(t); err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
return c.SetReadDeadline(t) |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) SetReadDeadline(t time.Time) error { |
|
|
|
return c.Inner.SetReadDeadline(t) |
|
|
|
} |
|
|
|
|
|
|
|
func (c GorillaNetConn) SetWriteDeadline(t time.Time) error { |
|
|
|
return c.Inner.SetWriteDeadline(t) |
|
|
|
} |
|
|
|
|
|
|
|
func NewGorillaNetConn(raw *wsGorilla.Conn) GorillaNetConn { |
|
|
|
return GorillaNetConn{ |
|
|
|
Inner: raw, |
|
|
|
DefaultMessageType: wsGorilla.BinaryMessage, |
|
|
|
} |
|
|
|
} |
|
|
|