# -*- coding: utf-8 -*-
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import inspect
import json
import types
import typing
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union

import proto

from google.cloud.aiplatform import base
from google.api import httpbody_pb2
from google.protobuf import struct_pb2
from google.protobuf import json_format

# Optional dependency: langchain_core. When it cannot be imported, the
# `RunnableConfig` alias degrades to `typing.Any` so that annotations using it
# still resolve.
try:
    # For LangChain templates, they might not import langchain_core and get
    #   PydanticUserError: `query` is not fully defined; you should define
    #   `RunnableConfig`, then call `query.model_rebuild()`.
    import langchain_core.runnables.config

    RunnableConfig = langchain_core.runnables.config.RunnableConfig
except ImportError:
    RunnableConfig = Any

# Optional dependency: llama_index. The three aliases below are the real
# llama_index classes when the package is importable, and `typing.Any`
# otherwise.
# NOTE(review): `isinstance(x, typing.Any)` raises TypeError, so helpers that
# run isinstance checks against these aliases presumably require llama_index
# to be installed -- confirm with callers.
try:
    from llama_index.core.base.response import schema as llama_index_schema
    from llama_index.core.base.llms import types as llama_index_types

    LlamaIndexResponse = llama_index_schema.Response
    LlamaIndexBaseModel = llama_index_schema.BaseModel
    LlamaIndexChatResponse = llama_index_types.ChatResponse
except ImportError:
    LlamaIndexResponse = Any
    LlamaIndexBaseModel = Any
    LlamaIndexChatResponse = Any

# Alias for a JSON-style dictionary: string keys, arbitrary values.
JsonDict = Dict[str, Any]

# Module-level logger shared by the helpers in this file.
_LOGGER = base.Logger(__name__)


def to_proto(
    obj: Union[JsonDict, proto.Message],
    message: Optional[proto.Message] = None,
) -> proto.Message:
    """Parses a JSON-like object into a message.

    If the object is already a message, this will return the object as-is. If
    the object is a JSON Dict, this will parse and merge the object into the
    message.

    Args:
        obj (Union[dict[str, Any], proto.Message]):
            Required. The object to convert to a proto message.
        message (proto.Message):
            Optional. A protocol buffer message to merge the obj into. It
            defaults to Struct() if unspecified.

    Returns:
        proto.Message: `obj` itself when it is already a message; otherwise
        `message` (or a newly created Struct) with `obj` merged into it.
    """
    # Check for the pass-through case first so no throwaway Struct is built.
    if isinstance(obj, (proto.Message, struct_pb2.Struct)):
        return obj
    if message is None:
        message = struct_pb2.Struct()
    # Messages that expose a `_pb` attribute are parsed through it; any other
    # message is parsed directly. Resolving the target up front (instead of
    # wrapping the whole ParseDict call in `except AttributeError`) keeps an
    # AttributeError raised *during* parsing from being swallowed and the
    # parse from being retried on a partially merged message.
    target = getattr(message, "_pb", message)
    json_format.ParseDict(obj, target)
    return message


def to_dict(message: proto.Message) -> JsonDict:
    """Converts the contents of the protobuf message to JSON format.

    Args:
        message (proto.Message):
            Required. The proto message to be converted to a JSON dictionary.

    Returns:
        dict[str, Any]: A dictionary containing the contents of the proto.
    """
    # Messages that expose a `_pb` attribute are serialized through it; any
    # other message is serialized directly. Resolving the target up front
    # (rather than catching AttributeError around the whole conversion) keeps
    # an AttributeError raised inside serialization from being masked and the
    # conversion from being silently retried.
    pb_message = getattr(message, "_pb", message)
    result: JsonDict = json.loads(json_format.MessageToJson(pb_message))
    return result


def dataclass_to_dict(obj: dataclasses.dataclass) -> Any:
    """Turns a dataclass instance into plain JSON data.

    The instance is first flattened with `dataclasses.asdict` and then passed
    through a JSON encode/decode round trip, so the result only contains
    JSON-native values (e.g. tuples come back as lists).

    Args:
        obj (dataclasses.dataclass):
            Required. The dataclass instance to convert.

    Returns:
        dict[str, Any]: The JSON representation of the dataclass contents.
    """
    flattened = dataclasses.asdict(obj)
    encoded = json.dumps(flattened)
    return json.loads(encoded)


