#!/usr/bin/env python3
"""
GENESIS PHASE 5 INTEGRATION TEST SUITE
========================================
Tests integration between all Phase 5 modules.

Tests:
    - Memory tier integration
    - Event bus pub/sub
    - Agent protocol messaging
    - Queue processing
    - Health monitoring
    - Verified executor
    - Cross-module and end-to-end coordination

Run:
    python -m pytest tests/test_phase5_integration.py -v
    python tests/test_phase5_integration.py  # Standalone
"""

import json
import sys
import time
import threading
import unittest
from pathlib import Path
from datetime import datetime, timedelta

# Put the project's ``core`` directory (this file's grandparent / "core") at
# the front of sys.path so the per-test imports such as
# ``from event_bus import EventBus`` resolve to the project's core modules.
sys.path.insert(0, str(Path(__file__).parent.parent / "core"))


class TestMemoryIntegration(unittest.TestCase):
    """Test memory tier integration."""

    def setUp(self):
        from memory_integration import MemoryIntegration
        self.memory = MemoryIntegration()

    def test_working_memory_store_recall(self):
        """Test working memory store and recall."""
        key = f"test_key_{time.time()}"
        value = {"test": "data", "number": 42}

        # Store
        success = self.memory.store(key, value, tier="working")
        self.assertTrue(success)

        # Recall
        result = self.memory.recall(key)
        self.assertEqual(len(result.entries), 1)
        self.assertEqual(result.entries[0].value, value)

    def test_episodic_memory_store_search(self):
        """Test episodic memory store and search."""
        key = f"task:test_{time.time()}"
        value = {"task_id": "test-001", "status": "complete"}

        # Store
        success = self.memory.store(
            key, value, tier="episodic",
            episode_type="task", importance=0.8
        )
        self.assertTrue(success)

        # Recall exact
        result = self.memory.recall(key)
        self.assertGreater(len(result.entries), 0)

    def test_semantic_memory_keyword_search(self):
        """Test semantic memory keyword search."""
        # Store knowledge
        self.memory.store(
            "genesis_architecture",
            "Genesis is a self-evolving AI system with memory tiers",
            tier="semantic",
            category="architecture"
        )

        # Search
        result = self.memory.recall("genesis architecture")
        # Should find based on keywords
        self.assertIsNotNone(result)

    def test_procedural_memory_steps(self):
        """Test procedural memory for skills."""
        steps = ["Step 1: Initialize", "Step 2: Process", "Step 3: Complete"]

        success = self.memory.store(
            "howto:test_procedure",
            steps,
            tier="procedural",
            category="testing"
        )
        self.assertTrue(success)

        result = self.memory.recall("test_procedure")
        self.assertGreater(len(result.entries), 0)

    def test_auto_tier_detection(self):
        """Test automatic tier detection."""
        # Task data -> episodic
        task_data = {"task_id": "123", "outcome": "success"}
        self.memory.store("auto_task", task_data, tier="auto")

        # Knowledge data -> semantic (if key indicates)
        self.memory.store("concept:testing", "Testing ensures quality", tier="auto")

    def test_health_check(self):
        """Test memory health check."""
        health = self.memory.health_check()

        self.assertIn("working", health)
        self.assertIn("episodic", health)
        self.assertIn("semantic", health)
        self.assertIn("procedural", health)

    def test_get_stats(self):
        """Test memory statistics."""
        stats = self.memory.get_stats()

        self.assertIn("working", stats)
        self.assertIn("episodic", stats)
        self.assertIn("semantic", stats)
        self.assertIn("procedural", stats)


