diff --git a/src/browser/App.tsx b/src/browser/App.tsx index 3f7a001158..9c602746eb 100644 --- a/src/browser/App.tsx +++ b/src/browser/App.tsx @@ -68,6 +68,7 @@ import { AboutDialogProvider } from "./contexts/AboutDialogContext"; import { SettingsModal } from "./components/Settings/SettingsModal"; import { AboutDialog } from "./components/About/AboutDialog"; import { MuxGatewaySessionExpiredDialog } from "./components/MuxGatewaySessionExpiredDialog"; +import { HostKeyVerificationDialog } from "./components/HostKeyVerificationDialog"; import { SplashScreenProvider } from "./components/splashScreens/SplashScreenProvider"; import { TutorialProvider } from "./contexts/TutorialContext"; import { PowerModeProvider } from "./contexts/PowerModeContext"; @@ -1071,6 +1072,7 @@ function AppInner() { + ); diff --git a/src/browser/components/HostKeyVerificationDialog.test.tsx b/src/browser/components/HostKeyVerificationDialog.test.tsx new file mode 100644 index 0000000000..3a325a229c --- /dev/null +++ b/src/browser/components/HostKeyVerificationDialog.test.tsx @@ -0,0 +1,298 @@ +import "../../../tests/ui/dom"; + +import { afterEach, beforeEach, describe, expect, it, mock } from "bun:test"; +import { act, cleanup, fireEvent, render, waitFor } from "@testing-library/react"; +import type { + HostKeyVerificationEvent, + HostKeyVerificationRequest, +} from "@/common/orpc/schemas/ssh"; +import type { ReactNode } from "react"; + +// Self-contained dialog stub — bun's mock.module is process-global and +// ShareTranscriptDialog.test.tsx registers an incomplete stub that omits +// DialogDescription/DialogFooter/Warning*. Our own complete mock prevents +// Radix context errors when tests run in the same bun process. +void mock.module("@/browser/components/ui/dialog", () => ({ + Dialog: (props: { open: boolean; children: ReactNode }) => + props.open ?
{props.children}
: null, + DialogContent: (props: { children: ReactNode }) =>
{props.children}
, + DialogHeader: (props: { children: ReactNode }) =>
{props.children}
, + DialogTitle: (props: { children: ReactNode }) =>

{props.children}

, + DialogDescription: (props: { children: ReactNode }) =>

{props.children}

