|
|
|
|
|
""" |
|
|
core/explainability.py |
|
|
Cross-model explainability and visualization layer for ESG ABSA. |
|
|
Integrates outputs from rule-based, classical, deep, and hybrid models |
|
|
to produce comparative dashboards and interpretable summaries. |
|
|
""" |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
|
|
|
from .app_state import app_state |
|
|
from .utils import safe_plot |
|
|
from .deep_model import plot_attention_plotly |
|
|
from .hybrid_model import plot_ontology_scatter |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_df_safe(key: str) -> pd.DataFrame:
    """
    Fetch the DataFrame stored under *key* in app_state.

    Model stages store their results as a dict with a "df" entry; anything
    else (missing key, empty state, non-dict payload) yields an empty frame
    so callers never have to null-check.
    """
    entry = app_state.get(key)
    if isinstance(entry, dict) and entry:
        return entry.get("df", pd.DataFrame())
    return pd.DataFrame()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Preference order for locating prediction columns: each model stage names its
# output column differently, so probe from most- to least-specific.
_SENTIMENT_COL_PRIORITY = ("Sentiment_Pred", "RQ2_Sentiment", "RQ1_Baseline_Sentiment")
_TONE_COL_PRIORITY = ("Tone_Pred", "RQ2_Tone")


def _first_present(df, candidates):
    """Return the first column name from *candidates* present in *df*, else None."""
    return next((c for c in candidates if c in df.columns), None)


