Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docker-compose-library.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ services:
- ./run.yaml:/app-root/run.yaml:Z
- ${GCP_KEYS_PATH:-./tmp/.gcp-keys-dummy}:/opt/app-root/.gcp-keys:ro
- ./tests/e2e/rag:/opt/app-root/src/.llama/storage/rag:Z
- ./tests/e2e/secrets/mcp-token:/tmp/mcp-token:ro
- ./tests/e2e/secrets/invalid-mcp-token:/tmp/invalid-mcp-token:ro
environment:
# LLM Provider API Keys
- BRAVE_SEARCH_API_KEY=${BRAVE_SEARCH_API_KEY:-}
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ services:
- "8080:8080"
volumes:
- ./lightspeed-stack.yaml:/app-root/lightspeed-stack.yaml:z
- ./tests/e2e/secrets/mcp-token:/tmp/mcp-secret-token:ro
- ./tests/e2e/secrets/invalid-mcp-token:/tmp/invalid-mcp-secret-token:ro
- ./tests/e2e/secrets/mcp-token:/tmp/mcp-token:ro
- ./tests/e2e/secrets/invalid-mcp-token:/tmp/invalid-mcp-token:ro
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
# Azure Entra ID credentials (AZURE_API_KEY is obtained dynamically)
Expand Down
13 changes: 8 additions & 5 deletions src/app/endpoints/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
UnauthorizedResponse,
)
from utils.endpoints import check_configuration_loaded
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
from utils.mcp_headers import McpHeaders, build_mcp_headers, mcp_headers_dependency
from utils.mcp_oauth_probe import check_mcp_auth
from utils.tool_formatter import format_tools_list
from log import get_logger
Expand Down Expand Up @@ -115,15 +115,18 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
ToolsResponse: An object containing the consolidated list of available tools
with metadata including tool name, description, parameters, and server source.
"""
# Used only by the middleware
_ = auth
_, _, _, token = auth

# Nothing interesting in the request
_ = request

check_configuration_loaded(configuration)

await check_mcp_auth(configuration, mcp_headers)
complete_mcp_headers = build_mcp_headers(
configuration, mcp_headers, request.headers, token
)

await check_mcp_auth(configuration, complete_mcp_headers)

toolgroups_response = []
try:
Expand All @@ -145,7 +148,7 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
for toolgroup in toolgroups_response:
try:
# Get tools for each toolgroup
headers = mcp_headers.get(toolgroup.identifier, {})
headers = complete_mcp_headers.get(toolgroup.identifier, {})
authorization = headers.pop("Authorization", None)

tools_response = await client.tools.list(
Expand Down
73 changes: 73 additions & 0 deletions src/utils/mcp_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import json
from collections.abc import Mapping
from typing import Optional
from urllib.parse import urlparse

from fastapi import Request

import constants
from configuration import AppConfig
from log import get_logger
from models.config import ModelContextProtocolServer
Expand Down Expand Up @@ -121,3 +123,74 @@ def extract_propagated_headers(
if value is not None:
propagated[header_name] = value
return propagated


def build_mcp_headers(
config: AppConfig,
mcp_headers: McpHeaders,
request_headers: Optional[Mapping[str, str]],
token: Optional[str] = None,
) -> McpHeaders:
"""Build complete MCP headers by merging all header sources for each MCP server.

For each configured MCP server, combines four header sources (in priority order,
highest first):

1. Client-supplied headers from the ``MCP-HEADERS`` request header (keyed by server name).
2. Statically resolved authorization headers from configuration (e.g. file-based secrets).
3. Kubernetes Bearer token: when a header is configured with the ``kubernetes`` keyword,
the supplied ``token`` is formatted as ``Bearer <token>`` and used as its value.
``client`` and ``oauth`` keywords are not resolved here — those values are already
provided by the client in source 1.
4. Headers propagated from the incoming request via the server's configured allowlist.

Args:
config: Application configuration containing mcp_servers.
mcp_headers: Per-request headers from the client, keyed by MCP server name.
request_headers: Headers from the incoming HTTP request used for allowlist
propagation, or ``None`` when not available.
token: Optional Kubernetes service-account token used to resolve headers
configured with the ``kubernetes`` keyword.

