"""
This module provides hardened parsing functions that mitigate various attack vectors.

It focuses on preventing vulnerabilities such as:
- Arbitrary code execution through crafted input
- Denial-of-service attacks due to resource exhaustion
- Information leakage via error messages or unexpected behavior

The module provides replacement functions for potentially vulnerable built-in parsing functions
or commonly used parsing libraries.  It prioritizes security and robustness over raw performance.

Specific hardening techniques employed:
- Input sanitization and validation
- Resource limits (e.g., maximum string length, nesting depth)
- Error handling and safe fallback mechanisms
- Avoiding unsafe functions
"""

import json
import xml.etree.ElementTree as ET
import logging
import re

# Root-logger setup: INFO threshold, "<timestamp> - <level> - <message>" records.
# NOTE: this call executes as a side effect of importing the module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

class ParserHardeningError(Exception):
    """Root exception type for every failure reported by this module's parsers."""

def safe_json_parse(json_string, max_length=4096, max_nesting=10):
    """
    Safely parses a JSON string with input validation and resource limits.

    The nesting limit is enforced by scanning the raw text *before* it is handed
    to the decoder. Both objects (``{}``) and arrays (``[]``) count towards the
    depth, brackets inside string literals are ignored, and sibling containers
    do not accumulate: only true nesting is measured.

    Args:
        json_string (str): The JSON string to parse.
        max_length (int): The maximum allowed length of the JSON string.  Defaults to 4096.
        max_nesting (int): The maximum allowed nesting depth in the JSON structure. Defaults to 10.

    Returns:
        The parsed JSON value: a dict, list, str, int, float, bool or None,
        depending on the top-level JSON value.

    Raises:
        ParserHardeningError: If the input is not a string, is invalid JSON,
                              exceeds the length limit, or exceeds the nesting
                              depth limit.
    """
    if not isinstance(json_string, str):
        raise ParserHardeningError("Input must be a string.")

    if len(json_string) > max_length:
        raise ParserHardeningError(f"JSON string exceeds maximum length of {max_length} characters.")

    # Measure bracket depth on the raw text so an over-nested document is
    # rejected before the (recursive) decoder ever sees it.
    depth = 0
    in_string = False
    escaped = False
    for char in json_string:
        if in_string:
            # Inside a string literal only an unescaped quote is significant.
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            in_string = True
        elif char in "{[":
            depth += 1
            if depth > max_nesting:
                raise ParserHardeningError(f"JSON exceeds maximum nesting depth of {max_nesting}.")
        elif char in "}]" and depth > 0:
            # Unbalanced closers are left for the decoder to report as invalid JSON.
            depth -= 1

    try:
        return json.loads(json_string)
    except ValueError as e:
        # json.JSONDecodeError is a ValueError subclass; catching the base class
        # also covers the plain ValueError raised for over-long integer literals.
        raise ParserHardeningError(f"Invalid JSON string: {e}") from e
    except RecursionError as e:
        # Only reachable when the caller sets max_nesting above the interpreter's
        # own recursion budget.
        raise ParserHardeningError(f"JSON exceeds maximum nesting depth of {max_nesting}.") from e

