"""
Extract symptoms for medicines that regex couldn't handle, using Gemini AI.
Targets only medicines with NULL/empty symptoms column.

Auto-resumable: only processes medicines still missing symptoms.
Handles rate limits with exponential backoff.

Usage:
    python scripts/extract_symptoms_gemini.py
    python scripts/extract_symptoms_gemini.py --batch-size 20 --delay 0.5
    python scripts/extract_symptoms_gemini.py --max-batches 100 --dry-run
"""

import sys
import argparse
import time
import json
from pathlib import Path
from sqlalchemy import text

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.connection import engine, SessionLocal


# Prompt template for one batch of medicines. main() fills it via
# SYMPTOMS_PROMPT.format(medicine_lines=...) with the output of
# build_medicine_lines(), and the reply is decoded by parse_gemini_response().
# The doubled braces {{ }} are str.format escapes: they render as literal { }
# so the JSON example reaches the model intact.
SYMPTOMS_PROMPT = """You are a pharmaceutical expert. Given a list of medicines with their available data, extract the symptoms/conditions each medicine is used to treat.

Medicines:
{medicine_lines}

Return ONLY a JSON array where each element has:
[
  {{
    "id": <medicine id>,
    "symptoms": "comma-separated list of symptoms/conditions this medicine treats (e.g., 'fever, headache, body pain')"
  }}
]

Rules:
- Extract specific symptoms and conditions the medicine treats
- Use the description, introduction, primary_use, and composition to determine symptoms
- If it's an antibiotic, list the types of infections it treats (e.g., 'bacterial infections, respiratory tract infections, urinary tract infections')
- Keep each symptom/condition concise (2-5 words each)
- If you cannot determine symptoms, use the primary therapeutic category (e.g., 'vitamin deficiency', 'nutritional supplement')
- Return ONLY the JSON array, no markdown or explanation"""


def get_medicines_without_symptoms(db, limit: int):
    """Fetch up to ``limit`` medicines whose symptoms are NULL or ''.

    Rows are returned in ascending id order with the columns
    (id, product_name, salt_composition, description, introduction,
    primary_use, medicine_type).
    """
    params = {"limit": limit}
    statement = text("""
        SELECT m.id, m.product_name, m.salt_composition, m.description,
               m.introduction, m.primary_use, m.medicine_type
        FROM medicines m
        WHERE (m.symptoms IS NULL OR m.symptoms = '')
        ORDER BY m.id
        LIMIT :limit
    """)
    result = db.execute(statement, params)
    return result.fetchall()


def count_missing(db) -> int:
    """Return how many medicines have a NULL or empty symptoms column."""
    statement = text(
        "SELECT COUNT(*) FROM medicines WHERE symptoms IS NULL OR symptoms = ''"
    )
    return db.execute(statement).scalar()


def build_medicine_lines(rows):
    """Render medicine rows as prompt text, one " | "-separated line each.

    Every line starts with ``[id=<id>] <name>``. It then lists whichever of
    these fields are truthy:

    - composition (first 200 chars);
    - description (first 300 chars), or the introduction (first 300 chars)
      only when the description is empty;
    - primary use (first 200 chars);
    - type.

    Lines are joined with newlines.
    """
    def render(row):
        mid, name, comp, desc, intro, primary_use, med_type = row
        # The description wins; the introduction is only a fallback.
        blurb = (
            f"Description: {str(desc)[:300]}" if desc
            else f"Introduction: {str(intro)[:300]}" if intro
            else None
        )
        optional = (
            f"Composition: {str(comp)[:200]}" if comp else None,
            blurb,
            f"Primary use: {str(primary_use)[:200]}" if primary_use else None,
            f"Type: {med_type}" if med_type else None,
        )
        return " | ".join([f"[id={mid}] {name}", *filter(None, optional)])

    return "\n".join(map(render, rows))


