import asyncio
import json
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

import app.main as main
from app.main import (
    ASSIGNED_DOCTORS_KEY,
    PRESCRIPTION_QUEUE,
    RX_PRESCRIPTION_MAPPING_KEY,
    RX_ASSIGNMENT_MAPPING_KEY,
    assign_doctor,
    block_doctor,
    handle_expired_prescription,
    process_prescription,
    vendor_keys,
)

# Vendor under test, and the Redis list key the autouse fixture registers for
# it in ``vendor_keys``; tests seed and inspect this list as the doctor pool.
VENDOR_ID = "vendor_1"
REDIS_KEY = "vendor_1_redis_list"


class FakeRedis:
    """In-memory stand-in for the subset of the async Redis client used by app.main.

    Lists, hashes and sets live in three plain dicts (``lists``, ``hashes``,
    ``sets``) so tests can seed and inspect them directly. String values
    written with :meth:`set` are stored in a private pseudo-hash named
    ``"__strings__"`` inside ``hashes``.

    Every command first awaits ``asyncio.sleep(0)``. Like a real network
    client, each call is therefore a suspension point, so coroutines run under
    ``asyncio.gather`` genuinely interleave between commands. The body after
    that point never yields, so each single command (including :meth:`eval`)
    is atomic, as in Redis.

    Deliberate deviations from Redis, kept for test convenience:
    - Emptied lists, sets and hashes stay behind as empty containers, because
      tests index ``fake.lists[key]`` after draining a pool.
    - ``ex`` on :meth:`set` is ignored.
    - :meth:`eval` does not run Lua.
    """

    _STRINGS = "__strings__"

    def __init__(self):
        self.lists = {}
        self.hashes = {}
        self.sets = {}

    def _strings(self):
        """Return the string-key bucket (a throwaway empty dict if none exists yet)."""
        return self.hashes.get(self._STRINGS, {})

    def _key_exists(self, key):
        """Return True if *key* is present in any store (internal bucket excluded)."""
        if key in self._strings():
            return True
        if key != self._STRINGS and key in self.hashes:
            return True
        return key in self.lists or key in self.sets

    async def lpop(self, key):
        """Pop and return the head of the list at *key*, or None if empty/missing."""
        await asyncio.sleep(0)
        items = self.lists.get(key, [])
        if not items:
            return None
        return items.pop(0)

    async def rpush(self, key, *values):
        """Append *values* to the list at *key*; return the new length."""
        await asyncio.sleep(0)
        items = self.lists.setdefault(key, [])
        items.extend(values)
        return len(items)

    async def lrange(self, key, start, end):
        """Return elements ``start..end`` inclusive (``end == -1`` means to the end)."""
        await asyncio.sleep(0)
        items = self.lists.get(key, [])
        if end == -1:
            return items[start:]
        return items[start : end + 1]

    async def lrem(self, key, count, value):
        """Remove occurrences of *value* using Redis LREM semantics.

        ``count > 0`` removes up to *count* matches from the head.
        ``count < 0`` removes up to ``abs(count)`` matches from the tail.
        ``count == 0`` removes every match.

        Returns the number of elements removed.
        """
        await asyncio.sleep(0)
        items = self.lists.get(key, [])
        limit = abs(count) if count else len(items)
        # Negative counts scan from the tail; the survivors are re-reversed below.
        scan = items if count >= 0 else items[::-1]
        kept = []
        removed = 0
        for item in scan:
            if item == value and removed < limit:
                removed += 1
                continue
            kept.append(item)
        if count < 0:
            kept.reverse()
        self.lists[key] = kept
        return removed

    async def lpos(self, key, value):
        """Return the index of the first *value* in the list, or None if absent."""
        await asyncio.sleep(0)
        items = self.lists.get(key, [])
        try:
            return items.index(value)
        except ValueError:
            return None

    async def eval(self, script, num_keys, *args):
        """Simulate the atomic LPOS+RPUSH Lua script used by safe_rpush_doctor.

        *script* is ignored. ``args[0]`` is the list key and ``args[1]`` the
        value, i.e. the caller is assumed to pass ``num_keys == 1``.

        Returns 1 if the value was appended, 0 if it was already present.
        """
        await asyncio.sleep(0)
        # No await below: the membership check and append form one atomic step.
        key = args[0]
        value = args[1]
        items = self.lists.get(key, [])
        if value in items:
            return 0
        self.lists.setdefault(key, []).append(value)
        return 1

    async def set(self, key, value, ex=None, nx=False):
        """Store a string value; with ``nx=True`` refuse to overwrite (returns False)."""
        await asyncio.sleep(0)
        strings = self.hashes.setdefault(self._STRINGS, {})
        if nx and key in strings:
            return False
        strings[key] = value
        return True

    async def delete(self, *keys):
        """Remove each key from every store; return how many of them existed."""
        await asyncio.sleep(0)
        removed = 0
        for key in keys:
            if self._key_exists(key):
                removed += 1
            self.lists.pop(key, None)
            self.sets.pop(key, None)
            self._strings().pop(key, None)
            # Never drop the internal string bucket as if it were a user hash.
            if key != self._STRINGS:
                self.hashes.pop(key, None)
        return removed

    async def hget(self, key, field):
        """Return a hash field's value, or None if the hash or field is missing."""
        await asyncio.sleep(0)
        return self.hashes.get(key, {}).get(field)

    async def hset(self, key, field, value):
        """Set a hash field; return 1 if the field is new, else 0."""
        await asyncio.sleep(0)
        bucket = self.hashes.setdefault(key, {})
        is_new = field not in bucket
        bucket[field] = value
        return int(is_new)

    async def hgetall(self, key):
        """Return a shallow copy of the hash at *key* (empty dict if missing)."""
        await asyncio.sleep(0)
        return dict(self.hashes.get(key, {}))

    async def hdel(self, key, *fields):
        """Delete hash fields; return how many were actually present."""
        await asyncio.sleep(0)
        bucket = self.hashes.get(key, {})
        removed = 0
        for field in fields:
            if field in bucket:
                del bucket[field]
                removed += 1
        return removed

    async def sadd(self, key, *values):
        """Add members to a set; return how many were newly added."""
        await asyncio.sleep(0)
        members = self.sets.setdefault(key, set())
        added = len(set(values) - members)
        members.update(values)
        return added

    async def srem(self, key, *values):
        """Remove members from a set; return how many were present."""
        await asyncio.sleep(0)
        members = self.sets.setdefault(key, set())
        removed = len(members & set(values))
        members.difference_update(values)
        return removed

    async def sismember(self, key, value):
        """Return True if *value* is a member of the set at *key*."""
        await asyncio.sleep(0)
        return value in self.sets.get(key, set())

    async def exists(self, *keys):
        """Return how many of *keys* exist in any store (including string keys)."""
        await asyncio.sleep(0)
        return sum(1 for key in keys if self._key_exists(key))


