package org.apache.flink.runtime.io.network.partition;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Queue;
import java.util.Random;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.PendingCheckpointTest;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

/**
 * Tests for {@link DataBuffer} implementations. Every test runs twice: once against a {@link
 * HashBasedDataBuffer} and once against a {@link SortBasedDataBuffer}.
 */
@RunWith(Parameterized.class)
public class DataBufferTest {

    /** {@code true} to test {@link HashBasedDataBuffer}, {@code false} for {@link SortBasedDataBuffer}. */
    private final boolean useHashBuffer;

    /** A record appended to a data buffer, remembered together with its data type for verification. */
    public static class DataAndType {
        private final ByteBuffer data;
        private final Buffer.DataType dataType;

        public DataAndType(ByteBuffer data, Buffer.DataType dataType) {
            this.data = data;
            this.dataType = dataType;
        }
    }

    @Parameterized.Parameters(name = "UseHashBuffer = {0}")
    public static Object[] parameters() {
        return new Object[] {true, false};
    }

    public DataBufferTest(boolean useHashBuffer) {
        this.useHashBuffer = useHashBuffer;
    }

    @Test
    public void testWriteAndReadDataBuffer() throws Exception {
        int numSubpartitions = 10;
        int bufferSize = 1024;
        int bufferPoolSize = 512;
        Random random = new Random(1111L);

        // Per-subpartition record of what was written to and what was read from the data buffer.
        @SuppressWarnings("unchecked")
        Queue<DataAndType>[] dataWritten = new Queue[numSubpartitions];
        @SuppressWarnings("unchecked")
        Queue<Buffer>[] buffersRead = new Queue[numSubpartitions];
        for (int i = 0; i < numSubpartitions; i++) {
            dataWritten[i] = new ArrayDeque<>();
            buffersRead[i] = new ArrayDeque<>();
        }

        // Java zero-initializes int arrays, so no explicit fill is needed.
        int[] numBytesWritten = new int[numSubpartitions];
        int[] numBytesRead = new int[numSubpartitions];

        int totalBytesWritten = 0;
        DataBuffer dataBuffer =
                createDataBuffer(
                        bufferPoolSize,
                        bufferSize,
                        numSubpartitions,
                        getRandomSubpartitionOrder(numSubpartitions));

        // Keep appending random records until append() has returned true five times.
        int numDataBuffers = 5;
        while (numDataBuffers > 0) {
            // A record may be up to 4x the buffer size, so it can span several segments.
            byte[] bytes = new byte[random.nextInt(bufferSize * 4 - 1) + 1];
            random.nextBytes(bytes);
            ByteBuffer record = ByteBuffer.wrap(bytes);

            int subpartition = random.nextInt(numSubpartitions);
            Buffer.DataType dataType =
                    random.nextBoolean()
                            ? Buffer.DataType.DATA_BUFFER
                            : Buffer.DataType.EVENT_BUFFER;
            boolean isFull = dataBuffer.append(record, subpartition, dataType);

            // Assumes append() advances the record's position past the bytes it consumed;
            // flip() then exposes exactly those bytes for later comparison.
            record.flip();
            if (record.hasRemaining()) {
                dataWritten[subpartition].add(new DataAndType(record, dataType));
                numBytesWritten[subpartition] += record.remaining();
                totalBytesWritten += record.remaining();
            }

            while (isFull && dataBuffer.hasRemaining()) {
                BufferWithChannel buffer = copyIntoSegment(bufferSize, dataBuffer);
                if (buffer == null) {
                    break;
                }
                addBufferRead(buffer, buffersRead, numBytesRead);
            }

            if (isFull) {
                numDataBuffers--;
                dataBuffer.reset();
            }
        }

        // Leftover data is only expected for the hash-based buffer; drain it after finishing.
        if (dataBuffer.hasRemaining()) {
            Assert.assertTrue(dataBuffer instanceof HashBasedDataBuffer);
            dataBuffer.reset();
            dataBuffer.finish();
            while (dataBuffer.hasRemaining()) {
                addBufferRead(copyIntoSegment(bufferSize, dataBuffer), buffersRead, numBytesRead);
            }
        }

        Assert.assertEquals(totalBytesWritten, dataBuffer.numTotalBytes());
        checkWriteReadResult(
                numSubpartitions, numBytesWritten, numBytesRead, dataWritten, buffersRead);
    }

