Browse Source

Check safety of `as_raw` with a simplified borrow checker (#37)

* wiggle-runtime: add as_raw method for [T]

* add trivial borrow checker back in

* integrate runtime borrow checker with as_raw methods

* handle pointer arith overflow correctly in as_raw, create PtrOverflow error

* runtime: add validation back to GuestType

* generate: impl validate for enums, flags, handles, ints

* oops! make validate its own method on trait GuestTypeTransparent

* fix transparent impls for enum, flag, handle, int

* some structs are transparent. fix tests.

* tests: define byte_slice_strat and friends

* wiggle-tests: i believe my allocator is working now

* some type juggling around memset for ease of use

* make GuestTypeTransparent an unsafe trait

* delete redundant validation of pointer align

* fix doc

* wiggle_test: aha, you cant use sets to track memory areas

* add multi-string test

which exercises the runtime borrow checker against
HostMemory::byte_slice_strat

* oops left debug panic in

* remove redundant (& incorrect, since unchecked) length calc

* redesign validate again, and actually hook to as_raw

* makr all validate impls as inline

this should hopefully allow as_raw's check loop to be unrolled to a
no-op in most cases!

* code review fixes
pull/1278/head
Pat Hickey 5 years ago
committed by GitHub
parent
commit
c78416912c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 27
      crates/generate/src/lifetimes.rs
  2. 16
      crates/generate/src/types/enum.rs
  3. 18
      crates/generate/src/types/flags.rs
  4. 11
      crates/generate/src/types/handle.rs
  5. 10
      crates/generate/src/types/int.rs
  6. 28
      crates/generate/src/types/struct.rs
  7. 91
      crates/runtime/src/borrow.rs
  8. 2
      crates/runtime/src/error.rs
  9. 25
      crates/runtime/src/guest_type.rs
  10. 75
      crates/runtime/src/lib.rs
  11. 206
      crates/test/src/lib.rs
  12. 8
      tests/arrays.rs
  13. 2
      tests/flags.rs
  14. 8
      tests/pointers.rs
  15. 144
      tests/strings.rs
  16. 8
      tests/strings.witx
  17. 22
      tests/structs.rs
  18. 4
      tests/union.rs

27
crates/generate/src/lifetimes.rs

@ -2,16 +2,34 @@ use proc_macro2::TokenStream;
use quote::quote;
pub trait LifetimeExt {
fn is_transparent(&self) -> bool;
fn needs_lifetime(&self) -> bool;
}
impl LifetimeExt for witx::TypeRef {
fn is_transparent(&self) -> bool {
self.type_().is_transparent()
}
fn needs_lifetime(&self) -> bool {
self.type_().needs_lifetime()
}
}
impl LifetimeExt for witx::Type {
fn is_transparent(&self) -> bool {
match self {
witx::Type::Builtin(b) => b.is_transparent(),
witx::Type::Struct(s) => s.is_transparent(),
witx::Type::Enum { .. }
| witx::Type::Flags { .. }
| witx::Type::Int { .. }
| witx::Type::Handle { .. } => true,
witx::Type::Union { .. }
| witx::Type::Pointer { .. }
| witx::Type::ConstPointer { .. }
| witx::Type::Array { .. } => false,
}
}
fn needs_lifetime(&self) -> bool {
match self {
witx::Type::Builtin(b) => b.needs_lifetime(),
@ -29,6 +47,9 @@ impl LifetimeExt for witx::Type {
}
impl LifetimeExt for witx::BuiltinType {
fn is_transparent(&self) -> bool {
!self.needs_lifetime()
}
fn needs_lifetime(&self) -> bool {
match self {
witx::BuiltinType::String => true,
@ -38,12 +59,18 @@ impl LifetimeExt for witx::BuiltinType {
}
impl LifetimeExt for witx::StructDatatype {
fn is_transparent(&self) -> bool {
self.members.iter().all(|m| m.tref.is_transparent())
}
fn needs_lifetime(&self) -> bool {
self.members.iter().any(|m| m.tref.needs_lifetime())
}
}
impl LifetimeExt for witx::UnionDatatype {
fn is_transparent(&self) -> bool {
false
}
fn needs_lifetime(&self) -> bool {
self.variants
.iter()

16
crates/generate/src/types/enum.rs

@ -87,8 +87,9 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
use std::convert::TryFrom;
let val = #repr::read(&location.cast())?;
#ident::try_from(val)
let reprval = #repr::read(&location.cast())?;
let value = #ident::try_from(reprval)?;
Ok(value)
}
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self)
@ -97,5 +98,16 @@ pub(super) fn define_enum(names: &Names, name: &witx::Id, e: &witx::EnumDatatype
#repr::write(&location.cast(), #repr::from(val))
}
}
unsafe impl <'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
#[inline]
fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
use std::convert::TryFrom;
// Validate value in memory using #ident::try_from(reprval)
let reprval = unsafe { (location as *mut #repr).read() };
let _val = #ident::try_from(reprval)?;
Ok(())
}
}
}
}

18
crates/generate/src/types/flags.rs

@ -134,10 +134,11 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty
#repr::guest_align()
}
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
fn read(location: &wiggle_runtime::GuestPtr<#ident>) -> Result<#ident, wiggle_runtime::GuestError> {
use std::convert::TryFrom;
let bits = #repr::read(&location.cast())?;
#ident::try_from(bits)
let reprval = #repr::read(&location.cast())?;
let value = #ident::try_from(reprval)?;
Ok(value)
}
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
@ -145,5 +146,16 @@ pub(super) fn define_flags(names: &Names, name: &witx::Id, f: &witx::FlagsDataty
#repr::write(&location.cast(), val)
}
}
unsafe impl <'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
#[inline]
fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
use std::convert::TryFrom;
// Validate value in memory using #ident::try_from(reprval)
let reprval = unsafe { (location as *mut #repr).read() };
let _val = #ident::try_from(reprval)?;
Ok(())
}
}
}
}

