"""
Genesis Memory Hardening Integration Tests
==========================================
Test suite for the hardened Genesis memory system.

Tests cover:
1. Security: Secrets management
2. Resilience: Circuit breaker, retry logic
3. Intelligence: Enhanced surprise detection
4. Performance: Atomic I/O (adaptive TTL is not yet covered by this suite)
5. Observability: Structured logging, metrics

Run with: python -m pytest tests/test_memory_hardening.py -v
"""

import json
import os
import sys
import tempfile
import time
import threading
from pathlib import Path
from datetime import datetime
from unittest.mock import MagicMock, patch

# Add core to path
sys.path.insert(0, str(Path(__file__).parent.parent / "core"))

import pytest


# =============================================================================
# TEST 1: Security - Secrets Management
# =============================================================================

class TestSecretsManagement:
    """Test secure credential loading."""

    def test_secrets_loader_import(self):
        """Secrets loader module can be imported."""
        from secrets_loader import get_redis_config, get_qdrant_config
        assert callable(get_redis_config)
        assert callable(get_qdrant_config)

    def test_redis_config_defaults(self):
        """Redis config has sensible defaults."""
        from secrets_loader import get_redis_config
        config = get_redis_config()
        assert config.host == "localhost"
        assert config.port == 6379
        assert isinstance(config.ssl, bool)

    def test_redis_config_from_env(self):
        """Redis config reads from environment."""
        from secrets_loader import get_redis_config

        with patch.dict(os.environ, {
            'GENESIS_REDIS_HOST': 'test-host',
            'GENESIS_REDIS_PORT': '6380'
        }):
            config = get_redis_config()
            assert config.host == 'test-host'
            assert config.port == 6380

    def test_qdrant_config_structure(self):
        """Qdrant config has required fields."""
        from secrets_loader import get_qdrant_config
        config = get_qdrant_config()
        assert hasattr(config, 'host')
        assert hasattr(config, 'port')
        assert hasattr(config, 'api_key')

    def test_no_hardcoded_passwords(self):
        """Verify no hardcoded passwords in core files."""
        core_dir = Path(__file__).parent.parent / "core"
        password_patterns = ['password=', 'api_key=', 'secret=']

        for py_file in core_dir.glob("*.py"):
            if py_file.name in ['secrets_loader.py']:
                continue  # Skip the secrets loader itself

            content = py_file.read_text()
            for pattern in password_patterns:
                # Check for hardcoded values (not environment lookups)
                if f'{pattern}"' in content or f"{pattern}'" in content:
                    # Make sure it's not just a variable assignment from env
                    lines = content.split('\n')
                    for line in lines:
                        if pattern in line and 'os.environ' not in line and '_get_env' not in line:
                            if '= "' in line or "= '" in line:
                                # Allow empty defaults and placeholders
                                if '""' not in line and "''" not in line:
                                    pytest.fail(f"Possible hardcoded secret in {py_file.name}: {line[:80]}")


# =============================================================================
# TEST 2: Resilience - Circuit Breaker
# =============================================================================

class TestCircuitBreaker:
    """Test circuit breaker pattern."""

    def test_circuit_breaker_import(self):
        """Circuit breaker can be imported."""
        from circuit_breaker import CircuitBreaker, CircuitState
        assert CircuitBreaker is not None
        assert CircuitState is not None

    def test_circuit_starts_closed(self):
        """Circuit breaker starts in closed state."""
        from circuit_breaker import CircuitBreaker, CircuitState
        cb = CircuitBreaker("test", failure_threshold=3)
        assert cb.state == CircuitState.CLOSED
        assert cb.is_available

    def test_circuit_opens_after_failures(self):
        """Circuit breaker opens after threshold failures."""
        from circuit_breaker import CircuitBreaker, CircuitState
        cb = CircuitBreaker("test", failure_threshold=3, recovery_timeout=60)

        for i in range(3):
            cb.record_failure(Exception(f"fail {i}"))

        assert cb.state == CircuitState.OPEN
        assert not cb.is_available

    def test_circuit_half_open_after_timeout(self):
        """Circuit breaker transitions to half-open after timeout."""
        from circuit_breaker import CircuitBreaker, CircuitState
        cb = CircuitBreaker("test", failure_threshold=2, recovery_timeout=0.1)

        # Trigger open state
        cb.record_failure(Exception("fail 1"))
        cb.record_failure(Exception("fail 2"))
        assert cb.state == CircuitState.OPEN

        # Wait for recovery timeout
        time.sleep(0.15)

        # Check availability (triggers half-open check)
        assert cb.is_available
        assert cb.state == CircuitState.HALF_OPEN

    def test_circuit_closes_after_success(self):
        """Circuit breaker closes after success in half-open state."""
        from circuit_breaker import CircuitBreaker, CircuitState
        cb = CircuitBreaker("test", failure_threshold=2, recovery_timeout=0.1, half_open_requests=1)

        # Open the circuit
        cb.record_failure(Exception("fail 1"))
        cb.record_failure(Exception("fail 2"))

        # Wait and transition to half-open
        time.sleep(0.15)
        _ = cb.is_available  # Trigger check

        # Record success
        cb.record_success()
        assert cb.state == CircuitState.CLOSED

    def test_circuit_breaker_registry(self):
        """Circuit breaker registry manages instances."""
        from circuit_breaker import get_circuit_breaker
        cb1 = get_circuit_breaker("service1")
        cb2 = get_circuit_breaker("service1")
        cb3 = get_circuit_breaker("service2")

        assert cb1 is cb2  # Same instance
        assert cb1 is not cb3  # Different instance


