"""
Parallel worker for dosage cache population.
Each worker handles one half-open medicine ID range, [--start, --end).

Usage:
    python scripts/populate_dosage_worker.py --start 197 --end 100000 --worker-id 1
"""

import sys
import argparse
import time
import json
from pathlib import Path
from sqlalchemy import text

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.models import DosageCache
from src.database.connection import engine, SessionLocal


def get_uncached_in_range(db, start_id: int, end_id: int, limit: int):
    """Return up to ``limit`` medicines in [start_id, end_id) with no dosage_cache row.

    Rows are ordered by ascending ``m.id`` and carry the columns id,
    product_name, salt_composition, how_to_use and primary_use.
    """
    bind_params = dict(start=start_id, end=end_id, limit=limit)
    statement = text("""
        SELECT m.id, m.product_name, m.salt_composition, m.how_to_use, m.primary_use
        FROM medicines m
        LEFT JOIN dosage_cache dc ON m.id = dc.medicine_id
        WHERE dc.medicine_id IS NULL AND m.id >= :start AND m.id < :end
        ORDER BY m.id
        LIMIT :limit
    """)
    result = db.execute(statement, bind_params)
    return result.fetchall()


def count_uncached_in_range(db, start_id: int, end_id: int) -> int:
    """Count medicines with ids in [start_id, end_id) that lack a dosage_cache row."""
    statement = text("""
        SELECT COUNT(*) FROM medicines m
        LEFT JOIN dosage_cache dc ON m.id = dc.medicine_id
        WHERE dc.medicine_id IS NULL AND m.id >= :start AND m.id < :end
    """)
    bounds = dict(start=start_id, end=end_id)
    return db.execute(statement, bounds).scalar()


def _trunc(val, maxlen):
    s = str(val or '')
    return s[:maxlen] if len(s) > maxlen else s


# Minimum Gemini confidence for a result to be stored as a real dosage.
_MIN_CONFIDENCE = 0.5
# Bounds for the rate-limit backoff. The floor keeps backoff from collapsing
# to 0 (a hot retry loop) when --delay is 0.
_MIN_BACKOFF = 1.0
_MAX_BACKOFF = 120.0


def _next_backoff(current: float) -> float:
    """Double *current*, clamped to [_MIN_BACKOFF, _MAX_BACKOFF]."""
    return min(max(current * 2, _MIN_BACKOFF), _MAX_BACKOFF)


def _is_duplicate_error(err: Exception) -> bool:
    """Heuristic duplicate-key check on the error text.

    NOTE(review): relies on the driver message containing 'Duplicate'
    (e.g. MySQL "Duplicate entry") — confirm for other backends.
    """
    return 'Duplicate' in str(err)


def _is_rate_limit_error(err: Exception) -> bool:
    """True if *err*'s message looks like an HTTP 429 / quota-exhausted error."""
    error_msg = str(err)
    return '429' in error_msg or 'RESOURCE_EXHAUSTED' in error_msg


def _ids_by_name(rows) -> dict:
    """Group batch row ids by product_name; several rows may share a name."""
    grouped = {}
    for row in rows:
        grouped.setdefault(row.product_name, []).append(row.id)
    return grouped


def _med_contexts(rows) -> list:
    """Build one context dict per row for Gemini's batch dosage lookup."""
    return [
        {
            "name": row.product_name,
            "composition": row.salt_composition,
            "how_to_use": row.how_to_use,
            "primary_use": row.primary_use,
        }
        for row in rows
    ]


def _confidence_of(dosage_data) -> float:
    """Return the result's 'confidence' as a float, or 0.0 if missing/unusable."""
    if not isinstance(dosage_data, dict):
        return 0.0
    try:
        return float(dosage_data.get('confidence', 0))
    except (TypeError, ValueError):
        return 0.0


def _failed_entry(med_id):
    """Placeholder cache row marking a medicine as attempted without a usable dosage."""
    return DosageCache(
        medicine_id=med_id, quantity='', dosage_code='',
        frequency='', meal_preference='', duration='',
        generation_method='gemini_failed', confidence_score=0.0,
    )


def _success_entry(med_id, dosage_data: dict, confidence: float):
    """Cache row built from a Gemini dosage result.

    Text fields are truncated to fixed lengths (presumably the DosageCache
    column sizes — confirm against the model).
    """
    age_dosages = dosage_data.get('age_group_dosages')
    return DosageCache(
        medicine_id=med_id,
        quantity=_trunc(dosage_data.get('quantity', ''), 50),
        dosage_code=_trunc(dosage_data.get('dosage', ''), 10),
        frequency=_trunc(dosage_data.get('frequency', ''), 100),
        meal_preference=_trunc(dosage_data.get('meal_preference', ''), 100),
        duration=_trunc(dosage_data.get('duration', ''), 200),
        dose_morning=dosage_data.get('dose_morning', 0),
        dose_afternoon=dosage_data.get('dose_afternoon', 0),
        dose_night=dosage_data.get('dose_night', 0),
        age_group_dosages=json.dumps(age_dosages) if age_dosages else None,
        generation_method='gemini',
        confidence_score=confidence,
    )


def _insert(db, entry) -> None:
    """Add *entry* and commit; on any error roll back and re-raise."""
    try:
        db.add(entry)
        db.commit()
    except Exception:
        db.rollback()
        raise


