"""
Production data migration script
Migrates 343,794 medicines from Excel to MySQL
Optimized for bulk insert performance
"""

import pandas as pd
import sys
import os
from pathlib import Path
from tqdm import tqdm
from sqlalchemy import text
import time

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.database.models import Base, Medicine, DDICache
from src.database.connection import engine, SessionLocal

# Excel header (key) -> database column name (value), consumed by
# ``DataFrame.rename`` in clean_dataframe. The renamed frame is inserted into
# the Medicine table column-for-column, so every value must be a real column.
# NOTE(review): 'Expirtation' / 'expirtation' is spelled this way on both
# sides; presumably it mirrors the spreadsheet header and the DB column —
# confirm both before "fixing" the typo, or the rename/insert will break.
COLUMN_MAPPING = dict([
    # Identity and classification
    ('Product ID', 'product_id'),
    ('Product Name', 'product_name'),
    ('Marketer', 'marketer'),
    ('salt_composition', 'salt_composition'),
    ('medicine_type', 'medicine_type'),
    # Descriptive / usage text
    ('Introduction', 'introduction'),
    ('benefits', 'benefits'),
    ('description', 'description'),
    ('how_to_use', 'how_to_use'),
    ('safety_advise', 'safety_advise'),
    ('If_miss', 'if_miss'),
    # Packaging and pricing
    ('Packaging Detail', 'packaging_detail'),
    ('Package', 'package'),
    ('Qty', 'qty'),
    ('Product Form', 'product_form'),
    ('mrp', 'mrp'),
    ('prescription_req', 'prescription_req'),
    ('Fact_Box', 'fact_box'),
    ('primary_use', 'primary_use'),
    ('storage', 'storage'),
    ('use_of', 'use_of'),
    ('common_side_effect', 'common_side_effect'),
    # Interaction warnings (camelCase headers -> snake_case columns)
    ('alcoholInteraction', 'alcohol_interaction'),
    ('pregnancyInteraction', 'pregnancy_interaction'),
    ('lactationInteraction', 'lactation_interaction'),
    ('drivingInteraction', 'driving_interaction'),
    ('kidneyInteraction', 'kidney_interaction'),
    ('liverInteraction', 'liver_interaction'),
    # Manufacturer, provenance and reference material
    ('MANUFACTURER_ADDRESS', 'manufacturer_address'),
    ('country_of_origin', 'country_of_origin'),
    ('Q_A', 'q_a'),
    ('How it works', 'how_it_works'),
    ('Interaction', 'interaction'),
    ('Manufacturer details', 'manufacturer_details'),
    ('Marketer details', 'marketer_details'),
    ('Expirtation', 'expirtation'),
    ('Reference', 'reference'),
])


def create_tables():
    """
    Emit CREATE TABLE statements for every model registered on ``Base``.

    Delegates to ``MetaData.create_all`` against the shared ``engine``; with
    SQLAlchemy's default ``checkfirst`` behaviour, tables that already exist
    are skipped.
    """
    print("Creating database tables...")
    schema = Base.metadata
    schema.create_all(bind=engine)
    print("Tables created successfully")


def clean_dataframe(
    df: pd.DataFrame,
    *,
    column_mapping: "dict[str, str] | None" = None,
) -> pd.DataFrame:
    """
    Rename spreadsheet columns and normalise cell values for DB insertion.

    Steps:
      1. Rename columns using ``column_mapping`` (defaults to the module-level
         ``COLUMN_MAPPING``); columns absent from the mapping keep their names.
      2. Strip leading/trailing whitespace from string cells.
      3. Replace every missing value (NaN, None, NaT) with ``None`` so the DB
         driver writes SQL NULL.

    Args:
        df: Raw frame as read from Excel. It is not modified.
        column_mapping: Optional override for the header -> column-name map.

    Returns:
        A new frame whose columns all have object dtype.
    """
    mapping = COLUMN_MAPPING if column_mapping is None else column_mapping
    df = df.rename(columns=mapping)  # returns a copy; caller's frame untouched

    # Only object columns can contain Python strings worth stripping.
    for col in df.select_dtypes(include=['object']).columns:
        df[col] = df[col].map(lambda x: x.strip() if isinstance(x, str) else x)

    # NaN -> None. Cast to object first: on float/int columns pandas coerces a
    # ``None`` fill value back to NaN, so a plain ``where(..., None)`` would
    # still hand float NaN (not NULL) to the database driver.
    return df.astype(object).where(df.notna(), None)


def bulk_insert_batch(df_batch: pd.DataFrame, source_file: str, session):
    """
    Insert one batch of medicine rows inside the caller's session transaction.

    Executes a SQLAlchemy Core ``INSERT`` on the ``Medicine`` table with a
    list of parameter dicts (executemany), which avoids building ORM objects.
    Nothing is committed here: the caller owns commit/rollback, so a failure
    in any batch can roll back the whole file rather than leaving earlier
    batches committed.

    Args:
        df_batch: Cleaned rows whose column names match the table's columns.
            The frame is not modified.
        source_file: Value written to the ``source_file`` column of every row.
        session: Open SQLAlchemy session supplied (and later committed) by
            the caller.
    """
    # assign() builds a new frame instead of writing a column into what may
    # be an iloc slice of the caller's DataFrame.
    records = df_batch.assign(source_file=source_file).to_dict('records')
    if not records:
        # An empty parameter list would execute a single parameterless INSERT.
        return
    # NOTE(review): any column clean_dataframe left unmapped becomes a record
    # key here — confirm the source workbooks only contain mapped headers.
    session.execute(Medicine.__table__.insert(), records)


