# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
import shutil
import sys
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from typing import Any


# Matches each run of digits that makes up a whole dot-delimited component (other than the very
# first one): it must be preceded by a dot and followed by either a dot or the end of the string.
_DIGIT_RX = re.compile(r"(?<=\.)(\d+)(?=\.|$)")  # numbers between dots or at the end


def _pattern_of(key: str) -> str:
    """
    Return the structural pattern of `key`: the same string with every dot-delimited integer swapped for '*'.

    Example: 'model.layers.3.mlp.weight' -> 'model.layers.*.mlp.weight'.
    """
    return re.sub(_DIGIT_RX, "*", key)


def _fmt_indices(values: list[int], cutoff=10) -> str:
    """
    Render a list of ints compactly.

    - exactly one value: that value, e.g. `"3"`
    - up to `cutoff` values: the sorted values joined by `", "`, e.g. `"0, 1, 2"` (no surrounding braces)
    - more than `cutoff` values: `"first...last"` of the sorted values, e.g. `"0...31"`
    """
    count = len(values)
    if count == 1:
        return str(values[0])

    ordered = sorted(values)
    if count <= cutoff:
        return ", ".join(str(index) for index in ordered)
    return f"{ordered[0]}...{ordered[-1]}"


def update_key_name(mapping: dict[str, Any]) -> dict[str, Any]:
    """
    Merge keys like 'layers.0.x', 'layers.1.x' into 'layers.{0, 1}.x'
    BUT only merge together keys that have the exact same value.

    Args:
        mapping:
            Either a dict `{key: value}`, or any other iterable of keys (e.g. a `set[str]`). In the latter
            case the keys carry no value, so every key sharing the same structure is merged.

    Returns:
        An ordered dict `{merged_key: value}` when `mapping` is a dict, otherwise a view over the merged keys
        only. For a dict the entries follow the first appearance of each (structure, value) group in `mapping`.
    """
    keys_only = not isinstance(mapping, dict)
    if keys_only:
        # Give every key the same placeholder value so that the "same value" rule never prevents a merge
        mapping = dict.fromkeys(mapping)

    # (literal parts, value) -> (literal parts, value, per-index sets of ints, original keys)
    groups: dict[Any, tuple[tuple[str, ...], Any, list[set[int]], list[str]]] = {}
    for key, val in mapping.items():
        # `_DIGIT_RX` has exactly one capturing group, so `split` alternates literal text (even positions)
        # and the captured integers (odd positions): 'layers.0.x' -> ['layers.', '0', '.x']
        pieces = _DIGIT_RX.split(key)
        literals, digits = tuple(pieces[::2]), pieces[1::2]

        group_id = (literals, val)
        try:
            hash(group_id)
        except TypeError:
            # Unhashable values (lists, dicts, ...) are compared through their repr instead
            group_id = (literals, "unhashable", repr(val))

        if group_id not in groups:
            groups[group_id] = (literals, val, [set() for _ in digits], [])
        _, _, index_sets, members = groups[group_id]
        for index_set, digit in zip(index_sets, digits):
            index_set.add(int(digit))
        members.append(key)

    out = OrderedDict()
    for literals, val, index_sets, members in groups.values():
        merged = literals[0]
        for index_set, tail in zip(index_sets, literals[1:]):
            indices = _fmt_indices(sorted(index_set))
            # Braces only when several indices were merged at this position
            merged += ("{" + indices + "}" if len(index_set) > 1 else indices) + tail

        if merged in out:
            # Two groups with different values can render to the same merged name (the per-position index sets
            # are a lossy summary). Never overwrite an entry: list the original keys of this group instead.
            for key in members:
                out[key] = val
        else:
            out[merged] = val

    if keys_only:
        return out.keys()
    return out


# ANSI color/style sequences: ESC, '[', any mix of digits and ';', then a final 'm'.
_ansi_re = re.compile(r"\x1b\[[0-9;]*m")


def _strip_ansi(s: str) -> str:
    """Return `str(s)` with every ANSI color/style sequence removed, i.e. only the visible text."""
    visible = re.sub(_ansi_re, "", str(s))
    return visible


def _pad(text, width):
    """
    Right-pad `str(text)` with spaces until its visible length (ANSI codes excluded) reaches `width`.

    Text that is already at least `width` visible characters long is returned as is (never truncated).
    """
    rendered = str(text)
    missing = width - len(_strip_ansi(rendered))
    if missing <= 0:
        return rendered
    return rendered + " " * missing


