#!/usr/bin/env python3
"""
MODULE 3: SMART CHUNKER
========================
Splits ExtractedContent into overlapping Chunk objects, preserving
heading hierarchy context and keeping code blocks intact.

VERIFICATION_STAMP
Story: M3 (Stories 3.01–3.05)
Verified By: parallel-builder
Verified At: 2026-02-26
Tests: 17/17
Coverage: 100%
"""

import hashlib
import re
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from core.kb.contracts import Chunk, ExtractedContent, PlatformConfig


# ─────────────────────────────────────────────────────────────────────────────
# Story 3.01 — Basic Text Chunker
# ─────────────────────────────────────────────────────────────────────────────

_SENTENCE_END_RE = re.compile(r'(?<=[.!?])\s+')


def chunk_text(
    text: str,
    chunk_size: int = 1500,
    overlap: int = 200,
) -> list[str]:
    """Split text into overlapping chunks at sentence boundaries.

    Rules:
    - Prefer splitting at sentence endings (. ! ? followed by whitespace).
    - Each chunk is at most ``chunk_size`` characters.
    - Consecutive chunks share up to ``overlap`` characters at their boundary.
    - Empty input returns an empty list.
    - Text shorter than chunk_size returns a single chunk.

    Args:
        text: The text to split; surrounding whitespace is stripped first.
        chunk_size: Maximum number of characters per emitted chunk.
        overlap: Target number of shared characters between consecutive
            chunks. Clamped to ``chunk_size - 1`` internally (see below).

    Returns:
        A list of non-empty chunk strings, in document order.
    """
    text = text.strip()
    if not text:
        return []

    if len(text) <= chunk_size:
        return [text]

    # BUG FIX: an overlap >= chunk_size would make the hard-split loop below
    # step backwards or stall (start never advances past end - overlap).
    overlap = max(0, min(overlap, chunk_size - 1))

    # Split into sentences. The lookbehind keeps the .!? delimiter attached
    # to the preceding sentence while the whitespace separator is dropped.
    sentences = [p.strip() for p in _SENTENCE_END_RE.split(text) if p.strip()]

    chunks: list[str] = []
    window: list[str] = []   # sentences accumulated for the current chunk
    window_len = 0           # == len(" ".join(window))

    for sentence in sentences:
        sentence_len = len(sentence)

        # A single sentence longer than chunk_size cannot be packed: flush
        # the pending window, then hard-split the sentence by character.
        if sentence_len > chunk_size:
            if window:
                chunks.append(" ".join(window))
                window, window_len = [], 0
            start = 0
            while start < sentence_len:
                end = start + chunk_size
                chunks.append(sentence[start:end])
                # Step back by `overlap` so consecutive pieces share text.
                start = end - overlap if end < sentence_len else sentence_len
            continue

        projected = window_len + (1 if window else 0) + sentence_len
        if projected > chunk_size and window:
            chunks.append(" ".join(window))

            # Seed the next window with tail sentences until >= overlap chars.
            tail: list[str] = []
            tail_len = 0
            for prev in reversed(window):
                tail.insert(0, prev)
                tail_len += len(prev) + 1
                if tail_len >= overlap:
                    break
            window = tail
            window_len = sum(len(s) for s in window) + max(0, len(window) - 1)

            # BUG FIX: the overlap tail plus the incoming sentence can exceed
            # chunk_size, which previously produced over-long chunks on the
            # next flush. Shrink the tail from the front until it fits.
            while window and window_len + 1 + sentence_len > chunk_size:
                window.pop(0)
                window_len = sum(len(s) for s in window) + max(0, len(window) - 1)

        window.append(sentence)
        window_len += (1 if len(window) > 1 else 0) + sentence_len

    if window:
        chunks.append(" ".join(window))

    return chunks


# ─────────────────────────────────────────────────────────────────────────────
# Internal helpers
# ─────────────────────────────────────────────────────────────────────────────

_HEADING_RE = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE)
_CODE_FENCE_RE = re.compile(r'```[\s\S]*?```', re.MULTILINE)