def _cache_one(db, tag: str, name, med_id, dosage_data) -> str:
    """Persist exactly one cache row for *med_id*.

    Returns 'cached' (real dosage stored), 'failed' (failure marker stored)
    or 'skipped' (nothing stored: row already existed or the insert errored).
    """
    confidence = _confidence_of(dosage_data)
    if confidence >= _MIN_CONFIDENCE:
        try:
            _insert(db, _success_entry(med_id, dosage_data, confidence))
            return 'cached'
        except Exception as insert_err:
            if _is_duplicate_error(insert_err):
                return 'skipped'  # Already cached, skip
            print(f"{tag} Insert error for {name}: {insert_err}")
            # Fall through and store a failure marker instead: otherwise this
            # row stays uncached and is re-selected on every batch forever.
    try:
        _insert(db, _failed_entry(med_id))
        return 'failed'
    except Exception as insert_err:
        if not _is_duplicate_error(insert_err):
            print(f"{tag} Insert error for {name}: {insert_err}")
        return 'skipped'


def _cache_batch(db, tag: str, ids_by_name: dict, results: dict) -> tuple:
    """Store one cache row per medicine in the batch; return (cached, failed).

    Names missing from *results*, or mapped to an empty/low-confidence result,
    get a failure marker. Result names that are not in the batch are ignored.
    """
    cached = failed = 0
    for name, med_ids in ids_by_name.items():
        dosage_data = results.get(name)
        for med_id in med_ids:
            outcome = _cache_one(db, tag, name, med_id, dosage_data)
            if outcome == 'cached':
                cached += 1
            elif outcome == 'failed':
                failed += 1
    return cached, failed


def main():
    """Populate dosage_cache for uncached medicines in [--start, --end).

    Fetches uncached medicines in batches, asks Gemini for dosages and stores
    one cache row per medicine (a real dosage or a 'gemini_failed' marker).
    Exits with status 1 if Gemini is unavailable, or if --max-retries
    consecutive attempts make no progress (0 disables that limit).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--start', type=int, required=True,
                        help='first medicine id (inclusive)')
    parser.add_argument('--end', type=int, required=True,
                        help='last medicine id (exclusive)')
    parser.add_argument('--worker-id', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=20)
    parser.add_argument('--delay', type=float, default=0.2,
                        help='seconds to sleep between batches')
    parser.add_argument('--max-retries', type=int, default=10,
                        help='consecutive attempts without progress before '
                             'aborting (0 = retry forever)')
    args = parser.parse_args()

    tag = f"[W{args.worker_id}]"
    print(f"{tag} Worker {args.worker_id}: IDs {args.start} to {args.end}")

    from src.nlp.gemini_extractor import get_gemini_extractor
    gemini = get_gemini_extractor()
    if not gemini.is_available():
        print(f"{tag} ERROR: Gemini not available")
        sys.exit(1)

    db = SessionLocal()
    try:
        total_uncached = count_uncached_in_range(db, args.start, args.end)
        print(f"{tag} Uncached in range: {total_uncached:,}")

        if total_uncached == 0:
            print(f"{tag} Nothing to do")
            return

        cached_count = 0
        failed_count = 0
        batch_num = 0
        start_time = time.time()
        backoff = args.delay
        remaining = total_uncached
        stalls = 0  # consecutive attempts that did not reduce `remaining`
        aborted = False

        while True:
            rows = get_uncached_in_range(db, args.start, args.end, args.batch_size)
            if not rows:
                print(f"{tag} All done!")
                break
            if args.max_retries > 0 and stalls >= args.max_retries:
                print(f"{tag} ERROR: no progress in {stalls} consecutive attempts — aborting")
                aborted = True
                break

            batch_num += 1
            ids_by_name = _ids_by_name(rows)

            try:
                results = gemini.batch_lookup_medicine_dosages(_med_contexts(rows))

                # If ALL results are None, Gemini is rate-limited — wait, don't cache as failed
                if all(v is None for v in results.values()):
                    stalls += 1
                    backoff = _next_backoff(backoff)
                    print(f"{tag} Rate limited — sleeping {backoff:.0f}s (no retry, no billing)")
                    time.sleep(backoff)
                    continue
                backoff = args.delay

                cached, failed = _cache_batch(db, tag, ids_by_name, results)
                cached_count += cached
                failed_count += failed

            except Exception as e:
                db.rollback()
                if _is_rate_limit_error(e):
                    stalls += 1
                    backoff = _next_backoff(backoff)
                    print(f"{tag} Rate limited — backing off {backoff:.0f}s")
                    time.sleep(backoff)
                    continue
                print(f"{tag} ERROR: {e}")

            elapsed = time.time() - start_time
            rate = cached_count / max(elapsed, 1) * 60
            previous_remaining = remaining
            remaining = count_uncached_in_range(db, args.start, args.end)
            # Any drop in the uncached count is progress; otherwise this batch
            # counts toward the --max-retries limit.
            stalls = 0 if remaining < previous_remaining else stalls + 1
            print(f"{tag} Batch {batch_num}: {cached_count:,} cached | {remaining:,} remaining | {rate:.0f}/min")

            if args.delay > 0:
                time.sleep(args.delay)

        elapsed = time.time() - start_time
        print(f"\n{tag} DONE: {cached_count:,} cached, {failed_count:,} failed in {elapsed:.0f}s")
        if aborted:
            sys.exit(1)

    finally:
        db.close()


if __name__ == "__main__":
    # Run the worker only when executed as a script, not when imported.
    main()