# =============================================================================
# TEST 3: Resilience - Retry Logic
# =============================================================================

class TestRetryLogic:
    """Test retry with exponential backoff."""

    def test_retry_utils_import(self):
        """Retry utils can be imported."""
        from retry_utils import retry, retry_call, calculate_delay
        assert callable(retry)
        assert callable(retry_call)

    def test_calculate_delay_exponential(self):
        """Delay grows exponentially."""
        from retry_utils import calculate_delay

        # Disable jitter for predictable testing
        d0 = calculate_delay(0, base_delay=1.0, jitter=False)
        d1 = calculate_delay(1, base_delay=1.0, jitter=False)
        d2 = calculate_delay(2, base_delay=1.0, jitter=False)

        assert d0 == 1.0
        assert d1 == 2.0
        assert d2 == 4.0

    def test_calculate_delay_max_cap(self):
        """Delay is capped at max."""
        from retry_utils import calculate_delay
        d = calculate_delay(10, base_delay=1.0, max_delay=10.0, jitter=False)
        assert d == 10.0

    def test_retry_call_success(self):
        """Retry call succeeds on first try."""
        from retry_utils import retry_call

        call_count = 0
        def succeed():
            nonlocal call_count
            call_count += 1
            return "ok"

        result = retry_call(succeed, max_attempts=3)
        assert result == "ok"
        assert call_count == 1

    def test_retry_call_eventual_success(self):
        """Retry call succeeds after failures."""
        from retry_utils import retry_call, RetryConfig

        call_count = 0
        def fail_then_succeed():
            nonlocal call_count
            call_count += 1
            if call_count < 3:
                raise ValueError(f"Attempt {call_count}")
            return "ok"

        config = RetryConfig(max_attempts=5, base_delay=0.01)
        result = retry_call(fail_then_succeed, config=config)
        assert result == "ok"
        assert call_count == 3

    def test_retry_decorator(self):
        """Retry decorator works."""
        from retry_utils import retry, RetryConfig

        call_count = 0

        @retry(config=RetryConfig(max_attempts=3, base_delay=0.01))
        def flaky_function():
            nonlocal call_count
            call_count += 1
            if call_count < 2:
                raise RuntimeError("Flaky!")
            return "success"

        result = flaky_function()
        assert result == "success"
        assert call_count == 2


# =============================================================================
# TEST 4: Intelligence - Enhanced Surprise Detection
# =============================================================================

