"""
Enhanced Consumption Extractor
Wraps SmartConsumptionExtractor with additional regex patterns for:
- Time-interval dosing (every X hours)
- PRN / SOS / as-needed medications
- Morning/evening patterns
- Alternate-day dosing
- Extended meal-timing patterns
- Better duration detection (lifelong, complete course)
- Multi-ingredient quantity extraction
- Improved confidence scoring
"""

import re
from typing import Dict

from src.nlp.smart_consumption_extractor import SmartConsumptionExtractor, get_smart_extractor
from src.utils.logger import get_logger

# Module-level logger, obtained from the project's logger factory under this module's name.
logger = get_logger(__name__)


class EnhancedExtractor:
    """
    Enhanced extraction layer that calls the base extractor returned by
    ``get_smart_extractor()`` first, then applies additional regex patterns
    to refine dosage, meal timing, duration, quantity and confidence.
    """

    # Maps a half-open bracket [lo, hi) on the *midpoint* of the stated
    # interval (in hours) to a dosage code.  Midpoints that fall in no
    # bracket (e.g. 9-10h, 13-24h, >= 25h) are handled by the threshold
    # fallback in ``_enhance_dosage``.
    INTERVAL_TO_CODE = {
        (3, 4): 'QDS',     # midpoint 3 to <4 hours  -> closest standard = QDS
        (4, 5): 'QDS',     # midpoint 4 to <5 hours
        (5, 6): 'QDS',     # midpoint 5 to <6 hours
        (6, 7): 'QDS',     # midpoint 6 to <7 hours  -> ~4x/day
        (7, 8): 'TDS',     # midpoint 7 to <8 hours  -> ~3x/day
        (8, 9): 'TDS',     # midpoint 8 to <9 hours  -> 3x/day
        (10, 12): 'BD',    # midpoint 10 to <12 hours -> ~2x/day
        (12, 13): 'BD',    # midpoint 12 to <13 hours -> 2x/day
        (24, 25): 'OD',    # midpoint 24 to <25 hours -> 1x/day
    }

    # Time-interval regex: "every 4 hours", "every 6-8 hours", "every 4-6 hrs"
    RE_INTERVAL = re.compile(
        r'every\s+(\d+)(?:\s*(?:to|-)\s*(\d+))?\s*(?:hours?|hrs?)',
        re.IGNORECASE,
    )

    # PRN / as-needed patterns
    RE_PRN = re.compile(
        r'\b(?:as\s+needed|when\s+required|if\s+needed|prn|sos|on\s+demand|as\s+required)\b',
        re.IGNORECASE,
    )

    # Morning/evening combination -> BD
    RE_MORNING_EVENING = re.compile(
        r'morning\s+(?:and|&)\s+(?:evening|night)',
        re.IGNORECASE,
    )

    # Alternate-day patterns
    RE_ALTERNATE = re.compile(
        r'\b(?:alternate\s+days?|every\s+other\s+day|every\s+2(?:nd)?\s+day)\b',
        re.IGNORECASE,
    )

    # Extra meal-timing patterns; the first pattern that matches wins, so the
    # list order is significant.
    EXTRA_MEAL_PATTERNS = [
        ('Empty Stomach', re.compile(
            r'\b(?:on\s+(?:an\s+)?empty\s+stomach|with\s+water\s+on\s+empty\s+stomach'
            r'|on\s+waking|first\s+thing\s+in\s+the\s+morning)\b', re.IGNORECASE)),
        ('Before Meal', re.compile(
            r'\b(?:before\s+breakfast|before\s+lunch|before\s+dinner'
            r'|30\s+min(?:utes?)?\s+before|with\s+first\s+bite)\b', re.IGNORECASE)),
        ('After Meal', re.compile(
            r'\b(?:after\s+breakfast|after\s+lunch|after\s+dinner'
            r'|immediately\s+after\s+(?:food|meal|eating))\b', re.IGNORECASE)),
        ('Between Meals', re.compile(
            r'\b(?:between\s+meals|in\s+between\s+meals|not\s+with\s+(?:food|meals))\b',
            re.IGNORECASE)),
        ('At Bedtime', re.compile(
            r'\b(?:at\s+(?:bed\s*time|night)|before\s+(?:sleep|sleeping|bed)'
            r'|h\.?s\.?|hora\s+somni)\b', re.IGNORECASE)),
    ]

    # Duration extras
    RE_LIFELONG = re.compile(
        r'\b(?:lifelong|life[\s-]?long|indefinitely|continuously|long[\s-]?term)\b',
        re.IGNORECASE,
    )
    RE_COMPLETE_COURSE = re.compile(
        r'\b(?:complete\s+(?:the\s+)?course|until\s+finished|finish\s+(?:the\s+)?course'
        r'|do\s+not\s+stop\s+early)\b',
        re.IGNORECASE,
    )

    # Multi-ingredient quantity: "Amoxicillin 500mg + Clavulanic Acid 125mg"
    # (not referenced by the methods of this class; kept as a class attribute)
    RE_MULTI_QTY = re.compile(
        r'(\d+\.?\d*)\s*(mg|mcg|g|ml|iu|%)(?:\s*(?:\+|/)\s*\w[\w\s]*?(\d+\.?\d*)\s*(mg|mcg|g|ml|iu|%))+',
        re.IGNORECASE,
    )
    # Simpler: find all "number unit" pairs
    RE_ALL_QTY = re.compile(r'(\d+\.?\d*)\s*(mg|mcg|g|ml|iu|%)', re.IGNORECASE)

    # Dose range in instructions: "1-2 tablets", "5-10 ml"
    # (not referenced by the methods of this class; kept as a class attribute)
    RE_DOSE_RANGE = re.compile(
        r'(\d+)(?:\s*(?:to|-)\s*(\d+))?\s*(tablet|capsule|pill|ml|drop|puff|sachet|spoon)s?',
        re.IGNORECASE,
    )

    def __init__(self):
        """Bind the base extractor returned by ``get_smart_extractor()``."""
        self._base = get_smart_extractor()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract_enhanced(self, medicine_record, enable_validation=True, use_ai_fallback=True) -> Dict:
        """
        Run the base extractor, then enhance its result with additional
        patterns.  Returns the same dict the base extractor produced,
        updated in place.

        Args:
            medicine_record: Object exposing ``how_to_use``,
                ``salt_composition`` and ``product_name`` attributes.
            enable_validation: Forwarded unchanged to the base extractor.
            use_ai_fallback: Forwarded unchanged to the base extractor.

        Returns:
            The result dict with ``dosage``, ``frequency``,
            ``dosage_explicit``, ``meal_preference``, ``duration``,
            ``quantity``, ``extraction_method``, ``confidence_score`` and
            ``needs_verification`` possibly rewritten.
        """
        result = self._base.extract_from_database_record(
            medicine_record,
            enable_validation=enable_validation,
            use_ai_fallback=use_ai_fallback,
        )

        # Instructions are lower-cased once; composition keeps its case.
        how_to_use = str(medicine_record.how_to_use or "").lower()
        composition = str(medicine_record.salt_composition or "")

        changed = False

        # --- 1. Enhance dosage code ---
        new_dosage = self._enhance_dosage(how_to_use, result['dosage'])
        if new_dosage != result['dosage']:
            result['dosage'] = new_dosage
            result['frequency'] = self._base._map_dosage_to_frequency(new_dosage)
            result['dosage_explicit'] = True  # Pattern-matched by enhanced logic
            changed = True

        # --- 2. SOS / PRN detection (overrides any interval-based code) ---
        if self.RE_PRN.search(how_to_use):
            result['dosage'] = 'SOS'
            result['frequency'] = 'As needed (SOS)'
            result['dosage_explicit'] = True  # Explicitly detected
            changed = True

        # --- 3. Enhance meal preference (only when the base left the default) ---
        if result['meal_preference'] == 'As advised':
            new_meal = self._enhance_meal(how_to_use)
            if new_meal:
                result['meal_preference'] = new_meal
                changed = True

        # --- 4. Enhance duration (only when the base left the default) ---
        if result['duration'] == 'As prescribed':
            new_dur = self._enhance_duration(how_to_use)
            if new_dur:
                result['duration'] = new_dur
                changed = True

        # --- 5. Enhance quantity (multi-ingredient) ---
        new_qty = self._enhance_quantity(composition)
        if new_qty and new_qty != result['quantity']:
            result['quantity'] = new_qty
            changed = True

        # --- 6. Relabel the method BEFORE scoring ---
        # _recalculate_confidence grants a bonus for 'enhanced_regex'; doing
        # the relabel afterwards would make that bonus unreachable here.
        if changed and result.get('extraction_method') == 'regex':
            result['extraction_method'] = 'enhanced_regex'

        # --- 7. Recalculate confidence ---
        # Skip recalculation if Gemini already provided a high-confidence result
        if result.get('extraction_method') != 'gemini_ai':
            result['confidence_score'] = self._recalculate_confidence(
                how_to_use, composition, result
            )
        result['needs_verification'] = result['confidence_score'] < 0.80

        if changed:
            logger.debug(
                f"Enhanced extraction for {medicine_record.product_name}: "
                f"dosage={result['dosage']}, meal={result['meal_preference']}"
            )

        return result

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _enhance_dosage(self, how_to_use: str, current_code: str) -> str:
        """
        Derive a dosage code from morning/evening, alternate-day or
        "every X hours" wording.

        Args:
            how_to_use: Usage instructions text.
            current_code: Dosage code to return when no pattern matches.

        Returns:
            'BD' for morning-and-evening wording, 'OD' for alternate-day
            wording, the interval-derived code for "every X hours", or
            ``current_code`` unchanged when nothing matches.
        """
        # Morning and evening -> BD
        if self.RE_MORNING_EVENING.search(how_to_use):
            return 'BD'

        # Alternate days -> OD (alternate)
        if self.RE_ALTERNATE.search(how_to_use):
            return 'OD'

        # Time-interval: "every X hours" / "every X-Y hours"
        m = self.RE_INTERVAL.search(how_to_use)
        if not m:
            return current_code

        lo = int(m.group(1))
        hi = int(m.group(2)) if m.group(2) else lo
        mid = (lo + hi) / 2.0

        # A table hit is authoritative.  Returning here (rather than comparing
        # the hit against current_code) keeps the result independent of the
        # code the caller happened to pass in.
        for (bracket_lo, bracket_hi), code in self.INTERVAL_TO_CODE.items():
            if bracket_lo <= mid < bracket_hi:
                return code

        # Fallback heuristic for midpoints not covered by the table
        if mid <= 6:
            return 'QDS'
        if mid <= 8:
            return 'TDS'
        if mid <= 12:
            return 'BD'
        return 'OD'

    def _enhance_meal(self, how_to_use: str) -> str:
        """Return the first matching extra meal-timing label, or "" if none match."""
        for preference, pattern in self.EXTRA_MEAL_PATTERNS:
            if pattern.search(how_to_use):
                return preference
        return ""

    def _enhance_duration(self, how_to_use: str) -> str:
        """
        Check lifelong / complete-course wording.

        Returns 'Long-term', 'Complete the course', or "" when neither
        pattern matches; the lifelong pattern is checked first.
        """
        if self.RE_LIFELONG.search(how_to_use):
            return 'Long-term'
        if self.RE_COMPLETE_COURSE.search(how_to_use):
            return 'Complete the course'
        return ""

    def _enhance_quantity(self, composition: str) -> str:
        """
        Extract multi-ingredient quantities like '500 mg + 125 mg'.

        Returns "" when the composition is empty, is the literal 'nan'
        (case-insensitive), or contains fewer than two "number unit" pairs.
        """
        if not composition or composition.lower() == 'nan':
            return ""
        matches = self.RE_ALL_QTY.findall(composition)
        if len(matches) >= 2:
            return " + ".join(f"{val} {unit}" for val, unit in matches)
        return ""

    def _recalculate_confidence(self, how_to_use: str, composition: str, result: Dict) -> float:
        """
        Improved confidence scoring with harder penalties for vague instructions
        and bonuses for specific pattern matches.

        Args:
            how_to_use: Usage instructions text.
            composition: Salt composition text.
            result: Extraction result; reads ``dosage``, ``dosage_explicit``,
                ``meal_preference``, ``duration`` and ``extraction_method``.

        Returns:
            Score clamped to [0.0, 1.0] and rounded to two decimals.
        """
        score = 0.0

        # --- Data availability ---
        if how_to_use and len(how_to_use) > 20:
            score += 0.25
        elif how_to_use and len(how_to_use) > 5:
            score += 0.10

        if composition and composition.lower() != 'nan':
            score += 0.15

        # --- Extraction specificity ---
        dosage = result.get('dosage', '')
        dosage_explicit = result.get('dosage_explicit', False)
        if dosage in ('OD', 'BD', 'TDS', 'QDS', 'SOS'):
            if dosage_explicit:
                score += 0.25
            else:
                score += 0.05  # Small credit for type-based guess

        meal = result.get('meal_preference', '')
        if meal not in ('As advised', ''):
            score += 0.15

        duration = result.get('duration', '')
        if duration not in ('As prescribed', ''):
            score += 0.10

        # --- Bonuses ---
        if result.get('extraction_method') in ('gemini_ai', 'enhanced_regex'):
            score += 0.05

        # --- Penalties ---
        text = how_to_use.lower()
        if 'as directed' in text or 'as advised' in text:
            score -= 0.15
        if 'consult' in text and 'physician' in text:
            score -= 0.10
        if not how_to_use or len(how_to_use.strip()) < 5:
            score -= 0.20

        return round(min(max(score, 0.0), 1.0), 2)


# ------------------------------------------------------------------
# Singleton
# ------------------------------------------------------------------

# Lazily created module-wide EnhancedExtractor; None until first requested.
_enhanced_instance = None


def get_enhanced_extractor() -> EnhancedExtractor:
    """Return the shared EnhancedExtractor, creating it on the first call."""
    global _enhanced_instance
    if _enhanced_instance is not None:
        return _enhanced_instance

    # First call: build the instance and record that initialization happened.
    _enhanced_instance = EnhancedExtractor()
    logger.info("Enhanced extractor initialized (wraps SmartConsumptionExtractor)")
    return _enhanced_instance
