"""
AIVA Real-Time Stream Processing System
========================================

Production-grade stream processing for handling vast data flows with:
- Apache Kafka-style stream processing semantics
- Time-windowed aggregations (tumbling, sliding, session)
- Event sourcing with full state reconstruction
- Backpressure handling for overwhelming data rates
- Exactly-once delivery guarantees
- State checkpointing for fault tolerance

Genesis Protocol - AIVA Queen Output
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import os
import pickle
import threading
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum, auto
from pathlib import Path
from typing import (
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Deque,
    Dict,
    Generic,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
)

# Configure logging
# NOTE(review): calling basicConfig at import time configures the root logger
# for every program that imports this module (it is a no-op if the root logger
# already has handlers); applications normally own this call.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("aiva.stream_processor")

# Type variables for the generic containers in this module:
# T - StreamEvent payload; K - grouping key; V - element value;
# R - aggregation result (see WindowedAggregation).
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
R = TypeVar("R")


# =============================================================================
# Core Data Structures
# =============================================================================


class EventType(Enum):
    """Kinds of records that can travel through a stream.

    ``DATA`` carries a user payload; the remaining members are control
    records (watermarks, checkpoint barriers, transaction markers).
    """

    DATA = 1
    WATERMARK = 2
    CHECKPOINT = 3
    BARRIER = 4
    COMMIT = 5
    ROLLBACK = 6


class WindowType(Enum):
    """Shapes of time window a value can be grouped into."""

    TUMBLING = 1
    SLIDING = 2
    SESSION = 3
    GLOBAL = 4


class DeliveryGuarantee(Enum):
    """How many times a message may be delivered to a consumer."""

    AT_MOST_ONCE = 1
    AT_LEAST_ONCE = 2
    EXACTLY_ONCE = 3


@dataclass
class StreamEvent(Generic[T]):
    """Represents a single event in the stream.

    An empty ``event_id`` is replaced by a random UUID4 string, and a missing
    ``idempotency_key`` is derived from ``source_topic``, ``partition_key`` and
    ``sequence_number``.
    """

    event_id: str
    payload: T
    timestamp: float
    event_type: EventType = EventType.DATA
    partition_key: Optional[str] = None
    headers: Dict[str, str] = field(default_factory=dict)
    sequence_number: int = 0
    source_topic: str = ""
    retry_count: int = 0
    idempotency_key: Optional[str] = None

    def __post_init__(self):
        if not self.event_id:
            self.event_id = str(uuid.uuid4())
        if not self.idempotency_key:
            self.idempotency_key = self._compute_idempotency_key()

    def _compute_idempotency_key(self) -> str:
        """Compute a deterministic idempotency key for exactly-once processing.

        The key is the first 16 hex chars of SHA-256 over
        ``"<source_topic>:<partition_key>:<sequence_number>"``.

        NOTE(review): events that share all three fields (e.g. every event
        left at the defaults) receive the same key, so key-based
        deduplication treats them as duplicates. Producers must set these
        fields, or pass an explicit ``idempotency_key``.
        """
        content = f"{self.source_topic}:{self.partition_key}:{self.sequence_number}"
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize event to a dictionary (``event_type`` becomes its name)."""
        return {
            "event_id": self.event_id,
            "payload": self.payload,
            "timestamp": self.timestamp,
            "event_type": self.event_type.name,
            "partition_key": self.partition_key,
            "headers": self.headers,
            "sequence_number": self.sequence_number,
            "source_topic": self.source_topic,
            "retry_count": self.retry_count,
            "idempotency_key": self.idempotency_key,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StreamEvent":
        """Deserialize an event produced by :meth:`to_dict`.

        ``data`` is not modified. ``event_type`` may be an :class:`EventType`
        member or its name; it defaults to ``DATA`` when absent.

        Raises:
            KeyError: if ``event_type`` is a string that names no member.
        """
        kwargs = dict(data)  # copy: never mutate the caller's mapping
        event_type = kwargs.get("event_type", EventType.DATA)
        if not isinstance(event_type, EventType):
            event_type = EventType[event_type]
        kwargs["event_type"] = event_type
        return cls(**kwargs)


@dataclass
class WindowedValue(Generic[K, V]):
    """A keyed value tagged with the time window it belongs to."""

    key: K
    value: V
    window_start: float
    window_end: float
    event_time: float
    # Defaults to the wall-clock time (time.time()) at construction.
    processing_time: float = field(default_factory=time.time)

    @property
    def window_duration(self) -> float:
        """Length of the window: ``window_end - window_start``."""
        end, start = self.window_end, self.window_start
        return end - start


@dataclass
class Checkpoint:
    """Snapshot of operator state, source offsets and watermark at one moment."""

    checkpoint_id: str
    timestamp: float
    state: Dict[str, Any]
    offsets: Dict[str, int]
    watermark: float
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_bytes(self) -> bytes:
        """Pickle this whole checkpoint object into bytes."""
        serialized = pickle.dumps(self)
        return serialized

    @classmethod
    def from_bytes(cls, data: bytes) -> "Checkpoint":
        """Unpickle a checkpoint previously produced by :meth:`to_bytes`."""
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload -- only call this on bytes from a trusted checkpoint store.
        restored = pickle.loads(data)
        return restored


# =============================================================================
# Backpressure Handler
# =============================================================================


class BackpressureStrategy(Enum):
    """What a BackpressureHandler does with new events once the queue is under pressure."""

    DROP_OLDEST = 1
    DROP_NEWEST = 2
    BLOCK = 3
    SAMPLE = 4
    ADAPTIVE_RATE_LIMIT = 5


@dataclass
class BackpressureMetrics:
    """Snapshot of queue occupancy, drops, blocking time and rates."""

    # Events currently buffered, and the configured capacity.
    current_queue_size: int = field(default=0)
    max_queue_size: int = field(default=0)
    # Cumulative count of events discarded by the backpressure strategy.
    dropped_events: int = field(default=0)
    # Cumulative milliseconds producers spent waiting for room.
    blocked_time_ms: float = field(default=0)
    # Observed and target ingest rates (events per second).
    current_rate: float = field(default=0)
    target_rate: float = field(default=0)
    # Whether the queue fill ratio is at or above the high watermark.
    backpressure_active: bool = field(default=False)


class BackpressureHandler:
    """
    Handles overwhelming data rates using various strategies.

    Producers call :meth:`offer` and consumers call :meth:`poll`. While the
    queue fill ratio is below ``high_watermark``, every event is accepted. At
    or above it, the configured :class:`BackpressureStrategy` decides whether
    the event is queued, dropped, or the producer waits for room.
    """

    def __init__(
        self,
        max_queue_size: int = 10000,
        strategy: BackpressureStrategy = BackpressureStrategy.ADAPTIVE_RATE_LIMIT,
        high_watermark: float = 0.8,
        low_watermark: float = 0.5,
        sample_rate: float = 0.1,
        rate_limit_window_ms: int = 1000,
    ):
        self.max_queue_size = max_queue_size
        self.strategy = strategy
        self.high_watermark = high_watermark
        self.low_watermark = low_watermark
        self.sample_rate = sample_rate
        self.rate_limit_window_ms = rate_limit_window_ms

        # All inserts go through _enqueue(), which evicts explicitly and counts
        # the drop, so the deque's maxlen is only a safety net.
        self._queue: Deque[StreamEvent] = deque(maxlen=max_queue_size)
        self._lock = asyncio.Lock()
        self._not_full = asyncio.Condition()
        self._not_empty = asyncio.Condition()

        self._metrics = BackpressureMetrics()
        self._current_rate = 0.0
        self._target_rate = float("inf")
        self._last_rate_check = time.time()
        self._event_count_in_window = 0
        self._sample_counter = 0

    @property
    def metrics(self) -> BackpressureMetrics:
        """Get current backpressure metrics (refreshed on every access)."""
        self._metrics.current_queue_size = len(self._queue)
        self._metrics.max_queue_size = self.max_queue_size
        self._metrics.current_rate = self._current_rate
        self._metrics.target_rate = self._target_rate
        self._metrics.backpressure_active = self._is_backpressure_active()
        return self._metrics

    def _is_backpressure_active(self) -> bool:
        """Return True when the queue fill ratio is at or above high_watermark."""
        fill_ratio = len(self._queue) / self.max_queue_size
        return fill_ratio >= self.high_watermark

    def _should_release_backpressure(self) -> bool:
        """Return True when the queue fill ratio is at or below low_watermark."""
        fill_ratio = len(self._queue) / self.max_queue_size
        return fill_ratio <= self.low_watermark

    def _enqueue(self, event: StreamEvent) -> None:
        """Append *event*, evicting (and counting) the oldest event when full."""
        if self._queue and len(self._queue) >= self.max_queue_size:
            self._queue.popleft()
            self._metrics.dropped_events += 1
        self._queue.append(event)

    async def offer(
        self, event: StreamEvent, timeout_ms: Optional[int] = None
    ) -> bool:
        """
        Attempt to add an event to the queue.

        Args:
            event: Event to enqueue.
            timeout_ms: Only used by the BLOCK strategy. It is the maximum time
                to wait for room in total. ``None`` waits indefinitely; ``0``
                does not wait.

        Returns True if the event was accepted, False otherwise.
        """
        async with self._lock:
            self._update_rate_metrics()

            if self._is_backpressure_active():
                accepted = await self._handle_backpressure(event, timeout_ms)
            else:
                self._enqueue(event)
                accepted = True

        if accepted:
            # Wake a consumer blocked in poll(). Without this, poll() on an
            # empty queue would sleep until its timeout despite new data.
            async with self._not_empty:
                self._not_empty.notify()
        return accepted

    async def _handle_backpressure(
        self, event: StreamEvent, timeout_ms: Optional[int]
    ) -> bool:
        """Handle backpressure based on the configured strategy.

        Called with ``self._lock`` held. Returns True if *event* was queued.
        """
        if self.strategy == BackpressureStrategy.DROP_OLDEST:
            self._enqueue(event)
            return True

        elif self.strategy == BackpressureStrategy.DROP_NEWEST:
            self._metrics.dropped_events += 1
            return False

        elif self.strategy == BackpressureStrategy.BLOCK:
            start_time = time.time()
            deadline = (
                None if timeout_ms is None else start_time + timeout_ms / 1000
            )

            async with self._not_full:
                while len(self._queue) >= self.max_queue_size:
                    # Budget the timeout across wake-ups instead of restarting it.
                    remaining = None if deadline is None else deadline - time.time()
                    if remaining is not None and remaining <= 0:
                        self._metrics.blocked_time_ms += (
                            time.time() - start_time
                        ) * 1000
                        return False
                    try:
                        await asyncio.wait_for(
                            self._not_full.wait(), timeout=remaining
                        )
                    except asyncio.TimeoutError:
                        self._metrics.blocked_time_ms += (
                            time.time() - start_time
                        ) * 1000
                        return False

            self._enqueue(event)
            self._metrics.blocked_time_ms += (time.time() - start_time) * 1000
            return True

        elif self.strategy == BackpressureStrategy.SAMPLE:
            self._sample_counter += 1
            # Keep every Nth event with N = int(1 / sample_rate). Clamp N to at
            # least 1 (rate > 1 would give 0) and drop everything for rate <= 0
            # instead of dividing by zero.
            if self.sample_rate > 0:
                interval = max(1, int(1 / self.sample_rate))
                if self._sample_counter % interval == 0:
                    self._enqueue(event)
                    return True
            self._metrics.dropped_events += 1
            return False

        elif self.strategy == BackpressureStrategy.ADAPTIVE_RATE_LIMIT:
            # NOTE(review): _target_rate starts at infinity (and poll() resets it
            # to infinity), and it is only lowered inside this `if`. The
            # throttle therefore never engages as written. In practice this
            # strategy currently behaves like DROP_OLDEST. The seeding policy
            # needs a design decision.
            if self._current_rate > self._target_rate:
                self._target_rate = max(
                    self._target_rate * 0.9, self._current_rate * 0.5
                )
                self._metrics.dropped_events += 1
                return False
            self._enqueue(event)
            return True

        return False

    async def poll(self, timeout_ms: Optional[int] = None) -> Optional[StreamEvent]:
        """
        Poll for the next event from the queue.

        ``timeout_ms=None`` waits indefinitely; ``0`` returns immediately.
        Returns None if the timeout expires while the queue is empty.
        """
        deadline = None if timeout_ms is None else time.time() + timeout_ms / 1000

        async with self._not_empty:
            while not self._queue:
                remaining = None if deadline is None else deadline - time.time()
                if remaining is not None and remaining <= 0:
                    return None
                try:
                    await asyncio.wait_for(self._not_empty.wait(), timeout=remaining)
                except asyncio.TimeoutError:
                    return None

            event = self._queue.popleft()

            if self._should_release_backpressure():
                self._target_rate = float("inf")

        # Wake producers waiting under the BLOCK strategy now that there is room.
        async with self._not_full:
            self._not_full.notify_all()

        return event

    def _update_rate_metrics(self):
        """Count one offered event and recompute the rate once per window.

        The rate is expressed in events per second.
        """
        current_time = time.time()
        elapsed = (current_time - self._last_rate_check) * 1000

        if elapsed >= self.rate_limit_window_ms:
            self._current_rate = (
                self._event_count_in_window / elapsed * 1000
            )
            self._event_count_in_window = 0
            self._last_rate_check = current_time

        self._event_count_in_window += 1


# =============================================================================
# Exactly-Once Delivery
# =============================================================================


class TransactionState(Enum):
    """Lifecycle of an exactly-once transaction: open, then committed or aborted."""

    PENDING = 1
    COMMITTED = 2
    ABORTED = 3


@dataclass
class TransactionRecord:
    """Log entry for one exactly-once transaction and its outcome."""

    transaction_id: str
    state: TransactionState
    events: List[str]  # IDs of the events added to this transaction
    created_at: float
    # Set when the transaction reaches the corresponding terminal state.
    committed_at: Optional[float] = field(default=None)
    aborted_at: Optional[float] = field(default=None)


class ExactlyOnceDelivery:
    """
    Deduplicates stream events by idempotency key within transactions.

    Events are staged in a transaction via :meth:`add_to_transaction`. Their
    idempotency keys become "processed" only when the transaction commits, so
    an aborted transaction leaves its events eligible for redelivery. The
    processed-key cache, per-partition sequence numbers and the most recent
    transaction records are pickled to ``<state_dir>/eo_state.pkl`` after
    every commit and abort.
    """

    # Number of transaction records retained in memory and on disk.
    _TRANSACTION_LOG_LIMIT = 1000

    def __init__(
        self,
        state_dir: Optional[Path] = None,
        transaction_timeout_ms: int = 30000,
        max_processed_cache_size: int = 100000,
    ):
        self.state_dir = state_dir or Path("/tmp/aiva_eo_state")
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.transaction_timeout_ms = transaction_timeout_ms
        self.max_processed_cache_size = max_processed_cache_size

        # idempotency_key -> time the owning transaction committed.
        self._processed_events: Dict[str, float] = {}
        self._active_transactions: Dict[str, TransactionRecord] = {}
        # transaction_id -> idempotency keys staged but not yet committed.
        self._pending_keys: Dict[str, Set[str]] = {}
        self._transaction_log: List[TransactionRecord] = []
        self._lock = asyncio.Lock()
        self._sequence_numbers: Dict[str, int] = defaultdict(int)

    async def begin_transaction(self) -> str:
        """Begin a new transaction and return its id."""
        async with self._lock:
            transaction_id = str(uuid.uuid4())
            self._active_transactions[transaction_id] = TransactionRecord(
                transaction_id=transaction_id,
                state=TransactionState.PENDING,
                events=[],
                created_at=time.time(),
            )
            self._pending_keys[transaction_id] = set()
            logger.debug(f"Started transaction: {transaction_id}")
            return transaction_id

    async def add_to_transaction(
        self, transaction_id: str, event: StreamEvent
    ) -> bool:
        """
        Add an event to a transaction.

        Returns True if the event should be processed. Returns False when its
        idempotency key was already committed, or was already staged in this
        same transaction.

        Raises:
            ValueError: if the transaction is unknown or already finished.
        """
        async with self._lock:
            transaction = self._active_transactions.get(transaction_id)
            if transaction is None:
                raise ValueError(f"Transaction {transaction_id} not found")

            key = event.idempotency_key
            staged = self._pending_keys.setdefault(transaction_id, set())
            if key in self._processed_events or key in staged:
                logger.debug(
                    f"Duplicate event detected: {event.idempotency_key}"
                )
                return False

            staged.add(key)
            transaction.events.append(event.event_id)
            return True

    async def commit_transaction(self, transaction_id: str) -> bool:
        """
        Commit a transaction, marking all its idempotency keys as processed.

        The record is appended to the transaction log, the keys are marked,
        the cache is trimmed, and state is persisted. This all happens under
        one lock; it is not a distributed two-phase commit.

        Raises:
            ValueError: if the transaction is unknown or already finished.
        """
        async with self._lock:
            transaction = self._active_transactions.pop(transaction_id, None)
            if transaction is None:
                raise ValueError(f"Transaction {transaction_id} not found")
            staged_keys = self._pending_keys.pop(transaction_id, set())

            committed_at = time.time()
            transaction.state = TransactionState.COMMITTED
            transaction.committed_at = committed_at
            self._append_to_log(transaction)

            # Mark by idempotency key. add_to_transaction() and is_duplicate()
            # look up this same key, so redelivered copies are rejected.
            for key in staged_keys:
                self._processed_events[key] = committed_at

            # Evict before persisting so the on-disk cache stays bounded too.
            await self._evict_old_entries()
            await self._persist_state()

            logger.debug(f"Committed transaction: {transaction_id}")
            return True

    async def abort_transaction(self, transaction_id: str):
        """Abort a transaction, discarding its staged keys. Unknown ids are ignored."""
        async with self._lock:
            transaction = self._active_transactions.pop(transaction_id, None)
            if transaction is None:
                return
            self._pending_keys.pop(transaction_id, None)

            transaction.state = TransactionState.ABORTED
            transaction.aborted_at = time.time()
            self._append_to_log(transaction)

            await self._persist_state()
            logger.debug(f"Aborted transaction: {transaction_id}")

    async def is_duplicate(self, event: StreamEvent) -> bool:
        """Return True if the event's idempotency key was already committed."""
        async with self._lock:
            return event.idempotency_key in self._processed_events

    async def get_next_sequence(self, partition: str) -> int:
        """Get the next sequence number for a partition (starting at 1)."""
        async with self._lock:
            self._sequence_numbers[partition] += 1
            return self._sequence_numbers[partition]

    def _append_to_log(self, transaction: TransactionRecord) -> None:
        """Record *transaction*, keeping only the newest log entries in memory."""
        self._transaction_log.append(transaction)
        overflow = len(self._transaction_log) - self._TRANSACTION_LOG_LIMIT
        if overflow > 0:
            del self._transaction_log[:overflow]

    async def _persist_state(self):
        """Persist state to disk for recovery, replacing the file atomically."""
        state_file = self.state_dir / "eo_state.pkl"
        tmp_file = self.state_dir / "eo_state.pkl.tmp"
        state = {
            "processed_events": self._processed_events,
            "sequence_numbers": dict(self._sequence_numbers),
            "transaction_log": self._transaction_log[-self._TRANSACTION_LOG_LIMIT:],
        }
        with open(tmp_file, "wb") as f:
            pickle.dump(state, f)
        # os.replace swaps the file in one step, so a crash mid-write never
        # leaves a truncated eo_state.pkl behind.
        os.replace(tmp_file, state_file)

    async def recover_state(self):
        """Recover state from disk after restart (no-op if no state file)."""
        state_file = self.state_dir / "eo_state.pkl"
        if state_file.exists():
            # SECURITY: pickle.load executes code embedded in the file; the
            # state directory must only be writable by this service.
            with open(state_file, "rb") as f:
                state = pickle.load(f)
            self._processed_events = state.get("processed_events", {})
            self._sequence_numbers = defaultdict(
                int, state.get("sequence_numbers", {})
            )
            self._transaction_log = state.get("transaction_log", [])
            logger.info("Recovered exactly-once state from disk")

    async def _evict_old_entries(self):
        """Trim the processed-key cache once it exceeds max_processed_cache_size.

        Removes the overflow plus 20% of the limit (oldest commits first), so
        eviction does not re-run on every commit and always makes progress.
        """
        excess = len(self._processed_events) - self.max_processed_cache_size
        if excess <= 0:
            return
        evict_count = excess + int(self.max_processed_cache_size * 0.2)
        oldest_first = sorted(self._processed_events.items(), key=lambda kv: kv[1])
        for key, _ in oldest_first[:evict_count]:
            del self._processed_events[key]


# =============================================================================
# Checkpoint Manager
# =============================================================================


class CheckpointManager:
    """
    Manages state checkpointing for fault tolerance.

    Checkpoints are pickled to ``<checkpoint_dir>/<checkpoint_id>.ckpt``.
    With ``async_checkpoints`` enabled, writes are queued and performed by a
    background task started in :meth:`start`; :meth:`stop` flushes the queue
    before shutting that task down. Only the newest ``max_checkpoints`` files
    (by modification time) are kept on disk.
    """

    def __init__(
        self,
        checkpoint_dir: Optional[Path] = None,
        checkpoint_interval_ms: int = 60000,
        max_checkpoints: int = 10,
        async_checkpoints: bool = True,
    ):
        self.checkpoint_dir = checkpoint_dir or Path("/tmp/aiva_checkpoints")
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_interval_ms = checkpoint_interval_ms
        self.max_checkpoints = max_checkpoints
        self.async_checkpoints = async_checkpoints

        self._last_checkpoint_time: float = 0
        self._checkpoint_in_progress = False
        self._checkpoint_lock = asyncio.Lock()
        # Ids of the most recent checkpoints created by this instance.
        self._checkpoint_history: List[str] = []
        # Items are (checkpoint, pre-serialized bytes) tuples.
        self._pending_writes: asyncio.Queue = asyncio.Queue()
        self._writer_task: Optional[asyncio.Task] = None

    def _writer_running(self) -> bool:
        """Return True while the background writer task is alive."""
        return self._writer_task is not None and not self._writer_task.done()

    async def start(self):
        """Start the background writer when async checkpoints are enabled."""
        if self.async_checkpoints and not self._writer_running():
            self._writer_task = asyncio.create_task(self._checkpoint_writer())
        logger.info("Checkpoint manager started")

    async def stop(self):
        """Stop the manager, first flushing any queued checkpoint writes."""
        if self._writer_task is not None:
            if not self._writer_task.done():
                # Drain writes queued by create_checkpoint() so they are not lost.
                await self._pending_writes.join()
            self._writer_task.cancel()
            try:
                await self._writer_task
            except asyncio.CancelledError:
                pass
            self._writer_task = None
        logger.info("Checkpoint manager stopped")

    async def should_checkpoint(self) -> bool:
        """Check if checkpoint_interval_ms has elapsed since the last checkpoint."""
        elapsed = (time.time() - self._last_checkpoint_time) * 1000
        return elapsed >= self.checkpoint_interval_ms

    async def create_checkpoint(
        self,
        state: Dict[str, Any],
        offsets: Dict[str, int],
        watermark: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Checkpoint:
        """Create a new checkpoint and write it (queued if async)."""
        async with self._checkpoint_lock:
            checkpoint_id = f"ckpt_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}"

            checkpoint = Checkpoint(
                checkpoint_id=checkpoint_id,
                timestamp=time.time(),
                state=state,
                offsets=offsets,
                watermark=watermark,
                metadata=metadata or {},
            )

            # Serialize now so later mutation of `state` by the caller cannot
            # leak into a checkpoint that is written asynchronously.
            payload = checkpoint.to_bytes()
            if self.async_checkpoints:
                await self._pending_writes.put((checkpoint, payload))
            else:
                await self._write_checkpoint(checkpoint, payload)

            self._last_checkpoint_time = time.time()
            self._checkpoint_history.append(checkpoint_id)
            excess = len(self._checkpoint_history) - self.max_checkpoints
            if excess > 0:
                del self._checkpoint_history[:excess]

            # Cleanup old checkpoints
            await self._cleanup_old_checkpoints()

            logger.info(f"Created checkpoint: {checkpoint_id}")
            return checkpoint

    async def _write_checkpoint(
        self, checkpoint: Checkpoint, payload: Optional[bytes] = None
    ):
        """Write *checkpoint* (or its pre-serialized *payload*) to disk atomically."""
        if payload is None:
            payload = checkpoint.to_bytes()
        checkpoint_file = self.checkpoint_dir / f"{checkpoint.checkpoint_id}.ckpt"
        tmp_file = checkpoint_file.with_name(checkpoint_file.name + ".tmp")
        with open(tmp_file, "wb") as f:
            f.write(payload)
        # The rename publishes the file in one step, so restore never picks up
        # a half-written .ckpt as the "latest" checkpoint.
        os.replace(tmp_file, checkpoint_file)

    async def _checkpoint_writer(self):
        """Background task for async checkpoint writes (runs until cancelled)."""
        while True:
            checkpoint, payload = await self._pending_writes.get()
            try:
                await self._write_checkpoint(checkpoint, payload)
            except Exception as e:
                logger.error(f"Error writing checkpoint: {e}")
            finally:
                # Always mark done so stop()'s join() cannot hang on a failed write.
                self._pending_writes.task_done()

    def _checkpoint_files_newest_first(self) -> List[Path]:
        """Return ``*.ckpt`` files in checkpoint_dir, newest modification time first."""

        def mtime(path: Path) -> float:
            try:
                return path.stat().st_mtime
            except FileNotFoundError:
                # Deleted concurrently; sort it last so it is never "latest".
                return float("-inf")

        return sorted(self.checkpoint_dir.glob("*.ckpt"), key=mtime, reverse=True)

    async def restore_from_checkpoint(
        self, checkpoint_id: Optional[str] = None
    ) -> Optional[Checkpoint]:
        """
        Restore state from a checkpoint.

        If no checkpoint_id is provided, restores from the latest checkpoint.
        Queued async writes are flushed first so they are visible. Returns
        None when no matching checkpoint exists.
        """
        if self._writer_running():
            await self._pending_writes.join()

        if checkpoint_id is None:
            checkpoint_files = self._checkpoint_files_newest_first()
            if not checkpoint_files:
                logger.warning("No checkpoints found")
                return None
            checkpoint_file = checkpoint_files[0]
        else:
            checkpoint_file = self.checkpoint_dir / f"{checkpoint_id}.ckpt"

        if not checkpoint_file.exists():
            logger.warning(f"Checkpoint not found: {checkpoint_id}")
            return None

        with open(checkpoint_file, "rb") as f:
            # Checkpoint.from_bytes unpickles; checkpoint_dir must be trusted.
            checkpoint = Checkpoint.from_bytes(f.read())

        logger.info(f"Restored from checkpoint: {checkpoint.checkpoint_id}")
        return checkpoint

    async def _cleanup_old_checkpoints(self):
        """Remove checkpoint files beyond the newest max_checkpoints."""
        for old_file in self._checkpoint_files_newest_first()[self.max_checkpoints :]:
            try:
                old_file.unlink()
            except FileNotFoundError:
                continue
            logger.debug(f"Removed old checkpoint: {old_file.name}")


# =============================================================================
# Windowed Aggregation
# =============================================================================


@dataclass
class Window:
    """Half-open time interval ``[start, end)`` tagged with its window kind."""

    start: float
    end: float
    window_type: WindowType

    def contains(self, timestamp: float) -> bool:
        """Return True when *timestamp* lies in ``[start, end)``."""
        at_or_after_start = self.start <= timestamp
        return at_or_after_start and timestamp < self.end

    def overlaps(self, other: "Window") -> bool:
        """Return True when the two half-open intervals share any instant."""
        return other.start < self.end and self.start < other.end


class WindowAssigner(ABC):
    """Abstract strategy that maps an event timestamp onto windows."""

    @abstractmethod
    def assign_windows(self, timestamp: float) -> List[Window]:
        """Return every window that *timestamp* belongs to."""
        ...


class TumblingWindowAssigner(WindowAssigner):
    """Put each timestamp into exactly one fixed-size, non-overlapping window.

    Window bounds are multiples of ``size_ms / 1000``.
    """

    def __init__(self, size_ms: int):
        self.size_ms = size_ms
        self.size_seconds = size_ms / 1000

    def assign_windows(self, timestamp: float) -> List[Window]:
        width = self.size_seconds
        lower = (timestamp // width) * width
        upper = lower + width
        return [Window(start=lower, end=upper, window_type=WindowType.TUMBLING)]


class SlidingWindowAssigner(WindowAssigner):
    """Assigns sliding (overlapping) windows.

    Every window is ``size_ms`` long and a new one starts every ``slide_ms``.
    A timestamp belongs to each window whose half-open range ``[start, end)``
    contains it.

    Raises:
        ValueError: if ``slide_ms`` is not positive.
    """

    def __init__(self, size_ms: int, slide_ms: int):
        if slide_ms <= 0:
            # A non-positive slide would divide by zero or loop forever below.
            raise ValueError(f"slide_ms must be positive, got {slide_ms}")
        self.size_ms = size_ms
        self.slide_ms = slide_ms
        self.size_seconds = size_ms / 1000
        self.slide_seconds = slide_ms / 1000

    def assign_windows(self, timestamp: float) -> List[Window]:
        size = self.size_seconds
        slide = self.slide_seconds

        # Window n starts at n * slide - size (the same grid as before). Each
        # start is derived from its integer index rather than by repeatedly
        # adding `slide`. That keeps the float bounds of a given window
        # identical no matter which event produced them, so state keyed by
        # (start, end) is not split across near-equal keys. Starting at index
        # floor(timestamp / slide) costs one extra iteration but catches a
        # window that float rounding might otherwise skip.
        index = int(timestamp // slide)
        windows: List[Window] = []
        window_start = index * slide - size
        while window_start <= timestamp:
            window_end = window_start + size
            if window_start <= timestamp < window_end:
                windows.append(
                    Window(
                        start=window_start,
                        end=window_end,
                        window_type=WindowType.SLIDING,
                    )
                )
            index += 1
            window_start = index * slide - size

        return windows


class SessionWindowAssigner(WindowAssigner):
    """Assigns session windows with a configurable inactivity gap.

    A session stays open while events arrive no more than ``gap_ms`` apart.
    Its window ends ``gap_ms`` after the latest event seen. One open session
    is tracked per ``session_key``; events without a key share ``"default"``.
    """

    def __init__(self, gap_ms: int):
        self.gap_ms = gap_ms
        self.gap_seconds = gap_ms / 1000
        # session_key -> (start, end) of the currently open session.
        self._sessions: Dict[str, Tuple[float, float]] = {}

    def assign_windows(
        self, timestamp: float, session_key: Optional[str] = None
    ) -> List[Window]:
        key = session_key or "default"
        gap = self.gap_seconds
        current = self._sessions.get(key)

        if current is not None:
            start, end = current
            # `end` is already latest_event + gap, so an event continues the
            # session iff it is at most `gap` after the latest event
            # (timestamp <= end). An out-of-order event also continues it if it
            # is at most `gap` before the start. (Comparing against end + gap
            # would merge events up to 2 * gap apart.)
            if start - gap <= timestamp <= end:
                merged_start = min(start, timestamp)
                merged_end = max(end, timestamp + gap)
                self._sessions[key] = (merged_start, merged_end)
                return [
                    Window(
                        start=merged_start,
                        end=merged_end,
                        window_type=WindowType.SESSION,
                    )
                ]
            if timestamp < start:
                # Too old to merge: give it its own window without rewinding
                # the session that is currently open.
                return [
                    Window(
                        start=timestamp,
                        end=timestamp + gap,
                        window_type=WindowType.SESSION,
                    )
                ]

        # First event for this key, or the inactivity gap has elapsed.
        start = timestamp
        end = timestamp + gap
        self._sessions[key] = (start, end)
        return [Window(start=start, end=end, window_type=WindowType.SESSION)]


class WindowedAggregation(Generic[K, V, R]):
    """
    Performs time-windowed computations over streaming data.

    :meth:`process` buffers each event's payload per (key, window). Results
    are emitted by :meth:`update_watermark` once the watermark reaches a
    window's end plus ``allowed_lateness_ms``. Events older than
    ``watermark - allowed_lateness`` go to the late-data handler, if one is
    set, and are otherwise dropped.
    """

    def __init__(
        self,
        window_assigner: WindowAssigner,
        aggregator: Callable[[List[V]], R],
        key_extractor: Callable[[Any], K],
        allowed_lateness_ms: int = 0,
    ):
        self.window_assigner = window_assigner
        self.aggregator = aggregator
        self.key_extractor = key_extractor
        self.allowed_lateness_ms = allowed_lateness_ms
        self.allowed_lateness_seconds = allowed_lateness_ms / 1000

        # key -> (window_start, window_end) -> buffered payloads.
        self._window_state: Dict[K, Dict[Tuple[float, float], List[V]]] = (
            defaultdict(lambda: defaultdict(list))
        )
        self._watermark: float = 0
        self._late_data_handler: Optional[Callable[[StreamEvent], None]] = None
        self._lock = asyncio.Lock()

    def set_late_data_handler(self, handler: Callable[[StreamEvent], None]):
        """Set handler for late data events."""
        self._late_data_handler = handler

    async def update_watermark(self, watermark: float):
        """Advance the watermark and return results for windows that are now due.

        Returns a list of :class:`WindowedValue`, one per fired (key, window).
        Fired windows are removed from state.
        """
        async with self._lock:
            # Watermarks only move forward; a regression is ignored so data
            # that is already late stays late.
            self._watermark = max(self._watermark, watermark)
            current = self._watermark
            lateness = self.allowed_lateness_seconds

            triggered_results = []
            for key in list(self._window_state):
                windows = self._window_state[key]
                for window_key in list(windows):
                    start, end = window_key
                    # Fire every window whose trigger time has passed, not only
                    # those crossed by this update, so no window is stranded.
                    if end + lateness <= current:
                        result = self.aggregator(windows[window_key])
                        del windows[window_key]
                        triggered_results.append(
                            WindowedValue(
                                key=key,
                                value=result,
                                window_start=start,
                                window_end=end,
                                event_time=end,
                            )
                        )
                if not windows:
                    # Drop empty per-key maps so idle keys do not leak memory.
                    del self._window_state[key]

            return triggered_results

    async def process(self, event: StreamEvent[V]) -> List[WindowedValue[K, R]]:
        """
        Buffer an event's payload into each window it belongs to.

        Always returns an empty list: window results are emitted by
        :meth:`update_watermark`. The return type is kept for API
        compatibility.
        """
        async with self._lock:
            # Check for late data
            if event.timestamp < self._watermark - self.allowed_lateness_seconds:
                if self._late_data_handler:
                    self._late_data_handler(event)
                return []

            key = self.key_extractor(event.payload)
            if isinstance(self.window_assigner, SessionWindowAssigner):
                # Sessions are per key. Without a session_key every key would
                # share (and keep extending) one "default" session.
                windows = self.window_assigner.assign_windows(
                    event.timestamp, session_key=str(key)
                )
            else:
                windows = self.window_assigner.assign_windows(event.timestamp)

            for window in windows:
                window_key = (window.start, window.end)
                self._window_state[key][window_key].append(event.payload)

            return []

    async def get_window_state(
        self, key: K
    ) -> Dict[Tuple[float, float], List[V]]:
        """Get a copy of the current window state for a key."""
        async with self._lock:
            windows = self._window_state.get(key, {})
            # Copy the value lists too so callers cannot mutate buffered state.
            return {window_key: list(values) for window_key, values in windows.items()}


# =============================================================================
# Event Sourcing
# =============================================================================


@dataclass
class DomainEvent:
    """A fact recorded against an aggregate, used for event sourcing."""

    event_id: str
    aggregate_id: str
    event_type: str
    timestamp: float
    version: int
    payload: Dict[str, Any]
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (nested values are not copied)."""
        names = (
            "event_id",
            "aggregate_id",
            "event_type",
            "timestamp",
            "version",
            "payload",
            "metadata",
        )
        return {name: getattr(self, name) for name in names}


class EventStore:
    """
    Stores and retrieves domain events for event sourcing.

    Events are held in memory grouped by ``aggregate_id``. Each one is also
    written to ``storage_dir`` as ``<aggregate_id>_<version>.json``. Every
    successfully appended event gets a store-wide ``global_sequence``
    (1, 2, 3, ...) recorded in its ``metadata``.
    """

    def __init__(self, storage_dir: Optional[Path] = None):
        self.storage_dir = storage_dir or Path("/tmp/aiva_event_store")
        self.storage_dir.mkdir(parents=True, exist_ok=True)

        # aggregate_id -> that aggregate's events, in append order.
        self._events: Dict[str, List[DomainEvent]] = defaultdict(list)
        # Sequence number of the most recent successful append (0 = none yet).
        self._global_sequence: int = 0
        self._lock = asyncio.Lock()

    async def append(self, event: DomainEvent) -> int:
        """Append ``event`` and return its global sequence number.

        The event is persisted to disk before it becomes visible in memory.
        If persisting fails, the exception propagates and the store is left
        unchanged: no sequence number is consumed and the event is not kept.
        """
        async with self._lock:
            sequence = self._global_sequence + 1
            event.metadata["global_sequence"] = sequence
            try:
                await self._persist_event(event)
            except Exception:
                # Do not leave a sequence number on an event that was never stored.
                event.metadata.pop("global_sequence", None)
                raise

            self._global_sequence = sequence
            self._events[event.aggregate_id].append(event)
            return sequence

    async def get_events(
        self,
        aggregate_id: str,
        from_version: int = 0,
        to_version: Optional[int] = None,
    ) -> List[DomainEvent]:
        """Get events for an aggregate within an inclusive version range."""
        async with self._lock:
            events = self._events.get(aggregate_id, [])
            filtered = [e for e in events if e.version >= from_version]
            if to_version is not None:
                filtered = [e for e in filtered if e.version <= to_version]
            return sorted(filtered, key=lambda e: e.version)

    async def get_all_events(
        self, from_sequence: int = 0
    ) -> AsyncGenerator[DomainEvent, None]:
        """Yield stored events with ``global_sequence >= from_sequence``.

        Events are yielded in global-sequence order. The event list is
        snapshotted under the lock, and the lock is released before the first
        ``yield``. ``asyncio.Lock`` is not reentrant, so holding it across
        ``yield`` would deadlock any consumer that appends to this store
        while iterating. Events appended after the snapshot is taken are not
        yielded.
        """
        async with self._lock:
            snapshot = [
                event
                for events in self._events.values()
                for event in events
                if event.metadata.get("global_sequence", 0) >= from_sequence
            ]

        snapshot.sort(key=lambda e: e.metadata.get("global_sequence", 0))
        for event in snapshot:
            yield event

    async def _persist_event(self, event: DomainEvent):
        """Write ``event`` to ``<storage_dir>/<aggregate_id>_<version>.json``.

        Serialization happens first, on the caller's thread, so errors such
        as a non-JSON-serializable payload surface before any file is
        touched. The blocking write runs in the loop's default executor so it
        does not stall the event loop.
        """
        data = json.dumps(event.to_dict())
        # NOTE(review): aggregate_id is used verbatim in the file name; ids
        # containing path separators would escape storage_dir.
        event_file = (
            self.storage_dir
            / f"{event.aggregate_id}_{event.version}.json"
        )
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, self._write_file_atomically, event_file, data
        )

    @staticmethod
    def _write_file_atomically(path: Path, data: str) -> None:
        """Write ``data`` to ``path`` via a temporary file and ``os.replace``.

        An interrupted write therefore cannot leave a partially written file
        at ``path``.
        """
        tmp_path = path.with_name(path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            f.write(data)
        os.replace(tmp_path, path)


class Aggregate(ABC):
    """Base class for event-sourced aggregates.

    Subclasses implement ``apply`` to fold one event into their state.
    ``version`` holds the version of the last event applied (0 before any).
    New events are created with ``raise_event`` and buffered until
    ``get_pending_events`` drains them.
    """

    def __init__(self, aggregate_id: str):
        self.aggregate_id = aggregate_id
        self.version = 0
        self._pending_events: List[DomainEvent] = []

    @abstractmethod
    def apply(self, event: DomainEvent):
        """Apply an event to update aggregate state."""

    def load_from_history(self, events: List[DomainEvent]):
        """Reconstruct state by applying ``events`` in ascending version order.

        ``version`` is updated after each event is applied.
        """
        for event in sorted(events, key=lambda e: e.version):
            self.apply(event)
            self.version = event.version

    def raise_event(self, event_type: str, payload: Dict[str, Any]) -> DomainEvent:
        """Create the next event, apply it, and buffer it as pending.

        ``version`` is advanced before ``apply`` runs, so ``apply`` sees the
        new event's version. If ``apply`` raises, the version is restored and
        the event is not buffered. A rejected event therefore leaves no gap
        in the version sequence of later events. Any state changes ``apply``
        made before raising are not undone.
        """
        previous_version = self.version
        self.version = previous_version + 1
        event = DomainEvent(
            event_id=str(uuid.uuid4()),
            aggregate_id=self.aggregate_id,
            event_type=event_type,
            timestamp=time.time(),
            version=self.version,
            payload=payload,
        )
        try:
            self.apply(event)
        except Exception:
            self.version = previous_version
            raise
        self._pending_events.append(event)
        return event

    def get_pending_events(self) -> List[DomainEvent]:
        """Return the buffered events and reset the buffer to empty."""
        events = self._pending_events
        self._pending_events = []
        return events


class EventSourcing:
    """
    Event-driven state management using the event sourcing pattern.

    Aggregates are saved by appending their pending events to an
    ``EventStore``, running every registered projection on each event.
    They are loaded by replaying stored events, starting from an in-memory
    snapshot when one exists.
    """

    def __init__(self, event_store: Optional[EventStore] = None):
        self.event_store = event_store or EventStore()
        # name -> handler invoked for each newly saved event
        self._projections: Dict[str, Callable[[DomainEvent], None]] = {}
        # aggregate_id -> (version, deep copy of the public attributes)
        self._snapshots: Dict[str, Tuple[int, Dict[str, Any]]] = {}
        # A snapshot is taken whenever a save crosses a multiple of this.
        self._snapshot_interval: int = 100
        self._lock = asyncio.Lock()

    def register_projection(
        self, name: str, handler: Callable[[DomainEvent], None]
    ):
        """Register (or replace) a projection that reacts to saved events."""
        self._projections[name] = handler

    async def save(self, aggregate: Aggregate):
        """Persist the aggregate's pending events and run projections on them.

        Projection failures are logged and do not stop the save. A snapshot
        is taken when this save moves the aggregate's version across a
        multiple of the snapshot interval, even if several events were saved
        at once.
        """
        async with self._lock:
            pending = aggregate.get_pending_events()
            if not pending:
                # Nothing new to persist. Skipping here also avoids
                # snapshotting a fresh version-0 aggregate, which would make
                # load() return an empty aggregate for an id with no events.
                return

            for event in pending:
                await self.event_store.append(event)

                for name, projection in self._projections.items():
                    try:
                        projection(event)
                    except Exception as e:
                        logger.exception(f"Projection error ({name}): {e}")

            interval = self._snapshot_interval
            previous_version = aggregate.version - len(pending)
            if aggregate.version // interval > previous_version // interval:
                await self._create_snapshot(aggregate)

    async def load(
        self, aggregate_type: type, aggregate_id: str
    ) -> Optional[Aggregate]:
        """Rebuild an aggregate from its snapshot (if any) plus later events.

        Returns ``None`` when there is neither a snapshot nor any event.
        """
        import copy

        async with self._lock:
            snapshot = await self._get_snapshot(aggregate_id)
            from_version = 0

            aggregate = aggregate_type(aggregate_id)

            if snapshot:
                from_version, state = snapshot
                # Restore from a copy so the loaded aggregate never shares
                # mutable objects with the stored snapshot.
                aggregate.__dict__.update(copy.deepcopy(state))
                aggregate.version = from_version

            # Replay only the events newer than the snapshot.
            events = await self.event_store.get_events(
                aggregate_id, from_version=from_version + 1
            )

            if not events and not snapshot:
                return None

            aggregate.load_from_history(events)
            return aggregate

    async def rebuild_projection(
        self, name: str, handler: Callable[[DomainEvent], None]
    ):
        """Register ``handler`` and replay every stored event through it."""
        self._projections[name] = handler

        async for event in self.event_store.get_all_events():
            try:
                handler(event)
            except Exception as e:
                logger.error(f"Error rebuilding projection {name}: {e}")

    async def _create_snapshot(self, aggregate: Aggregate):
        """Store a deep copy of the aggregate's public attributes.

        Attributes starting with ``_`` are excluded. The state is deep-copied
        so later mutations of the live aggregate cannot change the snapshot.
        Snapshots are an optimisation: if the state cannot be copied, the
        snapshot is skipped with a warning instead of failing the save.
        """
        import copy

        state = {
            k: v
            for k, v in aggregate.__dict__.items()
            if not k.startswith("_")
        }
        try:
            state = copy.deepcopy(state)
        except Exception:
            logger.warning(
                "Skipping snapshot of %s: state is not copyable",
                aggregate.aggregate_id,
                exc_info=True,
            )
            return
        self._snapshots[aggregate.aggregate_id] = (aggregate.version, state)

    async def _get_snapshot(
        self, aggregate_id: str
    ) -> Optional[Tuple[int, Dict[str, Any]]]:
        """Get the latest snapshot for an aggregate, or ``None``."""
        return self._snapshots.get(aggregate_id)


# =============================================================================
# Stream Processor
# =============================================================================


class StreamProcessor:
    """
    Apache Kafka-style stream processing for AIVA.

    Provides:
    - Topic-based message routing
    - Partitioned processing for parallelism
    - Consumer groups for load balancing
    - At-least-once and exactly-once delivery
    - Windowed aggregations
    - Event sourcing integration

    Topics are in-memory lists of per-partition deques. Per-partition
    offsets and per-topic watermarks are checkpointed through the
    ``CheckpointManager`` and restored by ``start()``.
    """

    # A failed event is re-queued while its retry_count stays below this
    # value (unless the guarantee is AT_MOST_ONCE); afterwards it is dropped.
    MAX_DELIVERY_ATTEMPTS = 3

    def __init__(
        self,
        processor_id: str,
        state_dir: Optional[Path] = None,
        delivery_guarantee: DeliveryGuarantee = DeliveryGuarantee.EXACTLY_ONCE,
        num_partitions: int = 8,
        checkpoint_interval_ms: int = 60000,
        max_batch_size: int = 100,
    ):
        self.processor_id = processor_id
        self.state_dir = state_dir or Path(f"/tmp/aiva_stream_{processor_id}")
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.delivery_guarantee = delivery_guarantee
        self.num_partitions = num_partitions
        self.checkpoint_interval_ms = checkpoint_interval_ms
        self.max_batch_size = max_batch_size

        # Core components
        self.backpressure_handler = BackpressureHandler()
        self.exactly_once = ExactlyOnceDelivery(self.state_dir / "eo")
        self.checkpoint_manager = CheckpointManager(
            self.state_dir / "checkpoints",
            checkpoint_interval_ms=checkpoint_interval_ms,
        )

        # Topic management: topic -> list of partition queues
        self._topics: Dict[str, List[Deque[StreamEvent]]] = {}
        # topic -> consumer group -> assigned partition indices
        self._consumer_groups: Dict[str, Dict[str, Set[int]]] = defaultdict(
            lambda: defaultdict(set)
        )

        # Processing state
        self._handlers: Dict[
            str, List[Callable[[StreamEvent], Awaitable[Optional[StreamEvent]]]]
        ] = defaultdict(list)
        self._running = False
        self._processor_tasks: List[asyncio.Task] = []
        # topic -> partition index -> sequence number of last processed event
        self._offsets: Dict[str, Dict[int, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        # topic -> highest event timestamp processed so far
        self._watermarks: Dict[str, float] = defaultdict(float)

        # Metrics
        self._events_processed = 0
        self._events_failed = 0
        self._processing_latency_sum = 0.0

    async def start(self):
        """Start the processor, restoring offsets and watermarks from the
        latest checkpoint when one exists."""
        await self.checkpoint_manager.start()
        await self.exactly_once.recover_state()

        checkpoint = await self.checkpoint_manager.restore_from_checkpoint()
        if checkpoint:
            self._offsets = defaultdict(lambda: defaultdict(int))
            for topic, partitions in checkpoint.offsets.items():
                # Partition ids are ints in memory. If the checkpoint went
                # through a text format such as JSON, the keys may come back
                # as strings, so coerce them to keep them matching.
                self._offsets[topic] = defaultdict(
                    int,
                    {int(p): offset for p, offset in partitions.items()},
                )
            self._watermarks = defaultdict(float)
            for topic, wm in checkpoint.metadata.get("watermarks", {}).items():
                self._watermarks[topic] = wm

        self._running = True
        logger.info(f"Stream processor {self.processor_id} started")

    async def stop(self):
        """Stop the stream processor gracefully."""
        self._running = False

        # Cancel all processor tasks and wait for them to finish.
        for task in self._processor_tasks:
            task.cancel()

        await asyncio.gather(*self._processor_tasks, return_exceptions=True)
        self._processor_tasks.clear()
        await self.checkpoint_manager.stop()

        logger.info(f"Stream processor {self.processor_id} stopped")

    def create_topic(self, topic: str, num_partitions: Optional[int] = None):
        """Create a topic with empty partitions.

        Calling this for an existing topic replaces its partitions, which
        discards any queued events.
        """
        partitions = num_partitions or self.num_partitions
        self._topics[topic] = [deque() for _ in range(partitions)]
        logger.info(f"Created topic {topic} with {partitions} partitions")

    def subscribe(
        self,
        topic: str,
        handler: Callable[[StreamEvent], Awaitable[Optional[StreamEvent]]],
        consumer_group: Optional[str] = None,
    ):
        """Subscribe a handler to a topic, creating the topic if needed."""
        if topic not in self._topics:
            self.create_topic(topic)

        self._handlers[topic].append(handler)

        if consumer_group:
            # Assign one not-yet-assigned partition to the consumer group.
            group = self._consumer_groups[topic][consumer_group]
            unassigned = set(range(len(self._topics[topic]))) - group
            if unassigned:
                group.add(unassigned.pop())

    @staticmethod
    def _partition_for(partition_key: Optional[str], num_partitions: int) -> int:
        """Map a partition key to a partition index in ``[0, num_partitions)``.

        Keyed messages use a SHA-256 digest rather than the built-in
        ``hash()``. ``hash(str)`` is salted per interpreter process
        (PYTHONHASHSEED), so the same key would land on a different partition
        after a restart. That breaks per-key ordering and the per-partition
        offsets restored from checkpoints. Unkeyed messages are spread
        randomly.
        """
        if partition_key:
            digest = hashlib.sha256(str(partition_key).encode("utf-8")).digest()
            return int.from_bytes(digest[:8], "big") % num_partitions
        return uuid.uuid4().int % num_partitions

    async def produce(
        self,
        topic: str,
        payload: Any,
        partition_key: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> StreamEvent:
        """Produce a message to a topic, creating the topic if needed.

        The event is returned even when backpressure rejects it; in that case
        it is not enqueued and a warning is logged.
        """
        if topic not in self._topics:
            self.create_topic(topic)

        partition = self._partition_for(partition_key, len(self._topics[topic]))

        # Get sequence number
        sequence = await self.exactly_once.get_next_sequence(
            f"{topic}:{partition}"
        )

        event = StreamEvent(
            event_id=str(uuid.uuid4()),
            payload=payload,
            timestamp=time.time(),
            partition_key=partition_key,
            headers=headers or {},
            sequence_number=sequence,
            source_topic=topic,
        )

        # Apply backpressure
        accepted = await self.backpressure_handler.offer(event)
        if not accepted:
            logger.warning(f"Event dropped due to backpressure: {event.event_id}")
            return event

        self._topics[topic][partition].append(event)
        return event

    async def process_topic(self, topic: str):
        """Drain every partition of ``topic``, retrying failed events.

        Unless the guarantee is AT_MOST_ONCE, a failed event is put back at
        the head of its partition until its retry count reaches
        ``MAX_DELIVERY_ATTEMPTS``, after which it is dropped with an error log.
        """
        if topic not in self._topics:
            return

        partitions = self._topics[topic]

        for partition_idx, partition in enumerate(partitions):
            while partition and self._running:
                event = partition.popleft()

                try:
                    await self._process_event(topic, partition_idx, event)
                except Exception as e:
                    logger.exception(f"Error processing event: {e}")
                    self._events_failed += 1
                    if self.delivery_guarantee == DeliveryGuarantee.AT_MOST_ONCE:
                        continue
                    event.retry_count += 1
                    if event.retry_count < self.MAX_DELIVERY_ATTEMPTS:
                        partition.appendleft(event)
                    else:
                        logger.error(
                            "Dropping event %s from %s[%d] after %d failed attempts",
                            event.event_id,
                            topic,
                            partition_idx,
                            event.retry_count,
                        )

    async def _process_event(
        self, topic: str, partition: int, event: StreamEvent
    ):
        """Process a single event with the configured delivery guarantee."""
        start_time = time.time()

        if self.delivery_guarantee == DeliveryGuarantee.EXACTLY_ONCE:
            # Check for duplicate
            if await self.exactly_once.is_duplicate(event):
                logger.debug(f"Skipping duplicate: {event.event_id}")
                return

            # Begin transaction
            txn_id = await self.exactly_once.begin_transaction()

            try:
                should_process = await self.exactly_once.add_to_transaction(
                    txn_id, event
                )
                if not should_process:
                    await self.exactly_once.abort_transaction(txn_id)
                    return

                # Run handlers
                for handler in self._handlers.get(topic, []):
                    await handler(event)

                # Commit transaction
                await self.exactly_once.commit_transaction(txn_id)

            except Exception:
                await self.exactly_once.abort_transaction(txn_id)
                raise

        else:
            # At-most-once or at-least-once: just process
            for handler in self._handlers.get(topic, []):
                await handler(event)

        # Update metrics
        self._events_processed += 1
        self._processing_latency_sum += time.time() - start_time

        # Update offset
        self._offsets[topic][partition] = event.sequence_number

        # Update watermark
        self._watermarks[topic] = max(
            self._watermarks[topic], event.timestamp
        )

        # Checkpoint if needed
        if await self.checkpoint_manager.should_checkpoint():
            await self._create_checkpoint()

    async def _create_checkpoint(self):
        """Create a checkpoint of counters, offsets and watermarks."""
        state = {
            "events_processed": self._events_processed,
            "events_failed": self._events_failed,
        }

        await self.checkpoint_manager.create_checkpoint(
            state=state,
            offsets={
                topic: dict(partitions)
                for topic, partitions in self._offsets.items()
            },
            watermark=max(self._watermarks.values()) if self._watermarks else 0,
            metadata={"watermarks": dict(self._watermarks)},
        )

    async def run(self):
        """Run the stream processor until stopped, then shut it down."""
        await self.start()

        try:
            while self._running:
                # Iterate over a copy: handlers may produce to new topics,
                # which adds keys to self._topics during this loop.
                for topic in list(self._topics):
                    await self.process_topic(topic)

                await asyncio.sleep(0.01)  # Small yield

        finally:
            await self.stop()

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of processor metrics."""
        avg_latency = (
            self._processing_latency_sum / self._events_processed
            if self._events_processed > 0
            else 0
        )

        return {
            "processor_id": self.processor_id,
            "events_processed": self._events_processed,
            "events_failed": self._events_failed,
            "average_latency_ms": avg_latency * 1000,
            # Copy so callers cannot mutate the handler's live metrics.
            "backpressure": dict(self.backpressure_handler.metrics.__dict__),
            "watermarks": dict(self._watermarks),
            "offsets": {
                topic: dict(partitions)
                for topic, partitions in self._offsets.items()
            },
        }


# =============================================================================
# Convenience Functions
# =============================================================================


def create_stream_processor(
    processor_id: str = "aiva_default",
    exactly_once: bool = True,
    **kwargs,
) -> StreamProcessor:
    """Build a ``StreamProcessor`` with the requested delivery guarantee.

    ``exactly_once=True`` selects EXACTLY_ONCE; otherwise AT_LEAST_ONCE is
    used. Any other keyword arguments are forwarded to ``StreamProcessor``.
    """
    if exactly_once:
        guarantee = DeliveryGuarantee.EXACTLY_ONCE
    else:
        guarantee = DeliveryGuarantee.AT_LEAST_ONCE

    return StreamProcessor(
        processor_id=processor_id,
        delivery_guarantee=guarantee,
        **kwargs,
    )


def create_tumbling_window(
    size_ms: int,
    aggregator: Callable[[List[V]], R],
    key_extractor: Callable[[Any], K],
) -> WindowedAggregation[K, V, R]:
    """Build a windowed aggregation over fixed, non-overlapping windows.

    ``size_ms`` is passed to ``TumblingWindowAssigner``; ``aggregator`` and
    ``key_extractor`` are passed through to ``WindowedAggregation``.
    """
    assigner = TumblingWindowAssigner(size_ms)
    return WindowedAggregation(
        key_extractor=key_extractor,
        aggregator=aggregator,
        window_assigner=assigner,
    )


def create_sliding_window(
    size_ms: int,
    slide_ms: int,
    aggregator: Callable[[List[V]], R],
    key_extractor: Callable[[Any], K],
) -> WindowedAggregation[K, V, R]:
    """Build a windowed aggregation over sliding windows.

    ``size_ms`` and ``slide_ms`` are passed to ``SlidingWindowAssigner``;
    ``aggregator`` and ``key_extractor`` go to ``WindowedAggregation``.
    """
    assigner = SlidingWindowAssigner(size_ms, slide_ms)
    return WindowedAggregation(
        key_extractor=key_extractor,
        aggregator=aggregator,
        window_assigner=assigner,
    )


# =============================================================================
# Example Usage and Testing
# =============================================================================


async def example_usage():
    """Run a small end-to-end demo of the stream processor.

    The demo does the following:
    - Produces ten synthetic readings for three sensors.
    - Feeds each reading through a tumbling-window sum (``size_ms=5000``)
      and records it in an ``EventStore``.
    - Advances the window watermark to emit window results.
    - Logs the processor metrics.

    The processor is stopped even if a step fails.
    """
    processor = create_stream_processor("aiva_demo")

    processor.create_topic("sensor_data", num_partitions=4)
    processor.create_topic("aggregated_data", num_partitions=2)

    window_agg = create_tumbling_window(
        size_ms=5000,
        aggregator=lambda values: sum(v.get("value", 0) for v in values),
        key_extractor=lambda payload: payload.get("sensor_id", "unknown"),
    )

    event_store = EventStore()
    # Last stored version per sensor (each sensor is its own aggregate).
    # Events of one aggregate need increasing versions. With a constant
    # version they cannot be ordered, and each persisted
    # "<aggregate_id>_<version>.json" file would overwrite the previous one.
    sensor_versions: Dict[str, int] = defaultdict(int)

    async def process_sensor_data(event: StreamEvent) -> Optional[StreamEvent]:
        # Buffer into the window; results are emitted on watermark updates.
        await window_agg.process(event)

        sensor_id = event.payload.get("sensor_id", "unknown")
        next_version = sensor_versions[sensor_id] + 1
        domain_event = DomainEvent(
            event_id=str(uuid.uuid4()),
            aggregate_id=sensor_id,
            event_type="SENSOR_READING",
            timestamp=event.timestamp,
            version=next_version,
            payload=event.payload,
        )
        await event_store.append(domain_event)
        # Advance only after a successful append, so a retry reuses the version.
        sensor_versions[sensor_id] = next_version

        logger.info(f"Processed sensor data: {event.payload}")
        return None

    processor.subscribe("sensor_data", process_sensor_data)

    await processor.start()
    try:
        for i in range(10):
            await processor.produce(
                "sensor_data",
                {"sensor_id": f"sensor_{i % 3}", "value": i * 10, "unit": "celsius"},
                partition_key=f"sensor_{i % 3}",
            )

        await processor.process_topic("sensor_data")

        # Push the watermark past every window so all of them fire.
        triggered = await window_agg.update_watermark(time.time() + 10)
        for result in triggered:
            logger.info(
                f"Window result: key={result.key}, value={result.value}, "
                f"window=[{result.window_start}, {result.window_end}]"
            )

        metrics = processor.get_metrics()
        logger.info(f"Processor metrics: {json.dumps(metrics, indent=2)}")
    finally:
        await processor.stop()


if __name__ == "__main__":
    # Run the end-to-end demo when executed as a script.
    demo = example_usage()
    asyncio.run(demo)
