# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from typing import List, Optional, Sequence, Union

from google.cloud.aiplatform import base
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import utils as aip_utils
from google.cloud.aiplatform_v1beta1 import types
from vertexai.generative_models import _generative_models
from vertexai.reasoning_engines import _utils
from google.protobuf import struct_pb2

# Module-level logger; `Extension.create` uses it to report the import
# long-running operation and the created resource.
_LOGGER = base.Logger(__name__)

# Caller-facing argument types: each accepts either a plain JSON-style dict or
# the equivalent proto message. `_AuthConfigOrJson` and `_RuntimeConfigOrJson`
# values are normalized with `_utils.to_proto` before use, while
# `_StructOrJson` values (operation params) are passed straight into the
# request message.
_AuthConfigOrJson = Union[_utils.JsonDict, types.AuthConfig]
_StructOrJson = Union[_utils.JsonDict, struct_pb2.Struct]
_RuntimeConfigOrJson = Union[_utils.JsonDict, types.RuntimeConfig]


# First-party ("1P") extensions that `Extension.from_hub` can create.
# Keys are the hub names accepted by `from_hub`; each entry supplies the
# display name, description and manifest that `from_hub` forwards to
# `Extension.create`. Every manifest points at a public OpenAPI YAML spec in
# GCS and uses Google service-account auth.
_VERTEX_EXTENSION_HUB = {
    "code_interpreter": {
        "display_name": "Code Interpreter",
        "description": (
            "This extension generates and executes code in the specified language"
        ),
        "manifest": {
            "name": "code_interpreter_tool",
            "description": "Google Code Interpreter Extension",
            "api_spec": {
                "open_api_gcs_uri": (
                    "gs://vertex-extension-public/code_interpreter.yaml"
                ),
            },
            "auth_config": {
                "auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
                "google_service_account_config": {},
            },
        },
    },
    "vertex_ai_search": {
        "display_name": "Vertex AI Search",
        "description": "This extension generates and executes search queries",
        "manifest": {
            "name": "vertex_ai_search",
            "description": "Vertex AI Search Extension",
            "api_spec": {
                "open_api_gcs_uri": (
                    "gs://vertex-extension-public/vertex_ai_search.yaml"
                ),
            },
            "auth_config": {
                "auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
                "google_service_account_config": {},
            },
        },
    },
    "webpage_browser": {
        "display_name": "Webpage Browser",
        "description": "This extension fetches the content of a webpage",
        "manifest": {
            "name": "webpage_browser",
            "description": "Vertex Webpage Browser Extension",
            "api_spec": {
                "open_api_gcs_uri": (
                    "gs://vertex-extension-public/webpage_browser.yaml"
                ),
            },
            "auth_config": {
                "auth_type": "GOOGLE_SERVICE_ACCOUNT_AUTH",
                "google_service_account_config": {},
            },
        },
    },
}


