AshenH committed on
Commit
6b8bffc
·
verified ·
1 Parent(s): 307e82b

Added the space folder to main

Browse files
space/Dockerfile ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Slim Python base image keeps the Space image small.
FROM python:3.11-slim
WORKDIR /app
# Docker can only COPY files from inside the build context, so a "../" source
# path is not usable; requirements.txt must sit next to this Dockerfile.
# Dependencies are installed before the source is copied so that this layer
# is reused when only application code changes.
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
COPY . /app
ENV HF_HOME=/app/.cache/hf_cache
# Gradio binds to 127.0.0.1 by default, which is unreachable from outside the
# container; listen on all interfaces on Gradio's default port instead.
ENV GRADIO_SERVER_NAME=0.0.0.0
EXPOSE 7860
CMD ["python", "app.py"]
space/README_SPACE.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deploying as a Hugging Face Space
2
+
3
+ 1. Create a new **Gradio** Space.
4
+ 2. Upload the **contents of `space/`** to the Space root, together with the repository's top-level `requirements.txt`.
5
+ 3. Add Space Secrets:
6
+ - `HF_TOKEN`
7
+ - For BigQuery: `GCP_SERVICE_ACCOUNT_JSON`, `GCP_PROJECT`
8
+ - For MotherDuck: `MOTHERDUCK_TOKEN`, `MOTHERDUCK_DB`
9
+ - Optional tracing: `LANGFUSE_PUBLIC_KEY`, `LANGFUSE_SECRET_KEY`, `LANGFUSE_HOST`
10
+ 4. Set `SQL_BACKEND` to `bigquery` or `motherduck`.
11
+ 5. Set `HF_MODEL_REPO` to your private model repo id.
12
+ 6. (Optional) Set `ORCHESTRATOR_MODEL` for the tiny CPU LLM.
space/app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import gradio as gr
4
+ import pandas as pd
5
+ from typing import Dict, Any
6
+
7
+ from tools.sql_tool import SQLTool
8
+ from tools.predict_tool import PredictTool
9
+ from tools.explain_tool import ExplainTool
10
+ from tools.report_tool import ReportTool
11
+ from utils.tracing import Tracer
12
+ from utils.config import AppConfig
13
+
14
+ # Optional: tiny orchestration LLM (keep it simple on CPU)
15
# Loading the orchestration model is best-effort: if transformers is missing
# or the model cannot be loaded, the app falls back to rule-based routing.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    LLM_ID = os.getenv("ORCHESTRATOR_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    _tok = AutoTokenizer.from_pretrained(LLM_ID)
    _mdl = AutoModelForCausalLM.from_pretrained(LLM_ID)
    llm = pipeline("text-generation", model=_mdl, tokenizer=_tok, max_new_tokens=512)
except Exception as exc:  # deliberate catch-all: the LLM is optional
    import logging

    # Record *why* the LLM is unavailable so that a mistyped
    # ORCHESTRATOR_MODEL or a missing dependency does not go unnoticed.
    logging.getLogger(__name__).warning(
        "Orchestrator LLM unavailable (%s: %s); using rule-based planning.",
        type(exc).__name__,
        exc,
    )
    llm = None  # Fallback: deterministic tool routing without LLM
23
+
24
# Shared configuration and tracer, both built from environment variables.
cfg = AppConfig.from_env()
tracer = Tracer.from_env()

# One instance of every tool, constructed in pipeline order; all of them
# receive the same config and tracer.
sql_tool, predict_tool, explain_tool, report_tool = (
    tool_cls(cfg, tracer)
    for tool_cls in (SQLTool, PredictTool, ExplainTool, ReportTool)
)

