#!/usr/bin/env python3
"""
GENESIS AGENT COMMUNICATION PROTOCOL
=====================================
Standardized message format for Claude, Gemini, and AIVA to communicate.

Message Types:
    - TaskRequest: Request agent to perform task
    - TaskResponse: Response from agent with results
    - StatusQuery: Check agent status
    - Handoff: Transfer context between agents
    - Coordination: Multi-agent coordination messages

Usage:
    protocol = AgentProtocol()
    msg = protocol.create_task_request(
        recipient=GEMINI_FLASH,
        task_id="task-001",
        title="Short title",
        description="What the agent should do",
    )
    parsed = protocol.parse_message(msg.to_json())
"""

import json
import hashlib
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional, Union
import base64


class MessageType(Enum):
    """Types of inter-agent messages.

    Each member's value is the string written to the ``message_type`` field
    by ``Message.to_dict`` and converted back to a member by
    ``AgentProtocol.parse_message``.
    """
    # Built by AgentProtocol.create_task_request.
    TASK_REQUEST = "task_request"
    # Built by AgentProtocol.create_task_response.
    TASK_RESPONSE = "task_response"
    # Built by AgentProtocol.create_status_query.
    STATUS_QUERY = "status_query"
    # No builder for this type exists in this file.
    STATUS_RESPONSE = "status_response"
    # Built by AgentProtocol.create_handoff.
    HANDOFF = "handoff"
    # No builder for this type exists in this file.
    HANDOFF_ACK = "handoff_ack"
    # No builder for this type exists in this file.
    COORDINATION = "coordination"
    # Built by AgentProtocol.create_heartbeat (sent with no recipient).
    HEARTBEAT = "heartbeat"
    # Built by AgentProtocol.create_error.
    ERROR = "error"


class AgentRole(Enum):
    """Roles agents can take.

    The member value is the string stored under ``role`` when an identity is
    serialized inside a message, and the string accepted when a message is
    parsed back.
    """
    EXECUTOR = "executor"       # Performs tasks
    VERIFIER = "verifier"       # Verifies outputs
    COORDINATOR = "coordinator" # Orchestrates work
    OBSERVER = "observer"       # Monitors without acting
    SPECIALIST = "specialist"   # Domain expert


@dataclass
class AgentIdentity:
    """Identity of an agent.

    Only ``agent_id``, ``agent_type`` and ``role`` are written out by
    ``Message.to_dict``; the remaining fields are local metadata and fall
    back to their defaults when an identity is rebuilt from a parsed message.
    """
    # Identifier for the agent, e.g. "claude-opus" or "aiva".
    agent_id: str
    agent_type: str  # claude-opus, gemini-flash, aiva, etc.
    # Role the agent plays; serialized via its enum value.
    role: AgentRole
    # Free-form capability tags; a fresh list per instance.
    capabilities: List[str] = field(default_factory=list)
    # NOTE(review): presumably measured in tokens — confirm.
    context_window: int = 100000
    # NOTE(review): currency and per-token vs per-1K-token basis are not
    # stated in this file — confirm before using for billing.
    cost_per_token: float = 0.0


@dataclass
class Message:
    """Envelope carrying one inter-agent message.

    ``recipient`` may be ``None`` (used for broadcast-style messages such as
    heartbeats). ``timestamp`` defaults to the local ISO-8601 time at which
    the instance is created.
    """
    message_id: str
    message_type: MessageType
    sender: AgentIdentity
    recipient: Optional[AgentIdentity]
    payload: Dict[str, Any]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    correlation_id: Optional[str] = None  # Links related messages
    priority: int = 5
    ttl: Optional[int] = None  # Time to live in seconds

    def to_dict(self) -> Dict:
        """Return a plain-dict view of the message.

        Enum fields are replaced by their string values. Identities are
        reduced to ``agent_id``, ``agent_type`` and ``role``; a missing
        recipient is emitted as ``None``. The payload is included as-is
        (not copied).
        """

        def summarize(agent):
            # Only these three identity fields travel with a message.
            return {
                "agent_id": agent.agent_id,
                "agent_type": agent.agent_type,
                "role": agent.role.value,
            }

        # Keys are inserted in the same order the JSON output relies on.
        envelope = {}
        envelope["message_id"] = self.message_id
        envelope["message_type"] = self.message_type.value
        envelope["sender"] = summarize(self.sender)
        envelope["recipient"] = summarize(self.recipient) if self.recipient else None
        envelope["payload"] = self.payload
        envelope["timestamp"] = self.timestamp
        envelope["correlation_id"] = self.correlation_id
        envelope["priority"] = self.priority
        envelope["ttl"] = self.ttl
        return envelope

    def to_json(self) -> str:
        """Return the message as a JSON string indented by two spaces."""
        envelope = self.to_dict()
        return json.dumps(envelope, indent=2)