@pytest.fixture(autouse=True)
def setup_vendor_keys(monkeypatch):
    """Map ``VENDOR_ID`` to ``REDIS_KEY`` in ``app.main.vendor_keys`` for every test.

    ``monkeypatch.setitem`` undoes the change at teardown. If an entry for
    ``VENDOR_ID`` existed before the test, it is restored rather than popped,
    so state is not leaked to later tests or modules.
    """
    monkeypatch.setitem(vendor_keys, VENDOR_ID, REDIS_KEY)
    yield


def _make_stored_value(assigned_time, is_blocked=False):
    """Build the JSON assignment record that handle_expired_prescription receives.

    Keys, in order: ``vendor_id`` (the test vendor), ``assignment_time``
    (stringified *assigned_time*) and ``is_blocked``.
    """
    record = {}
    record["vendor_id"] = VENDOR_ID
    record["assignment_time"] = str(assigned_time)
    record["is_blocked"] = is_blocked
    return json.dumps(record)


def _expired_time():
    """Return a time far enough in the past to be expired."""
    return int(time.time()) - 9999


def _fresh_time():
    """Return a time recent enough to not be expired."""
    return int(time.time())


def test_expired_non_blocked_returns_doctor_to_pool(monkeypatch):
    """When a non-blocked prescription expires, the assigned doctor should be rpush'd back to the pool.

    The doctor is returned through ``redis.eval`` (the atomic add-if-absent
    script), so the assertions inspect the eval call's positional arguments
    ``(script, num_keys, pool_key, doctor_id)``.
    """
    mock_redis = AsyncMock()
    mock_redis_client = AsyncMock()
    # ``raising=False``: app.main may not define ``redis_client`` at all.
    monkeypatch.setattr(main, "redis", mock_redis)
    monkeypatch.setattr(main, "redis_client", mock_redis_client, raising=False)

    # side_effect is consumed in call order, so the handler's hget sequence matters.
    mock_redis.hget = AsyncMock(side_effect=[
        "doc_1",               # ASSIGNED_DOCTORS_KEY lookup
        json.dumps({"is_blocked": False}),  # RX_ASSIGNMENT_MAPPING_KEY lock section
    ])
    mock_redis.hgetall = AsyncMock(return_value={})
    mock_redis.eval = AsyncMock(return_value=1)  # doctor added successfully
    mock_redis_client.hdel = AsyncMock()

    # Lock always acquired; requeueing is stubbed out.
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())
    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())

    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())

    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    mock_redis.eval.assert_called_once()
    call_args = mock_redis.eval.call_args
    # Positional args: [2] is the vendor pool key, [3] the doctor returned to it.
    assert call_args[0][2] == REDIS_KEY
    assert call_args[0][3] == "doc_1"


