diff --git a/ports/unix/modusocket.c b/ports/unix/modusocket.c index 1a073ca035..61402e001d 100644 --- a/ports/unix/modusocket.c +++ b/ports/unix/modusocket.c @@ -37,6 +37,7 @@ #include #include #include +#include #include "py/objtuple.h" #include "py/objstr.h" @@ -65,6 +66,7 @@ typedef struct _mp_obj_socket_t { mp_obj_base_t base; int fd; + bool blocking; } mp_obj_socket_t; const mp_obj_type_t mp_type_socket; @@ -78,6 +80,7 @@ STATIC mp_obj_socket_t *socket_new(int fd) { mp_obj_socket_t *o = m_new_obj(mp_obj_socket_t); o->base.type = &mp_type_socket; o->fd = fd; + o->blocking = true; return o; } @@ -92,12 +95,14 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc mp_obj_socket_t *o = MP_OBJ_TO_PTR(o_in); mp_int_t r = read(o->fd, buf, size); if (r == -1) { - *errcode = errno; - - if (*errcode == EAGAIN) { - *errcode = MP_ETIMEDOUT; + int err = errno; + // On blocking socket, we get EAGAIN in case SO_RCVTIMEO/SO_SNDTIMEO + // timed out, and need to convert that to ETIMEDOUT. + if (err == EAGAIN && o->blocking) { + err = MP_ETIMEDOUT; } + *errcode = err; return MP_STREAM_ERROR; } return r; @@ -107,12 +112,14 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in mp_obj_socket_t *o = MP_OBJ_TO_PTR(o_in); mp_int_t r = write(o->fd, buf, size); if (r == -1) { - *errcode = errno; - - if (*errcode == EAGAIN) { - *errcode = MP_ETIMEDOUT; + int err = errno; + // On blocking socket, we get EAGAIN in case SO_RCVTIMEO/SO_SNDTIMEO + // timed out, and need to convert that to ETIMEDOUT. + if (err == EAGAIN && o->blocking) { + err = MP_ETIMEDOUT; } + *errcode = err; return MP_STREAM_ERROR; } return r; @@ -320,6 +327,7 @@ STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) { } flags = fcntl(self->fd, F_SETFL, flags); RAISE_ERRNO(flags, errno); + self->blocking = val; return mp_const_none; } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); @@ -327,21 +335,37 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); STATIC mp_obj_t socket_settimeout(mp_obj_t self_in, mp_obj_t timeout_in) { mp_obj_socket_t *self = MP_OBJ_TO_PTR(self_in); struct timeval tv = {0,}; + bool new_blocking = true; + if (timeout_in == mp_const_none) { setsockopt(self->fd, SOL_SOCKET, SO_RCVTIMEO, NULL, 0); setsockopt(self->fd, SOL_SOCKET, SO_SNDTIMEO, NULL, 0); } else { - tv.tv_sec = mp_obj_get_int(timeout_in); - #if MICROPY_PY_BUILTINS_FLOAT - tv.tv_usec = (mp_obj_get_float(timeout_in) - tv.tv_sec) * 1000000; + mp_float_t val = mp_obj_get_float(timeout_in); + double ipart; + tv.tv_usec = round(modf(val, &ipart) * 1000000); + tv.tv_sec = ipart; + #else + tv.tv_sec = mp_obj_get_int(timeout_in); #endif - setsockopt(self->fd, SOL_SOCKET, SO_RCVTIMEO, - &tv, sizeof(struct timeval)); - setsockopt(self->fd, SOL_SOCKET, SO_SNDTIMEO, - &tv, sizeof(struct timeval)); + // For SO_RCVTIMEO/SO_SNDTIMEO, zero timeout means infinity, but + // for Python API it means non-blocking. + if (tv.tv_sec == 0 && tv.tv_usec == 0) { + new_blocking = false; + } else { + setsockopt(self->fd, SOL_SOCKET, SO_RCVTIMEO, + &tv, sizeof(struct timeval)); + setsockopt(self->fd, SOL_SOCKET, SO_SNDTIMEO, + &tv, sizeof(struct timeval)); + } } + + if (self->blocking != new_blocking) { + socket_setblocking(self_in, mp_obj_new_bool(new_blocking)); + } + return mp_const_none; } STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_settimeout_obj, socket_settimeout);