"""# Generated from swarm story V-K04: Post-Call Memory Extraction
"""

"""
Post-Call Memory Extraction System.

This module provides asynchronous memory extraction from call transcripts using
Google Gemini Flash. It implements structured output parsing, content-based
deduplication, and confidence-based filtering.

Example:
    >>> extractor = MemoryExtractor(storage=RedisMemoryStore())
    >>> await extractor.process_call(
    ...     call_id="call_123",
    ...     transcript="User agreed to meet tomorrow at 3pm...",
    ...     metadata={"user_id": "user_456"}
    ... )
"""

from __future__ import annotations

import hashlib
import json
import logging
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Set, Union
from dataclasses import dataclass, field

import google.generativeai as genai
from google.api_core.exceptions import GoogleAPIError, RetryError
from pydantic import BaseModel, Field, validator, ValidationError
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Configure module-level logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class MemoryType(str, Enum):
    """Classification of memory types extracted from conversations.

    Inherits from ``str`` so members serialize and compare as plain strings
    (e.g. when parsed from the LLM's JSON output by pydantic).
    """
    ENTITY = "entity"           # People, organizations, locations, objects
    DECISION = "decision"       # Agreements, choices made
    PREFERENCE = "preference"   # Likes, dislikes, preferences stated
    ACTION_ITEM = "action_item" # Tasks, deadlines, commitments
    TOPIC = "topic"             # Subjects discussed


class MemoryItem(BaseModel):
    """
    Represents a single extracted memory with metadata.

    Attributes:
        content: The normalized fact/memory text
        memory_type: Classification of the memory
        confidence: Model confidence score (0.0-1.0)
        source_call_id: Reference to originating call (assigned by the extractor)
        timestamp: UTC extraction timestamp
        content_hash: SHA256 hash for deduplication (assigned by the extractor)
        metadata: Additional context (speakers, turn numbers, etc.)
    """

    content: str = Field(..., min_length=1, description="Extracted memory content")
    memory_type: MemoryType = Field(..., description="Type of memory")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    # BUGFIX: source_call_id and content_hash are populated by
    # MemoryExtractor._filter_and_deduplicate AFTER the LLM response is parsed;
    # the LLM output schema in _EXTRACTION_PROMPT does not include them. With
    # the previous required Field(...), ExtractionResult(**llm_json) raised
    # ValidationError on every response. They now default to "" and are
    # overwritten before storage.
    source_call_id: str = Field(default="", description="Source call identifier")
    # NOTE(review): datetime.utcnow() returns a naive datetime and is deprecated
    # as of Python 3.12 -- consider datetime.now(timezone.utc) if downstream
    # consumers can accept timezone-aware timestamps; verify before changing.
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    content_hash: str = Field(default="", description="SHA256 hash of normalized content")
    metadata: Dict[str, Any] = Field(default_factory=dict)

    @validator("content")
    def normalize_content(cls, v: str) -> str:
        """Normalize content (strip whitespace, lowercase) for consistent hashing."""
        return v.strip().lower()

    def __hash__(self) -> int:
        """Enable use in sets/dicts based on content hash.

        Note: items whose content_hash has not yet been assigned share the
        default "" and will collide; the extractor assigns hashes before any
        set/dict use.
        """
        return hash(self.content_hash)

    def __eq__(self, other: object) -> bool:
        """Equality based on content hash for deduplication."""
        if not isinstance(other, MemoryItem):
            return NotImplemented
        return self.content_hash == other.content_hash


class ExtractionResult(BaseModel):
    """Structured envelope returned by the LLM extraction step."""

    memories: List[MemoryItem] = Field(
        default_factory=list,
        description="List of extracted memories"
    )

    @validator("memories")
    def validate_confidence_threshold(cls, memories: List[MemoryItem]) -> List[MemoryItem]:
        """Pass-through hook: confidence filtering is deliberately deferred.

        The threshold is configurable per MemoryExtractor instance, so the
        actual filtering happens there rather than at parse time.
        """
        return memories