def test_expired_non_blocked_skips_doctor_already_in_pool(monkeypatch):
    """The atomic add is attempted once, and its "already present" result is tolerated.

    The mocked ``eval`` returns 0, which is the script's result when the
    doctor is already in the pool. The handler must still target the vendor
    pool with the assigned doctor, exactly once, and finish without raising.
    """
    mock_redis = AsyncMock()
    mock_redis_client = AsyncMock()
    # ``raising=False``: app.main may not define ``redis_client`` at all.
    monkeypatch.setattr(main, "redis", mock_redis)
    monkeypatch.setattr(main, "redis_client", mock_redis_client, raising=False)

    # Consumed in call order: assigned-doctor lookup, then the assignment mapping.
    mock_redis.hget = AsyncMock(side_effect=[
        "doc_1",
        json.dumps({"is_blocked": False}),
    ])
    mock_redis.hgetall = AsyncMock(return_value={})
    mock_redis.eval = AsyncMock(return_value=0)  # doctor already in pool
    mock_redis_client.hdel = AsyncMock()

    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())
    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())

    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())

    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    mock_redis.eval.assert_called_once()
    # Without these checks the test would pass even if eval targeted the wrong
    # pool or doctor; positional args are (script, num_keys, pool_key, doctor_id).
    call_args = mock_redis.eval.call_args
    assert call_args[0][2] == REDIS_KEY
    assert call_args[0][3] == "doc_1"


def test_expired_non_blocked_returns_multiple_doctors(monkeypatch):
    """When multiple doctors are assigned (concurrent > 1), all should be returned to pool.

    The assigned-doctors value is a comma-separated string. One eval call is
    expected per doctor, whatever each individual eval returns.
    """
    mock_redis = AsyncMock()
    mock_redis_client = AsyncMock()
    monkeypatch.setattr(main, "redis", mock_redis)
    monkeypatch.setattr(main, "redis_client", mock_redis_client, raising=False)

    # doc_2 already in pool (eval returns 0), doc_1 and doc_3 added (eval returns 1)
    mock_redis.hget = AsyncMock(side_effect=[
        "doc_1,doc_2,doc_3",   # three assigned doctors
        json.dumps({"is_blocked": False}),
    ])
    mock_redis.hgetall = AsyncMock(return_value={})
    # NOTE(review): the 1/0/1 pairing assumes the handler iterates doctors in
    # string order; the assertions below do not depend on that order.
    mock_redis.eval = AsyncMock(side_effect=[1, 0, 1])  # doc_1=added, doc_2=skipped, doc_3=added
    mock_redis_client.hdel = AsyncMock()

    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())
    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())

    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())

    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    # Exactly three calls plus the three membership checks means each doctor
    # was attempted exactly once.
    assert mock_redis.eval.call_count == 3
    calls = [c[0][3] for c in mock_redis.eval.call_args_list]  # 4th arg is doctor_id
    assert "doc_1" in calls
    assert "doc_2" in calls
    assert "doc_3" in calls


def test_non_expired_prescription_does_nothing(monkeypatch):
    """A prescription that hasn't expired yet should not trigger any requeue or doctor return.

    Doctors go back to the pool via the atomic ``redis.eval`` script, so both
    ``eval`` and a raw ``rpush`` must stay untouched. Requeueing
    (``add_prescription_to_queue``) and the ``hdel`` cleanup must not happen
    either.
    """
    mock_redis = AsyncMock()
    mock_redis_client = AsyncMock()
    monkeypatch.setattr(main, "redis", mock_redis)
    monkeypatch.setattr(main, "redis_client", mock_redis_client, raising=False)

    mock_redis.hget = AsyncMock(side_effect=[
        "doc_1",
        json.dumps({"is_blocked": False}),
    ])
    mock_redis.hgetall = AsyncMock(return_value={})
    mock_redis.lrange = AsyncMock(return_value=[])
    mock_redis.rpush = AsyncMock()
    mock_redis.eval = AsyncMock(return_value=1)
    mock_redis_client.hdel = AsyncMock()

    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())
    mock_add = AsyncMock()
    monkeypatch.setattr(main, "add_prescription_to_queue", mock_add)

    stored_value = _make_stored_value(_fresh_time(), is_blocked=False)
    current_time = int(time.time())

    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    mock_redis.rpush.assert_not_called()
    # The doctor-return path uses eval (safe_rpush_doctor); checking only rpush
    # would miss a premature return of the doctor to the pool.
    mock_redis.eval.assert_not_called()
    mock_add.assert_not_called()
    mock_redis_client.hdel.assert_not_called()


