import argparse
import asyncio
import logging
import sys
import time
from typing import Any, Callable, Dict, List, Optional

import aiohttp
import statistics

# Root logger setup: INFO and above, timestamped, written to stdout.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

class LoadTestResults:
    """
    Accumulates the measurements of a load test and summarises them.

    Attributes:
        latencies: Latency samples recorded through ``add_latency``.
        errors: Number of errors recorded through ``increment_errors``.
        total_requests: Number of requests recorded through
            ``increment_total_requests``.
    """

    def __init__(self) -> None:
        self.latencies: List[float] = []
        self.errors: int = 0
        self.total_requests: int = 0

    def add_latency(self, latency: float) -> None:
        """Record one latency measurement."""
        self.latencies.append(latency)

    def increment_errors(self) -> None:
        """Count one more error."""
        self.errors += 1

    def increment_total_requests(self) -> None:
        """Count one more request."""
        self.total_requests += 1

    def calculate_percentile(self, percentile: float) -> float:
        """
        Return the requested percentile of the recorded latencies.

        Args:
            percentile: Percentile rank; the fractional part is truncated.
                Ranks of 0 or less yield the smallest sample and ranks of
                100 or more yield the largest.

        Returns:
            The percentile value, or 0.0 when no latencies were recorded.
        """
        if not self.latencies:
            return 0.0
        if len(self.latencies) == 1:
            # statistics.quantiles may reject a single data point; the only
            # sample is every percentile of itself.
            return self.latencies[0]

        rank = int(percentile)
        if rank <= 0:
            return min(self.latencies)
        if rank >= 100:
            return max(self.latencies)
        # quantiles(n=100) returns 99 cut points; the p-th percentile is
        # stored at index p - 1.
        return statistics.quantiles(self.latencies, n=100)[rank - 1]

    def calculate_throughput(self, duration: float) -> float:
        """
        Return requests per unit of ``duration``.

        Returns 0.0 when ``duration`` is zero or negative, since no rate can
        be derived from it.
        """
        if duration <= 0:
            return 0.0
        return self.total_requests / duration

    def generate_report(self, duration: float) -> Dict[str, Any]:
        """
        Build a summary of the collected measurements.

        Args:
            duration: Elapsed time of the test, used for the throughput.

        Returns:
            ``{"error": "No requests were made."}`` when nothing at all was
            recorded; otherwise a dict of counts, latency statistics,
            throughput and duration. Latency statistics are 0.0 when requests
            were counted but no latency was recorded (e.g. every request
            failed).
        """
        if not self.latencies and not self.total_requests:
            return {
                "error": "No requests were made."
            }

        if self.latencies:
            mean_latency = statistics.mean(self.latencies)
            median_latency = statistics.median(self.latencies)
        else:
            # mean/median raise on empty data; mirror calculate_percentile's
            # empty-data value instead.
            mean_latency = 0.0
            median_latency = 0.0

        return {
            "total_requests": self.total_requests,
            "errors": self.errors,
            "error_rate": self.errors / self.total_requests if self.total_requests else 0.0,
            "mean_latency": mean_latency,
            "median_latency": median_latency,
            "90th_percentile_latency": self.calculate_percentile(90),
            "95th_percentile_latency": self.calculate_percentile(95),
            "99th_percentile_latency": self.calculate_percentile(99),
            "throughput": self.calculate_throughput(duration),
            "duration": duration,
        }


async def make_request(
    session: aiohttp.ClientSession,
    url: str,
    results: LoadTestResults,
    request_data: Optional[Dict[str, Any]] = None,
    request_method: str = "GET",
) -> None:
    """
    Issue one HTTP request and record its outcome in ``results``.

    A GET is sent without a body; a POST sends ``request_data`` as JSON. On
    success the elapsed time is added as a latency and the request is
    counted. Any exception raised while sending or reading is logged and
    counted as both an error and a request. Any other method is logged and
    counted as an error only.
    """
    started = time.monotonic()
    try:
        # Choose the request context manager first so the response handling
        # below is written once for both supported methods.
        if request_method == "GET":
            pending = session.get(url)
        elif request_method == "POST":
            pending = session.post(url, json=request_data)
        else:
            logging.error(f"Unsupported request method: {request_method}")
            results.increment_errors()
            return

        async with pending as response:
            # Drain the body before checking the status.
            await response.read()
            response.raise_for_status()

        results.add_latency(time.monotonic() - started)
        results.increment_total_requests()

    except aiohttp.ClientError as err:
        logging.error(f"Request failed: {err}")
        results.increment_errors()
        results.increment_total_requests()
    except Exception as err:
        logging.exception(f"Unexpected error during request: {err}")
        results.increment_errors()
        results.increment_total_requests()