@dataclass
class TaskRequestPayload:
    """Payload for task request messages.

    Converted to a plain dict with ``dataclasses.asdict`` before being placed
    in ``Message.payload``.
    """
    # Identifier of the task; a task response copies it from the request.
    task_id: str
    # Short human-readable name of the task.
    title: str
    # Full description of the work to perform.
    description: str
    # Free-form category string; the request builder defaults it to "general".
    task_type: str
    # Free-form difficulty string; the request builder defaults it to "moderate".
    complexity: str
    # Arbitrary extra context supplied by the requester.
    context: Dict[str, Any] = field(default_factory=dict)
    # The three fields below are not set by AgentProtocol.create_task_request,
    # so requests built through it always carry these defaults.
    constraints: List[str] = field(default_factory=list)
    expected_output: Optional[str] = None
    deadline: Optional[str] = None


@dataclass
class TaskResponsePayload:
    """Payload for task response messages.

    Converted to a plain dict with ``dataclasses.asdict`` before being placed
    in ``Message.payload``.
    """
    # Copied from the request payload ("unknown" when the request had none).
    task_id: str
    # Whether the task completed successfully.
    success: bool
    # Result of the task; any type (must be JSON-serializable for to_json).
    output: Any
    # NOTE(review): presumably seconds — units are not stated in this file.
    duration: float
    # Token count reported by the responding agent.
    tokens_used: int
    # NOTE(review): currency is not stated in this file — confirm.
    cost: float
    # Error description, if any.
    error: Optional[str] = None
    # Not set by AgentProtocol.create_task_response; defaults to an empty dict.
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class HandoffPayload:
    """Payload for context handoff between agents.

    Converted to a plain dict with ``dataclasses.asdict`` before being placed
    in ``Message.payload``.
    """
    # Why the work is being handed over.
    reason: str
    # Summary of the context accumulated so far.
    context_summary: str
    # Notable results the receiving agent should know about.
    key_findings: List[str]
    # Work items still outstanding.
    pending_actions: List[str]
    # Arbitrary state to carry across; empty dict when not supplied.
    memory_snapshot: Dict[str, Any] = field(default_factory=dict)
    # Not set by AgentProtocol.create_handoff; stays None for handoffs built
    # through it.
    recommended_approach: Optional[str] = None