def parse_gemini_response(response_text):
    """Parse the JSON array of ``{"id", "symptoms"}`` items from a Gemini reply.

    Tolerates the usual LLM wrapping: Markdown code fences, prose before the
    array, prose after the array, and output truncated by the token limit.
    For a truncated array, every element that was emitted completely is kept.

    Args:
        response_text: Raw model output; ``None`` or ``""`` is accepted.

    Returns:
        The decoded JSON value (normally a list), or ``None`` if nothing
        usable could be recovered.
    """
    if not response_text:
        return None

    # Named `payload` (not `text`) so it does not shadow sqlalchemy.text.
    payload = response_text.strip()
    if '```json' in payload:
        payload = payload.split('```json', 1)[1].split('```', 1)[0].strip()
    elif '```' in payload:
        payload = payload.split('```', 1)[1].split('```', 1)[0].strip()

    # Drop any preamble ("Here is the JSON: ...") before the array.
    start = payload.find('[')
    if start > 0:
        payload = payload[start:]

    # raw_decode parses the leading JSON value and ignores trailing prose,
    # which json.loads would reject as "Extra data".
    try:
        value, _ = json.JSONDecoder().raw_decode(payload)
        return value
    except json.JSONDecodeError:
        pass

    return _repair_truncated_json(payload)


def _repair_truncated_json(payload):
    """Best-effort recovery of a JSON array cut off mid-output; None if hopeless."""
    # 1) Close whatever is still open. This works when the cut fell between
    #    tokens, e.g. '[{"id": 1, "symptoms": "fever"' -> '...}]'.
    fixed = payload.rstrip()
    if fixed.endswith(','):
        fixed = fixed[:-1]
    fixed += '}' * max(payload.count('{') - payload.count('}'), 0)
    fixed += ']' * max(payload.count('[') - payload.count(']'), 0)
    try:
        return json.loads(fixed)
    except json.JSONDecodeError:
        pass

    # 2) The cut fell inside a string or key, so closing brackets cannot help.
    #    Keep everything up to the last '}' that yields a valid array, which
    #    drops only the partial trailing element.
    if not payload.startswith('['):
        return None
    end = payload.rfind('}')
    while end != -1:
        try:
            return json.loads(payload[:end + 1] + ']')
        except json.JSONDecodeError:
            end = payload.rfind('}', 0, end)
    return None


def _collect_updates(parsed, rows):
    """Turn parsed Gemini items into UPDATE parameter dicts for this batch.

    Only ids belonging to ``rows`` are accepted, so a hallucinated or mistyped
    id can never overwrite an unrelated medicine. Ids are matched by their
    string form, because the model may echo 12 as "12". The database's own id
    value is used in the result. Duplicate ids keep the first answer.
    A list-valued ``symptoms`` is joined with ", ". Blank or non-string
    answers are skipped.
    """
    batch_ids = {str(row[0]): row[0] for row in rows}
    chosen = {}
    for item in parsed:
        if not isinstance(item, dict):
            continue
        mid = batch_ids.get(str(item.get('id')))
        if mid is None or mid in chosen:
            continue
        symptoms = item.get('symptoms')
        if isinstance(symptoms, (list, tuple)):
            parts = (str(s).strip() for s in symptoms)
            symptoms = ", ".join(p for p in parts if p)
        if not isinstance(symptoms, str):
            continue
        symptoms = symptoms.strip()
        if symptoms:
            chosen[mid] = symptoms
    return [{"mid": mid, "syms": syms} for mid, syms in chosen.items()]


def _generate_with_backoff(gemini, prompt, batch_num, base_delay,
                           max_attempts=8, max_backoff=60.0):
    """Call Gemini for one batch, retrying on errors and empty replies.

    Sleeps with exponential backoff between attempts, starting from
    ``base_delay`` and capped at ``max_backoff`` seconds. Returns the response
    text, or None once ``max_attempts`` attempts have all failed, so a failure
    that never clears cannot hang the run forever.
    """
    backoff = base_delay
    for attempt in range(1, max_attempts + 1):
        try:
            response_text = gemini._generate(prompt, max_tokens=1024 * 2)
        except Exception as e:  # client exception types aren't visible here; treat all as retryable
            print(f"  Gemini error: {e}")
        else:
            if response_text:
                return response_text
            print(f"  Batch {batch_num}: empty response (rate limited?)")
        if attempt == max_attempts:
            break
        # The 1s floor keeps the backoff growing with --delay 0; otherwise
        # 0 * 2 == 0 would retry in a tight loop.
        backoff = min(max(backoff * 2, 1.0), max_backoff)
        print(f"  Backing off {backoff:.0f}s...")
        time.sleep(backoff)
    return None


