import asyncio
import hashlib
import re
from datetime import datetime, timezone
from typing import Dict, Optional, Tuple

# Mock Memory Store (Replace with actual DB connector)
class MemoryStore:
    """In-memory stand-in for a persistent store (swap for a real DB connector).

    Holds three collections:

    * ``audit_log`` -- list of ``{"gate", "worker_id", "report"}`` entries,
      only ever appended to by this class.
    * ``worker_trust`` -- maps worker_id to ``{"successes": int, "failures": int}``.
    * ``data_quality_scores`` -- maps data_source to its most recently stored score.
    """

    def __init__(self):
        self.audit_log = []
        self.worker_trust = {}
        self.data_quality_scores = {}

    def log_audit(self, gate_name, worker_id, report):
        """Append one audit entry for *gate_name* / *worker_id*."""
        entry = {"gate": gate_name, "worker_id": worker_id, "report": report}
        self.audit_log.append(entry)

    def update_worker_trust(self, worker_id, passed):
        """Bump the success (truthy *passed*) or failure counter for *worker_id*.

        The counter record is created with both counts at zero on first use.
        """
        tally = self.worker_trust.setdefault(worker_id, {"successes": 0, "failures": 0})
        bucket = "successes" if passed else "failures"
        tally[bucket] += 1

    def store_data_quality_score(self, data_source, score):
        """Record *score* for *data_source*, overwriting any previous value."""
        self.data_quality_scores[data_source] = score

    def get_data_quality_score(self, data_source):
        """Return the stored score for *data_source*, or 0.9 if none was stored."""
        try:
            return self.data_quality_scores[data_source]
        except KeyError:
            return 0.9

# Module-level store instance; the global validator further down is built on it.
memory = MemoryStore()

# --- Base Gate Class ---
class ValidationGate:
    """Base class for a single validation gate.

    Subclasses implement :meth:`validate`, returning ``(score, checks)``.
    Gates can be linked through ``next_gate`` (see :meth:`run_next_gate`) and
    record their per-gate reports through :meth:`log_decision`.
    """

    def __init__(self, memory_store, next_gate=None):
        """
        :param memory_store: Store exposing ``log_audit(gate_name, worker_id, report)``.
        :param next_gate: Optional gate to invoke from :meth:`run_next_gate`.
        """
        self.memory = memory_store
        self.next_gate = next_gate
        # The concrete subclass name doubles as the gate identifier in audit entries.
        self.gate_name = self.__class__.__name__

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """
        Abstract validation method.  Must be implemented by subclasses.
        :param output: The output from the worker.
        :param task_description:  Description of the task assigned.
        :param worker_id: ID of the worker.
        :param metadata:  Metadata associated with the task.
        :return: A tuple of (score, checks) where score is a float between 0 and 1,
                 and checks is a dictionary of boolean checks performed.
        :raises NotImplementedError: Always, on the base class.
        """
        # type(self) rather than self.gate_name so this also works on bare instances.
        raise NotImplementedError(f"{type(self).__name__} must implement validate()")

    async def run_next_gate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Optional[Tuple[float, Dict]]:
        """
        Chain to the next gate if it exists.
        :param output: The output from the worker.
        :param task_description:  Description of the task assigned.
        :param worker_id: ID of the worker.
        :param metadata:  Metadata associated with the task.
        :return: The next gate's ``(score, checks)`` tuple from its ``validate``,
                 or None if there is no next gate.
        """
        if self.next_gate is None:
            return None
        return await self.next_gate.validate(output, task_description, worker_id, metadata)

    async def log_decision(self, worker_id: str, report: Dict):
        """
        Log the validation decision to the audit trail.

        Best-effort: any exception from the store is reported with a warning
        print and swallowed, so a failing audit sink never aborts validation.
        :param worker_id: ID of the worker.
        :param report: The validation report.
        """
        try:
            self.memory.log_audit(self.gate_name, worker_id, report)
        except Exception as e:
            print(f"[WARNING] Audit logging failed for {self.gate_name}: {e}")

