diff --git a/src/backend.rs b/src/backend.rs index 18cf1f1468..e6c59c7697 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1965,6 +1965,55 @@ impl Context<'_, M> { } } + pub fn select(&mut self) { + let cond = self.pop(); + let else_ = self.pop(); + let then = self.pop(); + + match cond { + Value::Immediate(i) => { + if i == 0 { + self.push(else_); + } else { + self.push(then); + } + + return; + } + other => { + let (reg, free) = self.into_reg(other); + + dynasm!(self.asm + ; test Rd(reg), Rd(reg) + ); + + if free { + self.block_state.regs.release_scratch_gpr(reg); + } + } + } + + let out = self.block_state.regs.take_scratch_gpr(); + + // TODO: Can do this better for variables on stack + let (reg, free) = self.into_reg(else_); + dynasm!(self.asm + ; cmovz Rq(out), Rq(reg) + ); + if free { + self.block_state.regs.release_scratch_gpr(reg); + } + let (reg, free) = self.into_reg(then); + dynasm!(self.asm + ; cmovnz Rq(out), Rq(reg) + ); + if free { + self.block_state.regs.release_scratch_gpr(reg); + } + + self.push(Value::Temp(out)); + } + // TODO: This is wildly unsound, we don't actually check if the // local was written first. Would be fixed by Microwasm. pub fn get_local(&mut self, local_idx: u32) { diff --git a/src/function_body.rs b/src/function_body.rs index 4c994ca0e3..448f3efcd8 100644 --- a/src/function_body.rs +++ b/src/function_body.rs @@ -460,6 +460,9 @@ pub fn translate( Operator::I64Load { memarg } => ctx.i64_load(memarg.offset)?, Operator::I32Store { memarg } => ctx.i32_store(memarg.offset)?, Operator::I64Store { memarg } => ctx.i64_store(memarg.offset)?, + Operator::Select => { + ctx.select(); + } Operator::Call { function_index } => { let callee_ty = session.module_context.func_type(function_index); diff --git a/src/tests.rs b/src/tests.rs index b63cd250d8..35ec12bf87 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -101,7 +101,7 @@ mod op32 { } fn lit(a: u32) -> bool { - let translated = translate_wat(&format!(concat!(" + let translated = translate_wat(&format!(concat!(" (module (func (result i32) (i32.",stringify!($name)," (i32.const {val})))) "), val = a)); @@ -153,6 +153,7 @@ mod op64 { ($op:ident, $func:expr, $retty:ident) => { mod $op { use super::{translate_wat, ExecutableModule}; + use std::sync::Once; const RETTY: &str = stringify!($retty); const OP: &str = stringify!($op); @@ -226,7 +227,7 @@ mod op64 { } fn lit(a: u64) -> bool { - let translated = translate_wat(&format!(concat!(" + let translated = translate_wat(&format!(concat!(" (module (func (result ",stringify!($out_ty),") (i64.",stringify!($name)," (i64.const {val})))) "), val = a)); @@ -1060,6 +1061,53 @@ fn call_indirect() { ); } +macro_rules! test_select { + ($name:ident, $ty:ident) => { + mod $name { + use super::{translate_wat, ExecutableModule}; + use std::sync::Once; + + lazy_static! { + static ref AS_PARAMS: ExecutableModule = translate_wat(&format!(" + (module + (func (param {ty}) (param {ty}) (param i32) (result {ty}) + (select (get_local 0) (get_local 1) (get_local 2)) + ) + )", + ty = stringify!($ty) + )); + } + + quickcheck! { + fn as_param(cond: bool, then: $ty, else_: $ty) -> bool { + let icond: i32 = if cond { 1 } else { 0 }; + AS_PARAMS.execute_func::<($ty, $ty, i32), $ty>(0, (then, else_, icond)) == + Ok(if cond { then } else { else_ }) + } + + fn lit(cond: bool, then: $ty, else_: $ty) -> bool { + let icond: i32 = if cond { 1 } else { 0 }; + let translated = translate_wat(&format!(" + (module (func (param {ty}) (param {ty}) (result {ty}) + (select (get_local 0) (get_local 1) (i32.const {val})))) + ", + val = icond, + ty = stringify!($ty) + )); + static ONCE: Once = Once::new(); + ONCE.call_once(|| translated.disassemble()); + + translated.execute_func::<($ty, $ty), $ty>(0, (then, else_)) == + Ok(if cond { then } else { else_ }) + } + } + } + } +} + +test_select!(select32, i32); +test_select!(select64, i64); + #[bench] fn bench_fibonacci_compile(b: &mut test::Bencher) { let wasm = wabt::wat2wasm(FIBONACCI).unwrap(); @@ -1087,120 +1135,3 @@ fn bench_fibonacci_baseline(b: &mut test::Bencher) { b.iter(|| test::black_box(fib(test::black_box(20)))); } - -#[test] -fn test_recursive_factorial() { - let code = r#" -(module - (func (export "fac-rec") (param i64) (result i64) - (if (result i64) (i64.eq (get_local 0) (i64.const 0)) - (then (i64.const 1)) - (else - (i64.mul (get_local 0) (call 0 (i64.sub (get_local 0) (i64.const 1)))) - ) - ) - ) -) -"#; - - assert_eq!(translate_wat(code).execute_func::<_, u64>(0, (25u64,)).unwrap(), 7034535277573963776u64); -} -#[test] -fn test_recursive_factorial_named() { - let code = r#" -(module - (func $fac-rec-named (export "fac-rec-named") (param $n i64) (result i64) - (if (result i64) (i64.eq (get_local $n) (i64.const 0)) - (then (i64.const 1)) - (else - (i64.mul - (get_local $n) - (call $fac-rec-named (i64.sub (get_local $n) (i64.const 1))) - ) - ) - ) - ) -) -"#; - - assert_eq!(translate_wat(code).execute_func::<_, u64>(0, (25u64,)).unwrap(), 7034535277573963776u64); -} -#[test] -fn test_iterative_factorial() { - let code = r#" -(module - (func (export "fac-iter") (param i64) (result i64) - (local i64 i64) - (set_local 1 (get_local 0)) - (set_local 2 (i64.const 1)) - (block - (loop - (if - (i64.eq (get_local 1) (i64.const 0)) - (then (br 2)) - (else - (set_local 2 (i64.mul (get_local 1) (get_local 2))) - (set_local 1 (i64.sub (get_local 1) (i64.const 1))) - ) - ) - (br 0) - ) - ) - (get_local 2) - ) -) -"#; - - assert_eq!(translate_wat(code).execute_func::<_, u64>(0, (25u64,)).unwrap(), 7034535277573963776u64); -} -#[test] -fn test_iterative_factorial_named() { - let code = r#" -(module - (func (export "fac-iter-named") (param $n i64) (result i64) - (local $i i64) - (local $res i64) - (set_local $i (get_local $n)) - (set_local $res (i64.const 1)) - (block $done - (loop $loop - (if - (i64.eq (get_local $i) (i64.const 0)) - (then (br $done)) - (else - (set_local $res (i64.mul (get_local $i) (get_local $res))) - (set_local $i (i64.sub (get_local $i) (i64.const 1))) - ) - ) - (br $loop) - ) - ) - (get_local $res) - ) -) -"#; - - assert_eq!(translate_wat(code).execute_func::<_, u64>(0, (25u64,)).unwrap(), 7034535277573963776u64); -} -#[test] -fn test_optimized_factorial() { - let code = r#" -(module - (func (export "fac-opt") (param i64) (result i64) - (local i64) - (set_local 1 (i64.const 1)) - (block - (br_if 0 (i64.lt_s (get_local 0) (i64.const 2))) - (loop - (set_local 1 (i64.mul (get_local 1) (get_local 0))) - (set_local 0 (i64.add (get_local 0) (i64.const -1))) - (br_if 0 (i64.gt_s (get_local 0) (i64.const 1))) - ) - ) - (get_local 1) - ) -) -"#; - - assert_eq!(translate_wat(code).execute_func::<_, u64>(0, (25u64,)).unwrap(), 7034535277573963776u64); -}