def test_expired_blocked_returns_doctor_to_pool(monkeypatch):
    """Regression guard: blocked expiry path should also return doctors to pool (existing behavior).

    The blocked path issues an HTTP cancel request through
    ``aiohttp.ClientSession``. That request is stubbed to answer 200.
    """
    mock_redis = AsyncMock()
    mock_redis_client = AsyncMock()
    monkeypatch.setattr(main, "redis", mock_redis)
    monkeypatch.setattr(main, "redis_client", mock_redis_client, raising=False)

    mock_redis.hget = AsyncMock(side_effect=[
        "doc_1",
        json.dumps({"is_blocked": True}),
    ])
    mock_redis.hgetall = AsyncMock(return_value={})
    mock_redis.eval = AsyncMock(return_value=1)  # doctor added successfully
    mock_redis_client.hdel = AsyncMock()

    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    # Mock the aiohttp cancel request to return 200
    mock_response = AsyncMock()
    mock_response.status = 200

    mock_post_ctx = AsyncMock()
    mock_post_ctx.__aenter__ = AsyncMock(return_value=mock_response)
    mock_post_ctx.__aexit__ = AsyncMock(return_value=False)

    mock_session = AsyncMock()
    mock_session.post = lambda *args, **kwargs: mock_post_ctx

    mock_session_ctx = AsyncMock()
    mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session_ctx.__aexit__ = AsyncMock(return_value=False)

    # Accept any constructor arguments (e.g. timeout=...) so the stub does not
    # fail with a TypeError unrelated to the behavior under test.
    monkeypatch.setattr(
        main.aiohttp, "ClientSession", lambda *args, **kwargs: mock_session_ctx
    )

    stored_value = _make_stored_value(_expired_time(), is_blocked=True)
    current_time = int(time.time())

    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    mock_redis.eval.assert_called_once()
    call_args = mock_redis.eval.call_args
    # Positional args are (script, num_keys, pool_key, doctor_id).
    assert call_args[0][2] == REDIS_KEY
    assert call_args[0][3] == "doc_1"


def test_expired_requeues_and_next_cycle_assigns_next_doctor(monkeypatch):
    """Expiry requeues the Rx and returns its doctor; the next cycle picks the next doctor.

    1. The pool starts as [doc_1, doc_2], and rx_001 is assigned doc_1.
    2. rx_001 expires: the pool becomes [doc_2, doc_1] and the queue [rx_001].
    3. ``process_prescription`` then assigns doc_2, leaving the pool [doc_1]
       and the queue empty.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1", "doc_2"]

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=1, immediate_unblock=False
                )
            ]
        ),
    )

    monkeypatch.setattr(main, "add_prescription_to_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "remove_prescription_from_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    asyncio.run(assign_doctor(
        SimpleNamespace(prescription_id="rx_001", vendor_id=VENDOR_ID)
    ))

    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())
    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    pool_after_expiry = asyncio.run(fake_redis.lrange(REDIS_KEY, 0, -1))
    assert pool_after_expiry == ["doc_2", "doc_1"]
    queue_after_expiry = asyncio.run(fake_redis.lrange(PRESCRIPTION_QUEUE, 0, -1))
    assert queue_after_expiry == ["rx_001"]

    # Stub the outbound HTTP call made while processing the requeued Rx.
    mock_response = AsyncMock()
    mock_response.status = 200

    mock_post_ctx = AsyncMock()
    mock_post_ctx.__aenter__ = AsyncMock(return_value=mock_response)
    mock_post_ctx.__aexit__ = AsyncMock(return_value=False)

    mock_session = AsyncMock()
    mock_session.post = lambda *args, **kwargs: mock_post_ctx

    mock_session_ctx = AsyncMock()
    mock_session_ctx.__aenter__ = AsyncMock(return_value=mock_session)
    mock_session_ctx.__aexit__ = AsyncMock(return_value=False)

    # Accept any constructor arguments (e.g. timeout=...) so the stub does not
    # fail with a TypeError unrelated to the behavior under test.
    monkeypatch.setattr(
        main.aiohttp, "ClientSession", lambda *args, **kwargs: mock_session_ctx
    )

    asyncio.run(process_prescription("rx_001", VENDOR_ID))

    assigned_doctor = asyncio.run(fake_redis.hget(ASSIGNED_DOCTORS_KEY, "rx_001"))
    assert assigned_doctor == "doc_2"
    pool_after_reassign = asyncio.run(fake_redis.lrange(REDIS_KEY, 0, -1))
    assert pool_after_reassign == ["doc_1"]
    queue_after_reassign = asyncio.run(fake_redis.lrange(PRESCRIPTION_QUEUE, 0, -1))
    assert queue_after_reassign == []


def test_blocked_doctor_not_reassigned_after_other_rx_expiry(monkeypatch):
    """A doctor blocked on one Rx must not return to the pool when another Rx expires.

    1. doc_1, the only pooled doctor, is assigned to rx_b.
    2. doc_1 is blocked for rx_a via ``block_doctor``.
    3. rx_b expires.

    Expected: the pool stays empty, a new assignment (rx_c) gets no doctors,
    and rx_a's assignment record has ``is_blocked`` set.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1"]

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=1, immediate_unblock=False
                )
            ]
        ),
    )

    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())
    monkeypatch.setattr(main, "remove_prescription_from_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    # rx_b takes doc_1, emptying the pool.
    asyncio.run(
        assign_doctor(SimpleNamespace(prescription_id="rx_b", vendor_id=VENDOR_ID))
    )

    # Seed rx_a's prescription record and queue entry by hand
    # (presumably what block_doctor expects to find; verify against app.main).
    asyncio.run(
        fake_redis.hset(
            RX_PRESCRIPTION_MAPPING_KEY,
            "rx_a",
            json.dumps(
                {"prescription_create_time": int(time.time()), "vendor_id": VENDOR_ID}
            ),
        )
    )
    asyncio.run(fake_redis.rpush(PRESCRIPTION_QUEUE, "rx_a"))

    asyncio.run(
        block_doctor(
            SimpleNamespace(
                block_doctor_id="doc_1",
                vendor_id=VENDOR_ID,
                prescription_id="rx_a",
            )
        )
    )

    # rx_b's own record is non-blocked, but doc_1 is now blocked on rx_a.
    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())
    asyncio.run(handle_expired_prescription("rx_b", stored_value, current_time))

    pool_after_expiry = asyncio.run(fake_redis.lrange(REDIS_KEY, 0, -1))
    assert pool_after_expiry == []

    response = asyncio.run(
        assign_doctor(SimpleNamespace(prescription_id="rx_c", vendor_id=VENDOR_ID))
    )
    assert response.data.doctor_ids == []

    blocked_mapping = asyncio.run(
        fake_redis.hget(RX_ASSIGNMENT_MAPPING_KEY, "rx_a")
    )
    assert json.loads(blocked_mapping)["is_blocked"] is True


