From c78416912c97d7beed1ac54bfa5338149fb40d15 Mon Sep 17 00:00:00 2001 From: Pat Hickey Date: Fri, 6 Mar 2020 16:04:56 -0800 Subject: [PATCH] Check safety of `as_raw` with a simplified borrow checker (#37) * wiggle-runtime: add as_raw method for [T] * add trivial borrow checker back in * integrate runtime borrow checker with as_raw methods * handle pointer arith overflow correctly in as_raw, create PtrOverflow error * runtime: add validation back to GuestType * generate: impl validate for enums, flags, handles, ints * oops! make validate its own method on trait GuestTypeTransparent * fix transparent impls for enum, flag, handle, int * some structs are transparent. fix tests. * tests: define byte_slice_strat and friends * wiggle-tests: i believe my allocator is working now * some type juggling around memset for ease of use * make GuestTypeTransparent an unsafe trait * delete redundant validation of pointer align * fix doc * wiggle_test: aha, you cant use sets to track memory areas * add multi-string test which exercises the runtime borrow checker against HostMemory::byte_slice_strat * oops left debug panic in * remove redundant (& incorrect, since unchecked) length calc * redesign validate again, and actually hook to as_raw * makr all validate impls as inline this should hopefully allow as_raw's check loop to be unrolled to a no-op in most cases! * code review fixes --- crates/generate/src/lifetimes.rs | 27 ++++ crates/generate/src/types/enum.rs | 16 ++- crates/generate/src/types/flags.rs | 18 ++- crates/generate/src/types/handle.rs | 11 ++ crates/generate/src/types/int.rs | 10 ++ crates/generate/src/types/struct.rs | 28 ++++ crates/runtime/src/borrow.rs | 91 ++++++++++++ crates/runtime/src/error.rs | 2 + crates/runtime/src/guest_type.rs | 25 ++++ crates/runtime/src/lib.rs | 75 +++++++++- crates/test/src/lib.rs | 206 ++++++++++++++++++++++++++-- tests/arrays.rs | 8 +- tests/flags.rs | 2 +- tests/pointers.rs | 8 +- tests/strings.rs | 144 ++++++++++++++++++- tests/strings.witx | 8 ++ tests/structs.rs | 22 +-- tests/union.rs | 4 +- 18 files changed, 655 insertions(+), 50 deletions(-) create mode 100644 crates/runtime/src/borrow.rs diff --git a/crates/generate/src/lifetimes.rs b/crates/generate/src/lifetimes.rs index ae9631a63b..75b102209c 100644 --- a/crates/generate/src/lifetimes.rs +++ b/crates/generate/src/lifetimes.rs @@ -2,16 +2,34 @@ use proc_macro2::TokenStream; use quote::quote; pub trait LifetimeExt { + fn is_transparent(&self) -> bool; fn needs_lifetime(&self) -> bool; } impl LifetimeExt for witx::TypeRef { + fn is_transparent(&self) -> bool { + self.type_().is_transparent() + } fn needs_lifetime(&self) -> bool { self.type_().needs_lifetime() } } impl LifetimeExt for witx::Type { + fn is_transparent(&self) -> bool { + match self { + witx::Type::Builtin(b) => b.is_transparent(), + witx::Type::Struct(s) => s.is_transparent(), + witx::Type::Enum { .. } + | witx::Type::Flags { .. } + | witx::Type::Int { .. } + | witx::Type::Handle { .. } => true, + witx::Type::Union { .. } + | witx::Type::Pointer { .. } + | witx::Type::ConstPointer { .. } + | witx::Type::Array { .. } => false, + } + } fn needs_lifetime(&self) -> bool { match self { witx::Type::Builtin(b) => b.needs_lifetime(), @@ -29,6 +47,9 @@ impl LifetimeExt for witx::Type { } impl LifetimeExt for witx::BuiltinType { + fn is_transparent(&self) -> bool { + !self.needs_lifetime() + } fn needs_lifetime(&self) -> bool { match self { witx::BuiltinType::String => true, @@ -38,12 +59,18 @@ impl LifetimeExt for witx::BuiltinType { } impl LifetimeExt for witx::StructDatatype { + fn is_transparent(&self) -> bool { + self.members.iter().all(|m| m.tref.is_transparent()) + } fn needs_lifetime(&self) -> bool { self.members.iter().any(|m| m.tref.needs_lifetime()) } } impl LifetimeExt for witx::UnionDatatype { + fn is_transparent(&self) -> bool { + false + } fn needs_lifetime(&self) -> bool { self.variants .iter() diff --git a/crates/generate/src/types/enum.rs b/crates/generate/src/types/enum.rs index ec9ab021b0..8aa77f8682 100644 --- a/crates/generate/src/types/enum.rs +++ b/crates/generate/src/types/enum.rs @@ -87,8 +87,9 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> { use std::convert::TryFrom; - let val = #repr::read(&location.cast())?; - #ident::try_from(val) + let reprval = #repr::read(&location.cast())?; + let value = #ident::try_from(reprval)?; + Ok(value) } fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) @@ -97,5 +98,16 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype #repr::write(&location.cast(), #repr::from(val)) } } + + unsafe impl <'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident { + #[inline] + fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> { + use std::convert::TryFrom; + // Validate value in memory using #ident::try_from(reprval) + let reprval = unsafe { (location as *mut #repr).read() }; + let _val = #ident::try_from(reprval)?; + Ok(()) + } + } } } diff --git a/crates/generate/src/types/flags.rs b/crates/generate/src/types/flags.rs index 201e0a4529..60aa671310 100644 --- a/crates/generate/src/types/flags.rs +++ b/crates/generate/src/types/flags.rs @@ -134,10 +134,11 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty #repr::guest_align() } - fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> { + fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> { use std::convert::TryFrom; - let bits = #repr::read(&location.cast())?; - #ident::try_from(bits) + let reprval = #repr::read(&location.cast())?; + let value = #ident::try_from(reprval)?; + Ok(value) } fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> { @@ -145,5 +146,16 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty #repr::write(&location.cast(), val) } } + unsafe impl <'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident { + #[inline] + fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> { + use std::convert::TryFrom; + // Validate value in memory using #ident::try_from(reprval) + let reprval = unsafe { (location as *mut #repr).read() }; + let _val = #ident::try_from(reprval)?; + Ok(()) + } + } + } } diff --git a/crates/generate/src/types/handle.rs b/crates/generate/src/types/handle.rs index 294e36a028..4f922cbd7b 100644 --- a/crates/generate/src/types/handle.rs +++ b/crates/generate/src/types/handle.rs @@ -13,6 +13,7 @@ pub(super) fn define_handle( let size = h.mem_size_align().size as u32; let align = h.mem_size_align().align as usize; quote! { + #[repr(transparent)] #[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)] pub struct #ident(u32); @@ -62,5 +63,15 @@ pub(super) fn define_handle( u32::write(&location.cast(), val.0) } } + + unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident { + #[inline] + fn validate(_location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> { + // All bit patterns accepted + Ok(()) + } + } + + } } diff --git a/crates/generate/src/types/int.rs b/crates/generate/src/types/int.rs index 25375700b9..ee870eb6a6 100644 --- a/crates/generate/src/types/int.rs +++ b/crates/generate/src/types/int.rs @@ -73,11 +73,21 @@ pub(super) fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype) fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> { Ok(#ident(#repr::read(&location.cast())?)) + } fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> { #repr::write(&location.cast(), val.0) } } + + unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident { + #[inline] + fn validate(_location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> { + // All bit patterns accepted + Ok(()) + } + } + } } diff --git a/crates/generate/src/types/struct.rs b/crates/generate/src/types/struct.rs index 11a6bdf06f..d0c3b03cdd 100644 --- a/crates/generate/src/types/struct.rs +++ b/crates/generate/src/types/struct.rs @@ -77,6 +77,32 @@ pub(super) fn define_struct( (quote!(), quote!(, Copy, PartialEq)) }; + let transparent = if s.is_transparent() { + let member_validate = s.member_layout().into_iter().map(|ml| { + let offset = ml.offset; + let typename = names.type_ref(&ml.member.tref, anon_lifetime()); + quote! { + // SAFETY: caller has validated bounds and alignment of `location`. + // member_layout gives correctly-aligned pointers inside that area. + #typename::validate( + unsafe { (location as *mut u8).add(#offset) as *mut _ } + )?; + } + }); + + quote! { + unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident { + #[inline] + fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> { + #(#member_validate)* + Ok(()) + } + } + } + } else { + quote!() + }; + quote! { #[derive(Clone, Debug #extra_derive)] pub struct #ident #struct_lifetime { @@ -102,5 +128,7 @@ pub(super) fn define_struct( Ok(()) } } + + #transparent } } diff --git a/crates/runtime/src/borrow.rs b/crates/runtime/src/borrow.rs new file mode 100644 index 0000000000..5c3c80f429 --- /dev/null +++ b/crates/runtime/src/borrow.rs @@ -0,0 +1,91 @@ +use crate::region::Region; +use crate::GuestError; + +#[derive(Debug)] +pub struct GuestBorrows { + borrows: Vec, +} + +impl GuestBorrows { + pub fn new() -> Self { + Self { + borrows: Vec::new(), + } + } + + fn is_borrowed(&self, r: Region) -> bool { + !self.borrows.iter().all(|b| !b.overlaps(r)) + } + + pub fn borrow(&mut self, r: Region) -> Result<(), GuestError> { + if self.is_borrowed(r) { + Err(GuestError::PtrBorrowed(r)) + } else { + self.borrows.push(r); + Ok(()) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn nonoverlapping() { + let mut bs = GuestBorrows::new(); + let r1 = Region::new(0, 10); + let r2 = Region::new(10, 10); + assert!(!r1.overlaps(r2)); + bs.borrow(r1).expect("can borrow r1"); + bs.borrow(r2).expect("can borrow r2"); + + let mut bs = GuestBorrows::new(); + let r1 = Region::new(10, 10); + let r2 = Region::new(0, 10); + assert!(!r1.overlaps(r2)); + bs.borrow(r1).expect("can borrow r1"); + bs.borrow(r2).expect("can borrow r2"); + } + + #[test] + fn overlapping() { + let mut bs = GuestBorrows::new(); + let r1 = Region::new(0, 10); + let r2 = Region::new(9, 10); + assert!(r1.overlaps(r2)); + bs.borrow(r1).expect("can borrow r1"); + assert!(bs.borrow(r2).is_err(), "cant borrow r2"); + + let mut bs = GuestBorrows::new(); + let r1 = Region::new(0, 10); + let r2 = Region::new(2, 5); + assert!(r1.overlaps(r2)); + bs.borrow(r1).expect("can borrow r1"); + assert!(bs.borrow(r2).is_err(), "cant borrow r2"); + + let mut bs = GuestBorrows::new(); + let r1 = Region::new(9, 10); + let r2 = Region::new(0, 10); + assert!(r1.overlaps(r2)); + bs.borrow(r1).expect("can borrow r1"); + assert!(bs.borrow(r2).is_err(), "cant borrow r2"); + + let mut bs = GuestBorrows::new(); + let r1 = Region::new(2, 5); + let r2 = Region::new(0, 10); + assert!(r1.overlaps(r2)); + bs.borrow(r1).expect("can borrow r1"); + assert!(bs.borrow(r2).is_err(), "cant borrow r2"); + + let mut bs = GuestBorrows::new(); + let r1 = Region::new(2, 5); + let r2 = Region::new(10, 5); + let r3 = Region::new(15, 5); + let r4 = Region::new(0, 10); + assert!(r1.overlaps(r4)); + bs.borrow(r1).expect("can borrow r1"); + bs.borrow(r2).expect("can borrow r2"); + bs.borrow(r3).expect("can borrow r3"); + assert!(bs.borrow(r4).is_err(), "cant borrow r4"); + } +} diff --git a/crates/runtime/src/error.rs b/crates/runtime/src/error.rs index 14c48be1b9..8a163dbb63 100644 --- a/crates/runtime/src/error.rs +++ b/crates/runtime/src/error.rs @@ -7,6 +7,8 @@ pub enum GuestError { InvalidFlagValue(&'static str), #[error("Invalid enum value {0}")] InvalidEnumValue(&'static str), + #[error("Pointer overflow")] + PtrOverflow, #[error("Pointer out of bounds: {0:?}")] PtrOutOfBounds(Region), #[error("Pointer not aligned to {1}: {0:?}")] diff --git a/crates/runtime/src/guest_type.rs b/crates/runtime/src/guest_type.rs index c7517bf0ea..981301da46 100644 --- a/crates/runtime/src/guest_type.rs +++ b/crates/runtime/src/guest_type.rs @@ -40,6 +40,22 @@ pub trait GuestType<'a>: Sized { fn write(ptr: &GuestPtr<'_, Self>, val: Self) -> Result<(), GuestError>; } +/// A trait for `GuestType`s that have the same representation in guest memory +/// as in Rust. These types can be used with the `GuestPtr::as_raw` method to +/// view as a slice. +/// +/// Unsafe trait because a correct GuestTypeTransparent implemengation ensures that the +/// GuestPtr::as_raw methods are safe. This trait should only ever be implemented +/// by wiggle_generate-produced code. +pub unsafe trait GuestTypeTransparent<'a>: GuestType<'a> { + /// Checks that the memory at `ptr` is a valid representation of `Self`. + /// + /// Assumes that memory safety checks have already been performed: `ptr` + /// has been checked to be aligned correctly and reside in memory using + /// `GuestMemory::validate_size_align` + fn validate(ptr: *mut Self) -> Result<(), GuestError>; +} + macro_rules! primitives { ($($i:ident)*) => ($( impl<'a> GuestType<'a> for $i { @@ -78,6 +94,15 @@ macro_rules! primitives { Ok(()) } } + + unsafe impl<'a> GuestTypeTransparent<'a> for $i { + #[inline] + fn validate(_ptr: *mut $i) -> Result<(), GuestError> { + // All bit patterns are safe, nothing to do here + Ok(()) + } + } + )*) } diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index b57a5bfe08..a101a19a13 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -6,11 +6,14 @@ use std::slice; use std::str; use std::sync::Arc; +mod borrow; mod error; mod guest_type; mod region; + +pub use borrow::GuestBorrows; pub use error::GuestError; -pub use guest_type::{GuestErrorType, GuestType}; +pub use guest_type::{GuestErrorType, GuestType, GuestTypeTransparent}; pub use region::Region; /// A trait which abstracts how to get at the region of host memory taht @@ -119,12 +122,12 @@ pub unsafe trait GuestMemory { // Figure out our pointer to the start of memory let start = match (base_ptr as usize).checked_add(offset as usize) { Some(ptr) => ptr, - None => return Err(GuestError::PtrOutOfBounds(region)), + None => return Err(GuestError::PtrOverflow), }; // and use that to figure out the end pointer let end = match start.checked_add(len as usize) { Some(ptr) => ptr, - None => return Err(GuestError::PtrOutOfBounds(region)), + None => return Err(GuestError::PtrOverflow), }; // and then verify that our end doesn't reach past the end of our memory if end > (base_ptr as usize) + (base_len as usize) { @@ -335,7 +338,7 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> { .and_then(|o| self.pointer.checked_add(o)); let offset = match offset { Some(o) => o, - None => return Err(GuestError::InvalidFlagValue("")), + None => return Err(GuestError::PtrOverflow), }; Ok(GuestPtr::new(self.mem, offset)) } @@ -369,6 +372,54 @@ impl<'a, T> GuestPtr<'a, [T]> { (0..self.len()).map(move |i| base.add(i)) } + /// Attempts to read a raw `*mut [T]` pointer from this pointer, performing + /// bounds checks and type validation. + /// The resulting `*mut [T]` can be used as a `&mut [t]` as long as the + /// reference is dropped before any Wasm code is re-entered. + /// + /// This function will return a raw pointer into host memory if all checks + /// succeed (valid utf-8, valid pointers, etc). If any checks fail then + /// `GuestError` will be returned. + /// + /// Note that the `*mut [T]` pointer is still unsafe to use in general, but + /// there are specific situations that it is safe to use. For more + /// information about using the raw pointer, consult the [`GuestMemory`] + /// trait documentation. + /// + /// For safety against overlapping mutable borrows, the user must use the + /// same `GuestBorrows` to create all *mut str or *mut [T] that are alive + /// at the same time. + pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut [T], GuestError> + where + T: GuestTypeTransparent<'a>, + { + let len = match self.pointer.1.checked_mul(T::guest_size()) { + Some(l) => l, + None => return Err(GuestError::PtrOverflow), + }; + let ptr = + self.mem + .validate_size_align(self.pointer.0, T::guest_align(), len)? as *mut T; + + bc.borrow(Region { + start: self.pointer.0, + len, + })?; + + // Validate all elements in slice. + // SAFETY: ptr has been validated by self.mem.validate_size_align + for offs in 0..self.pointer.1 { + T::validate(unsafe { ptr.add(offs as usize) })?; + } + + // SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same + // GuestBorrows), its valid to construct a *mut [T] + unsafe { + let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize); + Ok(s as *mut [T]) + } + } + /// Returns a `GuestPtr` pointing to the base of the array for the interior /// type `T`. pub fn as_ptr(&self) -> GuestPtr<'a, T> { @@ -396,6 +447,8 @@ impl<'a> GuestPtr<'a, str> { /// Attempts to read a raw `*mut str` pointer from this pointer, performing /// bounds checks and utf-8 checks. + /// The resulting `*mut str` can be used as a `&mut str` as long as the + /// reference is dropped before any Wasm code is re-entered. /// /// This function will return a raw pointer into host memory if all checks /// succeed (valid utf-8, valid pointers, etc). If any checks fail then @@ -405,12 +458,22 @@ impl<'a> GuestPtr<'a, str> { /// there are specific situations that it is safe to use. For more /// information about using the raw pointer, consult the [`GuestMemory`] /// trait documentation. - pub fn as_raw(&self) -> Result<*mut str, GuestError> { + /// + /// For safety against overlapping mutable borrows, the user must use the + /// same `GuestBorrows` to create all *mut str or *mut [T] that are alive + /// at the same time. + pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut str, GuestError> { let ptr = self .mem .validate_size_align(self.pointer.0, 1, self.pointer.1)?; - // TODO: doc unsafety here + bc.borrow(Region { + start: self.pointer.0, + len: self.pointer.1, + })?; + + // SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same + // GuestBorrows), its valid to construct a *mut str unsafe { let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize); match str::from_utf8_mut(s) { diff --git a/crates/test/src/lib.rs b/crates/test/src/lib.rs index c946df73ba..de4474a528 100644 --- a/crates/test/src/lib.rs +++ b/crates/test/src/lib.rs @@ -2,6 +2,45 @@ use proptest::prelude::*; use std::cell::UnsafeCell; use wiggle_runtime::GuestMemory; +#[derive(Debug, Clone)] +pub struct MemAreas(Vec); +impl MemAreas { + pub fn new() -> Self { + MemAreas(Vec::new()) + } + pub fn insert(&mut self, a: MemArea) { + // Find if `a` is already in the vector + match self.0.binary_search(&a) { + // It is present - insert it next to existing one + Ok(loc) => self.0.insert(loc, a), + // It is not present - heres where to insert it + Err(loc) => self.0.insert(loc, a), + } + } + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } +} + +impl From for MemAreas +where + R: AsRef<[MemArea]>, +{ + fn from(ms: R) -> MemAreas { + let mut out = MemAreas::new(); + for m in ms.as_ref().into_iter() { + out.insert(*m); + } + out + } +} + +impl Into> for MemAreas { + fn into(self) -> Vec { + self.0.clone() + } +} + #[repr(align(4096))] pub struct HostMemory { buffer: UnsafeCell<[u8; 4096]>, @@ -26,6 +65,42 @@ impl HostMemory { }) .boxed() } + + /// Takes a sorted list or memareas, and gives a sorted list of memareas covering + /// the parts of memory not covered by the previous + pub fn invert(regions: &MemAreas) -> MemAreas { + let mut out = MemAreas::new(); + let mut start = 0; + for r in regions.iter() { + let len = r.ptr - start; + if len > 0 { + out.insert(MemArea { + ptr: start, + len: r.ptr - start, + }); + } + start = r.ptr + r.len; + } + if start < 4096 { + out.insert(MemArea { + ptr: start, + len: 4096 - start, + }); + } + out + } + + pub fn byte_slice_strat(size: u32, exclude: &MemAreas) -> BoxedStrategy { + let available: Vec = Self::invert(exclude) + .iter() + .flat_map(|a| a.inside(size)) + .collect(); + + Just(available) + .prop_filter("available memory for allocation", |a| !a.is_empty()) + .prop_flat_map(|a| prop::sample::select(a)) + .boxed() + } } unsafe impl GuestMemory for HostMemory { @@ -37,7 +112,7 @@ unsafe impl GuestMemory for HostMemory { } } -#[derive(Debug)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct MemArea { pub ptr: u32, pub len: u32, @@ -48,7 +123,7 @@ impl MemArea { // test. // So, I implemented this one with std::ops::Range so it is less likely I wrote the same bug in two // places. - pub fn overlapping(&self, b: &Self) -> bool { + pub fn overlapping(&self, b: Self) -> bool { // a_range is all elems in A let a_range = std::ops::Range { start: self.ptr, @@ -73,18 +148,33 @@ impl MemArea { } return false; } - pub fn non_overlapping_set(areas: &[&Self]) -> bool { - // A is all areas - for (i, a) in areas.iter().enumerate() { - // (A, B) is every pair of areas - for b in areas[i + 1..].iter() { - if a.overlapping(b) { - return false; + pub fn non_overlapping_set(areas: M) -> bool + where + M: Into, + { + let areas = areas.into(); + for (aix, a) in areas.iter().enumerate() { + for (bix, b) in areas.iter().enumerate() { + if aix != bix { + // (A, B) is every pairing of areas + if a.overlapping(*b) { + return false; + } } } } return true; } + + /// Enumerate all memareas of size `len` inside a given area + fn inside(&self, len: u32) -> impl Iterator { + let end: i64 = self.len as i64 - len as i64; + let start = self.ptr; + (0..end).into_iter().map(move |v| MemArea { + ptr: start + v as u32, + len, + }) + } } #[cfg(test)] @@ -97,6 +187,104 @@ mod test { let h = Box::new(h); assert_eq!(h.base().0 as usize % 4096, 0); } + + #[test] + fn invert() { + fn invert_equality(input: &[MemArea], expected: &[MemArea]) { + let input: MemAreas = input.into(); + let inverted: Vec = HostMemory::invert(&input).into(); + assert_eq!(expected, inverted.as_slice()); + } + + invert_equality(&[], &[MemArea { ptr: 0, len: 4096 }]); + invert_equality( + &[MemArea { ptr: 0, len: 1 }], + &[MemArea { ptr: 1, len: 4095 }], + ); + + invert_equality( + &[MemArea { ptr: 1, len: 1 }], + &[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 2, len: 4094 }], + ); + + invert_equality( + &[MemArea { ptr: 1, len: 4095 }], + &[MemArea { ptr: 0, len: 1 }], + ); + + invert_equality( + &[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 1, len: 4095 }], + &[], + ); + + invert_equality( + &[MemArea { ptr: 1, len: 2 }, MemArea { ptr: 4, len: 1 }], + &[ + MemArea { ptr: 0, len: 1 }, + MemArea { ptr: 3, len: 1 }, + MemArea { ptr: 5, len: 4091 }, + ], + ); + } + + fn set_of_slices_strat( + s1: u32, + s2: u32, + s3: u32, + ) -> BoxedStrategy<(MemArea, MemArea, MemArea)> { + HostMemory::byte_slice_strat(s1, &MemAreas::new()) + .prop_flat_map(move |a1| { + ( + Just(a1), + HostMemory::byte_slice_strat(s2, &MemAreas::from(&[a1])), + ) + }) + .prop_flat_map(move |(a1, a2)| { + ( + Just(a1), + Just(a2), + HostMemory::byte_slice_strat(s3, &MemAreas::from(&[a1, a2])), + ) + }) + .boxed() + } + + #[test] + fn trivial_inside() { + let a = MemArea { ptr: 24, len: 4072 }; + let interior = a.inside(24).collect::>(); + + assert!(interior.len() > 0); + } + + proptest! { + #[test] + // For some random region of decent size + fn inside(r in HostMemory::mem_area_strat(123)) { + let set_of_r = MemAreas::from(&[r]); + // All regions outside of r: + let exterior = HostMemory::invert(&set_of_r); + // All regions inside of r: + let interior = r.inside(22); + for i in interior { + // i overlaps with r: + assert!(r.overlapping(i)); + // i is inside r: + assert!(i.ptr >= r.ptr); + assert!(r.ptr + r.len >= i.ptr + i.len); + // the set of exterior and i is non-overlapping + let mut all = exterior.clone(); + all.insert(i); + assert!(MemArea::non_overlapping_set(all)); + } + } + + #[test] + fn byte_slices((s1, s2, s3) in set_of_slices_strat(12, 34, 56)) { + let all = MemAreas::from(&[s1, s2, s3]); + assert!(MemArea::non_overlapping_set(all)); + } + } } use std::cell::RefCell; diff --git a/tests/arrays.rs b/tests/arrays.rs index d9ccd4ce93..7e6e7b332f 100644 --- a/tests/arrays.rs +++ b/tests/arrays.rs @@ -67,9 +67,9 @@ impl ReduceExcusesExcercise { }, ) .prop_filter("non-overlapping pointers", |e| { - let mut all = vec![&e.array_ptr_loc, &e.return_ptr_loc]; + let mut all = vec![e.array_ptr_loc, e.return_ptr_loc]; all.extend(e.excuse_ptr_locs.iter()); - MemArea::non_overlapping_set(&all) + MemArea::non_overlapping_set(all) }) .boxed() } @@ -155,9 +155,9 @@ impl PopulateExcusesExcercise { elements, }) .prop_filter("non-overlapping pointers", |e| { - let mut all = vec![&e.array_ptr_loc]; + let mut all = vec![e.array_ptr_loc]; all.extend(e.elements.iter()); - MemArea::non_overlapping_set(&all) + MemArea::non_overlapping_set(all) }) .boxed() } diff --git a/tests/flags.rs b/tests/flags.rs index d008dcb03c..d00c5b84f0 100644 --- a/tests/flags.rs +++ b/tests/flags.rs @@ -57,7 +57,7 @@ impl ConfigureCarExercise { }, ) .prop_filter("non-overlapping ptrs", |e| { - MemArea::non_overlapping_set(&[&e.other_config_by_ptr, &e.return_ptr_loc]) + MemArea::non_overlapping_set(&[e.other_config_by_ptr, e.return_ptr_loc]) }) .boxed() } diff --git a/tests/pointers.rs b/tests/pointers.rs index dc26ed33f6..f15a53349e 100644 --- a/tests/pointers.rs +++ b/tests/pointers.rs @@ -117,10 +117,10 @@ impl PointersAndEnumsExercise { ) .prop_filter("non-overlapping pointers", |e| { MemArea::non_overlapping_set(&[ - &e.input2_loc, - &e.input3_loc, - &e.input4_loc, - &e.input4_ptr_loc, + e.input2_loc, + e.input3_loc, + e.input4_loc, + e.input4_ptr_loc, ]) }) .boxed() diff --git a/tests/strings.rs b/tests/strings.rs index dc6526326e..7cf0badf14 100644 --- a/tests/strings.rs +++ b/tests/strings.rs @@ -1,6 +1,6 @@ use proptest::prelude::*; -use wiggle_runtime::{GuestError, GuestMemory, GuestPtr}; -use wiggle_test::{impl_errno, HostMemory, MemArea, WasiCtx}; +use wiggle_runtime::{GuestBorrows, GuestError, GuestMemory, GuestPtr}; +use wiggle_test::{impl_errno, HostMemory, MemArea, MemAreas, WasiCtx}; wiggle::from_witx!({ witx: ["tests/strings.witx"], @@ -11,12 +11,33 @@ impl_errno!(types::Errno); impl strings::Strings for WasiCtx { fn hello_string(&self, a_string: &GuestPtr) -> Result { - let s = a_string.as_raw().expect("should be valid string"); + let mut bc = GuestBorrows::new(); + let s = a_string.as_raw(&mut bc).expect("should be valid string"); unsafe { println!("a_string='{}'", &*s); Ok((*s).len() as u32) } } + + fn multi_string( + &self, + a: &GuestPtr, + b: &GuestPtr, + c: &GuestPtr, + ) -> Result { + let mut bc = GuestBorrows::new(); + let sa = a.as_raw(&mut bc).expect("A should be valid string"); + let sb = b.as_raw(&mut bc).expect("B should be valid string"); + let sc = c.as_raw(&mut bc).expect("C should be valid string"); + unsafe { + let total_len = (&*sa).len() + (&*sb).len() + (&*sc).len(); + println!( + "len={}, a='{}', b='{}', c='{}'", + total_len, &*sa, &*sb, &*sc + ); + Ok(total_len as u32) + } + } } fn test_string_strategy() -> impl Strategy { @@ -46,7 +67,7 @@ impl HelloStringExercise { return_ptr_loc, }) .prop_filter("non-overlapping pointers", |e| { - MemArea::non_overlapping_set(&[&e.string_ptr_loc, &e.return_ptr_loc]) + MemArea::non_overlapping_set(&[e.string_ptr_loc, e.return_ptr_loc]) }) .boxed() } @@ -85,3 +106,118 @@ proptest! { e.test() } } + +#[derive(Debug)] +struct MultiStringExercise { + a: String, + b: String, + c: String, + sa_ptr_loc: MemArea, + sb_ptr_loc: MemArea, + sc_ptr_loc: MemArea, + return_ptr_loc: MemArea, +} + +impl MultiStringExercise { + pub fn strat() -> BoxedStrategy { + ( + test_string_strategy(), + test_string_strategy(), + test_string_strategy(), + HostMemory::mem_area_strat(4), + ) + .prop_flat_map(|(a, b, c, return_ptr_loc)| { + ( + Just(a.clone()), + Just(b.clone()), + Just(c.clone()), + HostMemory::byte_slice_strat(a.len() as u32, &MemAreas::from([return_ptr_loc])), + Just(return_ptr_loc), + ) + }) + .prop_flat_map(|(a, b, c, sa_ptr_loc, return_ptr_loc)| { + ( + Just(a.clone()), + Just(b.clone()), + Just(c.clone()), + Just(sa_ptr_loc), + HostMemory::byte_slice_strat( + b.len() as u32, + &MemAreas::from([sa_ptr_loc, return_ptr_loc]), + ), + Just(return_ptr_loc), + ) + }) + .prop_flat_map(|(a, b, c, sa_ptr_loc, sb_ptr_loc, return_ptr_loc)| { + ( + Just(a.clone()), + Just(b.clone()), + Just(c.clone()), + Just(sa_ptr_loc), + Just(sb_ptr_loc), + HostMemory::byte_slice_strat( + c.len() as u32, + &MemAreas::from([sa_ptr_loc, sb_ptr_loc, return_ptr_loc]), + ), + Just(return_ptr_loc), + ) + }) + .prop_map( + |(a, b, c, sa_ptr_loc, sb_ptr_loc, sc_ptr_loc, return_ptr_loc)| { + MultiStringExercise { + a, + b, + c, + sa_ptr_loc, + sb_ptr_loc, + sc_ptr_loc, + return_ptr_loc, + } + }, + ) + .boxed() + } + + pub fn test(&self) { + let ctx = WasiCtx::new(); + let host_memory = HostMemory::new(); + + let write_string = |val: &str, loc: MemArea| { + let ptr = host_memory.ptr::((loc.ptr, val.len() as u32)); + for (slot, byte) in ptr.as_bytes().iter().zip(val.bytes()) { + slot.expect("should be valid pointer") + .write(byte) + .expect("failed to write"); + } + }; + + write_string(&self.a, self.sa_ptr_loc); + write_string(&self.b, self.sb_ptr_loc); + write_string(&self.c, self.sc_ptr_loc); + + let res = strings::multi_string( + &ctx, + &host_memory, + self.sa_ptr_loc.ptr as i32, + self.a.len() as i32, + self.sb_ptr_loc.ptr as i32, + self.b.len() as i32, + self.sc_ptr_loc.ptr as i32, + self.c.len() as i32, + self.return_ptr_loc.ptr as i32, + ); + assert_eq!(res, types::Errno::Ok.into(), "multi string errno"); + + let given = host_memory + .ptr::(self.return_ptr_loc.ptr) + .read() + .expect("deref ptr to return value"); + assert_eq!((self.a.len() + self.b.len() + self.c.len()) as u32, given); + } +} +proptest! { + #[test] + fn multi_string(e in MultiStringExercise::strat()) { + e.test() + } +} diff --git a/tests/strings.witx b/tests/strings.witx index ebc0f8bf05..b3531e87bb 100644 --- a/tests/strings.witx +++ b/tests/strings.witx @@ -5,4 +5,12 @@ (result $error $errno) (result $total_bytes u32) ) + + (@interface func (export "multi_string") + (param $a string) + (param $b string) + (param $c string) + (result $error $errno) + (result $total_bytes u32) + ) ) diff --git a/tests/structs.rs b/tests/structs.rs index 420d02561e..69abc4b003 100644 --- a/tests/structs.rs +++ b/tests/structs.rs @@ -72,7 +72,7 @@ impl SumOfPairExercise { return_loc, }) .prop_filter("non-overlapping pointers", |e| { - MemArea::non_overlapping_set(&[&e.input_loc, &e.return_loc]) + MemArea::non_overlapping_set(&[e.input_loc, e.return_loc]) }) .boxed() } @@ -157,10 +157,10 @@ impl SumPairPtrsExercise { ) .prop_filter("non-overlapping pointers", |e| { MemArea::non_overlapping_set(&[ - &e.input_first_loc, - &e.input_second_loc, - &e.input_struct_loc, - &e.return_loc, + e.input_first_loc, + e.input_second_loc, + e.input_struct_loc, + e.return_loc, ]) }) .boxed() @@ -245,11 +245,7 @@ impl SumIntAndPtrExercise { }, ) .prop_filter("non-overlapping pointers", |e| { - MemArea::non_overlapping_set(&[ - &e.input_first_loc, - &e.input_struct_loc, - &e.return_loc, - ]) + MemArea::non_overlapping_set(&[e.input_first_loc, e.input_struct_loc, e.return_loc]) }) .boxed() } @@ -371,11 +367,7 @@ impl ReturnPairPtrsExercise { }, ) .prop_filter("non-overlapping pointers", |e| { - MemArea::non_overlapping_set(&[ - &e.input_first_loc, - &e.input_second_loc, - &e.return_loc, - ]) + MemArea::non_overlapping_set(&[e.input_first_loc, e.input_second_loc, e.return_loc]) }) .boxed() } diff --git a/tests/union.rs b/tests/union.rs index 6081b0d4e7..87f3fb5464 100644 --- a/tests/union.rs +++ b/tests/union.rs @@ -99,7 +99,7 @@ impl GetTagExercise { return_loc, }) .prop_filter("non-overlapping pointers", |e| { - MemArea::non_overlapping_set(&[&e.input_loc, &e.return_loc]) + MemArea::non_overlapping_set(&[e.input_loc, e.return_loc]) }) .boxed() } @@ -176,7 +176,7 @@ impl ReasonMultExercise { }, ) .prop_filter("non-overlapping pointers", |e| { - MemArea::non_overlapping_set(&[&e.input_loc, &e.input_pointee_loc]) + MemArea::non_overlapping_set(&[e.input_loc, e.input_pointee_loc]) }) .boxed() }