"""
Drug-Drug Interaction (DDI) Checker
Checks for interactions between medicines using normalized salt pairs.
DB-first approach: uses salt_composition and interaction fields,
then checks DDICache (normalized salt pairs), then falls back to Gemini AI.
"""

import math
import re
from itertools import combinations, product
from typing import Dict, List, Optional, Tuple

from src.utils.logger import get_logger

logger = get_logger(__name__)


class DDIChecker:
    """
    Check drug-drug interactions between medicines.
    Uses normalized individual salt pairs for efficient caching.
    e.g. Augmentin (amoxicillin + clavulanate) vs Crocin (paracetamol)
    checks: (amoxicillin, paracetamol) and (clavulanate, paracetamol)

    For each medicine pair the checks run in this order, stopping at the first
    one that yields a verdict: duplicate ingredients, the medicines' own
    ``interaction`` text, the DDICache table, then Gemini (when enabled).
    """

    DDI_SALT_PROMPT = """You are a clinical pharmacology expert. Check for drug-drug interactions between these pharmaceutical salts/ingredients.

Salt A: {salt_a}
Salt B: {salt_b}

Return ONLY a JSON object:
{{
    "has_interaction": true or false,
    "severity": "none" or "mild" or "moderate" or "severe" or "contraindicated",
    "description": "brief clinical description of the interaction between these two ingredients",
    "recommendation": "what the patient/doctor should do",
    "confidence": 0.0 to 1.0
}}

Rules:
- severity "none" means no clinically significant interaction
- severity "mild" means minor interaction, usually safe with monitoring
- severity "moderate" means may require dose adjustment or monitoring
- severity "severe" means should generally be avoided
- severity "contraindicated" means must never be used together
- If unsure, err on the side of caution and recommend consulting a doctor
- Return ONLY the JSON object, no markdown or explanation"""

    BATCH_DDI_SALT_PROMPT = """You are a clinical pharmacology expert. Check for drug-drug interactions between these pairs of pharmaceutical salts/ingredients.

{pairs_text}

Return ONLY a JSON array where each element corresponds to one pair:
[
  {{
    "pair_index": 0,
    "has_interaction": true or false,
    "severity": "none" or "mild" or "moderate" or "severe" or "contraindicated",
    "description": "brief clinical description",
    "recommendation": "what patient/doctor should do",
    "confidence": 0.0 to 1.0
  }}
]

Rules:
- Return one object per pair in the same order
- severity levels: none, mild, moderate, severe, contraindicated
- If unsure, err on the side of caution
- Return ONLY the JSON array, no markdown or explanation"""

    SEVERITY_RANK = {
        "none": 0,
        "mild": 1,
        "moderate": 2,
        "severe": 3,
        "contraindicated": 4,
    }

    # Parenthetical dosages such as "(500mg)" or "(125mg/5ml)".
    _PAREN_RE = re.compile(r'\([^)]*\)')
    # Standalone dosages such as "500mg", "10ml" or "1%". The trailing
    # ``(?!\w)`` replaces a ``\b``: ``\b`` never matches between "%" and a
    # space/end of string, so percentage strengths were previously left in.
    _DOSE_RE = re.compile(
        r'\b\d+(\.\d+)?\s*(mg|g|mcg|ml|iu|%|units?)(?!\w)', re.IGNORECASE
    )
    # Separators left dangling once dosages are removed, e.g. "x /" from "x 125mg/5ml".
    _EDGE_PUNCT = " \t\r\n,;/-"

    # Keyword groups checked in priority order by _infer_severity_from_text.
    _SEVERITY_KEYWORDS = (
        ("contraindicated", ("contraindicated", "never", "must not", "do not use together", "fatal", "life-threatening")),
        ("severe", ("severe", "serious", "dangerous", "avoid", "high risk", "toxic")),
        ("moderate", ("moderate", "caution", "monitor", "adjust dose", "dose adjustment", "may increase", "may decrease")),
        ("mild", ("mild", "minor", "slight", "unlikely", "minimal")),
    )

    # ------------------------------------------------------------------ #
    # Parsing / normalization
    # ------------------------------------------------------------------ #

    def parse_salt_composition(self, composition: str) -> List[str]:
        """
        Parse salt composition string into normalized ingredient names.
        e.g. "Amoxicillin (500mg) + Potassium Clavulanate (125mg)" -> ["amoxicillin", "potassium clavulanate"]

        Dosages (parenthetical or standalone) are removed, internal whitespace
        is collapsed and names are lower-cased. Empty input returns [].
        """
        if not composition:
            return []

        salts = []
        for part in composition.split('+'):
            cleaned = self._PAREN_RE.sub('', part)
            cleaned = self._DOSE_RE.sub('', cleaned)
            # Collapse whitespace left by removed dosages so that e.g.
            # "Amoxicillin (500mg) Trihydrate" yields a stable cache key.
            cleaned = ' '.join(cleaned.split()).strip(self._EDGE_PUNCT).lower()
            if cleaned:
                salts.append(cleaned)

        return salts

    def _normalize_salt_pair(self, salt_a: str, salt_b: str) -> Tuple[str, str]:
        """Return alphabetically sorted salt pair for consistent cache keys."""
        a = salt_a.strip().lower()
        b = salt_b.strip().lower()
        return (a, b) if a <= b else (b, a)

    def _unique_salt_pairs(self, salts_a: List[str], salts_b: List[str]) -> List[Tuple[str, str]]:
        """Return every normalized cross-salt pair exactly once, in product order."""
        return list(dict.fromkeys(
            self._normalize_salt_pair(sa, sb) for sa, sb in product(salts_a, salts_b)
        ))

    def _severity_rank(self, severity: Optional[str]) -> int:
        """Rank a severity label; unknown or missing labels rank as 0."""
        return self.SEVERITY_RANK.get(severity, 0)

    @staticmethod
    def _build_result(
        comp_a: str, comp_b: str, severity: str, description: str,
        recommendation: str, source: str, confidence: float,
    ) -> Dict:
        """Assemble the medicine-pair result dict returned to callers."""
        return {
            "salt_a": comp_a,
            "salt_b": comp_b,
            "severity": severity,
            "description": description,
            "recommendation": recommendation,
            "source": source,
            "confidence": confidence,
        }

    # ------------------------------------------------------------------ #
    # Public entry point
    # ------------------------------------------------------------------ #

    def check_interactions(
        self,
        medicines: List[Dict],
        use_gemini_fallback: bool = False,
        db_session=None,
    ) -> List[Dict]:
        """
        Check DDI for all pairs of medicines.
        Breaks each medicine into individual salts and checks all cross-salt pairs.

        Args:
            medicines: medicine dicts; reads "product_name", "salt_composition"
                and "interaction" keys (missing keys are tolerated).
            use_gemini_fallback: allow Gemini lookups for unresolved pairs.
            db_session: optional session used for DDICache lookups/saves.

        Returns:
            One result dict per unordered medicine pair, with "medicine_a" and
            "medicine_b" set to the products' names. Empty for < 2 medicines.
        """
        if not medicines or len(medicines) < 2:
            return []

        results = []
        for med_a, med_b in combinations(medicines, 2):
            result = self._check_medicine_pair(med_a, med_b, use_gemini_fallback, db_session)
            result["medicine_a"] = med_a.get("product_name")
            result["medicine_b"] = med_b.get("product_name")
            results.append(result)

        return results

    def _check_medicine_pair(
        self, med_a: Dict, med_b: Dict, use_gemini_fallback: bool, db_session
    ) -> Dict:
        """
        Check a single medicine pair by examining all cross-salt pairs.

        Returns a result dict (see _build_result); "salt_a"/"salt_b" hold the
        raw composition strings of the two medicines.
        """
        comp_a = med_a.get("salt_composition", "") or ""
        comp_b = med_b.get("salt_composition", "") or ""
        interaction_a = med_a.get("interaction", "") or ""
        interaction_b = med_b.get("interaction", "") or ""
        label = f"{med_a.get('product_name')} + {med_b.get('product_name')}"

        salts_a = self.parse_salt_composition(comp_a)
        salts_b = self.parse_salt_composition(comp_b)
        have_salts = bool(salts_a and salts_b)

        # Check 1: Duplicate therapy (overlapping salts). Sorted so the
        # description is deterministic (set iteration order is not).
        overlap = sorted(set(salts_a) & set(salts_b))
        if overlap:
            return self._build_result(
                comp_a, comp_b, "moderate",
                f"Duplicate therapy detected. Both medicines contain: {', '.join(overlap)}. "
                f"Taking both may lead to overdose of the shared ingredient(s).",
                "Consult your doctor. Do not take both medicines without medical advice as this may cause an overdose.",
                "database", 0.95,
            )

        # Check 2: Search interaction text fields for cross-references
        description = (
            self._search_interaction_text(interaction_a, salts_b)
            or self._search_interaction_text(interaction_b, salts_a)
        )
        if description:
            return self._build_result(
                comp_a, comp_b, self._infer_severity_from_text(description), description,
                "Consult your doctor before taking these medicines together.",
                "database", 0.8,
            )

        # Check 3: If we had interaction text but no match, it's likely safe.
        # NOTE(review): this returns before the DDICache/Gemini checks below
        # whenever either medicine has any interaction text, even when only one
        # side's text was available to search — confirm this is intended.
        if have_salts and (interaction_a or interaction_b):
            return self._build_result(
                comp_a, comp_b, "none",
                "No known interaction found between these medicines based on available data.",
                "Generally safe to take together. Follow your doctor's instructions.",
                "database", 0.7,
            )

        # Check 4: Check DDICache for individual salt pairs
        if have_salts and db_session:
            cached_result = self._check_salt_pairs_in_cache(
                salts_a, salts_b, comp_a, comp_b, db_session
            )
            if cached_result:
                logger.info(
                    f"DDI [{label}]: "
                    f"source=ddi_cache_table | severity={cached_result.get('severity')}"
                )
                return cached_result

        # Check 5: Gemini fallback for individual salt pairs
        if have_salts and use_gemini_fallback:
            gemini_result = self._check_salt_pairs_via_gemini(
                salts_a, salts_b, comp_a, comp_b, db_session
            )
            if gemini_result:
                logger.info(
                    f"DDI [{label}]: "
                    f"source=gemini_ai | severity={gemini_result.get('severity')}"
                )
                return gemini_result

        # No data available
        if not have_salts:
            return self._build_result(
                comp_a, comp_b, "unknown",
                "Insufficient data to determine interaction. Consult a healthcare provider.",
                "Consult your doctor or pharmacist about taking these medicines together.",
                "insufficient_data", 0.0,
            )

        return self._build_result(
            comp_a, comp_b, "unknown",
            "Unable to determine interaction. Consult a healthcare provider.",
            "Consult your doctor or pharmacist about taking these medicines together.",
            "insufficient_data", 0.0,
        )

    # ------------------------------------------------------------------ #
    # Aggregation of per-salt-pair results
    # ------------------------------------------------------------------ #

    def _summarize_pair_results(
        self, pair_results: List[Dict], comp_a: str, comp_b: str,
        source: str, unresolved: int = 0,
    ) -> Optional[Dict]:
        """
        Collapse per-salt-pair results into one medicine-pair verdict.

        Args:
            pair_results: dicts with severity/description/recommendation/confidence.
            source: value for the result's "source" field.
            unresolved: number of salt pairs that could not be checked.

        Returns:
            The highest-severity result (first one wins on ties). When no
            interaction was found, a "none" verdict — unless some pairs were
            unresolved, in which case None (absence of a hit among the checked
            pairs is not evidence the combination is safe). None when
            pair_results is empty.
        """
        if not pair_results:
            return None

        worst = max(pair_results, key=lambda r: self._severity_rank(r.get("severity")))
        if self._severity_rank(worst.get("severity")) > 0:
            return self._build_result(
                comp_a, comp_b, worst["severity"],
                worst.get("description", ""),
                worst.get("recommendation", "Consult your doctor."),
                source,
                worst.get("confidence", 0.7),
            )

        if unresolved:
            return None

        return self._build_result(
            comp_a, comp_b, "none",
            "No known interaction found between these ingredients.",
            "Generally safe to take together. Follow your doctor's instructions.",
            source, 0.8,
        )

    # ------------------------------------------------------------------ #
    # DDICache access
    # ------------------------------------------------------------------ #

    def _fetch_cached_pair(self, db_session, norm_a: str, norm_b: str):
        """Return the DDICache row for an already-normalized pair, or None. May raise."""
        from src.database.models import DDICache

        return db_session.query(DDICache).filter(
            DDICache.salt_a == norm_a,
            DDICache.salt_b == norm_b,
        ).first()

    def _cached_row_to_result(self, row) -> Optional[Dict]:
        """
        Convert a DDICache row into a per-pair result dict.

        Returns None when the row's severity is missing or not a SEVERITY_RANK
        key, so callers treat the pair as unresolved instead of as "none".
        """
        if row.severity not in self.SEVERITY_RANK:
            return None
        return {
            "severity": row.severity,
            "description": row.description or "",
            "recommendation": row.recommendation or "",
            "confidence": row.confidence or 0.0,
        }

    def _check_salt_pairs_in_cache(
        self, salts_a: List[str], salts_b: List[str],
        comp_a: str, comp_b: str, db_session
    ) -> Optional[Dict]:
        """
        Check DDICache for all cross-salt pairs between two medicines.
        Returns the highest-severity interaction found, or None unless every
        pair has a usable cache entry (a partial answer is not returned).
        """
        pair_results = []
        try:
            for norm_a, norm_b in self._unique_salt_pairs(salts_a, salts_b):
                row = self._fetch_cached_pair(db_session, norm_a, norm_b)
                result = self._cached_row_to_result(row) if row is not None else None
                if result is None:
                    # At least one pair not cached (or unusable) — can't give complete answer
                    return None
                pair_results.append(result)
        except Exception as e:
            logger.debug(f"DDICache salt pair lookup failed: {e}")
            return None

        return self._summarize_pair_results(pair_results, comp_a, comp_b, "ddi_cache")

    # ------------------------------------------------------------------ #
    # Gemini fallback
    # ------------------------------------------------------------------ #

    def _check_salt_pairs_via_gemini(
        self, salts_a: List[str], salts_b: List[str],
        comp_a: str, comp_b: str, db_session
    ) -> Optional[Dict]:
        """
        Check all cross-salt pairs via Gemini.
        Skips pairs already in DDICache. Saves new results to DDICache.
        Returns the highest-severity interaction found; returns None when
        Gemini is unavailable, nothing resolved, or no interaction was found
        but some pairs remained unresolved.
        """
        try:
            from src.nlp.gemini_extractor import get_gemini_extractor
            gemini = get_gemini_extractor()
            if not gemini.is_available():
                return None
        except Exception:
            return None

        # Collect pairs that need Gemini lookup
        cached_results = []
        uncached_pairs = []
        for pair in self._unique_salt_pairs(salts_a, salts_b):
            result = None
            if db_session:
                try:
                    row = self._fetch_cached_pair(db_session, pair[0], pair[1])
                    if row is not None:
                        result = self._cached_row_to_result(row)
                except Exception as e:
                    logger.debug(f"DDICache lookup failed for {pair[0]} + {pair[1]}: {e}")
            if result is None:
                uncached_pairs.append(pair)
            else:
                cached_results.append(result)

        # Call Gemini for uncached pairs
        gemini_results = []
        if uncached_pairs:
            if len(uncached_pairs) == 1:
                result = self._gemini_single_salt_pair(gemini, uncached_pairs[0])
                if result:
                    gemini_results.append((uncached_pairs[0], result))
            else:
                gemini_results = self._gemini_batch_salt_pairs(gemini, uncached_pairs)

            # Save Gemini results to DDICache
            if db_session:
                for (norm_a, norm_b), result in gemini_results:
                    self._save_salt_pair_cache(db_session, norm_a, norm_b, result)

        resolved = {pair for pair, _ in gemini_results}
        unresolved = sum(1 for pair in uncached_pairs if pair not in resolved)
        pair_results = cached_results + [result for _, result in gemini_results]
        return self._summarize_pair_results(
            pair_results, comp_a, comp_b, "gemini_ai", unresolved=unresolved
        )

    def _gemini_single_salt_pair(self, gemini, pair: Tuple[str, str]) -> Optional[Dict]:
        """Check a single salt pair via Gemini; None on failure or unusable output."""
        salt_a, salt_b = pair
        prompt = self.DDI_SALT_PROMPT.format(salt_a=salt_a, salt_b=salt_b)

        try:
            text = gemini._generate(prompt, max_tokens=2048)
            parsed = gemini._parse_json_response(text) if text else None
            if parsed and isinstance(parsed, dict):
                return self._normalize_gemini_result(parsed, salt_a, salt_b)
        except Exception as e:
            logger.error(f"Gemini DDI salt pair check failed for {salt_a} + {salt_b}: {e}")

        return None

    @staticmethod
    def _coerce_pair_index(raw, n_pairs: int) -> Optional[int]:
        """
        Return ``raw`` as a valid index into ``n_pairs`` pairs, or None.

        Accepts ints and decimal-digit strings (e.g. "2"); rejects bools,
        floats, other types and out-of-range values.
        """
        if isinstance(raw, bool):
            return None
        if isinstance(raw, str) and raw.strip().isdecimal():
            raw = int(raw.strip())
        if isinstance(raw, int) and 0 <= raw < n_pairs:
            return raw
        return None

    def _gemini_batch_salt_pairs(
        self, gemini, pairs: List[Tuple[str, str]]
    ) -> List[Tuple[Tuple[str, str], Dict]]:
        """
        Check multiple salt pairs via Gemini in a single batch call.

        Pairs missing from (or unusable in) the batch answer are retried one
        at a time. If the batch answers a pair more than once, the
        highest-severity answer is kept.
        """
        pairs_text = "".join(
            f"Pair {idx}: {sa} + {sb}\n" for idx, (sa, sb) in enumerate(pairs)
        )
        prompt = self.BATCH_DDI_SALT_PROMPT.format(pairs_text=pairs_text)
        by_pair: Dict[Tuple[str, str], Dict] = {}

        try:
            text = gemini._generate(prompt, max_tokens=2048)
            parsed = gemini._parse_json_array_response(text) if text else None
            if isinstance(parsed, list):
                for item in parsed:
                    if not isinstance(item, dict):
                        continue
                    pair_idx = self._coerce_pair_index(item.get("pair_index"), len(pairs))
                    if pair_idx is None:
                        continue
                    pair = pairs[pair_idx]
                    result = self._normalize_gemini_result(item, pair[0], pair[1])
                    previous = by_pair.get(pair)
                    if previous is None or (
                        self._severity_rank(result["severity"])
                        > self._severity_rank(previous["severity"])
                    ):
                        by_pair[pair] = result
        except Exception as e:
            logger.error(f"Gemini batch DDI salt pair check failed: {e}")

        results = list(by_pair.items())

        # Individual fallback for any missing pairs
        for pair in pairs:
            if pair not in by_pair:
                result = self._gemini_single_salt_pair(gemini, pair)
                if result:
                    results.append((pair, result))

        return results

    @staticmethod
    def _coerce_confidence(value, default: float = 0.7) -> float:
        """Convert a model-supplied confidence to a float in [0, 1]; ``default`` if unusable."""
        try:
            confidence = float(value)
        except (TypeError, ValueError):
            return default
        if not math.isfinite(confidence):
            return default
        return min(max(confidence, 0.0), 1.0)

    def _normalize_gemini_result(self, parsed: Dict, salt_a: str, salt_b: str) -> Optional[Dict]:
        """
        Normalize a Gemini DDI result.

        Severity is matched case-insensitively against SEVERITY_RANK; if it is
        missing or unrecognised, "has_interaction" decides between "moderate"
        and "none". Missing text fields get generic defaults.
        """
        severity = str(parsed.get("severity") or "").strip().lower()
        if severity not in self.SEVERITY_RANK:
            severity = "moderate" if parsed.get("has_interaction") else "none"

        description = parsed.get("description", "") or ""
        recommendation = parsed.get("recommendation", "") or ""

        if not description and severity != "none":
            description = (
                f"Potential interaction between {salt_a} and {salt_b}. "
                f"Severity: {severity}. Consult a healthcare provider for details."
            )
        if not recommendation:
            recommendation = "Consult your doctor before taking these medicines together."

        return {
            "severity": severity,
            "description": description,
            "recommendation": recommendation,
            "confidence": self._coerce_confidence(parsed.get("confidence", 0.7)),
        }

    def _save_salt_pair_cache(self, db_session, norm_a: str, norm_b: str, result: Dict):
        """
        Save a single salt pair interaction result to DDICache.

        Existing entries are never overwritten. Failures are logged at debug
        level and the session is rolled back (best effort).
        """
        try:
            from src.database.models import DDICache

            if self._fetch_cached_pair(db_session, norm_a, norm_b) is not None:
                return

            db_session.add(DDICache(
                salt_a=norm_a,
                salt_b=norm_b,
                severity=result.get("severity"),
                description=result.get("description"),
                recommendation=result.get("recommendation"),
                source="gemini_ai",
                confidence=result.get("confidence"),
            ))
            db_session.commit()
            logger.info(f"DDI cached: {norm_a} + {norm_b} -> {result.get('severity')}")
        except Exception as e:
            logger.debug(f"DDICache save failed for {norm_a} + {norm_b}: {e}")
            try:
                db_session.rollback()
            except Exception:
                pass

    # ------------------------------------------------------------------ #
    # Interaction-text heuristics
    # ------------------------------------------------------------------ #

    @staticmethod
    def _salt_mention_pattern(salt: str):
        """
        Compile a case-insensitive pattern matching a mention of ``salt``.

        The salt must not start mid-word ("iron" must not match
        "environment"), but any suffix is allowed so plurals still match.
        Internal whitespace in multi-word salts matches any whitespace run.
        """
        body = r"\s+".join(re.escape(word) for word in salt.split())
        return re.compile(r"(?<!\w)" + body, re.IGNORECASE)

    def _search_interaction_text(self, interaction_text: str, salt_names: List[str]) -> Optional[str]:
        """
        Search interaction text for mentions of specific salt names.

        Returns a snippet (up to 100 chars before and 200 after the first
        mention of the first salt found, with "..." marking truncation), or
        None if no salt is mentioned.
        """
        if not interaction_text or not salt_names:
            return None

        for salt in salt_names:
            if not salt or not salt.strip():
                continue
            # Match on the original text (not a lower-cased copy) so the
            # indices stay valid for slicing even for length-changing Unicode.
            match = self._salt_mention_pattern(salt).search(interaction_text)
            if match is None:
                continue
            start = max(0, match.start() - 100)
            end = min(len(interaction_text), match.end() + 200)
            snippet = interaction_text[start:end].strip()
            if start > 0:
                snippet = "..." + snippet
            if end < len(interaction_text):
                snippet = snippet + "..."
            return snippet

        return None

    def _infer_severity_from_text(self, text: str) -> str:
        """
        Infer interaction severity from descriptive text.

        Keyword groups are checked from most to least severe; the first group
        with a substring hit wins. Defaults to "moderate" when nothing matches.
        """
        text_lower = text.lower()
        for severity, keywords in self._SEVERITY_KEYWORDS:
            if any(word in text_lower for word in keywords):
                return severity
        return "moderate"


# Module-level cache for the shared DDIChecker; populated on first request.
_ddi_checker_instance = None


def get_ddi_checker() -> DDIChecker:
    """Return the module's shared DDIChecker, constructing it on first use."""
    global _ddi_checker_instance
    if _ddi_checker_instance is not None:
        return _ddi_checker_instance
    _ddi_checker_instance = DDIChecker()
    return _ddi_checker_instance