def main():
    """Fill in missing medicine symptoms via Gemini, batch by batch.

    Each row is sent to Gemini at most once per run. Rows Gemini could not
    answer keep their empty symptoms and are retried by the next run.
    """
    parser = argparse.ArgumentParser(description="Extract symptoms via Gemini for remaining medicines")
    parser.add_argument('--batch-size', type=int, default=15,
                        help="Medicines per Gemini call (default: 15)")
    parser.add_argument('--delay', type=float, default=1.0,
                        help="Seconds between batches (default: 1.0)")
    parser.add_argument('--max-batches', type=int, default=0,
                        help="Max batches to process (0 = unlimited)")
    parser.add_argument('--dry-run', action='store_true',
                        help="Preview without writing to DB")
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.delay < 0:
        parser.error("--delay must not be negative")

    # Initialize Gemini
    from src.nlp.gemini_extractor import GeminiDosageExtractor
    gemini = GeminiDosageExtractor()
    if not gemini.is_available():
        print("ERROR: Gemini is not available. Check GOOGLE_API_KEY.")
        sys.exit(1)
    print(f"Gemini initialized ({gemini._backend}, model={gemini._model_name})")

    total_updated = 0
    total_failed = 0

    db = SessionLocal()
    try:
        remaining = count_missing(db)
        print(f"Medicines without symptoms: {remaining:,}")
        if remaining == 0:
            print("Nothing to do!")
            return

        batch_num = 0
        # Ids already sent to Gemini in this run. Rows that were not written
        # (Gemini failed on them, or any row in --dry-run) still match the
        # "missing symptoms" query; without this filter the same batch would
        # be fetched again forever.
        attempted = set()
        # Number of attempted rows that still match that query. Over-fetching
        # by this much guarantees room for a full batch of fresh rows.
        still_pending = 0

        while True:
            fetched = get_medicines_without_symptoms(db, args.batch_size + still_pending)
            rows = [row for row in fetched if row[0] not in attempted][:args.batch_size]
            if not rows:
                break

            batch_num += 1
            if args.max_batches > 0 and batch_num > args.max_batches:
                print(f"Reached max batches ({args.max_batches}). Stopping.")
                break

            prompt = SYMPTOMS_PROMPT.format(medicine_lines=build_medicine_lines(rows))
            response_text = _generate_with_backoff(gemini, prompt, batch_num, args.delay)
            if response_text is None:
                print(f"  Batch {batch_num}: Gemini kept failing; stopping. Re-run to resume.")
                break

            attempted.update(row[0] for row in rows)

            parsed = parse_gemini_response(response_text)
            if not parsed or not isinstance(parsed, list):
                print(f"  Batch {batch_num}: failed to parse response")
                total_failed += len(rows)
                still_pending += len(rows)
                time.sleep(args.delay)
                continue

            updates = _collect_updates(parsed, rows)

            if updates and not args.dry_run:
                with engine.begin() as conn:
                    conn.execute(
                        text("UPDATE medicines SET symptoms = :syms WHERE id = :mid"),
                        updates,
                    )

            # End the session's read transaction so the next SELECT sees the
            # rows committed above through the separate `engine` connection.
            db.commit()

            total_updated += len(updates)
            total_failed += len(rows) - len(updates)
            still_pending += len(rows) - (0 if args.dry_run else len(updates))
            remaining_now = remaining - total_updated - total_failed

            print(
                f"  Batch {batch_num}: {len(updates)}/{len(rows)} updated | "
                f"Total: {total_updated:,} updated, {total_failed:,} failed | "
                f"~{max(remaining_now, 0):,} remaining"
            )

            if args.dry_run and updates:
                for u in updates[:3]:
                    print(f"    [id={u['mid']}] {u['syms'][:100]}")

            time.sleep(args.delay)

    finally:
        db.close()

    print(f"\n{'='*60}")
    print(f"DONE: {total_updated:,} symptoms extracted | {total_failed:,} failed")
    if args.dry_run:
        print("(dry-run mode - no changes written)")


# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
