# enhanced_hallucination_detection.py

import html
import json
import logging
import os
import re
import threading
import time
import unittest
from queue import Queue
from typing import Any, Callable, Dict, List, Optional

import nltk
import numpy as np
import requests
import spacy
import torch
import transformers
from sklearn.metrics.pairwise import cosine_similarity

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# NLTK resources (download once, only if NLTK-based processing is used)
# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')
# nltk.download('maxent_ne_chunker')
# nltk.download('words')

# spaCy pipeline used for sentence splitting and word-vector embeddings.
# The large model ships word vectors; smaller models give much weaker similarities.
nlp = spacy.load("en_core_web_lg")

# Sentence-transformers checkpoint used for semantic similarity of Wikipedia snippets.
SEMANTIC_MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'

try:
    semantic_model = transformers.AutoModel.from_pretrained(SEMANTIC_MODEL_NAME)
    semantic_tokenizer = transformers.AutoTokenizer.from_pretrained(SEMANTIC_MODEL_NAME)
except OSError as e:
    # Degrade gracefully: semantic similarity then reports 0.0 instead of crashing.
    logging.error(f"Error loading transformer model: {e}. Please ensure you have the necessary libraries installed and an internet connection.")
    semantic_model = None
    semantic_tokenizer = None

# Result type aliases. Values are heterogeneous (bools, lists, floats), hence Any;
# the builtin function ``any`` is not a type.
ClaimVerificationResult = Dict[str, Any]
HallucinationDetectionResult = Dict[str, Any]

