# Source: absa-ontology/core/explainability.py ("Modular Attempt", commit 736b015)
# -*- coding: utf-8 -*-
"""
core/explainability.py
Cross-model explainability and visualization layer for ESG ABSA.
Integrates outputs from rule-based, classical, deep, and hybrid models
to produce comparative dashboards and interpretable summaries.
"""
import numpy as np
import pandas as pd
from .app_state import app_state
from .utils import safe_plot
from .deep_model import plot_attention_plotly
from .hybrid_model import plot_ontology_scatter
# -------------------------
# Helper: safe retrieve
# -------------------------
def _get_df_safe(key: str) -> pd.DataFrame:
    """Safely retrieve the DataFrame cached under *key* in app_state.

    Entries are expected to be dicts holding a "df" key. The isinstance
    check runs first so that a raw DataFrame (or any other object with
    ambiguous truthiness) stored under *key* cannot raise ValueError
    from the truth test, which the previous `if not state:` guard did.

    Returns an empty DataFrame for missing, falsy, or malformed entries.
    """
    state = app_state.get(key)
    if isinstance(state, dict):
        # An empty dict also falls through to an empty frame via the default.
        return state.get("df", pd.DataFrame())
    # None / missing / unexpected types all degrade to an empty frame.
    return pd.DataFrame()
# -------------------------
# Model comparison summary
# -------------------------
def _first_present(df, candidates):
    """Return the first column name in *candidates* that exists in *df*, else None."""
    return next((c for c in candidates if c in df.columns), None)


def _unify_model_output(name, df):
    """Project one model's output frame onto the shared comparison schema.

    Produces columns Sentence_Text / Sentiment / Tone / Model, resolving
    each model family's own column naming. Missing sentiment defaults to
    "Neutral" and missing tone to "Unknown"; when no Sentence_Text column
    exists the first column is assumed to hold the sentence text.
    """
    sent_col = _first_present(
        df, ("Sentiment_Pred", "RQ2_Sentiment", "RQ1_Baseline_Sentiment"))
    tone_col = _first_present(df, ("Tone_Pred", "RQ2_Tone"))
    unified = pd.DataFrame({
        "Sentence_Text": df["Sentence_Text"] if "Sentence_Text" in df.columns else df.iloc[:, 0],
        "Sentiment": df[sent_col] if sent_col else "Neutral",
        "Tone": df[tone_col] if tone_col else "Unknown",
    })
    unified["Model"] = name
    return unified


def compare_explain():
    """
    Combine predictions from rule-based, classical, deep, and hybrid models
    to highlight overlaps and divergences across sentiment/tone/aspect levels.

    Returns:
        tuple: (comparison DataFrame, matplotlib figure, optional plotly
        scatter from the hybrid model, or None when unavailable).
    """
    model_frames = [
        ("Rule-based", _get_df_safe("rule")),
        ("Classical", _get_df_safe("classical")),
        ("Deep", _get_df_safe("deep")),
        ("Hybrid", _get_df_safe("hybrid")),
    ]

    # Nothing loaded yet: show a placeholder figure instead of failing.
    if all(df.empty for _, df in model_frames):
        fig = safe_plot(
            lambda ax: (ax.text(0.5, 0.5, "No data yet", ha="center"), ax.axis("off")),
            "Explainability")
        return pd.DataFrame(), fig, None

    # Unify every non-empty model frame into one long-format DataFrame.
    merged = pd.concat(
        [_unify_model_output(name, df) for name, df in model_frames
         if df is not None and not df.empty],
        ignore_index=True)

    # Sentence-level disagreement: how many distinct sentiment labels the
    # models assign to the same sentence, averaged over all sentences.
    grp = merged.groupby("Sentence_Text")["Sentiment"].nunique().reset_index(
        name="Sentiment_Disagreement")
    avg_disagreement = grp["Sentiment_Disagreement"].mean() if not grp.empty else 0.0

    def plot_disagreement(ax):
        """Stacked bar chart of sentiment counts per model."""
        if merged.empty:
            ax.text(0.5, 0.5, "No models to compare", ha="center")
            ax.axis("off")
            return
        (merged.groupby(["Model", "Sentiment"]).size()
               .unstack(fill_value=0)
               .plot(kind="bar", stacked=True, ax=ax))
        ax.set_ylabel("Sentence Count")
        ax.set_title(f"Cross-model Sentiment Comparison (Disagreement: {avg_disagreement:.2f})")

    fig = safe_plot(plot_disagreement, "Model Sentiment Comparison")

    # Optional hybrid ontology scatter (best effort; ignore render failures).
    plotly_scatter = None
    hybrid_state = app_state.get("hybrid")
    if hybrid_state:
        try:
            plotly_scatter = plot_ontology_scatter(hybrid_state)
        except Exception:
            plotly_scatter = None

    return merged, fig, plotly_scatter
