-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_classifier.py
More file actions
93 lines (77 loc) · 4.57 KB
/
test_classifier.py
File metadata and controls
93 lines (77 loc) · 4.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Quick test for semantic task classifier"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.task_classifier import SemanticTaskClassifier
def main():
    """Run the semantic task classifier against labelled statements and report.

    Every ``(statement, expected_task_type)`` pair below is passed to
    ``SemanticTaskClassifier.classify``; the ``task_type`` of the result is
    compared with the expected label. A PASS/FAIL report (including the
    confidence reported by the classifier) is printed to stdout and also
    written to ``test_results.txt`` in the directory of this script.
    """
    classifier = SemanticTaskClassifier()

    tests = [
        # Basic tests
        ("Sort an array of integers in ascending order", "sorting"),
        ("Find element 42 in this array", "array_search"),
        ("Find shortest path from A to B in a graph", "graph_search"),
        ("Find pattern 'hello' in text document", "string_search"),
        ("Search for key 15 in a binary search tree", "tree_search"),
        ("I need to arrange these numbers from smallest to largest", "sorting"),
        ("Check if value exists in the sorted list using binary search", "array_search"),
        ("Navigate from node 1 to node 5 using Dijkstra", "graph_search"),
        # Complex/verbose problem statements (embedded newlines are intentional)
        (
            "I have a large dataset of customer records stored in an array. Each record\n"
            "contains a customer ID. I need to quickly determine if a customer with a\n"
            "specific ID exists in our database. The array is already sorted by ID.",
            "array_search",
        ),
        (
            "Given a map of cities connected by roads with different distances, I want\n"
            "to find the most efficient route from my starting city to my destination.\n"
            "Some roads are one-way and there are no negative distances.",
            "graph_search",
        ),
        (
            "I'm building a search engine and need to find all documents that contain\n"
            "certain keywords. Given a query with multiple words, find which documents\n"
            "contain the pattern. The documents are plain text files.",
            "string_search",
        ),
        (
            "We have inventory data that needs to be displayed on our website. The data\n"
            "should be shown in order from highest price to lowest price. Please help me\n"
            "arrange the items appropriately.",
            "sorting",
        ),
        (
            "I have a hierarchical organization structure stored as a BST where each node\n"
            "represents an employee. I need to find if a specific employee ID exists in\n"
            "our organization tree.",
            "tree_search",
        ),
        # Edge cases
        ("quicksort algorithm implementation", "sorting"),
        ("binary search implementation for sorted array", "array_search"),
        ("Dijkstra's algorithm for weighted graphs", "graph_search"),
        ("trie data structure for autocomplete", "string_search"),
        ("AVL tree operations", "tree_search"),
        # LeetCode-style indirect problem statements
        (
            "Given an array of integers, return indices of the two numbers "
            "that add up to a specific target.",
            "array_search",
        ),
        ("Merge k sorted linked lists and return it as one sorted list.", "sorting"),
        ("Given a 2d grid map of '1's and '0's, count the number of islands.", "graph_search"),
        ("Given a binary tree, determine if it is a valid binary search tree.", "tree_search"),
        ("There are courses with prerequisites. Can all courses be completed?", "graph_search"),
        ("Given a 2D board and a word, find if the word exists in the grid.", "graph_search"),
        ("Given a binary tree, find its maximum depth.", "tree_search"),
        ("Find the kth largest element in an unsorted array.", "sorting"),
        ("Given a string, find the longest substring without repeating characters.", "string_search"),
        ("Given an array of integers, find if the array contains any duplicates.", "array_search"),
        ("Invert a binary tree.", "tree_search"),
        ("Given an array of strings, group anagrams together.", "string_search"),
        ("Given a collection of intervals, merge all overlapping intervals.", "sorting"),
    ]

    preview_width = 40
    output_lines = [
        "Testing Semantic Task Classifier",
        "=" * 50,
    ]
    passed = 0

    for statement, expected in tests:
        result = classifier.classify(statement)
        # Evaluate the comparison once; it drives both the label and the tally.
        matched = result.task_type == expected
        if matched:
            passed += 1
        status = "PASS" if matched else "FAIL"

        # Only add an ellipsis when the statement was actually cut short, so
        # short statements are not shown as if they had been truncated.
        preview = statement[:preview_width]
        if len(statement) > preview_width:
            preview += "..."

        output_lines.append(f"{status}: '{preview}'")
        output_lines.append(f" Expected: {expected}, Got: {result.task_type}")
        output_lines.append(f" Confidence: {result.confidence:.2f}")
        output_lines.append("")

    output_lines.append(f"Result: {passed}/{len(tests)} tests passed")

    # Print and save to file
    output = "\n".join(output_lines)
    print(output)

    # Explicit encoding keeps the report identical across platforms instead of
    # depending on the locale's default encoding.
    results_path = os.path.join(os.path.dirname(__file__), "test_results.txt")
    with open(results_path, "w", encoding="utf-8") as f:
        f.write(output)
# Run the classifier checks only when this file is executed as a script,
# so importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()