def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Any:
    """Builds a JSON dictionary from the attributes present on a response.

    Only the attributes that exist on `obj` are included, in the order
    `response`, `source_nodes`, `metadata`. Each source node is stored as the
    string returned by its `model_dump_json()`.

    Args:
        obj (LlamaIndexResponse):
            Required. The response object to convert.

    Returns:
        dict[str, Any]: The collected fields after a JSON round trip.
    """
    # (attribute name, converter) pairs, applied in this exact order.
    extractors = (
        ("response", lambda value: value),
        ("source_nodes", lambda nodes: [n.model_dump_json() for n in nodes]),
        ("metadata", lambda value: value),
    )
    payload = {}
    for attr_name, convert in extractors:
        if hasattr(obj, attr_name):
            payload[attr_name] = convert(getattr(obj, attr_name))

    encoded = json.dumps(payload)
    return json.loads(encoded)


def _llama_index_chat_response_to_dict(obj: LlamaIndexChatResponse) -> Any:
    """Returns the parsed JSON dump of the chat response's `message`."""
    chat_message = obj.message
    serialized = chat_message.model_dump_json()
    return json.loads(serialized)


def _llama_index_base_model_to_dict(obj: LlamaIndexBaseModel) -> Any:
    """Returns the parsed result of the model's `model_dump_json()`."""
    serialized = obj.model_dump_json()
    return json.loads(serialized)


def to_json_serializable_llama_index_object(
    obj: Union[
        LlamaIndexResponse,
        LlamaIndexBaseModel,
        LlamaIndexChatResponse,
        Sequence[LlamaIndexBaseModel],
    ],
) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]:
    """Converts a LlamaIndex object to a JSON serializable object.

    The checks are applied in this order:
      1. `LlamaIndexResponse` -> dictionary of its available fields.
      2. `LlamaIndexChatResponse` -> dictionary of its message.
      3. A non-string sequence -> list in which base-model items become
         dictionaries and every other item becomes `str(item)`.
      4. `LlamaIndexBaseModel` -> dictionary of the model.
      5. Anything else -> `str(obj)`.

    Args:
        obj: Required. The object to convert.

    Returns:
        A string, a dictionary, or a list of strings/dictionaries.
    """
    if isinstance(obj, LlamaIndexResponse):
        return _llama_index_response_to_dict(obj)
    if isinstance(obj, LlamaIndexChatResponse):
        return _llama_index_chat_response_to_dict(obj)
    # `str` and `bytes` are themselves Sequences; without excluding them a
    # plain string would be exploded into a list of single characters instead
    # of reaching the `str(obj)` fallback below.
    if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        return [
            _llama_index_base_model_to_dict(item)
            if isinstance(item, LlamaIndexBaseModel)
            else str(item)
            for item in obj
        ]
    if isinstance(obj, LlamaIndexBaseModel):
        return _llama_index_base_model_to_dict(obj)
    return str(obj)


def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]:
    """Yields the JSON payload(s) carried by an HttpBody message.

    The body is yielded unchanged when it has no `content_type` or `data`,
    when the content type does not contain "application/json", or when the
    data cannot be decoded as UTF-8. An empty payload yields a single `None`.
    Otherwise the payload is split on newlines and each non-empty line is
    yielded as parsed JSON, or as the raw line if parsing it fails.

    Args:
        body (httpbody_pb2.HttpBody):
            Required. The httpbody body to be converted to a JSON.

    Yields:
        Any: A JSON object or the original body if it is not JSON or None.
    """
    content_type = getattr(body, "content_type", None)
    data = getattr(body, "data", None)

    looks_like_json = (
        content_type is not None
        and data is not None
        and "application/json" in content_type
    )
    if not looks_like_json:
        yield body
        return

    try:
        utf8_data = data.decode("utf-8")
    except Exception as e:
        _LOGGER.warning(f"Failed to decode data: {data}. Exception: {e}")
        yield body
        return

    if not utf8_data:
        yield None
        return

    # The payload may hold several JSON documents, one per line.
    for raw_line in utf8_data.split("\n"):
        if not raw_line:
            continue
        try:
            parsed = json.loads(raw_line)
        except Exception as e:
            _LOGGER.warning(f"failed to parse json: {raw_line}. Exception: {e}")
            # Fall back to handing the unparsed text to the consumer.
            parsed = raw_line
        yield parsed


