import asyncio
import re
import hashlib
from datetime import datetime
from aiva.db_connector import memory

# ---------------------------- Validation Gates ----------------------------

class ValidationGate:
    """Abstract base class for validation gates.

    Gates are linked into a chain via ``set_next_gate``. Each gate scores the
    data with ``_perform_validation``; when the score reaches the gate's pass
    threshold, validation is handed to the next gate in the chain.
    """

    def __init__(self, memory, gate_name):
        """
        Args:
            memory: Store object; its ``log_audit`` method is called to record
                every decision.
            gate_name: Short name used in log messages and audit event names.
        """
        self.memory = memory
        self.gate_name = gate_name
        # Declared up front so the attribute always exists; a gate with no
        # successor is the end of the chain.
        self.next_gate = None

    async def validate(self, input_data, output_data, task_description, worker_id, metadata):
        """
        Validates data and chains to the next gate if passed.

        Args:
            input_data: The original input data.
            output_data: The output data to validate.
            task_description: Description of the task performed.
            worker_id: ID of the worker who produced the output.
            metadata: Additional metadata.

        Returns:
            A tuple ``(score, checks)`` from the last gate that ran: the first
            gate that failed, or the final gate of the chain when every gate
            passed. If this gate raises, ``(0.0, {"error": <message>})`` is
            returned instead of propagating the exception.
        """
        try:
            score, checks = await self._perform_validation(input_data, output_data, task_description, worker_id, metadata)
            passed = score >= self._get_pass_threshold()

            log_message = f"Gate {self.gate_name}: {'Passed' if passed else 'Failed'} with score {score}. Checks: {checks}"
            print(log_message)
            await self._log_decision(worker_id, passed, score, checks)

            if not passed:
                # Return even if failed, for reporting.
                return score, checks

            # getattr keeps subclasses that skip super().__init__() working.
            next_gate = getattr(self, 'next_gate', None)
            if next_gate:
                return await next_gate.validate(input_data, output_data, task_description, worker_id, metadata)

            print(f"Gate {self.gate_name}: Validation chain complete.")
            return score, checks  # End of chain

        except Exception as e:
            # Deliberate best-effort boundary: a crashing gate is reported as
            # a zero score rather than aborting the caller.
            print(f"[ERROR] Gate {self.gate_name} failed: {e}")
            return 0.0, {"error": str(e)}

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Abstract method to be implemented by subclasses.

        Implementations return a ``(score, checks)`` tuple.
        """
        raise NotImplementedError

    def _get_pass_threshold(self):
        """Returns the pass threshold for this gate."""
        return 0.8  # Default threshold

    async def _log_decision(self, worker_id, passed, score, checks):
        """Logs the validation decision through ``memory.log_audit``.

        Audit failures are reported as a warning and never interrupt
        validation.
        """
        log_data = {
            "gate": self.gate_name,
            # NOTE(review): utcnow() yields a naive timestamp and is
            # deprecated since Python 3.12; kept so the audit format does not
            # change. Switch to an aware datetime once consumers accept an
            # offset suffix.
            "timestamp": datetime.utcnow().isoformat(),
            "worker_id": worker_id,
            "passed": passed,
            "score": score,
            "checks": checks
        }
        try:
            self.memory.log_audit(f"{self.gate_name}_validation", worker_id, log_data)
        except Exception as e:
            print(f"[WARNING] Audit logging failed for gate {self.gate_name}: {e}")

    def set_next_gate(self, next_gate):
        """Sets the next gate in the chain."""
        self.next_gate = next_gate


class GateAlpha(ValidationGate):
    """Gate Alpha: Input Validity - Verifies source data quality."""

    def __init__(self, memory):
        super().__init__(memory, "Alpha")

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Scores the input on non-emptiness, type and required keywords.

        Only ``input_data`` is inspected. Non-string input fails the type and
        keyword checks instead of raising, so the per-check breakdown stays
        available to the caller.

        Returns:
            A tuple ``(score, checks)`` where ``score`` is the fraction of
            checks that passed and ``checks`` maps check names to booleans.
        """
        is_text = isinstance(input_data, str)

        # Example: Check for specific keywords in the input (can be customized)
        required_keywords = ["information", "analysis"]
        # Lower-case once; only meaningful when the input really is a string.
        lowered = input_data.lower() if is_text else ""

        checks = {
            "input_not_empty": bool(input_data),
            "input_type_valid": is_text,
            "keywords_present": is_text and all(keyword in lowered for keyword in required_keywords),
        }

        score = sum(checks.values()) / len(checks)
        return score, checks


class GateBeta(ValidationGate):
    """Gate Beta: Output Quality - Checks accuracy and completeness."""

    def __init__(self, memory):
        super().__init__(memory, "Beta")

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Scores the output on non-emptiness, type and relevance.

        Relevance is a case-insensitive substring test of ``task_description``
        inside ``output_data``. When either value is not a string the
        relevance check fails instead of raising, so the type check can
        actually report the problem.

        Returns:
            A tuple ``(score, checks)`` where ``score`` is the fraction of
            checks that passed and ``checks`` maps check names to booleans.
        """
        output_is_text = isinstance(output_data, str)
        # Substring matching is only defined when both sides are strings.
        comparable = output_is_text and isinstance(task_description, str)

        checks = {
            "output_not_empty": bool(output_data),
            "output_type_valid": output_is_text,
            "output_relevant": comparable and task_description.lower() in output_data.lower(),
        }

        score = sum(checks.values()) / len(checks)
        return score, checks


class GateGamma(ValidationGate):
    """Gate Gamma: Insight Purity - Detects hallucinations."""

    def __init__(self, memory):
        super().__init__(memory, "Gamma")

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Scores the output with two simple text heuristics.

        Returns:
            A tuple ``(score, checks)`` where ``score`` is the fraction of
            checks that passed and ``checks`` maps check names to booleans.
        """
        produced = output_data.lower()

        # Placeholder heuristic: the output must not contain any of these
        # phrases. Real hallucination detection needs something stronger,
        # such as fact verification against a knowledge base.
        flagged_phrases = ("hallucination", "i don't know")
        free_of_flags = all(phrase not in produced for phrase in flagged_phrases)

        # Consistent when the output echoes either the input or the task
        # description (case-insensitive substring match).
        echoes_source = input_data.lower() in produced or task_description.lower() in produced

        checks = {
            "no_hallucinations": free_of_flags,
            "consistent_with_input": echoes_source,
        }
        return sum(checks.values()) / len(checks), checks


