From 536dde254be99e19700a0934af38b913256475e3 Mon Sep 17 00:00:00 2001 From: Damien George Date: Thu, 13 Mar 2014 22:07:55 +0000 Subject: [PATCH] py: In string.count, handle case of zero-length needle. --- py/objstr.c | 14 +++++++------- tests/basics/string_count.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/py/objstr.c b/py/objstr.c index 6a2625b621..64ba6c5fad 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -495,8 +495,8 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) { GET_STR_DATA_LEN(args[0], haystack, haystack_len); GET_STR_DATA_LEN(args[1], needle, needle_len); - size_t start = 0; - size_t end = haystack_len; + machine_uint_t start = 0; + machine_uint_t end = haystack_len; /* TODO use a non-exception-throwing mp_get_index */ if (n_args >= 3 && args[2] != mp_const_none) { start = mp_get_index(&str_type, haystack_len, args[2], true); @@ -505,13 +505,13 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) { end = mp_get_index(&str_type, haystack_len, args[3], true); } - machine_int_t num_occurrences = 0; - - // needle won't exist in haystack if it's longer, so nothing to count - if (needle_len > haystack_len) { - MP_OBJ_NEW_SMALL_INT(0); + // if needle_len is zero then we count each gap between characters as an occurrence + if (needle_len == 0) { + return MP_OBJ_NEW_SMALL_INT(end - start + 1); } + // count the occurrences + machine_int_t num_occurrences = 0; for (machine_uint_t haystack_index = start; haystack_index + needle_len <= end; haystack_index++) { if (memcmp(&haystack[haystack_index], needle, needle_len) == 0) { num_occurrences++; diff --git a/tests/basics/string_count.py b/tests/basics/string_count.py index bac99e78d8..0da1b1fcae 100644 --- a/tests/basics/string_count.py +++ b/tests/basics/string_count.py @@ -1,3 +1,29 @@ +print("".count("")) +print("".count("a")) +print("a".count("")) +print("a".count("a")) +print("a".count("b")) +print("b".count("a")) + +print("aaa".count("")) +print("aaa".count("a")) +print("aaa".count("aa")) +print("aaa".count("aaa")) +print("aaa".count("aaaa")) + +print("aaaa".count("")) +print("aaaa".count("a")) +print("aaaa".count("aa")) +print("aaaa".count("aaa")) +print("aaaa".count("aaaa")) +print("aaaa".count("aaaaa")) + +print("aaa".count("", 1)) +print("aaa".count("", 2)) +print("aaa".count("", 3)) + +print("aaa".count("", 1, 2)) + print("asdfasdfaaa".count("asdf", -100)) print("asdfasdfaaa".count("asdf", -8)) print("asdf".count('s', True))