// SPDX-License-Identifier: GPL-2.0
//
// macros/module.rs

use std::ffi::CString;

use proc_macro2::{
    Literal,
    Span,
    TokenStream, //
};
use quote::{
    format_ident,
    quote, //
};
use syn::{
    braced,
    bracketed,
    ext::IdentExt,
    parse::{
        Parse,
        ParseStream, //
    },
    parse_quote,
    punctuated::Punctuated,
    Error,
    Expr,
    Ident,
    LitStr,
    Path,
    Result,
    Token,
    Type, //
};

use crate::helpers::*;

35struct ModInfoBuilder<'a> {
36    module: &'a str,
37    counter: usize,
38    ts: TokenStream,
39    param_ts: TokenStream,
40}
41
42impl<'a> ModInfoBuilder<'a> {
43    fn new(module: &'a str) -> Self {
44        ModInfoBuilder {
45            module,
46            counter: 0,
47            ts: TokenStream::new(),
48            param_ts: TokenStream::new(),
49        }
50    }
51
52    fn emit_base(&mut self, field: &str, content: &str, builtin: bool, param: bool) {
53        let string = if builtin {
54            // Built-in modules prefix their modinfo strings by `module.`.
55            format!("{module}.{field}={content}\0", module = self.module)
56        } else {
57            // Loadable modules' modinfo strings go as-is.
58            format!("{field}={content}\0")
59        };
60        let length = string.len();
61        let string = Literal::byte_string(string.as_bytes());
62        let cfg = if builtin {
63            quote!(#[cfg(not(MODULE))])
64        } else {
65            quote!(#[cfg(MODULE)])
66        };
67
68        let counter = format_ident!(
69            "__{module}_{counter}",
70            module = self.module.to_uppercase(),
71            counter = self.counter
72        );
73        let item = quote! {
74            #cfg
75            #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")]
76            #[used(compiler)]
77            pub static #counter: [u8; #length] = *#string;
78        };
79
80        if param {
81            self.param_ts.extend(item);
82        } else {
83            self.ts.extend(item);
84        }
85
86        self.counter += 1;
87    }
88
89    fn emit_only_builtin(&mut self, field: &str, content: &str, param: bool) {
90        self.emit_base(field, content, true, param)
91    }
92
93    fn emit_only_loadable(&mut self, field: &str, content: &str, param: bool) {
94        self.emit_base(field, content, false, param)
95    }
96
97    fn emit(&mut self, field: &str, content: &str) {
98        self.emit_internal(field, content, false);
99    }
100
101    fn emit_internal(&mut self, field: &str, content: &str, param: bool) {
102        self.emit_only_builtin(field, content, param);
103        self.emit_only_loadable(field, content, param);
104    }
105
106    fn emit_param(&mut self, field: &str, param: &str, content: &str) {
107        let content = format!("{param}:{content}");
108        self.emit_internal(field, &content, true);
109    }
110
111    fn emit_params(&mut self, info: &ModuleInfo) {
112        let Some(params) = &info.params else {
113            return;
114        };
115
116        for param in params {
117            let param_name_str = param.name.to_string();
118            let param_type_str = param.ptype.to_string();
119
120            let ops = param_ops_path(&param_type_str);
121
122            // Note: The spelling of these fields is dictated by the user space
123            // tool `modinfo`.
124            self.emit_param("parmtype", &param_name_str, &param_type_str);
125            self.emit_param("parm", &param_name_str, &param.description.value());
126
127            let static_name = format_ident!("__{}_{}_struct", self.module, param.name);
128            let param_name_cstr =
129                CString::new(param_name_str).expect("name contains NUL-terminator");
130            let param_name_cstr_with_module =
131                CString::new(format!("{}.{}", self.module, param.name))
132                    .expect("name contains NUL-terminator");
133
134            let param_name = &param.name;
135            let param_type = &param.ptype;
136            let param_default = &param.default;
137
138            self.param_ts.extend(quote! {
139                #[allow(non_upper_case_globals)]
140                pub(crate) static #param_name:
141                    ::kernel::module_param::ModuleParamAccess<#param_type> =
142                        ::kernel::module_param::ModuleParamAccess::new(#param_default);
143
144                const _: () = {
145                    #[allow(non_upper_case_globals)]
146                    #[link_section = "__param"]
147                    #[used(compiler)]
148                    static #static_name:
149                        ::kernel::module_param::KernelParam =
150                        ::kernel::module_param::KernelParam::new(
151                            ::kernel::bindings::kernel_param {
152                                name: kernel::str::as_char_ptr_in_const_context(
153                                    if ::core::cfg!(MODULE) {
154                                        #param_name_cstr
155                                    } else {
156                                        #param_name_cstr_with_module
157                                    }
158                                ),
159                                // SAFETY: `__this_module` is constructed by the kernel at load
160                                // time and will not be freed until the module is unloaded.
161                                #[cfg(MODULE)]
162                                mod_: unsafe {
163                                    core::ptr::from_ref(&::kernel::bindings::__this_module)
164                                        .cast_mut()
165                                },
166                                #[cfg(not(MODULE))]
167                                mod_: ::core::ptr::null_mut(),
168                                ops: core::ptr::from_ref(&#ops),
169                                perm: 0, // Will not appear in sysfs
170                                level: -1,
171                                flags: 0,
172                                __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {
173                                    arg: #param_name.as_void_ptr()
174                                },
175                            }
176                        );
177                };
178            });
179        }
180    }
181}
182
183fn param_ops_path(param_type: &str) -> Path {
184    match param_type {
185        "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8),
186        "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8),
187        "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16),
188        "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16),
189        "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32),
190        "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32),
191        "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64),
192        "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64),
193        "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE),
194        "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE),
195        t => panic!("Unsupported parameter type {}", t),
196    }
197}
198
/// Parse fields that are required to use a specific order.
///
/// As fields must follow a specific order, we *could* just parse fields one by one by peeking.
/// However the error message generated when implementing that way is not very friendly.
///
/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing,
/// and if the wrong order is used, the proper order is communicated to the user with an error
/// message.
///
/// Every field, including the last one, must be followed by a comma. Each key may appear at most
/// once, and keys not listed in the invocation are rejected.
///
/// The expansion introduces `let` bindings (one per key) into the caller's scope and uses `?` to
/// propagate `syn::Error`s, so it must be invoked inside a function returning a compatible
/// `Result`.
///
/// Usage looks like this:
/// ```ignore
/// parse_ordered_fields! {
///     from input;
///
///     // This will extract "foo: <field>" into a variable named "foo".
///     // The variable will have type `Option<_>`.
///     foo => <expression that parses the field>,
///
///     // If you need the variable name to be different than the key name.
///     // This extracts "baz: <field>" into a variable named "bar".
///     // You might want this if "baz" is a keyword.
///     baz as bar => <expression that parses the field>,
///
///     // You can mark a key as required, and the variable will no longer be `Option`.
///     // foobar will be of type `Expr` instead of `Option<Expr>`.
///     foobar [required] => input.parse::<Expr>()?,
/// }
/// ```
macro_rules! parse_ordered_fields {
    // Final step: every field spec has been normalized into `[$name; $key; $parser]` (all
    // fields, in declaration order) and `[$name; $key]` (required fields only). Generate the
    // actual parsing code.
    (@gen
        [$input:expr]
        [$([$name:ident; $key:ident; $parser:expr])*]
        [$([$req_name:ident; $req_key:ident])*]
    ) => {
        $(let mut $name = None;)*

        // Declaration order of the keys; this is the order users must write them in.
        const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*];
        const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*];

        // Span of the next (first) token, used for errors not tied to a single key.
        let span = $input.span();
        // Keys accepted so far, in the order they were written.
        let mut seen_keys = Vec::new();

        while !$input.is_empty() {
            let key = $input.call(Ident::parse_any)?;

            if seen_keys.contains(&key) {
                Err(Error::new_spanned(
                    &key,
                    format!(r#"duplicated key "{key}". Keys can only be specified once."#),
                ))?
            }

            $input.parse::<Token![:]>()?;

            match &*key.to_string() {
                $(
                    stringify!($key) => $name = Some($parser),
                )*
                _ => {
                    Err(Error::new_spanned(
                        &key,
                        format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#),
                    ))?
                }
            }

            // The separating comma is mandatory, even after the last field.
            $input.parse::<Token![,]>()?;
            seen_keys.push(key);
        }

        for key in REQUIRED_KEYS {
            if !seen_keys.iter().any(|e| e == key) {
                Err(Error::new(span, format!(r#"missing required key "{key}""#)))?
            }
        }

        // The keys that were seen, rearranged into declaration order. If this differs from the
        // written order, the user misordered them and is told the expected order.
        let mut ordered_keys: Vec<&str> = Vec::new();
        for key in EXPECTED_KEYS {
            if seen_keys.iter().any(|e| e == key) {
                ordered_keys.push(key);
            }
        }

        if seen_keys != ordered_keys {
            Err(Error::new(
                span,
                format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#),
            ))?
        }

        // Presence of every required key was checked above, so these cannot fail.
        $(let $req_name = $req_name.expect("required field");)*
    };

    // Handle required fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident [required] => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)*
        )
    };

    // Handle optional fields.
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $key:ident as $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)*
        )
    };
    (@gen
        [$input:expr] [$($tok:tt)*] [$($req:tt)*]
        $name:ident => $parser:expr,
        $($rest:tt)*
    ) => {
        parse_ordered_fields!(
            @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)*
        )
    };

    // Entry point: start with empty accumulators and munch the field specs one at a time.
    (from $input:expr; $($tok:tt)*) => {
        parse_ordered_fields!(@gen [$input] [] [] $($tok)*)
    }
}

