Skip to main content

dfir_lang/graph/
meta_graph.rs

1#![warn(missing_docs)]
2
3extern crate proc_macro;
4
5use std::collections::{BTreeMap, BTreeSet};
6use std::fmt::Debug;
7use std::iter::FusedIterator;
8
9use itertools::Itertools;
10use proc_macro2::{Ident, Literal, Span, TokenStream};
11use quote::{ToTokens, format_ident, quote, quote_spanned};
12use serde::{Deserialize, Serialize};
13use slotmap::{Key, SecondaryMap, SlotMap, SparseSecondaryMap};
14use syn::spanned::Spanned;
15
16use super::graph_write::{Dot, GraphWrite, Mermaid};
17use super::ops::{
18    DelayType, OPERATORS, OperatorWriteOutput, WriteContextArgs, find_op_op_constraints,
19    null_write_iterator_fn,
20};
21use super::{
22    CONTEXT, Color, DiMulGraph, GRAPH, GraphEdgeId, GraphLoopId, GraphNode, GraphNodeId,
23    GraphSubgraphId, HANDOFF_NODE_STR, MODULE_BOUNDARY_NODE_STR, OperatorInstance, PortIndexValue,
24    Varname, change_spans, get_operator_generics,
25};
26use crate::diagnostic::{Diagnostic, Diagnostics, Level};
27use crate::pretty_span::{PrettyRowCol, PrettySpan};
28use crate::process_singletons;
29
/// An abstract "meta graph" representation of a DFIR graph.
///
/// Can be with or without subgraph partitioning, stratification, and handoff insertion. This is
/// the meta graph used for generating Rust source code in macros from DFIR syntax.
///
/// This struct has a lot of methods for manipulating the graph, vaguely grouped together in
/// separate `impl` blocks. You might notice a few particularly specific arbitrary-seeming methods
/// in here--those are just what was needed for the compilation algorithms. If you need another
/// method then add it.
#[derive(Default, Debug, Serialize, Deserialize)]
pub struct DfirGraph {
    /// Each node type (operator or handoff).
    nodes: SlotMap<GraphNodeId, GraphNode>,

    /// Instance data corresponding to each operator node.
    /// This field will be empty after deserialization.
    #[serde(skip)]
    operator_instances: SecondaryMap<GraphNodeId, OperatorInstance>,
    /// Debugging/tracing tag for each operator node.
    operator_tag: SecondaryMap<GraphNodeId, String>,
    /// Graph data structure (two-way adjacency list).
    graph: DiMulGraph<GraphNodeId, GraphEdgeId>,
    /// Input and output port for each edge.
    ports: SecondaryMap<GraphEdgeId, (PortIndexValue, PortIndexValue)>,

    /// Which loop a node belongs to (or none for top-level).
    node_loops: SecondaryMap<GraphNodeId, GraphLoopId>,
    /// Which nodes belong to each loop.
    loop_nodes: SlotMap<GraphLoopId, Vec<GraphNodeId>>,
    /// For the loop, what is its parent (`None` for top-level).
    loop_parent: SparseSecondaryMap<GraphLoopId, GraphLoopId>,
    /// What loops are at the root.
    root_loops: Vec<GraphLoopId>,
    /// For the loop, what are its child loops.
    loop_children: SecondaryMap<GraphLoopId, Vec<GraphLoopId>>,

    /// Which subgraph each node belongs to.
    node_subgraph: SecondaryMap<GraphNodeId, GraphSubgraphId>,

    /// Which nodes belong to each subgraph.
    subgraph_nodes: SlotMap<GraphSubgraphId, Vec<GraphNodeId>>,

    /// Resolved singleton varname references, per node.
    node_singleton_references: SparseSecondaryMap<GraphNodeId, Vec<Option<GraphNodeId>>>,
    /// What variable name each graph node belongs to (if any). For debugging (graph writing) purposes only.
    node_varnames: SparseSecondaryMap<GraphNodeId, Varname>,

    /// Delay type for handoff nodes that represent tick-boundary back-edges.
    /// Set by `order_subgraphs` for `defer_tick` / `defer_tick_lazy`, either on handoff nodes
    /// it injects or on existing handoff nodes that it marks as tick-boundary back-edges.
    handoff_delay_type: SparseSecondaryMap<GraphNodeId, DelayType>,
}
82
83/// Basic methods.
84impl DfirGraph {
85    /// Create a new empty graph.
86    pub fn new() -> Self {
87        Default::default()
88    }
89}
90
/// Node methods.
impl DfirGraph {
    /// Get a node with its operator instance (if applicable).
    ///
    /// Panics if `node_id` is not present in the graph.
    pub fn node(&self, node_id: GraphNodeId) -> &GraphNode {
        self.nodes.get(node_id).expect("Node not found.")
    }

    /// Get the `OperatorInstance` for a given node. Node must be an operator and have an
    /// `OperatorInstance` present, otherwise will return `None`.
    ///
    /// Note that no operator instances will be present after deserialization.
    pub fn node_op_inst(&self, node_id: GraphNodeId) -> Option<&OperatorInstance> {
        self.operator_instances.get(node_id)
    }

    /// Get the debug variable name attached to a graph node.
    pub fn node_varname(&self, node_id: GraphNodeId) -> Option<&Varname> {
        self.node_varnames.get(node_id)
    }

    /// Get subgraph for node, or `None` if the node has not been assigned to a subgraph.
    pub fn node_subgraph(&self, node_id: GraphNodeId) -> Option<GraphSubgraphId> {
        self.node_subgraph.get(node_id).copied()
    }

    /// Degree into a node, i.e. the number of predecessors.
    pub fn node_degree_in(&self, node_id: GraphNodeId) -> usize {
        self.graph.degree_in(node_id)
    }

    /// Degree out of a node, i.e. the number of successors.
    pub fn node_degree_out(&self, node_id: GraphNodeId) -> usize {
        self.graph.degree_out(node_id)
    }

    /// Successors, iterator of `(GraphEdgeId, GraphNodeId)` of outgoing edges.
    pub fn node_successors(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successors(src)
    }

    /// Predecessors, iterator of `(GraphEdgeId, GraphNodeId)` of incoming edges.
    pub fn node_predecessors(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = (GraphEdgeId, GraphNodeId)>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessors(dst)
    }

    /// Successor edges, iterator of `GraphEdgeId` of outgoing edges.
    pub fn node_successor_edges(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphEdgeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successor_edges(src)
    }

    /// Predecessor edges, iterator of `GraphEdgeId` of incoming edges.
    pub fn node_predecessor_edges(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphEdgeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessor_edges(dst)
    }

    /// Successor nodes, iterator of `GraphNodeId`.
    pub fn node_successor_nodes(
        &self,
        src: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphNodeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.successor_vertices(src)
    }

    /// Predecessor nodes, iterator of `GraphNodeId`.
    pub fn node_predecessor_nodes(
        &self,
        dst: GraphNodeId,
    ) -> impl '_
    + DoubleEndedIterator<Item = GraphNodeId>
    + ExactSizeIterator
    + FusedIterator
    + Clone
    + Debug {
        self.graph.predecessor_vertices(dst)
    }

