import collections

import numpy as np

try:
    # pip install python-fcl
    import fcl
except BaseException:
    fcl = None


class ContactData:
    """
    Data structure for holding information about a collision contact.
    """

    def __init__(self, names, contact):
        """
        Initialize a ContactData.

        Parameters
        ----------
        names : list of str
          The names of the two objects in order.
        contact : fcl.Contact
          The contact in question.
        """
        self.names = set(names)
        self._inds = {names[0]: contact.b1, names[1]: contact.b2}
        self._normal = contact.normal
        self._point = contact.pos
        self._depth = contact.penetration_depth

    @property
    def normal(self):
        """
        The 3D intersection normal for this contact.

        Returns
        -------
        normal : (3,) float
          The intersection normal.
        """
        return self._normal

    @property
    def point(self):
        """
        The 3D point of intersection for this contact.

        Returns
        -------
        point : (3,) float
          The intersection point.
        """
        return self._point

    @property
    def depth(self):
        """
        The penetration depth of the 3D point of intersection for this contact.

        Returns
        -------
        depth : float
          The penetration depth.
        """
        return self._depth

    def index(self, name):
        """
        Returns the index of the face in contact for the mesh with
        the given name.

        Parameters
        ----------
        name : str
          The name of the target object.

        Returns
        -------
        index : int
          The index of the face in collision
        """
        return self._inds[name]


class DistanceData:
    """
    Data structure for holding information about a distance query.
    """

    def __init__(self, names, result):
        """
        Initialize a DistanceData.

        Parameters
        ----------
        names : list of str
          The names of the two objects in order.
        contact : fcl.DistanceResult
          The distance query result.
        """
        self.names = set(names)
        self._inds = {names[0]: result.b1, names[1]: result.b2}
        self._points = {
            names[0]: result.nearest_points[0],
            names[1]: result.nearest_points[1],
        }
        self._distance = result.min_distance

    @property
    def distance(self):
        """
        Returns the distance between the two objects.

        Returns
        -------
        distance : float
          The euclidean distance between the objects.
        """
        return self._distance

    def index(self, name):
        """
        Returns the index of the closest face for the mesh with
        the given name.

        Parameters
        ----------
        name : str
          The name of the target object.

        Returns
        -------
        index : int
          The index of the face in collisoin.
        """
        return self._inds[name]

    def point(self, name):
        """
        The 3D point of closest distance on the mesh with the given name.

        Parameters
        ----------
        name : str
          The name of the target object.

        Returns
        -------
        point : (3,) float
          The closest point.
        """
        return self._points[name]