class TestEventBus(unittest.TestCase):
    """Test event bus pub/sub."""

    def setUp(self):
        from event_bus import EventBus
        self.bus = EventBus(persist_events=False)
        self.received_events = []

    def test_simple_pubsub(self):
        """Test simple publish/subscribe."""
        def handler(event):
            self.received_events.append(event)

        self.bus.subscribe("test.event", handler)
        self.bus.publish("test.event", {"data": "test"})

        self.assertEqual(len(self.received_events), 1)
        self.assertEqual(self.received_events[0].data["data"], "test")

    def test_wildcard_subscription(self):
        """Test wildcard topic matching."""
        def handler(event):
            self.received_events.append(event)

        self.bus.subscribe("task.*", handler)

        self.bus.publish("task.started", {"id": 1})
        self.bus.publish("task.completed", {"id": 1})
        self.bus.publish("other.event", {"id": 2})  # Should not match

        self.assertEqual(len(self.received_events), 2)

    def test_unsubscribe(self):
        """Test unsubscribe."""
        def handler(event):
            self.received_events.append(event)

        sub_id = self.bus.subscribe("test.unsub", handler)
        self.bus.publish("test.unsub", {"count": 1})

        self.bus.unsubscribe(sub_id)
        self.bus.publish("test.unsub", {"count": 2})

        self.assertEqual(len(self.received_events), 1)

    def test_event_filter(self):
        """Test event filtering."""
        def handler(event):
            self.received_events.append(event)

        def high_priority_filter(event):
            return event.priority >= 7

        self.bus.subscribe("test.filter", handler, filter_fn=high_priority_filter)

        self.bus.publish("test.filter", {"level": "low"}, priority=3)
        self.bus.publish("test.filter", {"level": "high"}, priority=8)

        self.assertEqual(len(self.received_events), 1)
        self.assertEqual(self.received_events[0].data["level"], "high")

    def test_event_history(self):
        """Test event history."""
        self.bus.publish("history.test", {"n": 1})
        self.bus.publish("history.test", {"n": 2})
        self.bus.publish("history.test", {"n": 3})

        history = self.bus.get_history("history.test")
        self.assertEqual(len(history), 3)

    def test_get_stats(self):
        """Test bus statistics."""
        self.bus.subscribe("stat.test", lambda e: None)
        self.bus.publish("stat.test", {})

        stats = self.bus.get_stats()
        self.assertIn("topics", stats)
        self.assertIn("subscriptions", stats)


class TestAgentProtocol(unittest.TestCase):
    """Test agent protocol messaging."""

    def setUp(self):
        from agent_protocol import (
            AgentProtocol, AgentIdentity, AgentRole,
            CLAUDE_OPUS, GEMINI_FLASH
        )
        self.protocol = AgentProtocol(CLAUDE_OPUS)
        self.claude = CLAUDE_OPUS
        self.gemini = GEMINI_FLASH

    def test_create_task_request(self):
        """Test creating a task request."""
        from agent_protocol import MessageType

        request = self.protocol.create_task_request(
            recipient=self.gemini,
            task_id="test-001",
            title="Test Task",
            description="A test task",
            task_type="testing"
        )

        self.assertEqual(request.message_type, MessageType.TASK_REQUEST)
        self.assertIsNotNone(request.message_id)
        self.assertEqual(request.payload["task_id"], "test-001")

    def test_create_task_response(self):
        """Test creating a task response."""
        request = self.protocol.create_task_request(
            recipient=self.gemini,
            task_id="test-002",
            title="Test",
            description="Test"
        )

        response = self.protocol.create_task_response(
            request_message=request,
            success=True,
            output={"result": "done"},
            duration=1.5,
            tokens_used=100
        )

        self.assertEqual(response.correlation_id, request.message_id)
        self.assertTrue(response.payload["success"])

    def test_create_handoff(self):
        """Test creating a handoff message."""
        from agent_protocol import MessageType, CLAUDE_SONNET

        handoff = self.protocol.create_handoff(
            recipient=CLAUDE_SONNET,
            reason="Need code review",
            context_summary="Implemented feature X",
            key_findings=["Feature works", "Tests pass"],
            pending_actions=["Review code", "Merge PR"]
        )

        self.assertEqual(handoff.message_type, MessageType.HANDOFF)
        self.assertEqual(handoff.priority, 8)  # Handoffs are high priority

    def test_message_validation(self):
        """Test message validation."""
        from agent_protocol import Message, MessageType

        # Valid message
        valid_msg = self.protocol.create_task_request(
            recipient=self.gemini,
            task_id="val-001",
            title="Valid",
            description="Valid task"
        )
        errors = self.protocol.validate_message(valid_msg)
        self.assertEqual(len(errors), 0)

    def test_message_serialization(self):
        """Test message JSON serialization."""
        request = self.protocol.create_task_request(
            recipient=self.gemini,
            task_id="ser-001",
            title="Serialize Test",
            description="Test serialization"
        )

        json_str = request.to_json()
        data = json.loads(json_str)

        self.assertEqual(data["message_id"], request.message_id)
        self.assertEqual(data["payload"]["task_id"], "ser-001")

    def test_format_for_agent(self):
        """Test agent-specific formatting."""
        request = self.protocol.create_task_request(
            recipient=self.gemini,
            task_id="fmt-001",
            title="Format Test",
            description="Test"
        )

        # Claude format (markdown)
        claude_format = self.protocol.format_for_agent(request, "claude-opus")
        self.assertIn("##", claude_format)

        # Gemini format (JSON)
        gemini_format = self.protocol.format_for_agent(request, "gemini-flash")
        self.assertTrue(gemini_format.startswith("{"))


