#!/usr/bin/env python3
"""
Long Document Analyzer Skill for Genesis System

This skill processes documents with 1M+ tokens using chunking and synthesis
strategies inspired by Qwen-Long patterns. It implements a memory mechanism
to maintain context across chunks and synthesize coherent analysis.

Usage:
    python long_document_analyzer.py <document_path> [--query "your question"]

    Or import and use programmatically:
    from long_document_analyzer import LongDocumentAnalyzer
    analyzer = LongDocumentAnalyzer()
    results = analyzer.analyze("path/to/document.pdf", query="Summarize key findings")
"""

import os
import sys
import json
import hashlib
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any, Generator, Tuple
from dataclasses import dataclass, asdict, field
from pathlib import Path
from abc import ABC, abstractmethod
import re
from collections import deque

# Configure module-level logging once at import time; INFO level with
# timestamps so long-running analyses can be followed live.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# Token estimation (roughly 4 chars per token for English)
CHARS_PER_TOKEN = 4
DEFAULT_CHUNK_SIZE = 8000  # tokens per chunk
DEFAULT_OVERLAP = 500  # tokens of overlap carried between adjacent chunks
MAX_CONTEXT_WINDOW = 128000  # tokens (Claude-3 context); NOTE(review): not referenced anywhere in this module


@dataclass
class DocumentChunk:
    """Represents a chunk of a document.

    Produced by LongDocumentAnalyzer.chunk_document; positions are
    character offsets into the original document text.
    """
    index: int  # 0-based ordinal of the chunk within the document
    content: str  # chunk text (whitespace-stripped)
    start_position: int  # character offset where the chunk starts
    end_position: int  # character offset where the chunk ends (exclusive)
    token_count: int  # estimated token count (~4 chars per token)
    metadata: Dict[str, Any] = field(default_factory=dict)  # optional extra info


@dataclass
class ChunkAnalysis:
    """Analysis result for a single chunk.

    Produced by ChunkProcessor.process_chunk (rule-based extraction).
    """
    chunk_index: int  # index of the analyzed DocumentChunk
    summary: str  # short extractive summary of the chunk
    key_points: List[str]  # sentences containing importance keywords
    entities: List[str]  # capitalized word sequences (naive NER)
    relevance_score: float  # 0.0-1.0 word overlap between chunk and query
    extracted_facts: List[str]  # declarative sentences found in the chunk


@dataclass
class MemoryEntry:
    """Entry in the document memory system."""
    content: str  # the remembered text (summary sentence, fact, ...)
    importance: float  # retention weight; higher survives pruning
    chunk_index: int  # chunk the entry came from
    timestamp: str  # ISO-8601 creation time
    entry_type: str  # 'summary', 'fact', 'entity', 'insight'


@dataclass
class AnalysisResult:
    """Complete analysis result for a document."""
    document_path: str
    document_hash: str
    total_tokens: int
    total_chunks: int
    query: str
    synthesized_answer: str
    executive_summary: str
    key_findings: List[str]
    chunk_analyses: List[ChunkAnalysis]
    memory_entries: List[MemoryEntry]
    processing_time: float
    metadata: Dict[str, Any]

    def to_dict(self) -> Dict:
        """Convert to dictionary for serialization."""
        # Field names listed in declaration order so the resulting dict
        # preserves the same key ordering as the dataclass definition.
        names = (
            "document_path", "document_hash", "total_tokens", "total_chunks",
            "query", "synthesized_answer", "executive_summary", "key_findings",
            "chunk_analyses", "memory_entries", "processing_time", "metadata",
        )
        data = {name: getattr(self, name) for name in names}
        # The two nested dataclass lists are flattened into plain dicts;
        # all other values are passed through by reference.
        data["chunk_analyses"] = [asdict(entry) for entry in self.chunk_analyses]
        data["memory_entries"] = [asdict(entry) for entry in self.memory_entries]
        return data


class DocumentReader(ABC):
    """Interface that every concrete document reader must implement."""

    @abstractmethod
    def read(self, path: str) -> str:
        """Return the full text content of the document at *path*."""
        ...

    @abstractmethod
    def supports(self, path: str) -> bool:
        """Return True if this reader can handle the given file type."""
        ...