# ---- Tests for safe_rpush_doctor and duplicate prevention ----


def test_safe_rpush_doctor_adds_when_not_in_pool():
    """A doctor missing from the pool is appended at the tail, and True is returned."""
    from app.main import safe_rpush_doctor

    redis_stub = FakeRedis()
    redis_stub.lists[REDIS_KEY] = ["doc_2"]

    was_added = asyncio.run(safe_rpush_doctor(redis_stub, REDIS_KEY, "doc_1"))

    expected_pool = ["doc_2", "doc_1"]
    assert was_added is True
    assert redis_stub.lists[REDIS_KEY] == expected_pool


def test_safe_rpush_doctor_skips_when_already_in_pool():
    """A doctor already in the pool is not re-added: False is returned and the pool is unchanged."""
    from app.main import safe_rpush_doctor

    initial_pool = ["doc_1", "doc_2"]
    redis_stub = FakeRedis()
    redis_stub.lists[REDIS_KEY] = list(initial_pool)

    was_added = asyncio.run(safe_rpush_doctor(redis_stub, REDIS_KEY, "doc_1"))

    assert was_added is False
    assert redis_stub.lists[REDIS_KEY] == initial_pool


def test_safe_rpush_doctor_handles_empty_pool():
    """Adding to an empty pool yields a single-element pool and returns True."""
    from app.main import safe_rpush_doctor

    redis_stub = FakeRedis()
    redis_stub.lists[REDIS_KEY] = []

    was_added = asyncio.run(safe_rpush_doctor(redis_stub, REDIS_KEY, "doc_1"))

    assert was_added is True
    assert redis_stub.lists[REDIS_KEY] == ["doc_1"]