11
crates/generate/src/types/handle.rs

@ -13,6 +13,7 @@ pub(super) fn define_handle(
let size = h.mem_size_align().size as u32;
let align = h.mem_size_align().align as usize;
quote! {
#[repr(transparent)]
#[derive(Copy, Clone, Debug, ::std::hash::Hash, Eq, PartialEq)]
pub struct #ident(u32);
@ -62,5 +63,15 @@ pub(super) fn define_handle(
u32::write(&location.cast(), val.0)
}
}
unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
#[inline]
fn validate(_location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
// All bit patterns accepted
Ok(())
}
}
}
}

10
crates/generate/src/types/int.rs

@ -73,11 +73,21 @@ pub(super) fn define_int(names: &Names, name: &witx::Id, i: &witx::IntDatatype)
fn read(location: &wiggle_runtime::GuestPtr<'a, #ident>) -> Result<#ident, wiggle_runtime::GuestError> {
Ok(#ident(#repr::read(&location.cast())?))
}
fn write(location: &wiggle_runtime::GuestPtr<'_, #ident>, val: Self) -> Result<(), wiggle_runtime::GuestError> {
#repr::write(&location.cast(), val.0)
}
}
unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
#[inline]
fn validate(_location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
// All bit patterns accepted
Ok(())
}
}
}
}

28
crates/generate/src/types/struct.rs

@ -77,6 +77,32 @@ pub(super) fn define_struct(
(quote!(), quote!(, Copy, PartialEq))
};
let transparent = if s.is_transparent() {
let member_validate = s.member_layout().into_iter().map(|ml| {
let offset = ml.offset;
let typename = names.type_ref(&ml.member.tref, anon_lifetime());
quote! {
// SAFETY: caller has validated bounds and alignment of `location`.
// member_layout gives correctly-aligned pointers inside that area.
#typename::validate(
unsafe { (location as *mut u8).add(#offset) as *mut _ }
)?;
}
});
quote! {
unsafe impl<'a> wiggle_runtime::GuestTypeTransparent<'a> for #ident {
#[inline]
fn validate(location: *mut #ident) -> Result<(), wiggle_runtime::GuestError> {
#(#member_validate)*
Ok(())
}
}
}
} else {
quote!()
};
quote! {
#[derive(Clone, Debug #extra_derive)]
pub struct #ident #struct_lifetime {
@ -102,5 +128,7 @@ pub(super) fn define_struct(
Ok(())
}
}
#transparent
}
}

91
crates/runtime/src/borrow.rs

