1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use documented::DocumentedVariants;
9use proc_macro2::{Ident, Literal, Span, TokenStream};
10use quote::quote_spanned;
11use serde::{Deserialize, Serialize};
12use slotmap::Key;
13use syn::punctuated::Punctuated;
14use syn::{Expr, Token, parse_quote_spanned};
15
16use super::{
17 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
18 PortIndexValue,
19};
20use crate::diagnostic::{Diagnostic, Diagnostics, Level};
21use crate::parse::{Operator, PortIndex};
22
/// Kind of delay an operator may require on one of its input ports.
///
/// Reported (wrapped in `Option`, `None` meaning no delay) by
/// `OperatorConstraints::input_delaytype_fn`. The derived ordering follows declaration order.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum DelayType {
    /// NOTE(review): presumably the input must be fully computed in an earlier stratum — confirm.
    Stratum,
    /// NOTE(review): presumably a stratum-style delay that tolerates monotone accumulation — confirm.
    MonotoneAccum,
    /// NOTE(review): presumably the input is delayed to the next tick — confirm.
    Tick,
    /// NOTE(review): presumably a `Tick` delay that does not itself force another tick — confirm.
    TickLazy,
}
35
/// Specification of an operator's ports, as returned by `OperatorConstraints::ports_inn` /
/// `OperatorConstraints::ports_out`.
pub enum PortListSpec {
    /// Any number of ports. NOTE(review): presumably port names are not restricted — confirm.
    Variadic,
    /// Exactly the listed (comma-separated) port indices.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
43
/// Static description of an operator: its name, categories, accepted port/argument counts,
/// and the hooks used to generate its code.
pub struct OperatorConstraints {
    /// Operator name; also the key under which `operator_lookup()` indexes this operator.
    pub name: &'static str,
    /// Categories this operator is listed under.
    pub categories: &'static [OperatorCategory],

    /// Accepted number of input ports.
    /// NOTE(review): presumably violating a "hard" range is an error while violating a "soft"
    /// range is only a warning — confirm.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Preferred number of input ports (see `hard_range_inn` for hard vs. soft).
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Accepted number of output ports (see `hard_range_inn` for hard vs. soft).
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Preferred number of output ports (see `hard_range_inn` for hard vs. soft).
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments the operator takes (presumably an exact count — confirm).
    pub num_args: usize,
    /// Accepted number of persistence arguments (e.g. `'tick`, `'static`).
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Accepted number of type arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// NOTE(review): presumably `true` for operators that receive data from outside the graph —
    /// confirm.
    pub is_external_input: bool,
    /// NOTE(review): presumably `true` when the operator exposes a singleton output
    /// (cf. `WriteContextArgs::singleton_output_ident`) — confirm.
    pub has_singleton_output: bool,
    /// Loop/windowing role of the operator, if it has one.
    pub flo_type: Option<FloType>,