# --- Gate Implementations ---
class GateAlpha(ValidationGate): # Input Validity
    """Gate 1 -- input validity: is the data source named and historically reliable?"""

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """Verify source data quality.

        Two equally weighted checks:

        * ``data_source_specified`` -- ``metadata["data_source"]`` is present and
          non-empty (None / "" are treated as unspecified).
        * ``data_quality_high`` -- the stored historical quality for that source
          is above 0.7 (the store supplies its own default for unknown sources).

        The report marks the gate as passed when the score is at least 0.8,
        i.e. both checks must hold.

        :return: ``(score, checks)``; score is the fraction of checks that passed.
        """
        checks = {}
        # Normalize missing / None / empty values to the "default" bucket so an
        # explicit ``data_source=None`` no longer counts as "specified".
        data_source = metadata.get("data_source") or "default"
        historical_quality = self.memory.get_data_quality_score(data_source)

        checks["data_source_specified"] = data_source != "default"
        checks["data_quality_high"] = historical_quality > 0.7  # Adjust threshold as needed

        score = sum(checks.values()) / len(checks)
        report = {
            "gate": self.gate_name,
            "worker_id": worker_id,
            "passed": score >= 0.8,  # Adjust threshold as needed
            "score": score,
            "checks": checks,
            "reasoning": "Data source must be specified and historical quality must be above threshold."
        }
        await self.log_decision(worker_id, report)

        return score, checks

class GateBeta(ValidationGate): # Output Quality
    """Gate 2 -- output quality: does the output cover the task and have substance?"""

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """Check accuracy and completeness of output.

        Two equally weighted checks:

        * ``keywords_present`` -- at least half of the word tokens in
          *task_description* occur (case-insensitively, as substrings) in *output*.
          This is vacuously true when the description has no word tokens.
        * ``output_length_sufficient`` -- *output* is longer than 50 characters.

        The report marks the gate as passed when the score is at least 0.7.

        :return: ``(score, checks)``; score is the fraction of checks that passed.
        """
        checks = {}

        # Tokenize on word characters so trailing punctuation ("output.") or
        # hyphenation ("revenue-focused") doesn't prevent an otherwise valid match.
        keywords = re.findall(r"\w+", task_description.lower())
        haystack = output.lower()  # lowercase once, not once per keyword
        present_keywords = [keyword for keyword in keywords if keyword in haystack]
        checks["keywords_present"] = len(present_keywords) >= len(keywords) / 2

        checks["output_length_sufficient"] = len(output) > 50

        score = sum(checks.values()) / len(checks)
        report = {
            "gate": self.gate_name,
            "worker_id": worker_id,
            "passed": score >= 0.7,  # Adjust threshold as needed
            "score": score,
            "checks": checks,
            "reasoning": "Output must contain relevant keywords and be of sufficient length."
        }
        await self.log_decision(worker_id, report)

        return score, checks

class GateGamma(ValidationGate):  # Insight Purity (Hallucination Detection)
    """Gate 3 -- insight purity: screen the output for signs of hallucination."""

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """Detect hallucinations in output.

        Both checks are placeholders:

        * ``no_contradictions`` only looks for the literal substring
          "contradiction" (case-insensitive).
        * ``factually_accurate`` is always True until a knowledge-base lookup
          is wired in.

        The report marks the gate as passed when the score is at least 0.6.

        :return: ``(score, checks)``; score is the fraction of checks that passed.
        """
        lowered_output = output.lower()
        checks = {
            # TODO: real contradiction detection.
            "no_contradictions": "contradiction" not in lowered_output,
            # TODO: verify claims against a knowledge base.
            "factually_accurate": True,
        }

        passed_count = sum(checks.values())
        score = passed_count / len(checks)

        await self.log_decision(
            worker_id,
            {
                "gate": self.gate_name,
                "worker_id": worker_id,
                "passed": score >= 0.6,  # Adjust threshold as needed
                "score": score,
                "checks": checks,
                "reasoning": "Output must not contain contradictions or factual inaccuracies.",
            },
        )
        return score, checks

