package org.apache.sysds.runtime.instructions.spark.data;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.util.Arrays;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlockFactory;
import org.apache.sysds.runtime.util.FastBufferedDataInputStream;
import org.apache.sysds.runtime.util.FastBufferedDataOutputStream;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/spark/data/PartitionedBlock.class */
/**
 * A cache block split into a row-major grid of blocks whose edge length is {@link #getBlocksize()}.
 *
 * <p>The constructors report failures as "Failed partitioning of broadcast variable input.", so this
 * type is used to partition broadcast inputs. An instance may also hold only a contiguous range of
 * the grid (see {@link #createPartition(int, int)}); {@code _offset} records the absolute block index
 * of the first stored block.
 *
 * <p>Serialized layout (see {@code writeHeaderAndPayload}): int number of dims, one long per dim,
 * int block size, int offset, int number of stored blocks, one byte block type code, then each
 * block's own {@code write} output.
 *
 * @param <T> the concrete cache block type stored in the grid
 */
public class PartitionedBlock<T extends CacheBlock> implements Externalizable {
    private static final long serialVersionUID = 1298817743064415129L;

    /** Stored blocks in row-major order; element {@code k} is absolute block {@code k + _offset}. */
    protected CacheBlock[] _partBlocks = null;
    /** Dimensions of the partitioned input; {@code {-1, -1}} until assigned. */
    protected long[] _dims = {-1, -1};
    /** Block edge length, applied along every dimension; {@code -1} until assigned. */
    protected int _blen = -1;
    /** Absolute block index of {@code _partBlocks[0]}; non-zero only for sub-partitions. */
    protected int _offset = 0;

    /** Creates an unconfigured instance (a public no-arg constructor is required by {@link Externalizable}). */
    public PartitionedBlock() {
        // all fields keep their declared defaults
    }

    /**
     * Partitions a two-dimensional block into {@code blen x blen} tiles (tiles on the last row/column
     * of the grid may be smaller). Tiles are sliced in parallel and stored in row-major order.
     *
     * @param block the block to partition
     * @param blen  the block edge length
     * @throws RuntimeException if slicing any tile fails
     */
    public PartitionedBlock(T block, int blen) {
        final int rlen = block.getNumRows();
        final int clen = block.getNumColumns();
        _dims = new long[] {rlen, clen};
        _blen = blen;
        final int nrblks = getNumRowBlocks();
        final int ncblks = getNumColumnBlocks();
        final int code = CacheBlockFactory.getCode(block);
        try {
            _partBlocks = new CacheBlock[nrblks * ncblks];
            Arrays.parallelSetAll(_partBlocks, index -> {
                // decode the row-major array index into (row, column) block coordinates
                int rix = index / ncblks;
                int cix = index % ncblks;
                return block.slice(
                    rix * _blen, Math.min((rix + 1) * _blen, rlen) - 1,
                    cix * _blen, Math.min((cix + 1) * _blen, clen) - 1,
                    false, CacheBlockFactory.newInstance(code));
            });
        } catch (Exception e) {
            throw new RuntimeException("Failed partitioning of broadcast variable input.", e);
        }
    }

    /**
     * Partitions {@code block} using explicitly supplied dimensions.
     *
     * <p>The number of stored blocks is the product of the per-dimension block counts, but slicing
     * uses only the first two dimensions. NOTE(review): for more than two dimensions this layout
     * looks incomplete — confirm before relying on it.
     *
     * @param block the block to partition
     * @param dims  the dimensions of {@code block}; stored by reference, not copied
     * @param blen  the block edge length
     * @throws RuntimeException if slicing any tile fails
     */
    public PartitionedBlock(T block, long[] dims, int blen) {
        _dims = dims;
        _blen = blen;
        int nblks = 1;
        for (int d = 0; d < dims.length; d++) {
            nblks *= getNumDimBlocks(d);
        }
        final int code = CacheBlockFactory.getCode(block);
        try {
            _partBlocks = new CacheBlock[nblks];
            // number of blocks per grid row; guarded so an input without row blocks
            // (and therefore without any blocks) does not divide by zero
            final int nrblks = getNumRowBlocks();
            final int ncblks = (nrblks == 0) ? 0 : nblks / nrblks;
            Arrays.parallelSetAll(_partBlocks, index -> {
                int rix = index / ncblks;
                int cix = index % ncblks;
                return block.slice(
                    rix * _blen, Math.min((rix + 1) * _blen, (int) _dims[0]) - 1,
                    cix * _blen, Math.min((cix + 1) * _blen, (int) _dims[1]) - 1,
                    CacheBlockFactory.newInstance(code));
            });
        } catch (Exception e) {
            throw new RuntimeException("Failed partitioning of broadcast variable input.", e);
        }
    }

