"""
Populate dosage_cache via Gemini AI for ALL medicines.

Gemini is the single source of truth for dosage data, including:
- dosage code (OD/BD/TDS/QDS), dose schedule (1 0 1), meal preference
- age-specific dosages (infant, child, adolescent, adult, elderly)

Auto-resumable: queries medicines with no dosage_cache entry (LEFT JOIN WHERE NULL).
Handles rate limits with exponential backoff.

Usage:
    python scripts/populate_dosage_cache.py
    python scripts/populate_dosage_cache.py --batch-size 20 --delay 0.5
    python scripts/populate_dosage_cache.py --max-batches 100 --dry-run
"""

import sys
import argparse
import time
import json
from pathlib import Path
from sqlalchemy import text

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.models import Base, DosageCache
from src.database.connection import engine, SessionLocal


def ensure_tables():
    """Create the ``dosage_cache`` table on ``engine`` unless it already exists."""
    table = DosageCache.__table__
    # checkfirst=True makes this a no-op when the table is already present,
    # so it is safe to run at the start of every invocation.
    table.create(bind=engine, checkfirst=True)


def get_uncached_medicines(db, limit: int):
    """Return up to ``limit`` medicine rows that have no ``dosage_cache`` row yet.

    The anti-join (LEFT JOIN ... WHERE dc.medicine_id IS NULL) keeps only
    uncached medicines. Rows are ordered by ``m.id`` and carry the columns
    used to build Gemini context (id, product_name, salt_composition,
    how_to_use, primary_use).
    """
    statement = text("""
        SELECT m.id, m.product_name, m.salt_composition, m.how_to_use, m.primary_use
        FROM medicines m
        LEFT JOIN dosage_cache dc ON m.id = dc.medicine_id
        WHERE dc.medicine_id IS NULL
        ORDER BY m.id
        LIMIT :limit
    """)
    bound_params = {"limit": limit}
    return db.execute(statement, bound_params).fetchall()


def count_uncached(db) -> int:
    """Return the number of medicines that still lack a ``dosage_cache`` row."""
    count_statement = text("""
        SELECT COUNT(*) FROM medicines m
        LEFT JOIN dosage_cache dc ON m.id = dc.medicine_id
        WHERE dc.medicine_id IS NULL
    """)
    outcome = db.execute(count_statement)
    return outcome.scalar()


# Floor (seconds) for rate-limit back-off. Without it, ``--delay 0`` makes
# ``backoff * 2`` stay 0 forever and rate-limit handling would spin without
# sleeping.
_MIN_BACKOFF_S = 1.0
# Gemini results below this confidence are stored as failures, not dosages.
_MIN_CONFIDENCE = 0.5
# Consecutive non-rate-limit errors tolerated before aborting. A failed batch
# is rolled back, so the same uncached rows would otherwise be retried forever.
_MAX_CONSECUTIVE_ERRORS = 3


def _trunc(val, maxlen):
    """Return ``str(val)`` (``''`` for falsy ``val``) cut to at most ``maxlen`` chars."""
    return str(val or '')[:maxlen]


def _confidence(dosage_data):
    """Return the result's ``confidence`` as a float; 0.0 when missing or invalid."""
    try:
        return float(dosage_data.get('confidence') or 0.0)
    except (TypeError, ValueError):
        return 0.0


def _success_entry(med_id, dosage_data, confidence):
    """Build a ``gemini`` DosageCache row from one Gemini result dict."""
    age_dosages = dosage_data.get('age_group_dosages')
    return DosageCache(
        medicine_id=med_id,
        # Truncate string fields to fit DB column limits.
        quantity=_trunc(dosage_data.get('quantity', ''), 50),
        dosage_code=_trunc(dosage_data.get('dosage', ''), 10),
        frequency=_trunc(dosage_data.get('frequency', ''), 100),
        meal_preference=_trunc(dosage_data.get('meal_preference', ''), 100),
        duration=_trunc(dosage_data.get('duration', ''), 200),
        dose_morning=dosage_data.get('dose_morning', 0),
        dose_afternoon=dosage_data.get('dose_afternoon', 0),
        dose_night=dosage_data.get('dose_night', 0),
        # age_group_dosages is stored serialized as JSON text.
        age_group_dosages=json.dumps(age_dosages) if age_dosages else None,
        generation_method='gemini',
        confidence_score=confidence,
    )


def _failed_entry(med_id):
    """Build a ``gemini_failed`` placeholder row so the medicine is not re-fetched."""
    return DosageCache(
        medicine_id=med_id,
        quantity='',
        dosage_code='',
        frequency='',
        meal_preference='',
        duration='',
        generation_method='gemini_failed',
        confidence_score=0.0,
    )


def _group_batch(rows):
    """Group a batch of medicine rows by product name.

    Gemini results are keyed by name, so medicines sharing a product name
    share one lookup. The context of the first row seen for each name is used.

    Returns:
        ``(name_to_ids, med_contexts)``. ``name_to_ids`` maps each distinct
        product name to every medicine id in the batch with that name.
        ``med_contexts`` holds one context dict per distinct name.
    """
    name_to_ids = {}
    med_contexts = []
    for row in rows:
        if row.product_name not in name_to_ids:
            name_to_ids[row.product_name] = []
            med_contexts.append({
                "name": row.product_name,
                "composition": row.salt_composition,
                "how_to_use": row.how_to_use,
                "primary_use": row.primary_use,
            })
        name_to_ids[row.product_name].append(row.id)
    return name_to_ids, med_contexts