# Instruction text placed in front of every planning prompt.
SYSTEM_PROMPT = " ".join(
    (
        "You are an analytical assistant for tabular data.",
        "When the user asks a question, decide which tools to call in order:",
        "1) SQL (if data retrieval is needed) 2) Predict (if scoring is requested)",
        "3) Explain (if attributions or why-questions) 4) Report (if a document is requested).",
        "Always disclose the steps taken and include links to traces if available.",
    )
)
39
+
40
+
41
def plan_actions(message: str) -> Dict[str, Any]:
    """Decide which tools to run for ``message``.

    Uses the orchestration LLM when one is loaded and its answer is a usable
    plan; otherwise falls back to keyword heuristics.

    Args:
        message: The user's free-text question.

    Returns:
        A dict with at least ``steps`` (a list drawn from ``'sql'``,
        ``'predict'``, ``'explain'``, ``'report'``, in that order) and
        ``rationale`` (a string).
    """
    allowed = ("sql", "predict", "explain", "report")

    if llm is not None:
        prompt = (
            f"{SYSTEM_PROMPT}\nUser: {message}\n"
            "Return JSON with fields: steps (array, subset of ['sql','predict','explain','report']), "
            "and rationale (one sentence)."
        )
        # Only the last generated line is expected to carry the JSON answer.
        out = llm(prompt)[0]["generated_text"].split("\n")[-1]
        try:
            plan = json.loads(out)
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            plan = None
        # Valid JSON is not necessarily a valid plan: it may be a list or a
        # string, or lack the fields the caller indexes. Accept it only when
        # it has the expected shape, and drop step names no tool implements.
        if isinstance(plan, dict) and isinstance(plan.get("steps"), list):
            steps = [step for step in allowed if step in plan["steps"]]
            rationale = str(plan.get("rationale") or "LLM plan.")
            return {**plan, "steps": steps, "rationale": rationale}

    # Heuristic fallback: a step is selected when any of its keywords occurs
    # as a substring of the lower-cased message.
    keywords = {
        "sql": ["show", "average", "count", "trend", "top", "sql", "query", "kpi"],
        "predict": ["predict", "score", "risk", "propensity", "probability"],
        "explain": ["why", "explain", "shap", "feature", "attribution"],
        "report": ["report", "download", "pdf", "summary"],
    }
    m = message.lower()
    steps = [step for step in allowed if any(k in m for k in keywords[step])]
    if not steps:
        # Nothing matched: default to fetching data.
        steps = ["sql"]
    return {"steps": steps, "rationale": "Rule-based plan."}
69
+
70
+
71
def run_agent(message: str, hitl_decision: str = "Approve", reviewer_note: str = ""):
    """Plan and execute the tool pipeline for one user message.

    Runs the planned steps in the fixed order SQL -> Predict -> Explain ->
    Report, records a human-in-the-loop event on the tracer, and builds a
    Markdown summary.

    Args:
        message: The user's question (also handed to the SQL tool).
        hitl_decision: The reviewer's decision, stored in the HITL record.
        reviewer_note: Optional free-text note, stored in the HITL record.

    Returns:
        A ``(response, preview_df)`` tuple: the Markdown summary and the
        predictions DataFrame if there is one, else the SQL DataFrame, else
        ``None``.
    """
    tracer.trace_event("user_message", {"message": message})
    plan = plan_actions(message)
    tracer.trace_event("plan", plan)

    # The plan may be relayed LLM output, so read it defensively instead of
    # indexing keys that might be absent.
    if not isinstance(plan, dict):
        plan = {}
    steps = plan.get("steps") or []
    if not isinstance(steps, (list, tuple)):
        steps = []
    rationale = plan.get("rationale", "")

    sql_df = None
    predict_df = None
    explain_plots = {}
    artifacts = {}
    skipped = []  # steps that were planned but had no input data

    if "sql" in steps:
        sql_df = sql_tool.run(message)
        artifacts["sql_rows"] = len(sql_df) if isinstance(sql_df, pd.DataFrame) else 0

    if "predict" in steps:
        # Scoring needs rows to score; without SQL output there is nothing to do.
        if isinstance(sql_df, pd.DataFrame):
            predict_df = predict_tool.run(sql_df)
        else:
            skipped.append("predict")

    if "explain" in steps:
        # `predict_df or sql_df` would raise ValueError: the truth value of a
        # DataFrame is ambiguous. Select by type instead.
        explain_input = _first_dataframe(predict_df, sql_df)
        if explain_input is not None:
            explain_plots = explain_tool.run(explain_input)
        else:
            skipped.append("explain")

    if skipped:
        artifacts["skipped"] = skipped

    report_link = None
    if "report" in steps:
        report_link = report_tool.render_and_save(
            user_query=message,
            sql_preview=sql_df.head(50) if isinstance(sql_df, pd.DataFrame) else None,
            predict_preview=predict_df.head(50) if isinstance(predict_df, pd.DataFrame) else None,
            explain_images=explain_plots,
            plan=plan,
        )

    # HITL log (append-only). In production, push to a private HF dataset via API.
    hitl_record = {
        "message": message,
        "decision": hitl_decision,
        "reviewer_note": reviewer_note,
        "timestamp": pd.Timestamp.utcnow().isoformat(),
        "artifacts": artifacts,
        "plan": plan,
    }
    tracer.trace_event("hitl", hitl_record)

    response = f"**Plan:** {steps}\n**Rationale:** {rationale}\n"
    if isinstance(sql_df, pd.DataFrame):
        response += f"\n**SQL rows:** {len(sql_df)}"
    if isinstance(predict_df, pd.DataFrame):
        response += f"\n**Predictions rows:** {len(predict_df)}"
    if skipped:
        response += f"\n**Skipped (no input data):** {', '.join(skipped)}"
    if report_link:
        response += f"\n**Report:** {report_link}"
    if tracer.trace_url:
        response += f"\n**Trace:** {tracer.trace_url}"

    preview_df = _first_dataframe(predict_df, sql_df)
    return response, preview_df