    /**
     * Creates an empty grid sized for a {@code rlen x clen} input. All slots start as {@code null};
     * populate them with {@link #setBlock(int, int, CacheBlock)}.
     *
     * @param rlen number of rows of the input
     * @param clen number of columns of the input
     * @param blen block edge length
     */
    public PartitionedBlock(int rlen, int clen, int blen) {
        _dims = new long[] {rlen, clen};
        _blen = blen;
        _partBlocks = new CacheBlock[getNumRowBlocks() * getNumColumnBlocks()];
    }

    /**
     * Returns a new instance holding {@code numBlks} consecutive blocks starting at absolute block
     * index {@code offset}. Dimensions are copied; the block objects themselves are shared.
     *
     * @param offset  absolute (row-major) index of the first block to include
     * @param numBlks number of blocks to include
     * @return the sub-partition
     */
    public PartitionedBlock<T> createPartition(int offset, int numBlks) {
        PartitionedBlock<T> ret = new PartitionedBlock<>();
        ret._dims = _dims.clone();
        ret._blen = _blen;
        ret._offset = offset;
        ret._partBlocks = new CacheBlock[numBlks];
        // translate the absolute offset into this instance's local array index
        System.arraycopy(_partBlocks, offset - _offset, ret._partBlocks, 0, numBlks);
        return ret;
    }

    /** @return the number of rows ({@code dims[0]}) */
    public long getNumRows() {
        return _dims[0];
    }

    /** @return the number of columns ({@code dims[1]}) */
    public long getNumCols() {
        return _dims[1];
    }

    /** @return the size of dimension {@code dim} */
    public long getDim(int dim) {
        return _dims[dim];
    }

    /** @return the block edge length */
    public long getBlocksize() {
        return _blen;
    }

    /** @return the number of blocks along the row dimension */
    public int getNumRowBlocks() {
        return getNumDimBlocks(0);
    }

    /** @return the number of blocks along the column dimension */
    public int getNumColumnBlocks() {
        return getNumDimBlocks(1);
    }

    /**
     * Returns the number of blocks needed to cover dimension {@code dim}, rounding up.
     *
     * @param dim the dimension index
     * @return {@code ceil(dims[dim] / blen)}
     */
    public int getNumDimBlocks(int dim) {
        // divide in floating point; long division would truncate before ceil is applied
        return (int) Math.ceil((double) _dims[dim] / _blen);
    }

    /**
     * Returns the block at the given 1-based grid coordinates.
     *
     * @param rowIndex 1-based block row index
     * @param colIndex 1-based block column index
     * @return the stored block (may be {@code null} if never set)
     * @throws DMLRuntimeException if the indexes are outside the grid
     */
    @SuppressWarnings("unchecked")
    public T getBlock(int rowIndex, int colIndex) {
        return (T) _partBlocks[toArrayIndex(rowIndex, colIndex)];
    }

    /**
     * Returns the block containing the given multi-dimensional index, as computed by
     * {@code UtilFunctions.computeBlockNumber}. No range check is performed here.
     *
     * @param ix the multi-dimensional index
     * @return the stored block
     */
    @SuppressWarnings("unchecked")
    public T getBlock(int[] ix) {
        return (T) _partBlocks[(int) (UtilFunctions.computeBlockNumber(ix, _dims, _blen) - _offset)];
    }

    /**
     * Stores a block at the given 1-based grid coordinates.
     *
     * @param rowIndex 1-based block row index
     * @param colIndex 1-based block column index
     * @param block    the block to store
     * @throws DMLRuntimeException if the indexes are outside the grid
     */
    public void setBlock(int rowIndex, int colIndex, T block) {
        _partBlocks[toArrayIndex(rowIndex, colIndex)] = block;
    }

