# constitutional_guard.py

import logging
import re
from typing import Dict, Any, List, Tuple

# Root logging setup: INFO level and above, each record prefixed with timestamp and level.
# This statement runs at import time, so importing this module configures the root logger.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class ConstitutionalGuard:
    """
    A constitutional guard that enforces the 3 Prime Directives (Memory, Evolution, Revenue).

    Validates actions against the directive sections parsed from a constitution document,
    blocks non-compliant operations, records every compliance decision (including actions
    that could not be validated) in an in-memory audit trail, and provides compliance reports.

    Attributes:
        constitution (str): Raw text of the constitution file.
        prime_directives (Dict[str, str]): Maps "MEMORY", "EVOLUTION" and "REVENUE" to the
            text of the matching section. Directives whose section was not found are absent.
        compliance_threshold (float): Minimum overall score for an action to be compliant.
        audit_trail (List[Dict[str, Any]]): One entry per compliance decision.
    """

    # Default minimum overall score (mean of the three directive scores) for compliance.
    # Kept at class level so instances created without __init__ still have a value.
    compliance_threshold: float = 0.7

    def __init__(self, constitution_path: str, compliance_threshold: float = 0.7):
        """
        Initializes the ConstitutionalGuard with the constitution loaded from the given path.

        Args:
            constitution_path (str): The path to the constitution file (read as UTF-8).
            compliance_threshold (float): Minimum overall score (0.0 to 1.0) an action needs
                to be considered compliant. Defaults to 0.7.

        Raises:
            FileNotFoundError: If the constitution file does not exist.
            Exception: Any other error raised while reading the file is logged and re-raised.
        """
        self.constitution = self._load_constitution(constitution_path)
        self.prime_directives = self._parse_prime_directives()
        self.compliance_threshold = compliance_threshold
        # List of dictionaries, each representing a compliance decision.
        self.audit_trail: List[Dict[str, Any]] = []

    def _load_constitution(self, constitution_path: str) -> str:
        """
        Loads the constitution from the given file path.

        Args:
            constitution_path (str): The path to the constitution file.

        Returns:
            str: The content of the constitution.

        Raises:
            FileNotFoundError: If the file does not exist (logged, then re-raised).
            Exception: Any other read/decode error (logged, then re-raised).
        """
        try:
            # Explicit UTF-8 so the result does not depend on the platform's locale encoding.
            with open(constitution_path, 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"Constitution file not found: {constitution_path}")
            raise
        except Exception as e:
            logging.error(f"Error loading constitution: {e}")
            raise

    def _parse_prime_directives(self) -> Dict[str, str]:
        """
        Parses the prime directive sections out of the constitution text.

        Each section runs from its "## PRIME DIRECTIVE N: ..." heading up to the next
        directive heading. The REVENUE section ends at "## THE INTEGRATION" or, if that
        heading is absent, at the end of the document.

        Returns:
            Dict[str, str]: Directive name -> stripped section text. Directives whose
            section cannot be found are omitted (and a warning is logged).
        """
        patterns = (
            ("MEMORY", r"## PRIME DIRECTIVE 1: MEMORY(.*?)## PRIME DIRECTIVE 2"),
            ("EVOLUTION", r"## PRIME DIRECTIVE 2: EVOLUTION(.*?)## PRIME DIRECTIVE 3"),
            # Lookahead keeps the terminating heading out of the match; \Z allows the
            # revenue section to be the last section of the document.
            ("REVENUE", r"## PRIME DIRECTIVE 3: REVENUE GENERATION(.*?)(?=## THE INTEGRATION|\Z)"),
        )

        directives: Dict[str, str] = {}
        for name, pattern in patterns:
            match = re.search(pattern, self.constitution, re.DOTALL)
            if match:
                directives[name] = match.group(1).strip()
            else:
                logging.warning(f"Prime directive {name} not found in constitution.")
        return directives

    def validate_action(self, action: Dict[str, Any]) -> Tuple[bool, str, float]:
        """
        Validates an action against the constitutional rules and records the decision.

        Args:
            action (Dict[str, Any]): A dictionary representing the action to be validated.
                Must contain a string under the 'description' key.

        Returns:
            Tuple[bool, str, float]: A tuple containing:
                - A boolean indicating whether the action is compliant.
                - A string explaining the compliance decision.
                - A float representing the compliance score (0.0 on validation errors).
            Validation errors are never raised; they yield (False, "Error validating
            action: ...", 0.0) and are recorded in the audit trail like any other decision.
        """
        try:
            # Basic checks
            if not isinstance(action, dict):
                raise ValueError("Action must be a dictionary.")
            if "description" not in action or not isinstance(action["description"], str):
                raise ValueError("Action must have a 'description' key with a string value.")

            # Compliance scoring (simplified example): mean of the per-directive scores.
            memory_score = self._score_compliance(action, "MEMORY")
            evolution_score = self._score_compliance(action, "EVOLUTION")
            revenue_score = self._score_compliance(action, "REVENUE")
            overall_score = (memory_score + evolution_score + revenue_score) / 3.0
        except Exception as e:
            explanation = f"Error validating action: {e}"
            logging.error(explanation)
            # An action that cannot be validated is rejected; record that decision too so
            # the audit trail covers every outcome.
            self._log_compliance_decision(action, False, explanation, 0.0)
            return False, explanation, 0.0

        if overall_score >= self.compliance_threshold:
            is_compliant = True
            explanation = "Action is compliant with all Prime Directives."
        else:
            is_compliant = False
            explanation = "Action violates one or more Prime Directives."

        self._log_compliance_decision(action, is_compliant, explanation, overall_score)
        return is_compliant, explanation, overall_score

    def _score_compliance(self, action: Dict[str, Any], directive: str) -> float:
        """
        Scores the compliance of an action with a specific prime directive. This is a placeholder
        and needs to be replaced with a more sophisticated scoring mechanism, potentially using
        LLMs to evaluate the action's description against the directive's principles.

        The score is the fraction of the directive's words longer than 3 characters (counted
        with repetition) that occur as substrings of the lower-cased action description.

        Args:
            action (Dict[str, Any]): The action to be scored; must have a string 'description'.
            directive (str): The prime directive to score against ("MEMORY", "EVOLUTION", "REVENUE").

        Returns:
            float: A score representing the compliance (0.0 to 1.0); 0.0 if the directive
            text contains no words longer than 3 characters.

        Raises:
            ValueError: If the directive was not found when parsing the constitution.
        """
        directive_text = self.prime_directives.get(directive)
        if directive_text is None:
            raise ValueError(f"Prime directive '{directive}' was not found in the constitution.")

        description = action["description"].lower()

        # Placeholder: check whether keywords from the directive appear in the description.
        keywords = re.findall(r'\b\w+\b', directive_text.lower())
        relevant_keywords = [word for word in keywords if len(word) > 3]  # Filter out short words
        if not relevant_keywords:
            return 0.0

        matched_keywords = sum(1 for keyword in relevant_keywords if keyword in description)
        # matched_keywords <= len(relevant_keywords), so the ratio is already within [0, 1].
        return matched_keywords / len(relevant_keywords)

    def _log_compliance_decision(self, action: Dict[str, Any], is_compliant: bool, explanation: str, score: float):
        """
        Records a compliance decision in the audit trail and logs it at INFO level.

        Args:
            action (Dict[str, Any]): The action being evaluated (may be a non-dict when
                validation rejected it for that reason).
            is_compliant (bool): Whether the action is compliant.
            explanation (str): The explanation for the compliance decision.
            score (float): The compliance score.
        """
        log_entry = {
            # Snapshot the action so later mutation by the caller does not rewrite history.
            "action": dict(action) if isinstance(action, dict) else action,
            "is_compliant": is_compliant,
            "explanation": explanation,
            "score": score,
        }
        self.audit_trail.append(log_entry)
        logging.info(f"Compliance Decision: {log_entry}")

    def handle_violation(self, action: Dict[str, Any], explanation: str):
        """
        Handles a violation of the constitutional rules. Currently this logs a warning and
        prints a "blocked" notice; it does not raise or escalate.

        Args:
            action (Dict[str, Any]): The action that violated the rules.
            explanation (str): The explanation for the violation.
        """
        logging.warning(f"Constitutional Violation: {explanation}. Action: {action}")
        # In a real system, you would likely block the action here (e.g. raise) or escalate it.
        print(f"Action blocked: {action}. Reason: {explanation}")

    def get_compliance_report(self) -> List[Dict[str, Any]]:
        """
        Provides a compliance report containing all compliance decisions made so far.

        Returns:
            List[Dict[str, Any]]: A new list of copied decision entries, in the order the
            decisions were made. Mutating it does not affect the guard's audit trail.
        """
        return [dict(entry) for entry in self.audit_trail]

    def enforce(self, action: Dict[str, Any]) -> None:
        """
        Validates an action and dispatches on the outcome: non-compliant actions go to
        handle_violation, compliant ones are logged at INFO level.

        Args:
            action (Dict[str, Any]): The action to be enforced.
        """
        is_compliant, explanation, _score = self.validate_action(action)
        if not is_compliant:
            self.handle_violation(action, explanation)
        else:
            logging.info(f"Action compliant. Proceeding. Action: {action}")
            # Proceed with the action


if __name__ == '__main__':
    # Example usage: run a few sample actions through the guard, then print the audit trail.
    try:
        guard = ConstitutionalGuard("E__genesis-system_CONSTITUTION_PRIME_DIRECTIVES_CORRECTED.md")

        # Example actions. Note: the placeholder scorer only matches directive keywords in the
        # description text; it does not compare amounts, so action3 and action4 differ only in
        # which literal tokens they contain.
        example_actions = {
            "action1": {"description": "Store customer interaction data in PostgreSQL RLM."},
            "action2": {"description": "Implement a new feature without documenting the changes."},
            "action3": {"description": "Recommend an MVP that is projected to generate $5000 in revenue."},
            "action4": {"description": "Recommend an MVP that is projected to generate $15000 in revenue."},
        }

        # Enforce the rules for each action
        for label, action in example_actions.items():
            print(f"\nEnforcing {label}:")
            guard.enforce(action)

        # Get and print the compliance report
        report = guard.get_compliance_report()
        print("\nCompliance Report:")
        for entry in report:
            print(entry)

    except Exception as e:
        print(f"An error occurred: {e}")
        # Signal failure to the shell/CI instead of exiting with status 0.
        raise SystemExit(1)