From 893fadb485717db6afaa9e2571221d4cb4a9bfc6 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 1 Aug 2022 11:05:09 -0500 Subject: [PATCH] components: Fix support for 0-sized flags (#4560) This commit goes through and updates support in the various argument passing routines to support 0-sized flags. A bit of a degenerate case but clarified in WebAssembly/component-model#76 as intentional. --- crates/component-macro/src/lib.rs | 17 +++++++++++++++++ crates/component-util/src/lib.rs | 6 +++++- .../fuzz/fuzz_targets/fact-valid-module.rs | 6 ++---- crates/environ/src/fact/trampoline.rs | 1 + crates/wasmtime/src/component/types.rs | 4 ++++ crates/wasmtime/src/component/values.rs | 3 +++ tests/all/component_model.rs | 4 +++- tests/all/component_model/macros.rs | 19 ++++++++++++++++--- .../misc_testsuite/component-model/fused.wast | 8 ++++++++ 9 files changed, 59 insertions(+), 9 deletions(-) diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index e782ff96a0..eb85f4010c 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -352,6 +352,7 @@ fn expand_record_for_component_type( #[repr(C)] pub struct #lower <#lower_generic_params> { #lower_field_declarations + _align: [wasmtime::ValRaw; 0], } unsafe impl #impl_generics wasmtime::component::ComponentType for #name #ty_generics #where_clause { @@ -965,6 +966,10 @@ fn expand_flags(flags: &Flags) -> Result { let count = flags.flags.len(); match size { + FlagsSize::Size0 => { + ty = quote!(()); + eq = quote!(true); + } FlagsSize::Size1 => { ty = quote!(u8); @@ -1021,6 +1026,17 @@ fn expand_flags(flags: &Flags) -> Result { let mut not; match size { + FlagsSize::Size0 => { + count = 0; + as_array = quote!([]); + bitor = quote!(Self {}); + bitor_assign = quote!(); + bitand = quote!(Self {}); + bitand_assign = quote!(); + bitxor = quote!(Self {}); + bitxor_assign = quote!(); + not = quote!(Self {}); + } FlagsSize::Size1 | FlagsSize::Size2 => { count = 1; as_array = quote!([self.__inner0 as u32]); @@ -1085,6 +1101,7 @@ fn expand_flags(flags: &Flags) -> Result { component_names.extend(quote!(#component_name,)); let fields = match size { + FlagsSize::Size0 => quote!(), FlagsSize::Size1 => { let init = 1_u8 << index; quote!(__inner0: #init) diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs index c59c39040f..3823abcedc 100644 --- a/crates/component-util/src/lib.rs +++ b/crates/component-util/src/lib.rs @@ -48,6 +48,8 @@ impl From for usize { /// Represents the number of bytes required to store a flags value in the component model pub enum FlagsSize { + /// There are no flags + Size0, /// Flags can fit in a u8 Size1, /// Flags can fit in a u16 @@ -59,7 +61,9 @@ pub enum FlagsSize { impl FlagsSize { /// Calculate the size needed to represent a value with the specified number of flags. pub fn from_count(count: usize) -> FlagsSize { - if count <= 8 { + if count == 0 { + FlagsSize::Size0 + } else if count <= 8 { FlagsSize::Size1 } else if count <= 16 { FlagsSize::Size2 diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs index 5d183a82f9..3aaafca092 100644 --- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs +++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs @@ -53,10 +53,8 @@ enum ValType { Float64, Char, Record(Vec), - // FIXME(WebAssembly/component-model#75) are zero-sized flags allowed? - // - // ... otherwise go up to 65 flags to exercise up to 3 u32 values - Flags(UsizeInRange<1, 65>), + // Up to 65 flags to exercise up to 3 u32 values + Flags(UsizeInRange<0, 65>), Tuple(Vec), Variant(NonZeroLenVec), Union(NonZeroLenVec), diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index b50000208a..8276af892a 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -700,6 +700,7 @@ impl Compiler<'_, '_> { assert_eq!(src_ty.names, dst_ty.names); let cnt = src_ty.names.len(); match FlagsSize::from_count(cnt) { + FlagsSize::Size0 => {} FlagsSize::Size1 => { let mask = if cnt == 8 { 0xff } else { (1 << cnt) - 1 }; self.convert_u8_mask(src, dst, mask); diff --git a/crates/wasmtime/src/component/types.rs b/crates/wasmtime/src/component/types.rs index 9d05866125..e90fe6e166 100644 --- a/crates/wasmtime/src/component/types.rs +++ b/crates/wasmtime/src/component/types.rs @@ -606,6 +606,10 @@ impl Type { } Type::Flags(handle) => match FlagsSize::from_count(handle.names().len()) { + FlagsSize::Size0 => SizeAndAlignment { + size: 0, + alignment: 1, + }, FlagsSize::Size1 => SizeAndAlignment { size: 1, alignment: 1, diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs index 4588c08a58..3ef855ecc2 100644 --- a/crates/wasmtime/src/component/values.rs +++ b/crates/wasmtime/src/component/values.rs @@ -699,6 +699,7 @@ impl Val { ty: handle.clone(), count: u32::try_from(handle.names().len())?, value: match FlagsSize::from_count(handle.names().len()) { + FlagsSize::Size0 => Box::new([]), FlagsSize::Size1 => iter::once(u8::load(mem, bytes)? as u32).collect(), FlagsSize::Size2 => iter::once(u16::load(mem, bytes)? as u32).collect(), FlagsSize::Size4Plus(n) => (0..n) @@ -850,6 +851,7 @@ impl Val { Val::Flags(Flags { count, value, .. }) => { match FlagsSize::from_count(*count as usize) { + FlagsSize::Size0 => {} FlagsSize::Size1 => u8::try_from(value[0]).unwrap().store(mem, offset)?, FlagsSize::Size2 => u16::try_from(value[0]).unwrap().store(mem, offset)?, FlagsSize::Size4Plus(_) => { @@ -1018,6 +1020,7 @@ fn lower_list( /// Note that this will always return at least 1, even if the `count` parameter is zero. pub(crate) fn u32_count_for_flag_count(count: usize) -> usize { match FlagsSize::from_count(count) { + FlagsSize::Size0 => 0, FlagsSize::Size1 | FlagsSize::Size2 => 1, FlagsSize::Size4Plus(n) => n, } diff --git a/tests/all/component_model.rs b/tests/all/component_model.rs index 2c5a0f7e63..4ebb3b6125 100644 --- a/tests/all/component_model.rs +++ b/tests/all/component_model.rs @@ -210,7 +210,9 @@ fn make_echo_component(type_definition: &str, type_size: u32) -> String { } fn make_echo_component_with_params(type_definition: &str, params: &[Param]) -> String { - let func = if params.len() == 1 || params.len() > 16 { + let func = if params.len() == 0 { + format!("(func (export \"echo\"))") + } else if params.len() == 1 || params.len() > 16 { let primitive = if params.len() == 1 { params[0].0.primitive() } else { diff --git a/tests/all/component_model/macros.rs b/tests/all/component_model/macros.rs index e8dd38baee..190d6062c7 100644 --- a/tests/all/component_model/macros.rs +++ b/tests/all/component_model/macros.rs @@ -426,6 +426,22 @@ fn enum_derive() -> Result<()> { #[test] fn flags() -> Result<()> { + let engine = super::engine(); + let mut store = Store::new(&engine, ()); + + // Edge case of 0 flags + wasmtime::component::flags! { + Flags0 {} + } + assert_eq!(Flags0::default(), Flags0::default()); + + let component = Component::new(&engine, make_echo_component(r#"(flags)"#, 0))?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Flags0,), Flags0, _>(&mut store, "echo")?; + let output = func.call_and_post_return(&mut store, (Flags0::default(),))?; + assert_eq!(output, Flags0::default()); + + // Simple 8-bit flags wasmtime::component::flags! { Foo { #[component(name = "foo-bar-baz")] @@ -442,9 +458,6 @@ fn flags() -> Result<()> { assert_eq!(Foo::default(), Foo::A ^ Foo::A); assert_eq!(Foo::B | Foo::C, !Foo::A); - let engine = super::engine(); - let mut store = Store::new(&engine, ()); - // Happy path: component type matches flag count and names let component = Component::new( diff --git a/tests/misc_testsuite/component-model/fused.wast b/tests/misc_testsuite/component-model/fused.wast index fbab3d704e..6d326da6e2 100644 --- a/tests/misc_testsuite/component-model/fused.wast +++ b/tests/misc_testsuite/component-model/fused.wast @@ -1228,6 +1228,7 @@ ;; test that flags get their upper bits all masked off (component + (type $f0 (flags)) (type $f1 (flags "f1")) (type $f8 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8")) (type $f9 (flags "f1" "f2" "f3" "f4" "f5" "f6" "f7" "f8" "f9")) @@ -1277,6 +1278,7 @@ (component $c1 (core module $m + (func (export "f0")) (func (export "f1") (param i32) (if (i32.ne (local.get 0) (i32.const 0x1)) (unreachable)) ) @@ -1310,6 +1312,7 @@ ) ) (core instance $m (instantiate $m)) + (func (export "f0") (param $f0) (canon lift (core func $m "f0"))) (func (export "f1") (param $f1) (canon lift (core func $m "f1"))) (func (export "f8") (param $f8) (canon lift (core func $m "f8"))) (func (export "f9") (param $f9) (canon lift (core func $m "f9"))) @@ -1324,6 +1327,7 @@ (component $c2 (import "" (instance $i + (export "f0" (func (param $f0))) (export "f1" (func (param $f1))) (export "f8" (func (param $f8))) (export "f9" (func (param $f9))) @@ -1334,6 +1338,7 @@ (export "f64" (func (param $f64))) (export "f65" (func (param $f65))) )) + (core func $f0 (canon lower (func $i "f0"))) (core func $f1 (canon lower (func $i "f1"))) (core func $f8 (canon lower (func $i "f8"))) (core func $f9 (canon lower (func $i "f9"))) @@ -1345,6 +1350,7 @@ (core func $f65 (canon lower (func $i "f65"))) (core module $m + (import "" "f0" (func $f0)) (import "" "f1" (func $f1 (param i32))) (import "" "f8" (func $f8 (param i32))) (import "" "f9" (func $f9 (param i32))) @@ -1356,6 +1362,7 @@ (import "" "f65" (func $f65 (param i32 i32 i32))) (func $start + (call $f0) (call $f1 (i32.const 0xffffff01)) (call $f8 (i32.const 0xffffff11)) (call $f9 (i32.const 0xffffff11)) @@ -1371,6 +1378,7 @@ ) (core instance $m (instantiate $m (with "" (instance + (export "f0" (func $f0)) (export "f1" (func $f1)) (export "f8" (func $f8)) (export "f9" (func $f9))