"""
Request logging middleware.
Sets request context (request_id, client_id) so every log line within a
request is automatically tagged — even across async calls.
"""

import time
import uuid
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from src.utils.logger import get_logger, set_request_context, clear_request_context

# Module logger obtained from the project's logging helper; request-context
# tags set below are applied to its output per this module's docstring.
logger = get_logger(__name__)


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Tag every log line of a request with its ids and log the request outcome.

    For each request this middleware:
      * takes the request id from the ``x-myrx-request_id`` header, or
        generates a UUID4 when the header is absent or empty;
      * stores ``request_id`` and ``client_id`` via ``set_request_context`` so
        downstream log lines are tagged;
      * logs a "started" line, then either a completion line whose level
        depends on the status code (info < 400, warning < 500, error
        otherwise) or an error line if the downstream app raised;
      * echoes the request id and elapsed time on the response headers.

    The request context is always cleared before ``dispatch`` returns or
    propagates an exception.
    """

    @staticmethod
    def _sync_client_id(request: Request, request_id: str, current: str) -> str:
        """Re-tag the log context if ``request.state.client_id`` now differs.

        A component running after this middleware (e.g. auth) may only set
        ``request.state.client_id`` during ``call_next``, so the value seen on
        entry can be stale. A missing or falsy value keeps *current*, which
        matches the ``"-"`` fallback applied on entry.

        Returns:
            The client id in effect after the check.
        """
        resolved = getattr(request.state, "client_id", None) or current
        if resolved != current:
            set_request_context(request_id=request_id, client_id=resolved)
        return resolved

    async def dispatch(self, request: Request, call_next) -> Response:
        """Log the request around ``call_next`` and attach tracing headers.

        Args:
            request: The incoming request.
            call_next: Starlette callable that runs the rest of the stack.

        Returns:
            The downstream response with ``x-myrx-request_id`` and
            ``x-response-time`` headers added.

        Raises:
            Whatever ``call_next`` raises; ``Exception`` subclasses are
            logged first, then re-raised unchanged.
        """
        # Reuse the caller-supplied request id when present, otherwise mint one.
        # NOTE(review): nginx drops header names containing underscores unless
        # `underscores_in_headers on` is set -- confirm proxies forward this one.
        request_id = request.headers.get("x-myrx-request_id") or str(uuid.uuid4())
        client_id = getattr(request.state, "client_id", None) or "-"

        # Set context for all downstream loggers.
        set_request_context(request_id=request_id, client_id=client_id)

        start = time.perf_counter()
        method = request.method
        path = request.url.path

        # try/finally guarantees the context is cleared on every exit path.
        # This includes exceptions raised after call_next returns and
        # BaseExceptions (e.g. task cancellation) that `except Exception` misses.
        try:
            logger.info(f"{method} {path} started")

            try:
                response = await call_next(request)
            except Exception:
                elapsed_ms = (time.perf_counter() - start) * 1000
                # Tag the failure line with the client id if auth resolved it.
                self._sync_client_id(request, request_id, client_id)
                logger.error(f"{method} {path} failed after {elapsed_ms:.0f}ms")
                raise

            elapsed_ms = (time.perf_counter() - start) * 1000
            status = response.status_code

            # Update client_id if auth middleware set it after us.
            self._sync_client_id(request, request_id, client_id)

            response.headers["x-myrx-request_id"] = request_id
            response.headers["x-response-time"] = f"{elapsed_ms:.0f}ms"

            if status < 400:
                log_fn = logger.info
            elif status < 500:
                log_fn = logger.warning
            else:
                log_fn = logger.error
            log_fn(f"{method} {path} -> {status} ({elapsed_ms:.0f}ms)")

            return response
        finally:
            clear_request_context()
