import asyncio
import datetime
import functools
import hashlib
import inspect
import json
import logging
import os
import secrets
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import jwt
import prometheus_client
from fastapi import (
    Depends,
    FastAPI,
    HTTPException,
    Request,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware import Middleware
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.routing import APIRoute
from prometheus_client import Counter, Histogram
from prometheus_client import multiprocess
from prometheus_client.metrics_process import ProcessCollector
from pydantic import BaseModel, Field
from starlette.endpoints import WebSocketEndpoint
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import WebSocketRoute
from starlette.status import HTTP_403_FORBIDDEN, HTTP_401_UNAUTHORIZED
from uvicorn import Config, Server

# --- Configuration ---
class Settings(BaseModel):
    """Runtime configuration for the API server.

    Every field has a default, so ``Settings()`` can be built without
    arguments.
    """

    app_name: str = "My Production API"
    admin_email: str = "admin@example.com"
    items_per_user: int = 50
    # Prefer a key supplied through the SECRET_KEY environment variable so that
    # every worker process, and every restart, signs and verifies tokens with
    # the same key. Without it a random key is generated once at import time
    # (the previous behaviour); tokens signed with such a key stop validating
    # as soon as a different process handles the request.
    secret_key: str = os.environ.get("SECRET_KEY") or secrets.token_urlsafe(32)
    algorithm: str = "HS256"
    access_token_expire_minutes: int = 30
    refresh_token_expire_minutes: int = 60 * 24 * 30  # 30 days
    allowed_origins: List[str] = ["*"]  # Replace with your actual origins
    rate_limit_per_minute: int = 20
    skills_directory: str = "skills"
    log_level: str = "INFO"  # DEBUG, INFO, WARNING, ERROR, CRITICAL
    prometheus_port: int = 8000
    uvicorn_host: str = "0.0.0.0"
    uvicorn_port: int = 8080

# Module-wide configuration object, built without any overrides.
settings = Settings()


# --- Logging ---
# Configure the root logger with the level name held in settings.log_level and
# create the logger used throughout this module.
logging.basicConfig(level=settings.log_level)
logger = logging.getLogger(__name__)


# --- Prometheus Metrics ---
# Per-request metrics, labelled by HTTP method and URL path (plus the response
# status code for the counter). Both are updated by the track_requests
# middleware.
REQUEST_COUNT = Counter("request_count", "Total number of requests", ["method", "path", "status_code"])
REQUEST_DURATION = Histogram("request_duration_seconds", "Request duration in seconds", ["method", "path"])
# Unlabelled counters; nothing in the visible part of this file increments them.
AI_VALIDATION_SUCCESS = Counter("ai_validation_success", "Number of successful AI validations")
AI_VALIDATION_FAILURE = Counter("ai_validation_failure", "Number of failed AI validations")


# --- FastAPI App ---
# The application instance. Its title comes from settings.app_name; the OpenAPI
# schema and both documentation UIs are served at the paths given below.
app = FastAPI(
    title=settings.app_name,
    description="A production-ready API server.",
    version="0.1.0",
    openapi_url="/openapi.json",
    docs_url="/docs",
    redoc_url="/redoc",
)

# --- CORS Configuration ---
# Credentialed cross-origin requests must not be combined with the "*"
# wildcard origin: doing so would let any website issue requests that carry
# the user's credentials. Credentials are therefore only enabled once
# settings.allowed_origins lists explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_credentials="*" not in settings.allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Data Models ---
class Message(BaseModel):
    """A text message; used to validate payloads received over the WebSocket."""

    text: str

class User(BaseModel):
    """An API user.

    ``user_id`` is the value stored in a token's ``sub`` claim; ``tier``
    selects the rate limit and gates the admin endpoint.
    """

    user_id: str
    username: str
    email: str
    tier: str = "free"  # Example: free, premium, admin

class Token(BaseModel):
    """Response body of the token endpoints.

    ``refresh_token`` is optional: the login endpoint returns one, while the
    refresh endpoint only returns a new access token. The field has to be
    declared here because endpoints using ``response_model=Token`` drop any
    key that the model does not define.
    """

    access_token: str
    refresh_token: Optional[str] = None
    token_type: str = "bearer"

class TokenData(BaseModel):
    """Holder for the user ID carried by a token.

    Not referenced anywhere in the visible part of this file.
    """

    user_id: Optional[str] = None

class SkillRequest(BaseModel):
    """Request body of the skill invocation endpoint.

    ``skill_name`` selects the skill to run and ``input_data`` is passed to it
    as keyword arguments.
    """

    skill_name: str
    input_data: Dict[str, Any] = Field(default_factory=dict)

# --- Authentication ---
def create_access_token(data: dict, expires_delta: Optional[datetime.timedelta] = None):
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token. The mapping is copied, so the
            caller's dict is left untouched.
        expires_delta: Lifetime of the token. When ``None``, the lifetime is
            ``settings.access_token_expire_minutes`` minutes. An explicit zero
            delta is honoured rather than being replaced by the default.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=settings.access_token_expire_minutes)
    # Use a timezone-aware "now": datetime.utcnow() is deprecated since
    # Python 3.12 and yields a naive value that is easily mistaken for local
    # time.
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta
    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)

