"""
Vector Store - Structured Multi-View Indexing Implementation (Section 3.2)

Paper Reference: Section 3.2 - Structured Indexing
Implements the three structured indexing dimensions:
- Semantic Layer: Dense vectors v_k in R^d (embedding-based similarity)
- Lexical Layer: Full-text search with Tantivy FTS
- Symbolic Layer: Metadata R_k = {(key, val)} (structured filtering via SQL)
"""
from typing import List, Optional, Dict, Any
import os

import lancedb
import pyarrow as pa

from simplemem.models.memory_entry import MemoryEntry
from simplemem.utils.embedding import EmbeddingModel
from simplemem.config import get_config


class VectorStore:
    """
    Structured Multi-View Indexing - Storage and retrieval for Atomic Entries

    Paper Reference: Section 3.2 - Structured Indexing
    Implements M(m_k) with three structured layers:
    1. Semantic Layer: Dense embedding vectors for conceptual similarity
    2. Lexical Layer: Full-text search via Tantivy FTS index
    3. Symbolic Layer: SQL-based metadata filtering with DataFusion
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        embedding_model: Optional[EmbeddingModel] = None,
        table_name: Optional[str] = None,
        storage_options: Optional[Dict[str, Any]] = None,
    ):
        config = get_config()
        self.db_path = db_path or config.lancedb_path
        self.embedding_model = embedding_model or EmbeddingModel()
        self.table_name = table_name or config.memory_table_name
        self.table: Any = None
        self._fts_initialized = False

        # Detect if using cloud storage (GCS, S3, Azure)
        self._is_cloud_storage = self.db_path.startswith(("gs://", "s3://", "az://"))

        # Connect to database
        if self._is_cloud_storage:
            self.db = lancedb.connect(self.db_path, storage_options=storage_options)
        else:
            os.makedirs(self.db_path, exist_ok=True)
            self.db = lancedb.connect(self.db_path)

        self._init_table()

    def _init_table(self) -> None:
        """Initialize table schema and FTS index."""
        schema = pa.schema(
            [
                pa.field("entry_id", pa.string()),
                pa.field("lossless_restatement", pa.string()),
                pa.field("keywords", pa.list_(pa.string())),
                pa.field("timestamp", pa.string()),
                pa.field("location", pa.string()),
                pa.field("persons", pa.list_(pa.string())),
                pa.field("entities", pa.list_(pa.string())),
                pa.field("topic", pa.string()),
                pa.field(
                    "vector", pa.list_(pa.float32(), self.embedding_model.dimension)
                ),
            ]
        )

        if self.table_name not in self.db.table_names():
            self.table = self.db.create_table(self.table_name, schema=schema)
            print(f"Created new table: {self.table_name}")
        else:
            self.table = self.db.open_table(self.table_name)
            print(f"Opened existing table: {self.table_name}")

    def _init_fts_index(self) -> None:
        """Initialize Full-Text Search index on lossless_restatement column."""
        if self._fts_initialized:
            return

        try:
            if self._is_cloud_storage:
                # Use native FTS for cloud storage (Tantivy only works with local filesystem)
                self.table.create_fts_index(
                    "lossless_restatement",
                    use_tantivy=False,
                    replace=True,
                )
                print("FTS index created (native mode for cloud storage)")
            else:
                # Use Tantivy FTS for local storage (better performance)
                self.table.create_fts_index(
                    "lossless_restatement",
                    use_tantivy=True,
                    tokenizer_name="en_stem",
                    replace=True,
                )
                print("FTS index created (Tantivy mode)")
            self._fts_initialized = True
        except Exception as e:
            print(f"FTS index creation skipped: {e}")

    def _results_to_entries(self, results: List[Dict[str, Any]]) -> List[MemoryEntry]:
        """Convert LanceDB results to MemoryEntry objects."""
        entries = []
        for r in results:
            try:
                entries.append(
                    MemoryEntry(
                        entry_id=r["entry_id"],
                        lossless_restatement=r["lossless_restatement"],
                        keywords=list(r.get("keywords") or []),
                        timestamp=r.get("timestamp") or None,
                        location=r.get("location") or None,
                        persons=list(r.get("persons") or []),
                        entities=list(r.get("entities") or []),
                        topic=r.get("topic") or None,
                    )
                )
            except Exception as e:
                print(f"Warning: Failed to parse result: {e}")
                continue
        return entries

    def add_entries(self, entries: List[MemoryEntry]) -> None:
        """Batch add memory entries."""
        if not entries:
            return

        restatements = [entry.lossless_restatement for entry in entries]
        vectors = self.embedding_model.encode_documents(restatements)

        data = []
        for entry, vector in zip(entries, vectors):
            data.append(
                {
                    "entry_id": entry.entry_id,
                    "lossless_restatement": entry.lossless_restatement,
                    "keywords": entry.keywords,
                    "timestamp": entry.timestamp or "",
                    "location": entry.location or "",
                    "persons": entry.persons,
                    "entities": entry.entities,
                    "topic": entry.topic or "",
                    "vector": vector.tolist(),
                }
            )

        self.table.add(data)
        print(f"Added {len(entries)} memory entries")

        # Initialize FTS index after first data insertion
        if not self._fts_initialized:
            self._init_fts_index()

    def semantic_search(self, query: str, top_k: int = 5) -> List[MemoryEntry]:
        """
        Semantic Layer Search - Dense vector similarity.

        Paper Reference: Section 3.1
        Retrieves based on v_k = E_dense(S_k) where S_k is the lossless restatement.
        """
        try:
            if self.table.count_rows() == 0:
                return []

            query_vector = self.embedding_model.encode_single(query, is_query=True)
            results = self.table.search(query_vector.tolist()).limit(top_k).to_list()
            return self._results_to_entries(results)

        except Exception as e:
            print(f"Error during semantic search: {e}")
            return []

    def keyword_search(
        self, keywords: List[str], top_k: int = 3
    ) -> List[MemoryEntry]:
        """
        Lexical Layer Search - Full-text search via Tantivy FTS.

        Paper Reference: Section 3.1
        Retrieves based on BM25 text matching using LanceDB native FTS.
        """
        try:
            if not keywords or self.table.count_rows() == 0:
                return []

            # LanceDB auto-detects string input as FTS query when FTS index exists
            query = " ".join(keywords)
            results = self.table.search(query).limit(top_k).to_list()
            return self._results_to_entries(results)

        except Exception as e:
            print(f"Error during keyword search: {e}")
            return []

    def structured_search(
        self,
        persons: Optional[List[str]] = None,
        timestamp_range: Optional[tuple] = None,
        location: Optional[str] = None,
        entities: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> List[MemoryEntry]:
        """
        Symbolic Layer Search - SQL-based metadata filtering.

        Paper Reference: Section 3.1
        Retrieves based on R_k = {(key, val)} for structured constraints.
        Uses DataFusion SQL expressions with array_has_any for list columns.
        """
        try:
            if self.table.count_rows() == 0:
                return []

            if not any([persons, timestamp_range, location, entities]):
                return []

            conditions = []

            if persons:
                values = ", ".join([f"'{p}'" for p in persons])
                conditions.append(f"array_has_any(persons, make_array({values}))")

            if location:
                safe_location = location.replace("'", "''")
                conditions.append(f"location LIKE '%{safe_location}%'")

            if entities:
                values = ", ".join([f"'{e}'" for e in entities])
                conditions.append(f"array_has_any(entities, make_array({values}))")

            if timestamp_range:
                start_time, end_time = timestamp_range
                conditions.append(
                    f"timestamp >= '{start_time}' AND timestamp <= '{end_time}'"
                )

            where_clause = " AND ".join(conditions)
            query = self.table.search().where(where_clause, prefilter=True)

            if top_k:
                query = query.limit(top_k)

            results = query.to_list()
            return self._results_to_entries(results)

        except Exception as e:
            print(f"Error during structured search: {e}")
            return []

    def get_all_entries(self) -> List[MemoryEntry]:
        """Get all memory entries."""
        results = self.table.to_arrow().to_pylist()
        return self._results_to_entries(results)

    def optimize(self) -> None:
        """Optimize table after bulk insertions for better query performance."""
        self.table.optimize()
        print("Table optimized")

    def clear(self) -> None:
        """Clear all data and reinitialize table."""
        self.db.drop_table(self.table_name)
        self._fts_initialized = False
        self._init_table()
        print("Database cleared")
