#!/usr/bin/env python3
"""
GENESIS DEPENDENCY GRAPH
=========================
Analyzes and visualizes module dependencies within Genesis.

Features:
    - Import graph construction
    - Circular dependency detection
    - Module classification
    - Dependency metrics
    - Text-based visualization

Usage:
    graph = DependencyGraph()
    graph.build_from_directory("core")
    graph.detect_cycles()
    print(graph.visualize_text())
"""

import ast
import json
import re
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Set, Optional, Tuple, Any
from enum import Enum


class ModuleType(Enum):
    """
    Categories a module can be assigned to.

    Each member's value is the lowercase string that appears in serialized
    output (``ModuleInfo.to_dict`` emits ``module_type.value``).
    """
    CORE = "core"          # Core system modules
    TOOL = "tool"          # Tool/utility modules
    SKILL = "skill"        # Skill modules
    TEST = "test"          # Test modules
    CONFIG = "config"      # Configuration modules
    EXTERNAL = "external"  # External dependencies


@dataclass
class ModuleInfo:
    """
    Record describing one analyzed module.

    Attributes:
        name: module name (the graph uses it as the node key).
        path: location of the source file, stored as a string.
        module_type: category of the module; its ``.value`` is serialized.
        imports: names this module imports.
        imported_by: names of modules that import this one.
        classes: class names found in the module.
        functions: function names found in the module.
        lines: line count of the source file.
        complexity: simple size-based complexity estimate.
    """
    name: str
    path: str
    module_type: ModuleType
    imports: List[str] = field(default_factory=list)
    imported_by: List[str] = field(default_factory=list)
    classes: List[str] = field(default_factory=list)
    functions: List[str] = field(default_factory=list)
    lines: int = 0
    complexity: int = 0

    def to_dict(self) -> Dict:
        """Return a plain-dict view of the record (lists are shared, not copied)."""
        # The module type is exported under "type" as its enum value.
        exported = dict(name=self.name, path=self.path, type=self.module_type.value)
        # The remaining fields are exported under their own attribute names.
        for attr in ("imports", "imported_by", "classes", "functions",
                     "lines", "complexity"):
            exported[attr] = getattr(self, attr)
        return exported


@dataclass
class CycleDependency:
    """
    One circular import chain.

    Attributes:
        modules: module names along the cycle, in order; the first name is
            not repeated at the end.
        severity: "warning" or "critical".
    """
    modules: List[str]
    severity: str  # "warning" or "critical"

    def __str__(self):
        # Repeat the starting module at the end so the rendered chain
        # visibly closes the loop, e.g. "a -> b -> a".
        start = self.modules[0]
        closed_chain = self.modules + [start]
        return " -> ".join(closed_chain)


