"""
AIVA Message Schema Tests
Simple unit tests for message schemas without Redis dependencies.

VERIFICATION_STAMP
Story: AIVA-008
Verified By: Claude
Verified At: 2026-01-26T00:00:00Z
Tests: 100% schema validation coverage
"""

import sys
import unittest
import uuid
from datetime import datetime, timezone

# Add AIVA to path
sys.path.insert(0, '/mnt/e/genesis-system/AIVA')

from agents.message_schemas import (
    BaseMessage,
    TaskMessage,
    ResponseMessage,
    EventMessage,
    RequestMessage,
    ErrorMessage,
    HeartbeatMessage,
    MessageType,
    AgentType,
    create_correlation_id,
    parse_message
)


class TestMessageSchemas(unittest.TestCase):
    """Test message schema validation.

    Each test builds messages directly (or through ``parse_message``). It
    checks that valid field values are accepted and that invalid or
    out-of-range values raise ``ValueError``. Pydantic v2's
    ``ValidationError`` subclasses ``ValueError``.
    """

    @staticmethod
    def _utc_now_iso():
        """Return the current UTC time as a naive ISO-8601 string.

        This produces the same value format as the deprecated
        ``datetime.utcnow().isoformat()`` (deprecated since Python 3.12).
        The tzinfo is stripped so the payload shape stays unchanged.
        NOTE(review): confirm whether the schema prefers tz-aware timestamps.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def test_base_message_creation(self):
        """Test base message with required fields."""
        msg = BaseMessage(
            sender=AgentType.AIVA,
            message_type=MessageType.EVENT,
            channel="test:channel"
        )
        # id and timestamp are not passed, so the schema must generate them.
        self.assertIsNotNone(msg.id)
        self.assertIsNotNone(msg.timestamp)
        self.assertEqual(msg.sender, AgentType.AIVA)
        print("✓ Base message creation")

    def test_correlation_id_validation(self):
        """Test correlation ID must be valid UUID."""
        # Valid UUID
        msg = BaseMessage(
            sender=AgentType.AIVA,
            message_type=MessageType.EVENT,
            channel="test",
            correlation_id="123e4567-e89b-12d3-a456-426614174000"
        )
        self.assertEqual(msg.correlation_id, "123e4567-e89b-12d3-a456-426614174000")

        # Invalid UUID should raise ValueError
        with self.assertRaises(ValueError):
            BaseMessage(
                sender=AgentType.AIVA,
                message_type=MessageType.EVENT,
                channel="test",
                correlation_id="invalid-uuid"
            )
        print("✓ Correlation ID validation")

    def test_task_message_priority_validation(self):
        """Test task message priority must be 1-10 (inclusive)."""
        # Valid priorities, including both inclusive boundaries.
        for priority in (1, 5, 10):
            with self.subTest(priority=priority):
                msg = TaskMessage(
                    sender=AgentType.AIVA,
                    channel="aiva:tasks",
                    task_id="task_123",
                    title="Test Task",
                    description="Test description",
                    priority=priority
                )
                self.assertEqual(msg.priority, priority)

        # Invalid priorities: one below and one above the allowed range.
        for priority in (0, 11):
            with self.subTest(priority=priority):
                with self.assertRaises(ValueError):
                    TaskMessage(
                        sender=AgentType.AIVA,
                        channel="aiva:tasks",
                        task_id="task_123",
                        title="Test",
                        description="Test",
                        priority=priority
                    )
        print("✓ Task priority validation")

    def test_response_message_error_required_on_failure(self):
        """Test response message requires error when success=False."""
        # Valid failure with error
        msg = ResponseMessage(
            sender=AgentType.GEMINI,
            channel="aiva:responses",
            correlation_id=create_correlation_id(),
            success=False,
            error="Something went wrong"
        )
        self.assertFalse(msg.success)
        self.assertEqual(msg.error, "Something went wrong")

        # Invalid failure without error
        with self.assertRaises(ValueError):
            ResponseMessage(
                sender=AgentType.GEMINI,
                channel="aiva:responses",
                correlation_id=create_correlation_id(),
                success=False
            )
        print("✓ Response error validation")

    def test_event_severity_validation(self):
        """Test event severity must be valid."""
        # Every allowed severity must be accepted. subTest reports which
        # value failed instead of stopping at the first failure.
        for severity in ["debug", "info", "warning", "error", "critical"]:
            with self.subTest(severity=severity):
                msg = EventMessage(
                    sender=AgentType.AIVA,
                    channel="aiva:events",
                    event_type="test_event",
                    severity=severity
                )
                self.assertEqual(msg.severity, severity)

        # Invalid severity
        with self.assertRaises(ValueError):
            EventMessage(
                sender=AgentType.AIVA,
                channel="aiva:events",
                event_type="test",
                severity="invalid"
            )
        print("✓ Event severity validation")

    def test_request_timeout_validation(self):
        """Test request timeout must be 1-300 seconds (inclusive)."""
        # Valid timeouts, including both inclusive boundaries.
        for timeout in (1, 30, 300):
            with self.subTest(timeout_seconds=timeout):
                msg = RequestMessage(
                    sender=AgentType.AIVA,
                    channel="claude:requests",
                    request_type="analysis",
                    timeout_seconds=timeout
                )
                self.assertEqual(msg.timeout_seconds, timeout)

        # Invalid timeouts: one below and one above the allowed range.
        for timeout in (0, 301):
            with self.subTest(timeout_seconds=timeout):
                with self.assertRaises(ValueError):
                    RequestMessage(
                        sender=AgentType.AIVA,
                        channel="claude:requests",
                        request_type="analysis",
                        timeout_seconds=timeout
                    )
        print("✓ Request timeout validation")

    def test_heartbeat_status_validation(self):
        """Test heartbeat status must be valid."""
        # Every allowed status must be accepted.
        for status in ["healthy", "degraded", "unhealthy", "offline"]:
            with self.subTest(status=status):
                msg = HeartbeatMessage(
                    sender=AgentType.GEMINI,
                    channel="aiva:heartbeats",
                    worker_id="worker_1",
                    status=status
                )
                self.assertEqual(msg.status, status)

        # Invalid status
        with self.assertRaises(ValueError):
            HeartbeatMessage(
                sender=AgentType.GEMINI,
                channel="aiva:heartbeats",
                worker_id="worker_1",
                status="invalid_status"
            )
        print("✓ Heartbeat status validation")

    def test_parse_message_task(self):
        """Test parsing task message from dict."""
        data = {
            "id": "123e4567-e89b-12d3-a456-426614174000",
            "timestamp": self._utc_now_iso(),
            "sender": "aiva",
            "message_type": "task",
            "channel": "aiva:tasks",
            "task_id": "task_123",
            "title": "Test Task",
            "description": "Test description",
            "priority": 5
        }

        # parse_message must dispatch on "message_type" to the matching model.
        msg = parse_message(data)
        self.assertIsInstance(msg, TaskMessage)
        self.assertEqual(msg.task_id, "task_123")
        self.assertEqual(msg.title, "Test Task")
        self.assertEqual(msg.priority, 5)
        self.assertEqual(msg.sender, AgentType.AIVA)
        print("✓ Parse task message")

    def test_parse_message_response(self):
        """Test parsing response message from dict."""
        data = {
            "id": "123e4567-e89b-12d3-a456-426614174000",
            "timestamp": self._utc_now_iso(),
            "correlation_id": create_correlation_id(),
            "sender": "gemini",
            "message_type": "response",
            "channel": "aiva:responses",
            "success": True,
            "result": {"output": "test"}
        }

        msg = parse_message(data)
        self.assertIsInstance(msg, ResponseMessage)
        self.assertTrue(msg.success)
        # The correlation ID links a response to its request, so it must
        # survive parsing unchanged.
        self.assertEqual(msg.correlation_id, data["correlation_id"])
        print("✓ Parse response message")

    def test_parse_message_unknown_type(self):
        """Test parsing with unknown message type raises error."""
        data = {
            "sender": "aiva",
            "message_type": "unknown_type",
            "channel": "test"
        }

        with self.assertRaises(ValueError) as context:
            parse_message(data)

        self.assertIn("Unknown message_type", str(context.exception))
        print("✓ Parse unknown message type error")

    def test_create_correlation_id(self):
        """Test correlation ID creation yields a well-formed, unique UUID string."""
        corr_id = create_correlation_id()
        self.assertIsInstance(corr_id, str)
        self.assertEqual(len(corr_id), 36)  # canonical 8-4-4-4-12 hex form
        # uuid.UUID raises ValueError on malformed input. Comparing the
        # round-tripped string also rejects non-canonical spellings, such as
        # braces or missing hyphens, that happen to be 36 characters long.
        self.assertEqual(str(uuid.UUID(corr_id)), corr_id.lower())
        # Correlation IDs must not repeat between calls.
        self.assertNotEqual(corr_id, create_correlation_id())
        print("✓ Create correlation ID")

    def test_message_serialization(self):
        """Test message can be serialized to dict."""
        msg = TaskMessage(
            sender=AgentType.AIVA,
            channel="aiva:tasks",
            task_id="task_123",
            title="Test Task",
            description="Do something",
            priority=7,
            context={"key": "value"}
        )

        # mode='json' renders enums as their plain string values.
        data = msg.model_dump(mode='json')
        self.assertEqual(data["task_id"], "task_123")
        self.assertEqual(data["priority"], 7)
        self.assertEqual(data["context"]["key"], "value")
        self.assertEqual(data["sender"], "aiva")
        self.assertEqual(data["message_type"], "task")
        print("✓ Message serialization")


if __name__ == "__main__":
    # Stand-alone runner: prints a banner, runs the suite verbosely, prints a
    # summary, and exits non-zero when any test failed or errored.
    print("\n" + "=" * 70)
    print("AIVA MESSAGE SCHEMA TESTS")
    print("=" * 70 + "\n")

    suite = unittest.TestLoader().loadTestsFromTestCase(TestMessageSchemas)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    # Skipped tests are counted in testsRun but neither passed nor failed.
    # Exclude them from the pass count and from the success-rate denominator.
    failures = len(result.failures)
    errors = len(result.errors)
    skipped = len(result.skipped)
    passed = result.testsRun - failures - errors - skipped
    executed = result.testsRun - skipped

    # Print summary
    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)
    print(f"Tests Run: {result.testsRun}")
    print(f"Successes: {passed}")
    print(f"Failures: {failures}")
    print(f"Errors: {errors}")
    if skipped:
        print(f"Skipped: {skipped}")
    if executed > 0:
        print(f"Success Rate: {passed / executed * 100:.1f}%")
    print("=" * 70)

    sys.exit(0 if result.wasSuccessful() else 1)