    /// Iterator of node IDs `GraphNodeId`.
    pub fn node_ids(&self) -> slotmap::basic::Keys<'_, GraphNodeId, GraphNode> {
        self.nodes.keys()
    }

    /// Iterator over `(GraphNodeId, &Node)` pairs.
    pub fn nodes(&self) -> slotmap::basic::Iter<'_, GraphNodeId, GraphNode> {
        self.nodes.iter()
    }

    /// Insert a node, assigning the given varname (if any) and loop membership (if any).
    ///
    /// Returns the newly-allocated node ID.
    pub fn insert_node(
        &mut self,
        node: GraphNode,
        varname_opt: Option<Ident>,
        loop_opt: Option<GraphLoopId>,
    ) -> GraphNodeId {
        let node_id = self.nodes.insert(node);
        if let Some(varname) = varname_opt {
            self.node_varnames.insert(node_id, Varname(varname));
        }
        if let Some(loop_id) = loop_opt {
            // Record loop membership in both directions (node -> loop, loop -> nodes).
            self.node_loops.insert(node_id, loop_id);
            self.loop_nodes[loop_id].push(node_id);
        }
        node_id
    }

    /// Insert an operator instance for the given node. Panics if already set.
    ///
    /// Also panics if `node_id` does not refer to a `GraphNode::Operator`.
    pub fn insert_node_op_inst(&mut self, node_id: GraphNodeId, op_inst: OperatorInstance) {
        assert!(matches!(
            self.nodes.get(node_id),
            Some(GraphNode::Operator(_))
        ));
        let old_inst = self.operator_instances.insert(node_id, op_inst);
        assert!(old_inst.is_none());
    }

    /// Assign all operator instances if not set. Write diagnostic messages/errors into `diagnostics`.
    ///
    /// Instances are first collected into a `Vec` and inserted in a second pass, since
    /// `self.nodes()` holds a borrow of `self` during iteration.
    pub fn insert_node_op_insts_all(&mut self, diagnostics: &mut Diagnostics) {
        let mut op_insts = Vec::new();
        for (node_id, node) in self.nodes() {
            let GraphNode::Operator(operator) = node else {
                continue;
            };
            // Skip nodes that already have an instance (e.g. previously assigned).
            if self.node_op_inst(node_id).is_some() {
                continue;
            };

            // Op constraints.
            let Some(op_constraints) = find_op_op_constraints(operator) else {
                diagnostics.push(Diagnostic::spanned(
                    operator.path.span(),
                    Level::Error,
                    format!("Unknown operator `{}`", operator.name_string()),
                ));
                continue;
            };

            // Input and output ports.
            let (input_ports, output_ports) = {
                let mut input_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_predecessors(node_id)
                    .map(|(edge_id, pred_id)| (self.edge_ports(edge_id).1, pred_id))
                    .collect();
                // Ensure sorted by port index.
                input_edges.sort();
                let input_ports: Vec<PortIndexValue> = input_edges
                    .into_iter()
                    .map(|(port, _pred)| port)
                    .cloned()
                    .collect();

                // Collect output arguments (successors).
                let mut output_edges: Vec<(&PortIndexValue, GraphNodeId)> = self
                    .node_successors(node_id)
                    .map(|(edge_id, succ)| (self.edge_ports(edge_id).0, succ))
                    .collect();
                // Ensure sorted by port index.
                output_edges.sort();
                let output_ports: Vec<PortIndexValue> = output_edges
                    .into_iter()
                    .map(|(port, _succ)| port)
                    .cloned()
                    .collect();

                (input_ports, output_ports)
            };

            // Generic arguments.
            let generics = get_operator_generics(diagnostics, operator);
            // Generic argument errors.
            {
                // Span of `generic_args` (if it exists), otherwise span of the operator name.
                let generics_span = generics
                    .generic_args
                    .as_ref()
                    .map(Spanned::span)
                    .unwrap_or_else(|| operator.path.span());

                if !op_constraints
                    .persistence_args
                    .contains(&generics.persistence_args.len())
                {
                    diagnostics.push(Diagnostic::spanned(
                        generics.persistence_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} persistence lifetime arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.persistence_args.human_string(),
                            generics.persistence_args.len()
                        ),
                    ));
                }
                if !op_constraints.type_args.contains(&generics.type_args.len()) {
                    diagnostics.push(Diagnostic::spanned(
                        generics.type_args_span().unwrap_or(generics_span),
                        Level::Error,
                        format!(
                            "`{}` should have {} generic type arguments, actually has {}.",
                            op_constraints.name,
                            op_constraints.type_args.human_string(),
                            generics.type_args.len()
                        ),
                    ));
                }
            }

            op_insts.push((
                node_id,
                OperatorInstance {
                    op_constraints,
                    input_ports,
                    output_ports,
                    singletons_referenced: operator.singletons_referenced.clone(),
                    generics,
                    arguments_pre: operator.args.clone(),
                    arguments_raw: operator.args_raw.clone(),
                },
            ));
        }

        // Second pass: actually insert, now that the iteration borrow has ended.
        for (node_id, op_inst) in op_insts {
            self.insert_node_op_inst(node_id, op_inst);
        }
    }

    /// Inserts a node between two existing nodes connected by the given `edge_id`.
    ///
    /// `edge`: (src, dst, dst_idx)
    ///
    /// Before: A (src) ------------> B (dst)
    /// After:  A (src) -> X (new) -> B (dst)
    ///
    /// Returns the ID of X & ID of edge OUT of X.
    ///
    /// Note that both the edges will be new and `edge_id` will be removed. Both new edges will
    /// get the edge type of the original edge.
    pub fn insert_intermediate_node(
        &mut self,
        edge_id: GraphEdgeId,
        new_node: GraphNode,
    ) -> (GraphNodeId, GraphEdgeId) {
        let span = Some(new_node.span());

        // Make corresponding operator instance (if `node` is an operator).
        let op_inst_opt = 'oc: {
            let GraphNode::Operator(operator) = &new_node else {
                break 'oc None;
            };
            let Some(op_constraints) = find_op_op_constraints(operator) else {
                break 'oc None;
            };
            // The new node inherits the ports of the edge it is spliced into.
            let (input_port, output_port) = self.ports.get(edge_id).cloned().unwrap();

            // Diagnostics are asserted empty: the intermediate node is compiler-generated,
            // so its generics must be well-formed.
            let mut dummy_diagnostics = Diagnostics::new();
            let generics = get_operator_generics(&mut dummy_diagnostics, operator);
            assert!(dummy_diagnostics.is_empty());

            Some(OperatorInstance {
                op_constraints,
                input_ports: vec![input_port],
                output_ports: vec![output_port],
                singletons_referenced: operator.singletons_referenced.clone(),
                generics,
                arguments_pre: operator.args.clone(),
                arguments_raw: operator.args_raw.clone(),
            })
        };

        // Insert new `node`.
        let node_id = self.nodes.insert(new_node);
        // Insert corresponding `OperatorInstance` if applicable.
        if let Some(op_inst) = op_inst_opt {
            self.operator_instances.insert(node_id, op_inst);
        }
        // Update edges to insert node within `edge_id`.
        let (e0, e1) = self
            .graph
            .insert_intermediate_vertex(node_id, edge_id)
            .unwrap();

        // Update corresponding ports: the original src/dst ports stay on the outer ends,
        // the new node gets elided ports on both sides.
        let (src_idx, dst_idx) = self.ports.remove(edge_id).unwrap();
        self.ports
            .insert(e0, (src_idx, PortIndexValue::Elided(span)));
        self.ports
            .insert(e1, (PortIndexValue::Elided(span), dst_idx));

        (node_id, e1)
    }

    /// Remove the node `node_id` but preserves and connects the single predecessor and single successor.
    /// Panics if the node does not have exactly one predecessor and one successor, or is not in the graph.
    ///
    /// Also panics if subgraph partitioning has already been performed.
    pub fn remove_intermediate_node(&mut self, node_id: GraphNodeId) {
        assert_eq!(
            1,
            self.node_degree_in(node_id),
            "Removed intermediate node must have one predecessor"
        );
        assert_eq!(
            1,
            self.node_degree_out(node_id),
            "Removed intermediate node must have one successor"
        );
        assert!(
            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
            "Should not remove intermediate node after subgraph partitioning"
        );

        assert!(self.nodes.remove(node_id).is_some());
        let (new_edge_id, (pred_edge_id, succ_edge_id)) =
            self.graph.remove_intermediate_vertex(node_id).unwrap();
        self.operator_instances.remove(node_id);
        self.node_varnames.remove(node_id);

        // The stitched edge keeps the outer ports of the two removed edges.
        let (src_port, _) = self.ports.remove(pred_edge_id).unwrap();
        let (_, dst_port) = self.ports.remove(succ_edge_id).unwrap();
        self.ports.insert(new_edge_id, (src_port, dst_port));
    }

    /// Helper method: determine the "color" (pull vs push) of a node based on its in and out degree,
    /// excluding reference edges. If linear (1 in, 1 out), color is `None`, indicating it can be
    /// either push or pull.
    ///
    /// Note that this does NOT consider `DelayType` barriers (which generally implies `Pull`).
    pub(crate) fn node_color(&self, node_id: GraphNodeId) -> Option<Color> {
        if matches!(self.node(node_id), GraphNode::Handoff { .. }) {
            return Some(Color::Hoff);
        }

        // TODO(shadaj): this is a horrible hack
        if let GraphNode::Operator(op) = self.node(node_id)
            && (op.name_string() == "resolve_futures_blocking"
                || op.name_string() == "resolve_futures_blocking_ordered")
        {
            return Some(Color::Push);
        }

        // In-degree, excluding ref-edges.
        let inn_degree = self.node_predecessor_nodes(node_id).len();
        // Out-degree excluding ref-edges.
        let out_degree = self.node_successor_nodes(node_id).len();

        match (inn_degree, out_degree) {
            (0, 0) => None, // Generally should not happen, "Degenerate subgraph detected".
            (0, 1) => Some(Color::Pull),
            (1, 0) => Some(Color::Push),
            (1, 1) => None, // Linear, can be either push or pull.
            (_many, 0 | 1) => Some(Color::Pull), // Fan-in (many predecessors).
            (0 | 1, _many) => Some(Color::Push), // Fan-out (many successors).
            (_many, _to_many) => Some(Color::Comp), // Both fan-in and fan-out.
        }
    }

    /// Set the operator tag (for debugging/tracing). Overwrites any previous tag.
    pub fn set_operator_tag(&mut self, node_id: GraphNodeId, tag: String) {
        self.operator_tag.insert(node_id, tag);
    }
}
485
486/// Singleton references.
487impl DfirGraph {
488    /// Set the singletons referenced for the `node_id` operator. Each reference corresponds to the
489    /// same index in the [`crate::parse::Operator::singletons_referenced`] vec.
490    pub fn set_node_singleton_references(
491        &mut self,
492        node_id: GraphNodeId,
493        singletons_referenced: Vec<Option<GraphNodeId>>,
494    ) -> Option<Vec<Option<GraphNodeId>>> {
495        self.node_singleton_references
496            .insert(node_id, singletons_referenced)
497    }
498
499    /// Gets the singletons referenced by a node. Returns an empty iterator for non-operators and
500    /// operators that do not reference singletons.
501    pub fn node_singleton_references(&self, node_id: GraphNodeId) -> &[Option<GraphNodeId>] {
502        self.node_singleton_references
503            .get(node_id)
504            .map(std::ops::Deref::deref)
505            .unwrap_or_default()
506    }
507}
508
/// Module methods.
impl DfirGraph {
    /// When modules are imported into a flat graph, they come with an input and output ModuleBoundary node.
    /// The partitioner doesn't understand these nodes and will panic if it encounters them.
    /// merge_modules removes them from the graph, stitching the input and output sides of the ModuleBoundaries based on their ports
    /// For example:
    ///     source_iter([]) -> \[myport\]ModuleBoundary(input)\[my_port\] -> map(|x| x) -> ModuleBoundary(output) -> null();
    /// in the above example, the \[myport\] port will be used to connect the source_iter with the map that is inside of the module.
    /// The output module boundary has elided ports, this is also used to match up the input/output across the module boundary.
    pub fn merge_modules(&mut self) -> Result<(), Diagnostic> {
        // Collect the boundary node IDs up front so the graph can be mutated below.
        let mod_bound_nodes = self
            .nodes()
            .filter(|(_nid, node)| matches!(node, GraphNode::ModuleBoundary { .. }))
            .map(|(nid, _node)| nid)
            .collect::<Vec<_>>();

        for mod_bound_node in mod_bound_nodes {
            self.remove_module_boundary(mod_bound_node)?;
        }

        Ok(())
    }

