Spaces:
Build error
Build error
porla
commited on
Commit
·
ec946c6
1
Parent(s):
f7a42c7
Implement CustomAgent and CustomToolNode; refactor tools and agent initialization
Browse files- app.py +19 -19
- src/{my_app.py → agent.py} +28 -23
- src/custom_tool_node.py +45 -0
- src/tools.py +14 -93
app.py
CHANGED
|
@@ -4,9 +4,9 @@ import requests
|
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
from langchain_core.messages import HumanMessage, SystemMessage
|
|
|
|
| 7 |
|
| 8 |
-
|
| 9 |
-
from src.my_app import build_graph, get_prompt
|
| 10 |
import json
|
| 11 |
|
| 12 |
# (Keep Constants as is)
|
|
@@ -15,22 +15,22 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
| 15 |
|
| 16 |
# --- Basic Agent Definition ---
|
| 17 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
| 18 |
-
class BasicAgent:
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 36 |
"""
|
|
@@ -53,7 +53,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 53 |
|
| 54 |
# 1. Instantiate Agent ( modify this part to create your agent)
|
| 55 |
try:
|
| 56 |
-
agent =
|
| 57 |
except Exception as e:
|
| 58 |
print(f"Error instantiating agent: {e}")
|
| 59 |
return f"Error initializing agent: {e}", None
|
|
|
|
| 4 |
import inspect
|
| 5 |
import pandas as pd
|
| 6 |
from langchain_core.messages import HumanMessage, SystemMessage
|
| 7 |
+
from src.
|
| 8 |
|
| 9 |
+
from src.agent import CustomAgent
|
|
|
|
| 10 |
import json
|
| 11 |
|
| 12 |
# (Keep Constants as is)
|
|
|
|
| 15 |
|
| 16 |
# --- Basic Agent Definition ---
|
| 17 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
| 18 |
+
# class BasicAgent:
|
| 19 |
+
# def __init__(self):
|
| 20 |
+
# print("BasicAgent initialized.")
|
| 21 |
+
# self.graph = build_graph() # Build the state graph for the agent
|
| 22 |
+
# def __call__(self, question: str, task_id: str) -> str:
|
| 23 |
+
# print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 24 |
+
# system_prompt = SystemMessage(content=get_prompt())
|
| 25 |
+
# messages = self.graph.invoke({
|
| 26 |
+
# "messages": [
|
| 27 |
+
# system_prompt,
|
| 28 |
+
# {"role": "user", "content": question}
|
| 29 |
+
# ],
|
| 30 |
+
# "task_id": task_id
|
| 31 |
+
# })
|
| 32 |
+
# answer = messages['messages'][-1].content
|
| 33 |
+
# return answer[14:]
|
| 34 |
|
| 35 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 36 |
"""
|
|
|
|
| 53 |
|
| 54 |
# 1. Instantiate Agent ( modify this part to create your agent)
|
| 55 |
try:
|
| 56 |
+
agent = CustomAgent()
|
| 57 |
except Exception as e:
|
| 58 |
print(f"Error instantiating agent: {e}")
|
| 59 |
return f"Error initializing agent: {e}", None
|
src/{my_app.py → agent.py}
RENAMED
|
@@ -1,32 +1,36 @@
|
|
| 1 |
import os
|
| 2 |
-
from typing import List, Dict, Any, Optional
|
| 3 |
-
from langgraph.graph import StateGraph, START, END
|
| 4 |
-
from langchain_openai import ChatOpenAI
|
| 5 |
-
from langchain_core.messages import HumanMessage
|
| 6 |
from IPython.display import Image, display
|
| 7 |
-
|
| 8 |
-
from langchain_openai import AzureChatOpenAI
|
| 9 |
-
from langgraph.prebuilt import ToolNode, tools_condition
|
| 10 |
-
from langgraph.prebuilt import tools_condition
|
| 11 |
-
|
| 12 |
-
from .tools import reverse_text, is_question_reversed, route_question, avaiable_tools, CustomToolNode
|
| 13 |
from langchain_core.messages import SystemMessage
|
|
|
|
|
|
|
|
|
|
| 14 |
from .state import State
|
| 15 |
-
|
| 16 |
-
from
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
-
# llm_with_tools = llm.bind_tools(avaiable_tools, parallel_tool_calls=False)
|
| 21 |
def get_prompt() -> str:
|
| 22 |
with open("system_prompt.txt", "r", encoding="utf-8") as f:
|
| 23 |
return f.read()
|
| 24 |
-
# return """You are a helpful assistant tasked with answering questions using a set of tools.
|
| 25 |
-
# Now, I will ask you a question. Report your thoughts, show the task_id and finish your answer with the following template:
|
| 26 |
-
# FINAL ANSWER: [YOUR FINAL ANSWER].
|
| 27 |
-
# YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
| 28 |
-
# Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
|
| 29 |
-
# """
|
| 30 |
|
| 31 |
|
| 32 |
def build_graph():
|
|
@@ -39,6 +43,7 @@ def build_graph():
|
|
| 39 |
openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
| 40 |
temperature=0.0,
|
| 41 |
)
|
|
|
|
| 42 |
llm_with_tools = llm.bind_tools(avaiable_tools)
|
| 43 |
|
| 44 |
|
|
@@ -59,8 +64,8 @@ def build_graph():
|
|
| 59 |
graph_builder = StateGraph(State)
|
| 60 |
|
| 61 |
# Add nodes
|
| 62 |
-
graph_builder.add_node("check_question_reversed", is_question_reversed)
|
| 63 |
-
graph_builder.add_node("reverse_text", reverse_text)
|
| 64 |
graph_builder.add_node("assistant", assistant)
|
| 65 |
tools_dict = {tool.name: tool for tool in avaiable_tools}
|
| 66 |
|
|
@@ -100,7 +105,7 @@ if __name__ == "__main__":
|
|
| 100 |
#question = """Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""""
|
| 101 |
# question = """Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order."""
|
| 102 |
# question = """The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places."""
|
| 103 |
-
question = """What
|
| 104 |
task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"
|
| 105 |
system_prompt = SystemMessage(content=get_prompt())
|
| 106 |
messages = react_graph.invoke({
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from IPython.display import Image, display
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from langchain_core.messages import SystemMessage
|
| 4 |
+
from langchain_openai import AzureChatOpenAI
|
| 5 |
+
from langgraph.graph import StateGraph, START, END
|
| 6 |
+
from langgraph.prebuilt import tools_condition
|
| 7 |
from .state import State
|
| 8 |
+
from .custom_tool_node import CustomToolNode
|
| 9 |
+
from .tools import get_avaiable_tools
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CustomAgent:
|
| 13 |
+
def __init__(self):
|
| 14 |
+
print("CustomAgent initialized.")
|
| 15 |
+
self.graph = build_graph() # Build the state graph for the agent
|
| 16 |
+
|
| 17 |
+
def __call__(self, question: str, task_id: str) -> str:
|
| 18 |
+
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 19 |
+
system_prompt = SystemMessage(content=get_prompt())
|
| 20 |
+
messages = self.graph.invoke({
|
| 21 |
+
"messages": [
|
| 22 |
+
system_prompt,
|
| 23 |
+
{"role": "user", "content": question}
|
| 24 |
+
],
|
| 25 |
+
"task_id": task_id
|
| 26 |
+
})
|
| 27 |
+
answer = messages['messages'][-1].content
|
| 28 |
+
return answer[14:]
|
| 29 |
|
| 30 |
|
|
|
|
| 31 |
def get_prompt() -> str:
|
| 32 |
with open("system_prompt.txt", "r", encoding="utf-8") as f:
|
| 33 |
return f.read()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
|
| 36 |
def build_graph():
|
|
|
|
| 43 |
openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
| 44 |
temperature=0.0,
|
| 45 |
)
|
| 46 |
+
avaiable_tools = get_avaiable_tools()
|
| 47 |
llm_with_tools = llm.bind_tools(avaiable_tools)
|
| 48 |
|
| 49 |
|
|
|
|
| 64 |
graph_builder = StateGraph(State)
|
| 65 |
|
| 66 |
# Add nodes
|
| 67 |
+
# graph_builder.add_node("check_question_reversed", is_question_reversed)
|
| 68 |
+
# graph_builder.add_node("reverse_text", reverse_text)
|
| 69 |
graph_builder.add_node("assistant", assistant)
|
| 70 |
tools_dict = {tool.name: tool for tool in avaiable_tools}
|
| 71 |
|
|
|
|
| 105 |
#question = """Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""""
|
| 106 |
# question = """Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order."""
|
| 107 |
# question = """The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places."""
|
| 108 |
+
question = """What country had the least number of athletes at the 1928 Summer Olympics? If there's a tie for a number of athletes, return the first in alphabetical order. Give the IOC country code as your answer."""
|
| 109 |
task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"
|
| 110 |
system_prompt = SystemMessage(content=get_prompt())
|
| 111 |
messages = react_graph.invoke({
|
src/custom_tool_node.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
class CustomToolNode:
|
| 3 |
+
"""Tool node personalizzato che può accedere allo stato completo"""
|
| 4 |
+
|
| 5 |
+
def __init__(self, tools_dict):
|
| 6 |
+
self.tools_dict = tools_dict
|
| 7 |
+
|
| 8 |
+
def __call__(self, state):
|
| 9 |
+
messages = state["messages"]
|
| 10 |
+
last_message = messages[-1]
|
| 11 |
+
|
| 12 |
+
# Estrai i tool calls dall'ultimo messaggio
|
| 13 |
+
if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
|
| 14 |
+
results = []
|
| 15 |
+
for tool_call in last_message.tool_calls:
|
| 16 |
+
tool_name = tool_call["name"]
|
| 17 |
+
tool_args = tool_call["args"]
|
| 18 |
+
|
| 19 |
+
# Aggiungi task_id agli argomenti del tool
|
| 20 |
+
tool_args_with_state = {
|
| 21 |
+
**tool_args,
|
| 22 |
+
"task_id": state["task_id"],
|
| 23 |
+
"state": state # Opzionale: passa tutto lo stato
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
if tool_name in self.tools_dict:
|
| 27 |
+
try:
|
| 28 |
+
result = self.tools_dict[tool_name].invoke(tool_args_with_state)
|
| 29 |
+
results.append({
|
| 30 |
+
"type": "tool",
|
| 31 |
+
"name": tool_name,
|
| 32 |
+
"tool_call_id": tool_call["id"],
|
| 33 |
+
"content": str(result)
|
| 34 |
+
})
|
| 35 |
+
except Exception as e:
|
| 36 |
+
results.append({
|
| 37 |
+
"type": "tool",
|
| 38 |
+
"name": tool_name,
|
| 39 |
+
"tool_call_id": tool_call["id"],
|
| 40 |
+
"content": f"Error: {str(e)}"
|
| 41 |
+
})
|
| 42 |
+
|
| 43 |
+
return {"messages": results}
|
| 44 |
+
|
| 45 |
+
return {"messages": []}
|
src/tools.py
CHANGED
|
@@ -8,7 +8,7 @@ from dotenv import load_dotenv
|
|
| 8 |
from langchain_openai import AzureChatOpenAI
|
| 9 |
from langchain_perplexity import ChatPerplexity
|
| 10 |
from langchain_core.tools import tool
|
| 11 |
-
from
|
| 12 |
|
| 13 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
| 14 |
from youtube_transcript_api import YouTubeTranscriptApi
|
|
@@ -27,8 +27,6 @@ from .state import State
|
|
| 27 |
|
| 28 |
load_dotenv()
|
| 29 |
|
| 30 |
-
# Set your OpenAI API key here
|
| 31 |
-
# os.environ["OPENAI_API_KEY"] = "sk-xxxxx" # Replace with your actual API key
|
| 32 |
|
| 33 |
# Initialize our LLM
|
| 34 |
llm = AzureChatOpenAI(
|
|
@@ -142,94 +140,17 @@ def route_question(state: State) -> str:
|
|
| 142 |
else:
|
| 143 |
return "question_not_reversed"
|
| 144 |
|
| 145 |
-
# web_search_tool = DuckDuckGoSearchRun()
|
| 146 |
-
web_search_tool = TavilySearchResults()
|
| 147 |
-
arxiv_search_tool = ArxivQueryRun()
|
| 148 |
-
wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
results = []
|
| 164 |
-
for tool_call in last_message.tool_calls:
|
| 165 |
-
tool_name = tool_call["name"]
|
| 166 |
-
tool_args = tool_call["args"]
|
| 167 |
-
|
| 168 |
-
# Aggiungi task_id agli argomenti del tool
|
| 169 |
-
tool_args_with_state = {
|
| 170 |
-
**tool_args,
|
| 171 |
-
"task_id": state["task_id"],
|
| 172 |
-
"state": state # Opzionale: passa tutto lo stato
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
if tool_name in self.tools_dict:
|
| 176 |
-
try:
|
| 177 |
-
result = self.tools_dict[tool_name].invoke(tool_args_with_state)
|
| 178 |
-
results.append({
|
| 179 |
-
"type": "tool",
|
| 180 |
-
"name": tool_name,
|
| 181 |
-
"tool_call_id": tool_call["id"],
|
| 182 |
-
"content": str(result)
|
| 183 |
-
})
|
| 184 |
-
except Exception as e:
|
| 185 |
-
results.append({
|
| 186 |
-
"type": "tool",
|
| 187 |
-
"name": tool_name,
|
| 188 |
-
"tool_call_id": tool_call["id"],
|
| 189 |
-
"content": f"Error: {str(e)}"
|
| 190 |
-
})
|
| 191 |
-
|
| 192 |
-
return {"messages": results}
|
| 193 |
-
|
| 194 |
-
return {"messages": []}
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
avaiable_tools = [
|
| 199 |
-
wikipedia_tool,
|
| 200 |
-
arxiv_search_tool,
|
| 201 |
-
web_search_tool,
|
| 202 |
-
get_youtube_transcript,
|
| 203 |
-
transcript_mp3_audio,
|
| 204 |
-
load_and_analyze_excel_file
|
| 205 |
-
]
|
| 206 |
-
# response = wikipedia_tool.run("HUNTER X HUNTER")
|
| 207 |
-
# print(response)
|
| 208 |
-
|
| 209 |
-
# 1) tool_for_fetch_wikipedia_data
|
| 210 |
-
|
| 211 |
-
# 2) fetch youtube video data
|
| 212 |
-
|
| 213 |
-
# 3 ) reverse del testo
|
| 214 |
-
|
| 215 |
-
# 4 ) tool/agente che valuta se la domanda è sensata o è scritta al contrario
|
| 216 |
-
|
| 217 |
-
# 5) chess-image-to-dict
|
| 218 |
-
|
| 219 |
-
# 6) chess agent
|
| 220 |
-
|
| 221 |
-
# 7) general python code execution tool
|
| 222 |
-
|
| 223 |
-
# 8) get trascript from youtube video
|
| 224 |
-
|
| 225 |
-
# 9) web-search tool
|
| 226 |
-
|
| 227 |
-
# 10) fetch page content
|
| 228 |
-
|
| 229 |
-
# 11) trascribe mp3 audio file
|
| 230 |
-
|
| 231 |
-
# 12) read mp3 audio file
|
| 232 |
-
|
| 233 |
-
# 13) Research paper MCP
|
| 234 |
-
|
| 235 |
-
# 14) read excel file
|
|
|
|
| 8 |
from langchain_openai import AzureChatOpenAI
|
| 9 |
from langchain_perplexity import ChatPerplexity
|
| 10 |
from langchain_core.tools import tool
|
| 11 |
+
from langchain_tavily import TavilySearchResults
|
| 12 |
|
| 13 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
| 14 |
from youtube_transcript_api import YouTubeTranscriptApi
|
|
|
|
| 27 |
|
| 28 |
load_dotenv()
|
| 29 |
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# Initialize our LLM
|
| 32 |
llm = AzureChatOpenAI(
|
|
|
|
| 140 |
else:
|
| 141 |
return "question_not_reversed"
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
+
def get_avaiable_tools():
|
| 145 |
+
"""Returns a list of available tools."""
|
| 146 |
+
web_search_tool = TavilySearchResults()
|
| 147 |
+
arxiv_search_tool = ArxivQueryRun()
|
| 148 |
+
wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
| 149 |
+
return [
|
| 150 |
+
wikipedia_tool,
|
| 151 |
+
arxiv_search_tool,
|
| 152 |
+
web_search_tool,
|
| 153 |
+
get_youtube_transcript,
|
| 154 |
+
transcript_mp3_audio,
|
| 155 |
+
load_and_analyze_excel_file
|
| 156 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|