# app.py
import os
import io
import base64

import streamlit as st
from huggingface_hub import InferenceClient
import pdfplumber
from PIL import Image
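# Assumed dependencies (a sketch, not tested pins): streamlit, pdfplumber,
# Pillow, and a recent huggingface_hub release that supports the
# `provider=` argument used below.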
# ---------- Configuration ----------
HF_TOKEN = os.environ.get("HF_TOKEN")       # required
GROQ_KEY = os.environ.get("GROQ_API_KEY")   # optional: to call Groq directly

USE_GROQ_PROVIDER = True  # set False to route chat to the default HF provider

# Model IDs (change if you prefer other models)
LLAMA_MODEL = "Groq/Llama-3-Groq-8B-Tool-Use"            # Groq Llama model on HF
TTS_MODEL = "espnet/kan-bayashi_ljspeech_vits"           # an HF-hosted TTS model example
SDXL_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"  # SDXL base model

# Chat client: routed through the Groq provider (billed via the HF token)
# or through the default HF provider.
if USE_GROQ_PROVIDER:
    client = InferenceClient(provider="groq", api_key=HF_TOKEN)
else:
    client = InferenceClient(api_key=HF_TOKEN)

# Separate HF-routed client for TTS and image generation: Groq does not
# serve TTS or image models, so those tasks must not go through the Groq
# provider.
hf_client = InferenceClient(api_key=HF_TOKEN)
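# To call Groq's API directly with your own Groq key instead of routing
# billing through Hugging Face, recent huggingface_hub versions accept the
# provider's key as `api_key` (one option, not verified here):
#   client = InferenceClient(provider="groq", api_key=GROQ_KEY)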
# ---------- Helpers ----------
def pdf_to_text(uploaded_file) -> str:
    """Extract the embedded text of every page and join pages with blank lines."""
    text_chunks = []
    with pdfplumber.open(uploaded_file) as pdf:
        for page in pdf.pages:
            ptext = page.extract_text()
            if ptext:
                text_chunks.append(ptext)
    return "\n\n".join(text_chunks)
def llama_summarize(text, max_tokens=512):
    messages = [
        {"role": "system", "content": "You are a concise summarizer. Produce a clear summary in bullet points."},
        {"role": "user", "content": f"Summarize the following document in <= 8 bullet points. Keep it short:\n\n{text}"},
    ]
    # OpenAI-style chat-completion endpoint
    resp = client.chat.completions.create(
        model=LLAMA_MODEL, messages=messages, max_tokens=max_tokens
    )
    return resp.choices[0].message.content
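# Note: very long documents can exceed the model's context window. The chat
# section below truncates its context to ~4000 characters; the same guard
# could be applied here (e.g. llama_summarize(text[:4000])).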
def llama_chat(chat_history, user_question):
    messages = chat_history + [{"role": "user", "content": user_question}]
    resp = client.chat.completions.create(model=LLAMA_MODEL, messages=messages)
    return resp.choices[0].message.content
def tts_synthesize(text) -> bytes:
    # InferenceClient's text_to_speech takes the text as its first argument
    # and returns raw audio bytes.
    return hf_client.text_to_speech(text, model=TTS_MODEL)


def generate_image(prompt_text) -> Image.Image:
    # text_to_image already returns a PIL Image, so no manual decoding is needed.
    return hf_client.text_to_image(prompt_text, model=SDXL_MODEL)
def audio_download_button(wav_bytes, filename="summary.wav"):
    b64 = base64.b64encode(wav_bytes).decode()
    href = f'<a href="data:audio/wav;base64,{b64}" download="{filename}">Download audio (WAV)</a>'
    st.markdown(href, unsafe_allow_html=True)
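# The same download link can be had without raw HTML via Streamlit's
# built-in widget, e.g.:
#   st.download_button("Download audio (WAV)", data=wav_bytes,
#                      file_name=filename, mime="audio/wav")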
# ---------- Streamlit UI ----------
st.set_page_config(page_title="PDFGPT (Groq + HF)", layout="wide")
st.title("PDF → Summary + Speech + Chat + Diagram (Groq + HF)")

uploaded = st.file_uploader("Upload PDF", type=["pdf"])

if uploaded:
    with st.spinner("Extracting text from PDF..."):
        text = pdf_to_text(uploaded)

    st.subheader("Extracted text (preview)")
    st.text_area("Document text", value=text[:1000], height=200)
| if st.button("Create summary (Groq Llama)"): | |
| with st.spinner("Summarizing with Groq Llama..."): | |
| summary = llama_summarize(text) | |
| st.subheader("Summary") | |
| st.write(summary) | |
| st.session_state["summary"] = summary | |
| if "summary" in st.session_state: | |
| summary = st.session_state["summary"] | |
| if st.button("Synthesize audio from summary (TTS)"): | |
| with st.spinner("Creating audio..."): | |
| try: | |
| audio = tts_synthesize(summary) | |
| st.audio(audio) | |
| audio_download_button(audio) | |
| except Exception as e: | |
| st.error(f"TTS failed: {e}") | |
| st.markdown("---") | |
| st.subheader("Chat with your PDF (ask questions about document)") | |
| if "chat_history" not in st.session_state: | |
| # start with system + doc context (shortened) | |
| doc_context = (text[:4000] + "...") if len(text) > 4000 else text | |
| st.session_state["chat_history"] = [ | |
| {"role":"system","content":"You are a helpful assistant that answers questions based on the provided document."}, | |
| {"role":"user","content": f"Document context:\n{doc_context}"} | |
| ] | |
| user_q = st.text_input("Ask a question about the PDF") | |
| if st.button("Ask") and user_q: | |
| with st.spinner("Getting answer from Groq Llama..."): | |
| answer = llama_chat(st.session_state["chat_history"], user_q) | |
| st.session_state.setdefault("convo", []).append(("You", user_q)) | |
| st.session_state.setdefault("convo", []).append(("Assistant", answer)) | |
| # append to history for next calls | |
| st.session_state["chat_history"].append({"role":"user","content":user_q}) | |
| st.session_state["chat_history"].append({"role":"assistant","content":answer}) | |
| st.write(answer) | |
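    # The full exchange accumulates in st.session_state["convo"]; to render
    # the whole conversation instead of only the latest answer, one sketch:
    #   for speaker, msg in st.session_state.get("convo", []):
    #       st.markdown(f"**{speaker}:** {msg}")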
| st.markdown("---") | |
| st.subheader("Generate a diagram from your question (SDXL)") | |
| diagram_prompt = st.text_input("Describe the diagram or scene to generate") | |
| if st.button("Generate diagram") and diagram_prompt: | |
| with st.spinner("Generating image (SDXL)..."): | |
| try: | |
| img = generate_image(diagram_prompt) | |
| st.image(img, use_column_width=True) | |
| # allow download | |
| buf = io.BytesIO() | |
| img.save(buf, format="PNG") | |
| st.download_button("Download diagram (PNG)", data=buf.getvalue(), file_name="diagram.png", mime="image/png") | |
| except Exception as e: | |
| st.error(f"Image generation failed: {e}") | |
st.sidebar.title("Settings")
st.sidebar.write("Models in use:")
st.sidebar.write(f"LLM: {LLAMA_MODEL}")
st.sidebar.write(f"TTS: {TTS_MODEL}")
st.sidebar.write(f"Image: {SDXL_MODEL}")
st.sidebar.markdown(
    "**Notes**\n"
    "- Set HF_TOKEN in Space secrets or environment before starting.\n"
    "- To route directly to Groq with your Groq API key, set `GROQ_API_KEY` and change the client init accordingly."
)