package org.deeplearning4j.rbm;

import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/* loaded from: input_file:org/deeplearning4j/rbm/ConvolutionalRBM.class */
public class ConvolutionalRBM extends RBM {
    private static final long serialVersionUID = 6868729665328916878L;
    // NOTE(review): none of these fields is assigned anywhere in this class; presumably they are
    // populated elsewhere (e.g. deserialization/reflection) — TODO confirm.
    private int numFilters;
    private int poolRows;
    private int poolColumns;

    /**
     * Convolves the input with every filter, adds the visible bias plus {@code d}, pools the
     * responses and returns {@code sigmoid(1 / (1 + pool(responses)))}.
     *
     * @param doubleMatrix input map passed to {@link MatrixUtil#convolution2D}
     * @param d extra scalar bias added on top of the visible bias
     * @return element-wise expectation matrix
     */
    public DoubleMatrix visibleExpectation(DoubleMatrix doubleMatrix, double d) {
        DoubleMatrix responses = convolveWithFilters(doubleMatrix, getvBias().add(d));
        DoubleMatrix pool = pool(responses);
        pool.addi(1.0d);
        // NOTE(review): this reuses the pooled hidden-style formula with the visible bias and then
        // applies a sigmoid; a conventional CRBM visible expectation sums full convolutions of the
        // hidden maps instead. Left unchanged — confirm the intended formula.
        return MatrixUtil.sigmoid(MatrixUtil.oneDiv(pool));
    }

    /**
     * Convolves the input with every filter, adds the hidden bias plus {@code d}, pools the
     * responses and returns {@code 1 / (1 + pool(responses))}.
     *
     * @param doubleMatrix input map passed to {@link MatrixUtil#convolution2D}
     * @param d extra scalar bias added on top of the hidden bias
     * @return element-wise {@code 1 / (1 + pooled sum)} matrix
     */
    public DoubleMatrix pooledExpectation(DoubleMatrix doubleMatrix, double d) {
        DoubleMatrix responses = convolveWithFilters(doubleMatrix, gethBias().add(d));
        DoubleMatrix pool = pool(responses);
        pool.addi(1.0d);
        // NOTE(review): under probabilistic max-pooling, 1 / (1 + sum(exp)) is the probability
        // that the pooling unit is OFF; if the ON probability is wanted this should be
        // 1 - 1 / (1 + sum(exp)). Left unchanged — confirm against callers.
        return MatrixUtil.oneDiv(pool);
    }

    /**
     * Convolves {@code input} with each filter (row {@code i} of {@link #getW()} is filter
     * {@code i}), adds {@code bias}, and stores each transposed response as one row of the result.
     *
     * @param input map to convolve
     * @param bias value added to every convolution response
     * @return a {@code numFilters x L} matrix, L being the element count of one response; a
     *     {@code numFilters x 1} (empty) matrix when there are no filters
     */
    private DoubleMatrix convolveWithFilters(DoubleMatrix input, DoubleMatrix bias) {
        DoubleMatrix weights = getW();
        DoubleMatrix responses = new DoubleMatrix(this.numFilters);
        for (int filter = 0; filter < this.numFilters; filter++) {
            DoubleMatrix response = MatrixUtil.convolution2D(input, input.columns, input.rows,
                    weights.getRow(filter), weights.rows, weights.columns).add(bias).transpose();
            if (filter == 0) {
                // Size the rows from the real response length: a numFilters x 1 matrix has room
                // for only one value per filter, so putRow would drop the rest of the response.
                responses = new DoubleMatrix(this.numFilters, response.length);
            }
            responses.putRow(filter, response);
        }
        return responses;
    }

    /**
     * Probabilistic-max-pooling style aggregation.
     *
     * <p>The matrix is tiled into non-overlapping {@code poolRows x poolColumns} blocks, starting
     * at (0, 0). Blocks on the bottom/right edge are clipped to the matrix bounds. Every cell of a
     * block receives the sum of {@code exp(x)} over all cells of that block.
     *
     * @param doubleMatrix activations to pool
     * @return a new matrix with the same shape as {@code doubleMatrix}
     * @throws IllegalStateException if {@code poolRows} or {@code poolColumns} is not positive
     */
    public DoubleMatrix pool(DoubleMatrix doubleMatrix) {
        if (this.poolRows <= 0 || this.poolColumns <= 0) {
            throw new IllegalStateException("pool size must be positive but was "
                    + this.poolRows + "x" + this.poolColumns);
        }
        DoubleMatrix active = MatrixFunctions.exp(doubleMatrix);
        DoubleMatrix pooled = DoubleMatrix.zeros(active.rows, active.columns);
        for (int rowStart = 0; rowStart < active.rows; rowStart += this.poolRows) {
            int rowEnd = Math.min(rowStart + this.poolRows, active.rows);
            for (int colStart = 0; colStart < active.columns; colStart += this.poolColumns) {
                int colEnd = Math.min(colStart + this.poolColumns, active.columns);
                double blockSum = 0.0d;
                for (int r = rowStart; r < rowEnd; r++) {
                    for (int c = colStart; c < colEnd; c++) {
                        blockSum += active.get(r, c);
                    }
                }
                // Broadcast the block total to every cell in the block.
                for (int r = rowStart; r < rowEnd; r++) {
                    for (int c = colStart; c < colEnd; c++) {
                        pooled.put(r, c, blockSum);
                    }
                }
            }
        }
        return pooled;
    }
}
