import logging
import traceback
import uuid
from typing import Any, Dict, List, Optional, Tuple

import google.api_core.exceptions
from google.cloud import aiplatform, aiplatform_v1
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import Namespace
from google.oauth2 import service_account
from pydantic import BaseModel

try:
    from langchain_core.documents import Document
except ImportError:  # pragma: no cover - fallback for older LangChain versions
    from langchain.schema import Document  # type: ignore[no-redef]

from mem0.configs.vector_stores.vertex_ai_vector_search import (
    GoogleMatchingEngineConfig,
)
from mem0.vector_stores.base import VectorStoreBase

# Configure logging
# NOTE(review): calling basicConfig(level=DEBUG) at import time configures the
# application's root logger as a side effect; library modules normally only
# create a module logger and leave configuration to the host app — confirm
# this is intentional.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    """Normalized result record returned by GoogleMatchingEngine queries."""

    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class GoogleMatchingEngine(VectorStoreBase):
    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Google Matching Engine client.

        Keyword arguments are validated against GoogleMatchingEngineConfig
        (project_id, project_number, region, endpoint_id, index_id,
        deployment_index_id / collection_name, optional credentials fields).

        Raises:
            ValueError: If the index or endpoint cannot be initialized from
                the resolved configuration.
        """
        logger.debug("Initializing Google Matching Engine with kwargs: %s", kwargs)

        # collection_name and deployment_index_id are treated as aliases:
        # whichever one the caller supplies is mirrored into the other before
        # config validation so both attributes end up populated.
        # If collection_name is passed, use it as deployment_index_id if deployment_index_id is not provided
        if "collection_name" in kwargs and "deployment_index_id" not in kwargs:
            kwargs["deployment_index_id"] = kwargs["collection_name"]
            logger.debug("Using collection_name as deployment_index_id: %s", kwargs["deployment_index_id"])
        elif "deployment_index_id" in kwargs and "collection_name" not in kwargs:
            kwargs["collection_name"] = kwargs["deployment_index_id"]
            logger.debug("Using deployment_index_id as collection_name: %s", kwargs["collection_name"])

        try:
            config = GoogleMatchingEngineConfig(**kwargs)
            logger.debug("Config created: %s", config.model_dump())
            logger.debug("Config collection_name: %s", getattr(config, "collection_name", None))
        except Exception as e:
            logger.error("Failed to validate config: %s", str(e))
            raise

        self.project_id = config.project_id
        self.project_number = config.project_number
        self.region = config.region
        self.endpoint_id = config.endpoint_id
        self.index_id = config.index_id  # The actual index ID
        self.deployment_index_id = config.deployment_index_id  # The deployment-specific ID
        self.collection_name = config.collection_name
        self.vector_search_api_endpoint = config.vector_search_api_endpoint

        logger.debug("Using project=%s, location=%s", self.project_id, self.region)

        # Initialize Vertex AI with credentials if provided
        init_args = {
            "project": self.project_id,
            "location": self.region,
        }

        # Support both credentials_path and service_account_json;
        # credentials_path takes precedence when both are set.
        if hasattr(config, "credentials_path") and config.credentials_path:
            logger.debug("Using credentials from file: %s", config.credentials_path)
            credentials = service_account.Credentials.from_service_account_file(config.credentials_path)
            init_args["credentials"] = credentials
        elif hasattr(config, "service_account_json") and config.service_account_json:
            logger.debug("Using credentials from provided JSON dict")
            credentials = service_account.Credentials.from_service_account_info(config.service_account_json)
            init_args["credentials"] = credentials

        try:
            aiplatform.init(**init_args)
            logger.debug("Vertex AI initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Vertex AI: %s", str(e))
            raise

        try:
            # Format the index path properly using the configured index_id.
            # NOTE: the resource path uses project_number, not project_id.
            index_path = f"projects/{self.project_number}/locations/{self.region}/indexes/{self.index_id}"
            logger.debug("Initializing index with path: %s", index_path)
            self.index = aiplatform.MatchingEngineIndex(index_name=index_path)
            logger.debug("Index initialized successfully")

            # Format the endpoint name properly
            endpoint_name = self.endpoint_id
            logger.debug("Initializing endpoint with name: %s", endpoint_name)
            self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(index_endpoint_name=endpoint_name)
            logger.debug("Endpoint initialized successfully")
        except Exception as e:
            logger.error("Failed to initialize Matching Engine components: %s", str(e))
            raise ValueError(f"Invalid configuration: {str(e)}")

    def _parse_output(self, data: Dict) -> List[OutputData]:
        """Parse a raw find-neighbors response dict into OutputData records.

        Args:
            data (Dict): Raw response containing ``nearestNeighbors.neighbors``.

        Returns:
            List[OutputData]: One record per neighbor. A neighbor with a
            missing ``datapoint`` entry yields ``id=None``/``payload=None``
            instead of raising (bug fix: the previous code called ``.get``
            on a possibly-absent dict, triggering AttributeError).
        """
        neighbors = data.get("nearestNeighbors", {}).get("neighbors", [])
        output_data = []
        for result in neighbors:
            # Look the datapoint up once and tolerate its absence.
            datapoint = result.get("datapoint") or {}
            output_data.append(
                OutputData(
                    id=datapoint.get("datapointId"),
                    score=result.get("distance"),
                    payload=datapoint.get("metadata"),
                )
            )
        return output_data

    def _create_restriction(self, key: str, value: Any) -> aiplatform_v1.types.index.IndexDatapoint.Restriction:
        """Build a single allow-list restriction for one metadata key/value.

        Args:
            key: Namespace (metadata key) the restriction applies to.
            value: Value to allow; ``None`` is encoded as an empty string.

        Returns:
            Restriction whose allow_list holds the stringified value.
        """
        token = "" if value is None else str(value)
        return aiplatform_v1.types.index.IndexDatapoint.Restriction(
            namespace=key,
            allow_list=[token],
        )

    def _create_datapoint(
        self, vector_id: str, vector: List[float], payload: Optional[Dict] = None
    ) -> aiplatform_v1.types.index.IndexDatapoint:
        """Assemble an IndexDatapoint, encoding payload entries as restrictions.

        Args:
            vector_id: ID to assign to the datapoint.
            vector: Embedding values to store.
            payload: Optional metadata; each key/value becomes one restriction.

        Returns:
            IndexDatapoint ready for upsert.
        """
        # An empty/None payload yields no restrictions at all.
        restricts = []
        if payload:
            for meta_key, meta_value in payload.items():
                restricts.append(self._create_restriction(meta_key, meta_value))

        return aiplatform_v1.types.index.IndexDatapoint(
            datapoint_id=vector_id,
            feature_vector=vector,
            restricts=restricts,
        )

    def insert(
        self,
        vectors: List[list],
        payloads: Optional[List[Dict]] = None,
        ids: Optional[List[str]] = None,
    ) -> None:
        """Upsert a batch of vectors into the Matching Engine index.

        Args:
            vectors: Embeddings to insert.
            payloads: Optional metadata dicts, one per vector.
            ids: Optional IDs, one per vector; UUIDs are generated otherwise.

        Raises:
            ValueError: If vectors is empty or list lengths disagree.
            GoogleAPIError: If the upsert call fails.
        """
        if not vectors:
            raise ValueError("No vectors provided for insertion")

        if payloads and len(payloads) != len(vectors):
            raise ValueError(f"Number of payloads ({len(payloads)}) does not match number of vectors ({len(vectors)})")

        if ids and len(ids) != len(vectors):
            raise ValueError(f"Number of ids ({len(ids)}) does not match number of vectors ({len(vectors)})")

        logger.debug("Starting insert of %d vectors", len(vectors))

        try:
            datapoints = []
            for position, embedding in enumerate(vectors):
                point_id = ids[position] if ids else str(uuid.uuid4())
                meta = payloads[position] if payloads and position < len(payloads) else None
                datapoints.append(
                    self._create_datapoint(vector_id=point_id, vector=embedding, payload=meta)
                )

            logger.debug("Created %d datapoints", len(datapoints))
            self.index.upsert_datapoints(datapoints=datapoints)
            logger.debug("Successfully inserted datapoints")

        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error("Failed to insert vectors: %s", str(e))
            raise
        except Exception as e:
            logger.error("Unexpected error during insert: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def search(
        self, query: str, vectors: List[float], limit: int = 5, filters: Optional[Dict] = None
    ) -> List[OutputData]:
        """Find the nearest neighbors of an embedding on the deployed index.

        Args:
            query (str): Textual query; not used by the lookup itself,
                kept for interface compatibility.
            vectors (List[float]): Query embedding.
            limit (int, optional): Maximum neighbors to return. Defaults to 5.
            filters (Optional[Dict], optional): Metadata filters. Scalar
                values become single-token allow lists; dict values may carry
                "include"/"exclude" token lists. Defaults to None.

        Returns:
            List[OutputData]: Neighbors with their restrict metadata unpacked
            into each payload (unwrapped).
        """
        logger.debug("Starting search")
        logger.debug("Limit: %d, Filters: %s", limit, filters)

        try:
            filter_namespaces = []
            if filters:
                logger.debug("Processing filters")
                for key, value in filters.items():
                    logger.debug("Processing filter %s=%s (type=%s)", key, value, type(value))
                    if isinstance(value, (str, int, float)):
                        logger.debug("Adding simple filter for %s", key)
                        filter_namespaces.append(Namespace(key, [str(value)], []))
                    elif isinstance(value, dict):
                        logger.debug("Adding complex filter for %s", key)
                        filter_namespaces.append(
                            Namespace(key, value.get("include", []), value.get("exclude", []))
                        )

            logger.debug("Final filter_namespaces: %s", filter_namespaces)

            response = self.index_endpoint.find_neighbors(
                deployed_index_id=self.deployment_index_id,
                queries=[vectors],
                num_neighbors=limit,
                filter=filter_namespaces or None,
                return_full_datapoint=True,
            )

            # find_neighbors returns one neighbor list per query; we send one.
            if not response or len(response[0]) == 0:
                logger.debug("No results found")
                return []

            parsed = []
            for neighbor in response[0]:
                logger.debug("Processing neighbor - id: %s, distance: %s", neighbor.id, neighbor.distance)

                # Rebuild the payload from the datapoint's restrict tokens.
                payload = {}
                if hasattr(neighbor, "restricts"):
                    logger.debug("Processing restricts")
                    for restrict in neighbor.restricts:
                        if hasattr(restrict, "name") and hasattr(restrict, "allow_tokens") and restrict.allow_tokens:
                            logger.debug("Adding %s: %s", restrict.name, restrict.allow_tokens[0])
                            payload[restrict.name] = restrict.allow_tokens[0]

                parsed.append(OutputData(id=neighbor.id, score=neighbor.distance, payload=payload))

            logger.debug("Returning %d results", len(parsed))
            return parsed

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def delete(self, vector_id: Optional[str] = None, ids: Optional[List[str]] = None) -> bool:
        """Remove datapoints from the Matching Engine index.

        Args:
            vector_id (Optional[str]): Single ID (kept for backward compatibility).
            ids (Optional[List[str]]): Batch of IDs to remove.

        Returns:
            bool: True on success or if the datapoints were already gone;
            False on permission/argument errors or any unexpected failure.
        """
        logger.debug("Starting delete, vector_id: %s, ids: %s", vector_id, ids)
        try:
            # Normalize the two calling conventions into one ID list.
            if vector_id:
                targets = [vector_id]
            elif ids:
                targets = ids
            else:
                raise ValueError("Either vector_id or ids must be provided")

            logger.debug("Deleting ids: %s", targets)
            try:
                self.index.remove_datapoints(datapoint_ids=targets)
            except google.api_core.exceptions.NotFound:
                # Already absent — treat as a successful delete.
                logger.debug("Datapoint already deleted")
                return True
            except google.api_core.exceptions.PermissionDenied as e:
                logger.error("Permission denied: %s", str(e))
                return False
            except google.api_core.exceptions.InvalidArgument as e:
                logger.error("Invalid argument: %s", str(e))
                return False
            logger.debug("Delete completed successfully")
            return True

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            return False

    def update(
        self,
        vector_id: str,
        vector: Optional[List[float]] = None,
        payload: Optional[Dict] = None,
    ) -> bool:
        """Update a vector and/or its payload by re-upserting the datapoint.

        Args:
            vector_id: ID of the vector to update
            vector: Optional new vector values
            payload: Optional new metadata payload

        Returns:
            bool: True if update was successful; False if the ID does not
            exist or the API call fails.

        Raises:
            ValueError: If neither vector nor payload is provided

        NOTE(review): when only ``payload`` is supplied, the upsert below
        sends an empty ``feature_vector``; depending on Matching Engine's
        upsert semantics this may overwrite the stored embedding with an
        empty vector — confirm before relying on payload-only updates.
        """
        logger.debug("Starting update for vector_id: %s", vector_id)

        if vector is None and payload is None:
            raise ValueError("Either vector or payload must be provided for update")

        # First check if the vector exists
        try:
            existing = self.get(vector_id)
            if existing is None:
                logger.error("Vector ID not found: %s", vector_id)
                return False

            # Empty list stands in for "no vector change" — see NOTE above.
            datapoint = self._create_datapoint(
                vector_id=vector_id, vector=vector if vector is not None else [], payload=payload
            )

            logger.debug("Upserting datapoint: %s", datapoint)
            self.index.upsert_datapoints(datapoints=[datapoint])
            logger.debug("Update completed successfully")
            return True

        except google.api_core.exceptions.GoogleAPIError as e:
            logger.error("API error during update: %s", str(e))
            return False
        except Exception as e:
            logger.error("Unexpected error during update: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def get(self, vector_id: str) -> Optional[OutputData]:
        """
        Retrieve a vector by ID via the low-level MatchService client.

        Implemented as a find-neighbors lookup that uses the target
        datapoint itself as the query with neighbor_count=1.

        Args:
            vector_id (str): ID of the vector to retrieve.
        Returns:
            Optional[OutputData]: Retrieved vector or None if not found
            (including NotFound / PermissionDenied responses).
        Raises:
            ValueError: If vector_search_api_endpoint is not configured.
        """
        logger.debug("Starting get for vector_id: %s", vector_id)

        try:
            # The public endpoint hostname is required to reach MatchService.
            if not self.vector_search_api_endpoint:
                raise ValueError("vector_search_api_endpoint is required for get operation")

            vector_search_client = aiplatform_v1.MatchServiceClient(
                client_options={"api_endpoint": self.vector_search_api_endpoint},
            )
            datapoint = aiplatform_v1.IndexDatapoint(datapoint_id=vector_id)

            query = aiplatform_v1.FindNeighborsRequest.Query(datapoint=datapoint, neighbor_count=1)
            request = aiplatform_v1.FindNeighborsRequest(
                index_endpoint=f"projects/{self.project_number}/locations/{self.region}/indexEndpoints/{self.endpoint_id}",
                deployed_index_id=self.deployment_index_id,
                queries=[query],
                return_full_datapoint=True,
            )

            try:
                response = vector_search_client.find_neighbors(request)
                logger.debug("Got response")

                if response and response.nearest_neighbors:
                    nearest = response.nearest_neighbors[0]
                    if nearest.neighbors:
                        neighbor = nearest.neighbors[0]

                        # Rebuild the payload from the stored restrictions.
                        payload = {}
                        if hasattr(neighbor.datapoint, "restricts"):
                            for restrict in neighbor.datapoint.restricts:
                                if restrict.allow_list:
                                    payload[restrict.namespace] = restrict.allow_list[0]

                        return OutputData(id=neighbor.datapoint.datapoint_id, score=neighbor.distance, payload=payload)

                logger.debug("No results found")
                return None

            except google.api_core.exceptions.NotFound:
                logger.debug("Datapoint not found")
                return None
            except google.api_core.exceptions.PermissionDenied as e:
                logger.error("Permission denied: %s", str(e))
                return None

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            logger.error("Error type: %s", type(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def list_cols(self) -> List[str]:
        """Return the names of the known collections (indexes).

        Only the configured deployment index is visible, so the result
        always contains exactly that one entry.
        """
        return [self.deployment_index_id]

    def delete_col(self):
        """Delete the collection (index).

        Not supported via this API — logs a warning and returns without
        doing anything.
        """
        logger.warning("Delete collection operation is not supported for Google Matching Engine")

    def col_info(self) -> Dict:
        """Describe the collection (index) this store is bound to.

        Returns:
            Dict: index_id, endpoint_id, project_id and region.
        """
        return dict(
            index_id=self.index_id,
            endpoint_id=self.endpoint_id,
            project_id=self.project_id,
            region=self.region,
        )

    def list(self, filters: Optional[Dict] = None, limit: Optional[int] = None) -> List[List[OutputData]]:
        """List vectors matching the given filters.

        Runs a nearest-neighbor search with a zero vector so that only the
        restriction filters determine the result set.

        Args:
            filters: Optional metadata filters to apply.
            limit: Optional maximum number of results (defaults to 10000).

        Returns:
            List[List[OutputData]]: Matching vectors wrapped in an extra
            list to match the interface.
        """
        logger.debug("Starting list operation")
        logger.debug("Filters: %s", filters)
        logger.debug("Limit: %s", limit)

        try:
            # Zero vector serves purely as a placeholder query.
            # NOTE(review): the dimension is hard-coded and must match the
            # deployed index's embedding size — make configurable if the
            # embedding model varies.
            dimension = 768
            zero_vector = [0.0] * dimension

            # Use a large limit if none specified.
            search_limit = limit if limit is not None else 10000

            # Bug fix: the embedding must be passed as `vectors`. The old
            # call passed it as `query` and omitted the required `vectors`
            # argument, so every invocation raised TypeError.
            results = self.search(query="", vectors=zero_vector, limit=search_limit, filters=filters)

            logger.debug("Found %d results", len(results))
            return [results]  # Wrap in extra array to match interface

        except Exception as e:
            logger.error("Error in list operation: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    def create_col(self, name=None, vector_size=None, distance=None):
        """No-op collection creation.

        Matching Engine indexes are provisioned out-of-band (Google Cloud
        Console or the admin API), so there is nothing to create here; the
        method exists only to satisfy the abstract base class.

        Args:
            name: Ignored.
            vector_size: Ignored.
            distance: Ignored.
        """
        # Intentionally a no-op — indexes are pre-created in Google Cloud.
        return None

    def add(self, text: str, metadata: Optional[Dict] = None, user_id: Optional[str] = None) -> str:
        """Embed a single text and insert it with its metadata.

        Args:
            text: Raw text to store; saved in the payload under "data".
            metadata: Extra metadata merged into the payload.
            user_id: Optional owner ID stored alongside the text.

        Returns:
            str: The generated UUID of the inserted datapoint.

        NOTE(review): relies on ``self.embedder``, which is never assigned
        in ``__init__`` within this file — presumably injected by the
        caller after construction; verify before use.
        """
        logger.debug("Starting add operation")
        logger.debug("Text: %s", text)
        logger.debug("Metadata: %s", metadata)
        logger.debug("User ID: %s", user_id)

        try:
            # Generate a unique ID for this entry
            vector_id = str(uuid.uuid4())

            # Create the payload with all necessary fields
            payload = {
                "data": text,  # Store the text in the data field
                "user_id": user_id,
                **(metadata or {}),
            }

            # Get the embedding
            vector = self.embedder.embed_query(text)

            # Insert using the insert method
            self.insert(vectors=[vector], payloads=[payload], ids=[vector_id])

            return vector_id

        except Exception as e:
            logger.error("Error occurred: %s", str(e))
            raise

    def add_texts(
        self,
        texts: List[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """Embed and insert a batch of texts.

        Args:
            texts: List of texts to add
            metadatas: Optional list of metadata dicts
            ids: Optional list of IDs to use

        Returns:
            List[str]: List of IDs of the added texts

        Raises:
            ValueError: If texts is empty or lengths don't match

        NOTE(review): relies on ``self.embedder``, which is never assigned
        in ``__init__`` within this file — presumably injected by the
        caller; verify before use.
        """
        if not texts:
            raise ValueError("No texts provided")

        if metadatas and len(metadatas) != len(texts):
            raise ValueError(
                f"Number of metadata items ({len(metadatas)}) does not match number of texts ({len(texts)})"
            )

        if ids and len(ids) != len(texts):
            raise ValueError(f"Number of ids ({len(ids)}) does not match number of texts ({len(texts)})")

        logger.debug("Starting add_texts operation")
        logger.debug("Number of texts: %d", len(texts))
        logger.debug("Has metadatas: %s", metadatas is not None)
        logger.debug("Has ids: %s", ids is not None)

        if ids is None:
            ids = [str(uuid.uuid4()) for _ in texts]

        try:
            # Get embeddings
            embeddings = self.embedder.embed_documents(texts)

            # Bug fix: `[{}] * len(texts)` produced N references to ONE
            # shared dict, so a mutation through any payload would show up
            # in all of them; build distinct dicts instead.
            payloads = metadatas if metadatas else [{} for _ in texts]
            self.insert(vectors=embeddings, payloads=payloads, ids=ids)
            return ids

        except Exception as e:
            logger.error("Error in add_texts: %s", str(e))
            logger.error("Stack trace: %s", traceback.format_exc())
            raise

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Any,
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> "GoogleMatchingEngine":
        """Construct a store from raw texts and index them immediately.

        Note: ``embedding`` is accepted for interface compatibility; the
        embedding itself happens inside ``add_texts``.
        """
        logger.debug("Creating instance from texts")
        instance = cls(**kwargs)
        instance.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return instance

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 5,
        filter: Optional[Dict] = None,
    ) -> List[Tuple[Document, float]]:
        """Return documents most similar to the query, with their scores.

        Args:
            query: Text to embed and search for.
            k: Number of results to return.
            filter: Optional metadata filters passed through to ``search``.

        Returns:
            List of (Document, score) tuples.
        """
        logger.debug("Starting similarity search with score")
        logger.debug("Query: %s", query)
        logger.debug("k: %d", k)
        logger.debug("Filter: %s", filter)

        embedding = self.embedder.embed_query(query)
        results = self.search(query=query, vectors=embedding, limit=k, filters=filter)

        docs_and_scores = []
        for result in results:
            # Bug fix: payload may be None, and `add` stores the raw text
            # under "data" (not "text"); prefer "data" and keep "text" as a
            # fallback for payloads written by other producers.
            payload = result.payload or {}
            content = payload.get("data", payload.get("text", ""))
            docs_and_scores.append((Document(page_content=content, metadata=payload), result.score))

        logger.debug("Found %d results", len(docs_and_scores))
        return docs_and_scores

    def similarity_search(
        self,
        query: str,
        k: int = 5,
        filter: Optional[Dict] = None,
    ) -> List[Document]:
        """Return the documents most similar to the query (scores dropped)."""
        logger.debug("Starting similarity search")
        scored = self.similarity_search_with_score(query, k, filter)
        return [document for document, _score in scored]

    def reset(self):
        """Reset the index contents.

        Not supported by Matching Engine; logs a warning and does nothing.
        """
        logger.warning("Reset operation is not supported for Google Matching Engine")
