"""
Token authentication for Amazon ingress API.

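Validates Bearer JWT tokens. The JWT signature is NOT verified; authenticity
rests entirely on the encrypted vendorUid claim:
1. Check token expiry (exp > current time)
2. Extract the vendorUid claim from the token payload
3. Decrypt vendorUid using AES-256-CBC (PBKDF2 key derivation)
4. Compare the decrypted value with the expected Amazon API key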
"""

import base64
import functools
import hashlib
import hmac
import json
import logging
import math
import os
import time
from typing import Optional

from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

# Module-level logger (standard getLogger(__name__) pattern).
logger = logging.getLogger(__name__)

# Key material for vendorUid decryption, read once at import time.
# An environment variable, when set, overrides the built-in fallback.
# NOTE(review): the fallback literals are secret material committed to source
# control. They presumably mirror the Java producer's defaults -- confirm, then
# move them out of code and require the env vars in production.
SECRET_KEY = os.environ.get(
    "TOKEN_SECRET_KEY", "aw345-der856-45ty675eu-&8%$#-87967-#@!^7!"
)
SALT = os.environ.get("TOKEN_SALT", "kte45%623@e#{}&&#$%67")


@functools.lru_cache(maxsize=8)
def _derive_key(secret_key: str, salt: str) -> bytes:
    """Return the 32-byte AES key derived with PBKDF2-HMAC-SHA256 (65536 iterations).

    PBKDF2 is deliberately slow. decrypt() is reached by every request that
    carries an unexpired token, authenticated or not, so the key is cached
    rather than recomputed per call. The secret and salt are passed in
    explicitly instead of read from globals, so a changed value yields a
    fresh cache entry rather than a stale key.
    """
    return hashlib.pbkdf2_hmac(
        "sha256",
        secret_key.encode("utf-8"),
        salt.encode("utf-8"),
        65536,
        dklen=32,
    )


def decrypt(encrypted_text: str) -> Optional[str]:
    """Decrypt AES-256-CBC encrypted text using PBKDF2WithHmacSHA256 key derivation.

    Equivalent to the Java decrypt function using:
    - PBKDF2WithHmacSHA256 with 65536 iterations, 256-bit key
    - AES/CBC/PKCS5Padding
    - 16-byte zero IV

    Args:
        encrypted_text: Standard-alphabet base64 of the ciphertext.

    Returns:
        The UTF-8 plaintext, or None if base64 decoding, decryption,
        unpadding or UTF-8 decoding fails for any reason (the error is logged).
    """
    try:
        key = _derive_key(SECRET_KEY, SALT)
        # Fixed all-zero IV to stay byte-compatible with the Java encryptor.
        # NOTE(review): a static IV makes equal plaintexts encrypt identically;
        # changing it requires a coordinated change on the producer side.
        iv = bytes(16)

        decryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor()
        ciphertext = base64.b64decode(encrypted_text)
        decrypted_padded = decryptor.update(ciphertext) + decryptor.finalize()

        # Java's PKCS5Padding is PKCS7 with the 128-bit AES block size.
        unpadder = padding.PKCS7(128).unpadder()
        decrypted = unpadder.update(decrypted_padded) + unpadder.finalize()

        return decrypted.decode("utf-8")
    except Exception as e:
        # Untrusted token data reaches here: treat any failure as "not decryptable".
        logger.error("Error while decrypting: %s", e)
        return None


def decode_jwt_payload(token: str) -> dict:
    """Decode a JWT's payload segment WITHOUT verifying its signature.

    We validate the token via vendorUid decryption instead of signature check,
    so nothing returned here should be trusted until that check passes.

    Args:
        token: Compact JWT serialization ``header.payload.signature``.

    Returns:
        The payload parsed as a JSON object.

    Raises:
        ValueError: if the token does not have exactly three segments, the
            payload segment is not valid base64url / UTF-8 / JSON, or it does
            not decode to a JSON object. (binascii.Error, UnicodeDecodeError
            and json.JSONDecodeError are all ValueError subclasses.)
    """
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("Invalid JWT format")

    segment = parts[1]
    # JWTs strip base64 padding; restore only as many "=" as are missing.
    segment += "=" * (-len(segment) % 4)
    payload = json.loads(base64.urlsafe_b64decode(segment))

    # Valid JSON is not necessarily an object (e.g. "[1, 2]"); callers rely on dict.
    if not isinstance(payload, dict):
        raise ValueError("JWT payload is not a JSON object")
    return payload


def validate_bearer_token(authorization: str, expected_api_key: str) -> dict:
    """Validate Bearer token from Authorization header.

    Args:
        authorization: Full Authorization header value (e.g. "Bearer <token>").
            Surrounding whitespace and quote characters around the token are
            stripped before decoding.
        expected_api_key: The Amazon x-api-key to compare against the
            decrypted vendorUid. An empty/None value rejects every token.

    Returns:
        dict with 'valid' bool and either 'error' (str) or 'payload' (dict).

    Validation steps:
        1. Extract Bearer token from header
        2. Decode JWT payload (the signature is NOT verified)
        3. Check exp is a finite number and > current time
        4. Extract and decrypt vendorUid
        5. Compare decrypted vendorUid with expected API key (constant time)
    """
    prefix = "Bearer "
    if not authorization or not authorization.startswith(prefix):
        return {"valid": False, "error": "Missing or invalid Authorization header"}

    # Some clients send the token wrapped in quotes; tolerate that.
    token = authorization[len(prefix):].strip().strip("'\"")

    try:
        payload = decode_jwt_payload(token)
    except Exception as e:
        logger.error("Failed to decode JWT: %s", e)
        return {"valid": False, "error": "Invalid JWT token"}

    # A well-formed JWT can still carry a non-object JSON payload (e.g. a list);
    # reject it here instead of crashing on .get() below.
    if not isinstance(payload, dict):
        logger.error("Failed to decode JWT: payload is not a JSON object")
        return {"valid": False, "error": "Invalid JWT token"}

    # Validate expiry
    exp = payload.get("exp")
    if not exp:
        return {"valid": False, "error": "Token missing expiry (exp) claim"}

    # exp is attacker-controlled JSON: a string would make the comparison raise
    # TypeError, and json.loads accepts NaN/Infinity, which would never expire.
    is_number = isinstance(exp, (int, float)) and not isinstance(exp, bool)
    if not is_number or (isinstance(exp, float) and not math.isfinite(exp)):
        return {"valid": False, "error": "Token has invalid expiry (exp) claim"}

    if exp <= int(time.time()):
        return {"valid": False, "error": "Token has expired"}

    # Extract and decrypt vendorUid
    vendor_uid = payload.get("vendorUid")
    if not vendor_uid:
        return {"valid": False, "error": "Token missing vendorUid claim"}

    decrypted_vendor_uid = decrypt(vendor_uid)
    if decrypted_vendor_uid is None:
        return {"valid": False, "error": "Failed to decrypt vendorUid"}

    # Without a configured key, a token encrypting "" would otherwise match.
    if not expected_api_key:
        logger.error("Expected API key is not configured; rejecting token")
        return {"valid": False, "error": "Invalid vendorUid"}

    # Constant-time comparison so response timing does not leak the key.
    if not hmac.compare_digest(
        decrypted_vendor_uid.encode("utf-8"), expected_api_key.encode("utf-8")
    ):
        # Never log the decrypted value or the expected key: both are secrets.
        logger.warning("vendorUid mismatch")
        return {"valid": False, "error": "Invalid vendorUid"}

    return {"valid": True, "payload": payload}
