"""
Import the enriched CSV/xlsx (medicines + dosage cache) into a fresh environment.

This is the counterpart to export_enriched.py. It populates both the
`medicines` and `dosage_cache` tables from a single file — no Gemini
calls needed.

Usage:
    python scripts/import_enriched.py
    python scripts/import_enriched.py --input enriched_drugs.csv.gz
    python scripts/import_enriched.py --input enriched_drugs.xlsx --truncate-all
"""

import pandas as pd
import sys
import argparse
import time
from pathlib import Path
from tqdm import tqdm
from sqlalchemy import text

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.models import Base, Medicine, DosageCache
from src.database.connection import engine, SessionLocal

# Maps each dosage-cache column name used in the enriched file (carrying the
# `dc_` prefix) to the un-prefixed field name it corresponds to.
DC_COLUMNS = {
    f'dc_{field}': field
    for field in (
        'quantity',
        'dosage_code',
        'frequency',
        'meal_preference',
        'duration',
        'dose_morning',
        'dose_afternoon',
        'dose_night',
        'age_group_dosages',
        'generation_method',
        'confidence_score',
    )
}


def create_tables():
    """Issue CREATE statements for every table registered on ``Base.metadata``.

    Delegates to ``create_all`` bound to the shared ``engine`` and prints a
    progress line before and after.
    """
    print("Creating database tables...")
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    print("Tables ready")


def clear_tables(truncate_all: bool):
    """Empty ``dosage_cache`` and ``medicines`` and reset the medicines counter.

    The same three statements run whatever ``truncate_all`` is; the flag only
    selects which confirmation message is printed.

    Args:
        truncate_all: Chooses the wording of the printed summary.
    """
    if truncate_all:
        summary = "Cleared all rows from medicines and dosage_cache (auto-increment reset)"
    else:
        summary = "Cleared all medicines and dosage_cache (auto-increment reset)"

    statements = (
        # dosage_cache goes first: its rows refer to medicine IDs.
        "DELETE FROM dosage_cache",
        "DELETE FROM medicines",
        # Restart IDs from 1 for the freshly imported rows.
        "ALTER TABLE medicines AUTO_INCREMENT = 1",
    )

    with engine.begin() as conn:
        for sql in statements:
            conn.execute(text(sql))
        print(summary)


def read_file(file_path: Path) -> pd.DataFrame:
    """Load an enriched export from disk into a DataFrame.

    The format is chosen from the file's final extension, compared
    case-insensitively:

    * ``.xlsx`` is read with ``pd.read_excel`` (openpyxl engine);
    * ``.gz`` is read as a gzip-compressed CSV (e.g. ``name.csv.gz``);
    * anything else is read as an uncompressed CSV.

    Only the last extension is inspected, so a name such as
    ``drugs.gzipped.csv`` is not mistaken for a gzip archive.

    Args:
        file_path: Location of the file to read.

    Returns:
        The parsed contents as a DataFrame.
    """
    final_suffix = file_path.suffix.lower()
    if final_suffix == '.xlsx':
        return pd.read_excel(str(file_path), engine='openpyxl')

    # Compression is passed explicitly so pandas does not infer it.
    compression = 'gzip' if final_suffix == '.gz' else None
    return pd.read_csv(file_path, compression=compression, low_memory=False)


def _resolve_input_path(raw_input):
    """Return the path of the file to import for a given ``--input`` value.

    With no value, the default ``enriched_drugs.csv.gz`` next to this script is
    used. Absolute paths are returned unchanged. Relative paths are resolved
    against this script's directory; if nothing exists there but the path
    exists relative to the current working directory, that one is used.
    """
    script_dir = Path(__file__).parent
    if not raw_input:
        return script_dir / 'enriched_drugs.csv.gz'

    candidate = Path(raw_input)
    if candidate.is_absolute():
        return candidate

    script_relative = script_dir / candidate
    if not script_relative.exists() and candidate.exists():
        return candidate
    return script_relative


def _clean_frame(df):
    """Strip whitespace from string cells and turn every missing value into None."""
    # Strip first: apply() re-infers dtypes, so running it after the None
    # conversion could turn None back into NaN.
    for col in df.select_dtypes(include=['object']).columns:
        df[col] = df[col].apply(lambda x: x.strip() if isinstance(x, str) else x)

    # Numeric columns cannot hold None, so `where(..., None)` alone leaves NaN
    # in them. NaN is truthy, so it would slip past the `or` fallbacks used
    # below and be sent to the database. Casting to object first makes the
    # replacement stick.
    cleaned = df.astype(object)
    return cleaned.where(cleaned.notna(), None)


