"""
Regression test for the assign_doctor double-assignment race.

assign_doctor reads the whole pool once (LRANGE), then removes each chosen
doctor with LREM and appends them to the result. In the real system every Redis
call is an awaited network round-trip, so two assign_doctor calls for
*different* prescriptions sharing one pool can interleave:

  - both read the same pool snapshot [D1, D2, D3, D4]
  - call A lrem's D1..D4 (removes them)
  - call B lrem's D1..D4 -> removes 0 (already gone) BUT still appends them

=> the same doctor is offered to TWO prescriptions at once.

RaceFakeRedis below models the round-trip by yielding inside lrange (after
capturing the snapshot), and its lrem returns the real LREM count. The fix is to
append a doctor only when lrem actually removed them (removed >= 1): LREM runs
atomically, so exactly one concurrent caller observes the removal and "claims"
the doctor, and no doctor can land on two prescriptions.

This test asserts that invariant, so it FAILS against the current code and
PASSES once assign_doctor checks the lrem return value.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

import app.main as main
from app.main import assign_doctor, vendor_keys

# Vendor id that the autouse setup_vendor_keys fixture maps to REDIS_KEY in
# main.vendor_keys for every test in this module.
VENDOR_ID = "vendor_race"
# Key of the Redis list holding this vendor's doctor pool in the fake Redis.
REDIS_KEY = "vendor_race_pool"


class RaceFakeRedis:
    """In-memory async stand-in for the Redis client used by assign_doctor.

    Two behaviours matter for the race test:

    * ``lrange`` yields control to the event loop *after* capturing its
      snapshot, modelling a network round-trip so a concurrent caller can read
      the same (soon stale) pool.
    * ``lrem`` returns the number of elements it removed and honours the sign
      of ``count`` like real LREM, so callers can use it as a "claim" check.

    Every other method completes without awaiting anything, so it runs without
    interruption relative to other coroutines on the same event loop.
    """

    def __init__(self):
        self.lists = {}  # key -> list of values (Redis lists)
        self.hashes = {}  # key -> {field: value} (Redis hashes)

    async def lrange(self, key, start, end):
        """Return elements ``start..end`` (inclusive) of the list at ``key``.

        ``end == -1`` means "through the last element", as in Redis. A missing
        key behaves like an empty list.
        """
        items = self.lists.get(key, [])
        # Slicing copies, so later changes to the list can't leak into the
        # snapshot. end + 1 converts Redis's inclusive end to Python's exclusive
        # one, except for -1, where end + 1 == 0 would give an empty slice.
        snapshot = items[start:] if end == -1 else items[start : end + 1]
        # Model the network round-trip: yield AFTER capturing the snapshot, so a
        # concurrent assign_doctor can read the same (now stale) pool.
        await asyncio.sleep(0)
        return snapshot

    async def lrem(self, key, count, value):
        """Remove occurrences of ``value`` from the list at ``key`` (LREM).

        * ``count > 0`` removes up to ``count`` matches, scanning head to tail.
        * ``count < 0`` removes up to ``-count`` matches, scanning tail to head.
        * ``count == 0`` removes every match.

        Returns:
            The number of elements actually removed (0 when nothing matched or
            the key is missing).
        """
        items = self.lists.get(key, [])
        if count == 0:
            kept = [item for item in items if item != value]
        else:
            limit = abs(count)
            # Negative counts scan from the tail; the order is restored below.
            scan = items if count > 0 else items[::-1]
            kept, removed = [], 0
            for item in scan:
                if item == value and removed < limit:
                    removed += 1
                    continue
                kept.append(item)
            if count < 0:
                kept.reverse()
        self.lists[key] = kept
        return len(items) - len(kept)

    async def rpush(self, key, *values):
        """Append ``values`` to the tail of the list at ``key`` (RPUSH).

        Like the real command, this accepts one or more values and returns the
        new length of the list.
        """
        bucket = self.lists.setdefault(key, [])
        bucket.extend(values)
        return len(bucket)

    async def hget(self, key, field):
        """Return ``field`` of the hash at ``key``, or None if either is missing."""
        return self.hashes.get(key, {}).get(field)

    async def hset(self, key, field, value):
        """Set ``field`` in the hash at ``key``; return 1 if it was new, else 0."""
        fields = self.hashes.setdefault(key, {})
        is_new = field not in fields
        fields[field] = value
        return int(is_new)

    async def hdel(self, key, field):
        """Delete ``field`` from the hash at ``key``; return how many were removed."""
        fields = self.hashes.get(key)
        if fields is None or field not in fields:
            return 0
        del fields[field]
        return 1


@pytest.fixture(autouse=True)
def setup_vendor_keys(monkeypatch):
    """Map VENDOR_ID to REDIS_KEY in ``vendor_keys`` for the duration of a test.

    ``monkeypatch.setitem`` restores ``vendor_keys`` to its exact prior state on
    teardown. It re-inserts any pre-existing entry for VENDOR_ID, or removes the
    key if it was absent, rather than unconditionally popping it.
    """
    monkeypatch.setitem(vendor_keys, VENDOR_ID, REDIS_KEY)


def test_concurrent_assign_does_not_offer_one_doctor_to_two_prescriptions(monkeypatch):
    """Two concurrent assign_doctor calls on one pool must never share a doctor.

    Both calls read the same pool snapshot (RaceFakeRedis.lrange yields after
    capturing it). Only checking lrem's removal count can stop both of them from
    offering the same doctors.
    """
    pool = ["D1", "D2", "D3", "D4"]
    fake_redis = RaceFakeRedis()
    # Copy so `pool` stays an untouched reference for the final checks.
    fake_redis.lists[REDIS_KEY] = list(pool)

    # Route Redis access in app.main through the fake; `redis_client` may not
    # exist in main, hence raising=False.
    monkeypatch.setattr(main, "redis", fake_redis)
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)
    # A single vendor entry allowing 4 concurrent doctors (presumably enough for
    # one call to claim the whole 4-doctor pool).
    monkeypatch.setattr(
        main,
        "doctors_data",
        SimpleNamespace(
            data=[
                SimpleNamespace(
                    vendor=VENDOR_ID, concurrent_doctors=4, immediate_unblock=False
                )
            ]
        ),
    )
    # Stub out queueing so the test exercises only the assignment logic.
    monkeypatch.setattr(main, "add_prescription_to_queue", AsyncMock())

    async def run():
        # gather lets the two calls interleave at every real suspension point.
        return await asyncio.gather(
            assign_doctor(SimpleNamespace(prescription_id="P1", vendor_id=VENDOR_ID)),
            assign_doctor(SimpleNamespace(prescription_id="P2", vendor_id=VENDOR_ID)),
        )

    resp1, resp2 = asyncio.run(run())
    docs1 = list(resp1.data.doctor_ids)
    docs2 = list(resp2.data.doctor_ids)

    overlap = set(docs1) & set(docs2)
    assert not overlap, (
        f"assign_doctor offered the same doctor(s) {overlap} to two prescriptions "
        f"(P1={docs1}, P2={docs2})"
    )

    # Pool entries are distinct, so every assigned doctor must be distinct too.
    # This also catches a doctor duplicated within a single prescription.
    all_assigned = docs1 + docs2
    assert len(all_assigned) == len(set(all_assigned)), (
        f"a doctor was assigned more than once across prescriptions: {all_assigned}"
    )

    # Guard against a vacuous pass: the invariants above hold trivially if no
    # doctor is assigned at all. Every offered doctor must also come from the pool.
    assert all_assigned, "assign_doctor assigned no doctors to either prescription"
    unknown = set(all_assigned) - set(pool)
    assert not unknown, f"assigned doctors that were never in the pool: {unknown}"