def test_concurrent_expiry_no_duplicate_doctors(monkeypatch):
    """
    Reproduce the race condition: two prescriptions expire concurrently,
    both assigned to the same doctor. The doctor should appear only once
    in the pool after both expiry handlers complete.

    NOTE(review): under ``asyncio.gather`` the handlers only interleave at
    awaits that actually suspend. Confirm that the fake client's commands
    yield; otherwise the two handlers may simply run back to back.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = []  # pool is empty, doctor is assigned

    # Set up two prescriptions assigned to the same doctor
    fake_redis.hashes[ASSIGNED_DOCTORS_KEY] = {
        "rx_001": "doc_1",
        "rx_002": "doc_1",
    }

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    # The lock always succeeds, so it cannot serialize the two handlers.
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())
    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())

    expired_time = _expired_time()
    current_time = int(time.time())
    stored_value_1 = _make_stored_value(expired_time, is_blocked=False)
    stored_value_2 = _make_stored_value(expired_time, is_blocked=False)

    # Run both expiry handlers concurrently
    async def run_concurrent():
        await asyncio.gather(
            handle_expired_prescription("rx_001", stored_value_1, current_time),
            handle_expired_prescription("rx_002", stored_value_2, current_time),
        )

    asyncio.run(run_concurrent())

    pool = fake_redis.lists[REDIS_KEY]
    assert pool.count("doc_1") == 1, f"Doctor doc_1 appears {pool.count('doc_1')} times in pool: {pool}"


def test_concurrent_expiry_multiple_doctors_no_duplicates(monkeypatch):
    """
    Two prescriptions expire concurrently, each assigned to two doctors
    with one shared doctor. No doctor should be duplicated in the pool.

    Expected final pool: doc_1, doc_2 and doc_3 exactly once each (order not
    checked).
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = []

    # rx_001 has doc_1,doc_2 and rx_002 has doc_2,doc_3 — doc_2 is shared
    fake_redis.hashes[ASSIGNED_DOCTORS_KEY] = {
        "rx_001": "doc_1,doc_2",
        "rx_002": "doc_2,doc_3",
    }

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    # The lock always succeeds, so it cannot serialize the two handlers.
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())
    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())

    expired_time = _expired_time()
    current_time = int(time.time())
    stored_value_1 = _make_stored_value(expired_time, is_blocked=False)
    stored_value_2 = _make_stored_value(expired_time, is_blocked=False)

    async def run_concurrent():
        await asyncio.gather(
            handle_expired_prescription("rx_001", stored_value_1, current_time),
            handle_expired_prescription("rx_002", stored_value_2, current_time),
        )

    asyncio.run(run_concurrent())

    pool = fake_redis.lists[REDIS_KEY]
    assert pool.count("doc_1") == 1, f"doc_1 duplicated: {pool}"
    assert pool.count("doc_2") == 1, f"doc_2 duplicated: {pool}"
    assert pool.count("doc_3") == 1, f"doc_3 duplicated: {pool}"
    assert len(pool) == 3


