# -*- coding: utf-8 -*-

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import grpc
import logging
import ray

from typing import Dict
from typing import Optional
from google.cloud import aiplatform
from google.cloud.aiplatform import initializer
from ray import client_builder
from .render import VertexRayTemplate
from .util import _validation_utils
from .util import _gapic_utils


# Version string of the imported google.cloud.aiplatform package; recorded on
# every _VertexRayClientContext and passed to its context table template.
VERTEX_SDK_VERSION = aiplatform.__version__


class _VertexRayClientContext(client_builder.ClientContext):
    """Ray ``ClientContext`` extended with Vertex AI cluster details.

    On top of the fields copied from the underlying Ray client context, an
    instance records the persistent resource ID, the Vertex SDK version and,
    when present, the head node interactive shell URI.
    """

    def __init__(
        self,
        persistent_resource_id: str,
        ray_head_uris: Dict[str, str],
        ray_client_context: client_builder.ClientContext,
    ) -> None:
        """Builds the context from an already established Ray client context.

        Args:
            persistent_resource_id: ID of the persistent resource that hosts
                the Ray cluster; used in messages and in the rendered table.
            ray_head_uris: Head node access URIs keyed by name. Must contain
                "RAY_DASHBOARD_URI"; "RAY_HEAD_NODE_INTERACTIVE_SHELL_URI"
                is optional.
            ray_client_context: Context whose fields (python/ray version,
                commit, client count, ...) are copied into this one.

        Raises:
            ValueError: If ``ray_head_uris`` has no "RAY_DASHBOARD_URI".
            ImportError: If the installed Ray version is not one of
                2.47.1, 2.42.0, 2.33.0 or 2.9.3.
        """
        dashboard_uri = ray_head_uris.get("RAY_DASHBOARD_URI")
        if dashboard_uri is None:
            # Build ONE message string: passing several positional arguments
            # to ValueError makes the error render as a tuple.
            raise ValueError(
                f"Ray Cluster {persistent_resource_id}"
                " failed to start Head node properly."
            )

        installed_ray_version = ray.__version__
        if installed_ray_version in ("2.47.1", "2.42.0", "2.33.0"):
            forward_protocol_version = False
        elif installed_ray_version == "2.9.3":
            # Only for 2.9.3 is protocol_version forwarded to the base class.
            forward_protocol_version = True
        else:
            raise ImportError(
                f"[Ray on Vertex AI]: Unsupported version {installed_ray_version}. "
                "Only 2.47.1, 2.42.0, 2.33.0, and 2.9.3 are supported."
            )

        base_fields = {
            "dashboard_url": dashboard_uri,
            "python_version": ray_client_context.python_version,
            "ray_version": ray_client_context.ray_version,
            "ray_commit": ray_client_context.ray_commit,
            "_num_clients": ray_client_context._num_clients,
            "_context_to_restore": ray_client_context._context_to_restore,
        }
        if forward_protocol_version:
            base_fields["protocol_version"] = ray_client_context.protocol_version
        super().__init__(**base_fields)

        self.persistent_resource_id = persistent_resource_id
        self.vertex_sdk_version = str(VERTEX_SDK_VERSION)
        # May be None when the cluster exposes no interactive shell URI.
        self.shell_uri = ray_head_uris.get("RAY_HEAD_NODE_INTERACTIVE_SHELL_URI")

    def _context_table_template(self):
        """Renders the context summary using the Vertex Ray templates.

        The shell URI row is rendered only when a shell URI is known;
        otherwise ``None`` is handed to the table template in its place.
        """
        shell_uri_row = None
        if self.shell_uri is not None:
            shell_uri_row = VertexRayTemplate("context_shellurirow.html.j2").render(
                shell_uri=self.shell_uri
            )

        return VertexRayTemplate("context_table.html.j2").render(
            python_version=self.python_version,
            ray_version=self.ray_version,
            vertex_sdk_version=self.vertex_sdk_version,
            dashboard_url=self.dashboard_url,
            persistent_resource_id=self.persistent_resource_id,
            shell_uri_row=shell_uri_row,
        )