class TextReader(DocumentReader):
    """Reader for plain-text formats (txt, md, rst, log)."""

    # File extensions this reader claims; compared case-insensitively.
    _EXTENSIONS = ('.txt', '.md', '.rst', '.log')

    def read(self, path: str) -> str:
        with open(path, 'r', encoding='utf-8', errors='ignore') as handle:
            return handle.read()

    def supports(self, path: str) -> bool:
        return path.lower().endswith(self._EXTENSIONS)


class PDFReader(DocumentReader):
    """Reader for PDF files (requires the optional PyPDF2 dependency)."""

    def read(self, path: str) -> str:
        try:
            import PyPDF2
            with open(path, 'rb') as handle:
                document = PyPDF2.PdfReader(handle)
                # extract_text() may return None for image-only pages.
                pages = [page.extract_text() or "" for page in document.pages]
            return "\n\n".join(pages)
        except ImportError:
            logger.warning("PyPDF2 not installed. Install with: pip install PyPDF2")
            raise
        except Exception as e:
            logger.error(f"Failed to read PDF: {e}")
            raise

    def supports(self, path: str) -> bool:
        return path.lower().endswith('.pdf')


class JSONReader(DocumentReader):
    """Reader for JSON files; re-serializes with indentation so the
    text chunks along object/array boundaries."""

    def read(self, path: str) -> str:
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
        return json.dumps(parsed, indent=2)

    def supports(self, path: str) -> bool:
        return path.lower().endswith('.json')


class HTMLReader(DocumentReader):
    """Reader for HTML files; extracts visible text only (requires the
    optional beautifulsoup4 dependency)."""

    def read(self, path: str) -> str:
        try:
            from bs4 import BeautifulSoup
        except ImportError:
            logger.warning("BeautifulSoup not installed. Install with: pip install beautifulsoup4")
            raise
        with open(path, 'r', encoding='utf-8', errors='ignore') as handle:
            parsed = BeautifulSoup(handle.read(), 'html.parser')
        # Drop non-visible content before extracting text.
        for tag in parsed(['script', 'style']):
            tag.decompose()
        return parsed.get_text(separator='\n', strip=True)

    def supports(self, path: str) -> bool:
        return path.lower().endswith(('.html', '.htm'))


class MemorySystem:
    """
    Memory system for maintaining context across document chunks.

    Implements a working memory with importance-based retention,
    similar to how Qwen-Long manages long-context processing.
    """

    def __init__(self, max_entries: int = 100, importance_threshold: float = 0.3):
        """
        Initialize the memory system.

        Args:
            max_entries: Maximum number of memory entries to retain
            importance_threshold: Minimum importance score an entry needs
                to survive pruning once memory is full
        """
        self.max_entries = max_entries
        self.importance_threshold = importance_threshold
        # String annotation avoids evaluating the sibling class at runtime.
        self.entries: List["MemoryEntry"] = []
        self.recent_context: deque = deque(maxlen=5)  # Last 5 chunk summaries

    def add_entry(self, content: str, importance: float, chunk_index: int, entry_type: str):
        """Add an entry to memory, pruning when capacity is exceeded.

        Args:
            content: Text to remember.
            importance: Retention weight; higher survives pruning longer.
            chunk_index: Index of the chunk the entry came from.
            entry_type: One of 'summary', 'fact', 'entity', 'insight'.
        """
        entry = MemoryEntry(
            content=content,
            importance=importance,
            chunk_index=chunk_index,
            timestamp=datetime.now().isoformat(),
            entry_type=entry_type
        )
        self.entries.append(entry)

        # Prune if needed
        if len(self.entries) > self.max_entries:
            self._prune_memory()

    def add_chunk_summary(self, summary: str, chunk_index: int):
        """Add a chunk summary to the rolling recent-context window
        and record it in long-term memory with default importance."""
        self.recent_context.append({
            "chunk_index": chunk_index,
            "summary": summary
        })
        self.add_entry(summary, 0.5, chunk_index, "summary")

    def _prune_memory(self):
        """Remove low-importance entries when memory is full.

        BUG FIX: ``importance_threshold`` was accepted by ``__init__`` and
        documented as "minimum importance score to retain" but was never
        applied anywhere. Pruning now discards entries below the threshold
        first, then caps the survivors at ``max_entries``.
        """
        # Drop entries below the configured importance floor.
        retained = [e for e in self.entries if e.importance >= self.importance_threshold]
        # Keep the most important (and, on ties, most recent) entries.
        retained.sort(key=lambda e: (e.importance, e.chunk_index), reverse=True)
        self.entries = retained[:self.max_entries]

    def get_context_for_chunk(self, chunk_index: int) -> str:
        """
        Get relevant context for processing a specific chunk.

        Returns a condensed version of important memories
        to provide context for the current chunk.
        """
        context_parts = []

        # Add recent chunk summaries
        if self.recent_context:
            context_parts.append("=== Recent Context ===")
            for ctx in self.recent_context:
                context_parts.append(f"Chunk {ctx['chunk_index']}: {ctx['summary']}")

        # Add high-importance facts and insights
        important_entries = [
            e for e in self.entries
            if e.importance >= 0.7 and e.entry_type in ('fact', 'insight')
        ]
        if important_entries:
            context_parts.append("\n=== Key Information ===")
            for entry in important_entries[-10:]:  # Last 10 important entries
                context_parts.append(f"- {entry.content}")

        return "\n".join(context_parts)

    def get_all_entries(self) -> List["MemoryEntry"]:
        """Return a shallow copy of all memory entries."""
        return self.entries.copy()

    def clear(self):
        """Clear all memory (long-term entries and recent context)."""
        self.entries = []
        self.recent_context.clear()


