"""
VAST_05 - MapReduce Distributed Computation Engine for AIVA
============================================================

A production-grade MapReduce implementation for distributed data processing
within the Genesis-OS / AIVA framework.

Components:
-----------
1. MapReduceEngine   - Core orchestration and execution
2. Mapper            - Distributed mapping phase with parallel execution
3. Reducer           - Aggregation phase with combiners
4. Combiner          - Local combining optimization for map outputs
5. Partitioner       - Data partitioning strategies
6. JobScheduler      - Task scheduling with priority queues

Author: Genesis Lead Architect
Version: 1.0.0
"""

import asyncio
import bisect
import hashlib
import json
import logging
import multiprocessing
import os
import pickle
import queue
import random
import threading
import time
import uuid
from abc import ABC, abstractmethod
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum, auto
from functools import partial, reduce
from heapq import heappush, heappop
from typing import (
    Any, Callable, Dict, Generator, Generic, Iterable, Iterator,
    List, Optional, Tuple, TypeVar, Union, Set
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("VAST_MapReduce")

# Type variables for generic typing
K = TypeVar('K')  # Key type
V = TypeVar('V')  # Value type
K2 = TypeVar('K2')  # Output key type
V2 = TypeVar('V2')  # Output value type


class TaskStatus(Enum):
    """Lifecycle states a MapReduce task or job can pass through."""
    PENDING = auto()      # created, waiting to be scheduled
    SCHEDULED = auto()    # accepted by the scheduler, not yet running
    RUNNING = auto()      # currently executing
    COMPLETED = auto()    # finished successfully
    FAILED = auto()       # terminated with an error
    RETRYING = auto()     # failed and queued for another attempt
    CANCELLED = auto()    # aborted before completion


class PartitionStrategy(Enum):
    """Strategies for assigning intermediate keys to reduce partitions."""
    HASH = auto()             # built-in hash() modulo partition count
    RANGE = auto()            # boundary-based buckets for ordered keys
    ROUND_ROBIN = auto()      # cycle through partitions evenly
    CUSTOM = auto()           # user-supplied partitioning function
    CONSISTENT_HASH = auto()  # hash-ring placement with virtual nodes


@dataclass
class TaskMetrics:
    """Execution statistics collected for a single map or reduce task."""
    task_id: str
    start_time: Optional[float] = None   # epoch seconds when work began
    end_time: Optional[float] = None     # epoch seconds when work ended
    records_processed: int = 0
    bytes_processed: int = 0
    errors: int = 0                      # per-record failures, not task failure
    retries: int = 0

    @property
    def duration(self) -> float:
        """Wall-clock runtime in seconds; 0.0 while either timestamp is unset."""
        if not self.start_time or not self.end_time:
            return 0.0
        return self.end_time - self.start_time

    @property
    def throughput(self) -> float:
        """Records processed per second; 0.0 when duration is zero."""
        elapsed = self.duration
        return self.records_processed / elapsed if elapsed > 0 else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize metrics (fields plus derived properties) to a dict."""
        return {
            name: getattr(self, name)
            for name in (
                "task_id", "start_time", "end_time", "duration",
                "records_processed", "bytes_processed", "errors",
                "retries", "throughput",
            )
        }


@dataclass
class MapTask:
    """Represents a single map task.

    One task covers one input split; ``result``/``error`` are filled in by
    the mapper when the task finishes.
    """
    task_id: str
    # Index of the input split this task processes.
    partition_id: int
    # Input records for this split as (key, value) pairs.
    data: List[Tuple[K, V]]
    status: TaskStatus = TaskStatus.PENDING
    # Map output; None until the task completes successfully.
    result: Optional[List[Tuple[K2, V2]]] = None
    # Failure description; None unless the task failed.
    error: Optional[str] = None
    # Per-task execution statistics; task_id is re-synced in __post_init__.
    metrics: TaskMetrics = field(default_factory=lambda: TaskMetrics(str(uuid.uuid4())))
    # Scheduling priority (higher = more urgent).
    priority: int = 0
    # Identifier of the worker this task is assigned to, if any.
    assigned_worker: Optional[str] = None

    def __post_init__(self):
        # The factory seeds metrics with a throwaway random id; overwrite it
        # so metrics are reported under this task's real id.
        self.metrics.task_id = self.task_id


@dataclass
class ReduceTask:
    """Represents a single reduce task.

    One task covers one shuffle partition of grouped intermediate data;
    ``result``/``error`` are filled in by the reducer when it finishes.
    """
    task_id: str
    # Shuffle partition index this task reduces.
    partition_id: int
    # Grouped intermediate data: key -> list of values to reduce.
    data: Dict[K2, List[V2]]
    status: TaskStatus = TaskStatus.PENDING
    # Reduced output per key; None until the task completes successfully.
    result: Optional[Dict[K2, V2]] = None
    # Failure description; None unless the task failed.
    error: Optional[str] = None
    # Per-task execution statistics; task_id is re-synced in __post_init__.
    metrics: TaskMetrics = field(default_factory=lambda: TaskMetrics(str(uuid.uuid4())))
    # Scheduling priority (higher = more urgent).
    priority: int = 0
    # Identifier of the worker this task is assigned to, if any.
    assigned_worker: Optional[str] = None

    def __post_init__(self):
        # The factory seeds metrics with a throwaway random id; overwrite it
        # so metrics are reported under this task's real id.
        self.metrics.task_id = self.task_id


@dataclass
class JobConfig:
    """Configuration for a MapReduce job."""
    # Unique job identifier; auto-generated when not supplied.
    job_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    job_name: str = "unnamed_job"
    # Degree of parallelism for the map phase.
    num_mappers: int = 4
    # Degree of parallelism for the reduce phase; also the partition count.
    num_reducers: int = 2
    partition_strategy: PartitionStrategy = PartitionStrategy.HASH
    # Run a local combiner on map output (only if a combine fn is provided).
    use_combiner: bool = True
    max_retries: int = 3
    # Delay between retry attempts, in seconds.
    retry_delay: float = 1.0
    # Timeout applied when collecting task futures, in seconds.
    timeout_seconds: float = 300.0
    memory_limit_mb: int = 1024
    # Buffer size handed to the combiner during the map/shuffle stage.
    shuffle_buffer_size: int = 10000
    sort_buffer_size: int = 5000
    compression_enabled: bool = False
    # Record checkpoints between phases for fault recovery.
    checkpoint_enabled: bool = True
    checkpoint_interval: int = 1000


class Partitioner(Generic[K2]):
    """
    Data partitioner for distributing intermediate results to reducers.

    Supports multiple partitioning strategies:
    - HASH: Hash-based partitioning (default). Note: built-in hash() of
      strings is salted per process, so assignments are stable only within
      a single process.
    - RANGE: Range-based partitioning for ordered keys
    - ROUND_ROBIN: Even distribution across partitions
    - CONSISTENT_HASH: Consistent hashing for distributed systems
    - CUSTOM: User-defined partitioning function
    """

    def __init__(
        self,
        num_partitions: int,
        strategy: PartitionStrategy = PartitionStrategy.HASH,
        custom_function: Optional[Callable[[K2, int], int]] = None,
        range_boundaries: Optional[List[K2]] = None
    ):
        """
        Initialize the partitioner.

        Args:
            num_partitions: Number of partitions (reducers)
            strategy: Partitioning strategy to use
            custom_function: Custom partitioning function (for CUSTOM strategy)
            range_boundaries: Sorted list of range boundaries (for RANGE strategy)
        """
        self.num_partitions = num_partitions
        self.strategy = strategy
        self.custom_function = custom_function
        self.range_boundaries = range_boundaries or []
        self._round_robin_counter = 0
        self._hash_ring: Dict[int, int] = {}
        # Ring positions kept pre-sorted so lookups can binary-search
        # instead of re-sorting the ring on every key.
        self._ring_keys: List[int] = []

        if strategy == PartitionStrategy.CONSISTENT_HASH:
            self._build_hash_ring()

    def _build_hash_ring(self, virtual_nodes: int = 150):
        """Build the consistent hash ring with virtual nodes per partition.

        The sorted ring positions are computed once here; the previous
        implementation re-sorted all num_partitions * virtual_nodes entries
        on every single lookup.
        """
        for partition in range(self.num_partitions):
            for vnode in range(virtual_nodes):
                key = f"{partition}:{vnode}"
                hash_val = int(hashlib.md5(key.encode()).hexdigest(), 16)
                self._hash_ring[hash_val] = partition

        self._ring_keys = sorted(self._hash_ring)

    def _consistent_hash(self, key: K2) -> int:
        """Get partition using consistent hashing (O(log ring) per key)."""
        key_hash = int(hashlib.md5(str(key).encode()).hexdigest(), 16)
        # First ring position >= key_hash; wrap around to the ring start
        # when the key hashes past the last node.
        idx = bisect.bisect_left(self._ring_keys, key_hash)
        if idx == len(self._ring_keys):
            idx = 0
        return self._hash_ring[self._ring_keys[idx]]

    def get_partition(self, key: K2) -> int:
        """
        Determine the partition for a given key.

        Args:
            key: The key to partition

        Returns:
            Partition index (0 to num_partitions - 1)

        Raises:
            ValueError: If strategy is CUSTOM but no custom_function was given.
        """
        if self.strategy == PartitionStrategy.HASH:
            return hash(key) % self.num_partitions

        elif self.strategy == PartitionStrategy.RANGE:
            # Without boundaries, degrade gracefully to hash partitioning.
            if not self.range_boundaries:
                return hash(key) % self.num_partitions
            for i, boundary in enumerate(self.range_boundaries):
                if key < boundary:
                    return i
            # Keys >= the last boundary fall into the final bucket.
            return len(self.range_boundaries)

        elif self.strategy == PartitionStrategy.ROUND_ROBIN:
            # NOTE: stateful counter — not thread-safe; fine for the
            # single-threaded shuffle where this is used.
            partition = self._round_robin_counter % self.num_partitions
            self._round_robin_counter += 1
            return partition

        elif self.strategy == PartitionStrategy.CONSISTENT_HASH:
            return self._consistent_hash(key)

        elif self.strategy == PartitionStrategy.CUSTOM:
            if self.custom_function:
                return self.custom_function(key, self.num_partitions)
            raise ValueError("Custom partitioning requires a custom_function")

        # Unknown strategy: fall back to hash partitioning.
        return hash(key) % self.num_partitions

    def partition_data(
        self,
        data: Iterable[Tuple[K2, V2]]
    ) -> Dict[int, List[Tuple[K2, V2]]]:
        """
        Partition a collection of key-value pairs.

        Args:
            data: Iterable of (key, value) tuples

        Returns:
            Dictionary mapping partition ID to list of (key, value) tuples
            (partitions that receive no keys are absent).
        """
        partitions: Dict[int, List[Tuple[K2, V2]]] = defaultdict(list)
        for key, value in data:
            partition_id = self.get_partition(key)
            partitions[partition_id].append((key, value))
        return dict(partitions)


class Combiner(Generic[K2, V2]):
    """
    Local "mini-reduce" stage applied to map output before the shuffle.

    Collapsing values per key on the map side cuts the volume of
    intermediate data that has to be transferred to the reducers.
    """

    def __init__(
        self,
        combine_function: Callable[[K2, List[V2]], V2],
        buffer_size: int = 1000
    ):
        """
        Initialize the combiner.

        Args:
            combine_function: Function to combine values for a key
            buffer_size: Size of internal buffer before flushing
        """
        self.combine_function = combine_function
        self.buffer_size = buffer_size
        self._buffer: Dict[K2, List[V2]] = defaultdict(list)
        self._output: List[Tuple[K2, V2]] = []
        self._records_combined = 0

    def add(self, key: K2, value: V2) -> None:
        """
        Buffer one (key, value) pair, flushing when the buffer fills.

        Args:
            key: The key
            value: The value
        """
        self._buffer[key].append(value)
        # Buffer capacity is measured in distinct keys, not total records.
        if len(self._buffer) >= self.buffer_size:
            self.flush()

    def add_all(self, pairs: Iterable[Tuple[K2, V2]]) -> None:
        """Buffer every pair from an iterable of (key, value) tuples."""
        for key, value in pairs:
            self.add(key, value)

    def flush(self) -> None:
        """Combine all buffered values per key and append them to the output."""
        for key, values in self._buffer.items():
            if len(values) == 1:
                # Single value: nothing to combine, pass it through.
                self._output.append((key, values[0]))
            else:
                self._output.append((key, self.combine_function(key, values)))
                self._records_combined += len(values) - 1
        self._buffer.clear()

    def get_output(self) -> List[Tuple[K2, V2]]:
        """
        Flush any remaining buffered pairs and return the combined output.

        Returns:
            List of combined (key, value) pairs
        """
        self.flush()
        return self._output

    @property
    def records_combined(self) -> int:
        """Number of records eliminated by combining."""
        return self._records_combined


class Mapper(Generic[K, V, K2, V2]):
    """
    Distributed mapper for the map phase.

    Handles:
    - Parallel map execution
    - Input splitting
    - Local combining (optional)
    - Fault tolerance with retries

    Tasks run on a thread pool (mirroring Reducer) rather than a process
    pool: user map functions are commonly lambdas/closures that cannot be
    pickled for a ProcessPoolExecutor, and in-process execution keeps the
    per-task metrics the workers record visible to the caller instead of
    being lost in a child process.
    """

    def __init__(
        self,
        map_function: Callable[[K, V], Iterable[Tuple[K2, V2]]],
        config: JobConfig,
        combiner: Optional[Combiner[K2, V2]] = None
    ):
        """
        Initialize the mapper.

        Args:
            map_function: User-defined map function producing (key, value) pairs
            config: Job configuration
            combiner: Optional combiner for local aggregation
        """
        self.map_function = map_function
        self.config = config
        self.combiner = combiner
        self._metrics: Dict[str, TaskMetrics] = {}

    def _map_partition(
        self,
        task: MapTask
    ) -> Tuple[str, List[Tuple[K2, V2]], Optional[str]]:
        """
        Map a single partition of data.

        Args:
            task: Map task to execute

        Returns:
            Tuple of (task_id, results, error_message); error_message is
            None on success.
        """
        try:
            # Timestamps mirror Reducer._reduce_partition so map-task
            # duration/throughput metrics are meaningful.
            task.metrics.start_time = time.time()
            results: List[Tuple[K2, V2]] = []
            local_combiner = None

            if self.combiner:
                # Fresh combiner per task so concurrent tasks never share buffers.
                local_combiner = Combiner(
                    self.combiner.combine_function,
                    self.combiner.buffer_size
                )

            for key, value in task.data:
                try:
                    map_output = self.map_function(key, value)
                    if local_combiner:
                        local_combiner.add_all(map_output)
                    else:
                        results.extend(map_output)
                    task.metrics.records_processed += 1
                except Exception as e:
                    # A bad record is counted but does not abort the task.
                    task.metrics.errors += 1
                    logger.warning(f"Map error for key {key}: {e}")

            if local_combiner:
                results = local_combiner.get_output()

            task.metrics.end_time = time.time()
            return task.task_id, results, None

        except Exception as e:
            task.metrics.end_time = time.time()
            error_msg = f"Map task {task.task_id} failed: {str(e)}"
            logger.error(error_msg)
            return task.task_id, [], error_msg

    def split_input(
        self,
        data: List[Tuple[K, V]],
        num_splits: int
    ) -> List[List[Tuple[K, V]]]:
        """
        Split input data into contiguous chunks for parallel processing.

        Args:
            data: Input data as list of (key, value) pairs
            num_splits: Desired number of splits

        Returns:
            List of at most num_splits non-empty splits (empty list for
            empty input).
        """
        if not data:
            return []

        split_size = max(1, len(data) // num_splits)
        splits = [data[i:i + split_size] for i in range(0, len(data), split_size)]

        # Integer division can leave a trailing remainder chunk; fold extras
        # into the previous split so we never exceed the requested count.
        while len(splits) > num_splits:
            splits[-2].extend(splits[-1])
            splits.pop()

        return splits

    def execute(
        self,
        data: List[Tuple[K, V]],
        num_workers: Optional[int] = None
    ) -> List[Tuple[K2, V2]]:
        """
        Execute the map phase.

        Args:
            data: Input data as list of (key, value) pairs
            num_workers: Number of parallel workers (defaults to config.num_mappers)

        Returns:
            List of mapped (key, value) pairs from all successful tasks
        """
        num_workers = num_workers or self.config.num_mappers
        splits = self.split_input(data, num_workers)

        # Create map tasks, one per split
        tasks = [
            MapTask(
                task_id=f"map_{i}_{uuid.uuid4().hex[:8]}",
                partition_id=i,
                data=split,
                priority=0
            )
            for i, split in enumerate(splits)
        ]

        all_results: List[Tuple[K2, V2]] = []

        # ThreadPoolExecutor (not a process pool): user map functions are
        # often unpicklable, and in-process workers keep task/metrics
        # mutations visible to this parent scope (a process pool silently
        # discarded them).
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            future_to_task = {
                executor.submit(self._map_partition, task): task
                for task in tasks
            }

            for future in as_completed(future_to_task):
                task = future_to_task[future]
                try:
                    task_id, results, error = future.result(
                        timeout=self.config.timeout_seconds
                    )

                    if error:
                        task.status = TaskStatus.FAILED
                        task.error = error
                        logger.error(f"Task {task_id} failed: {error}")
                    else:
                        task.status = TaskStatus.COMPLETED
                        task.result = results
                        all_results.extend(results)
                        logger.info(f"Task {task_id} completed with {len(results)} results")

                    self._metrics[task_id] = task.metrics

                except Exception as e:
                    task.status = TaskStatus.FAILED
                    task.error = str(e)
                    logger.error(f"Task execution error: {e}")

        return all_results

    async def execute_async(
        self,
        data: List[Tuple[K, V]],
        num_workers: Optional[int] = None
    ) -> List[Tuple[K2, V2]]:
        """
        Asynchronous execution of the map phase.

        Runs the blocking execute() in the default executor so the event
        loop is not blocked.

        Args:
            data: Input data as list of (key, value) pairs
            num_workers: Number of parallel workers

        Returns:
            List of mapped (key, value) pairs
        """
        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() here is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            partial(self.execute, data, num_workers)
        )


class Reducer(Generic[K2, V2]):
    """
    Aggregation reducer for the reduce phase.

    Responsibilities:
    - Shuffling and sorting map output
    - Grouping values by key
    - Parallel reduce execution
    - Fault tolerance
    """

    def __init__(
        self,
        reduce_function: Callable[[K2, Iterable[V2]], V2],
        config: JobConfig
    ):
        """
        Initialize the reducer.

        Args:
            reduce_function: User-defined reduce function
            config: Job configuration
        """
        self.reduce_function = reduce_function
        self.config = config
        self._metrics: Dict[str, TaskMetrics] = {}

    def shuffle_and_sort(
        self,
        map_output: List[Tuple[K2, V2]],
        partitioner: Partitioner[K2]
    ) -> Dict[int, Dict[K2, List[V2]]]:
        """
        Shuffle and sort map output for the reduce phase.

        Args:
            map_output: Output from map phase
            partitioner: Partitioner for distributing data

        Returns:
            Dictionary mapping partition ID to grouped key-values
        """
        grouped: Dict[int, Dict[K2, List[V2]]] = {}

        for partition_id, pairs in partitioner.partition_data(map_output).items():
            buckets: Dict[K2, List[V2]] = defaultdict(list)
            # Sorting before grouping gives each partition's group map a
            # deterministic key order.
            for key, value in sorted(pairs, key=lambda kv: kv[0]):
                buckets[key].append(value)
            grouped[partition_id] = dict(buckets)

        return grouped

    def _reduce_partition(
        self,
        task: ReduceTask
    ) -> Tuple[str, Dict[K2, V2], Optional[str]]:
        """
        Apply the reduce function to every key group in one partition.

        Args:
            task: Reduce task to execute

        Returns:
            Tuple of (task_id, results, error_message); error_message is
            None on success.
        """
        try:
            task.metrics.start_time = time.time()
            output: Dict[K2, V2] = {}

            for key, values in task.data.items():
                try:
                    output[key] = self.reduce_function(key, iter(values))
                    task.metrics.records_processed += 1
                except Exception as e:
                    # A failing key is counted but does not abort the task.
                    task.metrics.errors += 1
                    logger.warning(f"Reduce error for key {key}: {e}")

            task.metrics.end_time = time.time()
            return task.task_id, output, None

        except Exception as e:
            error_msg = f"Reduce task {task.task_id} failed: {str(e)}"
            logger.error(error_msg)
            return task.task_id, {}, error_msg

    def execute(
        self,
        grouped_data: Dict[int, Dict[K2, List[V2]]],
        num_workers: Optional[int] = None
    ) -> Dict[K2, V2]:
        """
        Execute the reduce phase.

        Args:
            grouped_data: Shuffled and grouped data from map phase
            num_workers: Number of parallel workers

        Returns:
            Dictionary of final reduced results
        """
        worker_count = num_workers or self.config.num_reducers

        # One reduce task per shuffle partition
        tasks = []
        for partition_id, data in grouped_data.items():
            tasks.append(
                ReduceTask(
                    task_id=f"reduce_{partition_id}_{uuid.uuid4().hex[:8]}",
                    partition_id=partition_id,
                    data=data,
                    priority=0
                )
            )

        final_results: Dict[K2, V2] = {}

        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = {pool.submit(self._reduce_partition, t): t for t in tasks}

            for future in as_completed(pending):
                task = pending[future]
                try:
                    task_id, results, error = future.result(
                        timeout=self.config.timeout_seconds
                    )
                except Exception as e:
                    task.status = TaskStatus.FAILED
                    task.error = str(e)
                    logger.error(f"Reduce execution error: {e}")
                    continue

                if error:
                    task.status = TaskStatus.FAILED
                    task.error = error
                else:
                    task.status = TaskStatus.COMPLETED
                    task.result = results
                    final_results.update(results)
                    logger.info(f"Reduce {task_id} completed with {len(results)} results")

                self._metrics[task_id] = task.metrics

        return final_results


class JobScheduler:
    """
    Task scheduler for MapReduce jobs.

    Features:
    - Priority-based scheduling
    - Resource management
    - Load balancing
    - Speculative execution
    """

    def __init__(
        self,
        max_concurrent_jobs: int = 4,
        max_workers_per_job: int = 8
    ):
        """
        Initialize the scheduler.

        Args:
            max_concurrent_jobs: Maximum number of concurrent jobs
            max_workers_per_job: Maximum workers per job
        """
        self.max_concurrent_jobs = max_concurrent_jobs
        self.max_workers_per_job = max_workers_per_job
        # Min-heap of (-priority, job_id, job); job_id breaks priority ties.
        self._job_queue: List[Tuple[int, str, 'MapReduceJob']] = []
        self._running_jobs: Dict[str, 'MapReduceJob'] = {}
        self._completed_jobs: Dict[str, 'MapReduceJob'] = {}
        self._lock = threading.Lock()
        self._worker_pool: Dict[str, Set[str]] = defaultdict(set)
        self._scheduler_running = False
        self._scheduler_thread: Optional[threading.Thread] = None

    def submit(
        self,
        job: 'MapReduceJob',
        priority: int = 0
    ) -> str:
        """
        Submit a job for execution.

        Args:
            job: MapReduce job to submit
            priority: Job priority (higher = more urgent)

        Returns:
            Job ID
        """
        job_id = job.config.job_id
        with self._lock:
            # heapq is a min-heap, so negate priority to pop highest first.
            heappush(self._job_queue, (-priority, job_id, job))
            logger.info(f"Job {job.config.job_id} submitted with priority {priority}")
        return job_id

    def start(self) -> None:
        """Start the scheduler background thread (no-op if already running)."""
        if self._scheduler_running:
            return
        self._scheduler_running = True
        self._scheduler_thread = threading.Thread(
            target=self._scheduler_loop,
            daemon=True
        )
        self._scheduler_thread.start()
        logger.info("Scheduler started")

    def stop(self) -> None:
        """Signal the scheduler loop to exit and wait briefly for it."""
        self._scheduler_running = False
        if self._scheduler_thread:
            self._scheduler_thread.join(timeout=5.0)
        logger.info("Scheduler stopped")

    def _scheduler_loop(self) -> None:
        """Poll for schedulable jobs until stop() is called."""
        while self._scheduler_running:
            self._schedule_next()
            time.sleep(0.1)

    def _schedule_next(self) -> None:
        """Pop and launch the highest-priority queued job, capacity permitting."""
        with self._lock:
            at_capacity = len(self._running_jobs) >= self.max_concurrent_jobs
            if at_capacity or not self._job_queue:
                return
            _, job_id, job = heappop(self._job_queue)
            self._running_jobs[job_id] = job

        # Launch outside the lock so scheduling isn't blocked by execution.
        threading.Thread(target=self._execute_job, args=(job,)).start()

    def _execute_job(self, job: 'MapReduceJob') -> None:
        """Run a job and move it to the completed set when it finishes."""
        try:
            job.run()
        except Exception as e:
            logger.error(f"Job {job.config.job_id} failed: {e}")
        finally:
            with self._lock:
                # Failed jobs are also recorded as completed (terminal state).
                self._running_jobs.pop(job.config.job_id, None)
                self._completed_jobs[job.config.job_id] = job

    def get_job_status(self, job_id: str) -> Optional[TaskStatus]:
        """Look up a job's coarse status, or None if the job is unknown."""
        with self._lock:
            if job_id in self._running_jobs:
                return TaskStatus.RUNNING
            if job_id in self._completed_jobs:
                return TaskStatus.COMPLETED
            if any(jid == job_id for _, jid, _ in self._job_queue):
                return TaskStatus.PENDING
        return None

    def get_metrics(self) -> Dict[str, Any]:
        """Snapshot of queued/running/completed job counts."""
        with self._lock:
            return {
                "queued_jobs": len(self._job_queue),
                "running_jobs": len(self._running_jobs),
                "completed_jobs": len(self._completed_jobs),
                "max_concurrent": self.max_concurrent_jobs
            }


class MapReduceJob(Generic[K, V, K2, V2]):
    """
    Complete MapReduce job orchestration.

    Coordinates:
    - Input splitting
    - Map phase execution
    - Shuffle and sort
    - Reduce phase execution
    - Output collection

    Input can be supplied either directly to run(), or ahead of time via
    the public ``input_data`` attribute — the latter is what lets a
    scheduler invoke ``run()`` with no arguments.
    """

    def __init__(
        self,
        config: JobConfig,
        map_function: Callable[[K, V], Iterable[Tuple[K2, V2]]],
        reduce_function: Callable[[K2, Iterable[V2]], V2],
        combine_function: Optional[Callable[[K2, List[V2]], V2]] = None,
        partition_function: Optional[Callable[[K2, int], int]] = None
    ):
        """
        Initialize a MapReduce job.

        Args:
            config: Job configuration
            map_function: User-defined map function
            reduce_function: User-defined reduce function
            combine_function: Optional combiner function
            partition_function: Optional custom partition function
        """
        self.config = config
        self.map_function = map_function
        self.reduce_function = reduce_function
        self.combine_function = combine_function
        self.partition_function = partition_function

        # Optional pre-loaded input, used when run() is called with no data
        # (e.g. JobScheduler._execute_job calls job.run() with no arguments).
        self.input_data: Optional[List[Tuple[K, V]]] = None

        # Initialize components
        self._combiner: Optional[Combiner[K2, V2]] = None
        if combine_function and config.use_combiner:
            self._combiner = Combiner(combine_function, config.shuffle_buffer_size)

        self._mapper = Mapper(
            map_function=map_function,
            config=config,
            combiner=self._combiner
        )

        self._reducer = Reducer(
            reduce_function=reduce_function,
            config=config
        )

        self._partitioner = Partitioner(
            num_partitions=config.num_reducers,
            strategy=config.partition_strategy,
            custom_function=partition_function
        )

        self._status = TaskStatus.PENDING
        self._result: Optional[Dict[K2, V2]] = None
        self._metrics = JobMetrics(config.job_id)
        self._checkpoints: List[Dict[str, Any]] = []

    def run(self, data: Optional[List[Tuple[K, V]]] = None) -> Dict[K2, V2]:
        """
        Execute the complete MapReduce job.

        Args:
            data: Input data as list of (key, value) pairs. May be omitted
                when ``input_data`` was set beforehand; this is how
                scheduler-driven runs (which call run() with no arguments)
                supply their input.

        Returns:
            Dictionary of final results

        Raises:
            ValueError: If no input data was provided via either path.
        """
        if data is None:
            data = self.input_data
        if data is None:
            raise ValueError(
                "No input data: pass data to run() or set job.input_data "
                "before submitting the job to a scheduler"
            )

        logger.info(f"Starting job {self.config.job_id}: {self.config.job_name}")
        self._status = TaskStatus.RUNNING
        self._metrics.start_time = time.time()
        self._metrics.input_records = len(data)

        try:
            # MAP PHASE
            logger.info(f"[{self.config.job_id}] Starting map phase with {len(data)} records")
            map_start = time.time()

            map_output = self._mapper.execute(data, self.config.num_mappers)

            self._metrics.map_duration = time.time() - map_start
            self._metrics.map_output_records = len(map_output)
            logger.info(f"[{self.config.job_id}] Map phase complete: {len(map_output)} intermediate records")

            # Checkpoint after map for fault recovery
            if self.config.checkpoint_enabled:
                self._checkpoint("after_map", {"map_output_size": len(map_output)})

            # SHUFFLE AND SORT
            logger.info(f"[{self.config.job_id}] Starting shuffle phase")
            shuffle_start = time.time()

            grouped_data = self._reducer.shuffle_and_sort(map_output, self._partitioner)

            self._metrics.shuffle_duration = time.time() - shuffle_start
            total_groups = sum(len(g) for g in grouped_data.values())
            logger.info(f"[{self.config.job_id}] Shuffle complete: {total_groups} unique keys in {len(grouped_data)} partitions")

            # REDUCE PHASE
            logger.info(f"[{self.config.job_id}] Starting reduce phase")
            reduce_start = time.time()

            self._result = self._reducer.execute(grouped_data, self.config.num_reducers)

            self._metrics.reduce_duration = time.time() - reduce_start
            self._metrics.output_records = len(self._result)
            logger.info(f"[{self.config.job_id}] Reduce phase complete: {len(self._result)} output records")

            self._status = TaskStatus.COMPLETED
            self._metrics.end_time = time.time()

            logger.info(f"Job {self.config.job_id} completed successfully")
            logger.info(f"  Total duration: {self._metrics.total_duration:.2f}s")
            logger.info(f"  Map: {self._metrics.map_duration:.2f}s, Shuffle: {self._metrics.shuffle_duration:.2f}s, Reduce: {self._metrics.reduce_duration:.2f}s")

            return self._result

        except Exception as e:
            # Record failure details, then re-raise so callers see the error.
            self._status = TaskStatus.FAILED
            self._metrics.end_time = time.time()
            self._metrics.error = str(e)
            logger.error(f"Job {self.config.job_id} failed: {e}")
            raise

    def _checkpoint(self, stage: str, data: Dict[str, Any]) -> None:
        """Record an in-memory checkpoint marker for fault recovery."""
        checkpoint = {
            "stage": stage,
            "timestamp": datetime.now().isoformat(),
            "job_id": self.config.job_id,
            "data": data
        }
        self._checkpoints.append(checkpoint)
        logger.debug(f"Checkpoint created: {stage}")

    @property
    def status(self) -> TaskStatus:
        """Current job status."""
        return self._status

    @property
    def result(self) -> Optional[Dict[K2, V2]]:
        """Final reduced results, or None if the job has not completed."""
        return self._result

    @property
    def metrics(self) -> 'JobMetrics':
        """Aggregate job metrics."""
        return self._metrics


@dataclass
class JobMetrics:
    """Comprehensive job metrics: timestamps, record counts, and phase timings."""
    job_id: str
    start_time: Optional[float] = None   # epoch seconds when the job started
    end_time: Optional[float] = None     # epoch seconds when the job finished
    input_records: int = 0               # records fed to the map phase
    map_output_records: int = 0          # intermediate (key, value) pairs emitted
    output_records: int = 0              # final records produced by reduce
    map_duration: float = 0.0
    shuffle_duration: float = 0.0
    reduce_duration: float = 0.0
    error: Optional[str] = None          # failure message, if the job failed

    @property
    def total_duration(self) -> float:
        """Total wall-clock duration, or 0.0 if the job has not finished.

        Uses explicit ``is not None`` checks: a timestamp of exactly 0.0 is
        falsy but still a valid time, so truthiness tests would be wrong.
        """
        if self.start_time is not None and self.end_time is not None:
            return self.end_time - self.start_time
        return 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize metrics (including the derived total) to a plain dict."""
        return {
            "job_id": self.job_id,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "total_duration": self.total_duration,
            "input_records": self.input_records,
            "map_output_records": self.map_output_records,
            "output_records": self.output_records,
            "map_duration": self.map_duration,
            "shuffle_duration": self.shuffle_duration,
            "reduce_duration": self.reduce_duration,
            "error": self.error
        }


class MapReduceEngine:
    """
    Top-level orchestrator that ties all MapReduce components together.

    Responsibilities:
    - Expose a high-level API for building and running jobs
    - Enforce resource limits (worker and concurrent-job caps)
    - Track job lifecycle
    - Aggregate metrics across jobs
    """

    def __init__(
        self,
        max_concurrent_jobs: int = 4,
        max_workers: int = 8,
        enable_scheduler: bool = True
    ):
        """
        Set up the engine and (optionally) start its background scheduler.

        Args:
            max_concurrent_jobs: Upper bound on jobs running at once
            max_workers: Upper bound on worker threads/processes
            enable_scheduler: Start the background scheduler immediately
        """
        self.max_concurrent_jobs = max_concurrent_jobs
        self.max_workers = max_workers
        self._jobs: Dict[str, MapReduceJob] = {}
        self._enable_scheduler = enable_scheduler
        self._scheduler = JobScheduler(
            max_concurrent_jobs=max_concurrent_jobs,
            max_workers_per_job=max_workers
        )
        if enable_scheduler:
            self._scheduler.start()

    def create_job(
        self,
        map_function: Callable[[K, V], Iterable[Tuple[K2, V2]]],
        reduce_function: Callable[[K2, Iterable[V2]], V2],
        combine_function: Optional[Callable[[K2, List[V2]], V2]] = None,
        job_name: str = "unnamed_job",
        num_mappers: int = 4,
        num_reducers: int = 2,
        partition_strategy: PartitionStrategy = PartitionStrategy.HASH,
        use_combiner: bool = True,
        **kwargs
    ) -> MapReduceJob:
        """
        Build and register a new MapReduce job.

        Mapper/reducer counts are clamped to the engine's worker cap, and
        the combiner is only enabled when a combine function was supplied.

        Args:
            map_function: Map function
            reduce_function: Reduce function
            combine_function: Optional combiner function
            job_name: Human-readable job name
            num_mappers: Requested number of map tasks
            num_reducers: Requested number of reduce tasks
            partition_strategy: Data partitioning strategy
            use_combiner: Whether to enable the combiner
            **kwargs: Extra JobConfig parameters, passed through verbatim

        Returns:
            The registered MapReduceJob
        """
        cfg = JobConfig(
            job_name=job_name,
            # Never request more parallel tasks than the engine allows.
            num_mappers=min(num_mappers, self.max_workers),
            num_reducers=min(num_reducers, self.max_workers),
            partition_strategy=partition_strategy,
            # A combiner flag without a combine function would be a no-op.
            use_combiner=use_combiner and (combine_function is not None),
            **kwargs
        )
        new_job = MapReduceJob(
            config=cfg,
            map_function=map_function,
            reduce_function=reduce_function,
            combine_function=combine_function
        )
        self._jobs[cfg.job_id] = new_job
        return new_job

    def submit(
        self,
        job: MapReduceJob,
        data: List[Tuple[K, V]],
        priority: int = 0,
        async_execution: bool = False
    ) -> Union[Dict[K2, V2], str]:
        """
        Execute a job immediately, or hand it to the scheduler.

        Args:
            job: MapReduce job to execute
            data: Input data
            priority: Scheduling priority (async mode only)
            async_execution: If True, queue the job and return its job_id
                instead of blocking for results

        Returns:
            Results dictionary (sync) or job_id string (async)
        """
        if not async_execution:
            return job.run(data)
        # Stash the input on the job so the scheduler can run it later.
        job._input_data = data
        return self._scheduler.submit(job, priority)

    def run(
        self,
        data: List[Tuple[K, V]],
        map_function: Callable[[K, V], Iterable[Tuple[K2, V2]]],
        reduce_function: Callable[[K2, Iterable[V2]], V2],
        combine_function: Optional[Callable[[K2, List[V2]], V2]] = None,
        **kwargs
    ) -> Dict[K2, V2]:
        """
        One-shot helper: create a job and run it synchronously.

        Args:
            data: Input data
            map_function: Map function
            reduce_function: Reduce function
            combine_function: Optional combiner
            **kwargs: JobConfig parameters

        Returns:
            Results dictionary
        """
        return self.create_job(
            map_function=map_function,
            reduce_function=reduce_function,
            combine_function=combine_function,
            **kwargs
        ).run(data)

    def get_job(self, job_id: str) -> Optional[MapReduceJob]:
        """Look up a registered job by its ID, or None if unknown."""
        return self._jobs.get(job_id)

    def get_job_status(self, job_id: str) -> Optional[TaskStatus]:
        """Ask the scheduler for a job's current status."""
        return self._scheduler.get_job_status(job_id)

    def get_metrics(self) -> Dict[str, Any]:
        """Aggregate engine configuration, scheduler, and per-job metrics."""
        per_job = {jid: j.metrics.to_dict() for jid, j in self._jobs.items()}
        return {
            "engine": {
                "max_concurrent_jobs": self.max_concurrent_jobs,
                "max_workers": self.max_workers,
                "total_jobs": len(self._jobs)
            },
            "scheduler": self._scheduler.get_metrics(),
            "jobs": per_job
        }

    def shutdown(self, wait: bool = True) -> None:
        """Stop the scheduler (if one was started) and log completion.

        NOTE(review): the `wait` parameter is currently unused — confirm
        whether it should be forwarded to the scheduler's stop().
        """
        if self._enable_scheduler:
            self._scheduler.stop()
        logger.info("MapReduce engine shutdown complete")


# ============================================================================
# EXAMPLE IMPLEMENTATIONS AND UTILITIES
# ============================================================================

def word_count_mapper(key: int, value: str) -> Iterable[Tuple[str, int]]:
    """Emit (word, 1) for every alphanumeric-stripped word in *value*."""
    for token in value.lower().split():
        # Drop punctuation: keep only alphanumeric characters.
        stripped = ''.join(ch for ch in token if ch.isalnum())
        if stripped:
            yield (stripped, 1)


def word_count_reducer(key: str, values: Iterable[int]) -> int:
    """Sum all partial counts for *key*."""
    total = 0
    for count in values:
        total += count
    return total


def word_count_combiner(key: str, values: List[int]) -> int:
    """Locally combine counts (same as the reducer — addition is associative)."""
    return reduce(lambda acc, n: acc + n, values, 0)


def inverted_index_mapper(doc_id: str, content: str) -> Iterable[Tuple[str, str]]:
    """Emit (word, doc_id) once per distinct raw token in *content*."""
    # Deduplicate raw tokens first; punctuation is stripped afterwards.
    for token in set(content.lower().split()):
        cleaned = ''.join(ch for ch in token if ch.isalnum())
        if cleaned:
            yield (cleaned, doc_id)


def inverted_index_reducer(word: str, doc_ids: Iterable[str]) -> List[str]:
    """Return the deduplicated, sorted posting list for *word*."""
    unique_docs = set()
    for doc in doc_ids:
        unique_docs.add(doc)
    return sorted(unique_docs)


def average_mapper(key: str, value: float) -> Iterable[Tuple[str, Tuple[float, int]]]:
    """Emit (key, (value, 1)) so sums and counts can be aggregated downstream."""
    pair = (value, 1)
    yield (key, pair)


def average_reducer(key: str, values: Iterable[Tuple[float, int]]) -> float:
    """Compute the mean from (partial_sum, count) pairs; 0.0 when count is 0."""
    # Materialize once: *values* may be a one-shot iterator.
    pairs = list(values)
    grand_total = sum(s for s, _ in pairs)
    n = sum(c for _, c in pairs)
    return grand_total / n if n > 0 else 0.0


def average_combiner(key: str, values: List[Tuple[float, int]]) -> Tuple[float, int]:
    """Fold (partial_sum, count) pairs into a single pair for the reduce phase."""
    running_sum = 0  # int start mirrors sum()'s behavior on an empty list
    running_count = 0
    for partial_sum, count in values:
        running_sum += partial_sum
        running_count += count
    return (running_sum, running_count)


# ============================================================================
# MAIN EXECUTION AND TESTING
# ============================================================================

if __name__ == "__main__":
    # Demo / smoke-test harness exercising the engine end to end.
    # NOTE(review): Partitioner and PartitionStrategy.CONSISTENT_HASH/CUSTOM
    # are defined earlier in this file — confirm they remain in sync with
    # the usages below.
    print("=" * 70)
    print("VAST_05 MapReduce Distributed Computation Engine")
    print("=" * 70)

    # Create engine
    engine = MapReduceEngine(
        max_concurrent_jobs=2,
        max_workers=4,
        enable_scheduler=False  # Disable for synchronous testing
    )

    # Test 1: Word Count
    print("\n--- Test 1: Word Count ---")
    documents = [
        (1, "Hello world hello"),
        (2, "World of MapReduce"),
        (3, "Hello MapReduce world"),
        (4, "Distributed computing with MapReduce"),
        (5, "Hello distributed world")
    ]

    word_counts = engine.run(
        data=documents,
        map_function=word_count_mapper,
        reduce_function=word_count_reducer,
        combine_function=word_count_combiner,
        job_name="word_count",
        num_mappers=2,
        num_reducers=2
    )

    print("Word Counts:")
    # Top 10 words by descending count
    for word, count in sorted(word_counts.items(), key=lambda x: -x[1])[:10]:
        print(f"  {word}: {count}")

    # Test 2: Inverted Index
    print("\n--- Test 2: Inverted Index ---")
    docs = [
        ("doc1", "python programming language"),
        ("doc2", "java programming language"),
        ("doc3", "python machine learning"),
        ("doc4", "java enterprise programming")
    ]

    inverted_index = engine.run(
        data=docs,
        map_function=inverted_index_mapper,
        reduce_function=inverted_index_reducer,
        job_name="inverted_index",
        num_mappers=2,
        num_reducers=2,
        use_combiner=False
    )

    print("Inverted Index:")
    for word, doc_ids in sorted(inverted_index.items()):
        print(f"  {word}: {doc_ids}")

    # Test 3: Computing Averages
    print("\n--- Test 3: Computing Averages ---")
    grades = [
        ("math", 85.0),
        ("math", 90.0),
        ("math", 78.0),
        ("science", 92.0),
        ("science", 88.0),
        ("english", 75.0),
        ("english", 82.0),
        ("english", 79.0)
    ]

    averages = engine.run(
        data=grades,
        map_function=average_mapper,
        reduce_function=average_reducer,
        combine_function=average_combiner,
        job_name="grade_averages",
        num_mappers=2,
        num_reducers=2
    )

    print("Subject Averages:")
    for subject, avg in sorted(averages.items()):
        print(f"  {subject}: {avg:.2f}")

    # Test 4: Large Scale Test
    print("\n--- Test 4: Large Scale Test (10,000 records) ---")
    # Synthetic corpus: 3 random words per record from a 100-word vocabulary
    large_data = [
        (i, f"word{random.randint(1, 100)} word{random.randint(1, 100)} word{random.randint(1, 100)}")
        for i in range(10000)
    ]

    start_time = time.time()
    large_counts = engine.run(
        data=large_data,
        map_function=word_count_mapper,
        reduce_function=word_count_reducer,
        combine_function=word_count_combiner,
        job_name="large_word_count",
        num_mappers=4,
        num_reducers=2
    )
    duration = time.time() - start_time

    print(f"Processed 10,000 records in {duration:.2f}s")
    print(f"Unique words: {len(large_counts)}")
    print("Top 5 words:")
    for word, count in sorted(large_counts.items(), key=lambda x: -x[1])[:5]:
        print(f"  {word}: {count}")

    # Test 5: Partitioner Strategies
    print("\n--- Test 5: Partitioner Strategies ---")

    # Hash partitioner
    hash_partitioner = Partitioner(num_partitions=4, strategy=PartitionStrategy.HASH)
    test_keys = ["apple", "banana", "cherry", "date", "elderberry"]
    print("Hash Partitioning:")
    for key in test_keys:
        print(f"  {key} -> partition {hash_partitioner.get_partition(key)}")

    # Consistent hash partitioner
    consistent_partitioner = Partitioner(num_partitions=4, strategy=PartitionStrategy.CONSISTENT_HASH)
    print("\nConsistent Hash Partitioning:")
    for key in test_keys:
        print(f"  {key} -> partition {consistent_partitioner.get_partition(key)}")

    # Test 6: Custom Partitioner
    print("\n--- Test 6: Custom Partitioner ---")

    def first_letter_partitioner(key: str, num_partitions: int) -> int:
        """Partition by first letter (lowercased ordinal mod partition count)."""
        return ord(key[0].lower()) % num_partitions

    custom_partitioner = Partitioner(
        num_partitions=4,
        strategy=PartitionStrategy.CUSTOM,
        custom_function=first_letter_partitioner
    )

    print("Custom (First Letter) Partitioning:")
    for key in test_keys:
        print(f"  {key} -> partition {custom_partitioner.get_partition(key)}")

    # Display engine metrics
    print("\n--- Engine Metrics ---")
    metrics = engine.get_metrics()
    print(json.dumps(metrics["engine"], indent=2))

    print("\n--- Job Metrics Summary ---")
    for job_id, job_metrics in metrics["jobs"].items():
        print(f"\nJob: {job_id[:20]}...")
        print(f"  Duration: {job_metrics['total_duration']:.2f}s")
        print(f"  Input: {job_metrics['input_records']} -> Output: {job_metrics['output_records']}")

    # Cleanup
    engine.shutdown()

    print("\n" + "=" * 70)
    print("MapReduce Engine Tests Complete")
    print("=" * 70)
