Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@
def torchvision_roi_align(
input,
boxes,
output_size: Sequence[int],
spatial_scale: float = 1.0,
spatial_scale: float,
pooled_height: int,
pooled_width: int,
sampling_ratio: int = -1,
aligned: bool = False,
):
"""roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
pooled_height, pooled_width = output_size
"""roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)"""

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

batch_indices = _process_batch_indices_for_roi_align(boxes)
rois_coords = _process_rois_for_roi_align(boxes)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
Expand All @@ -79,7 +79,6 @@
sampling_ratio=sampling_ratio,
)


@torch_op("torchvision::roi_pool", trace_only=True)
def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale: float = 1.0):
"""roi_pool(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0) -> torch.Tensor"""
Expand Down
50 changes: 50 additions & 0 deletions tests/function_libs/torch_lib/ops/vision_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning test

27-31: Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/format Warning test

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning test

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.

import unittest

Check warning

Code scanning / lintrunner

RUFF/I001 Warning test

Import block is un-sorted or un-formatted.
See https://docs.astral.sh/ruff/rules/unsorted-imports
import torch
import torchvision
from torch.onnx import export
import os

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note test

standard import "os" should be placed before third party imports "torch", "torchvision", "torch.onnx.export" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order

class VisionOperatorTest(unittest.TestCase):
def setUp(self):
self.model_path = "roi_align_test.onnx"

def tearDown(self):
if os.path.exists(self.model_path):
os.remove(self.model_path)

def test_roi_align_export_with_seven_arguments(self):
"""
Tests that torchvision::roi_align exports correctly with 7 positional arguments.
This covers the signature change where output_size is decomposed into

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning test

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

pooled_height and pooled_width.
"""
class RoiAlignModel(torch.nn.Module):
def forward(self, x, boxes):
return torchvision.ops.roi_align(
x,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

boxes,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

output_size=(7, 7),

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

spatial_scale=0.5,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

sampling_ratio=2,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning test

aligned=True
)

# Create dummy inputs: (N, C, H, W) and (K, 5)
x = torch.randn(1, 3, 32, 32, dtype=torch.float32)
boxes = torch.tensor([[0, 0, 0, 10, 10]], dtype=torch.float32)
model = RoiAlignModel().eval()

try:
export(model, (x, boxes), self.model_path)
export_success = True
except Exception as e:

Check warning

Code scanning / lintrunner

PYLINT/W0718 Warning test

Catching too general exception Exception (broad-exception-caught)
See broad-exception-caught. To disable, use # pylint: disable=broad-exception-caught
export_success = False
self.fail(f"torch.onnx.export failed for roi_align: {e}")

self.assertTrue(export_success)

if __name__ == "__main__":
unittest.main()
Loading