    /**
     * Reads the next buffer from {@code dataBuffer}.
     *
     * <p>For the sort-based buffer, the read goes directly into a fresh segment of {@code
     * bufferSize} bytes. For the hash-based buffer, a {@code null} segment is passed. Any data
     * buffer returned is then copied into a fresh unpooled segment and the original is recycled.
     * Null results and non-data (event) buffers are returned as-is.
     *
     * @param bufferSize size of the segment to read or copy into
     * @param dataBuffer the buffer to read from
     * @return the next buffer with its channel, or {@code null} if {@code getNextBuffer} returned
     *     {@code null}
     */
    private BufferWithChannel copyIntoSegment(int bufferSize, DataBuffer dataBuffer) {
        if (!useHashBuffer) {
            return dataBuffer.getNextBuffer(
                    MemorySegmentFactory.allocateUnpooledSegment(bufferSize));
        }

        BufferWithChannel buffer = dataBuffer.getNextBuffer((MemorySegment) null);
        if (buffer == null || !buffer.getBuffer().isBuffer()) {
            return buffer;
        }

        MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
        int numBytes = buffer.getBuffer().readableBytes();
        segment.put(0, buffer.getBuffer().getNioBufferReadable(), numBytes);
        buffer.getBuffer().recycleBuffer();
        return new BufferWithChannel(
                new NetworkBuffer(
                        segment, MemorySegment::free, Buffer.DataType.DATA_BUFFER, numBytes),
                buffer.getChannelIndex());
    }

    /** Records {@code buffer} under its channel and adds its readable bytes to that channel's count. */
    private void addBufferRead(
            BufferWithChannel buffer, Queue<Buffer>[] buffersRead, int[] numBytesRead) {
        int channel = buffer.getChannelIndex();
        buffersRead[channel].add(buffer.getBuffer());
        numBytesRead[channel] += buffer.getBuffer().readableBytes();
    }

    /**
     * Checks, for each subpartition, that the bytes read equal the bytes written.
     *
     * <p>The checks cover byte count, concatenated content, and event sequence. Events are the
     * written records with an event data type and the read buffers that are not data buffers.
     * They must match in count, data type and content.
     *
     * <p>Each written record's {@link ByteBuffer} is rewound after being copied, so callers can
     * inspect it again afterwards.
     */
    public static void checkWriteReadResult(
            int numSubpartitions,
            int[] numBytesWritten,
            int[] numBytesRead,
            Queue<DataAndType>[] dataWritten,
            Queue<Buffer>[] buffersRead) {
        for (int subpartition = 0; subpartition < numSubpartitions; subpartition++) {
            Assert.assertEquals(numBytesWritten[subpartition], numBytesRead[subpartition]);

            List<DataAndType> eventsWritten = new ArrayList<>();
            ByteBuffer expected = ByteBuffer.allocate(numBytesWritten[subpartition]);
            for (DataAndType dataAndType : dataWritten[subpartition]) {
                expected.put(dataAndType.data);
                // put() consumed the record; rewind so it can be compared again below.
                dataAndType.data.rewind();
                if (dataAndType.dataType.isEvent()) {
                    eventsWritten.add(dataAndType);
                }
            }

            List<Buffer> eventsRead = new ArrayList<>();
            ByteBuffer actual = ByteBuffer.allocate(numBytesRead[subpartition]);
            for (Buffer buffer : buffersRead[subpartition]) {
                actual.put(buffer.getNioBufferReadable());
                if (!buffer.isBuffer()) {
                    eventsRead.add(buffer);
                }
            }

            expected.flip();
            actual.flip();
            Assert.assertEquals(expected, actual);

            Assert.assertEquals(eventsWritten.size(), eventsRead.size());
            for (int i = 0; i < eventsWritten.size(); i++) {
                DataAndType eventWritten = eventsWritten.get(i);
                Buffer eventRead = eventsRead.get(i);
                Assert.assertEquals(eventWritten.dataType, eventRead.getDataType());
                Assert.assertEquals(eventWritten.data, eventRead.getNioBufferReadable());
            }
        }
    }

    @Test
    public void testWriteReadWithEmptyChannel() throws Exception {
        int bufferPoolSize = 10;
        int bufferSize = 1024;
        int numSubpartitions = 5;

        // Subpartitions 1 and 3 receive nothing. Subpartition 2's record (1536 bytes) is expected
        // to be read back as two buffers.
        ByteBuffer[] subpartitionRecords = {
            ByteBuffer.allocate(128),
            null,
            ByteBuffer.allocate(1536),
            null,
            ByteBuffer.allocate(1024)
        };

        DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, numSubpartitions);
        for (int subpartition = 0; subpartition < numSubpartitions; subpartition++) {
            ByteBuffer record = subpartitionRecords[subpartition];
            if (record != null) {
                dataBuffer.append(record, subpartition, Buffer.DataType.DATA_BUFFER);
                record.rewind();
            }
        }
        dataBuffer.finish();

