From b30fa93446afbfdf4e0543cc45a6c2654834f92e Mon Sep 17 00:00:00 2001 From: "Artem.Bukhonov" Date: Thu, 17 Jul 2025 21:47:57 +0200 Subject: [PATCH] [MCP Server] Support propagation of headers into mcp tool call (to pass a project with headers) Pass ClientInfo into MCP call (cherry picked from commit aa91dac054f062664c1eb436e2f7dc6de8dd8e3b) GitOrigin-RevId: 2250cc7cf91b28d3e148f3b15597e94346ce7711 --- .../src/com/intellij/mcpserver/ClientInfo.kt | 13 ++ .../mcpserver/impl/McpServerService.kt | 78 ++++++------ .../impl/util/network/mcp.sdk.util.kt | 117 ++++++++++++++++++ .../toolsets/terminal/TerminalToolset.kt | 8 +- .../toolsets/terminal/terminalToolsetUtil.kt | 23 +++- 5 files changed, 191 insertions(+), 48 deletions(-) create mode 100644 plugins/mcp-server/src/com/intellij/mcpserver/ClientInfo.kt create mode 100644 plugins/mcp-server/src/com/intellij/mcpserver/impl/util/network/mcp.sdk.util.kt diff --git a/plugins/mcp-server/src/com/intellij/mcpserver/ClientInfo.kt b/plugins/mcp-server/src/com/intellij/mcpserver/ClientInfo.kt new file mode 100644 index 000000000000..027d27e4bd8b --- /dev/null +++ b/plugins/mcp-server/src/com/intellij/mcpserver/ClientInfo.kt @@ -0,0 +1,13 @@ +package com.intellij.mcpserver + +import kotlin.coroutines.CoroutineContext + +class ClientInfo(val name: String, val version: String) + +internal class ClientInfoElement(val info: ClientInfo?) : CoroutineContext.Element { + companion object Key : CoroutineContext.Key + override val key: CoroutineContext.Key<*> = Key +} + +val CoroutineContext.clientInfoOrNull: ClientInfo? get() = get(ClientInfoElement.Key)?.info +val CoroutineContext.clientInfo: ClientInfo get() = get(ClientInfoElement.Key)?.info ?: ClientInfo("Unknown client", "Unknown version") \ No newline at end of file diff --git a/plugins/mcp-server/src/com/intellij/mcpserver/impl/McpServerService.kt b/plugins/mcp-server/src/com/intellij/mcpserver/impl/McpServerService.kt index 9de5eaa2313b..7b3c7ef00ccf 100644 --- a/plugins/mcp-server/src/com/intellij/mcpserver/impl/McpServerService.kt +++ b/plugins/mcp-server/src/com/intellij/mcpserver/impl/McpServerService.kt @@ -1,8 +1,7 @@ package com.intellij.mcpserver.impl import com.intellij.mcpserver.* -import com.intellij.mcpserver.impl.util.network.findFirstFreePort -import com.intellij.mcpserver.impl.util.network.installHostValidation +import com.intellij.mcpserver.impl.util.network.* import com.intellij.mcpserver.settings.McpServerSettings import com.intellij.mcpserver.statistics.McpServerCounterUsagesCollector import com.intellij.mcpserver.stdio.IJ_MCP_SERVER_PROJECT_PATH @@ -39,14 +38,10 @@ import io.modelcontextprotocol.kotlin.sdk.* import io.modelcontextprotocol.kotlin.sdk.server.RegisteredTool import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions -import io.modelcontextprotocol.kotlin.sdk.server.mcp -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.* import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.update -import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext import kotlinx.serialization.json.JsonPrimitive import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CopyOnWriteArrayList @@ -136,35 +131,39 @@ class McpServerService(val cs: CoroutineScope) { } }) - val mcpServer = Server( - Implementation( - name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server", - version = ApplicationInfo.getInstance().fullVersion - ), - ServerOptions( - capabilities = ServerCapabilities( - //prompts = ServerCapabilities.Prompts(listChanged = true), - //resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), - tools = ServerCapabilities.Tools(listChanged = true), - ) - ) - ) - cs.launch { - var previousTools: List? = null - mcpTools.collectLatest { updatedTools -> - previousTools?.forEach { previousTool -> - mcpServer.removeTool(previousTool.tool.name) - } - mcpServer.addTools(updatedTools) - previousTools = updatedTools - } - } + return cs.embeddedServer(CIO, host = "127.0.0.1", port = freePort) { installHostValidation() - mcp { - return@mcp mcpServer + installHttpRequestPropagation() + mcpPatched { + // this is added because now Kotlin MCP client doesn't support header adjusting for each request, only for initial one, see McpStdioRunner + val projectPath = call.request.headers[IJ_MCP_SERVER_PROJECT_PATH] + val mcpServer = Server( + Implementation( + name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server", + version = ApplicationInfo.getInstance().fullVersion + ), + ServerOptions( + capabilities = ServerCapabilities( + //prompts = ServerCapabilities.Prompts(listChanged = true), + //resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + tools = ServerCapabilities.Tools(listChanged = true), + ) + ) + ) + launch { + var previousTools: List? = null + mcpTools.collectLatest { updatedTools -> + previousTools?.forEach { previousTool -> + mcpServer.removeTool(previousTool.descriptor.name) + } + mcpServer.addTools(updatedTools.map { it.mcpToolToRegisteredTool(mcpServer, projectPath) }) + previousTools = updatedTools + } + } + return@mcpPatched mcpServer } }.start(wait = false) } @@ -177,9 +176,9 @@ class McpServerService(val cs: CoroutineScope) { logger.error("Cannot load tools for $it", e) emptyList() } - }.map { it.mcpToolToRegisteredTool() } + } -private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool { +private fun McpTool.mcpToolToRegisteredTool(server: Server, projectPathFromInitialRequest: String?): RegisteredTool { val tool = Tool(name = descriptor.name, description = descriptor.description, inputSchema = Tool.Input( @@ -188,7 +187,8 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool { outputSchema = null, annotations = null) return RegisteredTool(tool) { request -> - val projectPath = (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content + val httpRequest = currentCoroutineContext().httpRequestOrNull + val projectPath = httpRequest?.headers?.get(IJ_MCP_SERVER_PROJECT_PATH) ?: (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content ?: projectPathFromInitialRequest val project = if (!projectPath.isNullOrBlank()) { ProjectManager.getInstance().openProjects.find { it.basePath == projectPath } } @@ -223,7 +223,13 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool { application.messageBus.syncPublisher(ToolCallListener.TOPIC).beforeMcpToolCall(this@mcpToolToRegisteredTool.descriptor) logger.trace { "Start calling tool '${this@mcpToolToRegisteredTool.descriptor.name}'. Arguments: ${request.arguments}" } - val result = withContext(ProjectContextElement(project) + McpToolDescriptorElement(this@mcpToolToRegisteredTool.descriptor)) { + val clientVersion = server.clientVersion ?: Implementation("Unknown client", "Unknown version") + + val result = withContext( + ProjectContextElement(project) + + McpToolDescriptorElement(descriptor) + + ClientInfoElement(ClientInfo(clientVersion.name, clientVersion.version)) + ) { this@mcpToolToRegisteredTool.call(request.arguments) } diff --git a/plugins/mcp-server/src/com/intellij/mcpserver/impl/util/network/mcp.sdk.util.kt b/plugins/mcp-server/src/com/intellij/mcpserver/impl/util/network/mcp.sdk.util.kt new file mode 100644 index 000000000000..09ac2c8fcbd5 --- /dev/null +++ b/plugins/mcp-server/src/com/intellij/mcpserver/impl/util/network/mcp.sdk.util.kt @@ -0,0 +1,117 @@ +package com.intellij.mcpserver.impl.util.network + +import com.intellij.openapi.diagnostic.logger +import com.intellij.openapi.diagnostic.trace +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.Application +import io.ktor.server.application.ApplicationCallPipeline +import io.ktor.server.application.install +import io.ktor.server.request.ApplicationRequest +import io.ktor.server.response.respond +import io.ktor.server.routing.RoutingContext +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.ktor.server.sse.SSE +import io.ktor.server.sse.ServerSSESession +import io.ktor.server.sse.sse +import io.ktor.util.collections.ConcurrentMap +import io.ktor.utils.io.KtorDsl +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport +import kotlinx.coroutines.withContext +import kotlin.coroutines.CoroutineContext + + +private val logger = logger() + +/** + * Temporary copied code from MCP SDK to pass thought sse session into handler + */ +@KtorDsl +fun Application.mcpPatched(block: ServerSSESession.() -> Server) { + val transports = ConcurrentMap() + + install(SSE) + + routing { + sse("/sse") { + mcpSseEndpoint("/message", transports, block) + } + + post("/message") { + mcpPostEndpoint(transports) + } + } +} + +private suspend fun ServerSSESession.mcpSseEndpoint( + postEndpoint: String, + transports: ConcurrentMap, + block: ServerSSESession.() -> Server, +) { + val transport = mcpSseTransport(postEndpoint, transports) + + val server = this.block() + + server.onClose { + logger.trace { "Server connection closed for sessionId: ${transport.sessionId}" } + transports.remove(transport.sessionId) + } + + server.connect(transport) + logger.trace { "Server connected to transport for sessionId: ${transport.sessionId}" } +} + +internal fun ServerSSESession.mcpSseTransport( + postEndpoint: String, + transports: ConcurrentMap, +): SseServerTransport { + val transport = SseServerTransport(postEndpoint, this) + transports[transport.sessionId] = transport + + logger.trace { "New SSE connection established and stored with sessionId: ${transport.sessionId}" } + + return transport +} + + +internal suspend fun RoutingContext.mcpPostEndpoint( + transports: ConcurrentMap, +) { + val sessionId: String = call.request.queryParameters["sessionId"] + ?: run { + call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided") + return + } + + logger.trace { "Received message for sessionId: $sessionId" } + + val transport = transports[sessionId] + if (transport == null) { + logger.warn("Session not found for sessionId: $sessionId") + call.respond(HttpStatusCode.NotFound, "Session not found") + return + } + + transport.handlePostMessage(call) + logger.trace { "Message handled for sessionId: $sessionId" } +} + +//–– your custom context element +class HttpRequestElement(val request: ApplicationRequest) : CoroutineContext.Element { + companion object Key : CoroutineContext.Key + override val key: CoroutineContext.Key<*> = Key +} + +//–– install interceptor at the Call phase +fun Application.installHttpRequestPropagation() { + intercept(ApplicationCallPipeline.Call) { + // wrap the rest of the pipeline in your element + withContext(HttpRequestElement(this.context.request)) { + proceed() // this continues routing, handlers, etc. + } + } +} + +val CoroutineContext.httpRequestOrNull: ApplicationRequest? get() = get(HttpRequestElement)?.request +val CoroutineContext.mcpSessionId: String? get() = httpRequestOrNull?.queryParameters?.get("sessionId") diff --git a/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/TerminalToolset.kt b/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/TerminalToolset.kt index 7c47872069df..e790f40f2fde 100644 --- a/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/TerminalToolset.kt +++ b/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/TerminalToolset.kt @@ -2,12 +2,9 @@ package com.intellij.mcpserver.toolsets.terminal -import com.intellij.mcpserver.McpServerBundle -import com.intellij.mcpserver.McpToolset +import com.intellij.mcpserver.* import com.intellij.mcpserver.annotations.McpDescription import com.intellij.mcpserver.annotations.McpTool -import com.intellij.mcpserver.project -import com.intellij.mcpserver.reportToolActivity import com.intellij.mcpserver.toolsets.Constants import com.intellij.mcpserver.util.TruncateMode import com.intellij.mcpserver.util.checkUserConfirmationIfNeeded @@ -58,8 +55,7 @@ class TerminalToolset : McpToolset { val project = currentCoroutineContext().project checkUserConfirmationIfNeeded(McpServerBundle.message("label.do.you.want.to.execute.command.in.terminal"), command, project) - // TODO pass from http request later (MCP Client name or something else) - val id = "mcp_session" + val id = currentCoroutineContext().clientInfoOrNull?.name ?: "mcp_session" val window = ToolWindowManager.getInstance(project).getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID) return executeShellCommand(window = window, project = project, diff --git a/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/terminalToolsetUtil.kt b/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/terminalToolsetUtil.kt index dc8ba197594f..51cb87268852 100644 --- a/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/terminalToolsetUtil.kt +++ b/plugins/mcp-server/src/com/intellij/mcpserver/toolsets/terminal/terminalToolsetUtil.kt @@ -7,12 +7,14 @@ import com.intellij.execution.process.ProcessEvent import com.intellij.execution.process.ProcessListener import com.intellij.execution.process.ProcessOutputTypes import com.intellij.mcpserver.McpServerBundle +import com.intellij.mcpserver.clientInfoOrNull import com.intellij.mcpserver.mcpFail import com.intellij.mcpserver.toolsets.terminal.TerminalToolset.CommandExecutionResult import com.intellij.mcpserver.util.TruncateMode import com.intellij.mcpserver.util.truncateText import com.intellij.openapi.application.EDT import com.intellij.openapi.project.Project +import com.intellij.openapi.util.Disposer import com.intellij.openapi.util.Key import com.intellij.openapi.util.NlsSafe import com.intellij.openapi.util.io.toNioPathOrNull @@ -21,10 +23,7 @@ import com.intellij.sh.run.ShConfigurationType import com.intellij.terminal.TerminalExecutionConsole import com.intellij.ui.content.ContentFactory import com.intellij.util.execution.ParametersListUtil -import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext -import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.coroutines.* import kotlin.time.Duration class CommandSession(val sessionId: String, val console: TerminalExecutionConsole) @@ -87,8 +86,15 @@ suspend fun executeShellCommand( } else { val executionConsole = TerminalExecutionConsole(project, processHandler).withConvertLfToCrlfForNonPtyProcess(true) - val content = ContentFactory.getInstance().createContent(executionConsole.component, McpServerBundle.message("mcp.general.terminal.tab.name"), false) + @Suppress("HardCodedStringLiteral") + val displayName = currentCoroutineContext().clientInfoOrNull?.name ?: McpServerBundle.message ("mcp.general.terminal.tab.name") + val content = ContentFactory.getInstance().createContent(executionConsole.component, displayName, false) window.contentManager.addContent(content) + Disposer.register(content) { + @Suppress("HardCodedStringLiteral") // visible to LLM only + exitCode.completeExceptionally(ExecutionException("Terminal tab closed by user")) + processHandler.destroyProcess() + } if (sessionId != null) { content.putUserData(MCP_TERMINAL_KEY, CommandSession(sessionId, executionConsole)) } @@ -104,7 +110,12 @@ suspend fun executeShellCommand( processHandler.startNotify() val exitCodeValue = withTimeoutOrNull(timeout) { - exitCode.await() + try { + exitCode.await() + } + catch (e: ExecutionException) { + mcpFail("Execution failed: ${e.message}") + } } val truncateText = truncateText(text = output.toString(), maxLinesCount = maxLinesCount, truncateMode = truncateMode) if (exitCodeValue == null) {