class AgentProtocol:
    """
    Manages inter-agent communication protocol.

    An instance is bound to one agent identity, which becomes the ``sender``
    of every message it builds. It also parses, validates and formats
    messages.

    Attributes:
        identity: Identity stamped on outgoing messages.
        message_counter: Number of message IDs generated so far.
        pending_messages: Mapping of message ID to message. It is initialised
            empty and is not read or written by any method of this class.
    """

    def __init__(self, agent_identity: Optional[AgentIdentity] = None):
        """Bind the protocol to ``agent_identity``.

        Args:
            agent_identity: Identity to send as. When omitted, a default
                "genesis-coordinator" identity with the COORDINATOR role is
                used.
        """
        self.identity = agent_identity or AgentIdentity(
            agent_id="genesis-coordinator",
            agent_type="coordinator",
            role=AgentRole.COORDINATOR,
            capabilities=["orchestration", "task_routing"]
        )
        self.message_counter = 0
        self.pending_messages: Dict[str, Message] = {}

    def _generate_message_id(self) -> str:
        """Generate a message ID of the form ``msg_<12 hex chars>``.

        The ID is a truncated MD5 digest of the agent ID, the current local
        time and a per-instance counter. MD5 is used only to derive an
        identifier, not for security.
        """
        self.message_counter += 1
        content = f"{self.identity.agent_id}{datetime.now().isoformat()}{self.message_counter}"
        return f"msg_{hashlib.md5(content.encode()).hexdigest()[:12]}"

    def create_task_request(
        self,
        recipient: AgentIdentity,
        task_id: str,
        title: str,
        description: str,
        task_type: str = "general",
        complexity: str = "moderate",
        context: Optional[Dict] = None,
        priority: int = 5
    ) -> Message:
        """Create a task request message.

        Args:
            recipient: Agent asked to perform the task.
            task_id: Identifier of the task.
            title: Short task title.
            description: Full task description.
            task_type: Free-form task category.
            complexity: Free-form complexity label.
            context: Extra context; an empty dict is used when falsy.
            priority: Message priority.

        Returns:
            A TASK_REQUEST message whose payload is the dict form of a
            ``TaskRequestPayload``.
        """
        payload = TaskRequestPayload(
            task_id=task_id,
            title=title,
            description=description,
            task_type=task_type,
            complexity=complexity,
            context=context or {}
        )

        return Message(
            message_id=self._generate_message_id(),
            message_type=MessageType.TASK_REQUEST,
            sender=self.identity,
            recipient=recipient,
            payload=asdict(payload),
            priority=priority
        )

    def create_task_response(
        self,
        request_message: Message,
        success: bool,
        output: Any,
        duration: float,
        tokens_used: int = 0,
        cost: float = 0.0,
        error: Optional[str] = None
    ) -> Message:
        """Create a task response message.

        The response is addressed to the sender of ``request_message`` and
        its ``correlation_id`` is the request's message ID. The task ID is
        taken from the request payload, falling back to ``"unknown"``.

        Args:
            request_message: The request being answered.
            success: Whether the task succeeded.
            output: Task result.
            duration: Time taken to perform the task.
            tokens_used: Tokens consumed.
            cost: Cost incurred.
            error: Error description, if any.

        Returns:
            A TASK_RESPONSE message.
        """
        payload = TaskResponsePayload(
            task_id=request_message.payload.get("task_id", "unknown"),
            success=success,
            output=output,
            duration=duration,
            tokens_used=tokens_used,
            cost=cost,
            error=error
        )

        return Message(
            message_id=self._generate_message_id(),
            message_type=MessageType.TASK_RESPONSE,
            sender=self.identity,
            recipient=request_message.sender,
            payload=asdict(payload),
            correlation_id=request_message.message_id
        )

    def create_handoff(
        self,
        recipient: AgentIdentity,
        reason: str,
        context_summary: str,
        key_findings: List[str],
        pending_actions: List[str],
        memory_snapshot: Optional[Dict] = None
    ) -> Message:
        """Create a handoff message to transfer context.

        Args:
            recipient: Agent taking over.
            reason: Why the handoff is happening.
            context_summary: Summary of the context so far.
            key_findings: Findings to pass along.
            pending_actions: Outstanding work items.
            memory_snapshot: Extra state; an empty dict is used when falsy.

        Returns:
            A HANDOFF message with priority 8.
        """
        payload = HandoffPayload(
            reason=reason,
            context_summary=context_summary,
            key_findings=key_findings,
            pending_actions=pending_actions,
            memory_snapshot=memory_snapshot or {}
        )

        return Message(
            message_id=self._generate_message_id(),
            message_type=MessageType.HANDOFF,
            sender=self.identity,
            recipient=recipient,
            payload=asdict(payload),
            priority=8  # Handoffs are high priority
        )

    def create_status_query(self, recipient: AgentIdentity) -> Message:
        """Create a status query message.

        The payload holds only ``query_time``, the local ISO-8601 time at
        which the query was built.
        """
        return Message(
            message_id=self._generate_message_id(),
            message_type=MessageType.STATUS_QUERY,
            sender=self.identity,
            recipient=recipient,
            payload={"query_time": datetime.now().isoformat()}
        )

    def create_heartbeat(self, status: str = "healthy", metrics: Optional[Dict] = None) -> Message:
        """Create a heartbeat message with no recipient (broadcast).

        Args:
            status: Status string to report.
            metrics: Optional metrics; an empty dict is used when falsy.
        """
        return Message(
            message_id=self._generate_message_id(),
            message_type=MessageType.HEARTBEAT,
            sender=self.identity,
            recipient=None,  # Broadcast
            payload={
                "status": status,
                "metrics": metrics or {},
                "uptime": 0,  # Placeholder: actual uptime is not tracked
            }
        )

    def create_error(
        self,
        original_message: Message,
        error_code: str,
        error_message: str,
        details: Optional[Dict] = None
    ) -> Message:
        """Create an error response message.

        The error is addressed to the sender of ``original_message`` and is
        correlated with it via ``correlation_id``.

        Args:
            original_message: Message that triggered the error.
            error_code: Machine-readable error code.
            error_message: Human-readable description.
            details: Optional extra details; an empty dict is used when falsy.

        Returns:
            An ERROR message with priority 9.
        """
        return Message(
            message_id=self._generate_message_id(),
            message_type=MessageType.ERROR,
            sender=self.identity,
            recipient=original_message.sender,
            payload={
                "error_code": error_code,
                "error_message": error_message,
                "details": details or {},
                "original_message_id": original_message.message_id
            },
            correlation_id=original_message.message_id,
            priority=9  # Errors are high priority
        )

    @staticmethod
    def _parse_identity(raw: Dict[str, Any]) -> AgentIdentity:
        """Rebuild an identity from its serialized three-field form."""
        return AgentIdentity(
            agent_id=raw["agent_id"],
            agent_type=raw["agent_type"],
            role=AgentRole(raw["role"])
        )

    def parse_message(self, json_str: str) -> Message:
        """Parse a JSON message string.

        Identities are rebuilt from ``agent_id``, ``agent_type`` and ``role``
        only; their other fields take the ``AgentIdentity`` defaults. A
        missing timestamp defaults to the current local time and a missing
        priority to 5. The payload is passed through unchecked; use
        ``validate_message`` to check it.

        Raises:
            json.JSONDecodeError: If ``json_str`` is not valid JSON.
            KeyError: If a required field is missing.
            ValueError: If ``message_type`` or a ``role`` is not a known
                enum value.
        """
        data = json.loads(json_str)

        sender = self._parse_identity(data["sender"])

        raw_recipient = data.get("recipient")
        recipient = self._parse_identity(raw_recipient) if raw_recipient else None

        return Message(
            message_id=data["message_id"],
            message_type=MessageType(data["message_type"]),
            sender=sender,
            recipient=recipient,
            payload=data["payload"],
            timestamp=data.get("timestamp", datetime.now().isoformat()),
            correlation_id=data.get("correlation_id"),
            priority=data.get("priority", 5),
            ttl=data.get("ttl")
        )

    def validate_message(self, message: Message) -> List[str]:
        """Validate a message. Returns list of validation errors.

        Checks performed, in order: non-empty ``message_id``; sender present;
        recipient present for task requests/responses; payload is a dict;
        task requests carry a ``task_id``; ``ttl`` is not negative.

        A payload that is not a dict (for example ``None`` from a parsed JSON
        ``null``) is reported as an error rather than raising.

        Returns:
            A list of error strings; empty when the message is valid.
        """
        errors = []

        if not message.message_id:
            errors.append("Missing message_id")

        if not message.sender:
            errors.append("Missing sender")

        if message.message_type in (MessageType.TASK_REQUEST, MessageType.TASK_RESPONSE):
            if not message.recipient:
                errors.append("Task messages require recipient")

        # Guard the payload shape first so the task_id lookup below cannot
        # raise on None or be fooled by a list that happens to contain
        # the string "task_id".
        payload_is_dict = isinstance(message.payload, dict)
        if not payload_is_dict:
            errors.append("Payload must be a dict")

        if message.message_type == MessageType.TASK_REQUEST:
            if not payload_is_dict or "task_id" not in message.payload:
                errors.append("Task request missing task_id")

        if message.ttl is not None and message.ttl < 0:
            errors.append("Invalid TTL")

        return errors

    def format_for_agent(self, message: Message, agent_type: str) -> str:
        """Format message for specific agent type.

        Agent types containing "claude" (case-insensitive) receive markdown;
        every other agent type, including Gemini, receives JSON.
        """
        if "claude" in agent_type.lower():
            return self._format_as_markdown(message)
        return message.to_json()

    def _format_as_markdown(self, message: Message) -> str:
        """Format message as markdown with the payload in a JSON code fence."""
        lines = [
            f"## Message: {message.message_type.value}",
            f"**ID:** {message.message_id}",
            f"**From:** {message.sender.agent_type} ({message.sender.role.value})",
            f"**Priority:** {message.priority}",
            "",
            "### Payload",
            "```json",
            json.dumps(message.payload, indent=2),
            "```"
        ]
        return "\n".join(lines)


