[MCP Server] Support propagation of headers into mcp tool call (to pass a project with headers)

Pass ClientInfo into MCP call

(cherry picked from commit aa91dac054f062664c1eb436e2f7dc6de8dd8e3b)

GitOrigin-RevId: 2250cc7cf91b28d3e148f3b15597e94346ce7711
This commit is contained in:
Artem.Bukhonov
2025-07-17 21:47:57 +02:00
committed by intellij-monorepo-bot
parent 33193fcd38
commit b30fa93446
5 changed files with 191 additions and 48 deletions

View File

@@ -0,0 +1,13 @@
package com.intellij.mcpserver
import kotlin.coroutines.CoroutineContext
class ClientInfo(val name: String, val version: String)
internal class ClientInfoElement(val info: ClientInfo?) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<ClientInfoElement>
override val key: CoroutineContext.Key<*> = Key
}
val CoroutineContext.clientInfoOrNull: ClientInfo? get() = get(ClientInfoElement.Key)?.info
val CoroutineContext.clientInfo: ClientInfo get() = get(ClientInfoElement.Key)?.info ?: ClientInfo("Unknown client", "Unknown version")

View File

@@ -1,8 +1,7 @@
package com.intellij.mcpserver.impl
import com.intellij.mcpserver.*
import com.intellij.mcpserver.impl.util.network.findFirstFreePort
import com.intellij.mcpserver.impl.util.network.installHostValidation
import com.intellij.mcpserver.impl.util.network.*
import com.intellij.mcpserver.settings.McpServerSettings
import com.intellij.mcpserver.statistics.McpServerCounterUsagesCollector
import com.intellij.mcpserver.stdio.IJ_MCP_SERVER_PROJECT_PATH
@@ -39,14 +38,10 @@ import io.modelcontextprotocol.kotlin.sdk.*
import io.modelcontextprotocol.kotlin.sdk.server.RegisteredTool
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import io.modelcontextprotocol.kotlin.sdk.server.mcp
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.JsonPrimitive
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList
@@ -136,35 +131,39 @@ class McpServerService(val cs: CoroutineScope) {
}
})
val mcpServer = Server(
Implementation(
name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server",
version = ApplicationInfo.getInstance().fullVersion
),
ServerOptions(
capabilities = ServerCapabilities(
//prompts = ServerCapabilities.Prompts(listChanged = true),
//resources = ServerCapabilities.Resources(subscribe = true, listChanged = true),
tools = ServerCapabilities.Tools(listChanged = true),
)
)
)
cs.launch {
var previousTools: List<RegisteredTool>? = null
mcpTools.collectLatest { updatedTools ->
previousTools?.forEach { previousTool ->
mcpServer.removeTool(previousTool.tool.name)
}
mcpServer.addTools(updatedTools)
previousTools = updatedTools
}
}
return cs.embeddedServer(CIO, host = "127.0.0.1", port = freePort) {
installHostValidation()
mcp {
return@mcp mcpServer
installHttpRequestPropagation()
mcpPatched {
// this is added because now Kotlin MCP client doesn't support header adjusting for each request, only for initial one, see McpStdioRunner
val projectPath = call.request.headers[IJ_MCP_SERVER_PROJECT_PATH]
val mcpServer = Server(
Implementation(
name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server",
version = ApplicationInfo.getInstance().fullVersion
),
ServerOptions(
capabilities = ServerCapabilities(
//prompts = ServerCapabilities.Prompts(listChanged = true),
//resources = ServerCapabilities.Resources(subscribe = true, listChanged = true),
tools = ServerCapabilities.Tools(listChanged = true),
)
)
)
launch {
var previousTools: List<McpTool>? = null
mcpTools.collectLatest { updatedTools ->
previousTools?.forEach { previousTool ->
mcpServer.removeTool(previousTool.descriptor.name)
}
mcpServer.addTools(updatedTools.map { it.mcpToolToRegisteredTool(mcpServer, projectPath) })
previousTools = updatedTools
}
}
return@mcpPatched mcpServer
}
}.start(wait = false)
}
@@ -177,9 +176,9 @@ class McpServerService(val cs: CoroutineScope) {
logger.error("Cannot load tools for $it", e)
emptyList()
}
}.map { it.mcpToolToRegisteredTool() }
}
private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
private fun McpTool.mcpToolToRegisteredTool(server: Server, projectPathFromInitialRequest: String?): RegisteredTool {
val tool = Tool(name = descriptor.name,
description = descriptor.description,
inputSchema = Tool.Input(
@@ -188,7 +187,8 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
outputSchema = null,
annotations = null)
return RegisteredTool(tool) { request ->
val projectPath = (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content
val httpRequest = currentCoroutineContext().httpRequestOrNull
val projectPath = httpRequest?.headers?.get(IJ_MCP_SERVER_PROJECT_PATH) ?: (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content ?: projectPathFromInitialRequest
val project = if (!projectPath.isNullOrBlank()) {
ProjectManager.getInstance().openProjects.find { it.basePath == projectPath }
}
@@ -223,7 +223,13 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
application.messageBus.syncPublisher(ToolCallListener.TOPIC).beforeMcpToolCall(this@mcpToolToRegisteredTool.descriptor)
logger.trace { "Start calling tool '${this@mcpToolToRegisteredTool.descriptor.name}'. Arguments: ${request.arguments}" }
val result = withContext(ProjectContextElement(project) + McpToolDescriptorElement(this@mcpToolToRegisteredTool.descriptor)) {
val clientVersion = server.clientVersion ?: Implementation("Unknown client", "Unknown version")
val result = withContext(
ProjectContextElement(project) +
McpToolDescriptorElement(descriptor) +
ClientInfoElement(ClientInfo(clientVersion.name, clientVersion.version))
) {
this@mcpToolToRegisteredTool.call(request.arguments)
}

View File

@@ -0,0 +1,117 @@
package com.intellij.mcpserver.impl.util.network
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.diagnostic.trace
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.Application
import io.ktor.server.application.ApplicationCallPipeline
import io.ktor.server.application.install
import io.ktor.server.request.ApplicationRequest
import io.ktor.server.response.respond
import io.ktor.server.routing.RoutingContext
import io.ktor.server.routing.post
import io.ktor.server.routing.routing
import io.ktor.server.sse.SSE
import io.ktor.server.sse.ServerSSESession
import io.ktor.server.sse.sse
import io.ktor.util.collections.ConcurrentMap
import io.ktor.utils.io.KtorDsl
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext
private val logger = logger<RoutingContext>()
/**
* Temporary copied code from MCP SDK to pass thought sse session into handler
*/
@KtorDsl
fun Application.mcpPatched(block: ServerSSESession.() -> Server) {
val transports = ConcurrentMap<String, SseServerTransport>()
install(SSE)
routing {
sse("/sse") {
mcpSseEndpoint("/message", transports, block)
}
post("/message") {
mcpPostEndpoint(transports)
}
}
}
private suspend fun ServerSSESession.mcpSseEndpoint(
postEndpoint: String,
transports: ConcurrentMap<String, SseServerTransport>,
block: ServerSSESession.() -> Server,
) {
val transport = mcpSseTransport(postEndpoint, transports)
val server = this.block()
server.onClose {
logger.trace { "Server connection closed for sessionId: ${transport.sessionId}" }
transports.remove(transport.sessionId)
}
server.connect(transport)
logger.trace { "Server connected to transport for sessionId: ${transport.sessionId}" }
}
internal fun ServerSSESession.mcpSseTransport(
postEndpoint: String,
transports: ConcurrentMap<String, SseServerTransport>,
): SseServerTransport {
val transport = SseServerTransport(postEndpoint, this)
transports[transport.sessionId] = transport
logger.trace { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
return transport
}
internal suspend fun RoutingContext.mcpPostEndpoint(
transports: ConcurrentMap<String, SseServerTransport>,
) {
val sessionId: String = call.request.queryParameters["sessionId"]
?: run {
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
return
}
logger.trace { "Received message for sessionId: $sessionId" }
val transport = transports[sessionId]
if (transport == null) {
logger.warn("Session not found for sessionId: $sessionId")
call.respond(HttpStatusCode.NotFound, "Session not found")
return
}
transport.handlePostMessage(call)
logger.trace { "Message handled for sessionId: $sessionId" }
}
// your custom context element
class HttpRequestElement(val request: ApplicationRequest) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<HttpRequestElement>
override val key: CoroutineContext.Key<*> = Key
}
// install interceptor at the Call phase
fun Application.installHttpRequestPropagation() {
intercept(ApplicationCallPipeline.Call) {
// wrap the rest of the pipeline in your element
withContext(HttpRequestElement(this.context.request)) {
proceed() // this continues routing, handlers, etc.
}
}
}
val CoroutineContext.httpRequestOrNull: ApplicationRequest? get() = get(HttpRequestElement)?.request
val CoroutineContext.mcpSessionId: String? get() = httpRequestOrNull?.queryParameters?.get("sessionId")

View File

@@ -2,12 +2,9 @@
package com.intellij.mcpserver.toolsets.terminal
import com.intellij.mcpserver.McpServerBundle
import com.intellij.mcpserver.McpToolset
import com.intellij.mcpserver.*
import com.intellij.mcpserver.annotations.McpDescription
import com.intellij.mcpserver.annotations.McpTool
import com.intellij.mcpserver.project
import com.intellij.mcpserver.reportToolActivity
import com.intellij.mcpserver.toolsets.Constants
import com.intellij.mcpserver.util.TruncateMode
import com.intellij.mcpserver.util.checkUserConfirmationIfNeeded
@@ -58,8 +55,7 @@ class TerminalToolset : McpToolset {
val project = currentCoroutineContext().project
checkUserConfirmationIfNeeded(McpServerBundle.message("label.do.you.want.to.execute.command.in.terminal"), command, project)
// TODO pass from http request later (MCP Client name or something else)
val id = "mcp_session"
val id = currentCoroutineContext().clientInfoOrNull?.name ?: "mcp_session"
val window = ToolWindowManager.getInstance(project).getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID)
return executeShellCommand(window = window,
project = project,

View File

@@ -7,12 +7,14 @@ import com.intellij.execution.process.ProcessEvent
import com.intellij.execution.process.ProcessListener
import com.intellij.execution.process.ProcessOutputTypes
import com.intellij.mcpserver.McpServerBundle
import com.intellij.mcpserver.clientInfoOrNull
import com.intellij.mcpserver.mcpFail
import com.intellij.mcpserver.toolsets.terminal.TerminalToolset.CommandExecutionResult
import com.intellij.mcpserver.util.TruncateMode
import com.intellij.mcpserver.util.truncateText
import com.intellij.openapi.application.EDT
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Disposer
import com.intellij.openapi.util.Key
import com.intellij.openapi.util.NlsSafe
import com.intellij.openapi.util.io.toNioPathOrNull
@@ -21,10 +23,7 @@ import com.intellij.sh.run.ShConfigurationType
import com.intellij.terminal.TerminalExecutionConsole
import com.intellij.ui.content.ContentFactory
import com.intellij.util.execution.ParametersListUtil
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.coroutines.*
import kotlin.time.Duration
class CommandSession(val sessionId: String, val console: TerminalExecutionConsole)
@@ -87,8 +86,15 @@ suspend fun executeShellCommand(
}
else {
val executionConsole = TerminalExecutionConsole(project, processHandler).withConvertLfToCrlfForNonPtyProcess(true)
val content = ContentFactory.getInstance().createContent(executionConsole.component, McpServerBundle.message("mcp.general.terminal.tab.name"), false)
@Suppress("HardCodedStringLiteral")
val displayName = currentCoroutineContext().clientInfoOrNull?.name ?: McpServerBundle.message ("mcp.general.terminal.tab.name")
val content = ContentFactory.getInstance().createContent(executionConsole.component, displayName, false)
window.contentManager.addContent(content)
Disposer.register(content) {
@Suppress("HardCodedStringLiteral") // visible to LLM only
exitCode.completeExceptionally(ExecutionException("Terminal tab closed by user"))
processHandler.destroyProcess()
}
if (sessionId != null) {
content.putUserData(MCP_TERMINAL_KEY, CommandSession(sessionId, executionConsole))
}
@@ -104,7 +110,12 @@ suspend fun executeShellCommand(
processHandler.startNotify()
val exitCodeValue = withTimeoutOrNull(timeout) {
exitCode.await()
try {
exitCode.await()
}
catch (e: ExecutionException) {
mcpFail("Execution failed: ${e.message}")
}
}
val truncateText = truncateText(text = output.toString(), maxLinesCount = maxLinesCount, truncateMode = truncateMode)
if (exitCodeValue == null) {