#!/usr/bin/env python3
"""
Parallel tool execution optimization for ALS Research Agent
This module replaces sequential tool execution with parallel execution
to reduce response time by ~60-70% for multi-tool queries.
"""

import asyncio
import logging
import re
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)


async def execute_single_tool(
    tool_call: Dict,
    call_mcp_tool_func,
    index: int
) -> Tuple[int, str, Dict]:
    """
    Execute a single tool call asynchronously.

    Args:
        tool_call: Tool-call dict with "name", "input" and "id" keys.
        call_mcp_tool_func: Async callable invoked as
            ``await call_mcp_tool_func(name, args)``. Empty-result detection
            only applies when it returns a str.
        index: Position of this call in the original batch, echoed back so
            the caller can restore ordering after parallel execution.

    Returns:
        (index, progress_text, result_dict) where result_dict is a
        tool_result content block keyed by the original tool_use_id.
        Exceptions are caught and converted into an error result rather
        than propagated.
    """
    tool_name = tool_call["name"]
    tool_args = tool_call["input"]

    # Human-readable progress label: "server → tool" plus a truncated
    # preview of the search query/condition, if one was supplied.
    tool_display = tool_name.replace('__', ' → ')
    search_info = ""
    for key in ("query", "condition"):
        if key in tool_args:
            text = tool_args[key]
            search_info = f" `{text[:50]}{'...' if len(text) > 50 else ''}`"
            break

    try:
        # get_running_loop() is the non-deprecated way to reach the loop
        # clock from inside a coroutine.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        tool_result = await call_mcp_tool_func(tool_name, tool_args)
        elapsed = loop.time() - start_time

        logger.info("Tool %s completed in %.2fs", tool_name, elapsed)

        # Detect empty result sets so we can surface a clear indicator and
        # trigger self-correction.
        has_results = True
        results_count = 0

        if isinstance(tool_result, str):
            result_lower = tool_result.lower()

            # Parse an explicit "found N papers/trials/..." count if present.
            count_matches = re.findall(
                r'found (\d+) (?:papers?|trials?|preprints?|results?)',
                result_lower,
            )
            if count_matches:
                results_count = int(count_matches[0])

            # BUG FIX: only treat the result as empty when a no-results
            # phrase appears, or an explicitly parsed count is zero.  The
            # previous `or results_count == 0` flagged every result that
            # merely lacked a "found N ..." sentence as "no results".
            no_result_phrases = (
                "no results found", "0 results", "no papers found",
                "no trials found", "no preprints found", "not found",
                "zero results", "no matches",
            )
            if any(p in result_lower for p in no_result_phrases) or (
                count_matches and results_count == 0
            ):
                has_results = False

        # Clear success/failure indicator for the progress stream.
        if has_results:
            if results_count > 0:
                progress_text = f"\n✅ **Found {results_count} results:** {tool_display}{search_info}"
            else:
                progress_text = f"\n✅ **Success:** {tool_display}{search_info}"
        else:
            progress_text = f"\n⚠️ **No results:** {tool_display}{search_info} - will try alternatives"

        # Surface slow calls in the progress line.
        if elapsed > 5:
            progress_text += f" (took {elapsed:.1f}s)"

        if not has_results:
            # Append a self-correction hint so the model retries with a
            # broader or alternative query instead of giving up.
            tool_result += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
            tool_result += "1. Broadening search terms (remove qualifiers)\n"
            tool_result += "2. Using alternative terminology or synonyms\n"
            tool_result += "3. Searching related concepts\n"
            tool_result += "4. Checking for typos in search terms"

        result_dict = {
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": tool_result
        }

        return index, progress_text, result_dict

    except Exception as e:
        logger.error("Error executing tool %s: %s", tool_name, e)

        # Report the failure but keep the pipeline alive: the error text is
        # returned as the tool result so the model can react to it.
        progress_text = f"\n❌ **Failed:** {tool_display}{search_info} - {str(e)[:50]}"

        error_result = {
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": f"Error executing tool: {str(e)}"
        }
        return index, progress_text, error_result


