import heapq
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import tiktoken


class AbstractSummarizer(ABC):
    """
    Interface that every summarization backend must implement.

    Subclasses provide a single ``summarize`` method that condenses a piece of
    text; the class itself cannot be instantiated.
    """

    @abstractmethod
    def summarize(self, text: str, **kwargs) -> str:
        """
        Produce a condensed version of ``text``.

        Args:
            text (str): Input text to be condensed.
            **kwargs: Backend-specific options forwarded by the caller.

        Returns:
            str: The condensed text.
        """
        ...


class ContextWindowManager:
    """
    Manages the context window for a language model, including token tracking,
    context assembly, and compression strategies.

    Token accounting lives in ``token_usage`` (one counter per section) and
    ``available_tokens``, which is kept equal to ``max_tokens`` minus the sum of
    all section counters. Only the raw text of the system prompt, messages,
    knowledge snippets and query is counted; the labels and newline separators
    that ``build_context`` adds are not, so the rendered context can be slightly
    longer than the tracked total.
    """

    def __init__(
        self,
        model_name: str,
        system_prompt: str,
        max_tokens: int,
        summarizer: "Optional[AbstractSummarizer]" = None,
        compression_ratio: float = 0.8,
        history_fraction: float = 0.5,
        knowledge_fraction: float = 0.3,
        query_fraction: float = 0.2,
        tokenizer: Any = None,
    ):
        """
        Initializes the ContextWindowManager.

        Args:
            model_name (str): The name of the language model. Used to look up the
                tiktoken encoding when no ``tokenizer`` is supplied.
            system_prompt (str): The system prompt for the language model.
            max_tokens (int): The maximum number of tokens allowed in the context window.
            summarizer (AbstractSummarizer, optional): Model used to compress old
                history messages. Defaults to None, in which case history is never
                compressed.
            compression_ratio (float, optional): The target compression ratio for the
                context window. Stored on the instance; no method of this class reads
                it. Defaults to 0.8.
            history_fraction (float, optional): Fraction of the context budget
                (``max_tokens`` minus the system prompt) allocated to the conversation
                history. Defaults to 0.5.
            knowledge_fraction (float, optional): Fraction of the context budget
                allocated to the retrieved knowledge. Defaults to 0.3.
            query_fraction (float, optional): Fraction of the context budget allocated
                to the user query. Defaults to 0.2.
            tokenizer (optional): Any object with an ``encode(text)`` method that
                returns a sized sequence of tokens. Defaults to None, meaning
                ``tiktoken.encoding_for_model(model_name)`` is used.

        Raises:
            ValueError: If the system prompt alone exceeds ``max_tokens``.
        """
        self.model_name = model_name
        self.tokenizer = (
            tokenizer if tokenizer is not None else tiktoken.encoding_for_model(model_name)
        )
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.history: List[Dict[str, Any]] = []
        self.knowledge: List[Dict[str, Any]] = []
        self.summarizer = summarizer
        self.compression_ratio = compression_ratio
        self.history_fraction = history_fraction
        self.knowledge_fraction = knowledge_fraction
        self.query_fraction = query_fraction
        self.system_prompt_tokens = self._count_tokens(system_prompt)

        if self.system_prompt_tokens > max_tokens:
            raise ValueError(
                "System prompt exceeds the maximum token limit.  Please shorten it."
            )

        self.token_usage: Dict[str, int] = {
            "system_prompt": self.system_prompt_tokens,
            "history": 0,
            "knowledge": 0,
            "query": 0,
        }
        # The system prompt is always part of the context, so it is charged up
        # front (the same value reset_context restores).
        self.available_tokens = max_tokens - self.system_prompt_tokens
        # Cache of built contexts, keyed by the full conversation state + query.
        self.context_cache: Dict[Tuple[Any, ...], str] = {}

    def reset_context(self):
        """
        Resets the context, clearing the history, knowledge and query accounting.
        Preserves the system prompt.
        """
        self.history = []
        self.knowledge = []
        self.token_usage["history"] = 0
        self.token_usage["knowledge"] = 0
        self.token_usage["query"] = 0
        self._recompute_available_tokens()
        self.context_cache = {}

    def _count_tokens(self, text: str) -> int:
        """
        Counts the number of tokens in a given text using the tokenizer.

        Args:
            text (str): The text to tokenize.

        Returns:
            int: The number of tokens in the text.
        """
        return len(self.tokenizer.encode(text))

    def _update_token_count(self, section: str, new_count: int):
        """
        Sets the token count for one section and adjusts ``available_tokens`` by
        the difference, preserving ``available == max_tokens - sum(token_usage)``.

        Args:
            section (str): The section to update (e.g., "history", "knowledge").
            new_count (int): The new token count for the section.
        """
        token_difference = new_count - self.token_usage[section]
        self.available_tokens -= token_difference
        self.token_usage[section] = new_count

    def _recompute_available_tokens(self):
        """Recomputes ``available_tokens`` from scratch using the section counters."""
        self.available_tokens = self.max_tokens - sum(self.token_usage.values())

    def get_token_usage(self) -> Dict[str, int]:
        """
        Returns a dictionary containing the token usage for each section of the context.

        Returns:
            Dict[str, int]: Per-section counts plus ``available`` (tokens left) and
            ``total`` (tokens used, i.e. ``max_tokens - available``). ``available``
            can be negative if compression could not bring the context under budget.
        """
        return {
            "system_prompt": self.token_usage["system_prompt"],
            "history": self.token_usage["history"],
            "knowledge": self.token_usage["knowledge"],
            "query": self.token_usage["query"],
            "available": self.available_tokens,
            "total": self.max_tokens - self.available_tokens,
        }

    def add_message(self, role: str, content: str):
        """
        Adds a message to the conversation history and invalidates the context cache.

        Args:
            role (str): The role of the message sender (e.g., "user", "assistant").
            content (str): The content of the message.
        """
        message = {
            "role": role,
            "content": content,
            "timestamp": time.time(),
        }
        self.history.append(message)
        self._update_token_count(
            "history", self.token_usage["history"] + self._count_tokens(content)
        )
        self.context_cache = {}  # Invalidate cache

    def add_knowledge(self, content: str, relevance_score: float):
        """
        Adds a knowledge snippet to the context and invalidates the context cache.

        Args:
            content (str): The content of the knowledge snippet.
            relevance_score (float): The relevance score of the knowledge snippet.
        """
        knowledge = {
            "content": content,
            "relevance_score": relevance_score,
            "timestamp": time.time(),
        }
        self.knowledge.append(knowledge)
        self._update_token_count(
            "knowledge", self.token_usage["knowledge"] + self._count_tokens(content)
        )
        self.context_cache = {}  # Invalidate cache

    def _count_droppable_prefix(
        self, items: List[Dict[str, Any]], current_tokens: int, target_tokens: int
    ) -> Tuple[int, int]:
        """
        Finds the shortest prefix of ``items`` whose removal brings
        ``current_tokens`` down to ``target_tokens`` or below.

        Args:
            items (List[Dict[str, Any]]): Entries with a ``"content"`` key, already
                ordered so that the first ones are the first to drop.
            current_tokens (int): Tokens currently charged for the section.
            target_tokens (int): Desired token count for the section.

        Returns:
            Tuple[int, int]: ``(count, tokens)`` — the prefix length and the number
            of tokens it holds. If removing every item is still not enough, the
            whole list is returned.
        """
        count = 0
        removed = 0
        for item in items:
            if current_tokens - removed <= target_tokens:
                break
            removed += self._count_tokens(item["content"])
            count += 1
        return count, removed

    def _summarize_old_messages(self, target_tokens: int):
        """
        Summarizes the oldest messages in the history to reduce the token count.

        Messages are taken oldest-first until dropping them would bring the
        history at or below ``target_tokens``. Their text is passed to
        ``self.summarizer`` and, on success, they are replaced in place by a single
        message with the role "system" holding the summary. If the summarizer
        raises, the error is printed and the history is left unchanged. If no
        summarizer is configured, a warning is printed (only when compression is
        actually needed) and nothing changes.

        Args:
            target_tokens (int): The target number of tokens for the history.
        """
        if self.token_usage["history"] <= target_tokens:
            return  # No need to summarize

        if self.summarizer is None:
            print(
                "Warning: No summarizer provided. Skipping summarization.  Please provide an AbstractSummarizer implementation."
            )
            return

        # Oldest first, so the messages to compress form a prefix of the list.
        self.history.sort(key=lambda x: x["timestamp"])
        count, tokens_to_remove = self._count_droppable_prefix(
            self.history, self.token_usage["history"], target_tokens
        )
        if count == 0:
            return

        messages_to_summarize = self.history[:count]
        text_to_summarize = "\n".join(
            f"{msg['role']}: {msg['content']}" for msg in messages_to_summarize
        )

        try:
            summary = self.summarizer.summarize(text_to_summarize)
        except Exception as e:
            # Best-effort compression: report and keep the original history intact.
            print(f"Error during summarization: {e}")
            return

        summary_message = {
            "role": "system",
            "content": summary,
            # Inherit the oldest summarized timestamp so the summary keeps its
            # place at the head of the history when it is re-sorted later; a
            # fresh time.time() would sort it after newer messages.
            "timestamp": messages_to_summarize[0]["timestamp"],
        }
        self.history[:count] = [summary_message]

        self._update_token_count(
            "history",
            self.token_usage["history"]
            - tokens_to_remove
            + self._count_tokens(summary),
        )

    def _truncate_retrievals(self, target_tokens: int):
        """
        Drops the least relevant knowledge snippets (oldest first among equal
        scores) until the knowledge section fits ``target_tokens``.

        Args:
            target_tokens (int): The target number of tokens for the knowledge.
        """
        if self.token_usage["knowledge"] <= target_tokens:
            return  # No need to truncate

        # Lowest relevance first, then oldest first: the drop order is a prefix.
        self.knowledge.sort(key=lambda x: (x["relevance_score"], x["timestamp"]))
        count, tokens_to_remove = self._count_droppable_prefix(
            self.knowledge, self.token_usage["knowledge"], target_tokens
        )
        del self.knowledge[:count]

        self._update_token_count(
            "knowledge", self.token_usage["knowledge"] - tokens_to_remove
        )

    def _priority_based_pruning(self, query: str):
        """
        Combines summarization and truncation to reduce the token count while
        preserving the most important information. Summarizes old messages before
        truncating knowledge, then records the query's token usage.

        The per-section targets are fractions of the budget left after the system
        prompt. The query itself is never compressed: if it exceeds its share, the
        history and knowledge targets are scaled down proportionally so that all
        three fit in the budget.

        Args:
            query (str): The user query, used to estimate the query token size.
        """
        query_token_count = self._count_tokens(query)

        # The system prompt is always present, so only the remainder is shared out.
        budget = max(0, self.max_tokens - self.token_usage["system_prompt"])
        history_target_tokens = int(budget * self.history_fraction)
        knowledge_target_tokens = int(budget * self.knowledge_fraction)
        query_target_tokens = int(budget * self.query_fraction)

        if query_token_count > query_target_tokens:
            combined = history_target_tokens + knowledge_target_tokens
            leftover = max(0, budget - query_token_count)
            if combined > leftover:
                # combined > leftover >= 0, so the division is safe.
                scale = leftover / combined
                history_target_tokens = int(history_target_tokens * scale)
                knowledge_target_tokens = int(knowledge_target_tokens * scale)

        # Summarize history if needed
        self._summarize_old_messages(history_target_tokens)

        # Truncate knowledge if needed
        self._truncate_retrievals(knowledge_target_tokens)

        # Record the query in token_usage so it is counted exactly once, then
        # resync available_tokens with the section counters.
        self._update_token_count("query", query_token_count)
        self._recompute_available_tokens()

    def build_context(self, query: str) -> str:
        """
        Builds the context string for the language model.

        The layout is: system prompt, knowledge snippets (most relevant and most
        recent first, each prefixed with "Knowledge: "), history messages as
        "role: content", and finally "User: <query>", joined by newlines.

        Args:
            query (str): The user query.

        Returns:
            str: The context string.
        """
        # Check if the context is already cached
        cache_key = self._generate_cache_key(query)
        if cache_key in self.context_cache:
            # Keep query accounting in sync with the query actually being served.
            self._update_token_count("query", self._count_tokens(query))
            return self.context_cache[cache_key]

        # 1. Prune the context to fit within the token limit (also records the
        #    query's token usage).
        self._priority_based_pruning(query)

        # 2. Sort knowledge snippets by relevance and recency
        self.knowledge.sort(
            key=lambda x: (x["relevance_score"], x["timestamp"]), reverse=True
        )

        # 3. Build the context string
        context_parts = [self.system_prompt]
        context_parts.extend(f"Knowledge: {k['content']}" for k in self.knowledge)
        context_parts.extend(f"{msg['role']}: {msg['content']}" for msg in self.history)
        context_parts.append(f"User: {query}")  # User query at the end

        context = "\n".join(context_parts)

        # Cache the context
        self.context_cache[cache_key] = context

        return context

    def _generate_cache_key(self, query: str) -> Tuple[Any, ...]:
        """
        Generates a cache key based on the current conversation state.

        The key is the state tuple itself rather than its ``hash()``, so two
        different states can never collide on the same cache entry.

        Args:
            query (str): The user query.

        Returns:
            Tuple[Any, ...]: A hashable snapshot of system prompt, history,
            knowledge and query.
        """
        history_state = tuple((msg["role"], msg["content"]) for msg in self.history)
        knowledge_state = tuple(
            (knowledge["content"], knowledge["relevance_score"])
            for knowledge in self.knowledge
        )
        return (self.system_prompt, history_state, knowledge_state, query)