def test_release_doctor_no_duplicate(monkeypatch):
    """
    Releasing a doctor who is already in the pool should not create a duplicate.

    Locks are not stubbed here, so ``release_doctor`` runs against the fake
    client end to end. Only ``pop_prescription_from_vendor_map`` is mocked.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1"]  # doc_1 already in pool
    fake_redis.hashes[ASSIGNED_DOCTORS_KEY] = {"rx_001": "doc_1"}
    # Fresh, non-blocked assignment record for rx_001.
    fake_redis.hashes[RX_ASSIGNMENT_MAPPING_KEY] = {
        "rx_001": json.dumps({
            "assignment_time": str(int(time.time())),
            "vendor_id": VENDOR_ID,
            "is_blocked": False,
        })
    }

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(main, "pop_prescription_from_vendor_map", AsyncMock(return_value=None))

    from app.main import release_doctor
    asyncio.run(release_doctor(
        SimpleNamespace(doctor_id="doc_1", vendor_id=VENDOR_ID, prescription_id="rx_001")
    ))

    pool = fake_redis.lists[REDIS_KEY]
    assert pool.count("doc_1") == 1, f"doc_1 duplicated after release: {pool}"


def test_repeated_expiry_requeue_no_duplicate(monkeypatch):
    """
    Simulate the scenario from production logs: a single doctor with concurrent_doctors=1
    gets assigned, expires, requeued, assigned again, expires again. The pool should
    never contain duplicates.

    Each of the 5 cycles checks two invariants:
    - the pool is empty right after assignment;
    - doc_1 appears exactly once right after expiry.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1"]

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=1, immediate_unblock=False
                )
            ]
        ),
    )
    monkeypatch.setattr(main, "add_prescription_to_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "remove_prescription_from_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    for cycle in range(5):
        # Assign doctor to prescription
        asyncio.run(assign_doctor(
            SimpleNamespace(prescription_id="rx_001", vendor_id=VENDOR_ID)
        ))
        assert fake_redis.lists[REDIS_KEY] == [], f"Cycle {cycle}: pool should be empty after assign"

        # Expire the prescription
        stored_value = _make_stored_value(_expired_time(), is_blocked=False)
        current_time = int(time.time())
        asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

        pool = fake_redis.lists[REDIS_KEY]
        assert pool.count("doc_1") == 1, f"Cycle {cycle}: doc_1 appears {pool.count('doc_1')} times: {pool}"


# ---- Race condition test: proves atomic eval prevents duplicates ----


def test_old_approach_race_condition_causes_duplicates():
    """
    Show why the old check-then-push was unsafe.

    With the non-atomic LPOS-then-RPUSH sequence, two concurrent callers can
    both observe "not present" and both push, producing a duplicate. An
    ``asyncio.Event`` acts as a barrier: the first caller to pass the check
    waits until the second has passed it too, forcing the bad interleaving
    deterministically.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = []

    async def run_race():
        both_checked = asyncio.Event()
        checks_passed = 0

        async def old_safe_rpush(redis_conn, redis_key, doctor_id):
            """Replica of the OLD non-atomic safe_rpush_doctor, with a forced interleave."""
            nonlocal checks_passed
            if await redis_conn.lpos(redis_key, doctor_id) is not None:
                return False
            checks_passed += 1
            # The second caller through the check releases the first; only
            # then does either of them push.
            if checks_passed >= 2:
                both_checked.set()
            else:
                await both_checked.wait()
            await redis_conn.rpush(redis_key, doctor_id)
            return True

        racers = [old_safe_rpush(fake_redis, REDIS_KEY, "doc_1") for _ in range(2)]
        return await asyncio.gather(*racers)

    results = asyncio.run(run_race())

    pool = fake_redis.lists[REDIS_KEY]
    # Both tasks saw lpos=None and both did rpush → DUPLICATE
    assert results == [True, True], f"Expected both to succeed (race), got {results}"
    assert pool.count("doc_1") == 2, f"Expected duplicate, got {pool}"


def test_atomic_eval_prevents_race_condition():
    """
    Show the fix holds.

    safe_rpush_doctor goes through a single atomic ``eval``. When two
    concurrent calls try to add the same doctor, exactly one succeeds and the
    other reports "already present".
    """
    from app.main import safe_rpush_doctor

    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = []

    async def add_twice():
        attempts = [safe_rpush_doctor(fake_redis, REDIS_KEY, "doc_1") for _ in range(2)]
        return await asyncio.gather(*attempts)

    results = asyncio.run(add_twice())

    pool = fake_redis.lists[REDIS_KEY]
    # eval is atomic: first call adds, second call sees it and skips
    assert results.count(True) == 1, f"Expected exactly one success, got {results}"
    assert results.count(False) == 1, f"Expected exactly one skip, got {results}"
    assert pool.count("doc_1") == 1, f"Expected no duplicate, got {pool}"


def test_add_doctor_rejects_assigned_doctor_with_concurrent(monkeypatch):
    """
    Reproduce the add_doctor bug with concurrent_doctors=2.

    ``assigned_doctors`` stores ``'doc_aaa,doc_bbb'``. The old code compared
    the whole string ``== doctor_id``, which never matched; the fix uses
    ``split(',')`` plus an ``in`` check. A doctor still assigned to an Rx must
    be rejected and must not end up in any pool list.
    """
    pool_key = f"{main.DOCTORS_POOL}_{VENDOR_ID}"
    fake_redis = FakeRedis()
    fake_redis.lists[pool_key] = ["doc_ccc"]
    fake_redis.hashes[ASSIGNED_DOCTORS_KEY] = {
        "rx_010": "doc_aaa,doc_bbb",  # Two doctors assigned (concurrent_doctors=2)
    }

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)

    from app.main import add_doctor
    from app.models import DoctorRequest

    response = asyncio.run(add_doctor(
        doctor_request=DoctorRequest(id="doc_aaa", vendor_id=VENDOR_ID),
    ))

    # Should be rejected because doc_aaa is assigned to rx_010
    assert "currently assigned" in response.data.message
    pool = fake_redis.lists[pool_key]
    assert "doc_aaa" not in pool, f"doc_aaa should NOT be in pool: {pool}"
    # The autouse fixture also maps VENDOR_ID -> REDIS_KEY. Checking only
    # pool_key would pass vacuously if add_doctor wrote to another list, so
    # make sure doc_aaa was not pushed into *any* list.
    leaked = {key: items for key, items in fake_redis.lists.items() if "doc_aaa" in items}
    assert not leaked, f"doc_aaa should NOT be in any list: {leaked}"


# ---- Round-robin tests ----


def test_round_robin_skips_tried_doctors(monkeypatch):
    """
    With concurrent_doctors=1 and 2 doctors / 1 prescription, after expiry
    the prescription should get a DIFFERENT doctor (not the one it had before).
    This is the simplest case to prove round-robin works.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1", "doc_2"]

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=1, immediate_unblock=False
                )
            ]
        ),
    )
    monkeypatch.setattr(main, "add_prescription_to_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "remove_prescription_from_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    # Round 1: doc_1 assigned to rx_001
    r1 = asyncio.run(assign_doctor(
        SimpleNamespace(prescription_id="rx_001", vendor_id=VENDOR_ID)
    ))
    assert r1.data.doctor_ids == ["doc_1"]
    assert fake_redis.lists[REDIS_KEY] == ["doc_2"]  # doc_1 removed

    # Expire rx_001 → doc_1 returned to pool, rx_001 requeued with tried_doctors=[doc_1]
    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())
    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    assert fake_redis.lists[REDIS_KEY] == ["doc_2", "doc_1"]  # doc_1 at end

    # Round 2: Should skip doc_1 (tried) and assign doc_2
    r2 = asyncio.run(assign_doctor(
        SimpleNamespace(prescription_id="rx_001", vendor_id=VENDOR_ID)
    ))
    assert r2.data.doctor_ids == ["doc_2"], \
        f"Expected doc_2 (round-robin), got {r2.data.doctor_ids}"


