#!/usr/bin/env python3
"""
Tests for Supadata Quota Tracker

STORY-005: Enhance Supadata integration with quota tracking.

Covers unit behavior (with a mocked PostgreSQL connection) and integration
behavior (against a real PostgreSQL database; skipped when it is unreachable).
"""

import os
import sys
import json
import unittest
from datetime import datetime, timezone
from unittest.mock import patch, MagicMock

# Make the project-local modules importable (absolute paths into the repo checkout)
sys.path.insert(0, '/mnt/e/genesis-system/core/youtube')
sys.path.insert(0, '/mnt/e/genesis-system/data/genesis-memory')

from supadata_tracker import (
    SupadataQuotaTracker,
    QuotaExhaustedError,
    fetch_transcript_with_quota
)


class TestSupadataQuotaTrackerUnit(unittest.TestCase):
    """Unit tests with a mocked PostgreSQL connection.

    ``psycopg2.connect`` is patched inside ``supadata_tracker`` so no database
    is touched. Each test programs ``cursor().fetchone()`` with the row(s) the
    tracker's queries should observe. The monthly limit is 100 calls (asserted
    in ``test_check_quota_returns_correct_structure``).
    """

    def setUp(self):
        """Build a fake connection whose ``cursor()`` returns one shared mock cursor."""
        self.mock_conn = MagicMock()
        self.mock_cursor = MagicMock()
        self.mock_conn.cursor.return_value = self.mock_cursor

    @patch('supadata_tracker.psycopg2.connect')
    def test_check_quota_returns_correct_structure(self, mock_connect):
        """Test that check_quota returns all expected fields."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.return_value = (50,)  # 50 calls used

        tracker = SupadataQuotaTracker()
        quota = tracker.check_quota()

        self.assertIn('limit', quota)
        self.assertIn('used', quota)
        self.assertIn('remaining', quota)
        self.assertIn('month', quota)
        self.assertIn('reset_date', quota)
        self.assertIn('percentage_used', quota)

        self.assertEqual(quota['limit'], 100)
        self.assertEqual(quota['used'], 50)
        self.assertEqual(quota['remaining'], 50)
        self.assertEqual(quota['percentage_used'], 50.0)

    @patch('supadata_tracker.psycopg2.connect')
    def test_is_quota_available_true(self, mock_connect):
        """Test is_quota_available returns True when quota remains."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.return_value = (50,)

        tracker = SupadataQuotaTracker()
        self.assertTrue(tracker.is_quota_available())

    @patch('supadata_tracker.psycopg2.connect')
    def test_is_quota_available_false(self, mock_connect):
        """Test is_quota_available returns False when quota exhausted."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.return_value = (100,)

        tracker = SupadataQuotaTracker()
        self.assertFalse(tracker.is_quota_available())

    @patch('supadata_tracker.psycopg2.connect')
    def test_alert_if_low_returns_warning_at_threshold(self, mock_connect):
        """Test alert_if_low returns warning when quota is low."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.return_value = (95,)  # 5 remaining

        tracker = SupadataQuotaTracker()
        warning = tracker.alert_if_low(threshold=10)

        self.assertIsNotNone(warning)
        self.assertEqual(warning['level'], 'CRITICAL')
        self.assertEqual(warning['remaining'], 5)

    @patch('supadata_tracker.psycopg2.connect')
    def test_alert_if_low_returns_none_when_sufficient(self, mock_connect):
        """Test alert_if_low returns None when quota is sufficient."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.return_value = (50,)  # 50 remaining

        tracker = SupadataQuotaTracker()
        warning = tracker.alert_if_low(threshold=10)

        self.assertIsNone(warning)

    @patch('supadata_tracker.psycopg2.connect')
    def test_record_usage_increments_counter(self, mock_connect):
        """Test that record_usage records the call and reports the new remaining count."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.side_effect = [
            (1, datetime.now(timezone.utc)),  # INSERT RETURNING
            (51,)  # COUNT for check_quota
        ]

        tracker = SupadataQuotaTracker()
        result = tracker.record_usage("test_video_123")

        self.assertEqual(result['video_id'], "test_video_123")
        self.assertEqual(result['record_id'], 1)
        self.assertIn('remaining', result)
        # The post-insert COUNT of 51 against the 100-call limit must be
        # reflected; otherwise the "increment" is never actually verified.
        self.assertEqual(result['remaining'], 49)

    @patch('supadata_tracker.psycopg2.connect')
    def test_month_key_format(self, mock_connect):
        """Test that month key is in YYYY-MM format."""
        mock_connect.return_value = self.mock_conn
        self.mock_cursor.fetchone.return_value = (0,)

        tracker = SupadataQuotaTracker()
        quota = tracker.check_quota()

        # Should match YYYY-MM pattern (assertRegex compiles the pattern itself).
        self.assertRegex(quota['month'], r'^\d{4}-\d{2}$')


