From 867f5c1244191c3e68a1e433f0cec08146062a46 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 9 Aug 2022 09:26:33 -0500 Subject: [PATCH] Update behavior of zero-length lists/strings (#4648) The spec was expected to change to not bounds-check 0-byte lists/strings but has since been updated to match `memory.copy` which does indeed check the pointer for 0-byte copies. --- crates/environ/src/fact/trampoline.rs | 220 ++++++------------ crates/wasmtime/src/component/func/options.rs | 13 +- crates/wasmtime/src/component/func/typed.rs | 24 +- tests/all/component_model/func.rs | 14 +- .../component-model/adapter.wast | 2 +- 5 files changed, 89 insertions(+), 184 deletions(-) diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index fafbcc54e7..c00337f450 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -1533,8 +1533,18 @@ impl Compiler<'_, '_> { } fn validate_string_inbounds(&mut self, s: &WasmString<'_>, byte_len: u32) { + self.validate_memory_inbounds(s.opts, s.ptr.idx, byte_len, Trap::StringLengthOverflow) + } + + fn validate_memory_inbounds( + &mut self, + opts: &Options, + ptr_local: u32, + byte_len_local: u32, + trap: Trap, + ) { let extend_to_64 = |me: &mut Self| { - if !s.opts.memory64 { + if !opts.memory64 { me.instruction(I64ExtendI32U); } }; @@ -1546,7 +1556,7 @@ impl Compiler<'_, '_> { // arithmetic here is done always in 64-bits to accomodate 4G memories. // Additionally it's assumed that 64-bit memories never fill up // entirely. - self.instruction(MemorySize(s.opts.memory.unwrap().as_u32())); + self.instruction(MemorySize(opts.memory.unwrap().as_u32())); extend_to_64(self); self.instruction(I64Const(16)); self.instruction(I64Shl); @@ -1555,15 +1565,15 @@ impl Compiler<'_, '_> { // base pointer to the byte length. For 32-bit memories there's no need // to check for overflow since everything is extended to 64-bit, but for // 64-bit memories overflow is checked. 
- self.instruction(LocalGet(s.ptr.idx)); + self.instruction(LocalGet(ptr_local)); extend_to_64(self); - self.instruction(LocalGet(byte_len)); + self.instruction(LocalGet(byte_len_local)); extend_to_64(self); self.instruction(I64Add); - if s.opts.memory64 { + if opts.memory64 { let tmp = self.local_tee_new_tmp(ValType::I64); - self.instruction(LocalGet(s.ptr.idx)); - self.ptr_lt_u(s.opts); + self.instruction(LocalGet(ptr_local)); + self.ptr_lt_u(opts); self.instruction(BrIf(0)); self.instruction(LocalGet(tmp.idx)); self.free_temp_local(tmp); @@ -1576,7 +1586,7 @@ impl Compiler<'_, '_> { self.instruction(BrIf(1)); self.instruction(End); - self.trap(Trap::StringLengthOverflow); + self.trap(trap); self.instruction(End); } @@ -1619,15 +1629,20 @@ impl Compiler<'_, '_> { // `src_ptr` value is properly aligned. let src_mem = self.memory_operand(src_opts, src_ptr, src_align); - // Next the byte size of the allocation in the destination is - // determined. Note that this is pretty tricky because pointer widths - // could be changing and otherwise everything must stay within the - // 32-bit size-space. This internally will ensure that `src_len * - // dst_size` doesn't overflow 32-bits and will place the final result in - // `dst_byte_len` where `dst_byte_len` has the appropriate type for the - // destination. - let dst_byte_len = - self.calculate_dst_byte_len(src_len.idx, src_opts.ptr(), dst_opts.ptr(), dst_size); + // Calculate the source/destination byte lengths into unique locals. 
+ let src_byte_len = self.calculate_list_byte_len(src_opts, src_len.idx, src_size); + let dst_byte_len = if src_size == dst_size { + self.convert_src_len_to_dst(src_byte_len.idx, src_opts.ptr(), dst_opts.ptr()); + self.local_set_new_tmp(dst_opts.ptr()) + } else if src_opts.ptr() == dst_opts.ptr() { + self.calculate_list_byte_len(dst_opts, src_len.idx, dst_size) + } else { + self.convert_src_len_to_dst(src_byte_len.idx, src_opts.ptr(), dst_opts.ptr()); + let tmp = self.local_set_new_tmp(dst_opts.ptr()); + let ret = self.calculate_list_byte_len(dst_opts, tmp.idx, dst_size); + self.free_temp_local(tmp); + ret + }; // Here `realloc` is invoked (in a `malloc`-like fashion) to allocate // space for the list in the destination memory. This will also @@ -1635,122 +1650,22 @@ impl Compiler<'_, '_> { // correctly for the destination. let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len.idx), dst_align); - // At this point we have aligned pointers, a length, and a byte length - // for the destination. The spec also requires this translation to - // ensure that the range of memory within the source and destination - // memories are valid. Currently though this attempts to optimize that - // somewhat at least. The thinking is that if we hit an out-of-bounds - // memory access during translation that's the same as a trap up-front. - // This means we can generally minimize up-front checks in favor of - // simply trying to load out-of-bounds memory. - // - // This doesn't mean we can avoid a check entirely though. One major - // worry here is integer overflow of the pointers in linear memory as - // they're incremented to move to the next element as part of - // translation. For example if the entire 32-bit address space were - // valid and the base pointer was `0xffff_fff0` where the size was 17 - // that should not be a valid list but "simply defer to the loop below" - // would cause a wraparound to occur and no trap would be detected. 
- // - // To solve this a check is inserted here that the `base + byte_len` - // calculation doesn't overflow the 32-bit address space. Note though - // that this is only done for 32-bit memories, not 64-bit memories. - // Given the iteration of the loop below the only worry is when the - // address space is 100% mapped and wraparound is possible. Otherwise if - // anything in the address space is unmapped then we're guaranteed to - // hit a trap as we march from the base pointer to the end of the array. - // It's assumed that it's impossible for a 64-bit memory to have the - // entire address space mapped, so this isn't a concern for 64-bit - // memories. - // - // Technically this is only a concern for 32-bit memories if the entire - // address space is mapped, so `memory.size` could be used to skip most - // of the check here but it's assume that the `memory.size` check is - // probably more expensive than just checking for 32-bit overflow by - // using 64-bit arithmetic. This should hypothetically be tested though! - // - // TODO: the most-optimal thing here is to probably, once per adapter, - // call `memory.size` and put that in a local. If that is not the - // maximum for a 32-bit memory then this entire bounds-check here can be - // skipped. - if !src_opts.memory64 && src_size > 0 { - self.instruction(LocalGet(src_mem.addr.idx)); - self.instruction(I64ExtendI32U); - if src_size < dst_size { - // If the source byte size is less than the destination size - // then we can leverage the fact that `dst_byte_len` was already - // calculated and didn't overflow so this is also guaranteed to - // not overflow. - self.instruction(LocalGet(src_len.idx)); - self.instruction(I64ExtendI32U); - if src_size != 1 { - self.instruction(I64Const(i64::try_from(src_size).unwrap())); - self.instruction(I64Mul); - } - } else if src_size == dst_size { - // If the source byte size is the same as the destination byte - // size then that can be reused. 
Note that the destination byte - // size is already guaranteed to fit in 32 bits, even if it's - // store in a 64-bit local. - self.instruction(LocalGet(dst_byte_len.idx)); - if dst_opts.ptr() == ValType::I32 { - self.instruction(I64ExtendI32U); - } - } else { - // Otherwise if the source byte size is larger than the - // destination byte size then the source byte size needs to be - // calculated fresh here. Note, though, that the result of this - // multiplication is not checked for overflow. The reason for - // that is that the result here flows into the check below about - // overflow and if this computation overflows it should be - // guaranteed to overflow the next computation. - // - // In general what's being checked here is: - // - // src_mem.addr_local + src_len * src_size - // - // These three values are all 32-bits originally and if they're - // all assumed to be `u32::MAX` then: - // - // let max = u64::from(u32::MAX); - // let result = max + max * max; - // assert_eq!(result, 0xffffffff00000000); - // - // This means that once an upper bit is set it's guaranteed to - // stay set as part of this computation, so the multiplication - // here is left unchecked to fall through into the addition - // below. - self.instruction(LocalGet(src_len.idx)); - self.instruction(I64ExtendI32U); - self.instruction(I64Const(i64::try_from(src_size).unwrap())); - self.instruction(I64Mul); - } - self.instruction(I64Add); - self.instruction(I64Const(32)); - self.instruction(I64ShrU); - self.instruction(I32WrapI64); - self.instruction(If(BlockType::Empty)); - self.trap(Trap::ListByteLengthOverflow); - self.instruction(End); - } - - // If the destination is a 32-bit memory then its overflow check is - // relatively simple since we've already calculated the byte length of - // the destination above and can reuse that in this check. 
- if !dst_opts.memory64 && dst_size > 0 { - self.instruction(LocalGet(dst_mem.addr.idx)); - self.instruction(I64ExtendI32U); - self.instruction(LocalGet(dst_byte_len.idx)); - self.instruction(I64ExtendI32U); - self.instruction(I64Add); - self.instruction(I64Const(32)); - self.instruction(I64ShrU); - self.instruction(I32WrapI64); - self.instruction(If(BlockType::Empty)); - self.trap(Trap::ListByteLengthOverflow); - self.instruction(End); - } + // With all the pointers and byte lengths computed, verify that both the source + // and the destination buffers are in-bounds. + self.validate_memory_inbounds( + src_opts, + src_mem.addr.idx, + src_byte_len.idx, + Trap::ListByteLengthOverflow, + ); + self.validate_memory_inbounds( + dst_opts, + dst_mem.addr.idx, + dst_byte_len.idx, + Trap::ListByteLengthOverflow, + ); + self.free_temp_local(src_byte_len); self.free_temp_local(dst_byte_len); // This is the main body of the loop to actually translate list types. @@ -1840,22 +1755,17 @@ impl Compiler<'_, '_> { self.free_temp_local(dst_mem.addr); } - fn calculate_dst_byte_len( + fn calculate_list_byte_len( &mut self, - src_len_local: u32, - src_ptr_ty: ValType, - dst_ptr_ty: ValType, - dst_elt_size: usize, + opts: &Options, + len_local: u32, + elt_size: usize, ) -> TempLocal { // Zero-size types are easy to handle here because the byte size of the // destination is always zero. - if dst_elt_size == 0 { - if dst_ptr_ty == ValType::I64 { - self.instruction(I64Const(0)); - } else { - self.instruction(I32Const(0)); - } - return self.local_set_new_tmp(dst_ptr_ty); + if elt_size == 0 { + self.ptr_uconst(opts, 0); + return self.local_set_new_tmp(opts.ptr()); } // For one-byte elements in the destination the check here can be a bit @@ -1865,9 +1775,9 @@ impl Compiler<'_, '_> { // // If the source is 64-bit then all that needs to be checked is to // ensure that it does not have the upper 32-bits set. 
- if dst_elt_size == 1 { - if let ValType::I64 = src_ptr_ty { - self.instruction(LocalGet(src_len_local)); + if elt_size == 1 { + if let ValType::I64 = opts.ptr() { + self.instruction(LocalGet(len_local)); self.instruction(I64Const(32)); self.instruction(I64ShrU); self.instruction(I32WrapI64); @@ -1875,8 +1785,8 @@ impl Compiler<'_, '_> { self.trap(Trap::ListByteLengthOverflow); self.instruction(End); } - self.convert_src_len_to_dst(src_len_local, src_ptr_ty, dst_ptr_ty); - return self.local_set_new_tmp(dst_ptr_ty); + self.instruction(LocalGet(len_local)); + return self.local_set_new_tmp(opts.ptr()); } // The main check implemented by this function is to verify that @@ -1885,22 +1795,22 @@ impl Compiler<'_, '_> { // memories. self.instruction(Block(BlockType::Empty)); self.instruction(Block(BlockType::Empty)); - self.instruction(LocalGet(src_len_local)); - match src_ptr_ty { + self.instruction(LocalGet(len_local)); + match opts.ptr() { // The source's list length is guaranteed to be less than 32-bits // so simply extend it up to a 64-bit type for the multiplication // below. ValType::I32 => self.instruction(I64ExtendI32U), // If the source is a 64-bit memory then if the item length doesn't - // fit in 32-bits the byte length definitly won't, so generate a + // fit in 32-bits the byte length definitely won't, so generate a // branch to our overflow trap here if any of the upper 32-bits are set. ValType::I64 => { self.instruction(I64Const(32)); self.instruction(I64ShrU); self.instruction(I32WrapI64); self.instruction(BrIf(0)); - self.instruction(LocalGet(src_len_local)); + self.instruction(LocalGet(len_local)); } _ => unreachable!(), @@ -1914,7 +1824,7 @@ impl Compiler<'_, '_> { // // The result of the multiplication is saved into a local as well to // get the result afterwards. 
- self.instruction(I64Const(u32::try_from(dst_elt_size).unwrap().into())); + self.instruction(I64Const(u32::try_from(elt_size).unwrap().into())); self.instruction(I64Mul); let tmp = self.local_tee_new_tmp(ValType::I64); // Branch to success if the upper 32-bits are zero, otherwise @@ -1930,7 +1840,7 @@ impl Compiler<'_, '_> { // If a fresh local was used to store the result of the multiplication // then convert it down to 32-bits which should be guaranteed to not // lose information at this point. - if dst_ptr_ty == ValType::I64 { + if opts.ptr() == ValType::I64 { tmp } else { self.instruction(LocalGet(tmp.idx)); diff --git a/crates/wasmtime/src/component/func/options.rs b/crates/wasmtime/src/component/func/options.rs index af939a1c9b..7c07e59004 100644 --- a/crates/wasmtime/src/component/func/options.rs +++ b/crates/wasmtime/src/component/func/options.rs @@ -1,5 +1,6 @@ use crate::store::{StoreId, StoreOpaque}; use crate::StoreContextMut; +use crate::Trap; use anyhow::{bail, Result}; use std::ptr::NonNull; use wasmtime_environ::component::StringEncoding; @@ -96,19 +97,15 @@ impl Options { }; if result % old_align != 0 { - bail!("realloc return: result not aligned"); + bail!(Trap::new("realloc return: result not aligned")); } let result = usize::try_from(result)?; let memory = self.memory_mut(store.0); - let result_slice = if new_size == 0 { - &mut [] - } else { - match memory.get_mut(result..).and_then(|s| s.get_mut(..new_size)) { - Some(end) => end, - None => bail!("realloc return: beyond end of memory"), - } + let result_slice = match memory.get_mut(result..).and_then(|s| s.get_mut(..new_size)) { + Some(end) => end, + None => bail!(Trap::new("realloc return: beyond end of memory")), }; Ok((result_slice, result)) diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs index ec65231f5b..ac4f63a1d0 100644 --- a/crates/wasmtime/src/component/func/typed.rs +++ b/crates/wasmtime/src/component/func/typed.rs @@ -875,9 
+875,7 @@ fn lower_string(mem: &mut MemoryMut<'_, T>, string: &str) -> Result<(usize, u ); } let ptr = mem.realloc(0, 0, 1, string.len())?; - if string.len() > 0 { - mem.as_slice_mut()[ptr..][..string.len()].copy_from_slice(string.as_bytes()); - } + mem.as_slice_mut()[ptr..][..string.len()].copy_from_slice(string.as_bytes()); Ok((ptr, string.len())) } @@ -894,17 +892,15 @@ fn lower_string(mem: &mut MemoryMut<'_, T>, string: &str) -> Result<(usize, u } let mut ptr = mem.realloc(0, 0, 2, size)?; let mut copied = 0; - if size > 0 { - let bytes = &mut mem.as_slice_mut()[ptr..][..size]; - for (u, bytes) in string.encode_utf16().zip(bytes.chunks_mut(2)) { - let u_bytes = u.to_le_bytes(); - bytes[0] = u_bytes[0]; - bytes[1] = u_bytes[1]; - copied += 1; - } - if (copied * 2) < size { - ptr = mem.realloc(ptr, size, 2, copied * 2)?; - } + let bytes = &mut mem.as_slice_mut()[ptr..][..size]; + for (u, bytes) in string.encode_utf16().zip(bytes.chunks_mut(2)) { + let u_bytes = u.to_le_bytes(); + bytes[0] = u_bytes[0]; + bytes[1] = u_bytes[1]; + copied += 1; + } + if (copied * 2) < size { + ptr = mem.realloc(ptr, size, 2, copied * 2)?; } Ok((ptr, copied)) } diff --git a/tests/all/component_model/func.rs b/tests/all/component_model/func.rs index aa73a19ba9..6c30bd985d 100644 --- a/tests/all/component_model/func.rs +++ b/tests/all/component_model/func.rs @@ -1127,19 +1127,21 @@ fn some_traps() -> Result<()> { err, ); } - instance(&mut store)? + let err = instance(&mut store)? .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-base-oob")? .call(&mut store, (&[],)) - .unwrap(); + .unwrap_err(); + assert_oob(&err); let err = instance(&mut store)? .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-base-oob")? .call(&mut store, (&[1],)) .unwrap_err(); assert_oob(&err); - instance(&mut store)? + let err = instance(&mut store)? .get_typed_func::<(&str,), (), _>(&mut store, "take-string-base-oob")? 
.call(&mut store, ("",)) - .unwrap(); + .unwrap_err(); + assert_oob(&err); let err = instance(&mut store)? .get_typed_func::<(&str,), (), _>(&mut store, "take-string-base-oob")? .call(&mut store, ("x",)) @@ -1191,13 +1193,13 @@ fn some_traps() -> Result<()> { // For this function the first allocation, the space to store all the // arguments, is in-bounds but then all further allocations, such as for // each individual string, are all out of bounds. - instance(&mut store)? + let err = instance(&mut store)? .get_typed_func::<(&str, &str, &str, &str, &str, &str, &str, &str, &str, &str), (), _>( &mut store, "take-many-second-oob", )? .call(&mut store, ("", "", "", "", "", "", "", "", "", "")) - .unwrap(); + .unwrap_err(); assert_oob(&err); let err = instance(&mut store)? .get_typed_func::<(&str, &str, &str, &str, &str, &str, &str, &str, &str, &str), (), _>( diff --git a/tests/misc_testsuite/component-model/adapter.wast b/tests/misc_testsuite/component-model/adapter.wast index 96eaf258b2..69d73620de 100644 --- a/tests/misc_testsuite/component-model/adapter.wast +++ b/tests/misc_testsuite/component-model/adapter.wast @@ -130,4 +130,4 @@ ) (export "empty-list" (func $f)) ) -(assert_return (invoke "empty-list" (list.const)) (unit.const)) +(assert_trap (invoke "empty-list" (list.const)) "realloc return: beyond end of memory")