import logging
import time
from typing import Dict, Optional

from prometheus_client import Counter, Histogram, Gauge, start_http_server
from prometheus_client.core import CollectorRegistry, GaugeMetricFamily

# Configure logging
# NOTE: basicConfig only takes effect if the root logger has no handlers yet;
# in an application that configures logging elsewhere this line is a no-op.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define metrics with labels
# All four metrics share the same label set (provider, model, story_id) so
# their time series can be joined/aggregated on identical dimensions.
# These register themselves in prometheus_client's default REGISTRY at import time.
COST_TOTAL = Counter('cost_total', 'Total cost of operations', ['provider', 'model', 'story_id'])  # monotonically increasing monetary cost
REQUESTS_TOTAL = Counter('requests_total', 'Total number of requests', ['provider', 'model', 'story_id'])  # request count
LATENCY_HISTOGRAM = Histogram('latency_histogram', 'Latency of requests', ['provider', 'model', 'story_id'])  # latency in seconds (default buckets)
ERRORS_TOTAL = Counter('errors_total', 'Total number of errors', ['provider', 'model', 'story_id'])  # error count

# Custom Collector Example (if needed, otherwise remove)
class CustomCollector:
    """
    A custom Prometheus collector that fetches data on each scrape and exposes
    it as an ``example_gauge`` metric labelled by data source.

    This is an example skeleton: replace :meth:`_fetch_data` with real data
    fetching logic for your backend.
    """

    def __init__(self, data_source: str):
        """
        Initializes the custom collector.

        Args:
            data_source: The source of the data (e.g., database, API).
        """
        self.data_source = data_source
        # Use the module logger as-is; do NOT call setLevel() here — mutating
        # the shared logger's level from a constructor would affect every
        # other user of this module's logger.
        self.logger = logging.getLogger(__name__)

    def collect(self):
        """
        Collects metrics and yields them to Prometheus.

        Called by the registry on every scrape. Errors are logged (with
        traceback) and swallowed so a failing data source does not break
        the whole scrape.
        """
        try:
            data = self._fetch_data()

            gauge_metric = GaugeMetricFamily(
                'example_gauge',
                'Example gauge metric from custom collector',
                labels=['data_source']
            )
            gauge_metric.add_metric([self.data_source], data['value'])
            yield gauge_metric

        except Exception:
            # logger.exception records the full traceback, unlike an f-string
            # passed to logger.error; lazy formatting is also the logging idiom.
            self.logger.exception("Error collecting metrics")

    def _fetch_data(self) -> Dict[str, float]:
        """
        Fetches data from the data source. This is a placeholder.
        Replace with actual data fetching logic.

        Returns:
            A dictionary with a single 'value' key holding the sampled float.
        """
        # Simulate fetching data from a source; replace with real logic.
        time.sleep(0.1)  # Simulate some delay
        return {'value': time.time() % 100}  # Example: current time modulo 100



def start_prometheus_server(port: int, registry: Optional[CollectorRegistry] = None):
    """
    Starts a Prometheus HTTP server on the specified port.

    Failures are logged (with traceback) rather than raised, so a caller's
    startup sequence is not aborted by an exporter problem.

    Args:
        port: The port to listen on.
        registry: Optional CollectorRegistry. If not provided, the library's
            default REGISTRY is used.
    """
    try:
        logging.info("Starting Prometheus exporter on port %s", port)
        if registry is None:
            # BUG FIX: previously registry=None was forwarded unconditionally,
            # which overrides start_http_server's default REGISTRY with None
            # and breaks the exporter. Omit the argument to keep the default.
            start_http_server(port)
        else:
            start_http_server(port, registry=registry)
        logging.info("Prometheus exporter started successfully.")

    except Exception:
        # Keep best-effort semantics but preserve the traceback in the log.
        logging.exception("Failed to start Prometheus exporter")


def record_cost(provider: str, model: str, story_id: str, cost: float):
    """
    Add *cost* to the running cost counter for one provider/model/story series.

    Args:
        provider: The provider of the service.
        model: The model used.
        story_id: The ID of the story.
        cost: The cost of the operation.
    """
    series = COST_TOTAL.labels(provider=provider, model=model, story_id=story_id)
    series.inc(cost)


def record_request(provider: str, model: str, story_id: str):
    """
    Increment the request counter for one provider/model/story series.

    Args:
        provider: The provider of the service.
        model: The model used.
        story_id: The ID of the story.
    """
    labels = {"provider": provider, "model": model, "story_id": story_id}
    REQUESTS_TOTAL.labels(**labels).inc()


def record_latency(provider: str, model: str, story_id: str, latency: float):
    """
    Observe one request latency sample on the latency histogram.

    Args:
        provider: The provider of the service.
        model: The model used.
        story_id: The ID of the story.
        latency: The latency of the request in seconds.
    """
    LATENCY_HISTOGRAM.labels(
        provider=provider,
        model=model,
        story_id=story_id,
    ).observe(latency)


def record_error(provider: str, model: str, story_id: str):
    """
    Increment the error counter for one provider/model/story series.

    Args:
        provider: The provider of the service.
        model: The model used.
        story_id: The ID of the story.
    """
    counter = ERRORS_TOTAL.labels(provider=provider, model=model, story_id=story_id)
    counter.inc()


if __name__ == '__main__':
    # Example Usage
    # Port 8000 serves the default REGISTRY, which holds the module-level
    # counters/histogram defined above (COST_TOTAL, REQUESTS_TOTAL, ...).
    start_prometheus_server(8000)

    # Example Custom Collector
    # Port 8001 serves a separate registry containing ONLY the custom
    # collector — the record_* metrics below will NOT appear here.
    registry = CollectorRegistry()
    custom_collector = CustomCollector(data_source="MyDataSource")
    registry.register(custom_collector)
    start_prometheus_server(8001, registry=registry)

    # Simulate some activity (runs forever; stop with Ctrl-C)
    while True:
        record_request(provider="OpenAI", model="GPT-3", story_id="US-012")
        record_latency(provider="OpenAI", model="GPT-3", story_id="US-012", latency=0.5)
        record_cost(provider="OpenAI", model="GPT-3", story_id="US-012", cost=0.1)
        # Wall-clock-based pseudo-randomness: true for roughly 2 seconds out
        # of every 10, so errors are recorded intermittently.
        if time.time() % 10 > 8:
            record_error(provider="OpenAI", model="GPT-3", story_id="US-012")
        time.sleep(1)