async def run_load_test(
    url: str,
    num_users: int,
    request_rate: float,
    request_data: Optional[Dict[str, Any]] = None,
    request_method: str = "GET",
    duration_seconds: float = 60.0,
) -> "LoadTestResults":
    """
    Run the load test and return the collected results.

    One request task is started every ``1 / request_rate`` seconds until
    ``duration_seconds`` have elapsed; at most ``num_users`` requests are in
    flight at once. A report is logged once all requests have finished.

    Args:
        url: Target URL.
        num_users: Maximum number of concurrent requests; must be >= 1.
        request_rate: Requests started per second; must be > 0.
        request_data: JSON payload passed through to each request.
        request_method: HTTP method passed through to each request.
        duration_seconds: How long to keep starting requests; must be >= 0.
            At least one request is always started.

    Returns:
        The populated ``LoadTestResults``.

    Raises:
        ValueError: If ``num_users``, ``request_rate`` or
            ``duration_seconds`` is out of range.
    """
    # Validate up front: a zero-slot semaphore would block every worker
    # forever, and a non-positive rate cannot be turned into a pacing delay.
    if num_users < 1:
        raise ValueError("num_users must be at least 1")
    if request_rate <= 0:
        raise ValueError("request_rate must be greater than 0")
    if duration_seconds < 0:
        raise ValueError("duration_seconds must not be negative")

    results = LoadTestResults()
    semaphore = asyncio.Semaphore(num_users)  # Limit concurrent requests
    interval = 1 / request_rate  # Delay between request starts

    async with aiohttp.ClientSession() as session:

        async def worker() -> None:
            async with semaphore:
                await make_request(session, url, results, request_data, request_method)

        start_time = time.monotonic()
        tasks: List[asyncio.Task] = []
        while True:
            tasks.append(asyncio.create_task(worker()))
            await asyncio.sleep(interval)  # Control request rate

            # Drain the batch periodically so the task list stays bounded.
            if len(tasks) >= 1000:
                await asyncio.gather(*tasks)
                tasks = []

            if time.monotonic() - start_time > duration_seconds:
                break

        if tasks:
            await asyncio.gather(*tasks)  # Wait for remaining tasks

        duration = time.monotonic() - start_time

    report = results.generate_report(duration)
    logging.info(f"Load test report: {report}")
    return results


async def main(
    url: str,
    num_users: int,
    request_rate: float,
    request_data: Optional[Dict[str, Any]] = None,
    request_method: str = "GET",
) -> None:
    """
    Run the load test with the given settings and print a performance report.

    Args:
        url: Target URL.
        num_users: Number of concurrent users.
        request_rate: Requests per second.
        request_data: JSON payload for POST requests; logged when truthy.
        request_method: HTTP method to use.
    """
    logging.info(
        f"Starting load test with {num_users} users, request rate of {request_rate} requests/second, against {url} using {request_method} method."
    )
    if request_data:
        logging.info(f"Request data: {request_data}")

    # run_load_test returns only the results object, not how long it ran, so
    # time the call here to report a real duration and throughput.
    started = time.monotonic()
    results = await run_load_test(url, num_users, request_rate, request_data, request_method)
    elapsed = time.monotonic() - started
    report = results.generate_report(elapsed)

    print("\nPerformance Report:")
    for key, value in report.items():
        print(f"{key}: {value}")


if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, run the test.
    parser = argparse.ArgumentParser(description="Load test a service.")
    parser.add_argument("--url", required=True, help="The URL to load test.")
    parser.add_argument(
        "--num_users", type=int, default=10, help="The number of concurrent users."
    )
    parser.add_argument(
        "--request_rate",
        type=float,
        default=10.0,
        help="The number of requests per second.",
    )
    parser.add_argument(
        "--request_data",
        type=str,
        default=None,
        help="JSON data to send with POST requests (e.g., '{\"key\": \"value\"}').",
    )
    parser.add_argument(
        "--request_method",
        type=str,
        default="GET",
        choices=["GET", "POST"],
        help="HTTP method to use (GET or POST).",
    )

    args = parser.parse_args()

    # run_load_test sizes its concurrency limit from num_users and divides by
    # request_rate, so non-positive values cannot produce a meaningful run.
    # Reject them here with a usage error.
    if args.num_users < 1:
        parser.error("--num_users must be at least 1")
    if args.request_rate <= 0:
        parser.error("--request_rate must be greater than 0")

    request_data = None
    if args.request_data:
        import json
        try:
            request_data = json.loads(args.request_data)
        except json.JSONDecodeError as e:
            print(f"Error decoding request_data JSON: {e}")
            sys.exit(1)

    asyncio.run(main(args.url, args.num_users, args.request_rate, request_data, args.request_method))