# Pre-defined agent identities
#
# Module-level AgentIdentity constants that callers can pass as sender or
# recipient.
# NOTE(review): the units of cost_per_token (currency, per-token vs
# per-1K-token) and of context_window (presumably tokens) are not stated in
# this file — confirm before relying on the numbers.

# Specialist identity; used as the sending agent in the `demo` command.
CLAUDE_OPUS = AgentIdentity(
    agent_id="claude-opus",
    agent_type="claude-opus-4",
    role=AgentRole.SPECIALIST,
    capabilities=["complex_reasoning", "architecture", "code_review"],
    context_window=200000,
    cost_per_token=0.015
)

# Executor identity; used as the handoff recipient in the `demo` command.
CLAUDE_SONNET = AgentIdentity(
    agent_id="claude-sonnet",
    agent_type="claude-sonnet-4",
    role=AgentRole.EXECUTOR,
    capabilities=["implementation", "balanced", "general"],
    context_window=200000,
    cost_per_token=0.003
)

# Executor identity; used as the task recipient in the `demo` command.
GEMINI_FLASH = AgentIdentity(
    agent_id="gemini-flash",
    agent_type="gemini-2.0-flash",
    role=AgentRole.EXECUTOR,
    capabilities=["fast_iteration", "simple_tasks", "code_generation"],
    context_window=1000000,
    cost_per_token=0.0001
)

