From c9f9bcdd04e35e407e843124b81b3de84685f874 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sa=C3=BAl=20Cabrera?=
Date: Wed, 10 Apr 2024 13:13:47 -0400
Subject: [PATCH] winch: Refactor the MacroAssembler trait (#8320)

No new functionality is introduced as part of this change.

This commit introduces a small refactoring of Winch's `MacroAssembler`
trait by removing multiple methods that can be treated in an
ISA-independent way by relying on the base building blocks already
provided by the `MacroAssembler`.

This change simplifies the process of adding support for other backends
to Winch.
---
 winch/codegen/src/codegen/env.rs      |   9 ++
 winch/codegen/src/codegen/mod.rs      | 138 +++++++++++++++++++++++--
 winch/codegen/src/isa/aarch64/masm.rs |  20 +---
 winch/codegen/src/isa/x64/masm.rs     | 140 +-------------------------
 winch/codegen/src/masm.rs             |  18 +---
 winch/codegen/src/visitor.rs          |  13 +--
 6 files changed, 149 insertions(+), 189 deletions(-)

diff --git a/winch/codegen/src/codegen/env.rs b/winch/codegen/src/codegen/env.rs
index 63f1d31a2c..377bbcb726 100644
--- a/winch/codegen/src/codegen/env.rs
+++ b/winch/codegen/src/codegen/env.rs
@@ -127,6 +127,8 @@ pub struct FuncEnv<'a, 'translation: 'a, 'data: 'translation, P: PtrSize> {
     ptr_type: WasmValType,
     /// Whether or not to enable Spectre mitigation on heap bounds checks.
     heap_access_spectre_mitigation: bool,
+    /// Whether or not to enable Spectre mitigation on table element accesses.
+    table_access_spectre_mitigation: bool,
     name_map: PrimaryMap,
     name_intern: HashMap,
 }
@@ -158,6 +160,7 @@ impl<'a, 'translation, 'data, P: PtrSize> FuncEnv<'a, 'translation, 'data, P> {
             resolved_globals: HashMap::new(),
             ptr_type,
             heap_access_spectre_mitigation: isa.flags().enable_heap_access_spectre_mitigation(),
+            table_access_spectre_mitigation: isa.flags().enable_table_access_spectre_mitigation(),
             builtins,
             name_map: Default::default(),
             name_intern: Default::default(),
@@ -314,6 +317,12 @@ impl<'a, 'translation, 'data, P: PtrSize> FuncEnv<'a, 'translation, 'data, P> {
         self.heap_access_spectre_mitigation
     }
 
+    /// Returns true if Spectre mitigations are enabled for table element
+    /// accesses.
+    pub fn table_access_spectre_mitigation(&self) -> bool {
+        self.table_access_spectre_mitigation
+    }
+
     pub(crate) fn callee_sig<'b, A>(&'b mut self, callee: &'b Callee) -> &'b ABISig
     where
         A: ABI,
diff --git a/winch/codegen/src/codegen/mod.rs b/winch/codegen/src/codegen/mod.rs
index 1ba4e98676..3017dc8120 100644
--- a/winch/codegen/src/codegen/mod.rs
+++ b/winch/codegen/src/codegen/mod.rs
@@ -2,8 +2,10 @@ use crate::{
     abi::{vmctx, ABIOperand, ABISig, RetArea, ABI},
     codegen::BlockSig,
     isa::reg::Reg,
-    masm::{ExtendKind, IntCmpKind, MacroAssembler, OperandSize, RegImm, SPOffset, TrapCode},
-    stack::TypedReg,
+    masm::{
+        ExtendKind, IntCmpKind, MacroAssembler, OperandSize, RegImm, SPOffset, ShiftKind, TrapCode,
+    },
+    stack::{TypedReg, Val},
 };
 use anyhow::Result;
 use smallvec::SmallVec;
 use wasmparser::{ };
 use wasmtime_environ::{
     GlobalIndex, MemoryIndex, PtrSize, TableIndex, TypeIndex, WasmHeapType, WasmValType,
-    FUNCREF_MASK,
+    FUNCREF_MASK, WASM_PAGE_SIZE,
 };
 
 use cranelift_codegen::{
@@ -475,9 +477,7 @@ where
         let index = self.context.pop_to_reg(self.masm, None);
         let base = self.context.any_gpr(self.masm);
-        let elem_addr =
-            self.masm
-                .table_elem_address(index.into(), base, &table_data, &mut self.context);
+        let elem_addr = self.emit_compute_table_elem_addr(index.into(), base, &table_data);
         self.masm.load_ptr(elem_addr, elem_value);
         // Free the register used as base, once we have loaded the element
         // address into the element value register.
@@ -745,6 +745,132 @@ where
         }
         self.context.free_reg(src);
     }
+
+    /// Loads the address of the table element at a given index. Returns the
+    /// address of the table element using the provided register as base.
+    pub fn emit_compute_table_elem_addr(
+        &mut self,
+        index: Reg,
+        base: Reg,
+        table_data: &TableData,
+    ) -> M::Address {
+        let scratch = <M::ABI as ABI>::scratch_reg();
+        let bound = self.context.any_gpr(self.masm);
+        let tmp = self.context.any_gpr(self.masm);
+        let ptr_size: OperandSize = self.env.ptr_type().into();
+
+        if let Some(offset) = table_data.import_from {
+            // If the table data declares a particular offset base,
+            // load the address into a register to further use it as
+            // the table address.
+            self.masm.load_ptr(self.masm.address_at_vmctx(offset), base);
+        } else {
+            // Else, simply move the vmctx register into the addr register as
+            // the base to calculate the table address.
+            self.masm.mov(vmctx!(M).into(), base, ptr_size);
+        };
+
+        // OOB check.
+        let bound_addr = self
+            .masm
+            .address_at_reg(base, table_data.current_elems_offset);
+        let bound_size = table_data.current_elements_size;
+        self.masm.load(bound_addr, bound, bound_size.into());
+        self.masm.cmp(bound.into(), index, bound_size);
+        self.masm
+            .trapif(IntCmpKind::GeU, TrapCode::TableOutOfBounds);
+
+        // Move the index into the scratch register to calculate the table
+        // element address.
+        // Moving the value of the index register to the scratch register
+        // also avoids overwriting the contents of the index register.
+        self.masm.mov(index.into(), scratch, bound_size);
+        self.masm.mul(
+            scratch,
+            scratch,
+            RegImm::i32(table_data.element_size.bytes() as i32),
+            table_data.element_size,
+        );
+        self.masm
+            .load_ptr(self.masm.address_at_reg(base, table_data.offset), base);
+        // Copy the value of the table base into a temporary register
+        // so that we can use it later in case of a misspeculation.
+        self.masm.mov(base.into(), tmp, ptr_size);
+        // Calculate the address of the table element.
+        self.masm.add(base, base, scratch.into(), ptr_size);
+        if self.env.table_access_spectre_mitigation() {
+            // Perform a bounds check and override the value of the
+            // table element address in case the index is out of bounds.
+            self.masm.cmp(bound.into(), index, OperandSize::S32);
+            self.masm.cmov(tmp, base, IntCmpKind::GeU, ptr_size);
+        }
+        self.context.free_reg(bound);
+        self.context.free_reg(tmp);
+        self.masm.address_at_reg(base, 0)
+    }
+
+    /// Retrieves the size of the table, pushing the result to the value stack.
+    pub fn emit_compute_table_size(&mut self, table_data: &TableData) {
+        let scratch = <M::ABI as ABI>::scratch_reg();
+        let size = self.context.any_gpr(self.masm);
+        let ptr_size: OperandSize = self.env.ptr_type().into();
+
+        if let Some(offset) = table_data.import_from {
+            self.masm
+                .load_ptr(self.masm.address_at_vmctx(offset), scratch);
+        } else {
+            self.masm.mov(vmctx!(M).into(), scratch, ptr_size);
+        };
+
+        let size_addr = self
+            .masm
+            .address_at_reg(scratch, table_data.current_elems_offset);
+        self.masm
+            .load(size_addr, size, table_data.current_elements_size.into());
+
+        self.context.stack.push(TypedReg::i32(size).into());
+    }
+
+    /// Retrieves the size of the memory, pushing the result to the value stack.
+    pub fn emit_compute_memory_size(&mut self, heap_data: &HeapData) {
+        let size_reg = self.context.any_gpr(self.masm);
+        let scratch = <M::ABI as ABI>::scratch_reg();
+
+        let base = if let Some(offset) = heap_data.import_from {
+            self.masm
+                .load_ptr(self.masm.address_at_vmctx(offset), scratch);
+            scratch
+        } else {
+            vmctx!(M)
+        };
+
+        let size_addr = self
+            .masm
+            .address_at_reg(base, heap_data.current_length_offset);
+        self.masm.load_ptr(size_addr, size_reg);
+        // Prepare the stack to emit a shift to get the size in pages rather
+        // than in bytes.
+        self.context
+            .stack
+            .push(TypedReg::new(heap_data.ty, size_reg).into());
+
+        // Since the page size is a power of two, verify that 2^16 equals the
+        // defined constant. This is mostly a safeguard in case the constant
+        // value ever changes.
+        let pow = 16;
+        debug_assert_eq!(2u32.pow(pow), WASM_PAGE_SIZE);
+
+        // Ensure that the constant is correctly typed according to the heap
+        // type to reduce register pressure when emitting the shift operation.
+        match heap_data.ty {
+            WasmValType::I32 => self.context.stack.push(Val::i32(pow as i32)),
+            WasmValType::I64 => self.context.stack.push(Val::i64(pow as i64)),
+            _ => unreachable!(),
+        }
+
+        self.masm
+            .shift(&mut self.context, ShiftKind::ShrU, heap_data.ty.into());
+    }
 }
 
 /// Returns the index of the [`ControlStackFrame`] for the given
diff --git a/winch/codegen/src/isa/aarch64/masm.rs b/winch/codegen/src/isa/aarch64/masm.rs
index 8f2d205135..0a106d5829 100644
--- a/winch/codegen/src/isa/aarch64/masm.rs
+++ b/winch/codegen/src/isa/aarch64/masm.rs
@@ -1,7 +1,7 @@
 use super::{abi::Aarch64ABI, address::Address, asm::Assembler, regs};
 use crate::{
     abi::{self, local::LocalSlot},
-    codegen::{ptr_type_from_ptr_size, CodeGenContext, FuncEnv, HeapData, TableData},
+    codegen::{ptr_type_from_ptr_size, CodeGenContext, FuncEnv},
     isa::reg::Reg,
     masm::{
         CalleeKind, DivKind, ExtendKind, FloatCmpKind, Imm as I, IntCmpKind,
@@ -111,24 +111,6 @@ impl Masm for MacroAssembler {
         Address::offset(reg, offset as i64)
     }
 
-    fn table_elem_address(
-        &mut self,
-        _index: Reg,
-        _base: Reg,
-        _table_data: &TableData,
-        _context: &mut CodeGenContext,
-    ) -> Self::Address {
-        todo!()
-    }
-
-    fn table_size(&mut self, _table_data: &TableData, _context: &mut CodeGenContext) {
-        todo!()
-    }
-
-    fn memory_size(&mut self, _heap_data: &HeapData, _context: &mut CodeGenContext) {
-        todo!()
-    }
-
     fn address_from_sp(&self, _offset: SPOffset) -> Self::Address {
         todo!()
     }
diff --git a/winch/codegen/src/isa/x64/masm.rs b/winch/codegen/src/isa/x64/masm.rs
index 4b68dde4a6..2300253b7d 100644
--- a/winch/codegen/src/isa/x64/masm.rs
+++ b/winch/codegen/src/isa/x64/masm.rs
@@ -11,13 +11,12 @@ use crate::masm::{
 };
 use crate::{
     abi::{self, align_to, calculate_frame_adjustment, LocalSlot},
-    codegen::{ptr_type_from_ptr_size, CodeGenContext, FuncEnv, HeapData, TableData},
+    codegen::{ptr_type_from_ptr_size, CodeGenContext, FuncEnv},
     stack::Val,
 };
 use crate::{
     abi::{vmctx, ABI},
     masm::{SPOffset, StackSlot},
-    stack::TypedReg,
 };
 use crate::{
     isa::reg::{Reg, RegClass},
@@ -34,7 +33,7 @@ use cranelift_codegen::{
     settings, Final, MachBufferFinalized, MachLabel,
 };
 
-use wasmtime_environ::{PtrSize, WasmValType, WASM_PAGE_SIZE};
+use wasmtime_environ::{PtrSize, WasmValType};
 
 /// x64 MacroAssembler.
 pub(crate) struct MacroAssembler {
@@ -180,141 +179,6 @@ impl Masm for MacroAssembler {
         Address::offset(reg, offset)
     }
 
-    fn table_elem_address(
-        &mut self,
-        index: Reg,
-        ptr_base: Reg,
-        table_data: &TableData,
-        context: &mut CodeGenContext,
-    ) -> Self::Address {
-        let scratch = regs::scratch();
-        let bound = context.any_gpr(self);
-        let tmp = context.any_gpr(self);
-
-        if let Some(offset) = table_data.import_from {
-            // If the table data declares a particular offset base,
-            // load the address into a register to further use it as
-            // the table address.
-            self.asm.movzx_mr(
-                &self.address_at_vmctx(offset),
-                ptr_base,
-                self.ptr_size.into(),
-                TRUSTED_FLAGS,
-            );
-        } else {
-            // Else, simply move the vmctx register into the addr register as
-            // the base to calculate the table address.
-            self.asm.mov_rr(vmctx!(Self), ptr_base, self.ptr_size);
-        };
-
-        // OOB check.
-        let bound_addr = self.address_at_reg(ptr_base, table_data.current_elems_offset);
-        let bound_size = table_data.current_elements_size;
-        self.asm
-            .movzx_mr(&bound_addr, bound, bound_size.into(), TRUSTED_FLAGS);
-        self.asm.cmp_rr(bound, index, bound_size);
-        self.asm.trapif(IntCmpKind::GeU, TrapCode::TableOutOfBounds);
-
-        // Move the index into the scratch register to calcualte the table
-        // element address.
-        // Moving the value of the index register to the scratch register
-        // also avoids overwriting the context of the index register.
-        self.asm.mov_rr(index, scratch, bound_size);
-        self.asm.mul_ir(
-            table_data.element_size.bytes() as i32,
-            scratch,
-            table_data.element_size,
-        );
-        self.asm.movzx_mr(
-            &self.address_at_reg(ptr_base, table_data.offset),
-            ptr_base,
-            self.ptr_size.into(),
-            TRUSTED_FLAGS,
-        );
-        // Copy the value of the table base into a temporary register
-        // so that we can use it later in case of a misspeculation.
-        self.asm.mov_rr(ptr_base, tmp, self.ptr_size);
-        // Calculate the address of the table element.
-        self.asm.add_rr(scratch, ptr_base, self.ptr_size);
-        if self.shared_flags.enable_table_access_spectre_mitigation() {
-            // Perform a bounds check and override the value of the
-            // table element address in case the index is out of bounds.
-            self.asm.cmp_rr(bound, index, OperandSize::S32);
-            self.asm.cmov(tmp, ptr_base, IntCmpKind::GeU, self.ptr_size);
-        }
-        context.free_reg(bound);
-        context.free_reg(tmp);
-        self.address_at_reg(ptr_base, 0)
-    }
-
-    fn table_size(&mut self, table_data: &TableData, context: &mut CodeGenContext) {
-        let scratch = regs::scratch();
-        let size = context.any_gpr(self);
-
-        if let Some(offset) = table_data.import_from {
-            self.asm.movzx_mr(
-                &self.address_at_vmctx(offset),
-                scratch,
-                self.ptr_size.into(),
-                TRUSTED_FLAGS,
-            );
-        } else {
-            self.asm.mov_rr(vmctx!(Self), scratch, self.ptr_size);
-        };
-
-        let size_addr = Address::offset(scratch, table_data.current_elems_offset);
-        self.asm.movzx_mr(
-            &size_addr,
-            size,
-            table_data.current_elements_size.into(),
-            TRUSTED_FLAGS,
-        );
-
-        context.stack.push(TypedReg::i32(size).into());
-    }
-
-    fn memory_size(&mut self, heap_data: &HeapData, context: &mut CodeGenContext) {
-        let size_reg = context.any_gpr(self);
-        let scratch = regs::scratch();
-
-        let base = if let Some(offset) = heap_data.import_from {
-            self.asm.movzx_mr(
-                &self.address_at_vmctx(offset),
-                scratch,
-                self.ptr_size.into(),
-                TRUSTED_FLAGS,
-            );
-            scratch
-        } else {
-            vmctx!(Self)
-        };
-
-        let size_addr = Address::offset(base, heap_data.current_length_offset);
-        self.asm
-            .movzx_mr(&size_addr, size_reg, self.ptr_size.into(), TRUSTED_FLAGS);
-        // Prepare the stack to emit a shift to get the size in pages rather
-        // than in bytes.
-        context
-            .stack
-            .push(TypedReg::new(heap_data.ty, size_reg).into());
-
-        // Since the page size is a power-of-two, verify that 2^16, equals the
-        // defined constant. This is mostly a safeguard in case the constant
-        // value ever changes.
-        let pow = 16;
-        debug_assert_eq!(2u32.pow(pow), WASM_PAGE_SIZE);
-
-        // Ensure that the constant is correctly typed according to the heap
-        // type to reduce register pressure when emitting the shift operation.
-        match heap_data.ty {
-            WasmValType::I32 => context.stack.push(Val::i32(pow as i32)),
-            WasmValType::I64 => context.stack.push(Val::i64(pow as i64)),
-            _ => unreachable!(),
-        }
-
-        self.shift(context, ShiftKind::ShrU, heap_data.ty.into());
-    }
-
     fn address_from_sp(&self, offset: SPOffset) -> Self::Address {
         Address::offset(regs::rsp(), self.sp_offset - offset.as_u32())
     }
diff --git a/winch/codegen/src/masm.rs b/winch/codegen/src/masm.rs
index e04cf5cda1..5b86dc8fda 100644
--- a/winch/codegen/src/masm.rs
+++ b/winch/codegen/src/masm.rs
@@ -1,5 +1,5 @@
 use crate::abi::{self, align_to, LocalSlot};
-use crate::codegen::{CodeGenContext, FuncEnv, HeapData, TableData};
+use crate::codegen::{CodeGenContext, FuncEnv};
 use crate::isa::reg::Reg;
 use cranelift_codegen::{
     binemit::CodeOffset,
@@ -503,22 +503,6 @@ pub(crate) trait MacroAssembler {
     /// Get the address of a local slot.
     fn local_address(&mut self, local: &LocalSlot) -> Self::Address;
 
-    /// Loads the address of the table element at a given index. Returns the
-    /// address of the table element using the provided register as base.
-    fn table_elem_address(
-        &mut self,
-        index: Reg,
-        base: Reg,
-        table_data: &TableData,
-        context: &mut CodeGenContext,
-    ) -> Self::Address;
-
-    /// Retrieves the size of the table, pushing the result to the value stack.
-    fn table_size(&mut self, table_data: &TableData, context: &mut CodeGenContext);
-
-    /// Retrieves the size of the memory, pushing the result to the value stack.
-    fn memory_size(&mut self, heap_data: &HeapData, context: &mut CodeGenContext);
-
     /// Constructs an address with an offset that is relative to the
     /// current position of the stack pointer (e.g. [sp + (sp_offset -
     /// offset)].
diff --git a/winch/codegen/src/visitor.rs b/winch/codegen/src/visitor.rs
index 688e5b4bf5..c02abb763a 100644
--- a/winch/codegen/src/visitor.rs
+++ b/winch/codegen/src/visitor.rs
@@ -1469,7 +1469,7 @@ where
     fn visit_table_size(&mut self, table: u32) {
         let table_index = TableIndex::from_u32(table);
         let table_data = self.env.resolve_table_data(table_index);
-        self.masm.table_size(&table_data, &mut self.context);
+        self.emit_compute_table_size(&table_data);
     }
 
     fn visit_table_fill(&mut self, table: u32) {
@@ -1505,13 +1505,8 @@
             let value = self.context.pop_to_reg(self.masm, None);
             let index = self.context.pop_to_reg(self.masm, None);
             let base = self.context.any_gpr(self.masm);
-            let elem_addr = self.masm.table_elem_address(
-                index.into(),
-                base,
-                &table_data,
-                &mut self.context,
-            );
-
+            let elem_addr =
+                self.emit_compute_table_elem_addr(index.into(), base, &table_data);
             // Set the initialized bit.
             self.masm.or(
                 value.into(),
@@ -1605,7 +1600,7 @@ where
     fn visit_memory_size(&mut self, mem: u32, _: u8) {
         let heap = self.env.resolve_heap(MemoryIndex::from_u32(mem));
-        self.masm.memory_size(&heap, &mut self.context);
+        self.emit_compute_memory_size(&heap);
     }
 
     fn visit_memory_grow(&mut self, mem: u32, _: u8) {
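
Illustrative appendix (not part of the patch): the two ISA-independent helpers introduced in codegen/mod.rs reduce to a pair of simple scalar patterns, sketched below in plain Rust. The helper names `byte_size_to_pages` and `table_elem_addr` are hypothetical; the sketch only models the semantics of the machine code that `emit_compute_memory_size` and `emit_compute_table_elem_addr` emit, under the assumption that one Wasm page is `WASM_PAGE_SIZE` = 2^16 bytes.

// Hypothetical, stand-alone sketch of the emitted logic; not Winch API.
const WASM_PAGE_SIZE: u32 = 65536; // 2^16, mirrors wasmtime_environ::WASM_PAGE_SIZE

// `emit_compute_memory_size` converts the memory's byte length to pages
// with a logical right shift by 16 (ShiftKind::ShrU) instead of a division.
fn byte_size_to_pages(byte_size: u64) -> u64 {
    debug_assert_eq!(2u32.pow(16), WASM_PAGE_SIZE);
    byte_size >> 16
}

// `emit_compute_table_elem_addr` computes base + index * element_size.
// The explicit bounds trap mirrors `trapif(IntCmpKind::GeU,
// TrapCode::TableOutOfBounds)`; the final branchless select mirrors the
// `cmov` emitted when `table_access_spectre_mitigation` is enabled: even
// if the CPU speculates past the trap, the address degrades to the table
// base rather than an attacker-controlled out-of-bounds location.
fn table_elem_addr(table_base: usize, element_size: usize, index: u32, bound: u32) -> usize {
    if index >= bound {
        panic!("table out of bounds"); // the trap in the emitted code
    }
    let computed = table_base + index as usize * element_size; // mul + add
    if index < bound { computed } else { table_base } // cmov-style clamp
}

fn main() {
    assert_eq!(byte_size_to_pages(3 * 65536), 3); // 3 pages
    assert_eq!(table_elem_addr(0x1000, 8, 2, 10), 0x1010);
}

Note how, in the patch, the clamp overwrites the computed address with the saved table base (`tmp`) when `index >= bound`; this is why the table base pointer is copied into a temporary register before the final add.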