def test_round_robin_resets_when_all_tried(monkeypatch):
    """
    When all doctors have been tried, the tried list resets and doctors
    can be assigned again.

    With only doc_1 in the pool, rx_001 must get doc_1 again after expiry
    rather than no doctor at all.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1"]

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=1, immediate_unblock=False
                )
            ]
        ),
    )
    monkeypatch.setattr(main, "add_prescription_to_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "remove_prescription_from_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    # Round 1: Only doc_1 available → assigned
    r1 = asyncio.run(assign_doctor(
        SimpleNamespace(prescription_id="rx_001", vendor_id=VENDOR_ID)
    ))
    assert r1.data.doctor_ids == ["doc_1"]

    # Expire → doc_1 returned, tried_doctors=[doc_1]
    stored_value = _make_stored_value(_expired_time(), is_blocked=False)
    current_time = int(time.time())
    asyncio.run(handle_expired_prescription("rx_001", stored_value, current_time))

    # Round 2: Only doc_1 in pool, but all tried → resets → assigns doc_1 again
    r2 = asyncio.run(assign_doctor(
        SimpleNamespace(prescription_id="rx_001", vendor_id=VENDOR_ID)
    ))
    assert r2.data.doctor_ids == ["doc_1"], \
        f"Expected doc_1 (reset after all tried), got {r2.data.doctor_ids}"


def test_round_robin_three_doctors_three_prescriptions(monkeypatch):
    """
    Check round-robin with 3 doctors and 3 prescriptions (concurrent_doctors=1).

    Round 1 assigns doc_1/doc_2/doc_3 to rx_001/rx_002/rx_003. After every Rx
    expires, round 2 should give:

    - rx_001 (tried doc_1) → doc_2
    - rx_002 (tried doc_2) → doc_1
    - rx_003 (tried doc_3) → only doc_3 is left, the tried list resets, so it
      gets doc_3 again (acceptable)

    At minimum the first two prescriptions must change doctors, and all three
    doctors must still be handed out.
    """
    prescription_ids = ["rx_001", "rx_002", "rx_003"]

    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["doc_1", "doc_2", "doc_3"]

    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=1, immediate_unblock=False
                )
            ]
        ),
    )
    monkeypatch.setattr(main, "add_prescription_to_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "remove_prescription_from_vendor_map", AsyncMock())
    monkeypatch.setattr(main, "acquire_redis_lock", AsyncMock(return_value="lock123"))
    monkeypatch.setattr(main, "release_redis_lock", AsyncMock())

    def first_assigned_doctor(rx_id):
        """Run one assignment for *rx_id* and return the first doctor it received."""
        request = SimpleNamespace(prescription_id=rx_id, vendor_id=VENDOR_ID)
        return asyncio.run(assign_doctor(request)).data.doctor_ids[0]

    # Round 1: assign in order (dict comprehensions evaluate sequentially).
    r1 = {rx_id: first_assigned_doctor(rx_id) for rx_id in prescription_ids}
    assert r1 == {"rx_001": "doc_1", "rx_002": "doc_2", "rx_003": "doc_3"}

    # Expire all
    for rx_id in prescription_ids:
        stored_value = _make_stored_value(_expired_time(), is_blocked=False)
        current_time = int(time.time())
        asyncio.run(handle_expired_prescription(rx_id, stored_value, current_time))

    # Round 2: Assign again
    r2 = {rx_id: first_assigned_doctor(rx_id) for rx_id in prescription_ids}

    # rx_001 should NOT get doc_1 again, rx_002 should NOT get doc_2 again
    assert r2["rx_001"] != "doc_1", f"rx_001 got doc_1 again: {r2}"
    assert r2["rx_002"] != "doc_2", f"rx_002 got doc_2 again: {r2}"
    # rx_003 may get doc_3 again (only option left after others take untried docs)
    # All 3 doctors should be assigned (no leaked doctors)
    assert set(r2.values()) == {"doc_1", "doc_2", "doc_3"}, \
        f"Not all doctors used: {r2}"
