import logging
import nltk
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from nltk.corpus import stopwords
import re  # Import the regular expression module

def _ensure_nltk_resource(resource_path, package_name):
    """Download the NLTK package *package_name* unless *resource_path* is already installed."""
    try:
        nltk.data.find(resource_path)
    except LookupError:
        # nltk.data.find signals a missing resource with LookupError.
        nltk.download(package_name)


# Make sure the NLTK data this module asks for is present before it is used.
_ensure_nltk_resource("corpora/stopwords", "stopwords")
_ensure_nltk_resource("tokenizers/punkt", "punkt")


# Root logger: INFO and above, prefixed with timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class Axiom:
    """A single statement plus its bookkeeping metadata.

    Identity is defined by ``text`` alone: two axioms with the same text
    compare equal and hash alike, whatever their ID or metadata.

    ``inferred_from`` and ``contradicts`` are sets of axiom IDs (not Axiom
    objects), which is what ``to_dict`` exports and ``from_dict`` restores.
    """

    def __init__(self, text, axiom_id=None, type="Unknown", inferred_from=None, contradicts=None, confidence=1.0, usage_count=0, hierarchy_level=0):
        """Create an axiom.

        Args:
            text: The statement itself; also the basis for equality/hashing.
            axiom_id: Identifier, or None when not yet assigned.
            type: Category label (defaults to "Unknown").
            inferred_from: Iterable of IDs this axiom was derived from.
            contradicts: Iterable of IDs this axiom conflicts with.
            confidence: Confidence score.
            usage_count: Number of recorded uses.
            hierarchy_level: Depth in the dependency hierarchy.
        """
        self.text = text
        self.axiom_id = axiom_id
        self.type = type
        # Copy the iterables so later mutation never leaks back to the caller.
        self.inferred_from = set(inferred_from) if inferred_from is not None else set()
        self.contradicts = set(contradicts) if contradicts is not None else set()
        self.confidence = confidence
        self.usage_count = usage_count
        self.hierarchy_level = hierarchy_level

    def __eq__(self, other):
        """Compare by text; defer to the other operand for foreign types."""
        if not isinstance(other, Axiom):
            return NotImplemented
        return self.text == other.text

    def __hash__(self):
        """Hash on text so that it stays consistent with ``__eq__``."""
        return hash(self.text)

    def to_dict(self):
        """Export the axiom to a plain dictionary, storing dependencies by ID."""
        return {
            'axiom_id': self.axiom_id,
            'text': self.text,
            'type': self.type,
            'inferred_from': list(self.inferred_from),
            'contradicts': list(self.contradicts),
            'confidence': self.confidence,
            'usage_count': self.usage_count,
            'hierarchy_level': self.hierarchy_level,
        }

    @classmethod
    def from_dict(cls, data, axiom_lookup=None):
        """Build an axiom from a dictionary produced by ``to_dict``.

        Dependency and contradiction IDs are restored verbatim. They are kept
        as IDs so that references to axioms that have not been imported yet
        are not lost.

        Args:
            data: Mapping with at least a 'text' key; other keys fall back to
                the constructor defaults when absent.
            axiom_lookup: Accepted for backward compatibility; not consulted.

        Returns:
            The reconstructed Axiom.

        Raises:
            KeyError: If 'text' is missing from *data*.
        """
        return cls(
            text=data['text'],
            axiom_id=data.get('axiom_id'),
            type=data.get('type', "Unknown"),
            inferred_from=data.get('inferred_from') or (),
            contradicts=data.get('contradicts') or (),
            confidence=data.get('confidence', 1.0),
            usage_count=data.get('usage_count', 0),
            hierarchy_level=data.get('hierarchy_level', 0),
        )

    def __repr__(self):
        # Only mark the text as truncated when it actually was.
        preview = self.text if len(self.text) <= 50 else f"{self.text[:50]}..."
        return f"Axiom(id={self.axiom_id}, text='{preview}', type={self.type}, confidence={self.confidence}, usage={self.usage_count}, level={self.hierarchy_level})"


