diff --git a/extmod/modlwip.c b/extmod/modlwip.c index 33546a6324..5cc7bbf817 100644 --- a/extmod/modlwip.c +++ b/extmod/modlwip.c @@ -272,7 +272,15 @@ typedef struct _lwip_socket_obj_t { } pcb; volatile union { struct pbuf *pbuf; - struct tcp_pcb *connection; + struct { + uint8_t alloc; + uint8_t iget; + uint8_t iput; + union { + struct tcp_pcb *item; // if alloc == 0 + struct tcp_pcb **array; // if alloc != 0 + } tcp; + } connection; } incoming; mp_obj_t callback; byte peer[4]; @@ -371,13 +379,19 @@ STATIC err_t _lwip_tcp_accept(void *arg, struct tcp_pcb *newpcb, err_t err) { lwip_socket_obj_t *socket = (lwip_socket_obj_t*)arg; tcp_recv(newpcb, _lwip_tcp_recv_unaccepted); - if (socket->incoming.connection != NULL) { - DEBUG_printf("_lwip_tcp_accept: Tried to queue >1 pcb waiting for accept\n"); - // We need to handle this better. This single-level structure makes the - // backlog setting kind of pointless. FIXME - return ERR_BUF; + // Search for an empty slot to store the new connection + struct tcp_pcb *volatile *tcp_array; + if (socket->incoming.connection.alloc == 0) { + tcp_array = &socket->incoming.connection.tcp.item; } else { - socket->incoming.connection = newpcb; + tcp_array = socket->incoming.connection.tcp.array; + } + if (tcp_array[socket->incoming.connection.iput] == NULL) { + // Have an empty slot to store waiting connection + tcp_array[socket->incoming.connection.iput] = newpcb; + if (++socket->incoming.connection.iput >= socket->incoming.connection.alloc) { + socket->incoming.connection.iput = 0; + } if (socket->callback != MP_OBJ_NULL) { // Schedule accept callback to be called when lwIP is done // with processing this incoming connection on its side and @@ -386,6 +400,9 @@ STATIC err_t _lwip_tcp_accept(void *arg, struct tcp_pcb *newpcb, err_t err) { } return ERR_OK; } + + DEBUG_printf("_lwip_tcp_accept: No room to queue pcb waiting for accept\n"); + return ERR_BUF; } // Callback for inbound tcp packets. @@ -643,8 +660,15 @@ STATIC mp_obj_t lwip_socket_make_new(const mp_obj_type_t *type, size_t n_args, s } switch (socket->type) { - case MOD_NETWORK_SOCK_STREAM: socket->pcb.tcp = tcp_new(); break; - case MOD_NETWORK_SOCK_DGRAM: socket->pcb.udp = udp_new(); break; + case MOD_NETWORK_SOCK_STREAM: + socket->pcb.tcp = tcp_new(); + socket->incoming.connection.alloc = 0; + socket->incoming.connection.tcp.item = NULL; + break; + case MOD_NETWORK_SOCK_DGRAM: + socket->pcb.udp = udp_new(); + socket->incoming.pbuf = NULL; + break; //case MOD_NETWORK_SOCK_RAW: socket->pcb.raw = raw_new(); break; default: mp_raise_OSError(MP_EINVAL); } @@ -669,7 +693,6 @@ STATIC mp_obj_t lwip_socket_make_new(const mp_obj_type_t *type, size_t n_args, s } } - socket->incoming.pbuf = NULL; socket->timeout = -1; socket->state = STATE_NEW; socket->recv_offset = 0; @@ -721,6 +744,18 @@ STATIC mp_obj_t lwip_socket_listen(mp_obj_t self_in, mp_obj_t backlog_in) { mp_raise_OSError(MP_ENOMEM); } socket->pcb.tcp = new_pcb; + + // Allocate memory for the backlog of connections + if (backlog <= 1) { + socket->incoming.connection.alloc = 0; + socket->incoming.connection.tcp.item = NULL; + } else { + socket->incoming.connection.alloc = backlog; + socket->incoming.connection.tcp.array = m_new0(struct tcp_pcb*, backlog); + } + socket->incoming.connection.iget = 0; + socket->incoming.connection.iput = 0; + tcp_accept(new_pcb, _lwip_tcp_accept); // Socket is no longer considered "new" for purposes of polling @@ -746,19 +781,25 @@ STATIC mp_obj_t lwip_socket_accept(mp_obj_t self_in) { } // accept incoming connection - if (socket->incoming.connection == NULL) { + struct tcp_pcb *volatile *incoming_connection; + if (socket->incoming.connection.alloc == 0) { + incoming_connection = &socket->incoming.connection.tcp.item; + } else { + incoming_connection = &socket->incoming.connection.tcp.array[socket->incoming.connection.iget]; + } + if (*incoming_connection == NULL) { if (socket->timeout == 0) { mp_raise_OSError(MP_EAGAIN); } else if (socket->timeout != -1) { for (mp_uint_t retries = socket->timeout / 100; retries--;) { mp_hal_delay_ms(100); - if (socket->incoming.connection != NULL) break; + if (*incoming_connection != NULL) break; } - if (socket->incoming.connection == NULL) { + if (*incoming_connection == NULL) { mp_raise_OSError(MP_ETIMEDOUT); } } else { - while (socket->incoming.connection == NULL) { + while (*incoming_connection == NULL) { poll_sockets(); } } @@ -769,8 +810,11 @@ STATIC mp_obj_t lwip_socket_accept(mp_obj_t self_in) { socket2->base.type = &lwip_socket_type; // We get a new pcb handle... - socket2->pcb.tcp = socket->incoming.connection; - socket->incoming.connection = NULL; + socket2->pcb.tcp = *incoming_connection; + if (++socket->incoming.connection.iget >= socket->incoming.connection.alloc) { + socket->incoming.connection.iget = 0; + } + *incoming_connection = NULL; // ...and set up the new socket for it. socket2->domain = MOD_NETWORK_AF_INET; @@ -1222,15 +1266,27 @@ STATIC mp_uint_t lwip_socket_ioctl(mp_obj_t self_in, mp_uint_t request, uintptr_ } socket->pcb.tcp = NULL; socket->state = _ERR_BADF; - if (socket->incoming.pbuf != NULL) { - if (!socket_is_listener) { + if (!socket_is_listener) { + if (socket->incoming.pbuf != NULL) { pbuf_free(socket->incoming.pbuf); + socket->incoming.pbuf = NULL; + } + } else { + uint8_t alloc = socket->incoming.connection.alloc; + struct tcp_pcb *volatile *tcp_array; + if (alloc == 0) { + tcp_array = &socket->incoming.connection.tcp.item; } else { + tcp_array = socket->incoming.connection.tcp.array; + } + for (uint8_t i = 0; i < alloc; ++i) { // Deregister callback and abort - tcp_poll(socket->incoming.connection, NULL, 0); - tcp_abort(socket->incoming.connection); + if (tcp_array[i] != NULL) { + tcp_poll(tcp_array[i], NULL, 0); + tcp_abort(tcp_array[i]); + tcp_array[i] = NULL; + } } - socket->incoming.pbuf = NULL; } ret = 0;