"""
Genesis Memory System - Embedding Generator
============================================
Generates vector embeddings using OpenAI's text-embedding-3-large.
Includes caching via Redis for efficiency.
"""

import anthropic
import hashlib
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import os

# Try to import OpenAI for embeddings
try:
    import openai
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

# Import Redis store for caching
import sys
sys.path.insert(0, '/mnt/e/genesis-system/genesis-memory')
from storage.redis_store import get_redis_store


@dataclass
class EmbeddingResult:
    """Result from embedding generation: the vector plus provenance metadata."""
    # The embedding vector itself (unit-normalized for the hash fallback).
    embedding: List[float]
    # SHA-256 hex digest of the source content; doubles as the cache key.
    content_hash: str
    # Identifier of what produced the vector: the OpenAI model name,
    # "cached" on a cache hit, or "hash_fallback" for the trigram fallback.
    model: str
    # Number of components in `embedding`.
    dimensions: int
    # True when the vector was served from the Redis cache.
    cached: bool

class EmbeddingGenerator:
    """
    Generates embeddings for memory content.

    Primary: OpenAI text-embedding-3-large (3072 dimensions)
    Fallback: Trigram hash-based embeddings (for testing without API)

    Features:
    - Redis caching to avoid redundant API calls
    - Batch processing for efficiency
    - Automatic fallback on errors
    """

    # Models
    OPENAI_MODEL = "text-embedding-3-large"
    OPENAI_DIMENSIONS = 3072

    def __init__(
        self,
        openai_api_key: Optional[str] = None,
        use_cache: bool = True,
        fallback_to_hash: bool = True
    ):
        self.openai_api_key = openai_api_key or os.environ.get("OPENAI_API_KEY")
        self.use_cache = use_cache
        self.fallback_to_hash = fallback_to_hash

        # Initialize OpenAI client if available
        self.openai_client = None
        if OPENAI_AVAILABLE and self.openai_api_key:
            self.openai_client = openai.OpenAI(api_key=self.openai_api_key)

        # Initialize Redis for caching
        self.redis = None
        if use_cache:
            try:
                self.redis = get_redis_store()
            except Exception:
                self.redis = None

    def _get_content_hash(self, content: str) -> str:
        """Generate hash for content (used as cache key)."""
        return hashlib.sha256(content.encode()).hexdigest()

    def _check_cache(self, content_hash: str) -> Optional[List[float]]:
        """Check if embedding is cached."""
        if self.redis:
            return self.redis.get_cached_embedding(content_hash)
        return None

    def _store_cache(self, content_hash: str, embedding: List[float]) -> bool:
        """Store embedding in cache."""
        if self.redis:
            return self.redis.cache_embedding(content_hash, embedding)
        return False

    def _generate_openai_embedding(self, content: str) -> List[float]:
        """Generate embedding using OpenAI API."""
        if not self.openai_client:
            raise ValueError("OpenAI client not initialized")

        response = self.openai_client.embeddings.create(
            model=self.OPENAI_MODEL,
            input=content,
            dimensions=self.OPENAI_DIMENSIONS
        )

        return response.data[0].embedding

    def _generate_hash_embedding(self, content: str, dimensions: int = 3072) -> List[float]:
        """
        Generate a deterministic hash-based embedding.

        This is a fallback for when OpenAI is unavailable.
        Uses character trigrams and hashing for reproducible vectors.
        """
        # Normalize content
        content = content.lower().strip()

        # Generate trigrams
        trigrams = []
        for i in range(len(content) - 2):
            trigrams.append(content[i:i+3])

        # Create embedding vector
        embedding = [0.0] * dimensions

        for trigram in trigrams:
            # Hash trigram to get index positions
            h = hashlib.md5(trigram.encode()).hexdigest()

            for i in range(0, len(h), 4):
                idx = int(h[i:i+4], 16) % dimensions
                # Add contribution based on position in hash
                sign = 1 if int(h[i:i+2], 16) % 2 == 0 else -1
                embedding[idx] += sign * 0.1

        # Normalize to unit vector
        magnitude = sum(x**2 for x in embedding) ** 0.5
        if magnitude > 0:
            embedding = [x / magnitude for x in embedding]

        return embedding

    def generate(
        self,
        content: str,
        use_openai: bool = True
    ) -> EmbeddingResult:
        """
        Generate embedding for content.

        Args:
            content: Text to embed
            use_openai: Whether to try OpenAI API first

        Returns:
            EmbeddingResult with embedding vector and metadata
        """
        content_hash = self._get_content_hash(content)

        # Check cache first
        cached = self._check_cache(content_hash)
        if cached:
            return EmbeddingResult(
                embedding=cached,
                content_hash=content_hash,
                model="cached",
                dimensions=len(cached),
                cached=True
            )

        # Try OpenAI if available and requested
        if use_openai and self.openai_client:
            try:
                embedding = self._generate_openai_embedding(content)

                # Cache result
                self._store_cache(content_hash, embedding)

                return EmbeddingResult(
                    embedding=embedding,
                    content_hash=content_hash,
                    model=self.OPENAI_MODEL,
                    dimensions=self.OPENAI_DIMENSIONS,
                    cached=False
                )
            except Exception as e:
                if not self.fallback_to_hash:
                    raise e
                # Fall through to hash embedding

        # Fallback to hash-based embedding
        embedding = self._generate_hash_embedding(content, self.OPENAI_DIMENSIONS)

        # Cache result
        self._store_cache(content_hash, embedding)

        return EmbeddingResult(
            embedding=embedding,
            content_hash=content_hash,
            model="hash_fallback",
            dimensions=self.OPENAI_DIMENSIONS,
            cached=False
        )

    def generate_batch(
        self,
        contents: List[str],
        use_openai: bool = True
    ) -> List[EmbeddingResult]:
        """
        Generate embeddings for multiple contents.

        Optimizes by batching OpenAI requests and checking cache first.
        """
        results = []
        to_embed = []
        to_embed_indices = []

        # Check cache for each content
        for i, content in enumerate(contents):
            content_hash = self._get_content_hash(content)
            cached = self._check_cache(content_hash)

            if cached:
                results.append(EmbeddingResult(
                    embedding=cached,
                    content_hash=content_hash,
                    model="cached",
                    dimensions=len(cached),
                    cached=True
                ))
            else:
                results.append(None)  # Placeholder
                to_embed.append(content)
                to_embed_indices.append(i)

        # Generate embeddings for uncached content
        if to_embed and use_openai and self.openai_client:
            try:
                # OpenAI batch embedding
                response = self.openai_client.embeddings.create(
                    model=self.OPENAI_MODEL,
                    input=to_embed,
                    dimensions=self.OPENAI_DIMENSIONS
                )

                for j, embedding_data in enumerate(response.data):
                    idx = to_embed_indices[j]
                    content = contents[idx]
                    content_hash = self._get_content_hash(content)
                    embedding = embedding_data.embedding

                    # Cache
                    self._store_cache(content_hash, embedding)

                    results[idx] = EmbeddingResult(
                        embedding=embedding,
                        content_hash=content_hash,
                        model=self.OPENAI_MODEL,
                        dimensions=self.OPENAI_DIMENSIONS,
                        cached=False
                    )

                return results

            except Exception:
                # Fall through to hash fallback
                pass

        # Fallback: generate hash embeddings for remaining
        for j, content in enumerate(to_embed):
            idx = to_embed_indices[j]
            content_hash = self._get_content_hash(content)
            embedding = self._generate_hash_embedding(content, self.OPENAI_DIMENSIONS)

            # Cache
            self._store_cache(content_hash, embedding)

            results[idx] = EmbeddingResult(
                embedding=embedding,
                content_hash=content_hash,
                model="hash_fallback",
                dimensions=self.OPENAI_DIMENSIONS,
                cached=False
            )

        return results

    def similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Calculate cosine similarity between two embeddings."""
        if len(embedding1) != len(embedding2):
            raise ValueError("Embeddings must have same dimensions")

        dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
        magnitude1 = sum(a**2 for a in embedding1) ** 0.5
        magnitude2 = sum(b**2 for b in embedding2) ** 0.5

        if magnitude1 == 0 or magnitude2 == 0:
            return 0.0

        return dot_product / (magnitude1 * magnitude2)


# Lazily-constructed module-level singleton.
_generator = None

def get_embedding_generator() -> EmbeddingGenerator:
    """Return the shared EmbeddingGenerator, building it on first use."""
    global _generator
    if _generator is not None:
        return _generator
    _generator = EmbeddingGenerator()
    return _generator


if __name__ == "__main__":
    def _demo() -> None:
        """Smoke-test the single, similarity, and batch embedding paths."""
        generator = get_embedding_generator()

        print("Embedding Generator Test")
        print("=" * 50)

        # Single embedding
        first = generator.generate(
            "Genesis Memory System is a world-class AI memory architecture."
        )
        print(f"Model: {first.model}")
        print(f"Dimensions: {first.dimensions}")
        print(f"Cached: {first.cached}")
        print(f"First 5 values: {first.embedding[:5]}")

        # Cosine similarity against a semantically related sentence
        second = generator.generate(
            "Advanced memory system for artificial intelligence."
        )
        sim = generator.similarity(first.embedding, second.embedding)
        print(f"\nSimilarity to related text: {sim:.4f}")

        # Batch path
        batch_results = generator.generate_batch([
            "First test content",
            "Second test content",
            "Third test content",
        ])
        print(f"\nBatch results: {len(batch_results)} embeddings generated")

    _demo()