class TestEnhancedSurprise:
    """Test embedding-based novelty detection."""

    def test_surprise_detector_import(self):
        """Enhanced surprise can be imported."""
        from enhanced_surprise import EnhancedSurpriseDetector, SurpriseScore
        assert EnhancedSurpriseDetector is not None
        assert SurpriseScore is not None

    def test_surprise_detector_creation(self):
        """Surprise detector can be created."""
        from enhanced_surprise import EnhancedSurpriseDetector
        with tempfile.TemporaryDirectory() as tmpdir:
            detector = EnhancedSurpriseDetector(
                memory_path=f"{tmpdir}/test_vectors.json"
            )
            assert detector is not None
            assert detector.vector_size > 0

    def test_surprise_evaluation(self):
        """Surprise evaluation returns valid scores."""
        from enhanced_surprise import EnhancedSurpriseDetector
        with tempfile.TemporaryDirectory() as tmpdir:
            detector = EnhancedSurpriseDetector(
                memory_path=f"{tmpdir}/test_vectors.json"
            )

            result = detector.evaluate(
                "This is a test message about technology",
                "test",
                "tech"
            )

            assert "score" in result
            assert "tier" in result
            assert 0 <= result["score"]["novelty"] <= 1
            assert 0 <= result["score"]["total"] <= 1
            assert result["tier"] in ["working", "episodic", "semantic"]

    def test_surprise_novelty_detection(self):
        """Novel content gets higher scores than similar content."""
        from enhanced_surprise import EnhancedSurpriseDetector
        with tempfile.TemporaryDirectory() as tmpdir:
            detector = EnhancedSurpriseDetector(
                memory_path=f"{tmpdir}/test_vectors.json"
            )

            # Add baseline memory
            detector.add_to_memory(
                "Python is a programming language used for web development",
                "test",
                "tech"
            )

            # Similar content
            similar_result = detector.evaluate(
                "Python is used for building web applications",
                "test",
                "tech"
            )

            # Novel content
            novel_result = detector.evaluate(
                "Quantum entanglement enables faster-than-light communication",
                "test",
                "physics"
            )

            # Novel content should have higher novelty score
            assert novel_result["score"]["novelty"] > similar_result["score"]["novelty"]

    def test_memory_system_compatibility(self):
        """MemorySystem wrapper maintains backward compatibility."""
        from enhanced_surprise import MemorySystem
        with tempfile.TemporaryDirectory() as tmpdir:
            ms = MemorySystem(persistence_path=f"{tmpdir}/test.json")
            result = ms.evaluate("test content", "source", "domain")
            assert "score" in result


# =============================================================================
# TEST 5: Performance - Atomic I/O
# =============================================================================

class TestAtomicIO:
    """Test atomic file operations."""

    def test_atomic_io_import(self):
        """Atomic I/O can be imported."""
        from atomic_io import atomic_write, atomic_json_write, safe_read
        assert callable(atomic_write)
        assert callable(atomic_json_write)
        assert callable(safe_read)

    def test_atomic_write_creates_file(self):
        """Atomic write creates file."""
        from atomic_io import atomic_write, safe_read
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "test.txt"
            success = atomic_write(path, "Hello, World!")
            assert success
            assert path.exists()
            assert safe_read(path) == "Hello, World!"

    def test_atomic_write_creates_backup(self):
        """Atomic write creates backup of existing file."""
        from atomic_io import atomic_write
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "test.txt"

            # Create initial file
            atomic_write(path, "Original", backup=True)

            # Overwrite
            atomic_write(path, "Updated", backup=True)

            # Check backup
            backup_path = path.with_suffix(".txt.bak")
            assert backup_path.exists()
            assert backup_path.read_text() == "Original"

    def test_atomic_json_write(self):
        """Atomic JSON write works."""
        from atomic_io import atomic_json_write, safe_json_read
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "test.json"
            data = {"key": "value", "number": 42}

            success = atomic_json_write(path, data)
            assert success

            loaded = safe_json_read(path)
            assert loaded == data

    def test_atomic_update(self):
        """Atomic update modifies file atomically."""
        from atomic_io import atomic_json_write, atomic_update
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "test.json"

            # Create initial
            atomic_json_write(path, {"count": 0})

            # Update
            def increment(data):
                data["count"] += 1
                return data

            success = atomic_update(path, increment)
            assert success

            # Verify
            import json
            data = json.loads(path.read_text())
            assert data["count"] == 1

    def test_atomic_file_context_manager(self):
        """AtomicFile context manager works."""
        from atomic_io import AtomicFile
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "test.txt"

            with AtomicFile(path) as f:
                f.write("Context manager test")

            assert path.exists()
            assert path.read_text() == "Context manager test"

    def test_concurrent_atomic_writes(self):
        """Concurrent atomic writes don't corrupt file."""
        from atomic_io import atomic_write, safe_read
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "concurrent.txt"

            errors = []
            def writer(n):
                try:
                    for i in range(10):
                        atomic_write(path, f"Writer {n} iteration {i}", backup=False)
                        time.sleep(0.001)
                except Exception as e:
                    errors.append(e)

            threads = [threading.Thread(target=writer, args=(i,)) for i in range(3)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            assert not errors, f"Errors during concurrent writes: {errors}"
            # File should exist and be readable
            content = safe_read(path)
            assert content is not None
            assert "Writer" in content


# =============================================================================
# TEST 6: Observability - Logging
# =============================================================================

class TestLogging:
    """Test structured logging."""

    def test_logging_import(self):
        """Logging config can be imported."""
        from logging_config import get_logger, with_context, JSONFormatter
        assert callable(get_logger)
        assert callable(with_context)

    def test_json_formatter(self):
        """JSON formatter produces valid JSON."""
        import logging
        from logging_config import JSONFormatter

        formatter = JSONFormatter()
        record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="test.py",
            lineno=1,
            msg="Test message",
            args=(),
            exc_info=None
        )

        output = formatter.format(record)
        parsed = json.loads(output)

        assert parsed["level"] == "INFO"
        assert parsed["logger"] == "test"
        assert parsed["message"] == "Test message"
        assert "timestamp" in parsed

    def test_context_manager(self):
        """Logging context manager works."""
        from logging_config import with_context

        with with_context(correlation_id="test-123", component="TestComponent"):
            # Context should be set
            from logging_config import _correlation_id, _component
            assert _correlation_id.get() == "test-123"
            assert _component.get() == "TestComponent"

        # Context should be cleared
        from logging_config import _correlation_id, _component
        assert _correlation_id.get() is None
        assert _component.get() is None