class AxiomGenerator:
    """Generates, validates, and manages axioms."""

    def __init__(self, llm=None):
        """Set up an empty axiom store.

        Args:
            llm: Optional language-model handle, stored as given. Other
                methods only test whether it is truthy.
        """
        self.llm = llm
        # Axioms live in a set, so uniqueness follows Axiom equality.
        self.axioms = set()
        self.axiom_id_counter = 0
        # The matrix stays None until there is at least one axiom to fit on.
        self.tfidf_matrix = None
        english_stop_words = stopwords.words('english')
        self.tfidf_vectorizer = TfidfVectorizer(stop_words=english_stop_words)

    def _generate_axiom_id(self):
        """Advance the internal counter by one and return the new value as the next ID."""
        next_id = self.axiom_id_counter + 1
        self.axiom_id_counter = next_id
        return next_id

    def add_axiom(self, axiom):
        """Add *axiom* to the store unless it duplicates an existing one.

        An axiom is rejected when an equal axiom is already stored or when
        the TF-IDF similarity check reports it as redundant. The exact-match
        test is needed because texts with no TF-IDF weight (e.g. only stop
        words) have zero similarity to everything, including themselves.

        Args:
            axiom: The candidate; its ``axiom_id`` is assigned on success.

        Returns:
            True if the axiom was stored, False if it was rejected.
        """
        if axiom in self.axioms or self._is_redundant(axiom):
            logging.info(f"Redundant axiom not added: {axiom.text}")
            return False

        axiom.axiom_id = self._generate_axiom_id()
        self.axioms.add(axiom)
        logging.info(f"Added axiom: {axiom.axiom_id} - {axiom.text}")
        # Refit so the next redundancy check sees this axiom.
        self._update_tfidf_matrix()
        return True

    def _update_tfidf_matrix(self):
        """Refit the TF-IDF vectorizer on all current axiom texts.

        ``self.tfidf_matrix`` is set to the fitted matrix, or to None when
        there are no axioms or when no usable vocabulary can be built.
        """
        texts = [axiom.text for axiom in self.axioms]
        if not texts:
            self.tfidf_matrix = None
            return

        try:
            self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(texts)
        except ValueError as exc:
            # The vectorizer raises ValueError when every text reduces to an
            # empty vocabulary (e.g. only stop words). Without a matrix the
            # similarity check is simply skipped instead of crashing the add.
            logging.warning(f"TF-IDF matrix could not be built: {exc}")
            self.tfidf_matrix = None

    def generate_axioms_from_patent(self, patent_text):
        """Turn each sentence of *patent_text* into a candidate axiom.

        The text is split on whitespace that follows a '.' or '?' (with
        look-behinds that skip some abbreviation-like patterns). Every
        non-empty sentence gets a confidence score and, when an LLM is
        configured, is passed through the refinement hook first. If the hook
        raises, the failure is logged and the raw sentence is used instead.

        Returns:
            The list of axioms that were actually stored, in sentence order.
        """
        sentence_boundary = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s'
        accepted = []

        for raw_sentence in re.split(sentence_boundary, patent_text):
            candidate = raw_sentence.strip()
            if not candidate:
                continue

            score = self._calculate_confidence(candidate, patent_text)

            # Default to the raw sentence; only a successful refinement replaces it.
            axiom_text = candidate
            if self.llm:
                try:
                    axiom_text = self._llm_refine_axiom(candidate)
                except Exception as e:
                    logging.error(f"LLM processing failed for sentence: {candidate}.  Error: {e}")

            axiom = Axiom(axiom_text, confidence=score)
            if self.add_axiom(axiom):
                accepted.append(axiom)

        return accepted

    def _llm_refine_axiom(self, sentence):
        """Hook for LLM-based refinement of a candidate axiom.

        No model is called yet: the sentence is handed back unchanged. A real
        implementation would send a prompt such as
        "Refine the following axiom: <sentence>" to the configured model and
        return the stripped completion.
        """
        refined = sentence
        return refined

    def _calculate_confidence(self, sentence, patent_text):
        """Score a sentence by its relative frequency in the patent.

        The score is the number of occurrences of *sentence* in *patent_text*
        divided by the number of non-empty sentences in the text, capped at
        1.0. Empty fragments produced by the splitter (e.g. after trailing
        whitespace) are not counted as sentences.

        Args:
            sentence: The sentence to score.
            patent_text: The full text the sentence came from.

        Returns:
            A float in [0.0, 1.0]; 0.0 for an empty sentence or empty text.
        """
        if not sentence:
            # str.count("") would report len(text) + 1 spurious matches.
            return 0.0

        # Same boundary pattern as generate_axioms_from_patent.
        fragments = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', patent_text)
        total_sentences = sum(1 for fragment in fragments if fragment.strip())
        if total_sentences == 0:
            return 0.0

        return min(1.0, patent_text.count(sentence) / total_sentences)

    def _is_redundant(self, new_axiom):
        """Report whether *new_axiom* is a near-duplicate of a stored axiom.

        The candidate text is projected with the already-fitted vectorizer and
        compared against every row of the TF-IDF matrix; a best cosine
        similarity above 0.9 counts as redundant. With no matrix or no stored
        axioms nothing can be redundant.
        """
        nothing_indexed = self.tfidf_matrix is None or len(self.axioms) == 0
        if nothing_indexed:
            return False

        candidate_vector = self.tfidf_vectorizer.transform([new_axiom.text])
        best_match = cosine_similarity(candidate_vector, self.tfidf_matrix).max()
        return bool(best_match > 0.9)  # Threshold for redundancy

    def validate_axioms(self):
        """Flag inconsistent axioms and record pairwise contradictions.

        Each axiom is checked once with ``_check_logical_consistency`` (a
        True result is logged as an inconsistency). Each unordered pair of
        axioms with different IDs is then checked once with
        ``_detect_contradiction``; on a hit both axioms record the other's ID
        in ``contradicts`` and a single warning is logged.
        """
        logging.info("Validating axioms...")
        # Work on a snapshot so the set can change without breaking iteration.
        snapshot = list(self.axioms)

        # Per-axiom check: run once per axiom, including when only one exists.
        for axiom in snapshot:
            if self._check_logical_consistency(axiom):
                logging.warning(f"Axiom {axiom.axiom_id} is logically inconsistent")

        # Pairwise check: visit every unordered pair exactly once so each
        # contradiction is detected and reported a single time.
        for index, axiom1 in enumerate(snapshot):
            for axiom2 in snapshot[index + 1:]:
                if axiom1.axiom_id == axiom2.axiom_id:
                    continue
                if self._detect_contradiction(axiom1, axiom2):
                    axiom1.contradicts.add(axiom2.axiom_id)
                    axiom2.contradicts.add(axiom1.axiom_id)
                    logging.warning(f"Contradiction detected between axiom {axiom1.axiom_id} and {axiom2.axiom_id}")

    def _detect_contradiction(self, axiom1, axiom2):
        """Heuristically decide whether two axioms contradict each other.

        Keyword rule: if one text contains the whole word "not" and, with
        that word removed, appears as a whole-word phrase inside the other
        text, the pair is treated as contradictory. Case, extra whitespace
        and trailing sentence punctuation are ignored. This is a placeholder
        for real NLP/reasoning.

        When an LLM is configured, a stubbed LLM check runs afterwards; it
        currently always answers "False".

        Returns:
            True if a contradiction is detected, otherwise False.
        """
        def normalize(text):
            # Lower-case and collapse runs of whitespace to single spaces.
            return " ".join(text.lower().split())

        def affirmative_form(text):
            # Drop the whole word "not"; None when the text is not negated or
            # nothing meaningful is left.
            if not re.search(r"\bnot\b", text):
                return None
            stripped = " ".join(re.sub(r"\bnot\b", " ", text).split()).rstrip(" .?!")
            return stripped or None

        text1 = normalize(axiom1.text)
        text2 = normalize(axiom2.text)

        # Check for explicit negation in either direction.
        for negated, other in ((text1, text2), (text2, text1)):
            affirmed = affirmative_form(negated)
            if affirmed is None:
                continue
            # Require word boundaries so "fast" does not match inside "faster".
            if re.search(r"(?<!\w)" + re.escape(affirmed) + r"(?!\w)", other):
                return True

        # Use LLM for contradiction detection (Example)
        if self.llm:
            try:
                prompt = f"Do these two statements contradict each other? Answer True or False.\nStatement 1: {axiom1.text}\nStatement 2: {axiom2.text}"
                #llm_response = self.llm(prompt) #Replace with your LLM call
                llm_response = "False" #Dummy response
                if "true" in llm_response.lower():
                    return True
            except Exception as e:
                logging.error(f"LLM contradiction detection failed: {e}")
        return False

    def _check_logical_consistency(self, axiom):
        """Heuristically detect a self-contradicting axiom.

        NOTE: despite the name, the return value is True when an
        inconsistency IS found (this is how ``validate_axioms`` reads it).

        The text is split into clauses on punctuation and on the words
        "and"/"but". The axiom is flagged when some clause contains "is not"
        and the same clause with "is not" replaced by "is" also occurs, e.g.
        "The valve is open and the valve is not open." A plain negated
        statement such as "The system is not slow." is not flagged.
        Placeholder for formal logic or an LLM.

        Returns:
            True if the axiom asserts and denies the same clause, else False.
        """
        text = axiom.text.lower()
        # Normalise whitespace inside each clause so spacing cannot hide a match.
        clauses = {
            " ".join(part.split())
            for part in re.split(r"[.;,!?]|\band\b|\bbut\b", text)
        }
        clauses.discard("")

        for clause in clauses:
            if re.search(r"\bis not\b", clause):
                affirmative = re.sub(r"\bis not\b", "is", clause)
                if affirmative in clauses:
                    return True
        return False

    def derive_axioms(self):
        """Derive new axioms by combining pairs of existing ones.

        Only axioms present when the call starts are combined, and each
        unordered pair is tried once. Axioms derived during this pass are not
        recombined within the same pass; call the method again to derive from
        them. A combined text that ``add_axiom`` rejects as redundant is
        dropped.

        Returns:
            The list of newly stored derived axioms.
        """
        logging.info("Deriving axioms...")
        derived_axioms = []
        # Single snapshot: the set grows while we iterate, and freshly derived
        # axioms must not feed back into this pass.
        candidates = list(self.axioms)

        for index, axiom1 in enumerate(candidates):
            for axiom2 in candidates[index + 1:]:
                if axiom1.axiom_id == axiom2.axiom_id:
                    continue

                new_axiom_text = self._combine_axioms(axiom1.text, axiom2.text)
                if not new_axiom_text:
                    continue

                # The parent IDs are recorded once, at construction time.
                new_axiom = Axiom(new_axiom_text, type="Derived", inferred_from={axiom1.axiom_id, axiom2.axiom_id})
                if self.add_axiom(new_axiom):
                    derived_axioms.append(new_axiom)
                    logging.info(f"Derived axiom {new_axiom.axiom_id} from {axiom1.axiom_id} and {axiom2.axiom_id}")

        return derived_axioms

    def _combine_axioms(self, text1, text2):
        """Join two axiom texts with " and " when they share a word.

        Words are whitespace-separated tokens compared case-insensitively.
        Returns None when the two texts have no token in common. Placeholder
        for more sophisticated NLP.
        """
        left_tokens = text1.lower().split()
        right_tokens = set(text2.lower().split())
        if right_tokens.isdisjoint(left_tokens):
            return None
        return f"{text1} and {text2}"

    def categorize_axioms(self):
        """Assign a type label to every stored axiom.

        Axioms without any ``inferred_from`` entries become "Foundation".
        Axioms with dependencies keep the label "Derived" if they already
        carry it and are labelled "Operational" otherwise.
        """
        logging.info("Categorizing axioms...")
        for entry in self.axioms:
            has_parents = bool(entry.inferred_from)
            if has_parents and entry.type == "Derived":
                # Already marked as derived: leave untouched.
                continue
            entry.type = "Operational" if has_parents else "Foundation"

    def get_axioms_by_type(self, axiom_type):
        """Collect the stored axioms whose ``type`` equals *axiom_type* into a new list."""
        return list(filter(lambda candidate: candidate.type == axiom_type, self.axioms))

    def export_axioms(self):
        """Serialize every stored axiom via its ``to_dict`` and return the results as a list."""
        exported = []
        for axiom in self.axioms:
            exported.append(axiom.to_dict())
        return exported

    def import_axioms(self, axiom_data):
        """Import axioms from dictionaries such as those from ``export_axioms``.

        An entry whose 'axiom_id' matches a stored axiom updates that axiom in
        place; any other entry creates a new axiom (receiving a fresh ID when
        it has none). ``inferred_from`` and ``contradicts`` are restored as
        sets of IDs exactly as given, so references to axioms imported later
        or not at all are preserved. The ID counter is advanced past every
        integer ID seen, so IDs generated afterwards do not collide. The
        TF-IDF matrix is refreshed once at the end.

        Args:
            axiom_data: Iterable of axiom dictionaries.

        Returns:
            A list of all axioms stored after the import.
        """
        logging.info("Importing axioms...")
        axiom_lookup = {axiom.axiom_id: axiom for axiom in self.axioms if axiom.axiom_id is not None}  # Existing axioms by ID

        for data in axiom_data:
            axiom_id = data.get('axiom_id')
            inferred_from = set(data.get('inferred_from') or ())
            contradicts = set(data.get('contradicts') or ())

            if axiom_id in axiom_lookup:
                # Update existing axiom. Its hash depends on its text, so it
                # must leave the set before the text changes and be re-added
                # afterwards; otherwise the set can no longer find it.
                axiom = axiom_lookup[axiom_id]
                self.axioms.discard(axiom)
                axiom.text = data['text']
                axiom.type = data['type']
                axiom.confidence = data['confidence']
                axiom.usage_count = data['usage_count']
                axiom.hierarchy_level = data['hierarchy_level']
                axiom.inferred_from = inferred_from
                axiom.contradicts = contradicts
                self.axioms.add(axiom)

                logging.info(f"Updated axiom: {axiom.axiom_id} - {axiom.text}")
            else:
                # Create new axiom; dependencies are set explicitly as IDs.
                axiom = Axiom.from_dict(data, axiom_lookup)
                axiom.inferred_from = inferred_from
                axiom.contradicts = contradicts
                if axiom.axiom_id is None:
                    axiom.axiom_id = self._generate_axiom_id() #Assign new ID if missing
                self.axioms.add(axiom)
                axiom_lookup[axiom.axiom_id] = axiom  # Add to lookup
                logging.info(f"Imported axiom: {axiom.axiom_id} - {axiom.text}")

            # Keep generated IDs ahead of every imported integer ID.
            if isinstance(axiom.axiom_id, int) and axiom.axiom_id > self.axiom_id_counter:
                self.axiom_id_counter = axiom.axiom_id

        self._update_tfidf_matrix()  # Update TF-IDF matrix after import
        return list(self.axioms)

    def track_usage(self, axiom_id):
        """Record one use of the axiom with the given ID.

        The first stored axiom whose ID matches gets its usage count bumped
        and its confidence decayed. If no axiom matches, a warning is logged
        and nothing changes.
        """
        target = next(
            (candidate for candidate in self.axioms if candidate.axiom_id == axiom_id),
            None,
        )
        if target is None:
            logging.warning(f"Axiom with ID {axiom_id} not found for usage tracking.")
            return

        target.usage_count += 1
        logging.info(f"Axiom {axiom_id} usage count increased to {target.usage_count}")
        self._decay_confidence(target)  # Every recorded use also decays confidence

    def _decay_confidence(self, axiom, decay_rate=0.01, floor=0.1):
        """Reduce an axiom's confidence by a fixed fraction, down to a floor.

        The confidence is multiplied by ``1 - decay_rate`` but never pushed
        below *floor*. A confidence that is already below the floor is left
        as it is, so decaying can never increase it.

        Args:
            axiom: The axiom whose ``confidence`` is updated in place.
            decay_rate: Fraction removed per call (default 0.01).
            floor: Lowest value decay may reach (default 0.1).
        """
        # Clamp to the lower of the floor and the current value so a
        # confidence below the floor is not raised up to it.
        effective_floor = min(floor, axiom.confidence)
        axiom.confidence = max(effective_floor, axiom.confidence * (1 - decay_rate))
        logging.info(f"Axiom {axiom.axiom_id} confidence decayed to {axiom.confidence}")

    def build_hierarchy(self):
        """Assign hierarchy levels from the ``inferred_from`` dependencies.

        "Foundation" axioms are set to level 0. Every other axiom is raised
        until its level is greater than the level of each dependency that is
        actually stored; unknown dependency IDs are ignored, and an axiom
        with no stored dependency keeps its current level.

        The relaxation is bounded: with n axioms an acyclic graph settles
        within n passes. If levels are still changing after that, the
        dependencies are cyclic, so a warning is logged and the loop stops
        instead of running forever.
        """
        logging.info("Building axiom hierarchy...")
        # Start with foundation axioms at level 0
        for axiom in self.get_axioms_by_type("Foundation"):
            axiom.hierarchy_level = 0

        # Index once by ID (first occurrence wins) instead of scanning the
        # whole collection for every dependency.
        by_id = {}
        for axiom in self.axioms:
            by_id.setdefault(axiom.axiom_id, axiom)

        dependents = [axiom for axiom in self.axioms if axiom.type != "Foundation"]

        # n passes to settle plus one pass to observe that nothing changed.
        max_passes = len(self.axioms) + 1
        for _ in range(max_passes):
            changed = False
            for axiom in dependents:
                dependency_levels = [
                    by_id[dep_id].hierarchy_level
                    for dep_id in axiom.inferred_from
                    if dep_id in by_id
                ]
                if not dependency_levels:
                    continue
                max_dependency_level = max(dependency_levels)
                if axiom.hierarchy_level <= max_dependency_level:
                    axiom.hierarchy_level = max_dependency_level + 1
                    changed = True
            if not changed:
                break
        else:
            logging.warning("Cyclic axiom dependencies detected; hierarchy levels may be unreliable.")
        logging.info("Axiom hierarchy built.")

    def cross_reference(self):
        """Log which pairs of axioms share at least one word.

        Words are whitespace-separated, lower-cased tokens. Each unordered
        pair of axioms with different IDs is reported at most once. This is a
        placeholder: nothing is stored, the relation is only logged. A real
        implementation might use embeddings, topic models or a knowledge
        graph.
        """
        logging.info("Cross-referencing axioms...")
        # Tokenize every axiom once rather than once per pair.
        tokenized = [(axiom, set(axiom.text.lower().split())) for axiom in self.axioms]

        for index, (axiom1, words1) in enumerate(tokenized):
            for axiom2, words2 in tokenized[index + 1:]:
                if axiom1.axiom_id == axiom2.axiom_id:
                    continue
                common_words = words1 & words2
                if common_words:
                    logging.info(f"Axioms {axiom1.axiom_id} and {axiom2.axiom_id} share common words: {common_words}")

    def print_axioms(self):
        """Write each stored axiom's repr to stdout, one per line, in ascending ID order."""
        ordered = list(self.axioms)
        ordered.sort(key=lambda entry: entry.axiom_id)
        for entry in ordered:
            print(entry)


