-
Notifications
You must be signed in to change notification settings - Fork 78
LCORE-1461: Fix /tools not handling "kubernetes" and static token auth
#1349
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,10 +2,12 @@ | |
|
|
||
| import json | ||
| from collections.abc import Mapping | ||
| from typing import Optional | ||
| from urllib.parse import urlparse | ||
|
|
||
| from fastapi import Request | ||
|
|
||
| import constants | ||
| from configuration import AppConfig | ||
| from log import get_logger | ||
| from models.config import ModelContextProtocolServer | ||
|
|
@@ -121,3 +123,74 @@ def extract_propagated_headers( | |
| if value is not None: | ||
| propagated[header_name] = value | ||
| return propagated | ||
|
|
||
|
|
||
| def build_mcp_headers( | ||
| config: AppConfig, | ||
| mcp_headers: McpHeaders, | ||
| request_headers: Optional[Mapping[str, str]], | ||
| token: Optional[str] = None, | ||
| ) -> McpHeaders: | ||
| """Build complete MCP headers by merging all header sources for each MCP server. | ||
|
|
||
| For each configured MCP server, combines four header sources (in priority order, | ||
| highest first): | ||
|
|
||
| 1. Client-supplied headers from the ``MCP-HEADERS`` request header (keyed by server name). | ||
| 2. Statically resolved authorization headers from configuration (e.g. file-based secrets). | ||
| 3. Kubernetes Bearer token: when a header is configured with the ``kubernetes`` keyword, | ||
| the supplied ``token`` is formatted as ``Bearer <token>`` and used as its value. | ||
| ``client`` and ``oauth`` keywords are not resolved here — those values are already | ||
| provided by the client in source 1. | ||
| 4. Headers propagated from the incoming request via the server's configured allowlist. | ||
|
|
||
| Args: | ||
| config: Application configuration containing mcp_servers. | ||
| mcp_headers: Per-request headers from the client, keyed by MCP server name. | ||
| request_headers: Headers from the incoming HTTP request used for allowlist | ||
| propagation, or ``None`` when not available. | ||
| token: Optional Kubernetes service-account token used to resolve headers | ||
| configured with the ``kubernetes`` keyword. | ||
|
|
||
| Returns: | ||
| McpHeaders keyed by MCP server name with the complete merged set of headers. | ||
| Servers that end up with no headers are omitted from the result. | ||
| """ | ||
| if not config.mcp_servers: | ||
| return {} | ||
|
|
||
| complete: McpHeaders = {} | ||
|
|
||
| for mcp_server in config.mcp_servers: | ||
| server_headers: dict[str, str] = dict(mcp_headers.get(mcp_server.name, {})) | ||
| existing_lower = {k.lower() for k in server_headers} | ||
|
Comment on lines
+164
to
+166
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Normalize and validate raw This lookup assumes the incoming entry is already a Possible fix complete: McpHeaders = {}
for mcp_server in config.mcp_servers:
- server_headers: dict[str, str] = dict(mcp_headers.get(mcp_server.name, {}))
+ if mcp_server.name in mcp_headers:
+ raw_headers = mcp_headers[mcp_server.name]
+ elif mcp_server.url in mcp_headers:
+ raw_headers = mcp_headers[mcp_server.url]
+ else:
+ raw_headers = {}
+
+ if not isinstance(raw_headers, Mapping):
+ logger.warning(
+ "Ignoring invalid MCP headers for %s: expected object, got %s",
+ mcp_server.name,
+ type(raw_headers).__name__,
+ )
+ raw_headers = {}
+
+ server_headers = {
+ (
+ "Authorization"
+ if header_name.lower() == "authorization"
+ else header_name
+ ): header_value
+ for header_name, header_value in raw_headers.items()
+ if isinstance(header_name, str) and isinstance(header_value, str)
+ }
existing_lower = {k.lower() for k in server_headers}🤖 Prompt for AI Agents |
||
|
|
||
| for ( | ||
| header_name, | ||
| resolved_value, | ||
| ) in mcp_server.resolved_authorization_headers.items(): | ||
| if header_name.lower() in existing_lower: | ||
| continue | ||
| match resolved_value: | ||
| case constants.MCP_AUTH_KUBERNETES: | ||
| if token: | ||
| server_headers[header_name] = f"Bearer {token}" | ||
| existing_lower.add(header_name.lower()) | ||
| case constants.MCP_AUTH_CLIENT | constants.MCP_AUTH_OAUTH: | ||
| pass # client-provided; already included via the initial mcp_headers copy | ||
| case _: | ||
| server_headers[header_name] = resolved_value | ||
| existing_lower.add(header_name.lower()) | ||
|
|
||
| # Propagate allowlisted headers from the incoming request. | ||
| if mcp_server.headers and request_headers is not None: | ||
| propagated = extract_propagated_headers(mcp_server, request_headers) | ||
| for h_name, h_value in propagated.items(): | ||
| if h_name.lower() not in existing_lower: | ||
| server_headers[h_name] = h_value | ||
| existing_lower.add(h_name.lower()) | ||
|
|
||
| if server_headers: | ||
| complete[mcp_server.name] = server_headers | ||
|
|
||
| return complete | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -42,7 +42,16 @@ async def check_mcp_auth(configuration: AppConfig, mcp_headers: McpHeaders) -> N | |
| probes = [] | ||
| for mcp_server in configuration.mcp_servers: | ||
| headers = mcp_headers.get(mcp_server.name, {}) | ||
| authorization = headers.get("Authorization", None) | ||
| auth_header = headers.get("Authorization") | ||
| if auth_header: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is not None - you don't want to pass with empty auth header
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The condition passes only if the |
||
| authorization = ( | ||
| auth_header | ||
| if auth_header.startswith("Bearer ") | ||
| else f"Bearer {auth_header}" | ||
| ) | ||
| else: | ||
| authorization = None | ||
|
|
||
| if ( | ||
| authorization | ||
| or constants.MCP_AUTH_OAUTH | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,7 @@ | |
| NotFoundResponse, | ||
| ServiceUnavailableResponse, | ||
| ) | ||
| from utils.mcp_headers import McpHeaders, extract_propagated_headers | ||
| from utils.mcp_headers import McpHeaders, build_mcp_headers | ||
| from utils.prompts import get_system_prompt, get_topic_summary_system_prompt | ||
| from utils.query import ( | ||
| extract_provider_and_model_from_model_id, | ||
|
|
@@ -437,17 +437,21 @@ def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[InputToolFileSea | |
| ] | ||
|
|
||
|
|
||
| async def get_mcp_tools( # pylint: disable=too-many-return-statements,too-many-locals | ||
| async def get_mcp_tools( | ||
| token: Optional[str] = None, | ||
| mcp_headers: Optional[McpHeaders] = None, | ||
| request_headers: Optional[Mapping[str, str]] = None, | ||
| ) -> list[InputToolMCP]: | ||
| """Convert MCP servers to tools format for Responses API. | ||
|
|
||
| Fully delegates header assembly to ``build_mcp_headers``, which handles static | ||
| config tokens, the kubernetes Bearer token, client/oauth client-provided headers, | ||
| and propagated request headers. | ||
|
|
||
| Args: | ||
| token: Optional authentication token for MCP server authorization | ||
| mcp_headers: Optional per-request headers for MCP servers, keyed by server URL | ||
| request_headers: Optional incoming HTTP request headers for allowlist propagation | ||
| token: Optional Kubernetes service-account token for ``kubernetes`` auth headers. | ||
| mcp_headers: Optional per-request headers for MCP servers, keyed by server name. | ||
| request_headers: Optional incoming HTTP request headers for allowlist propagation. | ||
|
|
||
| Returns: | ||
| List of MCP tool definitions with server details and optional auth. When | ||
|
|
@@ -458,68 +462,30 @@ async def get_mcp_tools( # pylint: disable=too-many-return-statements,too-many- | |
| HTTPException: 401 with WWW-Authenticate header when an MCP server uses OAuth, | ||
| no headers are passed, and the server responds with 401 and WWW-Authenticate. | ||
| """ | ||
|
|
||
| def _get_token_value(original: str, header: str) -> Optional[str]: | ||
| """Convert to header value.""" | ||
| match original: | ||
| case constants.MCP_AUTH_KUBERNETES: | ||
| # use k8s token | ||
| if token is None or token == "": | ||
| return None | ||
| return f"Bearer {token}" | ||
| case constants.MCP_AUTH_CLIENT: | ||
| # use client provided token | ||
| if mcp_headers is None: | ||
| return None | ||
| c_headers = mcp_headers.get(mcp_server.name, None) | ||
| if c_headers is None: | ||
| return None | ||
| return c_headers.get(header, None) | ||
| case constants.MCP_AUTH_OAUTH: | ||
| # use oauth token | ||
| if mcp_headers is None: | ||
| return None | ||
| c_headers = mcp_headers.get(mcp_server.name, None) | ||
| if c_headers is None: | ||
| return None | ||
| return c_headers.get(header, None) | ||
| case _: | ||
| # use provided | ||
| return original | ||
| complete_headers = build_mcp_headers( | ||
| configuration, mcp_headers or {}, request_headers, token | ||
| ) | ||
|
|
||
| tools: list[InputToolMCP] = [] | ||
| for mcp_server in configuration.mcp_servers: | ||
| # Build headers | ||
| headers: dict[str, str] = {} | ||
| for name, value in mcp_server.resolved_authorization_headers.items(): | ||
| # for each defined header | ||
| h_value = _get_token_value(value, name) | ||
| # only add the header if we got value | ||
| if h_value is not None: | ||
| headers[name] = h_value | ||
|
|
||
| # Skip server if auth headers were configured but not all could be resolved | ||
| if mcp_server.authorization_headers and len(headers) != len( | ||
| mcp_server.authorization_headers | ||
| ): | ||
| logger.warning( | ||
| "Skipping MCP server %s: required %d auth headers but only resolved %d", | ||
| mcp_server.name, | ||
| len(mcp_server.authorization_headers), | ||
| len(headers), | ||
| ) | ||
| continue | ||
|
|
||
| # Propagate allowlisted headers from the incoming request | ||
| if mcp_server.headers and request_headers is not None: | ||
| propagated = extract_propagated_headers(mcp_server, request_headers) | ||
| existing_lower = {name.lower() for name in headers} | ||
| for h_name, h_value in propagated.items(): | ||
| if h_name.lower() not in existing_lower: | ||
| headers[h_name] = h_value | ||
| existing_lower.add(h_name.lower()) | ||
| headers: dict[str, str] = dict(complete_headers.get(mcp_server.name, {})) | ||
|
|
||
| # Skip server if any configured auth header could not be resolved. | ||
| if mcp_server.authorization_headers: | ||
| unresolved = [ | ||
| h | ||
| for h in mcp_server.authorization_headers | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ideal four lines for new utility function - easy testable, understandable, usable in more places if needed |
||
| if not any(k.lower() == h.lower() for k in headers) | ||
| ] | ||
| if unresolved: | ||
| logger.warning( | ||
| "Skipping MCP server %s: required %d auth headers but only resolved %d", | ||
| mcp_server.name, | ||
| len(mcp_server.authorization_headers), | ||
| len(mcp_server.authorization_headers) - len(unresolved), | ||
| ) | ||
| continue | ||
|
|
||
| # Build Authorization header | ||
| authorization = headers.pop("Authorization", None) | ||
| tools.append( | ||
| InputToolMCP( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be nice to refactor these nested loops. At least move the internal loop into it's own function - then you will be basically "forced" to name params etc. And it will be better testable