def _first_dataframe(*candidates):
    """Return the first argument that is a DataFrame, or None if none is."""
    for candidate in candidates:
        if isinstance(candidate, pd.DataFrame):
            return candidate
    return None
124
+
125
# Gradio UI: a question box, a human-review row, and two output panes that
# receive the (markdown, dataframe) pair returned by run_agent.
with gr.Blocks() as demo:
    gr.Markdown("# Tabular Agentic XAI (Free‑Tier)")

    with gr.Row():
        msg = gr.Textbox(label="Ask your question")

    with gr.Row():
        hitl = gr.Radio(
            ["Approve", "Needs Changes"],
            value="Approve",
            label="Human Review",
        )
        note = gr.Textbox(label="Reviewer note (optional)")

    out_md = gr.Markdown()
    out_df = gr.Dataframe(interactive=False)

    ask = gr.Button("Run")
    # Inputs map positionally onto run_agent's parameters.
    agent_inputs = [msg, hitl, note]
    agent_outputs = [out_md, out_df]
    ask.click(run_agent, inputs=agent_inputs, outputs=agent_outputs)

if __name__ == "__main__":
    demo.launch()
space/templates/report_styles.css ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ body { font-family: system-ui, -apple-system, Segoe UI, Roboto, Arial, sans-serif; padding: 24px; line-height: 1.5; }
2
+ h1,h2,h3 { margin-top: 1.2em; }
3
+ code, pre { background: #f6f8fa; padding: 2px 4px; border-radius: 4px; }
4
+ table { border-collapse: collapse; width: 100%; }
5
+ th, td { border: 1px solid #ddd; padding: 8px; }
6
+ th { background: #fafafa; }
space/templates/report_template.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Insight Report
2
+
3
+ **User Query**: {{ user_query }}
4
+
5
+ **Plan**: {{ plan.steps }}
6
+ **Rationale**: {{ plan.rationale }}
7
+
8
+ {% if sql_preview %}
9
+ ## SQL Preview
10
+ {{ sql_preview }}
11
+ {% endif %}
12
+
13
+ {% if predict_preview %}
14
+ ## Predictions Preview
15
+ {{ predict_preview }}
16
+ {% endif %}
17
+
18
+ {% if explain_images.global_bar %}
19
+ ## Global Feature Importance (SHAP)
20
+ <img src="{{ explain_images.global_bar }}" style="max-width: 100%;" />
21
+ {% endif %}
22
+
23
+ {% if explain_images.beeswarm %}
24
+ ## SHAP Beeswarm
25
+ <img src="{{ explain_images.beeswarm }}" style="max-width: 100%;" />
26
+ {% endif %}
space/tools/__init__.py ADDED
File without changes
space/tools/explain_tool.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import shap
4
+ import base64
5
+ import pandas as pd
6
+ from huggingface_hub import hf_hub_download
7
# Absolute imports: app.py loads `tools` and `utils` as top-level packages, so
# `..utils` would point above the top-level package and raise ImportError.
from utils.config import AppConfig
from utils.tracing import Tracer
9
+
10
class ExplainTool:
    """Compute SHAP plots for the hosted model and return them as PNG data URIs."""

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store the shared config and tracer; the model is loaded lazily."""
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # populated by _ensure_model() on first use

    def _ensure_model(self):
        """Download ``model.pkl`` from the configured Hub repo and load it once."""
        if self._model is None:
            import joblib
            path = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="model.pkl", token=os.getenv("HF_TOKEN"))
            # NOTE(security): joblib.load unpickles the file, which can execute
            # arbitrary code; HF_MODEL_REPO must point at a trusted repository.
            self._model = joblib.load(path)

    @staticmethod
    def _as_figure(plot):
        """Return an object with ``savefig`` for a plotting call's return value.

        Accepts either a figure or an axes-like object exposing ``.figure``.

        Raises:
            RuntimeError: If no figure can be obtained from ``plot``.
        """
        if plot is not None and hasattr(plot, "savefig"):
            return plot
        figure = getattr(plot, "figure", None)
        if figure is None or not hasattr(figure, "savefig"):
            raise RuntimeError(
                "SHAP plot call did not return a figure or axes; "
                "a shap version that returns the Axes when show=False is required"
            )
        return figure

    def _to_data_uri(self, fig) -> str:
        """Render ``fig`` (a figure, or axes belonging to one) to a base64 PNG data URI."""
        fig = self._as_figure(fig)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
        return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

    def run(self, df: pd.DataFrame):
        """Explain the model on a sample of ``df``.

        Args:
            df: Rows to explain; every column is passed to the explainer.

        Returns:
            A dict with keys ``global_bar`` and ``beeswarm``, each a PNG data URI.

        Raises:
            ValueError: If ``df`` is not a non-empty DataFrame.
        """
        if not isinstance(df, pd.DataFrame) or df.empty:
            raise ValueError("ExplainTool.run requires a non-empty DataFrame")
        self._ensure_model()
        # Use a small sample for speed on CPU Spaces
        sample = df.sample(min(len(df), 500), random_state=42)
        explainer = shap.Explainer(self._model, sample, feature_names=list(sample.columns))
        shap_values = explainer(sample)

        # Global summary plot. With show=False recent shap versions hand back
        # the Axes rather than a Figure, so resolve the figure before saving.
        bar_fig = self._as_figure(shap.plots.bar(shap_values, show=False))
        img1 = self._to_data_uri(bar_fig)
        # NOTE(review): shap appears to draw on pyplot's current figure, so
        # the figure is cleared to keep the next plot from being drawn on
        # top of this one - confirm against the installed shap version.
        bar_fig.clear()

        # Beeswarm (optional)
        swarm_fig = self._as_figure(shap.plots.beeswarm(shap_values, show=False))
        img2 = self._to_data_uri(swarm_fig)
        swarm_fig.clear()

        self.tracer.trace_event("explain", {"rows": len(sample)})
        return {"global_bar": img1, "beeswarm": img2}
space/tools/predict_tool.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import joblib
4
+ from huggingface_hub import hf_hub_download
5
# Absolute imports: app.py loads `tools` and `utils` as top-level packages, so
# `..utils` would point above the top-level package and raise ImportError.
from utils.config import AppConfig
from utils.tracing import Tracer
7
+
8
class PredictTool:
    """Score a DataFrame with the model stored in the configured Hub repo."""

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Store the shared config and tracer; artifacts are loaded lazily."""
        self.cfg = cfg
        self.tracer = tracer
        self._model = None  # estimator loaded from model.pkl
        self._feature_meta = None  # dict parsed from feature_metadata.json

    def _ensure_loaded(self):
        """Download the model and its feature metadata if not yet loaded.

        The two artifacts are checked independently so that a failed metadata
        download is retried on the next call instead of leaving
        ``_feature_meta`` as None next to an already loaded model.
        """
        token = os.getenv("HF_TOKEN")
        if self._model is None:
            path = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="model.pkl", token=token)
            # NOTE(security): joblib.load unpickles the file, which can execute
            # arbitrary code; HF_MODEL_REPO must point at a trusted repository.
            self._model = joblib.load(path)
        if self._feature_meta is None:
            meta = hf_hub_download(repo_id=self.cfg.hf_model_repo, filename="feature_metadata.json", token=token)
            import json
            with open(meta, "r", encoding="utf-8") as f:
                self._feature_meta = json.load(f)

    def run(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``df`` with a prediction column appended.

        The feature columns come from the metadata key ``feature_order``
        (default: every column of ``df``); the output column name comes from
        ``prediction_column`` (default: ``"prediction"``). When the model has
        ``predict_proba``, the second probability column is used; otherwise
        ``predict``.

        Raises:
            ValueError: If ``df`` is not a DataFrame or lacks feature columns.
        """
        if not isinstance(df, pd.DataFrame):
            raise ValueError("PredictTool.run requires a DataFrame of rows to score")
        self._ensure_loaded()
        use_cols = self._feature_meta.get("feature_order", list(df.columns))
        missing = [col for col in use_cols if col not in df.columns]
        if missing:
            raise ValueError(f"Input is missing feature columns: {missing}")
        X = df[use_cols].copy()
        if hasattr(self._model, "predict_proba"):
            # Column 1 is taken as the score; presumably the positive class of
            # a binary classifier - confirm for the deployed model.
            preds = self._model.predict_proba(X)[:, 1]
        else:
            preds = self._model.predict(X)
        out = df.copy()
        out[self._feature_meta.get("prediction_column", "prediction")] = preds
        self.tracer.trace_event("predict", {"rows": len(out)})
        return out
space/tools/report_tool.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from jinja2 import Environment, FileSystemLoader
3
+ import pandas as pd
4
# Absolute import: app.py loads `tools` and `utils` as top-level packages, so
# `..utils` would point above the top-level package and raise ImportError.
from utils.tracing import Tracer
5
+
6
+ class ReportTool:
7
+ def __init__(self, cfg, tracer: Tracer):
8
+ self.cfg = cfg
9
+ self.tracer = tracer
10
+ self.env = Environment(loader=FileSystemLoader("templates"))
11
+
12
+ def render_and_save(self, user_query: str, sql_preview: pd.DataFrame | None, predict_preview: pd.DataFrame | None, explain_images: dict, plan: dict):
13
+ tmpl = self.env.get_template("report_template.md")
14
+ html = tmpl.render(
15
+ user_query=user_query,
16
+ plan=plan,
17
+ sql_preview=sql_preview.to_markdown(index=False) if sql_preview is not None else "",
18
+ predict_preview=predict_preview.to_markdown(index=False) if predict_preview is not None else "",
19
+ explain_images=explain_images,
20
+ )
21
+ out_path = f"report_{pd.Timestamp.utcnow().strftime('%Y%m%d_%H%M%S')}.html"
22
+ with open(out_path, "w", encoding="utf-8") as f:
23
+ f.write("<link rel=\"stylesheet\" href=\"templates/report_styles.css\">\n" + html)
24
+ self.tracer.trace_event("report", {"path": out_path})
25
+ return out_path
space/tools/sql_tool.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ from typing import Optional
5
# Absolute imports: app.py loads `tools` and `utils` as top-level packages, so
# `..utils` would point above the top-level package and raise ImportError.
from utils.config import AppConfig
from utils.tracing import Tracer
7
+
8
class SQLTool:
    """Turn a user message into SQL and run it on BigQuery or MotherDuck."""

    def __init__(self, cfg: "AppConfig", tracer: "Tracer"):
        """Connect to the backend named by ``cfg.sql_backend``.

        Raises:
            RuntimeError: If the backend is unknown or its credentials are
                missing or malformed.
        """
        self.cfg = cfg
        self.tracer = tracer
        self.backend = cfg.sql_backend  # "bigquery" or "motherduck"
        if self.backend == "bigquery":
            from google.cloud import bigquery
            from google.oauth2 import service_account
            key_json = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
            if not key_json:
                raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
            creds = service_account.Credentials.from_service_account_info(
                self._parse_service_account(key_json)
            )
            self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
        elif self.backend == "motherduck":
            import duckdb
            token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
            if not token:
                # Otherwise the literal string "None" would be sent as the token.
                raise RuntimeError("Missing MOTHERDUCK_TOKEN secret")
            db_name = self.cfg.motherduck_db or "default"
            self.client = duckdb.connect(f"md:/{db_name}?motherduck_token={token}")
        else:
            raise RuntimeError("Unknown SQL backend")

    @staticmethod
    def _parse_service_account(key_json: str) -> dict:
        """Parse the service-account secret as JSON (never ``eval``).

        Raises:
            RuntimeError: If the value is not a JSON object. The secret's
                content is deliberately kept out of the message.
        """
        import json
        try:
            info = json.loads(key_json)
        except ValueError as err:
            raise RuntimeError("GCP_SERVICE_ACCOUNT_JSON is not valid JSON") from None
        if not isinstance(info, dict):
            raise RuntimeError("GCP_SERVICE_ACCOUNT_JSON must be a JSON object")
        return info

    def _nl_to_sql(self, message: str) -> str:
        """Map a user message to a SQL string.

        Explicit ``SELECT`` statements are passed through unchanged; anything
        else is mapped to one of two placeholder templates.
        """
        # Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
        # Expect users to include table names. Example: "avg revenue by month from dataset.sales"
        m = message.lower()
        # Check for explicit SQL first: a query such as
        # "SELECT AVG(x) ... GROUP BY y" also contains "avg" and " by " and
        # must not be swapped for the template below.
        # NOTE(security): this executes user-supplied SQL verbatim; the
        # backend credentials should be restricted to read-only access.
        if re.match(r"\s*select\s", m):
            return message
        if "avg" in m and " by " in m:
            return "-- Example template; edit me\nSELECT DATE_TRUNC(month, date_col) AS month, AVG(metric) AS avg_metric FROM dataset.table GROUP BY 1 ORDER BY 1;"
        return "SELECT * FROM dataset.table LIMIT 100;"

    def run(self, message: str) -> pd.DataFrame:
        """Translate ``message`` to SQL, trace it, run it, and return the rows."""
        sql = self._nl_to_sql(message)
        self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
        if self.backend == "bigquery":
            return self.client.query(sql).to_dataframe()
        return self.client.execute(sql).fetch_df()
space/utils/config.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class AppConfig:
6
+ hf_model_repo: str
7
+ sql_backend: str # "bigquery" or "motherduck"
8
+ gcp_project: str | None = None
9
+ motherduck_db: str | None = None
10
+ motherduck_token: str | None = None
11
+
12
+
13
+ @classmethod
14
+ def from_env(cls):
15
+ return cls(
16
+ hf_model_repo=os.getenv("HF_MODEL_REPO", "your-username/your-private-tabular-model"),
17
+ sql_backend=os.getenv("SQL_BACKEND", "motherduck"),
18
+ gcp_project=os.getenv("GCP_PROJECT"),
19
+ motherduck_db=os.getenv("MOTHERDUCK_DB", "default"),
20
+ motherduck_token=os.getenv("MOTHERDUCK_TOKEN")
21
+ )
space/utils/hf_io.py ADDED
File without changes
space/utils/tracing.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Optional
4
+
5
class Tracer:
    """Best-effort event tracer; a no-op when no client is configured."""

    def __init__(self, client=None, trace_url: Optional[str] = None):
        """Wrap ``client`` (any object with an ``event`` method) and an optional trace URL."""
        self.client = client
        self.trace_url = trace_url

    @classmethod
    def from_env(cls):
        """Build a tracer from the ``LANGFUSE_*`` environment variables.

        Returns a no-op tracer when the public or secret key is unset, or
        when the Langfuse client cannot be created; the failure is logged
        rather than raised because tracing is optional.
        """
        pk = os.getenv("LANGFUSE_PUBLIC_KEY")
        sk = os.getenv("LANGFUSE_SECRET_KEY")
        if not (pk and sk):
            return cls()
        try:
            from langfuse import Langfuse
            host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
            client = Langfuse(public_key=pk, secret_key=sk, host=host)
            # NOTE(review): the name is passed by keyword because the v2 SDK's
            # trace() is believed to accept keyword arguments only - confirm
            # against the pinned langfuse version.
            session = client.trace(name="tabular-agentic-xai")
            # The URL accessor's name varies; try the known spellings.
            url_getter = getattr(session, "get_trace_url", None) or getattr(session, "get_url", None)
            return cls(client=session, trace_url=url_getter() if callable(url_getter) else None)
        except Exception as exc:  # tracing must never prevent the app from starting
            import logging
            logging.getLogger(__name__).warning(
                "Langfuse tracing disabled (%s: %s)", type(exc).__name__, exc
            )
            return cls()

    def trace_event(self, name: str, payload: dict):
        """Send ``payload`` as a JSON string under ``name``; never raises.

        Values that JSON cannot encode are converted with ``str`` so that the
        event is still recorded instead of being dropped.
        """
        if not self.client:
            return
        try:
            self.client.event(name=name, input=json.dumps(payload, default=str))
        except Exception as exc:  # a tracing failure must not break a request
            import logging
            logging.getLogger(__name__).debug("Trace event %r dropped: %s", name, exc)