class CollisionManager:
    """
    A mesh-mesh collision manager.
    """

    def __init__(self):
        """
        Initialize a mesh-mesh collision manager.
        """
        if fcl is None:
            raise ValueError("No FCL Available! Please install the python-fcl library")
        # {name: {geom:, obj}}
        self._objs = {}
        # {id(bvh) : str, name}
        # unpopulated values will return None
        self._names = collections.defaultdict(lambda: None)

        self._manager = fcl.DynamicAABBTreeCollisionManager()
        self._manager.setup()

    def add_object(self, name, mesh, transform=None):
        """
        Add an object to the collision manager.

        If an object with the given name is already in the manager,
        replace it.

        Parameters
        ----------
        name : str
          An identifier for the object
        mesh : Trimesh object
          The geometry of the collision object
        transform : (4,4) float
          Homogeneous transform matrix for the object
        """

        # if no transform passed, assume identity transform
        if transform is None:
            transform = np.eye(4)
        transform = np.asanyarray(transform, dtype=np.float32)
        if transform.shape != (4, 4):
            raise ValueError("transform must be (4,4)!")

        # create BVH/Convex
        geom = self._get_fcl_obj(mesh)

        # create the FCL transform from (4,4) matrix
        t = fcl.Transform(transform[:3, :3], transform[:3, 3])
        o = fcl.CollisionObject(geom, t)

        # Add collision object to set
        if name in self._objs:
            self._manager.unregisterObject(self._objs[name])
        self._objs[name] = {"obj": o, "geom": geom}
        # store the name of the geometry
        self._names[id(geom)] = name

        self._manager.registerObject(o)
        self._manager.update()
        return o

    def remove_object(self, name):
        """
        Delete an object from the collision manager.

        Parameters
        ----------
        name : str
          The identifier for the object
        """
        if name in self._objs:
            self._manager.unregisterObject(self._objs[name]["obj"])
            self._manager.update(self._objs[name]["obj"])
            # remove objects from _objs
            geom_id = id(self._objs.pop(name)["geom"])
            # remove names
            self._names.pop(geom_id)
        else:
            raise ValueError(f"{name} not in collision manager!")

    def set_transform(self, name, transform):
        """
        Set the transform for one of the manager's objects.
        This replaces the prior transform.

        Parameters
        ----------
        name : str
          An identifier for the object already in the manager
        transform : (4,4) float
          A new homogeneous transform matrix for the object
        """
        if name in self._objs:
            o = self._objs[name]["obj"]
            o.setRotation(transform[:3, :3])
            o.setTranslation(transform[:3, 3])
            self._manager.update(o)
        else:
            raise ValueError(f"{name} not in collision manager!")

    def in_collision_single(
        self, mesh, transform=None, return_names=False, return_data=False
    ):
        """
        Check a single object for collisions against all objects in the
        manager.

        Parameters
        ----------
        mesh : Trimesh object
          The geometry of the collision object
        transform : (4,4) float
          Homogeneous transform matrix
        return_names : bool
          If true, a set is returned containing the names
          of all objects in collision with the object
        return_data :  bool
          If true, a list of ContactData is returned as well

        Returns
        ------------
        is_collision : bool
          True if a collision occurs and False otherwise
        names : set of str
          [OPTIONAL] The set of names of objects that collided with the
          provided one
        contacts : list of ContactData
          [OPTIONAL] All contacts detected
        """
        if transform is None:
            transform = np.eye(4)

        # create BVH/Convex
        geom = self._get_fcl_obj(mesh)

        # create the FCL transform from (4,4) matrix
        t = fcl.Transform(transform[:3, :3], transform[:3, 3])
        o = fcl.CollisionObject(geom, t)

        # Collide with manager's objects
        cdata = fcl.CollisionData()
        if return_names or return_data:
            cdata = fcl.CollisionData(
                request=fcl.CollisionRequest(num_max_contacts=100000, enable_contact=True)
            )

        self._manager.collide(o, cdata, fcl.defaultCollisionCallback)
        result = cdata.result.is_collision

        # If we want to return the objects that were collision, collect them.
        objs_in_collision = set()
        contact_data = []
        if return_names or return_data:
            for contact in cdata.result.contacts:
                cg = contact.o1
                if cg == geom:
                    cg = contact.o2
                name = self._extract_name(cg)

                names = (name, "__external")
                if cg == contact.o2:
                    names = tuple(reversed(names))

                if return_names:
                    objs_in_collision.add(name)
                if return_data:
                    contact_data.append(ContactData(names, contact))

        if return_names and return_data:
            return result, objs_in_collision, contact_data
        elif return_names:
            return result, objs_in_collision
        elif return_data:
            return result, contact_data
        else:
            return result

    def in_collision_internal(self, return_names=False, return_data=False):
        """
        Check if any pair of objects in the manager collide with one another.

        Parameters
        ----------
        return_names : bool
          If true, a set is returned containing the names
          of all pairs of objects in collision.
        return_data :  bool
          If true, a list of ContactData is returned as well

        Returns
        -------
        is_collision : bool
          True if a collision occurred between any pair of objects
          and False otherwise
        names : set of 2-tup
          The set of pairwise collisions. Each tuple
          contains two names in alphabetical order indicating
          that the two corresponding objects are in collision.
        contacts : list of ContactData
          All contacts detected
        """
        cdata = fcl.CollisionData()
        if return_names or return_data:
            cdata = fcl.CollisionData(
                request=fcl.CollisionRequest(num_max_contacts=100000, enable_contact=True)
            )

        self._manager.collide(cdata, fcl.defaultCollisionCallback)

        result = cdata.result.is_collision

        objs_in_collision = set()
        contact_data = []
        if return_names or return_data:
            for contact in cdata.result.contacts:
                names = (self._extract_name(contact.o1), self._extract_name(contact.o2))

                if return_names:
                    objs_in_collision.add(tuple(sorted(names)))
                if return_data:
                    contact_data.append(ContactData(names, contact))

        if return_names and return_data:
            return result, objs_in_collision, contact_data
        elif return_names:
            return result, objs_in_collision
        elif return_data:
            return result, contact_data
        else:
            return result

    def in_collision_other(self, other_manager, return_names=False, return_data=False):
        """
        Check if any object from this manager collides with any object
        from another manager.

        Parameters
        -------------------
        other_manager : CollisionManager
          Another collision manager object
        return_names : bool
          If true, a set is returned containing the names
          of all pairs of objects in collision.
        return_data : bool
          If true, a list of ContactData is returned as well

        Returns
        -------------
        is_collision : bool
          True if a collision occurred between any pair of objects
          and False otherwise
        names : set of 2-tup
          The set of pairwise collisions. Each tuple
          contains two names (first from this manager,
          second from the other_manager) indicating
          that the two corresponding objects are in collision.
        contacts : list of ContactData
          All contacts detected
        """
        cdata = fcl.CollisionData()
        if return_names or return_data:
            cdata = fcl.CollisionData(
                request=fcl.CollisionRequest(num_max_contacts=100000, enable_contact=True)
            )
        self._manager.collide(other_manager._manager, cdata, fcl.defaultCollisionCallback)
        result = cdata.result.is_collision

        objs_in_collision = set()
        contact_data = []
        if return_names or return_data:
            for contact in cdata.result.contacts:
                reverse = False
                names = (
                    self._extract_name(contact.o1),
                    other_manager._extract_name(contact.o2),
                )
                if names[0] is None:
                    names = (
                        self._extract_name(contact.o2),
                        other_manager._extract_name(contact.o1),
                    )
                    reverse = True

                if return_names:
                    objs_in_collision.add(names)
                if return_data:
                    if reverse:
                        names = tuple(reversed(names))
                    contact_data.append(ContactData(names, contact))

        if return_names and return_data:
            return result, objs_in_collision, contact_data
        elif return_names:
            return result, objs_in_collision
        elif return_data:
            return result, contact_data
        else:
            return result

    def min_distance_single(
        self, mesh, transform=None, return_name=False, return_data=False
    ):
        """
        Get the minimum distance between a single object and any
        object in the manager.

        Parameters
        ---------------
        mesh : Trimesh object
          The geometry of the collision object
        transform : (4,4) float
          Homogeneous transform matrix for the object
        return_names : bool
          If true, return name of the closest object
        return_data : bool
          If true, a DistanceData object is returned as well

        Returns
        -------------
        distance : float
          Min distance between mesh and any object in the manager
        name : str
          The name of the object in the manager that was closest
        data : DistanceData
          Extra data about the distance query
        """
        if transform is None:
            transform = np.eye(4)

        # create BVH/Convex
        geom = self._get_fcl_obj(mesh)

        # create the FCL transform from (4,4) matrix
        t = fcl.Transform(transform[:3, :3], transform[:3, 3])
        o = fcl.CollisionObject(geom, t)

        # Collide with manager's objects
        ddata = fcl.DistanceData(fcl.DistanceRequest(enable_signed_distance=True))
        if return_data:
            ddata = fcl.DistanceData(
                fcl.DistanceRequest(
                    enable_nearest_points=True, enable_signed_distance=True
                ),
                fcl.DistanceResult(),
            )

        self._manager.distance(o, ddata, fcl.defaultDistanceCallback)

        distance = ddata.result.min_distance

        # If we want to return the objects that were collision, collect them.
        name, data = None, None
        if return_name or return_data:
            cg = ddata.result.o1
            if cg == geom:
                cg = ddata.result.o2

            name = self._extract_name(cg)

            names = (name, "__external")
            if cg == ddata.result.o2:
                names = tuple(reversed(names))
            data = DistanceData(names, ddata.result)

        if return_name and return_data:
            return distance, name, data
        elif return_name:
            return distance, name
        elif return_data:
            return distance, data
        else:
            return distance

    def min_distance_internal(self, name=None, return_names=False, return_data=False):
        """
        Get the minimum distance between objects in the manager.

        If name is provided, computes the minimum distance between the
        specified object and any other object in the manager.
        If name is None, computes the minimum distance between any pair
        of objects in the manager.

        Parameters
        -------------
        name : str or None
          If provided, the identifier for the object already in the manager
          to compute distances from. If None, computes distances between
          all pairs of objects.
        return_names : bool
          If true, a 2-tuple is returned containing the names
          of the closest objects.
        return_data : bool
          If true, a DistanceData object is returned as well

        Returns
        -----------
        distance : float
          Min distance between objects
        names : (2,) str
          The names of the closest objects
        data : DistanceData
          Extra data about the distance query
        """
        ddata = fcl.DistanceData(fcl.DistanceRequest(enable_signed_distance=True))
        if return_data:
            ddata = fcl.DistanceData(
                fcl.DistanceRequest(
                    enable_nearest_points=True,
                    enable_signed_distance=True,
                ),
                fcl.DistanceResult(),
            )

        # If name is provided, compute distance from that object to others
        if name is not None:
            if name not in self._objs:
                raise ValueError(f"{name} not in collision manager!")
            obj = self._objs[name]["obj"]
            # remove object from manager temporarily
            self._manager.unregisterObject(obj)
            self._manager.update(obj)

            # compute distance
            self._manager.distance(obj, ddata, fcl.defaultDistanceCallback)

            # add it back to the manager
            self._manager.registerObject(obj)
            self._manager.update()

        else:
            # Compute distance between any pair of objects
            self._manager.distance(ddata, fcl.defaultDistanceCallback)

        distance = ddata.result.min_distance

        names, data = None, None
        if return_names or return_data:
            names = (
                self._extract_name(ddata.result.o1),
                self._extract_name(ddata.result.o2),
            )
            data = DistanceData(names, ddata.result)
            names = tuple(sorted(names))

        if return_names and return_data:
            return distance, names, data
        elif return_names:
            return distance, names
        elif return_data:
            return distance, data
        else:
            return distance

    def min_distance_other(self, other_manager, return_names=False, return_data=False):
        """
        Get the minimum distance between any pair of objects,
        one in each manager.

        Parameters
        ----------
        other_manager : CollisionManager
          Another collision manager object
        return_names : bool
          If true, a 2-tuple is returned containing
          the names of the closest objects.
        return_data : bool
          If true, a DistanceData object is returned as well

        Returns
        -----------
        distance : float
          The min distance between a pair of objects,
          one from each manager.
        names : 2-tup of str
          A 2-tuple containing two names (first from this manager,
          second from the other_manager) indicating
          the two closest objects.
        data : DistanceData
          Extra data about the distance query
        """
        ddata = fcl.DistanceData(fcl.DistanceRequest(enable_signed_distance=True))
        if return_data:
            ddata = fcl.DistanceData(
                fcl.DistanceRequest(
                    enable_nearest_points=True,
                    enable_signed_distance=True,
                ),
                fcl.DistanceResult(),
            )

        self._manager.distance(other_manager._manager, ddata, fcl.defaultDistanceCallback)

        distance = ddata.result.min_distance

        names, data = None, None
        if return_names or return_data:
            reverse = False
            names = (
                self._extract_name(ddata.result.o1),
                other_manager._extract_name(ddata.result.o2),
            )
            if names[0] is None:
                reverse = True
                names = (
                    self._extract_name(ddata.result.o2),
                    other_manager._extract_name(ddata.result.o1),
                )

            dnames = tuple(names)
            if reverse:
                dnames = tuple(reversed(dnames))
            data = DistanceData(dnames, ddata.result)

        if return_names and return_data:
            return distance, names, data
        elif return_names:
            return distance, names
        elif return_data:
            return distance, data
        else:
            return distance

    def _get_fcl_obj(self, mesh):
        """
        Get a BVH or Convex for a mesh.

        Parameters
        -------------
        mesh : Trimesh
          Mesh to create BVH/Convex for

        Returns
        --------------
        obj : fcl.BVHModel or fcl.Convex
          BVH/Convex object of source mesh
        """

        if mesh.is_convex:
            obj = mesh_to_convex(mesh)
        else:
            obj = mesh_to_BVH(mesh)
        return obj

    def _extract_name(self, geom):
        """
        Retrieve the name of an object from the manager by its
        CollisionObject, or return None if not found.

        Parameters
        -----------
        geom : CollisionObject or BVHModel
          Input model

        Returns
        ------------
        names : hashable
          Name of input geometry
        """
        return self._names[id(geom)]