def create_refresh_token(data: dict, expires_delta: Optional[datetime.timedelta] = None):
    """Create a signed JWT refresh token.

    Args:
        data: Claims to embed in the token. The mapping is copied, so the
            caller's dict is left untouched.
        expires_delta: Lifetime of the token. When ``None``, the lifetime is
            ``settings.refresh_token_expire_minutes`` minutes. An explicit
            zero delta is honoured rather than being replaced by the default.

    Returns:
        The encoded token as produced by ``jwt.encode``.
    """
    if expires_delta is None:
        expires_delta = datetime.timedelta(minutes=settings.refresh_token_expire_minutes)
    # Use a timezone-aware "now": datetime.utcnow() is deprecated since
    # Python 3.12 and yields a naive value that is easily mistaken for local
    # time.
    expire = datetime.datetime.now(datetime.timezone.utc) + expires_delta
    to_encode = data.copy()
    to_encode["exp"] = expire
    return jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)

async def get_current_user(token: str) -> User:
    """Resolve the user identified by a JWT.

    Args:
        token: Encoded JWT whose ``sub`` claim holds the user ID.

    Returns:
        A ``User`` built from the ``sub`` claim. No user store is consulted:
        the username and email are derived from the ID and ``tier`` keeps the
        model's default.

    Raises:
        HTTPException: 401 when the token has expired, cannot be decoded or
            verified, or carries no usable ``sub`` claim.
    """
    # Only the decode step lives inside the try block, so the HTTPException
    # raised further down for a bad "sub" claim is never intercepted and
    # rewritten by one of these handlers.
    try:
        payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm])
    except jwt.ExpiredSignatureError as e:
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Token has expired"
        ) from e
    except jwt.InvalidTokenError as e:
        # InvalidTokenError is PyJWT's base class for decode and validation
        # failures; the package does not define a ``JWTError``.
        logger.error(f"JWT Error: {e}")
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token") from e
    except Exception as e:
        logger.exception(f"Unexpected error decoding token: {e}")  # Log the full traceback
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") from e

    user_id = payload.get("sub")  # "sub" claim holds the user ID
    # A non-string subject cannot populate User.user_id, so it is rejected the
    # same way as a missing one.
    if not isinstance(user_id, str):
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token: Missing user ID")

    # In a real application, you would fetch the user from a database here
    # For this example, we'll create a dummy user
    return User(user_id=user_id, username=f"user_{user_id}", email=f"user_{user_id}@example.com")