def _normalize_name(value):
    """Return the lookup key for a product name: stripped and lower-cased."""
    # str() guards against non-string cells (e.g. a numeric xlsx cell).
    return str(value or '').strip().lower()


def _insert_medicines(med_df, batch_size):
    """Insert medicine rows in batches; return how many rows were inserted."""
    inserted = 0
    for offset in tqdm(range(0, len(med_df), batch_size), desc="Importing medicines"):
        batch = med_df.iloc[offset:offset + batch_size]
        records = batch.to_dict('records')
        try:
            # One transaction per batch: a failing batch is reported and
            # skipped without undoing earlier batches.
            with engine.begin() as conn:
                conn.execute(Medicine.__table__.insert(), records)
        except Exception as e:
            print(f"\nError at row {offset}: {e}")
        else:
            inserted += len(batch)
    return inserted


def _fetch_name_to_id():
    """Map normalized product names to the IDs currently in ``medicines``."""
    db = SessionLocal()
    try:
        rows = db.execute(text("SELECT id, product_name FROM medicines")).fetchall()
    finally:
        db.close()
    return {_normalize_name(row.product_name): row.id for row in rows}


def _build_dosage_records(df, name_to_id):
    """Build dosage-cache insert dicts; return ``(records, skipped_count)``."""
    dc_records = []
    skipped = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Mapping dosage cache"):
        # Rows without a dosage code carry no dosage data.
        if row.get('dc_dosage_code') is None:
            continue

        med_id = name_to_id.get(_normalize_name(row.get('product_name')))
        if not med_id:
            skipped += 1
            continue

        dc_records.append({
            'medicine_id': med_id,
            'quantity': row.get('dc_quantity') or '',
            'dosage_code': row.get('dc_dosage_code') or '',
            'frequency': row.get('dc_frequency') or '',
            'meal_preference': row.get('dc_meal_preference') or '',
            'duration': row.get('dc_duration') or '',
            'dose_morning': float(row.get('dc_dose_morning') or 0),
            'dose_afternoon': float(row.get('dc_dose_afternoon') or 0),
            'dose_night': float(row.get('dc_dose_night') or 0),
            'age_group_dosages': row.get('dc_age_group_dosages') or None,
            'generation_method': row.get('dc_generation_method') or 'imported',
            'confidence_score': row.get('dc_confidence_score') or 0.0,
        })
    return dc_records, skipped


def _insert_dosage_cache(df, batch_size):
    """Insert dosage-cache rows for ``df``; return how many were inserted.

    Rows are matched to medicines by normalized product name, which relies on
    product names being unique (per the original note, they are deduplicated
    during migrate_drugs1.py).
    """
    name_to_id = _fetch_name_to_id()
    dc_records, skipped = _build_dosage_records(df, name_to_id)

    if skipped > 0:
        print(f"  Skipped {skipped} dosage entries (no matching medicine)")

    print(f"Dosage cache entries to insert: {len(dc_records):,}")

    dc_inserted = 0
    for offset in tqdm(range(0, len(dc_records), batch_size), desc="Importing dosage cache"):
        batch = dc_records[offset:offset + batch_size]
        try:
            with engine.begin() as conn:
                conn.execute(DosageCache.__table__.insert(), batch)
        except Exception as e:
            print(f"\nDosage cache batch error at {offset}: {e}")
        else:
            dc_inserted += len(batch)
    return dc_inserted


def _create_indexes():
    """Create the medicines indexes and refresh table statistics (best effort)."""
    statements = (
        "CREATE INDEX IF NOT EXISTS idx_medicine_type ON medicines(medicine_type);",
        "CREATE INDEX IF NOT EXISTS idx_source_file ON medicines(source_file);",
        "CREATE INDEX IF NOT EXISTS idx_type_primary_use "
        "ON medicines(medicine_type(100), primary_use(100));",
        "ANALYZE TABLE medicines;",
        "ANALYZE TABLE dosage_cache;",
    )
    failures = 0
    try:
        with engine.connect() as conn:
            for statement in statements:
                # Each statement is isolated so that one failure (for example
                # a server that rejects `IF NOT EXISTS` on CREATE INDEX) does
                # not skip the remaining indexes or the ANALYZE calls.
                try:
                    conn.execute(text(statement))
                except Exception as e:
                    failures += 1
                    print(f"Index creation error: {e}")
            conn.commit()
    except Exception as e:
        failures += 1
        print(f"Index creation error: {e}")

    if not failures:
        print("Indexes and analysis complete")