    /// Declared input ports, or `None` if the operator does not declare them.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Declared output ports, or `None` if the operator does not declare them.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Delay type required on the given input port, or `None` if no delay is required.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Generates the operator's code.
    pub write_fn: WriteFn,
}
88
/// Signature of an operator's code generator: given the write context, produce the operator's
/// generated code, reporting problems through `Diagnostics`.
///
/// NOTE(review): presumably `Err(())` is returned only after an error diagnostic has been
/// pushed — confirm.
pub type WriteFn = fn(&WriteContextArgs<'_>, &mut Diagnostics) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// Code produced by an operator's [`WriteFn`], split by where it is emitted.
///
/// Every field defaults to an empty token stream; `IDENTITY_WRITE_FN` and `NULL_WRITE_FN`, for
/// example, fill in only `write_iterator`.
#[derive(Default)]
pub struct OperatorWriteOutput {
    /// NOTE(review): presumably setup code emitted once before the subgraph runs — confirm.
    pub write_prologue: TokenStream,
    /// NOTE(review): presumably emitted after all operators' `write_prologue` code — confirm.
    pub write_prologue_after: TokenStream,
    /// Code that builds and binds the operator's pipeline (e.g. `let #ident = ...;`).
    pub write_iterator: TokenStream,
    /// NOTE(review): presumably emitted after the pipeline has run — confirm.
    pub write_iterator_after: TokenStream,
    /// NOTE(review): presumably emitted at the end of each tick — confirm.
    pub write_tick_end: TokenStream,
}
138
/// Count range `0..`: accepts any count. Note that its `human_string()` reads "at least 0".
pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
/// Count range `0..=0`: accepts exactly zero.
pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
/// Count range `1..=1`: accepts exactly one.
pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
145
146pub fn identity_write_iterator_fn(
149 &WriteContextArgs {
150 root,
151 op_span,
152 ident,
153 inputs,
154 outputs,
155 is_pull,
156 op_inst:
157 OperatorInstance {
158 generics: OpInstGenerics { type_args, .. },
159 ..
160 },
161 ..
162 }: &WriteContextArgs,
163) -> TokenStream {
164 let generic_type = type_args
165 .first()
166 .map(quote::ToTokens::to_token_stream)
167 .unwrap_or(quote_spanned!(op_span=> _));
168
169 if is_pull {
170 let input = &inputs[0];
171 quote_spanned! {op_span=>
172 let #ident = {
173 fn check_input<Pull, Item>(pull: Pull) -> impl #root::dfir_pipes::pull::Pull<Item = Item, Meta = Pull::Meta, CanPend = Pull::CanPend, CanEnd = Pull::CanEnd>
174 where
175 Pull: #root::dfir_pipes::pull::Pull<Item = Item>,
176 {
177 pull
178 }
179 check_input::<_, #generic_type>(#input)
180 };
181 }
182 } else {
183 let output = &outputs[0];
184 quote_spanned! {op_span=>
185 let #ident = {
186 fn check_output<Psh, Item>(push: Psh) -> impl #root::dfir_pipes::push::Push<Item, (), CanPend = Psh::CanPend>
187 where
188 Psh: #root::dfir_pipes::push::Push<Item, ()>,
189 {
190 push
191 }
192 check_output::<_, #generic_type>(#output)
193 };
194 }
195 }
196}
197
198pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
200 let write_iterator = identity_write_iterator_fn(write_context_args);
201 Ok(OperatorWriteOutput {
202 write_iterator,
203 ..Default::default()
204 })
205};
206
207pub fn null_write_iterator_fn(
210 &WriteContextArgs {
211 root,
212 op_span,
213 ident,
214 inputs,
215 outputs,
216 is_pull,
217 op_inst:
218 OperatorInstance {
219 generics: OpInstGenerics { type_args, .. },
220 ..
221 },
222 ..
223 }: &WriteContextArgs,
224) -> TokenStream {
225 let default_type = parse_quote_spanned! {op_span=> _};
226 let iter_type = type_args.first().unwrap_or(&default_type);
227
228 if is_pull {
229 quote_spanned! {op_span=>
230 let #ident = #root::dfir_pipes::pull::poll_fn({
231 #(
232 let mut #inputs = ::std::boxed::Box::pin(#inputs);
233 )*
234 move |_cx| {
235 #(
239 let #inputs = #root::dfir_pipes::pull::Pull::pull(
240 ::std::pin::Pin::as_mut(&mut #inputs),
241 <_ as #root::dfir_pipes::Context>::from_task(_cx),
242 );
243 )*
244 #(
245 if let #root::dfir_pipes::pull::PullStep::Pending(_) = #inputs {
246 return #root::dfir_pipes::pull::PullStep::Pending(#root::dfir_pipes::Yes);
247 }
248 )*
249 #root::dfir_pipes::pull::PullStep::<_, _, #root::dfir_pipes::Yes, _>::Ended(#root::dfir_pipes::Yes)
250 }
251 });
252 }
253 } else {
254 quote_spanned! {op_span=>
255 #[allow(clippy::let_unit_value)]
256 let _ = (#(#outputs),*);
257 let #ident = #root::dfir_pipes::push::for_each::<_, #iter_type>(::std::mem::drop::<#iter_type>);
258 }
259 }
260}
261
262pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
265 let write_iterator = null_write_iterator_fn(write_context_args);
266 Ok(OperatorWriteOutput {
267 write_iterator,
268 ..Default::default()
269 })
270};
271
// Declares one `pub(crate)` child module per operator and collects each module's operator
// constant into `OPERATORS`, in invocation order.
//
// Each entry must have the form `module::CONSTANT,`; the matcher requires the trailing comma.
macro_rules! declare_ops {
    ( $( $mod:ident :: $op:ident, )* ) => {
        $( pub(crate) mod $mod; )*
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $mod :: $op, )*
        ];
    };
}
// Registry of every built-in operator. This order is the order of `OPERATORS`.
// NOTE(review): the list is mostly, but not strictly, alphabetical (e.g. `fold_keyed`, `map`,
// `union`, `defer_*` are out of place). Confirm nothing depends on `OPERATORS` order before
// re-sorting.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    chain_first_n::CHAIN_FIRST_N,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flat_map_stream_blocking::FLAT_MAP_STREAM_BLOCKING,
    flatten::FLATTEN,
    flatten_stream_blocking::FLATTEN_STREAM_BLOCKING,
    fold::FOLD,
    fold_no_replay::FOLD_NO_REPLAY,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    join_multiset_half::JOIN_MULTISET_HALF,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    resolve_futures::RESOLVE_FUTURES,
    resolve_futures_blocking::RESOLVE_FUTURES_BLOCKING,
    resolve_futures_blocking_ordered::RESOLVE_FUTURES_BLOCKING_ORDERED,
    resolve_futures_ordered::RESOLVE_FUTURES_ORDERED,
    reduce::REDUCE,
    reduce_no_replay::REDUCE_NO_REPLAY,
    scan::SCAN,
    scan_async_blocking::SCAN_ASYNC_BLOCKING,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