def compare_explain():
    """
    Combine predictions from rule-based, classical, deep, and hybrid models
    to highlight overlaps and divergences across sentiment/tone/aspect levels.

    Returns:
        tuple: (comparison DataFrame with columns Sentence_Text/Sentiment/Tone/Model,
                matplotlib figure, optional plotly scatter). When no model has
                produced output yet, the DataFrame is empty, the figure is a
                "No data yet" placeholder, and the scatter is None.
    """
    model_dfs = [
        ("Rule-based", _get_df_safe("rule")),
        ("Classical", _get_df_safe("classical")),
        ("Deep", _get_df_safe("deep")),
        ("Hybrid", _get_df_safe("hybrid")),
    ]

    # Nothing to compare yet: return a placeholder figure rather than failing.
    if all(df.empty for _, df in model_dfs):
        fig = safe_plot(
            lambda ax: (ax.text(0.5, 0.5, "No data yet", ha="center"), ax.axis("off")),
            "Explainability",
        )
        return pd.DataFrame(), fig, None

    frames = []
    for name, df in model_dfs:
        if df is None or df.empty:
            continue
        sent_col = _first_present(df, _SENTIMENT_COL_PRIORITY)
        tone_col = _first_present(df, _TONE_COL_PRIORITY)
        base = pd.DataFrame({
            # Fall back to the first column when the canonical text column is absent.
            "Sentence_Text": df["Sentence_Text"] if "Sentence_Text" in df.columns else df.iloc[:, 0],
            "Sentiment": df[sent_col] if sent_col else "Neutral",
            "Tone": df[tone_col] if tone_col else "Unknown",
        })
        base["Model"] = name
        frames.append(base)

    merged = pd.concat(frames, ignore_index=True)

    # Per-sentence count of distinct sentiment labels across models:
    # 1.0 means every model agrees, higher means divergence.
    grp = merged.groupby("Sentence_Text")["Sentiment"].nunique().reset_index(name="Sentiment_Disagreement")
    avg_disagreement = grp["Sentiment_Disagreement"].mean() if not grp.empty else 0.0

    def plot_disagreement(ax):
        """Stacked bar of sentiment counts per model, titled with avg disagreement."""
        if merged.empty:
            ax.text(0.5, 0.5, "No models to compare", ha="center")
            ax.axis("off")
            return
        (merged.groupby(["Model", "Sentiment"]).size()
               .unstack(fill_value=0)
               .plot(kind="bar", stacked=True, ax=ax))
        ax.set_ylabel("Sentence Count")
        ax.set_title(f"Cross-model Sentiment Comparison (Disagreement: {avg_disagreement:.2f})")

    fig = safe_plot(plot_disagreement, "Model Sentiment Comparison")

    # Optional ontology scatter from the hybrid model; best-effort only, any
    # failure simply omits the plot rather than breaking the dashboard.
    hybrid_state = app_state.get("hybrid")
    plotly_scatter = None
    if hybrid_state:
        try:
            plotly_scatter = plot_ontology_scatter(hybrid_state)
        except Exception:
            plotly_scatter = None

    return merged, fig, plotly_scatter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def explain_sentence_across_models(sentence_text: str):
    """
    Compare the same sentence across all available models and produce a consolidated explanation dict.
    Returns: dict with keys: 'sentence', 'rule', 'classical', 'deep', 'hybrid'

    Each model's explainer is imported lazily and wrapped in a try/except so
    that a missing or broken model degrades to a placeholder value instead of
    failing the whole explanation.
    """
    result = {"sentence": sentence_text}

    # Rule-based explanation.
    try:
        from .rule_based import explain_rule_based_sentence
        result["rule"] = explain_rule_based_sentence(sentence_text)
    except Exception:
        result["rule"] = ["(unavailable)"]

    # Classical ML explanation.
    try:
        from .classical_ml import explain_classical_sentence
        result["classical"] = explain_classical_sentence(sentence_text)
    except Exception:
        result["classical"] = {"error": "unavailable"}

    # Deep model: attach an attention plot when token weights are available.
    try:
        from .deep_model import explain_deep_sentence
        deep_out = explain_deep_sentence(sentence_text)
        if not (isinstance(deep_out, dict) and "tokens" in deep_out):
            result["deep"] = deep_out
        else:
            attn_fig = plot_attention_plotly(deep_out["tokens"], deep_out["weights"], title="Deep Attention")
            result["deep"] = {"tokens": deep_out["tokens"], "weights": deep_out["weights"], "plot": attn_fig}
    except Exception as err:
        # Unlike the other models, surface the actual error message here.
        result["deep"] = {"error": str(err)}

    # Hybrid ontology-aware explanation.
    try:
        from .hybrid_model import explain_hybrid_sentence
        result["hybrid"] = explain_hybrid_sentence(sentence_text)
    except Exception:
        result["hybrid"] = {"error": "unavailable"}

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_consistency_summary():
    """
    Build a pairwise sentiment-consistency heatmap across model outputs.

    For each pair of models, agreement is the fraction of positionally
    aligned predictions that match, truncated to the shorter series.
    NOTE(review): alignment is positional, not joined on Sentence_Text —
    if models predicted on differently ordered sentences the scores are
    only approximate; confirm upstream ordering is stable.

    Returns:
        matplotlib.figure.Figure: annotated heatmap, or a "No data"
        placeholder figure when no model has sentiment output yet.
    """
    dfs = {
        "Rule-based": _get_df_safe("rule"),
        "Classical": _get_df_safe("classical"),
        "Deep": _get_df_safe("deep"),
        "Hybrid": _get_df_safe("hybrid")
    }

    # Collect one sentiment Series per model that exposes a sentiment-like
    # column (first column whose name contains "Sentiment").
    sentiment_matrix = {}
    for name, df in dfs.items():
        if df.empty:
            continue
        sent_col = next((c for c in df.columns if "Sentiment" in c), None)
        if sent_col:
            sentiment_matrix[name] = df[sent_col].reset_index(drop=True)
    if not sentiment_matrix:
        fig = safe_plot(lambda ax: (ax.text(0.5, 0.5, "No data", ha="center"), ax.axis("off")), "Consistency")
        return fig

    models = list(sentiment_matrix.keys())
    n = len(models)
    mat = np.zeros((n, n))
    # Agreement is symmetric, so compute each pair once and mirror it.
    for i in range(n):
        mat[i, i] = 1.0
        for j in range(i + 1, n):
            si = sentiment_matrix[models[i]]
            sj = sentiment_matrix[models[j]]
            m = min(len(si), len(sj))
            # .iloc makes the positional truncation explicit; empty overlap -> 0.
            agreement = (si.iloc[:m] == sj.iloc[:m]).mean() if m else 0.0
            mat[i, j] = agreement
            mat[j, i] = agreement

    # Imported locally to keep matplotlib out of module import time.
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(mat, cmap="Blues")
    ax.set_xticks(range(n))
    ax.set_xticklabels(models, rotation=45, ha="right")
    ax.set_yticks(range(n))
    ax.set_yticklabels(models)
    # Annotate every cell with its agreement score for readability.
    for i in range(n):
        for j in range(n):
            ax.text(j, i, f"{mat[i,j]:.2f}", ha="center", va="center", color="black", fontsize=8)
    plt.title("Sentiment Consistency Matrix")
    plt.tight_layout()
    return fig
|
|
|