Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,13 @@ path = "tests/test_sampling.rs"
name = "test_close_connection"
required-features = ["server", "client"]
path = "tests/test_close_connection.rs"

[[test]]
name = "test_custom_headers"
required-features = [
"client",
"server",
"transport-streamable-http-client-reqwest",
"transport-streamable-http-server",
]
path = "tests/test_custom_headers.rs"
8 changes: 5 additions & 3 deletions crates/rmcp/src/transport/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, error, warn};

use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION;

const DEFAULT_EXCHANGE_URL: &str = "http://localhost";

/// Stored credentials for OAuth2 authorization
Expand Down Expand Up @@ -1051,7 +1053,7 @@ impl AuthorizationManager {
let response = match self
.http_client
.get(discovery_url.clone())
.header("MCP-Protocol-Version", "2024-11-05")
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Expand Down Expand Up @@ -1171,7 +1173,7 @@ impl AuthorizationManager {
let response = match self
.http_client
.get(url.clone())
.header("MCP-Protocol-Version", "2024-11-05")
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Expand Down Expand Up @@ -1224,7 +1226,7 @@ impl AuthorizationManager {
let response = match self
.http_client
.get(resource_metadata_url.clone())
.header("MCP-Protocol-Version", "2024-11-05")
.header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05")
.send()
.await
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use std::collections::HashMap;

use http::{HeaderName, HeaderValue};

use crate::transport::{
auth::AuthClient,
streamable_http_client::{StreamableHttpClient, StreamableHttpError},
Expand Down Expand Up @@ -47,6 +51,7 @@ where
message: crate::model::ClientJsonRpcMessage,
session_id: Option<std::sync::Arc<str>>,
mut auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<
crate::transport::streamable_http_client::StreamableHttpPostResponse,
StreamableHttpError<Self::Error>,
Expand All @@ -55,7 +60,7 @@ where
auth_token = Some(self.get_access_token().await?);
}
self.http_client
.post_message(uri, message, session_id, auth_token)
.post_message(uri, message, session_id, auth_token, custom_headers)
.await
}
}
1 change: 1 addition & 0 deletions crates/rmcp/src/transport/common/http_header.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version";
pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
pub const JSON_MIME_TYPE: &str = "application/json";
28 changes: 25 additions & 3 deletions crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::{borrow::Cow, sync::Arc};
use std::{borrow::Cow, collections::HashMap, sync::Arc};

use futures::{StreamExt, stream::BoxStream};
use http::header::WWW_AUTHENTICATE;
use http::{HeaderName, HeaderValue, header::WWW_AUTHENTICATE};
use reqwest::header::ACCEPT;
use sse_stream::{Sse, SseStream};

use crate::{
model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
transport::{
common::http_header::{
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE,
EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION,
HEADER_SESSION_ID, JSON_MIME_TYPE,
},
streamable_http_client::*,
},
Expand Down Expand Up @@ -94,13 +95,34 @@ impl StreamableHttpClient for reqwest::Client {
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let mut request = self
.post(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
if let Some(auth_header) = auth_token {
request = request.bearer_auth(auth_header);
}

// Apply custom headers
let reserved_headers = [
ACCEPT.as_str(),
HEADER_SESSION_ID,
HEADER_MCP_PROTOCOL_VERSION,
HEADER_LAST_EVENT_ID,
];
for (name, value) in custom_headers {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps disallow any of the headers that are controlled by the default client logic?

ACCEPT HEADER_SESSION_ID etc

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would check https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http for a full list

MCP-Protocol-Version is another one that comes to mind

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I've updated the PR to add the check.

if reserved_headers
.iter()
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
{
return Err(StreamableHttpError::ReservedHeaderConflict(
name.to_string(),
));
}

request = request.header(name, value);
}
if let Some(session_id) = session_id {
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
}
Expand Down
45 changes: 44 additions & 1 deletion crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{borrow::Cow, sync::Arc, time::Duration};
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};

use futures::{Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use http::{HeaderName, HeaderValue};
pub use sse_stream::Error as SseError;
use sse_stream::Sse;
use thiserror::Error;
Expand Down Expand Up @@ -76,6 +77,8 @@ pub enum StreamableHttpError<E: std::error::Error + Send + Sync + 'static> {
AuthRequired(AuthRequiredError),
#[error("Insufficient scope")]
InsufficientScope(InsufficientScopeError),
#[error("Header name '{0}' is reserved and conflicts with default headers")]
ReservedHeaderConflict(String),
}

#[derive(Debug, Clone, Error)]
Expand Down Expand Up @@ -173,6 +176,7 @@ pub trait StreamableHttpClient: Clone + Send + 'static {
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> impl Future<Output = Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>>>
+ Send
+ '_;
Expand Down Expand Up @@ -324,6 +328,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
initialize_request,
None,
self.config.auth_header,
self.config.custom_headers,
)
.await
{
Expand Down Expand Up @@ -372,6 +377,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
initialized_notification.message,
session_id.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
)
.await
.map_err(WorkerQuitReason::fatal_context(
Expand Down Expand Up @@ -477,6 +483,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
message,
session_id.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
)
.await;
let send_result = match response {
Expand Down Expand Up @@ -609,8 +616,10 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
/// StreamableHttpClientTransportConfig
/// };
/// use std::sync::Arc;
/// use std::collections::HashMap;
/// use futures::stream::BoxStream;
/// use rmcp::model::ClientJsonRpcMessage;
/// use http::{HeaderName, HeaderValue};
/// use sse_stream::{Sse, Error as SseError};
///
/// #[derive(Clone)]
Expand All @@ -634,6 +643,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
/// _message: ClientJsonRpcMessage,
/// _session_id: Option<Arc<str>>,
/// _auth_header: Option<String>,
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
/// ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
/// todo!()
/// }
Expand Down Expand Up @@ -690,8 +700,10 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
/// StreamableHttpClientTransportConfig
/// };
/// use std::sync::Arc;
/// use std::collections::HashMap;
/// use futures::stream::BoxStream;
/// use rmcp::model::ClientJsonRpcMessage;
/// use http::{HeaderName, HeaderValue};
/// use sse_stream::{Sse, Error as SseError};
///
/// // Define your custom client
Expand All @@ -716,6 +728,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
/// _message: ClientJsonRpcMessage,
/// _session_id: Option<Arc<str>>,
/// _auth_header: Option<String>,
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
/// ) -> Result<rmcp::transport::streamable_http_client::StreamableHttpPostResponse, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
/// todo!()
/// }
Expand Down Expand Up @@ -759,6 +772,8 @@ pub struct StreamableHttpClientTransportConfig {
pub allow_stateless: bool,
/// The value to send in the authorization header
pub auth_header: Option<String>,
/// Custom HTTP headers to include with every request
pub custom_headers: HashMap<HeaderName, HeaderValue>,
}

impl StreamableHttpClientTransportConfig {
Expand All @@ -779,6 +794,33 @@ impl StreamableHttpClientTransportConfig {
self.auth_header = Some(value.into());
self
}

/// Set custom HTTP headers to include with every request
///
/// # Arguments
///
/// * `custom_headers` - A HashMap of header names to header values
///
/// # Example
///
/// ```rust,no_run
/// use std::collections::HashMap;
/// use http::{HeaderName, HeaderValue};
/// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
///
/// let mut headers = HashMap::new();
/// headers.insert(
/// HeaderName::from_static("x-custom-header"),
/// HeaderValue::from_static("custom-value")
/// );
///
/// let config = StreamableHttpClientTransportConfig::with_uri("http://localhost:8000")
/// .custom_headers(headers);
/// ```
pub fn custom_headers(mut self, custom_headers: HashMap<HeaderName, HeaderValue>) -> Self {
self.custom_headers = custom_headers;
self
}
}

impl Default for StreamableHttpClientTransportConfig {
Expand All @@ -789,6 +831,7 @@ impl Default for StreamableHttpClientTransportConfig {
channel_buffer_capacity: 16,
allow_stateless: true,
auth_header: None,
custom_headers: HashMap::new(),
}
}
}
Loading