import json
import logging
import re
from typing import List, Optional

from pydantic import BaseModel

from mem0.memory.utils import extract_json
from mem0.vector_stores.base import VectorStoreBase

try:
    from azure.core.credentials import AzureKeyCredential
    from azure.core.exceptions import ResourceNotFoundError
    from azure.identity import DefaultAzureCredential
    from azure.search.documents import SearchClient
    from azure.search.documents.indexes import SearchIndexClient
    from azure.search.documents.indexes.models import (
        BinaryQuantizationCompression,
        HnswAlgorithmConfiguration,
        ScalarQuantizationCompression,
        SearchField,
        SearchFieldDataType,
        SearchIndex,
        SimpleField,
        VectorSearch,
        VectorSearchProfile,
    )
    from azure.search.documents.models import VectorizedQuery
except ImportError:
    raise ImportError(
        "The 'azure-search-documents' library is required. Please install it using 'pip install azure-search-documents==11.5.2'."
    )

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]
    score: Optional[float]
    payload: Optional[dict]


class AzureAISearch(VectorStoreBase):
    def __init__(
        self,
        service_name,
        collection_name,
        api_key,
        embedding_model_dims,
        compression_type: Optional[str] = None,
        use_float16: bool = False,
        hybrid_search: bool = False,
        vector_filter_mode: Optional[str] = None,
    ):
        """
        Initialize the Azure AI Search vector store.

        Args:
            service_name (str): Azure AI Search service name.
            collection_name (str): Index name.
            api_key (str): API key for the Azure AI Search service.
            embedding_model_dims (int): Dimension of the embedding vector.
            compression_type (Optional[str]): Specifies the type of quantization to use.
                Allowed values are None (no quantization), "scalar", or "binary".
            use_float16 (bool): Whether to store vectors in half precision (Edm.Half) or full precision (Edm.Single).
                (Note: This flag is preserved from the initial implementation per feedback.)
            hybrid_search (bool): Whether to use hybrid search. Default is False.
            vector_filter_mode (Optional[str]): Mode for vector filtering. Default is "preFilter".
        """
        self.service_name = service_name
        self.api_key = api_key
        self.index_name = collection_name
        self.collection_name = collection_name
        self.embedding_model_dims = embedding_model_dims
        # If compression_type is None, treat it as "none".
        self.compression_type = (compression_type or "none").lower()
        self.use_float16 = use_float16
        self.hybrid_search = hybrid_search
        self.vector_filter_mode = vector_filter_mode

        # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
        if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
            credential = DefaultAzureCredential()
            self.api_key = None
        else:
            credential = AzureKeyCredential(self.api_key)

        self.search_client = SearchClient(
            endpoint=f"https://{service_name}.search.windows.net",
            index_name=self.index_name,
            credential=credential,
        )
        self.index_client = SearchIndexClient(
            endpoint=f"https://{service_name}.search.windows.net",
            credential=credential,
        )

        self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
        self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

        collections = self.list_cols()
        if collection_name not in collections:
            self.create_col()

    def create_col(self):
        """Create a new index in Azure AI Search."""
        # Determine vector type based on use_float16 setting.
        if self.use_float16:
            vector_type = "Collection(Edm.Half)"
        else:
            vector_type = "Collection(Edm.Single)"

        # Configure compression settings based on the specified compression_type.
        compression_configurations = []
        compression_name = None
        if self.compression_type == "scalar":
            compression_name = "myCompression"
            # For SQ, rescoring defaults to True and oversampling defaults to 4.
            compression_configurations = [
                ScalarQuantizationCompression(
                    compression_name=compression_name
                    # rescoring defaults to True and oversampling defaults to 4
                )
            ]
        elif self.compression_type == "binary":
            compression_name = "myCompression"
            # For BQ, rescoring defaults to True and oversampling defaults to 10.
            compression_configurations = [
                BinaryQuantizationCompression(
                    compression_name=compression_name
                    # rescoring defaults to True and oversampling defaults to 10
                )
            ]
        # If no compression is desired, compression_configurations remains empty.
        fields = [
            SimpleField(name="id", type=SearchFieldDataType.String, key=True),
            SimpleField(name="user_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="run_id", type=SearchFieldDataType.String, filterable=True),
            SimpleField(name="agent_id", type=SearchFieldDataType.String, filterable=True),
            SearchField(
                name="vector",
                type=vector_type,
                searchable=True,
                vector_search_dimensions=self.embedding_model_dims,
                vector_search_profile_name="my-vector-config",
            ),
            SearchField(name="payload", type=SearchFieldDataType.String, searchable=True),
        ]

        vector_search = VectorSearch(
            profiles=[
                VectorSearchProfile(
                    name="my-vector-config",
                    algorithm_configuration_name="my-algorithms-config",
                    compression_name=compression_name if self.compression_type != "none" else None,
                )
            ],
            algorithms=[HnswAlgorithmConfiguration(name="my-algorithms-config")],
            compressions=compression_configurations,
        )
        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        self.index_client.create_or_update_index(index)

    def _generate_document(self, vector, payload, id):
        document = {"id": id, "vector": vector, "payload": json.dumps(payload)}
        # Extract additional fields if they exist.
        for field in ["user_id", "run_id", "agent_id"]:
            if field in payload:
                document[field] = payload[field]
        return document

    # Note: Explicit "insert" calls may later be decoupled from memory management decisions.
    def insert(self, vectors, payloads=None, ids=None):
        """
        Insert vectors into the index.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        logger.info(f"Inserting {len(vectors)} vectors into index {self.index_name}")
        documents = [
            self._generate_document(vector, payload, id) for id, vector, payload in zip(ids, vectors, payloads)
        ]
        response = self.search_client.upload_documents(documents)
        for doc in response:
            if not hasattr(doc, "status_code") and doc.get("status_code") != 201:
                raise Exception(f"Insert failed for document {doc.get('id')}: {doc}")
        return response

    def _sanitize_key(self, key: str) -> str:
        return re.sub(r"[^\w]", "", key)

    def _build_filter_expression(self, filters):
        filter_conditions = []
        for key, value in filters.items():
            safe_key = self._sanitize_key(key)
            if isinstance(value, str):
                safe_value = value.replace("'", "''")
                condition = f"{safe_key} eq '{safe_value}'"
            else:
                condition = f"{safe_key} eq {value}"
            filter_conditions.append(condition)
        filter_expression = " and ".join(filter_conditions)
        return filter_expression

    def search(self, query, vectors, limit=5, filters=None):
        """
        Search for similar vectors.

        Args:
            query (str): Query.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            List[OutputData]: Search results.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        vector_query = VectorizedQuery(vector=vectors, k_nearest_neighbors=limit, fields="vector")
        if self.hybrid_search:
            search_results = self.search_client.search(
                search_text=query,
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
                search_fields=["payload"],
            )
        else:
            search_results = self.search_client.search(
                vector_queries=[vector_query],
                filter=filter_expression,
                top=limit,
                vector_filter_mode=self.vector_filter_mode,
            )

        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return results

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        response = self.search_client.delete_documents(documents=[{"id": vector_id}])
        for doc in response:
            if not hasattr(doc, "status_code") and doc.get("status_code") != 200:
                raise Exception(f"Delete failed for document {vector_id}: {doc}")
        logger.info(f"Deleted document with ID '{vector_id}' from index '{self.index_name}'.")
        return response

    def update(self, vector_id, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        document = {"id": vector_id}
        if vector:
            document["vector"] = vector
        if payload:
            json_payload = json.dumps(payload)
            document["payload"] = json_payload
            for field in ["user_id", "run_id", "agent_id"]:
                document[field] = payload.get(field)
        response = self.search_client.merge_or_upload_documents(documents=[document])
        for doc in response:
            if not hasattr(doc, "status_code") and doc.get("status_code") != 200:
                raise Exception(f"Update failed for document {vector_id}: {doc}")
        return response

    def get(self, vector_id) -> OutputData:
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        try:
            result = self.search_client.get_document(key=vector_id)
        except ResourceNotFoundError:
            return None
        payload = json.loads(extract_json(result["payload"]))
        return OutputData(id=result["id"], score=None, payload=payload)

    def list_cols(self) -> List[str]:
        """
        List all collections (indexes).

        Returns:
            List[str]: List of index names.
        """
        try:
            names = self.index_client.list_index_names()
        except AttributeError:
            names = [index.name for index in self.index_client.list_indexes()]
        return names

    def delete_col(self):
        """Delete the index."""
        self.index_client.delete_index(self.index_name)

    def col_info(self):
        """
        Get information about the index.

        Returns:
            dict: Index information.
        """
        index = self.index_client.get_index(self.index_name)
        return {"name": index.name, "fields": index.fields}

    def list(self, filters=None, limit=100):
        """
        List all vectors in the index.

        Args:
            filters (dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        filter_expression = None
        if filters:
            filter_expression = self._build_filter_expression(filters)

        search_results = self.search_client.search(search_text="*", filter=filter_expression, top=limit)
        results = []
        for result in search_results:
            payload = json.loads(extract_json(result["payload"]))
            results.append(OutputData(id=result["id"], score=result["@search.score"], payload=payload))
        return [results]

    def __del__(self):
        """Close the search client when the object is deleted."""
        self.search_client.close()
        self.index_client.close()

    def reset(self):
        """Reset the index by deleting and recreating it."""
        logger.warning(f"Resetting index {self.index_name}...")

        try:
            # Close the existing clients
            self.search_client.close()
            self.index_client.close()

            # Delete the collection
            self.delete_col()

            # If the API key is not provided or is a placeholder, use DefaultAzureCredential.
            if self.api_key is None or self.api_key == "" or self.api_key == "your-api-key":
                credential = DefaultAzureCredential()
                self.api_key = None
            else:
                credential = AzureKeyCredential(self.api_key)

            # Reinitialize the clients
            service_endpoint = f"https://{self.service_name}.search.windows.net"
            self.search_client = SearchClient(
                endpoint=service_endpoint,
                index_name=self.index_name,
                credential=credential,
            )
            self.index_client = SearchIndexClient(
                endpoint=service_endpoint,
                credential=credential,
            )

            # Add user agent
            self.search_client._client._config.user_agent_policy.add_user_agent("mem0")
            self.index_client._client._config.user_agent_policy.add_user_agent("mem0")

            # Create the collection
            self.create_col()
        except Exception as e:
            logger.error(f"Error resetting index {self.index_name}: {e}")
            raise
