# 05.01.26

import os
import json
import shutil
import logging
from typing import Dict


# External libraries
from rich.console import Console


# Internal utilities
from StreamingCommunity.utils.http_client import get_headers
from StreamingCommunity.core.processors import join_video, join_audios, join_subtitles
from StreamingCommunity.core.processors.helper.nfo import create_nfo
from StreamingCommunity.source.utils.tracker import download_tracker, context_tracker
from StreamingCommunity.source.utils.media_players import MediaPlayers
from StreamingCommunity.utils import config_manager, os_manager, internet_manager


# DRM Utilities
from StreamingCommunity.source.N_m3u8 import MediaDownloader
from StreamingCommunity.setup import get_wvd_path, get_prd_path
from ..parser import MPDParser, DRMSystem
from ..drm import DRMManager


# Config
console = Console()
CLEANUP_TMP = config_manager.config.get_bool('M3U8_DOWNLOAD', 'cleanup_tmp_folder')
EXTENSION_OUTPUT = config_manager.config.get("M3U8_CONVERSION", "extension")
SKIP_DOWNLOAD = config_manager.config.get_bool('M3U8_DOWNLOAD', 'skip_download')
CREATE_NFO_FILES = config_manager.config.get_bool('M3U8_CONVERSION', 'generate_nfo', default=False)