def generate_schema(
    f: Callable[..., Any],
    *,
    schema_name: Optional[str] = None,
    descriptions: Mapping[str, str] = {},
    required: Sequence[str] = [],
) -> JsonDict:
    """Generates the OpenAPI Schema for a callable object.

    Only positional and keyword arguments of the function `f` will be supported
    in the OpenAPI Schema that is generated. I.e. `*args` and `**kwargs` will
    not be present in the OpenAPI schema returned from this function. For those
    cases, you can either include it in the docstring for `f`, or modify the
    OpenAPI schema returned from this function to include additional arguments.

    Args:
        f (Callable):
            Required. The function to generate an OpenAPI Schema for.
        schema_name (str):
            Optional. The name for the OpenAPI schema. If unspecified, the name
            of the Callable will be used.
        descriptions (Mapping[str, str]):
            Optional. A `{name: description}` mapping for annotating input
            arguments of the function with user-provided descriptions. It
            defaults to an empty dictionary (i.e. there will not be any
            description for any of the inputs).
        required (Sequence[str]):
            Optional. For the user to specify the set of required arguments in
            function calls to `f`. If unspecified, it will be automatically
            inferred from `f`.

    Returns:
        dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.

    Raises:
        ImportError: If pydantic cannot be imported.
    """
    # NOTE: the `{}` / `[]` defaults are part of the public signature and are
    # never mutated here; they are only read.
    pydantic = _import_pydantic_or_raise()
    # Sentinel used by `inspect` for "no annotation" / "no default". It must
    # be compared by identity: `==` would call the `__eq__` of arbitrary
    # user-supplied defaults, which need not return a plain bool.
    empty = inspect.Parameter.empty
    # We do not support *args or **kwargs.
    supported_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
        inspect.Parameter.POSITIONAL_ONLY,
    )
    # Origins that denote a union: `typing.Union[...]`, plus `X | Y`
    # (`types.UnionType`) on Python versions that provide it.
    union_origins = (Union, getattr(types, "UnionType", Union))

    params = dict(inspect.signature(f).parameters)
    fields_dict = {
        name: (
            # 1. We infer the argument type here: use Any rather than None so
            # it will not try to auto-infer the type based on the default value.
            (Any if param.annotation is empty else param.annotation),
            pydantic.Field(
                # 2. We do not support default values for now.
                # 3. We support user-provided descriptions.
                description=descriptions.get(name, None),
            ),
        )
        for name, param in params.items()
        if param.kind in supported_kinds
    }
    parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
    # Postprocessing
    # 4. Suppress unnecessary title generation:
    #    * https://github.com/pydantic/pydantic/issues/1051
    #    * http://cl/586221780
    parameters.pop("title", "")
    for name, function_arg in parameters.get("properties", {}).items():
        function_arg.pop("title", "")
        annotation = params[name].annotation
        # 5. Nullable fields:
        #     * https://github.com/pydantic/pydantic/issues/1270
        #     * https://stackoverflow.com/a/58841311
        #     * https://github.com/pydantic/pydantic/discussions/4872
        origin = typing.get_origin(annotation)
        is_union = any(origin is union_origin for union_origin in union_origins)
        if is_union and type(None) in typing.get_args(annotation):
            # for "typing.Optional" arguments, function_arg might be a
            # dictionary like
            #
            #   {'anyOf': [{'type': 'integer'}, {'type': 'null'}]
            #
            # Collapse it to the first non-null type and flag it nullable.
            for candidate in function_arg.pop("anyOf", []):
                candidate_type = candidate.get("type")
                if candidate_type and candidate_type != "null":
                    function_arg["type"] = candidate_type
                    break
            function_arg["nullable"] = True
    # 6. Annotate required fields.
    if required:
        # We use the user-provided "required" fields if specified.
        parameters["required"] = required
    else:
        # Otherwise we infer it from the function signature: every supported
        # parameter that has no default value is required.
        parameters["required"] = [
            name
            for name, param in params.items()
            if param.default is empty and param.kind in supported_kinds
        ]
    schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)
    if schema_name:
        schema["name"] = schema_name
    return schema


def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool:
    """Tells whether `tracer_provider` is a Proxy or NoOp tracer provider.

    The two provider classes are looked up on `opentelemetry.trace`, using the
    module returned by `_import_opentelemetry_or_warn()`.

    Args:
        tracer_provider: The object to inspect.

    Returns:
        bool: True if it is an instance of `ProxyTracerProvider` or
        `NoOpTracerProvider`, False otherwise.
    """
    trace_api = _import_opentelemetry_or_warn().trace
    proxy_cls = trace_api.ProxyTracerProvider
    noop_cls = trace_api.NoOpTracerProvider
    placeholder_provider_types = (noop_cls, proxy_cls)
    return isinstance(tracer_provider, placeholder_provider_types)