def migrate_excel_file(file_path: str, source_name: str, batch_size: int = 5000):
    """
    Load one Excel workbook into the medicines table in fixed-size batches.

    Args:
        file_path: Location of the ``.xlsx`` workbook (read with openpyxl).
        source_name: Label forwarded to ``bulk_insert_batch`` for each batch;
            ``main`` passes 'file1' or 'file2'.
        batch_size: Rows handed to ``bulk_insert_batch`` per call.

    Returns:
        The number of rows processed on success. Returns 0 when the file cannot
        be read, or when any batch raises; in that case the session is rolled
        back and the error is printed rather than re-raised.
    """
    print(f"\nProcessing: {file_path}")
    print(f"   Source: {source_name}")

    print("   Reading Excel file...")
    started_at = time.time()
    try:
        frame = pd.read_excel(file_path, engine='openpyxl')
    except Exception as exc:
        print(f"Error reading file: {exc}")
        return 0
    elapsed = time.time() - started_at
    print(f"   Read {len(frame):,} rows in {elapsed:.2f}s")

    print("   Cleaning data...")
    frame = clean_dataframe(frame)

    row_count = len(frame)
    batch_count = (row_count + batch_size - 1) // batch_size  # ceiling division
    print(f"   Inserting in {batch_count} batches of {batch_size}...")

    session = SessionLocal()
    processed = 0
    try:
        offsets = range(0, row_count, batch_size)
        for offset in tqdm(offsets, desc=f"   Migrating {source_name}"):
            chunk = frame.iloc[offset:offset + batch_size]
            bulk_insert_batch(chunk, source_name, session)
            processed += len(chunk)

        session.commit()
        print(f"   Successfully inserted {processed:,} records")
        return processed
    except Exception as exc:
        session.rollback()
        print(f"   Error during migration: {exc}")
        return 0
    finally:
        session.close()


def create_indexes():
    """
    Create secondary indexes on ``medicines`` for search/filter queries.

    MySQL's ``CREATE INDEX`` has no ``IF NOT EXISTS`` clause (MariaDB's does),
    so existing index names are read through the SQLAlchemy inspector first
    and only missing indexes are created. This keeps the step re-runnable on
    both servers.
    """
    from sqlalchemy import inspect  # only this step needs reflection

    print("\n🔧 Creating performance indexes...")

    # (index name, key-part list). Both are trusted constants, so building the
    # DDL with an f-string below is safe.
    index_specs = (
        # B-tree index for exact lookups by type.
        # NOTE(review): no prefix length here, but idx_type_primary_use uses
        # medicine_type(100); if medicine_type is a TEXT column MySQL requires
        # a prefix here as well — confirm against the Medicine model.
        ("idx_medicine_type", "medicine_type"),
        ("idx_source_file", "source_file"),
        # Composite index for common type + primary_use queries (100-char prefixes).
        ("idx_type_primary_use", "medicine_type(100), primary_use(100)"),
    )

    with engine.connect() as conn:
        existing = {index["name"] for index in inspect(conn).get_indexes("medicines")}
        for index_name, key_parts in index_specs:
            if index_name in existing:
                continue
            conn.execute(text(f"CREATE INDEX {index_name} ON medicines({key_parts})"))
        conn.commit()

    print("Indexes created successfully")


def analyze_table():
    """
    Refresh the optimizer statistics for ``medicines``.

    Executes ``ANALYZE TABLE medicines`` on a fresh connection from the shared
    engine, then commits.
    """
    print("\nAnalyzing table statistics...")
    analyze_sql = text("ANALYZE TABLE medicines;")
    with engine.connect() as connection:
        connection.execute(analyze_sql)
        connection.commit()
    print("Analysis complete")


def main():
    """
    Run the full migration workflow.

    The steps are: create tables, load both workbooks, then build indexes and
    refresh statistics if any rows were inserted.

    Returns:
        Process exit status. Returns 1 when an input workbook is missing,
        before the database is touched. Otherwise returns 0.
    """
    print("=" * 60)
    print("MEDTECH PRODUCTION DATABASE MIGRATION")
    print("=" * 60)

    # File paths
    data_dir = Path(__file__).parent.parent / 'data'
    file1_path = data_dir / 'medicines_file1.xlsx'
    file2_path = data_dir / 'medicines_file2.xlsx'

    # Verify inputs first; a non-zero status lets shells/CI detect the failure.
    for input_path in (file1_path, file2_path):
        if not input_path.exists():
            print(f"File not found: {input_path}", file=sys.stderr)
            return 1

    # Step 1: Create tables
    create_tables()

    # Step 2 & 3: Migrate both files
    total_file1 = migrate_excel_file(str(file1_path), 'file1')
    total_file2 = migrate_excel_file(str(file2_path), 'file2')

    # Step 4: Create indexes only if something was loaded
    if total_file1 > 0 or total_file2 > 0:
        create_indexes()
        analyze_table()

    # Summary
    print("\n" + "=" * 60)
    print("MIGRATION SUMMARY")
    print("=" * 60)
    print(f"   File 1: {total_file1:,} medicines")
    print(f"   File 2: {total_file2:,} medicines")
    print(f"   TOTAL:  {total_file1 + total_file2:,} medicines")
    print("=" * 60)
    print("Migration complete!")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code.
    sys.exit(main())