def main():
    """Populate ``dosage_cache`` for every uncached medicine via Gemini.

    Each batch of uncached medicines is sent to Gemini. Every medicine in the
    batch then gets exactly one cache row: a ``gemini`` row for a usable
    result, or a ``gemini_failed`` placeholder for missing, empty or
    low-confidence results. Writing both kinds keeps the uncached set
    shrinking and the loop terminating.

    The loop stops when nothing is left, when ``--max-batches`` is reached,
    or after too many consecutive non-rate-limit errors.
    """
    parser = argparse.ArgumentParser(description="Populate dosage_cache via Gemini AI")
    parser.add_argument('--batch-size', type=int, default=20,
                        help="Medicines per Gemini batch call (default: 20)")
    parser.add_argument('--delay', type=float, default=0.5,
                        help="Seconds between batches (default: 0.5)")
    parser.add_argument('--max-batches', type=int, default=0,
                        help="Max batches to process, 0=unlimited (default: 0)")
    parser.add_argument('--dry-run', action='store_true',
                        help="Show what would be done without calling Gemini")
    args = parser.parse_args()

    # A batch size below 1 would make LIMIT return nothing (falsely reporting
    # "All medicines cached!") and divide by zero in the dry-run estimate.
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.delay < 0:
        parser.error("--delay must be non-negative")

    print("=" * 60)
    print("DOSAGE CACHE — GEMINI POPULATION")
    print("=" * 60)

    ensure_tables()

    db = SessionLocal()

    try:
        total_uncached = count_uncached(db)
        print(f"Medicines without dosage cache: {total_uncached:,}")

        if total_uncached == 0:
            print("Nothing to do — all medicines already cached.")
            return

        # The dry run needs no Gemini access, so it runs before the
        # extractor (and its API key) is required.
        if args.dry_run:
            rows = get_uncached_medicines(db, min(args.batch_size, 20))
            print(f"\n[DRY RUN] First {len(rows)} uncached medicines:")
            for row in rows:
                print(f"  id={row.id}: {row.product_name}")
            print(f"\nWould process {total_uncached:,} medicines in "
                  f"~{(total_uncached + args.batch_size - 1) // args.batch_size} batches")
            return

        # Import Gemini extractor
        from src.nlp.gemini_extractor import get_gemini_extractor
        gemini = get_gemini_extractor()

        if not gemini.is_available():
            print("ERROR: Gemini extractor not available. Check GEMINI_API_KEY.")
            sys.exit(1)

        cached_count = 0
        failed_count = 0
        error_batches = 0
        consecutive_errors = 0
        batch_num = 0
        start_time = time.time()
        backoff = args.delay

        while True:
            rows = get_uncached_medicines(db, args.batch_size)
            if not rows:
                print("All medicines cached!")
                break

            # Check before incrementing so the summary reports the true count.
            if args.max_batches > 0 and batch_num >= args.max_batches:
                print(f"Reached max batches ({args.max_batches}), stopping.")
                break
            batch_num += 1

            name_to_ids, med_contexts = _group_batch(rows)

            print(f"\nBatch {batch_num}: {len(rows)} medicines "
                  f"(total cached so far: {cached_count})")

            try:
                results = gemini.batch_lookup_medicine_dosages(med_contexts)

                # If ALL results are None, Gemini is rate-limited — wait, don't waste money
                if all(v is None for v in results.values()):
                    backoff = min(max(backoff * 2, _MIN_BACKOFF_S), 120)
                    print(f"  Rate limited — sleeping {backoff:.0f}s (no retry, no billing)")
                    time.sleep(backoff)
                    continue
                backoff = args.delay  # Reset backoff on success

                batch_cached = 0
                batch_failed = 0
                for name, med_ids in name_to_ids.items():
                    dosage_data = results.get(name)
                    confidence = _confidence(dosage_data) if dosage_data else 0.0
                    usable = bool(dosage_data) and confidence >= _MIN_CONFIDENCE
                    for med_id in med_ids:
                        if usable:
                            db.add(_success_entry(med_id, dosage_data, confidence))
                            batch_cached += 1
                        else:
                            # Missing, empty or low-confidence: record a
                            # failure so this medicine is not re-fetched.
                            db.add(_failed_entry(med_id))
                            batch_failed += 1

                db.commit()
                # Count only after a successful commit so rolled-back rows
                # are never reported as cached.
                cached_count += batch_cached
                failed_count += batch_failed
                consecutive_errors = 0

            except Exception as e:
                db.rollback()
                error_msg = str(e)
                if '429' in error_msg or 'RESOURCE_EXHAUSTED' in error_msg:
                    backoff = min(max(backoff * 2, _MIN_BACKOFF_S), 60)
                    print(f"  Rate limited — backing off {backoff:.0f}s")
                    time.sleep(backoff)
                    continue
                print(f"  ERROR: {e}")
                error_batches += 1
                consecutive_errors += 1
                if consecutive_errors >= _MAX_CONSECUTIVE_ERRORS:
                    # The rolled-back rows stay uncached and would be
                    # re-fetched (and re-billed) indefinitely.
                    print(f"  {consecutive_errors} consecutive errors — aborting.")
                    break

            # Progress
            remaining = count_uncached(db)
            elapsed = time.time() - start_time
            rate = cached_count / max(elapsed, 1) * 60
            print(f"  Progress: {cached_count:,} cached | "
                  f"{remaining:,} remaining | {rate:.0f}/min")

            if args.delay > 0:
                time.sleep(args.delay)

        elapsed = time.time() - start_time

        print("\n" + "=" * 60)
        print("GEMINI CACHE SUMMARY")
        print("=" * 60)
        print(f"  Batches processed: {batch_num}")
        print(f"  Cached: {cached_count:,}")
        print(f"  Failed: {failed_count:,}")
        print(f"  Errored batches: {error_batches:,}")
        print(f"  Time: {elapsed:.1f}s")
        print("=" * 60)

    finally:
        db.close()


if __name__ == "__main__":
    main()