    /// see `merge_modules`
    /// This function removes a singular module boundary from the graph and performs the necessary stitching to fix the graph afterward.
    /// `merge_modules` calls this function for each module boundary in the graph.
    fn remove_module_boundary(&mut self, mod_bound_node: GraphNodeId) -> Result<(), Diagnostic> {
        assert!(
            self.node_subgraph.is_empty() && self.subgraph_nodes.is_empty(),
            "Should not remove intermediate node after subgraph partitioning"
        );

        // Map each boundary-facing port to the edge (and far-side port) on that side:
        // `mod_pred_ports`: incoming edges, keyed by the port INTO the boundary.
        // `mod_succ_ports`: outgoing edges, keyed by the port OUT of the boundary.
        let mut mod_pred_ports = BTreeMap::new();
        let mut mod_succ_ports = BTreeMap::new();

        for mod_out_edge in self.node_predecessor_edges(mod_bound_node) {
            let (pred_port, succ_port) = self.edge_ports(mod_out_edge);
            mod_pred_ports.insert(succ_port.clone(), (mod_out_edge, pred_port.clone()));
        }

        for mod_inn_edge in self.node_successor_edges(mod_bound_node) {
            let (pred_port, succ_port) = self.edge_ports(mod_inn_edge);
            mod_succ_ports.insert(pred_port.clone(), (mod_inn_edge, succ_port.clone()));
        }

        // The port sets on both sides of the boundary must match exactly to stitch them.
        if mod_pred_ports.keys().collect::<BTreeSet<_>>()
            != mod_succ_ports.keys().collect::<BTreeSet<_>>()
        {
            // get module boundary node
            let GraphNode::ModuleBoundary { input, import_expr } = self.node(mod_bound_node) else {
                panic!();
            };

            // NOTE(review): `{:?}` on the already-joined `String` prints surrounding quotes and
            // escapes; `{}` may be the intended formatting -- confirm before changing.
            if *input {
                return Err(Diagnostic {
                    span: *import_expr,
                    level: Level::Error,
                    message: format!(
                        "The ports into the module did not match. input: {:?}, expected: {:?}",
                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
                        mod_succ_ports.keys().map(|x| x.to_string()).join(", ")
                    ),
                });
            } else {
                return Err(Diagnostic {
                    span: *import_expr,
                    level: Level::Error,
                    message: format!(
                        "The ports out of the module did not match. output: {:?}, expected: {:?}",
                        mod_succ_ports.keys().map(|x| x.to_string()).join(", "),
                        mod_pred_ports.keys().map(|x| x.to_string()).join(", "),
                    ),
                });
            }
        }

        // Stitch: for each matched port, replace the (pred -> boundary -> succ) edge pair with a
        // single direct (pred -> succ) edge carrying the outer ports.
        for (port, (pred_edge, pred_port)) in mod_pred_ports {
            let (succ_edge, succ_port) = mod_succ_ports.remove(&port).unwrap();

            let (src, _) = self.edge(pred_edge);
            let (_, dst) = self.edge(succ_edge);
            self.remove_edge(pred_edge);
            self.remove_edge(succ_edge);

            let new_edge_id = self.graph.insert_edge(src, dst);
            self.ports.insert(new_edge_id, (pred_port, succ_port));
        }

        self.graph.remove_vertex(mod_bound_node);
        self.nodes.remove(mod_bound_node);

        Ok(())
    }
}
603
604/// Edge methods.
605impl DfirGraph {
606    /// Get the `src` and `dst` for an edge: `(src GraphNodeId, dst GraphNodeId)`.
607    pub fn edge(&self, edge_id: GraphEdgeId) -> (GraphNodeId, GraphNodeId) {
608        let (src, dst) = self.graph.edge(edge_id).expect("Edge not found.");
609        (src, dst)
610    }
611
612    /// Get the source and destination ports for an edge: `(src &PortIndexValue, dst &PortIndexValue)`.
613    pub fn edge_ports(&self, edge_id: GraphEdgeId) -> (&PortIndexValue, &PortIndexValue) {
614        let (src_port, dst_port) = self.ports.get(edge_id).expect("Edge not found.");
615        (src_port, dst_port)
616    }
617
618    /// Iterator of all edge IDs `GraphEdgeId`.
619    pub fn edge_ids(&self) -> slotmap::basic::Keys<'_, GraphEdgeId, (GraphNodeId, GraphNodeId)> {
620        self.graph.edge_ids()
621    }
622
623    /// Iterator over all edges: `(GraphEdgeId, (src GraphNodeId, dst GraphNodeId))`.
624    pub fn edges(
625        &self,
626    ) -> impl '_
627    + ExactSizeIterator<Item = (GraphEdgeId, (GraphNodeId, GraphNodeId))>
628    + FusedIterator
629    + Clone
630    + Debug {
631        self.graph.edges()
632    }
633
634    /// Insert an edge between nodes thru the given ports.
635    pub fn insert_edge(
636        &mut self,
637        src: GraphNodeId,
638        src_port: PortIndexValue,
639        dst: GraphNodeId,
640        dst_port: PortIndexValue,
641    ) -> GraphEdgeId {
642        let edge_id = self.graph.insert_edge(src, dst);
643        self.ports.insert(edge_id, (src_port, dst_port));
644        edge_id
645    }
646
647    /// Removes an edge and its corresponding ports and edge type info.
648    pub fn remove_edge(&mut self, edge: GraphEdgeId) {
649        let (_src, _dst) = self.graph.remove_edge(edge).unwrap();
650        let (_src_port, _dst_port) = self.ports.remove(edge).unwrap();
651    }
652}
653
654/// Subgraph methods.
655impl DfirGraph {
656    /// Nodes belonging to the given subgraph.
657    pub fn subgraph(&self, subgraph_id: GraphSubgraphId) -> &Vec<GraphNodeId> {
658        self.subgraph_nodes
659            .get(subgraph_id)
660            .expect("Subgraph not found.")
661    }
662
663    /// Iterator over all subgraph IDs.
664    pub fn subgraph_ids(&self) -> slotmap::basic::Keys<'_, GraphSubgraphId, Vec<GraphNodeId>> {
665        self.subgraph_nodes.keys()
666    }
667
668    /// Iterator over all subgraphs, ID and members: `(GraphSubgraphId, Vec<GraphNodeId>)`.
669    pub fn subgraphs(&self) -> slotmap::basic::Iter<'_, GraphSubgraphId, Vec<GraphNodeId>> {
670        self.subgraph_nodes.iter()
671    }
672
673    /// Create a subgraph consisting of `node_ids`. Returns an error if any of the nodes are already in a subgraph.
674    pub fn insert_subgraph(
675        &mut self,
676        node_ids: Vec<GraphNodeId>,
677    ) -> Result<GraphSubgraphId, (GraphNodeId, GraphSubgraphId)> {
678        // Check none are already in subgraphs
679        for &node_id in node_ids.iter() {
680            if let Some(&old_sg_id) = self.node_subgraph.get(node_id) {
681                return Err((node_id, old_sg_id));
682            }
683        }
684        let subgraph_id = self.subgraph_nodes.insert_with_key(|sg_id| {
685            for &node_id in node_ids.iter() {
686                self.node_subgraph.insert(node_id, sg_id);
687            }
688            node_ids
689        });
690
691        Ok(subgraph_id)
692    }
693
694    /// Removes a node from its subgraph. Returns true if the node was in a subgraph.
695    pub fn remove_from_subgraph(&mut self, node_id: GraphNodeId) -> bool {
696        if let Some(old_sg_id) = self.node_subgraph.remove(node_id) {
697            self.subgraph_nodes[old_sg_id].retain(|&other_node_id| other_node_id != node_id);
698            true
699        } else {
700            false
701        }
702    }
703
704    /// Gets the delay type for a handoff node, if set.
705    pub fn handoff_delay_type(&self, node_id: GraphNodeId) -> Option<DelayType> {
706        self.handoff_delay_type.get(node_id).copied()
707    }
708
709    /// Sets the delay type for a handoff node.
710    pub fn set_handoff_delay_type(&mut self, node_id: GraphNodeId, delay_type: DelayType) {
711        self.handoff_delay_type.insert(node_id, delay_type);
712    }
713
714    /// Helper: finds the first index in `subgraph_nodes` where it transitions from pull to push.
715    fn find_pull_to_push_idx(&self, subgraph_nodes: &[GraphNodeId]) -> usize {
716        subgraph_nodes
717            .iter()
718            .position(|&node_id| {
719                self.node_color(node_id)
720                    .is_some_and(|color| Color::Pull != color)
721            })
722            .unwrap_or(subgraph_nodes.len())
723    }
724}
725
726/// Display/output methods.
727impl DfirGraph {
728    /// Helper to generate a deterministic `Ident` for the given node.
729    fn node_as_ident(&self, node_id: GraphNodeId, is_pred: bool) -> Ident {
730        let name = match &self.nodes[node_id] {
731            GraphNode::Operator(_) => format!("op_{:?}", node_id.data()),
732            GraphNode::Handoff { .. } => format!(
733                "hoff_{:?}_{}",
734                node_id.data(),
735                if is_pred { "recv" } else { "send" }
736            ),
737            GraphNode::ModuleBoundary { .. } => panic!(),
738        };
739        let span = match (is_pred, &self.nodes[node_id]) {
740            (_, GraphNode::Operator(operator)) => operator.span(),
741            (true, &GraphNode::Handoff { src_span, .. }) => src_span,
742            (false, &GraphNode::Handoff { dst_span, .. }) => dst_span,
743            (_, GraphNode::ModuleBoundary { .. }) => panic!(),
744        };
745        Ident::new(&name, span)
746    }
747
748    /// Helper to generate the main buffer `Ident` for a handoff node.
749    fn hoff_buf_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
750        Ident::new(&format!("hoff_{:?}_buf", hoff_id.data()), span)
751    }
752
753    /// Helper to generate the back (double-buffer) `Ident` for a handoff node.
754    fn hoff_back_ident(&self, hoff_id: GraphNodeId, span: Span) -> Ident {
755        Ident::new(&format!("hoff_{:?}_back", hoff_id.data()), span)
756    }
757
758    /// For per-node singleton references. Helper to generate a deterministic `Ident` for the given node.
759    fn node_as_singleton_ident(&self, node_id: GraphNodeId, span: Span) -> Ident {
760        Ident::new(&format!("singleton_op_{:?}", node_id.data()), span)
761    }
762
763    /// Resolve the singletons via [`Self::node_singleton_references`] for the given `node_id`.
764    fn helper_resolve_singletons(&self, node_id: GraphNodeId, span: Span) -> Vec<Ident> {
765        self.node_singleton_references(node_id)
766            .iter()
767            .map(|singleton_node_id| {
768                // TODO(mingwei): this `expect` should be caught in error checking
769                self.node_as_singleton_ident(
770                    singleton_node_id
771                        .expect("Expected singleton to be resolved but was not, this is a bug."),
772                    span,
773                )
774            })
775            .collect::<Vec<_>>()
776    }
777
778    /// Returns each subgraph's receive and send handoffs.
779    /// `Map<GraphSubgraphId, (recv handoffs, send handoffs)>`
780    fn helper_collect_subgraph_handoffs(
781        &self,
782    ) -> SecondaryMap<GraphSubgraphId, (Vec<GraphNodeId>, Vec<GraphNodeId>)> {
783        // Get data on handoff src and dst subgraphs.
784        let mut subgraph_handoffs: SecondaryMap<
785            GraphSubgraphId,
786            (Vec<GraphNodeId>, Vec<GraphNodeId>),
787        > = self
788            .subgraph_nodes
789            .keys()
790            .map(|k| (k, Default::default()))
791            .collect();
792
793        // For each handoff node, add it to the `send`/`recv` lists for the corresponding subgraphs.
794        for (hoff_id, node) in self.nodes() {
795            if !matches!(node, GraphNode::Handoff { .. }) {
796                continue;
797            }
798            // Receivers from the handoff. (Should really only be one).
799            for (_edge, succ_id) in self.node_successors(hoff_id) {
800                let succ_sg = self.node_subgraph(succ_id).unwrap();
801                subgraph_handoffs[succ_sg].0.push(hoff_id);
802            }
803            // Senders into the handoff. (Should really only be one).
804            for (_edge, pred_id) in self.node_predecessors(hoff_id) {
805                let pred_sg = self.node_subgraph(pred_id).unwrap();
806                subgraph_handoffs[pred_sg].1.push(hoff_id);
807            }
808        }
809
810        subgraph_handoffs
811    }
812
    /// Emit this graph as runnable Rust source code tokens that execute inline.
    /// Generates a flat `async move |df: &mut Context|` closure where subgraph
    /// blocks are inlined in topological order, using local `Vec<T>` buffers
    /// instead of runtime handoffs. Each call to the closure runs one tick.
    ///
    /// The generated code block evaluates to a `Dfir` instance wrapping the
    /// closure. Operator prologues run at construction time before the `Context`
    /// is moved into `Dfir::new`. `Dfir` provides the `Context` to the closure
    /// on each tick run.
    ///
    /// # Errors
    ///
    /// Returns all diagnostics as `Err(diagnostics)` if any are errors
    /// (leaving `&mut diagnostics` empty).
    pub fn as_code(
        &self,
        root: &TokenStream,
        include_type_guards: bool,
        prefix: TokenStream,
        diagnostics: &mut Diagnostics,
    ) -> Result<TokenStream, Diagnostics> {
        // `include_meta = true`: bake the runtime meta graph + diagnostics JSON
        // blobs into the generated `Dfir::new(...)` call (see `as_code_with_options`).
        self.as_code_with_options(root, include_type_guards, true, prefix, diagnostics)
    }