366
367pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
369 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
370 OnceLock::new();
371 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
372}
373pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
375 if let GraphNode::Operator(operator) = node {
376 find_op_op_constraints(operator)
377 } else {
378 None
379 }
380}
381pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
383 let name = &*operator.name_string();
384 operator_lookup().get(name).copied()
385}
386
/// Everything an operator's [`WriteFn`] receives when generating its code.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Path prefix of the runtime crate; generated code refers to items as `#root::...`.
    pub root: &'a TokenStream,
    /// NOTE(review): presumably the identifier of the runtime context in generated code — confirm.
    pub context: &'a Ident,
    /// NOTE(review): presumably the identifier of the dataflow instance in generated code — confirm.
    pub df_ident: &'a Ident,
    /// Subgraph containing this operator; embedded in names built by `make_ident`.
    pub subgraph_id: GraphSubgraphId,
    /// This operator's node; embedded in names built by `make_ident`.
    pub node_id: GraphNodeId,
    /// Enclosing loop, if any. Also selects the default persistence (`None` inside a loop,
    /// `Tick` outside).
    pub loop_id: Option<GraphLoopId>,
    /// Span attached to generated tokens and to diagnostics.
    pub op_span: Span,
    /// NOTE(review): presumably an optional user-supplied label for this operator — confirm.
    pub op_tag: Option<String>,
    /// NOTE(review): presumably the identifier of a helper wrapping per-operator work — confirm.
    pub work_fn: &'a Ident,
    /// NOTE(review): presumably the async counterpart of `work_fn` — confirm.
    pub work_fn_async: &'a Ident,

    /// Variable the operator's generated pipeline is bound to (`let #ident = ...`).
    pub ident: &'a Ident,
    /// `true` to generate pull-based code, `false` for push-based code.
    pub is_pull: bool,
    /// Identifiers of the operator's input pipelines.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's output pipelines.
    pub outputs: &'a [Ident],
    /// NOTE(review): presumably the identifier of the operator's singleton output state
    /// (cf. `OperatorConstraints::has_singleton_output`) — confirm.
    pub singleton_output_ident: &'a Ident,

    /// Operator name, used in diagnostics.
    pub op_name: &'static str,
    /// The parsed operator instance, including its generics (type and persistence arguments).
    pub op_inst: &'a OperatorInstance,
    /// The operator's arguments.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    /// NOTE(review): a second form of the arguments, presumably rewritten to refer to state
    /// handles — confirm.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