class VertexRayClientBuilder(client_builder.ClientBuilder):
    """Class to initialize a Ray client with vertex on ray capabilities."""

    def __init__(self, address: Optional[str]) -> None:
        """Resolves the Ray head node address of a Vertex AI cluster.

        The persistent resource is fetched through the GAPIC helper and its
        access URIs decide which address is handed to the base builder: the
        "RAY_CLIENT_ENDPOINT" URI when present, otherwise the
        "RAY_HEAD_NODE_INTERNAL_IP" URI.

        Args:
            address: Cluster identifier. It is passed through
                ``_validation_utils.maybe_reconstruct_resource_name`` and the
                result is validated as a resource name.

        Raises:
            ValueError: If only the internal IP is available while a custom
                service account is configured, if no head node address can
                be resolved, or if the persistent resource cannot be
                converted to a Ray cluster.
        """
        address = _validation_utils.maybe_reconstruct_resource_name(address)
        _validation_utils.valid_resource_name(address)
        self._credentials = None
        self._metadata = None
        self.vertex_address = address
        logging.info(
            "[Ray on Vertex AI]: Using cluster resource name to access head address with GAPIC API"
        )

        self.resource_name = address

        self.response = _gapic_utils.get_persistent_resource(self.resource_name)
        access_uris = self.response.resource_runtime.access_uris
        private_address = access_uris.get("RAY_HEAD_NODE_INTERNAL_IP")
        public_address = access_uris.get("RAY_CLIENT_ENDPOINT")
        service_account = (
            self.response.resource_runtime_spec.service_account_spec.service_account
        )

        if public_address is None:
            address = private_address
            if service_account:
                # Single formatted message; multiple positional arguments
                # would make the exception render as a tuple.
                raise ValueError(
                    f"[Ray on Vertex AI]: Ray Cluster {address}"
                    " failed to start Head node properly because custom service"
                    " account isn't supported in peered VPC network. Use public"
                    " endpoint instead (create a cluster without specifying"
                    " VPC network)."
                )
        else:
            address = public_address

        if address is None:
            persistent_resource_id = self.resource_name.split("/")[5]
            raise ValueError(
                f"[Ray on Vertex AI]: Ray Cluster {persistent_resource_id}"
                " Head node is not reachable. Please ensure that a valid VPC"
                " network has been specified."
            )

        logging.debug("[Ray on Vertex AI]: Resolved head node ip: %s", address)
        cluster = _gapic_utils.persistent_resource_to_cluster(
            persistent_resource=self.response
        )
        if cluster is None:
            raise ValueError(
                "[Ray on Vertex AI]: Please delete and recreate the cluster (The cluster is not a Ray cluster or the cluster image is outdated)."
            )

        # A version mismatch is only reported, never treated as an error.
        local_ray_version = _validation_utils.get_local_ray_version()
        if cluster.ray_version != local_ray_version:
            if cluster.head_node_type.custom_image is None:
                install_ray_version = _validation_utils.SUPPORTED_RAY_VERSIONS.get(
                    cluster.ray_version
                )
                logging.info(
                    "[Ray on Vertex]: Local runtime has Ray version %s"
                    ", but the requested cluster runtime has %s. Please "
                    "ensure that the Ray versions match for client connectivity. You may "
                    '"pip install --user --force-reinstall ray[default]==%s"'
                    " and restart runtime before cluster connection.",
                    local_ray_version,
                    cluster.ray_version,
                    install_ray_version,
                )
            else:
                logging.info(
                    "[Ray on Vertex]: Local runtime has Ray version %s. "
                    "Please ensure that the Ray versions match for client connectivity.",
                    local_ray_version,
                )
        super().__init__(address)

    def connect(self) -> _VertexRayClientContext:
        """Connects to the cluster and wraps the resulting Ray context.

        When the cluster exposes a "RAY_CLIENT_ENDPOINT" URI and no
        "RAY_HEAD_NODE_INTERNAL_IP" URI, SSL channel credentials and
        bearer-token metadata are set on the builder before connecting.

        Returns:
            A ``_VertexRayClientContext`` built from the Ray client context,
            the head node access URIs and the persistent resource ID.
        """
        # Can send any other params to ray cluster here
        logging.info("[Ray on Vertex AI]: Connecting...")

        ray_head_uris = self.response.resource_runtime.access_uris
        public_address = ray_head_uris.get("RAY_CLIENT_ENDPOINT")
        private_address = ray_head_uris.get("RAY_HEAD_NODE_INTERNAL_IP")
        if public_address and not private_address:
            self._credentials = grpc.ssl_channel_credentials()
            bearer_token = _validation_utils.get_bearer_token()
            self._metadata = [
                ("authorization", f"Bearer {bearer_token}"),
                ("x-goog-user-project", f"{initializer.global_config.project}"),
            ]
        ray_client_context = super().connect()

        # Valid resource name (reference public doc for public release):
        # "projects/<project_num>/locations/<region>/persistentResources/<pr_id>"
        persistent_resource_id = self.resource_name.split("/")[5]

        return _VertexRayClientContext(
            persistent_resource_id=persistent_resource_id,
            ray_head_uris=ray_head_uris,
            ray_client_context=ray_client_context,
        )
