Pawan Patil commited on
Commit
afefd94
·
1 Parent(s): dcd4138

Final commit after Git identity setup

Browse files
Files changed (4) hide show
  1. agent_app.py +105 -0
  2. requirements.txt +8 -0
  3. sheet_tool.py +87 -0
  4. web_app.py +15 -0
agent_app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent_app.py
2
+ from smolagents import CodeAgent, TransformersModel
3
+ from smolagents import tool
4
+ from sheet_tool import (
5
+ fetch_sheet_as_df,
6
+ create_pivot,
7
+ summary_stats,
8
+ plot_dataframe,
9
+ df_to_csv_bytes,
10
+ )
11
+ import base64
12
+
13
+ # Initialize model
14
+ model = TransformersModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
15
+
16
+ # -------------------------------
17
+ # ✅ TOOL DEFINITIONS
18
+ # -------------------------------
19
+
20
+ @tool
21
+ def load_sheet() -> dict:
22
+ """Load Google Sheet into a dataframe and return a short summary (not the full sheet)."""
23
+ df = fetch_sheet_as_df()
24
+ if df.empty:
25
+ return {"error": "Sheet is empty or not found."}
26
+ return {
27
+ "rows": len(df),
28
+ "columns": list(df.columns),
29
+ "head": df.head(5).to_dict(orient="records"),
30
+ }
31
+
32
+
33
+ @tool
34
+ def pivot(index_cols: str, column_cols: str, value_cols: str, aggfunc: str = "sum") -> dict:
35
+ """
36
+ Create a pivot table from the Google Sheet.
37
+
38
+ Args:
39
+ index_cols (str): Comma-separated list of columns to use as the pivot table index.
40
+ column_cols (str): Comma-separated list of columns to use as pivot table columns.
41
+ value_cols (str): Comma-separated list of columns to aggregate.
42
+ aggfunc (str, optional): Aggregation function to apply (e.g., 'sum', 'mean', 'count'). Defaults to 'sum'.
43
+ """
44
+ df = fetch_sheet_as_df()
45
+ if df.empty:
46
+ return {"error": "Sheet empty"}
47
+ index = [c.strip() for c in index_cols.split(",")] if index_cols else []
48
+ columns = [c.strip() for c in column_cols.split(",")] if column_cols else []
49
+ values = [c.strip() for c in value_cols.split(",")] if value_cols else []
50
+ pivot_df = create_pivot(df, index=index, columns=columns, values=values, aggfunc=aggfunc)
51
+ csv_bytes = df_to_csv_bytes(pivot_df)
52
+ return {
53
+ "pivot_preview": pivot_df.head(10).to_dict(orient="records"),
54
+ "csv_b64": base64.b64encode(csv_bytes).decode("utf-8")
55
+ }
56
+
57
+
58
+
59
+ @tool
60
+ def stats() -> dict:
61
+ """Generate summary statistics of the sheet."""
62
+ df = fetch_sheet_as_df()
63
+ if df.empty:
64
+ return {"error": "Sheet empty"}
65
+ s = summary_stats(df)
66
+ return {"summary": s.to_dict()}
67
+
68
+
69
+ @tool
70
+ def plot(kind: str = "bar", x: str = None, y: str = None, title: str = None) -> dict:
71
+ """
72
+ Create a plot from the Google Sheet data.
73
+
74
+ Args:
75
+ kind (str): Type of chart to create. Example values: 'bar', 'line', 'pie', 'scatter'.
76
+ x (str, optional): Column name to use for the X-axis. Example: 'Date'.
77
+ y (str, optional): Comma-separated column names to use for Y-axis. Example: 'Sales,Profit'.
78
+ title (str, optional): Chart title to display at the top.
79
+
80
+ Returns:
81
+ dict: A dictionary containing the base64-encoded plot image or an error message if the sheet is empty.
82
+ """
83
+ df = fetch_sheet_as_df()
84
+ if df.empty:
85
+ return {"error": "Sheet empty"}
86
+
87
+ y_list = [c.strip() for c in y.split(",")] if y else None
88
+ img_data_uri = plot_dataframe(df, kind=kind, x=x, y=y_list, title=title)
89
+ return {"image": img_data_uri}
90
+
91
+
92
+ # -------------------------------
93
+ # ✅ AGENT CREATION
94
+ # -------------------------------
95
+ agent = CodeAgent(model=model, tools=[load_sheet, pivot, stats, plot], add_base_tools=True)
96
+
97
+
98
+ def ask_agent(nl_query: str) -> dict:
99
+ """Send a natural-language query to the agent and return structured response."""
100
+ try:
101
+ resp = agent.run(nl_query)
102
+ return {"text": str(resp)}
103
+ except Exception as e:
104
+ return {"error": str(e)}
105
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ pandas
3
+ google-auth
4
+ google-auth-oauthlib
5
+ google-auth-httplib2
6
+ gspread
7
+ smolagents
8
+ requests
sheet_tool.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sheet_tool.py
2
+ from google.oauth2.service_account import Credentials
3
+ import gspread
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import io
7
+ import base64
8
+ from typing import Tuple
9
+
10
+ # CONFIG - change these
11
+ CREDENTIALS_FILE = "credentials.json" # path to your service account JSON
12
+ SHEET_ID = "1nOekWGmPsjoHj9T-MFjlNGSFyKPRogVbEjFNRxGgxuM" # replace with your sheet id
13
+ WORKSHEET_INDEX = 0 # first sheet
14
+
15
+ SCOPES = [
16
+ "https://www.googleapis.com/auth/spreadsheets",
17
+ "https://www.googleapis.com/auth/drive",
18
+ ]
19
+
20
+ def authorize_gs():
21
+ creds = Credentials.from_service_account_file(CREDENTIALS_FILE, scopes=SCOPES)
22
+ client = gspread.authorize(creds)
23
+ return client
24
+
25
+ def fetch_sheet_as_df(sheet_id: str = SHEET_ID, worksheet_index: int = WORKSHEET_INDEX) -> pd.DataFrame:
26
+ client = authorize_gs()
27
+ sh = client.open_by_key(sheet_id)
28
+ worksheet = sh.get_worksheet(worksheet_index)
29
+ data = worksheet.get_all_values()
30
+ if not data:
31
+ return pd.DataFrame()
32
+ df = pd.DataFrame(data[1:], columns=data[0])
33
+ # try to convert numeric columns where possible
34
+ for col in df.columns:
35
+ df[col] = pd.to_numeric(df[col], errors="ignore")
36
+ return df
37
+
38
+ def create_pivot(df: pd.DataFrame, index: list, columns: list, values: list, aggfunc: str = "sum") -> pd.DataFrame:
39
+ if df.empty:
40
+ return pd.DataFrame()
41
+ # pandas supports aggfunc as string or function
42
+ agg = getattr(pd.core.groupby.SeriesGroupBy, aggfunc, None)
43
+ pivot = pd.pivot_table(df, index=index, columns=columns, values=values, aggfunc=aggfunc, fill_value=0)
44
+ # convert MultiIndex columns to string for easier display
45
+ pivot = pivot.reset_index()
46
+ pivot.columns = [(" " .join(map(str, c)) if isinstance(c, tuple) else c).strip() for c in pivot.columns]
47
+ return pivot
48
+
49
+ def summary_stats(df: pd.DataFrame, numeric_only: bool = True) -> pd.DataFrame:
50
+ if df.empty:
51
+ return pd.DataFrame()
52
+ return df.describe(include="all") if not numeric_only else df.describe()
53
+
54
+ def plot_dataframe(df: pd.DataFrame, kind: str = "bar", x: str = None, y: list = None, title: str = None, figsize=(8,5)) -> str:
55
+ """
56
+ Creates a matplotlib plot and returns a base64 PNG data URI.
57
+ """
58
+ if df.empty:
59
+ raise ValueError("DataFrame is empty")
60
+
61
+ plt.close('all')
62
+ fig, ax = plt.subplots(figsize=figsize)
63
+
64
+ if kind == "bar":
65
+ if x is None or y is None:
66
+ df.plot(kind="bar", ax=ax)
67
+ else:
68
+ df.plot(kind="bar", x=x, y=y, ax=ax)
69
+ elif kind == "line":
70
+ df.plot(kind="line", x=x, y=y, ax=ax)
71
+ elif kind == "pie":
72
+ df.set_index(x)[y].plot(kind="pie", ax=ax, autopct='%1.1f%%')
73
+ else:
74
+ df.plot(kind=kind, x=x, y=y, ax=ax)
75
+
76
+ if title:
77
+ ax.set_title(title)
78
+ ax.grid(True)
79
+ buf = io.BytesIO()
80
+ fig.tight_layout()
81
+ fig.savefig(buf, format="png")
82
+ buf.seek(0)
83
+ b64 = base64.b64encode(buf.read()).decode("utf-8")
84
+ return "data:image/png;base64," + b64
85
+
86
+ def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
87
+ return df.to_csv(index=False).encode("utf-8")
web_app.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # web_app.py
2
+ from flask import Flask, request, jsonify
3
+ from agent_app import ask_agent
4
+
5
+ app = Flask(__name__)
6
+
7
+ @app.route("/ask", methods=["POST"])
8
+ def ask():
9
+ data = request.json
10
+ prompt = data.get("prompt")
11
+ result = ask_agent(prompt)
12
+ return jsonify({"result": result})
13
+
14
+ if __name__ == "__main__":
15
+ app.run(port=8000, debug=True)