def _make_table(rows, headers):
    """
    Render `rows` under `headers` as a plain-text table.

    Cells are joined with " | ", a "-+-" rule separates the header from the body, and each column is padded to
    the visible width (ANSI codes ignored) of its widest cell. Columns are built with `zip`, so the table has
    as many columns as the shortest of `headers` and the rows; extra cells are dropped.

    Args:
        rows: Iterable of rows, each row being a sequence of cells (converted with `str` when rendered).
        headers: Sequence of column titles.

    Returns:
        The table as a single string: header line, separator line, then one line per row.
    """
    rows = list(rows)
    # `zip(headers, *rows)` transposes header + rows into columns. With no rows at all it still yields one
    # single-cell column per header, so the header line is rendered in full.
    columns = list(zip(headers, *rows))
    # compute display widths while ignoring ANSI codes
    widths = [max(len(_strip_ansi(cell)) for cell in column) for column in columns]
    header_line = " | ".join(_pad(h, w) for h, w in zip(headers, widths))
    sep_line = "-+-".join("-" * w for w in widths)
    body = [" | ".join(_pad(c, w) for c, w in zip(r, widths)) for r in rows]
    return "\n".join([header_line, sep_line] + body)


# ANSI SGR escape sequences used to style the report. The ESC introducer is spelled out as `\x1b` rather
# than embedded as a raw, invisible control byte, so that it cannot be lost by editors or copy/paste and
# so that every value is matched by the `\x1b\[[0-9;]*m` pattern used to strip styling.
PALETTE = {
    "reset": "\x1b[0m",
    "red": "\x1b[31m",
    "yellow": "\x1b[33m",
    "orange": "\x1b[38;5;208m",  # 256-color palette entry 208
    "purple": "\x1b[35m",
    "bold": "\x1b[1m",
    "italic": "\x1b[3m",
    "dim": "\x1b[2m",
}


def _color(s, color):
    """
    Return color-formatted input `s` if `sys.stdout` is interactive, e.g. connected to a terminal.

    Args:
        s: The text to style.
        color: A key of `PALETTE`.

    Returns:
        `s` wrapped between the `PALETTE[color]` and `PALETTE['reset']` codes when stdout is a terminal,
        otherwise `s` itself, untouched.
    """
    try:
        interactive = sys.stdout.isatty()
    except (AttributeError, ValueError):
        # `sys.stdout` can be `None` or a stream without `isatty` (AttributeError), or an already closed
        # stream (ValueError). Styling is cosmetic, so fall back to plain text instead of failing.
        interactive = False

    if not interactive:
        return s
    return f"{PALETTE[color]}{s}{PALETTE['reset']}"


def _get_terminal_width(default=80):
    """
    Return the width of the terminal in columns.

    Args:
        default: Width to report when the size cannot be determined (e.g. output is not a terminal).

    Returns:
        The number of columns reported by `shutil.get_terminal_size`, or `default` when unavailable.
    """
    try:
        # `get_terminal_size` does not raise when there is no terminal: it silently returns its own
        # `fallback`. Forward `default` there, otherwise it would be ignored in that (most common) case.
        return shutil.get_terminal_size(fallback=(default, 24)).columns
    except Exception:
        # Deliberately best-effort: the width only affects cosmetics and must never break the report
        return default