def _compute_chunk_id(source_url: str, chunk_index: int, text: str) -> str:
    """Deterministic SHA-256 hash of (source_url + chunk_index + text)."""
    payload = f"{source_url}\x00{chunk_index}\x00{text}"
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def _split_preserving_code_blocks(
    text: str,
    chunk_size: int,
    overlap: int,
) -> list[str]:
    """
    Split *text* into chunks while keeping fenced code blocks (```…```) intact.

    Strategy:
    1. Tokenise the text into alternating prose / code-fence segments.
    2. Pack segments greedily into chunks of at most chunk_size chars.
    3. When a code block fits in the current chunk, add it whole.
    4. When a code block is too large for even an empty chunk, emit it as its
       own (oversized) chunk rather than splitting it mid-block.
    5. Apply overlap at sentence boundaries between consecutive prose chunks.

    Segments packed into the same chunk are joined with a newline. (BUG FIX:
    the previous implementation concatenated with no separator, and because
    chunk_text() strips its output, prose was fused directly onto the opening
    code fence, corrupting the chunk text.)
    """
    # Tokenise: list of (is_code_block: bool, segment_text: str)
    segments: list[tuple[bool, str]] = []
    cursor = 0
    for match in _CODE_FENCE_RE.finditer(text):
        prose = text[cursor:match.start()]
        if prose:
            segments.append((False, prose))
        segments.append((True, match.group()))
        cursor = match.end()
    remainder = text[cursor:]
    if remainder:
        segments.append((False, remainder))

    chunks: list[str] = []
    current_parts: list[str] = []   # raw segment strings in current chunk
    current_len = 0                 # == len("\n".join(current_parts))

    def flush() -> None:
        nonlocal current_parts, current_len
        if current_parts:
            chunks.append("\n".join(current_parts))
        current_parts = []
        current_len = 0

    def add(part: str) -> None:
        # Account for the "\n" separator inserted before every part but the first.
        nonlocal current_len
        current_len += (1 if current_parts else 0) + len(part)
        current_parts.append(part)

    def fits(extra: int) -> bool:
        return current_len + (1 if current_parts else 0) + extra <= chunk_size

    for is_code, segment in segments:
        if is_code:
            if len(segment) > chunk_size:
                # Oversized code block — never split mid-block; emit alone.
                flush()
                chunks.append(segment)
            else:
                if not fits(len(segment)):
                    flush()
                add(segment)
        else:
            # Prose: delegate to the sentence-level chunker, then pack.
            for piece in chunk_text(segment, chunk_size=chunk_size, overlap=overlap):
                if current_parts and not fits(len(piece)):
                    flush()
                add(piece)

    flush()
    return [c for c in chunks if c.strip()]


def _extract_heading_context(lines_before: list[str]) -> str:
    """Return the active heading hierarchy as a string like '# H1 > ## H2'."""
    heading_stack: list[str] = []  # (level, heading_text)
    for line in lines_before:
        m = re.match(r'^(#{1,6})\s+(.+)$', line)
        if m:
            level = len(m.group(1))
            title = m.group(2).strip()
            # Pop headings of equal or deeper level.
            heading_stack = [h for h in heading_stack if h[0] < level]
            heading_stack.append((level, title))  # type: ignore[arg-type]
    if not heading_stack:
        return ""
    return " > ".join(f"{'#' * lvl} {txt}" for lvl, txt in heading_stack)


# ─────────────────────────────────────────────────────────────────────────────
# Story 3.02 — Heading-Aware Chunker
# ─────────────────────────────────────────────────────────────────────────────