class Extension(base.VertexAiResourceNounWithFutureManager):
    """Represents a Vertex AI Extension resource.

    An instance is either retrieved from an existing resource via
    ``Extension(extension_name)`` or created by importing a manifest via
    ``Extension.create(...)`` / ``Extension.from_hub(...)``.
    """

    client_class = aip_utils.ExtensionRegistryClientWithOverride
    _resource_noun = "extension"
    _getter_method = "get_extension"
    _list_method = "list_extensions"
    _delete_method = "delete_extension"
    _parse_resource_name_method = "parse_extension_path"
    _format_resource_name_method = "extension_path"

    def __init__(self, extension_name: str):
        """Retrieves an extension resource.

        Args:
            extension_name (str):
                Required. A fully-qualified resource name or ID such as
                "projects/123/locations/us-central1/extensions/456" or
                "456" when project and location are initialized or passed.
        """
        super().__init__(resource_name=extension_name)
        self._init_execution_state()
        self._gca_resource = self._get_gca_resource(resource_name=extension_name)

    def _init_execution_state(self) -> None:
        """Creates the execution client and resets the lazily-filled caches.

        Shared by ``__init__`` and ``create`` so both construction paths leave
        the instance with the same attributes.

        NOTE(review): a construction path that bypasses both ``__init__`` and
        ``create`` (e.g. a base-class helper building instances from gapic
        objects) would leave these attributes unset — confirm whether such
        paths need to call this too.
        """
        # Executing/querying goes through a separate client from the registry
        # client (`client_class`) used for get/list/delete/import.
        self.execution_api_client = initializer.global_config.create_client(
            client_class=aip_utils.ExtensionExecutionClientWithOverride,
        )
        # Populated on first call to `api_spec()` / `operation_schemas()`.
        self._api_spec = None
        self._operation_schemas = None

    @classmethod
    def create(
        cls,
        manifest: Union[_utils.JsonDict, types.ExtensionManifest],
        *,
        extension_name: Optional[str] = None,
        display_name: Optional[str] = None,
        description: Optional[str] = None,
        runtime_config: Optional[_RuntimeConfigOrJson] = None,
    ):
        """Creates a new Extension.

        Args:
            manifest (Union[dict[str, Any], ExtensionManifest]):
                Required. The manifest for the Extension to be created.
            extension_name (str):
                Optional. A fully-qualified extension resource name or extension
                ID such as "projects/123/locations/us-central1/extensions/456" or
                "456" when project and location are initialized or passed. If
                specifying the extension ID, it should be 4-63 characters, valid
                characters are lowercase letters, numbers and hyphens ("-"),
                and it should start with a number or a lower-case letter. If not
                provided, Vertex AI will generate a value for this ID.
            display_name (str):
                Optional. The user-defined name of the Extension.
                The name can be up to 128 characters long and can comprise any
                UTF-8 character. A name is generated when not provided.
            description (str):
                Optional. The description of the Extension.
            runtime_config (Union[dict[str, Any], RuntimeConfig]):
                Optional. Runtime config controlling the runtime behavior of
                this Extension. Defaults to None.

        Returns:
            Extension: The extension that was created.
        """
        sdk_resource = cls.__new__(cls)
        # Skip `Extension.__init__`, which fetches an existing resource; only
        # the base-class state (clients, project/location) is needed here.
        base.VertexAiResourceNounWithFutureManager.__init__(
            sdk_resource,
            resource_name=extension_name,
        )
        extension = types.Extension(
            name=extension_name,
            display_name=display_name or cls._generate_display_name(),
            description=description,
            manifest=_utils.to_proto(manifest, types.ExtensionManifest()),
        )
        if runtime_config:
            extension.runtime_config = _utils.to_proto(
                runtime_config,
                types.RuntimeConfig(),
            )
        operation_future = sdk_resource.api_client.import_extension(
            parent=initializer.global_config.common_location_path(),
            extension=extension,
        )
        _LOGGER.log_create_with_lro(cls, operation_future)
        created_extension = operation_future.result()
        _LOGGER.log_create_complete(
            cls,
            created_extension,
            cls._resource_noun,
            module_name="vertexai.preview.extensions",
        )
        # We use `._get_gca_resource(...)` instead of `created_extension` to
        # fully instantiate the attributes of the extension.
        sdk_resource._gca_resource = sdk_resource._get_gca_resource(
            resource_name=created_extension.name
        )
        sdk_resource._init_execution_state()
        return sdk_resource

    @property
    def resource_name(self) -> str:
        """Full qualified resource name for the extension."""
        return self._gca_resource.name

    def api_spec(self) -> _utils.JsonDict:
        """Returns the (Open)API Spec of the extension.

        The spec is loaded from the manifest on first call and cached.
        """
        if self._api_spec is None:
            self._api_spec = _load_api_spec(self._gca_resource.manifest.api_spec)
        return self._api_spec

    def operation_schemas(self) -> Sequence[_utils.JsonDict]:
        """Returns the (Open)API schemas for each operation of the extension.

        Each entry is the operation's function declaration converted to a
        dict; the list is computed on first call and cached.
        """
        if self._operation_schemas is None:
            self._operation_schemas = [
                _utils.to_dict(op.function_declaration)
                for op in self._gca_resource.extension_operations
            ]
        return self._operation_schemas

    def execute(
        self,
        operation_id: str,
        operation_params: Optional[_StructOrJson] = None,
        runtime_auth_config: Optional[_AuthConfigOrJson] = None,
    ) -> Union[_utils.JsonDict, str]:
        """Executes an operation of the extension with the specified params.

        Args:
            operation_id (str):
                Required. The ID of the operation to be executed.
            operation_params (Union[dict[str, Any], Struct]):
                Optional. Parameters used for executing the operation. It should
                be in a form of map with param name as the key and actual param
                value as the value. E.g. if this operation requires a param
                "name" to be set to "abc", you can set this to {"name": "abc"}.
                Passed as-is as the request's `operation_params`. Defaults to
                None.
            runtime_auth_config (Union[dict[str, Any], AuthConfig]):
                Optional. The Auth configuration to execute the operation.

        Returns:
            The response content parsed as JSON when it is valid JSON,
            otherwise the raw response content string.
        """
        request = types.ExecuteExtensionRequest(
            name=self.resource_name,
            operation_id=operation_id,
            operation_params=operation_params,
        )
        if runtime_auth_config:
            request.runtime_auth_config = _utils.to_proto(
                runtime_auth_config,
                types.AuthConfig(),
            )
        response = self.execution_api_client.execute_extension(request)
        return _try_parse_execution_response(response)

    def query(
        self,
        contents: _generative_models.ContentsType,
    ) -> "QueryExtensionResponse":
        """Queries an extension with the specified contents.

        Args:
            contents (ContentsType):
                Required. The content of the current conversation with the
                model. For single-turn queries, this is a single instance. For
                multi-turn queries, this is a repeated field that contains
                conversation history + latest request.

        Returns:
            The result of querying the extension.

        Raises:
            RuntimeError: If the response contains an error.
        """
        request = types.QueryExtensionRequest(
            name=self.resource_name,
            contents=_generative_models._content_types_to_gapic_contents(contents),
        )
        response = self.execution_api_client.query_extension(request)
        if response.failure_message:
            raise RuntimeError(response.failure_message)
        return QueryExtensionResponse._from_gapic(response)

    @classmethod
    def from_hub(
        cls,
        name: str,
        *,
        runtime_config: Optional[_RuntimeConfigOrJson] = None,
    ):
        """Creates a new Extension from the set of first party extensions.

        Args:
            name (str):
                Required. The name of the extension in the hub to be created.
                Supported values are "code_interpreter", "vertex_ai_search" and
                "webpage_browser".
            runtime_config (Union[dict[str, Any], RuntimeConfig]):
                Optional. Runtime config controlling the runtime behavior of
                the Extension. Required for "vertex_ai_search". Defaults to
                None.

        Returns:
            Extension: The extension that was created.

        Raises:
            ValueError: If the `name` is not supported in the hub.
            ValueError: If the `runtime_config` is specified but inconsistent
                with the name (e.g. the name was "code_interpreter" but the
                runtime_config was based on "vertex_ai_search_runtime_config"),
                or if it is missing for "vertex_ai_search".
        """
        # The hub table is the single source of truth for supported names.
        if name not in _VERTEX_EXTENSION_HUB:
            raise ValueError(f"Unsupported 1P extension name: {name}")
        if runtime_config:
            runtime_config = _utils.to_proto(
                runtime_config,
                types.RuntimeConfig(),
            )
        # `getattr(..., None)` keeps a non-RuntimeConfig object (which
        # `to_proto` returns unchanged) on the ValueError path rather than
        # raising AttributeError.
        if name == "code_interpreter":
            # Optional, but if given it must be the code interpreter variant.
            if runtime_config and not getattr(
                runtime_config,
                "code_interpreter_runtime_config",
                None,
            ):
                raise ValueError(
                    "code_interpreter_runtime_config is required for "
                    "code_interpreter extension"
                )
        elif name == "vertex_ai_search":
            if not runtime_config:
                raise ValueError(
                    "runtime_config is required for vertex_ai_search extension"
                )
            if not getattr(
                runtime_config,
                "vertex_ai_search_runtime_config",
                None,
            ):
                raise ValueError(
                    "vertex_ai_search_runtime_config is required for "
                    "vertex_ai_search extension"
                )
        # "webpage_browser" needs no runtime_config validation.
        extension_info = _VERTEX_EXTENSION_HUB[name]
        return cls.create(
            display_name=extension_info["display_name"],
            description=extension_info["description"],
            manifest=extension_info["manifest"],
            runtime_config=runtime_config,
        )


