#!/usr/bin/env python3
"""
Test suite for PostToolUse formatter hooks.
"""
import json
import os
import shutil
import subprocess
import sys
import tempfile
import unittest
from pathlib import Path

# Add parent directory to path
SCRIPTS_DIR = Path(__file__).parent
SKILL_DIR = SCRIPTS_DIR.parent


class TestAuditLogger(unittest.TestCase):
    """Tests for audit-logger.sh.

    Each test redirects HOME to a throwaway directory so the logger's
    output (~/.claude/audit/tool_usage.jsonl) never touches the real home.
    """

    def setUp(self):
        """Point HOME at a fresh temporary directory."""
        self.temp_dir = tempfile.mkdtemp()
        self.original_home = os.environ.get('HOME')
        os.environ['HOME'] = self.temp_dir

    def tearDown(self):
        """Restore HOME exactly as it was, then remove the temp dir.

        Uses an ``is not None`` check (not truthiness) so an empty-string
        HOME is restored correctly, and removes the variable entirely when
        it was unset before the test — otherwise HOME would keep pointing
        at the deleted temp dir and leak into subsequent tests.
        """
        if self.original_home is not None:
            os.environ['HOME'] = self.original_home
        else:
            os.environ.pop('HOME', None)
        import shutil
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_audit_logger_creates_log(self):
        """Test that audit logger creates JSONL log file."""
        input_data = {
            "session_id": "test-session-123",
            "tool_name": "Bash",
            "tool_input": {"command": "ls -la"},
            "tool_result": {"exitCode": 0}
        }

        script_path = SCRIPTS_DIR / "audit-logger.sh"
        result = subprocess.run(
            ["bash", str(script_path)],
            input=json.dumps(input_data),
            capture_output=True,
            text=True
        )

        self.assertEqual(result.returncode, 0, result.stderr)

        # Check log file was created under the redirected HOME
        log_file = Path(self.temp_dir) / ".claude" / "audit" / "tool_usage.jsonl"
        self.assertTrue(log_file.exists(), f"Log file should exist at {log_file}")

        # Parse log entry (single JSONL line expected)
        with open(log_file) as f:
            log_entry = json.loads(f.read().strip())

        self.assertEqual(log_entry["session"], "test-session-123")
        self.assertEqual(log_entry["tool"], "Bash")
        self.assertEqual(log_entry["exit_code"], 0)

    def test_audit_logger_appends(self):
        """Test that audit logger appends to existing log."""
        script_path = SCRIPTS_DIR / "audit-logger.sh"

        # First entry
        input1 = {
            "session_id": "session-1",
            "tool_name": "Bash",
            "tool_input": {"command": "echo 1"},
            "tool_result": {"exitCode": 0}
        }
        result1 = subprocess.run(
            ["bash", str(script_path)],
            input=json.dumps(input1),
            capture_output=True,
            text=True
        )
        # Fail here (not at the line-count assert) if the hook itself breaks
        self.assertEqual(result1.returncode, 0, result1.stderr)

        # Second entry
        input2 = {
            "session_id": "session-2",
            "tool_name": "Read",
            "tool_input": {"file_path": "/test"},
            "tool_result": {"exitCode": 0}
        }
        result2 = subprocess.run(
            ["bash", str(script_path)],
            input=json.dumps(input2),
            capture_output=True,
            text=True
        )
        self.assertEqual(result2.returncode, 0, result2.stderr)

        log_file = Path(self.temp_dir) / ".claude" / "audit" / "tool_usage.jsonl"
        with open(log_file) as f:
            lines = f.readlines()

        self.assertEqual(len(lines), 2, "Should have 2 log entries")


class TestHighlightErrors(unittest.TestCase):
    """Tests for highlight-errors.sh"""

    def _run_hook(self, payload):
        """Pipe *payload* (dict) as JSON into highlight-errors.sh and return the result."""
        hook = SCRIPTS_DIR / "highlight-errors.sh"
        return subprocess.run(
            ["bash", str(hook)],
            input=json.dumps(payload),
            capture_output=True,
            text=True
        )

    def test_no_modification_on_success(self):
        """Test that successful commands pass through unchanged."""
        completed = self._run_hook({
            "tool_name": "Bash",
            "tool_input": {"command": "echo hello"},
            "tool_result": {
                "stdout": "hello",
                "stderr": "",
                "exitCode": 0
            }
        })

        self.assertEqual(completed.returncode, 0)
        # No stdout means the hook passed the result through untouched
        self.assertEqual(completed.stdout.strip(), "")

    def test_formats_on_error(self):
        """Test that errors get formatted prominently."""
        completed = self._run_hook({
            "tool_name": "Bash",
            "tool_input": {"command": "false"},
            "tool_result": {
                "stdout": "",
                "stderr": "Command failed",
                "exitCode": 1
            }
        })

        self.assertEqual(completed.returncode, 0)

        # Failing commands must produce a JSON payload with the decorated result
        parsed = json.loads(completed.stdout)
        self.assertIn("hookSpecificOutput", parsed)
        modified = parsed["hookSpecificOutput"]["modifiedResult"]
        self.assertIn("⚠️", modified)
        self.assertIn("Exit: 1", modified)

    def test_formats_on_stderr_only(self):
        """Test that stderr presence triggers formatting even with exit 0."""
        completed = self._run_hook({
            "tool_name": "Bash",
            "tool_input": {"command": "some-command"},
            "tool_result": {
                "stdout": "output",
                "stderr": "warning message",
                "exitCode": 0
            }
        })

        self.assertEqual(completed.returncode, 0)
        parsed = json.loads(completed.stdout)
        self.assertIn("warning message", parsed["hookSpecificOutput"]["modifiedResult"])