class DASH_Downloader:
    def __init__(self, license_url: str, license_headers: Dict[str, str] = None, mpd_url: str = None, mpd_headers: Dict[str, str] = None, mpd_sub_list: list = None, output_path: str = None, drm_preference: str = 'widevine', decrypt_preference: str = "bento4", key: str = None, cookies: Dict[str, str] = None):
        """
        Initialize DASH Downloader.
        
        Parameters:
            license_url: URL to obtain DRM license
            mpd_url: URL of the MPD manifest
            mpd_sub_list: List of subtitle dicts (unused with MediaDownloader)
            output_path: Full path including filename and extension (e.g., /path/to/video.mp4)
            drm_preference: Preferred DRM system ('widevine', 'playready', 'auto')
        """
        self.mpd_url = str(mpd_url).strip() if mpd_url else None
        self.license_url = str(license_url).strip() if license_url else None
        self.mpd_headers = mpd_headers or get_headers()
        self.license_headers = license_headers
        self.mpd_sub_list = mpd_sub_list or []
        self.drm_preference = drm_preference.lower()
        self.key = key
        self.cookies = cookies or {}
        self.decrypt_preference = decrypt_preference.lower()
        self.drm_manager = DRMManager(get_wvd_path(), get_prd_path(), config_manager.remote_cdm.get('remote_cdm', 'widevine'), config_manager.remote_cdm.get('remote_cdm', 'playready'))
        
        # Tracking IDs - check context if not provided
        self.download_id = context_tracker.download_id
        self.site_name = context_tracker.site_name
        self.raw_mpd_path = None
        
        # Setup output path
        self.output_path = os_manager.get_sanitize_path(output_path)
        if not self.output_path.endswith(f'.{EXTENSION_OUTPUT}'):
            self.output_path += f'.{EXTENSION_OUTPUT}'
        
        self.filename_base = os.path.splitext(os.path.basename(self.output_path))[0]
        self.output_dir = os.path.join(os.path.dirname(self.output_path), self.filename_base + "_dash_temp")
        self.file_already_exists = os.path.exists(self.output_path)
        
        # DRM and state
        self.drm_info = None
        self.decryption_keys = []
        self.kid_to_label = {}
        self.media_downloader = None
        self.meta_json = self.meta_selected = self.raw_mpd = None
        self.error = None
        self.last_merge_result = None
        self.media_players = None
    
    def _setup_drm_info(self, selected_ids, selected_kids, selected_langs, selected_periods):
        """Fetch and setup DRM information."""
        try:
            parser = MPDParser(self.mpd_url, headers=self.mpd_headers)
            parser.parse_from_file(self.raw_mpd)
            
            # Map KIDs to labels
            self._map_kids_to_labels(parser, selected_ids, selected_kids, selected_langs, selected_periods)
            
            # Get DRM info
            self.drm_info = parser.get_drm_info(
                self.drm_preference, selected_ids, selected_kids, 
                selected_langs, selected_periods
            )
            return True
        
        except Exception as e:
            console.print(f"[yellow]Warning parsing MPD: {e}")
            return False
    
    def _map_kids_to_labels(self, parser, selected_ids, selected_kids, selected_langs, selected_periods):
        """Map KIDs to descriptive labels."""
        self.kid_to_label = {}
        sets = parser.get_adaptation_sets_info(selected_ids, selected_kids, selected_langs, selected_periods)
        
        # Group by content type
        groups = {}
        for s in sets:
            if s['content_type'] in ('image', 'text'):
                continue
            groups.setdefault(s['content_type'], []).append(s)
        
        has_filter = any([selected_ids, selected_kids, selected_langs])
        
        for c_type, items in groups.items():
            is_uni = len({i['default_kid'] for i in items}) == 1 and not has_filter
            
            for item in items:
                if not item['default_kid']:
                    continue
                
                norm_kid = item['default_kid'].lower().replace('-', '')
                label = self._generate_label(item, c_type, is_uni)
                self.kid_to_label[norm_kid] = label
    
    def _generate_label(self, item, content_type, is_uniform):
        """Generate label for a stream."""
        if is_uniform:
            return f"all {content_type}"
        
        parts = [content_type]
        if item.get('height'):
            parts.append(f"{item['height']}p")
        if item.get('language') and item['language'] != 'N/A':
            parts.append(f"({item['language']})")
        
        return " ".join(parts)
    
    def _fetch_decryption_keys(self):
        """Fetch decryption keys based on DRM type."""
        if len(self.drm_info.get('available_drm_types', [])) > 0 and (self.license_url is None or self.license_url == "") or len(self.drm_info.get('available_drm_types', [])) > 0 and (self.key is None or self.key == ""):
            if (len(self.drm_info.get('available_drm_types', [])) > 0 and (not self.license_url or self.license_url == "") and (not self.key or self.key == "")):
                console.print("[yellow]DRM detected but missing both license_url and key. Cannot proceed.")
                self.error = "Missing license_url and key for DRM-protected content"
                return False
            
        drm_type = self.drm_info['selected_drm_type']
        try:
            if drm_type == DRMSystem.WIDEVINE:
                keys = self.drm_manager.get_wv_keys(self.drm_info.get('widevine_pssh', []), self.license_url, self.license_headers, self.key, self.kid_to_label)
            elif drm_type == DRMSystem.PLAYREADY:
                keys = self.drm_manager.get_pr_keys(self.drm_info.get('playready_pssh', []), self.license_url, self.license_headers, self.key, self.kid_to_label)
            else:
                console.print(f"[red]Unsupported DRM type: {drm_type}")
                self.error = f"Unsupported DRM type: {drm_type}"
                return False
        
            if keys:
                self.decryption_keys = keys
                return True
        
            else:
                self.error = "Failed to fetch decryption keys"
                return False
            
        except Exception as e:
            console.print(f"[red]Error fetching keys: {e}")
            self.error = f"Key fetch error: {e}"
            return False
    
    def _extract_selected_track_info(self):
        """Extract selected track information from metadata files."""
        selected_ids, selected_kids, selected_langs, selected_periods = [], [], [], []
        has_video_in_selected = False
        
        # 1. Process meta_selected first if it exists
        if os.path.exists(self.meta_selected):
            try:
                with open(self.meta_selected, "r", encoding="utf-8-sig") as f:
                    data = json.load(f)
                
                for item in data:
                    is_video = bool(item.get("Resolution") or item.get("MediaType") == "VIDEO")
                    if is_video:
                        has_video_in_selected = True
                    
                    # Extract IDs
                    self._extract_ids(item, selected_ids)
                    
                    # Extract language
                    if lang := item.get("Language"):
                        selected_langs.append(lang.lower())
                    
                    # Extract period ID
                    if pid := item.get("PeriodId"):
                        selected_periods.append(str(pid))
                    
                    # Extract KIDs from EncryptInfo
                    self._extract_kids_from_encryptinfo(item, selected_kids)
            except Exception as e:
                console.print(f"[yellow]Warning reading {self.meta_selected}: {e}")

        # 2. Process meta_json for best video ONLY if no video was found in meta_selected
        force_best = getattr(self.media_downloader, "force_best_video", False)
        
        if not has_video_in_selected and force_best and os.path.exists(self.meta_json):
            try:
                with open(self.meta_json, "r", encoding="utf-8-sig") as f:
                    data = json.load(f)
                
                best_video = self._find_best_video(data)
                for item in best_video:
                    self._extract_ids(item, selected_ids)
                    
                    # Extract language
                    if lang := item.get("Language"):
                        selected_langs.append(lang.lower())
                    
                    # Extract period ID
                    if pid := item.get("PeriodId"):
                        selected_periods.append(str(pid))
                    
                    # Extract KIDs from EncryptInfo
                    self._extract_kids_from_encryptinfo(item, selected_kids)
            except Exception as e:
                console.print(f"[yellow]Warning reading {self.meta_json}: {e}")
        
        # Remove duplicates
        return (list(dict.fromkeys(selected_ids)), list(dict.fromkeys(selected_kids)), list(dict.fromkeys(selected_langs)), list(dict.fromkeys(selected_periods)))
    
    def _find_best_video(self, data):
        """Find best video track based on bandwidth."""
        videos = [
            item for item in data 
            if (item.get("Resolution") or item.get("MediaType") == "VIDEO") 
            and item.get("Bandwidth")
        ]
        return [max(videos, key=lambda x: x.get("Bandwidth", 0))] if videos else []
    
    def _extract_ids(self, item, selected_ids):
        """Extract IDs from item, prioritizing specific ID over GroupId."""
        extracted = []
        tid = item.get("Id", "")
        if tid:
            tid_s = str(tid)
            selected_ids.append(tid_s)
            extracted.append(tid_s)
            if ":" in tid_s:
                part = tid_s.split(":")[-1]
                selected_ids.append(part)
                extracted.append(part)
            if "-" in tid_s:
                part = tid_s.split("-")[-1]
                selected_ids.append(part)
                extracted.append(part)

        elif gid := item.get("GroupId"):
            gid_s = str(gid)
            selected_ids.append(gid_s)
            extracted.append(gid_s)
    
    def _extract_kids_from_encryptinfo(self, item, selected_kids):
        """Extract KIDs from EncryptInfo in MediaInit or MediaSegments."""
        playlist = item.get("Playlist", {})
        for part in playlist.get("MediaParts", []):

            # Check MediaInit for KID (common in PlutoTV and others)
            if init := part.get("MediaInit"):
                if kid_val := init.get("EncryptInfo", {}).get("KID"):
                    selected_kids.append(kid_val.lower().replace("-", ""))
            
            # Check MediaSegments for KID
            for seg in part.get("MediaSegments", []):
                if kid_val := seg.get("EncryptInfo", {}).get("KID"):
                    selected_kids.append(kid_val.lower().replace("-", ""))
    
    def start(self):
        """Main execution flow for downloading DASH content."""
        if self.file_already_exists:
            console.print("[yellow]File already exists.")
            return self.output_path, False
        
        # Create output directory
        os_manager.create_path(self.output_dir)
        
        # Create media player ignore files
        try:
            self.media_players = MediaPlayers(self.output_dir)
            self.media_players.create()
        except Exception:
            pass
        
        # Initialize MediaDownloader
        self.media_downloader = MediaDownloader(
            url=self.mpd_url,
            output_dir=self.output_dir,
            filename=self.filename_base,
            headers=self.mpd_headers,
            cookies=self.cookies,
            decrypt_preference=self.decrypt_preference,
            download_id=self.download_id,
            site_name=self.site_name
        )
        
        if self.mpd_sub_list:
            self.media_downloader.external_subtitles = self.mpd_sub_list
        
        if self.download_id:
            download_tracker.update_status(self.download_id, "Parsing...")
            
        self.media_downloader.parser_stream()
        
        # Get metadata
        self.meta_json, self.meta_selected, _, self.raw_mpd = self.media_downloader.get_metadata()
        
        # Extract selected track info
        selected_info = self._extract_selected_track_info()
        
        # Fetch DRM info
        if not self._setup_drm_info(*selected_info):
            logging.error("Failed to fetch DRM info")
            if self.download_id:
                download_tracker.complete_download(self.download_id, success=False, error="DRM parsing failed")
            return None, True
        
        # Fetch decryption keys if DRM protected
        if self.drm_info and self.drm_info['available_drm_types']:
            if not self._fetch_decryption_keys():
                logging.error(f"Failed to fetch decryption keys: {self.error}")
                if self.download_id:
                    download_tracker.complete_download(self.download_id, success=False, error=self.error)
                return None, True
            
        if SKIP_DOWNLOAD:
            console.print("[yellow]Skipping download as per configuration.")
            return self.output_path, False
        
        # Set keys and start download
        if self.download_id:
            download_tracker.update_status(self.download_id, "downloading")
            
        self.media_downloader.set_key(self.decryption_keys if self.decryption_keys else None)
        status = self.media_downloader.start_download()
        
        # Check for cancellation
        if status.get('error') == 'cancelled':
            if self.download_id:
                download_tracker.complete_download(self.download_id, success=False, error="cancelled")
            return None, True

        # Check if any media was downloaded
        if self._no_media_downloaded(status):
            logging.error("No media downloaded")
            if self.download_id:
                download_tracker.complete_download(self.download_id, success=False, error="No media downloaded")
            return None, True
        
        # Merge files
        if self.download_id:
            download_tracker.update_status(self.download_id, "Muxing...")
            
        final_file = self._merge_files(status)
        if not final_file or not os.path.exists(final_file):
            if self.download_id and download_tracker.is_stopped(self.download_id):
                download_tracker.complete_download(self.download_id, success=False, error="cancelled")
                return None, True
                
            logging.error("Merge operation failed")
            if self.download_id:
                download_tracker.complete_download(self.download_id, success=False, error="Merge failed")
            return None, True
        
        # Move to final location if needed
        self._move_to_final_location(final_file)
        
        # Print summary and cleanup
        self._print_summary()
        
        if CREATE_NFO_FILES:
            create_nfo(self.output_path)
        if self.download_id:
            download_tracker.complete_download(self.download_id, success=True, path=os.path.abspath(self.output_path))
            
        if CLEANUP_TMP:
            shutil.rmtree(self.output_dir, ignore_errors=True)
        return self.output_path, False
    
    def _no_media_downloaded(self, status):
        """Check if no media was downloaded."""
        return (status.get('video') is None and status.get('audios') == [] and status.get('subtitles') == [] and status.get('external_subtitles') == [])
    
    def _move_to_final_location(self, final_file):
        """Move file to final output path."""
        if os.path.abspath(final_file) != os.path.abspath(self.output_path):
            try:
                if os.path.exists(self.output_path):
                    os.remove(self.output_path)
                os.rename(final_file, self.output_path)
            except Exception as e:
                console.print(f"[yellow]Warning: Could not move file: {e}")
                self.output_path = final_file
    
    def _merge_files(self, status):
        """Merge downloaded files using FFmpeg."""
        if not status or not status.get('video') or not status['video'].get('path'):
            console.print("[red]Error: Video track information missing")
            self.error = "Video track missing"
            return None

        video_path = status['video']['path']
        
        if not os.path.exists(video_path):
            console.print(f"[red]Video file not found: {video_path}")
            self.error = "Video file missing"
            return None
        
        # If no additional tracks, just mux video
        if not status['audios'] and not status['subtitles']:
            console.print("[cyan]\nNo additional tracks to merge, muxing video...")
            merged_file, result_json = join_video(
                video_path=video_path,
                out_path=self.output_path
            )
            self.last_merge_result = result_json
            return merged_file if os.path.exists(merged_file) else None
        
        current_file = video_path
        
        # Merge audio tracks
        if status['audios']:
            current_file = self._merge_audio_tracks(current_file, status['audios'])
        
        # Merge subtitle tracks
        if status['subtitles']:
            current_file = self._merge_subtitle_tracks(current_file, status['subtitles'])
        
        return current_file
    
    def _merge_audio_tracks(self, current_file, audio_tracks):
        """Merge audio tracks with current video."""
        console.print(f"[cyan]\nMerging [red]{len(audio_tracks)} [cyan]audio track(s)...")
        audio_output = os.path.join(self.output_dir, f"{self.filename_base}_with_audio.{EXTENSION_OUTPUT}")
        
        merged_file, _, result_json = join_audios(
            video_path=current_file,
            audio_tracks=audio_tracks,
            out_path=audio_output
        )
        self.last_merge_result = result_json
        
        if os.path.exists(merged_file):
            return merged_file
        else:
            console.print("[yellow]Audio merge failed, continuing with video only")
            return current_file
    
    def _merge_subtitle_tracks(self, current_file, subtitle_tracks):
        """Merge subtitle tracks with current video."""
        console.print(f"[cyan]\nMerging [red]{len(subtitle_tracks)} [cyan]subtitle track(s)...")
        sub_output = os.path.join(self.output_dir, f"{self.filename_base}_final.{EXTENSION_OUTPUT}")
        
        merged_file, result_json = join_subtitles(
            video_path=current_file,
            subtitles_list=subtitle_tracks,
            out_path=sub_output
        )
        self.last_merge_result = result_json
        
        if os.path.exists(merged_file):
            return merged_file
        else:
            console.print("[yellow]Subtitle merge failed, continuing without subtitles")
            return current_file
    
    def _print_summary(self):
        """Print download summary."""
        if not os.path.exists(self.output_path):
            return
        
        file_size = internet_manager.format_file_size(os.path.getsize(self.output_path))
        duration = 'N/A'
        
        if self.last_merge_result and isinstance(self.last_merge_result, dict):
            duration = self.last_merge_result.get('time', 'N/A')
        
        console.print("\n[green]Output:")
        console.print(f"  [cyan]Path: [red]{os.path.abspath(self.output_path)}")
        console.print(f"  [cyan]Size: [red]{file_size}")
        console.print(f"  [cyan]Duration: [red]{duration}")