from typing import Optional, Union

from google.cloud import aiplatform
from google.cloud.aiplatform import initializer as aiplatform_initializer
from vertexai.language_models import _language_models
from vertexai.language_models import _language_models as tuning


_DISTILLATION_PIPELINE_URI = (
    "https://us-kfp.pkg.dev/ml-pipeline/distillation/distillation/v1.0.0"
)


class DistillationMixin:
    """Adds knowledge-distillation tuning (``distill_from``) to a model class.

    The class this is mixed into must expose an ``_endpoint_name`` string; its
    last ``/``-separated segment is used as the student model ID.
    """

    def distill_from(
        self,
        *,
        dataset: str,
        teacher_model: Union[str, _language_models._TextGenerationModel],
        train_steps: Optional[int] = None,
        learning_rate_multiplier: Optional[float] = None,
        evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
        accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
        model_display_name: Optional[str] = None,
        max_context_length: Optional[str] = None,
    ):
        """Distills knowledge from a (bigger) teacher model into this model.

        This model acts as the student. A distillation pipeline job is
        submitted and wrapped in a tuning job whose base model is this model.

        Args:
            dataset: URI of the training data, in JSON Lines format.
            teacher_model: Model to distill from. A string is forwarded
                unchanged as the teacher model ID; a language model object
                contributes the last segment of its endpoint name.
            train_steps: Number of training batches (8 samples per batch).
            learning_rate_multiplier: Multiplier applied to the tuning
                learning rate.
            evaluation_spec: Settings for model evaluation during tuning.
            accelerator_type: Accelerator to train on: "TPU" or "GPU".
            model_display_name: Custom display name for the tuned model.
            max_context_length: Maximum context length used for tuning,
                either '8k' or '32k'.

        Returns:
            A tuning job wrapping the submitted distillation pipeline job.

        Raises:
            RuntimeError: If this model's endpoint name (or a teacher model
                object's endpoint name) does not contain "/models/", or if
                ``teacher_model`` is neither a string nor a language model.
        """
        student_endpoint = self._endpoint_name
        # Distillation needs a ".../models/<id>" endpoint; the trailing segment
        # is the short model ID handed to the pipeline.
        if "/models/" not in student_endpoint:
            raise RuntimeError(
                f"Model does not support distillation: {student_endpoint}"
            )
        student_id = student_endpoint.rpartition("/")[2]

        if isinstance(teacher_model, str):
            teacher_id = teacher_model
        elif isinstance(teacher_model, _language_models._LanguageModel):
            teacher_endpoint = teacher_model._endpoint_name
            if "/models/" not in teacher_endpoint:
                raise RuntimeError(
                    f"Teacher model does not support distillation: {teacher_endpoint}"
                )
            teacher_id = teacher_endpoint.rpartition("/")[2]
        else:
            raise RuntimeError(f"Unsupported teacher model type: {teacher_model}")

        submitted_job = submit_distillation_pipeline_job(
            teacher_model=teacher_id,
            student_model=student_id,
            dataset=dataset,
            train_steps=train_steps,
            learning_rate_multiplier=learning_rate_multiplier,
            evaluation_spec=evaluation_spec,
            accelerator_type=accelerator_type,
            model_display_name=model_display_name,
            max_context_length=max_context_length,
        )
        return tuning._LanguageModelTuningJob(base_model=self, job=submitted_job)


def submit_distillation_pipeline_job(
    *,
    teacher_model: str,
    student_model: str,
    dataset: str,
    train_steps: Optional[int] = None,
    learning_rate_multiplier: Optional[float] = None,
    evaluation_spec: Optional[tuning.TuningEvaluationSpec] = None,
    accelerator_type: Optional[tuning._ACCELERATOR_TYPE_TYPE] = None,
    model_display_name: Optional[str] = None,
    max_context_length: Optional[str] = None,
):
    """Builds distillation pipeline parameters and submits the pipeline job.

    Project, location and (if configured) the CMEK key name are read from
    ``aiplatform_initializer.global_config``. Optional arguments left as None
    are not added to the pipeline parameters.

    Args:
        teacher_model: Teacher model reference, passed to the pipeline as
            ``teacher_model_reference``. Its last ``/``-separated segment is
            used in the default display name.
        student_model: Student model reference, passed as
            ``student_model_reference``. Its last ``/``-separated segment is
            used in the default display name.
        dataset: URI of the training dataset, passed as ``dataset_uri``.
        train_steps: Number of training steps.
        learning_rate_multiplier: Learning rate multiplier to use in tuning.
        evaluation_spec: Evaluation settings. Its fields are mapped onto the
            pipeline's evaluation, early-stopping, checkpoint-selection and
            tensorboard parameters.
        accelerator_type: Type of accelerator to use.
        model_display_name: Display name for the tuned model. Defaults to
            "<student id> distilled from <teacher id>".
        max_context_length: Max context length used for tuning.

    Returns:
        The ``aiplatform.PipelineJob``, on which ``submit()`` has been called.
    """
    teacher_short_model_id = teacher_model.split("/")[-1]
    student_short_model_id = student_model.split("/")[-1]
    pipeline_arguments = {
        "teacher_model_reference": teacher_model,
        "student_model_reference": student_model,
        "dataset_uri": dataset,
        "project": aiplatform_initializer.global_config.project,
        "location": aiplatform_initializer.global_config.location,
    }
    if train_steps is not None:
        pipeline_arguments["train_steps"] = train_steps
    if learning_rate_multiplier is not None:
        pipeline_arguments["learning_rate_multiplier"] = learning_rate_multiplier
    if evaluation_spec is not None:
        pipeline_arguments["evaluation_data_uri"] = evaluation_spec.evaluation_data
        pipeline_arguments["evaluation_interval"] = evaluation_spec.evaluation_interval
        pipeline_arguments["enable_early_stopping"] = (
            evaluation_spec.enable_early_stopping
        )
        pipeline_arguments["enable_checkpoint_selection"] = (
            evaluation_spec.enable_checkpoint_selection
        )
        tensorboard = evaluation_spec.tensorboard
        # TuningEvaluationSpec.tensorboard accepts either an
        # aiplatform.Tensorboard object or a resource-name string; the
        # pipeline parameter must be the string form.
        if isinstance(tensorboard, aiplatform.Tensorboard):
            tensorboard = tensorboard.resource_name
        pipeline_arguments["tensorboard_resource_id"] = tensorboard
        # NOTE: "evaluation_output_root_dir" is not currently passed.
    if accelerator_type is not None:
        pipeline_arguments["accelerator_type"] = accelerator_type
    if aiplatform_initializer.global_config.encryption_spec_key_name is not None:
        pipeline_arguments["encryption_spec_key_name"] = (
            aiplatform_initializer.global_config.encryption_spec_key_name
        )
    if max_context_length is not None:
        pipeline_arguments["max_context_length"] = max_context_length
    if model_display_name is None:
        model_display_name = (
            f"{student_short_model_id} distilled from {teacher_short_model_id}"
        )
    pipeline_arguments["model_display_name"] = model_display_name
    # Pipeline parameters intentionally not exposed by this function:
    # temperature, tpu_training_skip_cmek, api_endpoint, version.
    pipeline_job = aiplatform.PipelineJob(
        template_path=_DISTILLATION_PIPELINE_URI,
        display_name=None,
        parameter_values=pipeline_arguments,
    )
    pipeline_job.submit()
    return pipeline_job
