diff --git a/apps/dev-playground/app.yaml b/apps/dev-playground/app.yaml index 2827b8d..b7c34fe 100644 --- a/apps/dev-playground/app.yaml +++ b/apps/dev-playground/app.yaml @@ -1,3 +1,8 @@ env: - name: DATABRICKS_WAREHOUSE_ID valueFrom: sql-warehouse + +permissions: + user_authorization: + scopes: + - sql diff --git a/apps/dev-playground/client/src/routes/index.tsx b/apps/dev-playground/client/src/routes/index.tsx index 3f78342..c226955 100644 --- a/apps/dev-playground/client/src/routes/index.tsx +++ b/apps/dev-playground/client/src/routes/index.tsx @@ -5,6 +5,7 @@ import { } from "@tanstack/react-router"; import { Button, Card } from "@databricks/appkit-ui/react"; import { ThemeSelector } from "@/components/theme-selector"; +import { useEffect } from "react"; export const Route = createFileRoute("/")({ component: IndexRoute, @@ -16,6 +17,22 @@ export const Route = createFileRoute("/")({ function IndexRoute() { const navigate = useNavigate(); + useEffect(() => { + fetch("/sp") + .then((res) => res.json()) + .then((data) => { + console.log(data); + }); + }, []); + + // useEffect(() => { + // fetch("/obo") + // .then((res) => res.json()) + // .then((data) => { + // console.log(data); + // }); + // }, []); + return (
diff --git a/apps/dev-playground/server/index.ts b/apps/dev-playground/server/index.ts index 9391c62..5b36c7d 100644 --- a/apps/dev-playground/server/index.ts +++ b/apps/dev-playground/server/index.ts @@ -3,5 +3,51 @@ import { reconnect } from "./reconnect-plugin"; import { telemetryExamples } from "./telemetry-example-plugin"; createApp({ - plugins: [server(), reconnect(), telemetryExamples(), analytics({})], + plugins: [ + server({ autoStart: false }), + reconnect(), + telemetryExamples(), + analytics({}), + ], +}).then((appkit) => { + appkit.server + .extend((app) => { + // Debug endpoint to inspect headers + app.get("/debug-headers", (req, res) => { + const token = req.headers["x-forwarded-access-token"] as string; + res.json({ + hasToken: !!token, + tokenLength: token?.length, + tokenPrefix: token?.substring(0, 30), + userId: req.headers["x-forwarded-user"], + allHeaders: Object.keys(req.headers), + }); + }); + + app.get("/sp", (_req, res) => { + appkit.analytics.query("SELECT 1").then((result) => { + console.log(result); + res.json(result); + }); + }); + + app.get("/obo", (req, res) => { + appkit.analytics + .asUser(req) + .query("SELECT 1") + .then((result) => { + console.log(result); + res.json(result); + }) + .catch((error) => { + console.error("OBO Error:", error); + res.status(500).json({ + error: error.message, + errorCode: error.errorCode, + statusCode: error.statusCode, + }); + }); + }); + }) + .start(); }); diff --git a/packages/appkit/src/analytics/analytics.ts b/packages/appkit/src/analytics/analytics.ts index 81d4b12..1a638a3 100644 --- a/packages/appkit/src/analytics/analytics.ts +++ b/packages/appkit/src/analytics/analytics.ts @@ -6,9 +6,13 @@ import type { StreamExecutionSettings, } from "shared"; import { SQLWarehouseConnector } from "../connectors"; +import { + getCurrentUserId, + getWarehouseId, + getWorkspaceClient, +} from "../context"; import { Plugin, toPlugin } from "../plugin"; import type { Request, Response } from "../utils"; -import { getRequestContext, getWorkspaceClient } from "../utils"; import { queryDefaults } from "./defaults"; import { QueryProcessor } from "./query"; import type { @@ -20,7 +24,6 @@ import type { export class AnalyticsPlugin extends Plugin { name = "analytics"; envVars = []; - requiresDatabricksClient = true; protected static description = "Analytics plugin for data analysis"; protected declare config: IAnalyticsConfig; @@ -41,6 +44,7 @@ export class AnalyticsPlugin extends Plugin { } injectRoutes(router: IAppRouter) { + // Service principal endpoints this.route(router, { name: "arrow", method: "get", @@ -50,12 +54,22 @@ export class AnalyticsPlugin extends Plugin { }, }); + this.route(router, { + name: "query", + method: "post", + path: "/query/:query_key", + handler: async (req: Request, res: Response) => { + await this._handleQueryRoute(req, res); + }, + }); + + // User context endpoints - use asUser(req) to execute with user's identity this.route(router, { name: "arrowAsUser", method: "get", path: "/users/me/arrow-result/:jobId", handler: async (req: Request, res: Response) => { - await this._handleArrowRoute(req, res, { asUser: true }); + await this.asUser(req)._handleArrowRoute(req, res); }, }); @@ -64,29 +78,19 @@ export class AnalyticsPlugin extends Plugin { method: "post", path: "/users/me/query/:query_key", handler: async (req: Request, res: Response) => { - await this._handleQueryRoute(req, res, { asUser: true }); - }, - }); - - this.route(router, { - name: "query", - method: "post", - path: "/query/:query_key", - handler: async (req: Request, res: Response) => { - await this._handleQueryRoute(req, res, { asUser: false }); + await this.asUser(req)._handleQueryRoute(req, res); }, }); } - private async _handleArrowRoute( - req: Request, - res: Response, - { asUser = false }: { asUser?: boolean } = {}, - ): Promise { + /** + * Handle Arrow data download requests. + * When called via asUser(req), uses the user's Databricks credentials. + */ + async _handleArrowRoute(req: Request, res: Response): Promise { try { const { jobId } = req.params; - - const workspaceClient = getWorkspaceClient(asUser); + const workspaceClient = getWorkspaceClient(); console.log( `Processing Arrow job request: ${jobId} for plugin: ${this.name}`, @@ -111,11 +115,11 @@ export class AnalyticsPlugin extends Plugin { } } - private async _handleQueryRoute( - req: Request, - res: Response, - { asUser = false }: { asUser?: boolean } = {}, - ): Promise { + /** + * Handle SQL query execution requests. + * When called via asUser(req), uses the user's Databricks credentials. + */ + async _handleQueryRoute(req: Request, res: Response): Promise { const { query_key } = req.params; const { parameters, format = "JSON" } = req.body as IAnalyticsQueryRequest; const queryParameters = @@ -131,10 +135,8 @@ export class AnalyticsPlugin extends Plugin { type: "result", }; - const requestContext = getRequestContext(); - const userKey = asUser - ? requestContext.userId - : requestContext.serviceUserId; + // Get user key from current context (automatically includes user ID when in user context) + const userKey = getCurrentUserId(); if (!query_key) { res.status(400).json({ error: "query_key is required" }); @@ -164,7 +166,7 @@ export class AnalyticsPlugin extends Plugin { JSON.stringify(parameters), JSON.stringify(format), hashedQuery, - userKey, + // userKey is automatically set based on context ], }, }; @@ -186,9 +188,6 @@ export class AnalyticsPlugin extends Plugin { processedParams, queryParameters.formatParameters, signal, - { - asUser, - }, ); return { type: queryParameters.type, ...result }; @@ -198,15 +197,29 @@ export class AnalyticsPlugin extends Plugin { ); } + /** + * Execute a SQL query using the current execution context. + * + * When called directly: uses service principal credentials. + * When called via asUser(req).query(...): uses user's credentials. + * + * @example + * ```typescript + * // Service principal execution + * const result = await analytics.query("SELECT * FROM table") + * + * // User context execution (in route handler) + * const result = await this.asUser(req).query("SELECT * FROM table") + * ``` + */ async query( query: string, parameters?: Record, formatParameters?: Record, signal?: AbortSignal, - { asUser = false }: { asUser?: boolean } = {}, ): Promise { - const requestContext = getRequestContext(); - const workspaceClient = getWorkspaceClient(asUser); + const workspaceClient = getWorkspaceClient(); + const warehouseId = await getWarehouseId(); const { statement, parameters: sqlParameters } = this.queryProcessor.convertToSQLParameters(query, parameters); @@ -215,7 +228,7 @@ export class AnalyticsPlugin extends Plugin { workspaceClient, { statement, - warehouse_id: await requestContext.warehouseId, + warehouse_id: warehouseId, parameters: sqlParameters, ...formatParameters, }, @@ -225,8 +238,9 @@ export class AnalyticsPlugin extends Plugin { return response.result; } - // If we need arrow stream in more plugins we can define this as a base method in the core plugin class - // and have a generic endpoint for each plugin that consumes this arrow data. + /** + * Get Arrow-formatted data for a completed query job. + */ protected async getArrowData( workspaceClient: WorkspaceClient, jobId: string, diff --git a/packages/appkit/src/analytics/query.ts b/packages/appkit/src/analytics/query.ts index 39c9a2f..c891687 100644 --- a/packages/appkit/src/analytics/query.ts +++ b/packages/appkit/src/analytics/query.ts @@ -1,7 +1,7 @@ import { createHash } from "node:crypto"; import type { sql } from "@databricks/sdk-experimental"; import { isSQLTypeMarker, type SQLTypeMarker, sql as sqlHelpers } from "shared"; -import { getRequestContext } from "../utils"; +import { getWorkspaceId } from "../context"; type SQLParameterValue = SQLTypeMarker | null | undefined; @@ -18,8 +18,7 @@ export class QueryProcessor { // auto-inject workspaceId if needed and not provided if (queryParams.has("workspaceId") && !processed.workspaceId) { - const requestContext = getRequestContext(); - const workspaceId = await requestContext.workspaceId; + const workspaceId = await getWorkspaceId(); if (workspaceId) { processed.workspaceId = sqlHelpers.string(workspaceId); } diff --git a/packages/appkit/src/connectors/lakebase/client.ts b/packages/appkit/src/connectors/lakebase/client.ts index 053f166..e05d282 100644 --- a/packages/appkit/src/connectors/lakebase/client.ts +++ b/packages/appkit/src/connectors/lakebase/client.ts @@ -268,22 +268,22 @@ export class LakebaseConnector { this.close(); } - /** Get Databricks workspace client - from config or request context */ + /** Get Databricks workspace client - from config or execution context */ private getWorkspaceClient(): WorkspaceClient { if (this.config.workspaceClient) { return this.config.workspaceClient; } try { - const { getRequestContext } = require("../../utils"); - const { serviceDatabricksClient } = getRequestContext(); + const { getWorkspaceClient: getClient } = require("../../context"); + const client = getClient(); // cache it for subsequent calls - this.config.workspaceClient = serviceDatabricksClient; - return serviceDatabricksClient; + this.config.workspaceClient = client; + return client; } catch (_error) { throw new Error( - "Databricks workspace client not available. Either pass it in config or use within AppKit request context.", + "Databricks workspace client not available. Either pass it in config or ensure ServiceContext is initialized.", ); } } diff --git a/packages/appkit/src/context/execution-context.ts b/packages/appkit/src/context/execution-context.ts new file mode 100644 index 0000000..a3326a6 --- /dev/null +++ b/packages/appkit/src/context/execution-context.ts @@ -0,0 +1,79 @@ +import { AsyncLocalStorage } from "node:async_hooks"; +import type { ExecutionContext, IUserContext } from "./service-context"; +import { ServiceContext, isUserContext } from "./service-context"; + +/** + * AsyncLocalStorage for execution context. + * Used to pass user context through the call stack without explicit parameters. + */ +const executionContextStorage = new AsyncLocalStorage(); + +/** + * Run a function in the context of a user. + * All calls within the function will have access to the user context. + * + * @param userContext - The user context to use + * @param fn - The function to run + * @returns The result of the function + */ +export function runInUserContext(userContext: IUserContext, fn: () => T): T { + return executionContextStorage.run(userContext, fn); +} + +/** + * Get the current execution context. + * + * - If running inside a user context (via asUser), returns the user context + * - Otherwise, returns the service context + * + * @throws Error if ServiceContext is not initialized + */ +export function getExecutionContext(): ExecutionContext { + const userContext = executionContextStorage.getStore(); + if (userContext) { + return userContext; + } + return ServiceContext.get(); +} + +/** + * Get the current user ID for cache keying and telemetry. + * + * Returns the user ID if in user context, otherwise the service user ID. + */ +export function getCurrentUserId(): string { + const ctx = getExecutionContext(); + if (isUserContext(ctx)) { + return ctx.userId; + } + return ctx.serviceUserId; +} + +/** + * Get the WorkspaceClient for the current execution context. + */ +export function getWorkspaceClient() { + return getExecutionContext().client; +} + +/** + * Get the warehouse ID promise. + */ +export function getWarehouseId(): Promise { + return getExecutionContext().warehouseId; +} + +/** + * Get the workspace ID promise. + */ +export function getWorkspaceId(): Promise { + return getExecutionContext().workspaceId; +} + +/** + * Check if currently running in a user context. + */ +export function isInUserContext(): boolean { + const ctx = executionContextStorage.getStore(); + return ctx !== undefined; +} diff --git a/packages/appkit/src/context/index.ts b/packages/appkit/src/context/index.ts new file mode 100644 index 0000000..0516872 --- /dev/null +++ b/packages/appkit/src/context/index.ts @@ -0,0 +1,17 @@ +export { + ServiceContext, + isUserContext, + type ExecutionContext, + type IServiceContext, + type IUserContext, +} from "./service-context"; + +export { + getExecutionContext, + getCurrentUserId, + getWorkspaceClient, + getWarehouseId, + getWorkspaceId, + isInUserContext, + runInUserContext, +} from "./execution-context"; diff --git a/packages/appkit/src/context/service-context.ts b/packages/appkit/src/context/service-context.ts new file mode 100644 index 0000000..1e41680 --- /dev/null +++ b/packages/appkit/src/context/service-context.ts @@ -0,0 +1,288 @@ +import { + type ClientOptions, + type sql, + WorkspaceClient, +} from "@databricks/sdk-experimental"; +import { + name as productName, + version as productVersion, +} from "../../package.json"; + +/** + * Service context holds the service principal client and shared resources. + * This is initialized once at app startup and shared across all requests. + */ +export interface IServiceContext { + /** WorkspaceClient authenticated as the service principal */ + client: WorkspaceClient; + /** The service principal's user ID */ + serviceUserId: string; + /** Promise that resolves to the warehouse ID */ + warehouseId: Promise; + /** Promise that resolves to the workspace ID */ + workspaceId: Promise; +} + +/** + * User execution context extends the service context with user-specific data. + * Created on-demand when asUser(req) is called. + */ +export interface IUserContext { + /** WorkspaceClient authenticated as the user */ + client: WorkspaceClient; + /** The user's ID (from request headers) */ + userId: string; + /** The user's name (from request headers) */ + userName?: string; + /** Promise that resolves to the warehouse ID (inherited from service context) */ + warehouseId: Promise; + /** Promise that resolves to the workspace ID (inherited from service context) */ + workspaceId: Promise; + /** Flag indicating this is a user context */ + isUserContext: true; +} + +/** + * Execution context can be either service or user context. + */ +export type ExecutionContext = IServiceContext | IUserContext; + +/** + * Check if an execution context is a user context. + */ +export function isUserContext(ctx: ExecutionContext): ctx is IUserContext { + return "isUserContext" in ctx && ctx.isUserContext === true; +} + +function getClientOptions(): ClientOptions { + const isDev = process.env.NODE_ENV === "development"; + const normalizedVersion = productVersion + .split(".") + .slice(0, 3) + .join(".") as ClientOptions["productVersion"]; + + return { + product: productName, + productVersion: normalizedVersion, + ...(isDev && { userAgentExtra: { mode: "dev" } }), + }; +} + +/** + * ServiceContext is a singleton that manages the service principal's + * WorkspaceClient and shared resources like warehouse/workspace IDs. + * + * It's initialized once at app startup and provides the foundation + * for both service principal and user context execution. + */ +export class ServiceContext { + private static instance: IServiceContext | null = null; + private static initPromise: Promise | null = null; + + /** + * Initialize the service context. Should be called once at app startup. + * Safe to call multiple times - will return the same instance. + */ + static async initialize(): Promise { + if (ServiceContext.instance) { + return ServiceContext.instance; + } + + if (ServiceContext.initPromise) { + return ServiceContext.initPromise; + } + + ServiceContext.initPromise = ServiceContext.createContext(); + ServiceContext.instance = await ServiceContext.initPromise; + return ServiceContext.instance; + } + + /** + * Get the initialized service context. + * @throws Error if not initialized + */ + static get(): IServiceContext { + if (!ServiceContext.instance) { + throw new Error( + "ServiceContext not initialized. Call ServiceContext.initialize() first.", + ); + } + return ServiceContext.instance; + } + + /** + * Check if the service context has been initialized. + */ + static isInitialized(): boolean { + return ServiceContext.instance !== null; + } + + /** + * Create a user context from request headers. + * + * @param token - The user's access token from x-forwarded-access-token header + * @param userId - The user's ID from x-forwarded-user header + * @param userName - Optional user name + * @throws Error if token is not provided + */ + static createUserContext( + token: string, + userId: string, + userName?: string, + ): IUserContext { + if (!token) { + throw new Error("User token is required to create user context"); + } + + const host = process.env.DATABRICKS_HOST; + if (!host) { + throw new Error( + "DATABRICKS_HOST environment variable is required for user context", + ); + } + + const serviceCtx = ServiceContext.get(); + + // Create user client with the OAuth token from Databricks Apps + // Note: We use authType: "pat" because the token is passed as a Bearer token + // just like a PAT, even though it's technically an OAuth token + const userClient = new WorkspaceClient( + { + token, + host, + authType: "pat", + }, + getClientOptions(), + ); + + // Log for debugging in production + console.log( + `[ServiceContext] Created user context: userId=${userId}, tokenLength=${token.length}`, + ); + + return { + client: userClient, + userId, + userName, + warehouseId: serviceCtx.warehouseId, + workspaceId: serviceCtx.workspaceId, + isUserContext: true, + }; + } + + /** + * Get the client options for WorkspaceClient. + * Exposed for testing purposes. + */ + static getClientOptions(): ClientOptions { + return getClientOptions(); + } + + private static async createContext(): Promise { + const client = new WorkspaceClient({}, getClientOptions()); + + const warehouseId = ServiceContext.getWarehouseId(client); + const workspaceId = ServiceContext.getWorkspaceId(client); + const currentUser = await client.currentUser.me(); + + if (!currentUser.id) { + throw new Error("Service user ID not found"); + } + + return { + client, + serviceUserId: currentUser.id, + warehouseId, + workspaceId, + }; + } + + private static async getWorkspaceId( + client: WorkspaceClient, + ): Promise { + if (process.env.DATABRICKS_WORKSPACE_ID) { + return process.env.DATABRICKS_WORKSPACE_ID; + } + + const response = (await client.apiClient.request({ + path: "/api/2.0/preview/scim/v2/Me", + method: "GET", + headers: new Headers(), + raw: false, + query: {}, + responseHeaders: ["x-databricks-org-id"], + })) as { "x-databricks-org-id": string }; + + if (!response["x-databricks-org-id"]) { + throw new Error("Workspace ID not found"); + } + + return response["x-databricks-org-id"]; + } + + private static async getWarehouseId( + client: WorkspaceClient, + ): Promise { + if (process.env.DATABRICKS_WAREHOUSE_ID) { + return process.env.DATABRICKS_WAREHOUSE_ID; + } + + if (process.env.NODE_ENV === "development") { + const response = (await client.apiClient.request({ + path: "/api/2.0/sql/warehouses", + method: "GET", + headers: new Headers(), + raw: false, + query: { + skip_cannot_use: "true", + }, + })) as { warehouses: sql.EndpointInfo[] }; + + const priorities: Record = { + RUNNING: 0, + STOPPED: 1, + STARTING: 2, + STOPPING: 3, + DELETED: 99, + DELETING: 99, + }; + + const warehouses = (response.warehouses || []).sort((a, b) => { + return ( + priorities[a.state as sql.State] - priorities[b.state as sql.State] + ); + }); + + if (response.warehouses.length === 0) { + throw new Error( + "Warehouse ID not found. Please configure the DATABRICKS_WAREHOUSE_ID environment variable.", + ); + } + + const firstWarehouse = warehouses[0]; + if ( + firstWarehouse.state === "DELETED" || + firstWarehouse.state === "DELETING" || + !firstWarehouse.id + ) { + throw new Error( + "Warehouse ID not found. Please configure the DATABRICKS_WAREHOUSE_ID environment variable.", + ); + } + + return firstWarehouse.id; + } + + throw new Error( + "Warehouse ID not found. Please configure the DATABRICKS_WAREHOUSE_ID environment variable.", + ); + } + + /** + * Reset the service context. Only for testing purposes. + */ + static reset(): void { + ServiceContext.instance = null; + ServiceContext.initPromise = null; + } +} diff --git a/packages/appkit/src/core/appkit.ts b/packages/appkit/src/core/appkit.ts index 7aecaf7..ac6cfe2 100644 --- a/packages/appkit/src/core/appkit.ts +++ b/packages/appkit/src/core/appkit.ts @@ -8,6 +8,7 @@ import type { PluginMap, } from "shared"; import { CacheManager } from "../cache"; +import { ServiceContext } from "../context"; import type { TelemetryConfig } from "../telemetry"; import { TelemetryManager } from "../telemetry"; @@ -91,9 +92,14 @@ export class AppKit { cache?: CacheConfig; } = {}, ): Promise> { + // Initialize core services TelemetryManager.initialize(config?.telemetry); await CacheManager.getInstance(config?.cache); + // Initialize ServiceContext for Databricks client management + // This provides the service principal client and shared resources + await ServiceContext.initialize(); + const rawPlugins = config.plugins as T; const preparedPlugins = AppKit.preparePlugins(rawPlugins); const mergedConfig = { diff --git a/packages/appkit/src/index.ts b/packages/appkit/src/index.ts index ebb4677..803f0f9 100644 --- a/packages/appkit/src/index.ts +++ b/packages/appkit/src/index.ts @@ -10,6 +10,18 @@ export { } from "shared"; export { analytics } from "./analytics"; export { CacheManager } from "./cache"; +export { + ServiceContext, + getExecutionContext, + getCurrentUserId, + getWorkspaceClient, + getWarehouseId, + getWorkspaceId, + isInUserContext, + type ExecutionContext, + type IServiceContext, + type IUserContext, +} from "./context"; export { createApp } from "./core"; export { Plugin, toPlugin } from "./plugin"; export { server } from "./server"; @@ -22,4 +34,9 @@ export { SpanStatusCode, } from "./telemetry"; export { appKitTypesPlugin } from "./type-generator/vite-plugin"; + +/** + * @deprecated Use getExecutionContext() from "./context" instead. + * This export is kept for backward compatibility. + */ export { getRequestContext } from "./utils"; diff --git a/packages/appkit/src/plugin/interceptors/cache.ts b/packages/appkit/src/plugin/interceptors/cache.ts index 2ad9f73..7f9f31e 100644 --- a/packages/appkit/src/plugin/interceptors/cache.ts +++ b/packages/appkit/src/plugin/interceptors/cache.ts @@ -1,6 +1,6 @@ import type { CacheManager } from "../../cache"; import type { CacheConfig } from "shared"; -import type { ExecutionContext, ExecutionInterceptor } from "./types"; +import type { InterceptorContext, ExecutionInterceptor } from "./types"; // interceptor to handle caching logic export class CacheInterceptor implements ExecutionInterceptor { @@ -11,7 +11,7 @@ export class CacheInterceptor implements ExecutionInterceptor { async intercept( fn: () => Promise, - context: ExecutionContext, + context: InterceptorContext, ): Promise { // if cache disabled, ignore if (!this.config.enabled || !this.config.cacheKey?.length) { diff --git a/packages/appkit/src/plugin/interceptors/retry.ts b/packages/appkit/src/plugin/interceptors/retry.ts index cc62bb0..274a0d3 100644 --- a/packages/appkit/src/plugin/interceptors/retry.ts +++ b/packages/appkit/src/plugin/interceptors/retry.ts @@ -1,5 +1,5 @@ import type { RetryConfig } from "shared"; -import type { ExecutionContext, ExecutionInterceptor } from "./types"; +import type { InterceptorContext, ExecutionInterceptor } from "./types"; // interceptor to handle retry logic export class RetryInterceptor implements ExecutionInterceptor { @@ -15,7 +15,7 @@ export class RetryInterceptor implements ExecutionInterceptor { async intercept( fn: () => Promise, - context: ExecutionContext, + context: InterceptorContext, ): Promise { let lastError: Error | unknown; diff --git a/packages/appkit/src/plugin/interceptors/telemetry.ts b/packages/appkit/src/plugin/interceptors/telemetry.ts index 8335180..3c0b659 100644 --- a/packages/appkit/src/plugin/interceptors/telemetry.ts +++ b/packages/appkit/src/plugin/interceptors/telemetry.ts @@ -1,7 +1,7 @@ import type { ITelemetry, Span } from "../../telemetry"; import { SpanStatusCode } from "../../telemetry"; import type { TelemetryConfig } from "shared"; -import type { ExecutionContext, ExecutionInterceptor } from "./types"; +import type { InterceptorContext, ExecutionInterceptor } from "./types"; /** * Interceptor to automatically instrument plugin executions with telemetry spans. @@ -15,7 +15,7 @@ export class TelemetryInterceptor implements ExecutionInterceptor { async intercept( fn: () => Promise, - _context: ExecutionContext, + _context: InterceptorContext, ): Promise { const spanName = this.config?.spanName || "plugin.execute"; return this.telemetry.startActiveSpan( diff --git a/packages/appkit/src/plugin/interceptors/timeout.ts b/packages/appkit/src/plugin/interceptors/timeout.ts index e0ee42b..1f5a26a 100644 --- a/packages/appkit/src/plugin/interceptors/timeout.ts +++ b/packages/appkit/src/plugin/interceptors/timeout.ts @@ -1,4 +1,4 @@ -import type { ExecutionContext, ExecutionInterceptor } from "./types"; +import type { InterceptorContext, ExecutionInterceptor } from "./types"; // interceptor to handle timeout logic export class TimeoutInterceptor implements ExecutionInterceptor { @@ -6,7 +6,7 @@ export class TimeoutInterceptor implements ExecutionInterceptor { async intercept( fn: () => Promise, - context: ExecutionContext, + context: InterceptorContext, ): Promise { // create timeout signal const timeoutController = new AbortController(); diff --git a/packages/appkit/src/plugin/interceptors/types.ts b/packages/appkit/src/plugin/interceptors/types.ts index 1e10af3..633e38d 100644 --- a/packages/appkit/src/plugin/interceptors/types.ts +++ b/packages/appkit/src/plugin/interceptors/types.ts @@ -1,10 +1,13 @@ -export interface ExecutionContext { +/** + * Context passed through the interceptor chain. + * Contains signal for cancellation, metadata, and user identification. + */ +export interface InterceptorContext { signal?: AbortSignal; metadata?: Map; userKey: string; - asUser?: boolean; } export interface ExecutionInterceptor { - intercept(fn: () => Promise, context: ExecutionContext): Promise; + intercept(fn: () => Promise, context: InterceptorContext): Promise; } diff --git a/packages/appkit/src/plugin/plugin.ts b/packages/appkit/src/plugin/plugin.ts index 975d894..21c6e8f 100644 --- a/packages/appkit/src/plugin/plugin.ts +++ b/packages/appkit/src/plugin/plugin.ts @@ -13,6 +13,12 @@ import type { } from "shared"; import { AppManager } from "../app"; import { CacheManager } from "../cache"; +import { + ServiceContext, + getCurrentUserId, + runInUserContext, + type IUserContext, +} from "../context"; import { StreamManager } from "../stream"; import { type ITelemetry, @@ -26,10 +32,29 @@ import { RetryInterceptor } from "./interceptors/retry"; import { TelemetryInterceptor } from "./interceptors/telemetry"; import { TimeoutInterceptor } from "./interceptors/timeout"; import type { - ExecutionContext, + InterceptorContext, ExecutionInterceptor, } from "./interceptors/types"; +/** + * Methods that should not be proxied by asUser(). + * These are lifecycle/internal methods that don't make sense + * to execute in a user context. + */ +const EXCLUDED_FROM_PROXY = new Set([ + // Lifecycle methods + "setup", + "shutdown", + "validateEnv", + "injectRoutes", + "getEndpoints", + "abortActiveOperations", + // asUser itself - prevent chaining like .asUser().asUser() + "asUser", + // Internal methods + "constructor", +]); + export abstract class Plugin< TConfig extends BasePluginConfig = BasePluginConfig, > implements BasePlugin @@ -80,12 +105,123 @@ export abstract class Plugin< this.streamManager.abortAll(); } + /** + * Execute operations using the user's identity from the request. + * + * Returns a scoped instance of this plugin where all method calls + * will execute with the user's Databricks credentials instead of + * the service principal. + * + * @param req - The Express request containing the user token in headers + * @returns A scoped plugin instance that executes as the user + * @throws Error if user token is not available in request headers + * + * @example + * ```typescript + * // In route handler - execute query as the requesting user + * router.post('/users/me/query/:key', async (req, res) => { + * const result = await this.asUser(req).query(req.params.key) + * res.json(result) + * }) + * + * // Mixed execution in same handler + * router.post('/dashboard', async (req, res) => { + * const [systemData, userData] = await Promise.all([ + * this.getSystemStats(), // Service principal + * this.asUser(req).getUserPreferences(), // User context + * ]) + * res.json({ systemData, userData }) + * }) + * ``` + */ + asUser(req: express.Request): this { + const token = req.headers["x-forwarded-access-token"] as string; + const userId = req.headers["x-forwarded-user"] as string; + const isDev = process.env.NODE_ENV === "development"; + + // In local development, fall back to service principal + // since there's no user token available + if (!token && isDev) { + console.warn( + "[AppKit] asUser() called without user token in development mode. " + + "Using service principal. Use 'databricks apps run-local' for proper token passthrough.", + ); + // Return self - methods will use service context + return this; + } + + if (!token) { + throw new Error( + "User token not available in request headers. " + + "Ensure the request has the x-forwarded-access-token header.", + ); + } + + if (!userId && !isDev) { + throw new Error( + "User ID not available in request headers. " + + "Ensure the request has the x-forwarded-user header.", + ); + } + + // In dev mode without userId, use a placeholder + const effectiveUserId = userId || "dev-user"; + + // Debug logging for token passthrough issues + if (process.env.APPKIT_DEBUG === "true") { + console.log("[AppKit Debug] asUser() called:"); + console.log(" - Token present:", !!token); + console.log(" - Token length:", token?.length); + console.log(" - Token prefix:", token?.substring(0, 20) + "..."); + console.log(" - UserId:", effectiveUserId); + console.log(" - Host:", process.env.DATABRICKS_HOST); + } + + // Create user context + const userContext = ServiceContext.createUserContext( + token, + effectiveUserId, + ); + + // Return a proxy that wraps method calls in user context + return this.createUserContextProxy(userContext); + } + + /** + * Creates a proxy that wraps method calls in a user context. + * This allows all plugin methods to automatically use the user's + * Databricks credentials. + */ + private createUserContextProxy(userContext: IUserContext): this { + return new Proxy(this, { + get: (target, prop, receiver) => { + // Get the original property + const value = Reflect.get(target, prop, receiver); + + // Don't wrap non-functions + if (typeof value !== "function") { + return value; + } + + // Don't wrap excluded methods + if (typeof prop === "string" && EXCLUDED_FROM_PROXY.has(prop)) { + return value; + } + + // Wrap method to run in user context + return (...args: unknown[]) => { + return runInUserContext(userContext, () => value.apply(target, args)); + }; + }, + }) as this; + } + // streaming execution with interceptors protected async executeStream( res: IAppResponse, fn: StreamExecuteHandler, options: StreamExecutionSettings, - userKey: string, + userKey?: string, ) { // destructure options const { @@ -100,15 +236,18 @@ export abstract class Plugin< user: userConfig, }); + // Get user key from context if not provided + const effectiveUserKey = userKey ?? getCurrentUserId(); + const self = this; // wrapper function to ensure it returns a generator const asyncWrapperFn = async function* (streamSignal?: AbortSignal) { // build execution context - const context: ExecutionContext = { + const context: InterceptorContext = { signal: streamSignal, metadata: new Map(), - userKey: userKey, + userKey: effectiveUserKey, }; // build interceptors @@ -143,15 +282,18 @@ export abstract class Plugin< protected async execute( fn: (signal?: AbortSignal) => Promise, options: PluginExecutionSettings, - userKey: string, + userKey?: string, ): Promise { const executeConfig = this._buildExecutionConfig(options); const interceptors = this._buildInterceptors(executeConfig); - const context: ExecutionContext = { + // Get user key from context if not provided + const effectiveUserKey = userKey ?? getCurrentUserId(); + + const context: InterceptorContext = { metadata: new Map(), - userKey: userKey, + userKey: effectiveUserKey, }; try { @@ -232,7 +374,7 @@ export abstract class Plugin< private async _executeWithInterceptors( fn: (signal?: AbortSignal) => Promise, interceptors: ExecutionInterceptor[], - context: ExecutionContext, + context: InterceptorContext, ): Promise { // no interceptors, execute directly if (interceptors.length === 0) { diff --git a/packages/appkit/src/plugin/tests/cache.test.ts b/packages/appkit/src/plugin/tests/cache.test.ts index 3acf17f..7e01e77 100644 --- a/packages/appkit/src/plugin/tests/cache.test.ts +++ b/packages/appkit/src/plugin/tests/cache.test.ts @@ -1,7 +1,7 @@ import type { CacheConfig } from "shared"; import { beforeEach, describe, expect, test, vi } from "vitest"; import { CacheInterceptor } from "../interceptors/cache"; -import type { ExecutionContext } from "../interceptors/types"; +import type { InterceptorContext } from "../interceptors/types"; vi.mock("../../telemetry", () => ({ TelemetryManager: { @@ -78,7 +78,7 @@ class MockCacheManager { describe("CacheInterceptor", () => { let cacheManager: MockCacheManager; - let context: ExecutionContext; + let context: InterceptorContext; beforeEach(() => { cacheManager = new MockCacheManager(); @@ -180,7 +180,7 @@ describe("CacheInterceptor", () => { enabled: true, cacheKey: ["query", "sales"], }; - const contextWithToken: ExecutionContext = { + const contextWithToken: InterceptorContext = { metadata: new Map(), userKey: "user1", }; @@ -213,7 +213,7 @@ describe("CacheInterceptor", () => { ); // Service account context - const context1: ExecutionContext = { + const context1: InterceptorContext = { metadata: new Map(), userKey: "service", }; @@ -221,7 +221,7 @@ describe("CacheInterceptor", () => { await interceptor.intercept(fn1, context1); // User context - const context2: ExecutionContext = { + const context2: InterceptorContext = { metadata: new Map(), userKey: "user1", }; diff --git a/packages/appkit/src/plugin/tests/plugin.test.ts b/packages/appkit/src/plugin/tests/plugin.test.ts index 1095992..d485289 100644 --- a/packages/appkit/src/plugin/tests/plugin.test.ts +++ b/packages/appkit/src/plugin/tests/plugin.test.ts @@ -12,7 +12,7 @@ import { createMockTelemetry } from "@tools/test-helpers"; import { validateEnv } from "../../utils"; import type express from "express"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; -import type { ExecutionContext } from "../interceptors/types"; +import type { InterceptorContext } from "../interceptors/types"; import { Plugin } from "../plugin"; // Mock all dependencies @@ -499,7 +499,7 @@ describe("Plugin", () => { test("should execute function directly when no interceptors", async () => { const plugin = new TestPlugin(config); const mockFn = vi.fn().mockResolvedValue("direct-result"); - const context: ExecutionContext = { + const context: InterceptorContext = { metadata: new Map(), userKey: "test", }; @@ -514,7 +514,7 @@ describe("Plugin", () => { test("should chain interceptors correctly", async () => { const plugin = new TestPlugin(config); const mockFn = vi.fn().mockResolvedValue("chained-result"); - const context: ExecutionContext = { + const context: InterceptorContext = { metadata: new Map(), userKey: "test", }; @@ -541,9 +541,8 @@ describe("Plugin", () => { test("should pass context to interceptors", async () => { const plugin = new TestPlugin(config); const mockFn = vi.fn().mockResolvedValue("context-result"); - const context: ExecutionContext = { + const context: InterceptorContext = { metadata: new Map(), - asUser: true, signal: new AbortController().signal, userKey: "test", }; diff --git a/packages/appkit/src/plugin/tests/retry.test.ts b/packages/appkit/src/plugin/tests/retry.test.ts index b3b8578..913a826 100644 --- a/packages/appkit/src/plugin/tests/retry.test.ts +++ b/packages/appkit/src/plugin/tests/retry.test.ts @@ -1,10 +1,10 @@ import type { RetryConfig } from "shared"; import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; import { RetryInterceptor } from "../interceptors/retry"; -import type { ExecutionContext } from "../interceptors/types"; +import type { InterceptorContext } from "../interceptors/types"; describe("RetryInterceptor", () => { - let context: ExecutionContext; + let context: InterceptorContext; beforeEach(() => { context = { @@ -137,7 +137,7 @@ describe("RetryInterceptor", () => { const interceptor = new RetryInterceptor(config); const abortController = new AbortController(); - const contextWithSignal: ExecutionContext = { + const contextWithSignal: InterceptorContext = { metadata: new Map(), signal: abortController.signal, userKey: "test", diff --git a/packages/appkit/src/plugin/tests/timeout.test.ts b/packages/appkit/src/plugin/tests/timeout.test.ts index d0e5f01..b065bcb 100644 --- a/packages/appkit/src/plugin/tests/timeout.test.ts +++ b/packages/appkit/src/plugin/tests/timeout.test.ts @@ -1,9 +1,9 @@ import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; import { TimeoutInterceptor } from "../interceptors/timeout"; -import type { ExecutionContext } from "../interceptors/types"; +import type { InterceptorContext } from "../interceptors/types"; describe("TimeoutInterceptor", () => { - let context: ExecutionContext; + let context: InterceptorContext; beforeEach(() => { context = { @@ -61,7 +61,7 @@ describe("TimeoutInterceptor", () => { test("should combine user signal with timeout signal", async () => { const userController = new AbortController(); - const contextWithSignal: ExecutionContext = { + const contextWithSignal: InterceptorContext = { metadata: new Map(), signal: userController.signal, userKey: "test", @@ -82,7 +82,7 @@ describe("TimeoutInterceptor", () => { test("should combine signals when user signal exists", async () => { const userController = new AbortController(); - const contextWithSignal: ExecutionContext = { + const contextWithSignal: InterceptorContext = { metadata: new Map(), signal: userController.signal, userKey: "test", @@ -105,7 +105,7 @@ describe("TimeoutInterceptor", () => { const userController = new AbortController(); userController.abort(new Error("Already aborted")); - const contextWithSignal: ExecutionContext = { + const contextWithSignal: InterceptorContext = { metadata: new Map(), signal: userController.signal, userKey: "test", diff --git a/packages/appkit/src/server/index.ts b/packages/appkit/src/server/index.ts index ea837b2..b886db7 100644 --- a/packages/appkit/src/server/index.ts +++ b/packages/appkit/src/server/index.ts @@ -6,7 +6,6 @@ import express from "express"; import type { PluginPhase } from "shared"; import { Plugin, toPlugin } from "../plugin"; import { instrumentations } from "../telemetry"; -import { databricksClientMiddleware } from "../utils"; import { RemoteTunnelController } from "./remote-tunnel/remote-tunnel-controller"; import { StaticServer } from "./static-server"; import type { ServerConfig } from "./types"; @@ -183,10 +182,6 @@ export class ServerPlugin extends Plugin { if (plugin?.injectRoutes && typeof plugin.injectRoutes === "function") { const router = express.Router(); - // add databricks client middleware to the router if the plugin needs the request context - if (plugin.requiresDatabricksClient) - router.use(await databricksClientMiddleware()); - plugin.injectRoutes(router); const basePath = `/api/${plugin.name}`; diff --git a/packages/appkit/src/telemetry/tests/telemetry-interceptor.test.ts b/packages/appkit/src/telemetry/tests/telemetry-interceptor.test.ts index be6815f..f78358e 100644 --- a/packages/appkit/src/telemetry/tests/telemetry-interceptor.test.ts +++ b/packages/appkit/src/telemetry/tests/telemetry-interceptor.test.ts @@ -2,13 +2,13 @@ import type { TelemetryConfig } from "shared"; import { SpanStatusCode, type Span } from "@opentelemetry/api"; import { beforeEach, describe, expect, test, vi } from "vitest"; import { TelemetryInterceptor } from "../../plugin/interceptors/telemetry"; -import type { ExecutionContext } from "../../plugin/interceptors/types"; +import type { InterceptorContext } from "../../plugin/interceptors/types"; import type { ITelemetry } from "../types"; describe("TelemetryInterceptor", () => { let mockTelemetry: ITelemetry; let mockSpan: Span; - let context: ExecutionContext; + let context: InterceptorContext; beforeEach(() => { mockSpan = {