class GateDelta(ValidationGate):
    """Gate Delta: Memory Integration - Validates storage operations."""

    def __init__(self, memory):
        super().__init__(memory, "Delta")

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Scores whether the output was stored and reads back unchanged.

        Returns:
            A tuple ``(score, checks)`` where ``score`` is the fraction of
            checks that passed and ``checks`` maps check names to booleans.
        """
        checks = {}
        # Check if the output was successfully stored in memory (Placeholder)
        # This requires the worker to update the memory and provide confirmation in metadata.
        # Coerced to bool so a truthy non-boolean flag cannot skew the score,
        # and missing metadata counts as "not confirmed".
        checks["memory_update_successful"] = bool((metadata or {}).get("memory_update", False))

        # Verify the integrity of the stored data (Placeholder)
        # Requires reading back from memory and comparing to the original output
        try:
            stored_data = self.memory.retrieve_data(worker_id)  # Assuming such a method exists
            checks["memory_integrity"] = bool(stored_data == output_data)
        except Exception:
            checks["memory_integrity"] = False  # Assume failure if retrieval fails

        score = sum(checks.values()) / len(checks)
        return score, checks


class GateEpsilon(ValidationGate):
    """Gate Epsilon: Strategy Alignment - Ensures revenue pathway fit."""

    def __init__(self, memory):
        super().__init__(memory, "Epsilon")

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Scores the output on two keyword-based placeholder checks.

        Returns:
            A tuple ``(score, checks)`` where ``score`` is the fraction of
            checks that passed and ``checks`` maps check names to booleans.
        """
        text = output_data.lower()

        # Placeholder criteria: a real implementation needs a pre-defined set
        # of strategies and evaluation rules.
        offering_terms = ("product", "service")
        audience_terms = ("customer", "user")

        checks = {
            # Does the output suggest a specific product or service?
            "strategy_aligned": any(term in text for term in offering_terms),
            # Does the output target a specific customer segment?
            "customer_segment_fit": any(term in text for term in audience_terms),
        }
        return sum(checks.values()) / len(checks), checks


class GateZeta(ValidationGate):
    """Gate Zeta: Budget Compliance - Resource Monitoring."""

    def __init__(self, memory):
        super().__init__(memory, "Zeta")

    async def _perform_validation(self, input_data, output_data, task_description, worker_id, metadata):
        """Scores the task against example cost and time limits.

        Reads ``cost`` and ``time_taken`` from ``metadata``; each defaults to
        0 when the key is absent.

        Returns:
            A tuple ``(score, checks)`` where ``score`` is the fraction of
            checks that passed and ``checks`` maps check names to booleans.
        """
        # Example limits (placeholders): cost in units, time in seconds.
        cost_limit = 100
        time_limit = 60

        within_budget = metadata.get("cost", 0) < cost_limit
        within_time = metadata.get("time_taken", 0) < time_limit

        checks = {
            "budget_compliance": within_budget,
            "time_compliance": within_time,
        }
        return sum(checks.values()) / len(checks), checks

# ---------------------------- Orchestration ----------------------------

class SixGateValidator:
    """Builds the six gates and links them Alpha -> Beta -> Gamma -> Delta -> Epsilon -> Zeta."""

    def __init__(self, memory):
        """Creates one instance of each gate, all sharing ``memory``, and chains them."""
        self.memory = memory
        self.gate_alpha = GateAlpha(memory)
        self.gate_beta = GateBeta(memory)
        self.gate_gamma = GateGamma(memory)
        self.gate_delta = GateDelta(memory)
        self.gate_epsilon = GateEpsilon(memory)
        self.gate_zeta = GateZeta(memory)

        # Link every gate to its successor; the last gate has no successor.
        ordered_gates = (
            self.gate_alpha,
            self.gate_beta,
            self.gate_gamma,
            self.gate_delta,
            self.gate_epsilon,
            self.gate_zeta,
        )
        for upstream, downstream in zip(ordered_gates, ordered_gates[1:]):
            upstream.set_next_gate(downstream)

    async def validate_worker_output(self, input_data, output_data, task_description, worker_id, metadata):
        """Runs the complete 6-gate validation suite.

        Returns whatever the first gate's ``validate`` call returns.
        """
        entry_gate = self.gate_alpha  # Start of the chain
        return await entry_gate.validate(input_data, output_data, task_description, worker_id, metadata)

# Example usage (assuming 'memory' is an instance of your memory store):
# validator = SixGateValidator(memory)
# report = asyncio.run(validator.validate_worker_output(input_data, output_data, task_description, worker_id, metadata))
# print(report)