# ingestion_production.py

import hashlib
import json  # For config
import logging
import multiprocessing
import os
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any, Dict, List, Optional

from tqdm import tqdm  # For progress tracking


# Load configuration from JSON file
def load_config(config_path: str = "config.json") -> Dict[str, Any]:
    """Load the pipeline configuration from a JSON file.

    Args:
        config_path: Path of the JSON file to read.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist (logged, then re-raised).
        json.JSONDecodeError: If the file is not valid JSON (logged, then re-raised).
    """
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, so the same config file parses identically everywhere.
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        logging.error(f"Configuration file not found: {config_path}")
        raise
    except json.JSONDecodeError as e:
        logging.error(f"Error decoding JSON in {config_path}: {e}")
        raise

# Placeholder implementations for pipeline stages (replace with actual logic)
def extract_text_from_pdf(pdf_path: str, config: Dict[str, Any]) -> str:
    """Extract the text of a PDF file, never raising.

    PyPDF2 is tried first. If it is unavailable, fails to parse the file, or
    yields no text, the file is inspected directly: a file carrying a PDF
    header is a genuine PDF without extractable text and yields ``""``; any
    other file is assumed to be plain text saved with a ``.pdf`` name and is
    read as UTF-8 (undecodable bytes are ignored).

    Args:
        pdf_path: Path of the file to read.
        config: Pipeline configuration (currently unused by this stage).

    Returns:
        The extracted text, or ``""`` when nothing could be extracted.
    """
    text = ""
    try:
        # Replace with actual PDF extraction logic (e.g., using PyPDF2, pdfminer.six)
        # Example using PyPDF2 (install with: pip install PyPDF2)
        import PyPDF2
        with open(pdf_path, 'rb') as pdf_file:
            reader = PyPDF2.PdfReader(pdf_file)
            # join() avoids quadratic string growth; "or ''" tolerates a page
            # whose extraction result is falsy instead of crashing on it.
            text = "".join(page.extract_text() or "" for page in reader.pages)
    except OSError as e:
        # The file itself cannot be read, so the fallback below could not read it either.
        logging.error(f"Error extracting text from {pdf_path}: {e}")
        return ""
    except Exception as e:
        # Parser/import failure: still give the plaintext fallback a chance.
        logging.error(f"Error extracting text from {pdf_path}: {e}")

    if text:
        return text

    try:
        with open(pdf_path, 'rb') as raw_file:
            header = raw_file.read(1024)
        if b"%PDF-" in header:
            # A real PDF: decoding its bytes as text would only feed binary
            # garbage into the later stages, so report "no text" instead.
            logging.warning(f"No extractable text found in PDF {pdf_path}.")
            return ""
        logging.warning(f"No text extracted from {pdf_path} using PyPDF2. Trying fallback.")
        # Fallback to simple text reading (if PDF is already plaintext)
        with open(pdf_path, 'r', encoding='utf-8', errors='ignore') as f:
            return f.read()
    except OSError as e:
        logging.error(f"Error extracting text from {pdf_path}: {e}")
        return ""  # Return empty string to handle extraction failures gracefully


def clean_text(text: str, config: Dict[str, Any]) -> str:
    """Normalise extracted text to single-spaced ASCII.

    Line feeds become spaces, carriage returns are dropped, every run of
    whitespace collapses to one space, and non-ASCII characters are removed.

    Args:
        text: Raw text to clean.
        config: Pipeline configuration (currently unused by this stage).

    Returns:
        The cleaned text with no leading, trailing, or repeated spaces.
    """
    # Replace with actual text cleaning logic (e.g., using regex, nltk)
    import re  # standard library; nothing to install

    flattened = text.replace('\n', ' ').replace('\r', '')
    # Collapse first so that non-ASCII whitespace (e.g. NBSP) still separates
    # words before the non-ASCII characters are stripped.
    collapsed = ' '.join(flattened.split())
    ascii_only = re.sub(r'[^\x00-\x7F]+', '', collapsed)  # Remove non-ASCII characters
    # Collapse again: deleting a non-ASCII run that sat between two spaces
    # (or at either end) would otherwise leave doubled/stray spaces behind.
    return ' '.join(ascii_only.split())


