# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
from typing import TYPE_CHECKING

from packaging import version

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name, should_convert_module


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from safetensors import safe_open

from ..utils import is_torch_available, is_torchao_available, logging


if is_torch_available():
    from ..core_model_loading import WeightConverter


if is_torch_available():
    import torch

if is_torchao_available():
    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.15.0"):
        from torchao.prototype.safetensors.safetensors_support import (
            flatten_tensor_state_dict,
        )


logger = logging.get_logger(__name__)


def fuzzy_match_size(config_name: str) -> str | None:
    """
    Extract the size digit from strings like "4weight", "8weight".
    Returns the digit as an integer if found, otherwise None.
    """
    config_name = config_name.lower()

    str_match = re.search(r"(\d)weight", config_name)

    if str_match:
        return str_match.group(1)

    return None


def _quantization_type(weight):
    """
    Build a short text description of how a torchao tensor subclass is quantized.

    Returns a string of the form `ClassName(...)` for `AffineQuantizedTensor` and `LinearActivationQuantizedTensor`
    weights (the latter recursing into its `original_weight_tensor`), and `None` for any other object.
    """
    # Imported lazily so that this module can be imported without torchao installed.
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

    class_name = weight.__class__.__name__

    if isinstance(weight, AffineQuantizedTensor):
        return f"{class_name}({weight._quantization_type()})"

    if isinstance(weight, LinearActivationQuantizedTensor):
        return (
            f"{class_name}(activation={weight.input_quant_func}, "
            f"weight={_quantization_type(weight.original_weight_tensor)})"
        )

    # Not a tensor subclass handled here: there is nothing to describe.
    return None


def _linear_extra_repr(self):
    """
    Produce the `extra_repr` text for a linear-like module whose weight may be a torchao quantized tensor.

    The feature sizes are read from `self.weight.shape` (`shape[1]` is reported as `in_features`, `shape[0]` as
    `out_features`), and the `weight=` field shows the description from `_quantization_type`, or `None` when the
    weight is not a recognised quantized tensor.
    """
    description = _quantization_type(self.weight)
    shape = self.weight.shape
    weight_repr = "None" if description is None else description
    return f"in_features={shape[1]}, out_features={shape[0]}, weight={weight_repr}"


# Installed torchao version, parsed once at import time.
# NOTE: `TORCHAO_VERSION` is only bound when torchao is available; code reading it without torchao installed
# will raise `NameError`.
if is_torchao_available():
    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))


