diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 27f24e5958d..01d5e048cb3 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -14,6 +14,7 @@ """ import logging +from collections import deque from itertools import count from typing import Callable, List, Optional, Sequence, Tuple @@ -118,6 +119,76 @@ def reject_partition( ) +def _validate_partition(nodes: set[torch.fx.Node]) -> bool: + """Check whether a set of nodes can be extracted as a subgraph without cycles. + + Perform a BFS from the external users of partition nodes. If any node + reached by BFS is itself inside the partition, then extracting the + partition would create a dependency cycle in the remaining graph. + + Args: + nodes: The set of FX nodes that form the partition. + + Returns: + True if the partition is valid (no cycles), False otherwise. + + """ + outputs: list[torch.fx.Node] = [] + for node in nodes: + for user in node.users: + if user not in nodes: + outputs.append(user) + + visited: set[torch.fx.Node] = set() + queue = deque(outputs) + while queue: + current = queue.popleft() + if current in visited: + continue + visited.add(current) + if current in nodes: + return False + for user in current.users: + if user not in visited: + queue.append(user) + return True + + +def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]: + """Find connected components in a set of nodes treating edges as undirected. + + Two nodes are connected if one is an input or user of the other and both + are in ``nodes``. + + Args: + nodes: The node set to partition into components. + + Returns: + A list of disjoint node sets, one per connected component. + + """ + remaining = set(nodes) + components: list[set[torch.fx.Node]] = [] + while remaining: + seed = next(iter(remaining)) + component: set[torch.fx.Node] = set() + queue = deque([seed]) + while queue: + node = queue.popleft() + if node in component or node not in remaining: + continue + component.add(node) + for inp in node.all_input_nodes: + if inp in remaining and inp not in component: + queue.append(inp) + for user in node.users: + if user in remaining and user not in component: + queue.append(user) + remaining -= component + components.append(component) + return components + + class TOSAPartitioner(Partitioner): """Partition an exported program into TOSA-delegable subgraphs. @@ -255,6 +326,32 @@ def _tag_module( # noqa reporter, ) + # After de-tagging, the remaining tagged nodes may form + # dependency cycles. This happens when models contain complex + # attention blocks (e.g. MobileViT) where Q/DQ nodes act as + # bridges between partition segments. Detect such cycles and + # split the partition into valid connected components. + surviving = { + n for n in partition.nodes if is_partitioned(n, tag) + } + if surviving and not _validate_partition(surviving): + components = _find_connected_components(surviving) + logger.info( + f"Partition {tag} has dependency cycle after Q/DQ " + f"de-tagging. Splitting into {len(components)} " + f"sub-partition(s)." + ) + # Remove the original tag from all nodes + for node in surviving: + del node.meta["delegation_tag"] + tags.remove(tag) + # Re-tag each connected component as a new partition + for component in components: + new_tag = f"tag{next(tag_iterator)}" + tags.add(new_tag) + for node in component: + node.meta["delegation_tag"] = new_tag + # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." is_nocompute_partition = all( _is_noop_clone(node) @@ -272,7 +369,8 @@ def _tag_module( # noqa partition, reporter, ) - tags.remove(tag) + if tag in tags: + tags.remove(tag) return tags def partition(self, exported_program: ExportedProgram) -> PartitionResult: