Skip to content

Conversation

@HeshamHM28
Copy link
Contributor

No description provided.

@HeshamHM28 HeshamHM28 marked this pull request as ready for review February 6, 2026 17:20
@HeshamHM28
Copy link
Contributor Author

@claude review this

@claude
Copy link

claude bot commented Feb 6, 2026

Claude finished @HeshamHM28's task in 2m 8s —— View job


Review of PR #1390: Kryo Serialization

  • Gather context and read changed files
  • Review KryoPlaceholder.java and KryoPlaceholderAccessException.java
  • Review Serializer.java changes
  • Review Comparator.java changes
  • Review CodeFlash.java and ResultWriter.java changes
  • Review test files
  • Review instrumentation.py changes
  • Post comprehensive review feedback

Overall Summary

This PR replaces the previous serialization approach with Kryo-based binary serialization for the Java runtime. The main components are:

  1. Serializer.java - Kryo-based serializer with recursive fallback for unserializable objects
  2. KryoPlaceholder.java / KryoPlaceholderAccessException.java - Placeholder pattern for objects that can't be serialized
  3. Comparator.java - Deep object comparison with placeholder rejection
  4. instrumentation.py - Improvements for primitive array cast handling and void method call handling

The design is sound overall - the "serialize, fallback to recursive processing, replace with placeholder" pattern mirrors Python's dill/pickle_patcher approach. Test coverage is extensive.


Bugs / Issues

1. Critical: _get_method_call_pattern is defined twice (instrumentation.py:768-781)

The function _get_method_call_pattern is defined twice at the bottom of the file - lines 768-773 and 776-781. The second definition silently overrides the first. Both have identical implementations so there's no functional bug today, but the duplicate should be removed.

Fix this →

2. Bug: System.getenv() does not have a getOrDefault method (instrumentation.py:552)

In _add_timing_instrumentation, the generated Java code uses:

int _cf_innerIterations = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));

System.getenv() returns Map<String, String> which does support getOrDefault in Java 8+, so this technically works. However, it's inconsistent with the pattern used for CODEFLASH_LOOP_INDEX one line above (System.getenv("CODEFLASH_LOOP_INDEX")). More importantly, if CODEFLASH_INNER_ITERATIONS is set to an empty string, Integer.parseInt("") will throw. Consider using a null-check pattern like the one used in CodeFlash.java:getWarmupIterations().

3. Potential ConcurrentModificationException in isKnownUnserializable (Serializer.java:266-269)

for (Class<?> unserializable : UNSERIALIZABLE_TYPES) {
    if (unserializable.isAssignableFrom(clazz)) {
        UNSERIALIZABLE_TYPES.add(clazz); // Cache for future <-- modifying during iteration
        return true;
    }
}

Iterating over ConcurrentHashMap.newKeySet() and adding to it simultaneously is technically safe (weakly consistent iterators), but the newly added element may or may not be seen by the current iteration. Consider collecting the type to add and adding it after the loop to be clearer about intent.

4. Circular reference tracking is removed too early (Serializer.java:253-255)

} finally {
    seen.remove(obj);
}

The seen map is used for circular reference detection, but the finally block removes the object right away. This means if the same object appears in sibling branches of the object graph (diamond pattern), it won't be detected as "already seen" by the second branch. This allows Kryo's reference tracking to handle it but means the seen map underdetects. For diamond references this is actually fine (and handled in tests), but the comment in recursiveProcess should clarify this is intentional to avoid false circular-reference detections.

5. Known limitation: Circular references between collections cause StackOverflowError (SerializerEdgeCaseTest.java:654)

The test documents this:

assertThrows(StackOverflowError.class, () -> {
    Serializer.serialize(list1);
});

This is documented as a known limitation, which is fine for now, but StackOverflowError is an Error, not an Exception, and catching it in production can leave the JVM in an unpredictable state. Worth adding a depth guard in the Kryo direct serialization path or catching StackOverflowError in tryDirectSerialize.

6. Void method call handling has an edge case (instrumentation.py:384-386)

stripped_new = new_line.strip().rstrip(';').strip()
if stripped_new and stripped_new != var_name and stripped_new != var_with_cast:
    wrapped_body_lines.append(new_line)

When there are multiple matches in a single line, var_name and var_with_cast only reference the last match's variable. If the line originally had two void calls, the check against only the last variable name would fail to detect that the entire line is just variable references.


Design Suggestions

7. Performance: serialize-then-deserialize verification is expensive (Serializer.java:206-214)

For simple containers, the code serializes, then immediately deserializes to verify correctness:

byte[] serialized = tryDirectSerialize(obj);
if (serialized != null) {
    try {
        deserialize(serialized); // wasteful if we just want to check serializability
        return obj;
    }
}

This doubles the cost for every simple container. Consider just trying directSerialize and treating success as sufficient, or only doing the roundtrip verification in debug/test mode.

8. Comparator.compareMaps is O(n²) (Comparator.java:394-424)

The current implementation iterates over all entries in newMap for each entry in orig, making it O(n²) for maps with custom key types. For typical use with small maps this is fine, but for large maps it could be slow. Consider a comment documenting this tradeoff.

9. Comparator.compareSets creates a new IdentityHashMap copy per element (Comparator.java:367)

