fix: instrument PyTorch nn.Module forward method calls via instance #1418
base: main
Conversation
When optimizing a `forward` method on a class (e.g., `AlexNet.forward`), the test pattern `model = AlexNet(...); model(input_data)` wasn't being instrumented because the call `model(input_data)` didn't match the expected function name `"forward"`. This fix adds special handling for the PyTorch nn.Module pattern:

- Collect variable names assigned from class instantiations
- Also wrap calls to those instance variables when optimizing `forward`

Fixes the "Ignoring test case that passed but had no runtime" error when running codeflash on PyTorch model forward methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
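The two steps in the description can be sketched with the standard `ast` module. This is an illustrative reconstruction, not the PR's actual implementation: the function name and signature here are hypothetical, and the real code lives inside a visitor class.

```python
import ast

def collect_instance_variables(func_node: ast.FunctionDef, class_name: str) -> set[str]:
    """Collect names of variables assigned from instantiations of class_name.

    Sketch of the idea described above: given `model = AlexNet(...)` and
    class_name == "AlexNet", record "model" so later calls like
    `model(input_data)` can also be wrapped when optimizing `forward`.
    """
    instance_vars: set[str] = set()
    for node in ast.walk(func_node):
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            callee = node.value.func
            # Match `model = AlexNet(...)` where AlexNet is the class being optimized
            if isinstance(callee, ast.Name) and callee.id == class_name:
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        instance_vars.add(target.id)
    return instance_vars

src = "def test_alexnet():\n    model = AlexNet(num_classes=10)\n    model(x)\n"
func = ast.parse(src).body[0]
print(collect_instance_variables(func, "AlexNet"))  # {'model'}
```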
```python
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
    if node.name.startswith("test_"):
        # Collect instance variables for forward method instrumentation (PyTorch pattern)
        self.collect_instance_variables(node)
```
Nit (non-blocking): `instance_variable_names` is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from `test_a` will persist when processing `test_b`. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.
Consider clearing the set at the start of each test function:
```diff
+ self.instance_variable_names.clear()
  self.collect_instance_variables(node)
```
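The leakage the reviewer describes, and the effect of clearing at the start of each test function, can be demonstrated with a minimal visitor. The class and attribute names below mirror the reviewed snippet but the collection logic is simplified for illustration:

```python
import ast

class InstanceVarCollector(ast.NodeVisitor):
    """Minimal sketch of the reviewed pattern (collection logic simplified)."""

    def __init__(self) -> None:
        self.instance_variable_names: set[str] = set()
        self.per_test: dict[str, set[str]] = {}

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if node.name.startswith("test_"):
            # Reviewer's suggested fix: without this clear(), names from
            # test_a would still be present while processing test_b.
            self.instance_variable_names.clear()
            for child in ast.walk(node):
                if isinstance(child, ast.Assign) and isinstance(child.value, ast.Call):
                    for target in child.targets:
                        if isinstance(target, ast.Name):
                            self.instance_variable_names.add(target.id)
            self.per_test[node.name] = set(self.instance_variable_names)

src = """
def test_a():
    model = AlexNet()

def test_b():
    net = ResNet()
"""
collector = InstanceVarCollector()
collector.visit(ast.parse(src))
print(collector.per_test)  # {'test_a': {'model'}, 'test_b': {'net'}}
```

Without the `clear()` call, `test_b` would end up with `{'model', 'net'}`, which is exactly the false-positive scenario the nit warns about.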
PR Review Summary

Prek Checks

✅ All prek checks pass (ruff check, ruff format). No fixes needed.

Mypy

Code Review

Existing issue still open:

No new critical issues found in this synchronize push. The implementation correctly:

Test Coverage

New code coverage: The core new logic paths (instance variable collection,

Note: 10 tests in this file fail due to missing

Codeflash Optimization PRs

No optimization PRs are eligible for merge — all have failing CI checks.

Last updated: 2026-02-11T
The optimized code achieves a **768% speedup** (from 1.30ms to 150μs) by replacing the expensive `ast.walk()` traversal with a targeted manual traversal strategy.

**Key Optimization:** The original code uses `ast.walk(func_node)`, which recursively visits *every* node in the entire AST tree, including all expression nodes, operators, literals, and other irrelevant node types. The line profiler shows this single loop consumed 87.3% of the execution time (9.2ms out of 10.5ms). The optimized version implements a **work-list algorithm** that only traverses statement nodes (body, orelse, finalbody, handlers). This dramatically reduces the number of nodes examined:

- Original: 1,889 nodes visited per call
- Optimized: ~317 nodes visited per call (83% reduction)

**Why This Works:**

1. **Targeted traversal**: Assignment statements (`ast.Assign`) can only appear as statements, not as expressions buried deep in the tree. By only following statement-level structure (`body`, `orelse`, etc.), we skip visiting thousands of irrelevant expression nodes.
2. **Cache-friendly**: Local variables `class_name` and `instance_vars` eliminate repeated `self.` attribute lookups, reducing pointer indirection.
3. **Early filtering**: The manual stack-based approach allows us to skip entire branches of the AST that can't contain assignments.

**Performance Impact by Test Case:**

- Simple cases (single assignment): ~500-600% faster
- Complex nested cases: ~429% faster
- Large-scale scenario (300 assignments): **807% faster**, showing the optimization scales particularly well with code complexity

The optimization preserves all functionality (same nodes discovered, same instance variables collected) while dramatically reducing the algorithmic complexity from O(all_nodes) to O(statement_nodes).
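The work-list idea described above can be sketched as follows. This is an illustrative reconstruction, not the PR's code: it follows only the statement-holding fields named in the summary (`body`, `orelse`, `finalbody`, and exception handlers), and deliberately skips expression subtrees, since `ast.Assign` nodes can only appear at statement level. Some statement kinds (e.g. `match` cases) are not specially handled in this sketch.

```python
import ast

def collect_assignments_worklist(func_node: ast.FunctionDef) -> list[ast.Assign]:
    """Statement-level work-list traversal: push only statement bodies onto
    the stack, never descending into expressions, so far fewer nodes are
    examined than with ast.walk()."""
    assigns: list[ast.Assign] = []
    stack: list[ast.stmt] = list(func_node.body)
    while stack:
        stmt = stack.pop()
        if isinstance(stmt, ast.Assign):
            assigns.append(stmt)
        # Follow only statement-holding fields, not arbitrary child nodes.
        for field in ("body", "orelse", "finalbody"):
            stack.extend(getattr(stmt, field, []) or [])
        for handler in getattr(stmt, "handlers", []) or []:
            stack.extend(handler.body)
    return assigns

src = """
def test_model():
    model = AlexNet()
    if cond:
        other = ResNet()
    try:
        x = 1
    except ValueError:
        y = 2
"""
func = ast.parse(src).body[0]
print(sorted(a.targets[0].id for a in collect_assignments_worklist(func)))
# ['model', 'other', 'x', 'y']
```

Compared with `ast.walk()`, which would also visit every `Name`, `Call`, and constant node inside each expression, this traversal touches only the handful of statement nodes, which is the source of the node-count reduction quoted above.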
⚡️ Codeflash found optimizations for this PR

📄 769% (7.69x) speedup for
…2026-02-06T22.39.42 ⚡️ Speed up method `InjectPerfOnly.collect_instance_variables` by 769% in PR #1418 (`fix/pytorch-forward-method-instrumentation`)
This PR is now faster! 🚀 @KRRT7 accepted my optimizations from:
@claude fix the mypy type issues and push
Summary

- Instruments the PyTorch `nn.Module` `forward` method when called via instance (e.g., `model(input_data)`)
- Handles the pattern `model = ClassName(...); model(input_data)`, where `model(input_data)` internally calls `forward()`
- Applies only when optimizing `forward` methods

Problem

When running `codeflash --function AlexNet.forward`, tests with this pattern weren't being instrumented. The instrumentation was looking for direct calls to `forward` or `AlexNet`, but `model(input_data)` matched neither.

Solution

- Added `collect_instance_variables()` to track variables assigned from class instantiations
- Updated `find_and_update_line_node()` to wrap calls to instance variables when optimizing `forward` methods

Test plan

- Added `test_pytorch_forward_method_instrumentation` test case

🤖 Generated with Claude Code
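The test pattern from the Problem section looks roughly like the sketch below. `AlexNet` here is a hypothetical stand-in, given a plain `__call__` so the example runs without PyTorch installed; in real `torch.nn.Module` subclasses, calling the instance likewise dispatches to `forward()`.

```python
class AlexNet:
    """Stand-in for a torch.nn.Module subclass: calling the instance
    routes to forward(), which is why `model(input_data)` never matched
    the function name "forward" during instrumentation."""

    def __init__(self, num_classes: int = 10) -> None:
        self.num_classes = num_classes

    def __call__(self, x):
        # nn.Module.__call__ dispatches to forward() (hooks omitted here)
        return self.forward(x)

    def forward(self, x):
        return [v * self.num_classes for v in x]

def test_alexnet_forward():
    model = AlexNet(num_classes=2)   # instance variable the PR now tracks
    out = model([1, 2, 3])           # the call that previously went unwrapped
    assert out == [2, 4, 6]

test_alexnet_forward()
print("ok")  # ok
```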