836
837    /// Like [`Self::as_code`], but with `include_meta` controlling whether
838    /// the runtime meta graph + diagnostics JSON blobs are baked into the
839    /// generated `Dfir::new(...)` call.
840    ///
    /// The simulator calls `Dfir::new()` on each iteration, and as part of that it
    /// parses the meta graph and diagnostics blobs. One of these parses causes spans to
    /// be allocated; each span allocation increments a thread-local `u32`, so on a long
    /// simulator run the `u32` overflows and panics.
845    pub fn as_code_with_options(
846        &self,
847        root: &TokenStream,
848        include_type_guards: bool,
849        include_meta: bool,
850        prefix: TokenStream,
851        diagnostics: &mut Diagnostics,
852    ) -> Result<TokenStream, Diagnostics> {
853        // Extract the slot index from a slotmap key for use as a runtime metrics key.
854        // Uses the low 32 bits of `KeyData::as_ffi()` (the idx, ignoring the version).
855        // TODO(cleanup): When scheduled Dfir is removed, DfirMetrics could use slotmap
856        // SecondaryMaps directly, eliminating this conversion.
857        fn slotmap_raw_idx(key: impl Key) -> usize {
858            (key.data().as_ffi() & 0xFFFF_FFFF) as usize
859        }
860
861        let df = Ident::new(GRAPH, Span::call_site());
862        let context = Ident::new(CONTEXT, Span::call_site());
863
864        // 1. Generate local Vec buffers for each handoff node.
865        let handoff_nodes: Vec<_> = self
866            .nodes
867            .iter()
868            .filter_map(|(node_id, node)| match node {
869                GraphNode::Operator(_) => None,
870                &GraphNode::Handoff { src_span, dst_span } => Some((node_id, (src_span, dst_span))),
871                GraphNode::ModuleBoundary { .. } => panic!(),
872            })
873            .collect();
874
875        let buffer_code: Vec<TokenStream> = handoff_nodes
876            .iter()
877            .map(|&(node_id, (src_span, dst_span))| {
878                let span = src_span.join(dst_span).unwrap_or(src_span);
879                let buf_ident = self.hoff_buf_ident(node_id, span);
880                quote_spanned! {span=>
881                    let mut #buf_ident: Vec<_> = Vec::new();
882                }
883            })
884            .collect();
885
886        // For tick-boundary handoffs (`defer_tick` / `defer_tick_lazy`), declare a
887        // second "back" buffer for double-buffering. At the start of each tick, the
888        // main buffer and back buffer are swapped so the consumer reads last tick's
889        // data while the producer writes to a fresh buffer.
890        let back_buffer_code: Vec<TokenStream> = handoff_nodes
891            .iter()
892            .filter(|(node_id, _)| self.handoff_delay_type(*node_id).is_some())
893            .map(|&(node_id, (src_span, dst_span))| {
894                let span = src_span.join(dst_span).unwrap_or(src_span);
895                let back_ident = self.hoff_back_ident(node_id, span);
896                quote_spanned! {span=>
897                    let mut #back_ident: Vec<_> = Vec::new();
898                }
899            })
900            .collect();
901
902        // 2. Collect subgraph handoffs (same as as_code).
903        let subgraph_handoffs = self.helper_collect_subgraph_handoffs();
904
905        // 3. Sort subgraphs topologically and collect non-lazy defer_tick buffer idents.
906        //
907        // Handoffs marked with a `DelayType` (Tick/TickLazy) are tick-boundary back-edges.
908        // These are excluded from the topo sort (no ordering constraint). Double-buffering
909        // ensures data written by the producer in tick N is only visible to the consumer
910        // in tick N+1, regardless of execution order.
911        //
912        // While iterating handoffs, we also collect buffer idents for non-lazy tick-boundary
913        // edges (defer_tick). When these buffers are non-empty at end of tick, we set
914        // can_start_tick so that run_available continues ticking.
915        let mut defer_tick_buf_idents: Vec<Ident> = Vec::new();
916        let mut back_edge_hoff_ids: BTreeSet<GraphNodeId> = BTreeSet::new();
917        let all_subgraphs = {
918            // Build predecessor map for subgraphs.
919            let mut sg_preds = SecondaryMap::<_, Vec<_>>::with_capacity(self.subgraph_nodes.len());
920            for (hoff_id, node) in self.nodes() {
921                if !matches!(node, GraphNode::Handoff { .. }) {
922                    // Not a handoff; skip.
923                    continue;
924                }
925                assert_eq!(1, self.node_successors(hoff_id).len());
926                assert_eq!(1, self.node_predecessors(hoff_id).len());
927                let (_edge_id, pred) = self.node_predecessors(hoff_id).next().unwrap();
928                let (_edge_id, succ) = self.node_successors(hoff_id).next().unwrap();
929                let pred_sg = self.node_subgraph(pred).unwrap();
930                let succ_sg = self.node_subgraph(succ).unwrap();
931                if pred_sg == succ_sg {
932                    panic!("bug: unexpected subgraph self-handoff cycle");
933                }
934                if let Some(delay_type) = self.handoff_delay_type(hoff_id) {
935                    debug_assert!(matches!(delay_type, DelayType::Tick | DelayType::TickLazy));
936                    // Tick/back-edge handoff: no ordering constraint. Double-buffering
937                    // handles the tick deferral regardless of execution order.
938                    back_edge_hoff_ids.insert(hoff_id);
939
940                    // Non-lazy tick-boundary: defer_tick (not defer_tick_lazy).
941                    if !matches!(delay_type, DelayType::TickLazy) {
942                        defer_tick_buf_idents.push(self.hoff_buf_ident(hoff_id, node.span()));
943                    }
944                } else {
945                    sg_preds.entry(succ_sg).unwrap().or_default().push(pred_sg);
946                }
947            }
948
949            // Include singleton reference edges: if node A references the
950            // singleton output of node B, then A's subgraph must run after B's.
951            for dst_id in self.node_ids() {
952                for src_ref_id in self
953                    .node_singleton_references(dst_id)
954                    .iter()
955                    .copied()
956                    .flatten()
957                {
958                    let src_sg = self
959                        .node_subgraph(src_ref_id)
960                        .expect("bug: singleton ref node must belong to a subgraph");
961                    let dst_sg = self
962                        .node_subgraph(dst_id)
963                        .expect("bug: singleton ref consumer must belong to a subgraph");
964                    if src_sg != dst_sg {
965                        sg_preds.entry(dst_sg).unwrap().or_default().push(src_sg);
966                    }
967                }
968            }
969
970            let topo_sort = super::graph_algorithms::topo_sort(self.subgraph_ids(), |sg_id| {
971                sg_preds.get(sg_id).into_iter().flatten().copied()
972            })
973            .expect("bug: unexpected cycle between subgraphs within the tick");
974
975            topo_sort
976                .into_iter()
977                .map(|sg_id| (sg_id, self.subgraph(sg_id)))
978                .collect::<Vec<_>>()
979        };
980
981        // Generate swap code for tick-boundary (defer_tick / defer_tick_lazy) handoffs.
982        // At the start of each tick, swap the main buffer and back buffer so the
983        // consumer reads last tick's data from the back buffer.
984        let back_edge_swap_code: Vec<TokenStream> = back_edge_hoff_ids
985            .iter()
986            .map(|&hoff_id| {
987                let span = self.nodes[hoff_id].span();
988                let buf_ident = self.hoff_buf_ident(hoff_id, span);
989                let back_ident = self.hoff_back_ident(hoff_id, span);
990                quote_spanned! {span=>
991                    ::std::mem::swap(&mut #buf_ident, &mut #back_ident);
992                }
993            })
994            .collect();
995
996        let mut op_prologue_code = Vec::new();
997        let mut op_prologue_after_code = Vec::new();
998        let mut op_tick_end_code = Vec::new();
999        let mut subgraph_blocks = Vec::new();
1000        {
1001            for &(subgraph_id, subgraph_nodes) in all_subgraphs.iter() {
1002                let sg_metrics_idx = slotmap_raw_idx(subgraph_id);
1003                let (recv_hoffs, send_hoffs) = &subgraph_handoffs[subgraph_id];
1004
1005                // Generate buffer ident helpers for this subgraph's handoffs.
1006                let recv_port_idents: Vec<Ident> = recv_hoffs
1007                    .iter()
1008                    .map(|&hoff_id| self.node_as_ident(hoff_id, true))
1009                    .collect();
1010                let send_port_idents: Vec<Ident> = send_hoffs
1011                    .iter()
1012                    .map(|&hoff_id| self.node_as_ident(hoff_id, false))
1013                    .collect();
1014
1015                // Map handoff node IDs to buffer idents.
1016                let recv_buf_idents: Vec<Ident> = recv_hoffs
1017                    .iter()
1018                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1019                    .collect();
1020                let send_buf_idents: Vec<Ident> = send_hoffs
1021                    .iter()
1022                    .map(|&hoff_id| self.hoff_buf_ident(hoff_id, self.nodes[hoff_id].span()))
1023                    .collect();
1024
1025                // Recv port code: drain from buffer into iterator, tracking if non-empty.
1026                // For back-edge (defer_tick) handoffs, drain from the back buffer instead.
1027                // Also update handoff metrics (measured at recv, not send — see graph.rs).
1028                let recv_port_code: Vec<TokenStream> = recv_port_idents
1029                    .iter()
1030                    .zip(recv_buf_idents.iter())
1031                    .zip(recv_hoffs.iter())
1032                    .map(|((port_ident, buf_ident), &hoff_id)| {
1033                        let hoff_idx = slotmap_raw_idx(hoff_id);
1034                        // Use call_site span for internal identifiers to avoid
1035                        // hygiene issues when invoked through declarative macros
1036                        // (e.g. dfir_expect_warnings!). TODO(#2781): define these once.
1037                        let work_done = Ident::new("__dfir_work_done", Span::call_site());
1038                        let metrics = Ident::new("__dfir_metrics", Span::call_site());
1039                        // Tick-boundary handoffs drain from the back buffer (double-buffering).
1040                        // (Sending always writes to the regular buffer — no branch needed there.)
1041                        let drain_ident = if back_edge_hoff_ids.contains(&hoff_id) {
1042                            self.hoff_back_ident(hoff_id, buf_ident.span())
1043                        } else {
1044                            buf_ident.clone()
1045                        };
1046                        quote_spanned! {port_ident.span()=>
1047                            {
1048                                let hoff_len = #drain_ident.len();
1049                                if hoff_len > 0 {
1050                                    #work_done = true;
1051                                }
1052                                let hoff_metrics = &#metrics.handoffs[
1053                                    #root::util::slot_vec::Key::<#root::scheduled::HandoffTag>::from_raw(#hoff_idx)
1054                                ];
1055                                hoff_metrics.total_items_count.update(|x| x + hoff_len);
1056                                hoff_metrics.curr_items_count.set(hoff_len);
1057                            }
1058                            let #port_ident = #root::dfir_pipes::pull::iter(#drain_ident.drain(..));
1059                        }
1060                    })
1061                    .collect();
1062
1063                // Send port code: push into buffer.
1064                let send_port_code: Vec<TokenStream> = send_port_idents
1065                    .iter()
1066                    .zip(send_buf_idents.iter())
1067                    .map(|(port_ident, buf_ident)| {
1068                        quote_spanned! {port_ident.span()=>
1069                            let #port_ident = #root::dfir_pipes::push::vec_push(&mut #buf_ident);
1070                        }
1071                    })
1072                    .collect();
1073
1074                // All nodes in a subgraph should be in the same loop.
1075                let loop_id = self.node_loop(subgraph_nodes[0]);
1076
1077                let mut subgraph_op_iter_code = Vec::new();
1078                let mut subgraph_op_iter_after_code = Vec::new();
1079                {
1080                    let pull_to_push_idx = self.find_pull_to_push_idx(subgraph_nodes);
1081
1082                    let (pull_half, push_half) = subgraph_nodes.split_at(pull_to_push_idx);
1083                    let nodes_iter = pull_half.iter().chain(push_half.iter().rev());
1084
1085                    for (idx, &node_id) in nodes_iter.enumerate() {
1086                        let node = &self.nodes[node_id];
1087                        assert!(
1088                            matches!(node, GraphNode::Operator(_)),
1089                            "Handoffs are not part of subgraphs."
1090                        );
1091                        let op_inst = &self.operator_instances[node_id];
1092
1093                        let op_span = node.span();
1094                        let op_name = op_inst.op_constraints.name;
1095                        // Use op's span for root. #root is expected to be correct, any errors should span back to the op gen.
1096                        let root = change_spans(root.clone(), op_span);
1097                        let op_constraints = OPERATORS
1098                            .iter()
1099                            .find(|op| op_name == op.name)
1100                            .unwrap_or_else(|| panic!("Failed to find op: {}", op_name));
1101
1102                        let ident = self.node_as_ident(node_id, false);
1103
1104                        {
1105                            // TODO clean this up.
1106                            // Collect input arguments (predecessors).
1107                            let mut input_edges = self
1108                                .graph
1109                                .predecessor_edges(node_id)
1110                                .map(|edge_id| (self.edge_ports(edge_id).1, edge_id))
1111                                .collect::<Vec<_>>();
1112                            // Ensure sorted by port index.
1113                            input_edges.sort();
1114
1115                            let inputs = input_edges
1116                                .iter()
1117                                .map(|&(_port, edge_id)| {
1118                                    let (pred, _) = self.edge(edge_id);
1119                                    self.node_as_ident(pred, true)
1120                                })
1121                                .collect::<Vec<_>>();
1122
1123                            // Collect output arguments (successors).
1124                            let mut output_edges = self
1125                                .graph
1126                                .successor_edges(node_id)
1127                                .map(|edge_id| (&self.ports[edge_id].0, edge_id))
1128                                .collect::<Vec<_>>();
1129                            // Ensure sorted by port index.
1130                            output_edges.sort();
1131
1132                            let outputs = output_edges
1133                                .iter()
1134                                .map(|&(_port, edge_id)| {
1135                                    let (_, succ) = self.edge(edge_id);
1136                                    self.node_as_ident(succ, false)
1137                                })
1138                                .collect::<Vec<_>>();
1139
1140                            let is_pull = idx < pull_to_push_idx;
1141
1142                            let singleton_output_ident = &if op_constraints.has_singleton_output {
1143                                self.node_as_singleton_ident(node_id, op_span)
1144                            } else {
1145                                // This ident *should* go unused.
1146                                Ident::new(&format!("{}_has_no_singleton_output", op_name), op_span)
1147                            };
1148
1149                            // There's a bit of dark magic hidden in `Span`s... you'd think it's just a `file:line:column`,
1150                            // but it has one extra bit of info for _name resolution_, used for `Ident`s. `Span::call_site()`
1151                            // has the (unhygienic) resolution we want, an ident is just solely determined by its string name,
1152                            // which is what you'd expect out of unhygienic proc macros like this. Meanwhile, declarative macros
1153                            // use `Span::mixed_site()` which is weird and I don't understand it. It turns out that if you call
1154                            // the dfir syntax proc macro from _within_ a declarative macro then `op_span` will have the
1155                            // bad `Span::mixed_site()` name resolution and cause "Cannot find value `df/context`" errors. So
1156                            // we call `.resolved_at()` to fix resolution back to `Span::call_site()`. -Mingwei
1157                            let df_local = &Ident::new(GRAPH, op_span.resolved_at(df.span()));
1158                            let context = &Ident::new(CONTEXT, op_span.resolved_at(context.span()));
1159
1160                            let singletons_resolved =
1161                                self.helper_resolve_singletons(node_id, op_span);
1162                            let arguments = &process_singletons::postprocess_singletons(
1163                                op_inst.arguments_raw.clone(),
1164                                singletons_resolved.clone(),
1165                            );
1166                            let arguments_handles =
1167                                &process_singletons::postprocess_singletons_handles(
1168                                    op_inst.arguments_raw.clone(),
1169                                    singletons_resolved.clone(),
1170                                );
1171
1172                            let source_tag = 'a: {
1173                                if let Some(tag) = self.operator_tag.get(node_id).cloned() {
1174                                    break 'a tag;
1175                                }
1176
1177                                #[cfg(nightly)]
1178                                if proc_macro::is_available() {
1179                                    let op_span = op_span.unwrap();
1180                                    break 'a format!(
1181                                        "loc_{}_{}_{}_{}_{}",
1182                                        crate::pretty_span::make_source_path_relative(
1183                                            &op_span.file()
1184                                        )
1185                                        .display()
1186                                        .to_string()
1187                                        .replace(|x: char| !x.is_ascii_alphanumeric(), "_"),
1188                                        op_span.start().line(),
1189                                        op_span.start().column(),
1190                                        op_span.end().line(),
1191                                        op_span.end().column(),
1192                                    );
1193                                }
1194
1195                                format!(
1196                                    "loc_nopath_{}_{}_{}_{}",
1197                                    op_span.start().line,
1198                                    op_span.start().column,
1199                                    op_span.end().line,
1200                                    op_span.end().column
1201                                )
1202                            };
1203
1204                            let work_fn = format_ident!(
1205                                "{}__{}__{}",
1206                                ident,
1207                                op_name,
1208                                source_tag,
1209                                span = op_span
1210                            );
1211                            let work_fn_async = format_ident!("{}__async", work_fn, span = op_span);
1212
1213                            let context_args = WriteContextArgs {
1214                                root: &root,
1215                                df_ident: df_local,
1216                                context,
1217                                subgraph_id,
1218                                node_id,
1219                                loop_id,
1220                                op_span,
1221                                op_tag: self.operator_tag.get(node_id).cloned(),
1222                                work_fn: &work_fn,
1223                                work_fn_async: &work_fn_async,
1224                                ident: &ident,
1225                                is_pull,
1226                                inputs: &inputs,
1227                                outputs: &outputs,
1228                                singleton_output_ident,
1229                                op_name,
1230                                op_inst,
1231                                arguments,
1232                                arguments_handles,
1233                            };
1234
1235                            let write_result =
1236                                (op_constraints.write_fn)(&context_args, diagnostics);
1237                            let OperatorWriteOutput {
1238                                write_prologue,
1239                                write_prologue_after,
1240                                write_iterator,
1241                                write_iterator_after,
1242                                write_tick_end,
1243                            } = write_result.unwrap_or_else(|()| {
1244                                assert!(
1245                                    diagnostics.has_error(),
1246                                    "Operator `{}` returned `Err` but emitted no diagnostics, this is a bug.",
1247                                    op_name,
1248                                );
1249                                OperatorWriteOutput {
1250                                    write_iterator: null_write_iterator_fn(&context_args),
1251                                    ..Default::default()
1252                                }
1253                            });
1254
1255                            op_prologue_code.push(syn::parse_quote! {
1256                                #[allow(non_snake_case)]
1257                                #[inline(always)]
1258                                fn #work_fn<T>(thunk: impl ::std::ops::FnOnce() -> T) -> T {
1259                                    thunk()
1260                                }
1261
1262                                #[allow(non_snake_case)]
1263                                #[inline(always)]
1264                                async fn #work_fn_async<T>(
1265                                    thunk: impl ::std::future::Future<Output = T>,
1266                                ) -> T {
1267                                    thunk.await
1268                                }
1269                            });
1270                            op_prologue_code.push(write_prologue);
1271                            op_prologue_after_code.push(write_prologue_after);
1272                            op_tick_end_code.push(write_tick_end);
1273                            subgraph_op_iter_code.push(write_iterator);
1274
1275                            if include_type_guards {
1276                                let type_guard = if is_pull {
1277                                    quote_spanned! {op_span=>
1278                                        let #ident = {
1279                                            #[allow(non_snake_case)]
1280                                            #[inline(always)]
1281                                            pub fn #work_fn<Item, Input>(input: Input)
1282                                                -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = (), CanPend = Input::CanPend, CanEnd = Input::CanEnd>
1283                                            where
1284                                                Input: #root::dfir_pipes::pull::Pull<Item = Item, Meta = ()>,
1285                                            {
1286                                                #root::pin_project_lite::pin_project! {
1287                                                    #[repr(transparent)]
1288                                                    struct Pull<Item, Input: #root::dfir_pipes::pull::Pull<Item = Item>> {
1289                                                        #[pin]
1290                                                        inner: Input
1291                                                    }
1292                                                }
1293
1294                                                impl<Item, Input> #root::dfir_pipes::pull::Pull for Pull<Item, Input>
1295                                                where
1296                                                    Input: #root::dfir_pipes::pull::Pull<Item = Item>,
1297                                                {
1298                                                    type Ctx<'ctx> = Input::Ctx<'ctx>;
1299
1300                                                    type Item = Item;
1301                                                    type Meta = Input::Meta;
1302                                                    type CanPend = Input::CanPend;
1303                                                    type CanEnd = Input::CanEnd;
1304
1305                                                    #[inline(always)]
1306                                                    fn pull(
1307                                                        self: ::std::pin::Pin<&mut Self>,
1308                                                        ctx: &mut Self::Ctx<'_>,
1309                                                    ) -> #root::dfir_pipes::pull::PullStep<Self::Item, Self::Meta, Self::CanPend, Self::CanEnd> {
1310                                                        #root::dfir_pipes::pull::Pull::pull(self.project().inner, ctx)
1311                                                    }
1312
1313                                                    #[inline(always)]
1314                                                    fn size_hint(&self) -> (usize, Option<usize>) {
1315                                                        #root::dfir_pipes::pull::Pull::size_hint(&self.inner)
1316                                                    }
1317                                                }
1318
1319                                                Pull {
1320                                                    inner: input
1321                                                }
1322                                            }
1323                                            #work_fn::<_, _>( #ident )
1324                                        };
1325                                    }
1326                                } else {
1327                                    quote_spanned! {op_span=>
1328                                        let #ident = {
1329                                            #[allow(non_snake_case)]
1330                                            #[inline(always)]
1331                                            pub fn #work_fn<Item, Psh>(psh: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
1332                                            where
1333                                                Psh: #root::dfir_pipes::push::Push<Item, ()>
1334                                            {
1335                                                #root::pin_project_lite::pin_project! {
1336                                                    #[repr(transparent)]
1337                                                    struct PushGuard<Psh> {
1338                                                        #[pin]
1339                                                        inner: Psh,
1340                                                    }
1341                                                }
1342
1343                                                impl<Item, Psh> #root::dfir_pipes::push::Push<Item, ()> for PushGuard<Psh>
1344                                                where
1345                                                    Psh: #root::dfir_pipes::push::Push<Item, ()>,
1346                                                {
1347                                                    type Ctx<'ctx> = Psh::Ctx<'ctx>;
1348
1349                                                    type CanPend = Psh::CanPend;
1350
1351                                                    #[inline(always)]
1352                                                    fn poll_ready(
1353                                                        self: ::std::pin::Pin<&mut Self>,
1354                                                        ctx: &mut Self::Ctx<'_>,
1355                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1356                                                        #root::dfir_pipes::push::Push::poll_ready(self.project().inner, ctx)
1357                                                    }
1358
1359                                                    #[inline(always)]
1360                                                    fn start_send(
1361                                                        self: ::std::pin::Pin<&mut Self>,
1362                                                        item: Item,
1363                                                        meta: (),
1364                                                    ) {
1365                                                        #root::dfir_pipes::push::Push::start_send(self.project().inner, item, meta)
1366                                                    }
1367
1368                                                    #[inline(always)]
1369                                                    fn poll_flush(
1370                                                        self: ::std::pin::Pin<&mut Self>,
1371                                                        ctx: &mut Self::Ctx<'_>,
1372                                                    ) -> #root::dfir_pipes::push::PushStep<Self::CanPend> {
1373                                                        #root::dfir_pipes::push::Push::poll_flush(self.project().inner, ctx)
1374                                                    }
1375
1376                                                    #[inline(always)]
1377                                                    fn size_hint(
1378                                                        self: ::std::pin::Pin<&mut Self>,
1379                                                        hint: (usize, Option<usize>),
1380                                                    ) {
1381                                                        #root::dfir_pipes::push::Push::size_hint(self.project().inner, hint)
1382                                                    }
1383                                                }
1384
1385                                                PushGuard {
1386                                                    inner: psh
1387                                                }
1388                                            }
1389                                            #work_fn( #ident )
1390                                        };
1391                                    }
1392                                };
1393                                subgraph_op_iter_code.push(type_guard);
1394                            }
1395                            subgraph_op_iter_after_code.push(write_iterator_after);
1396                        }
1397                    }
1398
1399                    {
1400                        // Determine pull and push halves of the `Pivot`.
1401                        let pull_ident = if 0 < pull_to_push_idx {
1402                            self.node_as_ident(subgraph_nodes[pull_to_push_idx - 1], false)
1403                        } else {
1404                            // Entire subgraph is push (with a single recv/pull handoff input).
1405                            recv_port_idents[0].clone()
1406                        };
1407
1408                        #[rustfmt::skip]
1409                        let push_ident = if let Some(&node_id) =
1410                            subgraph_nodes.get(pull_to_push_idx)
1411                        {
1412                            self.node_as_ident(node_id, false)
1413                        } else if 1 == send_port_idents.len() {
1414                            // Entire subgraph is pull (with a single send/push handoff output).
1415                            send_port_idents[0].clone()
1416                        } else {
1417                            diagnostics.push(Diagnostic::spanned(
1418                                pull_ident.span(),
1419                                Level::Error,
1420                                "Degenerate subgraph detected, is there a disconnected `null()` or other degenerate pipeline somewhere?",
1421                            ));
1422                            continue;
1423                        };
1424
1425                        // Pivot span is combination of pull and push spans (or if not possible, just take the push).
1426                        let pivot_span = pull_ident
1427                            .span()
1428                            .join(push_ident.span())
1429                            .unwrap_or_else(|| push_ident.span());
1430                        let pivot_fn_ident =
1431                            Ident::new(&format!("pivot_run_sg_{:?}", subgraph_id.0), pivot_span);
1432                        let root = change_spans(root.clone(), pivot_span);
1433                        subgraph_op_iter_code.push(quote_spanned! {pivot_span=>
1434                            #[inline(always)]
1435                            fn #pivot_fn_ident<Pul, Psh, Item>(pull: Pul, push: Psh)
1436                                -> impl ::std::future::Future<Output = ()>
1437                            where
1438                                Pul: #root::dfir_pipes::pull::Pull<Item = Item>,
1439                                Psh: #root::dfir_pipes::push::Push<Item, Pul::Meta>,
1440                            {
1441                                #root::dfir_pipes::pull::Pull::send_push(pull, push)
1442                            }
1443                            (#pivot_fn_ident)(#pull_ident, #push_ident).await;
1444                        });
1445                    }
1446                };
1447
1448                // Each subgraph block is an async block so it can be individually instrumented.
1449                // Note: this ident is for the subgraph future, not a runtime SubgraphId binding
1450                // (unlike the scheduled path's `sg_ident`).
1451                let sg_fut_ident = subgraph_id.as_ident(Span::call_site());
1452
1453                // Generate send-side curr_items_count updates (after subgraph runs).
1454                let send_metrics_code: Vec<TokenStream> = send_hoffs
1455                    .iter()
1456                    .zip(send_buf_idents.iter())
1457                    .map(|(&hoff_id, buf_ident)| {
1458                        let hoff_idx = slotmap_raw_idx(hoff_id);
1459                        quote! {
1460                            __dfir_metrics.handoffs[
1461                                #root::util::slot_vec::Key::<#root::scheduled::HandoffTag>::from_raw(#hoff_idx)
1462                            ].curr_items_count.set(#buf_ident.len());
1463                        }
1464                    })
1465                    .collect();
1466
1467                subgraph_blocks.push(quote! {
1468                    let #sg_fut_ident = async {
1469                        let #context = &#df;
1470                        #( #recv_port_code )*
1471                        #( #send_port_code )*
1472                        #( #subgraph_op_iter_code )*
1473                        #( #subgraph_op_iter_after_code )*
1474                    };
1475                    {
1476                        let sg_metrics = &__dfir_metrics.subgraphs[
1477                            #root::util::slot_vec::Key::<#root::scheduled::SubgraphTag>::from_raw(#sg_metrics_idx)
1478                        ];
1479                        #root::scheduled::metrics::InstrumentSubgraph::new(
1480                            #sg_fut_ident, sg_metrics
1481                        ).await;
1482                        sg_metrics.total_run_count.update(|x| x + 1);
1483                    }
1484                    #( #send_metrics_code )*
1485                });
1486
1487                // Collect per-subgraph prologues into the main prologue lists.
1488                // (They are already pushed above in the operator loop.)
1489            }
1490        }
1491
1492        if diagnostics.has_error() {
1493            return Err(std::mem::take(diagnostics));
1494        }
1495        let _ = diagnostics; // Ensure no more diagnostics may be added after checking for errors.
1496
1497        let (meta_graph_arg, diagnostics_arg) = if include_meta {
1498            let meta_graph_json = serde_json::to_string(&self).unwrap();
1499            let meta_graph_json = Literal::string(&meta_graph_json);
1500
1501            let serde_diagnostics: Vec<_> = diagnostics.iter().map(Diagnostic::to_serde).collect();
1502            let diagnostics_json = serde_json::to_string(&*serde_diagnostics).unwrap();
1503            let diagnostics_json = Literal::string(&diagnostics_json);
1504
1505            (
1506                quote! { Some(#meta_graph_json) },
1507                quote! { Some(#diagnostics_json) },
1508            )
1509        } else {
1510            (quote! { None }, quote! { None })
1511        };
1512
1513        // Generate metrics initialization: one entry per handoff and per subgraph.
1514        let metrics_init_code = {
1515            let handoff_inits = handoff_nodes.iter().map(|&(node_id, _)| {
1516                let idx = slotmap_raw_idx(node_id);
1517                quote! {
1518                    dfir_metrics.handoffs.insert(
1519                        #root::util::slot_vec::Key::from_raw(#idx),
1520                        ::std::default::Default::default(),
1521                    );
1522                }
1523            });
1524            let subgraph_inits = all_subgraphs.iter().map(|&(sg_id, _)| {
1525                let idx = slotmap_raw_idx(sg_id);
1526                quote! {
1527                    dfir_metrics.subgraphs.insert(
1528                        #root::util::slot_vec::Key::from_raw(#idx),
1529                        ::std::default::Default::default(),
1530                    );
1531                }
1532            });
1533            handoff_inits.chain(subgraph_inits).collect::<Vec<_>>()
1534        };
1535
1536        // Prologues and buffer declarations persist across ticks (outside the closure).
1537        // Subgraph blocks run each tick (inside the closure).
1538        Ok(quote! {
1539            {
1540                #prefix
1541
1542                use #root::{var_expr, var_args};
1543
1544                let __dfir_wake_state = ::std::sync::Arc::new(
1545                    #root::scheduled::context::WakeState::default()
1546                );
1547
1548                let __dfir_metrics = {
1549                    let mut dfir_metrics = #root::scheduled::metrics::DfirMetrics::default();
1550                    #( #metrics_init_code )*
1551                    ::std::rc::Rc::new(dfir_metrics)
1552                };
1553
1554                #[allow(unused_mut)]
1555                let mut #df = #root::scheduled::context::Context::new(
1556                    ::std::clone::Clone::clone(&__dfir_wake_state),
1557                    __dfir_metrics,
1558                );
1559
1560                #( #buffer_code )*
1561                #( #back_buffer_code )*
1562                #( #op_prologue_code )*
1563                #( #op_prologue_after_code )*
1564
1565                // Pre-set to true so the first tick always returns true
1566                // (matching Dfir pre-scheduling behavior). Subsequent ticks
1567                // start false (from take()) and are set true by recv port code
1568                // if any handoff buffer has data.
1569                let mut __dfir_work_done = true;
1570                #[allow(unused_qualifications, unused_mut, unused_variables, clippy::await_holding_refcell_ref)]
1571                let __dfir_inline_tick = async move |#df: &mut #root::scheduled::context::Context| {
1572                    let __dfir_metrics = #df.metrics();
1573                    // Double-buffer swap for defer_tick handoffs: move last tick's
1574                    // producer output into the back buffer for the consumer to drain.
1575                    #( #back_edge_swap_code )*
1576                    #( #subgraph_blocks )*
1577
1578                    // For non-lazy defer_tick: if any deferred buffer has data,
1579                    // signal that another tick should run (sets can_start_tick).
1580                    // Inline DFIR doesn't dynamically schedule subgraph IDs, so the
1581                    // subgraph ID here is a meaningless placeholder.
1582                    // TODO(cleanup): remove the subgraph ID parameter once scheduled DFIR is gone.
1583                    if false #( || !#defer_tick_buf_idents.is_empty() )* {
1584                        #df.schedule_subgraph(
1585                            #root::scheduled::SubgraphId::from_raw(0),
1586                            true,
1587                        );
1588                    }
1589
1590                    // End-of-tick state reset (e.g. 'tick persistence).
1591                    #( #op_tick_end_code )*
1592
1593                    #df.__end_tick();
1594                    ::std::mem::take(&mut __dfir_work_done)
1595                };
1596                #root::scheduled::context::Dfir::new(
1597                    __dfir_inline_tick,
1598                    #df,
1599                    #meta_graph_arg,
1600                    #diagnostics_arg,
1601                )
1602            }
1603        })
1604    }
1605
1606    /// Color mode (pull vs. push, handoff vs. comp) for nodes. Some nodes can be push *OR* pull;
1607    /// those nodes will not be set in the returned map.
1608    pub fn node_color_map(&self) -> SparseSecondaryMap<GraphNodeId, Color> {
1609        let mut node_color_map: SparseSecondaryMap<GraphNodeId, Color> = self
1610            .node_ids()
1611            .filter_map(|node_id| {
1612                let op_color = self.node_color(node_id)?;
1613                Some((node_id, op_color))
1614            })
1615            .collect();
1616
1617        // Fill in rest via subgraphs.
1618        for sg_nodes in self.subgraph_nodes.values() {
1619            let pull_to_push_idx = self.find_pull_to_push_idx(sg_nodes);
1620
1621            for (idx, node_id) in sg_nodes.iter().copied().enumerate() {
1622                let is_pull = idx < pull_to_push_idx;
1623                node_color_map.insert(node_id, if is_pull { Color::Pull } else { Color::Push });
1624            }
1625        }
1626
1627        node_color_map
1628    }
1629
1630    /// Writes this graph as mermaid into a string.
1631    pub fn to_mermaid(&self, write_config: &WriteConfig) -> String {
1632        let mut output = String::new();
1633        self.write_mermaid(&mut output, write_config).unwrap();
1634        output
1635    }
1636
1637    /// Writes this graph as mermaid into the given `Write`.
1638    pub fn write_mermaid(
1639        &self,
1640        output: impl std::fmt::Write,
1641        write_config: &WriteConfig,
1642    ) -> std::fmt::Result {
1643        let mut graph_write = Mermaid::new(output);
1644        self.write_graph(&mut graph_write, write_config)
1645    }
1646
1647    /// Writes this graph as DOT (graphviz) into a string.
1648    pub fn to_dot(&self, write_config: &WriteConfig) -> String {
1649        let mut output = String::new();
1650        let mut graph_write = Dot::new(&mut output);
1651        self.write_graph(&mut graph_write, write_config).unwrap();
1652        output
1653    }
1654
1655    /// Writes this graph as DOT (graphviz) into the given `Write`.
1656    pub fn write_dot(
1657        &self,
1658        output: impl std::fmt::Write,
1659        write_config: &WriteConfig,
1660    ) -> std::fmt::Result {
1661        let mut graph_write = Dot::new(output);
1662        self.write_graph(&mut graph_write, write_config)
1663    }
1664
    /// Write out this graph using the given `GraphWrite`. E.g. `Mermaid` or `Dot`.
    ///
    /// Rendering happens in phases: prologue, node definitions, edges, singleton-reference
    /// edges, then the `loop -> subgraph -> varname -> node` grouping hierarchy, and finally
    /// the epilogue. Each hierarchy layer can be flattened away via `write_config`.
    pub(crate) fn write_graph<W>(
        &self,
        mut graph_write: W,
        write_config: &WriteConfig,
    ) -> Result<(), W::Err>
    where
        W: GraphWrite,
    {
        // Builds an edge label from the source/destination ports. `Path` and `Int` ports
        // produce text; other port kinds (the `_ => None` arms) produce nothing. When both
        // ends have labels they are joined with a newline.
        fn helper_edge_label(
            src_port: &PortIndexValue,
            dst_port: &PortIndexValue,
        ) -> Option<String> {
            let src_label = match src_port {
                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
                PortIndexValue::Int(index) => Some(index.value.to_string()),
                _ => None,
            };
            let dst_label = match dst_port {
                PortIndexValue::Path(path) => Some(path.to_token_stream().to_string()),
                PortIndexValue::Int(index) => Some(index.value.to_string()),
                _ => None,
            };
            let label = match (src_label, dst_label) {
                (Some(l1), Some(l2)) => Some(format!("{}\n{}", l1, l2)),
                (Some(l1), None) => Some(l1),
                (None, Some(l2)) => Some(l2),
                (None, None) => None,
            };
            label
        }

        // Make node color map one time.
        let node_color_map = self.node_color_map();

        // Write prologue.
        graph_write.write_prologue()?;

        // Define nodes.
        // `skipped_handoffs`: handoff nodes omitted entirely when `no_handoffs` is set;
        // edges through them are re-routed in the edge-writing loop below.
        // `subgraph_handoffs`: handoffs whose predecessor and successor are in the *same*
        // subgraph, grouped by that subgraph.
        // NOTE(review): `subgraph_handoffs` is populated but never read later in this
        // function — looks like dead code; confirm before removing.
        let mut skipped_handoffs = BTreeSet::new();
        let mut subgraph_handoffs = <BTreeMap<GraphSubgraphId, Vec<GraphNodeId>>>::new();
        for (node_id, node) in self.nodes() {
            if matches!(node, GraphNode::Handoff { .. }) {
                if write_config.no_handoffs {
                    skipped_handoffs.insert(node_id);
                    continue;
                } else {
                    // Handoffs are assumed to have exactly one predecessor and successor
                    // (the `.next().unwrap()` calls panic otherwise).
                    let pred_node = self.node_predecessor_nodes(node_id).next().unwrap();
                    let pred_sg = self.node_subgraph(pred_node);
                    let succ_node = self.node_successor_nodes(node_id).next().unwrap();
                    let succ_sg = self.node_subgraph(succ_node);
                    if let Some((pred_sg, succ_sg)) = pred_sg.zip(succ_sg)
                        && pred_sg == succ_sg
                    {
                        subgraph_handoffs.entry(pred_sg).or_default().push(node_id);
                    }
                }
            }
            graph_write.write_node_definition(
                node_id,
                // Node label text, per the configured verbosity.
                &if write_config.op_short_text {
                    node.to_name_string()
                } else if write_config.op_text_no_imports {
                    // Remove any lines that start with "use" (imports)
                    let full_text = node.to_pretty_string();
                    let mut output = String::new();
                    for sentence in full_text.split('\n') {
                        if sentence.trim().starts_with("use") {
                            continue;
                        }
                        output.push('\n');
                        output.push_str(sentence);
                    }
                    output.into()
                } else {
                    node.to_pretty_string()
                },
                // Pull/push coloring is optional.
                if write_config.no_pull_push {
                    None
                } else {
                    node_color_map.get(node_id).copied()
                },
            )?;
        }

        // Write edges.
        for (edge_id, (src_id, mut dst_id)) in self.edges() {
            // Handling for if `write_config.no_handoffs` true.
            // The handoff's outgoing edge is skipped here; its incoming edge is re-routed
            // below, so each skipped handoff still yields exactly one rendered edge.
            if skipped_handoffs.contains(&src_id) {
                continue;
            }

            let (src_port, mut dst_port) = self.edge_ports(edge_id);
            if skipped_handoffs.contains(&dst_id) {
                // Re-route this edge "through" the skipped handoff to its single successor.
                let mut handoff_succs = self.node_successors(dst_id);
                assert_eq!(1, handoff_succs.len());
                let (succ_edge, succ_node) = handoff_succs.next().unwrap();
                dst_id = succ_node;
                dst_port = self.edge_ports(succ_edge).1;
            }

            let label = helper_edge_label(src_port, dst_port);
            // Delay type comes from the destination operator's input-port constraints.
            let delay_type = self
                .node_op_inst(dst_id)
                .and_then(|op_inst| (op_inst.op_constraints.input_delaytype_fn)(dst_port));
            graph_write.write_edge(src_id, dst_id, delay_type, label.as_deref(), false)?;
        }

        // Write reference edges. Singleton references are rendered distinctly (final `true`
        // argument) and always carry a stratum delay.
        if !write_config.no_references {
            for dst_id in self.node_ids() {
                for src_ref_id in self
                    .node_singleton_references(dst_id)
                    .iter()
                    .copied()
                    .flatten()
                {
                    let delay_type = Some(DelayType::Stratum);
                    let label = None;
                    graph_write.write_edge(src_ref_id, dst_id, delay_type, label, true)?;
                }
            }
        }

        // The following code is a little bit tricky. Generally, the graph has the hierarchy:
        // `loop -> subgraph -> varname -> node`. However, each of these can be disabled via the `write_config`. To
        // handle both the enabled and disabled case, this code is structured as a series of nested loops. If the layer
        // is disabled, then the BTreeMap<Option<KEY>, Vec<VALUE>> will only have a single key (`None`) with a
        // corresponding `Vec` value containing everything. This way no special handling is needed for the next layer.

        // Loop -> Subgraphs
        let loop_subgraphs = self.subgraph_ids().map(|sg_id| {
            let loop_id = if write_config.no_loops {
                None
            } else {
                self.subgraph_loop(sg_id)
            };
            (loop_id, sg_id)
        });
        let loop_subgraphs = into_group_map(loop_subgraphs);
        for (loop_id, subgraph_ids) in loop_subgraphs {
            if let Some(loop_id) = loop_id {
                graph_write.write_loop_start(loop_id)?;
            }

            // Subgraph -> Varnames.
            let subgraph_varnames_nodes = subgraph_ids.into_iter().flat_map(|sg_id| {
                self.subgraph(sg_id).iter().copied().map(move |node_id| {
                    let opt_sg_id = if write_config.no_subgraphs {
                        None
                    } else {
                        Some(sg_id)
                    };
                    (opt_sg_id, (self.node_varname(node_id), node_id))
                })
            });
            let subgraph_varnames_nodes = into_group_map(subgraph_varnames_nodes);
            for (sg_id, varnames) in subgraph_varnames_nodes {
                if let Some(sg_id) = sg_id {
                    graph_write.write_subgraph_start(sg_id)?;
                }

                // Varnames -> Nodes.
                let varname_nodes = varnames.into_iter().map(|(varname, node)| {
                    let varname = if write_config.no_varnames {
                        None
                    } else {
                        varname
                    };
                    (varname, node)
                });
                let varname_nodes = into_group_map(varname_nodes);
                for (varname, node_ids) in varname_nodes {
                    if let Some(varname) = varname {
                        graph_write.write_varname_start(&varname.0.to_string(), sg_id)?;
                    }

                    // Write all nodes.
                    for node_id in node_ids {
                        graph_write.write_node(node_id)?;
                    }

                    if varname.is_some() {
                        graph_write.write_varname_end()?;
                    }
                }

                if sg_id.is_some() {
                    graph_write.write_subgraph_end()?;
                }
            }

            if loop_id.is_some() {
                graph_write.write_loop_end()?;
            }
        }

        // Write epilogue.
        graph_write.write_epilogue()?;

        Ok(())
    }