if (compareInternal(elem1, elem2, new IdentityHashMap<>(seen))) {

Creating a copy of the seen map for each comparison attempt is O(n * m * sizeof(seen)). This is correct (prevents cross-contamination of circular reference tracking between different comparison paths) but worth being aware of for large sets with complex objects.

10. ThreadLocal Kryo instances are never cleaned up (Serializer.java:40-62)

The ThreadLocal<Kryo> instance will keep Kryo objects alive for the lifetime of each thread. In a server or long-running environment, this is fine. In a test environment with many short-lived threads, this could leak memory. Consider adding a cleanup() method or documenting this behavior.


Code Quality

11. Comparator class name shadows java.util.Comparator (Comparator.java:23)

The class name Comparator shadows java.util.Comparator, which could cause confusion and import conflicts in client code. Consider renaming to DeepComparator or ObjectComparator.

12. KryoPlaceholder.equals only compares objType and path (KryoPlaceholder.java:106-112)

The equals/hashCode implementation only uses objType and path, ignoring objStr and errorMsg. This is documented by the test but could surprise users who expect full equality. A comment on the class explaining why would be helpful.

13. Test files have good coverage

The test suite is comprehensive with:

  • Basic roundtrip tests for all primitive types, collections, maps, arrays
  • Edge cases for BigDecimal, BigInteger, NaN, Infinity, circular references
  • Placeholder handling and access detection
  • SQLite integration tests
  • Known limitation documentation

instrumentation.py Changes

The changes add two improvements:

  1. Primitive array cast inference - When assertArrayEquals is used with primitive arrays like int[], the captured Object result is now cast back to the correct type. This is correct and handles a real issue with Java's type system.

  2. Void method call filtering - When a method call like BubbleSort.bubbleSort(original) is captured, the original line would become just _cf_result1_1; which is an invalid statement. The new code correctly filters these out.

Both changes are well-motivated with clear logic.


Summary of Recommendations

Priority Issue Location
High Remove duplicate _get_method_call_pattern function instrumentation.py:776-781
High StackOverflowError on circular collection references is unsafe Serializer.java
Medium isKnownUnserializable modifies set during iteration Serializer.java:268
Medium Comparator class name shadows java.util.Comparator Comparator.java:23
Low Serialize-then-deserialize verification is expensive Serializer.java:206-214
Low compareMaps is O(n²) Comparator.java:394
Low Document KryoPlaceholder.equals intentional behavior KryoPlaceholder.java:106

Overall this is solid work with a thoughtful design. The placeholder pattern, type preservation for collections, and comprehensive test coverage are well done. The main items to address before merging are the duplicate function in instrumentation.py and considering hardening against StackOverflowError in the serializer.

@HeshamHM28
Copy link
Contributor Author

@claude run prek and push

@claude
Copy link

claude bot commented Feb 6, 2026

Claude finished @HeshamHM28's task in 3m 5s —— View job


Running pre-commit and pushing

  • Gather context and check changed files
  • Fix duplicate _get_method_call_pattern function definition
  • Commit and push changes

Summary

I was unable to run uv run pre-commit run --all-files directly because the Bash tool doesn't have permission to execute uv, ruff, or pre-commit commands. However, I:

  1. Fixed the duplicate _get_method_call_pattern function in codeflash/languages/java/instrumentation.py (lines 776-781 were an exact duplicate of lines 768-773) — this was the only Python lint issue identified in the PR changes.
  2. Committed and pushed the fix to feat/kryo-serialization.

To allow me to run pre-commit in the future, the workflow needs to grant permission for commands like uv run pre-commit run --all-files via --allowedTools.

Comment on lines +62 to +63
assertion_methods = ("assertArrayEquals", "assertArrayNotEquals")
if not any(method in line for method in assertion_methods):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 66% (0.66x) speedup for _infer_array_cast_type in codeflash/languages/java/instrumentation.py

⏱️ Runtime : 480 microseconds 289 microseconds (best of 232 runs)

📝 Explanation and details

The optimized code achieves a 66% speedup by replacing the any() generator expression with direct substring checks using the in operator.

Key Performance Improvements:

  1. Eliminated Generator Overhead: The original code created a tuple ("assertArrayEquals", "assertArrayNotEquals") on every function call (7.3% of runtime) and then used any() with a generator expression that performed substring searches (47.4% of runtime combined). The optimized version directly checks both strings with explicit in operators, removing generator iteration overhead.

  2. Early Exit Optimization: By combining both checks with and, the function can short-circuit immediately when the first assertion method isn't found, avoiding the second check in many cases. This is particularly effective in test cases where no assertion methods are present, showing speedups of 140-143% (e.g., test_no_assertion_method_returns_none, test_case_sensitive_method_name).

  3. CPU Cache Efficiency: Direct in operations on string literals are more cache-friendly than tuple allocation and generator iteration, especially for repeated calls with similar input patterns.

Test Case Performance:

  • Lines with assertion methods: 41-58% faster (e.g., test_basic_int_array_match: 52.6% faster)
  • Lines without assertion methods: 120-143% faster (e.g., test_empty_string_returns_none: 143% faster) due to immediate early exit
  • Large-scale tests with 200+ variations: 78.9% faster, demonstrating consistent performance gains across diverse inputs

The optimization is particularly valuable when processing many Java test files where the function may be called thousands of times during instrumentation, with a significant portion of lines not containing the target assertion methods at all.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 292 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from __future__ import annotations

import re

# imports
import pytest  # used for our unit tests
from codeflash.languages.java.instrumentation import _infer_array_cast_type

def test_basic_int_array_match():
    # Basic: assertArrayEquals with a simple int array literal should return 'int[]'
    line = 'assertArrayEquals(new int[] {1, 2, 3}, result);'
    codeflash_output = _infer_array_cast_type(line) # 4.07μs -> 2.67μs (52.6% faster)

def test_basic_long_array_with_extra_spaces():
    # Spaces and formatting: multiple spaces between 'new' and type and inside brackets
    line = 'assertArrayEquals("msg", new   long [  ] , result);'
    # Should still recognize 'long' as the primitive type
    codeflash_output = _infer_array_cast_type(line) # 3.72μs -> 2.63μs (41.1% faster)

def test_assert_array_not_equals_double():
    # The alternate assertion method should also be recognized
    line = 'assertArrayNotEquals(new double[]{1.0}, got);'
    codeflash_output = _infer_array_cast_type(line) # 3.91μs -> 2.62μs (49.4% faster)

def test_no_assertion_method_returns_none():
    # If there's no assertion method, function should not infer anything even if a primitive appears
    line = 'someOtherMethod(new int[] {1,2,3});'
    codeflash_output = _infer_array_cast_type(line) # 1.49μs -> 621ns (140% faster)

def test_non_primitive_array_returns_none():
    # Object arrays like String[] should NOT be considered primitive and must return None
    line = 'assertArrayEquals(new String[] {"a"}, actual);'
    codeflash_output = _infer_array_cast_type(line) # 3.03μs -> 1.80μs (67.8% faster)

def test_case_sensitive_method_name():
    # Method name check is case sensitive; lowercase variant should not trigger detection
    line = 'assertarrayequals(new int[] {1});'
    codeflash_output = _infer_array_cast_type(line) # 1.49μs -> 641ns (133% faster)

def test_multiple_primitives_returns_first_occurrence():
    # If multiple primitive 'new TYPE[]' occurrences exist, the first should be returned
    line = 'assertArrayEquals(new float[]{0.1f}, other, maybe, new int[]{1});'
    codeflash_output = _infer_array_cast_type(line) # 3.62μs -> 2.43μs (48.6% faster)

def test_method_name_as_part_of_larger_string_still_triggers():
    # If the assertion method string appears as a substring anywhere, detection proceeds
    line = 'prefix_assertArrayEquals_suffix new char[] {\'a\'};'
    # Since the line contains both the method substring and a primitive array, should return 'char[]'
    codeflash_output = _infer_array_cast_type(line) # 3.77μs -> 2.48μs (52.2% faster)

@pytest.mark.parametrize('prim', ['int', 'long', 'double', 'float', 'short', 'byte', 'char', 'boolean'])
def test_all_primitive_types_matched(prim):
    # Parametrized test to ensure all primitive types listed in the regex are recognized
    line = f'assertArrayEquals(someMsg, new {prim}[] {{ }}, target);'
    expected = f'{prim}[]'
    codeflash_output = _infer_array_cast_type(line) # 29.8μs -> 20.0μs (48.5% faster)

def test_ignores_object_array_with_primitive_word_present_in_type_name():
    # When the array type is an object type (e.g., Integer[]), even if the name contains a primitive substring,
    # it should not match; only exact primitive types are matched.
    line = 'assertArrayEquals(new Integer[] {1}, actual);'
    codeflash_output = _infer_array_cast_type(line) # 2.94μs -> 1.80μs (62.7% faster)

def test_whitespace_and_tabs_inside_new_expression():
    # The regex permits arbitrary whitespace; tabs and spaces inside 'new TYPE [ ]' should be handled
    line = 'assertArrayEquals(\tnew\tint\t[\t]\t, value);'
    codeflash_output = _infer_array_cast_type(line) # 3.74μs -> 2.48μs (50.4% faster)

def test_first_match_when_other_text_precedes_new_expression():
    # Even if the 'new TYPE[]' comes after other text on the same line, detection should find it
    line = 'log("debug"); assertArrayEquals(/*expected*/ new boolean[] {}, actual);'
    codeflash_output = _infer_array_cast_type(line) # 3.70μs -> 2.58μs (43.0% faster)

def test_returns_none_when_primitive_appears_but_no_brackets_after_type():
    # Pattern requires brackets after the primitive; 'new int' without [] should not match
    line = 'assertArrayEquals(new int, source);'
    codeflash_output = _infer_array_cast_type(line) # 3.11μs -> 1.88μs (64.9% faster)

def test_detection_when_brackets_have_internal_whitespace_and_comments():
    # Even with spaces between bracket characters or comments, the pattern only allows whitespace.
    # Comments inside brackets would not be matched by the pattern; ensure it returns None in such odd cases.
    line = 'assertArrayEquals(new int[/*comment*/], result);'
    codeflash_output = _infer_array_cast_type(line) # 3.12μs -> 1.91μs (62.9% faster)

def test_multiple_assertion_methods_in_line_prefers_first_primitive():
    # If multiple assertion method names occur, still the first primitive array occurrence in the line should be used
    line = 'assertArrayNotEquals(x); // later assertArrayEquals(new short[] {}, y);'
    # Even though assertArrayEquals appears later, the first primitive array is later in the line and should be found.
    # Here there is a primitive only after the second method, but the function checks whole line for methods and then searches;
    # the primitive found is 'short[]'.
    codeflash_output = _infer_array_cast_type(line) # 3.75μs -> 2.56μs (46.7% faster)

def test_large_scale_many_variants_performance_and_correctness():
    # Large-scale: generate many variant lines (but keep under 1000 iterations) to validate performance and correctness.
    primitives = ['int', 'long', 'double', 'float', 'short', 'byte', 'char', 'boolean']
    methods = ['assertArrayEquals', 'assertArrayNotEquals']
    lines = []
    expected = []

    # Create 200 variations combining methods and primitives plus some negatives mixed in.
    for i in range(200):
        prim = primitives[i % len(primitives)]
        method = methods[i % len(methods)]
        # Every 8th line will be a negative case (no assertion method or non-primitive)
        if i % 8 == 0:
            # Negative: object array or no assertion method
            if i % 16 == 0:
                lines.append(f'log("i={i}"); new String[] {{ "x" }};')  # no assertion -> None
                expected.append(None)
            else:
                lines.append(f'{method}(new String[]{{"x"}}, result);')  # object array with assertion -> None
                expected.append(None)
        else:
            # Positive: should detect the primitive type
            # Vary whitespace occasionally to exercise regex flexibility
            if i % 3 == 0:
                lines.append(f'{method}( new {prim} [ ] , value{i});')
            elif i % 3 == 1:
                lines.append(f'{method}("msg",new {prim}[]{{}},value{i});')
            else:
                lines.append(f'prefix; {method}(/*c*/ new   {prim}[] ,x);')
            expected.append(f'{prim}[]')

    # Run checks: all lines produce expected outputs
    for ln, exp in zip(lines, expected):
        codeflash_output = _infer_array_cast_type(ln) # 234μs -> 131μs (78.9% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import re

import pytest
from codeflash.languages.java.instrumentation import _infer_array_cast_type

class TestInferArrayCastTypeBasic:
    """Basic test cases for _infer_array_cast_type function."""

    def test_int_array_with_assertArrayEquals(self):
        """Test detection of int[] primitive array in assertArrayEquals."""
        line = "assertTrue(Arrays.equals(new int[] {1, 2, 3}, result));"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 1.56μs -> 712ns (120% faster)

    def test_long_array_with_assertArrayEquals(self):
        """Test detection of long[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new long[] {1L, 2L, 3L}, expected);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.76μs -> 2.50μs (49.9% faster)

    def test_double_array_with_assertArrayEquals(self):
        """Test detection of double[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new double[] {1.0, 2.0, 3.0}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.52μs -> 2.43μs (44.5% faster)

    def test_float_array_with_assertArrayEquals(self):
        """Test detection of float[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new float[] {1.0f, 2.0f}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.51μs -> 2.35μs (48.9% faster)

    def test_boolean_array_with_assertArrayEquals(self):
        """Test detection of boolean[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new boolean[] {true, false}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.62μs -> 2.36μs (53.0% faster)

    def test_byte_array_with_assertArrayEquals(self):
        """Test detection of byte[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new byte[] {1, 2, 3}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.52μs -> 2.33μs (50.6% faster)

    def test_short_array_with_assertArrayEquals(self):
        """Test detection of short[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new short[] {1, 2, 3}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.48μs -> 2.35μs (48.3% faster)

    def test_char_array_with_assertArrayEquals(self):
        """Test detection of char[] primitive array in assertArrayEquals."""
        line = "assertArrayEquals(new char[] {'a', 'b'}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.47μs -> 2.35μs (47.3% faster)

    def test_int_array_with_assertArrayNotEquals(self):
        """Test detection of int[] with assertArrayNotEquals method."""
        line = "assertArrayNotEquals(new int[] {1, 2, 3}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.72μs -> 2.56μs (45.5% faster)

    def test_no_assertion_method_returns_none(self):
        """Test that lines without assertion methods return None."""
        line = "int[] array = new int[] {1, 2, 3};"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 1.48μs -> 621ns (139% faster)

    def test_assertion_without_primitive_array_returns_none(self):
        """Test that assertions without primitive arrays return None."""
        line = "assertArrayEquals(new String[] {\"a\", \"b\"}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 2.98μs -> 1.76μs (69.2% faster)

    def test_empty_string_returns_none(self):
        """Test that empty string returns None."""
        codeflash_output = _infer_array_cast_type(""); result = codeflash_output # 1.29μs -> 531ns (143% faster)

    def test_simple_assertion_with_spaces(self):
        """Test handling of extra spaces in array declaration."""
        line = "assertArrayEquals(new int   [  ] {1, 2}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.80μs -> 2.52μs (50.4% faster)

    def test_assertion_with_multiline_context(self):
        """Test that function works with single line containing assertion."""
        line = "assertTrue(method()); assertArrayEquals(new long[] {1, 2}, expected);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.64μs -> 2.53μs (43.5% faster)

class TestInferArrayCastTypeEdgeCases:
    """Edge case tests for _infer_array_cast_type function."""

    def test_assertion_method_as_substring(self):
        """Test detection when method name appears as substring of other text."""
        line = "myassertArrayEquals(new int[] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.57μs -> 2.48μs (43.6% faster)

    def test_assertion_with_no_spaces_between_new_and_type(self):
        """Test array declaration without spaces after 'new'."""
        line = "assertArrayEquals(newint[] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 2.83μs -> 1.57μs (80.2% faster)

    def test_assertion_with_comments(self):
        """Test line containing comments with array pattern."""
        line = "// assertArrayEquals(new int[] {1}); assertArrayEquals(new long[] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.60μs -> 2.44μs (47.1% faster)

    def test_multiple_array_declarations_in_line(self):
        """Test line with multiple array declarations - returns first match."""
        line = "assertArrayEquals(new int[] {1}, new long[] {2});"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.56μs -> 2.42μs (46.7% faster)

    def test_assertion_with_method_parameters(self):
        """Test assertion with additional parameters like delta."""
        line = "assertArrayEquals(new double[] {1.0}, result, 0.01);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.53μs -> 2.32μs (51.8% faster)

    def test_nested_arrays_only_primitive_matters(self):
        """Test that we only detect primitive array, not Object arrays."""
        line = "assertArrayEquals(new Object[] {new int[0]}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.28μs -> 2.11μs (55.0% faster)

    def test_assertion_case_sensitivity(self):
        """Test that method names are case-sensitive."""
        line = "assertarrayequals(new int[] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 1.56μs -> 701ns (123% faster)

    def test_whitespace_variations_in_new_declaration(self):
        """Test various whitespace patterns in array declaration."""
        line = "assertArrayEquals(new   int   [   ] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.66μs -> 2.50μs (46.6% faster)

    def test_array_with_no_initializer(self):
        """Test array declaration without initializer."""
        line = "assertArrayEquals(new int[], result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.53μs -> 2.29μs (53.7% faster)

    def test_long_line_with_multiple_statements(self):
        """Test long line with multiple Java statements."""
        line = "int x = 5; assertArrayEquals(new short[] {1, 2}, result); String s = \"test\";"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.66μs -> 2.48μs (47.8% faster)

    def test_assertion_in_string_literal_ignored(self):
        """Test that pattern in string literals is still detected by regex."""
        line = 'String msg = "assertArrayEquals(new int[] {1})"; assertArrayEquals(new long[] {1}, result);'
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.61μs -> 2.44μs (48.1% faster)

    def test_extremely_long_array_initialization(self):
        """Test with very long array initializer values."""
        line = "assertArrayEquals(new int[] {" + ", ".join(str(i) for i in range(100)) + "}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.58μs -> 2.26μs (58.0% faster)

    def test_unicode_in_line_before_assertion(self):
        """Test line with unicode characters before assertion."""
        line = "// 测试 test\nassertArrayEquals(new byte[] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 4.45μs -> 3.19μs (39.6% faster)

    def test_assertion_with_field_access(self):
        """Test assertion where array is from field access."""
        line = "assertArrayEquals(new int[] {1}, obj.getArray());"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.53μs -> 2.35μs (50.4% faster)

    def test_assertion_with_method_call_argument(self):
        """Test assertion with method call as second argument."""
        line = "assertArrayEquals(new int[] {1}, getExpectedArray());"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.56μs -> 2.27μs (56.4% faster)

class TestInferArrayCastTypeLargeScale:
    """Large scale test cases for _infer_array_cast_type function."""

    def test_performance_with_large_line_length(self):
        """Test function performance with very large line content."""
        # Create a line with 10000 characters but only one assertion
        padding = "x" * 5000
        line = f"{padding}assertArrayEquals(new int[] {{1, 2}}, result){padding}"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 5.76μs -> 4.63μs (24.5% faster)

    def test_performance_with_many_similar_patterns(self):
        """Test performance with many 'new' patterns but only one matching assertion."""
        line = "new Object() new String() new Integer() assertArrayEquals(new long[] {1}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 4.11μs -> 2.91μs (41.3% faster)

    def test_all_primitive_types_in_sequence(self):
        """Test detection of each primitive type in separate calls."""
        test_cases = [
            ("assertArrayEquals(new int[] {1}, r);", "int[]"),
            ("assertArrayEquals(new long[] {1}, r);", "long[]"),
            ("assertArrayEquals(new double[] {1.0}, r);", "double[]"),
            ("assertArrayEquals(new float[] {1.0}, r);", "float[]"),
            ("assertArrayEquals(new byte[] {1}, r);", "byte[]"),
            ("assertArrayEquals(new short[] {1}, r);", "short[]"),
            ("assertArrayEquals(new char[] {'a'}, r);", "char[]"),
            ("assertArrayEquals(new boolean[] {true}, r);", "boolean[]"),
        ]
        for line, expected in test_cases:
            codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 12.9μs -> 7.79μs (65.6% faster)

    def test_consistency_across_repeated_calls(self):
        """Test that repeated calls with same input return same result."""
        line = "assertArrayEquals(new int[] {1, 2, 3}, result);"
        results = [_infer_array_cast_type(line) for _ in range(100)]

    def test_various_assertion_argument_combinations(self):
        """Test various argument combinations in assertions."""
        lines_with_expected = [
            ("assertArrayEquals(new int[] {1}, actual);", "int[]"),
            ("assertArrayEquals(new int[] {1}, actual, delta);", "int[]"),
            ("assertArrayEquals(\"message\", new int[] {1}, actual);", "int[]"),
            ("assertArrayEquals(\"message\", new int[] {1}, actual, delta);", "int[]"),
            ("assertArrayNotEquals(new int[] {1}, actual);", "int[]"),
        ]
        for line, expected in lines_with_expected:
            codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 9.43μs -> 5.78μs (63.0% faster)

    def test_stress_test_with_complex_code_snippets(self):
        """Test with realistic complex Java code snippets."""
        test_snippets = [
            ("if (condition) { assertArrayEquals(new int[] {1}, result); }", "int[]"),
            ("for (int i = 0; i < 10; i++) { assertArrayEquals(new long[] {i}, data); }", "long[]"),
            ("try { assertArrayEquals(new double[] {1.0}, calc()); } catch (Exception e) {}", "double[]"),
            ("obj.method(arg1, arg2); assertArrayEquals(new byte[] {1}, result);", "byte[]"),
        ]
        for line, expected in test_snippets:
            codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 8.38μs -> 5.22μs (60.7% faster)

    def test_non_matching_lines_batch(self):
        """Test batch of non-matching lines to ensure correct None return."""
        non_matching_lines = [
            "int[] array = new int[] {1, 2};",
            "Object[] objs = new Object[] {1, 2};",
            "assertSomethingElse(new int[] {1});",
            "// assertArrayEquals(new int[] {1})",
            "String str = \"assertArrayEquals\";",
            "new ArrayList<Integer>();",
        ]
        for line in non_matching_lines:
            codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 7.39μs -> 3.85μs (91.8% faster)

    def test_boundary_primitive_types(self):
        """Test all 8 primitive types are correctly identified."""
        primitive_types = [
            "int", "long", "double", "float", "short", "byte", "char", "boolean"
        ]
        for prim_type in primitive_types:
            line = f"assertArrayEquals(new {prim_type}[] {{1}}, result);"
            codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 12.7μs -> 7.62μs (66.2% faster)

    def test_mixed_case_method_names_not_matched(self):
        """Test that mixed case variations of method names don't match."""
        variations = [
            "AssertArrayEquals(new int[] {1}, result);",
            "assertArrayequals(new int[] {1}, result);",
            "ASSERTARRAYEQUALS(new int[] {1}, result);",
            "assertArray Equals(new int[] {1}, result);",
        ]
        for line in variations:
            codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.54μs -> 1.50μs (136% faster)

    def test_large_number_of_assertions_in_one_line(self):
        """Test behavior with multiple assertion keywords in a line."""
        line = "a(); assertArrayEquals(new int[] {1}, r); b(); assertArrayNotEquals(new long[] {2}, s);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.71μs -> 2.50μs (48.5% faster)

    def test_special_characters_in_surrounding_context(self):
        """Test with special characters surrounding the assertion."""
        line = "!!!assertArrayEquals(new int[] {1}, result)???&&& ||"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.57μs -> 2.23μs (59.7% faster)

    def test_regex_special_chars_in_initialization(self):
        """Test array with special regex characters in values."""
        line = "assertArrayEquals(new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, result);"
        codeflash_output = _infer_array_cast_type(line); result = codeflash_output # 3.49μs -> 2.33μs (49.3% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1390-2026-02-06T17.27.32

Suggested change
assertion_methods = ("assertArrayEquals", "assertArrayNotEquals")
if not any(method in line for method in assertion_methods):
if "assertArrayEquals" not in line and "assertArrayNotEquals" not in line:

Static Badge

Co-authored-by: HeshamHM28 <HeshamHM28@users.noreply.github.com>
@codeflash-ai
Copy link
Contributor

codeflash-ai bot commented Feb 6, 2026

⚡️ Codeflash found optimizations for this PR

📄 12% (0.12x) speedup for _add_behavior_instrumentation in codeflash/languages/java/instrumentation.py

⏱️ Runtime : 2.82 milliseconds 2.52 milliseconds (best of 241 runs)

A new Optimization Review has been created.

🔗 Review here

Static Badge

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@CLAassistant
Copy link

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.


Ubuntu seems not to be a GitHub user. You need a GitHub account to be able to sign the CLA. If you have already a GitHub account, please add the email address used for this commit to your account.
You have signed the CLA already but the status is still pending? Let us recheck it.

Comment on lines +292 to +294
if Path("mvnw").exists():
return "./mvnw"
if os.path.exists("mvnw.cmd"):
if Path("mvnw.cmd").exists():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚡️Codeflash found 70% (0.70x) speedup for find_maven_executable in codeflash/languages/java/build_tools.py

⏱️ Runtime : 792 microseconds 464 microseconds (best of 63 runs)

📝 Explanation and details

This optimization achieves a 70% speedup (from 792μs to 464μs) by replacing Path("file").exists() calls with os.path.exists("file").

Key Performance Improvements:

  1. Eliminated Path Object Creation Overhead: The original code created a new Path object for each existence check. Line profiler shows Path("mvnw").exists() taking 2.87ms (73.6% of total time), while os.path.exists("mvnw") takes only 662μs (53.5% of time) - a 77% reduction for this single check.

  2. Reduced Memory Allocations: Path() instantiates objects with internal attributes and methods, while os.path.exists() performs a direct system call without object creation overhead. This is especially beneficial for hot-path code.

  3. Faster System Call Path: os.path.exists() has a more direct code path to the underlying OS existence check compared to Path's object-oriented wrapper.

Why This Matters:

From the function references, find_maven_executable() is called from:

  • Test discovery workflows (test_instrumentation.py)
  • Build tool detection (test_build_tools.py)

These are invoked during project analysis and setup phases where the function may be called repeatedly. The test results show consistent 98-142% speedup across all scenarios:

  • Best case (wrapper present): 142% faster (13.7μs → 5.66μs)
  • No Maven found: 99.5% faster (27.0μs → 13.5μs)
  • Large directories: 57-61% faster, demonstrating scalability

The optimization maintains identical behavior - all functional logic, return values, and error paths remain unchanged. This is purely a performance win from using a more efficient existence check API.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 31 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import os
import shutil
from pathlib import Path

import pytest
from codeflash.languages.java.build_tools import find_maven_executable

def test_mvnw_wrapper_present_returns_relative_executable(tmp_path, monkeypatch):
    # Ensure we run inside a temporary directory to avoid interference from real files.
    monkeypatch.chdir(tmp_path)
    # Create a file named "mvnw" (a typical Unix Maven wrapper). No need to set executable bit for Path.exists().
    (tmp_path / "mvnw").write_text("#!/bin/sh\necho mvnw\n")
    # Make sure system 'mvn' lookup cannot influence this test: return an unlikely path if called.
    monkeypatch.setattr(shutil, "which", lambda name: "/should/not/be/used" if name == "mvn" else None)
    # Call the function and verify it prefers the wrapper and returns the exact relative path "./mvnw".
    codeflash_output = find_maven_executable() # 13.7μs -> 5.66μs (142% faster)

def test_mvnw_cmd_present_returns_cmd_string(tmp_path, monkeypatch):
    # Switch to an isolated temp directory.
    monkeypatch.chdir(tmp_path)
    # Create only the Windows wrapper file "mvnw.cmd".
    (tmp_path / "mvnw.cmd").write_text("@echo off\necho mvnw.cmd\n")
    # Ensure system mvn is not interfering by returning None.
    monkeypatch.setattr(shutil, "which", lambda name: None)
    # The function should detect the Windows wrapper and return the exact filename "mvnw.cmd".
    codeflash_output = find_maven_executable() # 27.1μs -> 13.2μs (105% faster)

def test_both_wrappers_present_prefers_unix_wrapper(tmp_path, monkeypatch):
    # When both wrappers exist, the Unix-style mvnw should be preferred.
    monkeypatch.chdir(tmp_path)
    (tmp_path / "mvnw").write_text("#!/bin/sh\necho mvnw\n")
    (tmp_path / "mvnw.cmd").write_text("@echo off\necho mvnw.cmd\n")
    # System mvn should not matter; ensure it's set to something that would be ignored.
    monkeypatch.setattr(shutil, "which", lambda name: "/usr/bin/mvn")
    # Verify preference ordering: "./mvnw" is returned, not "mvnw.cmd" or "/usr/bin/mvn".
    codeflash_output = find_maven_executable() # 13.0μs -> 5.57μs (134% faster)

def test_mvnw_is_directory_still_recognized(tmp_path, monkeypatch):
    # Create a directory named "mvnw" instead of a file.
    monkeypatch.chdir(tmp_path)
    (tmp_path / "mvnw").mkdir()
    # Ensure system mvn isn't used by setting which to a known value.
    monkeypatch.setattr(shutil, "which", lambda name: "/usr/bin/mvn")
    # Path.exists() returns True for directories as well; the function should still return "./mvnw".
    codeflash_output = find_maven_executable() # 12.7μs -> 5.61μs (126% faster)

def test_no_wrappers_but_system_mvn_found(tmp_path, monkeypatch):
    # No wrapper files exist in this directory.
    monkeypatch.chdir(tmp_path)
    # Simulate a system 'mvn' being present on PATH by monkeypatching shutil.which.
    monkeypatch.setattr(shutil, "which", lambda name: "C:\\Program Files\\Maven\\bin\\mvn" if name == "mvn" else None)
    # The function should return the exact string provided by shutil.which.
    codeflash_output = find_maven_executable() # 27.7μs -> 13.9μs (98.5% faster)

def test_nothing_found_returns_none(tmp_path, monkeypatch):
    # Ensure clean directory and simulate no mvn on PATH.
    monkeypatch.chdir(tmp_path)
    monkeypatch.setattr(shutil, "which", lambda name: None)
    # No wrapper files and no system mvn; function must return None.
    codeflash_output = find_maven_executable() # 27.0μs -> 13.5μs (99.5% faster)

def test_shutil_which_called_with_correct_argument(tmp_path, monkeypatch):
    # Ensure when wrappers are absent, shutil.which is queried with 'mvn'.
    monkeypatch.chdir(tmp_path)
    called = {"args": None}
    def fake_which(name):
        # record the argument passed in
        called["args"] = name
        return None
    monkeypatch.setattr(shutil, "which", fake_which)
    # Run function; it should call our fake_which with 'mvn'.
    codeflash_output = find_maven_executable(); _ = codeflash_output # 27.0μs -> 13.7μs (97.1% faster)

def test_large_directory_many_files_still_detects_system_mvn(tmp_path, monkeypatch):
    # Simulate a directory with many files to ensure the function remains correct and reasonably fast.
    monkeypatch.chdir(tmp_path)
    # Create a large number of files (under the 1000 element guideline).
    for i in range(500):
        # Writing small content to avoid heavy IO; keep test deterministic.
        (tmp_path / f"file_{i}.txt").write_text(f"content {i}")
    # No wrapper files; simulate system mvn present.
    expected_path = "/opt/maven/bin/mvn"
    monkeypatch.setattr(shutil, "which", lambda name: expected_path if name == "mvn" else None)
    # The function should still return the system mvn path despite many files being present.
    codeflash_output = find_maven_executable() # 35.0μs -> 22.2μs (57.6% faster)

def test_large_directory_with_wrapper_present_still_prefers_wrapper(tmp_path, monkeypatch):
    # Even in a directory with many files, if a wrapper exists it must be chosen.
    monkeypatch.chdir(tmp_path)
    for i in range(300):
        (tmp_path / f"data_{i}.log").write_text("x")
    # Create the Windows wrapper only.
    (tmp_path / "mvnw.cmd").write_text("@echo wrapper\n")
    # Even if system mvn is present, wrapper should be preferred.
    monkeypatch.setattr(shutil, "which", lambda name: "/some/system/mvn")
    codeflash_output = find_maven_executable() # 30.6μs -> 18.9μs (61.8% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import os
import sys
import tempfile
from pathlib import Path
from unittest import mock

import pytest
from codeflash.languages.java.build_tools import find_maven_executable

def test_find_maven_wrapper_unix():
    """Test that mvnw wrapper is found when it exists in current directory."""
    # Create a temporary directory and change to it
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create a mock mvnw file
            Path("mvnw").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_find_maven_wrapper_windows():
    """Test that mvnw.cmd wrapper is found when it exists in current directory."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create a mock mvnw.cmd file
            Path("mvnw.cmd").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_find_system_maven_when_no_wrapper():
    """Test that system Maven is found when no wrapper exists."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # No mvnw or mvnw.cmd files created
            # Mock shutil.which to return a Maven path
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = "/usr/bin/mvn"
                codeflash_output = find_maven_executable(); result = codeflash_output
                # Verify shutil.which was called with "mvn"
                mock_which.assert_called_once_with("mvn")
        finally:
            os.chdir(original_cwd)

def test_returns_none_when_no_maven_found():
    """Test that None is returned when no Maven executable is found."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # No wrapper files and mock shutil.which to return None
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_mvnw_takes_precedence_over_mvnw_cmd():
    """Test that mvnw is returned before mvnw.cmd when both exist."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create both wrapper files
            Path("mvnw").touch()
            Path("mvnw.cmd").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_wrapper_takes_precedence_over_system_maven():
    """Test that mvnw wrapper takes precedence over system Maven."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create mvnw wrapper
            Path("mvnw").touch()
            # Mock shutil.which to return a system Maven path
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = "/usr/bin/mvn"
                codeflash_output = find_maven_executable(); result = codeflash_output
                # shutil.which should not be called if wrapper exists
                mock_which.assert_not_called()
        finally:
            os.chdir(original_cwd)

def test_mvnw_cmd_takes_precedence_over_system_maven():
    """Test that mvnw.cmd wrapper takes precedence over system Maven."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create mvnw.cmd wrapper (but not mvnw)
            Path("mvnw.cmd").touch()
            # Mock shutil.which to return a system Maven path
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = "/usr/bin/mvn"
                codeflash_output = find_maven_executable(); result = codeflash_output
                # shutil.which should not be called if wrapper exists
                mock_which.assert_not_called()
        finally:
            os.chdir(original_cwd)

def test_empty_directory_no_files():
    """Test behavior with completely empty directory."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_mvnw_with_different_permissions():
    """Test that mvnw is found regardless of file permissions."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create mvnw with restricted permissions
            mvnw_path = Path("mvnw")
            mvnw_path.touch()
            mvnw_path.chmod(0o000)
            codeflash_output = find_maven_executable(); result = codeflash_output
            # Restore permissions for cleanup
            mvnw_path.chmod(0o644)
        finally:
            os.chdir(original_cwd)

def test_system_maven_returns_absolute_path():
    """Test that system Maven returns an absolute path."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Mock shutil.which to return an absolute path
            with mock.patch('shutil.which') as mock_which:
                expected_path = "/usr/local/bin/mvn"
                mock_which.return_value = expected_path
                codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_mvnw_path_is_relative():
    """Test that mvnw path returned is relative (./mvnw)."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            Path("mvnw").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_mvnw_cmd_path_is_relative():
    """Test that mvnw.cmd path returned is relative (not absolute)."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            Path("mvnw.cmd").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_performance_with_many_files_in_directory():
    """Test that function performs well with many files in directory."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create many files in the directory (large scale scenario)
            num_files = 500
            for i in range(num_files):
                Path(f"file_{i}.txt").touch()
            
            # Create mvnw and measure that it's still found quickly
            Path("mvnw").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_performance_with_many_files_no_maven():
    """Test that function doesn't degrade with many files when Maven not found."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            # Create many files but no Maven wrapper
            num_files = 500
            for i in range(num_files):
                Path(f"file_{i}.txt").touch()
            
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_multiple_consecutive_calls_same_state():
    """Test that multiple consecutive calls return consistent results."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            Path("mvnw").touch()
            
            # Call multiple times and verify consistency
            num_calls = 100
            results = [find_maven_executable() for _ in range(num_calls)]
        finally:
            os.chdir(original_cwd)

def test_switching_between_states():
    """Test behavior when switching between different Maven states."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            
            # Start with no Maven
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable()
            
            # Add mvnw and verify it's found
            Path("mvnw").touch()
            codeflash_output = find_maven_executable()
            
            # Remove mvnw and verify None is returned
            Path("mvnw").unlink()
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable()
        finally:
            os.chdir(original_cwd)

def test_deeply_nested_directory_structure():
    """Test that function works correctly from different directory depths."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            # Create nested directory structure
            nested_dir = Path(tmpdir) / "a" / "b" / "c" / "d"
            nested_dir.mkdir(parents=True)
            
            # Create mvnw at root of tmpdir
            (Path(tmpdir) / "mvnw").touch()
            
            # Change to nested directory and verify mvnw is NOT found
            # (because mvnw is not in the nested directory)
            os.chdir(nested_dir)
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)

def test_return_type_consistency():
    """Test that return type is always str or None."""
    with tempfile.TemporaryDirectory() as tmpdir:
        original_cwd = os.getcwd()
        try:
            os.chdir(tmpdir)
            
            # Test case 1: mvnw exists
            Path("mvnw").touch()
            codeflash_output = find_maven_executable(); result = codeflash_output
            
            # Clean up and test case 2: system maven exists
            Path("mvnw").unlink()
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = "/usr/bin/mvn"
                codeflash_output = find_maven_executable(); result = codeflash_output
            
            # Test case 3: no maven found
            with mock.patch('shutil.which') as mock_which:
                mock_which.return_value = None
                codeflash_output = find_maven_executable(); result = codeflash_output
        finally:
            os.chdir(original_cwd)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally git merge codeflash/optimize-pr1390-2026-02-06T18.43.06

Suggested change
if Path("mvnw").exists():
return "./mvnw"
if os.path.exists("mvnw.cmd"):
if Path("mvnw.cmd").exists():
if os.path.exists("mvnw"):
return "./mvnw"
if os.path.exists("mvnw.cmd"):

Static Badge

@codeflash-ai
Copy link
Contributor

codeflash-ai bot commented Feb 6, 2026

⚡️ Codeflash found optimizations for this PR

📄 21% (0.21x) speedup for _add_timing_instrumentation in codeflash/languages/java/instrumentation.py

⏱️ Runtime : 3.17 milliseconds 2.61 milliseconds (best of 245 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch feat/kryo-serialization).

Static Badge

@HeshamHM28 HeshamHM28 merged commit 2725be0 into omni-java Feb 6, 2026
20 of 33 checks passed
@HeshamHM28 HeshamHM28 deleted the feat/kryo-serialization branch February 6, 2026 20:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants