diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 7bafa831f..aa05798df 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -362,6 +362,29 @@ def on_event(event: AnyEvent) -> None: raise ClickException(f"task could not run: {ve}") from ve +@controller.command(name="ws") +@click.argument("name", type=str) +@click.argument("parameters", type=ParametersType(), default={}, required=False) +def run_blocking( + name: str, + parameters: TaskParameters, +): + instrument_session = "cm33-3" + + from websockets.sync.client import connect + + task_req = TaskRequest( + name=name, + params=parameters, + instrument_session=instrument_session, + ) + + with connect("ws://localhost:8007/run_plan") as ws: + ws.send(task_req.model_dump_json()) + for message in ws: + print(message) + + @controller.command(name="state") @click.pass_obj @check_connection diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 6acc29ab7..4ea671145 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -1,5 +1,7 @@ +import logging from collections.abc import Mapping from functools import cache +from multiprocessing.connection import Connection from typing import Any from bluesky.callbacks.tiled_writer import TiledWriter @@ -21,6 +23,7 @@ WorkerTask, ) from blueapi.utils.serialization import access_blob +from blueapi.worker import task_worker from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask @@ -28,7 +31,7 @@ """This module provides interface between web application and underlying Bluesky context and worker""" - +LOGGER = logging.getLogger(__name__) _CONFIG: ApplicationConfig = ApplicationConfig() @@ -270,3 +273,22 @@ def get_python_env( """Retrieve information about the Python environment""" scratch = config().scratch return get_python_environment(config=scratch, name=name, source=source) + + +def pipe_events(tx: Connection) -> int: + + def handler( + worker_event: WorkerEvent, + cor_id: str | None, + ) -> None: + LOGGER.info("Sending event") + tx.send(worker_event) + + task_worker = worker() + sub_id = task_worker.worker_events.subscribe(handler) + return sub_id + + +def unpipe_events(h: int) -> None: + task_worker = worker() + task_worker.worker_events.unsubscribe(h) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index c79dd3df3..554f75aa3 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,6 +2,7 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from multiprocessing import Pipe from typing import Annotated, Any import jwt @@ -14,8 +15,10 @@ HTTPException, Request, Response, + WebSocket, status, ) +from fastapi.concurrency import run_in_threadpool from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse @@ -37,7 +40,8 @@ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface from blueapi.worker import TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum +from blueapi.worker.event import TaskStatusEnum, WorkerEvent +from blueapi.worker.worker_errors import WorkerBusyError from .model import ( DeviceModel, @@ -540,6 +544,50 @@ def logout(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> Response: ) +@secure_router.websocket("/run_plan") +async def run_plan( + ws: WebSocket, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + user = "alice" + + # ack ws + await ws.accept() + # accept task request through socket + rq = await ws.receive_json() + # submit task to runner + try: + task_request: TaskRequest = TaskRequest.model_validate(rq) + task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + except ValidationError: + await ws.close(code=1003, reason="invalid args") + return + + # add listener to runner + tx, rx = Pipe() + h = runner.run(interface.pipe_events, tx=tx) + # start task + try: + task = WorkerTask(task_id=task_id) + runner.run( + interface.begin_task, + task=task, + ) + except WorkerBusyError: + await ws.close(code=1013, reason="Worker busy") + return + # pipe events to ws + try: + while True: + event: WorkerEvent = await run_in_threadpool(rx.recv) + await ws.send_json(event.model_dump(mode="json")) + if event.is_complete(): + break + finally: + await ws.close() + runner.run(interface.unpipe_events, h=h) + + @start_as_current_span(TRACER, "config") def start(config: ApplicationConfig): import uvicorn