def _import_cloud_storage_or_raise() -> types.ModuleType:
    """Returns the `google.cloud.storage` module.

    Raises:
        ImportError: If the module cannot be imported; the message includes
            the pip command to install it.
    """
    try:
        from google.cloud import storage as storage_module
    except ImportError as import_error:
        install_hint = (
            "Cloud Storage is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
        raise ImportError(install_hint) from import_error
    else:
        return storage_module


def _import_cloudpickle_or_raise() -> types.ModuleType:
    """Returns the `cloudpickle` module.

    Raises:
        ImportError: If the module cannot be imported; the message includes
            the pip command to install it.
    """
    try:
        import cloudpickle as cloudpickle_module
    except ImportError as import_error:
        install_hint = (
            "cloudpickle is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
        raise ImportError(install_hint) from import_error
    else:
        return cloudpickle_module


def _import_pydantic_or_raise() -> types.ModuleType:
    """Tries to import the pydantic module.

    Returns:
        types.ModuleType: The top-level `pydantic` module when it exposes a
        `Field` attribute, otherwise the `pydantic.v1` module.

    Raises:
        ImportError: If pydantic cannot be imported; the message includes the
            pip command to install it.
    """
    try:
        import pydantic

        # Probe for the attribute callers rely on; a missing `Field` raises
        # AttributeError and triggers the fallback below.
        _ = pydantic.Field
    except AttributeError:
        # NOTE(review): presumably covers installations where the expected
        # API is only reachable under the `v1` namespace -- confirm. An
        # ImportError raised by this fallback import is not wrapped by the
        # handler below, since it is raised from within a sibling handler.
        from pydantic import v1 as pydantic
    except ImportError as e:
        raise ImportError(
            "pydantic is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        ) from e
    return pydantic


def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
    """Returns the `opentelemetry` module, or None (with a warning) if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import opentelemetry  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
        return None
    return opentelemetry


def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
    """Returns `opentelemetry.sdk.trace`, or None (with a warning) if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import opentelemetry.sdk.trace  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-sdk is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
        return None
    return opentelemetry.sdk.trace


def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]:
    """Returns `google.cloud.trace_v2`, or None (with a warning) if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import google.cloud.trace_v2
    except ImportError:
        _LOGGER.warning(
            "google-cloud-trace is not installed. Please call "
            "'pip install google-cloud-aiplatform[agent_engines]'."
        )
        return None
    return google.cloud.trace_v2


def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
    """Returns `opentelemetry.exporter.cloud_trace`, or None if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import opentelemetry.exporter.cloud_trace  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "opentelemetry-exporter-gcp-trace is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
        return None
    return opentelemetry.exporter.cloud_trace


def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
    """Returns `openinference.instrumentation.langchain`, or None if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import openinference.instrumentation.langchain  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-langchain is not installed. Please "
            "call 'pip install google-cloud-aiplatform[langchain]'."
        )
        return None
    return openinference.instrumentation.langchain


def _import_openinference_autogen_or_warn() -> Optional[types.ModuleType]:
    """Returns `openinference.instrumentation.autogen`, or None if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import openinference.instrumentation.autogen  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-autogen is not installed. Please "
            "call 'pip install openinference-instrumentation-autogen'."
        )
        return None
    return openinference.instrumentation.autogen


def _import_openinference_llama_index_or_warn() -> Optional[types.ModuleType]:
    """Returns `openinference.instrumentation.llama_index`, or None if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import openinference.instrumentation.llama_index  # noqa:F401
    except ImportError:
        _LOGGER.warning(
            "openinference-instrumentation-llama_index is not installed. Please "
            "call 'pip install google-cloud-aiplatform[llama_index]'."
        )
        return None
    return openinference.instrumentation.llama_index


def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]:
    """Returns the `autogen.tools` module, or None (with a warning) if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        from autogen import tools as autogen_tools
    except ImportError:
        _LOGGER.warning(
            "autogen.tools is not installed. Please call: `pip install ag2[tools]`"
        )
        return None
    else:
        return autogen_tools


def _import_nest_asyncio_or_warn() -> Optional[types.ModuleType]:
    """Returns the `nest_asyncio` module, or None (with a warning) if missing.

    Returns:
        Optional[types.ModuleType]: The imported module, or None when the
        import fails; in that case an install hint is logged as a warning.
    """
    try:
        import nest_asyncio as nest_asyncio_module
    except ImportError:
        _LOGGER.warning(
            "nest_asyncio is not installed. Please call: `pip install nest-asyncio`"
        )
        return None
    else:
        return nest_asyncio_module
