porla committed on
Commit
ec946c6
·
1 Parent(s): f7a42c7

Implement CustomAgent and CustomToolNode; refactor tools and agent initialization

Browse files
Files changed (4) hide show
  1. app.py +19 -19
  2. src/{my_app.py → agent.py} +28 -23
  3. src/custom_tool_node.py +45 -0
  4. src/tools.py +14 -93
app.py CHANGED
@@ -4,9 +4,9 @@ import requests
4
  import inspect
5
  import pandas as pd
6
  from langchain_core.messages import HumanMessage, SystemMessage
 
7
 
8
-
9
- from src.my_app import build_graph, get_prompt
10
  import json
11
 
12
  # (Keep Constants as is)
@@ -15,22 +15,22 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
 
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
18
- class BasicAgent:
19
- def __init__(self):
20
- print("BasicAgent initialized.")
21
- self.graph = build_graph() # Build the state graph for the agent
22
- def __call__(self, question: str, task_id: str) -> str:
23
- print(f"Agent received question (first 50 chars): {question[:50]}...")
24
- system_prompt = SystemMessage(content=get_prompt())
25
- messages = self.graph.invoke({
26
- "messages": [
27
- system_prompt,
28
- {"role": "user", "content": question}
29
- ],
30
- "task_id": task_id
31
- })
32
- answer = messages['messages'][-1].content
33
- return answer[14:]
34
 
35
  def run_and_submit_all( profile: gr.OAuthProfile | None):
36
  """
@@ -53,7 +53,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
53
 
54
  # 1. Instantiate Agent ( modify this part to create your agent)
55
  try:
56
- agent = BasicAgent()
57
  except Exception as e:
58
  print(f"Error instantiating agent: {e}")
59
  return f"Error initializing agent: {e}", None
 
4
  import inspect
5
  import pandas as pd
6
  from langchain_core.messages import HumanMessage, SystemMessage
7
+ from src.
8
 
9
+ from src.agent import CustomAgent
 
10
  import json
11
 
12
  # (Keep Constants as is)
 
15
 
16
  # --- Basic Agent Definition ---
17
  # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
18
+ # class BasicAgent:
19
+ # def __init__(self):
20
+ # print("BasicAgent initialized.")
21
+ # self.graph = build_graph() # Build the state graph for the agent
22
+ # def __call__(self, question: str, task_id: str) -> str:
23
+ # print(f"Agent received question (first 50 chars): {question[:50]}...")
24
+ # system_prompt = SystemMessage(content=get_prompt())
25
+ # messages = self.graph.invoke({
26
+ # "messages": [
27
+ # system_prompt,
28
+ # {"role": "user", "content": question}
29
+ # ],
30
+ # "task_id": task_id
31
+ # })
32
+ # answer = messages['messages'][-1].content
33
+ # return answer[14:]
34
 
35
  def run_and_submit_all( profile: gr.OAuthProfile | None):
36
  """
 
53
 
54
  # 1. Instantiate Agent ( modify this part to create your agent)
55
  try:
56
+ agent = CustomAgent()
57
  except Exception as e:
58
  print(f"Error instantiating agent: {e}")
59
  return f"Error initializing agent: {e}", None
src/{my_app.py → agent.py} RENAMED
@@ -1,32 +1,36 @@
1
  import os
2
- from typing import List, Dict, Any, Optional
3
- from langgraph.graph import StateGraph, START, END
4
- from langchain_openai import ChatOpenAI
5
- from langchain_core.messages import HumanMessage
6
  from IPython.display import Image, display
7
-
8
- from langchain_openai import AzureChatOpenAI
9
- from langgraph.prebuilt import ToolNode, tools_condition
10
- from langgraph.prebuilt import tools_condition
11
-
12
- from .tools import reverse_text, is_question_reversed, route_question, avaiable_tools, CustomToolNode
13
  from langchain_core.messages import SystemMessage
 
 
 
14
  from .state import State
15
- # from langgraph.prebuilt import create_react_agent
16
- from langchain.agents import create_tool_calling_agent
17
- from langchain_core.runnables import Runnable
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
 
20
- # llm_with_tools = llm.bind_tools(avaiable_tools, parallel_tool_calls=False)
21
  def get_prompt() -> str:
22
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
23
  return f.read()
24
- # return """You are a helpful assistant tasked with answering questions using a set of tools.
25
- # Now, I will ask you a question. Report your thoughts, show the task_id and finish your answer with the following template:
26
- # FINAL ANSWER: [YOUR FINAL ANSWER].
27
- # YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
28
- # Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
29
- # """
30
 
31
 
32
  def build_graph():
@@ -39,6 +43,7 @@ def build_graph():
39
  openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
40
  temperature=0.0,
41
  )
 
42
  llm_with_tools = llm.bind_tools(avaiable_tools)
43
 
44
 
