"""
Regression test for the block_doctor pool leak / prescription starvation.

Scenario (the "4 doctors active, concurrent_doctors=4" report):
  - 4 doctors are in the pool.
  - A prescription P1 is offered to all 4; the pool drains.
  - One doctor accepts P1 (block_doctor). The other 3 did NOT take it, so they
    are free and must be returned to the pool -- this is the same
    "rotate back to the end of the queue" invariant that release_doctor,
    reassign_doctor and the expiry-requeue all maintain.
  - If block_doctor fails to return them, the next prescription P2 finds an
    empty pool and STARVES, even though 3 doctors are idle.

This test asserts the correct (no-starvation) behaviour: it FAILS against a
block_doctor that leaks the non-accepting doctors (demonstrating the bug) and
PASSES once block_doctor returns them to the pool.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

import app.main as main
from app.main import assign_doctor, block_doctor, vendor_keys

# Synthetic vendor used only by this test module, and the Redis list key that
# holds its doctor pool (registered in ``vendor_keys`` by the autouse fixture).
VENDOR_ID = "vendor_leak"
REDIS_KEY = f"{VENDOR_ID}_pool"


class FakeRedis:
    """In-memory Redis fake with real list/hash/set semantics (subset used here).

    State lives in three plain dicts keyed by Redis key: ``lists`` (key -> list),
    ``hashes`` (key -> dict) and ``sets`` (key -> set). Tests may seed or inspect
    them directly. Like real Redis, a key whose container is empty is reported
    as missing by :meth:`exists`, and write commands return the same counts /
    lengths Redis replies with.
    """

    def __init__(self):
        self.lists = {}
        self.hashes = {}
        self.sets = {}

    async def lrange(self, key, start, end):
        """Return elements ``start..end`` inclusive; negative indices count from the tail."""
        items = self.lists.get(key, [])
        # end == -1 means "through the last element", but end + 1 would be 0.
        return items[start:] if end == -1 else items[start : end + 1]

    async def lrem(self, key, count, value):
        """Remove occurrences of ``value`` and return how many were removed.

        Follows Redis LREM: ``count > 0`` removes up to ``count`` matches
        scanning from the head, ``count < 0`` removes up to ``-count`` matches
        scanning from the tail, and ``count == 0`` removes every match.
        """
        items = self.lists.get(key, [])
        if count == 0:
            kept = [item for item in items if item != value]
        else:
            limit = abs(count)
            # Scan from the end Redis scans from; restore original order after.
            scan = items if count > 0 else items[::-1]
            kept, removed = [], 0
            for item in scan:
                if item == value and removed < limit:
                    removed += 1
                    continue
                kept.append(item)
            if count < 0:
                kept.reverse()
        self.lists[key] = kept
        return len(items) - len(kept)

    async def rpush(self, key, *values):
        """Append ``values`` to the tail of the list and return its new length."""
        bucket = self.lists.setdefault(key, [])
        bucket.extend(values)
        return len(bucket)

    async def eval(self, script, num_keys, *args):
        """Mimic safe_rpush_doctor's atomic LPOS+RPUSH: append only if absent.

        ``script`` and ``num_keys`` are ignored; ``args`` is expected to start
        with ``(key, value)``. Returns 1 if ``value`` was appended, 0 if it was
        already present in the list.
        """
        key, value = args[0], args[1]
        if value in self.lists.get(key, []):
            return 0
        self.lists.setdefault(key, []).append(value)
        return 1

    async def hget(self, key, field):
        """Return the value of ``field`` in the hash, or ``None`` if absent."""
        return self.hashes.get(key, {}).get(field)

    async def hgetall(self, key):
        """Return a copy of the whole hash (empty dict if the key is missing)."""
        return dict(self.hashes.get(key, {}))

    async def hset(self, key, field, value):
        """Set ``field`` to ``value``; return 1 if the field is new, else 0."""
        fields = self.hashes.setdefault(key, {})
        is_new = field not in fields
        fields[field] = value
        return int(is_new)

    async def hdel(self, key, *fields):
        """Delete ``fields`` from the hash and return how many were present."""
        bucket = self.hashes.get(key, {})
        removed = 0
        for field in fields:
            if field in bucket:
                del bucket[field]
                removed += 1
        return removed

    async def sadd(self, key, *values):
        """Add ``values`` to the set and return how many were newly added."""
        members = self.sets.setdefault(key, set())
        before = len(members)
        members.update(values)
        return len(members) - before

    async def srem(self, key, *values):
        """Remove ``values`` from the set and return how many were present.

        Unlike a ``setdefault``-based version, a missing key is not created.
        """
        members = self.sets.get(key)
        if not members:
            return 0
        before = len(members)
        members.difference_update(values)
        return before - len(members)

    async def sismember(self, key, value):
        """Return whether ``value`` is a member of the set stored at ``key``."""
        return value in self.sets.get(key, set())

    async def exists(self, *keys):
        """Return how many of ``keys`` exist.

        Redis deletes a key once its list/hash/set becomes empty, so an empty
        container here counts as missing.
        """
        return sum(
            1
            for key in keys
            if self.lists.get(key) or self.hashes.get(key) or self.sets.get(key)
        )


@pytest.fixture(autouse=True)
def setup_vendor_keys(monkeypatch):
    """Register the test vendor's pool key in ``vendor_keys`` for each test.

    ``monkeypatch.setitem`` undoes the change at teardown. It restores any value
    already mapped to ``VENDOR_ID``, or removes the entry if there was none.
    This means the test never clobbers shared module state. (A bare
    ``pop`` would delete a pre-existing mapping.)
    """
    monkeypatch.setitem(vendor_keys, VENDOR_ID, REDIS_KEY)


def _wire(monkeypatch, fake_redis):
    """Point ``app.main`` at ``fake_redis`` and stub its side helpers.

    Installs ``fake_redis`` as ``main.redis`` and ``main.redis_client``. It
    configures a single vendor entry with ``concurrent_doctors=4``. It also
    replaces the lock and queue helpers with ``AsyncMock`` objects, so only the
    doctor-pool logic is exercised. ``monkeypatch`` reverts every patch at
    teardown.
    """
    monkeypatch.setattr(main, "redis", fake_redis)
    # raising=False: install redis_client even if app.main does not define it.
    monkeypatch.setattr(main, "redis_client", fake_redis, raising=False)

    vendor_config = SimpleNamespace(
        vendor=VENDOR_ID, concurrent_doctors=4, immediate_unblock=False
    )
    replacements = {
        "doctors_data": SimpleNamespace(data=[vendor_config]),
        # Isolate the lock + queue side-helpers so we test pure pool behaviour.
        "acquire_redis_lock": AsyncMock(return_value="lock"),
        "release_redis_lock": AsyncMock(),
        "add_prescription_to_queue": AsyncMock(),
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(main, attr_name, replacement)


def test_block_doctor_returns_unaccepted_doctors_to_pool(monkeypatch):
    """Accepting a prescription must hand the other offered doctors back to the pool.

    P1 drains a four-doctor pool and D2 accepts it. The test then checks that
    D1, D3 and D4 are back in the pool, and that P2 is offered to exactly them.
    """
    fake_redis = FakeRedis()
    fake_redis.lists[REDIS_KEY] = ["D1", "D2", "D3", "D4"]  # 4 active doctors
    _wire(monkeypatch, fake_redis)

    def offer(prescription_id):
        # Drive assign_doctor for one prescription of the test vendor.
        request = SimpleNamespace(prescription_id=prescription_id, vendor_id=VENDOR_ID)
        return asyncio.run(assign_doctor(request))

    def current_pool():
        return asyncio.run(fake_redis.lrange(REDIS_KEY, 0, -1))

    idle = ["D1", "D3", "D4"]

    # Step 1: P1 fans out to every doctor, which empties the pool.
    first = offer("P1")
    assert sorted(first.data.doctor_ids) == ["D1", "D2", "D3", "D4"]
    assert current_pool() == []

    # Step 2: D2 accepts P1; D1, D3 and D4 never took it and are free again.
    acceptance = SimpleNamespace(
        block_doctor_id="D2", vendor_id=VENDOR_ID, prescription_id="P1"
    )
    asyncio.run(block_doctor(acceptance))

    # Invariant: the three non-accepting doctors must be back in the pool.
    pool_after_block = current_pool()
    assert sorted(pool_after_block) == idle, (
        f"block_doctor leaked the non-accepting doctors out of the pool; "
        f"pool={pool_after_block}"
    )

    # Step 3 (end-to-end): the next prescription still reaches those three.
    second = offer("P2")
    assert sorted(second.data.doctor_ids) == idle, (
        f"P2 starved because the non-accepting doctors were leaked; "
        f"offered to {second.data.doctor_ids}"
    )