def chunk_text(text: str, config: Dict[str, Any]) -> List[str]:
    """Split text into fixed-size, overlapping character chunks.

    Args:
        text: The text to split.
        config: Reads ``chunk_size`` (default 2000) and ``chunk_overlap``
            (default 200), both measured in characters.

    Returns:
        The chunks in document order; ``[]`` for empty text. Consecutive
        chunks share ``chunk_overlap`` characters, and the last chunk ends
        exactly at the end of the text.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``chunk_overlap`` is
            not in ``[0, chunk_size)``.
    """
    chunk_size = config.get("chunk_size", 2000)
    chunk_overlap = config.get("chunk_overlap", 200)

    # Fail with a clear message instead of a cryptic range() error (zero
    # step), a silently empty result (negative step) or silently skipped
    # text (negative overlap).
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError(
            f"chunk_overlap must be in [0, chunk_size); got "
            f"chunk_overlap={chunk_overlap}, chunk_size={chunk_size}"
        )

    # Replace with actual chunking logic (e.g., Langchain TextSplitter)
    step = chunk_size - chunk_overlap
    text_length = len(text)
    chunks = []
    for start in range(0, text_length, step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= text_length:
            # This chunk already reaches the end of the text; another one
            # would be a pure duplicate of its tail.
            break
    return chunks


# Loaded embedding models keyed by model name, so a model is constructed once
# per process instead of once per document.
_EMBEDDING_MODELS: Dict[str, Any] = {}


def embed_chunks(chunks: List[str], config: Dict[str, Any]) -> List[List[float]]:
    """Embed text chunks with a Sentence Transformers model.

    Args:
        chunks: The text chunks to embed.
        config: Reads ``embedding_model`` (default ``"default"``) and
            ``embedding_batch_size`` (default 32).

    Returns:
        One embedding (list of floats) per chunk, in order. Returns ``[]``
        when ``chunks`` is empty or when loading/encoding fails (the failure
        is logged).

    Raises:
        ImportError: If ``sentence_transformers`` is not installed and there
            is at least one chunk to embed.
    """
    if not chunks:
        # Nothing to embed: skip the heavy import and model load entirely.
        return []

    embedding_model = config.get("embedding_model", "default")
    embedding_batch_size = config.get("embedding_batch_size", 32)  # Add batch size
    # Replace with actual embedding logic (e.g., using Sentence Transformers, OpenAI API)
    # Example using Sentence Transformers (install with: pip install sentence-transformers)
    from sentence_transformers import SentenceTransformer
    try:
        model = _EMBEDDING_MODELS.get(embedding_model)
        if model is None:
            model = SentenceTransformer(embedding_model)
            _EMBEDDING_MODELS[embedding_model] = model
        embeddings = model.encode(chunks, batch_size=embedding_batch_size)  # Embed in batches
        return embeddings.tolist()  # Convert to list for JSON serialization
    except Exception as e:
        logging.error(f"Error embedding chunks: {e}")
        return []  # Return empty list to handle embedding failures gracefully


def store_embeddings_in_vector_db(chunks: List[str], embeddings: List[List[float]], document_id: str, vector_db_client, config: Dict[str, Any]) -> None:
    """Store each chunk with its embedding in the vector database.

    Each chunk is written with one ``vector_db_client.add(...)`` call using
    the keyword arguments ``documents``, ``embeddings``, ``ids`` and
    ``metadatas``. A chunk's ID is ``"<document_id>_<first 8 hex chars of the
    chunk's MD5>"``; when the same ID would be produced again within this
    call (e.g. repeated boilerplate text), ``_1``, ``_2``, ... is appended so
    every stored ID is unique.

    Args:
        chunks: Text chunks to store.
        embeddings: One embedding per chunk, in the same order.
        document_id: Identifier of the source document.
        vector_db_client: Object exposing the ``add`` method described above.
        config: Pipeline configuration (currently unused by this stage).

    Raises:
        ValueError: If ``chunks`` and ``embeddings`` differ in length.
        Exception: Whatever ``vector_db_client.add`` raises (logged, then re-raised).
    """
    # zip() would silently drop the unmatched tail; refuse instead.
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"Got {len(chunks)} chunks but {len(embeddings)} embeddings for document {document_id}"
        )

    # Replace with actual vector database interaction (e.g., using Pinecone, ChromaDB, FAISS)
    # This is a placeholder - assumes vector_db_client has an 'add' method
    used_ids = set()
    for chunk, embedding in zip(chunks, embeddings):
        base_id = f"{document_id}_{hashlib.md5(chunk.encode()).hexdigest()[:8]}"
        chunk_id = base_id
        duplicate_number = 0
        # First occurrence keeps the plain ID; later collisions get a suffix.
        while chunk_id in used_ids:
            duplicate_number += 1
            chunk_id = f"{base_id}_{duplicate_number}"
        used_ids.add(chunk_id)

        try:
            vector_db_client.add(
                documents=[chunk],
                embeddings=[embedding],
                ids=[chunk_id],  # Unique ID per chunk
                metadatas=[{"document_id": document_id}]
            )
        except Exception as e:
            logging.error(f"Error storing embedding for document {document_id}: {e}")
            raise  # Re-raise exception for handling at a higher level.