@dataclass
class LoadStateDictInfo:
    """
    Mutable container for state-dict loading results and diagnostics. Each entry in this structure is mutable,
    and will usually be mutated in-place during the loading pipeline.

    Attributes:
        missing_keys (`set[str]`):
            Keys that are missing from the loaded checkpoints but expected in the model's architecture.
        unexpected_keys (`set[str]`):
            Keys that are found in the checkpoints, but not expected in the model's architecture.
        mismatched_keys (`set[tuple[str, tuple[int], tuple[int]]]`):
            Keys that are found in the checkpoints and are expected in the model's architecture, but with a different shape.
        error_msgs ( `list[str]`):
            Some potential error messages.
        conversion_errors (`dict[str, str]`):
            Errors happening during the on-the-fly weight conversion process.
    """

    missing_keys: set[str]
    unexpected_keys: set[str]
    mismatched_keys: set[tuple[str, tuple[int], tuple[int]]]
    error_msgs: list[str]
    conversion_errors: dict[str, str]

    def missing_and_mismatched(self):
        """Return all effective missing keys, including `missing` and `mismatched` keys."""
        return self.missing_keys | {k[0] for k in self.mismatched_keys}

    def to_dict(self):
        # Does not include the `conversion_errors` to be coherent with legacy reporting in the tests
        return {
            "missing_keys": self.missing_keys,
            "unexpected_keys": self.unexpected_keys,
            "mismatched_keys": self.mismatched_keys,
            "error_msgs": self.error_msgs,
        }

    def create_loading_report(self) -> str | None:
        """Generate the minimal table of a loading report."""
        term_w = _get_terminal_width()

        rows = []
        tips = ""
        if self.unexpected_keys:
            tips += (
                f"\n- {_color('UNEXPECTED', 'orange') + PALETTE['italic']}\t:can be ignored when loading from different "
                "task/architecture; not ok if you expect identical arch."
            )
            for k in update_key_name(self.unexpected_keys):
                status = _color("UNEXPECTED", "orange")
                rows.append([k, status, "", ""])

        if self.missing_keys:
            tips += (
                f"\n- {_color('MISSING', 'red') + PALETTE['italic']}\t:those params were newly initialized because missing "
                "from the checkpoint. Consider training on your downstream task."
            )
            for k in update_key_name(self.missing_keys):
                status = _color("MISSING", "red")
                rows.append([k, status, ""])

        if self.mismatched_keys:
            tips += (
                f"\n- {_color('MISMATCH', 'yellow') + PALETTE['italic']}\t:ckpt weights were loaded, but they did not match "
                "the original empty weight shapes."
            )
            iterator = {a: (b, c) for a, b, c in self.mismatched_keys}
            for key, (shape_ckpt, shape_model) in update_key_name(iterator).items():
                status = _color("MISMATCH", "yellow")
                data = [
                    key,
                    status,
                    f"Reinit due to size mismatch - ckpt: {str(shape_ckpt)} vs model:{str(shape_model)}",
                ]
                rows.append(data)

        if self.conversion_errors:
            tips += f"\n- {_color('CONVERSION', 'purple') + PALETTE['italic']}\t:originate from the conversion scheme"
            for k, v in update_key_name(self.conversion_errors).items():
                status = _color("CONVERSION", "purple")
                _details = f"\n\n{v}\n\n"
                rows.append([k, status, _details])

        # If nothing is wrong, return None
        if len(rows) == 0:
            return None

        headers = ["Key", "Status"]
        if term_w > 200:
            headers += ["Details"]
        else:
            headers += ["", ""]
        table = _make_table(rows, headers=headers)
        tips = f"\n\n{PALETTE['italic']}Notes:{tips}{PALETTE['reset']}"
        report = table + tips

        return report


def log_state_dict_report(
    model,
    pretrained_model_name_or_path: str,
    ignore_mismatched_sizes: bool,
    loading_info: LoadStateDictInfo,
    logger: logging.Logger | None = None,
):
    """
    Log a readable report about state_dict loading issues, and raise if some of them are fatal.

    Args:
        model:
            The model being loaded. Only its class name is used, in the messages.
        pretrained_model_name_or_path (`str`):
            Where the checkpoint comes from. Only used in the report title.
        ignore_mismatched_sizes (`bool`):
            If `False`, shape mismatches are fatal (raised after the report is logged).
        loading_info (`LoadStateDictInfo`):
            The diagnostics collected while loading.
        logger (`logging.Logger`, *optional*):
            Logger receiving the report, as a warning. Defaults to this module's logger.

    Raises:
        RuntimeError: If `loading_info.error_msgs` is not empty (raised before anything is logged), if some
            conversion errors were recorded, or if there are mismatched keys while `ignore_mismatched_sizes`
            is `False`.
    """
    if logger is None:
        logger = logging.getLogger(__name__)

    model_name = model.__class__.__name__

    # Re-raise errors early if needed
    if loading_info.error_msgs:
        error_msg = "\n\t".join(loading_info.error_msgs)
        if "size mismatch" in error_msg:
            error_msg += (
                "\n\tYou may consider adding `ignore_mismatched_sizes=True` to `from_pretrained(...)` if appropriate."
            )
        raise RuntimeError(f"Error(s) in loading state_dict for {model_name}:\n\t{error_msg}")

    # Create the report table
    report = loading_info.create_loading_report()
    if report is None:
        return

    # Go through `_color` so that the bold title obeys the same "only on an interactive terminal" rule as the
    # rest of the report, instead of always writing raw escape codes to the log
    title = _color(f"{model_name} LOAD REPORT", "bold")
    prelude = f"{title} from: {pretrained_model_name_or_path}\n"

    # Log the report as warning
    logger.warning(prelude + report)

    # Re-raise in those case, after the report
    if loading_info.conversion_errors:
        raise RuntimeError(
            "We encountered some issues during automatic conversion of the weights. For details look at the `CONVERSION` entries of "
            "the above report!"
        )
    if not ignore_mismatched_sizes and loading_info.mismatched_keys:
        raise RuntimeError(
            "You set `ignore_mismatched_sizes` to `False`, thus raising an error. For details look at the above report!"
        )