# --- Rate Limiting ---
# {user_id: (requests_count, window_start_time)} for the user's current window.
user_rate_limits: Dict[str, Tuple[int, float]] = {}

# Length of one rate-limiting window, in seconds.
_RATE_LIMIT_WINDOW_SECONDS = 60


async def rate_limit(user: User = Depends(get_current_user)):
    """Enforce a per-user request limit over fixed 60-second windows.

    The limit depends on the user's tier: ``settings.rate_limit_per_minute``
    for "free" (and for unknown tiers), five times that for "premium", and no
    limit for "admin".

    Args:
        user: The authenticated user, resolved by ``get_current_user``.

    Raises:
        HTTPException: 429 when the user has used up the limit for the
            current window.
    """
    # A monotonic clock is used because only elapsed time matters here and it
    # cannot jump when the system clock is adjusted.
    now = time.monotonic()

    # Define rate limits per tier
    rate_limits = {
        "free": settings.rate_limit_per_minute,
        "premium": settings.rate_limit_per_minute * 5,
        "admin": float('inf'),  # No rate limit for admins
    }
    limit = rate_limits.get(user.tier, settings.rate_limit_per_minute)

    requests_count, window_start = user_rate_limits.get(user.user_id, (0, now))

    if now - window_start > _RATE_LIMIT_WINDOW_SECONDS:
        # The previous window is over: start a fresh one.
        requests_count, window_start = 0, now

    if requests_count >= limit:
        raise HTTPException(status_code=429, detail="Too Many Requests")

    # Keep the window's start time. Storing "now" here would push the reset
    # point forward on every request, so a steadily active user would never
    # get a fresh window.
    user_rate_limits[user.user_id] = (requests_count + 1, window_start)


# --- Skills Loading ---
# Registry of invocable skills: {function name: function}.
skills: Dict[str, Callable] = {}

def load_skills(skills_directory: str):
    """Load skills from the Python files in ``skills_directory``.

    Every ``.py`` file is executed as a module, and each public function
    (name not starting with ``_``) that the file itself defines is registered
    in the module-level ``skills`` dict under its own name. Files are handled
    in sorted order, so when two files define the same name the outcome is
    deterministic: the later file wins and a warning is logged.

    A file that fails to load is logged and skipped; it does not stop the
    remaining files from loading.

    Args:
        skills_directory: Directory to scan. It is not searched recursively.

    Raises:
        OSError: If the directory cannot be listed.
    """
    for filename in sorted(os.listdir(skills_directory)):
        if not filename.endswith(".py"):
            continue

        module_name = filename[:-3]  # Remove ".py" extension
        try:
            module_path = os.path.join(skills_directory, filename)
            spec = importlib.util.spec_from_file_location(module_name, module_path)
            if spec is None or spec.loader is None:
                logger.warning(f"Could not load skill from {filename}: Invalid module specification.")
                continue

            module = importlib.util.module_from_spec(spec)
            # NOTE: this executes the file's top-level code; the directory
            # must only contain trusted files.
            spec.loader.exec_module(module)

            for name, obj in inspect.getmembers(module, inspect.isfunction):
                if name.startswith("_"):  # Exclude private functions
                    continue
                # Register only functions defined by the skill file itself.
                # Helpers it merely imports (e.g. ``from shutil import
                # rmtree``) would otherwise become remotely invocable skills.
                if obj.__module__ != module.__name__:
                    continue
                if name in skills:
                    logger.warning(f"Skill {name} from {filename} replaces a previously loaded skill of the same name.")
                skills[name] = obj
                logger.info(f"Loaded skill: {name} from {filename}")

        except Exception as e:
            # logger.exception keeps the traceback, which is needed to debug a
            # broken skill file.
            logger.exception(f"Error loading skill from {filename}: {e}")

import importlib.util

# Create the skills directory if it doesn't exist. exist_ok=True avoids the
# check-then-create race that occurs when several worker processes import this
# module at the same time.
os.makedirs(settings.skills_directory, exist_ok=True)