def update_knowledge_graph(document_id: str, chunks: List[str], config: Dict[str, Any]) -> None:
    """Record a processed document in the knowledge graph.

    Placeholder: no graph is touched yet; the function only writes an
    INFO-level log entry naming the document and its chunk count.

    Args:
        document_id: Identifier of the processed document.
        chunks: The document's text chunks (only their count is used).
        config: Pipeline configuration (currently unused by this stage).
    """
    # Replace with actual knowledge graph update logic (e.g., using Neo4j, RDFlib)
    message = f"Knowledge graph updated for document {document_id} with {len(chunks)} chunks."
    logging.info(message)


def _text_has_content(text: str, stage_name: str) -> bool:
    """Return False (after a warning) for empty text; warn about but accept very short text."""
    if not text:
        logging.warning(f"{stage_name} produced empty text.")
        return False
    if len(text.strip()) < 10:  # Basic check for minimal content
        # Not critical enough to stop processing.
        logging.warning(f"{stage_name} produced very short text.  Possible issue.")
    return True


def _prepare_document(filepath: str, config: Dict[str, Any]) -> Dict[str, Any]:
    """Run the extract, clean, chunk and embed stages for one file.

    This is a module-level function taking only plain data, so it can be
    submitted to a ``ProcessPoolExecutor``. It never raises; the outcome is
    reported through the returned dict's ``"status"`` key:

    * ``"ok"``      - ``"chunks"`` and ``"embeddings"`` hold the results.
    * ``"skipped"`` - a stage produced nothing usable (already logged).
    * ``"error"``   - a stage raised; ``"error"`` and ``"traceback"`` hold
      the details as strings.
    """
    filename = os.path.basename(filepath)
    try:
        # 1. Extract
        raw_text = extract_text_from_pdf(filepath, config)
        if not _text_has_content(raw_text, "Extraction"):
            logging.warning(f"Extraction failed for {filename}. Skipping.")
            return {"status": "skipped"}

        # 2. Clean
        cleaned_text = clean_text(raw_text, config)
        if not _text_has_content(cleaned_text, "Cleaning"):
            logging.warning(f"Cleaning produced invalid text for {filename}. Continuing, but review cleaning config.")

        # 3. Chunk
        chunks = chunk_text(cleaned_text, config)
        if not chunks:
            logging.warning(f"Chunking produced no chunks for {filename}. Skipping.")
            return {"status": "skipped"}

        # 4. Embed
        embeddings = embed_chunks(chunks, config)
        if not embeddings or len(embeddings) != len(chunks):  # Crucial validation
            logging.error(f"Embedding failed or produced incorrect number of embeddings for {filename}. Skipping.")
            return {"status": "skipped"}

        return {"status": "ok", "chunks": chunks, "embeddings": embeddings}
    except Exception as e:
        # Report as data rather than raising, so the caller can tell a
        # per-document failure apart from a failure of the worker pool itself.
        return {"status": "error", "error": str(e), "traceback": traceback.format_exc()}