async def execute_tool_calls_parallel(
    tool_calls: List[Dict],
    call_mcp_tool_func
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls in parallel and collect results.
    Maintains the original order of tool calls in results.

    Args:
        tool_calls: List of tool-call dicts ("name", "input", "id").
        call_mcp_tool_func: Async callable forwarded to execute_single_tool.

    Returns: (progress_text, tool_results_content)
    """
    if not tool_calls:
        return "", []

    # Track execution time for progress reporting; get_running_loop() is
    # the non-deprecated accessor inside a coroutine.
    loop = asyncio.get_running_loop()
    start_time = loop.time()

    logger.info("Executing %d tools in parallel", len(tool_calls))

    tasks = [
        execute_single_tool(tool_call, call_mcp_tool_func, i)
        for i, tool_call in enumerate(tool_calls)
    ]

    results = await asyncio.gather(*tasks, return_exceptions=True)

    # asyncio.gather preserves task order, so a single ordered pass keeps
    # tool results aligned with tool_calls.  This replaces the previous
    # sort-then-insert bookkeeping, whose `if i < len(...)` guard could
    # silently drop an error result and misalign tool_use_ids.
    per_tool_progress: List[str] = []
    tool_results_content: List[Dict] = []
    completed_count = 0

    for i, result in enumerate(results):
        if isinstance(result, Exception):
            # execute_single_tool catches its own errors, so this only
            # fires for truly unexpected failures (e.g. CancelledError
            # subclasses on older Pythons, bugs in the helper itself).
            logger.error("Task %d failed with exception: %s", i, result)
            tool_results_content.append({
                "type": "tool_result",
                "tool_use_id": tool_calls[i]["id"],
                "content": f"Tool execution failed: {str(result)}"
            })
        else:
            _, prog_text, result_dict = result
            per_tool_progress.append(prog_text)
            tool_results_content.append(result_dict)
            completed_count += 1

    # Only mention timing when the batch was noticeably slow.
    elapsed_time = loop.time() - start_time
    timing_info = f" in {elapsed_time:.1f}s" if elapsed_time > 5 else ""

    progress_text = (
        f"\n📊 **Search Progress:** Completed {completed_count}/{len(tool_calls)} searches{timing_info}\n"
        + "".join(per_tool_progress)
    )

    return progress_text, tool_results_content


# Backward compatibility wrapper
async def execute_tool_calls_optimized(
    tool_calls: List[Dict],
    call_mcp_tool_func,
    parallel: bool = True
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls, choosing between parallel and sequential execution.

    Args:
        tool_calls: List of tool calls to execute
        call_mcp_tool_func: Function to call MCP tools
        parallel: If True, execute tools in parallel; if False, execute sequentially

    Returns: (progress_text, tool_results_content)
    """
    # Parallelism only pays off for more than one tool; a single call (or an
    # explicit opt-out) goes through the original sequential helper, which is
    # imported lazily so the parallel path carries no dependency on it.
    if not parallel or len(tool_calls) <= 1:
        from refactored_helpers import execute_tool_calls
        return await execute_tool_calls(tool_calls, call_mcp_tool_func)

    return await execute_tool_calls_parallel(tool_calls, call_mcp_tool_func)


def estimate_time_savings(num_tools: int, avg_tool_time: float = 3.5) -> Dict[str, float]:
    """
    Estimate how much wall-clock time parallel execution saves.

    Args:
        num_tools: Number of tools to execute
        avg_tool_time: Average time per tool in seconds

    Returns: Dictionary with keys "sequential_time", "parallel_time",
        "time_saved", and "savings_percent".
    """
    coordination_overhead = 0.5  # seconds spent scheduling/gathering tasks

    # Sequentially, tool times add up; in parallel, total time is roughly
    # the slowest tool plus coordination overhead.
    sequential_time = num_tools * avg_tool_time
    parallel_time = avg_tool_time + coordination_overhead

    savings = sequential_time - parallel_time
    if sequential_time > 0:
        savings_percent = (savings / sequential_time) * 100
    else:
        savings_percent = 0

    return {
        "sequential_time": sequential_time,
        "parallel_time": parallel_time,
        "time_saved": savings,
        "savings_percent": savings_percent,
    }


# Demo: print estimated time savings for a range of tool counts when the
# module is executed as a script.
if __name__ == "__main__":
    for tool_count in (2, 3, 4, 5):
        est = estimate_time_savings(tool_count)
        print(f"\n{tool_count} tools:")
        print(f"  Sequential: {est['sequential_time']:.1f}s")
        print(f"  Parallel: {est['parallel_time']:.1f}s")
        print(f"  Savings: {est['time_saved']:.1f}s ({est['savings_percent']:.0f}%)")