1use std::ffi::CString;
8
9use proc_macro2::TokenStream;
10use quote::{
11 format_ident,
12 quote,
13 ToTokens, };
15use syn::{
16 parse_quote,
17 Error,
18 Ident,
19 Item,
20 ItemMod,
21 LitCStr,
22 Result, };
24
25pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
26 if test_suite.to_string().len() > 255 {
27 return Err(Error::new_spanned(
28 test_suite,
29 "test suite names cannot exceed the maximum length of 255 bytes",
30 ));
31 }
32
33 let Some((module_brace, module_items)) = module.content.take() else {
35 Err(Error::new_spanned(
36 module,
37 "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
38 ))?
39 };
40
41 module
43 .attrs
44 .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
45
46 let mut processed_items = Vec::new();
47 let mut test_cases = Vec::new();
48
49 for item in module_items {
85 let Item::Fn(mut f) = item else {
86 processed_items.push(item);
87 continue;
88 };
89
90 if f.attrs
91 .extract_if(.., |attr| attr.path().is_ident("test"))
92 .count()
93 == 0
94 {
95 processed_items.push(Item::Fn(f));
96 continue;
97 }
98
99 let test = f.sig.ident.clone();
100
101 let cfg_attrs: Vec<_> = f
103 .attrs
104 .iter()
105 .filter(|attr| attr.path().is_ident("cfg"))
106 .cloned()
107 .collect();
108
109 let test_str = test.to_string();
112 let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL");
113 processed_items.push(parse_quote! {
114 #[allow(unused)]
115 macro_rules! assert {
116 ($cond:expr $(,)?) => {{
117 kernel::kunit_assert!(#test_str, #path, 0, $cond);
118 }}
119 }
120 });
121 processed_items.push(parse_quote! {
122 #[allow(unused)]
123 macro_rules! assert_eq {
124 ($left:expr, $right:expr $(,)?) => {{
125 kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
126 }}
127 }
128 });
129
130 processed_items.push(Item::Fn(f));
132
133 let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
134 let test_cstr = LitCStr::new(
135 &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
136 test.span(),
137 );
138 processed_items.push(parse_quote! {
139 unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
140 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
141
142 #(#cfg_attrs)*
146 {
147 (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
148 use ::kernel::kunit::is_test_result_ok;
149 assert!(is_test_result_ok(#test()));
150 }
151 }
152 });
153
154 test_cases.push(quote!(
155 ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
156 ));
157 }
158
159 let num_tests_plus_1 = test_cases.len() + 1;
160 processed_items.push(parse_quote! {
161 static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
162 #(#test_cases,)*
163 ::pin_init::zeroed(),
164 ];
165 });
166 processed_items.push(parse_quote! {
167 ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
168 });
169
170 module.content = Some((module_brace, processed_items));
171 Ok(module.to_token_stream())
172}