"""
Import Drugs_1.csv.gz + Drugs_2.xlsx (~337K medicines, 33 columns) into the medicines table.
Both files have identical columns and zero overlap in product names.

Usage:
    python scripts/migrate_drugs1.py
    python scripts/migrate_drugs1.py --truncate-all   # wipe everything first
"""

import pandas as pd
import sys
import os
import argparse
from pathlib import Path
from tqdm import tqdm
from sqlalchemy import text
import time

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.models import Base, Medicine
from src.database.connection import engine, SessionLocal

# Column mapping: source-file headers → database columns of the medicines table.
# Drugs_1.csv.gz and Drugs_2.xlsx share the same headers, so one mapping serves both.
# Key differences from old migrate_data.py:
#   - "Composition" (not "salt_composition") → salt_composition
#   - "prescription_required" (not "prescription_req") → prescription_req
#   - "MRP" (not "mrp") → mrp
#   - Missing columns: benefits, use_of, expirtation, reference → will be NULL
#     (NOTE(review): assumes those Medicine columns are nullable with no server
#     default; "expirtation" spelling presumably mirrors the DB column — verify)
# Identity entries (e.g. 'medicine_type': 'medicine_type') are listed on purpose:
# the mapping's values double as the whitelist of columns to keep (see below).
COLUMN_MAPPING = {
    'Product ID': 'product_id',
    'Product Name': 'product_name',
    'Marketer': 'marketer',
    'Composition': 'salt_composition',
    'medicine_type': 'medicine_type',
    'introduction': 'introduction',
    'description': 'description',
    'how_to_use': 'how_to_use',
    'safety_advise': 'safety_advise',
    'if_miss': 'if_miss',
    'Packaging Detail': 'packaging_detail',
    'Package': 'package',
    'Qty': 'qty',
    'Product Form': 'product_form',
    'MRP': 'mrp',
    'prescription_required': 'prescription_req',
    'Fact_Box': 'fact_box',
    'primary_use': 'primary_use',
    'storage': 'storage',
    'common_side_effect': 'common_side_effect',
    'alcoholInteraction': 'alcohol_interaction',
    'pregnancyInteraction': 'pregnancy_interaction',
    'lactationInteraction': 'lactation_interaction',
    'drivingInteraction': 'driving_interaction',
    'kidneyInteraction': 'kidney_interaction',
    'liverInteraction': 'liver_interaction',
    'MANUFACTURER_ADDRESS': 'manufacturer_address',
    'country_of_origin': 'country_of_origin',
    'Q_A': 'q_a',
    'How it works': 'how_it_works',
    'Interaction': 'interaction',
    'Manufacturer details': 'manufacturer_details',
    'Marketer details': 'marketer_details',
    'symptoms': 'symptoms',
}

# Database columns the source files can map to. During cleaning, any column
# whose (renamed) name is not in this set is dropped before insertion.
VALID_DB_COLUMNS = set(COLUMN_MAPPING.values())


def create_tables():
    """Ensure every table declared on ``Base.metadata`` exists in the database.

    ``create_all`` only issues CREATE TABLE for tables that are missing, so
    re-running this is harmless. Any failure is printed and terminates the
    script with exit status 1, since nothing else can proceed without tables.
    """
    print("Creating database tables...")
    try:
        Base.metadata.create_all(bind=engine)
    except Exception as exc:
        print(f"Error creating tables: {exc}")
        sys.exit(1)
    else:
        print("Tables ready")