# -------------------------
# Explainability fusion
# -------------------------
def explain_sentence_across_models(sentence_text: str):
    """Collect per-model explanations of a single sentence.

    Each backend is imported lazily inside its own try-block, so a missing
    or broken model degrades to a placeholder entry rather than failing the
    whole call.

    Returns:
        dict: keys 'sentence', 'rule', 'classical', 'deep', 'hybrid'.
    """
    result = {"sentence": sentence_text}

    # Rule-based lexicon explanation (list of matched cues).
    try:
        from .rule_based import explain_rule_based_sentence
        result["rule"] = explain_rule_based_sentence(sentence_text)
    except Exception:
        result["rule"] = ["(unavailable)"]

    # Classical ML feature attribution.
    try:
        from .classical_ml import explain_classical_sentence
        result["classical"] = explain_classical_sentence(sentence_text)
    except Exception:
        result["classical"] = {"error": "unavailable"}

    # Deep model: attach an attention plot when token weights are returned.
    try:
        from .deep_model import explain_deep_sentence
        deep = explain_deep_sentence(sentence_text)
        if isinstance(deep, dict) and "tokens" in deep:
            attention_fig = plot_attention_plotly(
                deep["tokens"], deep["weights"], title="Deep Attention")
            result["deep"] = {
                "tokens": deep["tokens"],
                "weights": deep["weights"],
                "plot": attention_fig,
            }
        else:
            result["deep"] = deep
    except Exception as err:
        result["deep"] = {"error": str(err)}

    # Ontology-aware hybrid explanation.
    try:
        from .hybrid_model import explain_hybrid_sentence
        result["hybrid"] = explain_hybrid_sentence(sentence_text)
    except Exception:
        result["hybrid"] = {"error": "unavailable"}

    return result
# -------------------------
# Visualization of consistency
# -------------------------
def plot_consistency_summary():
    """Render a pairwise agreement heatmap of sentiment labels across models.

    Agreement between two models is the fraction of matching labels over the
    first min(len_a, len_b) rows of their sentiment columns; the diagonal is
    1.0 by definition. Returns a matplotlib figure.
    """
    sources = {
        "Rule-based": _get_df_safe("rule"),
        "Classical": _get_df_safe("classical"),
        "Deep": _get_df_safe("deep"),
        "Hybrid": _get_df_safe("hybrid")
    }

    # Pull one sentiment series per model: the first column whose name
    # contains "Sentiment", re-indexed so positional comparison lines up.
    columns = {}
    for label, frame in sources.items():
        if frame.empty:
            continue
        chosen = next((c for c in frame.columns if "Sentiment" in c), None)
        if chosen:
            columns[label] = frame[chosen].reset_index(drop=True)

    # No model output yet: placeholder figure.
    if not columns:
        return safe_plot(
            lambda ax: (ax.text(0.5, 0.5, "No data", ha="center"), ax.axis("off")),
            "Consistency")

    names = list(columns)
    size = len(names)
    agreement = np.zeros((size, size))
    for row in range(size):
        for col in range(size):
            if row == col:
                agreement[row, col] = 1.0
                continue
            a = columns[names[row]]
            b = columns[names[col]]
            overlap = min(len(a), len(b))
            agreement[row, col] = (a[:overlap] == b[:overlap]).mean() if overlap else 0.0

    # Lazy import keeps matplotlib out of module import time.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(agreement, cmap="Blues")
    ax.set_xticks(range(size))
    ax.set_xticklabels(names, rotation=45, ha="right")
    ax.set_yticks(range(size))
    ax.set_yticklabels(names)
    # Annotate each heatmap cell with its numeric agreement score.
    for row in range(size):
        for col in range(size):
            ax.text(col, row, f"{agreement[row, col]:.2f}",
                    ha="center", va="center", color="black", fontsize=8)
    plt.title("Sentiment Consistency Matrix")
    plt.tight_layout()
    return fig