        checkReadResult(dataBuffer, subpartitionRecords[0], 0, bufferSize);

        ByteBuffer firstPart = subpartitionRecords[2].duplicate();
        firstPart.limit(bufferSize);
        checkReadResult(dataBuffer, firstPart.slice(), 2, bufferSize);

        ByteBuffer secondPart = subpartitionRecords[2].duplicate();
        secondPart.position(bufferSize);
        checkReadResult(dataBuffer, secondPart.slice(), 2, bufferSize);

        checkReadResult(dataBuffer, subpartitionRecords[4], 4, bufferSize);
    }

    /** Reads one buffer into a fresh segment and checks its channel index and readable content. */
    private void checkReadResult(
            DataBuffer dataBuffer, ByteBuffer expectedData, int expectedChannel, int segmentSize) {
        BufferWithChannel buffer =
                dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(segmentSize));
        Assert.assertEquals(expectedChannel, buffer.getChannelIndex());
        Assert.assertEquals(expectedData, buffer.getBuffer().getNioBufferReadable());
    }

    @Test(expected = IllegalArgumentException.class)
    public void testWriteEmptyData() throws Exception {
        DataBuffer dataBuffer = createDataBuffer(1, 1024, 1);
        ByteBuffer record = ByteBuffer.allocate(1);
        // Move the position to the limit so the record has no remaining bytes.
        record.position(1);
        dataBuffer.append(record, 0, Buffer.DataType.DATA_BUFFER);
    }

    @Test(expected = IllegalStateException.class)
    public void testWriteFinishedDataBuffer() throws Exception {
        DataBuffer dataBuffer = createDataBuffer(1, 1024, 1);
        dataBuffer.finish();
        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
    }

    @Test(expected = IllegalStateException.class)
    public void testWriteReleasedDataBuffer() throws Exception {
        DataBuffer dataBuffer = createDataBuffer(1, 1024, 1);
        dataBuffer.release();
        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
    }

    @Test
    public void testWriteMoreDataThanCapacity() throws Exception {
        int bufferPoolSize = 10;
        int bufferSize = 1024;

        DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, 1);
        for (int i = 1; i < bufferPoolSize; i++) {
            appendAndCheckResult(dataBuffer, bufferSize, false, (long) bufferSize * i, i, true);
        }

        // A record of bufferSize + 1 bytes is expected to make append() return true. The record
        // count stays at bufferPoolSize - 1. The hash-based buffer is expected to report
        // bufferSize * bufferPoolSize bytes, the sort-based one only the earlier records' bytes.
        int numRecords = bufferPoolSize - 1;
        long numBytes = (long) bufferSize * numRecords;
        appendAndCheckResult(
                dataBuffer,
                bufferSize + 1,
                true,
                useHashBuffer ? (long) bufferSize * bufferPoolSize : numBytes,
                numRecords,
                true);
    }

    @Test
    public void testWriteLargeRecord() throws Exception {
        int bufferPoolSize = 10;
        int bufferSize = 1024;
        DataBuffer dataBuffer = createDataBuffer(bufferPoolSize, bufferSize, 1);
        appendAndCheckResult(
                dataBuffer,
                bufferPoolSize * bufferSize + 1,
                true,
                useHashBuffer ? (long) bufferPoolSize * bufferSize : 0L,
                0L,
                useHashBuffer);
    }

    /**
     * Appends a zero-filled data record of {@code recordSize} bytes to subpartition 0 and checks
     * the resulting state of {@code dataBuffer}.
     *
     * @param isFull expected return value of {@link DataBuffer#append}
     * @param numBytes expected {@code numTotalBytes()} afterwards
     * @param numRecords expected {@code numTotalRecords()} afterwards
     * @param hasRemaining expected {@code hasRemaining()} afterwards
     */
    private void appendAndCheckResult(
            DataBuffer dataBuffer,
            int recordSize,
            boolean isFull,
            long numBytes,
            long numRecords,
            boolean hasRemaining)
            throws IOException {
        Assert.assertEquals(
                isFull,
                dataBuffer.append(
                        ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER));
        Assert.assertEquals(numBytes, dataBuffer.numTotalBytes());
        Assert.assertEquals(numRecords, dataBuffer.numTotalRecords());
        Assert.assertEquals(hasRemaining, dataBuffer.hasRemaining());
    }

    @Test(expected = IllegalStateException.class)
    public void testReadUnfinishedDataBuffer() throws Exception {
        DataBuffer dataBuffer = createDataBuffer(1, 1024, 1);
        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
        Assert.assertTrue(dataBuffer.hasRemaining());
        dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(1024));
    }

    @Test(expected = IllegalStateException.class)
    public void testReadReleasedDataBuffer() throws Exception {
        DataBuffer dataBuffer = createDataBuffer(1, 1024, 1);
        dataBuffer.append(ByteBuffer.allocate(1), 0, Buffer.DataType.DATA_BUFFER);
        dataBuffer.finish();
        Assert.assertTrue(dataBuffer.hasRemaining());

        dataBuffer.release();
        Assert.assertTrue(dataBuffer.hasRemaining());
        dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(1024));
    }

    @Test
    public void testReadEmptyDataBuffer() throws Exception {
        DataBuffer dataBuffer = createDataBuffer(1, 1024, 1);
        dataBuffer.finish();
        Assert.assertFalse(dataBuffer.hasRemaining());
        Assert.assertNull(
                dataBuffer.getNextBuffer(MemorySegmentFactory.allocateUnpooledSegment(1024)));
    }

    @Test
    public void testReleaseDataBuffer() throws Exception {
        int bufferPoolSize = 10;
        int bufferSize = 1024;
        int recordSize = (bufferPoolSize - 1) * bufferSize;

        // Keep a reference to the pool so its usage can be checked before and after release.
        BufferPool bufferPool =
                new NetworkBufferPool(bufferPoolSize, bufferSize)
                        .createBufferPool(bufferPoolSize, bufferPoolSize);
        SortBasedDataBuffer dataBuffer =
                new SortBasedDataBuffer(bufferPool, 1, bufferSize, bufferPoolSize, (int[]) null);
        dataBuffer.append(ByteBuffer.allocate(recordSize), 0, Buffer.DataType.DATA_BUFFER);

        Assert.assertEquals(bufferPoolSize, bufferPool.bestEffortGetNumOfUsedBuffers());
        Assert.assertTrue(dataBuffer.hasRemaining());
        Assert.assertEquals(1L, dataBuffer.numTotalRecords());
        Assert.assertEquals(recordSize, dataBuffer.numTotalBytes());

        // After release the pool reports no used buffers; the buffer's statistics are unchanged.
        dataBuffer.release();
        Assert.assertEquals(0L, bufferPool.bestEffortGetNumOfUsedBuffers());
        Assert.assertTrue(dataBuffer.hasRemaining());
        Assert.assertEquals(1L, dataBuffer.numTotalRecords());
        Assert.assertEquals(recordSize, dataBuffer.numTotalBytes());
    }

    /** Creates the data buffer under test with the default subpartition read order. */
    private DataBuffer createDataBuffer(int bufferPoolSize, int bufferSize, int numSubpartitions)
            throws IOException {
        return createDataBuffer(bufferPoolSize, bufferSize, numSubpartitions, null);
    }

    /**
     * Creates the data buffer under test, backed by a new pool of {@code bufferPoolSize} buffers
     * of {@code bufferSize} bytes.
     *
     * @param customReadOrder subpartition read order passed to the buffer; may be {@code null}
     */
    private DataBuffer createDataBuffer(
            int bufferPoolSize, int bufferSize, int numSubpartitions, int[] customReadOrder)
            throws IOException {
        BufferPool bufferPool =
                new NetworkBufferPool(bufferPoolSize, bufferSize)
                        .createBufferPool(bufferPoolSize, bufferPoolSize);
        if (useHashBuffer) {
            return new HashBasedDataBuffer(
                    bufferPool, numSubpartitions, bufferPoolSize, customReadOrder);
        }
        return new SortBasedDataBuffer(
                bufferPool, numSubpartitions, bufferSize, bufferPoolSize, customReadOrder);
    }

    /**
     * Returns the subpartition indices {@code 0..numSubpartitions-1} rotated by an offset drawn
     * from a fixed-seed {@link Random}. The result is deterministic for a given {@code
     * numSubpartitions}.
     */
    public static int[] getRandomSubpartitionOrder(int numSubpartitions) {
        int[] order = new int[numSubpartitions];
        int offset = new Random(1111L).nextInt(numSubpartitions);
        for (int i = 0; i < numSubpartitions; i++) {
            order[i] = (i + offset) % numSubpartitions;
        }
        return order;
    }
}
