# aiva/api/middleware.py
import time
from opentelemetry import trace
from opentelemetry.context import attach, detach, get_current, set_value
from fastapi import Request, Response

# Module-level tracer named after this module; tracing_middleware below opens
# its per-request spans from it.
tracer = trace.get_tracer(__name__)

# Request headers whose values must never be copied into trace data because
# they carry credentials or session secrets. Compared case-insensitively.
_REDACTED_HEADERS = frozenset({
    "authorization",
    "proxy-authorization",
    "cookie",
    "set-cookie",
    "x-api-key",
})


async def tracing_middleware(request: Request, call_next):
    """Trace one HTTP request and tag the OpenTelemetry context with its ID.

    The ``X-Request-ID`` header (or ``'unknown'`` when absent) is stored in the
    OpenTelemetry context under the key ``'request_id'`` for the duration of
    the request. A span named after the request path wraps the downstream
    handler and records the method, URL, request headers (with credential
    headers redacted), response status code and handler time in milliseconds.

    Args:
        request: The incoming request.
        call_next: The next handler in the middleware chain; it is awaited
            with ``request`` to produce the response.

    Returns:
        The response produced by ``call_next``.

    Raises:
        Whatever ``call_next`` raises. The attached context is still detached
        before the exception propagates.
    """
    request_id = request.headers.get('X-Request-ID', 'unknown')
    token = attach(set_value('request_id', request_id))
    # Detach in ``finally`` and *after* the span's ``with`` block exits, so
    # context tokens are reset in LIFO order and are reset even when the
    # downstream handler raises.
    try:
        with tracer.start_as_current_span(request.url.path) as span:
            span.set_attribute("http.method", request.method)
            span.set_attribute("http.url", str(request.url))

            # Add request headers as span attributes, masking credentials so
            # they do not end up in the tracing backend.
            for header, value in request.headers.items():
                if header.lower() in _REDACTED_HEADERS:
                    value = "[REDACTED]"
                span.set_attribute(f"http.request.header.{header}", value)

            # perf_counter is monotonic, so the measured duration cannot be
            # skewed by wall-clock adjustments.
            start = time.perf_counter()
            response: Response = await call_next(request)
            process_time_ms = (time.perf_counter() - start) * 1000

            span.set_attribute("http.status_code", response.status_code)
            span.set_attribute("process_time_ms", int(process_time_ms))
    finally:
        detach(token)

    return response