"""
Genesis Memory System - Qdrant Vector Store
============================================
Handles vector embeddings and semantic similarity search.
Built on Elestio Qdrant with HNSW indexing.
"""

import hashlib
import json
import os
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional, Tuple

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import (
    Distance, VectorParams, PointStruct,
    Filter, FieldCondition, MatchValue, Range,
    SearchRequest, RecommendRequest
)


@dataclass
class VectorSearchResult:
    """A single hit returned by a vector similarity search.

    Mirrors a Qdrant scored point: the point id, a short human-readable
    rendering of the matched content, the similarity score, and the full
    stored payload.
    """

    # Qdrant point id, stringified.
    id: str
    # Human-readable preview (e.g. content_preview / name / description).
    content: str
    # Similarity score (cosine similarity for all collections here).
    score: float
    # Complete payload stored alongside the vector.
    payload: Dict[str, Any]


class QdrantStore:
    """
    Qdrant-backed vector memory store.

    Collections:
    - episodic_embeddings: Episode content vectors
    - entity_embeddings: Semantic entity vectors
    - community_summaries: Community summary vectors

    Features:
    - HNSW-based approximate nearest neighbor search
    - Hybrid search with payload filtering
    - Batch vector operations
    - Quantization for memory efficiency
    """

    # Collection configurations. Every collection stores 3072-dim vectors
    # (OpenAI text-embedding-3-large) compared with cosine distance.
    COLLECTIONS = {
        "episodic_embeddings": {
            "size": 3072,  # text-embedding-3-large
            "distance": Distance.COSINE
        },
        "entity_embeddings": {
            "size": 3072,
            "distance": Distance.COSINE
        },
        "community_summaries": {
            "size": 3072,
            "distance": Distance.COSINE
        }
    }

    def __init__(
        self,
        host: str = "qdrant-b3knu-u50607.vm.elestio.app",
        port: int = 6333,
        api_key: Optional[str] = None,
        https: bool = True
    ):
        """Connect to Qdrant and ensure all collections exist.

        Args:
            host: Qdrant server hostname.
            port: Qdrant HTTP port.
            api_key: API key. When omitted, it is read from the
                QDRANT_API_KEY environment variable, falling back to the
                legacy embedded key for backward compatibility.
            https: Use HTTPS when True, plain HTTP otherwise.
        """
        # SECURITY: an API key was previously hard-coded here and has been
        # committed to source control -- it should be rotated. It remains
        # only as a last-resort fallback so existing deployments keep
        # working; prefer setting QDRANT_API_KEY or passing api_key.
        if api_key is None:
            api_key = os.environ.get(
                "QDRANT_API_KEY",
                "7b74e6621bd0e6650789f6662bca4cbf4143d3d1d710a0002b3b563973ca6876"
            )

        scheme = "https" if https else "http"
        self.client = QdrantClient(
            url=f"{scheme}://{host}:{port}",
            api_key=api_key,
            timeout=30
        )

        # Ensure collections exist before any read/write hits them.
        self._ensure_collections()

    def _ensure_collections(self):
        """Create any collection from COLLECTIONS that doesn't exist yet."""
        existing = {c.name for c in self.client.get_collections().collections}

        for name, config in self.COLLECTIONS.items():
            if name not in existing:
                self.client.create_collection(
                    collection_name=name,
                    vectors_config=VectorParams(
                        size=config["size"],
                        distance=config["distance"]
                    ),
                    # HNSW graph parameters: m controls graph connectivity,
                    # ef_construct the build-time search depth.
                    hnsw_config=models.HnswConfigDiff(
                        m=16,
                        ef_construct=100
                    ),
                    # Defer HNSW indexing until the collection has enough
                    # points to make the index worthwhile.
                    optimizers_config=models.OptimizersConfigDiff(
                        indexing_threshold=20000
                    )
                )

    # =========================================================================
    # INTERNAL HELPERS
    # =========================================================================

    @staticmethod
    def _utc_now_iso() -> str:
        """Current UTC time as a timezone-aware ISO-8601 string.

        Replaces the deprecated naive ``datetime.utcnow()``; the value now
        carries an explicit "+00:00" offset, which ``fromisoformat`` parses
        the same way as the old naive form.
        """
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _to_results(hits, content_key: str) -> List[VectorSearchResult]:
        """Convert raw Qdrant scored points into VectorSearchResult objects.

        Args:
            hits: Scored points returned by search/recommend.
            content_key: Payload field to surface as the readable content.
        """
        # Defensive `or {}`: guards against a missing payload so callers
        # never hit an AttributeError on hit.payload.get(...).
        return [
            VectorSearchResult(
                id=str(hit.id),
                content=(hit.payload or {}).get(content_key, ""),
                score=hit.score,
                payload=hit.payload or {}
            )
            for hit in hits
        ]

    # =========================================================================
    # EPISODIC EMBEDDINGS
    # =========================================================================

    def store_episode_embedding(
        self,
        episode_id: str,
        embedding: List[float],
        content: str,
        source_type: str = "observation",
        agent_id: Optional[str] = None,
        importance: float = 0.5,
        strength: float = 1.0,
        metadata: Optional[Dict] = None
    ) -> bool:
        """Store an episode's embedding vector.

        Args:
            episode_id: Point id (also duplicated into the payload).
            embedding: 3072-dim embedding vector.
            content: Full episode text; only the first 500 chars are stored
                as a payload preview.
            source_type: Origin of the episode (e.g. "observation").
            agent_id: Owning agent; defaults to "unknown" in the payload.
            importance: Relevance weight in [0, 1].
            strength: Memory strength (decayable), defaults to 1.0.
            metadata: Extra payload fields, merged last so they can
                override the defaults above.

        Returns:
            True (upsert raises on failure).
        """
        point = PointStruct(
            id=episode_id,
            vector=embedding,
            payload={
                "episode_id": episode_id,
                "content_preview": content[:500],
                "source_type": source_type,
                "agent_id": agent_id or "unknown",
                "importance": importance,
                "strength": strength,
                "created_at": self._utc_now_iso(),
                "status": "active",
                **(metadata or {})
            }
        )

        self.client.upsert(
            collection_name="episodic_embeddings",
            points=[point]
        )

        return True

    def search_episodes(
        self,
        query_embedding: List[float],
        limit: int = 10,
        min_score: float = 0.0,
        source_type: Optional[str] = None,
        agent_id: Optional[str] = None,
        min_importance: float = 0.0,
        min_strength: float = 0.0,
        status: str = "active"
    ) -> List[VectorSearchResult]:
        """
        Search for similar episodes by embedding.

        All filter arguments are combined with AND semantics; the status
        filter is always applied (default "active").

        Returns results sorted by similarity score.
        """
        # Build filter conditions
        conditions = []

        if source_type:
            conditions.append(
                FieldCondition(key="source_type", match=MatchValue(value=source_type))
            )

        if agent_id:
            conditions.append(
                FieldCondition(key="agent_id", match=MatchValue(value=agent_id))
            )

        if min_importance > 0:
            conditions.append(
                FieldCondition(key="importance", range=Range(gte=min_importance))
            )

        if min_strength > 0:
            conditions.append(
                FieldCondition(key="strength", range=Range(gte=min_strength))
            )

        # The status condition is unconditional, so the filter is never empty.
        conditions.append(
            FieldCondition(key="status", match=MatchValue(value=status))
        )

        results = self.client.search(
            collection_name="episodic_embeddings",
            query_vector=query_embedding,
            query_filter=Filter(must=conditions),
            limit=limit,
            score_threshold=min_score,
            with_payload=True
        )

        return self._to_results(results, "content_preview")

    def update_episode_status(
        self,
        episode_id: str,
        status: str = "active",
        strength: Optional[float] = None
    ) -> bool:
        """Update an episode's status (and optionally strength) in place.

        Only the given payload keys are touched; the vector is unchanged.
        """
        payload_update: Dict[str, Any] = {"status": status}

        if strength is not None:
            payload_update["strength"] = strength

        self.client.set_payload(
            collection_name="episodic_embeddings",
            payload=payload_update,
            points=[episode_id]
        )

        return True

    def delete_episode(self, episode_id: str) -> bool:
        """Delete an episode's point from the vector store."""
        self.client.delete(
            collection_name="episodic_embeddings",
            points_selector=models.PointIdsList(points=[episode_id])
        )
        return True

    # =========================================================================
    # ENTITY EMBEDDINGS
    # =========================================================================

    def store_entity_embedding(
        self,
        entity_id: str,
        embedding: List[float],
        name: str,
        entity_type: str,
        properties: Optional[Dict] = None,
        importance: float = 0.5
    ) -> bool:
        """Store a semantic entity's embedding vector.

        Args:
            entity_id: Point id (also duplicated into the payload).
            embedding: 3072-dim embedding vector.
            name: Entity display name.
            entity_type: Entity category used for filtered search.
            properties: Arbitrary structured attributes.
            importance: Relevance weight in [0, 1].

        Returns:
            True (upsert raises on failure).
        """
        point = PointStruct(
            id=entity_id,
            vector=embedding,
            payload={
                "entity_id": entity_id,
                "name": name,
                "entity_type": entity_type,
                "properties": properties or {},
                "importance": importance,
                "created_at": self._utc_now_iso()
            }
        )

        self.client.upsert(
            collection_name="entity_embeddings",
            points=[point]
        )

        return True

    def search_entities(
        self,
        query_embedding: List[float],
        limit: int = 10,
        entity_type: Optional[str] = None,
        min_importance: float = 0.0
    ) -> List[VectorSearchResult]:
        """Search for similar entities by embedding.

        Optional filters (AND semantics): exact entity_type match and a
        minimum importance threshold.
        """
        conditions = []

        if entity_type:
            conditions.append(
                FieldCondition(key="entity_type", match=MatchValue(value=entity_type))
            )

        if min_importance > 0:
            conditions.append(
                FieldCondition(key="importance", range=Range(gte=min_importance))
            )

        query_filter = Filter(must=conditions) if conditions else None

        results = self.client.search(
            collection_name="entity_embeddings",
            query_vector=query_embedding,
            query_filter=query_filter,
            limit=limit,
            with_payload=True
        )

        return self._to_results(results, "name")

    def find_related_entities(
        self,
        entity_id: str,
        limit: int = 10
    ) -> List[VectorSearchResult]:
        """Find entities similar to a given entity (Qdrant recommend API).

        Uses the stored vector of `entity_id` as the positive example.
        """
        results = self.client.recommend(
            collection_name="entity_embeddings",
            positive=[entity_id],
            limit=limit,
            with_payload=True
        )

        return self._to_results(results, "name")

    # =========================================================================
    # COMMUNITY SUMMARIES
    # =========================================================================

    def store_community_embedding(
        self,
        community_id: str,
        embedding: List[float],
        name: str,
        description: str,
        member_count: int = 0
    ) -> bool:
        """Store a community summary embedding.

        Only the first 500 chars of the description are kept in the payload.
        """
        point = PointStruct(
            id=community_id,
            vector=embedding,
            payload={
                "community_id": community_id,
                "name": name,
                "description": description[:500],
                "member_count": member_count,
                "created_at": self._utc_now_iso()
            }
        )

        self.client.upsert(
            collection_name="community_summaries",
            points=[point]
        )

        return True

    def search_communities(
        self,
        query_embedding: List[float],
        limit: int = 5
    ) -> List[VectorSearchResult]:
        """Search for relevant communities by embedding (no payload filter)."""
        results = self.client.search(
            collection_name="community_summaries",
            query_vector=query_embedding,
            limit=limit,
            with_payload=True
        )

        return self._to_results(results, "description")

    # =========================================================================
    # BATCH OPERATIONS
    # =========================================================================

    def batch_store_episodes(
        self,
        episodes: List[Tuple[str, List[float], Dict]]
    ) -> int:
        """
        Batch store multiple episode embeddings in one upsert.

        Args:
            episodes: List of (episode_id, embedding, payload) tuples.
                Caller payload keys are merged last and may override the
                default "status" and "created_at" fields.

        Returns:
            Number of episodes stored (0 for an empty input).
        """
        # Single timestamp for the whole batch keeps entries consistent.
        now = self._utc_now_iso()

        points = [
            PointStruct(
                id=episode_id,
                vector=embedding,
                payload={
                    "episode_id": episode_id,
                    "status": "active",
                    "created_at": now,
                    **payload
                }
            )
            for episode_id, embedding, payload in episodes
        ]

        if points:
            self.client.upsert(
                collection_name="episodic_embeddings",
                points=points
            )

        return len(points)

    def batch_delete_episodes(self, episode_ids: List[str]) -> int:
        """Batch delete multiple episodes; returns the number requested."""
        if episode_ids:
            self.client.delete(
                collection_name="episodic_embeddings",
                points_selector=models.PointIdsList(points=episode_ids)
            )
        return len(episode_ids)

    # =========================================================================
    # HYBRID SEARCH
    # =========================================================================

    def hybrid_search(
        self,
        query_embedding: List[float],
        keywords: Optional[List[str]] = None,
        limit: int = 10,
        vector_weight: float = 0.7,
        keyword_weight: float = 0.3
    ) -> List[VectorSearchResult]:
        """
        Hybrid search combining vector similarity and keyword matching.

        The keyword score is the fraction of distinct keywords found in the
        (500-char) content preview, so it only approximates full-text match.

        Note: This is a simplified implementation. For production,
        consider using Qdrant's sparse vectors or external BM25.
        """
        # Vector search; over-fetch to leave room for keyword reranking.
        vector_results = self.search_episodes(
            query_embedding=query_embedding,
            limit=limit * 2
        )

        if not keywords:
            return vector_results[:limit]

        # Case-insensitive, de-duplicated keyword set.
        keyword_set = {k.lower() for k in keywords}

        scored_results = []
        for result in vector_results:
            content_lower = result.content.lower()
            keyword_matches = sum(1 for k in keyword_set if k in content_lower)
            keyword_score = min(keyword_matches / len(keyword_set), 1.0)

            # Weighted blend of vector similarity and keyword coverage.
            combined_score = (
                result.score * vector_weight +
                keyword_score * keyword_weight
            )

            scored_results.append((combined_score, result))

        # Rerank by the combined score, best first.
        scored_results.sort(key=lambda pair: pair[0], reverse=True)

        return [result for _, result in scored_results[:limit]]

    # =========================================================================
    # STATISTICS
    # =========================================================================

    def get_stats(self) -> Dict[str, Any]:
        """Get per-collection statistics; errors are reported per collection
        rather than raised, so one bad collection doesn't hide the rest."""
        stats: Dict[str, Any] = {}

        for name in self.COLLECTIONS:
            try:
                info = self.client.get_collection(name)
                stats[name] = {
                    "vectors_count": info.vectors_count,
                    "points_count": info.points_count,
                    "status": info.status.value,
                    "optimizer_status": info.optimizer_status.status.value
                }
            except Exception as e:
                stats[name] = {"error": str(e)}

        return stats

    def health_check(self) -> bool:
        """Return True if the Qdrant server answers a listing request."""
        try:
            self.client.get_collections()
            return True
        except Exception:
            return False

    def get_collection_count(self, collection_name: str) -> int:
        """Get the vector count in a collection (0 on error or when unset)."""
        try:
            info = self.client.get_collection(collection_name)
            return info.vectors_count or 0
        except Exception:
            return 0


# Module-level singleton, created lazily by get_qdrant_store().
_store = None

def get_qdrant_store() -> QdrantStore:
    """Return the process-wide QdrantStore, constructing it on first use.

    Not thread-safe: concurrent first calls may race to create the instance.
    """
    global _store
    if _store is None:
        _store = QdrantStore()
    return _store


if __name__ == "__main__":
    import random

    # Test the store
    store = get_qdrant_store()

    print("Qdrant Vector Store Test")
    print("=" * 50)

    # Generate a test embedding (random for demo)
    test_embedding = [random.random() for _ in range(3072)]

    # Store a test episode
    episode_id = str(uuid.uuid4())
    store.store_episode_embedding(
        episode_id=episode_id,
        embedding=test_embedding,
        content="Genesis Memory System test episode with vector embedding.",
        source_type="test",
        agent_id="test_agent",
        importance=0.8
    )
    print(f"Stored episode: {episode_id}")

    # Search (using same embedding should return high similarity)
    results = store.search_episodes(
        query_embedding=test_embedding,
        limit=5
    )
    print(f"Search results: {len(results)}")
    for r in results:
        print(f"  - {r.id}: score={r.score:.3f}")

    # Get stats
    stats = store.get_stats()
    print(f"\nStats: {stats}")