# Create a dummy skill if the directory is empty.
if not os.listdir(settings.skills_directory):
    # An explicit encoding keeps the generated file identical on every platform.
    with open(os.path.join(settings.skills_directory, "dummy_skill.py"), "w", encoding="utf-8") as f:
        f.write("""
def greet(name: str = "World"):
    return f"Hello, {name}!"
""")

load_skills(settings.skills_directory)


# --- Prometheus Middleware ---
@app.middleware("http")
async def track_requests(request: Request, call_next):
    """Record the request count and duration metrics for every HTTP request.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request on and returns the
            response.

    Returns:
        The response produced downstream, unchanged.
    """
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments.
    start_time = time.perf_counter()
    # Reported when the downstream raises instead of returning a response.
    status_code = 500
    try:
        response = await call_next(request)
        status_code = response.status_code
        return response
    finally:
        # Recording in ``finally`` ensures failed requests are counted too;
        # the exception itself still propagates.
        process_time = time.perf_counter() - start_time
        REQUEST_COUNT.labels(method=request.method, path=request.url.path, status_code=status_code).inc()
        REQUEST_DURATION.labels(method=request.method, path=request.url.path).observe(process_time)

# --- Endpoints ---
@app.get("/health")
async def health_check():
    """Liveness probe: always reports the service status as "ok"."""
    status_report = {"status": "ok"}
    return status_report

@app.post("/token", response_model=Token)
async def login(user: User):  # Replace with actual form data
    """Issue an access token and a refresh token for the posted user.

    Both tokens carry the user's ID in the ``sub`` claim and expire after the
    lifetimes configured in ``settings``.

    NOTE(review): no credentials are verified here; tokens are issued for
    whatever ``user_id`` the request body contains.
    """
    subject_claims = {"sub": user.user_id}

    access_token = create_access_token(
        data=dict(subject_claims),
        expires_delta=datetime.timedelta(minutes=settings.access_token_expire_minutes),
    )
    refresh_token = create_refresh_token(
        data=dict(subject_claims),
        expires_delta=datetime.timedelta(minutes=settings.refresh_token_expire_minutes),
    )

    token_response = {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer",
    }
    return token_response

@app.post("/refresh_token", response_model=Token)
async def refresh_access_token(refresh_token: str):
    """Exchange a refresh token for a new access token.

    Args:
        refresh_token: Encoded JWT whose ``sub`` claim holds the user ID.

    Returns:
        A dict with the new ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 401 when the token has expired, cannot be decoded or
            verified, or carries no ``sub`` claim.

    NOTE(review): nothing in a token marks it as a refresh token, so a valid
    access token is accepted here as well.
    """
    # Only the decode step lives inside the try block, so the HTTPException
    # raised further down for a missing "sub" claim is never intercepted and
    # rewritten by one of these handlers.
    try:
        payload = jwt.decode(refresh_token, settings.secret_key, algorithms=[settings.algorithm])
    except jwt.ExpiredSignatureError as e:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Refresh token has expired") from e
    except jwt.InvalidTokenError as e:
        # InvalidTokenError is PyJWT's base class for decode and validation
        # failures; the package does not define a ``JWTError``.
        logger.error(f"JWT Error: {e}")
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid refresh token") from e
    except Exception as e:
        logger.exception(f"Unexpected error decoding refresh token: {e}")
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") from e

    user_id = payload.get("sub")
    if user_id is None:
        raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid refresh token: Missing user ID")

    access_token_expires = datetime.timedelta(minutes=settings.access_token_expire_minutes)
    access_token = create_access_token(data={"sub": user_id}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}

