#!/usr/bin/env python3
"""
Memory Auto-Ingestion Pipeline
================================
Watches E:\\genesis-system\\Research reports\\ for new files.
Auto-ingests new research reports into Qdrant vectors.
Uses PostgreSQL for deduplication tracking.

Supports: .md, .txt, .pdf, .docx files
Runs as a daemon or one-shot via --scan-once flag.

Usage:
    python memory_auto_ingest.py               # daemon mode
    python memory_auto_ingest.py --scan-once   # scan existing + exit
    python memory_auto_ingest.py --file <path> # ingest single file
"""

import argparse
import hashlib
import json
import logging
import re
import sys
import time
from datetime import datetime
from pathlib import Path

# ---------------------------------------------------------------------------
# Path setup
# ---------------------------------------------------------------------------
GENESIS_ROOT = Path("E:/genesis-system")
sys.path.insert(0, str(GENESIS_ROOT / "data" / "genesis-memory"))

RESEARCH_DIR = GENESIS_ROOT / "Research reports"
KG_ENTITIES_DIR = GENESIS_ROOT / "KNOWLEDGE_GRAPH" / "entities"
INGEST_LOG = KG_ENTITIES_DIR / "auto_ingest_log.jsonl"
WATCH_EXTENSIONS = {".md", ".txt", ".pdf", ".docx"}
POLL_INTERVAL_SECONDS = 30
EMBEDDING_DIM = 1536  # Must match Qdrant collection vector size

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [auto-ingest] %(levelname)s %(message)s",
)
logger = logging.getLogger("memory_auto_ingest")


# ---------------------------------------------------------------------------
# Deduplication via PostgreSQL
# ---------------------------------------------------------------------------

def _get_pg_conn():
    """Open a PostgreSQL connection for dedup tracking.

    Returns the live connection, or None when the driver/config import or
    the connect itself fails (errors are logged, never raised).
    """
    try:
        import psycopg2
        from elestio_config import PostgresConfig

        params = PostgresConfig.get_connection_params()
        return psycopg2.connect(connect_timeout=8, **params)
    except Exception as e:
        logger.error("PostgreSQL connection failed: %s", e)
        return None


def _ensure_ingest_table(conn):
    """Idempotently create the `genesis_ingest_tracker` table used for dedup.

    On failure the error is logged and the transaction is rolled back so the
    connection stays usable.
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
            CREATE TABLE IF NOT EXISTS genesis_ingest_tracker (
                id SERIAL PRIMARY KEY,
                file_path TEXT NOT NULL UNIQUE,
                file_hash TEXT NOT NULL,
                ingested_at TIMESTAMP DEFAULT NOW(),
                chunk_count INTEGER DEFAULT 0,
                qdrant_ids TEXT[],
                status TEXT DEFAULT 'pending'
            )
        """)
        conn.commit()
    except Exception as e:
        logger.error("Failed to create ingest table: %s", e)
        conn.rollback()


def _is_already_ingested(conn, file_path: str, file_hash: str) -> bool:
    """Return True iff this exact (path, content-hash) pair is marked 'done'.

    A DB error is logged and treated as "not ingested" so the file is
    (re)processed rather than silently dropped.
    """
    query = (
        "SELECT id FROM genesis_ingest_tracker "
        "WHERE file_path = %s AND file_hash = %s AND status = 'done'"
    )
    try:
        with conn.cursor() as cur:
            cur.execute(query, (file_path, file_hash))
            found = cur.fetchone() is not None
        return found
    except Exception as e:
        logger.error("Dedup check failed: %s", e)
        return False


