mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 02:59:33 +07:00
[MCP Server] Support propagation of headers into mcp tool call (to pass a project with headers)
Pass ClientInfo into MCP call (cherry picked from commit aa91dac054f062664c1eb436e2f7dc6de8dd8e3b) GitOrigin-RevId: 2250cc7cf91b28d3e148f3b15597e94346ce7711
This commit is contained in:
committed by
intellij-monorepo-bot
parent
33193fcd38
commit
b30fa93446
13
plugins/mcp-server/src/com/intellij/mcpserver/ClientInfo.kt
Normal file
13
plugins/mcp-server/src/com/intellij/mcpserver/ClientInfo.kt
Normal file
@@ -0,0 +1,13 @@
|
||||
package com.intellij.mcpserver
|
||||
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
class ClientInfo(val name: String, val version: String)
|
||||
|
||||
internal class ClientInfoElement(val info: ClientInfo?) : CoroutineContext.Element {
|
||||
companion object Key : CoroutineContext.Key<ClientInfoElement>
|
||||
override val key: CoroutineContext.Key<*> = Key
|
||||
}
|
||||
|
||||
val CoroutineContext.clientInfoOrNull: ClientInfo? get() = get(ClientInfoElement.Key)?.info
|
||||
val CoroutineContext.clientInfo: ClientInfo get() = get(ClientInfoElement.Key)?.info ?: ClientInfo("Unknown client", "Unknown version")
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.intellij.mcpserver.impl
|
||||
|
||||
import com.intellij.mcpserver.*
|
||||
import com.intellij.mcpserver.impl.util.network.findFirstFreePort
|
||||
import com.intellij.mcpserver.impl.util.network.installHostValidation
|
||||
import com.intellij.mcpserver.impl.util.network.*
|
||||
import com.intellij.mcpserver.settings.McpServerSettings
|
||||
import com.intellij.mcpserver.statistics.McpServerCounterUsagesCollector
|
||||
import com.intellij.mcpserver.stdio.IJ_MCP_SERVER_PROJECT_PATH
|
||||
@@ -39,14 +38,10 @@ import io.modelcontextprotocol.kotlin.sdk.*
|
||||
import io.modelcontextprotocol.kotlin.sdk.server.RegisteredTool
|
||||
import io.modelcontextprotocol.kotlin.sdk.server.Server
|
||||
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
|
||||
import io.modelcontextprotocol.kotlin.sdk.server.mcp
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.collectLatest
|
||||
import kotlinx.coroutines.flow.update
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.serialization.json.JsonPrimitive
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.CopyOnWriteArrayList
|
||||
@@ -136,35 +131,39 @@ class McpServerService(val cs: CoroutineScope) {
|
||||
}
|
||||
})
|
||||
|
||||
val mcpServer = Server(
|
||||
Implementation(
|
||||
name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server",
|
||||
version = ApplicationInfo.getInstance().fullVersion
|
||||
),
|
||||
ServerOptions(
|
||||
capabilities = ServerCapabilities(
|
||||
//prompts = ServerCapabilities.Prompts(listChanged = true),
|
||||
//resources = ServerCapabilities.Resources(subscribe = true, listChanged = true),
|
||||
tools = ServerCapabilities.Tools(listChanged = true),
|
||||
)
|
||||
)
|
||||
)
|
||||
cs.launch {
|
||||
var previousTools: List<RegisteredTool>? = null
|
||||
mcpTools.collectLatest { updatedTools ->
|
||||
previousTools?.forEach { previousTool ->
|
||||
mcpServer.removeTool(previousTool.tool.name)
|
||||
}
|
||||
mcpServer.addTools(updatedTools)
|
||||
previousTools = updatedTools
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
return cs.embeddedServer(CIO, host = "127.0.0.1", port = freePort) {
|
||||
installHostValidation()
|
||||
mcp {
|
||||
return@mcp mcpServer
|
||||
installHttpRequestPropagation()
|
||||
mcpPatched {
|
||||
// this is added because now Kotlin MCP client doesn't support header adjusting for each request, only for initial one, see McpStdioRunner
|
||||
val projectPath = call.request.headers[IJ_MCP_SERVER_PROJECT_PATH]
|
||||
val mcpServer = Server(
|
||||
Implementation(
|
||||
name = "${ApplicationNamesInfo.getInstance().fullProductName} MCP Server",
|
||||
version = ApplicationInfo.getInstance().fullVersion
|
||||
),
|
||||
ServerOptions(
|
||||
capabilities = ServerCapabilities(
|
||||
//prompts = ServerCapabilities.Prompts(listChanged = true),
|
||||
//resources = ServerCapabilities.Resources(subscribe = true, listChanged = true),
|
||||
tools = ServerCapabilities.Tools(listChanged = true),
|
||||
)
|
||||
)
|
||||
)
|
||||
launch {
|
||||
var previousTools: List<McpTool>? = null
|
||||
mcpTools.collectLatest { updatedTools ->
|
||||
previousTools?.forEach { previousTool ->
|
||||
mcpServer.removeTool(previousTool.descriptor.name)
|
||||
}
|
||||
mcpServer.addTools(updatedTools.map { it.mcpToolToRegisteredTool(mcpServer, projectPath) })
|
||||
previousTools = updatedTools
|
||||
}
|
||||
}
|
||||
return@mcpPatched mcpServer
|
||||
}
|
||||
}.start(wait = false)
|
||||
}
|
||||
@@ -177,9 +176,9 @@ class McpServerService(val cs: CoroutineScope) {
|
||||
logger.error("Cannot load tools for $it", e)
|
||||
emptyList()
|
||||
}
|
||||
}.map { it.mcpToolToRegisteredTool() }
|
||||
}
|
||||
|
||||
private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
|
||||
private fun McpTool.mcpToolToRegisteredTool(server: Server, projectPathFromInitialRequest: String?): RegisteredTool {
|
||||
val tool = Tool(name = descriptor.name,
|
||||
description = descriptor.description,
|
||||
inputSchema = Tool.Input(
|
||||
@@ -188,7 +187,8 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
|
||||
outputSchema = null,
|
||||
annotations = null)
|
||||
return RegisteredTool(tool) { request ->
|
||||
val projectPath = (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content
|
||||
val httpRequest = currentCoroutineContext().httpRequestOrNull
|
||||
val projectPath = httpRequest?.headers?.get(IJ_MCP_SERVER_PROJECT_PATH) ?: (request._meta[IJ_MCP_SERVER_PROJECT_PATH] as? JsonPrimitive)?.content ?: projectPathFromInitialRequest
|
||||
val project = if (!projectPath.isNullOrBlank()) {
|
||||
ProjectManager.getInstance().openProjects.find { it.basePath == projectPath }
|
||||
}
|
||||
@@ -223,7 +223,13 @@ private fun McpTool.mcpToolToRegisteredTool(): RegisteredTool {
|
||||
application.messageBus.syncPublisher(ToolCallListener.TOPIC).beforeMcpToolCall(this@mcpToolToRegisteredTool.descriptor)
|
||||
|
||||
logger.trace { "Start calling tool '${this@mcpToolToRegisteredTool.descriptor.name}'. Arguments: ${request.arguments}" }
|
||||
val result = withContext(ProjectContextElement(project) + McpToolDescriptorElement(this@mcpToolToRegisteredTool.descriptor)) {
|
||||
val clientVersion = server.clientVersion ?: Implementation("Unknown client", "Unknown version")
|
||||
|
||||
val result = withContext(
|
||||
ProjectContextElement(project) +
|
||||
McpToolDescriptorElement(descriptor) +
|
||||
ClientInfoElement(ClientInfo(clientVersion.name, clientVersion.version))
|
||||
) {
|
||||
this@mcpToolToRegisteredTool.call(request.arguments)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
package com.intellij.mcpserver.impl.util.network
|
||||
|
||||
import com.intellij.openapi.diagnostic.logger
|
||||
import com.intellij.openapi.diagnostic.trace
|
||||
import io.ktor.http.HttpStatusCode
|
||||
import io.ktor.server.application.Application
|
||||
import io.ktor.server.application.ApplicationCallPipeline
|
||||
import io.ktor.server.application.install
|
||||
import io.ktor.server.request.ApplicationRequest
|
||||
import io.ktor.server.response.respond
|
||||
import io.ktor.server.routing.RoutingContext
|
||||
import io.ktor.server.routing.post
|
||||
import io.ktor.server.routing.routing
|
||||
import io.ktor.server.sse.SSE
|
||||
import io.ktor.server.sse.ServerSSESession
|
||||
import io.ktor.server.sse.sse
|
||||
import io.ktor.util.collections.ConcurrentMap
|
||||
import io.ktor.utils.io.KtorDsl
|
||||
import io.modelcontextprotocol.kotlin.sdk.server.Server
|
||||
import io.modelcontextprotocol.kotlin.sdk.server.SseServerTransport
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlin.coroutines.CoroutineContext
|
||||
|
||||
|
||||
private val logger = logger<RoutingContext>()
|
||||
|
||||
/**
|
||||
* Temporary copied code from MCP SDK to pass thought sse session into handler
|
||||
*/
|
||||
@KtorDsl
|
||||
fun Application.mcpPatched(block: ServerSSESession.() -> Server) {
|
||||
val transports = ConcurrentMap<String, SseServerTransport>()
|
||||
|
||||
install(SSE)
|
||||
|
||||
routing {
|
||||
sse("/sse") {
|
||||
mcpSseEndpoint("/message", transports, block)
|
||||
}
|
||||
|
||||
post("/message") {
|
||||
mcpPostEndpoint(transports)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private suspend fun ServerSSESession.mcpSseEndpoint(
|
||||
postEndpoint: String,
|
||||
transports: ConcurrentMap<String, SseServerTransport>,
|
||||
block: ServerSSESession.() -> Server,
|
||||
) {
|
||||
val transport = mcpSseTransport(postEndpoint, transports)
|
||||
|
||||
val server = this.block()
|
||||
|
||||
server.onClose {
|
||||
logger.trace { "Server connection closed for sessionId: ${transport.sessionId}" }
|
||||
transports.remove(transport.sessionId)
|
||||
}
|
||||
|
||||
server.connect(transport)
|
||||
logger.trace { "Server connected to transport for sessionId: ${transport.sessionId}" }
|
||||
}
|
||||
|
||||
internal fun ServerSSESession.mcpSseTransport(
|
||||
postEndpoint: String,
|
||||
transports: ConcurrentMap<String, SseServerTransport>,
|
||||
): SseServerTransport {
|
||||
val transport = SseServerTransport(postEndpoint, this)
|
||||
transports[transport.sessionId] = transport
|
||||
|
||||
logger.trace { "New SSE connection established and stored with sessionId: ${transport.sessionId}" }
|
||||
|
||||
return transport
|
||||
}
|
||||
|
||||
|
||||
internal suspend fun RoutingContext.mcpPostEndpoint(
|
||||
transports: ConcurrentMap<String, SseServerTransport>,
|
||||
) {
|
||||
val sessionId: String = call.request.queryParameters["sessionId"]
|
||||
?: run {
|
||||
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
|
||||
return
|
||||
}
|
||||
|
||||
logger.trace { "Received message for sessionId: $sessionId" }
|
||||
|
||||
val transport = transports[sessionId]
|
||||
if (transport == null) {
|
||||
logger.warn("Session not found for sessionId: $sessionId")
|
||||
call.respond(HttpStatusCode.NotFound, "Session not found")
|
||||
return
|
||||
}
|
||||
|
||||
transport.handlePostMessage(call)
|
||||
logger.trace { "Message handled for sessionId: $sessionId" }
|
||||
}
|
||||
|
||||
//–– your custom context element
|
||||
class HttpRequestElement(val request: ApplicationRequest) : CoroutineContext.Element {
|
||||
companion object Key : CoroutineContext.Key<HttpRequestElement>
|
||||
override val key: CoroutineContext.Key<*> = Key
|
||||
}
|
||||
|
||||
//–– install interceptor at the Call phase
|
||||
fun Application.installHttpRequestPropagation() {
|
||||
intercept(ApplicationCallPipeline.Call) {
|
||||
// wrap the rest of the pipeline in your element
|
||||
withContext(HttpRequestElement(this.context.request)) {
|
||||
proceed() // this continues routing, handlers, etc.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val CoroutineContext.httpRequestOrNull: ApplicationRequest? get() = get(HttpRequestElement)?.request
|
||||
val CoroutineContext.mcpSessionId: String? get() = httpRequestOrNull?.queryParameters?.get("sessionId")
|
||||
@@ -2,12 +2,9 @@
|
||||
|
||||
package com.intellij.mcpserver.toolsets.terminal
|
||||
|
||||
import com.intellij.mcpserver.McpServerBundle
|
||||
import com.intellij.mcpserver.McpToolset
|
||||
import com.intellij.mcpserver.*
|
||||
import com.intellij.mcpserver.annotations.McpDescription
|
||||
import com.intellij.mcpserver.annotations.McpTool
|
||||
import com.intellij.mcpserver.project
|
||||
import com.intellij.mcpserver.reportToolActivity
|
||||
import com.intellij.mcpserver.toolsets.Constants
|
||||
import com.intellij.mcpserver.util.TruncateMode
|
||||
import com.intellij.mcpserver.util.checkUserConfirmationIfNeeded
|
||||
@@ -58,8 +55,7 @@ class TerminalToolset : McpToolset {
|
||||
val project = currentCoroutineContext().project
|
||||
checkUserConfirmationIfNeeded(McpServerBundle.message("label.do.you.want.to.execute.command.in.terminal"), command, project)
|
||||
|
||||
// TODO pass from http request later (MCP Client name or something else)
|
||||
val id = "mcp_session"
|
||||
val id = currentCoroutineContext().clientInfoOrNull?.name ?: "mcp_session"
|
||||
val window = ToolWindowManager.getInstance(project).getToolWindow(TerminalToolWindowFactory.TOOL_WINDOW_ID)
|
||||
return executeShellCommand(window = window,
|
||||
project = project,
|
||||
|
||||
@@ -7,12 +7,14 @@ import com.intellij.execution.process.ProcessEvent
|
||||
import com.intellij.execution.process.ProcessListener
|
||||
import com.intellij.execution.process.ProcessOutputTypes
|
||||
import com.intellij.mcpserver.McpServerBundle
|
||||
import com.intellij.mcpserver.clientInfoOrNull
|
||||
import com.intellij.mcpserver.mcpFail
|
||||
import com.intellij.mcpserver.toolsets.terminal.TerminalToolset.CommandExecutionResult
|
||||
import com.intellij.mcpserver.util.TruncateMode
|
||||
import com.intellij.mcpserver.util.truncateText
|
||||
import com.intellij.openapi.application.EDT
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.Disposer
|
||||
import com.intellij.openapi.util.Key
|
||||
import com.intellij.openapi.util.NlsSafe
|
||||
import com.intellij.openapi.util.io.toNioPathOrNull
|
||||
@@ -21,10 +23,7 @@ import com.intellij.sh.run.ShConfigurationType
|
||||
import com.intellij.terminal.TerminalExecutionConsole
|
||||
import com.intellij.ui.content.ContentFactory
|
||||
import com.intellij.util.execution.ParametersListUtil
|
||||
import kotlinx.coroutines.CompletableDeferred
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import kotlinx.coroutines.withTimeoutOrNull
|
||||
import kotlinx.coroutines.*
|
||||
import kotlin.time.Duration
|
||||
|
||||
class CommandSession(val sessionId: String, val console: TerminalExecutionConsole)
|
||||
@@ -87,8 +86,15 @@ suspend fun executeShellCommand(
|
||||
}
|
||||
else {
|
||||
val executionConsole = TerminalExecutionConsole(project, processHandler).withConvertLfToCrlfForNonPtyProcess(true)
|
||||
val content = ContentFactory.getInstance().createContent(executionConsole.component, McpServerBundle.message("mcp.general.terminal.tab.name"), false)
|
||||
@Suppress("HardCodedStringLiteral")
|
||||
val displayName = currentCoroutineContext().clientInfoOrNull?.name ?: McpServerBundle.message ("mcp.general.terminal.tab.name")
|
||||
val content = ContentFactory.getInstance().createContent(executionConsole.component, displayName, false)
|
||||
window.contentManager.addContent(content)
|
||||
Disposer.register(content) {
|
||||
@Suppress("HardCodedStringLiteral") // visible to LLM only
|
||||
exitCode.completeExceptionally(ExecutionException("Terminal tab closed by user"))
|
||||
processHandler.destroyProcess()
|
||||
}
|
||||
if (sessionId != null) {
|
||||
content.putUserData(MCP_TERMINAL_KEY, CommandSession(sessionId, executionConsole))
|
||||
}
|
||||
@@ -104,7 +110,12 @@ suspend fun executeShellCommand(
|
||||
processHandler.startNotify()
|
||||
|
||||
val exitCodeValue = withTimeoutOrNull(timeout) {
|
||||
exitCode.await()
|
||||
try {
|
||||
exitCode.await()
|
||||
}
|
||||
catch (e: ExecutionException) {
|
||||
mcpFail("Execution failed: ${e.message}")
|
||||
}
|
||||
}
|
||||
val truncateText = truncateText(text = output.toString(), maxLinesCount = maxLinesCount, truncateMode = truncateMode)
|
||||
if (exitCodeValue == null) {
|
||||
|
||||
Reference in New Issue
Block a user