class MemoryRepository(ABC):
    """
    Interface for memory persistence backends.

    Concrete implementations must supply atomic check-and-set semantics so
    the extractor can deduplicate before storing.
    """

    @abstractmethod
    async def exists(self, content_hash: str) -> bool:
        """
        Report whether a memory with the given hash is already stored.

        Args:
            content_hash: SHA256 hex digest of the normalized content

        Returns:
            True when a matching memory exists, False otherwise
        """
        ...

    @abstractmethod
    async def store(self, memory: MemoryItem) -> None:
        """
        Persist a single memory item.

        Args:
            memory: Validated MemoryItem to store

        Raises:
            StorageError: If persistence fails
        """
        ...

    @abstractmethod
    async def get_by_call(self, call_id: str) -> List[MemoryItem]:
        """
        Fetch every memory recorded for one call.

        Args:
            call_id: The call identifier

        Returns:
            List of memories associated with the call
        """
        ...


class InMemoryRepository(MemoryRepository):
    """
    Dictionary-backed repository for tests and local development.

    All state lives in process memory and is lost on restart, so this is
    not suitable for production use.
    """

    def __init__(self) -> None:
        # content_hash -> stored MemoryItem
        self._storage: Dict[str, MemoryItem] = {}
        # call_id -> set of content hashes extracted from that call
        self._call_index: Dict[str, Set[str]] = {}

    async def exists(self, content_hash: str) -> bool:
        """Return True when a memory with this hash has been stored."""
        return content_hash in self._storage

    async def store(self, memory: MemoryItem) -> None:
        """Insert the memory and index it under its source call id."""
        self._storage[memory.content_hash] = memory
        self._call_index.setdefault(memory.source_call_id, set()).add(
            memory.content_hash
        )

    async def get_by_call(self, call_id: str) -> List[MemoryItem]:
        """Return every memory recorded for *call_id* (empty list if none)."""
        return [self._storage[h] for h in self._call_index.get(call_id, set())]


class ExtractionConfig(BaseModel):
    """
    Configuration for memory extraction behavior.

    Attributes:
        confidence_threshold: Minimum confidence to store (0.0-1.0)
        model_name: Gemini model identifier
        max_retries: API retry attempts for transient failures
        temperature: LLM sampling temperature
        timeout_seconds: API call timeout
    """

    # Memories scoring below this are dropped in _filter_and_deduplicate.
    confidence_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
    # Passed verbatim to genai.GenerativeModel in MemoryExtractor.__init__.
    model_name: str = Field(default="gemini-1.5-flash")
    # NOTE(review): declared but never consulted -- the @retry decorator on
    # _extract_with_llm hardcodes stop_after_attempt(3); confirm whether this
    # field is meant to be wired in.
    max_retries: int = Field(default=3, ge=1)
    # Low temperature keeps extraction output close to deterministic.
    temperature: float = Field(default=0.1, ge=0.0, le=2.0)
    # NOTE(review): not referenced anywhere in this module -- presumably
    # intended for the API call; verify against callers before relying on it.
    timeout_seconds: int = Field(default=30, ge=1)

    class Config:
        frozen = True  # Immutable configuration (pydantic v1 style)


class ExtractionError(Exception):
    """Root of the extraction exception hierarchy."""


class LLMExtractionError(ExtractionError):
    """The LLM call or the parsing of its response failed."""


class StorageError(ExtractionError):
    """Persisting an extracted memory failed."""