Returns:
McpHeaders keyed by MCP server name with the complete merged set of headers.
Servers that end up with no headers are omitted from the result.
"""
if not config.mcp_servers:
return {}

complete: McpHeaders = {}

for mcp_server in config.mcp_servers:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be nice to refactor these nested loops. At least move the internal loop into it's own function - then you will be basically "forced" to name params etc. And it will be better testable

server_headers: dict[str, str] = dict(mcp_headers.get(mcp_server.name, {}))
existing_lower = {k.lower() for k in server_headers}
Comment on lines +164 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Normalize and validate raw MCP-HEADERS before copying them.

This lookup assumes the incoming entry is already a dict[str, str] keyed by mcp_server.name. In this module, extract_mcp_headers() only validates the top-level JSON object and handle_mcp_headers_with_toolgroups() still documents URL/toolgroup keyed payloads. That means {"server": "token"} can blow up at dict(...), while {"https://server": {...}} is silently dropped. Normalizing the supported key shapes and filtering to string-to-string entries here will avoid 500s and keep downstream Authorization handling consistent.

Possible fix
     complete: McpHeaders = {}

     for mcp_server in config.mcp_servers:
-        server_headers: dict[str, str] = dict(mcp_headers.get(mcp_server.name, {}))
+        if mcp_server.name in mcp_headers:
+            raw_headers = mcp_headers[mcp_server.name]
+        elif mcp_server.url in mcp_headers:
+            raw_headers = mcp_headers[mcp_server.url]
+        else:
+            raw_headers = {}
+
+        if not isinstance(raw_headers, Mapping):
+            logger.warning(
+                "Ignoring invalid MCP headers for %s: expected object, got %s",
+                mcp_server.name,
+                type(raw_headers).__name__,
+            )
+            raw_headers = {}
+
+        server_headers = {
+            (
+                "Authorization"
+                if header_name.lower() == "authorization"
+                else header_name
+            ): header_value
+            for header_name, header_value in raw_headers.items()
+            if isinstance(header_name, str) and isinstance(header_value, str)
+        }
         existing_lower = {k.lower() for k in server_headers}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/utils/mcp_headers.py` around lines 164 - 166, The loop over
config.mcp_servers assumes mcp_headers[mcp_server.name] is already a
dict[str,str]; instead, before calling dict(...) normalize and validate raw
mcp_headers entries (mcp_headers variable used in the loop) so only recognized
key shapes (server name or URL forms produced by
extract_mcp_headers()/handle_mcp_headers_with_toolgroups) and string-to-string
mappings are accepted; specifically, if mcp_headers.get(mcp_server.name) (or an
entry keyed by a URL form) is not a mapping, treat it as absent, and when it is
a mapping filter out non-str keys/values and ignore nested non-string payloads,
then build server_headers from that sanitized mapping (adjusting any URL-keyed
entries to mcp_server.name). This change should be made around the loop that
defines server_headers and existing_lower to avoid TypeErrors and silently
dropping URL-keyed payloads.


for (
header_name,
resolved_value,
) in mcp_server.resolved_authorization_headers.items():
if header_name.lower() in existing_lower:
continue
match resolved_value:
case constants.MCP_AUTH_KUBERNETES:
if token:
server_headers[header_name] = f"Bearer {token}"
existing_lower.add(header_name.lower())
case constants.MCP_AUTH_CLIENT | constants.MCP_AUTH_OAUTH:
pass # client-provided; already included via the initial mcp_headers copy
case _:
server_headers[header_name] = resolved_value
existing_lower.add(header_name.lower())

# Propagate allowlisted headers from the incoming request.
if mcp_server.headers and request_headers is not None:
propagated = extract_propagated_headers(mcp_server, request_headers)
for h_name, h_value in propagated.items():
if h_name.lower() not in existing_lower:
server_headers[h_name] = h_value
existing_lower.add(h_name.lower())

if server_headers:
complete[mcp_server.name] = server_headers

return complete
11 changes: 10 additions & 1 deletion src/utils/mcp_oauth_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,16 @@ async def check_mcp_auth(configuration: AppConfig, mcp_headers: McpHeaders) -> N
probes = []
for mcp_server in configuration.mcp_servers:
headers = mcp_headers.get(mcp_server.name, {})
authorization = headers.get("Authorization", None)
auth_header = headers.get("Authorization")
if auth_header:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is not None - you don't want to pass with empty auth header

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The condition passes only if the auth_header is present and is non-empty, right? Isn't this what we actually want?

authorization = (
auth_header
if auth_header.startswith("Bearer ")
else f"Bearer {auth_header}"
)
else:
authorization = None

if (
authorization
or constants.MCP_AUTH_OAUTH
Expand Down
92 changes: 29 additions & 63 deletions src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
NotFoundResponse,
ServiceUnavailableResponse,
)
from utils.mcp_headers import McpHeaders, extract_propagated_headers
from utils.mcp_headers import McpHeaders, build_mcp_headers
from utils.prompts import get_system_prompt, get_topic_summary_system_prompt
from utils.query import (
extract_provider_and_model_from_model_id,
Expand Down Expand Up @@ -437,17 +437,21 @@ def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[InputToolFileSea
]


async def get_mcp_tools( # pylint: disable=too-many-return-statements,too-many-locals
async def get_mcp_tools(
token: Optional[str] = None,
mcp_headers: Optional[McpHeaders] = None,
request_headers: Optional[Mapping[str, str]] = None,
) -> list[InputToolMCP]:
"""Convert MCP servers to tools format for Responses API.