def clean_dataframe(df: pd.DataFrame, *, column_mapping: "dict[str, str] | None" = None) -> pd.DataFrame:
    """Rename, trim, null-normalize and deduplicate a raw source frame.

    Steps, in order:
      1. Rename source headers to database column names.
      2. Keep only columns that are targets of the mapping; drop the rest.
      3. Strip surrounding whitespace from string values.
      4. Turn every missing value (NaN/NaT/None) into ``None`` so the DB
         driver stores SQL NULL.
      5. Drop rows with a repeated ``product_name``, keeping the first.

    Args:
        df: Raw frame as read from a source file. It is not modified.
        column_mapping: Source header -> DB column mapping. Defaults to the
            module-level ``COLUMN_MAPPING`` when omitted.

    Returns:
        A new object-dtype frame, suitable for ``to_dict('records')``.
    """
    mapping = COLUMN_MAPPING if column_mapping is None else column_mapping
    db_columns = set(mapping.values())

    df = df.rename(columns=mapping)

    # Keep only columns that exist in the database model. .copy() yields an
    # independent frame so the per-column assignments below are safe.
    df = df[[col for col in df.columns if col in db_columns]].copy()

    # Strip whitespace from string columns. This runs BEFORE the NULL step:
    # Series.map re-infers dtypes and would turn [1.5, None] back into NaN.
    for col in df.select_dtypes(include=['object']).columns:
        df[col] = df[col].map(lambda x: x.strip() if isinstance(x, str) else x)

    # `df.where(pd.notnull(df), None)` does NOT work for float columns: pandas
    # treats None as that dtype's NA and keeps NaN, which DB drivers do not
    # store as NULL. Casting to object first lets None actually be stored.
    df = df.astype(object).where(df.notna(), None)

    # Deduplicate by product_name — keep first occurrence
    before = len(df)
    df = df.drop_duplicates(subset='product_name', keep='first')
    dupes = before - len(df)
    if dupes > 0:
        print(f"  Removed {dupes:,} duplicate product names (kept first occurrence)")

    return df


def bulk_insert_batch(df_batch: pd.DataFrame, source_file: str):
    """Insert ``df_batch`` into ``medicines`` within a single transaction.

    Every row is tagged with ``source_file`` (overwriting any existing value
    in that column) on a copy, so the caller's frame is left untouched. Rows
    go through SQLAlchemy Core as a list of dicts (executemany) rather than
    ORM objects, for speed.
    """
    tagged = df_batch.assign(source_file=source_file)
    rows = tagged.to_dict('records')
    insert_stmt = Medicine.__table__.insert()
    with engine.begin() as conn:
        conn.execute(insert_stmt, rows)

def clear_old_data(keep_gemini: bool):
    """Delete existing ``medicines`` rows before a fresh import.

    Args:
        keep_gemini: When True, delete every row except those whose
            ``source_file`` is exactly ``'gemini_ai'``. Rows with a NULL
            ``source_file`` are deleted too. When False, delete all rows.

    The delete runs inside ``engine.begin()``, i.e. one committed transaction.
    """
    with engine.begin() as conn:
        if keep_gemini:
            # `source_file != 'gemini_ai'` alone evaluates to NULL (not TRUE)
            # when source_file is NULL, so such non-Gemini rows would silently
            # survive the cleanup; match them explicitly.
            result = conn.execute(
                text(
                    "DELETE FROM medicines "
                    "WHERE source_file IS NULL OR source_file != 'gemini_ai'"
                )
            )
            print(f"Deleted {result.rowcount:,} non-Gemini rows (gemini_ai rows preserved)")
        else:
            result = conn.execute(text("DELETE FROM medicines"))
            print(f"Deleted {result.rowcount:,} rows (full truncate)")


