From cc57bd2663fb742770a56f205488dbb2c757fb8f Mon Sep 17 00:00:00 2001 From: Paul Sokolovsky Date: Sun, 12 Jan 2014 01:55:50 +0200 Subject: [PATCH 1/2] mp_obj_equal(): For non-trivial types, call out to type's special method. --- py/obj.c | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/py/obj.c b/py/obj.c index 81b5c69f7a..2759437fd7 100644 --- a/py/obj.c +++ b/py/obj.c @@ -117,6 +117,13 @@ bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { } else if (MP_OBJ_IS_TYPE(o1, &str_type) && MP_OBJ_IS_TYPE(o2, &str_type)) { return mp_obj_str_get(o1) == mp_obj_str_get(o2); } else { + mp_obj_base_t *o = o1; + if (o->type->binary_op != NULL) { + mp_obj_t r = o->type->binary_op(RT_COMPARE_OP_EQUAL, o1, o2); + if (r != MP_OBJ_NULL) { + return r == mp_const_true ? true : false; + } + } // TODO: Debugging helper printf("Equality for '%s' and '%s' types not yet implemented\n", mp_obj_get_type_str(o1), mp_obj_get_type_str(o2)); assert(0); From 1945e60aeb1099e7cd0ab0ccd9c427c2c68ead99 Mon Sep 17 00:00:00 2001 From: Paul Sokolovsky Date: Sun, 12 Jan 2014 02:01:00 +0200 Subject: [PATCH 2/2] list: Implement comparison operators. --- py/objlist.c | 64 ++++++++++++++++++++++++++++++ tests/basics/tests/list_compare.py | 50 +++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 tests/basics/tests/list_compare.py diff --git a/py/objlist.c b/py/objlist.c index c153d2222b..fa8ec67d09 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -62,6 +62,61 @@ static mp_obj_t list_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args return NULL; } +// Don't pass RT_COMPARE_OP_NOT_EQUAL here +static bool list_cmp_helper(int op, mp_obj_t self_in, mp_obj_t another_in) { + assert(MP_OBJ_IS_TYPE(self_in, &list_type)); + if (!MP_OBJ_IS_TYPE(another_in, &list_type)) { + return false; + } + mp_obj_list_t *self = self_in; + mp_obj_list_t *another = another_in; + if (op == RT_COMPARE_OP_EQUAL && self->len != another->len) { + return false; + } + + // Let's deal only with > & >= + if (op == RT_COMPARE_OP_LESS || op == RT_COMPARE_OP_LESS_EQUAL) { + mp_obj_t t = self; + self = another; + another = t; + if (op == RT_COMPARE_OP_LESS) { + op = RT_COMPARE_OP_MORE; + } else { + op = RT_COMPARE_OP_MORE_EQUAL; + } + } + + int len = self->len < another->len ? self->len : another->len; + bool eq_status = true; // empty lists are equal + bool rel_status; + for (int i = 0; i < len; i++) { + eq_status = mp_obj_equal(self->items[i], another->items[i]); + if (op == RT_COMPARE_OP_EQUAL && !eq_status) { + return false; + } + rel_status = (rt_binary_op(op, self->items[i], another->items[i]) == mp_const_true); + if (!eq_status && !rel_status) { + return false; + } + } + + // If we had tie in the last element... + if (eq_status) { + // ... and we have lists of different lengths... + if (self->len != another->len) { + if (self->len < another->len) { + // ... then longer list length wins (we deal only with >) + return false; + } + } else if (op == RT_COMPARE_OP_MORE) { + // Otherwise, if we have strict relation, equality means failure + return false; + } + } + + return true; +} + static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { mp_obj_list_t *o = lhs; switch (op) { @@ -105,6 +160,15 @@ static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { } return s; } + case RT_COMPARE_OP_EQUAL: + case RT_COMPARE_OP_LESS: + case RT_COMPARE_OP_LESS_EQUAL: + case RT_COMPARE_OP_MORE: + case RT_COMPARE_OP_MORE_EQUAL: + return MP_BOOL(list_cmp_helper(op, lhs, rhs)); + case RT_COMPARE_OP_NOT_EQUAL: + return MP_BOOL(!list_cmp_helper(RT_COMPARE_OP_EQUAL, lhs, rhs)); + default: // op not supported return NULL; diff --git a/tests/basics/tests/list_compare.py b/tests/basics/tests/list_compare.py new file mode 100644 index 0000000000..eea8814247 --- /dev/null +++ b/tests/basics/tests/list_compare.py @@ -0,0 +1,50 @@ +print([] == []) +print([] > []) +print([] < []) +print([] == [1]) +print([1] == []) +print([] > [1]) +print([1] > []) +print([] < [1]) +print([1] < []) +print([] >= [1]) +print([1] >= []) +print([] <= [1]) +print([1] <= []) + +print([1] == [1]) +print([1] != [1]) +print([1] == [2]) +print([1] == [1, 0]) + +print([1] > [1]) +print([1] > [2]) +print([2] > [1]) +print([1, 0] > [1]) +print([1, -1] > [1]) +print([1] > [1, 0]) +print([1] > [1, -1]) + +print([1] < [1]) +print([2] < [1]) +print([1] < [2]) +print([1] < [1, 0]) +print([1] < [1, -1]) +print([1, 0] < [1]) +print([1, -1] < [1]) + +print([1] >= [1]) +print([1] >= [2]) +print([2] >= [1]) +print([1, 0] >= [1]) +print([1, -1] >= [1]) +print([1] >= [1, 0]) +print([1] >= [1, -1]) + +print([1] <= [1]) +print([2] <= [1]) +print([1] <= [2]) +print([1] <= [1, 0]) +print([1] <= [1, -1]) +print([1, 0] <= [1]) +print([1, -1] <= [1])