, + DialogFooter: (props: { children: ReactNode }) =>
{props.children}
, + WarningBox: (props: { children: ReactNode }) =>
{props.children}
, + WarningTitle: (props: { children: ReactNode }) =>
{props.children}
, + WarningText: (props: { children: ReactNode }) =>
{props.children}
, +})); + +import { HostKeyVerificationDialog } from "./HostKeyVerificationDialog"; + +interface ControlledSubscription { + iterable: AsyncIterable; + push: (value: T) => void; + close: () => void; + returnSpy: ReturnType; +} + +function createMockIterableSubscription(): ControlledSubscription { + const buffered: T[] = []; + const pending: Array<(result: IteratorResult) => void> = []; + let closed = false; + + const doneResult = (): IteratorResult => ({ + value: undefined as unknown as T, + done: true, + }); + + const flushDone = () => { + while (pending.length > 0) { + const resolve = pending.shift(); + resolve?.(doneResult()); + } + }; + + const returnSpy = mock((_value?: unknown) => { + closed = true; + flushDone(); + return Promise.resolve(doneResult()); + }); + + const iterator: AsyncIterator = { + next() { + if (closed) { + return Promise.resolve(doneResult()); + } + + if (buffered.length > 0) { + return Promise.resolve({ value: buffered.shift()!, done: false }); + } + + return new Promise((resolve) => { + pending.push(resolve); + }); + }, + return: returnSpy, + }; + + return { + iterable: { + [Symbol.asyncIterator]: () => iterator, + }, + returnSpy, + push(value: T) { + if (closed) { + return; + } + + const resolve = pending.shift(); + if (resolve) { + resolve({ value, done: false }); + return; + } + + buffered.push(value); + }, + close() { + if (closed) { + return; + } + + closed = true; + flushDone(); + }, + }; +} + +interface HostKeyVerificationApi { + ssh: { + hostKeyVerification: { + subscribe: ( + _input?: undefined, + _options?: { signal?: AbortSignal } + ) => Promise>; + respond: (input: { requestId: string; accept: boolean }) => Promise; + }; + }; +} + +let api: HostKeyVerificationApi | null = null; +let respondMock: ReturnType; +let subscribeMock: ReturnType; +let mockSubscription: ControlledSubscription; + +// mock.module is hoisted by bun — the mock is active before static imports resolve. +void mock.module("@/browser/contexts/API", () => ({ + useAPI: () => ({ api }), +})); + +const MOCK_REQUEST: HostKeyVerificationRequest = { + requestId: "req-1", + host: "example.com", + keyType: "ssh-ed25519", + fingerprint: "SHA256:abcdef", + prompt: "Trust host key?", +}; + +async function flushReactWork(): Promise { + await Promise.resolve(); + await Promise.resolve(); +} + +async function enqueueRequest(request: HostKeyVerificationRequest): Promise { + await act(async () => { + mockSubscription.push({ type: "request", ...request }); + await flushReactWork(); + }); +} + +describe("HostKeyVerificationDialog", () => { + beforeEach(() => { + cleanup(); + + mockSubscription = createMockIterableSubscription(); + respondMock = mock(() => Promise.resolve()); + subscribeMock = mock(() => Promise.resolve(mockSubscription.iterable)); + + api = { + ssh: { + hostKeyVerification: { + subscribe: subscribeMock, + respond: respondMock, + }, + }, + }; + }); + + afterEach(() => { + mockSubscription.close(); + cleanup(); + api = null; + }); + + it("dequeues request on successful respond", async () => { + const { getByRole, queryByRole } = render(); + + await waitFor(() => expect(subscribeMock).toHaveBeenCalledTimes(1)); + await enqueueRequest(MOCK_REQUEST); + + await act(async () => { + fireEvent.click(getByRole("button", { name: "Reject" })); + await flushReactWork(); + }); + + await waitFor(() => { + expect(respondMock).toHaveBeenCalledWith({ requestId: "req-1", accept: false }); + }); + expect(respondMock).toHaveBeenCalledTimes(1); + + // Successful respond dequeues → dialog closes → no Reject button + expect(queryByRole("button", { name: "Reject" })).toBeNull(); + }); + + it("keeps request visible when respond fails", async () => { + respondMock = mock(() => Promise.reject(new Error("RPC transport error"))); + subscribeMock = mock(() => Promise.resolve(mockSubscription.iterable)); + api = { + ssh: { + hostKeyVerification: { + subscribe: subscribeMock, + respond: respondMock, + }, + }, + }; + + const { getByRole, queryByRole } = render(); + + await waitFor(() => expect(subscribeMock).toHaveBeenCalledTimes(1)); + await enqueueRequest(MOCK_REQUEST); + + // Regression guard: failed responses must leave the same prompt active so retry works. + await act(async () => { + fireEvent.click(getByRole("button", { name: "Reject" })); + await flushReactWork(); + }); + await waitFor(() => expect(respondMock).toHaveBeenCalledTimes(1)); + + // Button still visible — user can retry + expect(queryByRole("button", { name: "Reject" })).not.toBeNull(); + + await act(async () => { + fireEvent.click(getByRole("button", { name: "Reject" })); + await flushReactWork(); + }); + await waitFor(() => expect(respondMock).toHaveBeenCalledTimes(2)); + + expect(respondMock).toHaveBeenNthCalledWith(1, { requestId: "req-1", accept: false }); + expect(respondMock).toHaveBeenNthCalledWith(2, { requestId: "req-1", accept: false }); + }); + + it("closes late iterator when cleanup runs before subscribe resolves", async () => { + let resolveSubscribe: ((iterable: AsyncIterable) => void) | null = + null; + subscribeMock = mock( + () => + new Promise>((resolve) => { + resolveSubscribe = resolve; + }) + ); + api = { + ssh: { + hostKeyVerification: { + subscribe: subscribeMock, + respond: respondMock, + }, + }, + }; + + const { unmount } = render(); + await waitFor(() => expect(subscribeMock).toHaveBeenCalledTimes(1)); + + // Cleanup fires while subscribe() is still pending — iteratorRef is undefined. + unmount(); + + // Now resolve the subscribe promise. The abort guard should close the iterator. + await act(async () => { + resolveSubscribe?.(mockSubscription.iterable); + await flushReactWork(); + }); + + // The abort guard should have called return() on the late iterator. + await waitFor(() => expect(mockSubscription.returnSpy).toHaveBeenCalledTimes(1)); + }); + + it("does not double-close iterator on normal cleanup", async () => { + const { unmount } = render(); + await waitFor(() => expect(subscribeMock).toHaveBeenCalledTimes(1)); + await enqueueRequest(MOCK_REQUEST); + + unmount(); + + // Give async tasks time to settle. + await act(async () => { + await flushReactWork(); + }); + + // Normal cleanup path: return() called exactly once. + expect(mockSubscription.returnSpy).toHaveBeenCalledTimes(1); + }); + + it("clears pending queue when api becomes null", async () => { + const { queryByRole, rerender } = render(); + + await waitFor(() => expect(subscribeMock).toHaveBeenCalledTimes(1)); + await enqueueRequest(MOCK_REQUEST); + + // Dialog should be visible + expect(queryByRole("button", { name: "Reject" })).not.toBeNull(); + + // Simulate disconnect — api becomes null + api = null; + await act(async () => { + rerender(); + await flushReactWork(); + }); + + // Queue cleared → dialog dismissed + expect(queryByRole("button", { name: "Reject" })).toBeNull(); + }); +}); diff --git a/src/browser/components/HostKeyVerificationDialog.tsx b/src/browser/components/HostKeyVerificationDialog.tsx new file mode 100644 index 0000000000..2845dffc47 --- /dev/null +++ b/src/browser/components/HostKeyVerificationDialog.tsx @@ -0,0 +1,164 @@ +import { useEffect, useState } from "react"; +import { useAPI } from "@/browser/contexts/API"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, + DialogFooter, + WarningBox, + WarningTitle, + WarningText, +} from "@/browser/components/ui/dialog"; +import { Button } from "@/browser/components/ui/button"; +import type { + HostKeyVerificationEvent, + HostKeyVerificationRequest, +} from "@/common/orpc/schemas/ssh"; + +export function HostKeyVerificationDialog() { + const { api } = useAPI(); + const [pendingQueue, setPendingQueue] = useState([]); + const pending = pendingQueue[0] ?? null; + const [responding, setResponding] = useState(false); + + useEffect(() => { + if (!api) { + setPendingQueue([]); + return; + } + + const controller = new AbortController(); + const { signal } = controller; + + // Track the async iterator so we can explicitly close it on cleanup. + // Some oRPC iterators don't reliably terminate on abort alone; + // calling return() ensures the backend subscription finally block runs, + // which releases the responder lease and listener state. + let iteratorRef: AsyncIterator | undefined; + + // Global subscription: backend can request host-key verification at any time. + // Queue pending requests so concurrent prompts are handled FIFO without drops. + (async () => { + try { + const iterable = await api.ssh.hostKeyVerification.subscribe(undefined, { signal }); + // Consume and cleanup the same iterator instance — + // for-await-of would call [Symbol.asyncIterator]() again, + // creating a second iterator that cleanup can't reach. + const iterator = iterable[Symbol.asyncIterator](); + + // If cleanup ran while subscribe() was in-flight, close the late + // iterator now — nobody else will call return() on it. + if (signal.aborted) { + void iterator.return?.(undefined); + return; + } + + iteratorRef = iterator; + + while (!signal.aborted) { + const { value: event, done } = await iterator.next(); + if (done) { + break; + } + + if (event.type === "removed") { + // Backend finalized this request (timeout or another subscriber responded). + setPendingQueue((prev) => prev.filter((item) => item.requestId !== event.requestId)); + } else { + const { type: _, ...request } = event; + setPendingQueue((prev) => + prev.some((item) => item.requestId === request.requestId) ? prev : [...prev, request] + ); + } + } + } catch { + // Subscription closed (cleanup/reconnect): no-op + } + })(); + + return () => { + controller.abort(); + void iteratorRef?.return?.(undefined); + setPendingQueue([]); // Drop stale prompts; reconnect delivers fresh snapshot + }; + }, [api]); + + const respond = async (accept: boolean) => { + if (!api || !pending || responding) { + return; + } + + const requestId = pending.requestId; + setResponding(true); + + try { + await api.ssh.hostKeyVerification.respond({ requestId, accept }); + // Dequeue only on success — RPC failure keeps prompt visible for retry. + setPendingQueue((prev) => prev.filter((item) => item.requestId !== requestId)); + } catch { + // Transport/RPC failure: keep current request in queue so user can retry. + } finally { + setResponding(false); + } + }; + + return ( + { + // Treat dismiss/escape as explicit rejection so backend unblocks promptly. + if (!open && !responding) { + void respond(false); + } + }} + > + + + Unknown SSH Host + + {pending?.prompt ?? ( + <> + The authenticity of host{" "} + {pending?.host} cannot be + established. + + )} + + + +
+
{pending?.keyType} key fingerprint:
+
{pending?.fingerprint}
+
+ + + Host Key Verification + Accepting will add the host to your known_hosts file. + + + + + + +
+
+ ); +} diff --git a/src/cli/cli.test.ts b/src/cli/cli.test.ts index cf06d17140..1840a322f2 100644 --- a/src/cli/cli.test.ts +++ b/src/cli/cli.test.ts @@ -87,6 +87,7 @@ async function createTestServer(authToken?: string): Promise { sessionUsageService: services.sessionUsageService, signingService: services.signingService, coderService: services.coderService, + hostKeyVerificationService: services.hostKeyVerificationService, }; // Use the actual createOrpcServer function diff --git a/src/cli/server.test.ts b/src/cli/server.test.ts index 4c2fa39845..9d2bdb3a1b 100644 --- a/src/cli/server.test.ts +++ b/src/cli/server.test.ts @@ -90,6 +90,7 @@ async function createTestServer(): Promise { sessionUsageService: services.sessionUsageService, signingService: services.signingService, coderService: services.coderService, + hostKeyVerificationService: services.hostKeyVerificationService, }; // Use the actual createOrpcServer function diff --git a/src/common/constants/ssh.ts b/src/common/constants/ssh.ts new file mode 100644 index 0000000000..8fc9d2d901 --- /dev/null +++ b/src/common/constants/ssh.ts @@ -0,0 +1,8 @@ +/** + * Maximum time (ms) to wait for the user to accept/reject a host-key + * verification prompt in the UI dialog. Shared across: + * - HostKeyVerificationService (auto-reject timeout) + * - OpenSSH connection pool (probe deadline extension) + * - SSH2 connection pool (readyTimeout extension) + */ +export const HOST_KEY_APPROVAL_TIMEOUT_MS = 60_000; diff --git a/src/common/orpc/schemas.ts b/src/common/orpc/schemas.ts index 223f4cd32d..c76205fa67 100644 --- a/src/common/orpc/schemas.ts +++ b/src/common/orpc/schemas.ts @@ -203,6 +203,7 @@ export { signing, type SigningCapabilities, type SignatureEnvelope, + ssh, terminal, tokenizer, update, diff --git a/src/common/orpc/schemas/api.ts b/src/common/orpc/schemas/api.ts index 5132928871..dc3e788eac 100644 --- a/src/common/orpc/schemas/api.ts +++ b/src/common/orpc/schemas/api.ts @@ -6,6 +6,7 @@ import { SendMessageErrorSchema } from "./errors"; import { BranchListResultSchema, FilePartSchema, MuxMessageSchema } from "./message"; import { ProjectConfigSchema, SectionConfigSchema } from "./project"; import { ResultSchema } from "./result"; +import { HostKeyVerificationEventSchema } from "./ssh"; import { RuntimeConfigSchema, RuntimeAvailabilitySchema } from "./runtime"; import { SecretSchema } from "./secrets"; import { @@ -1695,3 +1696,21 @@ export const debug = { output: z.boolean(), // true if error was triggered on an active stream }, }; + +export const ssh = { + hostKeyVerification: { + subscribe: { + input: z.void(), + output: eventIterator(HostKeyVerificationEventSchema), + }, + respond: { + input: z + .object({ + requestId: z.string(), + accept: z.boolean(), + }) + .strict(), + output: ResultSchema(z.void(), z.string()), + }, + }, +}; diff --git a/src/common/orpc/schemas/ssh.ts b/src/common/orpc/schemas/ssh.ts new file mode 100644 index 0000000000..bd22dc1f51 --- /dev/null +++ b/src/common/orpc/schemas/ssh.ts @@ -0,0 +1,18 @@ +import { z } from "zod"; + +export const HostKeyVerificationRequestSchema = z.object({ + requestId: z.string(), + host: z.string(), + keyType: z.string(), + fingerprint: z.string(), + prompt: z.string(), +}); + +export type HostKeyVerificationRequest = z.infer; + +export const HostKeyVerificationEventSchema = z.discriminatedUnion("type", [ + HostKeyVerificationRequestSchema.extend({ type: z.literal("request") }), + z.object({ type: z.literal("removed"), requestId: z.string() }), +]); + +export type HostKeyVerificationEvent = z.infer; diff --git a/src/common/utils/ssh/formatSshEndpoint.ts b/src/common/utils/ssh/formatSshEndpoint.ts new file mode 100644 index 0000000000..c47ac0f18d --- /dev/null +++ b/src/common/utils/ssh/formatSshEndpoint.ts @@ -0,0 +1,9 @@ +/** + * Canonical SSH endpoint identity for dedupe purposes. + * IPv6-safe: wraps bare IPv6 addresses in brackets. + */ +export function formatSshEndpoint(host: string, port: number): string { + const needsBrackets = host.includes(":") && !host.startsWith("[") && !host.endsWith("]"); + const normalizedHost = needsBrackets ? `[${host}]` : host; + return `${normalizedHost}:${port}`; +} diff --git a/src/node/orpc/context.ts b/src/node/orpc/context.ts index 61eee9391d..b91d9a406a 100644 --- a/src/node/orpc/context.ts +++ b/src/node/orpc/context.ts @@ -29,6 +29,7 @@ import type { SessionUsageService } from "@/node/services/sessionUsageService"; import type { TaskService } from "@/node/services/taskService"; import type { PolicyService } from "@/node/services/policyService"; import type { CoderService } from "@/node/services/coderService"; +import type { HostKeyVerificationService } from "@/node/services/hostKeyVerificationService"; export interface ORPCContext { config: Config; @@ -61,5 +62,6 @@ export interface ORPCContext { policyService: PolicyService; signingService: SigningService; coderService: CoderService; + hostKeyVerificationService: HostKeyVerificationService; headers?: IncomingHttpHeaders; } diff --git a/src/node/orpc/router.ts b/src/node/orpc/router.ts index 928bae44ba..aee285c044 100644 --- a/src/node/orpc/router.ts +++ b/src/node/orpc/router.ts @@ -16,6 +16,10 @@ import type { FrontendWorkspaceMetadataSchemaType, } from "@/common/orpc/types"; import type { WorkspaceMetadata } from "@/common/types/workspace"; +import type { + HostKeyVerificationEvent, + HostKeyVerificationRequest, +} from "@/common/orpc/schemas/ssh"; import { createAuthMiddleware } from "./authMiddleware"; import { createAsyncMessageQueue } from "@/common/utils/asyncMessageQueue"; import { clearLogFiles, getLogFilePath } from "@/node/services/log"; @@ -3839,6 +3843,49 @@ export const router = (authToken?: string) => { return { success: true }; }), }, + ssh: { + hostKeyVerification: { + subscribe: t + .input(schemas.ssh.hostKeyVerification.subscribe.input) + .output(schemas.ssh.hostKeyVerification.subscribe.output) + .handler(async function* ({ context, signal }) { + if (signal?.aborted) return; + + const service = context.hostKeyVerificationService; + const releaseResponder = service.registerInteractiveResponder(); + const queue = createAsyncEventQueue(); + + const onRequest = (req: HostKeyVerificationRequest) => + queue.push({ type: "request" as const, ...req }); + const onRemoved = (requestId: string) => + queue.push({ type: "removed" as const, requestId }); + + // Atomic handshake: register listener + snapshot in one step. + // No requests can be lost between snapshot and subscription. + const { snapshot, unsubscribe } = service.subscribeRequests(onRequest, onRemoved); + for (const req of snapshot) queue.push({ type: "request" as const, ...req }); + + const onAbort = () => queue.end(); + signal?.addEventListener("abort", onAbort, { once: true }); + + try { + yield* queue.iterate(); + } finally { + signal?.removeEventListener("abort", onAbort); + releaseResponder(); + queue.end(); + unsubscribe(); + } + }), + respond: t + .input(schemas.ssh.hostKeyVerification.respond.input) + .output(schemas.ssh.hostKeyVerification.respond.output) + .handler(({ context, input }) => { + context.hostKeyVerificationService.respond(input.requestId, input.accept); + return Ok(undefined); + }), + }, + }, }); }; diff --git a/src/node/runtime/SSH2ConnectionPool.ts b/src/node/runtime/SSH2ConnectionPool.ts index e63f8dd40a..e0c4571930 100644 --- a/src/node/runtime/SSH2ConnectionPool.ts +++ b/src/node/runtime/SSH2ConnectionPool.ts @@ -13,11 +13,20 @@ import * as path from "path"; import { spawn, type ChildProcess } from "child_process"; import { Duplex } from "stream"; import type { Client } from "ssh2"; +import { HOST_KEY_APPROVAL_TIMEOUT_MS } from "@/common/constants/ssh"; import { getErrorMessage } from "@/common/utils/errors"; +import { formatSshEndpoint } from "@/common/utils/ssh/formatSshEndpoint"; import { log } from "@/node/services/log"; import { attachStreamErrorHandler } from "@/node/utils/streamErrors"; import type { SSHConnectionConfig } from "./sshConnectionPool"; import { resolveSSHConfig, type ResolvedSSHConfig } from "./sshConfigParser"; +import type { HostKeyVerificationService } from "@/node/services/hostKeyVerificationService"; + +let hostKeyService: HostKeyVerificationService | undefined; + +export function setHostKeyVerificationService(svc: HostKeyVerificationService): void { + hostKeyService = svc; +} /** * Connection health status @@ -494,6 +503,8 @@ export class SSH2ConnectionPool { const readableKeys = await resolvePrivateKeys(resolvedConfigWithIdentities.identityFiles); const keysToTry: Array = readableKeys.length > 0 ? readableKeys : [undefined]; + const verificationService = hostKeyService; + const canPromptInteractively = verificationService?.hasInteractiveResponder() === true; const connectWithKey = async ( privateKey: Buffer | undefined, @@ -607,10 +618,40 @@ export class SSH2ConnectionPool { username, agent: agentOverride, sock: proxy?.sock, - readyTimeout: timeoutMs, + // hostVerifier can wait for user approval in the UI dialog, + // so keep the handshake alive long enough for that interaction. + readyTimeout: canPromptInteractively + ? Math.max(timeoutMs, HOST_KEY_APPROVAL_TIMEOUT_MS) + : timeoutMs, keepaliveInterval: 5000, keepaliveCountMax: 2, ...(privateKey ? { privateKey } : {}), + // Host key verification + ...(canPromptInteractively && verificationService + ? { + hostHash: "sha256" as const, + hostVerifier: ( + fingerprint: string, + verify: (accept: boolean) => void + ): boolean => { + void verificationService + .requestVerification({ + host: resolvedConfig.hostName, + dedupeKey: formatSshEndpoint( + resolvedConfig.hostName, + resolvedConfig.port + ), + keyType: "unknown", // ssh2 doesn't expose key type in this callback + fingerprint: `SHA256:${fingerprint}`, + prompt: `The authenticity of host '${resolvedConfig.hostName}' can't be established.\nFingerprint: SHA256:${fingerprint}`, + }) + .then(verify); + return true; + }, + } + : { + hostVerifier: () => true, + }), }; client.connect(connectOptions); diff --git a/src/node/runtime/SSHRuntime.test.ts b/src/node/runtime/SSHRuntime.test.ts index de2f54de06..77ea21cbd7 100644 --- a/src/node/runtime/SSHRuntime.test.ts +++ b/src/node/runtime/SSHRuntime.test.ts @@ -101,6 +101,47 @@ describe("SSHRuntime.ensureReady repository checks", () => { } }); }); + +describe("SSHRuntime.resolvePath", () => { + let runtime: SSHRuntime; + let transport: ReturnType; + let acquireConnectionSpy: ReturnType> | null = + null; + let execBufferedSpy: ReturnType> | null = + null; + + beforeEach(() => { + const config = { host: "example.com", srcBaseDir: "/home/user/src" }; + transport = createSSHTransport(config, false); + runtime = new SSHRuntime(config, transport, { + projectPath: "/project", + workspaceName: "ws", + }); + }); + + afterEach(() => { + acquireConnectionSpy?.mockRestore(); + acquireConnectionSpy = null; + execBufferedSpy?.mockRestore(); + execBufferedSpy = null; + }); + + it("passes a 10s timeout and max wait to preflight acquireConnection", async () => { + acquireConnectionSpy = spyOn(transport, "acquireConnection").mockResolvedValue(undefined); + execBufferedSpy = spyOn(runtimeHelpers, "execBuffered").mockResolvedValue({ + stdout: "/home/user/foo\n", + stderr: "", + exitCode: 0, + duration: 0, + }); + + expect(await runtime.resolvePath("~/foo")).toBe("/home/user/foo"); + expect(acquireConnectionSpy).toHaveBeenCalledWith({ + timeoutMs: 10_000, + maxWaitMs: 10_000, + }); + }); +}); describe("computeBaseRepoPath", () => { it("computes the correct bare repo path", () => { // computeBaseRepoPath uses getProjectName (basename) to compute: diff --git a/src/node/runtime/SSHRuntime.ts b/src/node/runtime/SSHRuntime.ts index f1ba95f335..e2ab32a47e 100644 --- a/src/node/runtime/SSHRuntime.ts +++ b/src/node/runtime/SSHRuntime.ts @@ -298,7 +298,17 @@ export class SSHRuntime extends RemoteRuntime { const command = `bash -lc ${shescape.quote(script)}`; - const abortController = createAbortController(10_000); + // Wait for connection establishment (including host-key confirmation) before + // starting the 10s command timeout. Otherwise users who take >10s to accept + // the host key prompt will hit a false timeout immediately after acceptance. + const resolvePathTimeoutMs = 10_000; + + await this.transport.acquireConnection({ + timeoutMs: resolvePathTimeoutMs, + maxWaitMs: resolvePathTimeoutMs, + }); + + const abortController = createAbortController(resolvePathTimeoutMs); try { const result = await execBuffered(this, command, { cwd: "/tmp", diff --git a/src/node/runtime/sshAskpass.test.ts b/src/node/runtime/sshAskpass.test.ts new file mode 100644 index 0000000000..94b45efa7e --- /dev/null +++ b/src/node/runtime/sshAskpass.test.ts @@ -0,0 +1,218 @@ +import { describe, expect, spyOn, test } from "bun:test"; +import * as fs from "fs"; +import * as os from "os"; +import * as path from "path"; +import { createAskpassSession, parseHostKeyPrompt } from "./sshAskpass"; + +describe("sshAskpass", () => { + describe("createAskpassSession", () => { + async function simulateAskpassInvocation( + askpassDir: string, + promptText: string, + requestId: string, + cleanupResponseFiles = true + ): Promise { + const promptFile = path.join(askpassDir, `prompt.${requestId}.txt`); + const responseFile = path.join(askpassDir, `response.${requestId}.txt`); + + // Simulate askpass writing prompt content for this invocation. + await fs.promises.writeFile(promptFile, promptText, "utf-8"); + + // Poll for the response file written by createAskpassSession(). + for (let i = 0; i < 100; i += 1) { + try { + const response = await fs.promises.readFile(responseFile, "utf-8"); + + if (cleanupResponseFiles) { + // Simulate askpass script cleanup. + await fs.promises.unlink(promptFile).catch(() => undefined); + await fs.promises.unlink(responseFile).catch(() => undefined); + } + + return response.trim(); + } catch { + await new Promise((resolve) => setTimeout(resolve, 50)); + } + } + + throw new Error(`Timeout waiting for response for request '${requestId}'`); + } + + async function listAskpassTempDirs(): Promise { + return (await fs.promises.readdir(os.tmpdir())) + .filter((entry) => entry.startsWith("mux-askpass-")) + .sort(); + } + + test("does not leak temp dir when script bootstrap fails", async () => { + const tmpBefore = await listAskpassTempDirs(); + const accessSpy = spyOn(fs.promises, "access").mockRejectedValueOnce( + new Error("ENOENT: script missing") + ); + const writeFileSpy = spyOn(fs.promises, "writeFile").mockRejectedValueOnce( + new Error("EACCES: permission denied") + ); + + let leakedDirs: string[] = []; + + try { + let error: unknown; + try { + await createAskpassSession(() => Promise.resolve("ok")); + } catch (thrown) { + error = thrown; + } + + expect(error).toBeInstanceOf(Error); + if (!(error instanceof Error)) { + throw new Error("Expected createAskpassSession to throw an Error"); + } + expect(error.message).toContain("EACCES"); + + const tmpAfter = await listAskpassTempDirs(); + const beforeSet = new Set(tmpBefore); + leakedDirs = tmpAfter.filter((entry) => !beforeSet.has(entry)); + expect(leakedDirs).toHaveLength(0); + } finally { + accessSpy.mockRestore(); + writeFileSpy.mockRestore(); + for (const dir of leakedDirs) { + await fs.promises + .rm(path.join(os.tmpdir(), dir), { recursive: true, force: true }) + .catch(() => undefined); + } + } + }); + + test("handles a single prompt and returns response", async () => { + const prompts: string[] = []; + const session = await createAskpassSession((prompt) => { + prompts.push(prompt); + return Promise.resolve("yes"); + }); + + try { + const result = await simulateAskpassInvocation( + session.env.MUX_ASKPASS_DIR, + "Are you sure you want to continue connecting (yes/no)?", + "req1" + ); + + expect(result).toBe("yes"); + expect(prompts).toEqual(["Are you sure you want to continue connecting (yes/no)?"]); + } finally { + session.cleanup(); + } + }); + + test("handles two sequential prompts without ignoring the second", async () => { + const prompts: string[] = []; + const session = await createAskpassSession((prompt) => { + prompts.push(prompt); + return Promise.resolve(prompt.includes("continue connecting") ? "yes" : "denied"); + }); + + try { + const askpassDir = session.env.MUX_ASKPASS_DIR; + + const first = await simulateAskpassInvocation( + askpassDir, + "Are you sure you want to continue connecting (yes/no)?", + "1001.1234", + false + ); + expect(first).toBe("yes"); + + const second = await simulateAskpassInvocation( + askpassDir, + "Enter passphrase for key '/home/user/.ssh/id_ed25519':", + "1001.5678" + ); + expect(second).toBe("denied"); + + expect(prompts).toHaveLength(2); + } finally { + session.cleanup(); + } + }); + + test("cleanup is idempotent", async () => { + const session = await createAskpassSession(() => Promise.resolve("ok")); + + session.cleanup(); + expect(() => session.cleanup()).not.toThrow(); + }); + + test("ignores duplicate request IDs", async () => { + let callCount = 0; + const session = await createAskpassSession(() => { + callCount += 1; + return Promise.resolve("yes"); + }); + + try { + const askpassDir = session.env.MUX_ASKPASS_DIR; + const requestId = "dup-test"; + const promptFile = path.join(askpassDir, `prompt.${requestId}.txt`); + const responseFile = path.join(askpassDir, `response.${requestId}.txt`); + + // Simulate duplicate writes for the same askpass request id. + await fs.promises.writeFile(promptFile, "test prompt", "utf-8"); + await fs.promises.writeFile(promptFile, "test prompt", "utf-8"); + + for (let i = 0; i < 100; i += 1) { + try { + await fs.promises.access(responseFile); + break; + } catch { + await new Promise((resolve) => setTimeout(resolve, 50)); + } + } + + await new Promise((resolve) => setTimeout(resolve, 200)); + expect(callCount).toBe(1); + } finally { + session.cleanup(); + } + }); + + test("session env includes required SSH variables", async () => { + const session = await createAskpassSession(() => Promise.resolve("ok")); + + try { + expect(session.env.SSH_ASKPASS).toBeDefined(); + expect(session.env.SSH_ASKPASS_REQUIRE).toBe("force"); + expect(session.env.DISPLAY).toBeDefined(); + expect(session.env.MUX_ASKPASS_DIR).toBeDefined(); + + const stat = await fs.promises.stat(session.env.MUX_ASKPASS_DIR); + expect(stat.isDirectory()).toBe(true); + } finally { + session.cleanup(); + } + }); + }); + + describe("parseHostKeyPrompt", () => { + test("parses standard host-key prompt", () => { + const text = + "The authenticity of host 'example.com (1.2.3.4)' can't be established.\n" + + "ED25519 key fingerprint is SHA256:abcdef123456\n" + + "Are you sure you want to continue connecting (yes/no/[fingerprint])?"; + + const result = parseHostKeyPrompt(text); + + expect(result.host).toBe("example.com (1.2.3.4)"); + expect(result.keyType).toBe("ED25519"); + expect(result.fingerprint).toBe("SHA256:abcdef123456"); + }); + + test("returns unknown for non-host-key text", () => { + const result = parseHostKeyPrompt("Enter passphrase for key '/home/user/.ssh/id_ed25519':"); + + expect(result.host).toBe("unknown"); + expect(result.keyType).toBe("unknown"); + expect(result.fingerprint).toBe("unknown"); + }); + }); +}); diff --git a/src/node/runtime/sshAskpass.ts b/src/node/runtime/sshAskpass.ts new file mode 100644 index 0000000000..7f7db1e84d --- /dev/null +++ b/src/node/runtime/sshAskpass.ts @@ -0,0 +1,175 @@ +import * as fs from "fs"; +import * as path from "path"; +import * as os from "os"; +import { log } from "@/node/services/log"; + +const ASKPASS_SCRIPT = `#!/bin/sh +# mux-askpass — SSH_ASKPASS helper for Mux +# Each invocation is an independent request/response transaction identified +# by a unique ID, so multiple prompts per SSH handshake are handled correctly. +# Uses only regular files (no mkfifo) for cross-platform portability. +req_id="$$.$(date +%s%N)" +prompt_file="$MUX_ASKPASS_DIR/prompt.$req_id.txt" +response_file="$MUX_ASKPASS_DIR/response.$req_id.txt" +printf '%s' "$1" > "$prompt_file" +# Poll for response file (60s timeout = 1200 × 50ms) +i=0 +while [ "$i" -lt 1200 ]; do + if [ -f "$response_file" ]; then + cat "$response_file" + rm -f "$prompt_file" "$response_file" + exit 0 + fi + sleep 0.05 + i=$((i + 1)) +done +exit 1 +`; + +let askpassPath: string | undefined; + +function extractRequestId(filename: string): string | undefined { + const match = /^prompt\.(.+)\.txt$/.exec(filename); + return match?.[1]; +} + +async function ensureAskpassScript(): Promise { + if (askpassPath) { + try { + await fs.promises.access(askpassPath, fs.constants.X_OK); + return askpassPath; + } catch { + // Recreate the helper script if it was deleted. + } + } + + const dir = path.join(os.homedir(), ".mux", "bin"); + await fs.promises.mkdir(dir, { recursive: true }); + askpassPath = path.join(dir, "mux-askpass"); + await fs.promises.writeFile(askpassPath, ASKPASS_SCRIPT, { mode: 0o755 }); + return askpassPath; +} + +/** Parse host/keyType/fingerprint from OpenSSH output. */ +export function parseHostKeyPrompt(text: string): { + host: string; + keyType: string; + fingerprint: string; + prompt: string; +} { + const hostMatch = /authenticity of host '([^']+)'/.exec(text); + const keyMatch = /(\w+) key fingerprint is (SHA256:\S+)/.exec(text); + return { + host: hostMatch?.[1] ?? "unknown", + keyType: keyMatch?.[1] ?? "unknown", + fingerprint: keyMatch?.[2] ?? "unknown", + prompt: text.trim(), + }; +} + +export interface AskpassSession { + /** Merge into the spawn env: { ...process.env, ...env } */ + env: Record; + /** Must be called when the SSH process exits. */ + cleanup(): void; +} + +/** + * Creates a per-probe askpass session. + * + * @param onPrompt Called when askpass fires. Receives the prompt text, + * must return the response string (e.g. "yes" or "no"). + */ +export async function createAskpassSession( + onPrompt: (prompt: string) => Promise +): Promise { + // Resolve script path before allocating temp resources. + const scriptPath = await ensureAskpassScript(); + + const dir = fs.mkdtempSync(path.join(os.tmpdir(), "mux-askpass-")); + const processed = new Set(); + let closed = false; + + async function handlePrompt(requestId: string): Promise { + const promptFile = path.join(dir, `prompt.${requestId}.txt`); + const responseFile = path.join(dir, `response.${requestId}.txt`); + + try { + await fs.promises.access(promptFile); + } catch { + processed.delete(requestId); + return; + } + + if (closed) return; + + try { + const promptText = await fs.promises.readFile(promptFile, "utf-8"); + const response = await onPrompt(promptText); + await fs.promises.writeFile(responseFile, response + "\n"); + } catch (err) { + log.debug("Askpass prompt handling failed:", err); + // Write rejection to unblock askpass (best-effort) + try { + await fs.promises.writeFile(responseFile, "no\n"); + } catch { + /* askpass may already be gone */ + } + } + } + + // Watch for askpass to write prompt files. + // fs.watch is set up BEFORE SSH is spawned, so we cannot miss events. + let watcher: fs.FSWatcher; + try { + watcher = fs.watch(dir, (_, filename) => { + if (closed) return; + + void (async () => { + let candidateFilenames: string[]; + if (typeof filename === "string") { + candidateFilenames = [filename]; + } else { + try { + candidateFilenames = await fs.promises.readdir(dir); + } catch { + return; + } + } + + for (const candidate of candidateFilenames) { + const requestId = extractRequestId(candidate); + if (!requestId || processed.has(requestId)) { + continue; + } + + processed.add(requestId); + void handlePrompt(requestId); + } + })(); + }); + } catch (error) { + fs.rmSync(dir, { recursive: true, force: true }); + throw error; + } + + return { + env: { + SSH_ASKPASS: scriptPath, + // Force askpass usage even with a controlling terminal (OpenSSH 8.4+) + SSH_ASKPASS_REQUIRE: "force", + // Enable askpass on pre-8.4 OpenSSH (DISPLAY must be non-empty) + DISPLAY: process.env.DISPLAY ?? "mux", + MUX_ASKPASS_DIR: dir, + }, + cleanup() { + closed = true; + watcher.close(); + try { + fs.rmSync(dir, { recursive: true, force: true }); + } catch { + /* best-effort */ + } + }, + }; +} diff --git a/src/node/runtime/sshConnectionPool.test.ts b/src/node/runtime/sshConnectionPool.test.ts index e8efb9f667..1eb2e5043a 100644 --- a/src/node/runtime/sshConnectionPool.test.ts +++ b/src/node/runtime/sshConnectionPool.test.ts @@ -1,6 +1,12 @@ import * as os from "os"; import * as path from "path"; -import { getControlPath, SSHConnectionPool, type SSHRuntimeConfig } from "./sshConnectionPool"; +import { describe, expect, test, spyOn } from "bun:test"; +import { + appendOpenSSHHostKeyPolicyArgs, + getControlPath, + SSHConnectionPool, + type SSHRuntimeConfig, +} from "./sshConnectionPool"; describe("sshConnectionPool", () => { describe("getControlPath", () => { @@ -119,6 +125,30 @@ describe("sshConnectionPool", () => { }); }); +describe("appendOpenSSHHostKeyPolicyArgs", () => { + test("appends fallback args when interactive mode is unavailable", () => { + const args: string[] = ["-T"]; + + appendOpenSSHHostKeyPolicyArgs(args, false); + + expect(args).toEqual([ + "-T", + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", + ]); + }); + + test("does not append fallback args when interactive mode is available", () => { + const args: string[] = ["-T"]; + + appendOpenSSHHostKeyPolicyArgs(args, true); + + expect(args).toEqual(["-T"]); + }); +}); + describe("username isolation", () => { test("controlPath includes local username to prevent cross-user collisions", () => { // This test verifies that os.userInfo().username is included in the hash @@ -216,6 +246,7 @@ describe("SSHConnectionPool", () => { }; // Trigger a failure via acquireConnection (will fail to connect) + // eslint-disable-next-line @typescript-eslint/await-thenable await expect( pool.acquireConnection(config, { timeoutMs: 1000, maxWaitMs: 0 }) ).rejects.toThrow(); @@ -296,11 +327,13 @@ describe("SSHConnectionPool", () => { }; // Trigger a failure to put connection in backoff + // eslint-disable-next-line @typescript-eslint/await-thenable await expect( pool.acquireConnection(config, { timeoutMs: 1000, maxWaitMs: 0 }) ).rejects.toThrow(); // Second call should throw immediately with backoff message + // eslint-disable-next-line @typescript-eslint/await-thenable await expect(pool.acquireConnection(config, { maxWaitMs: 0 })).rejects.toThrow(/in backoff/); }); @@ -317,6 +350,68 @@ describe("SSHConnectionPool", () => { expect(path1).toBe(path2); expect(path1).toBe(getControlPath(config)); }); + + test("records backoff when probe rejects without recording it (safety net)", async () => { + const pool = new SSHConnectionPool(); + const config: SSHRuntimeConfig = { + host: "askpass-fail.example.com", + srcBaseDir: "/work", + }; + + // Simulate a probe that fails before reaching markFailedByKey + // (e.g., createAskpassSession throwing on fs.watch ENOSPC). + const spy = spyOn( + pool as unknown as { probeConnection: () => Promise }, + "probeConnection" + ).mockRejectedValueOnce(new Error("ENOSPC: no space left on device")); + + // eslint-disable-next-line @typescript-eslint/await-thenable + await expect( + pool.acquireConnection(config, { timeoutMs: 1000, maxWaitMs: 0 }) + ).rejects.toThrow(/ENOSPC/); + + // Safety net should have recorded backoff despite probeConnection not doing so. + const health = pool.getConnectionHealth(config); + expect(health?.status).toBe("unhealthy"); + expect(health?.backoffUntil).toBeDefined(); + expect(health?.lastError).toContain("ENOSPC"); + + spy.mockRestore(); + }); + }); + + describe("askpass prompt classification", () => { + // classifyAskpassPrompt() routes prompts containing "continue connecting" + // through host-key verification and treats other prompts as credentials. + const HOST_KEY_PATTERN = /continue connecting/i; + + test("detects standard host-key confirmation prompt", () => { + expect( + HOST_KEY_PATTERN.test( + "Are you sure you want to continue connecting (yes/no/[fingerprint])? " + ) + ).toBe(true); + }); + + test("detects host-key prompt case-insensitively", () => { + expect(HOST_KEY_PATTERN.test("Are you sure you want to Continue Connecting (yes/no)?")).toBe( + true + ); + }); + + test("rejects passphrase prompt", () => { + expect(HOST_KEY_PATTERN.test("Enter passphrase for key '/home/user/.ssh/id_ed25519':")).toBe( + false + ); + }); + + test("rejects password prompt", () => { + expect(HOST_KEY_PATTERN.test("user@host's password:")).toBe(false); + }); + + test("rejects empty prompt", () => { + expect(HOST_KEY_PATTERN.test("")).toBe(false); + }); }); describe("singleflighting", () => { diff --git a/src/node/runtime/sshConnectionPool.ts b/src/node/runtime/sshConnectionPool.ts index cadd1ed17b..a6507c4dbe 100644 --- a/src/node/runtime/sshConnectionPool.ts +++ b/src/node/runtime/sshConnectionPool.ts @@ -18,7 +18,45 @@ import * as crypto from "crypto"; import * as path from "path"; import * as os from "os"; import { spawn } from "child_process"; +import { HOST_KEY_APPROVAL_TIMEOUT_MS } from "@/common/constants/ssh"; +import { formatSshEndpoint } from "@/common/utils/ssh/formatSshEndpoint"; import { log } from "@/node/services/log"; +import type { HostKeyVerificationService } from "@/node/services/hostKeyVerificationService"; +import { createAskpassSession, parseHostKeyPrompt } from "./sshAskpass"; + +/** + * Classify an SSH_ASKPASS prompt to route it through the correct handler. + * + * OpenSSH askpass receives the bare prompt question (e.g., "Are you sure you + * want to continue connecting (yes/no/[fingerprint])?" for host-key, or + * "Enter passphrase for key '...':" for encrypted keys). We classify based + * on the prompt text so host-key prompts go through verification UI and + * everything else fails fast. + */ +function classifyAskpassPrompt(promptText: string): "host-key" | "credential" { + // OpenSSH host-key confirmation prompt always contains "continue connecting" + if (/continue connecting/i.test(promptText)) return "host-key"; + return "credential"; +} + +let hostKeyService: HostKeyVerificationService | undefined; + +export function setHostKeyVerificationService(svc: HostKeyVerificationService): void { + hostKeyService = svc; +} + +export function isInteractiveHostKeyApprovalAvailable(): boolean { + return hostKeyService?.hasInteractiveResponder() === true; +} + +export function appendOpenSSHHostKeyPolicyArgs(args: string[], interactive: boolean): void { + if (interactive) { + return; + } + + args.push("-o", "StrictHostKeyChecking=no"); + args.push("-o", "UserKnownHostsFile=/dev/null"); +} /** * SSH connection configuration (host/port/identity only). @@ -253,10 +291,16 @@ export class SSHConnectionPool { await probe; return; } catch (error) { + // Ensure backoff is recorded even if probeConnection rejected before + // reaching markFailedByKey (e.g., askpass setup failure). Without this, + // the while-loop retries immediately with no backoff — a hot loop. + const h = this.health.get(key); + if (!h?.backoffUntil || h.backoffUntil <= new Date()) { + this.markFailedByKey(key, error instanceof Error ? error.message : String(error)); + } if (!shouldWait) { throw error; } - // In wait mode: probeConnection() recorded backoff; loop and wait. continue; } finally { this.inflight.delete(key); @@ -363,6 +407,8 @@ export class SSHConnectionPool { key: string ): Promise { const controlPath = getControlPath(config); + const verificationService = hostKeyService; + const canPromptInteractively = isInteractiveHostKeyApprovalAvailable(); const args: string[] = ["-T"]; // No PTY needed for probe @@ -372,8 +418,6 @@ export class SSHConnectionPool { if (config.identityFile) { args.push("-i", config.identityFile); - args.push("-o", "StrictHostKeyChecking=no"); - args.push("-o", "UserKnownHostsFile=/dev/null"); args.push("-o", "LogLevel=ERROR"); } @@ -382,35 +426,94 @@ export class SSHConnectionPool { args.push("-o", `ControlPath=${controlPath}`); args.push("-o", "ControlPersist=60"); - // Aggressive timeouts for probe - const connectTimeout = Math.min(Math.ceil(timeoutMs / 1000), 15); + // ConnectTimeout covers the entire SSH handshake including SSH_ASKPASS waits. + // When host-key prompts are possible, use the longer prompt timeout so SSH + // doesn't self-terminate while the user is responding to the dialog. + // The Node.js timer still provides fast-fail for unreachable hosts. + const connectTimeout = canPromptInteractively + ? Math.ceil(HOST_KEY_APPROVAL_TIMEOUT_MS / 1000) + : Math.min(Math.ceil(timeoutMs / 1000), 15); args.push("-o", `ConnectTimeout=${connectTimeout}`); args.push("-o", "ServerAliveInterval=5"); args.push("-o", "ServerAliveCountMax=2"); + // When no host-key verification UI is available (headless/test/CLI), + // fall back to auto-accepting unknown hosts for probe connectivity checks. + appendOpenSSHHostKeyPolicyArgs(args, canPromptInteractively); + args.push(config.host, "echo ok"); log.debug(`SSH probe: ssh ${args.join(" ")}`); + let stderr = ""; + // Wired to the probe timer inside the Promise; the askpass callback + // calls this to transition from connection phase (10s) to interaction + // phase (60s) when a host-key prompt is detected. + let extendDeadline: ((ms: number) => void) | undefined; + + // Set up SSH_ASKPASS for interactive host-key verification. + // The askpass helper exchanges prompt/response text through temp files. + // Non-host-key prompts (passphrase, password) return empty to fail fast — + // passphrase-protected keys must be agent-unlocked before Mux can use them. + const askpass = + canPromptInteractively && verificationService + ? await createAskpassSession(async (promptText) => { + if (classifyAskpassPrompt(promptText) !== "host-key") { + // Credential prompts (passphrase/password) are not supported during + // probe — keys must be unlocked via ssh-agent. Return empty string + // so SSH treats this as auth failure and moves on. + log.warn("SSH askpass: unsupported credential prompt during probe, failing fast"); + return ""; + } + + extendDeadline?.(HOST_KEY_APPROVAL_TIMEOUT_MS); + + const fullContext = stderr + "\n" + promptText; + const parsed = parseHostKeyPrompt(fullContext); + const accepted = await verificationService.requestVerification({ + ...parsed, + dedupeKey: formatSshEndpoint(config.host, config.port ?? 22), + }); + return accepted ? "yes" : "no"; + }) + : undefined; + return new Promise((resolve, reject) => { - const proc = spawn("ssh", args, { stdio: ["ignore", "pipe", "pipe"] }); + const proc = spawn("ssh", args, { + stdio: ["ignore", "pipe", "pipe"], + ...(askpass ? { env: { ...process.env, ...askpass.env } } : {}), + }); + + let timedOut = false; + let timer: ReturnType | undefined; + + const scheduleKill = (ms: number) => { + if (timer) { + clearTimeout(timer); + } + timer = setTimeout(() => { + timedOut = true; + proc.kill("SIGKILL"); + askpass?.cleanup(); + const error = "SSH probe timed out"; + this.markFailedByKey(key, error); + reject(new Error(error)); + }, ms); + }; + + // Wire askpass deadline extension, then start initial fast timeout. + extendDeadline = scheduleKill; + scheduleKill(timeoutMs); - let stderr = ""; proc.stderr.on("data", (data: Buffer) => { stderr += data.toString(); }); - let timedOut = false; - const timeout = setTimeout(() => { - timedOut = true; - proc.kill("SIGKILL"); - const error = "SSH probe timed out"; - this.markFailedByKey(key, error); - reject(new Error(error)); - }, timeoutMs); - proc.on("close", (code) => { - clearTimeout(timeout); + if (timer) { + clearTimeout(timer); + } + askpass?.cleanup(); if (timedOut) return; // Already handled by timeout if (code === 0) { @@ -425,7 +528,10 @@ export class SSHConnectionPool { }); proc.on("error", (err) => { - clearTimeout(timeout); + if (timer) { + clearTimeout(timer); + } + askpass?.cleanup(); const error = `SSH probe spawn error: ${err.message}`; this.markFailedByKey(key, error); reject(new Error(error)); diff --git a/src/node/runtime/transports/OpenSSHTransport.test.ts b/src/node/runtime/transports/OpenSSHTransport.test.ts new file mode 100644 index 0000000000..5b43ed2e99 --- /dev/null +++ b/src/node/runtime/transports/OpenSSHTransport.test.ts @@ -0,0 +1,79 @@ +import { afterEach, beforeEach, describe, expect, mock, spyOn, test } from "bun:test"; +import * as childProcess from "child_process"; + +import { HostKeyVerificationService } from "@/node/services/hostKeyVerificationService"; +import { setHostKeyVerificationService, sshConnectionPool } from "../sshConnectionPool"; +import { OpenSSHTransport } from "./OpenSSHTransport"; + +function createMockChildProcess(): ReturnType { + return { + on: mock(() => undefined), + pid: 12345, + } as unknown as ReturnType; +} + +describe("OpenSSHTransport.spawnRemoteProcess", () => { + let spawnSpy: ReturnType>; + let acquireConnectionSpy: ReturnType>; + let releaseInteractiveResponder: (() => void) | undefined; + + beforeEach(() => { + spawnSpy = spyOn(childProcess, "spawn").mockImplementation((() => + createMockChildProcess()) as unknown as typeof childProcess.spawn); + acquireConnectionSpy = spyOn(sshConnectionPool, "acquireConnection").mockResolvedValue( + undefined + ); + }); + + afterEach(() => { + releaseInteractiveResponder?.(); + releaseInteractiveResponder = undefined; + // Reset to a default headless service so interactive state does not leak across tests. + setHostKeyVerificationService(new HostKeyVerificationService()); + + spawnSpy.mockRestore(); + acquireConnectionSpy.mockRestore(); + }); + + function setInteractiveHostKeyApproval(interactive: boolean): void { + const service = new HostKeyVerificationService(); + if (interactive) { + releaseInteractiveResponder = service.registerInteractiveResponder(); + } + setHostKeyVerificationService(service); + } + + async function runSpawnRemoteProcess(): Promise { + const transport = new OpenSSHTransport({ host: "remote.example.com" }); + await transport.spawnRemoteProcess("echo ok", {}); + + expect(spawnSpy).toHaveBeenCalledTimes(1); + const [command, args] = spawnSpy.mock.calls[0] as [string, string[], childProcess.SpawnOptions]; + expect(command).toBe("ssh"); + return args; + } + + test("headless mode includes host-key fallback options and BatchMode=yes", async () => { + setInteractiveHostKeyApproval(false); + + const args = await runSpawnRemoteProcess(); + + expect(args).toContain("BatchMode=yes"); + expect(args).toContain("StrictHostKeyChecking=no"); + expect(args).toContain("UserKnownHostsFile=/dev/null"); + expect(args.indexOf("StrictHostKeyChecking=no")).toBeGreaterThan(args.indexOf("BatchMode=yes")); + expect(args.indexOf("UserKnownHostsFile=/dev/null")).toBeGreaterThan( + args.indexOf("StrictHostKeyChecking=no") + ); + }); + + test("interactive mode keeps BatchMode=yes but excludes host-key fallback options", async () => { + setInteractiveHostKeyApproval(true); + + const args = await runSpawnRemoteProcess(); + + expect(args).toContain("BatchMode=yes"); + expect(args).not.toContain("StrictHostKeyChecking=no"); + expect(args).not.toContain("UserKnownHostsFile=/dev/null"); + }); +}); diff --git a/src/node/runtime/transports/OpenSSHTransport.ts b/src/node/runtime/transports/OpenSSHTransport.ts index 7493fdddff..74c3edd510 100644 --- a/src/node/runtime/transports/OpenSSHTransport.ts +++ b/src/node/runtime/transports/OpenSSHTransport.ts @@ -3,7 +3,13 @@ import { log } from "@/node/services/log"; import { spawnPtyProcess } from "../ptySpawn"; import { expandTildeForSSH } from "../tildeExpansion"; -import { getControlPath, sshConnectionPool, type SSHConnectionConfig } from "../sshConnectionPool"; +import { + appendOpenSSHHostKeyPolicyArgs, + getControlPath, + isInteractiveHostKeyApprovalAvailable, + sshConnectionPool, + type SSHConnectionConfig, +} from "../sshConnectionPool"; import type { SpawnResult } from "../RemoteRuntime"; import type { SSHTransport, @@ -41,11 +47,13 @@ export class OpenSSHTransport implements SSHTransport { async acquireConnection(options?: { abortSignal?: AbortSignal; timeoutMs?: number; + maxWaitMs?: number; onWait?: (waitMs: number) => void; }): Promise { await sshConnectionPool.acquireConnection(this.config, { abortSignal: options?.abortSignal, timeoutMs: options?.timeoutMs, + maxWaitMs: options?.maxWaitMs, onWait: options?.onWait, }); } @@ -63,6 +71,11 @@ export class OpenSSHTransport implements SSHTransport { sshArgs.push("-o", `ConnectTimeout=${connectTimeout}`); sshArgs.push("-o", "ServerAliveInterval=5"); sshArgs.push("-o", "ServerAliveCountMax=2"); + // Non-interactive execs must never hang on host-key or password prompts. + // The probe path handles host-key verification via the Mux dialog; + // by the time we reach here, the host key should already be accepted. + sshArgs.push("-o", "BatchMode=yes"); + appendOpenSSHHostKeyPolicyArgs(sshArgs, isInteractiveHostKeyApprovalAvailable()); sshArgs.push(this.config.host, fullCommand); @@ -110,8 +123,6 @@ export class OpenSSHTransport implements SSHTransport { if (this.config.identityFile) { args.push("-i", this.config.identityFile); - args.push("-o", "StrictHostKeyChecking=no"); - args.push("-o", "UserKnownHostsFile=/dev/null"); } args.push("-o", "LogLevel=FATAL"); diff --git a/src/node/runtime/transports/SSH2Transport.ts b/src/node/runtime/transports/SSH2Transport.ts index ec51e302c4..04262ef2f3 100644 --- a/src/node/runtime/transports/SSH2Transport.ts +++ b/src/node/runtime/transports/SSH2Transport.ts @@ -239,11 +239,13 @@ export class SSH2Transport implements SSHTransport { async acquireConnection(options?: { abortSignal?: AbortSignal; timeoutMs?: number; + maxWaitMs?: number; onWait?: (waitMs: number) => void; }): Promise { await ssh2ConnectionPool.acquireConnection(this.config, { abortSignal: options?.abortSignal, timeoutMs: options?.timeoutMs, + maxWaitMs: options?.maxWaitMs, onWait: options?.onWait, }); } diff --git a/src/node/runtime/transports/SSHTransport.ts b/src/node/runtime/transports/SSHTransport.ts index 4e37ec344b..43f7aedef4 100644 --- a/src/node/runtime/transports/SSHTransport.ts +++ b/src/node/runtime/transports/SSHTransport.ts @@ -40,6 +40,7 @@ export interface SSHTransport { acquireConnection(options?: { abortSignal?: AbortSignal; timeoutMs?: number; + maxWaitMs?: number; onWait?: (waitMs: number) => void; }): Promise; diff --git a/src/node/services/hostKeyVerificationService.test.ts b/src/node/services/hostKeyVerificationService.test.ts new file mode 100644 index 0000000000..8343d51fd4 --- /dev/null +++ b/src/node/services/hostKeyVerificationService.test.ts @@ -0,0 +1,279 @@ +import { describe, it, expect, beforeEach } from "bun:test"; + +import { HostKeyVerificationService } from "./hostKeyVerificationService"; +import type { HostKeyVerificationRequest } from "@/common/orpc/schemas/ssh"; + +/** Short timeout for tests — avoids waiting the real 60s. */ +const TEST_TIMEOUT_MS = 20; + +const REQUEST_PARAMS: Omit = { + host: "example.com", + keyType: "ssh-ed25519", + fingerprint: "SHA256:abcdef", + prompt: "Trust host key?", +}; + +function waitForTimeout(): Promise { + return new Promise((resolve) => { + setTimeout(resolve, TEST_TIMEOUT_MS * 3); + }); +} + +describe("HostKeyVerificationService", () => { + let service: HostKeyVerificationService; + let requests: HostKeyVerificationRequest[]; + let releaseResponder: () => void; + + beforeEach(() => { + service = new HostKeyVerificationService(TEST_TIMEOUT_MS); + requests = []; + service.on("request", (req: HostKeyVerificationRequest) => { + requests.push(req); + }); + releaseResponder = service.registerInteractiveResponder(); + }); + + it("resolves on explicit respond", async () => { + const verification = service.requestVerification(REQUEST_PARAMS); + + expect(requests).toHaveLength(1); + service.respond(requests[0].requestId, true); + + const result = await verification; + expect(result).toBe(true); + }); + + it("resolves false on timeout", async () => { + const verification = service.requestVerification(REQUEST_PARAMS); + + await waitForTimeout(); + + const result = await verification; + expect(result).toBe(false); + }); + + it("deduped waiters all resolve on respond", async () => { + const verification1 = service.requestVerification(REQUEST_PARAMS); + const verification2 = service.requestVerification(REQUEST_PARAMS); + const verification3 = service.requestVerification(REQUEST_PARAMS); + + expect(requests).toHaveLength(1); + service.respond(requests[0].requestId, true); + + const results = await Promise.all([verification1, verification2, verification3]); + expect(results).toEqual([true, true, true]); + }); + + it("deduped waiters all resolve false on timeout", async () => { + const verification1 = service.requestVerification(REQUEST_PARAMS); + const verification2 = service.requestVerification(REQUEST_PARAMS); + const verification3 = service.requestVerification(REQUEST_PARAMS); + + await waitForTimeout(); + + const results = await Promise.all([verification1, verification2, verification3]); + expect(results).toEqual([false, false, false]); + }); + + it("late respond after timeout is a no-op", async () => { + const verification = service.requestVerification(REQUEST_PARAMS); + const requestId = requests[0].requestId; + + await waitForTimeout(); + const result = await verification; + expect(result).toBe(false); + + expect(() => { + service.respond(requestId, true); + }).not.toThrow(); + }); + + it("host can be re-requested after timeout cleanup", async () => { + const firstVerification = service.requestVerification(REQUEST_PARAMS); + + await waitForTimeout(); + const firstResult = await firstVerification; + expect(firstResult).toBe(false); + + const secondVerification = service.requestVerification(REQUEST_PARAMS); + + expect(requests).toHaveLength(2); + expect(requests[0].requestId).not.toBe(requests[1].requestId); + + service.respond(requests[1].requestId, true); + + const secondResult = await secondVerification; + expect(secondResult).toBe(true); + }); + + it("emits request event only for first caller", async () => { + const verification1 = service.requestVerification(REQUEST_PARAMS); + const verification2 = service.requestVerification(REQUEST_PARAMS); + const verification3 = service.requestVerification(REQUEST_PARAMS); + + expect(requests).toHaveLength(1); + + service.respond(requests[0].requestId, true); + + const results = await Promise.all([verification1, verification2, verification3]); + expect(results).toEqual([true, true, true]); + }); + + it("rejects immediately with no responders", async () => { + releaseResponder(); + + const result = await service.requestVerification(REQUEST_PARAMS); + + expect(result).toBe(false); + expect(requests).toHaveLength(0); + }); + + it("emits request when responder is registered", async () => { + releaseResponder(); + const release = service.registerInteractiveResponder(); + + const verification = service.requestVerification(REQUEST_PARAMS); + + expect(requests).toHaveLength(1); + service.respond(requests[0].requestId, true); + + const result = await verification; + expect(result).toBe(true); + + release(); + }); + + it("rejects immediately after responder released", async () => { + releaseResponder(); + + const result = await service.requestVerification(REQUEST_PARAMS); + + expect(result).toBe(false); + expect(requests).toHaveLength(0); + }); + + it("keeps pending verification alive when last responder disconnects", async () => { + const verification = service.requestVerification(REQUEST_PARAMS); + expect(requests).toHaveLength(1); + + // Simulate renderer disconnect — last responder released while prompt pending. + releaseResponder(); + + // Re-register before responding (simulates reconnect). + const release2 = service.registerInteractiveResponder(); + service.respond(requests[0].requestId, true); + + // Pending request survived the responder gap and was accepted. + const result = await verification; + expect(result).toBe(true); + + release2(); + }); + + it("times out pending verification even after all responders disconnect", async () => { + const verification = service.requestVerification(REQUEST_PARAMS); + expect(requests).toHaveLength(1); + + releaseResponder(); + + // No responder to approve — timeout should still fire and reject. + await waitForTimeout(); + const result = await verification; + expect(result).toBe(false); + }); + it("double-release is safe", () => { + releaseResponder(); + + const release = service.registerInteractiveResponder(); + + expect(() => { + release(); + release(); + }).not.toThrow(); + }); + + it("does not coalesce when dedupeKey differs", async () => { + const v1 = service.requestVerification({ ...REQUEST_PARAMS, dedupeKey: "example.com:22" }); + const v2 = service.requestVerification({ ...REQUEST_PARAMS, dedupeKey: "example.com:2222" }); + + expect(requests).toHaveLength(2); + + service.respond(requests[0].requestId, true); + service.respond(requests[1].requestId, false); + + expect(await v1).toBe(true); + expect(await v2).toBe(false); + }); + + it("coalesces when dedupeKey matches", async () => { + const v1 = service.requestVerification({ ...REQUEST_PARAMS, dedupeKey: "example.com:22" }); + const v2 = service.requestVerification({ ...REQUEST_PARAMS, dedupeKey: "example.com:22" }); + + expect(requests).toHaveLength(1); + + service.respond(requests[0].requestId, true); + + const results = await Promise.all([v1, v2]); + expect(results).toEqual([true, true]); + }); + + it("replays pending requests to late subscribers", async () => { + // Request emitted BEFORE subscriber connects + const verification = service.requestVerification(REQUEST_PARAMS); + expect(requests).toHaveLength(1); + const requestId = requests[0].requestId; + + // Late subscriber should see the pending request via snapshot + const lateRequests: HostKeyVerificationRequest[] = []; + const { snapshot, unsubscribe } = service.subscribeRequests((req) => { + lateRequests.push(req); + }); + + expect(snapshot).toHaveLength(1); + expect(snapshot[0].requestId).toBe(requestId); + expect(snapshot[0].host).toBe("example.com"); + + unsubscribe(); + service.respond(requestId, true); + await verification; + }); + + it("does not replay resolved requests", async () => { + const verification = service.requestVerification(REQUEST_PARAMS); + const requestId = requests[0].requestId; + + service.respond(requestId, true); + await verification; + + // eslint-disable-next-line @typescript-eslint/no-empty-function -- no-op listener; we only care about the snapshot + const { snapshot, unsubscribe } = service.subscribeRequests(() => {}); + expect(snapshot).toHaveLength(0); + unsubscribe(); + }); + + it("emits removed event on timeout", async () => { + const removedIds: string[] = []; + service.on("removed", (id: string) => removedIds.push(id)); + + const verification = service.requestVerification(REQUEST_PARAMS); + const requestId = requests[0].requestId; + + await waitForTimeout(); + await verification; + + expect(removedIds).toEqual([requestId]); + }); + + it("emits removed event on explicit respond", async () => { + const removedIds: string[] = []; + service.on("removed", (id: string) => removedIds.push(id)); + + const verification = service.requestVerification(REQUEST_PARAMS); + const requestId = requests[0].requestId; + + service.respond(requestId, true); + await verification; + + expect(removedIds).toEqual([requestId]); + }); +}); diff --git a/src/node/services/hostKeyVerificationService.ts b/src/node/services/hostKeyVerificationService.ts new file mode 100644 index 0000000000..cc0778bff7 --- /dev/null +++ b/src/node/services/hostKeyVerificationService.ts @@ -0,0 +1,138 @@ +import { EventEmitter } from "events"; +import * as crypto from "crypto"; +import { HOST_KEY_APPROVAL_TIMEOUT_MS } from "@/common/constants/ssh"; +import type { HostKeyVerificationRequest } from "@/common/orpc/schemas/ssh"; + +interface PendingEntry { + request: HostKeyVerificationRequest; + dedupeKey: string; + timer: ReturnType; + waiters: Array<(accept: boolean) => void>; +} + +export class HostKeyVerificationService extends EventEmitter { + private pending = new Map(); + /** + * Dedup: endpoint identity -> inflight requestId. + * Callers can provide host+port identity to avoid cross-port prompt coalescing. + */ + private inflightByDedupeKey = new Map(); + private activeResponders = 0; + private readonly timeoutMs: number; + + constructor(timeoutMs = HOST_KEY_APPROVAL_TIMEOUT_MS) { + super(); + this.timeoutMs = timeoutMs; + } + + registerInteractiveResponder(): () => void { + this.activeResponders += 1; + + let released = false; + return () => { + if (released) { + return; + } + + released = true; + this.activeResponders = Math.max(0, this.activeResponders - 1); + + // Keep responder count as an admission gate only. Pending requests are + // not rejected on disconnect and instead resolve via explicit respond() + // or timeout, which prevents reconnect churn from killing in-flight + // prompts. + }; + } + + hasInteractiveResponder(): boolean { + return this.activeResponders > 0; + } + + /** + * Atomic subscribe+snapshot: register listener FIRST, then return current + * pending requests. Any request emitted between registration and snapshot + * appears in both the listener and snapshot — callers must deduplicate + * (the frontend already does via requestId check in setPendingQueue). + */ + subscribeRequests( + onRequest: (req: HostKeyVerificationRequest) => void, + onRemoved?: (requestId: string) => void + ): { + snapshot: HostKeyVerificationRequest[]; + unsubscribe: () => void; + } { + this.on("request", onRequest); + if (onRemoved) this.on("removed", onRemoved); + return { + snapshot: Array.from(this.pending.values()).map((entry) => entry.request), + unsubscribe: () => { + this.off("request", onRequest); + if (onRemoved) this.off("removed", onRemoved); + }, + }; + } + + private finalizeRequest(requestId: string, accept: boolean): void { + const entry = this.pending.get(requestId); + if (!entry) { + return; + } + + clearTimeout(entry.timer); + this.pending.delete(requestId); + this.inflightByDedupeKey.delete(entry.dedupeKey); + this.emit("removed", requestId); + + for (const resolve of entry.waiters) { + resolve(accept); + } + } + + /** + * Called from SSH pool when a host-key prompt is detected. + * Blocks until the user responds or timeout fires. + */ + async requestVerification( + params: Omit & { dedupeKey?: string } + ): Promise { + if (!this.hasInteractiveResponder()) { + return false; + } + + const { dedupeKey: dedupeKeyOverride, ...requestParams } = params; + const dedupeKey = dedupeKeyOverride ?? requestParams.host; + + // Dedup: if a prompt for this endpoint is already pending, append another waiter. + const existingId = this.inflightByDedupeKey.get(dedupeKey); + if (existingId) { + const entry = this.pending.get(existingId); + if (entry) { + return new Promise((resolve) => { + entry.waiters.push(resolve); + }); + } + } + + const requestId = crypto.randomUUID(); + this.inflightByDedupeKey.set(dedupeKey, requestId); + + return new Promise((resolve) => { + const request: HostKeyVerificationRequest = { requestId, ...requestParams }; + const entry: PendingEntry = { + request, + dedupeKey, + timer: setTimeout(() => { + this.finalizeRequest(requestId, false); + }, this.timeoutMs), + waiters: [resolve], + }; + + this.pending.set(requestId, entry); + this.emit("request", request); + }); + } + + respond(requestId: string, accept: boolean): void { + this.finalizeRequest(requestId, accept); + } +} diff --git a/src/node/services/serviceContainer.ts b/src/node/services/serviceContainer.ts index bb41baeab7..286c12aea3 100644 --- a/src/node/services/serviceContainer.ts +++ b/src/node/services/serviceContainer.ts @@ -45,12 +45,15 @@ import { McpOauthService } from "@/node/services/mcpOauthService"; import { IdleCompactionService } from "@/node/services/idleCompactionService"; import { getSigningService, type SigningService } from "@/node/services/signingService"; import { coderService, type CoderService } from "@/node/services/coderService"; +import { HostKeyVerificationService } from "@/node/services/hostKeyVerificationService"; import { WorkspaceLifecycleHooks } from "@/node/services/workspaceLifecycleHooks"; import { createStartCoderOnUnarchiveHook, createStopCoderOnArchiveHook, } from "@/node/runtime/coderLifecycleHooks"; import { setGlobalCoderService } from "@/node/runtime/runtimeFactory"; +import { setHostKeyVerificationService } from "@/node/runtime/sshConnectionPool"; +import { setHostKeyVerificationService as setSSH2HostKeyVerificationService } from "@/node/runtime/SSH2ConnectionPool"; import { PolicyService } from "@/node/services/policyService"; import type { ORPCContext } from "@/node/orpc/context"; @@ -112,6 +115,7 @@ export class ServiceContainer { public readonly signingService: SigningService; public readonly policyService: PolicyService; public readonly coderService: CoderService; + public readonly hostKeyVerificationService = new HostKeyVerificationService(); private readonly ptyService: PTYService; public readonly idleCompactionService: IdleCompactionService; @@ -223,6 +227,8 @@ export class ServiceContainer { // Register globally so all createRuntime calls can create CoderSSHRuntime setGlobalCoderService(this.coderService); + setHostKeyVerificationService(this.hostKeyVerificationService); + setSSH2HostKeyVerificationService(this.hostKeyVerificationService); // Backend timing stats (behind feature flag). this.aiService.on("stream-start", (data: StreamStartEvent) => @@ -435,6 +441,7 @@ export class ServiceContainer { policyService: this.policyService, signingService: this.signingService, coderService: this.coderService, + hostKeyVerificationService: this.hostKeyVerificationService, }; } diff --git a/tests/ipc/runtime/runtimeExecuteBash.test.ts b/tests/ipc/runtime/runtimeExecuteBash.test.ts index b1f5f32a9a..11821fc6e4 100644 --- a/tests/ipc/runtime/runtimeExecuteBash.test.ts +++ b/tests/ipc/runtime/runtimeExecuteBash.test.ts @@ -22,6 +22,8 @@ import { sendMessageAndWait, extractTextFromEvents, HAIKU_MODEL, + STREAM_TIMEOUT_LOCAL_MS, + STREAM_TIMEOUT_SSH_MS, TEST_TIMEOUT_LOCAL_MS, TEST_TIMEOUT_SSH_MS, getTestRunner, @@ -148,6 +150,8 @@ describeIntegration("Runtime Bash Execution", () => { return undefined; // undefined = defaults to local }; + const streamTimeoutMs = type === "ssh" ? STREAM_TIMEOUT_SSH_MS : STREAM_TIMEOUT_LOCAL_MS; + // SSH tests run serially to avoid Docker container overload const runTest = getTestRunner(type); @@ -184,7 +188,8 @@ describeIntegration("Runtime Bash Execution", () => { workspaceId, 'Use the bash tool with args: { script: "echo Hello World", timeout_secs: 30, run_in_background: false, display_name: "echo-hello" }. Do not spawn a sub-agent.', HAIKU_MODEL, - BASH_ONLY + BASH_ONLY, + streamTimeoutMs ); // Extract response text @@ -253,7 +258,8 @@ describeIntegration("Runtime Bash Execution", () => { workspaceId, 'Use the bash tool with args: { script: "export TEST_VAR=test123 && echo Value:$TEST_VAR", timeout_secs: 30, run_in_background: false, display_name: "env-var" }. Do not spawn a sub-agent.', HAIKU_MODEL, - BASH_ONLY + BASH_ONLY, + streamTimeoutMs ); // Extract response text @@ -331,7 +337,7 @@ describeIntegration("Runtime Bash Execution", () => { 'Use the bash tool with args: { script: "echo testdata > /tmp/test.txt && cat /tmp/test.txt | grep test", timeout_secs: 30, run_in_background: false, display_name: "stdin-grep" }. Do not spawn a sub-agent.', HAIKU_MODEL, BASH_ONLY, - 30000 // Relaxed timeout for CI stability (was 10s) + streamTimeoutMs ); // Calculate actual tool execution duration @@ -410,7 +416,7 @@ describeIntegration("Runtime Bash Execution", () => { 'Use the bash tool with args: { script: "for i in {1..1000}; do echo \"terminal bench line $i\" >> testfile.txt; done && grep -n \"terminal bench\" testfile.txt | head -n 200", timeout_secs: 60, run_in_background: false, display_name: "grep-head" }. Do not spawn a sub-agent.', HAIKU_MODEL, BASH_ONLY, - 30000 // Relaxed timeout for CI stability (was 15s) + streamTimeoutMs ); // Calculate actual tool execution duration diff --git a/tests/ipc/setup.ts b/tests/ipc/setup.ts index 31c5ae7dd7..3bc8891a13 100644 --- a/tests/ipc/setup.ts +++ b/tests/ipc/setup.ts @@ -14,6 +14,7 @@ import { } from "./helpers"; import type { OrpcSource } from "./helpers"; import type { ORPCContext } from "../../src/node/orpc/context"; +import type { HostKeyVerificationService } from "../../src/node/services/hostKeyVerificationService"; import type { RuntimeConfig } from "../../src/common/types/runtime"; import { createOrpcTestClient, type OrpcTestClient } from "./orpcTestClient"; import { shouldRunIntegrationTests, validateApiKeys, getApiKey } from "../testUtils"; @@ -112,6 +113,7 @@ export async function createTestEnvironment(): Promise { signingService: services.signingService, coderService: services.coderService, policyService: services.policyService, + hostKeyVerificationService: {} as HostKeyVerificationService, }; const orpc = createOrpcTestClient(orpcContext);