class GateDelta(ValidationGate):  # Memory Integration
    """Gate 4 -- memory integration: confirm the output was persisted."""

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """Validate memory storage operations.

        Placeholder: no storage call is made yet, so ``memory_storage_success``
        is always True. The report marks the gate as passed when the score is
        at least 0.9.

        :return: ``(score, checks)``; score is the fraction of checks that passed.
        """
        # TODO: call the real storage routine (e.g. store_in_memory(output, metadata))
        # and derive this flag from whether it raised; until then success is assumed.
        stored_ok = True
        checks = {"memory_storage_success": stored_ok}

        passed_count = sum(checks.values())
        score = passed_count / len(checks)

        await self.log_decision(
            worker_id,
            {
                "gate": self.gate_name,
                "worker_id": worker_id,
                "passed": score >= 0.9,  # Adjust threshold as needed
                "score": score,
                "checks": checks,
                "reasoning": "Output must be successfully stored in memory.",
            },
        )
        return score, checks

class GateEpsilon(ValidationGate):  # Strategy Alignment
    """Gate 5 -- strategy alignment: check the output serves the revenue pathway."""

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """Ensure output aligns with revenue pathway.

        Both checks are placeholders:

        * ``monetizable_content`` is True when "revenue" appears in the output
          (case-insensitive).
        * ``strategy_aligned`` is always True.

        The report marks the gate as passed when the score is at least 0.5.

        :return: ``(score, checks)``; score is the fraction of checks that passed.
        """
        mentions_revenue = "revenue" in output.lower()
        checks = {
            # TODO: richer monetization analysis.
            "monetizable_content": mentions_revenue,
            # TODO: compare against the business strategy.
            "strategy_aligned": True,
        }

        passed_count = sum(checks.values())
        score = passed_count / len(checks)

        await self.log_decision(
            worker_id,
            {
                "gate": self.gate_name,
                "worker_id": worker_id,
                "passed": score >= 0.5,  # Adjust threshold as needed
                "score": score,
                "checks": checks,
                "reasoning": "Output must contain monetizable content and align with overall business strategy.",
            },
        )
        return score, checks