class ImportAnalyzer:
    """
    Extracts imports and definition names from Python source files.

    Attributes:
        project_root: directory that reported module paths are made
            relative to.
        genesis_modules: names of project-internal modules. Used to resolve
            bare relative imports such as ``from . import memory``.
    """

    # Path keywords checked in order; the first one found decides the type.
    _TYPE_KEYWORDS = (
        ("test", ModuleType.TEST),
        ("skill", ModuleType.SKILL),
        ("tool", ModuleType.TOOL),
        ("config", ModuleType.CONFIG),
    )

    def __init__(self, project_root: Path):
        self.project_root = project_root
        self.genesis_modules: Set[str] = set()

    def analyze_file(self, file_path: Path) -> ModuleInfo:
        """
        Analyze a single Python file.

        Args:
            file_path: path of the source file to parse.

        Returns:
            A ModuleInfo named after the file stem. If the file cannot be
            read or parsed, a stub with no imports, classes or functions and
            zero lines/complexity is returned instead of raising.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
            tree = ast.parse(content, filename=str(file_path))
        except (OSError, SyntaxError, ValueError, RecursionError):
            # Best effort: unreadable, undecodable (UnicodeDecodeError is a
            # ValueError) or unparsable files still get a node in the graph.
            return ModuleInfo(
                name=file_path.stem,
                path=self._relative_path(file_path),
                module_type=self._classify_module(file_path),
                imports=[],
                complexity=0
            )

        imports: Set[str] = set()
        classes = []
        functions = []

        # ast.walk visits every node, so nested classes, methods and imports
        # inside function bodies are all collected.
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                # Only the top-level package name is recorded.
                imports.update(alias.name.split('.')[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                if node.module:
                    imports.add(node.module.split('.')[0])
                elif node.level:
                    # "from . import x": there is no module name, so the
                    # imported names themselves may be sibling modules.
                    # Only names known to be project modules are recorded.
                    imports.update(
                        alias.name for alias in node.names
                        if alias.name in self.genesis_modules
                    )
            elif isinstance(node, ast.ClassDef):
                classes.append(node.name)
            elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                # Skip names with a single leading underscore; names that
                # start with a double underscore are kept.
                if not node.name.startswith('_') or node.name.startswith('__'):
                    functions.append(node.name)

        return ModuleInfo(
            name=file_path.stem,
            path=self._relative_path(file_path),
            module_type=self._classify_module(file_path),
            # Sorted so repeated runs produce identical output.
            imports=sorted(imports),
            classes=classes,
            functions=functions,
            # splitlines() does not count a phantom empty line after a
            # trailing newline and yields 0 for an empty file.
            lines=len(content.splitlines()),
            # Simple size-based estimate: classes weigh three times a function.
            complexity=len(classes) * 3 + len(functions)
        )

    def _relative_path(self, file_path: Path) -> str:
        """Return file_path relative to project_root, or unchanged if outside it."""
        try:
            return str(file_path.relative_to(self.project_root))
        except ValueError:
            return str(file_path)

    def _classify_module(self, file_path: Path) -> ModuleType:
        """
        Classify a module by keywords in its project-relative path.

        The path is taken relative to project_root so that directory names
        above the project (e.g. a checkout under ".../latest/") cannot
        influence the result. Matching is by case-insensitive substring;
        paths without any keyword are classified as CORE.
        """
        path_str = self._relative_path(file_path).lower()

        for keyword, module_type in self._TYPE_KEYWORDS:
            if keyword in path_str:
                return module_type
        return ModuleType.CORE


class DependencyGraph:
    """
    Builds and analyzes the import graph between project modules.

    Modules are keyed by file stem, so files sharing a stem share one node
    (the file analyzed last wins).

    Attributes:
        project_root: base directory that analyzed directories are resolved
            against.
        modules: module name -> ModuleInfo for every analyzed file.
        edges: module name -> names of analyzed modules it imports.
        reverse_edges: module name -> names of analyzed modules importing it.
        external_deps: imported names that match no analyzed module.
        analyzer: ImportAnalyzer used to parse individual files.
    """

    def __init__(self, project_root: Optional[Path] = None):
        """
        Args:
            project_root: base directory; defaults to the parent of the
                directory containing this file.
        """
        self.project_root = project_root or Path(__file__).parent.parent
        self.modules: Dict[str, ModuleInfo] = {}
        self.edges: Dict[str, Set[str]] = defaultdict(set)
        self.reverse_edges: Dict[str, Set[str]] = defaultdict(set)
        self.external_deps: Set[str] = set()
        self.analyzer = ImportAnalyzer(self.project_root)

    def build_from_directory(self, directory: str = "core"):
        """
        Analyze every ``*.py`` file under ``project_root / directory``.

        May be called repeatedly; edges are always resolved against all
        modules analyzed so far. A missing directory is ignored.
        """
        self._build_from_directories([directory])

    def build_full(self):
        """Build the complete graph from all genesis directories that exist."""
        self._build_from_directories(["core", "tools", "skills", "tests"])

    def _build_from_directories(self, directories: List[str]):
        """Analyze all files in the given directories, then resolve edges."""
        py_files: List[Path] = []
        for directory in directories:
            target_dir = self.project_root / directory
            if not target_dir.exists():
                continue
            # Sorted so that stem collisions resolve the same way on every run.
            py_files.extend(sorted(
                py_file for py_file in target_dir.rglob("*.py")
                if "__pycache__" not in py_file.parts
            ))

        if not py_files:
            return

        # Tell the analyzer every internal name before any file is parsed.
        self.analyzer.genesis_modules = (
            set(self.modules) | {py_file.stem for py_file in py_files}
        )

        for py_file in py_files:
            info = self.analyzer.analyze_file(py_file)
            self.modules[info.name] = info

        self._rebuild_edges()

    def _rebuild_edges(self):
        """
        Recompute edges, reverse edges and external deps from module imports.

        Resolving against the full module set (rather than one directory at
        a time) is what lets an import from one directory into another be
        recorded as an internal edge instead of an external dependency.
        """
        known = set(self.modules)
        self.edges.clear()
        self.reverse_edges.clear()
        self.external_deps.clear()

        for name, info in self.modules.items():
            for imp in info.imports:
                if imp in known:
                    self.edges.setdefault(name, set()).add(imp)
                    self.reverse_edges.setdefault(imp, set()).add(name)
                else:
                    self.external_deps.add(imp)

        for name, info in self.modules.items():
            info.imported_by = sorted(self.reverse_edges.get(name, ()))

    def detect_cycles(self) -> "List[CycleDependency]":
        """
        Detect circular dependencies with a depth-first search.

        One cycle is reported per back edge found, so overlapping cycles may
        not all be enumerated. Cycles of one or two modules are "critical",
        longer ones are "warning".

        Returns:
            The cycles found; empty if the graph is acyclic.
        """
        cycles = []
        visited: Set[str] = set()
        on_path: Set[str] = set()   # nodes on the current DFS path
        path: List[str] = []        # the same nodes, in visiting order

        def dfs(node: str) -> None:
            visited.add(node)
            on_path.add(node)
            path.append(node)

            # Sorted for deterministic results regardless of set ordering.
            for neighbor in sorted(self.edges.get(node, ())):
                if neighbor in on_path:
                    # Back edge: the cycle is the path segment from neighbor.
                    cycle_modules = path[path.index(neighbor):]
                    severity = "critical" if len(cycle_modules) <= 2 else "warning"
                    cycles.append(CycleDependency(
                        modules=cycle_modules,
                        severity=severity
                    ))
                elif neighbor not in visited:
                    dfs(neighbor)

            # Always unwind, even after a cycle was found, so state left
            # over from this branch cannot produce false cycles later.
            path.pop()
            on_path.remove(node)

        for module in self.modules:
            if module not in visited:
                dfs(module)

        return cycles

    def get_dependency_metrics(self) -> Dict:
        """
        Calculate dependency metrics.

        Returns:
            An empty dict when no modules are loaded; otherwise a dict with
            module/edge counts, the coupling score (edges per module), the
            five most imported and most importing modules, the layer of each
            module and up to 20 external dependency names.
        """
        if not self.modules:
            return {}

        total_modules = len(self.modules)
        total_edges = sum(len(deps) for deps in self.edges.values())

        # Fan-in (imported by many) and fan-out (imports many)
        fan_in = {name: len(self.reverse_edges.get(name, ())) for name in self.modules}
        fan_out = {name: len(self.edges.get(name, ())) for name in self.modules}

        most_imported = sorted(fan_in.items(), key=lambda x: -x[1])[:5]
        most_importing = sorted(fan_out.items(), key=lambda x: -x[1])[:5]

        return {
            "total_modules": total_modules,
            "total_dependencies": total_edges,
            "external_dependencies": len(self.external_deps),
            "coupling_score": round(total_edges / total_modules, 2),
            "most_imported": most_imported,
            "most_importing": most_importing,
            "layers": self._calculate_layers(),
            "external_deps": sorted(self.external_deps)[:20]
        }

    def _calculate_layers(self) -> Dict[str, int]:
        """
        Calculate module layers (topological depth).

        Layer 0 holds modules with no internal dependencies; each later
        layer holds modules whose dependencies all sit in earlier layers.
        If a cycle blocks progress, all remaining modules share the next
        layer number.
        """
        layers = {}
        remaining = set(self.modules.keys())

        layer = 0
        while remaining:
            # Modules whose dependencies have all been placed already.
            layer_modules = {
                m for m in remaining
                if not (self.edges.get(m, set()) & remaining)
            }

            if not layer_modules:
                # Cycle detected, assign remaining to max layer
                for m in remaining:
                    layers[m] = layer
                break

            for m in layer_modules:
                layers[m] = layer

            remaining -= layer_modules
            layer += 1

        return layers

    def get_module_tree(self, module_name: str, depth: int = 3) -> Dict:
        """
        Get the dependency tree for a specific module.

        Args:
            module_name: root of the tree.
            depth: maximum depth expanded below the root.

        Returns:
            Nested dicts with "name", "type" and "children". A node beyond
            the depth limit, or one already on the current branch, is a leaf
            marked ``"truncated": True`` without a "type" key.
        """
        def build_tree(name: str, current_depth: int, visited: Set[str]) -> Dict:
            if current_depth > depth or name in visited:
                return {"name": name, "children": [], "truncated": True}

            visited.add(name)
            # Each child gets its own copy so siblings do not block each other.
            children = [
                build_tree(dep, current_depth + 1, visited.copy())
                for dep in sorted(self.edges.get(name, ()))
            ]

            info = self.modules.get(name)
            # Names that were never analyzed are reported as external.
            type_value = (
                info.module_type.value if info is not None
                else ModuleType.EXTERNAL.value
            )
            return {"name": name, "type": type_value, "children": children}

        return build_tree(module_name, 0, set())

    def visualize_text(self) -> str:
        """
        Generate a text report: modules grouped by type, then any circular
        dependencies, then summary metrics.
        """
        lines = [
            "=" * 60,
            "GENESIS DEPENDENCY GRAPH",
            "=" * 60,
            ""
        ]

        by_type = defaultdict(list)
        for info in self.modules.values():
            by_type[info.module_type.value].append(info)

        for mtype in ["core", "tool", "skill", "test", "config"]:
            if mtype not in by_type:
                continue

            lines.append(f"\n[{mtype.upper()} MODULES]")
            lines.append("-" * 40)

            for info in sorted(by_type[mtype], key=lambda x: x.name):
                deps = self.edges.get(info.name, set())
                imported_by = self.reverse_edges.get(info.name, set())

                lines.append(f"\n  {info.name}.py")
                lines.append(f"    Lines: {info.lines} | Complexity: {info.complexity}")

                if deps:
                    lines.append(f"    Imports: {', '.join(sorted(deps))}")
                if imported_by:
                    lines.append(f"    Used by: {', '.join(sorted(imported_by))}")

        cycles = self.detect_cycles()
        if cycles:
            lines.append("\n" + "=" * 60)
            lines.append("CIRCULAR DEPENDENCIES DETECTED")
            lines.append("=" * 60)
            for cycle in cycles:
                lines.append(f"  [{cycle.severity.upper()}] {cycle}")

        # metrics is {} for an empty graph, hence the .get() defaults below.
        metrics = self.get_dependency_metrics()
        lines.extend([
            "",
            "=" * 60,
            "METRICS",
            "=" * 60,
            f"  Total Modules: {metrics.get('total_modules', 0)}",
            f"  Total Dependencies: {metrics.get('total_dependencies', 0)}",
            f"  External Dependencies: {metrics.get('external_dependencies', 0)}",
            f"  Coupling Score: {metrics.get('coupling_score', 0)}",
            "",
            "  Most Imported:",
        ])

        for name, count in metrics.get("most_imported", []):
            lines.append(f"    {name}: {count} importers")

        lines.append("\n  Most Dependencies:")
        for name, count in metrics.get("most_importing", []):
            lines.append(f"    {name}: {count} imports")

        return "\n".join(lines)

    def export_json(self) -> str:
        """Export modules, edges, metrics and cycles as a JSON string."""
        data = {
            "modules": {name: info.to_dict() for name, info in self.modules.items()},
            # Sets are emitted as sorted lists so the output is reproducible.
            "edges": {name: sorted(deps) for name, deps in self.edges.items()},
            "metrics": self.get_dependency_metrics(),
            "cycles": [{"modules": c.modules, "severity": c.severity} for c in self.detect_cycles()]
        }
        return json.dumps(data, indent=2)

    def export_dot(self) -> str:
        """Export the graph in DOT format for Graphviz."""
        lines = ["digraph Genesis {", '  rankdir=LR;', '  node [shape=box];', '']

        # Fill colour per module type. "lightorange" is not a Graphviz
        # colour name, so config modules use "moccasin" (a light orange).
        colors = {
            "core": "lightblue",
            "tool": "lightyellow",
            "skill": "lightgreen",
            "test": "lightgray",
            "config": "moccasin",
            "external": "white"
        }

        for name, info in self.modules.items():
            color = colors.get(info.module_type.value, "white")
            lines.append(f'  "{name}" [fillcolor="{color}" style="filled"];')

        lines.append('')

        # Sorted so repeated exports of the same graph are identical.
        for source in sorted(self.edges):
            for target in sorted(self.edges[source]):
                lines.append(f'  "{source}" -> "{target}";')

        lines.append('}')
        return '\n'.join(lines)


def main(argv: Optional[List[str]] = None):
    """
    CLI for the dependency graph.

    Args:
        argv: argument list to parse; defaults to ``sys.argv[1:]`` when None.

    Raises:
        SystemExit: with status 2 on invalid arguments, including the
            ``tree`` command without ``--module``.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Genesis Dependency Graph")
    parser.add_argument("command", choices=["analyze", "cycles", "metrics", "tree", "dot", "json"])
    parser.add_argument("--module", help="Module name for tree command")
    parser.add_argument("--dir", default="core", help="Directory to analyze")
    parser.add_argument("--full", action="store_true", help="Analyze all directories")
    args = parser.parse_args(argv)

    # Validate before building the graph so a usage error fails fast, goes
    # to stderr and yields a non-zero exit status.
    if args.command == "tree" and not args.module:
        parser.error("--module is required for the tree command")

    graph = DependencyGraph()

    if args.full:
        graph.build_full()
    else:
        graph.build_from_directory(args.dir)

    if args.command == "analyze":
        print(graph.visualize_text())

    elif args.command == "cycles":
        cycles = graph.detect_cycles()
        if cycles:
            print("Circular dependencies found:")
            for cycle in cycles:
                print(f"  [{cycle.severity}] {cycle}")
        else:
            print("No circular dependencies found.")

    elif args.command == "metrics":
        metrics = graph.get_dependency_metrics()
        print(json.dumps(metrics, indent=2, default=str))

    elif args.command == "tree":
        tree = graph.get_module_tree(args.module)
        print(json.dumps(tree, indent=2))

    elif args.command == "dot":
        print(graph.export_dot())

    elif args.command == "json":
        print(graph.export_json())

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
