-
Notifications
You must be signed in to change notification settings - Fork 106
[torch_lib] Fix torchvision_roi_align signature mismatch for PyTorch 2.10+ #2848
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,13 +56,13 @@ | |
| def torchvision_roi_align( | ||
| input, | ||
| boxes, | ||
| output_size: Sequence[int], | ||
| spatial_scale: float = 1.0, | ||
| spatial_scale: float, | ||
| pooled_height: int, | ||
| pooled_width: int, | ||
| sampling_ratio: int = -1, | ||
| aligned: bool = False, | ||
| ): | ||
| """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor""" | ||
| pooled_height, pooled_width = output_size | ||
| """roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)""" | ||
Check warningCode scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| batch_indices = _process_batch_indices_for_roi_align(boxes) | ||
| rois_coords = _process_rois_for_roi_align(boxes) | ||
| coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" | ||
|
|
@@ -79,7 +79,6 @@ | |
| sampling_ratio=sampling_ratio, | ||
| ) | ||
|
|
||
|
|
||
| @torch_op("torchvision::roi_pool", trace_only=True) | ||
| def torchvision_roi_pool(input, boxes, output_size: Sequence[int], spatial_scale: float = 1.0): | ||
| """roi_pool(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0) -> torch.Tensor""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning test
27-31: Trailing whitespace
Check warningCode scanning / lintrunner RUFF/format Warning test
Run lintrunner -a to apply this patch.
Check warningCode scanning / lintrunner RUFF-FORMAT/format Warning test
Run lintrunner -a to apply this patch.
|
||
| # Licensed under the MIT License. | ||
|
|
||
| import unittest | ||
Check warningCode scanning / lintrunner RUFF/I001 Warning test
Import block is un-sorted or un-formatted.
See https://docs.astral.sh/ruff/rules/unsorted-imports |
||
| import torch | ||
| import torchvision | ||
| from torch.onnx import export | ||
| import os | ||
Check noticeCode scanning / lintrunner PYLINT/C0411 Note test
standard import "os" should be placed before third party imports "torch", "torchvision", "torch.onnx.export" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order |
||
|
|
||
| class VisionOperatorTest(unittest.TestCase): | ||
| def setUp(self): | ||
| self.model_path = "roi_align_test.onnx" | ||
|
|
||
| def tearDown(self): | ||
| if os.path.exists(self.model_path): | ||
| os.remove(self.model_path) | ||
|
|
||
| def test_roi_align_export_with_seven_arguments(self): | ||
| """ | ||
| Tests that torchvision::roi_align exports correctly with 7 positional arguments. | ||
| This covers the signature change where output_size is decomposed into | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning test
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W291 Warning test
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| pooled_height and pooled_width. | ||
| """ | ||
| class RoiAlignModel(torch.nn.Module): | ||
| def forward(self, x, boxes): | ||
| return torchvision.ops.roi_align( | ||
| x, | ||
Check warningCode scanning / lintrunner RUFF/W291 Warning test
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| boxes, | ||
Check warningCode scanning / lintrunner RUFF/W291 Warning test
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| output_size=(7, 7), | ||
Check warningCode scanning / lintrunner RUFF/W291 Warning test
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| spatial_scale=0.5, | ||
Check warningCode scanning / lintrunner RUFF/W291 Warning test
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| sampling_ratio=2, | ||
Check warningCode scanning / lintrunner RUFF/W291 Warning test
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
| aligned=True | ||
| ) | ||
|
|
||
| # Create dummy inputs: (N, C, H, W) and (K, 5) | ||
| x = torch.randn(1, 3, 32, 32, dtype=torch.float32) | ||
| boxes = torch.tensor([[0, 0, 0, 10, 10]], dtype=torch.float32) | ||
| model = RoiAlignModel().eval() | ||
|
|
||
| try: | ||
| export(model, (x, boxes), self.model_path) | ||
| export_success = True | ||
| except Exception as e: | ||
Check warningCode scanning / lintrunner PYLINT/W0718 Warning test
Catching too general exception Exception (broad-exception-caught)
See broad-exception-caught. To disable, use # pylint: disable=broad-exception-caught |
||
| export_success = False | ||
| self.fail(f"torch.onnx.export failed for roi_align: {e}") | ||
|
|
||
| self.assertTrue(export_success) | ||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
Check warning
Code scanning / lintrunner
EDITORCONFIG-CHECKER/editorconfig Warning