
Conversation

@aseembits93
Contributor

Summary

  • Fix instrumentation of PyTorch nn.Module forward method when called via instance (e.g., model(input_data))
  • Add special handling for the pattern: model = ClassName(...); model(input_data) where model(input_data) internally calls forward()
  • Resolves "Ignoring test case that passed but had no runtime" error when optimizing forward methods

Problem

When running codeflash --function AlexNet.forward, tests with this pattern weren't being instrumented:

model = AlexNet(num_classes=10)
result = model(input_data)  # calls __call__ which invokes forward()

The instrumentation was looking for direct calls to forward or AlexNet, but model(input_data) matched neither.
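To make the mismatch concrete, here is the AST shape of the failing call (an illustrative snippet using only the standard-library `ast` module; `AlexNet` and `input_data` are just the names from the example above):

import ast

# The test body as the instrumenter sees it.
tree = ast.parse(
    "model = AlexNet(num_classes=10)\n"
    "result = model(input_data)"
)

# The second statement's value is a Call whose func is the variable name
# "model", so matching against "forward" or "AlexNet" finds nothing.
call = tree.body[1].value
print(ast.dump(call.func))  # Name(id='model', ctx=Load())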

Solution

  1. Added collect_instance_variables() to track variables assigned from class instantiations (a sketch follows this list)
  2. Modified find_and_update_line_node() to wrap calls to instance variables when optimizing forward methods
  3. Added test case specifically for this PyTorch pattern
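As referenced in item 1, a minimal sketch of the idea behind `collect_instance_variables()` (illustrative only; the real implementation is a method on `InjectPerfOnly` and its exact signature differs):

import ast

def collect_instance_variables(func_node: ast.FunctionDef, class_name: str) -> set[str]:
    # Record variables assigned from an instantiation of the class under
    # optimization, e.g. "model = AlexNet(...)" records the name "model".
    instance_vars: set[str] = set()
    for node in ast.walk(func_node):
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            callee = node.value.func
            if isinstance(callee, ast.Name) and callee.id == class_name:
                instance_vars.update(
                    target.id for target in node.targets if isinstance(target, ast.Name)
                )
    return instance_vars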

Test plan

  • Added test_pytorch_forward_method_instrumentation test case
  • All existing instrumentation tests pass (19/19)
  • Verified with actual codeflash command - runtime is now properly measured

🤖 Generated with Claude Code

When optimizing a `forward` method on a class (e.g., AlexNet.forward),
the test pattern `model = AlexNet(...); model(input_data)` wasn't being
instrumented because the call `model(input_data)` didn't match the
expected function name "forward".

This fix adds special handling for the PyTorch nn.Module pattern:
- Collect variable names assigned from class instantiations
- Also wrap calls to those instance variables when optimizing `forward`

Fixes the "Ignoring test case that passed but had no runtime" error
when running codeflash on PyTorch model forward methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
    if node.name.startswith("test_"):
        # Collect instance variables for forward method instrumentation (PyTorch pattern)
        self.collect_instance_variables(node)

Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.

Consider clearing the set at the start of each test function:

Suggested change
-        self.collect_instance_variables(node)
+        self.instance_variable_names.clear()
+        self.collect_instance_variables(node)
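To make the concern concrete, a hypothetical pair of tests (the helper and input names are made up) where the leaked state could misfire:

def test_a():
    model = AlexNet(num_classes=10)  # "model" is recorded as an instance variable
    model(input_a)                   # wrapped, as intended

def test_b():
    model = build_helper()           # not an AlexNet instantiation
    model(input_b)                   # would still be wrapped if "model" leaked over from test_a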

@claude

claude bot commented Feb 6, 2026

PR Review Summary

Prek Checks

✅ All prek checks pass (ruff check, ruff format). No fixes needed.

Mypy

⚠️ 122 mypy errors found in the two changed files, but these are all pre-existing — none are introduced by this PR's changes. The file codeflash/code_utils/instrument_existing_tests.py is not in mypy_allowlist.txt and is not checked in CI.

Code Review

Existing issue still open:

  • instance_variable_names is accumulated across test functions without being cleared (inline comment from previous review). This could cause false-positive instrumentation when a variable name from one test matches a call in another test. The fix is to call self.instance_variable_names.clear() before self.collect_instance_variables(node) in visit_FunctionDef.

No new critical issues found in this synchronize push. The implementation correctly:

  • Detects PyTorch nn.Module pattern (model = ClassName(...) followed by model(input_data))
  • Only activates for forward method optimization (guarded by self.only_function_name == "forward"); see the sketch after this list
  • Properly traverses nested AST nodes for instance variable collection
  • Includes a well-structured test covering the core use case
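A sketch of that `forward`-only guard as the call-wrapping decision might express it (hypothetical helper; the real check lives inside `find_and_update_line_node()` and uses the instrumenter's own state):

import ast

def should_wrap_call(call: ast.Call, only_function_name: str,
                     class_name: str, instance_vars: set[str]) -> bool:
    if not isinstance(call.func, ast.Name):
        return False
    name = call.func.id
    # Existing behavior: wrap direct calls to the target function or class.
    if name in (only_function_name, class_name):
        return True
    # New behavior: when optimizing forward(), also wrap calls made through
    # a recorded instance variable, e.g. model(input_data).
    return only_function_name == "forward" and name in instance_vars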

Test Coverage

| File | Stmts | Miss | Coverage |
| --- | ---: | ---: | ---: |
| codeflash/code_utils/instrument_existing_tests.py | 452 | 197 | 56.4% |
| tests/test_instrument_tests.py | 1192 | 590 | 50.5% |

New code coverage: The core new logic paths (instance variable collection, is_instance_call check, modified filter, collect_instance_variables call site) are all covered by the new test. Only deeply nested traversal paths in collect_instance_variables for orelse, finalbody, and handlers blocks (lines 120-127) are uncovered — these are edge cases that don't affect correctness verification.

Note: 10 tests in this file fail due to missing CODEFLASH_API_KEY environment variable (pre-existing, not related to this PR). Coverage numbers reflect only the 9 passing tests.

Codeflash Optimization PRs

No optimization PRs are eligible for merge — all have failing CI checks.


Last updated: 2026-02-11T

The optimized code achieves a **768% speedup** (from 1.30ms to 150μs) by replacing the expensive `ast.walk()` traversal with a targeted manual traversal strategy.

**Key Optimization:**

The original code uses `ast.walk(func_node)`, which recursively visits *every* node in the entire AST tree - including all expression nodes, operators, literals, and other irrelevant node types. The line profiler shows this single loop consumed 87.3% of the execution time (9.2ms out of 10.5ms).

The optimized version implements a **work-list algorithm** that only traverses statement nodes (body, orelse, finalbody, handlers). This dramatically reduces the number of nodes examined:
- Original: 1,889 nodes visited per call
- Optimized: ~317 nodes visited per call (83% reduction)
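A sketch of that work-list traversal (simplified; the optimized method is on `InjectPerfOnly` and tracks more state than shown here):

import ast

def collect_instance_variables_fast(func_node: ast.FunctionDef, class_name: str) -> set[str]:
    instance_vars: set[str] = set()
    # Work list of statement-level nodes; only statement-bearing fields are
    # followed, so expression subtrees are never pushed or visited.
    stack: list[ast.AST] = [func_node]
    while stack:
        node = stack.pop()
        for field in ("body", "orelse", "finalbody", "handlers"):
            stack.extend(getattr(node, field, []))
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            callee = node.value.func
            if isinstance(callee, ast.Name) and callee.id == class_name:
                instance_vars.update(
                    t.id for t in node.targets if isinstance(t, ast.Name)
                )
    return instance_vars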

**Why This Works:**

1. **Targeted traversal**: Assignment statements (`ast.Assign`) can only appear as statements, not as expressions buried deep in the tree. By only following statement-level structure (`body`, `orelse`, etc.), we skip visiting thousands of irrelevant expression nodes.

2. **Fewer attribute lookups**: Caching `class_name` and `instance_vars` in local variables eliminates repeated `self.` attribute lookups inside the loop.

3. **Early filtering**: The manual stack-based approach allows us to skip entire branches of the AST that can't contain assignments.

**Performance Impact by Test Case:**

- Simple cases (single assignment): ~500-600% faster
- Complex nested cases: ~429% faster  
- Large-scale scenario (300 assignments): **807% faster** - showing the optimization scales particularly well with code complexity

The optimization preserves all functionality (same nodes discovered, same instance variables collected) while dramatically reducing the algorithmic complexity from O(all_nodes) to O(statement_nodes).
@codeflash-ai
Contributor

codeflash-ai bot commented Feb 6, 2026

⚡️ Codeflash found optimizations for this PR

📄 769% (7.69x) speedup for InjectPerfOnly.collect_instance_variables in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 1.30 milliseconds → 150 microseconds (best of 15 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch fix/pytorch-forward-method-instrumentation).


…2026-02-06T22.39.42

⚡️ Speed up method `InjectPerfOnly.collect_instance_variables` by 769% in PR #1418 (`fix/pytorch-forward-method-instrumentation`)
@codeflash-ai
Contributor

codeflash-ai bot commented Feb 11, 2026

@KRRT7
Collaborator

KRRT7 commented Feb 11, 2026

@claude fix the mypy type issues and push

1 similar comment