438impl WriteContextArgs<'_> {
439 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
445 Ident::new(
446 &format!(
447 "sg_{:?}_node_{:?}_{}",
448 self.subgraph_id.data(),
449 self.node_id.data(),
450 suffix.as_ref(),
451 ),
452 self.op_span,
453 )
454 }
455
456 pub fn persistence_as_state_lifespan(&self, persistence: Persistence) -> Option<TokenStream> {
459 let root = self.root;
460 let variant =
461 persistence.as_state_lifespan_variant(self.subgraph_id, self.loop_id, self.op_span)?;
462 Some(quote_spanned! {self.op_span=>
463 #root::scheduled::StateLifespan::#variant
464 })
465 }
466
467 pub fn persistence_args_disallow_mutable<const N: usize>(
469 &self,
470 diagnostics: &mut Diagnostics,
471 ) -> [Persistence; N] {
472 let len = self.op_inst.generics.persistence_args.len();
473 if 0 != len && 1 != len && N != len {
474 diagnostics.push(Diagnostic::spanned(
475 self.op_span,
476 Level::Error,
477 format!(
478 "The operator `{}` only accepts 0, 1, or {} persistence arguments",
479 self.op_name, N
480 ),
481 ));
482 }
483
484 let default_persistence = if self.loop_id.is_some() {
485 Persistence::None
486 } else {
487 Persistence::Tick
488 };
489 let mut out = [default_persistence; N];
490 self.op_inst
491 .generics
492 .persistence_args
493 .iter()
494 .copied()
495 .cycle() .take(N)
497 .enumerate()
498 .filter(|&(_i, p)| {
499 if p == Persistence::Mutable {
500 diagnostics.push(Diagnostic::spanned(
501 self.op_span,
502 Level::Error,
503 format!(
504 "An implementation of `'{}` does not exist",
505 p.to_str_lowercase()
506 ),
507 ));
508 false
509 } else {
510 true
511 }
512 })
513 .for_each(|(i, p)| {
514 out[i] = p;
515 });
516 out
517 }
518}
519
/// Object-safe view of a [`RangeBounds`], used for the count constraints in
/// `OperatorConstraints` (stored as `&'static dyn RangeTrait<usize>`).
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;

    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;

    /// Whether `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in English, e.g. `"exactly 1"`, `"at least 2"`,
    /// `"more than 1 and less than 4"`, or `"any number of"` when both ends are unbounded.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        let start = self.start_bound();
        let end = self.end_bound();

        // A closed range with equal endpoints reads better as a single value.
        if let (Bound::Included(lo), Bound::Included(hi)) = (start, end) {
            if lo == hi {
                return format!("exactly {}", lo);
            }
        }

        let lower = match start {
            Bound::Included(n) => Some(format!("at least {}", n)),
            Bound::Excluded(n) => Some(format!("more than {}", n)),
            Bound::Unbounded => None,
        };
        let upper = match end {
            Bound::Included(x) => Some(format!("at most {}", x)),
            Bound::Excluded(x) => Some(format!("less than {}", x)),
            Bound::Unbounded => None,
        };

        match (lower, upper) {
            (Some(lower), Some(upper)) => format!("{} and {}", lower, upper),
            (Some(only), None) | (None, Some(only)) => only,
            (None, None) => "any number of".to_owned(),
        }
    }
}