class GateZeta(ValidationGate):  # Budget Compliance
    """Gate 6 -- budget compliance: resource usage must stay within the limit."""

    async def validate(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[float, Dict]:
        """Monitor resource usage and budget compliance.

        Usage is currently a fixed placeholder value of 10. It is compared
        against ``metadata["budget_limit"]``, which defaults to 100. The report
        marks the gate as passed when the score is at least 0.8.

        :return: ``(score, checks)``; score is the fraction of checks that passed.
        """
        # TODO: replace with get_resource_usage(worker_id) once it exists.
        simulated_usage = 10
        limit = metadata.get("budget_limit", 100)
        checks = {"budget_compliant": simulated_usage <= limit}

        passed_count = sum(checks.values())
        score = passed_count / len(checks)

        await self.log_decision(
            worker_id,
            {
                "gate": self.gate_name,
                "worker_id": worker_id,
                "passed": score >= 0.8,  # Adjust threshold as needed
                "score": score,
                "checks": checks,
                "reasoning": "Resource usage must be within budget.",
            },
        )
        return score, checks

# --- Orchestrator ---
class SixGateValidator:
    """Orchestrate the six validation gates and aggregate their results.

    Gates run in order alpha -> zeta. Each gate's score must reach its
    threshold for the next gate to run. The first gate that falls short stops
    the pipeline, and the remaining gates are reported as skipped.
    """

    # Minimum score each gate needs for the pipeline to continue; these mirror
    # the "passed" thresholds in each gate's own report.
    GATE_THRESHOLDS = {
        "alpha": 0.8,
        "beta": 0.7,
        "gamma": 0.6,
        "delta": 0.9,
        "epsilon": 0.5,
        "zeta": 0.8,
    }
    # Minimum mean score across executed gates for an overall pass.
    OVERALL_THRESHOLD = 0.7

    def __init__(self, memory):
        """
        :param memory: Store passed to every gate. It must expose ``log_audit``
            and ``update_worker_trust``.
        """
        self.memory = memory
        self.gate_alpha = GateAlpha(memory)
        self.gate_beta = GateBeta(memory)
        self.gate_gamma = GateGamma(memory)
        self.gate_delta = GateDelta(memory)
        self.gate_epsilon = GateEpsilon(memory)
        self.gate_zeta = GateZeta(memory)

        # Chain the gates together (kept for callers using run_next_gate; the
        # orchestrator below drives the gates itself).
        self.gate_alpha.next_gate = self.gate_beta
        self.gate_beta.next_gate = self.gate_gamma
        self.gate_gamma.next_gate = self.gate_delta
        self.gate_delta.next_gate = self.gate_epsilon
        self.gate_epsilon.next_gate = self.gate_zeta

    def _pipeline(self):
        """Return ``(name, gate, threshold)`` triples in execution order."""
        return (
            ("alpha", self.gate_alpha, self.GATE_THRESHOLDS["alpha"]),
            ("beta", self.gate_beta, self.GATE_THRESHOLDS["beta"]),
            ("gamma", self.gate_gamma, self.GATE_THRESHOLDS["gamma"]),
            ("delta", self.gate_delta, self.GATE_THRESHOLDS["delta"]),
            ("epsilon", self.gate_epsilon, self.GATE_THRESHOLDS["epsilon"]),
            ("zeta", self.gate_zeta, self.GATE_THRESHOLDS["zeta"]),
        )

    async def _run_gates(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Tuple[Dict, list, bool]:
        """Run gates in order, stopping after the first one below its threshold.

        :return: ``(gate_results, executed_scores, all_gates_passed)``.
            ``gate_results`` has an entry for every gate. Skipped gates get
            ``score`` None, empty ``checks`` and ``skipped`` True.
        """
        gate_results = {}
        executed_scores = []
        all_gates_passed = True
        for name, gate, threshold in self._pipeline():
            if not all_gates_passed:
                gate_results[name] = {"score": None, "checks": {}, "skipped": True}
                continue
            score, checks = await gate.validate(output, task_description, worker_id, metadata)
            gate_results[name] = {"score": score, "checks": checks, "skipped": False}
            executed_scores.append(score)
            if score < threshold:
                all_gates_passed = False
        return gate_results, executed_scores, all_gates_passed

    def _record_outcome(self, worker_id: str, report: Dict, passed: bool) -> None:
        """Best-effort: append *report* to the audit trail and update worker trust.

        Store failures are printed as warnings and swallowed, so bookkeeping
        never prevents the report from being returned.
        """
        try:
            self.memory.log_audit("six_gate_validation", worker_id, report)
        except Exception as e:
            print(f"[WARNING] Audit logging failed: {e}")

        try:
            self.memory.update_worker_trust(worker_id, passed)
        except Exception as e:
            print(f"[WARNING] Trust update failed: {e}")

    async def validate_worker_output(self, output: str, task_description: str, worker_id: str, metadata: Dict) -> Dict:
        """Run complete 6-gate validation suite asynchronously.

        The output passes only if every gate meets its threshold and the mean
        gate score is at least ``OVERALL_THRESHOLD``. The result is written to
        the audit trail and the worker's trust record before it is returned.

        :return: Report dict with these keys:
            ``timestamp`` -- UTC ISO-8601 string with offset.
            ``worker_id``.
            ``passed``.
            ``overall_score`` -- mean score of the gates that ran.
            ``gates`` -- keyed alpha..zeta; each entry has ``score``,
            ``checks`` and ``skipped``.
        """
        gate_results, executed_scores, all_gates_passed = await self._run_gates(
            output, task_description, worker_id, metadata
        )
        # Alpha always runs, so executed_scores is never empty.
        overall_score = sum(executed_scores) / len(executed_scores)
        passed = all_gates_passed and overall_score >= self.OVERALL_THRESHOLD

        report = {
            # Timezone-aware replacement for the deprecated datetime.utcnow().
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "worker_id": worker_id,
            "passed": passed,
            "overall_score": overall_score,
            "gates": gate_results,
        }

        self._record_outcome(worker_id, report, passed)
        return report

# Global validator instance, built on the module-level `memory` store; main() uses it.
validator = SixGateValidator(memory)

# Example usage (async)
async def main():
    """Validate a canned sample output with the global validator and print the report."""
    sample_output = "This is a sample output containing the word revenue. It is factually accurate and doesn't contradict itself."
    sample_metadata = {
        "data_source": "internal_data",
        "output_hash": hashlib.sha256(sample_output.encode()).hexdigest(),
        "budget_limit": 150,
    }
    result = await validator.validate_worker_output(
        sample_output,
        "Generate a revenue-focused output.",
        "worker123",
        sample_metadata,
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(main())