#!/usr/bin/env python3
"""
GENESIS QUERY ENGINE
=====================
Unified query interface across data stores.

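Features:
    - SQL-like query syntax (SELECT / FROM / WHERE / GROUP BY / ORDER BY /
      LIMIT ... OFFSET)
    - Filter expressions (WHERE conditions combined with AND)
    - Aggregations (COUNT, SUM, AVG, MIN, MAX, FIRST, LAST, DISTINCT)
    - Joins across stores (planned; not implemented yet)
    - Query optimization (planned; not implemented yet)
    - Result caching (time-based, keyed on the query string)
    - Pagination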

Usage:
    engine = QueryEngine()
    engine.register_store("memory", memory_store)
    engine.register_store("logs", log_store)

    results = engine.query('''
        SELECT * FROM memory
        WHERE type = 'episodic'
        AND timestamp > '2024-01-01'
        ORDER BY timestamp DESC
        LIMIT 10
    ''')
"""

import json
import operator
import re
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from functools import reduce
from typing import Dict, List, Any, Optional, Callable, Union, Tuple, Iterator, Iterable


class QueryOperator(Enum):
    """Query comparison operators.

    Each member's value is its textual SQL-like spelling.  How every
    operator is evaluated against a record is defined by
    ``FilterEngine.matches``.
    """
    EQ = "="
    NE = "!="
    GT = ">"
    GTE = ">="
    LT = "<"
    LTE = "<="
    LIKE = "LIKE"                # pattern: '%' = any run of chars, '_' = one char
    IN = "IN"                    # value: a collection to test membership in
    NOT_IN = "NOT IN"            # value: a collection; matches when not a member
    IS_NULL = "IS NULL"          # value is ignored
    IS_NOT_NULL = "IS NOT NULL"  # value is ignored
    BETWEEN = "BETWEEN"          # value: (low, high), both bounds inclusive
    CONTAINS = "CONTAINS"        # substring test against str(field value)
    STARTS_WITH = "STARTS_WITH"  # prefix test against str(field value)
    ENDS_WITH = "ENDS_WITH"      # suffix test against str(field value)


class AggregateFunction(Enum):
    """Aggregate functions usable in a SELECT list (computed by AggregationEngine).

    The member value is the (upper-case) function name the parser looks up.
    """
    COUNT = "COUNT"        # COUNT(*) counts rows; COUNT(f) counts rows where f is not None
    SUM = "SUM"            # sum of int/float values of the field; other values ignored
    AVG = "AVG"            # mean of int/float values; 0 when there are none
    MIN = "MIN"            # smallest non-None value, or None
    MAX = "MAX"            # largest non-None value, or None
    FIRST = "FIRST"        # field value from the first row of the group
    LAST = "LAST"          # field value from the last row of the group
    DISTINCT = "DISTINCT"  # list of the distinct non-None values of the field


class SortOrder(Enum):
    """Sort direction for an ORDER BY item (ascending is the default)."""
    ASC = "ASC"
    DESC = "DESC"


@dataclass
class FilterCondition:
    """A single ``field <operator> value`` filter condition."""
    # Field name; dot notation (e.g. "meta.kind") reaches into nested dicts.
    field: str
    operator: QueryOperator
    # Operand; its expected shape depends on ``operator`` (scalar, collection,
    # (low, high) pair, or unused for the IS [NOT] NULL operators).
    value: Any
    # When False, string field values are lower-cased before comparison
    # (see FilterEngine.matches for exactly what is normalised).
    case_sensitive: bool = True


@dataclass
class SortSpec:
    """One ORDER BY item: a field name and a direction."""
    # Top-level record key to sort on.
    field: str
    order: SortOrder = SortOrder.ASC


@dataclass
class AggregateSpec:
    """One aggregate in a SELECT list, e.g. ``COUNT(*) AS total``."""
    function: AggregateFunction
    # Field to aggregate, or "*" (meaningful for COUNT).
    field: str
    # Output key; when None the key defaults to "FUNC(field)", e.g. "COUNT(*)".
    alias: Optional[str] = None


@dataclass
class QuerySpec:
    """Full query specification, as produced by QueryParser or built by hand."""
    # Name of a store registered with the QueryEngine.
    store: str
    # Fields to return; ["*"] returns every field of each record.
    select: List[str] = field(default_factory=lambda: ["*"])
    # Conditions combined with AND: a record must satisfy all of them.
    filters: List[FilterCondition] = field(default_factory=list)
    # Sort keys; the first entry is the primary key.
    sort: List[SortSpec] = field(default_factory=list)
    # Maximum number of rows returned; None means no limit.
    limit: Optional[int] = None
    # Number of rows skipped before ``limit`` is applied.
    offset: int = 0
    aggregates: List[AggregateSpec] = field(default_factory=list)
    # Grouping fields; only consulted when ``aggregates`` is non-empty.
    group_by: List[str] = field(default_factory=list)
    # Request de-duplication of rows.
    distinct: bool = False


@dataclass
class QueryResult:
    """Rows produced by a query plus bookkeeping about how they were produced.

    Iteration and ``len()`` operate on ``data``.  Because ``__len__`` is
    defined, a result with no rows is falsy; compare against ``None`` when
    checking whether a result exists at all.  ``total_count`` is the row
    count the engine reports before pagination, so it may exceed
    ``len(result)``.
    """
    data: List[Dict[str, Any]]
    total_count: int
    execution_time_ms: float
    from_cache: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __iter__(self):
        rows = self.data
        return iter(rows)

    def __len__(self):
        rows = self.data
        return len(rows)

    def first(self) -> Optional[Dict[str, Any]]:
        """Return the first row, or ``None`` when there are no rows."""
        if not self.data:
            return None
        return self.data[0]

    def to_dict(self) -> Dict:
        """Serialise the result to a plain dict (``metadata`` is not included)."""
        return dict(
            data=self.data,
            total_count=self.total_count,
            execution_time_ms=self.execution_time_ms,
            from_cache=self.from_cache,
        )


class DataStore(ABC):
    """Interface implemented by every store that can be queried.

    Subclasses must provide ``get_all``, ``get_fields`` and the read-only
    ``name`` property; until they do, the class cannot be instantiated.
    """

    @abstractmethod
    def get_all(self) -> Iterator[Dict[str, Any]]:
        """Return an iterator over every record held by the store."""
        ...

    @abstractmethod
    def get_fields(self) -> List[str]:
        """Return the field names available in the store's records."""
        ...

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable identifier of the store."""
        ...


class MemoryDataStore(DataStore):
    """In-memory data store backed by a plain list of record dicts.

    The list passed as ``data`` is used directly, not copied, so records
    added via ``insert`` appear in the caller's list and vice versa.
    """

    def __init__(self, name: str, data: List[Dict[str, Any]] = None):
        self._name = name
        # ``is None`` rather than ``or``: an empty list supplied by the caller
        # must be aliased exactly like a non-empty one; ``data or []`` would
        # silently swap it for a private list.
        self._data = data if data is not None else []

    @property
    def name(self) -> str:
        """Name given at construction."""
        return self._name

    def get_all(self) -> Iterator[Dict[str, Any]]:
        """Return an iterator over the live record list (no copy is made)."""
        return iter(self._data)

    def get_fields(self, sample_size: int = 100) -> List[str]:
        """Return the field names found in the first ``sample_size`` records.

        Args:
            sample_size: How many leading records to inspect (default 100).
                Fields that only occur in later records are not reported.

        Returns:
            Unique field names in first-seen order (deterministic, unlike a
            plain set); an empty list when the store is empty.
        """
        seen: Dict[str, None] = {}
        for row in self._data[:sample_size]:
            for key in row:
                seen.setdefault(key, None)
        return list(seen)

    def insert(self, record: Dict[str, Any]):
        """Append ``record`` to the store."""
        self._data.append(record)

    def clear(self):
        """Remove every record (the underlying list is emptied in place)."""
        self._data.clear()


class FilterEngine:
    """Evaluates ``FilterCondition`` objects against individual records."""

    # Comparison operators that map directly onto a two-argument function.
    _operators = {
        QueryOperator.EQ: operator.eq,
        QueryOperator.NE: operator.ne,
        QueryOperator.GT: operator.gt,
        QueryOperator.GTE: operator.ge,
        QueryOperator.LT: operator.lt,
        QueryOperator.LTE: operator.le,
    }

    @classmethod
    def matches(cls, record: Dict[str, Any], condition: FilterCondition) -> bool:
        """Return True if ``record`` satisfies ``condition``.

        Args:
            record: Record to test.  ``condition.field`` may use dot notation
                to reach into nested dicts.
            condition: Condition to evaluate.

        Returns:
            Whether the record matches.  A missing/None field value only
            satisfies IS NULL.  Comparisons between incompatible types (and
            IN/BETWEEN against a malformed operand) evaluate to False instead
            of raising.  An operator this method does not handle matches
            every record.
        """
        field_value = cls._get_nested_value(record, condition.field)
        op = condition.operator
        target = condition.value

        # The null checks are the only operators that can match a missing value.
        if op == QueryOperator.IS_NULL:
            return field_value is None
        if op == QueryOperator.IS_NOT_NULL:
            return field_value is not None
        if field_value is None:
            return False

        # Case-insensitive matching lower-cases the field value and the
        # target, including string members of IN / NOT IN / BETWEEN operands.
        if not condition.case_sensitive and isinstance(field_value, str):
            field_value = field_value.lower()
            target = cls._lower_target(target)

        if op in cls._operators:
            try:
                return cls._operators[op](field_value, target)
            except TypeError:
                return False

        if op == QueryOperator.LIKE:
            # DOTALL so '%' also spans newlines inside the value.
            flags = re.DOTALL if condition.case_sensitive else re.DOTALL | re.IGNORECASE
            pattern = cls._like_to_regex(str(target))
            return re.match(pattern, str(field_value), flags) is not None

        if op in (QueryOperator.IN, QueryOperator.NOT_IN):
            try:
                found = field_value in target
            except TypeError:
                # Operand is not a container, or an unhashable value was
                # looked up in a set.
                return False
            return found if op == QueryOperator.IN else not found

        if op == QueryOperator.BETWEEN:
            try:
                if len(target) < 2:
                    return False
                return target[0] <= field_value <= target[1]
            except (TypeError, KeyError):
                return False

        if op == QueryOperator.CONTAINS:
            return str(target) in str(field_value)

        if op == QueryOperator.STARTS_WITH:
            return str(field_value).startswith(str(target))

        if op == QueryOperator.ENDS_WITH:
            return str(field_value).endswith(str(target))

        return True

    @staticmethod
    def _lower_target(target: Any) -> Any:
        """Lower-case a string operand, or the string members of a collection."""
        if isinstance(target, str):
            return target.lower()
        if isinstance(target, (set, frozenset)):
            return {v.lower() if isinstance(v, str) else v for v in target}
        if isinstance(target, (list, tuple)):
            return [v.lower() if isinstance(v, str) else v for v in target]
        return target

    @classmethod
    def _get_nested_value(cls, record: Dict, field: str) -> Any:
        """Return the value at dotted path ``field`` (e.g. "a.b"), or None.

        Returns None as soon as an intermediate value is not a dict or a key
        is missing.
        """
        keys = field.split('.')
        value = record
        for key in keys:
            if isinstance(value, dict):
                value = value.get(key)
            else:
                return None
        return value

    @staticmethod
    def _like_to_regex(pattern: str) -> str:
        """Convert a SQL LIKE pattern into an anchored regular expression.

        ``%`` matches any run of characters (including none) and ``_``
        matches exactly one character.  Every other character, including
        regex metacharacters such as ``.`` or ``(``, is matched literally.
        Escaped wildcards (``\\%``) are not supported.
        """
        parts = []
        for ch in pattern:
            if ch == '%':
                parts.append('.*')
            elif ch == '_':
                parts.append('.')
            else:
                parts.append(re.escape(ch))
        # \Z rather than $ so a trailing newline in the value is not ignored.
        return "^" + "".join(parts) + r"\Z"


class AggregationEngine:
    """Computes aggregate values over lists of record dicts."""

    @classmethod
    def aggregate(
        cls,
        data: List[Dict[str, Any]],
        aggregates: List[AggregateSpec],
        group_by: List[str] = None
    ) -> List[Dict[str, Any]]:
        """Aggregate ``data``, optionally per group.

        Args:
            data: Records to aggregate.
            aggregates: Aggregations to compute; each one produces one key in
                every output row.
            group_by: Top-level field names to group on.  When empty or None
                the whole input is one group, so exactly one row is returned
                even for empty ``data``.

        Returns:
            One dict per group, holding the group-by values followed by the
            aggregate results, in the order groups are first encountered.
        """
        if not group_by:
            return [cls._aggregate_group(data, aggregates)]

        groups: Dict[tuple, List[Dict[str, Any]]] = defaultdict(list)
        for record in data:
            groups[tuple(record.get(f) for f in group_by)].append(record)

        results = []
        for key, rows in groups.items():
            row = dict(zip(group_by, key))
            row.update(cls._aggregate_group(rows, aggregates))
            results.append(row)
        return results

    @classmethod
    def _aggregate_group(
        cls,
        data: List[Dict[str, Any]],
        aggregates: List[AggregateSpec]
    ) -> Dict[str, Any]:
        """Compute every aggregate in ``aggregates`` over one group of rows.

        Output keys are the aggregate aliases, defaulting to "FUNC(field)".
        """
        result: Dict[str, Any] = {}

        for agg in aggregates:
            name = agg.field
            alias = agg.alias or f"{agg.function.value}({name})"
            func = agg.function

            if func == AggregateFunction.COUNT:
                if name == "*":
                    result[alias] = len(data)
                else:
                    result[alias] = sum(1 for r in data if r.get(name) is not None)

            elif func in (AggregateFunction.SUM, AggregateFunction.AVG):
                # Only int/float values take part (bool counts, being an int).
                numbers = [
                    v for v in (r.get(name) for r in data)
                    if isinstance(v, (int, float))
                ]
                if func == AggregateFunction.SUM:
                    result[alias] = sum(numbers)
                else:
                    result[alias] = sum(numbers) / len(numbers) if numbers else 0

            elif func in (AggregateFunction.MIN, AggregateFunction.MAX):
                present = [v for v in (r.get(name) for r in data) if v is not None]
                pick = min if func == AggregateFunction.MIN else max
                result[alias] = pick(present) if present else None

            elif func == AggregateFunction.FIRST:
                result[alias] = data[0].get(name) if data else None

            elif func == AggregateFunction.LAST:
                result[alias] = data[-1].get(name) if data else None

            elif func == AggregateFunction.DISTINCT:
                result[alias] = cls._distinct_values(
                    r.get(name) for r in data if r.get(name) is not None
                )

        return result

    @staticmethod
    def _distinct_values(values: Iterable[Any]) -> List[Any]:
        """De-duplicate ``values``, keeping first-seen order.

        Unlike ``list(set(...))`` the output order is deterministic, and
        unhashable values (lists, dicts) are compared by equality instead of
        raising TypeError.
        """
        hashable_seen = set()
        unhashable_seen: List[Any] = []
        unique: List[Any] = []
        for value in values:
            try:
                if value in hashable_seen:
                    continue
                hashable_seen.add(value)
            except TypeError:
                if value in unhashable_seen:
                    continue
                unhashable_seen.append(value)
            unique.append(value)
        return unique


class QueryParser:
    """Parses SQL-like query strings into ``QuerySpec`` objects.

    Accepted shape (keywords are case-insensitive)::

        SELECT [DISTINCT] <* | field, ... | FUNC(field|*) [AS alias], ...>
        FROM <store>
        [WHERE <cond> [AND <cond> ...]]
        [GROUP BY field, ...]
        [ORDER BY field [ASC|DESC], ...]
        [LIMIT n [OFFSET m]]

    Supported conditions are =, !=, >, >=, <, <=, LIKE, IN (...),
    NOT IN (...), IS NULL and IS NOT NULL.  Only AND is supported (no OR or
    parentheses).  Clause detection is regex based and does not skip
    keywords that appear inside quoted literals.  A condition that matches
    none of the supported forms is ignored.
    """

    # Simple SQL parser patterns (DOTALL lets clauses span several lines).
    SELECT_PATTERN = re.compile(
        r"SELECT\s+(.+?)\s+FROM\s+(\w+)",
        re.IGNORECASE | re.DOTALL
    )
    WHERE_PATTERN = re.compile(
        r"WHERE\s+(.+?)(?:\s+ORDER\s+BY|\s+LIMIT|\s+GROUP\s+BY|$)",
        re.IGNORECASE | re.DOTALL
    )
    ORDER_BY_PATTERN = re.compile(
        r"ORDER\s+BY\s+(.+?)(?:\s+LIMIT|$)",
        re.IGNORECASE | re.DOTALL
    )
    LIMIT_PATTERN = re.compile(
        r"LIMIT\s+(\d+)(?:\s+OFFSET\s+(\d+))?",
        re.IGNORECASE
    )
    GROUP_BY_PATTERN = re.compile(
        r"GROUP\s+BY\s+(.+?)(?:\s+ORDER\s+BY|\s+LIMIT|$)",
        re.IGNORECASE | re.DOTALL
    )
    # DISTINCT modifier at the start of the select list ("SELECT DISTINCT a").
    DISTINCT_PATTERN = re.compile(r"DISTINCT\s+", re.IGNORECASE)
    # FUNC(field) or FUNC(*), optionally followed by "AS alias".
    AGGREGATE_PATTERN = re.compile(
        r"(\w+)\((\w+|\*)\)(?:\s+AS\s+(\w+))?",
        re.IGNORECASE
    )
    # Tried in order with ``match``.  ">=" / "<=" must come before ">" / "<",
    # otherwise "a >= 1" would be read as "a > '= 1'".
    CONDITION_PATTERNS = [
        (re.compile(r"(\w+(?:\.\w+)*)\s*>=\s*(.+)", re.IGNORECASE), QueryOperator.GTE),
        (re.compile(r"(\w+(?:\.\w+)*)\s*<=\s*(.+)", re.IGNORECASE), QueryOperator.LTE),
        (re.compile(r"(\w+(?:\.\w+)*)\s*!=\s*(.+)", re.IGNORECASE), QueryOperator.NE),
        (re.compile(r"(\w+(?:\.\w+)*)\s*=\s*(.+)", re.IGNORECASE), QueryOperator.EQ),
        (re.compile(r"(\w+(?:\.\w+)*)\s*>\s*(.+)", re.IGNORECASE), QueryOperator.GT),
        (re.compile(r"(\w+(?:\.\w+)*)\s*<\s*(.+)", re.IGNORECASE), QueryOperator.LT),
        (re.compile(r"(\w+(?:\.\w+)*)\s+LIKE\s+(.+)", re.IGNORECASE), QueryOperator.LIKE),
        (re.compile(r"(\w+(?:\.\w+)*)\s+NOT\s+IN\s+\((.*)\)", re.IGNORECASE), QueryOperator.NOT_IN),
        (re.compile(r"(\w+(?:\.\w+)*)\s+IN\s+\((.*)\)", re.IGNORECASE), QueryOperator.IN),
        (re.compile(r"(\w+(?:\.\w+)*)\s+IS\s+NULL", re.IGNORECASE), QueryOperator.IS_NULL),
        (re.compile(r"(\w+(?:\.\w+)*)\s+IS\s+NOT\s+NULL", re.IGNORECASE), QueryOperator.IS_NOT_NULL),
    ]

    @classmethod
    def parse(cls, query: str) -> QuerySpec:
        """Parse a SQL-like query string.

        Args:
            query: Query text.

        Returns:
            The equivalent ``QuerySpec``.

        Raises:
            ValueError: If the SELECT ... FROM <store> part is missing.
        """
        query = query.strip()

        select_match = cls.SELECT_PATTERN.search(query)
        if not select_match:
            raise ValueError("Invalid query: missing SELECT or FROM")

        select_clause = select_match.group(1).strip()
        store = select_match.group(2).strip()

        # Only a leading "DISTINCT " in the select list is the row-level
        # modifier; DISTINCT(field) is the aggregate and is handled below.
        distinct_match = cls.DISTINCT_PATTERN.match(select_clause)
        distinct = distinct_match is not None
        if distinct_match:
            select_clause = select_clause[distinct_match.end():].strip()

        if select_clause == '*':
            select = ['*']
        else:
            select = [f.strip() for f in select_clause.split(',')]

        filters: List[FilterCondition] = []
        where_match = cls.WHERE_PATTERN.search(query)
        if where_match:
            filters = cls._parse_where(where_match.group(1).strip())

        sort: List[SortSpec] = []
        order_match = cls.ORDER_BY_PATTERN.search(query)
        if order_match:
            sort = cls._parse_order_by(order_match.group(1).strip())

        limit = None
        offset = 0
        limit_match = cls.LIMIT_PATTERN.search(query)
        if limit_match:
            limit = int(limit_match.group(1))
            if limit_match.group(2):
                offset = int(limit_match.group(2))

        group_by: List[str] = []
        group_match = cls.GROUP_BY_PATTERN.search(query)
        if group_match:
            group_by = [f.strip() for f in group_match.group(1).strip().split(',')]

        aggregates = []
        for item in select:
            agg_match = cls.AGGREGATE_PATTERN.match(item)
            if not agg_match:
                continue
            try:
                func = AggregateFunction(agg_match.group(1).upper())
            except ValueError:
                # Not an aggregate we support; leave it as a plain field.
                continue
            aggregates.append(AggregateSpec(
                function=func,
                field=agg_match.group(2),
                alias=agg_match.group(3)
            ))

        return QuerySpec(
            store=store,
            select=select,
            filters=filters,
            sort=sort,
            limit=limit,
            offset=offset,
            aggregates=aggregates,
            group_by=group_by,
            distinct=distinct
        )

    @classmethod
    def _parse_where(cls, where_clause: str) -> List[FilterCondition]:
        """Split a WHERE clause on AND and parse each condition.

        Conditions that cannot be parsed are skipped.
        """
        conditions = []
        for part in re.split(r'\s+AND\s+', where_clause, flags=re.IGNORECASE):
            condition = cls._parse_condition(part.strip())
            if condition is not None:
                conditions.append(condition)
        return conditions

    @classmethod
    def _parse_condition(cls, condition: str) -> Optional[FilterCondition]:
        """Parse one ``field <op> value`` condition, or return None."""
        for pattern, op in cls.CONDITION_PATTERNS:
            match = pattern.match(condition)
            if not match:
                continue

            field = match.group(1)
            raw = match.group(2).strip() if pattern.groups >= 2 else None

            if op in (QueryOperator.IN, QueryOperator.NOT_IN):
                # Split first, then parse each item, so quoted strings keep
                # their quotes until they are parsed individually.
                value = [cls._parse_value(item) for item in cls._split_list(raw or "")]
            elif raw:
                value = cls._parse_value(raw)
            else:
                value = raw

            return FilterCondition(field=field, operator=op, value=value)

        return None

    @staticmethod
    def _split_list(text: str) -> List[str]:
        """Split a comma-separated list, ignoring commas inside quotes.

        Items keep their surrounding quotes; empty items are dropped.
        """
        items: List[str] = []
        current: List[str] = []
        quote = None
        for ch in text:
            if quote:
                current.append(ch)
                if ch == quote:
                    quote = None
            elif ch in ("'", '"'):
                quote = ch
                current.append(ch)
            elif ch == ',':
                items.append(''.join(current).strip())
                current = []
            else:
                current.append(ch)
        items.append(''.join(current).strip())
        return [item for item in items if item]

    @classmethod
    def _parse_value(cls, value: str) -> Any:
        """Convert a literal into a Python value.

        Quoted text -> str (quotes removed); text containing '.' -> float if
        it parses; otherwise int if it parses; TRUE/FALSE (any case) -> bool;
        anything else is returned unchanged as a string.
        """
        value = value.strip()

        # Needs at least two characters so a lone quote is not "quoted".
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            return value[1:-1]

        try:
            if '.' in value:
                return float(value)
            return int(value)
        except ValueError:
            pass

        if value.upper() == 'TRUE':
            return True
        if value.upper() == 'FALSE':
            return False

        return value

    @classmethod
    def _parse_order_by(cls, order_clause: str) -> List[SortSpec]:
        """Parse "field [ASC|DESC], ..." into sort specs; empty items are skipped."""
        specs = []
        for part in order_clause.split(','):
            tokens = part.split()
            if not tokens:
                continue
            descending = len(tokens) > 1 and tokens[1].upper() == 'DESC'
            specs.append(SortSpec(
                field=tokens[0],
                order=SortOrder.DESC if descending else SortOrder.ASC
            ))
        return specs


class QueryEngine:
    """
    Central query execution engine.

    Keeps a registry of named ``DataStore`` objects and runs queries against
    them.  ``query`` takes SQL-like strings and caches results for
    ``cache_ttl`` seconds when caching is enabled.  ``execute`` takes
    ``QuerySpec`` objects and never uses the cache.
    """

    def __init__(self, enable_cache: bool = True, cache_ttl: int = 60):
        self._stores: Dict[str, DataStore] = {}
        self._enable_cache = enable_cache
        self._cache_ttl = cache_ttl
        # stripped query string -> (time.monotonic() when stored, result)
        self._cache: Dict[str, Tuple[float, QueryResult]] = {}
        # Guards ``self._cache``.
        self._lock = threading.RLock()

        # Stats
        self._stats = {
            "queries_executed": 0,
            "cache_hits": 0,
            "cache_misses": 0,
            "total_time_ms": 0
        }

    def register_store(self, name: str, store: DataStore):
        """Register ``store`` under ``name``, replacing any existing one.

        Replacing a store drops the query cache, because cached results may
        have come from the previous store.
        """
        replacing = name in self._stores
        self._stores[name] = store
        if replacing:
            self.clear_cache()

    def unregister_store(self, name: str):
        """Remove the store registered as ``name``; no-op if there is none.

        Drops the query cache so later queries against the removed store
        fail instead of being answered from stale cached results.
        """
        if name in self._stores:
            del self._stores[name]
            self.clear_cache()

    def query(self, query_string: str) -> QueryResult:
        """Parse and execute a SQL-like query, consulting the result cache.

        Args:
            query_string: Query text (see ``QueryParser``).  Leading and
                trailing whitespace does not affect caching.

        Returns:
            The query result.  A cache hit returns a copy with
            ``from_cache=True``.

        Raises:
            ValueError: If the query cannot be parsed or names an unknown
                store.
        """
        cache_key = query_string.strip()
        if self._enable_cache:
            cached = self._get_cached(cache_key)
            # ``is not None``: QueryResult defines __len__, so a cached result
            # with no rows is falsy and would otherwise count as a miss.
            if cached is not None:
                self._stats["cache_hits"] += 1
                return cached

        self._stats["cache_misses"] += 1

        spec = QueryParser.parse(query_string)
        result = self.execute(spec)

        if self._enable_cache:
            with self._lock:
                self._cache[cache_key] = (time.monotonic(), result)

        return result

    def execute(self, spec: QuerySpec) -> QueryResult:
        """Execute a query specification (never cached).

        Pipeline:
            1. Filter.
            2. Aggregate (if any aggregates are requested).
            3. Sort.
            4. Project the selected fields.
            5. De-duplicate (if ``distinct``).
            6. Count, then apply offset/limit.

        Sorting also applies to aggregated rows, e.g. by a group key or alias.

        Raises:
            ValueError: If ``spec.store`` is not registered.
        """
        start_time = time.perf_counter()
        self._stats["queries_executed"] += 1

        # ``is None`` so a store whose __len__ is 0 is not mistaken for missing.
        store = self._stores.get(spec.store)
        if store is None:
            raise ValueError(f"Unknown store: {spec.store}")

        data = list(store.get_all())
        total_before_filter = len(data)

        if spec.filters:
            data = [
                record for record in data
                if all(FilterEngine.matches(record, f) for f in spec.filters)
            ]
        total_after_filter = len(data)

        if spec.aggregates:
            data = AggregationEngine.aggregate(data, spec.aggregates, spec.group_by)

        if spec.sort:
            self._sort_records(data, spec.sort)

        # Project before DISTINCT so duplicates are judged on the selected
        # fields only (e.g. "SELECT DISTINCT role").
        if spec.select and spec.select != ['*'] and not spec.aggregates:
            wanted = set(spec.select)
            data = [
                {k: v for k, v in record.items() if k in wanted}
                for record in data
            ]

        if spec.distinct:
            data = self._dedupe(data)

        # Pagination; total_count is the size before offset/limit.
        total_count = len(data)
        if spec.offset:
            data = data[spec.offset:]
        # ``is not None`` so LIMIT 0 returns no rows instead of all of them.
        if spec.limit is not None:
            data = data[:spec.limit]

        execution_time = (time.perf_counter() - start_time) * 1000
        self._stats["total_time_ms"] += execution_time

        return QueryResult(
            data=data,
            total_count=total_count,
            execution_time_ms=execution_time,
            from_cache=False,
            metadata={
                "store": spec.store,
                "total_before_filter": total_before_filter,
                "total_after_filter": total_after_filter
            }
        )

    @staticmethod
    def _sort_records(data: List[Dict[str, Any]], sort_specs: List[SortSpec]):
        """Sort ``data`` in place; the first spec is the primary key.

        Stable sorts are applied from the last spec to the first.  Missing or
        None values sort after present ones for ASC and, because the whole
        key is reversed, before them for DESC.
        """
        for sort_spec in reversed(sort_specs):
            name = sort_spec.field
            data.sort(
                key=lambda r, name=name: (r.get(name) is None, r.get(name)),
                reverse=sort_spec.order == SortOrder.DESC
            )

    @classmethod
    def _dedupe(cls, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Drop duplicate rows, keeping the first occurrence of each."""
        seen = set()
        unique = []
        for record in data:
            key = cls._record_key(record)
            if key not in seen:
                seen.add(key)
                unique.append(record)
        return unique

    @staticmethod
    def _record_key(record: Dict[str, Any]) -> Any:
        """Return a hashable key identifying ``record``'s contents.

        Falls back to a repr-based key when a value is unhashable (a list or
        nested dict), instead of raising TypeError.
        """
        try:
            key = tuple(sorted(record.items()))
            hash(key)
            return key
        except TypeError:
            return repr(sorted(record.items(), key=lambda kv: repr(kv[0])))

    def filter(
        self,
        store_name: str,
        **conditions
    ) -> QueryResult:
        """Return records of ``store_name`` whose fields equal the given values."""
        filters = [
            FilterCondition(field=k, operator=QueryOperator.EQ, value=v)
            for k, v in conditions.items()
        ]

        spec = QuerySpec(store=store_name, filters=filters)
        return self.execute(spec)

    def count(self, store_name: str, **conditions) -> int:
        """Count records matching the equality conditions."""
        result = self.filter(store_name, **conditions)
        return result.total_count

    def first(self, store_name: str, **conditions) -> Optional[Dict]:
        """Return the first record matching the equality conditions, or None."""
        result = self.filter(store_name, **conditions)
        return result.first()

    def _get_cached(self, key: str) -> Optional[QueryResult]:
        """Return a copy of the cached result for ``key`` if it has not expired.

        Expired entries are removed when encountered.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            stored_at, result = entry
            # Monotonic clock: TTL is unaffected by wall-clock adjustments.
            if time.monotonic() - stored_at >= self._cache_ttl:
                del self._cache[key]
                return None
            return QueryResult(
                data=result.data.copy(),
                total_count=result.total_count,
                execution_time_ms=result.execution_time_ms,
                from_cache=True,
                metadata=result.metadata.copy()
            )

    def clear_cache(self):
        """Clear the query cache."""
        with self._lock:
            self._cache.clear()

    def get_stats(self) -> Dict:
        """Return counters, cache size, registered stores and mean query time."""
        return {
            **self._stats,
            "cache_entries": len(self._cache),
            "stores": list(self._stores.keys()),
            "avg_query_time_ms": (
                self._stats["total_time_ms"] / self._stats["queries_executed"]
                if self._stats["queries_executed"] > 0 else 0
            )
        }

    def get_stores(self) -> List[str]:
        """List registered store names."""
        return list(self._stores.keys())


# Global query engine
# Module-level instance shared by callers of get_query_engine(); it stays
# None until that function first creates it.
_engine: Optional[QueryEngine] = None


def get_query_engine() -> QueryEngine:
    """Return the module-wide QueryEngine, constructing it on first call.

    The instance is created with default settings and reused afterwards.
    """
    global _engine
    if _engine is not None:
        return _engine
    _engine = QueryEngine()
    return _engine


def main():
    """CLI and demo for query engine.

    Commands:
        demo:   Build two in-memory stores ("users" and "logs") and print the
                results of a series of example queries, ending with the
                engine statistics.
        status: Print the statistics of the module-wide engine returned by
                get_query_engine().  In a fresh process nothing in this
                function registers stores on it, so the stats are expected
                to be empty.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Query Engine")
    parser.add_argument("command", choices=["demo", "status"])
    args = parser.parse_args()

    if args.command == "demo":
        print("Query Engine Demo")
        print("=" * 40)

        # The demo uses its own engine instance, not the global one.
        engine = QueryEngine()

        # Create sample data
        users = MemoryDataStore("users", [
            {"id": 1, "name": "Alice", "age": 30, "role": "admin", "active": True},
            {"id": 2, "name": "Bob", "age": 25, "role": "user", "active": True},
            {"id": 3, "name": "Charlie", "age": 35, "role": "user", "active": False},
            {"id": 4, "name": "Diana", "age": 28, "role": "admin", "active": True},
            {"id": 5, "name": "Eve", "age": 32, "role": "user", "active": True},
        ])

        logs = MemoryDataStore("logs", [
            {"id": 1, "user_id": 1, "action": "login", "timestamp": "2024-01-01T10:00:00"},
            {"id": 2, "user_id": 1, "action": "logout", "timestamp": "2024-01-01T12:00:00"},
            {"id": 3, "user_id": 2, "action": "login", "timestamp": "2024-01-01T09:00:00"},
            {"id": 4, "user_id": 2, "action": "create", "timestamp": "2024-01-01T09:30:00"},
            {"id": 5, "user_id": 3, "action": "login", "timestamp": "2024-01-01T11:00:00"},
        ])

        engine.register_store("users", users)
        engine.register_store("logs", logs)

        # Basic SELECT
        print("\n1. Basic SELECT:")
        result = engine.query("SELECT * FROM users")
        print(f"  Found {len(result)} users")
        for row in result:
            print(f"    {row}")

        # SELECT with WHERE
        print("\n2. WHERE clause:")
        result = engine.query("SELECT name, age FROM users WHERE age > 28")
        print(f"  Users over 28: {[r['name'] for r in result]}")

        # Multiple conditions
        print("\n3. Multiple conditions:")
        result = engine.query("SELECT * FROM users WHERE role = 'admin' AND active = True")
        print(f"  Active admins: {[r['name'] for r in result]}")

        # ORDER BY
        print("\n4. ORDER BY:")
        result = engine.query("SELECT name, age FROM users ORDER BY age DESC")
        print(f"  By age (desc): {[(r['name'], r['age']) for r in result]}")

        # LIMIT and OFFSET
        print("\n5. LIMIT and OFFSET:")
        result = engine.query("SELECT * FROM users LIMIT 2 OFFSET 1")
        print(f"  Page 2 (limit=2, offset=1): {[r['name'] for r in result]}")

        # LIKE operator
        print("\n6. LIKE operator:")
        result = engine.query("SELECT * FROM users WHERE name LIKE '%li%'")
        print(f"  Names containing 'li': {[r['name'] for r in result]}")

        # Aggregations
        print("\n7. Aggregations:")
        result = engine.query("SELECT COUNT(*) AS total FROM users")
        print(f"  Total users: {result.first()}")

        result = engine.query("SELECT AVG(age) AS avg_age FROM users")
        print(f"  Average age: {result.first()}")

        # GROUP BY
        print("\n8. GROUP BY:")
        result = engine.query("SELECT role, COUNT(*) AS count FROM users GROUP BY role")
        for row in result:
            print(f"  {row['role']}: {row['count']} users")

        # Logs query
        print("\n9. Logs query:")
        result = engine.query("SELECT * FROM logs WHERE action = 'login' ORDER BY timestamp")
        print(f"  Login events: {len(result)}")

        # Simple filter API
        print("\n10. Simple filter API:")
        result = engine.filter("users", role="admin")
        print(f"  Admins: {[r['name'] for r in result]}")

        # Cached query: the same query string is issued twice; the second
        # call should be answered from the cache (within the cache TTL).
        print("\n11. Query caching:")
        result1 = engine.query("SELECT * FROM users WHERE active = True")
        print(f"  First query: {result1.execution_time_ms:.2f}ms, cached={result1.from_cache}")

        result2 = engine.query("SELECT * FROM users WHERE active = True")
        print(f"  Second query: {result2.execution_time_ms:.2f}ms, cached={result2.from_cache}")

        # Stats
        print("\n12. Engine stats:")
        print(f"  {json.dumps(engine.get_stats(), indent=4)}")

    elif args.command == "status":
        engine = get_query_engine()
        print(json.dumps(engine.get_stats(), indent=2))


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