def create_indexes():
    """Create secondary indexes on ``medicines`` for search queries.

    Existing indexes are detected by name through SQLAlchemy's inspector and
    skipped, rather than relying on ``CREATE INDEX IF NOT EXISTS``. MariaDB
    accepts that clause, but MySQL rejects it, so the previous single-try
    version failed on the first statement and created nothing there.

    Each index is created independently, so one failure does not prevent the
    others. Errors are printed, not raised.
    """
    from sqlalchemy import inspect as sa_inspect

    # name -> DDL. "(100)" is a MySQL/MariaDB prefix length; it is required
    # when indexing TEXT/BLOB columns (presumably the case for primary_use —
    # confirm against the Medicine model).
    index_ddl = {
        "idx_medicine_type": "CREATE INDEX idx_medicine_type ON medicines(medicine_type)",
        "idx_source_file": "CREATE INDEX idx_source_file ON medicines(source_file)",
        "idx_type_primary_use": (
            "CREATE INDEX idx_type_primary_use "
            "ON medicines(medicine_type(100), primary_use(100))"
        ),
    }

    print("Creating performance indexes...")
    try:
        existing = {ix["name"] for ix in sa_inspect(engine).get_indexes("medicines")}
    except Exception as e:
        print(f"Error creating indexes: {e}")
        return

    failures = 0
    for name, ddl in index_ddl.items():
        if name in existing:
            continue
        try:
            with engine.begin() as conn:
                conn.execute(text(ddl))
        except Exception as e:
            failures += 1
            print(f"Error creating index {name}: {e}")

    if failures:
        print(f"Indexes created with {failures} error(s)")
    else:
        print("Indexes created")


def analyze_table():
    """Refresh optimizer statistics for ``medicines`` with ``ANALYZE TABLE``.

    Any failure is printed and swallowed, so a statistics refresh can never
    abort the import.
    """
    print("Analyzing table statistics...")
    statement = text("ANALYZE TABLE medicines;")
    try:
        with engine.connect() as conn:
            conn.execute(statement)
            conn.commit()
    except Exception as exc:
        print(f"Error analyzing table: {exc}")
        return
    print("Analysis complete")


def read_source_file(file_path: Path) -> pd.DataFrame:
    """Load one source file into a DataFrame.

    Supported formats are ``.csv``, gzip-compressed ``.csv.gz``, and ``.xlsx``
    (read via openpyxl). The format is chosen from the final suffix, or the
    one before a trailing ``.gz``, compared case-insensitively. So
    ``Drugs.CSV.GZ`` is accepted and ``notes.csvbak`` is rejected instead of
    being parsed as CSV.

    Raises:
        ValueError: if the file's suffix is not a supported format.
    """
    suffix = ''.join(file_path.suffixes)
    parts = [s.lower() for s in file_path.suffixes]
    print(f"Reading {file_path.name}...")
    start = time.time()

    gzipped = parts[-1:] == ['.gz']
    # Format suffix: the last one, or the one just before a trailing ".gz".
    fmt = (parts[:-1] if gzipped else parts)[-1:]

    if fmt == ['.csv']:
        # low_memory=False: infer each column's dtype from the whole file
        # instead of per chunk, avoiding mixed-type columns.
        df = pd.read_csv(file_path, compression='gzip' if gzipped else None, low_memory=False)
    elif fmt == ['.xlsx'] and not gzipped:
        df = pd.read_excel(file_path, engine='openpyxl')
    else:
        raise ValueError(f"Unsupported file type: {suffix}")

    print(f"  Read {len(df):,} rows in {time.time() - start:.1f}s")
    return df


