from __future__ import annotations

import asyncio
import json
import time
from dataclasses import asdict, dataclass
from typing import Any, Literal

# Reference point for the relative timestamps printed by log(); captured at import time.
START = time.perf_counter()

# All three demo services listen on the loopback interface, one port each.
HOST = "127.0.0.1"
API_PORT = 9000
CACHE_PORT = 9001
DB_PORT = 9002

# Contents served by the cache service; keys absent here are answered with "MISS".
CACHE_DATA: dict[str, str] = {
    "user:1": "Alice (cached)",
    "product:10": "Mechanical Keyboard (cached)",
}

# Contents served by the DB service; keys absent here are answered with "NOT_FOUND".
# Inventory values are integer strings and recommendation values are comma-separated
# lists; the API layer converts them with int() and str.split().
DB_DATA: dict[str, str] = {
    "user:1": "Alice",
    "user:2": "Bob",
    "product:10": "Mechanical Keyboard",
    "product:20": "Ergonomic Mouse",
    "inventory:10": "42",
    "inventory:20": "7",
    "recommendations:1": "mousepad, wrist_rest",
    "recommendations:2": "usb_hub, webcam",
}


@dataclass(frozen=True)
class ClientRequest:
    """One API request, as sent by run_client and decoded by parse_client_request.

    Instances are immutable (frozen dataclass).
    """

    request_id: int  # echoed back unchanged in the API response
    user_id: int  # selects the "user:<id>" and "recommendations:<id>" keys
    product_id: int  # selects the "product:<id>" and "inventory:<id>" keys
    mode: Literal["sequential", "concurrent"]  # how the API issues its backend lookups


@dataclass(frozen=True)
class ServerGroup:
    """Handles for the three listening servers created by start_servers."""

    api: asyncio.AbstractServer  # server running api_server_handler
    cache: asyncio.AbstractServer  # server running cache_server_handler
    db: asyncio.AbstractServer  # server running db_server_handler


def log(service: str, message: str) -> None:
    """Print *message* prefixed with seconds since START and a left-aligned service tag."""
    stamp = format(time.perf_counter() - START, "6.2f")
    tag = format(service, "<8")
    print(f"[{stamp}s] {tag} | {message}")


async def read_line(reader: asyncio.StreamReader) -> str:
    """Read one line from *reader* and return it as UTF-8 text with surrounding whitespace removed.

    Once the peer has closed the stream and no data is buffered, this returns "".
    """
    return (await reader.readline()).decode("utf-8").strip()


async def _serve_lookup(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    *,
    service: str,
    table: dict[str, str],
    delay: float,
    missing: str,
) -> None:
    """Answer a single ``GET <key>`` command from *table*, then close the connection.

    Shared by the cache and DB handlers so both speak the same one-line protocol.

    Args:
        service: tag passed to log().
        table: key/value data to serve.
        delay: seconds slept before answering (simulated backend latency).
        missing: reply sent when the key is not in *table*.

    The writer is closed in ``finally`` so that a peer disconnecting mid-exchange
    (e.g. a ConnectionResetError from drain) cannot leak the transport.
    """
    addr = writer.get_extra_info("peername")
    try:
        command = await read_line(reader)
        log(service, f"from={addr} cmd={command!r}")

        await asyncio.sleep(delay)

        if command.startswith("GET "):
            response = table.get(command.removeprefix("GET "), missing)
        else:
            response = "ERROR bad command"

        writer.write((response + "\n").encode())
        await writer.drain()
    finally:
        writer.close()
        await writer.wait_closed()
    log(service, f"replied={response!r}")


async def cache_server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Serve one cache lookup from CACHE_DATA after ~60 ms; unknown keys get ``MISS``."""
    await _serve_lookup(
        reader,
        writer,
        service="CACHE",
        table=CACHE_DATA,
        delay=0.06,
        missing="MISS",
    )


async def db_server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Serve one DB lookup from DB_DATA after ~250 ms; unknown keys get ``NOT_FOUND``."""
    await _serve_lookup(
        reader,
        writer,
        service="DB",
        table=DB_DATA,
        delay=0.25,
        missing="NOT_FOUND",
    )


async def send_command(port: int, command: str) -> str:
    """Send one newline-terminated *command* to HOST:*port* and return the reply line.

    Returns "" if the peer closes the connection without replying. The
    connection is closed even when writing or reading fails, so a misbehaving
    backend cannot leak sockets from the caller.
    """
    reader, writer = await asyncio.open_connection(HOST, port)
    try:
        writer.write((command + "\n").encode())
        await writer.drain()
        return await read_line(reader)
    finally:
        writer.close()
        await writer.wait_closed()


