import os
import json
from typing import Dict, List, Any

class PRDArchive:
    """File-backed archive of PRD documents with simple cross-PRD analysis.

    Each PRD is stored as ``<archive_dir>/<prd_id>.json``. A sibling
    ``index.json`` file maps every archived ``prd_id`` to the path of its
    JSON file; the analysis methods walk that index.
    """

    def __init__(self, archive_dir: str = 'archive'):
        """Bind the archive to *archive_dir*, creating it (and an empty
        index) when they do not exist yet.

        Args:
            archive_dir: Directory holding the PRD files and ``index.json``.
        """
        self.archive_dir = archive_dir
        self.index_file = os.path.join(archive_dir, 'index.json')
        self._ensure_archive_dir()

    def _ensure_archive_dir(self) -> None:
        """Create the archive directory and an empty index if missing."""
        os.makedirs(self.archive_dir, exist_ok=True)
        if not os.path.exists(self.index_file):
            self._save_index({})

    def _prd_path(self, prd_id: str) -> str:
        """Return the file path for *prd_id*, rejecting unsafe ids.

        Raises:
            ValueError: If the id contains a path component (and would thus
                be written outside the archive directory) or would resolve
                to the index file itself.
        """
        file_name = f'{prd_id}.json'
        # A name that is not its own basename carries a directory part
        # ('../x', 'a/b', ...) and must not be joined onto archive_dir.
        if os.path.basename(file_name) != file_name:
            raise ValueError(
                f'PRD id must not contain path separators: {prd_id!r}'
            )

        file_path = os.path.join(self.archive_dir, file_name)
        # Archiving under the id 'index' would overwrite index.json with PRD
        # content and corrupt the index on the next load.
        resolved = os.path.normcase(os.path.abspath(file_path))
        reserved = os.path.normcase(os.path.abspath(self.index_file))
        if resolved == reserved:
            raise ValueError(
                f'PRD id collides with the archive index file: {prd_id!r}'
            )
        return file_path

    @staticmethod
    def _write_json(path: str, data: Any) -> None:
        """Write *data* to *path* as indented JSON.

        The data is serialized before the file is opened, so a value that
        cannot be encoded raises without truncating an existing file.
        """
        text = json.dumps(data, indent=2)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(text)

    def archive(self, prd_id: str, prd_content: Dict[str, Any]) -> bool:
        """Store *prd_content* under *prd_id* and record it in the index.

        Archiving an id that already exists overwrites its file.

        Args:
            prd_id: Identifier used as the file stem and the index key.
            prd_content: JSON-serializable PRD document.

        Returns:
            ``True`` when the PRD was written and indexed.

        Raises:
            RuntimeError: On any failure (invalid id, serialization error,
                I/O error); the original exception is chained as the cause.
        """
        try:
            file_path = self._prd_path(prd_id)
            self._write_json(file_path, prd_content)

            index = self._load_index()
            index[prd_id] = file_path
            self._save_index(index)

            return True
        except Exception as e:
            raise RuntimeError(f'Failed to archive PRD: {str(e)}') from e

    def _iter_prds(self):
        """Yield the parsed content of every PRD listed in the index.

        Errors from opening or parsing an indexed file propagate to the
        caller unchanged.
        """
        for path in self._load_index().values():
            with open(path, 'r', encoding='utf-8') as f:
                yield json.load(f)

    def extract_learnings(self) -> Dict[str, Any]:
        """Aggregate achievements and requirement frequencies over all PRDs.

        Returns:
            A dict with two entries:

            * ``'key_achievements'``: list concatenating the
              ``key_achievements`` of every PRD that has that key.
            * ``'common_requirements'``: dict mapping each entry found in
              a PRD's ``requirements`` to the number of times it occurs
              across the archive. Entries are used as dict keys and so
              must be hashable (e.g. strings).
        """
        achievements: List[Any] = []
        requirement_counts: Dict[Any, int] = {}

        for prd in self._iter_prds():
            if 'key_achievements' in prd:
                achievements.extend(prd['key_achievements'])

            if 'requirements' in prd:
                for req in prd['requirements']:
                    requirement_counts[req] = requirement_counts.get(req, 0) + 1

        return {
            'key_achievements': achievements,
            'common_requirements': requirement_counts
        }

    def identify_patterns(self) -> Dict[str, Any]:
        """Report which top-level keys recur across the archived PRDs.

        Returns:
            A dict with three entries:

            * ``'common_keys'``: sorted list of every top-level key seen in
              any PRD.
            * ``'pattern_keys'``: keys present in at least 50% of the
              indexed PRDs.
            * ``'key_occurrence'``: dict mapping each key to the number of
              PRDs that contain it.
        """
        total = len(self._load_index())
        key_occurrence: Dict[str, int] = {}

        for prd in self._iter_prds():
            for key in prd.keys():
                key_occurrence[key] = key_occurrence.get(key, 0) + 1

        # "At least half" is compared as count * 2 >= total so that odd
        # totals are not rounded down (1 of 3 PRDs is not 50%).
        patterns = [
            key for key, count in key_occurrence.items()
            if count * 2 >= total
        ]

        return {
            'common_keys': sorted(key_occurrence),
            'pattern_keys': patterns,
            'key_occurrence': key_occurrence
        }

    def _load_index(self) -> Dict[str, str]:
        """Return the id-to-path index, or an empty dict if it is missing."""
        try:
            with open(self.index_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def _save_index(self, index: Dict[str, str]):
        """Persist *index* to the index file as indented JSON."""
        self._write_json(self.index_file, index)