def _positive_int(value: str) -> int:
    """argparse ``type=``: parse *value* as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def _load_sources(found: "list[tuple[Path, str]]") -> pd.DataFrame:
    """Read, clean, tag and merge all source files into one deduplicated frame."""
    frames = [
        clean_dataframe(read_source_file(file_path)).assign(source_file=source_tag)
        for file_path, source_tag in found
    ]
    combined = pd.concat(frames, ignore_index=True)

    # Deduplicate across files by product_name (earlier files win)
    before = len(combined)
    combined = combined.drop_duplicates(subset='product_name', keep='first')
    dupes = before - len(combined)
    if dupes > 0:
        print(f"Removed {dupes:,} cross-file duplicates")
    print(f"Total unique medicines: {len(combined):,}")

    # concat can re-introduce NaN (e.g. a column present in only one file), and
    # NaN is not NULL to the DB driver: normalize every missing value to None.
    return combined.astype(object).where(combined.notna(), None)


def _insert_batches(df: pd.DataFrame, batch_size: int) -> "tuple[int, int]":
    """Insert *df* in chunks of *batch_size*, one transaction per chunk.

    A failing chunk is rolled back as a whole, reported, and skipped.
    Returns ``(inserted_rows, failed_rows)``.
    """
    total_batches = (len(df) + batch_size - 1) // batch_size
    print(f"Inserting in {total_batches} batches of {batch_size}...")

    insert_stmt = Medicine.__table__.insert()
    inserted = failed = 0
    for i in tqdm(range(0, len(df), batch_size), desc="Importing"):
        batch = df.iloc[i:i + batch_size]
        try:
            with engine.begin() as conn:
                conn.execute(insert_stmt, batch.to_dict('records'))
        except Exception as e:
            failed += len(batch)
            print(f"\nError inserting batch at row {i}: {e}")
            print("Continuing with next batch...")
        else:
            inserted += len(batch)
    return inserted, failed


def main():
    """Command-line entry point for the Drugs_1 + Drugs_2 import.

    Order of operations:
      1. Ensure tables exist.
      2. Read and clean every source file found next to this script.
      3. Only then delete old rows, so an unreadable file cannot leave the
         table emptied. gemini_ai rows are kept unless --truncate-all is given.
      4. Insert in batches, one transaction per batch.
      5. Create indexes and refresh statistics if anything was inserted.

    Exits with status 1 in three cases: no source file is found, the sources
    contain no rows, or any batch fails to insert.
    """
    parser = argparse.ArgumentParser(description="Import Drugs_1 + Drugs_2 into medicines table")
    parser.add_argument('--truncate-all', action='store_true',
                        help="Delete ALL rows before import (including gemini_ai)")
    parser.add_argument('--batch-size', type=_positive_int, default=5000,
                        help="Records per batch (default: 5000)")
    args = parser.parse_args()

    print("=" * 60)
    print("MYRX DRUGS IMPORT (Drugs_1 + Drugs_2)")
    print("=" * 60)

    # Locate source files (expected in the same directory as this script)
    scripts_dir = Path(__file__).parent
    sources = [
        (scripts_dir / 'Drugs_1.csv.gz', 'drugs_1'),
        (scripts_dir / 'Drugs_2.xlsx', 'drugs_2'),
    ]

    found = [(p, tag) for p, tag in sources if p.exists()]
    if not found:
        print("No source files found. Place these in scripts/:")
        for p, _ in sources:
            print(f"  {p.name}")
        sys.exit(1)

    print(f"Found {len(found)} source file(s):")
    for p, tag in found:
        print(f"  {p.name} -> source_file='{tag}'")

    # Step 1: Create tables
    create_tables()

    # Step 2: Read, clean and combine all sources BEFORE deleting anything, so
    # a missing dependency or malformed file aborts without data loss.
    combined = _load_sources(found)
    if combined.empty:
        print("Source files contain no rows; leaving existing data untouched.")
        sys.exit(1)

    # Step 3: Clear old data
    clear_old_data(keep_gemini=not args.truncate_all)

    # Step 4: Batch insert
    start_time = time.time()
    inserted, failed = _insert_batches(combined, args.batch_size)
    insert_time = time.time() - start_time
    print(f"Inserted {inserted:,} records in {insert_time:.1f}s")
    if failed:
        print(f"Failed to insert {failed:,} records (see errors above)")

    # Step 5: Create indexes + analyze
    if inserted > 0:
        create_indexes()
        analyze_table()

    # Summary
    print("\n" + "=" * 60)
    print("IMPORT SUMMARY")
    print("=" * 60)
    for p, tag in found:
        print(f"  {p.name} ({tag})")
    print(f"  Total unique: {inserted:,} inserted")
    if failed:
        print(f"  Failed: {failed:,} not inserted")
    print(f"  Time: {insert_time:.1f}s")
    print("=" * 60)

    if inserted > 0:
        print("\nNext step:")
        print("  python scripts/populate_dosage_cache.py   # Gemini dosage + age data (~$41, ~9.5 hours)")

    # Signal partial failure to shells/CI instead of exiting 0.
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()