@app.get("/items/")
async def read_items(user: User = Depends(get_current_user), skip: int = 0, limit: int = 10,  rate_limit_dependency: None = Depends(rate_limit)):
    """Return ``limit`` placeholder item names, starting at index ``skip``.

    Requires an authenticated user and is subject to rate limiting.

    NOTE(review): ``skip`` and ``limit`` are used as given; neither is bounded.
    """
    logger.info(f"User {user.user_id} accessed /items/ endpoint.")

    item_names = []
    for index in range(skip, skip + limit):
        item_names.append(f"item_{index}")
    return {"items": item_names}

@app.post("/skills/invoke")
async def invoke_skill(request: SkillRequest, user: User = Depends(get_current_user),  rate_limit_dependency: None = Depends(rate_limit)):
    """Invoke a loaded skill with the keyword arguments from the request.

    Args:
        request: Names the skill and carries its keyword arguments.
        user: The authenticated caller.
        rate_limit_dependency: Present only to apply rate limiting.

    Returns:
        ``{"result": <value returned by the skill>}``.

    Raises:
        HTTPException: 404 for an unknown skill, 422 when the arguments do
            not fit the skill's signature, 500 when the skill itself fails.
    """
    skill_name = request.skill_name
    input_data = request.input_data

    logger.info(f"User {user.user_id} invoking skill: {skill_name} with data: {input_data}")

    skill_function = skills.get(skill_name)
    if skill_function is None:
        raise HTTPException(status_code=404, detail="Skill not found")

    # Check the arguments against the skill's signature before running it, so
    # a client mistake is reported as a client error instead of surfacing as a
    # TypeError-driven 500.
    try:
        signature = inspect.signature(skill_function)
    except (TypeError, ValueError):
        signature = None  # Not introspectable; the call itself will validate.
    if signature is not None:
        try:
            signature.bind(**input_data)
        except TypeError as e:
            raise HTTPException(status_code=422, detail=f"Invalid input for skill: {e}") from e

    try:
        result = skill_function(**input_data)  # Invoke the skill function
        # Covers coroutine functions as well as callables that hand back any
        # other awaitable.
        if inspect.isawaitable(result):
            result = await result
        return {"result": result}
    except HTTPException:
        # A skill may deliberately signal an HTTP error; pass it through.
        raise
    except Exception as e:
        logger.exception(f"Error invoking skill {skill_name}: {e}")
        # The exception text stays in the server log only; it may expose
        # internal details and is therefore not sent to the client.
        raise HTTPException(status_code=500, detail="Error invoking skill") from e

@app.get("/admin/evolution")
async def trigger_evolution(user: User = Depends(get_current_user),  rate_limit_dependency: None = Depends(rate_limit)):
    """Placeholder admin action; rejects callers whose tier is not "admin".

    NOTE(review): confirm how an authenticated user's tier gets set to
    "admin".
    """
    is_admin = user.tier == "admin"
    if not is_admin:
        raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Admin access required")

    logger.info(f"Admin {user.user_id} triggered evolution.")
    return {"message": "Evolution triggered (placeholder)"}

async def generate_data():
    """Yield ten server-sent-event lines, pausing one second after each."""
    line_number = 0
    while line_number < 10:
        yield f"data: This is line {line_number}\n\n"
        await asyncio.sleep(1)
        line_number += 1

@app.get("/stream")
async def stream_data(user: User = Depends(get_current_user),  rate_limit_dependency: None = Depends(rate_limit)):
    """Stream the output of ``generate_data`` as a text/event-stream response."""
    logger.info(f"User {user.user_id} accessed /stream endpoint.")
    event_stream = generate_data()
    return StreamingResponse(event_stream, media_type="text/event-stream")

