import logging
import time
from typing import Dict, Optional

from pydantic import BaseModel

from mem0.vector_stores.base import VectorStoreBase

try:
    import pymochow
    from pymochow.auth.bce_credentials import BceCredentials
    from pymochow.configuration import Configuration
    from pymochow.exception import ServerError
    from pymochow.model.enum import (
        FieldType,
        IndexType,
        MetricType,
        ServerErrCode,
        TableState,
    )
    from pymochow.model.schema import (
        AutoBuildRowCountIncrement,
        Field,
        FilteringIndex,
        HNSWParams,
        Schema,
        VectorIndex,
    )
    from pymochow.model.table import (
        FloatVector,
        Partition,
        Row,
        VectorSearchConfig,
        VectorTopkSearchRequest,
    )
except ImportError:
    raise ImportError("The 'pymochow' library is required. Please install it using 'pip install pymochow'.")

logger = logging.getLogger(__name__)


class OutputData(BaseModel):
    id: Optional[str]  # memory id
    score: Optional[float]  # distance
    payload: Optional[Dict]  # metadata


class BaiduDB(VectorStoreBase):
    def __init__(
        self,
        endpoint: str,
        account: str,
        api_key: str,
        database_name: str,
        table_name: str,
        embedding_model_dims: int,
        metric_type: MetricType,
    ) -> None:
        """Initialize the BaiduDB database.

        Args:
            endpoint (str): Endpoint URL for Baidu VectorDB.
            account (str): Account for Baidu VectorDB.
            api_key (str): API Key for Baidu VectorDB.
            database_name (str): Name of the database.
            table_name (str): Name of the table.
            embedding_model_dims (int): Dimensions of the embedding model.
            metric_type (MetricType): Metric type for similarity search.
        """
        self.endpoint = endpoint
        self.account = account
        self.api_key = api_key
        self.database_name = database_name
        self.table_name = table_name
        self.embedding_model_dims = embedding_model_dims
        self.metric_type = metric_type

        # Initialize Mochow client
        config = Configuration(credentials=BceCredentials(account, api_key), endpoint=endpoint)
        self.client = pymochow.MochowClient(config)

        # Ensure database and table exist
        self._create_database_if_not_exists()
        self.create_col(
            name=self.table_name,
            vector_size=self.embedding_model_dims,
            distance=self.metric_type,
        )

    def _create_database_if_not_exists(self):
        """Create database if it doesn't exist."""
        try:
            # Check if database exists
            databases = self.client.list_databases()
            db_exists = any(db.database_name == self.database_name for db in databases)
            if not db_exists:
                self._database = self.client.create_database(self.database_name)
                logger.info(f"Created database: {self.database_name}")
            else:
                self._database = self.client.database(self.database_name)
                logger.info(f"Database {self.database_name} already exists")
        except Exception as e:
            logger.error(f"Error creating database: {e}")
            raise

    def create_col(self, name, vector_size, distance):
        """Create a new table.

        Args:
            name (str): Name of the table to create.
            vector_size (int): Dimension of the vector.
            distance (str): Metric type for similarity search.
        """
        # Check if table already exists
        try:
            tables = self._database.list_table()
            table_exists = any(table.table_name == name for table in tables)
            if table_exists:
                logger.info(f"Table {name} already exists. Skipping creation.")
                self._table = self._database.describe_table(name)
                return

            # Convert distance string to MetricType enum
            metric_type = None
            for k, v in MetricType.__members__.items():
                if k == distance:
                    metric_type = v
            if metric_type is None:
                raise ValueError(f"Unsupported metric_type: {distance}")

            # Define table schema
            fields = [
                Field(
                    "id", FieldType.STRING, primary_key=True, partition_key=True, auto_increment=False, not_null=True
                ),
                Field("vector", FieldType.FLOAT_VECTOR, dimension=vector_size),
                Field("metadata", FieldType.JSON),
            ]

            # Create vector index
            indexes = [
                VectorIndex(
                    index_name="vector_idx",
                    index_type=IndexType.HNSW,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                    auto_build=True,
                    auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000),
                ),
                FilteringIndex(index_name="metadata_filtering_idx", fields=["metadata"]),
            ]

            schema = Schema(fields=fields, indexes=indexes)

            # Create table
            self._table = self._database.create_table(
                table_name=name, replication=3, partition=Partition(partition_num=1), schema=schema
            )
            logger.info(f"Created table: {name}")

            # Wait for table to be ready
            while True:
                time.sleep(2)
                table = self._database.describe_table(name)
                if table.state == TableState.NORMAL:
                    logger.info(f"Table {name} is ready.")
                    break
                logger.info(f"Waiting for table {name} to be ready, current state: {table.state}")
            self._table = table
        except Exception as e:
            logger.error(f"Error creating table: {e}")
            raise

    def insert(self, vectors, payloads=None, ids=None):
        """Insert vectors into the table.

        Args:
            vectors (List[List[float]]): List of vectors to insert.
            payloads (List[Dict], optional): List of payloads corresponding to vectors.
            ids (List[str], optional): List of IDs corresponding to vectors.
        """
        # Prepare data for insertion
        for idx, vector, metadata in zip(ids, vectors, payloads):
            row = Row(id=idx, vector=vector, metadata=metadata)
            self._table.upsert(rows=[row])

    def search(self, query: str, vectors: list, limit: int = 5, filters: dict = None) -> list:
        """
        Search for similar vectors.

        Args:
            query (str): Query string.
            vectors (List[float]): Query vector.
            limit (int, optional): Number of results to return. Defaults to 5.
            filters (Dict, optional): Filters to apply to the search. Defaults to None.

        Returns:
            list: Search results.
        """
        # Add filters if provided
        search_filter = None
        if filters:
            search_filter = self._create_filter(filters)

        # Create AnnSearch for vector search
        request = VectorTopkSearchRequest(
            vector_field="vector",
            vector=FloatVector(vectors),
            limit=limit,
            filter=search_filter,
            config=VectorSearchConfig(ef=200),
        )

        # Perform search
        projections = ["id", "metadata"]
        res = self._table.vector_search(request=request, projections=projections)

        # Parse results
        output = []
        for row in res.rows:
            row_data = row.get("row", {})
            output_data = OutputData(
                id=row_data.get("id"), score=row.get("score", 0.0), payload=row_data.get("metadata", {})
            )
            output.append(output_data)

        return output

    def delete(self, vector_id):
        """
        Delete a vector by ID.

        Args:
            vector_id (str): ID of the vector to delete.
        """
        self._table.delete(primary_key={"id": vector_id})

    def update(self, vector_id=None, vector=None, payload=None):
        """
        Update a vector and its payload.

        Args:
            vector_id (str): ID of the vector to update.
            vector (List[float], optional): Updated vector.
            payload (Dict, optional): Updated payload.
        """
        row = Row(id=vector_id, vector=vector, metadata=payload)
        self._table.upsert(rows=[row])

    def get(self, vector_id):
        """
        Retrieve a vector by ID.

        Args:
            vector_id (str): ID of the vector to retrieve.

        Returns:
            OutputData: Retrieved vector.
        """
        projections = ["id", "metadata"]
        result = self._table.query(primary_key={"id": vector_id}, projections=projections)
        row = result.row
        return OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))

    def list_cols(self):
        """
        List all tables (collections).

        Returns:
            List[str]: List of table names.
        """
        tables = self._database.list_table()
        return [table.table_name for table in tables]

    def delete_col(self):
        """Delete the table."""
        try:
            tables = self._database.list_table()

            # skip drop table if table not exists
            table_exists = any(table.table_name == self.table_name for table in tables)
            if not table_exists:
                logger.info(f"Table {self.table_name} does not exist, skipping deletion")
                return

            # Delete the table
            self._database.drop_table(self.table_name)
            logger.info(f"Initiated deletion of table {self.table_name}")

            # Wait for table to be completely deleted
            while True:
                time.sleep(2)
                try:
                    self._database.describe_table(self.table_name)
                    logger.info(f"Waiting for table {self.table_name} to be deleted...")
                except ServerError as e:
                    if e.code == ServerErrCode.TABLE_NOT_EXIST:
                        logger.info(f"Table {self.table_name} has been completely deleted")
                        break
                    logger.error(f"Error checking table status: {e}")
                    raise
        except Exception as e:
            logger.error(f"Error deleting table: {e}")
            raise

    def col_info(self):
        """
        Get information about the table.

        Returns:
            Dict[str, Any]: Table information.
        """
        return self._table.stats()

    def list(self, filters: dict = None, limit: int = 100) -> list:
        """
        List all vectors in the table.

        Args:
            filters (Dict, optional): Filters to apply to the list.
            limit (int, optional): Number of vectors to return. Defaults to 100.

        Returns:
            List[OutputData]: List of vectors.
        """
        projections = ["id", "metadata"]
        list_filter = self._create_filter(filters) if filters else None
        result = self._table.select(filter=list_filter, projections=projections, limit=limit)

        memories = []
        for row in result.rows:
            obj = OutputData(id=row.get("id"), score=None, payload=row.get("metadata", {}))
            memories.append(obj)

        return [memories]

    def reset(self):
        """Reset the table by deleting and recreating it."""
        logger.warning(f"Resetting table {self.table_name}...")
        try:
            self.delete_col()
            self.create_col(
                name=self.table_name,
                vector_size=self.embedding_model_dims,
                distance=self.metric_type,
            )
        except Exception as e:
            logger.warning(f"Error resetting table: {e}")
            raise

    def _create_filter(self, filters: dict) -> str:
        """
        Create filter expression for queries.

        Args:
            filters (dict): Filter conditions.

        Returns:
            str: Filter expression.
        """
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] = "{value}"')
            else:
                conditions.append(f'metadata["{key}"] = {value}')
        return " AND ".join(conditions)