async def get_with_cache_fallback(key: str) -> str:
    """Look *key* up in the cache service, falling back to the DB when the cache says ``MISS``."""
    command = f"GET {key}"
    cached = await send_command(CACHE_PORT, command)
    if cached == "MISS":
        log("API", f"cache miss for {key!r}; querying DB")
        return await send_command(DB_PORT, command)
    log("API", f"cache hit for {key!r}")
    return cached


async def get_db_only(key: str) -> str:
    """Fetch *key* directly from the DB service, bypassing the cache."""
    db_reply = await send_command(DB_PORT, "GET {}".format(key))
    return db_reply


async def build_response_sequential(req: ClientRequest) -> dict[str, Any]:
    # Every operation waits for the previous one to finish.
    user = await get_with_cache_fallback(f"user:{req.user_id}")
    product = await get_with_cache_fallback(f"product:{req.product_id}")
    inventory = await get_db_only(f"inventory:{req.product_id}")
    recommendations = await get_db_only(f"recommendations:{req.user_id}")

    return {
        "user": user,
        "product": product,
        "inventory": int(inventory),
        "recommendations": recommendations.split(", ") if ", " in recommendations else recommendations.split(","),
    }


async def build_response_concurrent(req: ClientRequest) -> dict[str, Any]:
    # Launch independent operations immediately, then await all results together.
    user_task = asyncio.create_task(get_with_cache_fallback(f"user:{req.user_id}"))
    product_task = asyncio.create_task(get_with_cache_fallback(f"product:{req.product_id}"))
    inventory_task = asyncio.create_task(get_db_only(f"inventory:{req.product_id}"))
    recommendations_task = asyncio.create_task(get_db_only(f"recommendations:{req.user_id}"))

    user, product, inventory, recommendations = await asyncio.gather(
        user_task,
        product_task,
        inventory_task,
        recommendations_task,
    )

    return {
        "user": user,
        "product": product,
        "inventory": int(inventory),
        "recommendations": recommendations.split(", ") if ", " in recommendations else recommendations.split(","),
    }


def parse_client_request(raw: str) -> ClientRequest:
    """Decode one JSON request line into a ClientRequest.

    ``mode`` defaults to ``"concurrent"`` when omitted.

    Raises:
        ValueError: malformed JSON (json.JSONDecodeError), a payload that is not
            a JSON object, an unknown ``mode``, or a non-numeric id string.
        KeyError: a required id field is missing.
        TypeError: an id field has a non-numeric type such as null or a list.
    """
    payload = json.loads(raw)
    if not isinstance(payload, dict):
        # e.g. "[1, 2]", "null" or "42": report a ValueError instead of letting
        # the .get() call below fail with AttributeError.
        raise ValueError("request must be a JSON object")

    mode_raw = payload.get("mode", "concurrent")
    # Compare against a tuple (equality), not a set (hashing), so unhashable
    # values such as lists are rejected with ValueError rather than TypeError.
    if mode_raw not in ("sequential", "concurrent"):
        raise ValueError("mode must be 'sequential' or 'concurrent'")

    return ClientRequest(
        request_id=int(payload["request_id"]),
        user_id=int(payload["user_id"]),
        product_id=int(payload["product_id"]),
        mode=mode_raw,
    )