def _run_demo():
    """Run the example pipeline on two sample patents, then round-trip the result.

    Steps: generate axioms, validate, derive, categorize, build the hierarchy,
    cross-reference, print, record one usage, export, and import into a second
    generator whose contents are printed as well.
    """
    generator = AxiomGenerator()  # An LLM handle could be passed here

    sample_patents = (
        "The system comprises a processor. The processor is configured to execute instructions. The instructions cause the processor to perform a method. This method includes receiving data. The data is processed by the processor. The system is efficient. The system is not slow.",
        "A device includes a memory. The memory stores data. The device also includes a controller. The controller accesses the data from the memory. The controller processes the data. The device is fast. The device is reliable.",
    )
    for patent_text in sample_patents:
        generator.generate_axioms_from_patent(patent_text)

    # Analysis passes, in the order the later ones depend on.
    generator.validate_axioms()
    generator.derive_axioms()
    generator.categorize_axioms()
    generator.build_hierarchy()
    generator.cross_reference()

    generator.print_axioms()

    # Record one usage of an arbitrary stored axiom.
    if generator.axioms:
        first_axiom = next(iter(generator.axioms))
        generator.track_usage(first_axiom.axiom_id)

    # Export, then import into a fresh generator.
    exported_axioms = generator.export_axioms()
    restored_generator = AxiomGenerator()
    restored_generator.import_axioms(exported_axioms)

    print("\nImported Axioms:")
    restored_generator.print_axioms()


if __name__ == '__main__':
    _run_demo()