class TestFetchTranscriptWithQuota(unittest.TestCase):
    """Tests for the ``fetch_transcript_with_quota`` wrapper.

    ``SupadataQuotaTracker`` is patched inside ``supadata_tracker`` so the
    wrapper receives a fully mocked tracker; the transcript client is a
    ``MagicMock`` whose ``extract_video_id`` always yields ``"abc123"``.
    """

    @patch('supadata_tracker.SupadataQuotaTracker')
    def test_raises_quota_exhausted_error(self, MockTracker):
        """QuotaExhaustedError is raised, and no API call made, when quota is exhausted."""
        mock_tracker = MagicMock()
        mock_tracker.is_quota_available.return_value = False
        mock_tracker.check_quota.return_value = {
            'used': 100,
            'limit': 100,
            'reset_date': '2026-03-01'
        }
        mock_tracker.check_video_already_fetched.return_value = None
        MockTracker.return_value = mock_tracker

        mock_client = MagicMock()
        mock_client.extract_video_id.return_value = "abc123"

        with self.assertRaises(QuotaExhaustedError):
            fetch_transcript_with_quota(mock_client, "https://youtube.com/watch?v=abc123")

        # The quota guard exists to avoid spending a paid call.
        mock_client.get_transcript.assert_not_called()

    @patch('supadata_tracker.SupadataQuotaTracker')
    def test_returns_cached_if_already_fetched(self, MockTracker):
        """Test that cached response is returned for already-fetched videos."""
        mock_tracker = MagicMock()
        mock_tracker.check_video_already_fetched.return_value = {
            'video_id': 'abc123',
            'created_at': '2026-01-15T10:00:00',
            'status': 'success'
        }
        MockTracker.return_value = mock_tracker

        mock_client = MagicMock()
        mock_client.extract_video_id.return_value = "abc123"

        result = fetch_transcript_with_quota(mock_client, "abc123")

        self.assertEqual(result['status'], 'cached')
        self.assertEqual(result['video_id'], 'abc123')
        mock_client.get_transcript.assert_not_called()

    @patch('supadata_tracker.SupadataQuotaTracker')
    def test_records_successful_fetch(self, MockTracker):
        """Test that successful fetches are recorded with video_id and status kwargs."""
        mock_tracker = MagicMock()
        mock_tracker.is_quota_available.return_value = True
        mock_tracker.check_video_already_fetched.return_value = None
        mock_tracker.alert_if_low.return_value = None
        MockTracker.return_value = mock_tracker

        mock_client = MagicMock()
        mock_client.extract_video_id.return_value = "abc123"
        mock_client.get_transcript.return_value = {
            'content': [{'text': 'Hello'}],
            'lang': 'en'
        }

        fetch_transcript_with_quota(mock_client, "abc123")

        mock_tracker.record_usage.assert_called_once()
        # call_args unpacks as (positional_args, keyword_args).
        _, kwargs = mock_tracker.record_usage.call_args
        self.assertEqual(kwargs['video_id'], 'abc123')
        self.assertEqual(kwargs['status'], 'success')