class TestAutonomousQueue(unittest.TestCase):
    """Test autonomous queue manager."""

    def setUp(self):
        from autonomous_queue import AutonomousQueue, TaskPriority
        self.queue = AutonomousQueue(persist=False)
        self.TaskPriority = TaskPriority

    def test_enqueue_dequeue(self):
        """Test basic enqueue and dequeue."""
        task_id = self.queue.enqueue(
            title="Test Task",
            description="A test task"
        )

        self.assertIsNotNone(task_id)

        task = self.queue.dequeue("worker-1")
        self.assertEqual(task.task_id, task_id)

    def test_priority_ordering(self):
        """Test priority-based ordering."""
        low = self.queue.enqueue("Low Priority", priority=self.TaskPriority.LOW)
        high = self.queue.enqueue("High Priority", priority=self.TaskPriority.HIGH)
        critical = self.queue.enqueue("Critical", priority=self.TaskPriority.CRITICAL)

        # Dequeue should be in priority order
        task1 = self.queue.dequeue()
        task2 = self.queue.dequeue()
        task3 = self.queue.dequeue()

        self.assertEqual(task1.task_id, critical)
        self.assertEqual(task2.task_id, high)
        self.assertEqual(task3.task_id, low)

    def test_dependencies(self):
        """Test dependency resolution."""
        t1 = self.queue.enqueue("Task 1")
        t2 = self.queue.enqueue("Task 2", dependencies=[t1])

        # T2 should not be dequeued before T1 completes
        task = self.queue.dequeue()
        self.assertEqual(task.task_id, t1)

        # Complete T1
        self.queue.complete(t1, {"done": True})

        # Now T2 should be available
        task2 = self.queue.dequeue()
        self.assertEqual(task2.task_id, t2)

    def test_retry_on_failure(self):
        """Test retry on failure."""
        task_id = self.queue.enqueue("Retry Test")

        task = self.queue.dequeue()
        self.queue.fail(task_id, "First failure")

        # Task should be back in queue for retry
        status = self.queue.get_queue_status()
        task_obj = self.queue.get_task(task_id)
        self.assertEqual(task_obj.retry_count, 1)

    def test_complete_task(self):
        """Test task completion."""
        task_id = self.queue.enqueue("Complete Test")

        self.queue.dequeue()
        self.queue.complete(task_id, {"result": "success"})

        task = self.queue.get_task(task_id)
        self.assertEqual(task.state.value, "completed")
        self.assertIsNotNone(task.completed_at)

    def test_cancel_task(self):
        """Test task cancellation."""
        task_id = self.queue.enqueue("Cancel Test")

        success = self.queue.cancel(task_id)
        self.assertTrue(success)

        task = self.queue.get_task(task_id)
        self.assertEqual(task.state.value, "cancelled")

    def test_batch_enqueue(self):
        """Test batch enqueue."""
        tasks = [
            {"title": "Batch 1", "priority": 2},
            {"title": "Batch 2", "priority": 1},
            {"title": "Batch 3", "priority": 3}
        ]

        task_ids = self.queue.batch_enqueue(tasks)
        self.assertEqual(len(task_ids), 3)

    def test_queue_status(self):
        """Test queue status."""
        self.queue.enqueue("Status Test")

        status = self.queue.get_queue_status()

        self.assertIn("total_tasks", status)
        self.assertIn("queued", status)
        self.assertIn("running", status)
        self.assertIn("metrics", status)