def _record_ingest(conn, file_path: str, file_hash: str, chunk_count: int, qdrant_ids: list):
    """Upsert a 'done' tracking row for *file_path* in PostgreSQL.

    Re-ingesting an existing path (e.g. after an edit) updates the stored
    hash, timestamp, chunk count and Qdrant IDs in place via ON CONFLICT.
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
            INSERT INTO genesis_ingest_tracker (file_path, file_hash, chunk_count, qdrant_ids, status)
            VALUES (%s, %s, %s, %s, 'done')
            ON CONFLICT (file_path)
            DO UPDATE SET
                file_hash = EXCLUDED.file_hash,
                ingested_at = NOW(),
                chunk_count = EXCLUDED.chunk_count,
                qdrant_ids = EXCLUDED.qdrant_ids,
                status = 'done'
        """, (file_path, file_hash, chunk_count, qdrant_ids))
        conn.commit()
        logger.info("Recorded ingest: %s (%d chunks)", file_path, chunk_count)
    except Exception as e:
        logger.error("Failed to record ingest: %s", e)
        conn.rollback()


# ---------------------------------------------------------------------------
# Text extraction
# ---------------------------------------------------------------------------

def _extract_text(file_path: Path) -> str:
    """Extract text from a file based on its extension."""
    ext = file_path.suffix.lower()

    if ext in (".md", ".txt"):
        try:
            return file_path.read_text(encoding="utf-8", errors="replace")
        except Exception as e:
            logger.error("Failed to read %s: %s", file_path, e)
            return ""

    elif ext == ".pdf":
        try:
            import pdfplumber
            text_parts = []
            with pdfplumber.open(str(file_path)) as pdf:
                for page in pdf.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text_parts.append(page_text)
            return "\n\n".join(text_parts)
        except ImportError:
            logger.warning("pdfplumber not installed; trying PyPDF2 for %s", file_path)
            try:
                import PyPDF2
                text_parts = []
                with open(file_path, "rb") as f:
                    reader = PyPDF2.PdfReader(f)
                    for page in reader.pages:
                        text_parts.append(page.extract_text() or "")
                return "\n\n".join(text_parts)
            except Exception as e:
                logger.error("PDF extraction failed for %s: %s", file_path, e)
                return ""
        except Exception as e:
            logger.error("PDF extraction failed for %s: %s", file_path, e)
            return ""

    elif ext == ".docx":
        try:
            import docx
            doc = docx.Document(str(file_path))
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
        except ImportError:
            logger.warning("python-docx not installed; skipping %s", file_path)
            return ""
        except Exception as e:
            logger.error("DOCX extraction failed for %s: %s", file_path, e)
            return ""

    return ""


# ---------------------------------------------------------------------------
# Chunking
# ---------------------------------------------------------------------------

def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 100) -> list[str]:
    """Split text into overlapping chunks for vector embedding."""
    # Split on paragraphs first, then by size
    paragraphs = re.split(r"\n\s*\n", text)
    chunks = []
    current_chunk = []
    current_len = 0

    for para in paragraphs:
        para = para.strip()
        if not para:
            continue
        para_len = len(para)

        if current_len + para_len > chunk_size and current_chunk:
            chunk_text = "\n\n".join(current_chunk)
            chunks.append(chunk_text)
            # Keep overlap by retaining last paragraph if short enough
            if current_chunk and len(current_chunk[-1]) <= overlap:
                current_chunk = [current_chunk[-1]]
                current_len = len(current_chunk[0])
            else:
                current_chunk = []
                current_len = 0

        current_chunk.append(para)
        current_len += para_len

    if current_chunk:
        chunks.append("\n\n".join(current_chunk))

    return [c for c in chunks if len(c.strip()) > 50]


# ---------------------------------------------------------------------------
# Embedding (via OpenAI-compatible API or fallback)
# ---------------------------------------------------------------------------