class TestSupadataQuotaTrackerIntegration(unittest.TestCase):
    """Integration tests with real PostgreSQL connection.

    These tests require a live PostgreSQL connection to Elestio. The whole
    class is skipped if the tracker cannot be constructed or cannot run a
    quota query.

    NOTE(review): ``test_integration_record_and_check`` inserts a real usage
    row that is never removed, and the test expects that row to be counted in
    the month's ``used`` total. Every run therefore adds to live usage. It can
    also fail spuriously if another process records usage between its two
    ``check_quota`` calls.
    """

    @classmethod
    def setUpClass(cls):
        """Connect once for the class; skip every test if PostgreSQL is unavailable."""
        try:
            cls.tracker = SupadataQuotaTracker()
            # A real query round-trip proves the connection is usable.
            cls.tracker.check_quota()
        except Exception as e:  # deliberate: any setup failure means "environment unavailable"
            raise unittest.SkipTest(f"PostgreSQL unavailable: {e}") from e

    def test_integration_check_quota(self):
        """Integration test: check_quota connects and returns data."""
        quota = self.tracker.check_quota()

        self.assertIsInstance(quota['limit'], int)
        self.assertIsInstance(quota['used'], int)
        self.assertIsInstance(quota['remaining'], int)
        self.assertGreaterEqual(quota['remaining'], 0)
        self.assertLessEqual(quota['used'], quota['limit'] + 1000)  # Allow for test runs

    def test_integration_record_and_check(self):
        """Integration test: record usage and verify it's counted."""
        # Generate unique test video ID (second resolution)
        test_video_id = f"TEST_{datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')}"

        # Get initial quota
        initial = self.tracker.check_quota()

        # Record usage (writes a permanent row; see class NOTE)
        result = self.tracker.record_usage(
            video_id=test_video_id,
            status="test",
            metadata={"test": True}
        )

        self.assertEqual(result['video_id'], test_video_id)

        # Check quota increased
        after = self.tracker.check_quota()
        self.assertEqual(after['used'], initial['used'] + 1)

    def test_integration_get_usage_stats(self):
        """Integration test: get_usage_stats returns valid structure."""
        stats = self.tracker.get_usage_stats()

        self.assertIn('current_month', stats)
        self.assertIn('monthly_breakdown', stats)
        self.assertIn('all_time', stats)
        self.assertIn('recent_videos', stats)
        self.assertIn('generated_at', stats)

    def test_integration_check_video_already_fetched(self):
        """Integration test: check if video was already fetched."""
        # Use a unique test ID that shouldn't exist
        result = self.tracker.check_video_already_fetched("NONEXISTENT_VIDEO_XYZ")
        self.assertIsNone(result)

    def test_integration_alert_threshold(self):
        """Integration test: alert_if_low with various thresholds."""
        quota = self.tracker.check_quota()

        # Test with threshold higher than remaining (should warn)
        high_threshold = quota['remaining'] + 10
        warning = self.tracker.alert_if_low(threshold=high_threshold)
        self.assertIsNotNone(warning)

        # Test with threshold of 0 (should only warn if exactly 0 remaining)
        if quota['remaining'] > 0:
            warning = self.tracker.alert_if_low(threshold=0)
            self.assertIsNone(warning)


class TestEdgeCases(unittest.TestCase):
    """Edge cases: connection failures and usage counts at or beyond the limit."""

    @staticmethod
    def _tracker_reporting(mock_connect, used_calls):
        """Wire ``mock_connect`` to a fake DB whose COUNT row is ``used_calls``; return a tracker."""
        cursor = MagicMock()
        cursor.fetchone.return_value = (used_calls,)
        connection = MagicMock()
        connection.cursor.return_value = cursor
        mock_connect.return_value = connection
        return SupadataQuotaTracker()

    @patch('supadata_tracker.psycopg2.connect')
    def test_handles_connection_error_gracefully(self, mock_connect):
        """A psycopg2 OperationalError during connect surfaces as ConnectionError."""
        import psycopg2
        mock_connect.side_effect = psycopg2.OperationalError("Connection refused")

        self.assertRaises(ConnectionError, SupadataQuotaTracker)

    @patch('supadata_tracker.psycopg2.connect')
    def test_quota_at_exactly_limit(self, mock_connect):
        """With usage equal to the limit, nothing remains and usage reads 100%."""
        tracker = self._tracker_reporting(mock_connect, 100)  # exactly at the limit
        quota = tracker.check_quota()

        self.assertEqual(quota['remaining'], 0)
        self.assertEqual(quota['percentage_used'], 100.0)
        self.assertFalse(tracker.is_quota_available())

    @patch('supadata_tracker.psycopg2.connect')
    def test_quota_over_limit_edge_case(self, mock_connect):
        """With usage above the limit, remaining is clamped to zero rather than negative."""
        tracker = self._tracker_reporting(mock_connect, 105)  # past the limit
        quota = tracker.check_quota()

        self.assertEqual(quota['remaining'], 0)
        self.assertFalse(tracker.is_quota_available())


if __name__ == '__main__':
    # verbosity=2 prints each test's name and outcome as it runs.
    unittest.main(verbosity=2)
