use proc_macro::TokenStream; use proc_macro2::Span; use quote::{format_ident, quote}; use syn::{ parse_macro_input, FnArg, ItemFn, PatType, ReturnType, Visibility, }; /** * wrap webtest annotated functions in an async Pin **/ pub fn webtest_impl(_args: TokenStream, input: TokenStream) -> TokenStream { let input_fn = parse_macro_input!(input as ItemFn); if let Err(err) = validate_function(&input_fn) { return syn::Error::new(Span::call_site(), err) .to_compile_error() .into(); } let fn_name = &input_fn.sig.ident; let fn_name_str = fn_name.to_string(); let vis = &input_fn.vis; let attrs = &input_fn.attrs; let fn_block = &input_fn.block; let fn_inputs = &input_fn.sig.inputs; let fn_output = &input_fn.sig.output; let fn_generics = &input_fn.sig.generics; let fn_asyncness = &input_fn.sig.asyncness; let original_fn_name = format_ident!("{}_original", fn_name); let expanded = quote! { #(#attrs)* #vis #fn_asyncness fn #original_fn_name #fn_generics(#fn_inputs) #fn_output #fn_block fn #fn_name #fn_generics( driver: &::thirtyfour::WebDriver ) -> ::std::pin::Pin + Send + '_>> { Box::pin(async move { #original_fn_name(driver).await }) } ::inventory::submit! { crate::TestFunction { name: #fn_name_str, func: #fn_name, } } }; TokenStream::from(expanded) } fn validate_function(func: &ItemFn) -> Result<(), String> { is_async(func)?; is_visible(func)?; matches_arguments(&["WebDriver"], func)?; matches_return_type("TestResult", func)?; Ok(()) } fn is_async(func: &ItemFn) -> Result<(), String> { if func.sig.asyncness.is_none() { Err("Test functions must be async".to_string()) } else { Ok(()) } } fn is_visible(func: &ItemFn) -> Result<(), String> { match func.vis { Visibility::Inherited | Visibility::Public(_) => { Ok(()) } _ => { Err("Test functions should be pub or have default visibility".to_string()) } } } fn matches_arguments(types: &[&str], func: &ItemFn) -> Result<(), String> { let args = matches_argument_count(types.len(), func)?; for index in 0..args.len() { matches_argument_type(types.get(index).unwrap(), args.get(index).unwrap())?; } Ok(()) } fn matches_argument_count(expected: usize, func: &ItemFn) -> Result, String> { let inputs: Vec<_> = func.sig.inputs.iter().collect(); if inputs.len() != expected { Err(format!("Test functions must take exactly {} parameters", expected)) } else { Ok(inputs) } } fn matches_argument_type(expected: &str, arg: &FnArg) -> Result<(), String> { match arg { FnArg::Receiver(_) => { Err("Test functions cannot be methods (no self parameter)".to_string()) } FnArg::Typed(PatType { ty, .. }) => { let type_str = quote!(#ty).to_string(); if !type_str.contains(expected) { Err("First parameter must be &WebDriver".to_string()) } else { Ok(()) } } } } fn matches_return_type(expected: &str, func: &ItemFn) -> Result<(), String> { match &func.sig.output { ReturnType::Default => { Err("Test functions must return TestResult".to_string()) } ReturnType::Type(_, ty) => { let type_str = quote!(#ty).to_string(); if !type_str.contains(expected) && !type_str.contains("Result") { Err("Test functions must return TestResult".to_string()) } else { Ok(()) } } } }