"""
Extract symptoms/conditions from the `description` column (falling back to the `introduction` column for medicines still without symptoms) into a dedicated `symptoms` column.
Uses regex patterns only — no external API calls.

Usage:
    python scripts/extract_symptoms.py
    python scripts/extract_symptoms.py --dry-run     # preview without writing
    python scripts/extract_symptoms.py --sample 20   # spot-check N random medicines
"""

import re
import sys
import argparse
from pathlib import Path
from tqdm import tqdm
from sqlalchemy import text

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.connection import engine, SessionLocal

# Regex patterns to extract symptoms/conditions from description text.
# Each pattern captures the relevant symptom list in group "syms".
# A capture stops at the first period, so text such as "e.g." truncates it.
# Several patterns overlap (e.g. "used to treat X" also matches the generic
# "treats?" pattern), so the same list may be captured more than once.
PATTERNS = [
    # "used in the treatment of cough, fever, and runny nose"
    re.compile(
        r'used\s+in\s+the\s+treatment\s+of\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "used to treat/relieve/manage/prevent/control cough, fever"
    re.compile(
        r'used\s+to\s+(?:treat|relieve|manage|prevent|control)\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "helps treat/relieve/reduce/manage pain and inflammation"
    re.compile(
        r'helps\s+(?:treat|relieve|reduce|manage)\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "effective against bacterial infections"
    re.compile(
        r'effective\s+against\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "indicated for pain, fever"
    re.compile(
        r'indicated\s+for\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "used for the treatment of ..."
    re.compile(
        r'used\s+for\s+(?:the\s+)?treatment\s+of\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "used for prevention of heart attack, stroke"
    re.compile(
        r'used\s+for\s+(?:the\s+)?prevention\s+of\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "prescribed to treat various types of bacterial infections"
    re.compile(
        r'prescribed\s+to\s+treat\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "commonly given to treat stomach pain, bloating"
    re.compile(
        r'(?:commonly\s+)?given\s+to\s+(?:treat|children\s+for\s+the\s+treatment\s+of)\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "helps in relieving moderate pain and reducing fever"
    re.compile(
        r'helps\s+in\s+(?:relieving|reducing|controlling|preventing)\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "helps to control/lower/reduce/ease/relieve high blood pressure"
    re.compile(
        r'helps\s+to\s+(?:control|lower|reduce|ease|relieve|bring\s+down)\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "used to lower cholesterol" / "used to reduce pressure"
    re.compile(
        r'used\s+to\s+(?:lower|reduce|promote)\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "used in symptomatic treatment of common cold"
    re.compile(
        r'used\s+in\s+(?:the\s+)?(?:symptomatic\s+)?treatment\s+of\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "It treats conditions like headache, toothache"
    re.compile(
        r'treats?\s+(?:conditions?\s+(?:like|such\s+as)\s+)?(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "provides/gives quick relief from common cold symptoms such as runny nose"
    re.compile(
        r'(?:provides|gives)\s+(?:\w+\s+)?relief\s+from\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "relieves allergy symptoms such as runny nose, stuffy nose"
    re.compile(
        r'relieves?\s+(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "control high blood sugar levels" / "control high blood pressure"
    re.compile(
        r'(?:control|controls)\s+(?P<syms>high\s+blood\s+(?:sugar|pressure)[^.]*?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "reduces the formation of ... blood clots" / "reduces pain and swelling"
    re.compile(
        r'reduces?\s+(?P<syms>(?:pain|swelling|inflammation|fever|cholesterol|blood\s+clot)[^.]*?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # Catch-all: "used/prescribed for osteoarthritis", optionally followed by
    # "(the) treatment of" or "(the) prevention of". Both lead-ins are consumed
    # so they do not leak into the captured list (e.g. "prevention of stroke").
    re.compile(
        r'(?:prescribed|used)\s+for\s+(?:the\s+)?(?:(?:treatment|prevention)\s+of\s+)?(?P<syms>[^.]+?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
    # "fights against the infection" — generic antibiotic
    re.compile(
        r'fights?\s+against\s+(?:the\s+)?(?P<syms>(?:infection|microorganism)[^.]*?)(?:\.\s|\.\Z|\Z)',
        re.IGNORECASE,
    ),
]

# HTML-like tags and artifacts to strip
HTML_RE = re.compile(r'<[^>]+>')

# Angle brackets stored as literal escape sequences: the six characters
# backslash, "u", "003c" (or "003e"). HTML_RE cannot see tags written this way
# until they are decoded. Both hex letter cases are accepted.
ESCAPED_BRACKETS = (
    ('\\u003c', '<'),
    ('\\u003C', '<'),
    ('\\u003e', '>'),
    ('\\u003E', '>'),
)


def clean_text(raw: str) -> str:
    """Strip HTML tags and normalize whitespace.

    Tags are removed both in their plain form (``<p>``) and when the angle
    brackets are stored as literal ``\\u003c`` / ``\\u003e`` escape sequences.
    Each removed tag becomes a space. Then every whitespace run collapses to a
    single space, and the result is stripped.

    Args:
        raw: Text to clean.

    Returns:
        The cleaned text (possibly empty).
    """
    cleaned = HTML_RE.sub(' ', raw)
    # Decode escaped brackets, then strip the tags they form (second pass).
    for escaped, bracket in ESCAPED_BRACKETS:
        cleaned = cleaned.replace(escaped, bracket)
    cleaned = HTML_RE.sub(' ', cleaned)
    return re.sub(r'\s+', ' ', cleaned).strip()


# Separators between individual items of a captured symptom list:
# - a comma, optionally followed by "and"/"&" (the serial comma in "a, b, and c");
# - a bare " and " or " & ".
SYMPTOM_SPLIT_RE = re.compile(r'\s*,\s*(?:(?:and|&)\s+)?|\s+(?:and|&)\s+')


def extract_symptoms(description: str):
    """Extract comma-separated symptoms from a description string.

    The text is cleaned with ``clean_text``, and every regex in ``PATTERNS`` is
    applied. Each captured ``syms`` span is then split into individual items on
    commas, " and " and " & ". Items shorter than 2 or longer than 120
    characters are dropped as noise. Duplicates are removed case-insensitively,
    keeping the first occurrence (with its original casing) in order.

    Args:
        description: Free-text description; may be None or empty.

    Returns:
        The items joined with ", ", or None if the input is empty or nothing
        usable was extracted.
    """
    if not description:
        return None

    desc = clean_text(description)

    # Patterns overlap, so the same list may be captured several times;
    # duplicates are removed below.
    found = [
        m.group('syms').strip()
        for pat in PATTERNS
        for m in pat.finditer(desc)
    ]

    symptoms = []
    seen = set()
    for chunk in found:
        if not chunk:
            continue
        for part in SYMPTOM_SPLIT_RE.split(chunk):
            part = part.strip().rstrip('.')
            # Skip very short or very long fragments (noise)
            if not 2 <= len(part) <= 120:
                continue
            key = part.lower()
            if key not in seen:
                seen.add(key)
                symptoms.append(part)

    return ', '.join(symptoms) if symptoms else None


def add_column_if_missing():
    """Ensure medicines.symptoms and its FULLTEXT index idx_symptoms_ft exist.

    The column and the index are checked independently. In MySQL, ALTER TABLE
    and CREATE INDEX commit implicitly, so wrapping them in one transaction
    does not make them atomic. If index creation failed on an earlier run, the
    column already exists but the index still has to be created.
    """
    with engine.connect() as conn:
        column_exists = conn.execute(text(
            "SELECT COUNT(*) FROM information_schema.COLUMNS "
            "WHERE TABLE_SCHEMA = DATABASE() "
            "AND TABLE_NAME = 'medicines' "
            "AND COLUMN_NAME = 'symptoms'"
        )).scalar() > 0
        index_exists = conn.execute(text(
            "SELECT COUNT(*) FROM information_schema.STATISTICS "
            "WHERE TABLE_SCHEMA = DATABASE() "
            "AND TABLE_NAME = 'medicines' "
            "AND INDEX_NAME = 'idx_symptoms_ft'"
        )).scalar() > 0

    if column_exists and index_exists:
        print("'symptoms' column already exists.")
        return

    if not column_exists:
        print("Adding 'symptoms' column to medicines table...")
        with engine.begin() as conn:
            conn.execute(text("ALTER TABLE medicines ADD COLUMN symptoms TEXT NULL AFTER description"))

    if not index_exists:
        with engine.begin() as conn:
            conn.execute(text(
                "CREATE FULLTEXT INDEX idx_symptoms_ft ON medicines(symptoms)"
            ))

    if column_exists:
        print("'symptoms' column already exists; created missing FULLTEXT index.")
    elif index_exists:
        print("Column created (FULLTEXT index already present).")
    else:
        print("Column and FULLTEXT index created.")


# SQL fragments shared by the extraction phases. They are fixed strings defined
# here (never user input), so interpolating them into queries is safe.
_HAS_DESCRIPTION_SQL = "description IS NOT NULL AND description != ''"
_HAS_INTRODUCTION_SQL = "introduction IS NOT NULL AND introduction != ''"
_MISSING_SYMPTOMS_SQL = "(symptoms IS NULL OR symptoms = '')"
_UPDATE_SYMPTOMS_SQL = "UPDATE medicines SET symptoms = :syms WHERE id = :mid"


def _symptoms_column_exists() -> bool:
    """Return True if medicines.symptoms exists in the current database (read-only check)."""
    with engine.connect() as conn:
        count = conn.execute(text(
            "SELECT COUNT(*) FROM information_schema.COLUMNS "
            "WHERE TABLE_SCHEMA = DATABASE() "
            "AND TABLE_NAME = 'medicines' "
            "AND COLUMN_NAME = 'symptoms'"
        )).scalar()
    return count > 0


def _count_medicines(where_sql: str) -> int:
    """Return how many medicines rows match the trusted WHERE fragment *where_sql*."""
    with engine.connect() as conn:
        return conn.execute(text(f"SELECT COUNT(*) FROM medicines WHERE {where_sql}")).scalar()


def _extract_pass(source_col, where_sql, total, *, batch_size, dry_run, label,
                  skip_ids=None, hit_ids=None):
    """Extract symptoms from *source_col* for every medicine matching *where_sql*.

    Rows are paged by primary key (``id > last_id``) rather than LIMIT/OFFSET.
    The fallback filter depends on ``symptoms``, so rows updated in one batch
    drop out of the result set; with OFFSET, the rows that slide into their
    place would be skipped. Keyset paging also avoids re-reading all earlier
    rows on every batch.

    Args:
        source_col: Trusted column name to read text from.
        where_sql: Trusted SQL filter fragment selecting candidate rows.
        total: Expected row count, used only for the progress bar.
        batch_size: Rows fetched per query.
        dry_run: When True, no UPDATE is issued.
        label: Progress-bar description.
        skip_ids: Optional set of ids to ignore (not counted as match or miss).
        hit_ids: Optional set; ids of rows with extracted symptoms are added.

    Returns:
        Tuple ``(extracted, no_match)`` for the rows that were not skipped.
    """
    extracted = 0
    no_match = 0
    last_id = None

    with tqdm(total=total, desc=label) as pbar:
        while True:
            params = {"limit": batch_size}
            conditions = where_sql
            if last_id is not None:
                conditions = f"({where_sql}) AND id > :last_id"
                params["last_id"] = last_id

            with engine.connect() as conn:
                rows = conn.execute(text(
                    f"SELECT id, {source_col} FROM medicines "
                    f"WHERE {conditions} ORDER BY id LIMIT :limit"
                ), params).fetchall()

            if not rows:
                break
            last_id = rows[-1][0]

            updates = []
            for med_id, source_text in rows:
                if skip_ids and med_id in skip_ids:
                    continue
                symptoms = extract_symptoms(source_text)
                if symptoms:
                    updates.append({"mid": med_id, "syms": symptoms})
                else:
                    no_match += 1

            if updates and not dry_run:
                with engine.begin() as conn:
                    conn.execute(text(_UPDATE_SYMPTOMS_SQL), updates)
            if hit_ids is not None:
                hit_ids.update(u["mid"] for u in updates)

            extracted += len(updates)
            pbar.update(len(rows))

    return extracted, no_match


def run_extraction(batch_size: int = 5000, dry_run: bool = False):
    """Extract symptoms from description (and introduction as fallback).

    Phase 1 sets ``symptoms`` from ``description`` for every medicine whose
    description matches a pattern. Phase 2 then tries ``introduction`` for
    medicines that still have no symptoms.

    In dry-run mode nothing is written. The schema is only inspected (no
    ALTER TABLE / CREATE INDEX). Phase 1 hits are kept in memory so that
    Phase 2 does not process them again.

    Args:
        batch_size: Rows fetched (and updated) per database round trip.
        dry_run: Preview the counts without modifying the database.
    """
    if dry_run:
        column_exists = _symptoms_column_exists()
        if not column_exists:
            print("'symptoms' column does not exist yet (dry-run: not creating it).")
    else:
        add_column_if_missing()
        column_exists = True

    # Phase 1: Extract from description column
    total = _count_medicines(_HAS_DESCRIPTION_SQL)
    print(f"Phase 1: Processing {total:,} medicines with descriptions...")

    phase1_hits = set() if dry_run else None
    updated, no_match = _extract_pass(
        "description", _HAS_DESCRIPTION_SQL, total,
        batch_size=batch_size, dry_run=dry_run,
        label="Extracting from description", hit_ids=phase1_hits,
    )
    print(f"\nPhase 1 done! Extracted: {updated:,} | No match: {no_match:,} | Total: {total:,}")

    # Phase 2: Fallback to introduction column for medicines still without
    # symptoms. Without the column (possible only in dry-run), every medicine
    # counts as lacking symptoms.
    fallback_where = _HAS_INTRODUCTION_SQL
    if column_exists:
        fallback_where = f"{_MISSING_SYMPTOMS_SQL} AND {_HAS_INTRODUCTION_SQL}"
    # In dry-run this count still includes Phase 1 hits (nothing was written);
    # the pass skips them.
    fallback_total = _count_medicines(fallback_where)

    if fallback_total == 0:
        print("\nPhase 2: No medicines need introduction fallback.")
    else:
        print(f"\nPhase 2: Processing {fallback_total:,} medicines using introduction fallback...")
        fb_updated, fb_no_match = _extract_pass(
            "introduction", fallback_where, fallback_total,
            batch_size=batch_size, dry_run=dry_run,
            label="Extracting from introduction", skip_ids=phase1_hits,
        )
        print(f"\nPhase 2 done! Extracted: {fb_updated:,} | No match: {fb_no_match:,}")
        updated += fb_updated

    print(f"\n{'='*60}")
    if dry_run:
        # Nothing was written, so the database cannot say what would still be missing.
        print(f"TOTAL: Symptoms extracted: {updated:,} | Still missing: n/a in dry-run")
        print("(dry-run mode — no changes written)")
    else:
        # Counted from the database: Phase 2 also covers medicines that Phase 1
        # never saw (no description), so the per-phase counters cannot be
        # combined into an accurate total.
        still_missing = _count_medicines(_MISSING_SYMPTOMS_SQL)
        print(f"TOTAL: Symptoms extracted: {updated:,} | Still missing: {still_missing:,}")


def sample_check(n: int = 20):
    """Print N random medicines with their extracted symptoms for verification.

    Symptoms may come from the introduction fallback, so ``description`` can be
    NULL or empty for a sampled row. It is printed as "(none)" in that case.
    The description preview is cut at 200 characters, and "..." is added only
    when text was actually truncated.

    Args:
        n: Number of random rows to show.
    """
    with engine.connect() as conn:
        rows = conn.execute(text(
            "SELECT product_name, description, symptoms FROM medicines "
            "WHERE symptoms IS NOT NULL ORDER BY RAND() LIMIT :n"
        ), {"n": n}).fetchall()

    if not rows:
        print("No medicines with symptoms found. Run extraction first.")
        return

    for i, (name, desc, syms) in enumerate(rows, 1):
        if not desc:
            preview = "(none)"
        elif len(desc) > 200:
            preview = f"{desc[:200]}..."
        else:
            preview = desc
        print(f"\n{'='*60}")
        print(f"[{i}] {name}")
        print(f"Description: {preview}")
        print(f"Symptoms: {syms}")


def main():
    """Parse command-line arguments and run the spot-check or the extraction.

    ``--sample N`` (N > 0) only prints random rows. Otherwise the extraction
    runs, honoring ``--dry-run`` and ``--batch-size``. Invalid numeric
    arguments exit with status 2 through ``argparse``.
    """
    parser = argparse.ArgumentParser(description="Extract symptoms from medicine descriptions")
    parser.add_argument('--dry-run', action='store_true', help="Preview without writing to DB")
    parser.add_argument('--sample', type=int, default=0,
                        help="Spot-check N random medicines (after extraction)")
    parser.add_argument('--batch-size', type=int, default=5000, help="Batch size (default: 5000)")
    args = parser.parse_args()

    # A batch size of 0 would silently process nothing (LIMIT 0), and a
    # negative one is invalid SQL.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    # A negative --sample would otherwise fall through to a full, writing
    # extraction run.
    if args.sample < 0:
        parser.error("--sample must be zero or a positive integer")

    if args.sample > 0:
        sample_check(args.sample)
    else:
        run_extraction(batch_size=args.batch_size, dry_run=args.dry_run)


if __name__ == "__main__":
    main()