def _get_embeddings(texts: list[str]) -> list[list[float]]:
    """Embed *texts* via an OpenAI-compatible endpoint.

    Prefers OPENAI_API_KEY (api.openai.com), falling back to
    OPENROUTER_API_KEY (openrouter.ai). On missing keys or any API error,
    returns deterministic zero vectors of EMBEDDING_DIM so ingestion can
    proceed as a best-effort (search quality degrades, pipeline survives).
    """
    try:
        import os
        import openai

        key = os.environ.get("OPENAI_API_KEY", "")
        if key:
            endpoint = "https://api.openai.com/v1"
        else:
            key = os.environ.get("OPENROUTER_API_KEY", "")
            endpoint = "https://openrouter.ai/api/v1"

        if not key:
            logger.warning("No embedding API key found; using zero vectors (fallback)")
            return [[0.0] * EMBEDDING_DIM for _ in texts]

        client = openai.OpenAI(api_key=key, base_url=endpoint)
        resp = client.embeddings.create(
            model="text-embedding-3-small",
            input=texts,
        )
        return [item.embedding for item in resp.data]

    except Exception as e:
        logger.error("Embedding failed: %s — using zero vectors", e)
        return [[0.0] * EMBEDDING_DIM for _ in texts]


# ---------------------------------------------------------------------------
# Qdrant upsert
# ---------------------------------------------------------------------------

def _upsert_to_qdrant(
    chunks: list[str],
    embeddings: list[list[float]],
    file_path: str,
    file_name: str,
) -> list[str]:
    """Write chunk/embedding pairs into Qdrant as new points.

    Creates the configured collection (cosine distance, EMBEDDING_DIM) if it
    does not exist yet. Returns the freshly generated point IDs, or [] when
    anything fails (the error is logged, never raised).
    """
    try:
        import uuid

        from qdrant_client import QdrantClient
        from qdrant_client.models import PointStruct, VectorParams, Distance
        from elestio_config import QdrantConfig

        cfg = QdrantConfig()
        client = QdrantClient(url=cfg.url, api_key=cfg.api_key)

        # Lazily create the collection on first use.
        try:
            client.get_collection(cfg.collection_name)
        except Exception:
            logger.info("Creating Qdrant collection: %s", cfg.collection_name)
            client.create_collection(
                collection_name=cfg.collection_name,
                vectors_config=VectorParams(size=EMBEDDING_DIM, distance=Distance.COSINE),
            )

        point_ids: list[str] = []
        points: list = []
        for idx, (chunk, vector) in enumerate(zip(chunks, embeddings)):
            pid = str(uuid.uuid4())
            point_ids.append(pid)
            points.append(PointStruct(
                id=pid,
                vector=vector,
                payload={
                    "text": chunk[:2000],
                    "source_file": file_name,
                    "source_path": file_path,
                    "chunk_index": idx,
                    "ingested_at": datetime.utcnow().isoformat() + "Z",
                    "category": "research_report",
                },
            ))

        # Push in batches of 50 to keep individual requests bounded.
        step = 50
        for start in range(0, len(points), step):
            client.upsert(collection_name=cfg.collection_name, points=points[start:start + step])

        logger.info("Upserted %d chunks to Qdrant for: %s", len(points), file_name)
        return point_ids

    except Exception as e:
        logger.error("Qdrant upsert failed for %s: %s", file_name, e)
        return []


# ---------------------------------------------------------------------------
# Main ingest function
# ---------------------------------------------------------------------------

