"""
PM-003: Test Execution Framework
Run tests and capture results for validation in Genesis.

Acceptance Criteria:
- [x] GIVEN task with test_command WHEN run THEN tests executed
- [x] AND captures pass/fail/error counts
- [x] AND supports pytest, unittest, npm test
- [x] AND timeout: 5 minutes max

Dependencies: None
"""

import json
import logging
import os
import re
import signal
import subprocess
import time
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Module logger; this module only configures logging itself in the __main__ demo.
logger = logging.getLogger(__name__)


class TestFramework(Enum):
    """
    Supported test frameworks.

    A member's ``value`` is the string recorded in ``TestResult.framework``.
    ``UNKNOWN`` marks commands that match none of the known frameworks.
    """

    # The name starts with "Test": stop pytest from trying to collect this class
    # when it is imported into a test module. Dunder names never become members.
    __test__ = False

    PYTEST = "pytest"
    UNITTEST = "unittest"
    NPM = "npm"
    JEST = "jest"
    MOCHA = "mocha"
    GO = "go"
    CARGO = "cargo"
    UNKNOWN = "unknown"


@dataclass
class TestResult:
    """
    Result of a test execution.

    Count fields hold whatever the framework-specific parser could extract
    from the output; anything it could not find stays 0. ``timestamp`` is a
    naive ISO-8601 string in UTC, taken when the object is created.
    """

    # The name starts with "Test": stop pytest from trying to collect this class.
    # It has no annotation, so it is a plain class attribute, not a dataclass field.
    __test__ = False

    framework: str  # TestFramework value, e.g. "pytest"
    command: str  # Shell command that was executed
    exit_code: int  # Process exit code; TestRunner uses -1 for timeouts/launch failures
    passed: int = 0
    failed: int = 0
    errors: int = 0
    skipped: int = 0
    total: int = 0
    duration_seconds: float = 0.0  # Wall-clock duration of the run
    stdout: str = ""  # Captured stdout (TestRunner may truncate it)
    stderr: str = ""  # Captured stderr (TestRunner may truncate it)
    success: bool = False  # TestRunner sets this to exit_code == 0
    error_message: Optional[str] = None  # Why the run failed to complete, if it did
    # datetime.utcnow() is deprecated (3.12+); this yields the same naive UTC string.
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    )
    test_details: List[Dict[str, Any]] = field(default_factory=list)  # Optional per-test records

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (recursively, via ``dataclasses.asdict``)."""
        return asdict(self)

    @property
    def failure_rate(self) -> float:
        """
        Percentage of tests that failed or errored.

        Computed as ``(failed + errors) / total * 100``; 0.0 when ``total`` is 0.
        """
        if self.total == 0:
            return 0.0
        return ((self.failed + self.errors) / self.total) * 100


class TestRunner:
    """
    Execute tests and capture results for validation.

    Features:
    - Supports pytest, unittest, npm/yarn/pnpm test, jest, mocha, go test,
      cargo test (anything else falls back to generic count parsing)
    - 5-minute timeout by default; on POSIX a timed-out run is killed together
      with every process it spawned, and the output produced so far is kept
    - Captures pass/fail/error/skip counts, parsed from the complete output
    - Stores at most MAX_OUTPUT_LENGTH characters each of stdout and stderr on
      the result (head and tail are kept, since summaries are printed last)
    - Extracts failure messages heuristically (extract_failure_messages)

    Commands are executed through the system shell, so they must come from a
    trusted source.
    """

    # The name starts with "Test": stop pytest from trying to collect this class.
    __test__ = False

    DEFAULT_TIMEOUT = 300  # 5 minutes
    MAX_OUTPUT_LENGTH = 50000  # Truncate large outputs
    KILL_GRACE_SECONDS = 5  # How long to let pipes drain after killing a timed-out run

    _TRUNCATION_MARKER = "\n... [output truncated] ...\n"

    # Checked in order against the lower-cased command; the first match wins.
    # Word boundaries keep e.g. "cargo test" from being detected as "go test".
    _FRAMEWORK_PATTERNS = (
        (TestFramework.PYTEST, re.compile(r"py\.?test")),
        (TestFramework.UNITTEST, re.compile(r"unittest")),
        (TestFramework.NPM, re.compile(r"\b(?:npm|yarn|pnpm)\s+(?:run\s+)?test")),
        (TestFramework.JEST, re.compile(r"jest")),
        (TestFramework.MOCHA, re.compile(r"mocha")),
        (TestFramework.GO, re.compile(r"\bgo\s+test")),
        (TestFramework.CARGO, re.compile(r"\bcargo\s+test")),
    )

    # Framework -> name of its parsing method; anything else uses _parse_generic.
    _PARSER_NAMES = {
        TestFramework.PYTEST: "_parse_pytest",
        TestFramework.UNITTEST: "_parse_unittest",
        TestFramework.NPM: "_parse_jest",
        TestFramework.JEST: "_parse_jest",
        TestFramework.MOCHA: "_parse_mocha",
        TestFramework.GO: "_parse_go",
        TestFramework.CARGO: "_parse_cargo",
    }

    # pytest's final line, e.g. "==== 1 failed, 5 passed in 0.31s ====".
    _PYTEST_SUMMARY = re.compile(r"^=+ (.*\bin [\d.]+s\b.*?) =+[ \t]*$", re.MULTILINE)
    _PYTEST_COUNTS = (
        (re.compile(r"(\d+)\s+passed"), "passed"),
        (re.compile(r"(\d+)\s+failed"), "failed"),
        (re.compile(r"(\d+)\s+error"), "errors"),
        (re.compile(r"(\d+)\s+skipped"), "skipped"),
    )

    _UNITTEST_RAN = re.compile(r"Ran\s+(\d+)\s+test")
    # "FAILED (failures=2, errors=1, skipped=3)"; the [(,] anchor skips "expected failures=".
    _UNITTEST_COUNTS = (
        (re.compile(r"[(,]\s*failures=(\d+)"), "failed"),
        (re.compile(r"[(,]\s*errors=(\d+)"), "errors"),
        (re.compile(r"[(,]\s*skipped=(\d+)"), "skipped"),
    )

    # "Tests:       1 failed, 2 skipped, 5 passed, 8 total" (Jest prints failed first).
    _JEST_TESTS_LINE = re.compile(r"^[ \t]*Tests:[ \t]*(.*)$", re.MULTILINE)
    _JEST_COUNT = re.compile(r"(\d+)\s+(passed|failed|skipped|pending|todo|total)\b")

    _MOCHA_COUNTS = (
        (re.compile(r"(\d+)\s+passing"), "passed"),
        (re.compile(r"(\d+)\s+failing"), "failed"),
        (re.compile(r"(\d+)\s+pending"), "skipped"),
    )

    _GO_PASS = re.compile(r"---\s+PASS:")
    _GO_FAIL = re.compile(r"---\s+FAIL:")
    _GO_SKIP = re.compile(r"---\s+SKIP:")
    _GO_PACKAGE_OK = re.compile(r"^ok\s+", re.MULTILINE)
    _GO_PACKAGE_BROKEN = re.compile(r"^FAIL\s+\S+.*\[(?:build|setup) failed\]", re.MULTILINE)

    # One line per test binary: "test result: ok. 5 passed; 0 failed; 1 ignored; ..."
    _CARGO_RESULT = re.compile(
        r"test result:.*?(\d+)\s+passed;\s*(\d+)\s+failed;\s*(\d+)\s+ignored"
    )

    _GENERIC_COUNTS = (
        (re.compile(r"(\d+)\s+(?:tests?\s+)?pass(?:ed|ing)?", re.IGNORECASE), "passed"),
        (re.compile(r"(\d+)\s+(?:tests?\s+)?fail(?:ed|ure|ing)?", re.IGNORECASE), "failed"),
        (re.compile(r"(\d+)\s+(?:tests?\s+)?(?:error|err)", re.IGNORECASE), "errors"),
        (re.compile(r"(\d+)\s+(?:tests?\s+)?(?:skip|pending|ignored)", re.IGNORECASE), "skipped"),
    )

    # Heuristic failure markers, searched case-insensitively in this order.
    _FAILURE_PATTERNS = (
        re.compile(r"FAILED\s+(.+)", re.IGNORECASE),
        re.compile(r"AssertionError:\s*(.+)", re.IGNORECASE),
        re.compile(r"Error:\s*(.+)", re.IGNORECASE),
        re.compile(r"failure\s*:\s*(.+)", re.IGNORECASE),
    )

    def __init__(self, working_dir: Optional[str] = None):
        """
        Initialize TestRunner.

        Args:
            working_dir: Working directory for test execution. Defaults to cwd
                (captured once, at construction time).
        """
        self.working_dir = working_dir or os.getcwd()

    def detect_framework(self, command: str) -> TestFramework:
        """
        Detect test framework from command.

        Args:
            command: Test command string (matched case-insensitively)

        Returns:
            Detected TestFramework, or TestFramework.UNKNOWN
        """
        command_lower = command.lower()
        for framework, pattern in self._FRAMEWORK_PATTERNS:
            if pattern.search(command_lower):
                return framework
        return TestFramework.UNKNOWN

    def run(self,
            command: str,
            timeout: Optional[int] = None,
            env: Optional[Dict[str, str]] = None,
            working_dir: Optional[str] = None) -> TestResult:
        """
        Run test command and capture results.

        Args:
            command: Test command to execute (through the shell)
            timeout: Timeout in seconds; None or a non-positive value means
                DEFAULT_TIMEOUT (5 minutes)
            env: Environment variables to add on top of os.environ
                (CI is always forced to "true")
            working_dir: Working directory (overrides instance default)

        Returns:
            TestResult with execution details. Execution problems do not raise:
            a timeout yields exit_code -1 plus the partial output, and a launch
            failure yields exit_code -1; both set success=False and error_message.
        """
        if timeout is None or timeout <= 0:
            timeout = self.DEFAULT_TIMEOUT
        working_dir = working_dir or self.working_dir
        framework = self.detect_framework(command)

        logger.info("Running tests: %s (framework: %s, timeout: %ss)",
                    command, framework.value, timeout)

        # Prepare environment
        run_env = os.environ.copy()
        if env:
            run_env.update(env)

        # Add CI=true for consistent output in some frameworks
        run_env["CI"] = "true"

        start_time = time.monotonic()  # monotonic: immune to wall-clock jumps
        try:
            stdout, stderr, exit_code, timed_out = self._execute(
                command, working_dir, run_env, timeout)
        except Exception as e:  # boundary: report launch failures as a result
            duration = time.monotonic() - start_time
            logger.exception("Test execution failed: %s", e)
            return TestResult(
                framework=framework.value,
                command=command,
                exit_code=-1,
                duration_seconds=duration,
                success=False,
                error_message=str(e),
            )
        duration = time.monotonic() - start_time

        result = self._parse_results(
            framework=framework,
            command=command,
            exit_code=exit_code,
            stdout=stdout,
            stderr=stderr,
            duration=duration,
        )
        if timed_out:
            logger.error("Test execution timed out after %ss", timeout)
            result.error_message = f"Test execution timed out after {timeout} seconds"
        return result

    def _execute(self,
                 command: str,
                 cwd: str,
                 env: Dict[str, str],
                 timeout: float) -> Tuple[str, str, int, bool]:
        """
        Run ``command`` through the shell and collect its output.

        Returns:
            ``(stdout, stderr, exit_code, timed_out)``. On timeout the process
            (on POSIX: its whole process group) is killed, exit_code is -1 and
            the output captured up to that point is returned.

        Raises:
            OSError if the process cannot be started (e.g. ``cwd`` is missing).
        """
        popen_kwargs: Dict[str, Any] = {}
        if os.name == "posix":
            # New session = own process group, so a timeout can kill the test
            # processes the shell spawned (npm -> node, xdist workers, ...).
            popen_kwargs["start_new_session"] = True

        with subprocess.Popen(
            command,
            shell=True,
            cwd=cwd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            errors="replace",  # one undecodable byte must not abort the whole run
            **popen_kwargs,
        ) as process:
            try:
                stdout, stderr = process.communicate(timeout=timeout)
            except subprocess.TimeoutExpired:
                self._kill_process_tree(process)
                try:
                    # Calling communicate() again after TimeoutExpired returns
                    # everything captured so far.
                    stdout, stderr = process.communicate(timeout=self.KILL_GRACE_SECONDS)
                except subprocess.TimeoutExpired:
                    # Something still holds the pipes open; give up on the output.
                    stdout, stderr = "", ""
                return stdout or "", stderr or "", -1, True
            except BaseException:
                # e.g. KeyboardInterrupt: the child runs in its own session and
                # won't receive the terminal's SIGINT, so don't leave it running.
                self._kill_process_tree(process)
                raise
            return stdout or "", stderr or "", process.returncode, False

    @staticmethod
    def _kill_process_tree(process: subprocess.Popen) -> None:
        """Kill ``process`` and, on POSIX, every process in its process group."""
        if os.name == "posix":
            try:
                # start_new_session=True made the child a group leader (pgid == pid).
                os.killpg(process.pid, signal.SIGKILL)
                return
            except OSError:
                pass  # group already gone or not signalable; kill the direct child
        try:
            process.kill()
        except OSError:
            pass  # already exited

    @classmethod
    def _truncate_output(cls, text: Optional[str]) -> str:
        """
        Clip ``text`` to at most MAX_OUTPUT_LENGTH characters.

        Keeps the head (startup/collection errors) and the tail (summaries and
        failure lists) around a truncation marker.
        """
        if not text:
            return ""
        limit = cls.MAX_OUTPUT_LENGTH
        if len(text) <= limit:
            return text
        marker = cls._TRUNCATION_MARKER
        if limit <= len(marker):
            return text[len(text) - limit:]  # too small for a marker; keep the tail
        keep = limit - len(marker)
        head = keep // 2
        tail = keep - head
        return text[:head] + marker + text[len(text) - tail:]

    @staticmethod
    def _has_counts(result: TestResult) -> bool:
        """True if any count field on ``result`` is non-zero."""
        return any((result.passed, result.failed, result.errors,
                    result.skipped, result.total))

    def _parse_results(self,
                       framework: TestFramework,
                       command: str,
                       exit_code: int,
                       stdout: str,
                       stderr: str,
                       duration: float) -> TestResult:
        """
        Build a TestResult from raw process output.

        Counts are parsed from the full output (runners print their summary
        last, exactly the part a head-only truncation would drop). Only the
        stdout/stderr copies stored on the result are truncated. ``success``
        is ``exit_code == 0``; ``total`` falls back to the sum of the parsed
        counts when no explicit total was found.
        """
        stdout = stdout or ""
        stderr = stderr or ""
        # Separate the streams so stdout's last line can't fuse with stderr's first.
        combined_output = f"{stdout}\n{stderr}"

        result = TestResult(
            framework=framework.value,
            command=command,
            exit_code=exit_code,
            stdout=self._truncate_output(stdout),
            stderr=self._truncate_output(stderr),
            duration_seconds=duration,
            success=exit_code == 0,
        )

        parser = getattr(self, self._PARSER_NAMES.get(framework, "_parse_generic"))
        parser(result, combined_output)

        # `npm test` can wrap any runner (Jest, Mocha, Vitest, ...); when no Jest
        # summary was found, fall back to the generic heuristics.
        if framework == TestFramework.NPM and not self._has_counts(result):
            self._parse_generic(result, combined_output)

        # Calculate total if not set
        if result.total == 0:
            result.total = result.passed + result.failed + result.errors + result.skipped

        return result

    def _parse_pytest(self, result: TestResult, output: str) -> None:
        """
        Parse pytest counts, e.g. ``== 1 failed, 5 passed, 1 error in 0.3s ==``.

        Reads the last ``=== ... in N.NNs ===`` summary line when present (so
        numbers printed by the tests themselves are ignored); otherwise (e.g.
        ``-q`` mode) uses the last occurrence of each count in the output.
        """
        summaries = self._PYTEST_SUMMARY.findall(output)
        text = summaries[-1] if summaries else output
        for pattern, attr in self._PYTEST_COUNTS:
            counts = pattern.findall(text)
            if counts:
                setattr(result, attr, int(counts[-1]))

    def _parse_unittest(self, result: TestResult, output: str) -> None:
        """
        Parse unittest output: ``Ran 10 tests in 0.5s`` followed by ``OK`` or
        ``FAILED (failures=2, errors=1, skipped=1)``.

        Uses the last "Ran N tests" line and the status that follows it.
        ``passed`` is total minus failures, errors and skips (never negative).
        """
        runs = list(self._UNITTEST_RAN.finditer(output))
        if runs:
            result.total = int(runs[-1].group(1))
            output = output[runs[-1].end():]  # the status line follows "Ran N tests"

        for pattern, attr in self._UNITTEST_COUNTS:
            match = pattern.search(output)
            if match:
                setattr(result, attr, int(match.group(1)))

        if result.total > 0:
            result.passed = max(
                result.total - result.failed - result.errors - result.skipped, 0)

    def _parse_jest(self, result: TestResult, output: str) -> None:
        """
        Parse the Jest ``Tests:`` summary line in any field order, e.g.
        ``Tests:       2 failed, 1 skipped, 5 passed, 8 total``.

        ``skipped``, ``pending`` and ``todo`` are all counted as skipped.
        """
        lines = self._JEST_TESTS_LINE.findall(output)
        if not lines:
            return
        skipped = 0
        for count, label in self._JEST_COUNT.findall(lines[-1]):
            value = int(count)
            if label == "passed":
                result.passed = value
            elif label == "failed":
                result.failed = value
            elif label == "total":
                result.total = value
            else:
                skipped += value
        result.skipped = skipped

    def _parse_mocha(self, result: TestResult, output: str) -> None:
        """Parse Mocha's footer: ``5 passing (500ms)``, ``1 pending``, ``2 failing``."""
        for pattern, attr in self._MOCHA_COUNTS:
            match = pattern.search(output)
            if match:
                setattr(result, attr, int(match.group(1)))

    def _parse_go(self, result: TestResult, output: str) -> None:
        """
        Parse Go test output.

        Counts ``--- PASS/FAIL/SKIP:`` lines (``go test -v``, including
        subtests). Without ``-v`` passing tests aren't listed, so passing
        packages (``ok  <pkg>``) are counted instead. Packages that failed to
        build or set up are counted as errors.
        """
        result.passed = len(self._GO_PASS.findall(output))
        result.failed = len(self._GO_FAIL.findall(output))
        result.skipped = len(self._GO_SKIP.findall(output))
        if result.passed == 0:
            result.passed = len(self._GO_PACKAGE_OK.findall(output))
        result.errors = len(self._GO_PACKAGE_BROKEN.findall(output))

    def _parse_cargo(self, result: TestResult, output: str) -> None:
        """
        Parse Cargo test output, summing every ``test result:`` line.

        Cargo prints one summary per test binary (unit, integration, doc tests).
        Ignored tests are counted as skipped.
        """
        summaries = self._CARGO_RESULT.findall(output)
        if summaries:
            result.passed = sum(int(passed) for passed, _, _ in summaries)
            result.failed = sum(int(failed) for _, failed, _ in summaries)
            result.skipped = sum(int(ignored) for _, _, ignored in summaries)

    def _parse_generic(self, result: TestResult, output: str) -> None:
        """Generic parsing for unknown frameworks: first "<N> passed/failed/..." hit per count."""
        for pattern, attr in self._GENERIC_COUNTS:
            match = pattern.search(output)
            if match:
                setattr(result, attr, int(match.group(1)))

    def run_pytest(self,
                   test_path: str = ".",
                   args: str = "",
                   **kwargs) -> TestResult:
        """Run ``python -m pytest <test_path> <args> -v``; kwargs go to run()."""
        command = f"python -m pytest {test_path} {args} -v"
        return self.run(command, **kwargs)

    def run_unittest(self,
                     test_module: str,
                     args: str = "",
                     **kwargs) -> TestResult:
        """Run ``python -m unittest <test_module> <args>``; kwargs go to run()."""
        command = f"python -m unittest {test_module} {args}"
        return self.run(command, **kwargs)

    def run_npm_test(self, **kwargs) -> TestResult:
        """Run ``npm test``; kwargs go to run()."""
        return self.run("npm test", **kwargs)

    def extract_failure_messages(self, result: TestResult, limit: int = 20) -> List[str]:
        """
        Extract failure messages from test output.

        Heuristic: collects the text following ``FAILED``, ``AssertionError:``,
        ``Error:`` and ``failure:`` markers (case-insensitive), in that pattern
        order, stripped and de-duplicated. Only the output stored on ``result``
        (possibly truncated) is searched.

        Args:
            result: TestResult to analyze
            limit: Maximum number of messages to return (default 20)

        Returns:
            List of failure message strings
        """
        combined = f"{result.stdout}\n{result.stderr}"
        seen = set()
        failures: List[str] = []
        for pattern in self._FAILURE_PATTERNS:
            for raw in pattern.findall(combined):
                message = raw.strip()
                if message and message not in seen:
                    seen.add(message)
                    failures.append(message)
        return failures[:limit]


# Process-wide TestRunner, created lazily by get_test_runner() (no locking).
_test_runner: Optional[TestRunner] = None


def get_test_runner() -> TestRunner:
    """Return the shared TestRunner, building one with default settings on first call."""
    global _test_runner
    if _test_runner is not None:
        return _test_runner
    _test_runner = TestRunner()
    return _test_runner


if __name__ == "__main__":
    # Smoke test: run a command that only prints a summary-style line and show
    # what the runner extracted from it.
    logging.basicConfig(level=logging.INFO)

    demo_runner = TestRunner()
    demo_command = "python -c \"print('Tests: 5 passed, 2 failed, 7 total')\""
    demo = demo_runner.run(demo_command)

    report_lines = (
        f"Framework: {demo.framework}",
        f"Success: {demo.success}",
        f"Passed: {demo.passed}, Failed: {demo.failed}, Total: {demo.total}",
        f"Duration: {demo.duration_seconds:.2f}s",
    )
    for line in report_lines:
        print(line)