def main():
    """Command-line entry point: import an enriched file into the database.

    Steps:
        1. Parse arguments and locate the input file (exit status 1 if missing).
        2. Create the tables.
        3. Read and clean the file. This happens *before* any data is deleted,
           so an unreadable file does not leave the tables empty.
        4. Clear ``dosage_cache`` and ``medicines``.
        5. Insert medicines in batches, then dosage-cache rows when the file
           has any ``dc_`` columns listed in ``DC_COLUMNS``.
        6. Create indexes and analyze the tables if any medicine was inserted.
        7. Print a summary.

    Batch failures are printed and skipped rather than aborting the import.
    """
    parser = argparse.ArgumentParser(description="Import enriched medicines + dosage cache")
    parser.add_argument('--input', type=str, default=None,
                        help="Input file (default: scripts/enriched_drugs.csv.gz)")
    parser.add_argument('--truncate-all', action='store_true',
                        help="Delete ALL existing rows before import")
    parser.add_argument('--batch-size', type=int, default=5000,
                        help="Records per batch (default: 5000)")
    args = parser.parse_args()

    # A non-positive step would make range() raise (0) or silently import
    # nothing (negative), after the tables have already been cleared.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    batch_size = args.batch_size

    # Find input file
    file_path = _resolve_input_path(args.input)
    if not file_path.exists():
        print(f"File not found: {file_path}")
        sys.exit(1)

    print("=" * 60)
    print("IMPORT ENRICHED MEDICINES + DOSAGE CACHE")
    print("=" * 60)
    print(f"Source: {file_path}")

    # Step 1: Create tables
    create_tables()

    # Step 2: Read and clean the file while the existing data is still intact
    print(f"Reading {file_path.name}...")
    start = time.time()
    df = read_file(file_path)
    print(f"Read {len(df):,} rows in {time.time() - start:.1f}s")
    df = _clean_frame(df)

    # Step 3: Clear old data (only once the input is known to be readable)
    clear_tables(args.truncate_all)

    # Separate medicine columns from dosage cache columns
    dc_cols_present = [c for c in DC_COLUMNS if c in df.columns]
    has_dosage_data = len(dc_cols_present) > 0
    med_cols = [c for c in df.columns if not c.startswith('dc_')]

    print(f"Medicine columns: {len(med_cols)}")
    print(f"Dosage cache columns: {len(dc_cols_present)}")

    # Step 4: Insert medicines in batches
    print(f"\nInserting medicines ({len(df):,} rows)...")
    med_df = df[med_cols].copy()
    # Ensure source_file has a default
    if 'source_file' not in med_df.columns:
        med_df['source_file'] = 'enriched_import'
    med_df.loc[med_df['source_file'].isna(), 'source_file'] = 'enriched_import'

    start = time.time()
    inserted = _insert_medicines(med_df, batch_size)
    med_time = time.time() - start
    print(f"Inserted {inserted:,} medicines in {med_time:.1f}s")

    # Step 5: Insert dosage cache entries
    dc_inserted = 0
    if has_dosage_data:
        print("\nInserting dosage cache...")
        start = time.time()
        dc_inserted = _insert_dosage_cache(df, batch_size)
        dc_time = time.time() - start
        print(f"Inserted {dc_inserted:,} dosage cache entries in {dc_time:.1f}s")

    # Step 6: Create indexes + analyze
    if inserted > 0:
        print("\nCreating indexes...")
        _create_indexes()

    # Summary
    print("\n" + "=" * 60)
    print("IMPORT COMPLETE")
    print("=" * 60)
    print(f"  Medicines imported:    {inserted:,}")
    print(f"  Dosage cache imported: {dc_inserted:,}")
    print(f"  Source: {file_path.name}")
    print("=" * 60)
    print("\nReady to start the API — no Gemini calls needed!")


# Run the import only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
