Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 99 additions & 1 deletion backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright 2023-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -14,6 +14,7 @@
"""

import logging
from collections import deque
from itertools import count
from typing import Callable, List, Optional, Sequence, Tuple

Expand Down Expand Up @@ -118,6 +119,76 @@
)


def _validate_partition(nodes: set[torch.fx.Node]) -> bool:
    """Report whether ``nodes`` can be pulled out as one subgraph without a cycle.

    The search starts at every user of a partition node that is itself outside
    the partition and follows ``users`` edges forward. Reaching a partition
    node again means some path leaves the partition and re-enters it, so the
    extracted subgraph and the rest of the graph would depend on each other.

    Args:
        nodes: The set of FX nodes that form the partition.

    Returns:
        True if no forward path from an external user leads back into
        ``nodes``, False otherwise. An empty set is reported as valid.

    """
    # Initial frontier: consumers of partition nodes that live outside it.
    frontier = deque(
        consumer
        for member in nodes
        for consumer in member.users
        if consumer not in nodes
    )

    seen: set[torch.fx.Node] = set()
    while frontier:
        candidate = frontier.popleft()
        if candidate in seen:
            # Already expanded via another path.
            continue
        seen.add(candidate)
        if candidate in nodes:
            # A path left the partition and came back in.
            return False
        frontier.extend(
            downstream for downstream in candidate.users if downstream not in seen
        )
    return True


def _find_connected_components(nodes: set[torch.fx.Node]) -> list[set[torch.fx.Node]]:
    """Find connected components in a set of nodes treating edges as undirected.

    Two nodes are connected if one is an input or user of the other and both
    are in ``nodes``.

    Components are returned in a deterministic order: each component is
    positioned by its earliest node in the node list of the graph that owns
    the nodes. Seeding from set iteration order instead would let the
    component order (and anything a caller numbers from it) vary, because the
    iteration order of a set of nodes is not guaranteed to be stable between
    runs.

    Args:
        nodes: The node set to partition into components.

    Returns:
        A list of disjoint node sets, one per connected component, ordered by
        first appearance in the owning graph. An empty input yields an empty
        list.

    """
    if not nodes:
        return []

    # Walk candidate seeds in graph order so the result does not depend on
    # set iteration order.
    anchor = next(iter(nodes))
    seeds = [candidate for candidate in anchor.graph.nodes if candidate in nodes]
    if len(seeds) != len(nodes):
        # Some nodes were not found in the anchor's graph; still cover them
        # so every input node ends up in exactly one component.
        covered = set(seeds)
        seeds.extend(candidate for candidate in nodes if candidate not in covered)

    unassigned = set(nodes)
    components: list[set[torch.fx.Node]] = []
    for seed in seeds:
        if seed not in unassigned:
            # Already absorbed into an earlier component.
            continue
        unassigned.discard(seed)
        component: set[torch.fx.Node] = {seed}
        queue = deque([seed])
        while queue:
            node = queue.popleft()
            # Follow edges in both directions, but only within ``nodes``.
            for neighbour in (*node.all_input_nodes, *node.users):
                if neighbour in unassigned:
                    # Claim on enqueue so no node is queued twice.
                    unassigned.discard(neighbour)
                    component.add(neighbour)
                    queue.append(neighbour)
        components.append(component)
    return components


class TOSAPartitioner(Partitioner):
"""Partition an exported program into TOSA-delegable subgraphs.

Expand Down Expand Up @@ -255,6 +326,32 @@
reporter,
)

# After de-tagging, the remaining tagged nodes may form
# dependency cycles. This happens when models contain complex
# attention blocks (e.g. MobileViT) where Q/DQ nodes act as
# bridges between partition segments. Detect such cycles and
# split the partition into valid connected components.
surviving = {
n for n in partition.nodes if is_partitioned(n, tag)
}
if surviving and not _validate_partition(surviving):
components = _find_connected_components(surviving)
logger.info(
f"Partition {tag} has dependency cycle after Q/DQ "
f"de-tagging. Splitting into {len(components)} "
f"sub-partition(s)."
)
# Remove the original tag from all nodes
for node in surviving:
del node.meta["delegation_tag"]
tags.remove(tag)
# Re-tag each connected component as a new partition
for component in components:
new_tag = f"tag{next(tag_iterator)}"
tags.add(new_tag)
for node in component:
node.meta["delegation_tag"] = new_tag

# Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation.
is_nocompute_partition = all(
_is_noop_clone(node)
Expand All @@ -272,7 +369,8 @@
partition,
reporter,
)
tags.remove(tag)
if tag in tags:
tags.remove(tag)
return tags

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
Expand Down
Loading