Fully delegates header assembly to ``build_mcp_headers``, which handles static
config tokens, the kubernetes Bearer token, client/oauth client-provided headers,
and propagated request headers.

Args:
token: Optional authentication token for MCP server authorization
mcp_headers: Optional per-request headers for MCP servers, keyed by server URL
request_headers: Optional incoming HTTP request headers for allowlist propagation
token: Optional Kubernetes service-account token for ``kubernetes`` auth headers.
mcp_headers: Optional per-request headers for MCP servers, keyed by server name.
request_headers: Optional incoming HTTP request headers for allowlist propagation.

Returns:
List of MCP tool definitions with server details and optional auth. When
Expand All @@ -458,68 +462,30 @@ async def get_mcp_tools( # pylint: disable=too-many-return-statements,too-many-
HTTPException: 401 with WWW-Authenticate header when an MCP server uses OAuth,
no headers are passed, and the server responds with 401 and WWW-Authenticate.
"""

def _get_token_value(original: str, header: str) -> Optional[str]:
"""Convert to header value."""
match original:
case constants.MCP_AUTH_KUBERNETES:
# use k8s token
if token is None or token == "":
return None
return f"Bearer {token}"
case constants.MCP_AUTH_CLIENT:
# use client provided token
if mcp_headers is None:
return None
c_headers = mcp_headers.get(mcp_server.name, None)
if c_headers is None:
return None
return c_headers.get(header, None)
case constants.MCP_AUTH_OAUTH:
# use oauth token
if mcp_headers is None:
return None
c_headers = mcp_headers.get(mcp_server.name, None)
if c_headers is None:
return None
return c_headers.get(header, None)
case _:
# use provided
return original
complete_headers = build_mcp_headers(
configuration, mcp_headers or {}, request_headers, token
)

tools: list[InputToolMCP] = []
for mcp_server in configuration.mcp_servers:
# Build headers
headers: dict[str, str] = {}
for name, value in mcp_server.resolved_authorization_headers.items():
# for each defined header
h_value = _get_token_value(value, name)
# only add the header if we got value
if h_value is not None:
headers[name] = h_value

# Skip server if auth headers were configured but not all could be resolved
if mcp_server.authorization_headers and len(headers) != len(
mcp_server.authorization_headers
):
logger.warning(
"Skipping MCP server %s: required %d auth headers but only resolved %d",
mcp_server.name,
len(mcp_server.authorization_headers),
len(headers),
)
continue

# Propagate allowlisted headers from the incoming request
if mcp_server.headers and request_headers is not None:
propagated = extract_propagated_headers(mcp_server, request_headers)
existing_lower = {name.lower() for name in headers}
for h_name, h_value in propagated.items():
if h_name.lower() not in existing_lower:
headers[h_name] = h_value
existing_lower.add(h_name.lower())
headers: dict[str, str] = dict(complete_headers.get(mcp_server.name, {}))

# Skip server if any configured auth header could not be resolved.
if mcp_server.authorization_headers:
unresolved = [
h
for h in mcp_server.authorization_headers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ideal four lines for new utility function - easy testable, understandable, usable in more places if needed

if not any(k.lower() == h.lower() for k in headers)
]
if unresolved:
logger.warning(
"Skipping MCP server %s: required %d auth headers but only resolved %d",
mcp_server.name,
len(mcp_server.authorization_headers),
len(mcp_server.authorization_headers) - len(unresolved),
)
continue

# Build Authorization header
authorization = headers.pop("Authorization", None)
tools.append(
InputToolMCP(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ mcp_servers:
- name: "mcp-file"
url: "http://mock-mcp:3001"
authorization_headers:
Authorization: "/tmp/invalid-mcp-secret-token"
Authorization: "/tmp/invalid-mcp-token"
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ mcp_servers:
- name: "mcp-file"
url: "http://mock-mcp:3001"
authorization_headers:
Authorization: "/tmp/mcp-secret-token"
Authorization: "/tmp/mcp-token"
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ user_data_collection:
transcripts_enabled: true
transcripts_storage: "/tmp/data/transcripts"
authentication:
module: "noop"
module: "noop-with-token"
mcp_servers:
- name: "mcp-kubernetes"
url: "http://mock-mcp:3001"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mcp_servers:
- name: "mcp-file"
url: "http://mock-mcp:3001"
authorization_headers:
Authorization: "/tmp/mcp-secret-token"
Authorization: "/tmp/mcp-token"
- name: "mcp-client"
url: "http://mock-mcp:3001"
authorization_headers:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ mcp_servers:
- name: "mcp-file"
url: "http://mock-mcp:3001"
authorization_headers:
Authorization: "/tmp/invalid-mcp-secret-token"
Authorization: "/tmp/invalid-mcp-token"
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ mcp_servers:
- name: "mcp-file"
url: "http://mock-mcp:3001"
authorization_headers:
Authorization: "/tmp/mcp-secret-token"
Authorization: "/tmp/mcp-token"
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ user_data_collection:
transcripts_enabled: true
transcripts_storage: "/tmp/data/transcripts"
authentication:
module: "noop"
module: "noop-with-token"
mcp_servers:
- name: "mcp-kubernetes"
url: "http://mock-mcp:3001"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ mcp_servers:
- name: "mcp-file"
url: "http://mock-mcp:3001"
authorization_headers:
Authorization: "/tmp/mcp-secret-token"
Authorization: "/tmp/mcp-token"
- name: "mcp-client"
url: "http://mock-mcp:3001"
authorization_headers:
Expand Down
11 changes: 0 additions & 11 deletions tests/e2e/features/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,6 @@ def before_feature(context: Context, feature: Feature) -> None:
switch_config(context.feature_config)
restart_container("lightspeed-stack")

if "MCPFileAuth" in feature.tags:
context.feature_config = _get_config_path("mcp-file-auth", mode_dir)
context.default_config_backup = create_config_backup("lightspeed-stack.yaml")
switch_config(context.feature_config)
restart_container("lightspeed-stack")


def after_feature(context: Context, feature: Feature) -> None:
"""Run after each feature file is exercised.
Expand Down Expand Up @@ -473,8 +467,3 @@ def after_feature(context: Context, feature: Feature) -> None:
switch_config(context.default_config_backup)
restart_container("lightspeed-stack")
remove_config_backup(context.default_config_backup)

if "MCPFileAuth" in feature.tags:
switch_config(context.default_config_backup)
restart_container("lightspeed-stack")
remove_config_backup(context.default_config_backup)
Loading
Loading