1use std::collections::{BTreeMap, BTreeSet};
4
5use proc_macro2::Span;
6use slotmap::{SecondaryMap, SparseSecondaryMap};
7
8use super::meta_graph::DfirGraph;
9use super::ops::{DelayType, FloType};
10use super::{Color, GraphEdgeId, GraphNode, GraphNodeId, GraphSubgraphId, graph_algorithms};
11use crate::diagnostic::{Diagnostic, Level};
12use crate::union_find::UnionFind;
13
14struct BarrierCrossers {
16 pub edge_barrier_crossers: SecondaryMap<GraphEdgeId, DelayType>,
18 pub singleton_barrier_crossers: Vec<(GraphNodeId, GraphNodeId)>,
20}
21impl BarrierCrossers {
22 fn iter_node_pairs<'a>(
24 &'a self,
25 partitioned_graph: &'a DfirGraph,
26 ) -> impl 'a + Iterator<Item = ((GraphNodeId, GraphNodeId), DelayType)> {
27 let edge_pairs_iter = self
28 .edge_barrier_crossers
29 .iter()
30 .map(|(edge_id, &delay_type)| {
31 let src_dst = partitioned_graph.edge(edge_id);
32 (src_dst, delay_type)
33 });
34 let singleton_pairs_iter = self
35 .singleton_barrier_crossers
36 .iter()
37 .map(|&src_dst| (src_dst, DelayType::Stratum));
38 edge_pairs_iter.chain(singleton_pairs_iter)
39 }
40
41 fn replace_edge(&mut self, old_edge_id: GraphEdgeId, new_edge_id: GraphEdgeId) {
43 if let Some(delay_type) = self.edge_barrier_crossers.remove(old_edge_id) {
44 self.edge_barrier_crossers.insert(new_edge_id, delay_type);
45 }
46 }
47}
48
49fn find_barrier_crossers(partitioned_graph: &DfirGraph) -> BarrierCrossers {
51 let edge_barrier_crossers = partitioned_graph
52 .edges()
53 .filter(|&(_, (_src, dst))| {
54 partitioned_graph.node_loop(dst).is_none()
56 })
57 .filter_map(|(edge_id, (_src, dst))| {
58 let (_src_port, dst_port) = partitioned_graph.edge_ports(edge_id);
59 let op_constraints = partitioned_graph.node_op_inst(dst)?.op_constraints;
60 let input_barrier = (op_constraints.input_delaytype_fn)(dst_port)?;
61 Some((edge_id, input_barrier))
62 })
63 .collect();
64 let singleton_barrier_crossers = partitioned_graph
65 .node_ids()
66 .flat_map(|dst| {
67 partitioned_graph
68 .node_singleton_references(dst)
69 .iter()
70 .flatten()
71 .map(move |&src_ref| (src_ref, dst))
72 })
73 .collect();
74 BarrierCrossers {
75 edge_barrier_crossers,
76 singleton_barrier_crossers,
77 }
78}
79
80fn find_subgraph_unionfind(
81 partitioned_graph: &DfirGraph,
82 barrier_crossers: &BarrierCrossers,
83) -> (UnionFind<GraphNodeId>, BTreeSet<GraphEdgeId>) {
84 let mut node_color = partitioned_graph
89 .node_ids()
90 .filter_map(|node_id| {
91 let op_color = partitioned_graph.node_color(node_id)?;
92 Some((node_id, op_color))
93 })
94 .collect::<SparseSecondaryMap<_, _>>();
95
96 let mut subgraph_unionfind: UnionFind<GraphNodeId> =
97 UnionFind::with_capacity(partitioned_graph.nodes().len());
98
99 let mut handoff_edges: BTreeSet<GraphEdgeId> = partitioned_graph.edge_ids().collect();
102 let mut progress = true;
111 while progress {
112 progress = false;
113 for (edge_id, (src, dst)) in partitioned_graph.edges().collect::<Vec<_>>() {
115 if subgraph_unionfind.same_set(src, dst) {
117 continue;
120 }
121
122 if barrier_crossers
124 .iter_node_pairs(partitioned_graph)
125 .any(|((x_src, x_dst), _)| {
126 (subgraph_unionfind.same_set(x_src, src)
127 && subgraph_unionfind.same_set(x_dst, dst))
128 || (subgraph_unionfind.same_set(x_src, dst)
129 && subgraph_unionfind.same_set(x_dst, src))
130 })
131 {
132 continue;
133 }
134
135 if partitioned_graph.node_loop(src) != partitioned_graph.node_loop(dst) {
137 continue;
138 }
139 if partitioned_graph.node_op_inst(dst).is_some_and(|op_inst| {
141 Some(FloType::NextIteration) == op_inst.op_constraints.flo_type
142 }) {
143 continue;
144 }
145
146 if can_connect_colorize(&mut node_color, src, dst) {
147 subgraph_unionfind.union(src, dst);
150 assert!(handoff_edges.remove(&edge_id));
151 progress = true;
152 }
153 }
154 }
155
156 (subgraph_unionfind, handoff_edges)
157}
158
159fn make_subgraph_collect(
163 partitioned_graph: &DfirGraph,
164 mut subgraph_unionfind: UnionFind<GraphNodeId>,
165) -> SecondaryMap<GraphNodeId, Vec<GraphNodeId>> {
166 let topo_sort = graph_algorithms::topo_sort(
170 partitioned_graph
171 .nodes()
172 .filter(|&(_, node)| !matches!(node, GraphNode::Handoff { .. }))
173 .map(|(node_id, _)| node_id),
174 |v| {
175 partitioned_graph
176 .node_predecessor_nodes(v)
177 .filter(|&pred_id| {
178 let pred = partitioned_graph.node(pred_id);
179 !matches!(pred, GraphNode::Handoff { .. })
180 })
181 },
182 )
183 .expect("Subgraphs are in-out trees.");
184
185 let mut grouped_nodes: SecondaryMap<GraphNodeId, Vec<GraphNodeId>> = Default::default();
186 for node_id in topo_sort {
187 let repr_node = subgraph_unionfind.find(node_id);
188 if !grouped_nodes.contains_key(repr_node) {
189 grouped_nodes.insert(repr_node, Default::default());
190 }
191 grouped_nodes[repr_node].push(node_id);
192 }
193 grouped_nodes
194}
195
196fn make_subgraphs(partitioned_graph: &mut DfirGraph, barrier_crossers: &mut BarrierCrossers) {
200 let (subgraph_unionfind, handoff_edges) =
209 find_subgraph_unionfind(partitioned_graph, barrier_crossers);
210
211 for edge_id in handoff_edges {
213 let (src_id, dst_id) = partitioned_graph.edge(edge_id);
214
215 let src_node = partitioned_graph.node(src_id);
217 let dst_node = partitioned_graph.node(dst_id);
218 if matches!(src_node, GraphNode::Handoff { .. })
219 || matches!(dst_node, GraphNode::Handoff { .. })
220 {
221 continue;
222 }
223
224 let hoff = GraphNode::Handoff {
225 src_span: src_node.span(),
226 dst_span: dst_node.span(),
227 };
228 let (_node_id, out_edge_id) = partitioned_graph.insert_intermediate_node(edge_id, hoff);
229
230 barrier_crossers.replace_edge(edge_id, out_edge_id);
232 }
233
234 let grouped_nodes = make_subgraph_collect(partitioned_graph, subgraph_unionfind);
238 for (_repr_node, member_nodes) in grouped_nodes {
239 partitioned_graph.insert_subgraph(member_nodes).unwrap();
240 }
241}
242
243fn can_connect_colorize(
249 node_color: &mut SparseSecondaryMap<GraphNodeId, Color>,
250 src: GraphNodeId,
251 dst: GraphNodeId,
252) -> bool {
253 let can_connect = match (node_color.get(src), node_color.get(dst)) {
258 (None, None) => false,
261
262 (None, Some(Color::Pull | Color::Comp)) => {
264 node_color.insert(src, Color::Pull);
265 true
266 }
267 (None, Some(Color::Push | Color::Hoff)) => {
268 node_color.insert(src, Color::Push);
269 true
270 }
271
272 (Some(Color::Pull | Color::Hoff), None) => {
274 node_color.insert(dst, Color::Pull);
275 true
276 }
277 (Some(Color::Comp | Color::Push), None) => {
278 node_color.insert(dst, Color::Push);
279 true
280 }
281
282 (Some(Color::Pull), Some(Color::Pull)) => true,
284 (Some(Color::Pull), Some(Color::Comp)) => true,
285 (Some(Color::Pull), Some(Color::Push)) => true,
286
287 (Some(Color::Comp), Some(Color::Pull)) => false,
288 (Some(Color::Comp), Some(Color::Comp)) => false,
289 (Some(Color::Comp), Some(Color::Push)) => true,
290
291 (Some(Color::Push), Some(Color::Pull)) => false,
292 (Some(Color::Push), Some(Color::Comp)) => false,
293 (Some(Color::Push), Some(Color::Push)) => true,
294
295 (Some(Color::Hoff), Some(_)) => false,
297 (Some(_), Some(Color::Hoff)) => false,
298 };
299 can_connect
300}
301
302fn order_subgraphs(
308 partitioned_graph: &mut DfirGraph,
309 barrier_crossers: &BarrierCrossers,
310) -> Result<(), Diagnostic> {
311 let mut sg_preds: BTreeMap<GraphSubgraphId, Vec<GraphSubgraphId>> = Default::default();
313
314 let mut tick_edges: Vec<(GraphEdgeId, DelayType)> = Vec::new();
316
317 for (node_id, node) in partitioned_graph.nodes() {
319 if !matches!(node, GraphNode::Handoff { .. }) {
320 continue;
321 }
322 assert_eq!(1, partitioned_graph.node_successors(node_id).len());
323 let (succ_edge, succ) = partitioned_graph.node_successors(node_id).next().unwrap();
324
325 let succ_edge_delaytype = barrier_crossers
326 .edge_barrier_crossers
327 .get(succ_edge)
328 .copied();
329 if let Some(delay_type @ (DelayType::Tick | DelayType::TickLazy)) = succ_edge_delaytype {
331 tick_edges.push((succ_edge, delay_type));
332 continue;
333 }
334
335 assert_eq!(1, partitioned_graph.node_predecessors(node_id).len());
336 let (_edge_id, pred) = partitioned_graph.node_predecessors(node_id).next().unwrap();
337
338 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
339 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
340
341 sg_preds.entry(succ_sg).or_default().push(pred_sg);
342 }
343 for &(pred, succ) in barrier_crossers.singleton_barrier_crossers.iter() {
345 assert_ne!(pred, succ, "TODO(mingwei)");
346 let pred_sg = partitioned_graph.node_subgraph(pred).unwrap();
347 let succ_sg = partitioned_graph.node_subgraph(succ).unwrap();
348 assert_ne!(pred_sg, succ_sg);
349 sg_preds.entry(succ_sg).or_default().push(pred_sg);
350 }
351
352 if let Err(cycle) = graph_algorithms::topo_sort(partitioned_graph.subgraph_ids(), |v| {
354 sg_preds.get(&v).into_iter().flatten().copied()
355 }) {
356 let span = cycle
357 .first()
358 .and_then(|&sg_id| partitioned_graph.subgraph(sg_id).first().copied())
359 .map(|n| partitioned_graph.node(n).span())
360 .unwrap_or_else(Span::call_site);
361 return Err(Diagnostic::spanned(
362 span,
363 Level::Error,
364 "Cyclical dataflow within a tick is not supported. Use `defer_tick()` or `defer_tick_lazy()` to break the cycle across ticks.",
365 ));
366 }
367
368 for (edge_id, delay_type) in tick_edges {
373 let (hoff, _dst) = partitioned_graph.edge(edge_id);
374 assert!(matches!(
375 partitioned_graph.node(hoff),
376 GraphNode::Handoff { .. }
377 ));
378 partitioned_graph.set_handoff_delay_type(hoff, delay_type);
379 }
380 Ok(())
381}
382
383pub fn partition_graph(flat_graph: DfirGraph) -> Result<DfirGraph, Diagnostic> {
387 let mut barrier_crossers = find_barrier_crossers(&flat_graph);
389 let mut partitioned_graph = flat_graph;
390
391 make_subgraphs(&mut partitioned_graph, &mut barrier_crossers);
393
394 order_subgraphs(&mut partitioned_graph, &barrier_crossers)?;
396
397 Ok(partitioned_graph)
398}