diff options
Diffstat (limited to 'e2etest-macros/src/webtest.rs')
| -rw-r--r-- | e2etest-macros/src/webtest.rs | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/e2etest-macros/src/webtest.rs b/e2etest-macros/src/webtest.rs new file mode 100644 index 0000000..6b47559 --- /dev/null +++ b/e2etest-macros/src/webtest.rs @@ -0,0 +1,123 @@ +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::{format_ident, quote}; +use syn::{ + parse_macro_input, FnArg, ItemFn, PatType, ReturnType, + Visibility, +}; + +/** +* wrap webtest annotated functions in an async Pin<Box> +**/ +pub fn webtest_impl(_args: TokenStream, input: TokenStream) -> TokenStream { + let input_fn = parse_macro_input!(input as ItemFn); + + if let Err(err) = validate_function(&input_fn) { + return syn::Error::new(Span::call_site(), err) + .to_compile_error() + .into(); + } + + let fn_name = &input_fn.sig.ident; + let fn_name_str = fn_name.to_string(); + let vis = &input_fn.vis; + let attrs = &input_fn.attrs; + let fn_block = &input_fn.block; + let fn_inputs = &input_fn.sig.inputs; + let fn_output = &input_fn.sig.output; + let fn_generics = &input_fn.sig.generics; + let fn_asyncness = &input_fn.sig.asyncness; + + let original_fn_name = format_ident!("{}_original", fn_name); + + let expanded = quote! { + #(#attrs)* + #vis #fn_asyncness fn #original_fn_name #fn_generics(#fn_inputs) #fn_output #fn_block + + fn #fn_name #fn_generics( + driver: &::thirtyfour::WebDriver + ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = crate::TestResult> + Send + '_>> { + Box::pin(async move { + #original_fn_name(driver).await + }) + } + + ::inventory::submit! { + crate::TestFunction { + name: #fn_name_str, + func: #fn_name, + } + } + }; + + TokenStream::from(expanded) +} + +fn validate_function(func: &ItemFn) -> Result<(), String> { + is_async(func)?; + is_visible(func)?; + matches_arguments(&["WebDriver"], func)?; + matches_return_type("TestResult", func)?; + + Ok(()) +} + +fn is_async(func: &ItemFn) -> Result<(), String> { + if func.sig.asyncness.is_none() { + Err("Test functions must be async".to_string()) + } else { Ok(()) } +} + +fn is_visible(func: &ItemFn) -> Result<(), String> { + match func.vis { + Visibility::Inherited | Visibility::Public(_) => { + Ok(()) + } + _ => { + Err("Test functions should be pub or have default visibility".to_string()) + } + } +} + +fn matches_arguments(types: &[&str], func: &ItemFn) -> Result<(), String> { + let args = matches_argument_count(types.len(), func)?; + for index in 0..args.len() { + matches_argument_type(types.get(index).unwrap(), args.get(index).unwrap())?; + } + Ok(()) +} + +fn matches_argument_count(expected: usize, func: &ItemFn) -> Result<Vec<&FnArg>, String> { + let inputs: Vec<_> = func.sig.inputs.iter().collect(); + if inputs.len() != expected { + Err(format!("Test functions must take exactly {} parameters", expected)) + } else { Ok(inputs) } +} + +fn matches_argument_type(expected: &str, arg: &FnArg) -> Result<(), String> { + match arg { + FnArg::Receiver(_) => { + Err("Test functions cannot be methods (no self parameter)".to_string()) + } + FnArg::Typed(PatType { ty, .. }) => { + let type_str = quote!(#ty).to_string(); + if !type_str.contains(expected) { + Err("First parameter must be &WebDriver".to_string()) + } else { Ok(()) } + } + } +} + +fn matches_return_type(expected: &str, func: &ItemFn) -> Result<(), String> { + match &func.sig.output { + ReturnType::Default => { + Err("Test functions must return TestResult".to_string()) + } + ReturnType::Type(_, ty) => { + let type_str = quote!(#ty).to_string(); + if !type_str.contains(expected) && !type_str.contains("Result") { + Err("Test functions must return TestResult".to_string()) + } else { Ok(()) } + } + } +} |
