import logging
import psycopg2
import psycopg2.extensions
import os
import subprocess
from typing import List, Tuple

# Configure the root logger when this module is imported: records at INFO and
# above are emitted, formatted as "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class MigrationTester:
    """
    Tests database schema migrations against a PostgreSQL database.

    Forward migrations are the ``*.sql`` files found directly inside
    ``migrations_dir``. A migration ``NAME.sql`` may have a companion
    ``NAME.rollback.sql`` script that undoes it; companions are never treated
    as forward migrations themselves.
    """

    # Suffix identifying a forward migration script.
    _MIGRATION_SUFFIX = ".sql"
    # Suffix identifying the rollback companion of a migration script.
    _ROLLBACK_SUFFIX = ".rollback.sql"

    def __init__(self, dbname: str, user: str, password: str, host: str, port: int, migrations_dir: str):
        """
        Initializes the MigrationTester with database connection details and migration directory.

        No connection is opened here; call connect() (or run_migration_test(),
        which connects on its own) before executing SQL.

        Args:
            dbname: The name of the database.
            user: The database user.
            password: The database password.
            host: The database host.
            port: The database port.
            migrations_dir: The directory containing the migration scripts.
        """
        self.dbname = dbname
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.migrations_dir = migrations_dir
        self.conn = None  # type: psycopg2.extensions.connection | None
        self.cursor = None  # type: psycopg2.extensions.cursor | None

    def connect(self) -> None:
        """
        Connects to the PostgreSQL database and opens a cursor.

        Autocommit is switched off, so statements only take effect once
        commit() is called on the connection.

        Raises:
            psycopg2.Error: If the connection cannot be established (the error
                is logged and re-raised).
        """
        try:
            self.conn = psycopg2.connect(
                dbname=self.dbname,
                user=self.user,
                password=self.password,
                host=self.host,
                port=self.port
            )
            self.conn.autocommit = False  # Ensure transactions are used
            self.cursor = self.conn.cursor()
            logging.info("Successfully connected to the database.")
        except psycopg2.Error as e:
            logging.error(f"Error connecting to the database: {e}")
            raise

    def disconnect(self) -> None:
        """
        Disconnects from the PostgreSQL database.

        The cursor and the connection are closed independently, so a failure
        while closing the cursor does not leave the connection open. Both
        attributes are reset to None so later calls hit the "call connect()
        first" guards instead of operating on closed handles. Errors raised
        while closing are logged, not propagated.
        """
        cursor, conn = self.cursor, self.conn
        # Forget the handles up front: whatever happens below, they must not
        # be reused after a disconnect attempt.
        self.cursor = None
        self.conn = None

        closed_cleanly = True
        for handle in (cursor, conn):
            if handle is None:
                continue
            try:
                handle.close()
            except psycopg2.Error as e:
                logging.error(f"Error disconnecting from the database: {e}")
                closed_cleanly = False

        if closed_cleanly:
            logging.info("Successfully disconnected from the database.")

    def execute_sql(self, sql: str) -> None:
        """
        Executes a SQL statement and commits it.

        Args:
            sql: The SQL statement to execute.

        Raises:
            ValueError: If connect() has not been called.
            psycopg2.Error: If execution or commit fails; the transaction is
                rolled back before the error is re-raised.
        """
        if self.cursor is None or self.conn is None:
            raise ValueError("Database cursor is not initialized. Call connect() first.")
        try:
            self.cursor.execute(sql)
            self.conn.commit()
        except psycopg2.Error as e:
            logging.error(f"Error executing SQL: {e}")
            self.conn.rollback()
            raise

    def get_migration_files(self) -> List[str]:
        """
        Gets a list of forward migration file paths sorted by filename.

        Only regular files ending in ``.sql`` are returned. Files ending in
        ``.rollback.sql`` are excluded: they are rollback companions and must
        not be applied as forward migrations. Sorting is plain lexicographic
        string order.

        Returns:
            A list of migration file paths.
        """
        candidates = (
            os.path.join(self.migrations_dir, name)
            for name in os.listdir(self.migrations_dir)
            if name.endswith(self._MIGRATION_SUFFIX)
            and not name.endswith(self._ROLLBACK_SUFFIX)
        )
        return sorted(path for path in candidates if os.path.isfile(path))

    def apply_migrations(self) -> bool:
        """
        Applies all migration scripts in the migrations directory.

        Scripts are executed in the order returned by get_migration_files();
        processing stops at the first failure.

        Returns:
            True if all migrations were applied successfully, False otherwise.
        """
        for migration_file in self.get_migration_files():
            logging.info(f"Applying migration: {migration_file}")
            try:
                # Read as UTF-8 explicitly so the result does not depend on
                # the platform's default encoding.
                with open(migration_file, 'r', encoding='utf-8') as f:
                    sql = f.read()
                self.execute_sql(sql)
                logging.info(f"Migration applied successfully: {migration_file}")
            except Exception as e:
                logging.error(f"Error applying migration {migration_file}: {e}")
                return False
        return True

    def rollback_migrations(self) -> bool:
        """
        Rolls back all migrations by executing rollback scripts (if available).
        Assumes rollback scripts are named like `*.rollback.sql`.

        Migrations are visited in reverse order. A migration without a
        rollback script is skipped with a warning; processing stops at the
        first rollback script that fails.

        Returns:
            True if all rollbacks were successful, False otherwise.
        """
        for migration_file in reversed(self.get_migration_files()):
            # Swap only the trailing extension; a plain str.replace would also
            # rewrite any ".sql" occurring earlier in the path.
            stem, _ = os.path.splitext(migration_file)
            rollback_file = stem + self._ROLLBACK_SUFFIX
            if not os.path.exists(rollback_file):
                logging.warning(f"No rollback script found for {migration_file}")
                continue

            logging.info(f"Rolling back migration: {migration_file}")
            try:
                with open(rollback_file, 'r', encoding='utf-8') as f:
                    sql = f.read()
                self.execute_sql(sql)
                logging.info(f"Migration rollback successful: {migration_file}")
            except Exception as e:
                logging.error(f"Error rolling back migration {migration_file}: {e}")
                return False
        return True

    def validate_data_integrity(self, validation_queries: List[Tuple[str, int]]) -> bool:
        """
        Validates data integrity after migration by running validation queries.

        Each query is executed and the first column of its first row is
        compared with the expected value. Validation stops at the first
        mismatch or database error.

        Args:
            validation_queries: A list of tuples, where each tuple contains a SQL query and the expected result count.

        Returns:
            True if all validation queries pass, False otherwise (including
            when connect() has not been called).
        """
        if self.cursor is None:
            logging.error("Database cursor is not initialized. Call connect() first.")
            return False

        try:
            for query, expected_count in validation_queries:
                logging.info(f"Running validation query: {query}")
                self.cursor.execute(query)
                result = self.cursor.fetchone()
                if result is None or result[0] != expected_count:
                    logging.error(f"Validation query failed: {query}. Expected count: {expected_count}, Actual count: {result[0] if result else None}")
                    return False
                logging.info(f"Validation query passed: {query}")
            return True
        except psycopg2.Error as e:
            logging.error(f"Error during data integrity validation: {e}")
            # A failed statement aborts the open transaction; roll it back so
            # the connection remains usable for later statements.
            if self.conn is not None:
                self.conn.rollback()
            return False

    def run_migration_test(self, validation_queries: List[Tuple[str, int]]) -> bool:
        """
        Runs the complete migration test: forward migration, data integrity validation, and rollback.

        The method connects on entry and always disconnects on exit. If the
        migration or validation step fails it returns immediately, without
        running the rollback step.

        Args:
            validation_queries: A list of tuples, where each tuple contains a SQL query and the expected result count.

        Returns:
            True if the entire test passes, False otherwise.
        """
        try:
            self.connect()

            logging.info("Applying migrations...")
            if not self.apply_migrations():
                logging.error("Migration failed.")
                return False

            logging.info("Validating data integrity...")
            if not self.validate_data_integrity(validation_queries):
                logging.error("Data integrity validation failed.")
                return False

            logging.info("Rolling back migrations...")
            if not self.rollback_migrations():
                logging.error("Rollback failed.")
                return False

            logging.info("Migration test passed successfully.")
            return True

        except Exception as e:
            logging.error(f"Migration test failed: {e}")
            return False
        finally:
            self.disconnect()

if __name__ == '__main__':
    # Example usage: replace the placeholder values below with your actual
    # configuration before running this file as a script.
    dbname = "your_db_name"
    user = "your_db_user"
    password = "your_db_password"
    host = "postgresql-genesis-u50607.vm.elestio.app"
    port = 25432
    migrations_dir = "/mnt/e/genesis-system/migrations"  # Replace with your migrations directory

    # Example validation queries (replace with your actual queries).
    # Each entry is (query, expected value of the first column of the first row).
    # These run after the forward migrations and before the rollback step.
    validation_queries = [
        ("SELECT COUNT(*) FROM your_table;", 0),  # Example: expect the table to be empty once migrated
    ]

    tester = MigrationTester(dbname, user, password, host, port, migrations_dir)
    test_result = tester.run_migration_test(validation_queries)

    if test_result:
        print("Migration test PASSED")
    else:
        print("Migration test FAILED")

    # Report the outcome through the process exit status so shells and CI
    # pipelines can detect a failed migration test.
    raise SystemExit(0 if test_result else 1)