    /**
     * Validates 1-based grid coordinates and converts them to an index into {@code _partBlocks}.
     */
    private int toArrayIndex(int rowIndex, int colIndex) {
        int nrblks = getNumRowBlocks();
        int ncblks = getNumColumnBlocks();
        if (rowIndex <= 0 || rowIndex > nrblks || colIndex <= 0 || colIndex > ncblks) {
            throw new DMLRuntimeException("Block indexes [" + rowIndex + "," + colIndex + "] out of range [" + nrblks + "," + ncblks + "]");
        }
        return (rowIndex - 1) * ncblks + (colIndex - 1) - _offset;
    }

    /**
     * Estimates the in-memory size: a fixed overhead of 24 + 32 bytes plus the in-memory size of
     * every stored block. Unset ({@code null}) slots contribute nothing.
     *
     * @return estimated size in bytes
     */
    public long getInMemorySize() {
        long size = 24 + 32;
        if (_partBlocks != null) {
            for (CacheBlock block : _partBlocks) {
                if (block != null) {
                    size += block.getInMemorySize();
                }
            }
        }
        return size;
    }

    /**
     * Estimates the serialized size: a fixed overhead of 24 bytes plus the exact serialized size of
     * every stored block. Unset ({@code null}) slots contribute nothing.
     *
     * @return estimated size in bytes
     */
    public long getExactSerializedSize() {
        long size = 24;
        if (_partBlocks != null) {
            for (CacheBlock block : _partBlocks) {
                if (block != null) {
                    size += block.getExactSerializedSize();
                }
            }
        }
        return size;
    }

    /**
     * Currently a no-op.
     * NOTE(review): despite its name this does not release {@code _partBlocks}; confirm whether
     * that is intended before relying on it to free memory.
     */
    public void clearBlocks() {
        // empty body kept as-is
    }

    /**
     * Reads the header directly from {@code is}. For block type code 0 on an
     * {@link ObjectInputStream}, the payload is then read through a {@code FastBufferedDataInputStream}.
     */
    @Override // java.io.Externalizable
    public void readExternal(ObjectInput is) throws IOException {
        DataInput dis = is;
        int code = readHeader(dis);
        if (is instanceof ObjectInputStream && code == 0) {
            dis = new FastBufferedDataInputStream((ObjectInputStream) is);
        }
        readPayload(dis, code);
    }

    /**
     * Writes header and payload. On an {@link ObjectOutputStream} the output goes through a
     * {@code FastBufferedDataOutputStream}, which is flushed but not closed so the underlying
     * stream stays usable.
     */
    @Override // java.io.Externalizable
    public void writeExternal(ObjectOutput os) throws IOException {
        if (os instanceof ObjectOutputStream) {
            // NOTE(review): the writer buffers for every block code, while readExternal only switches
            // to the fast stream for code 0 — presumably both encodings agree; confirm.
            FastBufferedDataOutputStream fos = new FastBufferedDataOutputStream((ObjectOutputStream) os);
            writeHeaderAndPayload(fos);
            fos.flush();
        } else {
            writeHeaderAndPayload(os);
        }
    }

    /** Writes the header fields followed by every stored block. */
    private void writeHeaderAndPayload(DataOutput dos) throws IOException {
        dos.writeInt(_dims.length);
        for (long dim : _dims) {
            dos.writeLong(dim);
        }
        dos.writeInt(_blen);
        dos.writeInt(_offset);
        dos.writeInt(_partBlocks.length);
        // an empty grid has no block to derive a type code from; readPayload does not
        // instantiate anything when there are no blocks, so any placeholder works
        dos.writeByte(_partBlocks.length > 0 ? CacheBlockFactory.getCode(_partBlocks[0]) : 0);
        for (CacheBlock block : _partBlocks) {
            block.write(dos);
        }
    }

    /**
     * Reads the header fields, allocates an empty block array of the stored length, and returns the
     * block type code.
     */
    private int readHeader(DataInput dis) throws IOException {
        int ndims = dis.readInt();
        _dims = new long[ndims];
        for (int d = 0; d < ndims; d++) {
            _dims[d] = dis.readLong();
        }
        _blen = dis.readInt();
        _offset = dis.readInt();
        int nblks = dis.readInt();
        byte code = dis.readByte();
        _partBlocks = new CacheBlock[nblks];
        return code;
    }

    /** Instantiates each block for the given type code and reads its contents from {@code dis}. */
    private void readPayload(DataInput dis, int code) throws IOException {
        for (int k = 0; k < _partBlocks.length; k++) {
            _partBlocks[k] = CacheBlockFactory.newInstance(code);
            _partBlocks[k].readFields(dis);
        }
    }
}