def mesh_to_BVH(mesh):
    """
    Create a BVHModel object from a Trimesh object

    Parameters
    -----------
    mesh : Trimesh
      Input geometry

    Returns
    ------------
    bvh : fcl.BVHModel
      BVH of input geometry
    """
    bvh = fcl.BVHModel()
    bvh.beginModel(num_tris_=len(mesh.faces), num_vertices_=len(mesh.vertices))
    bvh.addSubModel(verts=mesh.vertices, triangles=mesh.faces)
    bvh.endModel()
    return bvh


def mesh_to_convex(mesh):
    """
    Create a Convex object from a Trimesh object

    Parameters
    -----------
    mesh : Trimesh
      Input geometry

    Returns
    ------------
    convex : fcl.Convex
      Convex of input geometry
    """
    fs = np.concatenate(
        (3 * np.ones((len(mesh.faces), 1), dtype=np.int64), mesh.faces), axis=1
    )
    return fcl.Convex(mesh.vertices, len(fs), fs.flatten())


def scene_to_collision(scene):
    """
    Create collision objects from a trimesh.Scene object.

    Parameters
    ------------
    scene : trimesh.Scene
      Scene to create collision objects for

    Returns
    ------------
    manager : CollisionManager
      CollisionManager for objects in scene
    objects: {node name: CollisionObject}
      Collision objects for nodes in scene
    """
    manager = CollisionManager()
    objects = {}
    for node in scene.graph.nodes_geometry:
        T, geometry = scene.graph[node]
        objects[node] = manager.add_object(
            name=node, mesh=scene.geometry[geometry], transform=T
        )
    return manager, objects