1867
1868    /// Convert back into surface syntax.
1869    pub fn surface_syntax_string(&self) -> String {
1870        let mut string = String::new();
1871        self.write_surface_syntax(&mut string).unwrap();
1872        string
1873    }
1874
1875    /// Convert back into surface syntax.
1876    pub fn write_surface_syntax(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1877        for (key, node) in self.nodes.iter() {
1878            match node {
1879                GraphNode::Operator(op) => {
1880                    writeln!(write, "{:?} = {};", key.data(), op.to_token_stream())?;
1881                }
1882                GraphNode::Handoff { .. } => {
1883                    writeln!(write, "// {:?} = <handoff>;", key.data())?;
1884                }
1885                GraphNode::ModuleBoundary { .. } => panic!(),
1886            }
1887        }
1888        writeln!(write)?;
1889        for (_e, (src_key, dst_key)) in self.graph.edges() {
1890            writeln!(write, "{:?} -> {:?};", src_key.data(), dst_key.data())?;
1891        }
1892        Ok(())
1893    }
1894
1895    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1896    pub fn mermaid_string_flat(&self) -> String {
1897        let mut string = String::new();
1898        self.write_mermaid_flat(&mut string).unwrap();
1899        string
1900    }
1901
1902    /// Convert into a [mermaid](https://mermaid-js.github.io/) graph. Ignores subgraphs.
1903    pub fn write_mermaid_flat(&self, write: &mut impl std::fmt::Write) -> std::fmt::Result {
1904        writeln!(write, "flowchart TB")?;
1905        for (key, node) in self.nodes.iter() {
1906            match node {
1907                GraphNode::Operator(operator) => writeln!(
1908                    write,
1909                    "    %% {span}\n    {id:?}[\"{row_col} <tt>{code}</tt>\"]",
1910                    span = PrettySpan(node.span()),
1911                    id = key.data(),
1912                    row_col = PrettyRowCol(node.span()),
1913                    code = operator
1914                        .to_token_stream()
1915                        .to_string()
1916                        .replace('&', "&amp;")
1917                        .replace('<', "&lt;")
1918                        .replace('>', "&gt;")
1919                        .replace('"', "&quot;")
1920                        .replace('\n', "<br>"),
1921                ),
1922                GraphNode::Handoff { .. } => {
1923                    writeln!(write, r#"    {:?}{{"{}"}}"#, key.data(), HANDOFF_NODE_STR)
1924                }
1925                GraphNode::ModuleBoundary { .. } => {
1926                    writeln!(
1927                        write,
1928                        r#"    {:?}{{"{}"}}"#,
1929                        key.data(),
1930                        MODULE_BOUNDARY_NODE_STR
1931                    )
1932                }
1933            }?;
1934        }
1935        writeln!(write)?;
1936        for (_e, (src_key, dst_key)) in self.graph.edges() {
1937            writeln!(write, "    {:?}-->{:?}", src_key.data(), dst_key.data())?;
1938        }
1939        Ok(())
1940    }
1941}
1942
1943/// Loops
1944impl DfirGraph {
1945    /// Iterator over all loop IDs.
1946    pub fn loop_ids(&self) -> slotmap::basic::Keys<'_, GraphLoopId, Vec<GraphNodeId>> {
1947        self.loop_nodes.keys()
1948    }
1949
1950    /// Iterator over all loops, ID and members: `(GraphLoopId, Vec<GraphNodeId>)`.
1951    pub fn loops(&self) -> slotmap::basic::Iter<'_, GraphLoopId, Vec<GraphNodeId>> {
1952        self.loop_nodes.iter()
1953    }
1954
1955    /// Create a new loop context, with the given parent loop (or `None`).
1956    pub fn insert_loop(&mut self, parent_loop: Option<GraphLoopId>) -> GraphLoopId {
1957        let loop_id = self.loop_nodes.insert(Vec::new());
1958        self.loop_children.insert(loop_id, Vec::new());
1959        if let Some(parent_loop) = parent_loop {
1960            self.loop_parent.insert(loop_id, parent_loop);
1961            self.loop_children
1962                .get_mut(parent_loop)
1963                .unwrap()
1964                .push(loop_id);
1965        } else {
1966            self.root_loops.push(loop_id);
1967        }
1968        loop_id
1969    }
1970
1971    /// Get a node's loop context (or `None` for root).
1972    pub fn node_loop(&self, node_id: GraphNodeId) -> Option<GraphLoopId> {
1973        self.node_loops.get(node_id).copied()
1974    }
1975
1976    /// Get a subgraph's loop context (or `None` for root).
1977    pub fn subgraph_loop(&self, subgraph_id: GraphSubgraphId) -> Option<GraphLoopId> {
1978        let &node_id = self.subgraph(subgraph_id).first().unwrap();
1979        let out = self.node_loop(node_id);
1980        debug_assert!(
1981            self.subgraph(subgraph_id)
1982                .iter()
1983                .all(|&node_id| self.node_loop(node_id) == out),
1984            "Subgraph nodes should all have the same loop context."
1985        );
1986        out
1987    }
1988
1989    /// Get a loop context's parent loop context (or `None` for root).
1990    pub fn loop_parent(&self, loop_id: GraphLoopId) -> Option<GraphLoopId> {
1991        self.loop_parent.get(loop_id).copied()
1992    }
1993
1994    /// Get a loop context's child loops.
1995    pub fn loop_children(&self, loop_id: GraphLoopId) -> &Vec<GraphLoopId> {
1996        self.loop_children.get(loop_id).unwrap()
1997    }
1998}
1999
/// Configuration for writing graphs.
///
/// All flags default to `false` (via `Default`), i.e. everything is rendered with full
/// operator source text. Each `no_*` flag removes one layer or annotation from the output.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "clap-derive", derive(clap::Args))]
pub struct WriteConfig {
    /// Subgraphs will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_subgraphs: bool,
    /// Variable names will not be rendered if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_varnames: bool,
    /// Will not render pull/push shapes if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_pull_push: bool,
    /// Will not render handoffs if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_handoffs: bool,
    /// Will not render singleton references if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_references: bool,
    /// Will not render loops if set.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub no_loops: bool,

    /// Op text will only be their name instead of the whole source.
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_short_text: bool,
    /// Op text will exclude any line that starts with "use".
    #[cfg_attr(feature = "clap-derive", arg(long))]
    pub op_text_no_imports: bool,
}
2030
/// Enum for choosing between mermaid and dot graph writing.
///
/// Usable as a CLI argument value (`clap::ValueEnum`) when the `clap-derive` feature is
/// enabled.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "clap-derive", derive(clap::Parser, clap::ValueEnum))]
pub enum WriteGraphType {
    /// Mermaid graphs.
    Mermaid,
    /// Dot (Graphviz) graphs.
    Dot,
}
2040
/// [`itertools::Itertools::into_group_map`], but for `BTreeMap`.
///
/// Groups `(key, value)` pairs by key, preserving the input order of values within each
/// group.
fn into_group_map<K, V>(iter: impl IntoIterator<Item = (K, V)>) -> BTreeMap<K, Vec<V>>
where
    K: Ord,
{
    iter.into_iter()
        .fold(BTreeMap::new(), |mut groups, (key, value)| {
            groups.entry(key).or_default().push(value);
            groups
        })
}