class TorchAoHfQuantizer(HfQuantizer):
    """
    Quantizer for torchao: https://github.com/pytorch/ao/

    Validates the loading environment, decides which parameters are quantized, estimates the per-element size of
    quantized parameters, and exposes the quantize / deserialize operations used while loading weights.
    """

    requires_calibration = False

    def __init__(self, quantization_config, **kwargs):
        """
        Initialize the quantizer and derive `self.quantized_param_size` from the configured quant type.

        `quantized_param_size` is the assumed size in bytes of one quantized element (0.5 for 4-bit, 1 for 8-bit).
        It stays `None` when `quant_type` is a string that is not one of the known names below.
        """
        super().__init__(quantization_config, **kwargs)

        self.quantized_param_size = None
        quant_type = self.quantization_config.quant_type
        if isinstance(quant_type, str):
            map_to_param_size = {
                "int4_weight_only": 0.5,
                "int8_weight_only": 1,
                "int8_dynamic_activation_int8_weight": 1,
            }
            # Unknown string names leave the size as `None`.
            self.quantized_param_size = map_to_param_size.get(quant_type)
        else:
            # Config objects: guess the bit width from the class name ("...4Weight..." means 4-bit); every other
            # case, including no match at all, is treated as 1 byte per element.
            size_digit = fuzzy_match_size(quant_type.__class__.__name__)
            self.quantized_param_size = 0.5 if size_digit == "4" else 1

    def validate_environment(self, *args, **kwargs):
        """
        Check that torchao is installed and that the requested loading setup is supported.

        Also sets `self.offload`: it becomes `True` when `device_map` is a dict with more than one entry and at least
        one entry is placed on "cpu" or "disk".

        Raises:
            ImportError: If torchao is not installed.
            ValueError: If the model is pre-quantized and such a `device_map` places an entry on "disk".
            RuntimeError: If the model is pre-quantized, `weights_only` is truthy and torch is older than 2.5.0.
        """
        if not is_torchao_available():
            raise ImportError("Loading an torchao quantized model requires torchao library (`pip install torchao`)")

        self.offload = False
        device_map = kwargs.get("device_map")
        if isinstance(device_map, dict):
            if ("disk" in device_map.values() or "cpu" in device_map.values()) and len(device_map) > 1:
                self.offload = True
                if self.pre_quantized and "disk" in device_map.values():
                    raise ValueError(
                        "You are attempting to perform disk offload with a pre-quantized torchao model "
                        "This is not supported yet . Please remove the disk device from the device_map."
                    )
        if self.pre_quantized:
            weights_only = kwargs.get("weights_only")
            if weights_only:
                torch_version = version.parse(importlib.metadata.version("torch"))
                if torch_version < version.parse("2.5.0"):
                    raise RuntimeError(
                        f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
                        f" You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
                    )

    def update_dtype(self, dtype):
        """
        Return the dtype to load the model with.

        For the string quant type "int4_weight_only" any dtype other than `torch.bfloat16` is replaced by
        `torch.bfloat16` (with a one-time warning); every other quant type leaves `dtype` untouched.
        """
        if self.quantization_config.quant_type == "int4_weight_only":
            if dtype != torch.bfloat16:
                logger.warning_once(
                    f"Setting dtype to {dtype} for int4_weight_only quantization, but only bfloat16 is supported right now. Overwriting torch_dtype to bfloat16."
                )
                dtype = torch.bfloat16
        return dtype

    def get_state_dict_and_metadata(self, model):
        """
        We flatten the state dict of tensor subclasses so that it is compatible with the safetensors format.

        Returns:
            The result of `flatten_tensor_state_dict(model.state_dict())`.

        Raises:
            RuntimeError: If the installed torchao version is older than 0.15.0.
        """
        if version.parse("0.15.0") <= TORCHAO_VERSION:
            return flatten_tensor_state_dict(model.state_dict())
        else:
            raise RuntimeError(
                f"In order to use safetensors with torchao, please use torchao version >= 0.15.0. Current version: {TORCHAO_VERSION}"
            )

    def param_element_size(self, model: "PreTrainedModel", param_name: str, param: "torch.Tensor") -> float:
        """
        Return the element size (in bytes) for `param_name`.

        Parameters that will be quantized use `self.quantized_param_size` when it is known; everything else falls
        back to the parent implementation.
        """
        if self.param_needs_quantization(model, param_name) and self.quantized_param_size is not None:
            return self.quantized_param_size

        return super().param_element_size(model, param_name, param)

    def adjust_max_memory(self, max_memory: dict[str, int | str]) -> dict[str, int | str]:
        """
        Return a copy of `max_memory` with every budget scaled down to 90%.

        NOTE(review): the multiplication assumes numeric values; string budgets are presumably normalized by the
        caller beforehand - confirm.
        """
        # need more space for the quantization parameters (e.g. scale). Tested with int4 wo and group size = 128
        max_memory = {key: val * 0.9 for key, val in max_memory.items()}
        return max_memory

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", checkpoint_files=None, **kwargs):
        """
        Prepare quantizer state before weights are loaded.

        Computes `self.modules_to_not_convert`, removes the input / output embedding modules from it when
        `include_input_output_embeddings` is set, and records checkpoint metadata when `checkpoint_files` is given.
        """
        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, model._keep_in_fp32_modules
        )
        if self.quantization_config.include_input_output_embeddings:
            input_emb = model.get_input_embeddings()
            output_emb = model.get_output_embeddings()
            # Collect the names of both embedding modules in a single pass; the set is built once instead of being
            # rebuilt for every entry of the filter below.
            embedding_names = {
                name for name, module in model.named_modules() if module is input_emb or module is output_emb
            }
            self.modules_to_not_convert = [x for x in self.modules_to_not_convert if x not in embedding_names]
        if checkpoint_files is not None:
            # Torchao needs access to all metadata later
            self.set_metadata(checkpoint_files)

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """
        Tell whether the parameter called `param_name` should be quantized.

        A parameter is skipped when `should_convert_module` rejects its name. Otherwise, it is quantized when it is
        the `weight` of a quantizable module, or when an `FqnToConfig` quant type (torchao >= 0.15.0) matches it.
        """
        # check if the param_name is not in self.modules_to_not_convert
        if not should_convert_module(param_name, self.modules_to_not_convert):
            return False

        # we only quantize the weight of nn.Linear, plus nn.Embedding when embeddings are explicitly included
        module, tensor_name = get_module_from_name(model, param_name)
        quantizable_types = (torch.nn.Linear,)
        if self.quantization_config.include_input_output_embeddings:
            quantizable_types += (torch.nn.Embedding,)

        # Handle FqnToConfig, introduced in torchao 0.15.0+
        if self.quantization_config._get_ao_version() >= version.parse("0.15.0"):
            from torchao.quantization import FqnToConfig, fqn_matches_fqn_config

            quant_type = self.quantization_config.quant_type
            if isinstance(quant_type, FqnToConfig):
                # `rpartition` yields an empty module fqn for a name without any dot instead of failing to unpack.
                module_fqn = param_name.rpartition(".")[0]
                if (
                    fqn_matches_fqn_config(module_fqn, quant_type)
                    or fqn_matches_fqn_config(param_name, quant_type)
                    or ("_default" in quant_type.fqn_to_config and isinstance(module, quantizable_types))
                ):
                    return True

        return isinstance(module, quantizable_types) and tensor_name == "weight"

    def _process_model_after_weight_loading(self, model, **kwargs):
        """Nothing to do after the weights are loaded."""
        return

    def is_serializable(self) -> bool:
        """
        Return `True` when the installed torchao version (>= 0.15.0) supports saving the quantized model.

        A warning is logged when it does not.
        """
        _is_torchao_serializable = version.parse("0.15.0") <= TORCHAO_VERSION
        if not _is_torchao_serializable:
            logger.warning(
                "torchao quantized model only supports serialization for torchao version >= 0.15.0, please upgrade "
                "your version to save the quantized model"
            )
        return _is_torchao_serializable

    @property
    def is_trainable(self) -> bool:
        """Whether the configured quant type is one of the string quant types listed as supporting training."""
        supported_quant_types_for_training = [
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",
        ]
        return self.quantization_config.quant_type in supported_quant_types_for_training

    @property
    def is_compileable(self) -> bool:
        """Always `True`."""
        return True

    def set_metadata(self, checkpoint_files: list[str]):
        """
        Merge the safetensors metadata of all `checkpoint_files` into `self.metadata`.

        Nothing happens (and `self.metadata` is not assigned) when the list is empty or when the first file does not
        end with ".safetensors". When several files define the same key, the value from the later file wins.
        """
        if not checkpoint_files or not checkpoint_files[0].endswith(".safetensors"):
            return

        metadata = {}
        for checkpoint in checkpoint_files:
            with safe_open(checkpoint, framework="pt") as f:
                # `f.metadata()` may be falsy; treat that as "no metadata".
                metadata.update(f.metadata() or {})
        # Save it
        self.metadata = metadata

    def get_quantize_ops(self):
        """Return the `TorchAoQuantize` operation bound to this quantizer."""
        from ..integrations.torchao import TorchAoQuantize

        return TorchAoQuantize(self)

    def get_weight_conversions(self):
        """
        Return the weight converters needed to load a checkpoint.

        For pre-quantized models, a single `WeightConverter` maps the flattened torchao tensor components onto
        `weight` through `TorchAoDeserialize`; otherwise the list is empty.
        """
        from ..integrations.torchao import TorchAoDeserialize

        if self.pre_quantized:
            return [
                WeightConverter(
                    # TODO: increase flexibility by generalizing the source patterns to match the format of "_weight_"
                    # note that the matching logic is greedy: if, for example, _weight_scale came before
                    # _weight_scale_and_zero in this list, it would always match _weight_scale (which is incorrect)
                    # thus, the order of source_patterns is intentional
                    source_patterns=[
                        "_weight_qdata",
                        "_weight_scale_and_zero",
                        "_weight_scale",
                        "_weight_zero_point",
                        "_weight_act_pre_scale",
                    ],
                    target_patterns="weight",
                    operations=[TorchAoDeserialize(self)],
                ),
            ]
        return []