class QueryExtensionResponse:
    """Holds the sequence of content steps returned by an extension query."""

    def __init__(self, steps: List[_generative_models.Content]):
        """Stores the given steps.

        Args:
            steps (List[Content]):
                Required. The content steps of the query response, in order.
        """
        self.steps = steps

    @classmethod
    def _from_gapic(
        cls, response: types.QueryExtensionResponse
    ) -> "QueryExtensionResponse":
        """Builds a QueryExtensionResponse from a gapic query response.

        Every gapic step is rewrapped as a `Content`: each of its parts is
        converted with `Part._from_gapic` and the step's role is kept.
        """
        converted_steps = []
        for gapic_content in response.steps:
            sdk_parts = [
                _generative_models.Part._from_gapic(gapic_part)
                for gapic_part in gapic_content.parts
            ]
            converted_steps.append(
                _generative_models.Content(parts=sdk_parts, role=gapic_content.role)
            )
        return cls(steps=converted_steps)


def _try_parse_execution_response(
    response: types.ExecuteExtensionResponse,
) -> Union[_utils.JsonDict, str]:
    """Decodes an execution response's content as JSON when possible.

    Args:
        response: The gapic ExecuteExtensionResponse returned by the service.

    Returns:
        The result of ``json.loads`` on ``response.content`` when the content
        is valid JSON (this may be any JSON value, not only an object), or the
        raw content string unchanged when it is not.
    """
    raw_content = response.content
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError:
        # Not valid JSON: hand the string back untouched.
        return raw_content


def _load_api_spec(api_spec) -> _utils.JsonDict:
    """Loads the (Open)API Spec of the extension and converts it to JSON.

    Inline YAML (`open_api_yaml`) takes precedence and is parsed with
    `safe_load`; otherwise the spec referenced by `open_api_gcs_uri` is loaded
    via `aip_utils.yaml_utils.load_yaml`. Returns an empty dict when the
    manifest's api_spec sets neither field.
    """
    inline_yaml = api_spec.open_api_yaml
    if inline_yaml:
        yaml_module = aip_utils.yaml_utils._maybe_import_yaml()
        return yaml_module.safe_load(inline_yaml)
    gcs_uri = api_spec.open_api_gcs_uri
    if gcs_uri:
        return aip_utils.yaml_utils.load_yaml(gcs_uri)
    return {}