@ -0,0 +1,91 @@
use crate::region::Region;
use crate::GuestError;
#[derive(Debug)]
pub struct GuestBorrows {
borrows: Vec<Region>,
}
impl GuestBorrows {
pub fn new() -> Self {
Self {
borrows: Vec::new(),
}
}
fn is_borrowed(&self, r: Region) -> bool {
!self.borrows.iter().all(|b| !b.overlaps(r))
}
pub fn borrow(&mut self, r: Region) -> Result<(), GuestError> {
if self.is_borrowed(r) {
Err(GuestError::PtrBorrowed(r))
} else {
self.borrows.push(r);
Ok(())
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn nonoverlapping() {
let mut bs = GuestBorrows::new();
let r1 = Region::new(0, 10);
let r2 = Region::new(10, 10);
assert!(!r1.overlaps(r2));
bs.borrow(r1).expect("can borrow r1");
bs.borrow(r2).expect("can borrow r2");
let mut bs = GuestBorrows::new();
let r1 = Region::new(10, 10);
let r2 = Region::new(0, 10);
assert!(!r1.overlaps(r2));
bs.borrow(r1).expect("can borrow r1");
bs.borrow(r2).expect("can borrow r2");
}
#[test]
fn overlapping() {
let mut bs = GuestBorrows::new();
let r1 = Region::new(0, 10);
let r2 = Region::new(9, 10);
assert!(r1.overlaps(r2));
bs.borrow(r1).expect("can borrow r1");
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
let mut bs = GuestBorrows::new();
let r1 = Region::new(0, 10);
let r2 = Region::new(2, 5);
assert!(r1.overlaps(r2));
bs.borrow(r1).expect("can borrow r1");
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
let mut bs = GuestBorrows::new();
let r1 = Region::new(9, 10);
let r2 = Region::new(0, 10);
assert!(r1.overlaps(r2));
bs.borrow(r1).expect("can borrow r1");
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
let mut bs = GuestBorrows::new();
let r1 = Region::new(2, 5);
let r2 = Region::new(0, 10);
assert!(r1.overlaps(r2));
bs.borrow(r1).expect("can borrow r1");
assert!(bs.borrow(r2).is_err(), "cant borrow r2");
let mut bs = GuestBorrows::new();
let r1 = Region::new(2, 5);
let r2 = Region::new(10, 5);
let r3 = Region::new(15, 5);
let r4 = Region::new(0, 10);
assert!(r1.overlaps(r4));
bs.borrow(r1).expect("can borrow r1");
bs.borrow(r2).expect("can borrow r2");
bs.borrow(r3).expect("can borrow r3");
assert!(bs.borrow(r4).is_err(), "cant borrow r4");
}
}

2
crates/runtime/src/error.rs

@ -7,6 +7,8 @@ pub enum GuestError {
InvalidFlagValue(&'static str),
#[error("Invalid enum value {0}")]
InvalidEnumValue(&'static str),
#[error("Pointer overflow")]
PtrOverflow,
#[error("Pointer out of bounds: {0:?}")]
PtrOutOfBounds(Region),
#[error("Pointer not aligned to {1}: {0:?}")]

25
crates/runtime/src/guest_type.rs

@ -40,6 +40,22 @@ pub trait GuestType<'a>: Sized {
fn write(ptr: &GuestPtr<'_, Self>, val: Self) -> Result<(), GuestError>;
}
/// A trait for `GuestType`s that have the same representation in guest memory
/// as in Rust. These types can be used with the `GuestPtr::as_raw` method to
/// view as a slice.
///
/// Unsafe trait because a correct GuestTypeTransparent implemengation ensures that the
/// GuestPtr::as_raw methods are safe. This trait should only ever be implemented
/// by wiggle_generate-produced code.
pub unsafe trait GuestTypeTransparent<'a>: GuestType<'a> {
/// Checks that the memory at `ptr` is a valid representation of `Self`.
///
/// Assumes that memory safety checks have already been performed: `ptr`
/// has been checked to be aligned correctly and reside in memory using
/// `GuestMemory::validate_size_align`
fn validate(ptr: *mut Self) -> Result<(), GuestError>;
}
macro_rules! primitives {
($($i:ident)*) => ($(
impl<'a> GuestType<'a> for $i {
@ -78,6 +94,15 @@ macro_rules! primitives {
Ok(())
}
}
unsafe impl<'a> GuestTypeTransparent<'a> for $i {
#[inline]
fn validate(_ptr: *mut $i) -> Result<(), GuestError> {
// All bit patterns are safe, nothing to do here
Ok(())
}
}
)*)
}

75
crates/runtime/src/lib.rs

@ -6,11 +6,14 @@ use std::slice;
use std::str;
use std::sync::Arc;
mod borrow;
mod error;
mod guest_type;
mod region;
pub use borrow::GuestBorrows;
pub use error::GuestError;
pub use guest_type::{GuestErrorType, GuestType};
pub use guest_type::{GuestErrorType, GuestType, GuestTypeTransparent};
pub use region::Region;
/// A trait which abstracts how to get at the region of host memory taht
@ -119,12 +122,12 @@ pub unsafe trait GuestMemory {
// Figure out our pointer to the start of memory
let start = match (base_ptr as usize).checked_add(offset as usize) {
Some(ptr) => ptr,
None => return Err(GuestError::PtrOutOfBounds(region)),
None => return Err(GuestError::PtrOverflow),
};
// and use that to figure out the end pointer
let end = match start.checked_add(len as usize) {
Some(ptr) => ptr,
None => return Err(GuestError::PtrOutOfBounds(region)),
None => return Err(GuestError::PtrOverflow),
};
// and then verify that our end doesn't reach past the end of our memory
if end > (base_ptr as usize) + (base_len as usize) {
@ -335,7 +338,7 @@ impl<'a, T: ?Sized + Pointee> GuestPtr<'a, T> {
.and_then(|o| self.pointer.checked_add(o));
let offset = match offset {
Some(o) => o,
None => return Err(GuestError::InvalidFlagValue("")),
None => return Err(GuestError::PtrOverflow),
};
Ok(GuestPtr::new(self.mem, offset))
}
@ -369,6 +372,54 @@ impl<'a, T> GuestPtr<'a, [T]> {
(0..self.len()).map(move |i| base.add(i))
}
/// Attempts to read a raw `*mut [T]` pointer from this pointer, performing
/// bounds checks and type validation.
/// The resulting `*mut [T]` can be used as a `&mut [t]` as long as the
/// reference is dropped before any Wasm code is re-entered.
///
/// This function will return a raw pointer into host memory if all checks
/// succeed (valid utf-8, valid pointers, etc). If any checks fail then
/// `GuestError` will be returned.
///
/// Note that the `*mut [T]` pointer is still unsafe to use in general, but
/// there are specific situations that it is safe to use. For more
/// information about using the raw pointer, consult the [`GuestMemory`]
/// trait documentation.
///
/// For safety against overlapping mutable borrows, the user must use the
/// same `GuestBorrows` to create all *mut str or *mut [T] that are alive
/// at the same time.
pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut [T], GuestError>
where
T: GuestTypeTransparent<'a>,
{
let len = match self.pointer.1.checked_mul(T::guest_size()) {
Some(l) => l,
None => return Err(GuestError::PtrOverflow),
};
let ptr =
self.mem
.validate_size_align(self.pointer.0, T::guest_align(), len)? as *mut T;
bc.borrow(Region {
start: self.pointer.0,
len,
})?;
// Validate all elements in slice.
// SAFETY: ptr has been validated by self.mem.validate_size_align
for offs in 0..self.pointer.1 {
T::validate(unsafe { ptr.add(offs as usize) })?;
}
// SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same
// GuestBorrows), its valid to construct a *mut [T]
unsafe {
let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize);
Ok(s as *mut [T])
}
}
/// Returns a `GuestPtr` pointing to the base of the array for the interior
/// type `T`.
pub fn as_ptr(&self) -> GuestPtr<'a, T> {
@ -396,6 +447,8 @@ impl<'a> GuestPtr<'a, str> {
/// Attempts to read a raw `*mut str` pointer from this pointer, performing
/// bounds checks and utf-8 checks.
/// The resulting `*mut str` can be used as a `&mut str` as long as the
/// reference is dropped before any Wasm code is re-entered.
///
/// This function will return a raw pointer into host memory if all checks
/// succeed (valid utf-8, valid pointers, etc). If any checks fail then
@ -405,12 +458,22 @@ impl<'a> GuestPtr<'a, str> {
/// there are specific situations that it is safe to use. For more
/// information about using the raw pointer, consult the [`GuestMemory`]
/// trait documentation.
pub fn as_raw(&self) -> Result<*mut str, GuestError> {
///
/// For safety against overlapping mutable borrows, the user must use the
/// same `GuestBorrows` to create all *mut str or *mut [T] that are alive
/// at the same time.
pub fn as_raw(&self, bc: &mut GuestBorrows) -> Result<*mut str, GuestError> {
let ptr = self
.mem
.validate_size_align(self.pointer.0, 1, self.pointer.1)?;
// TODO: doc unsafety here
bc.borrow(Region {
start: self.pointer.0,
len: self.pointer.1,
})?;
// SAFETY: iff there are no overlapping borrows (all uses of as_raw use this same
// GuestBorrows), its valid to construct a *mut str
unsafe {
let s = slice::from_raw_parts_mut(ptr, self.pointer.1 as usize);
match str::from_utf8_mut(s) {

206
crates/test/src/lib.rs

@ -2,6 +2,45 @@ use proptest::prelude::*;
use std::cell::UnsafeCell;
use wiggle_runtime::GuestMemory;
#[derive(Debug, Clone)]
pub struct MemAreas(Vec<MemArea>);
impl MemAreas {
pub fn new() -> Self {
MemAreas(Vec::new())
}
pub fn insert(&mut self, a: MemArea) {
// Find if `a` is already in the vector
match self.0.binary_search(&a) {
// It is present - insert it next to existing one
Ok(loc) => self.0.insert(loc, a),
// It is not present - heres where to insert it
Err(loc) => self.0.insert(loc, a),
}
}
pub fn iter(&self) -> impl Iterator<Item = &MemArea> {
self.0.iter()
}
}
impl<R> From<R> for MemAreas
where
R: AsRef<[MemArea]>,
{
fn from(ms: R) -> MemAreas {
let mut out = MemAreas::new();
for m in ms.as_ref().into_iter() {
out.insert(*m);
}
out
}
}
impl Into<Vec<MemArea>> for MemAreas {
fn into(self) -> Vec<MemArea> {
self.0.clone()
}
}
#[repr(align(4096))]
pub struct HostMemory {
buffer: UnsafeCell<[u8; 4096]>,
@ -26,6 +65,42 @@ impl HostMemory {
})
.boxed()
}
/// Takes a sorted list or memareas, and gives a sorted list of memareas covering
/// the parts of memory not covered by the previous
pub fn invert(regions: &MemAreas) -> MemAreas {
let mut out = MemAreas::new();
let mut start = 0;
for r in regions.iter() {
let len = r.ptr - start;
if len > 0 {
out.insert(MemArea {
ptr: start,
len: r.ptr - start,
});
}
start = r.ptr + r.len;
}
if start < 4096 {
out.insert(MemArea {
ptr: start,
len: 4096 - start,
});
}
out
}
pub fn byte_slice_strat(size: u32, exclude: &MemAreas) -> BoxedStrategy<MemArea> {
let available: Vec<MemArea> = Self::invert(exclude)
.iter()
.flat_map(|a| a.inside(size))
.collect();
Just(available)
.prop_filter("available memory for allocation", |a| !a.is_empty())
.prop_flat_map(|a| prop::sample::select(a))
.boxed()
}
}
unsafe impl GuestMemory for HostMemory {
@ -37,7 +112,7 @@ unsafe impl GuestMemory for HostMemory {
}
}
#[derive(Debug)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct MemArea {
pub ptr: u32,
pub len: u32,
@ -48,7 +123,7 @@ impl MemArea {
// test.
// So, I implemented this one with std::ops::Range so it is less likely I wrote the same bug in two
// places.
pub fn overlapping(&self, b: &Self) -> bool {
pub fn overlapping(&self, b: Self) -> bool {
// a_range is all elems in A
let a_range = std::ops::Range {
start: self.ptr,
@ -73,18 +148,33 @@ impl MemArea {
}
return false;
}
pub fn non_overlapping_set(areas: &[&Self]) -> bool {
// A is all areas
for (i, a) in areas.iter().enumerate() {
// (A, B) is every pair of areas
for b in areas[i + 1..].iter() {
if a.overlapping(b) {
return false;
pub fn non_overlapping_set<M>(areas: M) -> bool
where
M: Into<MemAreas>,
{
let areas = areas.into();
for (aix, a) in areas.iter().enumerate() {
for (bix, b) in areas.iter().enumerate() {
if aix != bix {
// (A, B) is every pairing of areas
if a.overlapping(*b) {
return false;
}
}
}
}
return true;
}
/// Enumerate all memareas of size `len` inside a given area
fn inside(&self, len: u32) -> impl Iterator<Item = MemArea> {
let end: i64 = self.len as i64 - len as i64;
let start = self.ptr;
(0..end).into_iter().map(move |v| MemArea {
ptr: start + v as u32,
len,
})
}
}
#[cfg(test)]
@ -97,6 +187,104 @@ mod test {
let h = Box::new(h);
assert_eq!(h.base().0 as usize % 4096, 0);
}
#[test]
fn invert() {
fn invert_equality(input: &[MemArea], expected: &[MemArea]) {
let input: MemAreas = input.into();
let inverted: Vec<MemArea> = HostMemory::invert(&input).into();
assert_eq!(expected, inverted.as_slice());
}
invert_equality(&[], &[MemArea { ptr: 0, len: 4096 }]);
invert_equality(
&[MemArea { ptr: 0, len: 1 }],
&[MemArea { ptr: 1, len: 4095 }],
);
invert_equality(
&[MemArea { ptr: 1, len: 1 }],
&[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 2, len: 4094 }],
);
invert_equality(
&[MemArea { ptr: 1, len: 4095 }],
&[MemArea { ptr: 0, len: 1 }],
);
invert_equality(
&[MemArea { ptr: 0, len: 1 }, MemArea { ptr: 1, len: 4095 }],
&[],
);
invert_equality(
&[MemArea { ptr: 1, len: 2 }, MemArea { ptr: 4, len: 1 }],
&[
MemArea { ptr: 0, len: 1 },
MemArea { ptr: 3, len: 1 },
MemArea { ptr: 5, len: 4091 },
],
);
}
fn set_of_slices_strat(
s1: u32,
s2: u32,
s3: u32,
) -> BoxedStrategy<(MemArea, MemArea, MemArea)> {
HostMemory::byte_slice_strat(s1, &MemAreas::new())
.prop_flat_map(move |a1| {
(
Just(a1),
HostMemory::byte_slice_strat(s2, &MemAreas::from(&[a1])),
)
})
.prop_flat_map(move |(a1, a2)| {
(
Just(a1),
Just(a2),
HostMemory::byte_slice_strat(s3, &MemAreas::from(&[a1, a2])),
)
})
.boxed()
}
#[test]
fn trivial_inside() {
let a = MemArea { ptr: 24, len: 4072 };
let interior = a.inside(24).collect::<Vec<_>>();
assert!(interior.len() > 0);
}
proptest! {
#[test]
// For some random region of decent size
fn inside(r in HostMemory::mem_area_strat(123)) {
let set_of_r = MemAreas::from(&[r]);
// All regions outside of r:
let exterior = HostMemory::invert(&set_of_r);
// All regions inside of r:
let interior = r.inside(22);
for i in interior {
// i overlaps with r:
assert!(r.overlapping(i));
// i is inside r:
assert!(i.ptr >= r.ptr);
assert!(r.ptr + r.len >= i.ptr + i.len);
// the set of exterior and i is non-overlapping
let mut all = exterior.clone();
all.insert(i);
assert!(MemArea::non_overlapping_set(all));
}
}
#[test]
fn byte_slices((s1, s2, s3) in set_of_slices_strat(12, 34, 56)) {
let all = MemAreas::from(&[s1, s2, s3]);
assert!(MemArea::non_overlapping_set(all));
}
}
}
use std::cell::RefCell;

8
tests/arrays.rs

@ -67,9 +67,9 @@ impl ReduceExcusesExcercise {
},
)
.prop_filter("non-overlapping pointers", |e| {
let mut all = vec![&e.array_ptr_loc, &e.return_ptr_loc];
let mut all = vec![e.array_ptr_loc, e.return_ptr_loc];
all.extend(e.excuse_ptr_locs.iter());
MemArea::non_overlapping_set(&all)
MemArea::non_overlapping_set(all)
})
.boxed()
}
@ -155,9 +155,9 @@ impl PopulateExcusesExcercise {
elements,
})
.prop_filter("non-overlapping pointers", |e| {
let mut all = vec![&e.array_ptr_loc];
let mut all = vec![e.array_ptr_loc];
all.extend(e.elements.iter());
MemArea::non_overlapping_set(&all)
MemArea::non_overlapping_set(all)
})
.boxed()
}

2
tests/flags.rs

@ -57,7 +57,7 @@ impl ConfigureCarExercise {
},
)
.prop_filter("non-overlapping ptrs", |e| {
MemArea::non_overlapping_set(&[&e.other_config_by_ptr, &e.return_ptr_loc])
MemArea::non_overlapping_set(&[e.other_config_by_ptr, e.return_ptr_loc])
})
.boxed()
}

8
tests/pointers.rs

@ -117,10 +117,10 @@ impl PointersAndEnumsExercise {
)
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[
&e.input2_loc,
&e.input3_loc,
&e.input4_loc,
&e.input4_ptr_loc,
e.input2_loc,
e.input3_loc,
e.input4_loc,
e.input4_ptr_loc,
])
})
.boxed()

144
tests/strings.rs

@ -1,6 +1,6 @@
use proptest::prelude::*;
use wiggle_runtime::{GuestError, GuestMemory, GuestPtr};
use wiggle_test::{impl_errno, HostMemory, MemArea, WasiCtx};
use wiggle_runtime::{GuestBorrows, GuestError, GuestMemory, GuestPtr};
use wiggle_test::{impl_errno, HostMemory, MemArea, MemAreas, WasiCtx};
wiggle::from_witx!({
witx: ["tests/strings.witx"],
@ -11,12 +11,33 @@ impl_errno!(types::Errno);
impl strings::Strings for WasiCtx {
fn hello_string(&self, a_string: &GuestPtr<str>) -> Result<u32, types::Errno> {
let s = a_string.as_raw().expect("should be valid string");
let mut bc = GuestBorrows::new();
let s = a_string.as_raw(&mut bc).expect("should be valid string");
unsafe {
println!("a_string='{}'", &*s);
Ok((*s).len() as u32)
}
}
fn multi_string(
&self,
a: &GuestPtr<str>,
b: &GuestPtr<str>,
c: &GuestPtr<str>,
) -> Result<u32, types::Errno> {
let mut bc = GuestBorrows::new();
let sa = a.as_raw(&mut bc).expect("A should be valid string");
let sb = b.as_raw(&mut bc).expect("B should be valid string");
let sc = c.as_raw(&mut bc).expect("C should be valid string");
unsafe {
let total_len = (&*sa).len() + (&*sb).len() + (&*sc).len();
println!(
"len={}, a='{}', b='{}', c='{}'",
total_len, &*sa, &*sb, &*sc
);
Ok(total_len as u32)
}
}
}
fn test_string_strategy() -> impl Strategy<Value = String> {
@ -46,7 +67,7 @@ impl HelloStringExercise {
return_ptr_loc,
})
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[&e.string_ptr_loc, &e.return_ptr_loc])
MemArea::non_overlapping_set(&[e.string_ptr_loc, e.return_ptr_loc])
})
.boxed()
}
@ -85,3 +106,118 @@ proptest! {
e.test()
}
}
#[derive(Debug)]
struct MultiStringExercise {
a: String,
b: String,
c: String,
sa_ptr_loc: MemArea,
sb_ptr_loc: MemArea,
sc_ptr_loc: MemArea,
return_ptr_loc: MemArea,
}
impl MultiStringExercise {
pub fn strat() -> BoxedStrategy<Self> {
(
test_string_strategy(),
test_string_strategy(),
test_string_strategy(),
HostMemory::mem_area_strat(4),
)
.prop_flat_map(|(a, b, c, return_ptr_loc)| {
(
Just(a.clone()),
Just(b.clone()),
Just(c.clone()),
HostMemory::byte_slice_strat(a.len() as u32, &MemAreas::from([return_ptr_loc])),
Just(return_ptr_loc),
)
})
.prop_flat_map(|(a, b, c, sa_ptr_loc, return_ptr_loc)| {
(
Just(a.clone()),
Just(b.clone()),
Just(c.clone()),
Just(sa_ptr_loc),
HostMemory::byte_slice_strat(
b.len() as u32,
&MemAreas::from([sa_ptr_loc, return_ptr_loc]),
),
Just(return_ptr_loc),
)
})
.prop_flat_map(|(a, b, c, sa_ptr_loc, sb_ptr_loc, return_ptr_loc)| {
(
Just(a.clone()),
Just(b.clone()),
Just(c.clone()),
Just(sa_ptr_loc),
Just(sb_ptr_loc),
HostMemory::byte_slice_strat(
c.len() as u32,
&MemAreas::from([sa_ptr_loc, sb_ptr_loc, return_ptr_loc]),
),
Just(return_ptr_loc),
)
})
.prop_map(
|(a, b, c, sa_ptr_loc, sb_ptr_loc, sc_ptr_loc, return_ptr_loc)| {
MultiStringExercise {
a,
b,
c,
sa_ptr_loc,
sb_ptr_loc,
sc_ptr_loc,
return_ptr_loc,
}
},
)
.boxed()
}
pub fn test(&self) {
let ctx = WasiCtx::new();
let host_memory = HostMemory::new();
let write_string = |val: &str, loc: MemArea| {
let ptr = host_memory.ptr::<str>((loc.ptr, val.len() as u32));
for (slot, byte) in ptr.as_bytes().iter().zip(val.bytes()) {
slot.expect("should be valid pointer")
.write(byte)
.expect("failed to write");
}
};
write_string(&self.a, self.sa_ptr_loc);
write_string(&self.b, self.sb_ptr_loc);
write_string(&self.c, self.sc_ptr_loc);
let res = strings::multi_string(
&ctx,
&host_memory,
self.sa_ptr_loc.ptr as i32,
self.a.len() as i32,
self.sb_ptr_loc.ptr as i32,
self.b.len() as i32,
self.sc_ptr_loc.ptr as i32,
self.c.len() as i32,
self.return_ptr_loc.ptr as i32,
);
assert_eq!(res, types::Errno::Ok.into(), "multi string errno");
let given = host_memory
.ptr::<u32>(self.return_ptr_loc.ptr)
.read()
.expect("deref ptr to return value");
assert_eq!((self.a.len() + self.b.len() + self.c.len()) as u32, given);
}
}
proptest! {
#[test]
fn multi_string(e in MultiStringExercise::strat()) {
e.test()
}
}

8
tests/strings.witx

@ -5,4 +5,12 @@
(result $error $errno)
(result $total_bytes u32)
)
(@interface func (export "multi_string")
(param $a string)
(param $b string)
(param $c string)
(result $error $errno)
(result $total_bytes u32)
)
)

