[MCP Server] Support propagation of headers into mcp tool call (to pass a project with headers)

Pass ClientInfo into MCP call

(cherry picked from commit aa91dac054f062664c1eb436e2f7dc6de8dd8e3b)

GitOrigin-RevId: 2250cc7cf91b28d3e148f3b15597e94346ce7711
This commit is contained in:
Artem.Bukhonov
2025-07-17 21:47:57 +02:00
committed by intellij-monorepo-bot
parent 33193fcd38
commit b30fa93446
5 changed files with 191 additions and 48 deletions

View File

@@ -0,0 +1,13 @@
package com.intellij.mcpserver
import kotlin.coroutines.CoroutineContext
class ClientInfo(val name: String, val version: String)
internal class ClientInfoElement(val info: ClientInfo?) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<ClientInfoElement>
override val key: CoroutineContext.Key<*> = Key
}
val CoroutineContext.clientInfoOrNull: ClientInfo? get() = get(ClientInfoElement.Key)?.info
val CoroutineContext.clientInfo: ClientInfo get() = get(ClientInfoElement.Key)?.info ?: ClientInfo("Unknown client", "Unknown version")

View File

@@ -1,8 +1,7 @@
package com.intellij.mcpserver.impl package com.intellij.mcpserver.impl
import com.intellij.mcpserver.* import com.intellij.mcpserver.*
import com.intellij.mcpserver.impl.util.network.findFirstFreePort import com.intellij.mcpserver.impl.util.network.*
import com.intellij.mcpserver.impl.util.network.installHostValidation
import com.intellij.mcpserver.settings.McpServerSettings import com.intellij.mcpserver.settings.McpServerSettings
import com.intellij.mcpserver.statistics.McpServerCounterUsagesCollector import com.intellij.mcpserver.statistics.McpServerCounterUsagesCollector
import com.intellij.mcpserver.stdio.IJ_MCP_SERVER_PROJECT_PATH import com.intellij.mcpserver.stdio.IJ_MCP_SERVER_PROJECT_PATH
@@ -39,14 +38,10 @@ import io.modelcontextprotocol.kotlin.sdk.*
import io.modelcontextprotocol.kotlin.sdk.server.RegisteredTool import io.modelcontextprotocol.kotlin.sdk.server.RegisteredTool
import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
import io.modelcontextprotocol.kotlin.sdk.server.mcp import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.collectLatest import kotlinx.coroutines.flow.collectLatest
import kotlinx.coroutines.flow.update import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.JsonPrimitive
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.CopyOnWriteArrayList
@@ -136,35 +131,39 @@ class McpServerService(val cs: CoroutineScope) {
} }
}) })
val mcpServer = Server(
Implementation(
name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server",
version = ApplicationInfo.getInstance().fullVersion
),
ServerOptions(
capabilities = ServerCapabilities(
//prompts = ServerCapabilities.Prompts(listChanged = true),
//resources = ServerCapabilities.Resources(subscribe = true, listChanged = true),
tools = ServerCapabilities.Tools(listChanged = true),
)
)
)
cs.launch {
var previousTools: List<RegisteredTool>? = null
mcpTools.collectLatest { updatedTools ->
previousTools?.forEach { previousTool ->
mcpServer.removeTool(previousTool.tool.name)
}
mcpServer.addTools(updatedTools)
previousTools = updatedTools
}
}
return cs.embeddedServer(CIO, host = "127.0.0.1", port = freePort) { return cs.embeddedServer(CIO, host = "127.0.0.1", port = freePort) {
installHostValidation() installHostValidation()
mcp { installHttpRequestPropagation()
return@mcp mcpServer mcpPatched {
// this is added because now Kotlin MCP client doesn't support header adjusting for each request, only for initial one, see McpStdioRunner
val projectPath = call.request.headers[IJ_MCP_SERVER_PROJECT_PATH]
val mcpServer = Server(
Implementation(
name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server",
version = ApplicationInfo.getInstance().fullVersion
),
ServerOptions(
capabilities = ServerCapabilities(
//prompts = ServerCapabilities.Prompts(listChanged = true),
//resources = ServerCapabilities.Resources(subscribe = true, listChanged = true),
tools = ServerCapabilities.Tools(listChanged = true),
)
)
)
launch {
var previousTools: List<McpTool>? = null
mcpTools.collectLatest { updatedTools ->
previousTools?.forEach { previousTool ->
mcpServer.removeTool(previousTool.descriptor.name)
}
mcpServer.addTools(updatedTools.map { it.mcpToolToRegisteredTool(mcpServer, projectPath) })
previousTools = updatedTools
}
}
return@mcpPatched mcpServer
} }
}.start(wait = false) }.start(wait = false)
} }
@@ -177,9 +176,9 @@ class McpServerService(val cs: CoroutineScope) {
logger.error("Cannot load tools for $it", e) logger.error("Cannot load tools for $it", e)
emptyList() emptyList()
} }
}.map { it.mcpToolToRegisteredTool() } }
private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool { private fun McpTool.mcpToolToRegisteredTool(server: Server, projectPathFromInitialRequest: String?): RegisteredTool {
val tool = Tool(name = descriptor.name, val tool = Tool(name = descriptor.name,
description = descriptor.description, description = descriptor.description,
inputSchema = Tool.Input( inputSchema = Tool.Input(
@@ -188,7 +187,8 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
outputSchema = null, outputSchema = null,
annotations = null) annotations = null)
return RegisteredTool(tool) { request -> return RegisteredTool(tool) { request ->
val projectPath = (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content val httpRequest = currentCoroutineContext().httpRequestOrNull
val projectPath = httpRequest?.headers?.get(IJ_MCP_SERVER_PROJECT_PATH) ?: (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content ?: projectPathFromInitialRequest
val project = if (!projectPath.isNullOrBlank()) { val project = if (!projectPath.isNullOrBlank()) {
ProjectManager.getInstance().openProjects.find { it.basePath == projectPath } ProjectManager.getInstance().openProjects.find { it.basePath == projectPath }
} }
@@ -223,7 +223,13 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
application.messageBus.syncPublisher(ToolCallListener.TOPIC).beforeMcpToolCall(this@mcpToolToRegisteredTool.descriptor) application.messageBus.syncPublisher(ToolCallListener.TOPIC).beforeMcpToolCall(this@mcpToolToRegisteredTool.descriptor)
logger.trace { "Start calling tool '${this@mcpToolToRegisteredTool.descriptor.name}'. Arguments: ${request.arguments}" } logger.trace { "Start calling tool '${this@mcpToolToRegisteredTool.descriptor.name}'. Arguments: ${request.arguments}" }
val result = withContext(ProjectContextElement(project) + McpToolDescriptorElement(this@mcpToolToRegisteredTool.descriptor)) { val clientVersion = server.clientVersion ?: Implementation("Unknown client", "Unknown version")
val result = withContext(
ProjectContextElement(project) +
McpToolDescriptorElement(descriptor) +
ClientInfoElement(ClientInfo(clientVersion.name, clientVersion.version))
) {
this@mcpToolToRegisteredTool.call(request.arguments) this@mcpToolToRegisteredTool.call(request.arguments)
} }

View File

@@ -0,0 +1,117 @@
package com.intellij.mcpserver.impl.util.network
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.diagnostic.trace
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.Application
import io.ktor.server.application.ApplicationCallPipeline
import io.ktor.server.application.install
import io.ktor.server.request.ApplicationRequest
import io.ktor.server.response.respond
import io.ktor.server.routing.RoutingContext
import io.ktor.server.routing.post
import io.ktor.server.routing.routing
import io.ktor.server.sse.SSE
import io.ktor.server.sse.ServerSSESession
import io.ktor.server.sse.sse
import io.ktor.util.collections.ConcurrentMap
import io.ktor.utils.io.KtorDsl
import io.modelcontextprotocol.kotlin.sdk.server.Server
import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport
import kotlinx.coroutines.withContext
import kotlin.coroutines.CoroutineContext
private val logger = logger<RoutingContext>()
/**
* Temporary copied code from MCP SDK to pass thought sse session into handler
*/
@KtorDsl
fun Application.mcpPatched(block: ServerSSESession.() -> Server) {
val transports = ConcurrentMap<String, SseServerTransport>()
install(SSE)
routing {
sse("/sse") {
mcpSseEndpoint("/message", transports, block)
}
post("/message") {
mcpPostEndpoint(transports)
}
}
}
private suspend fun ServerSSESession.mcpSseEndpoint(
postEndpoint: String,
transports: ConcurrentMap<String, SseServerTransport>,
block: ServerSSESession.() -> Server,
) {
val transport = mcpSseTransport(postEndpoint, transports)
val server = this.block()
server.onClose {
logger.trace { "Server connection closed for sessionId: ${transport.sessionId}" }
transports.remove(transport.sessionId)
}
server.connect(transport)
logger.trace { "Server connected to transport for sessionId: ${transport.sessionId}" }
}
internal fun ServerSSESession.mcpSseTransport(
postEndpoint: String,
transports: ConcurrentMap<String, SseServerTransport>,
): SseServerTransport {
val transport = SseServerTransport(postEndpoint, this)
transports[transport.sessionId] = transport
logger.trace { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
return transport
}
internal suspend fun RoutingContext.mcpPostEndpoint(
transports: ConcurrentMap<String, SseServerTransport>,
) {
val sessionId: String = call.request.queryParameters["sessionId"]
?: run {
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
return
}
logger.trace { "Received message for sessionId: $sessionId" }
val transport = transports[sessionId]
if (transport == null) {
logger.warn("Session not found for sessionId: $sessionId")
call.respond(HttpStatusCode.NotFound, "Session not found")
return
}
transport.handlePostMessage(call)
logger.trace { "Message handled for sessionId: $sessionId" }
}
// your custom context element
class HttpRequestElement(val request: ApplicationRequest) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<HttpRequestElement>
override val key: CoroutineContext.Key<*> = Key
}
// install interceptor at the Call phase
fun Application.installHttpRequestPropagation() {
intercept(ApplicationCallPipeline.Call) {
// wrap the rest of the pipeline in your element
withContext(HttpRequestElement(this.context.request)) {
proceed() // this continues routing, handlers, etc.
}
}
}
val CoroutineContext.httpRequestOrNull: ApplicationRequest? get() = get(HttpRequestElement)?.request
val CoroutineContext.mcpSessionId: String? get() = httpRequestOrNull?.queryParameters?.get("sessionId")

View File

@@ -2,12 +2,9 @@
package com.intellij.mcpserver.toolsets.terminal package com.intellij.mcpserver.toolsets.terminal
import com.intellij.mcpserver.McpServerBundle import com.intellij.mcpserver.*
import com.intellij.mcpserver.McpToolset
import com.intellij.mcpserver.annotations.McpDescription import com.intellij.mcpserver.annotations.McpDescription
import com.intellij.mcpserver.annotations.McpTool import com.intellij.mcpserver.annotations.McpTool
import com.intellij.mcpserver.project
import com.intellij.mcpserver.reportToolActivity
import com.intellij.mcpserver.toolsets.Constants import com.intellij.mcpserver.toolsets.Constants
import com.intellij.mcpserver.util.TruncateMode import com.intellij.mcpserver.util.TruncateMode
import com.intellij.mcpserver.util.checkUserConfirmationIfNeeded import com.intellij.mcpserver.util.checkUserConfirmationIfNeeded
@@ -58,8 +55,7 @@ class TerminalToolset : McpToolset {
val project = currentCoroutineContext().project val project = currentCoroutineContext().project
checkUserConfirmationIfNeeded(McpServerBundle.message("label.do.you.want.to.execute.command.in.terminal"), command, project) checkUserConfirmationIfNeeded(McpServerBundle.message("label.do.you.want.to.execute.command.in.terminal"), command, project)
// TODO pass from http request later (MCP Client name or something else) val id = currentCoroutineContext().clientInfoOrNull?.name ?: "mcp_session"
val id = "mcp_session"
val window = ToolWindowManager.getInstance(project).getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID) val window = ToolWindowManager.getInstance(project).getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID)
return executeShellCommand(window = window, return executeShellCommand(window = window,
project = project, project = project,

View File

@@ -7,12 +7,14 @@ import com.intellij.execution.process.ProcessEvent
import com.intellij.execution.process.ProcessListener import com.intellij.execution.process.ProcessListener
import com.intellij.execution.process.ProcessOutputTypes import com.intellij.execution.process.ProcessOutputTypes
import com.intellij.mcpserver.McpServerBundle import com.intellij.mcpserver.McpServerBundle
import com.intellij.mcpserver.clientInfoOrNull
import com.intellij.mcpserver.mcpFail import com.intellij.mcpserver.mcpFail
import com.intellij.mcpserver.toolsets.terminal.TerminalToolset.CommandExecutionResult import com.intellij.mcpserver.toolsets.terminal.TerminalToolset.CommandExecutionResult
import com.intellij.mcpserver.util.TruncateMode import com.intellij.mcpserver.util.TruncateMode
import com.intellij.mcpserver.util.truncateText import com.intellij.mcpserver.util.truncateText
import com.intellij.openapi.application.EDT import com.intellij.openapi.application.EDT
import com.intellij.openapi.project.Project import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Disposer
import com.intellij.openapi.util.Key import com.intellij.openapi.util.Key
import com.intellij.openapi.util.NlsSafe import com.intellij.openapi.util.NlsSafe
import com.intellij.openapi.util.io.toNioPathOrNull import com.intellij.openapi.util.io.toNioPathOrNull
@@ -21,10 +23,7 @@ import com.intellij.sh.run.ShConfigurationType
import com.intellij.terminal.TerminalExecutionConsole import com.intellij.terminal.TerminalExecutionConsole
import com.intellij.ui.content.ContentFactory import com.intellij.ui.content.ContentFactory
import com.intellij.util.execution.ParametersListUtil import com.intellij.util.execution.ParametersListUtil
import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.*
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeoutOrNull
import kotlin.time.Duration import kotlin.time.Duration
class CommandSession(val sessionId: String, val console: TerminalExecutionConsole) class CommandSession(val sessionId: String, val console: TerminalExecutionConsole)
@@ -87,8 +86,15 @@ suspend fun executeShellCommand(
} }
else { else {
val executionConsole = TerminalExecutionConsole(project, processHandler).withConvertLfToCrlfForNonPtyProcess(true) val executionConsole = TerminalExecutionConsole(project, processHandler).withConvertLfToCrlfForNonPtyProcess(true)
val content = ContentFactory.getInstance().createContent(executionConsole.component, McpServerBundle.message("mcp.general.terminal.tab.name"), false) @Suppress("HardCodedStringLiteral")
val displayName = currentCoroutineContext().clientInfoOrNull?.name ?: McpServerBundle.message ("mcp.general.terminal.tab.name")
val content = ContentFactory.getInstance().createContent(executionConsole.component, displayName, false)
window.contentManager.addContent(content) window.contentManager.addContent(content)
Disposer.register(content) {
@Suppress("HardCodedStringLiteral") // visible to LLM only
exitCode.completeExceptionally(ExecutionException("Terminal tab closed by user"))
processHandler.destroyProcess()
}
if (sessionId != null) { if (sessionId != null) {
content.putUserData(MCP_TERMINAL_KEY, CommandSession(sessionId, executionConsole)) content.putUserData(MCP_TERMINAL_KEY, CommandSession(sessionId, executionConsole))
} }
@@ -104,7 +110,12 @@ suspend fun executeShellCommand(
processHandler.startNotify() processHandler.startNotify()
val exitCodeValue = withTimeoutOrNull(timeout) { val exitCodeValue = withTimeoutOrNull(timeout) {
exitCode.await() try {
exitCode.await()
}
catch (e: ExecutionException) {
mcpFail("Execution failed: ${e.message}")
}
} }
val truncateText = truncateText(text = output.toString(), maxLinesCount = maxLinesCount, truncateMode = truncateMode) val truncateText = truncateText(text = output.toString(), maxLinesCount = maxLinesCount, truncateMode = truncateMode)
if (exitCodeValue == null) { if (exitCodeValue == null) {