class ChunkProcessor:
    """
    Processes individual document chunks.

    This is where you would integrate with an LLM for actual analysis.
    Currently provides a rule-based analysis as a demonstration.
    """

    def __init__(self, memory: "MemorySystem"):
        """Initialize the chunk processor.

        Args:
            memory: Shared MemorySystem that accumulates findings across chunks.
        """
        self.memory = memory

    def process_chunk(self, chunk: "DocumentChunk", query: Optional[str] = None) -> "ChunkAnalysis":
        """
        Process a single chunk and extract insights.

        In production, this would call an LLM API. Currently uses
        rule-based extraction as a demonstration.

        Args:
            chunk: The document chunk to analyze.
            query: Optional query used to score relevance (0.5 when absent).

        Returns:
            ChunkAnalysis with summary, key points, entities, relevance
            score and extracted facts. Side effect: relevant summaries and
            facts are written to the shared memory system.
        """
        content = chunk.content

        # Extract key points (sentences with important keywords)
        key_points = self._extract_key_points(content)

        # Extract entities (simple NER-like extraction)
        entities = self._extract_entities(content)

        # Generate summary (first few sentences + key sentence)
        summary = self._generate_summary(content)

        # Extract facts (declarative statements)
        facts = self._extract_facts(content)

        # Calculate relevance to query
        relevance_score = self._calculate_relevance(content, query) if query else 0.5

        # Update memory with findings
        if relevance_score > 0.6:
            self.memory.add_entry(summary, relevance_score, chunk.index, "summary")
            for fact in facts[:3]:  # Top 3 facts
                self.memory.add_entry(fact, 0.7, chunk.index, "fact")

        # Add to recent context
        self.memory.add_chunk_summary(summary[:200], chunk.index)

        return ChunkAnalysis(
            chunk_index=chunk.index,
            summary=summary,
            key_points=key_points,
            entities=entities,
            relevance_score=relevance_score,
            extracted_facts=facts
        )

    def _extract_key_points(self, content: str) -> List[str]:
        """Extract key points: sentences longer than 20 chars containing
        an importance keyword (at most 10)."""
        sentences = re.split(r'[.!?]+', content)
        key_words = {'important', 'key', 'significant', 'main', 'critical',
                    'essential', 'notable', 'primary', 'major', 'conclude'}

        key_points = []
        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) > 20:
                sentence_lower = sentence.lower()
                if any(kw in sentence_lower for kw in key_words):
                    key_points.append(sentence)

        return key_points[:10]  # Limit to 10

    def _extract_entities(self, content: str) -> List[str]:
        """Extract named entities (simplified pattern matching).

        BUG FIX: deduplication previously used ``list(set(...))``, whose
        ordering depends on string-hash randomization, so both the order
        and (after the ``[:20]`` cut) the surviving entities differed
        between runs. ``dict.fromkeys`` deduplicates deterministically in
        first-occurrence order.
        """
        # Capitalized word sequences (potential names/organizations)
        pattern = r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)+\b'
        entities = re.findall(pattern, content)
        return list(dict.fromkeys(entities))[:20]

    def _generate_summary(self, content: str) -> str:
        """Generate a basic extractive summary: first two sentences plus
        a middle sentence for longer texts; raw prefix for short texts."""
        sentences = re.split(r'[.!?]+', content)
        sentences = [s.strip() for s in sentences if len(s.strip()) > 30]

        if not sentences:
            return content[:200]

        # Take first 2 sentences and a middle sentence
        summary_parts = []
        if len(sentences) >= 1:
            summary_parts.append(sentences[0])
        if len(sentences) >= 2:
            summary_parts.append(sentences[1])
        if len(sentences) >= 5:
            summary_parts.append(sentences[len(sentences)//2])

        return ". ".join(summary_parts) + "."

    def _extract_facts(self, content: str) -> List[str]:
        """Extract factual statements: mid-length sentences containing a
        copular/reporting verb surrounded by spaces (at most 10)."""
        sentences = re.split(r'[.!?]+', content)
        facts = []

        fact_indicators = ['is', 'are', 'was', 'were', 'has', 'have', 'shows',
                         'demonstrates', 'indicates', 'reveals', 'found', 'discovered']

        for sentence in sentences:
            sentence = sentence.strip()
            if len(sentence) > 20 and len(sentence) < 300:
                sentence_lower = sentence.lower()
                if any(f' {ind} ' in sentence_lower for ind in fact_indicators):
                    facts.append(sentence)

        return facts[:10]

    def _calculate_relevance(self, content: str, query: Optional[str]) -> float:
        """Calculate relevance as the fraction of non-stop-word query
        terms that appear in the content (0.5 when no usable query)."""
        if not query:
            return 0.5

        content_lower = content.lower()
        query_words = set(query.lower().split())

        # Remove common words
        stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'what',
                     'how', 'why', 'when', 'where', 'who', 'which', 'and', 'or'}
        query_words = query_words - stop_words

        if not query_words:
            return 0.5

        matches = sum(1 for word in query_words if word in content_lower)
        return min(1.0, matches / len(query_words))


class LongDocumentAnalyzer:
    """
    Main analyzer for long documents using Qwen-Long inspired patterns.

    Features:
    - Intelligent chunking with overlap
    - Memory-based context maintenance
    - Hierarchical synthesis
    - Query-focused analysis
    """

    def __init__(self,
                 chunk_size: int = DEFAULT_CHUNK_SIZE,
                 overlap: int = DEFAULT_OVERLAP,
                 knowledge_base_path: Optional[str] = None):
        """
        Initialize the Long Document Analyzer.

        Args:
            chunk_size: Size of each chunk in tokens
            overlap: Overlap between chunks in tokens
            knowledge_base_path: Path to store analysis results; defaults to
                ``../knowledge_base/documents`` relative to this file
        """
        self.chunk_size = chunk_size
        self.overlap = overlap
        self.memory = MemorySystem()
        self.processor = ChunkProcessor(self.memory)

        # Readers are tried in registration order; first supports() match wins.
        self.readers: List["DocumentReader"] = [
            TextReader(),
            PDFReader(),
            JSONReader(),
            HTMLReader()
        ]

        # Set up knowledge base directory (created eagerly so _store_results
        # never fails on a missing parent).
        if knowledge_base_path is None:
            knowledge_base_path = os.path.join(
                os.path.dirname(__file__), "..", "knowledge_base", "documents"
            )
        self.knowledge_base_path = Path(knowledge_base_path)
        self.knowledge_base_path.mkdir(parents=True, exist_ok=True)

    def _get_reader(self, path: str) -> Optional["DocumentReader"]:
        """Return the first registered reader that supports the file type,
        or None when no reader claims it."""
        for reader in self.readers:
            if reader.supports(path):
                return reader
        return None

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count from text (~4 characters per token)."""
        return len(text) // CHARS_PER_TOKEN

    def _calculate_hash(self, content: str) -> str:
        """Return the first 16 hex chars of the SHA-256 of the content."""
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def chunk_document(self, content: str) -> Generator["DocumentChunk", None, None]:
        """
        Split document into overlapping chunks.

        Implements intelligent chunking that tries to break at paragraph
        boundaries when possible, falling back to sentence boundaries.

        BUG FIX: the advance step previously computed
        ``position = end_position - overlap_chars`` and only reset it when
        the result was <= 0, so an overlap larger than the effective chunk
        advance (reachable via the CLI ``--overlap``/``--chunk-size`` flags,
        or a break-point-shortened chunk) moved the cursor backwards and
        looped forever. The cursor now always advances strictly. Chunk
        indices are also incremented only for yielded chunks, so
        ``DocumentChunk.index`` values are contiguous.
        """
        total_chars = len(content)
        chunk_chars = self.chunk_size * CHARS_PER_TOKEN
        overlap_chars = self.overlap * CHARS_PER_TOKEN

        position = 0
        index = 0

        while position < total_chars:
            # Calculate tentative end position
            end_position = min(position + chunk_chars, total_chars)

            # Try to find a good break point (paragraph or sentence)
            if end_position < total_chars:
                # Look for paragraph break near the end of the chunk
                para_break = content.rfind('\n\n', position + chunk_chars - 500, end_position)
                if para_break > position:
                    end_position = para_break

                # Or sentence break
                elif end_position > position:
                    for punct in ['. ', '! ', '? ', '.\n']:
                        sent_break = content.rfind(punct, position + chunk_chars - 300, end_position)
                        if sent_break > position:
                            end_position = sent_break + 1
                            break

            chunk_content = content[position:end_position].strip()

            if chunk_content:
                yield DocumentChunk(
                    index=index,
                    content=chunk_content,
                    start_position=position,
                    end_position=end_position,
                    token_count=self._estimate_tokens(chunk_content)
                )
                index += 1

            # Move the cursor forward, retaining `overlap` tokens of context.
            # Clamp to strict forward progress to avoid an infinite loop when
            # overlap_chars >= the effective chunk length.
            next_position = end_position - overlap_chars
            if next_position <= position or end_position >= total_chars:
                next_position = end_position
            position = next_position

    def analyze(self, document_path: str, query: Optional[str] = None) -> "AnalysisResult":
        """
        Perform comprehensive analysis of a long document.

        Args:
            document_path: Path to the document
            query: Optional query to focus the analysis

        Returns:
            AnalysisResult with complete analysis

        Raises:
            Exception: re-raises any reader failure after logging it.
        """
        start_time = datetime.now()
        logger.info(f"Starting analysis of: {document_path}")

        # Clear memory for new document
        self.memory.clear()

        # Read document; unknown extensions fall back to plain-text reading
        reader = self._get_reader(document_path)
        if reader is None:
            reader = TextReader()

        try:
            content = reader.read(document_path)
        except Exception as e:
            logger.error(f"Failed to read document: {e}")
            raise

        total_tokens = self._estimate_tokens(content)
        doc_hash = self._calculate_hash(content)
        logger.info(f"Document loaded: ~{total_tokens:,} tokens")

        # Chunk and process
        chunks = list(self.chunk_document(content))
        logger.info(f"Document split into {len(chunks)} chunks")

        chunk_analyses = []
        for chunk in chunks:
            # Get context from memory (in production this would be passed
            # to the LLM alongside the chunk; unused by the rule-based demo)
            context = self.memory.get_context_for_chunk(chunk.index)

            logger.info(f"Processing chunk {chunk.index + 1}/{len(chunks)} "
                       f"(tokens: {chunk.token_count})")

            analysis = self.processor.process_chunk(chunk, query)
            chunk_analyses.append(analysis)

        # Synthesize results
        synthesized_answer = self._synthesize_answer(chunk_analyses, query)
        executive_summary = self._generate_executive_summary(chunk_analyses)
        key_findings = self._extract_key_findings(chunk_analyses)

        # Calculate processing time
        processing_time = (datetime.now() - start_time).total_seconds()

        result = AnalysisResult(
            document_path=document_path,
            document_hash=doc_hash,
            total_tokens=total_tokens,
            total_chunks=len(chunks),
            query=query or "",
            synthesized_answer=synthesized_answer,
            executive_summary=executive_summary,
            key_findings=key_findings,
            chunk_analyses=chunk_analyses,
            memory_entries=self.memory.get_all_entries(),
            processing_time=processing_time,
            metadata={
                "chunk_size": self.chunk_size,
                "overlap": self.overlap,
                "analysis_timestamp": datetime.now().isoformat(),
                "reader_type": reader.__class__.__name__
            }
        )

        # Store results
        self._store_results(result)

        return result

    def _synthesize_answer(self, analyses: List["ChunkAnalysis"], query: Optional[str]) -> str:
        """Synthesize an answer from the most query-relevant chunk analyses."""
        if not query:
            return "No specific query provided. See executive summary."

        # Gather relevant information, most relevant first
        relevant_analyses = sorted(
            analyses,
            key=lambda a: a.relevance_score,
            reverse=True
        )

        # Combine top relevant summaries and facts
        parts = []
        for analysis in relevant_analyses[:5]:  # Top 5 most relevant
            if analysis.relevance_score > 0.3:
                parts.append(analysis.summary)
                parts.extend(analysis.extracted_facts[:2])

        if parts:
            return f"Based on analysis of the document regarding '{query}':\n\n" + \
                   "\n\n".join(parts[:5])
        else:
            return f"No highly relevant information found for query: {query}"

    def _generate_executive_summary(self, analyses: List["ChunkAnalysis"]) -> str:
        """Generate an executive summary from the first, middle and last
        chunk summaries."""
        all_summaries = [a.summary for a in analyses]

        summary_parts = []

        if all_summaries:
            summary_parts.append("Introduction: " + all_summaries[0])

        if len(all_summaries) > 2:
            middle_idx = len(all_summaries) // 2
            summary_parts.append("Key content: " + all_summaries[middle_idx])

        if len(all_summaries) > 1:
            summary_parts.append("Conclusion: " + all_summaries[-1])

        return "\n\n".join(summary_parts)

    def _extract_key_findings(self, analyses: List["ChunkAnalysis"]) -> List[str]:
        """Collect key points and facts across all chunks, deduplicated
        case-insensitively in first-occurrence order (at most 20)."""
        all_points = []
        for analysis in analyses:
            all_points.extend(analysis.key_points)
            all_points.extend(analysis.extracted_facts)

        # Deduplicate and limit
        seen = set()
        unique_findings = []
        for point in all_points:
            point_normalized = point.lower().strip()
            if point_normalized not in seen and len(point) > 20:
                seen.add(point_normalized)
                unique_findings.append(point)

        return unique_findings[:20]  # Top 20 findings

    def _store_results(self, result: "AnalysisResult") -> str:
        """Write the analysis result as JSON into the knowledge base and
        return the stored file path."""
        filename = f"analysis_{result.document_hash}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        filepath = self.knowledge_base_path / filename

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(result.to_dict(), f, indent=2, ensure_ascii=False)

        logger.info(f"Results stored at: {filepath}")
        return str(filepath)


def main():
    """CLI entry point: parse arguments, run the analyzer, print a report."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Analyze long documents with memory-based processing"
    )
    cli.add_argument("document", help="Path to the document to analyze")
    cli.add_argument("--query", "-q", help="Query to focus the analysis")
    cli.add_argument("--chunk-size", type=int, default=DEFAULT_CHUNK_SIZE,
                     help=f"Chunk size in tokens (default: {DEFAULT_CHUNK_SIZE})")
    cli.add_argument("--overlap", type=int, default=DEFAULT_OVERLAP,
                     help=f"Overlap between chunks in tokens (default: {DEFAULT_OVERLAP})")
    opts = cli.parse_args()

    # Fail fast if the input file is missing.
    if not os.path.exists(opts.document):
        print(f"Error: Document not found: {opts.document}")
        sys.exit(1)

    analyzer = LongDocumentAnalyzer(
        chunk_size=opts.chunk_size,
        overlap=opts.overlap
    )

    try:
        report = analyzer.analyze(opts.document, opts.query)
    except Exception as e:
        print(f"Error during analysis: {e}")
        sys.exit(1)

    # Render the report to stdout.
    banner = "=" * 60
    print("\n" + banner)
    print("LONG DOCUMENT ANALYSIS REPORT")
    print(banner)
    print(f"\nDocument: {report.document_path}")
    print(f"Hash: {report.document_hash}")
    print(f"Total Tokens: ~{report.total_tokens:,}")
    print(f"Chunks Processed: {report.total_chunks}")
    print(f"Processing Time: {report.processing_time:.2f}s")

    if report.query:
        print("\n--- Query Response ---")
        print(f"Query: {report.query}")
        print(f"\n{report.synthesized_answer}")

    print("\n--- Executive Summary ---")
    print(report.executive_summary)

    print(f"\n--- Key Findings ({len(report.key_findings)}) ---")
    for rank, finding in enumerate(report.key_findings[:10], 1):
        print(f"{rank}. {finding}")

    overflow = len(report.key_findings) - 10
    if overflow > 0:
        print(f"... and {overflow} more findings")

    print("\nFull results stored in knowledge base.")

    return report


# Allow direct execution as a script (see module docstring for usage).
if __name__ == "__main__":
    main()