@@ -59,8 +64,8 @@ def build_graph():
59
  graph_builder = StateGraph(State)
60
 
61
  # Add nodes
62
- graph_builder.add_node("check_question_reversed", is_question_reversed)
63
- graph_builder.add_node("reverse_text", reverse_text)
64
  graph_builder.add_node("assistant", assistant)
65
  tools_dict = {tool.name: tool for tool in avaiable_tools}
66
 
@@ -100,7 +105,7 @@ if __name__ == "__main__":
100
  #question = """Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""""
101
  # question = """Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order."""
102
  # question = """The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places."""
103
- question = """What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?"""
104
  task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"
105
  system_prompt = SystemMessage(content=get_prompt())
106
  messages = react_graph.invoke({
 
1
  import os
 
 
 
 
2
  from IPython.display import Image, display
 
 
 
 
 
 
3
  from langchain_core.messages import SystemMessage
4
+ from langchain_openai import AzureChatOpenAI
5
+ from langgraph.graph import StateGraph, START, END
6
+ from langgraph.prebuilt import tools_condition
7
  from .state import State
8
+ from .custom_tool_node import CustomToolNode
9
+ from .tools import get_avaiable_tools
10
+
11
+
12
+ class CustomAgent:
13
+ def __init__(self):
14
+ print("CustomAgent initialized.")
15
+ self.graph = build_graph() # Build the state graph for the agent
16
+
17
+ def __call__(self, question: str, task_id: str) -> str:
18
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
19
+ system_prompt = SystemMessage(content=get_prompt())
20
+ messages = self.graph.invoke({
21
+ "messages": [
22
+ system_prompt,
23
+ {"role": "user", "content": question}
24
+ ],
25
+ "task_id": task_id
26
+ })
27
+ answer = messages['messages'][-1].content
28
+ return answer[14:]
29
 
30
 
 
31
  def get_prompt() -> str:
32
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
33
  return f.read()
 
 
 
 
 
 
34
 
35
 
36
  def build_graph():
 
43
  openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
44
  temperature=0.0,
45
  )
46
+ avaiable_tools = get_avaiable_tools()
47
  llm_with_tools = llm.bind_tools(avaiable_tools)
48
 
49
 
 
64
  graph_builder = StateGraph(State)
65
 
66
  # Add nodes
67
+ # graph_builder.add_node("check_question_reversed", is_question_reversed)
68
+ # graph_builder.add_node("reverse_text", reverse_text)
69
  graph_builder.add_node("assistant", assistant)
70
  tools_dict = {tool.name: tool for tool in avaiable_tools}
71
 
 
105
  #question = """Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""""
106
  # question = """Hi, I was out sick from my classes on Friday, so I'm trying to figure out what I need to study for my Calculus mid-term next week. My friend from class sent me an audio recording of Professor Willowbrook giving out the recommended reading for the test, but my headphones are broken :(\n\nCould you please listen to the recording for me and tell me the page numbers I'm supposed to go over? I've attached a file called Homework.mp3 that has the recording. Please provide just the page numbers as a comma-delimited list. And please provide the list in ascending order."""
107
  # question = """The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places."""
108
+ question = """What country had the least number of athletes at the 1928 Summer Olympics? If there's a tie for a number of athletes, return the first in alphabetical order. Give the IOC country code as your answer."""
109
  task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733"
110
  system_prompt = SystemMessage(content=get_prompt())
111
  messages = react_graph.invoke({
src/custom_tool_node.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ class CustomToolNode:
3
+ """Tool node personalizzato che può accedere allo stato completo"""
4
+
5
+ def __init__(self, tools_dict):
6
+ self.tools_dict = tools_dict
7
+
8
+ def __call__(self, state):
9
+ messages = state["messages"]
10
+ last_message = messages[-1]
11
+
12
+ # Estrai i tool calls dall'ultimo messaggio
13
+ if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
14
+ results = []
15
+ for tool_call in last_message.tool_calls:
16
+ tool_name = tool_call["name"]
17
+ tool_args = tool_call["args"]
18
+
19
+ # Aggiungi task_id agli argomenti del tool
20
+ tool_args_with_state = {
21
+ **tool_args,
22
+ "task_id": state["task_id"],
23
+ "state": state # Opzionale: passa tutto lo stato
24
+ }
25
+
26
+ if tool_name in self.tools_dict:
27
+ try:
28
+ result = self.tools_dict[tool_name].invoke(tool_args_with_state)
29
+ results.append({
30
+ "type": "tool",
31
+ "name": tool_name,
32
+ "tool_call_id": tool_call["id"],
33
+ "content": str(result)
34
+ })
35
+ except Exception as e:
36
+ results.append({
37
+ "type": "tool",
38
+ "name": tool_name,
39
+ "tool_call_id": tool_call["id"],
40
+ "content": f"Error: {str(e)}"
41
+ })
42
+
43
+ return {"messages": results}
44
+
45
+ return {"messages": []}
src/tools.py CHANGED
@@ -8,7 +8,7 @@ from dotenv import load_dotenv
8
  from langchain_openai import AzureChatOpenAI
9
  from langchain_perplexity import ChatPerplexity
10
  from langchain_core.tools import tool
11
- from langchain_community.tools.tavily_search.tool import TavilySearchResults
12
 
13
  from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
14
  from youtube_transcript_api import YouTubeTranscriptApi
@@ -27,8 +27,6 @@ from .state import State
27
 
28
  load_dotenv()
29
 
30
- # Set your OpenAI API key here
31
- # os.environ["OPENAI_API_KEY"] = "sk-xxxxx" # Replace with your actual API key
32
 
33
  # Initialize our LLM
34
  llm = AzureChatOpenAI(
@@ -142,94 +140,17 @@ def route_question(state: State) -> str:
142
  else:
143
  return "question_not_reversed"
144
 
145
- # web_search_tool = DuckDuckGoSearchRun()
146
- web_search_tool = TavilySearchResults()
147
- arxiv_search_tool = ArxivQueryRun()
148
- wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
149
 
150
-
151
- class CustomToolNode:
152
- """Tool node personalizzato che può accedere allo stato completo"""
153
-
154
- def __init__(self, tools_dict):
155
- self.tools_dict = tools_dict
156
-
157
- def __call__(self, state):
158
- messages = state["messages"]
159
- last_message = messages[-1]
160
-
161
- # Estrai i tool calls dall'ultimo messaggio
162
- if hasattr(last_message, 'tool_calls') and last_message.tool_calls:
163
- results = []
164
- for tool_call in last_message.tool_calls:
165
- tool_name = tool_call["name"]
166
- tool_args = tool_call["args"]
167
-
168
- # Aggiungi task_id agli argomenti del tool
169
- tool_args_with_state = {
170
- **tool_args,
171
- "task_id": state["task_id"],
172
- "state": state # Opzionale: passa tutto lo stato
173
- }
174
-
175
- if tool_name in self.tools_dict:
176
- try:
177
- result = self.tools_dict[tool_name].invoke(tool_args_with_state)
178
- results.append({
179
- "type": "tool",
180
- "name": tool_name,
181
- "tool_call_id": tool_call["id"],
182
- "content": str(result)
183
- })
184
- except Exception as e:
185
- results.append({
186
- "type": "tool",
187
- "name": tool_name,
188
- "tool_call_id": tool_call["id"],
189
- "content": f"Error: {str(e)}"
190
- })
191
-
192
- return {"messages": results}
193
-
194
- return {"messages": []}
195
-
196
-
197
-
198
- avaiable_tools = [
199
- wikipedia_tool,
200
- arxiv_search_tool,
201
- web_search_tool,
202
- get_youtube_transcript,
203
- transcript_mp3_audio,
204
- load_and_analyze_excel_file
205
- ]
206
- # response = wikipedia_tool.run("HUNTER X HUNTER")
207
- # print(response)
208
-
209
- # 1) tool_for_fetch_wikipedia_data
210
-
211
- # 2) fetch youtube video data
212
-
213
- # 3 ) reverse del testo
214
-
215
- # 4 ) tool/agente che valuta se la domanda è sensata o è scritta al contrario
216
-
217
- # 5) chess-image-to-dict
218
-
219
- # 6) chess agent
220
-
221
- # 7) general python code execution tool
222
-
223
- # 8) get trascript from youtube video
224
-
225
- # 9) web-search tool
226
-
227
- # 10) fetch page content
228
-
229
- # 11) trascribe mp3 audio file
230
-
231
- # 12) read mp3 audio file
232
-
233
- # 13) Research paper MCP
234
-
235
- # 14) read excel file
 
8
  from langchain_openai import AzureChatOpenAI
9
  from langchain_perplexity import ChatPerplexity
10
  from langchain_core.tools import tool
11
+ from langchain_tavily import TavilySearchResults
12
 
13
  from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
14
  from youtube_transcript_api import YouTubeTranscriptApi
 
27
 
28
  load_dotenv()
29
 
 
 
30
 
31
  # Initialize our LLM
32
  llm = AzureChatOpenAI(
 
140
  else:
141
  return "question_not_reversed"
142
 
 
 
 
 
143
 
144
+ def get_avaiable_tools():
145
+ """Returns a list of available tools."""
146
+ web_search_tool = TavilySearchResults()
147
+ arxiv_search_tool = ArxivQueryRun()
148
+ wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
149
+ return [
150
+ wikipedia_tool,
151
+ arxiv_search_tool,
152
+ web_search_tool,
153
+ get_youtube_transcript,
154
+ transcript_mp3_audio,
155
+ load_and_analyze_excel_file
156
+ ]