From 5542c4ef268eb62bc8da7e1c351511b4ce9d04d1 Mon Sep 17 00:00:00 2001 From: Joel Dice Date: Tue, 5 Jul 2022 09:36:43 -0600 Subject: [PATCH] support enums with more than 256 variants in derive macro (#4370) * support enums with more than 256 variants in derive macro This addresses #4361. Technically, we now support up to 2^32 variants, which is the maximum for the canonical ABI. In practice, though, the derived code for enums with even just 2^16 variants takes a prohibitively long time to compile. Signed-off-by: Joel Dice * simplify `LowerExpander::expand_variant` code Signed-off-by: Joel Dice --- Cargo.lock | 10 ++ Cargo.toml | 1 + crates/component-macro/src/lib.rs | 121 +++++++++++++++----- crates/misc/component-macro-test/Cargo.toml | 15 +++ crates/misc/component-macro-test/src/lib.rs | 34 ++++++ tests/all/component_model/macros.rs | 37 ++++++ 6 files changed, 190 insertions(+), 28 deletions(-) create mode 100644 crates/misc/component-macro-test/Cargo.toml create mode 100644 crates/misc/component-macro-test/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 04c2b74f10..4717e16e90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -450,6 +450,15 @@ dependencies = [ "cc", ] +[[package]] +name = "component-macro-test" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "console" version = "0.15.0" @@ -3413,6 +3422,7 @@ dependencies = [ "anyhow", "async-trait", "clap 3.1.15", + "component-macro-test", "criterion", "env_logger 0.9.0", "filecheck", diff --git a/Cargo.toml b/Cargo.toml index eb1281ddf1..74e73f1954 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ async-trait = "0.1" wat = "1.0.43" once_cell = "1.9.0" rayon = "1.5.0" +component-macro-test = { path = "crates/misc/component-macro-test" } [target.'cfg(windows)'.dev-dependencies] windows-sys = { version = "0.36.0", features = ["Win32_System_Memory"] } diff --git a/crates/component-macro/src/lib.rs b/crates/component-macro/src/lib.rs index 
b956c9e318..878c53f681 100644 --- a/crates/component-macro/src/lib.rs +++ b/crates/component-macro/src/lib.rs @@ -145,6 +145,54 @@ fn add_trait_bounds(generics: &syn::Generics, bound: syn::TypeParamBound) -> syn generics } +#[derive(Debug, Copy, Clone)] +enum DiscriminantSize { + Size1, + Size2, + Size4, +} + +impl DiscriminantSize { + fn quote(self, discriminant: usize) -> TokenStream { + match self { + Self::Size1 => { + let discriminant = u8::try_from(discriminant).unwrap(); + quote!(#discriminant) + } + Self::Size2 => { + let discriminant = u16::try_from(discriminant).unwrap(); + quote!(#discriminant) + } + Self::Size4 => { + let discriminant = u32::try_from(discriminant).unwrap(); + quote!(#discriminant) + } + } + } +} + +impl From<DiscriminantSize> for u32 { + fn from(size: DiscriminantSize) -> u32 { + match size { + DiscriminantSize::Size1 => 1, + DiscriminantSize::Size2 => 2, + DiscriminantSize::Size4 => 4, + } + } +} + +fn discriminant_size(case_count: usize) -> Option<DiscriminantSize> { + if case_count <= 0xFF { + Some(DiscriminantSize::Size1) + } else if case_count <= 0xFFFF { + Some(DiscriminantSize::Size2) + } else if case_count <= 0xFFFF_FFFF { + Some(DiscriminantSize::Size4) + } else { + None + } +} + struct VariantCase<'a> { attrs: &'a [syn::Attribute], ident: &'a syn::Ident, @@ -157,6 +205,7 @@ trait Expander { fn expand_variant( &self, input: &DeriveInput, + discriminant_size: DiscriminantSize, cases: &[VariantCase], style: VariantStyle, ) -> Result<TokenStream>; @@ -217,6 +266,13 @@ fn expand_variant( )); } + let discriminant_size = discriminant_size(body.variants.len()).ok_or_else(|| { + Error::new( + input.ident.span(), + "`enum`s with more than 2^32 variants are not supported", + ) + })?; + let cases = body .variants .iter() @@ -240,8 +296,13 @@ fn expand_variant( name.span(), format!( "`{}` component types can only be derived for Rust `enum`s \ - containing variants with at most one unnamed field each", - style + containing variants with {}", + style, + match style { + 
VariantStyle::Variant => "at most one unnamed field each", + VariantStyle::Enum => "no fields", + VariantStyle::Union => "exactly one unnamed field each", + } ), )) } @@ -251,7 +312,7 @@ fn expand_variant( ) .collect::<Result<Vec<_>>>()?; - expander.expand_variant(input, &cases, style) + expander.expand_variant(input, discriminant_size, &cases, style) } #[proc_macro_derive(Lift, attributes(component))] @@ -321,6 +382,7 @@ impl Expander for LiftExpander { fn expand_variant( &self, input: &DeriveInput, + discriminant_size: DiscriminantSize, cases: &[VariantCase], _style: VariantStyle, ) -> Result<TokenStream> { @@ -330,31 +392,26 @@ impl Expander for LiftExpander { let mut loads = TokenStream::new(); for (index, VariantCase { ident, ty, .. }) in cases.iter().enumerate() { - let index_u8 = u8::try_from(index).map_err(|_| { - Error::new( - input.ident.span(), - "`enum`s with more than 256 variants not yet supported", - ) - })?; + let index_u32 = u32::try_from(index).unwrap(); - let index_i32 = index_u8 as i32; + let index_quoted = discriminant_size.quote(index); if let Some(ty) = ty { lifts.extend( - quote!(#index_i32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift( + quote!(#index_u32 => Self::#ident(<#ty as wasmtime::component::Lift>::lift( store, options, unsafe { &src.payload.#ident } )?),), ); loads.extend( - quote!(#index_u8 => Self::#ident(<#ty as wasmtime::component::Lift>::load( + quote!(#index_quoted => Self::#ident(<#ty as wasmtime::component::Lift>::load( memory, &payload[..<#ty as wasmtime::component::ComponentType>::SIZE32] )?),), ); } else { - lifts.extend(quote!(#index_i32 => Self::#ident,)); + lifts.extend(quote!(#index_u32 => Self::#ident,)); - loads.extend(quote!(#index_u8 => Self::#ident,)); + loads.extend(quote!(#index_quoted => Self::#ident,)); } } @@ -362,6 +419,14 @@ impl Expander for LiftExpander { let generics = add_trait_bounds(&input.generics, parse_quote!(wasmtime::component::Lift)); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); 
+ let from_bytes = match discriminant_size { + DiscriminantSize::Size1 => quote!(bytes[0]), + DiscriminantSize::Size2 => quote!(u16::from_le_bytes(bytes[0..2].try_into()?)), + DiscriminantSize::Size4 => quote!(u32::from_le_bytes(bytes[0..4].try_into()?)), + }; + + let payload_offset = u32::from(discriminant_size) as usize; + let expanded = quote! { unsafe impl #impl_generics wasmtime::component::Lift for #name #ty_generics #where_clause { #[inline] @@ -370,7 +435,7 @@ impl Expander for LiftExpander { options: &#internal::Options, src: &Self::Lower, ) -> #internal::anyhow::Result<Self> { - Ok(match src.tag.get_i32() { + Ok(match src.tag.get_u32() { #lifts discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim), }) @@ -380,8 +445,8 @@ impl Expander for LiftExpander { fn load(memory: &#internal::Memory, bytes: &[u8]) -> #internal::anyhow::Result<Self> { let align = <Self as wasmtime::component::ComponentType>::ALIGN32; debug_assert!((bytes.as_ptr() as usize) % (align as usize) == 0); - let discrim = bytes[0]; - let payload = &bytes[#internal::align_to(1, align)..]; + let discrim = #from_bytes; + let payload = &bytes[#internal::align_to(#payload_offset, align)..]; Ok(match discrim { #loads discrim => #internal::anyhow::bail!("unexpected discriminant: {}", discrim), @@ -456,6 +521,7 @@ impl Expander for LowerExpander { fn expand_variant( &self, input: &DeriveInput, + discriminant_size: DiscriminantSize, cases: &[VariantCase], _style: VariantStyle, ) -> Result<TokenStream> { @@ -465,14 +531,9 @@ impl Expander for LowerExpander { let mut stores = TokenStream::new(); for (index, VariantCase { ident, ty, .. 
}) in cases.iter().enumerate() { - let index_u8 = u8::try_from(index).map_err(|_| { - Error::new( - input.ident.span(), - "`enum`s with more than 256 variants not yet supported", - ) - })?; + let index_u32 = u32::try_from(index).unwrap(); - let index_i32 = index_u8 as i32; + let index_quoted = discriminant_size.quote(index); let pattern; let lower; @@ -492,12 +553,14 @@ impl Expander for LowerExpander { } lowers.extend(quote!(#pattern => { - #internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_i32)); + #internal::map_maybe_uninit!(dst.tag).write(wasmtime::ValRaw::i32(#index_u32 as i32)); #lower })); + let discriminant_size = u32::from(discriminant_size) as usize; + stores.extend(quote!(#pattern => { - memory.get::<1>(offset)[0] = #index_u8; + *memory.get::<#discriminant_size>(offset) = #index_quoted.to_le_bytes(); #store })); } @@ -668,6 +731,7 @@ impl Expander for ComponentTypeExpander { fn expand_variant( &self, input: &DeriveInput, + discriminant_size: DiscriminantSize, cases: &[VariantCase], style: VariantStyle, ) -> Result<TokenStream> { @@ -766,6 +830,7 @@ impl Expander for ComponentTypeExpander { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let lower = format_ident!("Lower{}", name); let lower_payload = format_ident!("LowerPayload{}", name); + let discriminant_size = u32::from(discriminant_size); // You may wonder why we make the types of all the fields of the #lower struct and #lower_payload union // generic. 
This is to work around a [normalization bug in @@ -806,11 +871,11 @@ impl Expander for ComponentTypeExpander { const SIZE32: usize = { let mut size = 0; #sizes - #internal::align_to(1, Self::ALIGN32) + size + #internal::align_to(#discriminant_size as usize, Self::ALIGN32) + size }; const ALIGN32: u32 = { - let mut align = 1; + let mut align = #discriminant_size; #alignments align }; diff --git a/crates/misc/component-macro-test/Cargo.toml b/crates/misc/component-macro-test/Cargo.toml new file mode 100644 index 0000000000..f613aaeb2f --- /dev/null +++ b/crates/misc/component-macro-test/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "component-macro-test" +authors = ["The Wasmtime Project Developers"] +license = "Apache-2.0 WITH LLVM-exception" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "1.0", features = ["full"] } diff --git a/crates/misc/component-macro-test/src/lib.rs b/crates/misc/component-macro-test/src/lib.rs new file mode 100644 index 0000000000..59a62de385 --- /dev/null +++ b/crates/misc/component-macro-test/src/lib.rs @@ -0,0 +1,34 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::parse_macro_input; + +#[proc_macro_attribute] +pub fn add_variants( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + expand_variants( + &parse_macro_input!(attr as syn::LitInt), + parse_macro_input!(item as syn::ItemEnum), + ) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +fn expand_variants(count: &syn::LitInt, mut ty: syn::ItemEnum) -> syn::Result<TokenStream> { + let count = count + .base10_digits() + .parse::<usize>() + .map_err(|_| syn::Error::new(count.span(), "expected unsigned integer"))?; + + ty.variants = (0..count) + .map(|index| syn::Variant { + attrs: Vec::new(), + ident: syn::Ident::new(&format!("V{}", index), Span::call_site()), + fields: syn::Fields::Unit, + discriminant: None, + }) + 
.collect(); + + Ok(quote!(#ty)) +} diff --git a/tests/all/component_model/macros.rs b/tests/all/component_model/macros.rs index 71a1385b11..ca519d2f42 100644 --- a/tests/all/component_model/macros.rs +++ b/tests/all/component_model/macros.rs @@ -1,5 +1,6 @@ use super::TypedFuncExt; use anyhow::Result; +use component_macro_test::add_variants; use std::fmt::Write; use wasmtime::component::{Component, ComponentType, Lift, Linker, Lower}; use wasmtime::Store; @@ -475,5 +476,41 @@ fn enum_derive() -> Result<()> { .get_typed_func::<(Foo,), Foo, _>(&mut store, "echo") .is_err()); + #[add_variants(257)] + #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] + #[component(enum)] + enum Many {} + + let component = Component::new( + &engine, + make_echo_component( + &format!( + r#"(type $Foo (enum {}))"#, + (0..257) + .map(|index| format!(r#""V{}""#, index)) + .collect::<Vec<_>>() + .join(" ") + ), + 4, + ), + )?; + let instance = Linker::new(&engine).instantiate(&mut store, &component)?; + let func = instance.get_typed_func::<(Many,), Many, _>(&mut store, "echo")?; + + for &input in &[Many::V0, Many::V1, Many::V254, Many::V255, Many::V256] { + let output = func.call_and_post_return(&mut store, (input,))?; + + assert_eq!(input, output); + } + + // TODO: The following case takes forever (i.e. I gave up after 30 minutes) to compile; we'll need to profile + // the compiler to find out why, which may point the way to a more efficient option. On the other hand, this + // may not be worth spending time on. Enums with over 2^16 variants are rare enough. + + // #[add_variants(65537)] + // #[derive(ComponentType, Lift, Lower, PartialEq, Eq, Debug, Copy, Clone)] + // #[component(enum)] + // enum ManyMore {} + Ok(()) }