class TestHealthMonitor(unittest.TestCase):
    """Test health monitoring."""

    def setUp(self):
        from health_monitor import HealthMonitor
        self.monitor = HealthMonitor(check_interval=1)

    def test_get_full_status(self):
        """Test getting full system status."""
        status = self.monitor.get_full_status()

        self.assertIn("overall_status", status)
        self.assertIn("timestamp", status)
        self.assertIn("components", status)
        self.assertIn("system", status)

    def test_system_metrics(self):
        """Test system metrics collection."""
        from health_monitor import SystemMetrics

        cpu = SystemMetrics.get_cpu_usage()
        self.assertIsInstance(cpu, float)
        self.assertGreaterEqual(cpu, 0)

        mem = SystemMetrics.get_memory_usage()
        self.assertIn("total_gb", mem)
        self.assertIn("percent", mem)

        process = SystemMetrics.get_process_info()
        self.assertIn("pid", process)
        self.assertIn("memory_mb", process)

    def test_metrics_recording(self):
        """Test metrics recording."""
        self.monitor.metrics.record("test_metric", 42.5)
        self.monitor.metrics.record("test_metric", 43.0)

        latest = self.monitor.metrics.get_latest("test_metric")
        self.assertEqual(latest.value, 43.0)

        history = self.monitor.metrics.get_history("test_metric")
        self.assertEqual(len(history), 2)

    def test_generate_report(self):
        """Test report generation."""
        report = self.monitor.generate_report()

        self.assertIn("GENESIS HEALTH REPORT", report)
        self.assertIn("Overall Status", report)
        self.assertIn("COMPONENTS", report)
        self.assertIn("SYSTEM RESOURCES", report)

    def test_dashboard_data(self):
        """Test dashboard data format."""
        dashboard = self.monitor.get_dashboard_data()

        self.assertIn("status_banner", dashboard)
        self.assertIn("components", dashboard)
        self.assertIn("system_gauges", dashboard)


class TestVerifiedExecutor(unittest.TestCase):
    """Test verified executor."""

    def setUp(self):
        from verified_executor import VerifiedExecutor
        self.executor = VerifiedExecutor(max_fix_attempts=2)

    def test_is_code_output(self):
        """Test code detection."""
        code = "def hello(): print('hello')"
        not_code = "This is just text"

        self.assertTrue(self.executor._is_code_output(code))
        self.assertFalse(self.executor._is_code_output(not_code))

    def test_get_stats(self):
        """Test statistics."""
        stats = self.executor.get_stats()

        self.assertIn("total", stats)
        self.assertIn("verified", stats)


class TestCrossModuleIntegration(unittest.TestCase):
    """Test integration between multiple modules."""

    def test_memory_to_event_bus(self):
        """Test memory operations triggering events."""
        from memory_integration import MemoryIntegration
        from event_bus import EventBus

        memory = MemoryIntegration()
        bus = EventBus(persist_events=False)
        events_received = []

        def on_memory_event(event):
            events_received.append(event)

        bus.subscribe("memory.*", on_memory_event)

        # Store in memory and publish event
        memory.store("test_key", {"data": "value"})
        bus.publish("memory.stored", {"key": "test_key"})

        self.assertEqual(len(events_received), 1)

    def test_queue_to_health_monitor(self):
        """Test queue metrics in health monitor."""
        from autonomous_queue import AutonomousQueue
        from health_monitor import HealthMonitor

        queue = AutonomousQueue(persist=False)
        monitor = HealthMonitor()

        # Add some tasks
        queue.enqueue("Test 1")
        queue.enqueue("Test 2")

        # Get queue status
        q_status = queue.get_queue_status()

        # Record in metrics
        monitor.metrics.record("queue_depth", q_status["queued"])

        latest = monitor.metrics.get_latest("queue_depth")
        self.assertEqual(latest.value, 2)

    def test_agent_protocol_with_queue(self):
        """Test agent protocol messages via queue."""
        from agent_protocol import AgentProtocol, CLAUDE_OPUS, GEMINI_FLASH
        from autonomous_queue import AutonomousQueue, TaskPriority

        protocol = AgentProtocol(CLAUDE_OPUS)
        queue = AutonomousQueue(persist=False)

        # Create protocol message
        request = protocol.create_task_request(
            recipient=GEMINI_FLASH,
            task_id="queue-proto-001",
            title="Protocol Queue Test",
            description="Test message"
        )

        # Enqueue as task
        task_id = queue.enqueue(
            title=request.payload["title"],
            description=request.payload["description"],
            metadata={"protocol_message": request.to_dict()}
        )

        # Dequeue and verify
        task = queue.dequeue()
        self.assertIn("protocol_message", task.metadata)
        self.assertEqual(
            task.metadata["protocol_message"]["payload"]["task_id"],
            "queue-proto-001"
        )


