diff --git a/platform/platform-impl/src/com/intellij/execution/ijent/IjentChildProcessAdapterDelegate.kt b/platform/platform-impl/src/com/intellij/execution/ijent/IjentChildProcessAdapterDelegate.kt index df8bc454ee9a..952eb6b85e10 100644 --- a/platform/platform-impl/src/com/intellij/execution/ijent/IjentChildProcessAdapterDelegate.kt +++ b/platform/platform-impl/src/com/intellij/execution/ijent/IjentChildProcessAdapterDelegate.kt @@ -38,12 +38,12 @@ internal class IjentChildProcessAdapterDelegate( ijentChildProcess.stderr.consumeEach { chunk -> merged.send(chunk) } } - inputStream = ChannelInputStream(coroutineScope, merged) + inputStream = ChannelInputStream.forArrays(coroutineScope, merged) errorStream = ByteArrayInputStream(byteArrayOf()) } else { - inputStream = ChannelInputStream(coroutineScope, ijentChildProcess.stdout) - errorStream = ChannelInputStream(coroutineScope, ijentChildProcess.stderr) + inputStream = ChannelInputStream.forArrays(coroutineScope, ijentChildProcess.stdout) + errorStream = ChannelInputStream.forArrays(coroutineScope, ijentChildProcess.stderr) } } diff --git a/platform/platform-impl/src/com/intellij/execution/ijent/IjentStdinOutputStream.kt b/platform/platform-impl/src/com/intellij/execution/ijent/IjentStdinOutputStream.kt index 4c5ccb686b22..fcfbebdf1182 100644 --- a/platform/platform-impl/src/com/intellij/execution/ijent/IjentStdinOutputStream.kt +++ b/platform/platform-impl/src/com/intellij/execution/ijent/IjentStdinOutputStream.kt @@ -1,4 +1,4 @@ -// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. package com.intellij.execution.ijent import com.intellij.platform.ijent.IjentChildProcess @@ -12,7 +12,7 @@ internal class IjentStdinOutputStream( private val coroutineContext: CoroutineContext, private val ijentChildProcess: IjentChildProcess, ) : OutputStream() { - private val delegate = ChannelOutputStream(ijentChildProcess.stdin) + private val delegate = ChannelOutputStream.forArrays(ijentChildProcess.stdin) override fun write(b: Int) { delegate.write(b) diff --git a/platform/util/coroutines/api-dump.txt b/platform/util/coroutines/api-dump.txt index 0f71283f34da..e8da94bdfb61 100644 --- a/platform/util/coroutines/api-dump.txt +++ b/platform/util/coroutines/api-dump.txt @@ -16,17 +16,23 @@ f:com.intellij.platform.util.coroutines.CoroutineScopeKt - a:out(java.lang.Object,kotlin.coroutines.Continuation):java.lang.Object *f:com.intellij.platform.util.coroutines.channel.ChannelInputStream - java.io.InputStream -- (kotlinx.coroutines.CoroutineScope,kotlinx.coroutines.channels.ReceiveChannel):V +- *sf:Companion:com.intellij.platform.util.coroutines.channel.ChannelInputStream$Companion - available():I - close():V - read():I - read(B[],I,I):I -*f:com.intellij.platform.util.coroutines.channel.ChannelOutputStream +*f:com.intellij.platform.util.coroutines.channel.ChannelInputStream$Companion +- f:forArrays(kotlinx.coroutines.CoroutineScope,kotlinx.coroutines.channels.ReceiveChannel):com.intellij.platform.util.coroutines.channel.ChannelInputStream +- f:forByteBuffers(kotlinx.coroutines.CoroutineScope,kotlinx.coroutines.channels.ReceiveChannel):com.intellij.platform.util.coroutines.channel.ChannelInputStream +*a:com.intellij.platform.util.coroutines.channel.ChannelOutputStream - java.io.OutputStream -- (kotlinx.coroutines.channels.SendChannel):V +- *sf:Companion:com.intellij.platform.util.coroutines.channel.ChannelOutputStream$Companion - close():V - write(I):V - write(B[],I,I):V +*f:com.intellij.platform.util.coroutines.channel.ChannelOutputStream$Companion +- f:forArrays(kotlinx.coroutines.channels.SendChannel):com.intellij.platform.util.coroutines.channel.ChannelOutputStream +- f:forByteBuffers(kotlinx.coroutines.channels.SendChannel):com.intellij.platform.util.coroutines.channel.ChannelOutputStream f:com.intellij.platform.util.coroutines.flow.FlowKt - sf:collectLatestUndispatched(kotlinx.coroutines.flow.SharedFlow,kotlin.jvm.functions.Function2,kotlin.coroutines.Continuation):java.lang.Object - sf:debounceBatch-HG0u8IE(kotlinx.coroutines.flow.Flow,J):kotlinx.coroutines.flow.Flow diff --git a/platform/util/coroutines/src/channel/ChannelUtil.kt b/platform/util/coroutines/src/channel/ChannelUtil.kt index bd0532b4e7ff..e470473ab4fd 100644 --- a/platform/util/coroutines/src/channel/ChannelUtil.kt +++ b/platform/util/coroutines/src/channel/ChannelUtil.kt @@ -1,4 +1,4 @@ -// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. package com.intellij.platform.util.coroutines.channel import kotlinx.coroutines.CoroutineScope @@ -8,10 +8,10 @@ import kotlinx.coroutines.channels.consumeEach import kotlinx.coroutines.channels.trySendBlocking import kotlinx.coroutines.launch import org.jetbrains.annotations.ApiStatus -import java.io.ByteArrayInputStream import java.io.IOException import java.io.InputStream import java.io.OutputStream +import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingDeque import kotlin.coroutines.cancellation.CancellationException @@ -19,24 +19,32 @@ import kotlin.coroutines.cancellation.CancellationException Solution from fleet.api.exec.ExecApiProcess.kt. Maybe it should be merged somehow */ @ApiStatus.Experimental -class ChannelInputStream( - parentCoroutineScope: CoroutineScope, - private val channel: ReceiveChannel, -) : InputStream() { - private sealed class Content { - class Data(val stream: ByteArrayInputStream) : Content() - class Error(val cause: Throwable) : Content() - object End : Content() - } +class ChannelInputStream private constructor( + private val channel: ReceiveChannel<*>, +): InputStream() { + companion object { + fun forArrays(parentCoroutineScope: CoroutineScope, channel: ReceiveChannel): ChannelInputStream { + val result = ChannelInputStream(channel) + parentCoroutineScope.launch { + consumeChannel(channel, result.myBuffer) { ByteBuffer.wrap(it) } + } + return result + } - private val myBuffer = LinkedBlockingDeque() + fun forByteBuffers(parentCoroutineScope: CoroutineScope, channel: ReceiveChannel): ChannelInputStream { + val result = ChannelInputStream(channel) + parentCoroutineScope.launch { + consumeChannel(channel, result.myBuffer) { it } + } + return result + } - init { - parentCoroutineScope.launch { + private suspend inline fun consumeChannel(channel: ReceiveChannel, myBuffer: LinkedBlockingDeque, crossinline transform: (T) -> ByteBuffer) { try { - channel.consumeEach { bytes -> - if (bytes.isNotEmpty()) { - myBuffer.offerLast(Content.Data(ByteArrayInputStream(bytes))) + channel.consumeEach { obj -> + val bytes = transform(obj) + if (bytes.hasRemaining()) { + myBuffer.offerLast(Content.Data(bytes)) } } myBuffer.offerLast(Content.End) @@ -52,18 +60,28 @@ class ChannelInputStream( } } + private val myBuffer = LinkedBlockingDeque() + + private sealed class Content { + class Data(val buffer: ByteBuffer) : Content() + class Error(val cause: Throwable) : Content() + object End : Content() + } + override fun close() { channel.cancel(CancellationException("ChannelInputStream was closed")) } override fun read(): Int { val available = getAvailableBuffer() ?: return -1 - return available.read() + return available.get().toInt() } override fun read(b: ByteArray, off: Int, len: Int): Int { val available = getAvailableBuffer() ?: return -1 - return available.read(b, off, minOf(len, available.available())) + val resultSize = minOf(len, available.remaining()) + available.get(b, off, resultSize) + return resultSize } override tailrec fun available(): Int = @@ -71,7 +89,7 @@ class ChannelInputStream( null -> 0 is Content.Data -> { - val availableInCurrent = current.stream.available() + val availableInCurrent = current.buffer.remaining() if (availableInCurrent > 0) { myBuffer.putFirst(current) availableInCurrent @@ -87,7 +105,7 @@ class ChannelInputStream( } } - private fun getAvailableBuffer(): ByteArrayInputStream? { + private fun getAvailableBuffer(): ByteBuffer? { while (true) { val current = try { @@ -107,9 +125,9 @@ class ChannelInputStream( throw IOException(current.cause) } is Content.Data -> { - if (current.stream.available() > 0) { + if (current.buffer.hasRemaining()) { myBuffer.putFirst(current) - return current.stream + return current.buffer } } } @@ -120,9 +138,33 @@ class ChannelInputStream( private const val MAX_ARRAY_SIZE_SENT: Int = 1024 @ApiStatus.Experimental -class ChannelOutputStream(private val channel: SendChannel) : OutputStream() { +sealed class ChannelOutputStream(private val channel: SendChannel) : OutputStream() { + companion object { + fun forArrays(channel: SendChannel): ChannelOutputStream = + ForByteArray(channel) + + fun forByteBuffers(channel: SendChannel): ChannelOutputStream = + ForByteBuffer(channel) + } + + private class ForByteArray(channel: SendChannel) : ChannelOutputStream(channel) { + override fun fromByte(b: Byte): ByteArray = byteArrayOf(b) + override fun range(b: ByteArray, offset: Int, nextOffset: Int): ByteArray = b.copyOfRange(offset, nextOffset) + } + + private class ForByteBuffer(channel: SendChannel) : ChannelOutputStream(channel) { + override fun fromByte(b: Byte): ByteBuffer = ByteBuffer.wrap(byteArrayOf(b)) + override fun range(b: ByteArray, offset: Int, nextOffset: Int): ByteBuffer = ByteBuffer.wrap(b, offset, nextOffset) + } + + @ApiStatus.Internal + protected abstract fun fromByte(b: Byte): T + + @ApiStatus.Internal + protected abstract fun range(b: ByteArray, offset: Int, nextOffset: Int): T + override fun write(b: Int) { - val result = channel.trySendBlocking(byteArrayOf(b.toByte())) + val result = channel.trySendBlocking(fromByte(b.toByte())) when { result.isClosed -> throw IOException("Unable to write, channel is closed") result.isFailure -> throw IOException("Unable to write to channel", result.exceptionOrNull()) @@ -134,7 +176,7 @@ class ChannelOutputStream(private val channel: SendChannel) : OutputS while (offset < len) { val nextOffset = minOf(offset + MAX_ARRAY_SIZE_SENT, len) - val result = channel.trySendBlocking(b.copyOfRange(offset, nextOffset)) + val result = channel.trySendBlocking(range(b, offset, nextOffset)) when { result.isClosed -> throw IOException("Unable to write, channel is closed")