"""
Tier 1: Populate dosage_cache using rule-based extraction (no API calls).

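Runs SmartConsumptionExtractor + EnhancedExtractor on all medicines that
don't yet have a dosage_cache entry and saves results whose confidence is at
least --min-confidence (default 0.85). Medicines below the threshold stay
uncached so scripts/populate_dosage_cache.py can fill them via Gemini.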

Usage:
    python scripts/populate_dosage_cache_regex.py
    python scripts/populate_dosage_cache_regex.py --min-confidence 0.80
    python scripts/populate_dosage_cache_regex.py --batch-size 1000
"""

import sys
import argparse
import time
from pathlib import Path
from sqlalchemy import text
from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.models import Base, Medicine, DosageCache
from src.database.connection import engine, SessionLocal
from src.nlp.enhanced_extractor import get_enhanced_extractor


def ensure_tables():
    """Create the dosage_cache table if the database does not have it yet.

    Only DosageCache's own table is created; ``checkfirst=True`` makes the
    call a no-op when the table already exists.
    """
    dosage_table = DosageCache.__table__
    dosage_table.create(bind=engine, checkfirst=True)


def get_uncached_medicines(db, batch_size: int, offset: int):
    """Return one page of medicines that have no dosage_cache row yet.

    Rows are ordered by medicine id and sliced with LIMIT/OFFSET. Callers
    paging through the result must account for rows that leave the
    "uncached" set as cache entries are committed.

    Args:
        db: Session used to execute the raw SQL query.
        batch_size: Maximum number of rows to return (SQL LIMIT).
        offset: Number of matching rows to skip (SQL OFFSET).

    Returns:
        The fetched rows, each exposing id, product_name, salt_composition,
        medicine_type, how_to_use, product_form and description.
    """
    page = {"limit": batch_size, "offset": offset}
    uncached_sql = text("""
        SELECT m.id, m.product_name, m.salt_composition, m.medicine_type,
               m.how_to_use, m.product_form, m.description
        FROM medicines m
        LEFT JOIN dosage_cache dc ON m.id = dc.medicine_id
        WHERE dc.medicine_id IS NULL
        ORDER BY m.id
        LIMIT :limit OFFSET :offset
    """)
    return db.execute(uncached_sql, page).fetchall()


def count_uncached(db) -> int:
    """Return how many medicines have no matching dosage_cache row.

    Uses the same LEFT JOIN / ``IS NULL`` anti-join as
    get_uncached_medicines, so the count covers the same set of medicines.
    """
    count_sql = text("""
        SELECT COUNT(*) FROM medicines m
        LEFT JOIN dosage_cache dc ON m.id = dc.medicine_id
        WHERE dc.medicine_id IS NULL
    """)
    return db.execute(count_sql).scalar()


class MedicineProxy:
    """Lightweight stand-in for a Medicine ORM instance, built from a SQL row.

    Copies the seven columns selected by get_uncached_medicines onto plain
    instance attributes so the extractor can read them like ORM fields.
    """

    def __init__(self, row):
        # Copied in SELECT-list order; a row lacking one of these columns
        # raises AttributeError, exactly like direct attribute access would.
        for column in ('id', 'product_name', 'salt_composition', 'medicine_type',
                       'how_to_use', 'product_form', 'description'):
            setattr(self, column, getattr(row, column))


# Meal-slot dose counts (morning, afternoon, night) per dosage code.
# Codes not listed here fall back to a single morning dose.
# NOTE(review): QDS (four times daily) maps to 1-1-1 because there are only
# three meal slots; confirm this is the intended representation.
_MEAL_SLOT_SCHEDULE = {
    'OD': (1, 0, 0), 'BD': (1, 0, 1),
    'TDS': (1, 1, 1), 'QDS': (1, 1, 1),
}


def _derive_meal_slots(dosage_code, meal_preference):
    """Return the (morning, afternoon, night) dose counts for a dosage code.

    A once-daily ('OD') dose whose meal preference mentions bedtime moves to
    the night slot; unknown codes default to one morning dose.
    """
    if dosage_code == 'OD' and 'bedtime' in (meal_preference or '').lower():
        return 0, 0, 1
    return _MEAL_SLOT_SCHEDULE.get(dosage_code, (1, 0, 0))


def _build_cache_entry(extractor, row, min_confidence):
    """Run rule-based extraction for one medicine row.

    Returns an unsaved DosageCache instance when the extracted confidence is
    at least ``min_confidence``, otherwise None. Extractor errors propagate
    to the caller.
    """
    med = MedicineProxy(row)
    result = extractor.extract_enhanced(
        med,
        enable_validation=False,
        use_ai_fallback=False,
    )
    confidence = result['confidence_score']
    # Positive comparison kept on purpose: a NaN score counts as low confidence.
    if not (confidence >= min_confidence):
        return None

    dosage_code = result.get('dosage', 'OD')
    dose_morning, dose_afternoon, dose_night = _derive_meal_slots(
        dosage_code, result.get('meal_preference'))
    return DosageCache(
        medicine_id=med.id,
        quantity=result.get('quantity', ''),
        dosage_code=dosage_code,
        frequency=result.get('frequency', ''),
        meal_preference=result.get('meal_preference', ''),
        duration=result.get('duration', ''),
        dose_morning=dose_morning,
        dose_afternoon=dose_afternoon,
        dose_night=dose_night,
        generation_method='rule_based',
        confidence_score=confidence,
    )


def main():
    """Populate dosage_cache for every uncached medicine using rule-based extraction.

    Medicines are processed in id order in batches of ``--batch-size``, and
    each batch is committed separately. Results below ``--min-confidence``
    are not cached; they are left for scripts/populate_dosage_cache.py.
    """
    parser = argparse.ArgumentParser(description="Populate dosage_cache with rule-based extraction")
    parser.add_argument('--min-confidence', type=float, default=0.85,
                        help="Minimum confidence to cache (default: 0.85)")
    parser.add_argument('--batch-size', type=int, default=1000,
                        help="DB query batch size (default: 1000)")
    args = parser.parse_args()
    if args.batch_size <= 0:
        # LIMIT 0 would silently process nothing; negative LIMIT is DB-specific.
        parser.error("--batch-size must be a positive integer")

    print("=" * 60)
    print("DOSAGE CACHE — RULE-BASED POPULATION (Tier 1)")
    print("=" * 60)

    ensure_tables()
    extractor = get_enhanced_extractor()
    db = SessionLocal()

    try:
        total_uncached = count_uncached(db)
        print(f"Medicines without dosage cache: {total_uncached:,}")

        if total_uncached == 0:
            print("Nothing to do — all medicines already cached.")
            return

        cached_count = 0
        skipped_count = 0
        error_count = 0
        # A medicine stays in the "uncached" query result until a cache row is
        # committed for it. Skipped, failed and rolled-back rows therefore
        # remain at the front of the id-ordered result, so the offset must
        # step past them. Re-querying at offset 0 would keep returning the
        # same unresolvable rows and never terminate.
        offset = 0
        start_time = time.time()

        with tqdm(total=total_uncached, desc="Extracting") as pbar:
            while True:
                rows = get_uncached_medicines(db, args.batch_size, offset)
                if not rows:
                    break

                batch_cached = 0
                for row in rows:
                    try:
                        entry = _build_cache_entry(extractor, row, args.min_confidence)
                        if entry is None:
                            skipped_count += 1
                        else:
                            db.add(entry)
                            batch_cached += 1
                    except Exception as e:
                        error_count += 1
                        if error_count <= 5:
                            print(f"\nError processing medicine {row.id} ({row.product_name}): {e}")

                    pbar.update(1)

                # Commit each batch
                try:
                    db.commit()
                except Exception as e:
                    print(f"\nBatch commit error: {e}")
                    db.rollback()
                    # Nothing from this batch was persisted; count it as failed.
                    error_count += batch_cached
                    batch_cached = 0

                cached_count += batch_cached
                # Every row of this batch that did not get a committed cache
                # row is still "uncached" and must be skipped next time.
                offset += len(rows) - batch_cached

        elapsed = time.time() - start_time
        processed = cached_count + skipped_count + error_count

        print("\n" + "=" * 60)
        print("RULE-BASED CACHE SUMMARY")
        print("=" * 60)
        print(f"  Total processed: {processed:,}")
        print(f"  Cached (>= {args.min_confidence}): {cached_count:,}")
        print(f"  Skipped (low confidence): {skipped_count:,}")
        print(f"  Errors: {error_count:,}")
        print(f"  Time: {elapsed:.1f}s")
        print(f"  Coverage: {cached_count / max(total_uncached, 1) * 100:.1f}%")
        print("=" * 60)

        if skipped_count > 0:
            print(f"\nNext: run 'python scripts/populate_dosage_cache.py' to fill "
                  f"remaining {skipped_count:,} via Gemini")

    finally:
        db.close()


# Run the population job only when executed as a script, not on import.
if __name__ == "__main__":
    main()
