mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 11:53:49 +07:00
(cherry picked from commit bcb3c215ddcccfad3311ed5163503175e54254a4) IJ-CR-166188 GitOrigin-RevId: 0a16d89676ee5f7d9f80688805bd4312d7f20ccd
134 lines
4.4 KiB
Kotlin
134 lines
4.4 KiB
Kotlin
package com.intellij.mcpserver
|
|
|
|
import com.intellij.mcpserver.annotations.McpDescription
|
|
import com.intellij.mcpserver.impl.McpServerService
|
|
import com.intellij.mcpserver.impl.util.asTool
|
|
import com.intellij.openapi.project.Project
|
|
import com.intellij.openapi.util.Disposer
|
|
import com.intellij.openapi.util.use
|
|
import com.intellij.testFramework.junit5.TestApplication
|
|
import com.intellij.testFramework.junit5.fixture.projectFixture
|
|
import com.intellij.util.application
|
|
import io.ktor.client.HttpClient
|
|
import io.ktor.client.plugins.sse.SSE
|
|
import io.ktor.utils.io.streams.asInput
|
|
import io.modelcontextprotocol.kotlin.sdk.Implementation
|
|
import io.modelcontextprotocol.kotlin.sdk.client.Client
|
|
import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport
|
|
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
|
|
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
|
|
import kotlinx.coroutines.CompletableDeferred
|
|
import kotlinx.coroutines.currentCoroutineContext
|
|
import kotlinx.coroutines.runBlocking
|
|
import kotlinx.coroutines.withTimeout
|
|
import kotlinx.io.asSink
|
|
import kotlinx.io.buffered
|
|
import org.junit.jupiter.params.ParameterizedTest
|
|
import org.junit.jupiter.params.provider.MethodSource
|
|
import java.util.concurrent.TimeUnit
|
|
import kotlin.test.assertEquals
|
|
import kotlin.test.fail
|
|
|
|
/**
 * End-to-end tests that start the MCP server via [McpServerService] and talk to it through each
 * client transport supplied by [getTransports].
 */
@TestApplication
class TransportTest {

  companion object {
    val projectFixture = projectFixture(openAfterCreation = true)
    val project by projectFixture

    /** Parameter source for the transport-parameterized tests: one holder per transport kind. */
    @JvmStatic
    fun getTransports(): Array<TransportHolder> {
      return arrayOf(
        StdioTransportHolder(project),
        SseTransportHolder(project),
      )
    }
  }

  /** The server must report a non-empty tool list over each transport. */
  @ParameterizedTest
  @MethodSource("getTransports")
  fun list_tools(transport: TransportHolder) = transportTest(transport) { client ->
    val listTools = client.listTools() ?: fail("No tools returned")
    // fail() instead of Kotlin's assert(): assert() is a no-op unless the JVM runs with -ea.
    if (listTools.tools.isEmpty()) fail("No tools returned")
  }

  /**
   * Registers [test_tool] as an extra tool, calls it through the client and checks that the project
   * it observed in its coroutine context is the test [project].
   */
  @ParameterizedTest
  @MethodSource("getTransports")
  fun tool_call_has_project(transport: TransportHolder) = transportTest(transport) { client ->
    Disposer.newDisposable().use { disposable ->
      // The provider is registered against `disposable`, which `use` disposes at the end of this block.
      application.extensionArea.getExtensionPoint(McpToolsProvider.EP).registerExtension(object : McpToolsProvider {
        override fun getTools(): List<McpTool> {
          return listOf(this@TransportTest::test_tool.asTool())
        }
      }, disposable)

      client.callTool("test_tool", emptyMap())

      // test_tool completes the deferred; bound the wait so a call that never arrives fails after 2 s.
      val actual = withTimeout(2000) { projectFromTool.await() }
      assertEquals(project, actual)
    }
  }

  /** Completed by [test_tool] with the project found in its coroutine context (or null). */
  val projectFromTool = CompletableDeferred<Project?>()