/// Every `RangeBounds<T>` that is also `Send + Sync + Debug` (e.g. `0..`, `1..=1`, `..`) is a
/// [`RangeTrait<T>`]; each method delegates to the corresponding `RangeBounds` method.
impl<R, T> RangeTrait<T> for R
where
    R: RangeBounds<T> + Send + Sync + Debug,
{
    fn start_bound(&self) -> Bound<&T> {
        // Fully qualified so it is obvious this delegates to `RangeBounds` rather than recursing.
        <R as RangeBounds<T>>::start_bound(self)
    }

    fn end_bound(&self) -> Bound<&T> {
        <R as RangeBounds<T>>::end_bound(self)
    }

    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>,
    {
        <R as RangeBounds<T>>::contains(self, item)
    }
}
584
/// Persistence requested for an operator's state, e.g. through its persistence arguments.
///
/// See `Persistence::as_state_lifespan_variant` for how each variant maps to a
/// `StateLifespan`. The derived ordering follows declaration order.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// Maps to `StateLifespan::Subgraph(..)`; the default inside a loop.
    None,
    /// Maps to `StateLifespan::Loop(..)`; requires a loop context.
    Loop,
    /// Maps to `StateLifespan::Tick`; the default outside a loop.
    Tick,
    /// Has no `StateLifespan` (`as_state_lifespan_variant` returns `None`).
    Static,
    /// Has no `StateLifespan`; rejected by `persistence_args_disallow_mutable`.
    Mutable,
}
599impl Persistence {
600 pub fn as_state_lifespan_variant(
602 self,
603 subgraph_id: GraphSubgraphId,
604 loop_id: Option<GraphLoopId>,
605 span: Span,
606 ) -> Option<TokenStream> {
607 match self {
608 Persistence::None => {
609 let sg_ident = subgraph_id.as_ident(span);
610 Some(quote_spanned!(span=> Subgraph(#sg_ident)))
611 }
612 Persistence::Loop => {
613 let loop_ident = loop_id
614 .expect("`Persistence::Loop` outside of a loop context.")
615 .as_ident(span);
616 Some(quote_spanned!(span=> Loop(#loop_ident)))
617 }
618 Persistence::Tick => Some(quote_spanned!(span=> Tick)),
619 Persistence::Static => None,
620 Persistence::Mutable => None,
621 }
622 }
623
624 pub fn to_str_lowercase(self) -> &'static str {
626 match self {
627 Persistence::None => "none",
628 Persistence::Tick => "tick",
629 Persistence::Loop => "loop",
630 Persistence::Static => "static",
631 Persistence::Mutable => "mutable",
632 }
633 }
634}
635
636fn make_missing_runtime_msg(op_name: &str) -> Literal {
638 Literal::string(&format!(
639 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
640 op_name
641 ))
642}
643
/// Category an operator can be listed under (see `OperatorConstraints::categories`).
///
/// The derived ordering follows declaration order.
// NOTE: `name()` and `description()` split the string returned by `get_variant_docs()` at the
// first `:` and panic if there is none. Presumably that string comes from each variant's doc
// comment via the `DocumentedVariants` derive, so variant docs must read `Name: description`.
// Do not add `///` comments to variants casually; they are runtime data.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, DocumentedVariants)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}
678impl OperatorCategory {
679 pub fn name(self) -> &'static str {
681 self.get_variant_docs().split_once(":").unwrap().0
682 }
683 pub fn description(self) -> &'static str {
685 self.get_variant_docs().split_once(":").unwrap().1
686 }
687}
688
/// Loop/windowing role of an operator, as recorded in `OperatorConstraints::flo_type`.
///
/// NOTE(review): the variant descriptions below are inferred from their names — confirm.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// Presumably an operator that introduces data (a source).
    Source,
    /// Presumably an operator that groups data into windows (e.g. entering a loop).
    Windowing,
    /// Presumably an operator that flattens windows back out (e.g. leaving a loop).
    Unwindowing,
    /// Presumably an operator that feeds data into the next loop iteration.
    NextIteration,
}