class HallucinationDetector:
    """
    Flags potentially hallucinated sentences in AI-generated text.

    Each sentence of the input is treated as a claim and checked against:

    * a knowledge-base HTTP endpoint (GET with a ``query`` parameter; the JSON
      response is expected to contain ``{"results": [<document text>, ...]}``),
    * optionally the Wikipedia search API,
    * a temporal check (currently a placeholder that always passes), and
    * any caller-supplied domain validators.

    A daemon worker thread is started in ``__init__``; text pushed through
    :meth:`process_stream` is analysed on that thread and the results are
    logged. Call :meth:`stop_stream_processing` to shut the worker down.
    """

    # MediaWiki Action API endpoint used for Wikipedia lookups.
    WIKIPEDIA_API_URL = "https://en.wikipedia.org/w/api.php"
    # Matches HTML tags. Wikipedia search snippets wrap matched terms in
    # <span class="searchmatch"> markup that must not reach the similarity model.
    _HTML_TAG_RE = re.compile(r"<[^>]+>")

    def __init__(self, knowledge_base_url: str, similarity_threshold: float = 0.75,
                 hallucination_threshold: float = 0.8, confidence_calibration_model=None,
                 domain_validators: Optional[List[Callable[[str], Dict[str, Any]]]] = None,
                 enable_citation_validation: bool = True,
                 request_timeout: float = 10.0):
        """
        Initializes the HallucinationDetector and starts the stream worker thread.

        Args:
            knowledge_base_url (str): URL of the knowledge base API endpoint.
            similarity_threshold (float): Minimum cosine similarity for a claim to
                count as supported by a document.
            hallucination_threshold (float): A claim is flagged when its
                hallucination score is strictly greater than this value.
            confidence_calibration_model: Optional object with a scikit-learn style
                ``predict`` method, applied to scores of supported claims.
            domain_validators (List[callable]): Callables taking a claim and
                returning a dict with an ``is_valid`` key and optional ``reason``.
            enable_citation_validation (bool): Stored flag for citation validation.
            request_timeout (float): Timeout in seconds for each HTTP request, so
                an unresponsive service cannot block analysis forever.
        """
        self.knowledge_base_url = knowledge_base_url
        self.similarity_threshold = similarity_threshold
        self.hallucination_threshold = hallucination_threshold
        self.confidence_calibration_model = confidence_calibration_model
        self.domain_validators = list(domain_validators) if domain_validators else []
        self.enable_citation_validation = enable_citation_validation
        self.request_timeout = request_timeout

        self.fact_checking_enabled = True  # Enable/disable fact-checking
        self.wikipedia_enabled = True  # Enable/disable Wikipedia integration
        self.temporal_validation_enabled = True  # Enable/disable temporal validation

        self.stream_queue = Queue()
        self.processing_thread = threading.Thread(target=self._process_stream, daemon=True)
        self.processing_thread.start()

    def enable_fact_checking(self):
        """Enables fact-checking against the knowledge base."""
        self.fact_checking_enabled = True

    def disable_fact_checking(self):
        """Disables fact-checking; every claim is then treated as fully supported."""
        self.fact_checking_enabled = False

    def enable_wikipedia_integration(self):
        """Enables Wikipedia integration for fact verification."""
        self.wikipedia_enabled = True

    def disable_wikipedia_integration(self):
        """Disables Wikipedia integration."""
        self.wikipedia_enabled = False

    def enable_temporal_validation(self):
        """Enables temporal validation of facts."""
        self.temporal_validation_enabled = True

    def disable_temporal_validation(self):
        """Disables temporal validation."""
        self.temporal_validation_enabled = False

    def detect_hallucinations(self, ai_output: str) -> HallucinationDetectionResult:
        """
        Detects potential hallucinations in AI output.

        Args:
            ai_output (str): The AI-generated text to analyze.

        Returns:
            HallucinationDetectionResult: A dict with per-claim lists
            ``hallucinations`` (bool), ``claim_verifications`` and
            ``hallucination_scores``, plus ``flagged_claims`` (the claims whose
            score exceeded the threshold) and ``quarantine_flag`` (True if any
            claim was flagged).
        """
        claims = self.extract_claims(ai_output)
        hallucinations = []
        claim_verifications = []
        hallucination_scores = []
        flagged_claims = []

        for claim in claims:
            if self.fact_checking_enabled:
                verification_result = self.verify_claim(claim)
            else:
                # Fact-checking disabled: treat every claim as fully supported.
                verification_result = {
                    'is_supported': True,
                    'supporting_evidence': [],
                    'similarity_score': 1.0,
                }
            claim_verifications.append(verification_result)

            hallucination_score = self.calculate_hallucination_score(verification_result)
            hallucination_scores.append(hallucination_score)

            is_hallucination = hallucination_score > self.hallucination_threshold
            hallucinations.append(is_hallucination)
            if is_hallucination:
                flagged_claims.append(claim)

        return {
            'hallucinations': hallucinations,
            'claim_verifications': claim_verifications,
            'hallucination_scores': hallucination_scores,
            'quarantine_flag': any(hallucinations),
            'flagged_claims': flagged_claims,
        }

    def extract_claims(self, text: str) -> List[str]:
        """
        Splits text into candidate claims, one per sentence.

        Sentence boundaries come from the module-level spaCy pipeline. Each
        sentence is stripped and whitespace-only sentences are dropped, so they
        are never sent to the knowledge base.

        Args:
            text (str): The text to extract claims from.

        Returns:
            List[str]: The non-empty sentences of ``text``; ``[]`` for empty or
            whitespace-only input.
        """
        if not text or not text.strip():
            return []
        sentences = (sent.text.strip() for sent in nlp(text).sents)
        return [sentence for sentence in sentences if sentence]

    def verify_claim(self, claim: str) -> ClaimVerificationResult:
        """
        Verifies a claim against the knowledge base, Wikipedia, temporal validation
        and the configured domain validators.

        A claim is supported if the knowledge base or Wikipedia supports it, and
        it is then invalidated by a failed temporal check or a failed domain
        validator. A failed domain validator appends its ``reason`` to
        ``supporting_evidence``.

        Args:
            claim (str): The claim to verify.

        Returns:
            ClaimVerificationResult: Dict with ``is_supported``,
            ``supporting_evidence`` and ``similarity_score`` (the best score seen).
        """
        kb_verification = self._verify_claim_against_kb(claim)
        is_supported = kb_verification['is_supported']
        # Copy so the extensions below never mutate the KB result's list.
        supporting_evidence = list(kb_verification['supporting_evidence'])
        similarity_score = kb_verification['similarity_score']

        if self.wikipedia_enabled:
            wikipedia_verification = self._verify_claim_against_wikipedia(claim)
            is_supported = is_supported or wikipedia_verification['is_supported']
            supporting_evidence.extend(wikipedia_verification['supporting_evidence'])
            similarity_score = max(similarity_score, wikipedia_verification['similarity_score'])

        if self.temporal_validation_enabled:
            temporal_result = self._perform_temporal_validation(claim)
            if not temporal_result['is_valid']:
                is_supported = False  # Invalidate if temporal context is wrong.

        for validator in self.domain_validators:
            try:
                validation_result = validator(claim)
                is_valid = validation_result['is_valid']
                reason = validation_result.get('reason', 'Domain validation failed')
            except Exception:
                # Best effort: a broken validator is logged (with traceback) and
                # skipped rather than aborting the whole analysis.
                logging.exception("Error in domain validator %r", validator)
                continue
            if not is_valid:
                is_supported = False  # Domain-specific validation failed
                supporting_evidence.append(reason)

        return {
            'is_supported': is_supported,
            'supporting_evidence': supporting_evidence,
            'similarity_score': similarity_score,
        }

    @staticmethod
    def _unsupported_result() -> ClaimVerificationResult:
        """Returns a fresh 'no support found' verification result."""
        return {
            'is_supported': False,
            'supporting_evidence': [],
            'similarity_score': 0.0,
        }

    def _verify_claim_against_kb(self, claim: str) -> ClaimVerificationResult:
        """
        Verifies a claim against the primary knowledge base.

        The endpoint is queried with ``?query=<claim>`` and the claim is compared
        against every string document in the response's ``results`` list using
        cosine similarity of :meth:`get_embedding` vectors. Non-string or empty
        documents are ignored.

        Args:
            claim (str): The claim to verify.

        Returns:
            ClaimVerificationResult: The best-matching document as evidence (only
            if its similarity is positive), its similarity, and whether it reaches
            ``similarity_threshold``. Network, HTTP and JSON errors, or an
            unexpected response shape, yield an unsupported result.
        """
        try:
            response = requests.get(self.knowledge_base_url, params={'query': claim},
                                    timeout=self.request_timeout)
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers invalid JSON bodies on older requests versions.
            logging.error("Error querying knowledge base: %s", e)
            return self._unsupported_result()

        results = data.get('results') if isinstance(data, dict) else None
        if not isinstance(results, list):
            logging.warning("Unexpected knowledge base response format: %s", data)
            return self._unsupported_result()

        documents = [doc for doc in results if isinstance(doc, str) and doc.strip()]
        if not documents:
            return self._unsupported_result()

        claim_embedding = np.asarray(self.get_embedding(claim)).reshape(1, -1)
        document_embeddings = np.vstack([self.get_embedding(doc) for doc in documents])
        similarities = cosine_similarity(claim_embedding, document_embeddings)[0]

        best_index = int(np.argmax(similarities))  # first maximum wins ties
        best_similarity = float(similarities[best_index])
        if best_similarity > 0.0:
            best_supporting_evidence = [documents[best_index]]
        else:
            # Non-positive similarity is not treated as evidence.
            best_similarity, best_supporting_evidence = 0.0, []

        return {
            'is_supported': best_similarity >= self.similarity_threshold,
            'supporting_evidence': best_supporting_evidence,
            'similarity_score': best_similarity,
        }

    @classmethod
    def _strip_html(cls, text: str) -> str:
        """Removes HTML tags, unescapes entities (e.g. ``&quot;``) and collapses whitespace."""
        # Strip tags before unescaping so an escaped '&lt;' is not mistaken for a tag.
        without_tags = cls._HTML_TAG_RE.sub('', text)
        return ' '.join(html.unescape(without_tags).split())

    def _verify_claim_against_wikipedia(self, claim: str) -> ClaimVerificationResult:
        """
        Verifies a claim against the top Wikipedia full-text search result.

        Only the search-result snippet (with HTML markup removed) is compared,
        using the transformer-based :meth:`_calculate_semantic_similarity`.

        Args:
            claim (str): The claim to verify.

        Returns:
            ClaimVerificationResult: The cleaned snippet as evidence, its
            similarity, and whether it reaches ``similarity_threshold``. Errors,
            no results or an empty snippet yield an unsupported result.
        """
        # Passing params lets requests URL-encode the claim; interpolating it
        # into the URL broke on '&', '#', '=' and similar characters.
        params = {
            'action': 'query',
            'format': 'json',
            'list': 'search',
            'srsearch': claim,
            'srlimit': 1,  # only the top hit is used
        }
        try:
            response = requests.get(self.WIKIPEDIA_API_URL, params=params,
                                    timeout=self.request_timeout)
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            logging.error("Error querying Wikipedia: %s", e)
            return self._unsupported_result()

        query = data.get('query') if isinstance(data, dict) else None
        search_results = query.get('search') if isinstance(query, dict) else None
        if not isinstance(search_results, list) or not search_results:
            return self._unsupported_result()

        top_result = search_results[0]
        raw_snippet = top_result.get('snippet', '') if isinstance(top_result, dict) else ''
        snippet = self._strip_html(raw_snippet) if isinstance(raw_snippet, str) else ''
        if not snippet:
            return self._unsupported_result()

        similarity_score = self._calculate_semantic_similarity(claim, snippet)
        return {
            'is_supported': similarity_score >= self.similarity_threshold,
            'supporting_evidence': [snippet],
            'similarity_score': similarity_score,
        }

    def _perform_temporal_validation(self, claim: str) -> Dict[str, bool]:
        """
        Performs temporal validation of the claim. Placeholder.

        A real implementation would identify temporal expressions in the claim
        and check whether the claim still holds at the current time.

        Args:
            claim (str): The claim to validate.

        Returns:
            Dict[str, bool]: ``{'is_valid': True}`` (always, for now).
        """
        return {'is_valid': True}

    def calculate_hallucination_score(self, verification_result: ClaimVerificationResult) -> float:
        """
        Calculates a hallucination score in [0, 1] from a verification result.

        Unsupported claims score 1.0. Supported claims score
        ``1 - similarity_score``, optionally passed through the calibration
        model. The result is clamped to [0, 1].

        Args:
            verification_result (ClaimVerificationResult): The result of the claim verification.

        Returns:
            float: The hallucination score, as a plain Python float.
        """
        if not verification_result['is_supported']:
            return 1.0

        score = 1.0 - float(verification_result['similarity_score'])
        # 'is not None': some model objects define __len__/__bool__ and could be falsy.
        if self.confidence_calibration_model is not None:
            score = float(self.confidence_calibration_model.predict(np.array([[score]]))[0])
        return min(max(score, 0.0), 1.0)

    def get_embedding(self, text: str) -> np.ndarray:
        """
        Generates a text embedding using spaCy.

        Args:
            text (str): The text to embed.

        Returns:
            np.ndarray: ``doc.vector`` from the module-level spaCy pipeline.
        """
        return nlp(text).vector

    def _calculate_semantic_similarity(self, text1: str, text2: str) -> float:
        """
        Calculates the cosine similarity of two texts using the transformer model.

        Both texts are encoded in a single padded batch and mean-pooled over
        their attention masks.

        Args:
            text1 (str): The first text.
            text2 (str): The second text.

        Returns:
            float: The cosine similarity. Returns 0.0 if the model is not loaded
            or encoding fails (the error is logged).
        """
        if semantic_model is None or semantic_tokenizer is None:
            logging.warning("Semantic similarity model not loaded. Returning 0 similarity.")
            return 0.0

        try:
            encoded_input = semantic_tokenizer([text1, text2], padding=True, truncation=True,
                                               return_tensors='pt')
            with torch.no_grad():
                model_output = semantic_model(**encoded_input)
                embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
            similarity = torch.nn.functional.cosine_similarity(embeddings[0:1], embeddings[1:2])
            return float(similarity.item())
        except Exception as e:
            # Best effort: similarity failures must not abort claim verification.
            logging.error("Error calculating semantic similarity: %s", e)
            return 0.0

    def _mean_pooling(self, model_output, attention_mask):
        """
        Averages token embeddings over non-padding positions.

        Args:
            model_output: Model output whose first element holds the token
                embeddings (batch, seq_len, hidden).
            attention_mask: Tensor of shape (batch, seq_len), 1 for real tokens.

        Returns:
            torch.Tensor: Sentence embeddings of shape (batch, hidden).
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp avoids division by zero for an all-padding row.
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def process_stream(self, text_chunk: str):
        """
        Queues a chunk of text for analysis on the worker thread.

        Args:
            text_chunk (str): The chunk of text to process.
        """
        self.stream_queue.put(text_chunk)

    def _process_stream(self):
        """
        Worker loop: analyses queued chunks until a ``None`` sentinel arrives.

        Every item, including the sentinel, is marked ``task_done`` so
        ``stream_queue.join()`` never blocks on it.
        """
        while True:
            text_chunk = self.stream_queue.get()
            try:
                if text_chunk is None:
                    break  # stop signal; 'finally' still marks it done
                results = self.detect_hallucinations(text_chunk)
                logging.info("Stream Processing Results: %s", results)
            except Exception:
                logging.exception("Error processing stream chunk")
            finally:
                self.stream_queue.task_done()

    def stop_stream_processing(self):
        """
        Stops the stream worker after it has drained the chunks already queued.

        Safe to call more than once: the stop sentinel is only enqueued while
        the worker thread is still alive.
        """
        if self.processing_thread.is_alive():
            self.stream_queue.put(None)  # processed after any pending chunks (FIFO)
            self.processing_thread.join()
        logging.info("Stream processing stopped.")

    def validate_citation(self, text: str, citation: str) -> bool:
        """
        Validates a citation against the given text. Placeholder.

        A real implementation would parse the citation, retrieve the cited
        document and check that it supports the claim.

        Args:
            text (str): The text containing the claim.
            citation (str): The citation string.

        Returns:
            bool: Always True for now.
        """
        return True


# Example usage and tests. Run this file directly; requests.get is replaced by a
# local mock below, so neither the knowledge base nor Wikipedia is contacted.
if __name__ == '__main__':
    # Must be imported before any claim is analysed: validate_date_format uses it.
    import datetime

    # Mock knowledge base (replace with a real API endpoint)
    MOCK_KNOWLEDGE_BASE_URL = "https://example.com/knowledge_base"

    def mock_knowledge_base(query):
        """Return canned knowledge-base results for a few known queries."""
        if not query:
            return {"results": []}
        if "capital of France" in query:
            return {"results": ["Paris is the capital of France."]}
        if "invented the telephone" in query:
            return {"results": ["Alexander Graham Bell invented the telephone."]}
        if "first man on the moon" in query:
            return {"results": ["Neil Armstrong was the first man to walk on the Moon."]}
        if "invented the light bulb" in query:
            return {"results": ["Thomas Edison is credited with inventing the practical light bulb."]}
        return {"results": []}

    class MockResponse:
        """Minimal stand-in for requests.Response."""

        def __init__(self, json_data, status_code=200):
            self.json_data = json_data
            self.status_code = status_code

        def json(self):
            return self.json_data

        def raise_for_status(self):
            if self.status_code >= 400:
                raise requests.exceptions.HTTPError(f"HTTP Error: {self.status_code}")

    def mock_requests_get(url, params=None, **kwargs):
        """
        Replacement for requests.get used by the demo and tests.

        Serves the mock knowledge base for MOCK_KNOWLEDGE_BASE_URL and returns an
        empty Wikipedia search result for any other URL. Extra keyword arguments
        (e.g. timeout) are accepted and ignored.
        """
        if url == MOCK_KNOWLEDGE_BASE_URL:
            query = (params or {}).get('query')
            return MockResponse(mock_knowledge_base(query))
        return MockResponse({"query": {"search": []}})

    original_requests_get = requests.get
    requests.get = mock_requests_get

    def validate_date_format(claim: str) -> Dict[str, object]:
        """Example domain validator: a 'date:' value must be in YYYY-MM-DD format."""
        lowered = claim.lower()
        if "date:" not in lowered:
            return {"is_valid": True}
        date_str = lowered.split("date:", 1)[1].strip()
        try:
            datetime.datetime.strptime(date_str, "%Y-%m-%d")
        except ValueError:
            return {"is_valid": False, "reason": "Invalid date format."}
        return {"is_valid": True}

    detector = HallucinationDetector(
        knowledge_base_url=MOCK_KNOWLEDGE_BASE_URL,
        similarity_threshold=0.75,
        hallucination_threshold=0.8,
        domain_validators=[validate_date_format],
    )

    ai_outputs = [
        "The capital of France is Paris. Alexander Graham Bell invented the telephone. The first man on the moon was Neil Armstrong.",
        "The capital of Germany is Berlin. Einstein invented the light bulb. The first woman on Mars was Valentina Tereshkova.",
        "The capital of France is Paris. Thomas Edison invented the light bulb. date: 2023-10-27",
        "The capital of France is Paris. Thomas Edison invented the light bulb. date: October 27, 2023",  # Invalid date
    ]

    for index, ai_output in enumerate(ai_outputs, start=1):
        separator = "" if index == 1 else "\n"
        print(f"{separator}Analyzing AI Output {index}:")
        print(detector.detect_hallucinations(ai_output))

    # Stream processing: stop_stream_processing() enqueues its stop signal behind
    # these chunks and joins the worker, so both are processed before it returns.
    print("\nStream Processing Example:")
    detector.process_stream("The Eiffel Tower is in Paris. ")
    detector.process_stream("Mount Everest is the tallest mountain.")
    detector.stop_stream_processing()

    class TestHallucinationDetector(unittest.TestCase):
        def setUp(self):
            requests.get = mock_requests_get
            self.detector = HallucinationDetector(
                knowledge_base_url=MOCK_KNOWLEDGE_BASE_URL,
                similarity_threshold=0.75,
                hallucination_threshold=0.8,
                domain_validators=[validate_date_format],
            )

        def tearDown(self):
            # Do not leave one worker thread per test running.
            self.detector.stop_stream_processing()

        def test_detect_hallucinations_accurate(self):
            result = self.detector.detect_hallucinations("The capital of France is Paris.")
            self.assertFalse(result['quarantine_flag'])
            # A supported claim scores 1 - similarity, i.e. at most
            # 1 - similarity_threshold (0.25 here); the exact value depends on the vectors.
            self.assertLessEqual(result['hallucination_scores'][0], 0.25 + 1e-6)

        def test_detect_hallucinations_hallucination(self):
            result = self.detector.detect_hallucinations("The capital of Germany is London.")
            self.assertTrue(result['quarantine_flag'])
            self.assertEqual(result['hallucination_scores'][0], 1.0)

        def test_domain_validator(self):
            self.assertEqual(validate_date_format("The event occurred date: 2024-01-01"), {"is_valid": True})
            self.assertFalse(validate_date_format("The event occurred date: January 1, 2024")["is_valid"])

            verification = self.detector.verify_claim("The event occurred date: January 1, 2024")
            self.assertFalse(verification['is_supported'])
            self.assertIn("Invalid date format.", verification['supporting_evidence'])

            result_invalid = self.detector.detect_hallucinations("The event occurred date: January 1, 2024")
            self.assertTrue(result_invalid['quarantine_flag'])

        def test_extract_claims(self):
            claims = self.detector.extract_claims("This is a claim. This is another claim.")
            self.assertEqual(len(claims), 2)
            self.assertEqual(claims[0], "This is a claim.")
            self.assertEqual(claims[1], "This is another claim.")

        def test_stream_processing(self):
            self.detector.process_stream("Test stream claim.")
            self.detector.stop_stream_processing()
            self.assertFalse(self.detector.processing_thread.is_alive())

    unittest.main(argv=['first-arg-is-ignored'], exit=False)
    requests.get = original_requests_get