def safe_xml_parse(xml_string, max_length=4096, allowed_tags=None, allowed_attributes=None):
    """
    Safely parses an XML string with input validation, resource limits and tag/attribute whitelisting.

    Documents that declare entities (``<!ENTITY ...>``) are rejected outright:
    entity expansion lets a payload far smaller than ``max_length`` blow up
    into a huge amount of text while parsing. The check is a plain text match,
    so the literal sequence ``<!ENTITY`` inside a comment or CDATA section is
    rejected as well.

    Args:
        xml_string (str): The XML string to parse.
        max_length (int): The maximum allowed length of the XML string. Defaults to 4096.
        allowed_tags (list): A list of allowed XML tag names. If None, all tags are allowed.
        allowed_attributes (list): A list of allowed XML attribute names. If None, all attributes are allowed.

    Returns:
        xml.etree.ElementTree.Element: The root element of the parsed XML tree.

    Raises:
        ParserHardeningError: If the input is not a string, is invalid XML,
                              exceeds the length limit, declares entities,
                              or contains disallowed tags or attributes.
    """
    if not isinstance(xml_string, str):
        raise ParserHardeningError("Input must be a string.")

    if len(xml_string) > max_length:
        raise ParserHardeningError(f"XML string exceeds maximum length of {max_length} characters.")

    # Refuse entity declarations before parsing; the length cap alone does not
    # bound the expanded size of an entity-expansion payload.
    if "<!ENTITY" in xml_string:
        raise ParserHardeningError("XML entity declarations are not allowed.")

    try:
        root = ET.fromstring(xml_string)
    except ET.ParseError as e:
        raise ParserHardeningError(f"Invalid XML string: {e}") from e

    check_tags = allowed_tags is not None
    check_attributes = allowed_attributes is not None

    # Validate every element (root included) against the whitelists.
    if check_tags or check_attributes:
        for element in root.iter():
            if check_tags and element.tag not in allowed_tags:
                raise ParserHardeningError(f"Disallowed tag: {element.tag}")

            if check_attributes:
                for attribute in element.attrib:
                    if attribute not in allowed_attributes:
                        raise ParserHardeningError(f"Disallowed attribute: {attribute}")

    return root

def safe_regex_search(pattern, text, max_length=1024, max_groups=10):
    """
    Searches for a pattern in text after validating both inputs against resource limits.

    NOTE: the pattern is compiled and used as given. This function does not
    rewrite or analyse it, so it cannot by itself rule out catastrophic
    backtracking; the cap on the text length only bounds the input the regex
    engine is run against. Patterns from untrusted sources should be vetted
    separately.

    Args:
        pattern (str): The regular expression pattern (an already compiled
            string pattern is accepted as well).
        text (str): The text to search within.
        max_length (int): Maximum length of the text string.
        max_groups (int): Maximum number of capturing groups allowed.

    Returns:
        re.Match object or None: The match object if a match is found, otherwise None.

    Raises:
        ParserHardeningError: If the text or pattern is not a string, the text
            is too long, the pattern is invalid or cannot be compiled, or it
            has more capturing groups than max_groups.
    """
    if not isinstance(text, str):
        raise ParserHardeningError("Text must be a string.")

    if len(text) > max_length:
        raise ParserHardeningError(f"Text exceeds maximum length of {max_length} characters.")

    try:
        compiled_pattern = re.compile(pattern)
    except re.error as e:
        raise ParserHardeningError(f"Invalid regular expression: {e}") from e
    except TypeError as e:
        # re.compile rejects anything that is not a string or compiled pattern.
        raise ParserHardeningError("Pattern must be a string.") from e
    except (RecursionError, OverflowError) as e:
        # Pathologically nested groups or oversized repetition counts.
        raise ParserHardeningError("Regex is too complex to compile.") from e

    # A bytes pattern compiles fine but cannot be applied to str text.
    if not isinstance(compiled_pattern.pattern, str):
        raise ParserHardeningError("Pattern must be a string.")

    # Limit the number of capturing groups the caller is willing to accept.
    if compiled_pattern.groups > max_groups:
        raise ParserHardeningError(f"Regex exceeds maximum number of capturing groups: {max_groups}")

    return compiled_pattern.search(text)


if __name__ == '__main__':
    # Demonstration run: call each hardened parser once on a small, valid sample
    # and print the outcome; any module error is reported instead of propagating.
    try:
        sample_json = '{"name": "John Doe", "age": 30}'
        print(f"Parsed JSON: {safe_json_parse(sample_json)}")

        sample_xml = '<root><element attribute="value">Text</element></root>'
        xml_root = safe_xml_parse(
            sample_xml,
            allowed_tags=['root', 'element'],
            allowed_attributes=['attribute'],
        )
        print(f"Parsed XML root tag: {xml_root.tag}")

        greeting_pattern = r"hello (\w+)"
        haystack = "hello world"
        found = safe_regex_search(greeting_pattern, haystack)
        if found is not None:
            print(f"Regex match: {found.group(1)}")

    except ParserHardeningError as problem:
        print(f"Error: {problem}")