def ingest_file(file_path: Path, conn=None) -> dict:
    """Ingest a single file: hash, dedup-check, extract, chunk, embed, upsert.

    Args:
        file_path: File to ingest (.md/.txt/.pdf/.docx expected).
        conn: Optional open PostgreSQL connection. When None, a connection
            is opened (and closed) internally.

    Returns:
        Result dict with keys file / file_name / timestamp / status /
        chunk_count / qdrant_ids (first 5 IDs only). status is one of
        "ingested", "skipped", "empty", "error".
    """
    close_conn = False
    if conn is None:
        conn = _get_pg_conn()
        close_conn = True
        if conn is None:
            return {"status": "error", "error": "No database connection"}
        _ensure_ingest_table(conn)

    result = {
        "file": str(file_path),
        "file_name": file_path.name,
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "status": "skipped",
        "chunk_count": 0,
        "qdrant_ids": [],
    }

    try:
        # Short content hash for change detection: 16 hex chars (64 bits of
        # SHA-256) is ample to distinguish edits of the same path.
        file_hash = hashlib.sha256(file_path.read_bytes()).hexdigest()[:16]

        # Skip files already ingested with identical content.
        if _is_already_ingested(conn, str(file_path), file_hash):
            logger.debug("Already ingested (unchanged): %s", file_path.name)
            result["status"] = "skipped"
            return result

        text = _extract_text(file_path)
        if not text or len(text.strip()) < 100:
            logger.info("No extractable text in: %s", file_path.name)
            result["status"] = "empty"
            return result

        chunks = _chunk_text(text)
        if not chunks:
            result["status"] = "empty"
            return result

        logger.info("Ingesting %s: %d chunks from %d chars", file_path.name, len(chunks), len(text))

        # Embed in small batches to stay within API payload limits.
        all_embeddings = []
        batch_size = 20
        for i in range(0, len(chunks), batch_size):
            all_embeddings.extend(_get_embeddings(chunks[i:i + batch_size]))

        point_ids = _upsert_to_qdrant(chunks, all_embeddings, str(file_path), file_path.name)

        # BUG FIX: the original recorded status 'done' (and reported
        # "ingested") even when the Qdrant upsert failed and returned [],
        # so a transient failure permanently skipped the file. Treat an
        # empty ID list for non-empty chunks as an error and record
        # nothing, so the next scan retries it.
        if not point_ids:
            result["status"] = "error"
            result["error"] = "Qdrant upsert returned no point IDs"
            return result

        # Mark as done in PostgreSQL for future dedup.
        _record_ingest(conn, str(file_path), file_hash, len(chunks), point_ids)

        # Append an audit entry to the knowledge-graph JSONL log.
        KG_ENTITIES_DIR.mkdir(parents=True, exist_ok=True)
        log_entry = {
            "id": f"ingest_{file_hash}",
            "timestamp": result["timestamp"],
            "file_name": file_path.name,
            "file_path": str(file_path),
            "chunk_count": len(chunks),
            "qdrant_point_count": len(point_ids),
            "text_length": len(text),
        }
        with open(INGEST_LOG, "a", encoding="utf-8") as f:
            f.write(json.dumps(log_entry) + "\n")

        result["status"] = "ingested"
        result["chunk_count"] = len(chunks)
        result["qdrant_ids"] = point_ids[:5]  # sample only, to keep the dict small
        return result

    except Exception as e:
        logger.error("Failed to ingest %s: %s", file_path, e)
        result["status"] = "error"
        result["error"] = str(e)
        return result
    finally:
        # Close only connections we opened ourselves.
        if close_conn and conn:
            try:
                conn.close()
            except Exception:
                pass


# ---------------------------------------------------------------------------
# Directory scanner
# ---------------------------------------------------------------------------

def scan_and_ingest(directory: Path = RESEARCH_DIR) -> dict:
    """Recursively scan *directory* and ingest any new/changed files.

    Files with unsupported extensions or under 100 bytes are ignored.

    Returns:
        Counter dict: {"scanned", "ingested", "skipped", "errors"}.
    """
    stats = {"scanned": 0, "ingested": 0, "skipped": 0, "errors": 0}
    if not directory.exists():
        logger.warning("Research reports directory not found: %s", directory)
        return stats

    conn = _get_pg_conn()
    # BUG FIX: the original closed the connection only on the happy path;
    # an exception during the scan leaked it. try/finally guarantees close.
    try:
        if conn:
            _ensure_ingest_table(conn)

        for file_path in directory.rglob("*"):
            if file_path.suffix.lower() not in WATCH_EXTENSIONS:
                continue
            try:
                # Files can vanish between rglob() and stat() — skip them.
                if file_path.stat().st_size < 100:
                    continue  # skip near-empty files
            except OSError:
                continue

            stats["scanned"] += 1
            result = ingest_file(file_path, conn=conn)

            status = result["status"]
            if status == "ingested":
                stats["ingested"] += 1
            elif status == "skipped":
                stats["skipped"] += 1
            elif status == "error":
                stats["errors"] += 1
    finally:
        if conn:
            try:
                conn.close()
            except Exception:
                pass

    logger.info(
        "Scan complete: %d scanned, %d ingested, %d skipped, %d errors",
        stats["scanned"], stats["ingested"], stats["skipped"], stats["errors"]
    )
    return stats


