import json
import os
import hashlib
import datetime
from typing import Dict, Any, List, Optional

class ReasoningTraceAuditor:
    """
    The Genesis Prime Mother's Reasoning Trace Auditor, ensuring full decision provenance.

    Every decision is appended to a JSON Lines log as a hash-chained entry:
    each entry stores the SHA-256 hash of the previous entry, so any
    tampering, deletion, or reordering breaks the chain and is detected by
    :meth:`verify_integrity`.

    Patent P4: Full Decision Provenance System.
    """

    def __init__(self, log_file_path: str = "genesis_reasoning_trace.jsonl"):
        """
        Initializes the Reasoning Trace Auditor.

        Args:
            log_file_path (str): The path to the file where decision traces will be stored.
                                 Uses JSON Lines format for append-only logging.
        """
        self.log_file_path = log_file_path
        self._ensure_log_directory_exists()
        # Resume the hash chain from an existing log so entries written
        # across process restarts stay linked.
        self._last_entry_hash = self._get_last_hash_from_file()

    def _ensure_log_directory_exists(self) -> None:
        """Ensures the directory for the log file exists."""
        dir_name = os.path.dirname(self.log_file_path)
        if dir_name:
            # exist_ok avoids a TOCTOU race between an exists() check and mkdir.
            os.makedirs(dir_name, exist_ok=True)

    def _calculate_hash(self, data: Dict[str, Any]) -> str:
        """
        Calculates the SHA256 hash of a dictionary.

        The dictionary is serialized to canonical JSON (sorted keys, no
        extra whitespace) so semantically equal dicts always hash the same.
        """
        serialized_data = json.dumps(data, sort_keys=True, separators=(',', ':'))
        return hashlib.sha256(serialized_data.encode('utf-8')).hexdigest()

    def _get_last_hash_from_file(self) -> Optional[str]:
        """
        Retrieves the hash of the last entry from the log file, if it exists.

        Returns:
            Optional[str]: The 'entry_hash' of the final entry, or None if the
                           log is missing, empty, or unreadable.
        """
        if not os.path.exists(self.log_file_path):
            return None
        try:
            with open(self.log_file_path, 'r', encoding='utf-8') as f:
                last_line = None
                for line in f:
                    # Ignore blank lines (e.g. a stray trailing newline) so
                    # they cannot corrupt the resumed chain.
                    if line.strip():
                        last_line = line
                if last_line:
                    return json.loads(last_line).get("entry_hash")
            return None
        except (json.JSONDecodeError, OSError) as e:
            # Best-effort: a corrupt log should not block new recordings.
            print(f"Warning: Could not read last hash from {self.log_file_path}: {e}")
            return None

    def record_decision(self,
                        decision_id: str,
                        agent_id: str,
                        inputs: Dict[str, Any],
                        outputs: Dict[str, Any],
                        reasoning: str,
                        context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Records a decision made by an agent, providing full provenance.

        Args:
            decision_id (str): A unique identifier for this specific decision.
            agent_id (str): The ID of the agent or module that made the decision.
            inputs (Dict[str, Any]): The input parameters or state that led to the decision.
            outputs (Dict[str, Any]): The resulting outputs or actions taken due to the decision.
            reasoning (str): A natural language or structured explanation of *why* the decision was made.
            context (Optional[Dict[str, Any]]): Additional contextual information relevant to the decision.

        Returns:
            Dict[str, Any]: The complete decision entry that was logged, including hashes.

        Raises:
            OSError: If the entry cannot be appended to the log file.
        """
        # Timezone-aware UTC timestamp; utcnow() is deprecated since 3.12.
        # .replace keeps the original trailing-"Z" format instead of "+00:00".
        timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z")

        entry_data = {
            "timestamp": timestamp,
            "decision_id": decision_id,
            "agent_id": agent_id,
            "inputs": inputs,
            "outputs": outputs,
            "reasoning": reasoning,
            "context": context if context is not None else {},
            "previous_entry_hash": self._last_entry_hash  # Link to previous entry for immutability
        }

        # Hash the entry *before* embedding its own hash field, so
        # verify_integrity() can recompute it after popping "entry_hash".
        entry_hash = self._calculate_hash(entry_data)
        entry_data["entry_hash"] = entry_hash

        # Append to the log file
        try:
            with open(self.log_file_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(entry_data) + '\n')
            self._last_entry_hash = entry_hash  # Update last hash for the next entry
            return entry_data
        except OSError as e:
            print(f"Error writing to log file {self.log_file_path}: {e}")
            # In a real system, this would trigger more robust error handling/fallback
            raise

    def get_trace(self, decision_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a specific decision trace by its ID.

        Args:
            decision_id (str): The unique identifier of the decision to retrieve.

        Returns:
            Optional[Dict[str, Any]]: The decision entry if found, otherwise None.
        """
        if not os.path.exists(self.log_file_path):
            return None
        try:
            with open(self.log_file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    entry = json.loads(line)
                    if entry.get("decision_id") == decision_id:
                        return entry
            return None
        except (json.JSONDecodeError, OSError) as e:
            print(f"Error decoding JSON from log file {self.log_file_path}: {e}")
            return None

    def get_full_trace(self) -> List[Dict[str, Any]]:
        """
        Retrieves all recorded decision traces.

        Returns:
            List[Dict[str, Any]]: A list of all decision entries, in log order.
        """
        if not os.path.exists(self.log_file_path):
            return []
        try:
            with open(self.log_file_path, 'r', encoding='utf-8') as f:
                return [json.loads(line) for line in f if line.strip()]
        except (json.JSONDecodeError, OSError) as e:
            print(f"Error decoding JSON from log file {self.log_file_path}: {e}")
            return []

    def verify_integrity(self) -> bool:
        """
        Verifies the integrity of the audit log by checking the hash chain.
        Ensures that no entries have been tampered with or reordered.

        Returns:
            bool: True if the log's integrity is intact, False otherwise.
        """
        if not os.path.exists(self.log_file_path):
            return True  # An empty log is considered intact

        previous_hash = None
        line_number = 0
        try:
            with open(self.log_file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line_number += 1
                    entry = json.loads(line)

                    # Check previous hash link
                    if entry.get("previous_entry_hash") != previous_hash:
                        print(f"Integrity Error: Hash chain broken at line {line_number}. "
                              f"Expected previous_entry_hash '{previous_hash}', got '{entry.get('previous_entry_hash')}'")
                        return False

                    # Re-calculate and verify current entry's hash; remove the
                    # stored hash first so we hash exactly what record_decision hashed.
                    stored_entry_hash = entry.pop("entry_hash", None)
                    if stored_entry_hash is None:
                        print(f"Integrity Error: Entry at line {line_number} missing 'entry_hash'.")
                        return False

                    calculated_hash = self._calculate_hash(entry)

                    if calculated_hash != stored_entry_hash:
                        print(f"Integrity Error: Entry hash mismatch at line {line_number}. "
                              f"Calculated '{calculated_hash}', Stored '{stored_entry_hash}'")
                        return False

                    previous_hash = stored_entry_hash  # Update for the next iteration
            return True
        except (json.JSONDecodeError, OSError) as e:
            print(f"Integrity Error: Problem reading or parsing log file {self.log_file_path}: {e}")
            return False

    def replay_decision(self, decision_id: str) -> Optional[Dict[str, Any]]:
        """
        Retrieves a decision trace for replay. This method itself does not
        'replay' the decision but provides all the necessary data for an
        external system to reconstruct or simulate the decision context.

        Args:
            decision_id (str): The unique identifier of the decision to replay.

        Returns:
            Optional[Dict[str, Any]]: The full decision entry, including inputs,
                                      outputs, reasoning, and context, if found.
        """
        return self.get_trace(decision_id)
