Skip to main content

dfir_lang/
process_singletons.rs

1//! Utility methods for processing singleton references: `#my_var`.
2
3use itertools::Itertools;
4use proc_macro2::{Group, Ident, TokenStream, TokenTree};
5use syn::punctuated::Punctuated;
6use syn::{Expr, Token};
7
8use crate::parse::parse_terminated;
9
10/// Finds all the singleton references `#my_var` and appends them to `found_idents`. Returns the
11/// `TokenStream` but with the hashes removed from the varnames.
12///
13/// The returned tokens are used for "preflight" parsing, to check that the rest of the syntax is
14/// OK. However the returned tokens are not used in the codegen as we need to use [`postprocess_singletons`]
15/// later to substitute-in the context referencing code for each singleton
16pub fn preprocess_singletons(tokens: TokenStream, found_idents: &mut Vec<Ident>) -> TokenStream {
17    process_singletons(tokens, &mut |singleton_ident| {
18        found_idents.push(singleton_ident.clone());
19        TokenTree::Ident(singleton_ident)
20    })
21}
22
23/// Replaces singleton references `#my_var` with the code needed to actually get the value inside.
24///
25/// * `tokens` - The tokens to update singleton references within.
26/// * `resolved_idents` - The local variable idents that correspond 1:1 and in the same
27///   order as the singleton references within `tokens` (found in-order via [`preprocess_singletons`]).
28///
29/// Generates a direct reference to the local variable. Use [`postprocess_singletons_handles`] for
30/// just the raw idents.
31pub fn postprocess_singletons(
32    tokens: TokenStream,
33    resolved_idents: impl IntoIterator<Item = Ident>,
34) -> Punctuated<Expr, Token![,]> {
35    let mut resolved_idents_iter = resolved_idents.into_iter();
36    let processed = process_singletons(tokens, &mut |singleton_ident| {
37        let span = singleton_ident.span();
38        let mut resolved_ident = resolved_idents_iter.next().unwrap();
39        resolved_ident.set_span(span);
40        TokenTree::Ident(resolved_ident)
41    });
42    parse_terminated(processed).unwrap()
43}
44
45/// Same as [`postprocess_singletons`] but generates just the raw ident rather than
46/// `RefCell` borrowing code.
47pub fn postprocess_singletons_handles(
48    tokens: TokenStream,
49    resolved_idents: impl IntoIterator<Item = Ident>,
50) -> Punctuated<Expr, Token![,]> {
51    let mut resolved_idents_iter = resolved_idents.into_iter();
52    let processed = process_singletons(tokens, &mut |singleton_ident| {
53        let mut resolved_ident = resolved_idents_iter.next().unwrap();
54        resolved_ident.set_span(singleton_ident.span().resolved_at(resolved_ident.span()));
55        TokenTree::Ident(resolved_ident)
56    });
57    parse_terminated(processed).unwrap()
58}
59
60/// Traverse the token stream, applying the `map_singleton_fn` whenever a singleton is found,
61/// returning the transformed token stream.
62fn process_singletons(
63    tokens: TokenStream,
64    map_singleton_fn: &mut impl FnMut(Ident) -> TokenTree,
65) -> TokenStream {
66    tokens
67        .into_iter()
68        .peekable()
69        .batching(|iter| {
70            let out = match iter.next()? {
71                TokenTree::Group(group) => {
72                    let mut new_group = Group::new(
73                        group.delimiter(),
74                        process_singletons(group.stream(), map_singleton_fn),
75                    );
76                    new_group.set_span(group.span());
77                    TokenTree::Group(new_group)
78                }
79                TokenTree::Ident(ident) => TokenTree::Ident(ident),
80                TokenTree::Punct(punct) => {
81                    if '#' == punct.as_char() && matches!(iter.peek(), Some(TokenTree::Ident(_))) {
82                        // Found a singleton.
83                        let Some(TokenTree::Ident(mut singleton_ident)) = iter.next() else {
84                            unreachable!()
85                        };
86                        {
87                            // Include the `#` in the span.
88                            let span = singleton_ident
89                                .span()
90                                .join(punct.span())
91                                .unwrap_or(singleton_ident.span());
92                            singleton_ident.set_span(span.resolved_at(singleton_ident.span()));
93                        }
94                        (map_singleton_fn)(singleton_ident)
95                    } else {
96                        TokenTree::Punct(punct)
97                    }
98                }
99                TokenTree::Literal(lit) => TokenTree::Literal(lit),
100            };
101            Some(out)
102        })
103        .collect()
104}