# =============================================================================
# TEST 7: Observability - Metrics
# =============================================================================

class TestMetrics:
    """Test metrics collection."""

    def test_metrics_import(self):
        """Metrics can be imported."""
        from metrics import Counter, Gauge, Histogram, GenesisMetrics
        assert Counter is not None
        assert Gauge is not None
        assert GenesisMetrics is not None

    def test_counter(self):
        """Counter increments correctly."""
        from metrics import Counter
        c = Counter("test_counter", "Test counter")
        c.inc()
        c.inc(5)
        assert c.get() == 6

    def test_counter_with_labels(self):
        """Counter works with labels."""
        from metrics import Counter
        c = Counter("test_counter", "Test")
        c.inc(labels={"method": "GET"})
        c.inc(labels={"method": "POST"})
        c.inc(labels={"method": "GET"})

        assert c.get(labels={"method": "GET"}) == 2
        assert c.get(labels={"method": "POST"}) == 1

    def test_gauge(self):
        """Gauge sets and changes value."""
        from metrics import Gauge
        g = Gauge("test_gauge", "Test gauge")
        g.set(10)
        assert g.get() == 10
        g.inc(5)
        assert g.get() == 15
        g.dec(3)
        assert g.get() == 12

    def test_histogram(self):
        """Histogram records observations."""
        from metrics import Histogram
        h = Histogram("test_histogram", "Test", buckets=[0.1, 0.5, 1.0])
        h.observe(0.05)
        h.observe(0.3)
        h.observe(0.8)

        stats = h.get_stats()
        assert stats["count"] == 3
        assert stats["buckets"][0.1] == 1  # 0.05 <= 0.1
        assert stats["buckets"][0.5] == 2  # 0.05, 0.3 <= 0.5
        assert stats["buckets"][1.0] == 3  # all <= 1.0

    def test_genesis_metrics_snapshot(self):
        """GenesisMetrics snapshot works."""
        from metrics import GenesisMetrics
        snapshot = GenesisMetrics.snapshot()
        assert "timestamp" in snapshot
        assert "metrics" in snapshot
        assert "genesis_memory_operations_total" in snapshot["metrics"]


# =============================================================================
# RUN TESTS
# =============================================================================

if __name__ == "__main__":
    # Delegate to pytest; if it cannot be imported, fall back to a tiny
    # in-process runner that instantiates each suite and calls its test_* methods.
    # NOTE(review): this module also runs `import pytest` unconditionally at the
    # top of the file, so when pytest is missing the import fails there and
    # the ImportError fallback below is never reached.
    try:
        import pytest
        sys.exit(pytest.main([__file__, "-v", "--tb=short"]))
    except ImportError:
        print("pytest not available, running basic tests...")

        suites = (
            TestSecretsManagement,
            TestCircuitBreaker,
            TestRetryLogic,
            TestEnhancedSurprise,
            TestAtomicIO,
            TestLogging,
            TestMetrics,
        )

        n_ok = n_bad = 0

        for suite in suites:
            print(f"\n=== {suite.__name__} ===")
            runner = suite()
            # dir() returns names sorted alphabetically, fixing the run order.
            test_names = [attr for attr in dir(runner) if attr.startswith("test_")]
            for attr in test_names:
                try:
                    getattr(runner, attr)()
                    print(f"  [OK] {attr}")
                    n_ok += 1
                except Exception as exc:
                    print(f"  [FAIL] {attr}: {exc}")
                    n_bad += 1

        print(f"\n{'='*40}")
        print(f"Passed: {n_ok}, Failed: {n_bad}")
        sys.exit(1 if n_bad else 0)
