diff --git a/py/compile.c b/py/compile.c index adf76fb974..2f6a9a326d 100644 --- a/py/compile.c +++ b/py/compile.c @@ -2996,7 +2996,12 @@ STATIC void compile_scope(compiler_t *comp, scope_t *scope, pass_kind_t pass) { // nodes[2] can be null or a test-expr if (MP_PARSE_NODE_IS_ID(pn_annotation)) { qstr ret_type = MP_PARSE_NODE_LEAF_ARG(pn_annotation); - EMIT_ARG(set_native_type, MP_EMIT_NATIVE_TYPE_RETURN, 0, ret_type); + int native_type = mp_native_type_from_qstr(ret_type); + if (native_type < 0) { + comp->compile_error = mp_obj_new_exception_msg_varg(&mp_type_ViperTypeError, "unknown type '%q'", ret_type); + } else { + scope->scope_flags |= native_type << MP_SCOPE_FLAG_VIPERRET_POS; + } } else { compile_syntax_error(comp, pn_annotation, "return annotation must be an identifier"); } diff --git a/py/emit.h b/py/emit.h index f63bb1d7a5..d511eb259d 100644 --- a/py/emit.h +++ b/py/emit.h @@ -51,7 +51,6 @@ typedef enum { #define MP_EMIT_BREAK_FROM_FOR (0x8000) -#define MP_EMIT_NATIVE_TYPE_RETURN (1) #define MP_EMIT_NATIVE_TYPE_ARG (2) // Kind for emit_id_ops->local() @@ -161,6 +160,8 @@ typedef struct _emit_method_table_t { void (*end_except_handler)(emit_t *emit); } emit_method_table_t; +int mp_native_type_from_qstr(qstr qst); + void mp_emit_common_get_id_for_load(scope_t *scope, qstr qst); void mp_emit_common_get_id_for_modification(scope_t *scope, qstr qst); void mp_emit_common_id_op(emit_t *emit, const mp_emit_method_table_id_ops_t *emit_method_table, scope_t *scope, qstr qst); diff --git a/py/emitnative.c b/py/emitnative.c index 0794b9d502..0301d85b2e 100644 --- a/py/emitnative.c +++ b/py/emitnative.c @@ -128,6 +128,20 @@ typedef enum { VTYPE_BUILTIN_CAST = 0x70 | MP_NATIVE_TYPE_OBJ, } vtype_kind_t; +int mp_native_type_from_qstr(qstr qst) { + switch (qst) { + case MP_QSTR_object: return MP_NATIVE_TYPE_OBJ; + case MP_QSTR_bool: return MP_NATIVE_TYPE_BOOL; + case MP_QSTR_int: return MP_NATIVE_TYPE_INT; + case MP_QSTR_uint: return MP_NATIVE_TYPE_UINT; + case MP_QSTR_ptr: return MP_NATIVE_TYPE_PTR; + case MP_QSTR_ptr8: return MP_NATIVE_TYPE_PTR8; + case MP_QSTR_ptr16: return MP_NATIVE_TYPE_PTR16; + case MP_QSTR_ptr32: return MP_NATIVE_TYPE_PTR32; + default: return -1; + } +} + STATIC qstr vtype_to_qstr(vtype_kind_t vtype) { switch (vtype) { case VTYPE_PYOBJ: return MP_QSTR_object; @@ -169,8 +183,6 @@ struct _emit_t { bool do_viper_types; - vtype_kind_t return_vtype; - mp_uint_t local_vtype_alloc; vtype_kind_t *local_vtype; @@ -224,22 +236,14 @@ void EXPORT_FUN(free)(emit_t *emit) { } STATIC void emit_native_set_native_type(emit_t *emit, mp_uint_t op, mp_uint_t arg1, qstr arg2) { + (void)op; { - vtype_kind_t type; - switch (arg2) { - case MP_QSTR_object: type = VTYPE_PYOBJ; break; - case MP_QSTR_bool: type = VTYPE_BOOL; break; - case MP_QSTR_int: type = VTYPE_INT; break; - case MP_QSTR_uint: type = VTYPE_UINT; break; - case MP_QSTR_ptr: type = VTYPE_PTR; break; - case MP_QSTR_ptr8: type = VTYPE_PTR8; break; - case MP_QSTR_ptr16: type = VTYPE_PTR16; break; - case MP_QSTR_ptr32: type = VTYPE_PTR32; break; - default: EMIT_NATIVE_VIPER_TYPE_ERROR(emit, "unknown type '%q'", arg2); return; + int type = mp_native_type_from_qstr(arg2); + if (type < 0) { + EMIT_NATIVE_VIPER_TYPE_ERROR(emit, "unknown type '%q'", arg2); + return; } - if (op == MP_EMIT_NATIVE_TYPE_RETURN) { - emit->return_vtype = type; - } else { + { assert(arg1 < emit->local_vtype_alloc); emit->local_vtype[arg1] = type; } @@ -267,9 +271,6 @@ STATIC void emit_native_start_pass(emit_t *emit, pass_kind_t pass, scope_t *scop emit->local_vtype_alloc = scope->num_locals; } - // set default type for return - emit->return_vtype = VTYPE_PYOBJ; - // set default type for arguments mp_uint_t num_args = emit->scope->num_pos_args + emit->scope->num_kwonly_args; if (scope->scope_flags & MP_SCOPE_FLAG_VARARGS) { @@ -482,7 +483,7 @@ STATIC void emit_native_end_pass(emit_t *emit) { // compute type signature // note that the lower 4 bits of a vtype are tho correct MP_NATIVE_TYPE_xxx - mp_uint_t type_sig = emit->return_vtype & 0xf; + mp_uint_t type_sig = emit->scope->scope_flags >> MP_SCOPE_FLAG_VIPERRET_POS; for (mp_uint_t i = 0; i < emit->scope->num_pos_args; i++) { type_sig |= (emit->local_vtype[i] & 0xf) << (i * 4 + 4); } @@ -2420,9 +2421,10 @@ STATIC void emit_native_call_method(emit_t *emit, mp_uint_t n_positional, mp_uin STATIC void emit_native_return_value(emit_t *emit) { DEBUG_printf("return_value\n"); if (emit->do_viper_types) { + vtype_kind_t return_vtype = emit->scope->scope_flags >> MP_SCOPE_FLAG_VIPERRET_POS; if (peek_vtype(emit, 0) == VTYPE_PTR_NONE) { emit_pre_pop_discard(emit); - if (emit->return_vtype == VTYPE_PYOBJ) { + if (return_vtype == VTYPE_PYOBJ) { ASM_MOV_REG_IMM(emit->as, REG_RET, (mp_uint_t)mp_const_none); } else { ASM_MOV_REG_IMM(emit->as, REG_RET, 0); @@ -2430,10 +2432,10 @@ STATIC void emit_native_return_value(emit_t *emit) { } else { vtype_kind_t vtype; emit_pre_pop_reg(emit, &vtype, REG_RET); - if (vtype != emit->return_vtype) { + if (vtype != return_vtype) { EMIT_NATIVE_VIPER_TYPE_ERROR(emit, "return expected '%q' but got '%q'", - vtype_to_qstr(emit->return_vtype), vtype_to_qstr(vtype)); + vtype_to_qstr(return_vtype), vtype_to_qstr(vtype)); } } } else { diff --git a/py/runtime0.h b/py/runtime0.h index b47a10ea22..f26b701bf1 100644 --- a/py/runtime0.h +++ b/py/runtime0.h @@ -32,6 +32,7 @@ #define MP_SCOPE_FLAG_GENERATOR (0x04) #define MP_SCOPE_FLAG_DEFKWARGS (0x08) #define MP_SCOPE_FLAG_REFGLOBALS (0x10) // used only if native emitter enabled +#define MP_SCOPE_FLAG_VIPERRET_POS (5) // top 3 bits used for viper return type // types for native (viper) function signature #define MP_NATIVE_TYPE_OBJ (0x00)