22
tests/structs.rs

@ -72,7 +72,7 @@ impl SumOfPairExercise {
return_loc,
})
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[&e.input_loc, &e.return_loc])
MemArea::non_overlapping_set(&[e.input_loc, e.return_loc])
})
.boxed()
}
@ -157,10 +157,10 @@ impl SumPairPtrsExercise {
)
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[
&e.input_first_loc,
&e.input_second_loc,
&e.input_struct_loc,
&e.return_loc,
e.input_first_loc,
e.input_second_loc,
e.input_struct_loc,
e.return_loc,
])
})
.boxed()
@ -245,11 +245,7 @@ impl SumIntAndPtrExercise {
},
)
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[
&e.input_first_loc,
&e.input_struct_loc,
&e.return_loc,
])
MemArea::non_overlapping_set(&[e.input_first_loc, e.input_struct_loc, e.return_loc])
})
.boxed()
}
@ -371,11 +367,7 @@ impl ReturnPairPtrsExercise {
},
)
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[
&e.input_first_loc,
&e.input_second_loc,
&e.return_loc,
])
MemArea::non_overlapping_set(&[e.input_first_loc, e.input_second_loc, e.return_loc])
})
.boxed()
}

4
tests/union.rs

@ -99,7 +99,7 @@ impl GetTagExercise {
return_loc,
})
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[&e.input_loc, &e.return_loc])
MemArea::non_overlapping_set(&[e.input_loc, e.return_loc])
})
.boxed()
}
@ -176,7 +176,7 @@ impl ReasonMultExercise {
},
)
.prop_filter("non-overlapping pointers", |e| {
MemArea::non_overlapping_set(&[&e.input_loc, &e.input_pointee_loc])
MemArea::non_overlapping_set(&[e.input_loc, e.input_pointee_loc])
})
.boxed()
}

Loading…
Cancel
Save