Browse Source

Merge pull request #148 from pfalcon/list-cmp

Implement type virtual equality method support and implement comparisons for lists
pull/157/head
Damien George 11 years ago
parent
commit
97eb73cf84
  1. 7
      py/obj.c
  2. 64
      py/objlist.c
  3. 50
      tests/basics/tests/list_compare.py

7
py/obj.c

@ -117,6 +117,13 @@ bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) {
} else if (MP_OBJ_IS_TYPE(o1, &str_type) && MP_OBJ_IS_TYPE(o2, &str_type)) { } else if (MP_OBJ_IS_TYPE(o1, &str_type) && MP_OBJ_IS_TYPE(o2, &str_type)) {
return mp_obj_str_get(o1) == mp_obj_str_get(o2); return mp_obj_str_get(o1) == mp_obj_str_get(o2);
} else { } else {
mp_obj_base_t *o = o1;
if (o->type->binary_op != NULL) {
mp_obj_t r = o->type->binary_op(RT_COMPARE_OP_EQUAL, o1, o2);
if (r != MP_OBJ_NULL) {
return r == mp_const_true ? true : false;
}
}
// TODO: Debugging helper // TODO: Debugging helper
printf("Equality for '%s' and '%s' types not yet implemented\n", mp_obj_get_type_str(o1), mp_obj_get_type_str(o2)); printf("Equality for '%s' and '%s' types not yet implemented\n", mp_obj_get_type_str(o1), mp_obj_get_type_str(o2));
assert(0); assert(0);

64
py/objlist.c

@ -62,6 +62,61 @@ static mp_obj_t list_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args
return NULL; return NULL;
} }
// Don't pass RT_COMPARE_OP_NOT_EQUAL here
static bool list_cmp_helper(int op, mp_obj_t self_in, mp_obj_t another_in) {
assert(MP_OBJ_IS_TYPE(self_in, &list_type));
if (!MP_OBJ_IS_TYPE(another_in, &list_type)) {
return false;
}
mp_obj_list_t *self = self_in;
mp_obj_list_t *another = another_in;
if (op == RT_COMPARE_OP_EQUAL && self->len != another->len) {
return false;
}
// Let's deal only with > & >=
if (op == RT_COMPARE_OP_LESS || op == RT_COMPARE_OP_LESS_EQUAL) {
mp_obj_t t = self;
self = another;
another = t;
if (op == RT_COMPARE_OP_LESS) {
op = RT_COMPARE_OP_MORE;
} else {
op = RT_COMPARE_OP_MORE_EQUAL;
}
}
int len = self->len < another->len ? self->len : another->len;
bool eq_status = true; // empty lists are equal
bool rel_status;
for (int i = 0; i < len; i++) {
eq_status = mp_obj_equal(self->items[i], another->items[i]);
if (op == RT_COMPARE_OP_EQUAL && !eq_status) {
return false;
}
rel_status = (rt_binary_op(op, self->items[i], another->items[i]) == mp_const_true);
if (!eq_status && !rel_status) {
return false;
}
}
// If we had tie in the last element...
if (eq_status) {
// ... and we have lists of different lengths...
if (self->len != another->len) {
if (self->len < another->len) {
// ... then longer list length wins (we deal only with >)
return false;
}
} else if (op == RT_COMPARE_OP_MORE) {
// Otherwise, if we have strict relation, equality means failure
return false;
}
}
return true;
}
static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
mp_obj_list_t *o = lhs; mp_obj_list_t *o = lhs;
switch (op) { switch (op) {
@ -105,6 +160,15 @@ static mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
} }
return s; return s;
} }
case RT_COMPARE_OP_EQUAL:
case RT_COMPARE_OP_LESS:
case RT_COMPARE_OP_LESS_EQUAL:
case RT_COMPARE_OP_MORE:
case RT_COMPARE_OP_MORE_EQUAL:
return MP_BOOL(list_cmp_helper(op, lhs, rhs));
case RT_COMPARE_OP_NOT_EQUAL:
return MP_BOOL(!list_cmp_helper(RT_COMPARE_OP_EQUAL, lhs, rhs));
default: default:
// op not supported // op not supported
return NULL; return NULL;

50
tests/basics/tests/list_compare.py

@ -0,0 +1,50 @@
print([] == [])
print([] > [])
print([] < [])
print([] == [1])
print([1] == [])
print([] > [1])
print([1] > [])
print([] < [1])
print([1] < [])
print([] >= [1])
print([1] >= [])
print([] <= [1])
print([1] <= [])
print([1] == [1])
print([1] != [1])
print([1] == [2])
print([1] == [1, 0])
print([1] > [1])
print([1] > [2])
print([2] > [1])
print([1, 0] > [1])
print([1, -1] > [1])
print([1] > [1, 0])
print([1] > [1, -1])
print([1] < [1])
print([2] < [1])
print([1] < [2])
print([1] < [1, 0])
print([1] < [1, -1])
print([1, 0] < [1])
print([1, -1] < [1])
print([1] >= [1])
print([1] >= [2])
print([2] >= [1])
print([1, 0] >= [1])
print([1, -1] >= [1])
print([1] >= [1, 0])
print([1] >= [1, -1])
print([1] <= [1])
print([2] <= [1])
print([1] <= [2])
print([1] <= [1, 0])
print([1] <= [1, -1])
print([1, 0] <= [1])
print([1, -1] <= [1])
Loading…
Cancel
Save