# constitutional_guard.py
import logging
import re
from enum import Enum
from typing import Dict, Any, List

# Root-logger setup performed at import time: timestamped records at INFO and
# above. logging.basicConfig does nothing when the root logger already has
# handlers, so an application that configures logging first keeps its setup.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class PrimeDirective(Enum):
    """The prime directives every action is assessed against.

    Each member's value is the directive title exactly as it is expected to
    appear after ``## PRIME DIRECTIVE <n>:`` in the constitution, and it is
    also the key used for that directive in validation results. Declaration
    order determines the order in which directives are assessed and reported.
    """

    MEMORY = "MEMORY"
    EVOLUTION = "EVOLUTION"
    REVENUE_GENERATION = "REVENUE GENERATION"  # title contains a space


class ComplianceLevel(Enum):
    """Outcome of assessing one action against one prime directive.

    The string values (not the members) are what get stored under
    ``"compliance_level"`` in assessment results.
    """

    COMPLIANT = "COMPLIANT"
    NON_COMPLIANT = "NON_COMPLIANT"
    UNCERTAIN = "UNCERTAIN"  # used when the directive has no scoring rule


class ConstitutionalGuard:
    """Check proposed actions against the prime directives of a constitution file.

    The constitution text is read once at construction time, and the
    "The Principle" section of each prime directive is extracted so it can be
    attached to every per-directive assessment.

    NOTE(review): compliance scoring in ``_assess_compliance`` is a keyword
    placeholder, and ``enforce`` allows an action only when *every* directive
    is assessed COMPLIANT. Combined, this means an action must mention keywords
    for all directives to pass. Confirm this is the intended policy.
    """

    # Placeholder scoring rules, one entry per directive:
    #   (keywords, reason if any keyword matches, reason if none match).
    # Matching is a case-insensitive substring test against the action text,
    # so e.g. "learn" also matches "learning" and "store" matches "restore".
    _KEYWORD_RULES = {
        PrimeDirective.MEMORY: (
            ("memory", "store"),
            "Action involves memory operations.",
            "Action does not seem to address memory requirements.",
        ),
        PrimeDirective.EVOLUTION: (
            ("improve", "learn"),
            "Action promotes system evolution.",
            "Action does not seem to contribute to system evolution.",
        ),
        PrimeDirective.REVENUE_GENERATION: (
            ("revenue", "profit"),
            "Action contributes to revenue generation.",
            "Action does not seem to focus on revenue generation.",
        ),
    }

    def __init__(self, constitution_path: str):
        """
        Initializes the ConstitutionalGuard with the path to the constitution file.

        Args:
            constitution_path (str): The path to the constitution file (read as UTF-8).

        Raises:
            FileNotFoundError: If no file exists at ``constitution_path``.
        """
        self.constitution_path = constitution_path
        self.constitution = self._load_constitution()
        self.directive_explanations = self._extract_directive_explanations()

    def _load_constitution(self) -> str:
        """
        Loads the constitution from the specified file.

        The file is decoded as UTF-8 explicitly rather than with the platform's
        locale encoding, so non-ASCII markdown reads the same on every OS.

        Returns:
            str: The content of the constitution file.

        Raises:
            FileNotFoundError: If the file does not exist (logged, then re-raised).
        """
        try:
            with open(self.constitution_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            logging.error(f"Constitution file not found at {self.constitution_path}")
            raise

    def _extract_directive_explanations(self) -> Dict[PrimeDirective, str]:
        """
        Extracts explanations for each prime directive from the constitution.

        Each directive is expected to be laid out as::

            ## PRIME DIRECTIVE <n>: <TITLE>

            ### The Principle
            <explanation text>

        The explanation runs up to the next markdown heading line or the end
        of the file. When a directive's section is not found, a warning is
        logged and the placeholder "Explanation not found." is stored.

        Returns:
            Dict[PrimeDirective, str]: A dictionary mapping each PrimeDirective to its explanation.
        """
        directive_explanations: Dict[PrimeDirective, str] = {}
        for directive in PrimeDirective:
            pattern = re.compile(
                # Heading line. \d+ so directives numbered 10 and above still
                # match; the title is literal text, so it is regex-escaped.
                rf"## PRIME DIRECTIVE \d+: {re.escape(directive.value)}[ \t]*\n"
                # Blank line(s), then the "### The Principle" sub-heading.
                r"\s*### The Principle[ \t]*\n"
                # Lazily capture up to the next heading line or end of file, so
                # the text cannot run on into the following directive's section.
                r"(.*?)(?=^#{1,6}[ \t]|\Z)",
                re.DOTALL | re.MULTILINE,
            )
            match = pattern.search(self.constitution)
            if match:
                directive_explanations[directive] = match.group(1).strip()
            else:
                logging.warning(f"Could not extract explanation for {directive}")
                directive_explanations[directive] = "Explanation not found."
        return directive_explanations

    def validate_action(self, action: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validates an action against the constitutional rules.

        Args:
            action (str): The action to validate.
            context (Dict[str, Any]): The context in which the action is performed.

        Returns:
            Dict[str, Any]: Maps each directive's value (e.g. "MEMORY") to the
            assessment dict produced by ``_assess_compliance``.
        """
        return {
            directive.value: self._assess_compliance(action, directive, context)
            for directive in PrimeDirective
        }

    def _assess_compliance(self, action: str, directive: PrimeDirective, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Assesses the compliance of an action with a specific prime directive.

        Placeholder logic: the action is COMPLIANT when its lower-cased text
        contains any of the directive's keywords in ``_KEYWORD_RULES``, and
        NON_COMPLIANT otherwise. A directive without a rule is UNCERTAIN.
        ``context`` is accepted for a future scoring engine but is not used yet.

        Args:
            action (str): The action to assess.
            directive (PrimeDirective): The prime directive to assess against.
            context (Dict[str, Any]): The context in which the action is performed.

        Returns:
            Dict[str, Any]: ``compliance_level`` (a ComplianceLevel value string),
            ``reason``, and ``directive_explanation``.
        """
        rule = self._KEYWORD_RULES.get(directive)
        if rule is None:
            compliance_level = ComplianceLevel.UNCERTAIN
            reason = "Unknown directive."
        else:
            keywords, compliant_reason, non_compliant_reason = rule
            lowered = action.lower()  # computed once, not once per keyword
            if any(keyword in lowered for keyword in keywords):
                compliance_level = ComplianceLevel.COMPLIANT
                reason = compliant_reason
            else:
                compliance_level = ComplianceLevel.NON_COMPLIANT
                reason = non_compliant_reason

        return {
            "compliance_level": compliance_level.value,
            "reason": reason,
            "directive_explanation": self.directive_explanations.get(directive, "No explanation found."),
        }

    @staticmethod
    def _all_compliant(validation_results: Dict[str, Any]) -> bool:
        """Return True if every directive's assessment in *validation_results* is COMPLIANT."""
        return all(
            result["compliance_level"] == ComplianceLevel.COMPLIANT.value
            for result in validation_results.values()
        )

    def enforce(self, action: str, context: Dict[str, Any]) -> bool:
        """
        Enforces the constitutional rules by validating the action and blocking non-compliant operations.

        The decision is always logged. An action is allowed only if it is
        COMPLIANT with every prime directive; otherwise ``_handle_violation``
        is invoked.

        Args:
            action (str): The action to enforce.
            context (Dict[str, Any]): The context in which the action is performed.

        Returns:
            bool: True if the action is compliant and allowed, False otherwise.
        """
        validation_results = self.validate_action(action, context)
        is_compliant = self._all_compliant(validation_results)

        self._log_compliance_decision(action, validation_results, is_compliant)

        if not is_compliant:
            self._handle_violation(action, validation_results)

        return is_compliant

    def _log_compliance_decision(self, action: str, validation_results: Dict[str, Any], is_compliant: bool):
        """
        Logs the compliance decision, including the action, validation results, and compliance status.

        Emits a single multi-line INFO record: the action, one indented line
        per directive, then the overall verdict.

        Args:
            action (str): The action that was validated.
            validation_results (Dict[str, Any]): The results of the validation.
            is_compliant (bool): Whether the action is compliant.
        """
        lines = [f"Action: {action}"]
        lines.extend(
            f"  {directive}: Compliance - {result['compliance_level']}, Reason - {result['reason']}"
            for directive, result in validation_results.items()
        )
        lines.append(f"Overall Compliance: {'Compliant' if is_compliant else 'Non-Compliant'}")
        logging.info("\n".join(lines))

    def _handle_violation(self, action: str, validation_results: Dict[str, Any]):
        """
        Handles a violation of the constitutional rules.

        Currently this only logs a warning; the caller (``enforce``) signals the
        block by returning False.

        Args:
            action (str): The action that violated the rules.
            validation_results (Dict[str, Any]): The results of the validation
                (not used yet; available for escalation logic).
        """
        logging.warning(f"Constitutional violation detected for action: {action}. Blocking action.")
        # TODO: escalate (e.g. notify a human supervisor) or raise a dedicated
        # exception type once callers are prepared to handle it.

    def generate_compliance_report(self, actions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Generates a compliance report for a list of actions.

        Unlike ``enforce``, this neither logs decisions nor invokes violation
        handling; it only assesses.

        Args:
            actions (List[Dict[str, Any]]): A list of actions to generate the report for. Each action should be a dictionary
                                            containing the 'action' and 'context' keys.

        Returns:
            List[Dict[str, Any]]: One report per input item, in order, with keys
            ``action``, ``validation_results`` and ``is_compliant``.

        Raises:
            KeyError: If an item lacks the 'action' or 'context' key.
        """
        reports = []
        for item in actions:
            action = item["action"]
            validation_results = self.validate_action(action, item["context"])
            reports.append({
                "action": action,
                "validation_results": validation_results,
                "is_compliant": self._all_compliant(validation_results),
            })
        return reports


if __name__ == '__main__':
    # Example usage. The constitution path may be given as the first
    # command-line argument; otherwise the default file name below is used.
    import sys

    constitution_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "E__genesis-system_CONSTITUTION_PRIME_DIRECTIVES_CORRECTED.md"
    )
    guard = ConstitutionalGuard(constitution_path)

    # enforce() passes an action only if every directive is assessed
    # COMPLIANT, so an action aimed at a single directive may still be
    # reported as NOT compliant.
    examples = [
        # Example 1: memory-oriented action.
        {
            "action": "Store customer data in PostgreSQL RLM for future analysis.",
            "context": {"user": "AIVA", "purpose": "Customer Relationship Management"},
        },
        # Example 2: destructive action that should be blocked.
        {
            "action": "Delete all customer data without backup.",
            "context": {"user": "rogue_agent", "purpose": "unknown"},
        },
        # Example 3: evolution-oriented action.
        {
            "action": "Implement a new machine learning algorithm to improve harvest prediction accuracy.",
            "context": {"user": "AIVA", "purpose": "RiverSun Farm Optimization"},
        },
        # Example 4: revenue-oriented action.
        {
            "action": "Launch a new marketing campaign to attract more AgileAdapt customers.",
            "context": {"user": "AIVA", "purpose": "AgileAdapt Growth"},
        },
    ]

    for example in examples:
        action = example["action"]
        verdict = "is compliant" if guard.enforce(action, example["context"]) else "is NOT compliant"
        print(f"Action '{action}' {verdict}.")

    # Generate a compliance report for the same examples.
    report = guard.generate_compliance_report(examples)
    print("\nCompliance Report:")
    for entry in report:
        print(f"Action: {entry['action']}")
        print(f"  Compliant: {entry['is_compliant']}")
        print(f"  Results: {entry['validation_results']}")
        print("-" * 20)