  @com.intellij.mcpserver.annotations.McpTool()
  @McpDescription("Test description")
  suspend fun test_tool() {
    projectFromTool.complete(currentCoroutineContext().projectOrNull)
  }

  /**
   * Starts the MCP server, connects a [Client] over [transportHolder]'s transport and runs [action] with it.
   * Afterwards the transport is closed and the server stopped, whether or not [action] succeeded.
   */
  private fun transportTest(transportHolder: TransportHolder, action: suspend (Client) -> Unit) = runBlocking {
    try {
      McpServerService.getInstance().start()
      val client = Client(Implementation(name = "test client", version = "1.0"))
      client.connect(transportHolder.transport)
      action(client)
    }
    finally {
      // Nested so the server is stopped even if closing the transport throws (e.g. the stdio process
      // does not exit); otherwise the still-running server would leak into the next test.
      try {
        transportHolder.close()
      }
      finally {
        McpServerService.getInstance().stop()
      }
    }
  }
}
|
|
|
|
/** Supplies one client-side MCP transport to the parameterized tests and knows how to shut it down. */
abstract class TransportHolder {
  /** The transport the test client connects through. */
  abstract val transport: AbstractTransport

  /**
   * Closes [transport], blocking the calling thread until the close completes.
   *
   * Deliberately not [AutoCloseable]: JUnit would close such a parameter automatically,
   * whereas the tests close holders themselves inside the test method.
   */
  open fun close(): Unit = runBlocking {
    transport.close()
  }
}
|
|
|
|
/**
 * Launches the stdio MCP server as a child process (command line built by [createStdioMcpServerCommandLine])
 * and talks to it by reading the process's stdout and writing to its stdin.
 */
class StdioTransportHolder(project: Project) : TransportHolder() {
  // Held as an explicit Lazy so close() can tell whether the server process was ever started.
  private val processDelegate: Lazy<Process> = lazy {
    createStdioMcpServerCommandLine(McpServerService.getInstance().port, project.basePath).toProcessBuilder().start()
  }

  /** The stdio server process; started on first access. */
  val process: Process by processDelegate

  override val transport: AbstractTransport by lazy {
    StdioClientTransport(process.inputStream.asInput(), process.outputStream.asSink().buffered())
  }

  /**
   * Closes the transport, then waits for the server process to exit, force-killing it after 10 s.
   * Fails the test if the process is still alive after a further 10 s.
   */
  override fun close() {
    // The transport is built from the process, so no process means nothing to close. Calling
    // super.close() here would otherwise launch a server process only to shut it down again.
    if (!processDelegate.isInitialized()) return
    try {
      super.close()
    }
    finally {
      // Runs even if closing the transport failed, so the child process is never leaked.
      if (!process.waitFor(10, TimeUnit.SECONDS)) process.destroyForcibly()
      if (!process.waitFor(10, TimeUnit.SECONDS)) fail("Process is still alive")
    }
  }

  override fun toString(): String {
    return "Stdio"
  }
}
|
|
|
|
/**
 * Connects to the MCP server over Server-Sent Events at `http://localhost:<port>/`, where the port is
 * read from [McpServerService] when the transport is first accessed.
 */
class SseTransportHolder(project: Project) : TransportHolder() {
  // NOTE(review): `project` is unused here; presumably kept so both holders share a constructor shape.

  // Held as an explicit Lazy so close() releases the client only if it was actually created.
  private val httpClientDelegate = lazy {
    HttpClient {
      install(SSE)
    }
  }

  override val transport: AbstractTransport by lazy {
    SseClientTransport(httpClientDelegate.value, "http://localhost:${McpServerService.getInstance().port}/")
  }

  /** Closes the transport, then the [HttpClient] this holder created for it. */
  override fun close() {
    try {
      super.close()
    }
    finally {
      // The HttpClient is created here and only handed to the transport; close it explicitly so its
      // engine resources do not outlive the test. HttpClient.close() is safe to call more than once.
      if (httpClientDelegate.isInitialized()) httpClientDelegate.value.close()
    }
  }

  override fun toString(): String {
    return "SSE"
  }
}