class MemoryExtractor:
    """
    Core orchestrator for post-call memory extraction.

    Handles the complete pipeline: preprocessing, LLM extraction,
    validation, deduplication, filtering, and storage.

    Attributes:
        repository: Storage backend for memories
        config: Extraction behavior configuration
        _gen_model: Initialized Gemini model instance
    """

    # System prompt engineered for structured extraction
    _EXTRACTION_PROMPT = """You are a precise memory extraction system. Analyze the provided conversation transcript and extract factual information as structured memories.

Extract the following types:
1. **entities**: People, companies, products, locations mentioned (e.g., "User works at Google")
2. **decisions**: Agreements, choices, conclusions reached (e.g., "Decided to use Python for the project")
3. **preferences**: Likes, dislikes, stated preferences (e.g., "Prefers morning meetings")
4. **action_items**: Tasks, deadlines, commitments (e.g., "Will send report by Friday")
5. **topics**: Subjects discussed (e.g., "Discussed Q4 budget planning")

For each memory:
- Provide normalized, factual content (3rd person, past tense)
- Assign confidence score (0.0-1.0) based on clarity and explicitness
- Classify into exactly one type

Output MUST be valid JSON matching this schema:
{
  "memories": [
    {
      "content": "string",
      "memory_type": "entity|decision|preference|action_item|topic",
      "confidence": float,
      "metadata": {"speakers": ["name"], "turn": int}
    }
  ]
}

Rules:
- Only extract explicit facts, not inferences
- Confidence < 0.7 for ambiguous or implied information
- Normalize names and dates to standard formats
- If no memories found, return empty array"""

    def __init__(
        self,
        repository: MemoryRepository,
        config: Optional[ExtractionConfig] = None,
        api_key: Optional[str] = None
    ) -> None:
        """
        Initialize the extractor.

        Args:
            repository: Storage backend implementation
            config: Optional custom configuration
            api_key: Google API key (or set via env var GOOGLE_API_KEY)
        """
        self.repository = repository
        self.config = config or ExtractionConfig()

        # Only configure the shared genai client when a key is supplied;
        # otherwise the library picks up GOOGLE_API_KEY from the environment.
        if api_key:
            genai.configure(api_key=api_key)

        self._gen_model = genai.GenerativeModel(
            model_name=self.config.model_name,
            generation_config={
                "temperature": self.config.temperature,
                "response_mime_type": "application/json",
            }
        )

        # Lazy %-style args avoid formatting cost when INFO is disabled.
        logger.info(
            "Initialized MemoryExtractor with model=%s, confidence_threshold=%s",
            self.config.model_name,
            self.confidence_threshold,
        )

    @property
    def confidence_threshold(self) -> float:
        """Current confidence filtering threshold."""
        return self.config.confidence_threshold

    def _compute_content_hash(self, content: str) -> str:
        """
        Generate SHA256 hash of normalized content for deduplication.

        Args:
            content: Raw memory content

        Returns:
            Hex digest of SHA256 hash
        """
        # Same normalization as MemoryItem.normalize_content, so hashing is
        # stable whether or not the content already passed through the model.
        normalized = content.strip().lower().encode("utf-8")
        return hashlib.sha256(normalized).hexdigest()

    @retry(
        retry=retry_if_exception_type((GoogleAPIError, RetryError)),
        # NOTE(review): the attempt count is fixed here because the decorator
        # is evaluated at class-definition time; self.config.max_retries is
        # therefore NOT consulted -- wiring it in would require
        # tenacity.AsyncRetrying inside the method body.
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=2, max=10),
        reraise=True
    )
    async def _extract_with_llm(
        self,
        transcript: str,
        metadata: Dict[str, Any]
    ) -> ExtractionResult:
        """
        Call Gemini Flash to extract structured memories.

        Args:
            transcript: Call transcript text
            metadata: Call metadata (participants, duration, etc.)

        Returns:
            Parsed ExtractionResult

        Raises:
            GoogleAPIError, RetryError: Propagated unchanged so the @retry
                policy can act on them (retries exhausted -> raised to caller)
            LLMExtractionError: If extraction or parsing fails
        """
        # Construct prompt with context
        prompt = f"{self._EXTRACTION_PROMPT}\n\nMetadata: {json.dumps(metadata)}\n\nTranscript:\n{transcript}"

        try:
            # The google-generativeai client is synchronous; off-load the
            # blocking call to the default executor so the loop stays live.
            # BUGFIX: get_running_loop() replaces the deprecated
            # get_event_loop(), which is unreliable inside coroutines.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self._gen_model.generate_content(
                    prompt,
                    generation_config={"response_mime_type": "application/json"}
                )
            )
            raw_text = response.text
        except (GoogleAPIError, RetryError):
            # BUGFIX: propagate transient API errors unchanged. The previous
            # blanket `except Exception` re-wrapped them as LLMExtractionError,
            # which defeated the retry_if_exception_type policy above.
            logger.warning("Transient Gemini API error; retrying if attempts remain")
            raise
        except Exception as e:
            logger.error("Unexpected error during extraction: %s", e)
            raise LLMExtractionError(f"Extraction failed: {e}") from e

        # Raised outside the try above so it is not re-wrapped by the
        # catch-all (the original double-wrapped this error).
        if not raw_text:
            raise LLMExtractionError("Empty response from LLM")

        try:
            data = json.loads(raw_text)
            return ExtractionResult(**data)
        except json.JSONDecodeError as e:
            logger.error("Failed to parse LLM response as JSON: %s", e)
            raise LLMExtractionError(f"Invalid JSON in LLM response: {e}") from e
        except ValidationError as e:
            logger.error("Schema validation failed: %s", e)
            raise LLMExtractionError(f"Schema mismatch: {e}") from e

    async def _filter_and_deduplicate(
        self,
        memories: List[MemoryItem],
        call_id: str
    ) -> List[MemoryItem]:
        """
        Filter by confidence and deduplicate against existing storage.

        Args:
            memories: Raw extracted memories
            call_id: Source call identifier

        Returns:
            Filtered list of unique, high-confidence memories
        """
        filtered: List[MemoryItem] = []
        # BUGFIX: track hashes seen within THIS batch. Storage happens after
        # filtering (in process_call), so repository.exists() alone cannot
        # catch two identical memories in the same LLM response.
        seen_hashes: Set[str] = set()

        for memory in memories:
            # Apply confidence threshold
            if memory.confidence < self.confidence_threshold:
                logger.debug(
                    "Filtered out low-confidence memory: %s... (%.2f < %s)",
                    memory.content[:50],
                    memory.confidence,
                    self.confidence_threshold,
                )
                continue

            # Compute hash for deduplication
            content_hash = self._compute_content_hash(memory.content)

            # Skip duplicates, both in-batch and already-persisted.
            if content_hash in seen_hashes or await self.repository.exists(content_hash):
                logger.debug("Duplicate detected, skipping: %s...", memory.content[:50])
                continue
            seen_hashes.add(content_hash)

            # Update memory with computed hash and call ID
            memory.content_hash = content_hash
            memory.source_call_id = call_id
            filtered.append(memory)

        return filtered

    async def process_call(
        self,
        call_id: str,
        transcript: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> List[MemoryItem]:
        """
        Execute full extraction pipeline for a call.

        This is the main entry point for processing call transcripts.

        Args:
            call_id: Unique identifier for the call
            transcript: Call transcript text
            metadata: Optional context (participants, timestamps, etc.)

        Returns:
            List of successfully stored memories

        Raises:
            ExtractionError: If extraction or storage fails
        """
        metadata = metadata or {}
        logger.info("Starting memory extraction for call %s", call_id)

        try:
            # Step 1: Extract from LLM
            extraction = await self._extract_with_llm(transcript, metadata)

            # Step 2: Filter and deduplicate
            valid_memories = await self._filter_and_deduplicate(
                extraction.memories,
                call_id
            )

            # Step 3: Store results
            stored = []
            for memory in valid_memories:
                try:
                    await self.repository.store(memory)
                    stored.append(memory)
                    logger.info(
                        "Stored memory [%s]: %s... (confidence: %.2f)",
                        memory.memory_type,
                        memory.content[:50],
                        memory.confidence,
                    )
                except Exception as e:
                    logger.error("Failed to store memory: %s", e)
                    raise StorageError(f"Failed to store memory: {e}") from e

            logger.info(
                "Extraction complete for %s: %d/%d memories stored",
                call_id,
                len(stored),
                len(extraction.memories),
            )
            return stored

        except (LLMExtractionError, StorageError):
            raise
        except (GoogleAPIError, RetryError) as e:
            # Retries inside _extract_with_llm are exhausted (reraise=True);
            # surface as the documented LLM failure type.
            logger.error("LLM extraction failed after retries for %s: %s", call_id, e)
            raise LLMExtractionError(f"Extraction failed: {e}") from e
        except Exception as e:
            logger.exception("Unexpected error processing call %s", call_id)
            raise ExtractionError(f"Processing failed: {e}") from e

    async def process_call_background(
        self,
        call_id: str,
        transcript: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> asyncio.Task[List[MemoryItem]]:
        """
        Schedule extraction as a background task.

        Use this method to trigger extraction without blocking
        the main flow after a call ends.

        Args:
            call_id: Unique identifier for the call
            transcript: Call transcript text
            metadata: Optional context

        Returns:
            Asyncio Task handle for monitoring/awaiting if needed
        """
        logger.info("Scheduling background extraction for call %s", call_id)

        async def _wrapped():
            # Log-and-reraise so failures are visible even when nobody
            # awaits the returned task.
            try:
                return await self.process_call(call_id, transcript, metadata)
            except Exception as e:
                logger.error("Background extraction failed for %s: %s", call_id, e)
                # In production, you might want to send to a dead-letter queue
                raise

        return asyncio.create_task(_wrapped())


# ============================================================================
# TESTING UTILITIES AND EXAMPLES
# ============================================================================

class MockGeminiClient:
    """Stand-in for the Gemini client that replays canned JSON responses."""

    def __init__(self, responses: Optional[List[Dict]] = None):
        # FIFO queue of dicts returned by successive generate_content calls;
        # once exhausted, an empty "memories" payload is returned.
        self.responses = responses or []
        self.call_count = 0

    def generate_content(self, prompt: str, **kwargs) -> Any:
        """Return the next canned payload wrapped in a minimal response object."""
        self.call_count += 1
        if self.responses:
            payload = self.responses.pop(0)
        else:
            payload = {"memories": []}

        class _CannedResponse:
            """Mimics the `.text` accessor of a real Gemini response."""

            def __init__(self, data):
                self._data = data

            @property
            def text(self):
                return json.dumps(self._data)

        return _CannedResponse(payload)


async def run_black_box_test():
    """
    Black Box Test: Complete call flow verification.

    Verifies:
    - Extraction runs after call ends
    - Memories are stored
    - No duplicates stored
    - Low-confidence items filtered
    """
    print("\n=== BLACK BOX TEST ===")

    repository = InMemoryRepository()
    extractor = MemoryExtractor(repository=repository)

    # Canned LLM output: two keepers, one low-confidence item, one duplicate.
    mock_response = {
        "memories": [
            {
                "content": "User prefers Python over JavaScript",
                "memory_type": "preference",
                "confidence": 0.95,
                "metadata": {"speakers": ["Alice", "Bob"]},
            },
            {
                # High confidence -- should be stored.
                "content": "Meeting scheduled for tomorrow 3pm",
                "memory_type": "action_item",
                "confidence": 0.85,
                "metadata": {},
            },
            {
                # Below the 0.7 default threshold -- should be filtered out.
                "content": "Maybe they like coffee",
                "memory_type": "preference",
                "confidence": 0.5,
                "metadata": {},
            },
            {
                # Same content as the first entry -- should be deduplicated.
                "content": "User prefers Python over JavaScript",
                "memory_type": "preference",
                "confidence": 0.92,
                "metadata": {},
            },
        ]
    }

    # Swap the real model for the mock before running the pipeline.
    extractor._gen_model = MockGeminiClient([mock_response])

    call_id = "test_call_001"
    transcript = "Alice: I prefer Python over JavaScript...\nBob: Great, meeting tomorrow at 3pm?"

    memories = await extractor.process_call(call_id, transcript)

    assert len(memories) == 2, f"Expected 2 memories, got {len(memories)}"
    assert all(m.confidence >= 0.7 for m in memories), "Low confidence not filtered"

    stored = await repository.get_by_call(call_id)
    assert len(stored) == 2, "Not all memories stored"

    # Replay the same content under a fresh call id: everything should dedupe.
    extractor._gen_model = MockGeminiClient([mock_response])
    memories_2 = await extractor.process_call("test_call_002", transcript)
    assert len(memories_2) == 0, "Duplicates not filtered"

    print("✓ Extraction runs successfully")
    print("✓ Memories stored correctly")
    print("✓ Low-confidence items filtered (0.5 < 0.7)")
    print("✓ Duplicates prevented")
    print(f"✓ Stored memories: {[m.content for m in stored]}")


async def run_white_box_test():
    """
    White Box Test: Internal implementation verification.

    Verifies:
    - Extraction prompt includes structured output format
    - Deduplication uses content hash (SHA256)
    - Confidence threshold is configurable
    """
    print("\n=== WHITE BOX TEST ===")

    repository = InMemoryRepository()
    extractor = MemoryExtractor(repository=repository)

    # Check 1: hashing normalizes case/whitespace and is SHA256-sized.
    digest_padded = extractor._compute_content_hash("  Test Content  ")
    digest_plain = extractor._compute_content_hash("test content")
    digest_other = extractor._compute_content_hash("different content")

    assert digest_padded == digest_plain, "Hash not normalized (case/whitespace)"
    assert digest_padded != digest_other, "Hash collision"
    assert len(digest_padded) == 64, "Not SHA256 (64 hex chars)"
    print("✓ Deduplication uses SHA256 content hash")
    print(f"  Hash: {digest_padded}")

    # Check 2: the confidence threshold honors custom configuration.
    strict_extractor = MemoryExtractor(
        repository=repository,
        config=ExtractionConfig(confidence_threshold=0.9),
    )
    assert strict_extractor.confidence_threshold == 0.9
    print("✓ Confidence threshold configurable (0.9)")

    # Check 3: the system prompt advertises the structured JSON contract.
    for fragment in ("application/json", "entity", "confidence", "memory_type"):
        assert fragment in extractor._EXTRACTION_PROMPT
    print("✓ Extraction prompt includes structured output format")

    # Check 4: background scheduling yields an awaitable asyncio.Task.
    extractor._gen_model = MockGeminiClient([{"memories": []}])
    task = await extractor.process_call_background(
        "call_123",
        "test transcript"
    )
    assert isinstance(task, asyncio.Task)
    await task  # Ensure completion
    print("✓ Background task creation works")


async def example_usage():
    """Demonstrate typical production usage patterns."""
    print("\n=== EXAMPLE USAGE ===")

    # Custom behavior knobs for this deployment.
    config = ExtractionConfig(
        confidence_threshold=0.75,
        model_name="gemini-1.5-flash",
        temperature=0.1,
    )

    # Use Redis or PostgreSQL in production
    repository = InMemoryRepository()
    extractor = MemoryExtractor(
        repository=repository,
        config=config,
        api_key="your-api-key-here",  # Or use env var GOOGLE_API_KEY
    )

    async def on_call_ended(call_id: str, transcript: str, metadata: dict):
        """Webhook or event handler for call completion."""
        # Fire and forget: schedule extraction without blocking call teardown.
        await extractor.process_call_background(
            call_id=call_id,
            transcript=transcript,
            metadata=metadata,
        )
        print(f"Scheduled extraction for {call_id}")

    sample_transcript = """
        Sarah: Hi John, do you still prefer Slack over email?
        John: Yes, definitely. Let's also schedule the review for next Tuesday.
        Sarah: Perfect, I'll send the calendar invite.
        """

    # Simulate a call-completion event.
    await on_call_ended(
        call_id="call_2024_001",
        transcript=sample_transcript,
        metadata={"participants": ["Sarah", "John"], "duration": 300},
    )

    # Give the fire-and-forget task a moment to finish before inspecting.
    await asyncio.sleep(0.1)

    memories = await repository.get_by_call("call_2024_001")
    print(f"Extracted {len(memories)} memories from example call")


if __name__ == "__main__":
    # Run tests
    asyncio.run(run_white_box_test())
    asyncio.run(run_black_box_test())
    asyncio.run(example_usage())