class TestFormatPython(unittest.TestCase):
    """Tests for format.py"""

    def _run_formatter(self, stdin_text):
        """Feed *stdin_text* to format.py under the current interpreter; return the result."""
        formatter = SCRIPTS_DIR / "format.py"
        return subprocess.run(
            [sys.executable, str(formatter)],
            input=stdin_text,
            capture_output=True,
            text=True
        )

    def test_formats_bash_output(self):
        """Test Bash output formatting."""
        payload = {
            "tool_name": "Bash",
            "tool_input": {"command": "ls"},
            "tool_result": {
                "stdout": "file1\nfile2",
                "exitCode": 0
            }
        }

        completed = self._run_formatter(json.dumps(payload))
        self.assertEqual(completed.returncode, 0)

        decorated = json.loads(completed.stdout)["hookSpecificOutput"]["modifiedResult"]
        for expected in ("Command Output", "Exit Code:** 0", "file1"):
            self.assertIn(expected, decorated)

    def test_passes_through_unknown_tools(self):
        """Test that unknown tools pass through unchanged."""
        payload = {
            "tool_name": "UnknownTool",
            "tool_input": {},
            "tool_result": {"data": "test"}
        }

        completed = self._run_formatter(json.dumps(payload))
        self.assertEqual(completed.returncode, 0)
        # Empty stdout signals an unmodified pass-through
        self.assertEqual(completed.stdout.strip(), "")

    def test_handles_malformed_json(self):
        """Test graceful handling of malformed JSON."""
        completed = self._run_formatter("not valid json")
        # Should exit 0 (pass through on error)
        self.assertEqual(completed.returncode, 0)

    def test_truncates_long_output(self):
        """Test that very long outputs get truncated."""
        # 200 lines of stdout should exceed the formatter's display limit
        payload = {
            "tool_name": "Bash",
            "tool_input": {"command": "cat bigfile"},
            "tool_result": {
                "stdout": "\n".join(f"line {n}" for n in range(200)),
                "exitCode": 0
            }
        }

        completed = self._run_formatter(json.dumps(payload))
        self.assertEqual(completed.returncode, 0)

        decorated = json.loads(completed.stdout)["hookSpecificOutput"]["modifiedResult"]
        self.assertIn("more lines", decorated)


class TestSkillStructure(unittest.TestCase):
    """Tests for skill file structure."""

    def test_skill_md_exists(self):
        """Test SKILL.md exists and has required fields."""
        manifest = SKILL_DIR / "SKILL.md"
        self.assertTrue(manifest.exists())

        text = manifest.read_text()
        for required in ("name: posttooluse-formatter", "description:", "allowed-tools:"):
            self.assertIn(required, text)

    def test_all_scripts_executable(self):
        """Test all scripts have execute permission."""
        for shell_script in SCRIPTS_DIR.glob("*.sh"):
            executable = os.access(shell_script, os.X_OK)
            self.assertTrue(executable, f"{shell_script.name} should be executable")

    def test_scripts_have_shebangs(self):
        """Test all scripts have proper shebangs."""
        # Shell scripts: first line must invoke bash
        for shell_script in SCRIPTS_DIR.glob("*.sh"):
            opening = shell_script.read_text().splitlines()[0] if shell_script.read_text() else ""
            self.assertTrue(
                opening.startswith("#!/bin/bash"),
                f"{shell_script.name} should have bash shebang"
            )

        # Python scripts (excluding the test modules themselves): env-based shebang
        for py_script in SCRIPTS_DIR.glob("*.py"):
            if py_script.name.startswith("test_"):
                continue
            opening = py_script.read_text().splitlines()[0] if py_script.read_text() else ""
            self.assertTrue(
                opening.startswith("#!/usr/bin/env python"),
                f"{py_script.name} should have python shebang"
            )


# Run the full suite with per-test verbose output when executed directly.
if __name__ == "__main__":
    unittest.main(verbosity=2)