class PatentIngestionPipeline:
    """
    Orchestrates the patent ingestion pipeline, handling PDF extraction, text cleaning,
    chunking, embedding, vector storage, and knowledge graph updates.

    ``run()`` executes the side-effect-free stages (extract, clean, chunk,
    embed) in worker processes and performs everything that touches shared
    state (vector DB client, processed-files log, progress bar) in the
    calling process. The pipeline object itself is therefore never pickled,
    and ``vector_db_client`` does not need to be picklable.
    """

    def __init__(self, source_directory: str, vector_db_client, config_path: str = "config.json", processed_files_log: str = "processed_files.log", num_workers: int = 4):
        """
        Initializes the pipeline.

        Args:
            source_directory: The directory containing the PDF files to process.
            vector_db_client: Client for interacting with the vector database.
            config_path: Path to the JSON configuration file.
            processed_files_log: Path to the log file tracking processed files.
            num_workers: Number of worker processes for parallel processing.
        """
        self.source_directory = source_directory
        self.vector_db_client = vector_db_client
        self.config = load_config(config_path)
        self.processed_files_log = processed_files_log
        self.num_workers = num_workers
        self.processed_files = self._load_processed_files()  # Load on initialization
        self.lock = multiprocessing.Lock()  # Guards processed_files, the log file and the progress counters
        self.total_files = 0  # Total files to process for progress tracking
        self.processed_count = 0  # Number of files processed for progress tracking
        self.progress_bar = None  # Created by run(); stays None for direct process_document() calls

    def _load_processed_files(self) -> Dict[str, str]:
        """Load the ``filename -> hash`` map of already processed files.

        Each log line is ``<filename>,<hash>``. The split is made at the last
        comma so filenames containing commas survive, and blank or malformed
        lines are skipped individually instead of discarding the whole log.

        Returns:
            The parsed map; ``{}`` if the log is missing or cannot be read.
        """
        processed_files: Dict[str, str] = {}
        if not os.path.exists(self.processed_files_log):
            return processed_files

        try:
            with open(self.processed_files_log, 'r') as f:
                for line_number, line in enumerate(f, start=1):
                    entry = line.strip()
                    if not entry:
                        continue
                    filename, separator, file_hash = entry.rpartition(',')
                    if not (separator and filename and file_hash):
                        logging.warning(f"Ignoring malformed line {line_number} in {self.processed_files_log}.")
                        continue
                    processed_files[filename] = file_hash
        except (OSError, UnicodeError) as e:
            logging.warning(f"Error loading processed files log: {e}. Starting with an empty log.")
            return {}

        return processed_files

    def _save_processed_file(self, filename: str, file_hash: str) -> None:
        """Appends the processed filename and its hash to the log file."""
        with self.lock:
            with open(self.processed_files_log, 'a') as f:
                f.write(f"{filename},{file_hash}\n")
            self.processed_files[filename] = file_hash

    def _is_duplicate(self, filename: str, file_hash: str) -> bool:
        """Return True when ``filename`` was already processed with the same hash."""
        return self.processed_files.get(filename) == file_hash

    def _calculate_file_hash(self, filepath: str) -> Optional[str]:
        """Calculate the MD5 hex digest of a file's contents.

        The file is read in 1 MiB blocks so large PDFs are not loaded into
        memory at once. MD5 serves only as a change-detection fingerprint.

        Returns:
            The hex digest, or ``None`` if the file cannot be read (logged).
        """
        hasher = hashlib.md5()
        try:
            with open(filepath, 'rb') as afile:
                for block in iter(lambda: afile.read(1 << 20), b""):
                    hasher.update(block)
        except OSError as e:
            logging.error(f"Error calculating hash for {filepath}: {e}")
            return None
        return hasher.hexdigest()

    def _validate_text(self, text: str, stage_name: str) -> bool:
        """Validates that the text is not empty and contains meaningful content."""
        return _text_has_content(text, stage_name)

    def _advance_progress(self) -> None:
        """Count one finished file and tick the progress bar if one exists."""
        with self.lock:
            self.processed_count += 1
            if self.progress_bar is not None:
                self.progress_bar.update(1)

    def _claim_for_processing(self, filepath: str) -> Optional[str]:
        """Hash a file and decide whether it needs processing.

        Returns:
            The file hash when the file must be processed. ``None`` when it
            is skipped (unreadable or an unchanged duplicate); progress has
            already been advanced for it in that case.
        """
        filename = os.path.basename(filepath)
        file_hash = self._calculate_file_hash(filepath)

        if not file_hash:
            logging.warning(f"Skipping {filename} due to hash calculation failure.")
            self._advance_progress()
            return None

        if self._is_duplicate(filename, file_hash):
            logging.info(f"Skipping duplicate file: {filename}")
            self._advance_progress()
            return None

        return file_hash

    def _move_to_error_directory(self, filepath: str) -> None:
        """Move a failed file into the configured error directory for manual review."""
        filename = os.path.basename(filepath)
        error_directory = self.config.get("error_directory", "error_pdfs")
        try:
            os.makedirs(error_directory, exist_ok=True)  # no exists()/makedirs() race
            os.rename(filepath, os.path.join(error_directory, filename))
            logging.info(f"Moved {filename} to {error_directory} for manual review.")
        except Exception as move_err:
            logging.error(f"Failed to move {filename} to error directory: {move_err}")

    def _finalize_document(self, filepath: str, file_hash: str, prepared: Dict[str, Any], start_time: float) -> None:
        """Store a prepared document, update the graph and mark it processed.

        ``prepared`` is the dict returned by ``_prepare_document``. A stage
        error or a failure while storing moves the file to the error
        directory; a skipped document is left untouched. Never raises.
        """
        filename = os.path.basename(filepath)
        status = prepared.get("status")

        if status == "error":
            logging.error(f"Error processing {filename}: {prepared.get('error')}\n{prepared.get('traceback', '')}")
            self._move_to_error_directory(filepath)
            return
        if status != "ok":
            return  # Skipped; the stage that gave up has already logged why.

        try:
            # 5. Store
            # Strip only a trailing ".pdf": str.replace() would remove every
            # occurrence and could map different files to the same document id.
            document_id = filename[:-len(".pdf")] if filename.endswith(".pdf") else filename
            store_embeddings_in_vector_db(prepared["chunks"], prepared["embeddings"], document_id, self.vector_db_client, self.config)

            # 6. Index
            update_knowledge_graph(document_id, prepared["chunks"], self.config)

            # Mark as processed
            self._save_processed_file(filename, file_hash)

            logging.info(f"Successfully processed {filename} in {time.time() - start_time:.2f} seconds.")
        except Exception as e:
            logging.error(f"Error processing {filename}: {e}", exc_info=True)  # Log traceback
            self._move_to_error_directory(filepath)

    def process_document(self, filepath: str) -> None:
        """Process a single patent document end to end in the current process.

        Unreadable files and unchanged duplicates are skipped. Failures are
        logged and the file is moved to the error directory; no exception
        propagates. Progress is advanced exactly once per call, and a missing
        progress bar (``run()`` not used) is tolerated.
        """
        file_hash = self._claim_for_processing(filepath)
        if file_hash is None:
            return

        logging.info(f"Processing file: {os.path.basename(filepath)}")
        start_time = time.time()
        try:
            prepared = _prepare_document(filepath, self.config)
            self._finalize_document(filepath, file_hash, prepared, start_time)
        finally:
            self._advance_progress()

    def run(self):
        """Run the ingestion pipeline for all ``.pdf`` files in the source directory.

        Hashing, duplicate detection, storage and bookkeeping happen in this
        process; the compute stages run in a pool of ``num_workers`` worker
        processes. Only the file path and the config dict are sent to the
        workers. The reported per-file duration is measured from submission,
        so it includes time spent waiting in the pool's queue.
        """
        pdf_files = sorted(
            os.path.join(self.source_directory, f)
            for f in os.listdir(self.source_directory)
            if f.endswith(".pdf")
        )
        self.total_files = len(pdf_files)
        logging.info(f"Found {self.total_files} PDF files to process.")

        if not pdf_files:
            logging.info("No PDF files found. Exiting.")
            return

        # Initialize progress bar
        self.progress_bar = tqdm(total=self.total_files, desc="Processing Patents", unit="file")
        try:
            # Batch processing using a process pool
            with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
                in_flight = {}
                for filepath in pdf_files:
                    file_hash = self._claim_for_processing(filepath)
                    if file_hash is None:
                        continue
                    logging.info(f"Processing file: {os.path.basename(filepath)}")
                    future = executor.submit(_prepare_document, filepath, self.config)
                    in_flight[future] = (filepath, file_hash, time.time())

                for future in as_completed(in_flight):
                    filepath, file_hash, start_time = in_flight[future]
                    try:
                        prepared = future.result()
                    except Exception as e:
                        # Raised by the pool itself (e.g. a worker died), not by
                        # the document: log it but leave the file where it is.
                        logging.error(f"Task generated an exception: {e}")
                    else:
                        self._finalize_document(filepath, file_hash, prepared, start_time)
                    finally:
                        self._advance_progress()
        finally:
            self.progress_bar.close()
        logging.info("Pipeline execution complete.")