def chunk_with_headings(
    content: ExtractedContent,
    platform: str,
    customer_id: Optional[str] = None,
    chunk_size: int = 1500,
    overlap: int = 200,
) -> list[Chunk]:
    """Chunk content preserving heading hierarchy context.

    - Each Chunk's ``heading_context`` contains the nearest parent headings.
    - Prefers to split at heading boundaries.
    - Keeps fenced code blocks intact.
    - chunk_id is a deterministic SHA-256 of (source_url + chunk_index + text).

    BUG FIX: heading detection is now suppressed inside ``` fences, so a
    code comment like '# foo' can no longer split a section mid-code-block
    (which previously broke the code-block-preservation guarantee).

    Args:
        content: Extracted page; ``content.text`` holds the markdown-ish body.
        platform: Tenant/platform tag stored on every chunk.
        customer_id: Optional customer tag for multi-tenant isolation.
        chunk_size: Maximum characters per chunk.
        overlap: Overlap between consecutive chunks within a section.

    Returns:
        Ordered list of Chunk objects; empty when the text is blank.
    """
    text = content.text.strip()
    if not text:
        return []

    lines = text.splitlines()

    # Split text into sections at heading boundaries first.
    # A "section" is a list of lines that belong to the same heading block.
    sections: list[tuple[str, list[str]]] = []  # (heading_context, lines)
    current_section_lines: list[str] = []
    seen_lines: list[str] = []   # for computing heading_context at each point
    in_code_fence = False        # tracks ``` state so fenced '#' lines are ignored

    for line in lines:
        if line.lstrip().startswith("```"):
            in_code_fence = not in_code_fence
        is_heading = (not in_code_fence) and re.match(r'^#{1,6}\s+', line)
        if is_heading and current_section_lines:
            # Close the previous section; its context includes its own heading
            # because that heading is already in seen_lines.
            ctx = _extract_heading_context(seen_lines)
            sections.append((ctx, current_section_lines))
            current_section_lines = [line]
        else:
            current_section_lines.append(line)
        seen_lines.append(line)

    if current_section_lines:
        ctx = _extract_heading_context(seen_lines)
        sections.append((ctx, current_section_lines))

    # Now chunk each section independently.
    raw_chunks: list[tuple[str, str]] = []  # (heading_context, chunk_text)
    for heading_ctx, sec_lines in sections:
        section_text = "\n".join(sec_lines).strip()
        if not section_text:
            continue
        sub_chunks = _split_preserving_code_blocks(section_text, chunk_size, overlap)
        for sc in sub_chunks:
            if sc.strip():
                raw_chunks.append((heading_ctx, sc.strip()))

    total = len(raw_chunks)
    result: list[Chunk] = []
    for idx, (heading_ctx, chunk_text_val) in enumerate(raw_chunks):
        chunk_id = _compute_chunk_id(content.url, idx, chunk_text_val)
        result.append(
            Chunk(
                chunk_id=chunk_id,
                source_url=content.url,
                platform=platform,
                customer_id=customer_id,
                title=content.title,
                text=chunk_text_val,
                heading_context=heading_ctx,
                chunk_index=idx,
                total_chunks=total,
                metadata=dict(content.metadata),
            )
        )

    return result


# ─────────────────────────────────────────────────────────────────────────────
# Story 3.03 — Multi-Tenant Tagging
# ─────────────────────────────────────────────────────────────────────────────

def tag_chunks(
    chunks: list[Chunk],
    platform: str,
    customer_id: Optional[str] = None,
    extra_metadata: Optional[Dict[str, Any]] = None,
) -> list[Chunk]:
    """Apply platform and customer tags to chunks for multi-tenant isolation.

    - Returns NEW Chunk objects (originals are not mutated).
    - Adds ``ingested_at`` (ISO 8601) and ``source_url`` to metadata.
    - Merges ``extra_metadata`` into each chunk's metadata last, so
      caller-supplied keys win over the two added above.
    """
    stamp = datetime.now(tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    def retag(src: Chunk) -> Chunk:
        # Merge order matters: extra_metadata is applied last and may
        # override the stamped keys, matching the documented contract.
        merged: Dict[str, Any] = {
            **src.metadata,
            "ingested_at": stamp,
            "source_url": src.source_url,
            **(extra_metadata or {}),
        }
        return Chunk(
            chunk_id=src.chunk_id,
            source_url=src.source_url,
            platform=platform,
            customer_id=customer_id,
            title=src.title,
            text=src.text,
            heading_context=src.heading_context,
            chunk_index=src.chunk_index,
            total_chunks=src.total_chunks,
            metadata=merged,
        )

    return [retag(c) for c in chunks]


# ─────────────────────────────────────────────────────────────────────────────
# Story 3.04 — Batch Chunker
# ─────────────────────────────────────────────────────────────────────────────

def chunk_batch(
    contents: list[ExtractedContent],
    config: PlatformConfig,
    customer_id: Optional[str] = None,
) -> list[Chunk]:
    """Chunk multiple extracted contents using platform config.

    - Uses config.chunk_size and config.chunk_overlap for all documents.
    - Skips None or empty-text contents gracefully.
    - Returns a flat list of all resulting Chunks, in input order.
    """
    collected: list[Chunk] = []

    for doc in contents:
        # Skip missing documents and documents with no usable text.
        if doc is None or not doc.text or not doc.text.strip():
            continue
        collected += chunk_with_headings(
            content=doc,
            platform=config.name,
            customer_id=customer_id,
            chunk_size=config.chunk_size,
            overlap=config.chunk_overlap,
        )

    return collected