class TestEndToEnd(unittest.TestCase):
    """End-to-end integration tests."""

    def test_full_task_lifecycle(self):
        """Test complete task lifecycle through all modules."""
        from memory_integration import MemoryIntegration
        from event_bus import EventBus
        from autonomous_queue import AutonomousQueue, TaskPriority
        from health_monitor import HealthMonitor

        memory = MemoryIntegration()
        bus = EventBus(persist_events=False)
        queue = AutonomousQueue(persist=False)
        monitor = HealthMonitor()

        # Track events
        lifecycle_events = []

        def track_event(event):
            lifecycle_events.append(event.topic)

        bus.subscribe("task.**", track_event)

        # 1. Create task
        task_id = queue.enqueue(
            title="Lifecycle Test",
            description="Full lifecycle test",
            priority=TaskPriority.HIGH
        )
        bus.publish("task.created", {"task_id": task_id})

        # 2. Dequeue task
        task = queue.dequeue("test-worker")
        bus.publish("task.started", {"task_id": task_id})

        # 3. Store task context in memory
        memory.store(
            f"task:{task_id}:context",
            {"started_at": task.started_at},
            tier="episodic"
        )

        # 4. Complete task
        queue.complete(task_id, {"result": "success"})
        bus.publish("task.completed", {"task_id": task_id})

        # 5. Store result in memory
        memory.store(
            f"task:{task_id}:result",
            {"status": "complete", "result": "success"},
            tier="episodic",
            importance=0.9
        )

        # 6. Record metrics
        monitor.metrics.record("tasks_completed", 1)

        # Verify lifecycle
        self.assertIn("task.created", lifecycle_events)
        self.assertIn("task.started", lifecycle_events)
        self.assertIn("task.completed", lifecycle_events)

        # Verify memory
        result = memory.recall(f"task:{task_id}:result")
        self.assertGreater(len(result.entries), 0)

        # Verify metrics
        latest = monitor.metrics.get_latest("tasks_completed")
        self.assertEqual(latest.value, 1)


def run_tests():
    """Run every Phase 5 test class verbosely and print a summary.

    Returns:
        bool: True when no test failed or errored, False otherwise.
    """
    # Order here is the order the classes run in.
    test_classes = (
        TestMemoryIntegration,
        TestEventBus,
        TestAgentProtocol,
        TestAutonomousQueue,
        TestHealthMonitor,
        TestVerifiedExecutor,
        TestCrossModuleIntegration,
        TestEndToEnd,
    )

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for case_class in test_classes:
        suite.addTests(loader.loadTestsFromTestCase(case_class))

    outcome = unittest.TextTestRunner(verbosity=2).run(suite)

    rule = "=" * 70
    summary_lines = [
        "\n" + rule,
        "PHASE 5 INTEGRATION TEST SUMMARY",
        rule,
        f"Tests run: {outcome.testsRun}",
        f"Failures: {len(outcome.failures)}",
        f"Errors: {len(outcome.errors)}",
        f"Skipped: {len(outcome.skipped)}",
        f"Success: {outcome.wasSuccessful()}",
    ]
    for line in summary_lines:
        print(line)

    return outcome.wasSuccessful()


if __name__ == "__main__":
    # Exit status 0 when every test passes, 1 otherwise (for shells / CI).
    sys.exit(0 if run_tests() else 1)
