lib.rs (7318B)
1 // Copyright 2019 The Servo Project Developers. See the COPYRIGHT 2 // file at the top-level directory of this distribution and at 3 // http://rust-lang.org/COPYRIGHT. 4 // 5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license 7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your 8 // option. This file may not be copied, modified, or distributed 9 // except according to those terms. 10 11 use proc_macro2::{Span, TokenStream}; 12 use quote::quote; 13 use syn::{Ident, Index, TraitBound}; 14 use synstructure::{decl_derive, Structure, BindStyle, AddBounds}; 15 16 /// Calculates size type for number of variants (used for enums) 17 fn get_discriminant_size_type(len: usize) -> TokenStream { 18 if len <= u8::MAX as usize { 19 quote! { u8 } 20 } else if len <= u16::MAX as usize { 21 quote! { u16 } 22 } else { 23 quote! { u32 } 24 } 25 } 26 27 fn is_struct(s: &Structure) -> bool { 28 // a single variant with no prefix is 'struct' 29 matches!(s.variants(), [v] if v.prefix.is_none()) 30 } 31 32 fn derive_max_size(s: &Structure) -> TokenStream { 33 let max_size = s.variants().iter().fold(quote!(0), |acc, vi| { 34 let variant_size = vi.bindings().iter().fold(quote!(0), |acc, bi| { 35 // compute size of each variant by summing the sizes of its bindings 36 let ty = &bi.ast().ty; 37 quote!(#acc + <#ty>::max_size()) 38 }); 39 40 // find the maximum of each variant 41 quote! { 42 max(#acc, #variant_size) 43 } 44 }); 45 46 let body = if is_struct(s) { 47 max_size 48 } else { 49 let discriminant_size_type = get_discriminant_size_type(s.variants().len()); 50 quote! { 51 #discriminant_size_type ::max_size() + #max_size 52 } 53 }; 54 55 quote! { 56 #[inline(always)] 57 fn max_size() -> usize { 58 use std::cmp::max; 59 #body 60 } 61 } 62 } 63 64 fn derive_peek_from_for_enum(s: &mut Structure) -> TokenStream { 65 assert!(!is_struct(s)); 66 s.bind_with(|_| BindStyle::Move); 67 68 let num_variants = s.variants().len(); 69 let discriminant_size_type = get_discriminant_size_type(num_variants); 70 let body = s 71 .variants() 72 .iter() 73 .enumerate() 74 .fold(quote!(), |acc, (i, vi)| { 75 let bindings = vi 76 .bindings() 77 .iter() 78 .map(|bi| quote!(#bi)) 79 .collect::<Vec<_>>(); 80 81 let variant_pat = Index::from(i); 82 let poke_exprs = bindings.iter().fold(quote!(), |acc, bi| { 83 quote! { 84 #acc 85 let (#bi, bytes) = peek_poke::peek_from_default(bytes); 86 } 87 }); 88 let construct = vi.construct(|_, i| { 89 let bi = &bindings[i]; 90 quote!(#bi) 91 }); 92 93 quote! { 94 #acc 95 #variant_pat => { 96 #poke_exprs 97 *output = #construct; 98 bytes 99 } 100 } 101 }); 102 103 let type_name = s.ast().ident.to_string(); 104 let max_tag_value = num_variants - 1; 105 106 quote! { 107 #[inline(always)] 108 unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 { 109 let (variant, bytes) = peek_poke::peek_from_default::<#discriminant_size_type>(bytes); 110 match variant { 111 #body 112 out_of_range_tag => { 113 panic!("WRDL: memory corruption detected while parsing {} - enum tag should be <= {}, but was {}", 114 #type_name, #max_tag_value, out_of_range_tag); 115 } 116 } 117 } 118 } 119 } 120 121 fn derive_peek_from_for_struct(s: &mut Structure) -> TokenStream { 122 assert!(is_struct(s)); 123 124 s.variants_mut()[0].bind_with(|_| BindStyle::RefMut); 125 let pat = s.variants()[0].pat(); 126 let peek_exprs = s.variants()[0].bindings().iter().fold(quote!(), |acc, bi| { 127 let ty = &bi.ast().ty; 128 quote! { 129 #acc 130 let bytes = <#ty>::peek_from(bytes, #bi); 131 } 132 }); 133 134 let body = quote! { 135 #pat => { 136 #peek_exprs 137 bytes 138 } 139 }; 140 141 quote! { 142 #[inline(always)] 143 unsafe fn peek_from(bytes: *const u8, output: *mut Self) -> *const u8 { 144 match &mut (*output) { 145 #body 146 } 147 } 148 } 149 } 150 151 fn derive_poke_into(s: &Structure) -> TokenStream { 152 let is_struct = is_struct(s); 153 let discriminant_size_type = get_discriminant_size_type(s.variants().len()); 154 let body = s 155 .variants() 156 .iter() 157 .enumerate() 158 .fold(quote!(), |acc, (i, vi)| { 159 let init = if !is_struct { 160 let index = Index::from(i); 161 quote! { 162 let bytes = #discriminant_size_type::poke_into(&#index, bytes); 163 } 164 } else { 165 quote!() 166 }; 167 let variant_pat = vi.pat(); 168 let poke_exprs = vi.bindings().iter().fold(init, |acc, bi| { 169 quote! { 170 #acc 171 let bytes = #bi.poke_into(bytes); 172 } 173 }); 174 175 quote! { 176 #acc 177 #variant_pat => { 178 #poke_exprs 179 bytes 180 } 181 } 182 }); 183 184 quote! { 185 #[inline(always)] 186 unsafe fn poke_into(&self, bytes: *mut u8) -> *mut u8 { 187 match &*self { 188 #body 189 } 190 } 191 } 192 } 193 194 fn peek_poke_derive(mut s: Structure) -> TokenStream { 195 s.binding_name(|_, i| Ident::new(&format!("__self_{}", i), Span::call_site())); 196 197 let max_size_fn = derive_max_size(&s); 198 let poke_into_fn = derive_poke_into(&s); 199 let peek_from_fn = if is_struct(&s) { 200 derive_peek_from_for_struct(&mut s) 201 } else { 202 derive_peek_from_for_enum(&mut s) 203 }; 204 205 let poke_impl = s.gen_impl(quote! { 206 extern crate peek_poke; 207 208 gen unsafe impl peek_poke::Poke for @Self { 209 #max_size_fn 210 #poke_into_fn 211 } 212 }); 213 214 // To implement `fn peek_from` we require that types implement `Default` 215 // trait to create temporary values. This code does the addition all 216 // manually until https://github.com/mystor/synstructure/issues/24 is fixed. 217 let default_trait = syn::parse_str::<TraitBound>("::std::default::Default").unwrap(); 218 let peek_trait = syn::parse_str::<TraitBound>("peek_poke::Peek").unwrap(); 219 220 let ast = s.ast(); 221 let name = &ast.ident; 222 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl(); 223 let mut where_clause = where_clause.cloned(); 224 s.add_trait_bounds(&default_trait, &mut where_clause, AddBounds::Generics); 225 s.add_trait_bounds(&peek_trait, &mut where_clause, AddBounds::Generics); 226 227 let peek_impl = quote! { 228 #[allow(non_upper_case_globals)] 229 const _: () = { 230 extern crate peek_poke; 231 232 impl #impl_generics peek_poke::Peek for #name #ty_generics #where_clause { 233 #peek_from_fn 234 } 235 }; 236 }; 237 238 quote! { 239 #poke_impl 240 #peek_impl 241 } 242 } 243 244 decl_derive!([PeekPoke] => peek_poke_derive);