import re
import nltk
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import spacy
import torch
from transformers import pipeline

# nltk.download('stopwords')  # Uncomment if you haven't downloaded stopwords
# nltk.download('wordnet')  # Uncomment if you haven't downloaded wordnet

class HallucinationDetector:
    """
    Detects likely hallucinations in generated text.

    Combines four per-claim signals: cosine similarity against a knowledge
    base, internal self-consistency of the full text, cosine similarity
    against caller-supplied sources, and a zero-shot NLI classifier. The
    weighted combination of these signals yields a hallucination score.
    """

    def __init__(self, knowledge_base_path=None, nlp_model="en_core_web_sm", zero_shot_model="facebook/bart-large-mnli"):
        """
        Initializes the HallucinationDetector.

        Args:
            knowledge_base_path (str, optional): Path to the knowledge base file. Defaults to None.
            nlp_model (str, optional): Name of the spaCy NLP model. Defaults to "en_core_web_sm".
            zero_shot_model (str, optional): Name of the Hugging Face Transformers zero-shot classification model. Defaults to "facebook/bart-large-mnli".
        """
        self.knowledge_base = self._load_knowledge_base(knowledge_base_path) if knowledge_base_path else []
        self.nlp = spacy.load(nlp_model)
        self.stopwords = nltk.corpus.stopwords.words('english')
        # Set gives O(1) membership tests during preprocessing; the public
        # list attribute above is kept unchanged for backward compatibility.
        self._stopword_set = set(self.stopwords)
        self.lemmatizer = nltk.stem.WordNetLemmatizer()
        self.device = 0 if torch.cuda.is_available() else -1  # Use GPU if available
        self.classifier = pipeline("zero-shot-classification", model=zero_shot_model, device=self.device)
        # Lazily-built cache of preprocessed knowledge-base document vectors,
        # so repeated claim checks do not re-run spaCy over every document.
        self._kb_vectors = None

    def _load_knowledge_base(self, knowledge_base_path):
        """
        Loads the knowledge base from a file.

        Documents are expected to be separated by blank lines ("\\n\\n").
        Failures are non-fatal: a warning is printed and an empty knowledge
        base is returned, so detection degrades gracefully.

        Args:
            knowledge_base_path (str): Path to the knowledge base file.

        Returns:
            list: A list of strings, where each string is a document from the knowledge base.
        """
        try:
            with open(knowledge_base_path, 'r', encoding='utf-8') as f:
                knowledge_base = f.read().split('\n\n')  # Split into documents
            return knowledge_base
        except FileNotFoundError:
            print(f"Warning: Knowledge base file not found at {knowledge_base_path}. Hallucination detection will be less effective.")
            return []
        except Exception as e:
            print(f"Warning: Error loading knowledge base: {e}. Hallucination detection will be less effective.")
            return []

    def _preprocess_text(self, text):
        """
        Preprocesses the input text by lowercasing, removing punctuation,
        stop words, and lemmatizing.

        Args:
            text (str): The input text.

        Returns:
            str: The preprocessed text.
        """
        text = re.sub(r'[^\w\s]', '', text.lower())
        # Single pass: filter stop words and lemmatize together (the original
        # split/join/split round-trip produced identical tokens).
        tokens = [
            self.lemmatizer.lemmatize(word)
            for word in text.split()
            if word not in self._stopword_set
        ]
        return ' '.join(tokens)

    def _knowledge_base_verification(self, claim):
        """
        Verifies a claim against the knowledge base using cosine similarity.

        Args:
            claim (str): The claim to verify.

        Returns:
            float: The maximum cosine similarity score between the claim and the documents in the knowledge base. Returns 0 if the knowledge base is empty.
        """
        if not self.knowledge_base:
            return 0.0

        # Vectorize the knowledge base once and reuse it on every subsequent
        # call — preprocessing + spaCy are by far the dominant cost here.
        if self._kb_vectors is None:
            self._kb_vectors = [
                self.nlp(self._preprocess_text(document)).vector
                for document in self.knowledge_base
            ]

        claim_vector = self.nlp(self._preprocess_text(claim)).vector

        max_similarity = 0.0
        for document_vector in self._kb_vectors:
            similarity = cosine_similarity([claim_vector], [document_vector])[0][0]
            max_similarity = max(max_similarity, similarity)

        return max_similarity

    def _self_consistency_check(self, text):
        """
        Evaluates the internal consistency of the generated text using cosine similarity between sentences.

        Args:
            text (str): The generated text.

        Returns:
            float: The average cosine similarity score between all pairs of sentences in the text.
        """
        sentences = [sent.text for sent in self.nlp(text).sents]
        if len(sentences) < 2:
            return 1.0  # Consider single-sentence text as consistent

        # Vectorize each sentence once (O(n) spaCy calls) instead of inside
        # the pair loop (O(n^2) spaCy calls, as the original did).
        vectors = [self.nlp(sentence).vector for sentence in sentences]

        similarities = [
            cosine_similarity([vectors[i]], [vectors[j]])[0][0]
            for i in range(len(vectors))
            for j in range(i + 1, len(vectors))
        ]

        # len(sentences) >= 2 guarantees at least one pair, so no division by zero.
        return sum(similarities) / len(similarities)

    def _source_attribution(self, claim, sources):
        """
        Compares a claim against provided sources using cosine similarity.

        Args:
            claim (str): The claim to verify.
            sources (list): A list of source documents.

        Returns:
            float: The maximum cosine similarity score between the claim and the sources. Returns 0 if no sources are provided.
        """
        if not sources:
            return 0.0

        claim_vector = self.nlp(self._preprocess_text(claim)).vector

        max_similarity = 0.0
        for source in sources:
            source_vector = self.nlp(self._preprocess_text(source)).vector
            similarity = cosine_similarity([claim_vector], [source_vector])[0][0]
            max_similarity = max(max_similarity, similarity)

        return max_similarity

    def _classify_hallucination(self, text, claim):
        """
        Classifies a claim as a hallucination using zero-shot classification.

        Args:
            text (str): The generated text (currently unused; kept for API compatibility).
            claim (str): The claim to classify.

        Returns:
            float: The probability that the claim is a hallucination.
        """
        candidate_labels = ["hallucination", "not hallucination"]
        hypothesis_template = "This example is {}."
        try:
            result = self.classifier(claim, candidate_labels, hypothesis_template=hypothesis_template)
            # BUG FIX: the zero-shot pipeline returns labels/scores sorted by
            # score (descending), so scores[0] belongs to whichever label won —
            # not necessarily "hallucination". Look the label up explicitly.
            hallucination_index = result['labels'].index("hallucination")
            return result['scores'][hallucination_index]
        except Exception as e:
            print(f"Error in zero-shot classification: {e}")
            return 0.5  # Return a neutral probability in case of error

    def detect_hallucinations(self, text, sources=None, claim_extraction_method="rule_based"):
        """
        Detects hallucinations in the generated text.

        Args:
            text (str): The generated text.
            sources (list, optional): A list of source documents. Defaults to None.
            claim_extraction_method (str, optional): Method for extracting claims. Defaults to "rule_based".

        Returns:
            dict: A dictionary containing a report of flagged segments and confidence scores.
        """

        claims = self._extract_claims(text, method=claim_extraction_method)
        report = {"flagged_segments": [], "confidence_scores": {}}

        # Self-consistency depends only on the full text, not on any single
        # claim — compute it once instead of once per claim (loop invariant).
        self_consistency_score = self._self_consistency_check(text)
        calibrated_self_consistency_score = min(1.0, self_consistency_score * 1.1)

        for i, claim in enumerate(claims):
            knowledge_base_score = self._knowledge_base_verification(claim)
            source_attribution_score = self._source_attribution(claim, sources)
            hallucination_probability = self._classify_hallucination(text, claim)

            # Confidence Calibration (Simple Example - Can be replaced with more sophisticated methods)
            # This is a placeholder.  Ideally, this would involve training a calibration model.
            calibrated_knowledge_base_score = min(1.0, knowledge_base_score * 1.2)  # Boost slightly
            calibrated_source_attribution_score = min(1.0, source_attribution_score * 1.3)

            # Hallucination Probability Scoring
            # You can adjust the weights based on the performance of each method
            hallucination_score = (
                0.3 * (1 - calibrated_knowledge_base_score) +
                0.2 * (1 - calibrated_self_consistency_score) +
                0.2 * (1 - calibrated_source_attribution_score) +
                0.3 * hallucination_probability
            )

            report["confidence_scores"][f"claim_{i+1}"] = {
                "knowledge_base_score": knowledge_base_score,
                "self_consistency_score": self_consistency_score,
                "source_attribution_score": source_attribution_score,
                "hallucination_probability": hallucination_probability,
                "calibrated_knowledge_base_score": calibrated_knowledge_base_score,
                "calibrated_self_consistency_score": calibrated_self_consistency_score,
                "calibrated_source_attribution_score": calibrated_source_attribution_score,
                "hallucination_score": hallucination_score
            }

            if hallucination_score > 0.5:  # Threshold can be adjusted
                report["flagged_segments"].append({
                    "claim": claim,
                    "hallucination_score": hallucination_score
                })

        return report

    def _extract_claims(self, text, method="rule_based"):
        """
        Extracts claims from the text using a specified method.

        Args:
            text (str): The input text.
            method (str, optional): The method for claim extraction. Defaults to "rule_based".

        Returns:
            list: A list of extracted claims.
        """
        if method == "rule_based":
            # Simple rule-based claim extraction (splitting by sentences)
            return [sent.text for sent in self.nlp(text).sents]
        elif method == "spacy_ner":
            # Extract claims based on Named Entity Recognition (NER)
            doc = self.nlp(text)
            return [ent.text for ent in doc.ents]  # All named entities
        else:
            print("Warning: Invalid claim extraction method. Using rule-based extraction.")
            return [sent.text for sent in self.nlp(text).sents]

    def _domain_specific_detection(self, claim, domain="patent"):
        """
        Performs domain-specific hallucination detection.  This is a placeholder.
        In a real application, this would involve specialized knowledge bases,
        ontologies, and rules specific to the domain.  For example, for patents,
        this might involve checking against patent databases, legal precedents,
        and technical standards.

        Args:
            claim (str): The claim to verify.
            domain (str, optional): The domain of the text. Defaults to "patent".

        Returns:
            float: A score indicating the likelihood of hallucination in the given domain.
        """
        if domain == "patent":
            # Placeholder for patent-specific logic
            # This could involve querying patent databases, checking for legal precedents, etc.
            # For now, return a random value.
            return np.random.rand()  # Replace with actual domain-specific logic
        else:
            print(f"Warning: Domain '{domain}' not supported for domain-specific detection.")
            return 0.5  # Neutral score if domain is not supported


if __name__ == '__main__':
    # Write a tiny throwaway knowledge base so the demo has documents to
    # verify claims against; documents are separated by blank lines.
    kb_documents = [
        "The quick brown fox jumps over the lazy dog.",
        "Hallucination detection is an important task in NLP.",
        "This is a sample document for the knowledge base.",
    ]
    with open("knowledge_base.txt", "w", encoding="utf-8") as kb_file:
        kb_file.write("\n\n".join(kb_documents))

    # Example usage
    detector = HallucinationDetector(knowledge_base_path="knowledge_base.txt")

    sample_text = "The quick brown rabbit jumps over the lazy cat. Hallucination detection is a crucial part of modern NLP.  This is unrelated to the knowledge base.  Elephants can fly."
    sample_sources = ["The quick brown fox jumps over the lazy dog.", "Hallucination detection is an important task in NLP."]

    print("--- Analysis without Sources ---")
    print(detector.detect_hallucinations(sample_text))

    print("\n--- Analysis with Sources ---")
    print(detector.detect_hallucinations(sample_text, sources=sample_sources))