# ---------------------------------------------------------------------------
# Daemon mode - watches for new files
# ---------------------------------------------------------------------------

def watch_daemon(directory: Path = RESEARCH_DIR, interval: int = POLL_INTERVAL_SECONDS):
    """Run forever, polling *directory* every *interval* seconds for files.

    New or modified files are detected by in-memory mtime tracking; the
    PostgreSQL hash dedup still prevents re-embedding unchanged content
    across daemon restarts. Stops on KeyboardInterrupt.
    """
    logger.info("Memory Auto-Ingest daemon started (watching: %s, interval: %ds)", directory, interval)

    # Map of file path -> last observed mtime.
    known_files: dict = {}

    while True:
        try:
            if not directory.exists():
                logger.warning("Watch directory missing: %s", directory)
                time.sleep(interval)
                continue

            conn = _get_pg_conn()
            # BUG FIX: the original closed the connection only on the happy
            # path, leaking one connection per iteration whenever an
            # exception escaped between open and close.
            try:
                if conn:
                    _ensure_ingest_table(conn)

                new_or_changed = []
                for file_path in directory.rglob("*"):
                    if file_path.suffix.lower() not in WATCH_EXTENSIONS:
                        continue
                    try:
                        st = file_path.stat()
                    except OSError:
                        continue  # file vanished between rglob() and stat()
                    if st.st_size < 100:
                        continue

                    fp_str = str(file_path)
                    # .get() returns None for unseen paths, so both "new"
                    # and "mtime changed" fall through to ingestion.
                    if known_files.get(fp_str) != st.st_mtime:
                        new_or_changed.append(file_path)
                        known_files[fp_str] = st.st_mtime

                if new_or_changed:
                    logger.info("Found %d new/changed files to process", len(new_or_changed))
                    for file_path in new_or_changed:
                        result = ingest_file(file_path, conn=conn)
                        logger.info("  %s -> %s (%d chunks)", file_path.name, result["status"], result.get("chunk_count", 0))
            finally:
                if conn:
                    try:
                        conn.close()
                    except Exception:
                        pass

        except KeyboardInterrupt:
            logger.info("Daemon stopped by user.")
            break
        except Exception as e:
            logger.error("Daemon error: %s", e)

        try:
            time.sleep(interval)
        except KeyboardInterrupt:
            logger.info("Daemon stopped by user.")
            break


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main():
    """CLI entry point: single-file ingest, one-shot scan, or daemon mode."""
    parser = argparse.ArgumentParser(description="Genesis Memory Auto-Ingest Pipeline")
    parser.add_argument("--scan-once", action="store_true", help="Scan once and exit")
    parser.add_argument("--file", type=str, help="Ingest a single specific file")
    parser.add_argument("--dir", type=str, default=str(RESEARCH_DIR), help="Directory to watch")
    parser.add_argument("--interval", type=int, default=POLL_INTERVAL_SECONDS, help="Poll interval (daemon mode)")
    args = parser.parse_args()

    watch_dir = Path(args.dir)

    # --file takes precedence over --scan-once; daemon mode is the default.
    if args.file:
        target = Path(args.file)
        if not target.exists():
            logger.error("File not found: %s", args.file)
            sys.exit(1)
        outcome = ingest_file(target)
        print(json.dumps(outcome, indent=2))
        sys.exit(0 if outcome["status"] in ("ingested", "skipped") else 1)

    if args.scan_once:
        print(json.dumps(scan_and_ingest(watch_dir), indent=2))
        sys.exit(0)

    watch_daemon(watch_dir, interval=args.interval)


if __name__ == "__main__":
    main()