async def api_server_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    """Serve exactly one JSON API request on this connection.

    Reads one JSON line and builds the response in the requested mode. It then
    writes back one JSON line carrying either the assembled ``data`` or an
    ``error`` field. The connection is closed on every path (``finally``), so a
    failed backend lookup no longer leaves the client waiting for a reply that
    never arrives.
    """
    started = time.perf_counter()
    addr = writer.get_extra_info("peername")

    def elapsed_ms() -> int:
        # Wall-clock time spent on this request so far, in whole milliseconds.
        return round((time.perf_counter() - started) * 1000)

    async def reply(body: dict[str, Any]) -> None:
        # One JSON document per line: the framing run_client reads with read_line().
        writer.write((json.dumps(body) + "\n").encode())
        await writer.drain()

    try:
        try:
            raw = await read_line(reader)
            req = parse_client_request(raw)
        except (ValueError, KeyError, TypeError, AttributeError) as exc:
            # json.JSONDecodeError subclasses ValueError. AttributeError is accepted
            # too, so valid JSON that is not an object (e.g. "[1, 2]") is reported
            # as a bad request however the parser rejects it.
            await reply({"error": str(exc)})
            log("API", f"bad request from={addr}: {exc}")
            return

        log("API", f"request_id={req.request_id} mode={req.mode} from={addr}")

        try:
            if req.mode == "sequential":
                data = await build_response_sequential(req)
            else:
                data = await build_response_concurrent(req)
        except Exception as exc:
            # Request boundary: report backend failures (unreachable service,
            # non-integer inventory such as "NOT_FOUND", ...) to the client
            # instead of letting the handler task die without a reply.
            duration_ms = elapsed_ms()
            await reply({
                "request_id": req.request_id,
                "mode": req.mode,
                "duration_ms": duration_ms,
                "error": str(exc),
            })
            log("API", f"request_id={req.request_id} failed after {duration_ms}ms: {exc!r}")
            return

        duration_ms = elapsed_ms()
        await reply({
            "request_id": req.request_id,
            "mode": req.mode,
            "duration_ms": duration_ms,
            "data": data,
        })
    finally:
        writer.close()
        await writer.wait_closed()

    log("API", f"request_id={req.request_id} complete in {duration_ms}ms")


async def run_client(request: ClientRequest) -> dict[str, Any]:
    """Send *request* to the API server and return its decoded JSON reply.

    The connection is closed even if sending, reading, or JSON decoding fails.

    Raises:
        json.JSONDecodeError: if the server closes without sending a reply line
            (read_line then yields "").
    """
    reader, writer = await asyncio.open_connection(HOST, API_PORT)
    try:
        # asdict() is the public dataclass API; __dict__ would break with slots=True.
        writer.write((json.dumps(asdict(request)) + "\n").encode())
        await writer.drain()
        return json.loads(await read_line(reader))
    finally:
        writer.close()
        await writer.wait_closed()


async def start_servers() -> ServerGroup:
    """Start the cache, DB and API servers (backends first) and return their handles.

    If any server fails to start (for example because its port is already in
    use), the servers already started are closed before the error propagates,
    so no listening socket is leaked.
    """
    started: list[asyncio.AbstractServer] = []
    try:
        cache = await asyncio.start_server(cache_server_handler, HOST, CACHE_PORT)
        started.append(cache)
        db = await asyncio.start_server(db_server_handler, HOST, DB_PORT)
        started.append(db)
        api = await asyncio.start_server(api_server_handler, HOST, API_PORT)
    except BaseException:
        for server in reversed(started):
            server.close()
            await server.wait_closed()
        raise

    log("MAIN", f"cache listening on {HOST}:{CACHE_PORT}")
    log("MAIN", f"db listening on {HOST}:{DB_PORT}")
    log("MAIN", f"api listening on {HOST}:{API_PORT}")

    return ServerGroup(api=api, cache=cache, db=db)


async def stop_servers(servers: ServerGroup) -> None:
    """Close the servers one at a time, in the order API, cache, DB.

    Each server is fully closed (wait_closed) before the next one is touched.
    NOTE(review): presumably the API goes first so it stops taking requests
    before the backends it depends on go away — confirm if reordering.
    """
    for role in ("api", "cache", "db"):
        listener = getattr(servers, role)
        listener.close()
        await listener.wait_closed()


async def run_demo() -> None:
    """Run the same lookup sequentially and then concurrently, and print both replies.

    Also logs the observed speedup. The servers are stopped on every exit path.
    """
    servers = await start_servers()
    try:
        await asyncio.sleep(0.05)

        results: dict[str, dict[str, Any]] = {}
        for request_id, mode in ((1, "sequential"), (2, "concurrent")):
            log("MAIN", f"running one {mode} request")
            request = ClientRequest(request_id=request_id, user_id=2, product_id=20, mode=mode)
            results[mode] = await run_client(request)
            print(json.dumps(results[mode], indent=2))

        ratio = results["sequential"]["duration_ms"] / results["concurrent"]["duration_ms"]
        log("MAIN", f"concurrent speedup ~= {round(ratio, 2)}x")
    finally:
        await stop_servers(servers)


if __name__ == "__main__":
    # Script entry point: asyncio.run() creates a fresh event loop and runs the whole demo on it.
    asyncio.run(run_demo())
