From 5f597abf1414603bff95315f89a6b8aeb8f7a5d3 Mon Sep 17 00:00:00 2001
From: beomwookang
Date: Mon, 16 Mar 2026 16:07:11 +0900
Subject: [PATCH] fix(arm): validate partitions for dependency cycles after Q/DQ de-tagging

`_detag_boundary_nodes` removes Q/DQ nodes from partition boundaries after
`CapabilityBasedPartitioner` has produced cycle-free partitions. However,
this de-tagging can introduce dependency cycles for models with complex
attention blocks (e.g. MobileViT, where CNN and Transformer ops are grouped
into a single large partition).

The cycle occurs because removing Q/DQ bridge nodes creates paths that exit
the partition and re-enter it through the now-unpartitioned nodes, making it
impossible to extract the partition as a valid subgraph.

This change adds cycle validation after `_detag_boundary_nodes`. When a
cycle is detected, the partition is split into connected components of the
surviving (still-tagged) nodes. Each component becomes a separate partition;
for the model tested below, each one was individually cycle-free after
de-tagging.

- Add `_validate_partition()`: BFS-based cycle detection (same algorithm as
  `torch.fx.passes.utils.fuser_utils.validate_partition`)
- Add `_find_connected_components()`: undirected graph traversal to split
  surviving nodes into disjoint sub-partitions
- Guard the nocompute-partition `tags.remove()` against already-removed tags
  from the cycle-split path

Tested with MobileViT-S on Ethos-U85: previously failed with
`AssertionError: Invalid partition, found dependency cycles`, now
successfully produces a .pte file (5.7 MB). Nine attention-block partitions
are each split into 3 sub-partitions. All sub-partitions remain on NPU (no
CPU fallback).

Existing CNN-only models (ResNet, MobileNetV2, EfficientNet) are unaffected
as their partitions have no cycles after de-tagging.
---
 backends/arm/tosa/partitioner.py | 122 ++++++++++++++++++++++++++++++-
 1 file changed, 121 insertions(+), 1 deletion(-)

diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index 27f24e5958d..01d5e048cb3 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -14,6 +14,7 @@
 """
 
 import logging
+from collections import deque
 from itertools import count
 from typing import Callable, List, Optional, Sequence, Tuple
 
@@ -118,6 +119,80 @@ def reject_partition(
         )
 
 
+def _validate_partition(nodes: set[torch.fx.Node]) -> bool:
+    """Check whether a set of nodes can be extracted as a subgraph without cycles.
+
+    Perform a BFS from the external users of partition nodes. If any node
+    reached by BFS is itself inside the partition, then extracting the
+    partition would create a dependency cycle in the remaining graph.
+
+    Args:
+        nodes: The set of FX nodes that form the partition.
+
+    Returns:
+        True if the partition is valid (no cycles), False otherwise.
+
+    """
+    outputs: list[torch.fx.Node] = []
+    for node in nodes:
+        for user in node.users:
+            if user not in nodes:
+                outputs.append(user)
+
+    visited: set[torch.fx.Node] = set()
+    queue = deque(outputs)
+    while queue:
+        current = queue.popleft()
+        if current in visited:
+            continue
+        visited.add(current)
+        if current in nodes:
+            return False
+        for user in current.users:
+            if user not in visited:
+                queue.append(user)
+    return True
+
+
+def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]:
+    """Find connected components in a set of nodes treating edges as undirected.
+
+    Two nodes are connected if one is an input or user of the other and both
+    are in ``nodes``.
+
+    Connectivity is the only criterion: a returned component can still have a
+    path that leaves it and re-enters it, so callers that need extractable
+    subgraphs must check each component with ``_validate_partition``.
+
+    Args:
+        nodes: The node set to partition into components.
+
+    Returns:
+        A list of disjoint node sets, one per connected component.
+
+    """
+    remaining = set(nodes)
+    components: list[set[torch.fx.Node]] = []
+    while remaining:
+        seed = next(iter(remaining))
+        component: set[torch.fx.Node] = set()
+        queue = deque([seed])
+        while queue:
+            node = queue.popleft()
+            if node in component or node not in remaining:
+                continue
+            component.add(node)
+            for inp in node.all_input_nodes:
+                if inp in remaining and inp not in component:
+                    queue.append(inp)
+            for user in node.users:
+                if user in remaining and user not in component:
+                    queue.append(user)
+        remaining -= component
+        components.append(component)
+    return components
+
+
 class TOSAPartitioner(Partitioner):
     """Partition an exported program into TOSA-delegable subgraphs.
 
@@ -255,6 +330,49 @@ def _tag_module(  # noqa
                     reporter,
                 )
 
+            # Tags currently labelling this partition's nodes. Starts as the
+            # partition's original tag and is rewritten if the partition is
+            # split, so that a later rejection can retire all of them.
+            partition_tags = {tag}
+
+            # Removing boundary Q/DQ nodes can leave paths that exit the tagged
+            # node set and re-enter it (seen with MobileViT attention blocks),
+            # so the partition can no longer be extracted as a subgraph. When
+            # that happens, split the surviving nodes into connected
+            # components and keep only the components that are cycle-free.
+            surviving = {n for n in partition.nodes if is_partitioned(n, tag)}
+            if surviving and not _validate_partition(surviving):
+                components = _find_connected_components(surviving)
+                # Retire the original tag before issuing per-component tags.
+                for node in surviving:
+                    del node.meta["delegation_tag"]
+                tags.remove(tag)
+                partition_tags.clear()
+                dropped = 0
+                for component in components:
+                    # Connectivity alone does not rule out a path that leaves
+                    # a component and comes back, so re-check each one and
+                    # leave offenders untagged rather than fail at extraction.
+                    if not _validate_partition(component):
+                        dropped += 1
+                        continue
+                    new_tag = f"tag{next(tag_iterator)}"
+                    tags.add(new_tag)
+                    partition_tags.add(new_tag)
+                    for node in component:
+                        node.meta["delegation_tag"] = new_tag
+                logger.info(
+                    f"Partition {tag} has dependency cycle after Q/DQ "
+                    f"de-tagging. Splitting into {len(partition_tags)} "
+                    "sub-partition(s)."
+                )
+                if dropped:
+                    logger.warning(
+                        f"Partition {tag}: {dropped} component(s) still had "
+                        "dependency cycles after splitting and were left "
+                        "undelegated."
+                    )
+
             # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation."
             is_nocompute_partition = all(
                 _is_noop_clone(node)
@@ -272,7 +390,9 @@ def _tag_module(  # noqa
                     partition,
                     reporter,
                 )
-                tags.remove(tag)
+                # The partition may have been re-tagged by the cycle split, so
+                # retire every tag issued for it rather than just ``tag``.
+                tags.difference_update(partition_tags)
         return tags
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult: