compute_squared_distance.rs (4519B)
1 /* This Source Code Form is subject to the terms of the Mozilla Public 2 * License, v. 2.0. If a copy of the MPL was not distributed with this 3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ 4 5 use crate::animate::{AnimationFieldAttrs, AnimationInputAttrs, AnimationVariantAttrs}; 6 use crate::cg; 7 use proc_macro2::TokenStream; 8 use quote::TokenStreamExt; 9 use syn::{DeriveInput, WhereClause}; 10 use synstructure; 11 12 pub fn derive(mut input: DeriveInput) -> TokenStream { 13 let animation_input_attrs = cg::parse_input_attrs::<AnimationInputAttrs>(&input); 14 let no_bound = animation_input_attrs.no_bound.unwrap_or_default(); 15 let mut where_clause = input.generics.where_clause.take(); 16 for param in input.generics.type_params() { 17 if !no_bound.iter().any(|name| name.is_ident(¶m.ident)) { 18 cg::add_predicate( 19 &mut where_clause, 20 parse_quote!(#param: crate::values::distance::ComputeSquaredDistance), 21 ); 22 } 23 } 24 25 let (mut match_body, needs_catchall_branch) = { 26 let s = synstructure::Structure::new(&input); 27 let needs_catchall_branch = s.variants().len() > 1; 28 29 let match_body = s.variants().iter().fold(quote!(), |body, variant| { 30 let arm = derive_variant_arm(variant, &mut where_clause); 31 quote! { #body #arm } 32 }); 33 34 (match_body, needs_catchall_branch) 35 }; 36 37 input.generics.where_clause = where_clause; 38 39 if needs_catchall_branch { 40 // This ideally shouldn't be needed, but see: 41 // https://github.com/rust-lang/rust/issues/68867 42 43 match_body.append_all(quote! { _ => unsafe { 44 use ::debug_unreachable::debug_unreachable; 45 debug_unreachable!() 46 } }); 47 } 48 49 let name = &input.ident; 50 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); 51 52 quote! { 53 impl #impl_generics crate::values::distance::ComputeSquaredDistance for #name #ty_generics #where_clause { 54 #[allow(unused_variables, unused_imports)] 55 #[inline] 56 fn compute_squared_distance( 57 &self, 58 other: &Self, 59 ) -> Result<crate::values::distance::SquaredDistance, ()> { 60 if std::mem::discriminant(self) != std::mem::discriminant(other) { 61 return Err(()); 62 } 63 match (self, other) { 64 #match_body 65 } 66 } 67 } 68 } 69 } 70 71 fn derive_variant_arm( 72 variant: &synstructure::VariantInfo, 73 mut where_clause: &mut Option<WhereClause>, 74 ) -> TokenStream { 75 let variant_attrs = cg::parse_variant_attrs_from_ast::<AnimationVariantAttrs>(&variant.ast()); 76 let (this_pattern, this_info) = cg::ref_pattern(&variant, "this"); 77 let (other_pattern, other_info) = cg::ref_pattern(&variant, "other"); 78 79 if variant_attrs.error { 80 return quote! { 81 (&#this_pattern, &#other_pattern) => Err(()), 82 }; 83 } 84 85 let sum = if this_info.is_empty() { 86 quote! { crate::values::distance::SquaredDistance::from_sqrt(0.) } 87 } else { 88 let mut sum = quote!(); 89 sum.append_separated(this_info.iter().zip(&other_info).map(|(this, other)| { 90 let field_attrs = cg::parse_field_attrs::<DistanceFieldAttrs>(&this.ast()); 91 if field_attrs.field_bound { 92 let ty = &this.ast().ty; 93 cg::add_predicate( 94 &mut where_clause, 95 parse_quote!(#ty: crate::values::distance::ComputeSquaredDistance), 96 ); 97 } 98 99 let animation_field_attrs = 100 cg::parse_field_attrs::<AnimationFieldAttrs>(&this.ast()); 101 102 if animation_field_attrs.constant { 103 quote! { 104 { 105 if #this != #other { 106 return Err(()); 107 } 108 crate::values::distance::SquaredDistance::from_sqrt(0.) 109 } 110 } 111 } else { 112 quote! { 113 crate::values::distance::ComputeSquaredDistance::compute_squared_distance(#this, #other)? 114 } 115 } 116 }), quote!(+)); 117 sum 118 }; 119 120 return quote! { 121 (&#this_pattern, &#other_pattern) => Ok(#sum), 122 }; 123 } 124 125 #[derive(Default, FromField)] 126 #[darling(attributes(distance), default)] 127 struct DistanceFieldAttrs { 128 field_bound: bool, 129 }