import multiprocessing

if __name__ == '__main__':
    # Demo entry point: builds a throw-away input directory, an in-memory
    # vector DB client and (only if absent) a config file, then runs the pipeline.

    # Configure logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # Example usage:
    # 1. Create a directory with some dummy PDF files (or use real ones)
    # 2. Implement the placeholder functions (extract_text_from_pdf, etc.)
    # 3. Instantiate a vector database client (replace with your actual client)

    # Create a dummy source directory if it doesn't exist
    source_directory = "patent_pdfs"
    if not os.path.exists(source_directory):
        os.makedirs(source_directory)
        # Create dummy PDF files (plain text saved under a .pdf name)
        for i in range(5):
            with open(os.path.join(source_directory, f"patent_{i}.pdf"), 'w') as f:
                f.write(f"This is a dummy patent document {i}.\nIt contains some text to be processed.  This is line two.\nAnd line three.")

    # Dummy Vector DB client (replace with your actual client)
    class DummyVectorDBClient:
        """In-memory stand-in for a vector database; keeps every added record in ``self.data``."""

        def __init__(self):
            self.data = []

        def add(self, documents: List[str], embeddings: List[List[float]], ids: List[str], metadatas: List[Dict[str, Any]]):
            """Append one record per (document, embedding, id, metadata) tuple."""
            # "vector_id" rather than "id" so the builtin is not shadowed.
            for doc, emb, vector_id, meta in zip(documents, embeddings, ids, metadatas):
                self.data.append({"document": doc, "embedding": emb, "id": vector_id, "metadata": meta})
            logging.info(f"Added {len(documents)} vectors to the dummy vector DB.")

    vector_db_client = DummyVectorDBClient()

    # Create a dummy config.json file, but never overwrite an existing one:
    # a real configuration in the working directory must survive a demo run.
    config_path = "config.json"
    if not os.path.exists(config_path):
        config_data = {
            "chunk_size": 1500,
            "chunk_overlap": 150,
            "embedding_model": "all-MiniLM-L6-v2",  # Requires sentence-transformers
            "embedding_batch_size": 64,
            "error_directory": "failed_patents"
        }
        with open(config_path, "w") as f:
            json.dump(config_data, f, indent=4)  # Indent for readability

    # Instantiate and run the pipeline
    pipeline = PatentIngestionPipeline(source_directory=source_directory, vector_db_client=vector_db_client, num_workers=multiprocessing.cpu_count())
    # run() logs "Pipeline execution complete." itself, so it is not repeated here.
    pipeline.run()