# --- WebSocket ---
class WebSocketEndpoint(WebSocketEndpoint):
    """JSON WebSocket endpoint that authenticates via the Authorization header.

    The class keeps the name of the Starlette base class it extends, because
    the route registration further down refers to it by this name.
    """

    encoding: str = "json"

    @staticmethod
    def _extract_token(authorization: str) -> str:
        """Return the token from an Authorization header value.

        The "Bearer" scheme is matched case-insensitively and only as a
        prefix. A value without that scheme is returned unchanged, so a bare
        token keeps working.
        """
        scheme, separator, credentials = authorization.partition(" ")
        if separator and scheme.lower() == "bearer":
            return credentials.strip()
        return authorization

    async def on_connect(self, websocket: WebSocket) -> None:
        """Accept the connection, then authenticate it.

        On success ``self.user_id`` holds the authenticated user's ID. On
        failure the socket is closed (1008 for a missing or invalid token,
        1011 for an unexpected error) and ``self.user_id`` stays ``None``.
        """
        await websocket.accept()
        self.user_id = None  # Initialize user_id

        try:
            # Authenticate the user via token (you might need a custom header)
            authorization = websocket.headers.get("Authorization")
            if not authorization:
                logger.warning("WebSocket connection attempted without authentication.")
                await websocket.close(code=1008)  # Policy Violation
                return

            user = await get_current_user(self._extract_token(authorization))
            self.user_id = user.user_id
            logger.info(f"WebSocket connection established for user: {self.user_id}")
        except HTTPException as e:
            logger.warning(f"WebSocket authentication failed: {e.detail}")
            await websocket.close(code=1008)  # Policy Violation
            return
        except Exception as e:
            logger.exception(f"Unexpected error during WebSocket authentication: {e}")
            await websocket.close(code=1011)  # Internal Error
            return

    async def on_receive(self, websocket: WebSocket, data: Any) -> None:
        """Handles incoming WebSocket messages.

        Messages on an authenticated socket are validated as ``Message`` and
        echoed back; anything that fails validation gets an error reply. An
        unauthenticated socket is closed with code 1008.
        """
        if not self.user_id:
            logger.warning("Received message on unauthenticated WebSocket.")
            await websocket.close(code=1008)
            return

        logger.info(f"Received message from user {self.user_id}: {data}")
        try:
            message = Message(**data)  # Validate incoming data
            response_text = f"You said: {message.text}"
            await websocket.send_json({"message": response_text})
        except Exception as e:
            logger.error(f"Error processing WebSocket message: {e}")
            await websocket.send_json({"error": "Invalid message format"})

    async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None:
        """Handles WebSocket disconnections by logging who disconnected."""
        if self.user_id:
            logger.info(f"WebSocket disconnected for user: {self.user_id} with code: {close_code}")
        else:
            logger.info(f"Unauthenticated WebSocket disconnected with code: {close_code}")


# WebSocketEndpoint is a class-based ASGI endpoint, so it is registered as a
# plain WebSocketRoute. add_api_websocket_route is meant for endpoint
# functions whose parameters FastAPI resolves itself, which does not fit a
# class that is constructed with (scope, receive, send).
app.router.routes.append(WebSocketRoute("/ws", WebSocketEndpoint))

# --- Prometheus Integration ---
from prometheus_client import make_asgi_app

# Build the Prometheus ASGI application with its default arguments and serve
# it as a sub-application under /metrics.
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)

# --- Startup and Shutdown Events ---
@app.on_event("startup")
async def startup_event():
    """Log that the API is starting; the place to add real initialisation."""
    # Perform any necessary initialization tasks here (e.g., database connections)
    startup_message = "Starting up the API..."
    logger.info(startup_message)

@app.on_event("shutdown")
async def shutdown_event():
    """Log that the API is shutting down; the place to add real cleanup."""
    # Perform any necessary cleanup tasks here (e.g., closing database connections)
    shutdown_message = "Shutting down the API..."
    logger.info(shutdown_message)

# --- Main ---
if __name__ == "__main__":
    # When this file is executed as a script, serve the app with uvicorn on the
    # host and port configured in settings.
    import uvicorn
    uvicorn.run(app, host=settings.uvicorn_host, port=settings.uvicorn_port)
