import multiprocessing
import os
import time
from typing import Any, Callable, Dict, List, Optional

# Assuming rwl_runner.py exists and contains the run_prd function
from core.rwl_runner import run_prd
from core.report_merger import merge_reports


class ParallelRWLExecutor:
    """
    Executor that runs multiple PRDs (Process Requirement Documents) in parallel
    using a multiprocessing worker pool.

    Execution is best-effort: a PRD whose worker raises is reported and skipped,
    and the remaining PRDs still contribute to the merged report.

    Attributes:
        max_parallel (int): The maximum number of PRDs to run in parallel. Defaults to 4.
    """

    def __init__(self, max_parallel: int = 4):
        """
        Initializes the ParallelRWLExecutor with a maximum number of parallel processes.

        Args:
            max_parallel (int): The maximum number of PRDs to run concurrently. Defaults to 4.

        Raises:
            ValueError: If max_parallel is not a positive integer.
        """
        if not isinstance(max_parallel, int) or max_parallel <= 0:
            raise ValueError("max_parallel must be a positive integer.")
        self.max_parallel = max_parallel

    def execute(self, prd_paths: List[str], progress_callback: Optional[Callable[[str], None]] = None) -> Dict[str, Any]:
        """
        Executes a list of PRDs in parallel, respecting the maximum parallel limit.

        Args:
            prd_paths (List[str]): A list of paths to the PRD files.
            progress_callback (Optional[Callable[[str], None]]): A callback invoked with
                each PRD path as it is submitted to the pool. Defaults to None.

        Returns:
            Dict[str, Any]: A merged report aggregating the results of all PRD
            executions that succeeded, or an empty dict if none succeeded.

        Raises:
            TypeError: If prd_paths is not a list of strings.
            ValueError: If any path in prd_paths does not exist.
        """
        if not isinstance(prd_paths, list):
            raise TypeError("prd_paths must be a list.")
        if not all(isinstance(path, str) for path in prd_paths):
            raise TypeError("All elements in prd_paths must be strings.")
        if not all(os.path.exists(path) for path in prd_paths):
            raise ValueError("All paths in prd_paths must exist.")

        # Keep (path, AsyncResult) pairs so results can be collected in
        # submission order and failures can be attributed to a specific PRD
        # (the original callback-based collection was completion-ordered and
        # lost track of which PRD had failed).
        pending = []
        with multiprocessing.Pool(processes=self.max_parallel) as pool:
            for prd_path in prd_paths:
                pending.append((prd_path, pool.apply_async(run_prd, args=(prd_path,))))
                if progress_callback:
                    progress_callback(prd_path)

            pool.close()  # No more tasks will be submitted to the pool
            pool.join()   # Wait for all tasks to complete

        # All tasks are finished, so each .get() returns immediately; a task
        # that raised re-raises here and is handled per-PRD (best-effort).
        results = []
        for prd_path, async_result in pending:
            try:
                results.append(async_result.get())
            except Exception as e:
                print(f"Error in worker process: {e}")

        if not results:
            print("No results received from PRD executions.")
            return {}  # Return an empty dictionary if no results are available.

        # Merge the reports from each PRD execution
        return merge_reports(results)


if __name__ == "__main__":
    # Example usage: create two throwaway PRD files, run them through the
    # executor, and print the merged report plus total wall-clock time.
    import tempfile

    def progress_report(prd_path: str) -> None:
        """Progress hook: announce each PRD as it is submitted."""
        print(f"PRD processing started: {prd_path}")

    # tempfile gives portable scratch paths (a hard-coded "/tmp/..." does not
    # exist on Windows) and the directory removes itself, replacing the manual
    # os.remove() cleanup.
    with tempfile.TemporaryDirectory() as tmp_dir:
        prd_file1 = os.path.join(tmp_dir, "prd1.txt")
        prd_file2 = os.path.join(tmp_dir, "prd2.txt")

        with open(prd_file1, "w") as f:
            f.write("Some content for PRD 1")

        with open(prd_file2, "w") as f:
            f.write("Some content for PRD 2")

        try:
            prd_paths = [prd_file1, prd_file2]  # Replace with your actual PRD file paths
            executor = ParallelRWLExecutor(max_parallel=2)
            start_time = time.time()
            final_report = executor.execute(prd_paths, progress_callback=progress_report)
            execution_time = time.time() - start_time

            print("Final Merged Report:")
            print(final_report)
            print(f"Total execution time: {execution_time:.2f} seconds")

        except Exception as e:
            # Demo script: report the failure and exit cleanly rather than
            # dumping a traceback at the user.
            print(f"An error occurred: {e}")