336struct Parameter {
337    name: Ident,
338    ptype: Ident,
339    default: Expr,
340    description: LitStr,
341}
342
343impl Parse for Parameter {
344    fn parse(input: ParseStream<'_>) -> Result<Self> {
345        let name = input.parse()?;
346        input.parse::<Token![:]>()?;
347        let ptype = input.parse()?;
348
349        let fields;
350        braced!(fields in input);
351
352        parse_ordered_fields! {
353            from fields;
354            default [required] => fields.parse()?,
355            description [required] => fields.parse()?,
356        }
357
358        Ok(Self {
359            name,
360            ptype,
361            default,
362            description,
363        })
364    }
365}
366
367pub(crate) struct ModuleInfo {
368    type_: Type,
369    license: AsciiLitStr,
370    name: AsciiLitStr,
371    authors: Option<Punctuated<AsciiLitStr, Token![,]>>,
372    description: Option<LitStr>,
373    alias: Option<Punctuated<AsciiLitStr, Token![,]>>,
374    firmware: Option<Punctuated<AsciiLitStr, Token![,]>>,
375    imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>,
376    params: Option<Punctuated<Parameter, Token![,]>>,
377}
378
379impl Parse for ModuleInfo {
380    fn parse(input: ParseStream<'_>) -> Result<Self> {
381        parse_ordered_fields!(
382            from input;
383            type as type_ [required] => input.parse()?,
384            name [required] => input.parse()?,
385            authors => {
386                let list;
387                bracketed!(list in input);
388                Punctuated::parse_terminated(&list)?
389            },
390            description => input.parse()?,
391            license [required] => input.parse()?,
392            alias => {
393                let list;
394                bracketed!(list in input);
395                Punctuated::parse_terminated(&list)?
396            },
397            firmware => {
398                let list;
399                bracketed!(list in input);
400                Punctuated::parse_terminated(&list)?
401            },
402            imports_ns => {
403                let list;
404                bracketed!(list in input);
405                Punctuated::parse_terminated(&list)?
406            },
407            params => {
408                let list;
409                braced!(list in input);
410                Punctuated::parse_terminated(&list)?
411            },
412        );
413
414        Ok(ModuleInfo {
415            type_,
416            license,
417            name,
418            authors,
419            description,
420            alias,
421            firmware,
422            imports_ns,
423            params,
424        })
425    }
426}
427
428pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> {
429    let ModuleInfo {
430        type_,
431        license,
432        name,
433        authors,
434        description,
435        alias,
436        firmware,
437        imports_ns,
438        params: _,
439    } = &info;
440
441    // Rust does not allow hyphens in identifiers, use underscore instead.
442    let ident = name.value().replace('-', "_");
443    let mut modinfo = ModInfoBuilder::new(ident.as_ref());
444    if let Some(authors) = authors {
445        for author in authors {
446            modinfo.emit("author", &author.value());
447        }
448    }
449    if let Some(description) = description {
450        modinfo.emit("description", &description.value());
451    }
452    modinfo.emit("license", &license.value());
453    if let Some(aliases) = alias {
454        for alias in aliases {
455            modinfo.emit("alias", &alias.value());
456        }
457    }
458    if let Some(firmware) = firmware {
459        for fw in firmware {
460            modinfo.emit("firmware", &fw.value());
461        }
462    }
463    if let Some(imports) = imports_ns {
464        for ns in imports {
465            modinfo.emit("import_ns", &ns.value());
466        }
467    }
468
469    // Built-in modules also export the `file` modinfo string.
470    let file =
471        std::env::var("RUST_MODFILE").expect("Unable to fetch RUST_MODFILE environmental variable");
472    modinfo.emit_only_builtin("file", &file, false);
473
474    modinfo.emit_params(&info);
475
476    let modinfo_ts = modinfo.ts;
477    let params_ts = modinfo.param_ts;
478
479    let ident_init = format_ident!("__{ident}_init");
480    let ident_exit = format_ident!("__{ident}_exit");
481    let ident_initcall = format_ident!("__{ident}_initcall");
482    let initcall_section = ".initcall6.init";
483
484    let global_asm = format!(
485        r#".section "{initcall_section}", "a"
486        __{ident}_initcall:
487            .long   __{ident}_init - .
488            .previous
489        "#
490    );
491
492    let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator");
493
494    Ok(quote! {
495        /// The module name.
496        ///
497        /// Used by the printing macros, e.g. [`info!`].
498        const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul();
499
500        // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
501        // freed until the module is unloaded.
502        #[cfg(MODULE)]
503        static THIS_MODULE: ::kernel::ThisModule = unsafe {
504            extern "C" {
505                static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>;
506            };
507
508            ::kernel::ThisModule::from_ptr(__this_module.get())
509        };
510
511        #[cfg(not(MODULE))]
512        static THIS_MODULE: ::kernel::ThisModule = unsafe {
513            ::kernel::ThisModule::from_ptr(::core::ptr::null_mut())
514        };
515
516        /// The `LocalModule` type is the type of the module created by `module!`,
517        /// `module_pci_driver!`, `module_platform_driver!`, etc.
518        type LocalModule = #type_;
519
520        impl ::kernel::ModuleMetadata for #type_ {
521            const NAME: &'static ::kernel::str::CStr = #name_cstr;
522        }
523
524        // Double nested modules, since then nobody can access the public items inside.
525        #[doc(hidden)]
526        mod __module_init {
527            mod __module_init {
528                use pin_init::PinInit;
529
530                /// The "Rust loadable module" mark.
531                //
532                // This may be best done another way later on, e.g. as a new modinfo
533                // key or a new section. For the moment, keep it simple.
534                #[cfg(MODULE)]
535                #[used(compiler)]
536                static __IS_RUST_MODULE: () = ();
537
538                static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> =
539                    ::core::mem::MaybeUninit::uninit();
540
541                // Loadable modules need to export the `{init,cleanup}_module` identifiers.
542                /// # Safety
543                ///
544                /// This function must not be called after module initialization, because it may be
545                /// freed after that completes.
546                #[cfg(MODULE)]
547                #[no_mangle]
548                #[link_section = ".init.text"]
549                pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int {
550                    // SAFETY: This function is inaccessible to the outside due to the double
551                    // module wrapping it. It is called exactly once by the C side via its
552                    // unique name.
553                    unsafe { __init() }
554                }
555
556                #[cfg(MODULE)]
557                #[used(compiler)]
558                #[link_section = ".init.data"]
559                static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 =
560                    init_module;
561
562                #[cfg(MODULE)]
563                #[no_mangle]
564                #[link_section = ".exit.text"]
565                pub extern "C" fn cleanup_module() {
566                    // SAFETY:
567                    // - This function is inaccessible to the outside due to the double
568                    //   module wrapping it. It is called exactly once by the C side via its
569                    //   unique name,
570                    // - furthermore it is only called after `init_module` has returned `0`
571                    //   (which delegates to `__init`).
572                    unsafe { __exit() }
573                }
574
575                #[cfg(MODULE)]
576                #[used(compiler)]
577                #[link_section = ".exit.data"]
578                static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module;
579
580                // Built-in modules are initialized through an initcall pointer
581                // and the identifiers need to be unique.
582                #[cfg(not(MODULE))]
583                #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))]
584                #[link_section = #initcall_section]
585                #[used(compiler)]
586                pub static #ident_initcall: extern "C" fn() ->
587                    ::kernel::ffi::c_int = #ident_init;
588
589                #[cfg(not(MODULE))]
590                #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)]
591                ::core::arch::global_asm!(#global_asm);
592
593                #[cfg(not(MODULE))]
594                #[no_mangle]
595                pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int {
596                    // SAFETY: This function is inaccessible to the outside due to the double
597                    // module wrapping it. It is called exactly once by the C side via its
598                    // placement above in the initcall section.
599                    unsafe { __init() }
600                }
601
602                #[cfg(not(MODULE))]
603                #[no_mangle]
604                pub extern "C" fn #ident_exit() {
605                    // SAFETY:
606                    // - This function is inaccessible to the outside due to the double
607                    //   module wrapping it. It is called exactly once by the C side via its
608                    //   unique name,
609                    // - furthermore it is only called after `#ident_init` has
610                    //   returned `0` (which delegates to `__init`).
611                    unsafe { __exit() }
612                }
613
614                /// # Safety
615                ///
616                /// This function must only be called once.
617                unsafe fn __init() -> ::kernel::ffi::c_int {
618                    let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init(
619                        &super::super::THIS_MODULE
620                    );
621                    // SAFETY: No data race, since `__MOD` can only be accessed by this module
622                    // and there only `__init` and `__exit` access it. These functions are only
623                    // called once and `__exit` cannot be called before or during `__init`.
624                    match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } {
625                        Ok(m) => 0,
626                        Err(e) => e.to_errno(),
627                    }
628                }
629
630                /// # Safety
631                ///
632                /// This function must
633                /// - only be called once,
634                /// - be called after `__init` has been called and returned `0`.
635                unsafe fn __exit() {
636                    // SAFETY: No data race, since `__MOD` can only be accessed by this module
637                    // and there only `__init` and `__exit` access it. These functions are only
638                    // called once and `__init` was already called.
639                    unsafe {
640                        // Invokes `drop()` on `__MOD`, which should be used for cleanup.
641                        __MOD.assume_init_drop();
642                    }
643                }
644
645                #modinfo_ts
646            }
647        }
648
649        mod module_parameters {
650            #params_ts
651        }
652    })
653}