package org.deeplearning4j.plot;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/plot/PlotFilters.class */
/**
 * Arranges a set of filters (one flattened image per row of the input) into a
 * single tiled raster image.
 *
 * <p>Usage: construct, call {@link #plot()}, then read the result with
 * {@link #getPlot()}.
 */
public class PlotFilters {
    /** Result of the last {@link #plot()} call; {@code null} until then (or until set explicitly). */
    private INDArray plot;
    /** Filters to render. */
    private INDArray input;
    /** Number of tiles as {rows, columns}. */
    private final int[] tileShape;
    /** Gap between neighbouring tiles as {vertical, horizontal}. */
    private final int[] tileSpacing;
    /** Shape of a single tile as {height, width}. */
    private final int[] imageShape;
    /** Whether every tile is rescaled with {@link #scale(INDArray)} before drawing. No setter in this class. */
    private boolean scaleRowsToInterval = true;
    /** Whether tile values are multiplied by 255 before drawing. No setter in this class. */
    private boolean outputPixels = true;

    /**
     * @param input       filters to render
     * @param tileShape   number of tiles as {rows, columns}
     * @param tileSpacing gap between tiles as {vertical, horizontal}; {@code null} means no gap ({0, 0})
     * @param imageShape  shape of each individual tile as {height, width}
     */
    public PlotFilters(INDArray input, int[] tileShape, int[] tileSpacing, int[] imageShape) {
        this.input = input;
        this.tileShape = tileShape;
        // Fall back to "no spacing" instead of storing null, which would only
        // surface later as a NullPointerException inside plot().
        this.tileSpacing = tileSpacing != null ? tileSpacing : new int[]{0, 0};
        this.imageShape = imageShape;
    }

    /** @return the filters that will be rendered */
    public INDArray getInput() {
        return this.input;
    }

    /** @param input the filters to render on the next {@link #plot()} call */
    public void setInput(INDArray input) {
        this.input = input;
    }

    /**
     * Rescales an array to the unit interval: the minimum is subtracted and the
     * result is divided by its (shifted) maximum plus {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param toScale the array to rescale
     * @return the rescaled array, produced from the result of {@code sub(...)}
     */
    public INDArray scale(INDArray toScale) {
        // Integer.MAX_VALUE is passed as the reduction dimension, exactly as before;
        // the reduction result is read as a single value through getDouble(0).
        INDArray shifted = toScale.sub(toScale.min(new int[]{Integer.MAX_VALUE}));
        // The divisor must be the maximum AFTER shifting. Using the maximum of the
        // unshifted array only maps to [0, 1] when the minimum happens to be 0, and
        // yields a negative or near-zero divisor for all-negative input.
        double shiftedMax = shifted.max(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return shifted.muli(Double.valueOf(1.0d / (Nd4j.EPS_THRESHOLD + shiftedMax)));
    }

    /**
     * Builds the tiled image and stores it; retrieve it with {@link #getPlot()}.
     *
     * <p>A rank-2 input is rendered as a single section. Any other rank is treated
     * as four slices, each rendered separately and written into a
     * {@code [height, width, 4]} array.
     */
    public void plot() {
        // Every tile contributes its own size plus one gap, except that there is no
        // gap after the last tile.
        int height = ((this.imageShape[0] + this.tileSpacing[0]) * this.tileShape[0]) - this.tileSpacing[0];
        int width = ((this.imageShape[1] + this.tileSpacing[1]) * this.tileShape[1]) - this.tileSpacing[1];
        int[] canvasShape = {height, width};

        if (this.input.rank() == 2) {
            this.plot = plotSection(this.input, canvasShape);
            return;
        }

        // NOTE(review): the target is allocated as [height, width, 4] but filled with
        // putSlice(i, <[height, width] section>); these shapes look inconsistent
        // (a [4, height, width] layout may have been intended) - confirm against the
        // ND4J putSlice contract before relying on this branch.
        this.plot = Nd4j.zeros(new int[]{height, width, 4});
        for (int channel = 0; channel < 4; channel++) {
            this.plot.putSlice(channel, plotSection(this.input.slice(channel), canvasShape));
        }
    }

    /** @return the image produced by the last {@link #plot()} call, or whatever was set explicitly */
    public INDArray getPlot() {
        return this.plot;
    }

    /** @param plot replaces the stored image */
    public void setPlot(INDArray plot) {
        this.plot = plot;
    }

    /**
     * Renders one section: each row of {@code section} is reshaped to
     * {@code imageShape} and written into its cell of a zero-filled canvas.
     *
     * @param section     filters to draw, one per index along dimension 0
     * @param canvasShape shape of the canvas to allocate
     * @return the filled canvas
     */
    private INDArray plotSection(INDArray section, int[] canvasShape) {
        INDArray canvas = Nd4j.zeros(canvasShape);
        if (section.getLeadingOnes() == 2) {
            // Drop the two leading singleton dimensions, keeping the last two.
            section = section.reshape(section.size(-2), section.size(-1));
        }

        int tileHeight = this.imageShape[0];
        int tileWidth = this.imageShape[1];
        int rowStride = tileHeight + this.tileSpacing[0];
        int colStride = tileWidth + this.tileSpacing[1];
        int tileRows = this.tileShape[0];
        int tileCols = this.tileShape[1];

        for (int tileRow = 0; tileRow < tileRows; tileRow++) {
            int top = tileRow * rowStride;
            for (int tileCol = 0; tileCol < tileCols; tileCol++) {
                int filterIndex = (tileRow * tileCols) + tileCol;
                if (filterIndex >= section.size(0)) {
                    // Fewer filters than grid cells: leave the remaining cells zero.
                    continue;
                }

                INDArray tile = section
                        .get(new INDArrayIndex[]{new NDArrayIndex(new int[]{filterIndex})})
                        .reshape(this.imageShape);
                if (this.scaleRowsToInterval) {
                    tile = scale(tile);
                }
                if (this.outputPixels) {
                    // Copying multiply rather than muli: without the scale step the tile
                    // comes straight from get()/reshape() on the caller's input, and an
                    // in-place multiply could write through to it.
                    tile = tile.mul(255);
                }

                int left = tileCol * colStride;
                canvas.put(new INDArrayIndex[]{
                        NDArrayIndex.interval(top, top + tileHeight),
                        NDArrayIndex.interval(left, left + tileWidth)
                }, tile);
            }
        }
        return canvas;
    }
}
