"""
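agent/ai_agent.py — AI Voice Agent (MySQL vector edition)
Twilio media-stream voice agent: Deepgram speech-to-text → RAG over MySQL-stored
embeddings + OpenAI chat → ElevenLabs text-to-speech, plus admin /embed endpoints.
Start (from the agent/ directory, so that `ai_agent:app` is importable):
    py -3.12 -m uvicorn ai_agent:app --host 127.0.0.1 --port 8000 --reload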
"""

import os, json, time, base64, asyncio, tempfile, logging, threading
from pathlib import Path
from typing import Optional
import numpy as np

# Optionally load environment variables from the .env file that sits next to
# this module. python-dotenv is optional: if it is not installed the
# ImportError is swallowed and only the real process environment is used.
# NOTE(review): the "loaded" message is printed whenever the import succeeds;
# load_dotenv's return value is ignored, so it prints even if no .env file was
# found — confirm this is acceptable.
try:
    from dotenv import load_dotenv
    load_dotenv(Path(__file__).parent / ".env")
    print("✓ .env loaded")
except ImportError:
    pass

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Form, HTTPException, Header
from openai import AsyncOpenAI
from deepgram import DeepgramClient, LiveTranscriptionEvents, LiveOptions
from elevenlabs.client import ElevenLabs
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
import mysql.connector

# Process-wide logging: INFO and above, each line prefixed with the
# wall-clock time (HH:MM:SS) and the level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("voice_agent")  # shared logger for this module

def req(k):
    """Return the whitespace-stripped value of environment variable *k*.

    Raises:
        RuntimeError: if *k* is unset or contains only whitespace.
    """
    value = os.environ.get(k, "").strip()
    if value:
        return value
    raise RuntimeError(f"Missing .env: {k}")

# ── Configuration (required keys fail fast via req) ─────────────────────
OPENAI_KEY       = req("OPENAI_API_KEY")
DEEPGRAM_KEY     = req("DEEPGRAM_API_KEY")
ELEVENLABS_KEY   = req("ELEVENLABS_API_KEY")
ELEVENLABS_VOICE = os.environ.get("ELEVENLABS_VOICE_ID", "EXAVITQu4vr4xnSDxMaL")

# Shared secret checked by check_key() against the x_agent_key header of the
# admin /embed endpoints. A blank value is treated as unset: an empty secret
# would otherwise be matched by an empty header.
_DEV_AGENT_SECRET = "local_dev_secret_key_change_me"
AGENT_SECRET     = os.environ.get("AI_AGENT_KEY", "").strip() or _DEV_AGENT_SECRET
if AGENT_SECRET == _DEV_AGENT_SECRET:
    log.warning("AI_AGENT_KEY is not set; admin /embed endpoints accept the built-in "
                "development key. Set AI_AGENT_KEY before exposing this service.")

DB_HOST          = req("DB_HOST")
DB_PORT          = int(os.environ.get("DB_PORT", "3306"))
DB_USER          = req("DB_USER")
DB_PASS          = req("DB_PASS")
DB_NAME          = req("DB_NAME")
LLM_MODEL        = os.environ.get("LLM_MODEL", "gpt-4o")

# Shared API clients and the ASGI app.
oai = AsyncOpenAI(api_key=OPENAI_KEY)
el  = ElevenLabs(api_key=ELEVENLABS_KEY)
app = FastAPI(title="AI Voice Agent")

# ── DB ───────────────────────────────────────────────────────────────────
def db_exec(sql: str, params: tuple = ()) -> list:
    """Execute one SQL statement on a fresh MySQL connection and return its rows.

    A new connection is opened for every call and is always closed, even when
    the statement fails. The statement is committed, so this helper serves both
    reads and writes.

    Args:
        sql: statement using ``%s`` placeholders.
        params: values bound to the placeholders by the driver.

    Returns:
        A list of dict rows (column name -> value) when the statement produces
        a result set, otherwise an empty list.

    Raises:
        Whatever mysql.connector raises on connection or execution failure
        (propagated unchanged).
    """
    # NOTE(review): ssl_disabled=True sends credentials and data in clear text;
    # fine for localhost, reconsider for a remote DB_HOST.
    conn = mysql.connector.connect(
        host=DB_HOST, port=DB_PORT, user=DB_USER,
        password=DB_PASS, database=DB_NAME,
        connection_timeout=10, ssl_disabled=True,
        consume_results=True
    )
    try:
        cur = conn.cursor(dictionary=True, buffered=True)
        try:
            cur.execute(sql, params)
            conn.commit()
            # Buffered cursor: rows are already fetched, so reading after commit is safe.
            return cur.fetchall() if cur.description else []
        finally:
            cur.close()
    finally:
        conn.close()

def get_config(key: str, default: str = "") -> str:
    """Return ``agent_config.config_value`` for *key*, or *default*.

    Best-effort: a missing row, a NULL value, or any database error yields
    *default*, so callers never fail just because configuration is
    unavailable. Database errors are logged rather than silently dropped.
    """
    try:
        rows = db_exec("SELECT config_value FROM agent_config WHERE config_key=%s", (key,))
    except Exception as e:
        log.warning(f"get_config({key!r}) failed, using default: {e}")
        return default
    if not rows:
        return default
    value = rows[0]["config_value"]
    # A NULL column would otherwise leak out as None despite the str contract.
    return default if value is None else value

# ── Embeddings ───────────────────────────────────────────────────────────
async def get_embedding(text: str) -> list:
    """Embed a single string with ``text-embedding-3-small`` and return its vector.

    Newlines are replaced with spaces before the request is sent.
    """
    flattened = " ".join(text.split("\n"))
    response = await oai.embeddings.create(
        model="text-embedding-3-small",
        input=flattened
    )
    first_item = response.data[0]
    return first_item.embedding

async def get_embeddings_batch(texts: list) -> list:
    """Embed many strings, sending at most 100 per request.

    Newlines in each string are replaced with spaces. The returned vectors are
    in the same order as *texts*; an empty input yields an empty list.
    """
    vectors = []
    start = 0
    while start < len(texts):
        batch = [item.replace("\n", " ") for item in texts[start:start + 100]]
        response = await oai.embeddings.create(
            model="text-embedding-3-small",
            input=batch
        )
        vectors += [entry.embedding for entry in response.data]
        start += 100
    return vectors

# ── RAG ──────────────────────────────────────────────────────────────────
async def retrieve_context(query: str, top_n: int = 4, min_score: float = 0.3) -> str:
    """Return the knowledge-base chunks most similar to *query*, joined by blank lines.

    Every stored chunk embedding (JSON text in ``kb_chunks.embedding``) is
    scored by cosine similarity against the embedding of *query*. Chunks
    scoring above *min_score* are kept, and the best *top_n* are returned.

    Args:
        query: the caller's utterance to search for.
        top_n: maximum number of chunks to return.
        min_score: similarity cut-off. Chunks at or below it are discarded,
            which filters navigation menus and junk HTML.

    Returns:
        The selected chunk texts separated by ``"\\n\\n"``. Returns ``""`` when
        there are no chunks, when none clear the threshold, or on any error
        (the error is logged).
    """
    try:
        # db_exec is synchronous; keep the MySQL round-trip off the event loop
        # so concurrent calls keep streaming audio.
        rows = await asyncio.to_thread(
            db_exec, "SELECT content, embedding FROM kb_chunks WHERE embedding IS NOT NULL")
        if not rows:
            return ""
        q_vec  = np.array(await get_embedding(query), dtype=np.float32)
        q_norm = np.linalg.norm(q_vec)  # loop-invariant, computed once
        scored = []
        for row in rows:
            try:
                cv    = np.array(json.loads(row["embedding"]), dtype=np.float32)
                denom = q_norm * np.linalg.norm(cv)
                score = float(np.dot(q_vec, cv) / denom) if denom > 0 else 0.0
            except (ValueError, TypeError) as e:
                # Malformed JSON or dimension mismatch: skip only this chunk.
                log.debug("RAG: skipping malformed chunk embedding: %s", e)
                continue
            scored.append((score, row["content"]))
        scored.sort(key=lambda x: x[0], reverse=True)
        if scored:
            log.info(f"RAG scores: {[round(s, 3) for s, _ in scored[:top_n]]}")
        relevant = [(s, t) for s, t in scored if s > min_score]
        if not relevant:
            log.warning("RAG: no chunks above threshold — try re-embedding cleaner content")
            return ""
        return "\n\n".join(t for _, t in relevant[:top_n])
    except Exception as e:
        log.error(f"RAG error: {e}")
        return ""

# ── LLM ──────────────────────────────────────────────────────────────────
async def generate_answer(transcript: str, context: str) -> str:
    """Ask the chat model to answer *transcript* using only the RAG *context*.

    The agent name, fallback message and sampling temperature come from the
    ``agent_config`` table via get_config. The system prompt limits the reply
    to at most two plain sentences and tells the model to say the fallback
    message when the answer is not in the context.

    Returns:
        The model's reply, stripped. Returns the fallback message if the model
        returns no content. Returns a fixed apology string if the LLM call
        fails (the error is logged).
    """
    # get_config performs blocking MySQL queries; run them in worker threads.
    agent_name = await asyncio.to_thread(get_config, "agent_name", "AI Assistant")
    fallback   = await asyncio.to_thread(
        get_config, "fallback_message", "I'm sorry, I don't have that information.")
    raw_temp   = await asyncio.to_thread(get_config, "llm_temperature", "0.3")
    try:
        temperature = float(raw_temp)
    except (TypeError, ValueError):
        # A bad config value must not silence the whole turn.
        log.warning(f"Invalid llm_temperature {raw_temp!r}; using 0.3")
        temperature = 0.3
    system = f"""You are {agent_name}, a voice assistant answering phone calls.
Answer ONLY using the CONTEXT below. If not found say: "{fallback}"
Max 2 sentences. No bullet points or markdown.
CONTEXT: {context or '(empty)'}"""
    try:
        resp = await oai.chat.completions.create(
            model=LLM_MODEL,
            messages=[{"role":"system","content":system},{"role":"user","content":transcript}],
            temperature=temperature, max_tokens=150
        )
        content = resp.choices[0].message.content
        # content can be None (e.g. no text produced); speak the fallback instead.
        return (content or fallback).strip()
    except Exception as e:
        log.error(f"LLM error: {e}")
        return "I am having a technical issue. Please try again."

# ── TTS ──────────────────────────────────────────────────────────────────
def synthesize(text: str) -> bytes:
    """Render *text* to speech with ElevenLabs and return the raw audio bytes.

    Uses the configured voice, the ``eleven_turbo_v2`` model and the
    ``ulaw_8000`` output format. The streamed chunks are concatenated into one
    bytes object. Any failure is logged and yields ``b""``.
    """
    try:
        chunks = el.generate(
            text=text,
            voice=ELEVENLABS_VOICE,
            model="eleven_turbo_v2",
            output_format="ulaw_8000",
        )
        return b"".join(chunks)
    except Exception as err:
        log.error(f"TTS error: {err}")
        return b""

# ── WebSocket ─────────────────────────────────────────────────────────────
@app.websocket("/ws/audio")
async def audio_ws(ws: WebSocket):
    """Bridge one Twilio media-stream call through STT → RAG/LLM → TTS.

    Query params:
        call_sid: identifier used in log lines and DB rows (default "unknown").

    Incoming JSON frames ("start", "media", "stop") are read on this coroutine.
    Caller audio from "media" frames is forwarded to a Deepgram live
    connection. Deepgram's transcript callback hands each final transcript to
    an asyncio.Queue through call_soon_threadsafe. A background task drains
    the queue, answers each utterance, records the turn in
    ``conversation_turns``, and sends the synthesized audio back as a "media"
    frame tagged with the stream's streamSid.
    """
    await ws.accept()
    call_sid   = ws.query_params.get("call_sid", "unknown")
    turn_index = 0
    log.info(f"[{call_sid}] Call connected")

    # Best-effort status update; db_exec blocks, so run it in a worker thread.
    try:
        await asyncio.to_thread(
            db_exec, "UPDATE call_logs SET status='in-progress' WHERE call_sid=%s", (call_sid,))
    except Exception as e:
        log.warning(f"[{call_sid}] call_logs status update failed: {e}")

    # get_running_loop() is the supported call inside a coroutine. The loop is
    # captured so the Deepgram callback can schedule work on it thread-safely.
    loop       = asyncio.get_running_loop()
    queue      = asyncio.Queue()
    stream_sid = call_sid  # replaced by the streamSid from Twilio's "start" event

    # ── Deepgram ─────────────────────────────────────────────
    dg      = DeepgramClient(DEEPGRAM_KEY)
    dg_conn = dg.listen.live.v("1")

    def on_message(self_ref, result, **kwargs):
        # Called by the Deepgram SDK, presumably off the event-loop thread,
        # hence call_soon_threadsafe rather than a direct put_nowait.
        try:
            alt = result.channel.alternatives[0]
            if result.is_final and alt.transcript.strip():
                text = alt.transcript.strip()
                log.info(f"[{call_sid}] Deepgram final: {text}")
                loop.call_soon_threadsafe(queue.put_nowait, text)
        except Exception as e:
            log.warning(f"on_message error: {e}")

    def on_error(self_ref, error, **kwargs):
        log.error(f"[{call_sid}] Deepgram error: {error}")

    dg_conn.on(LiveTranscriptionEvents.Transcript, on_message)
    dg_conn.on(LiveTranscriptionEvents.Error,      on_error)

    # These options must match the inbound audio: 8 kHz, mono, mu-law.
    started = dg_conn.start(LiveOptions(
        model="nova-2",
        language="en-US",
        smart_format=True,
        endpointing=500,
        encoding="mulaw",
        sample_rate=8000,
        channels=1,
        interim_results=False
    ))
    log.info(f"[{call_sid}] Deepgram started: {started}")
    if not started:
        log.error(f"[{call_sid}] Deepgram connection did not start; no transcripts will arrive")

    # ── Process transcripts from queue ───────────────────────
    async def process_queue():
        nonlocal turn_index
        while True:
            try:
                text = await queue.get()
                log.info(f"[{call_sid}] Processing: {text}")

                t0      = time.perf_counter()
                context = await retrieve_context(text)
                answer  = await generate_answer(text, context)
                # The ElevenLabs call is synchronous; run it off the event loop so
                # the receive loop below keeps forwarding caller audio meanwhile.
                audio   = await asyncio.to_thread(synthesize, answer)
                ms      = int((time.perf_counter() - t0) * 1000)

                log.info(f"[{call_sid}] Answer: {answer}")

                try:
                    await asyncio.to_thread(db_exec, """INSERT INTO conversation_turns
                        (call_sid,turn_index,caller_said,agent_replied,latency_ms)
                        VALUES(%s,%s,%s,%s,%s)""",
                        (call_sid, turn_index, text, answer, ms))
                    turn_index += 1
                except Exception as e:
                    log.warning(f"DB save error: {e}")

                if audio:
                    # stream_sid is read at send time, so it reflects the "start" event.
                    await ws.send_text(json.dumps({
                        "event": "media",
                        "streamSid": stream_sid,
                        "media": {"payload": base64.b64encode(audio).decode()}
                    }))
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.error(f"Process error: {e}")

    processor = asyncio.create_task(process_queue())

    # ── Receive audio from Twilio ─────────────────────────────
    try:
        async for raw in ws.iter_text():
            try:
                data  = json.loads(raw)
                event = data.get("event", "")
                if event == "start":
                    # Capture the real streamSid Twilio sends
                    stream_sid = data.get("streamSid", call_sid)
                    log.info(f"[{call_sid}] Stream started, streamSid={stream_sid}")
                elif event == "media":
                    audio_bytes = base64.b64decode(data["media"]["payload"])
                    dg_conn.send(audio_bytes)
                elif event == "stop":
                    log.info(f"[{call_sid}] Stream stopped")
                    break
            except Exception as e:
                log.warning(f"Message error: {e}")
    except WebSocketDisconnect:
        log.info(f"[{call_sid}] Disconnected")
    finally:
        processor.cancel()
        try:
            dg_conn.finish()
        except Exception as e:
            log.warning(f"[{call_sid}] Deepgram finish error: {e}")
        log.info(f"[{call_sid}] Cleaned up")

# ── Embed helpers ─────────────────────────────────────────────────────────
async def store_chunks(doc_id: int, texts: list):
    """Embed *texts*, insert one kb_chunks row per chunk, and mark the document ready.

    Each chunk is stored with its position as chunk_index and its embedding as
    JSON text. Afterwards kb_documents.status is set to 'ready' and
    chunk_count to the number of input texts.

    Raises:
        ValueError: if *texts* is empty.
    """
    if not texts:
        raise ValueError("No chunks to store")
    total = len(texts)
    log.info(f"Embedding {total} chunks...")
    vectors = await get_embeddings_batch(texts)
    for position, pair in enumerate(zip(texts, vectors)):
        chunk_text, vector = pair
        db_exec("""INSERT INTO kb_chunks (document_id,chunk_index,content,embedding)
                   VALUES(%s,%s,%s,%s)""", (doc_id, position, chunk_text, json.dumps(vector)))
    db_exec("UPDATE kb_documents SET status='ready',chunk_count=%s WHERE id=%s",
            (total, doc_id))
    log.info(f"✓ {total} chunks stored for doc {doc_id}")

def split_docs(docs) -> list:
    """Split loaded documents into chunk strings, skipping whitespace-only chunks.

    Uses RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=60).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=60)
    kept = []
    for chunk in splitter.split_documents(docs):
        body = chunk.page_content
        if body.strip():
            kept.append(body)
    return kept

def split_text(text: str) -> list:
    """Split raw *text* into chunk strings, skipping whitespace-only chunks.

    Uses RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=60).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=60)
    pieces = splitter.split_text(text)
    return list(filter(lambda piece: piece.strip(), pieces))

def check_key(k):
    """Reject the request unless *k* equals the configured AGENT_SECRET.

    Args:
        k: value of the x_agent_key header, or None when it was not sent.

    Raises:
        HTTPException: 403 when the key is missing or wrong.
    """
    import hmac  # local import keeps this block self-contained
    # compare_digest runs in time independent of where the strings differ,
    # so the secret cannot be recovered by timing the 403 responses.
    if k is None or not hmac.compare_digest(k.encode(), AGENT_SECRET.encode()):
        raise HTTPException(403, "Forbidden")

# ── Admin endpoints ───────────────────────────────────────────────────────
@app.post("/embed/pdf")
async def embed_pdf(file: UploadFile=File(...), title: str=Form(...),
                    doc_id: int=Form(...), x_agent_key: Optional[str]=Header(None)):
    """Chunk and embed an uploaded PDF into the knowledge base under *doc_id*.

    The upload is written to a temporary file because PyPDFLoader reads from a
    path. The file is then parsed, split with split_docs, and stored with
    store_chunks. ``title`` is a required form field but is not used here.

    Returns:
        ``{"ok": True, "chunks": n}`` on success.

    Raises:
        HTTPException: 403 from check_key. Raises 500 on any processing
            failure, after marking the kb_documents row 'error' with the
            message.
    """
    check_key(x_agent_key)
    db_exec("UPDATE kb_documents SET status='processing' WHERE id=%s", (doc_id,))
    # Fall back to .pdf when the filename is missing or has no extension.
    suffix = Path(file.filename or "").suffix or ".pdf"
    data = await file.read()
    # delete=False: the file is closed before PyPDFLoader reopens it by path;
    # it is removed in the finally clause. path is bound before writing so a
    # failed write cannot leak the temp file.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    path = tmp.name
    try:
        with tmp:
            tmp.write(data)
        # PDF parsing is synchronous; keep it off the event loop serving live calls.
        docs = await asyncio.to_thread(PyPDFLoader(path).load)
        texts = split_docs(docs)
        await store_chunks(doc_id, texts)
        return {"ok": True, "chunks": len(texts)}
    except Exception as e:
        db_exec("UPDATE kb_documents SET status='error',error_msg=%s WHERE id=%s", (str(e), doc_id))
        raise HTTPException(500, str(e)) from e
    finally:
        os.unlink(path)

@app.post("/embed/url")
async def embed_url(payload: dict, x_agent_key: Optional[str]=Header(None)):
    """Fetch a web page, strip boilerplate, and embed its text under ``payload["doc_id"]``.

    Expects a JSON body with ``url`` and ``doc_id``. Before chunking, nav,
    header, footer, script, style, aside and form elements are removed, and
    text lines of 40 characters or fewer (menu items, single words) are
    dropped.

    Returns:
        ``{"ok": True, "chunks": n}`` on success.

    Raises:
        HTTPException: 403 from check_key. Raises 400 if url or doc_id is
            missing. Raises 500 on any fetch/parse/embed failure (including a
            non-2xx HTTP status), after marking the kb_documents row 'error'.
    """
    check_key(x_agent_key)
    url, doc_id = payload.get("url",""), payload.get("doc_id")
    if not url or not doc_id: raise HTTPException(400, "url and doc_id required")
    db_exec("UPDATE kb_documents SET status='processing' WHERE id=%s", (doc_id,))
    try:
        # httpx + BeautifulSoup give much cleaner text than WebBaseLoader.
        import httpx
        from bs4 import BeautifulSoup
        # Async client: a slow site (up to the 30 s timeout) must not block
        # the event loop that is also serving live calls.
        async with httpx.AsyncClient(timeout=30, follow_redirects=True,
                                     headers={"User-Agent": "Mozilla/5.0"}) as client:
            resp = await client.get(url)
        # Never embed a 404/500 error page as knowledge-base content.
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "lxml")
        # Remove nav, header, footer, scripts, styles
        for tag in soup(["nav","header","footer","script","style","aside","form"]):
            tag.decompose()
        clean = soup.get_text(separator="\n", strip=True)
        # Remove very short lines (menu items, single words)
        lines = [line.strip() for line in clean.splitlines() if len(line.strip()) > 40]
        clean_text = "\n".join(lines)
        log.info(f"URL fetched: {len(clean_text)} chars after cleaning")
        texts = split_text(clean_text)
        await store_chunks(doc_id, texts)
        return {"ok": True, "chunks": len(texts)}
    except Exception as e:
        db_exec("UPDATE kb_documents SET status='error',error_msg=%s WHERE id=%s", (str(e), doc_id))
        raise HTTPException(500, str(e)) from e

@app.post("/embed/text")
async def embed_text(payload: dict, x_agent_key: Optional[str]=Header(None)):
    """Chunk and embed raw text sent as JSON ``{"text": ..., "doc_id": ...}``.

    Returns ``{"ok": True, "chunks": n}``. Responds 400 when either field is
    missing or empty. Responds 500 when chunking or embedding fails, after
    setting kb_documents.status to 'error'.
    """
    check_key(x_agent_key)
    text = payload.get("text","")
    doc_id = payload.get("doc_id")
    if not (text and doc_id):
        raise HTTPException(400, "text and doc_id required")
    db_exec("UPDATE kb_documents SET status='processing' WHERE id=%s", (doc_id,))
    try:
        chunks = split_text(text)
        await store_chunks(doc_id, chunks)
    except Exception as err:
        db_exec("UPDATE kb_documents SET status='error',error_msg=%s WHERE id=%s", (str(err), doc_id))
        raise HTTPException(500, str(err))
    return {"ok": True, "chunks": len(chunks)}

@app.delete("/embed/{doc_id}")
async def delete_embed(doc_id: int, x_agent_key: Optional[str]=Header(None)):
    """Delete every kb_chunks row belonging to *doc_id*.

    The kb_documents row itself is not modified here.
    """
    check_key(x_agent_key)
    db_exec(
        "DELETE FROM kb_chunks WHERE document_id=%s",
        (doc_id,),
    )
    return dict(ok=True)

@app.get("/health")
def health():
    """Report DB connectivity, embedded-chunk count and model/voice settings.

    Always returns status "ok"; DB problems are reported in ``db_connected``
    and ``db_error``. The chunk count is 0 when the DB is unreachable or the
    count query fails.
    """
    try:
        db_exec("SELECT 1")
        db_ok, db_err = True, None
    except Exception as e:
        db_ok, db_err = False, str(e)
    chunks = 0
    # Skip the count when the DB is down: it would just wait out another
    # connection timeout and report 0 anyway.
    if db_ok:
        try:
            rows = db_exec("SELECT COUNT(*) AS n FROM kb_chunks WHERE embedding IS NOT NULL")
            chunks = rows[0]["n"] if rows else 0
        except Exception as e:
            log.warning(f"health: chunk count failed: {e}")
    # "chroma_chunks" is a legacy key name kept for existing clients; the
    # vectors actually live in MySQL (see "vector_engine").
    return {"status":"ok","db_connected":db_ok,"db_error":db_err,
            "chroma_chunks":chunks,"llm_model":LLM_MODEL,"vector_engine":"mysql",
            "elevenlabs_voice":ELEVENLABS_VOICE}