# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors


from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
    from .common import DATA
    from ._lancedb import (
        MergeInsertResult,
    )


class LanceMergeInsertBuilder(object):
    """Builder for a LanceDB merge insert operation

    See [`merge_insert`][lancedb.table.Table.merge_insert] for
    more context

    Each configuration method records an option on the builder and returns
    the builder itself, so calls can be chained.  The options are only
    acted upon when ``execute`` is called, which hands the builder to the
    table.
    """

    def __init__(self, table: "Table", on: List[str]):  # noqa: F821
        # Do not put a docstring here.  This method should be hidden
        # from API docs.  Users should use merge_insert to create
        # this object.
        self._table = table
        # Key column(s) used to match source (new) rows to target (old) rows.
        self._on = on
        # The fields below are set by the fluent methods and passed, via the
        # builder itself, to the table's ``_do_merge`` in ``execute``.
        self._when_matched_update_all = False
        self._when_matched_update_all_condition = None
        self._when_not_matched_insert_all = False
        self._when_not_matched_by_source_delete = False
        self._when_not_matched_by_source_condition = None
        self._timeout = None
        self._use_index = True

    def when_matched_update_all(
        self, *, where: Optional[str] = None
    ) -> LanceMergeInsertBuilder:
        """
        Rows that exist in both the source table (new data) and
        the target table (old data) will be updated, replacing
        the old row with the corresponding matching row.

        If there are multiple matches then the behavior is undefined.
        Currently this causes multiple copies of the row to be created
        but that behavior is subject to change.

        Parameters
        ----------
        where: Optional[str], default None
            An optional SQL filter used as the update condition, limiting
            which matched rows are updated.  If None, every matched row is
            updated.  Calling this method again replaces any previous
            condition.

        Returns
        -------
        LanceMergeInsertBuilder
            This builder, to allow chaining.
        """
        self._when_matched_update_all = True
        self._when_matched_update_all_condition = where
        return self

    def when_not_matched_insert_all(self) -> LanceMergeInsertBuilder:
        """
        Rows that exist only in the source table (new data) should
        be inserted into the target table.

        Returns
        -------
        LanceMergeInsertBuilder
            This builder, to allow chaining.
        """
        self._when_not_matched_insert_all = True
        return self

    def when_not_matched_by_source_delete(
        self, condition: Optional[str] = None
    ) -> LanceMergeInsertBuilder:
        """
        Rows that exist only in the target table (old data) will be
        deleted.  An optional condition can be provided to limit what
        data is deleted.

        Parameters
        ----------
        condition: Optional[str], default None
            If None then all such rows will be deleted.  Otherwise the
            condition will be used as an SQL filter to limit what rows
            are deleted.  Calling this method again replaces any previous
            condition.

        Returns
        -------
        LanceMergeInsertBuilder
            This builder, to allow chaining.
        """
        self._when_not_matched_by_source_delete = True
        # Always overwrite: a None here must clear any condition set by an
        # earlier call, otherwise "delete all such rows" would silently keep
        # a stale filter.
        self._when_not_matched_by_source_condition = condition
        return self

    def use_index(self, use_index: bool) -> LanceMergeInsertBuilder:
        """
        Controls whether to use indexes for the merge operation.

        When set to `True` (the default), the operation will use an index if available
        on the join key for improved performance. When set to `False`, it forces a full
        table scan even if an index exists. This can be useful for benchmarking or when
        the query optimizer chooses a suboptimal path.

        Parameters
        ----------
        use_index: bool
            Whether to use indices for the merge operation. Defaults to `True`.

        Returns
        -------
        LanceMergeInsertBuilder
            This builder, to allow chaining.
        """
        self._use_index = use_index
        return self

    def execute(
        self,
        new_data: DATA,
        on_bad_vectors: str = "error",
        fill_value: float = 0.0,
        timeout: Optional[timedelta] = None,
    ) -> MergeInsertResult:
        """
        Executes the merge insert operation

        The [`Table`][lancedb.table.Table] is updated in place and a
        result describing the merge is returned.

        Parameters
        ----------
        new_data: DATA
            New records which will be matched against the existing records
            to potentially insert or update into the table.  This parameter
            can be anything you use for [`add`][lancedb.table.Table.add]
        on_bad_vectors: str, default "error"
            What to do if any of the vectors are not the same size or contains NaNs.
            One of "error", "drop", "fill".
        fill_value: float, default 0.
            The value to use when filling vectors. Only used if on_bad_vectors="fill".
        timeout: Optional[timedelta], default None
            Maximum time to run the operation before cancelling it.

            By default, there is a 30-second timeout that is only enforced after the
            first attempt. This is to prevent spending too long retrying to resolve
            conflicts. For example, if a write attempt takes 20 seconds and fails,
            the second attempt will be cancelled after 10 seconds, hitting the
            30-second timeout. However, a write that takes one hour and succeeds on the
            first attempt will not be cancelled.

            When this is set, the timeout is enforced on all attempts, including
            the first.

        Returns
        -------
        MergeInsertResult
            version: the new version number of the table after doing merge insert.
        """
        # Only record a timeout when the caller supplies one; a previously
        # stored value (if any) is otherwise kept as-is.
        if timeout is not None:
            self._timeout = timeout
        # The table reads the options recorded on this builder.
        return self._table._do_merge(self, new_data, on_bad_vectors, fill_value)