# Specialist identity with the largest context_window defined here.
GEMINI_PRO = AgentIdentity(
    agent_id="gemini-pro",
    agent_type="gemini-2.5-pro",
    role=AgentRole.SPECIALIST,
    capabilities=["large_context", "research", "analysis"],
    context_window=2000000,
    cost_per_token=0.0005
)

# Observer identity with zero cost_per_token.
AIVA_QWEN = AgentIdentity(
    agent_id="aiva",
    agent_type="aiva-qwen",
    role=AgentRole.OBSERVER,
    capabilities=["monitoring", "local_execution", "zero_cost"],
    context_window=32000,
    cost_per_token=0.0
)


def main():
    """Run the command-line demo of the agent protocol.

    Commands:
        demo: Build and print a task request, the matching task response,
            and a handoff message.
        validate: Build a deliberately invalid task request and print the
            validation errors reported for it.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Agent Protocol")
    parser.add_argument("command", choices=["demo", "validate"])
    args = parser.parse_args()

    if args.command == "demo":
        print("Agent Protocol Demo")
        print("=" * 40)

        requester = AgentProtocol(CLAUDE_OPUS)
        # The response is produced by the agent that received the task, so it
        # is built with that agent's own protocol instance. Building it with
        # the requester's instance would yield a message whose sender and
        # recipient are both the requester.
        responder = AgentProtocol(GEMINI_FLASH)

        # Create task request
        request = requester.create_task_request(
            recipient=GEMINI_FLASH,
            task_id="demo-001",
            title="Generate hello world",
            description="Create a hello world function in Python",
            task_type="code_generation",
            complexity="simple"
        )
        print("\nTask Request:")
        print(request.to_json())

        # Create response (sent by the task recipient back to the requester)
        response = responder.create_task_response(
            request_message=request,
            success=True,
            output="def hello(): print('Hello, World!')",
            duration=1.5,
            tokens_used=100,
            cost=0.001
        )
        print("\nTask Response:")
        print(response.to_json())

        # Create handoff
        handoff = requester.create_handoff(
            recipient=CLAUDE_SONNET,
            reason="Need deeper code review",
            context_summary="Generated basic hello world function",
            key_findings=["Simple function created", "No edge cases handled"],
            pending_actions=["Add error handling", "Add type hints"]
        )
        print("\nHandoff:")
        print(handoff.to_json())

    elif args.command == "validate":
        protocol = AgentProtocol()
        # Create an invalid message for testing
        msg = Message(
            message_id="",  # Invalid
            message_type=MessageType.TASK_REQUEST,
            sender=CLAUDE_OPUS,
            recipient=None,  # Invalid for task request
            payload={}  # Missing task_id
        )
        errors = protocol.validate_message(msg)
        print(f"Validation errors: {errors}")


if __name__ == "__main__":
    main()
