/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.objdetect;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.layers.objdetect.DetectedObject;
import org.deeplearning4j.nn.layers.objdetect.YoloUtils;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.IsMax;
import org.nd4j.linalg.api.ops.impl.transforms.Not;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.impl.LossL2;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * YOLOv2 object-detection output layer.
 * <p>
 * <b>Input.</b> The input (activations of the preceding convolutional layer) must have shape
 * {@code [minibatch, B*(5+C), H, W]}. Here {@code B} is the number of bounding-box priors in the layer configuration
 * and {@code C} is the number of object classes. Internally the input is viewed (row-major) as
 * {@code [minibatch, B, 5+C, H, W]}. For every box:
 * <ul>
 *     <li>channels 0-1 are the pre-sigmoid x/y position within the grid cell;</li>
 *     <li>channels 2-3 are the pre-exp width/height, multiplied by the box prior after {@code exp};</li>
 *     <li>channel 4 is the pre-sigmoid confidence;</li>
 *     <li>channels 5..5+C-1 are the pre-softmax class scores.</li>
 * </ul>
 * <p>
 * <b>Labels.</b> Labels must be rank 4 with shape {@code [minibatch, 4+C, H, W]}:
 * <ul>
 *     <li>channels 0-1 are the top-left x/y of the object;</li>
 *     <li>channels 2-3 are the bottom-right x/y;</li>
 *     <li>channels 4.. are the class labels.</li>
 * </ul>
 * The fractional part of the label box centre is used as the in-cell position target, so label coordinates are
 * expected in grid-cell units. The per-cell sum of the class labels is used as the "object present" mask. This
 * presumably means class labels are one-hot, or all zero for empty cells (TODO confirm).
 * <p>
 * <b>Score.</b> The score is computed as
 * {@code lambdaCoord * (positionLoss + sizeLoss) + confObj + lambdaNoObj * confNoObj + classLoss}.
 * That sum is divided by the minibatch size, and the full-network L1/L2 terms are then added.
 */
public class Yolo2OutputLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer>
        implements Serializable, IOutputLayer {
    /** This layer reports no parameter gradients; backprop always returns this empty gradient. */
    private static final Gradient EMPTY_GRADIENT = new DefaultGradient();
    /** Labels of shape [minibatch, 4+C, H, W], set via {@link #setLabels(INDArray)}. */
    protected INDArray labels;
    /** Full-network L1 term added to the score; set by computeScore / computeScoreForExamples. */
    private double fullNetworkL1;
    /** Full-network L2 term added to the score; set by computeScore / computeScoreForExamples. */
    private double fullNetworkL2;
    /** Most recently computed minibatch-averaged score (plus regularization). */
    private double score;

    public Yolo2OutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Computes the gradient of the YOLOv2 loss with respect to this layer's input.
     * The incoming {@code epsilon} is not used: the error signal comes from the labels.
     * As a side effect, {@link #score()} is updated, using the most recently supplied L1/L2 terms.
     *
     * @return an empty parameter gradient, and the gradient w.r.t. the input (same shape as the input)
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray epsOut = this.computeBackpropGradientAndScore(workspaceMgr, false, false);
        return new Pair<>(EMPTY_GRADIENT, epsOut);
    }

    /**
     * Computes the YOLOv2 loss and, optionally, the gradient with respect to this layer's input.
     *
     * @param workspaceMgr            workspace manager used to allocate the returned array
     * @param scoreOnly               if true, only {@link #score} is updated and {@code null} is returned
     * @param computeScoreForExamples if true, returns per-example scores instead of a gradient;
     *                                {@link #score} is not updated in that case
     * @return the per-example scores, the gradient w.r.t. the input (same shape as the input), or {@code null}
     * @throws IllegalStateException (via Preconditions.checkState) if labels are missing,
     *                               are not rank 4, or do not match the input shape
     */
    private INDArray computeBackpropGradientAndScore(LayerWorkspaceMgr workspaceMgr, boolean scoreOnly, boolean computeScoreForExamples) {
        this.assertInputSet(true);
        Preconditions.checkState(this.labels != null, "Cannot calculate gradients/score: labels are null");
        Preconditions.checkState(this.labels.rank() == 4, "Expected rank 4 labels array with shape [minibatch, 4+numClasses, h, w] but got rank %s labels array with shape %s", (Object) this.labels.rank(), (Object) this.labels.shape());
        org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer layerConf = this.layerConf();
        double lambdaCoord = layerConf.getLambdaCoord();
        double lambdaNoObj = layerConf.getLambdaNoObj();
        int mb = (int) this.input.size(0);
        int h = (int) this.input.size(2);
        int w = (int) this.input.size(3);
        int b = (int) layerConf.getBoundingBoxes().size(0);
        int c = (int) this.labels.size(1) - 4;
        int[] nhw = new int[]{mb, h, w};
        long size1 = this.labels.size(1);

        // ---- Labels: [mb, 4+C, h, w] ----
        // Class labels are channels [4, 4+C). Their per-cell sum, shape [mb, h, w], is the object-present mask.
        INDArray classLabels = this.labels.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(4L, size1), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray maskObjectPresent = classLabels.sum(Nd4j.createUninitialized(nhw, 'c'), new int[]{1});
        // Channels 0-1 are the top-left x/y; channels 2-3 are the bottom-right x/y.
        INDArray labelTLXY = this.labels.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray labelBRXY = this.labels.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(2, 4), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray labelCenterXY = labelTLXY.add(labelBRXY).muli((Number) 0.5);
        // Position target: fractional part of the box centre, i.e. its offset within the grid cell.
        INDArray labelsCenterXYInGridBox = labelCenterXY.dup(labelCenterXY.ordering());
        labelsCenterXYInGridBox.subi(Transforms.floor(labelsCenterXYInGridBox, true));
        // Size target: sqrt of the label width/height.
        INDArray labelWHSqrt = labelBRXY.sub(labelTLXY);
        labelWHSqrt = Transforms.sqrt(labelWHSqrt, false);

        // ---- Predictions: view input [mb, B*(5+C), h, w] as [mb, B, 5+C, h, w] ----
        long[] expInputShape = new long[]{mb, b * (5 + c), h, w};
        Preconditions.checkState(Arrays.equals(expInputShape, this.input.shape()), "Unable to reshape input - input array shape does not match expected shape. Expected input shape [minibatch, B*(5+C), H, W]=%s but got input of shape %ndShape. This may be due to an incorrect nOut (layer size/channels) for the last convolutional layer in the network. nOut of the last layer must be B*(5+C) where B is the number of bounding boxes, and C is the number of object classes. Expected B=%s from network configuration and C=%s from labels", (Object) expInputShape, (Object) this.input, (Object) b, (Object) c);
        INDArray input5 = this.input.dup('c').reshape('c', new int[]{mb, b, 5 + c, h, w});
        INDArray inputClassesPreSoftmax = input5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(5, 5 + c), NDArrayIndex.all(), NDArrayIndex.all()});
        // In-cell x/y = sigmoid(channels 0-1).
        INDArray preSigmoidPredictedXYCenterGrid = input5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray predictedXYCenterGrid = Transforms.sigmoid(preSigmoidPredictedXYCenterGrid, true);
        // w/h = exp(channels 2-3) * box prior. The priors are [B, 2], broadcast over dims 1 and 2.
        INDArray predictedWHPreExp = input5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(2, 4), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray predictedWH = Transforms.exp(predictedWHPreExp, true);
        Broadcast.mul(predictedWH, layerConf.getBoundingBoxes(), predictedWH, new int[]{1, 2});
        INDArray predictedWHSqrt = Transforms.sqrt(predictedWH, true);

        // ---- Responsibility masks, based on IOU between each predicted box and the label box ----
        IOURet iouRet = Yolo2OutputLayer.calculateIOULabelPredicted(labelTLXY, labelBRXY, predictedWH, predictedXYCenterGrid, maskObjectPresent);
        INDArray iou = iouRet.getIou();
        // mask1_ij_obj is [mb, B, h, w]. It marks the box with the highest IOU (over the box dimension) in cells
        // whose object-present mask is set, and is 0 elsewhere. mask1_ij_noobj is its logical complement.
        INDArray mask1_ij_obj = Nd4j.getExecutioner().execAndReturn((TransformOp) new IsMax(iou.dup('c'), new int[]{1}));
        Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(mask1_ij_obj, maskObjectPresent, mask1_ij_obj, new int[]{0, 2, 3}));
        INDArray mask1_ij_noobj = Transforms.not(mask1_ij_obj);
        // Confidence target: IOU for responsible boxes, 0 otherwise.
        INDArray labelConfidence = iou.mul(mask1_ij_obj);
        INDArray predictedConfidencePreSigmoid = input5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(4L), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray predictedConfidence = Transforms.sigmoid(predictedConfidencePreSigmoid, true);

        // ---- Flatten everything to 2d [mb*B*h*w, channels] for the loss functions ----
        INDArray mask1_ij_obj_2d = mask1_ij_obj.reshape((long) (mb * b * h * w), 1L);
        INDArray mask1_ij_noobj_2d = Transforms.not(mask1_ij_obj_2d);
        INDArray predictedXYCenter2d = predictedXYCenterGrid.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape('c', mb * b * h * w, 2);
        INDArray labelsCenterXYInGridBroadcast = repeatForEachBox(labelsCenterXYInGridBox, mb, b, 2, h, w);
        INDArray labelXYCenter2d = labelsCenterXYInGridBroadcast.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape('c', mb * b * h * w, 2);
        INDArray predictedWHSqrt2d = predictedWHSqrt.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape((long) (mb * b * h * w), 2L).dup('c');
        INDArray labelWHSqrtBroadcast = repeatForEachBox(labelWHSqrt, mb, b, 2, h, w);
        INDArray labelWHSqrt2d = labelWHSqrtBroadcast.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape((long) (mb * b * h * w), 2L).dup('c');
        INDArray labelConfidence2d = labelConfidence.dup('c').reshape('c', mb * b * h * w, 1);
        INDArray predictedConfidence2d = predictedConfidence.dup('c').reshape('c', mb * b * h * w, 1).dup('c');
        INDArray predictedConfidence2dPreSigmoid = predictedConfidencePreSigmoid.dup('c').reshape('c', mb * b * h * w, 1).dup('c');
        INDArray classPredictionsPreSoftmax2d = inputClassesPreSoftmax.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape('c', new int[]{mb * b * h * w, c});
        INDArray classLabelsBroadcast = repeatForEachBox(classLabels, mb, b, c, h, w);
        INDArray classLabels2d = classLabelsBroadcast.permute(new int[]{0, 1, 3, 4, 2}).dup('c').reshape('c', new int[]{mb * b * h * w, c});
        LossL2 lossConfidence = new LossL2();
        ActivationIdentity identity = new ActivationIdentity();

        if (computeScoreForExamples) {
            INDArray positionLoss = layerConf.getLossPositionScale().computeScoreArray(labelXYCenter2d, predictedXYCenter2d, identity, mask1_ij_obj_2d);
            INDArray sizeScaleLoss = layerConf.getLossPositionScale().computeScoreArray(labelWHSqrt2d, predictedWHSqrt2d, identity, mask1_ij_obj_2d);
            INDArray confidenceLossPt1 = lossConfidence.computeScoreArray(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_obj_2d);
            // The no-object confidence term is scaled by lambdaNoObj exactly once, matching the aggregate score below.
            INDArray confidenceLossPt2 = lossConfidence.computeScoreArray(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_noobj_2d).muli((Number) lambdaNoObj);
            INDArray classPredictionLoss = layerConf.getLossClassPredictions().computeScoreArray(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), mask1_ij_obj_2d);
            INDArray scoreForExamples = positionLoss.addi(sizeScaleLoss).muli((Number) lambdaCoord).addi(confidenceLossPt1).addi(confidenceLossPt2).addi(classPredictionLoss).dup('c');
            // Sum the per-(box, cell) losses of each example.
            scoreForExamples = scoreForExamples.reshape('c', mb, b * h * w).sum(new int[]{1}).addi((Number) (this.fullNetworkL1 + this.fullNetworkL2));
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, scoreForExamples);
        }

        double positionLoss = layerConf.getLossPositionScale().computeScore(labelXYCenter2d, predictedXYCenter2d, identity, mask1_ij_obj_2d, false);
        double sizeScaleLoss = layerConf.getLossPositionScale().computeScore(labelWHSqrt2d, predictedWHSqrt2d, identity, mask1_ij_obj_2d, false);
        double confidenceLoss = lossConfidence.computeScore(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_obj_2d, false) + lambdaNoObj * lossConfidence.computeScore(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_noobj_2d, false);
        double classPredictionLoss = layerConf.getLossClassPredictions().computeScore(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), mask1_ij_obj_2d, false);
        this.score = lambdaCoord * (positionLoss + sizeScaleLoss) + confidenceLoss + classPredictionLoss;
        this.score /= (double) this.getInputMiniBatchSize();
        this.score += this.fullNetworkL1 + this.fullNetworkL2;
        if (scoreOnly) {
            return null;
        }

        // ---- Gradient w.r.t. the input. epsOut5 is a 5d view of epsOut; every channel is assigned below. ----
        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, this.input.shape(), 'c');
        INDArray epsOut5 = Shape.newShapeNoCopy(epsOut, new int[]{mb, b, 5 + c, h, w}, false);
        INDArray epsClassPredictions = epsOut5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(5, 5 + c), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray epsXY = epsOut5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 2), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray epsWH = epsOut5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(2, 4), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray epsC = epsOut5.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(4L), NDArrayIndex.all(), NDArrayIndex.all()});

        // Class predictions (softmax is handled by the loss function's activation).
        INDArray gradPredictionLoss2d = layerConf.getLossClassPredictions().computeGradient(classLabels2d, classPredictionsPreSoftmax2d, new ActivationSoftmax(), mask1_ij_obj_2d);
        INDArray gradPredictionLoss5d = gradPredictionLoss2d.dup('c').reshape(new long[]{mb, b, h, w, c}).permute(new int[]{0, 1, 4, 2, 3}).dup('c');
        epsClassPredictions.assign(gradPredictionLoss5d);

        // Position: scale by lambdaCoord, then backprop through the sigmoid.
        // The pre-sigmoid array is dup'd because ActivationSigmoid.backprop may modify its input.
        INDArray gradXYCenter2d = layerConf.getLossPositionScale().computeGradient(labelXYCenter2d, predictedXYCenter2d, identity, mask1_ij_obj_2d);
        gradXYCenter2d.muli((Number) lambdaCoord);
        INDArray gradXYCenter5d = gradXYCenter2d.dup('c').reshape('c', new int[]{mb, b, h, w, 2}).permute(new int[]{0, 1, 4, 2, 3});
        gradXYCenter5d = new ActivationSigmoid().backprop(preSigmoidPredictedXYCenterGrid.dup(), gradXYCenter5d).getFirst();
        epsXY.assign(gradXYCenter5d);

        // Size: d sqrt(x)/dx = 0.5 / sqrt(x), then d(prior * exp(z))/dz = predictedWH, then scale by lambdaCoord.
        INDArray gradWHSqrt2d = layerConf.getLossPositionScale().computeGradient(labelWHSqrt2d, predictedWHSqrt2d, identity, mask1_ij_obj_2d);
        INDArray gradWH2d = gradWHSqrt2d.muli((Number) 0.5).divi(predictedWHSqrt2d);
        INDArray gradWH5d = gradWH2d.dup('c').reshape(new long[]{mb, b, h, w, 2L}).permute(new int[]{0, 1, 4, 2, 3});
        gradWH5d.muli(predictedWH);
        gradWH5d.muli((Number) lambdaCoord);
        epsWH.assign(gradWH5d);

        // Confidence: object term plus lambdaNoObj * no-object term, backpropagated through the sigmoid.
        INDArray gradConfidence2dA = lossConfidence.computeGradient(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_obj_2d);
        INDArray gradConfidence2dB = lossConfidence.computeGradient(labelConfidence2d, predictedConfidence2d, identity, mask1_ij_noobj_2d);
        INDArray dLc_dC_2d = gradConfidence2dA.addi(gradConfidence2dB.muli((Number) lambdaNoObj));
        INDArray dLc_dzc_2d = new ActivationSigmoid().backprop(predictedConfidence2dPreSigmoid, dLc_dC_2d).getFirst();
        INDArray epsConfidence4d = dLc_dzc_2d.dup('c').reshape('c', new int[]{mb, b, h, w});
        epsC.assign(epsConfidence4d);

        // The confidence target is the IOU, which depends on the predicted x/y/w/h. Backprop the confidence loss
        // through it: dLc/dIOU = 2*(IOU - C) * (mask_obj + lambdaNoObj*mask_noobj).
        // Note: iou and mask1_ij_noobj are modified in place here; neither is used afterwards.
        INDArray twoIOUSubPredicted = iou.subi(predictedConfidence).muli((Number) 2.0);
        INDArray dLc_dIOU = twoIOUSubPredicted.muli(mask1_ij_obj.add(mask1_ij_noobj.muli((Number) lambdaNoObj)));
        INDArray dLc_dxy = Nd4j.createUninitialized(iouRet.dIOU_dxy.shape(), iouRet.dIOU_dxy.ordering());
        Broadcast.mul(iouRet.dIOU_dxy, dLc_dIOU, dLc_dxy, new int[]{0, 1, 3, 4});
        INDArray dLc_dwh = Nd4j.createUninitialized(iouRet.dIOU_dwh.shape(), iouRet.dIOU_dwh.ordering());
        Broadcast.mul(iouRet.dIOU_dwh, dLc_dIOU, dLc_dwh, new int[]{0, 1, 3, 4});
        INDArray dLc_din_wh = dLc_dwh.muli(predictedWH);
        INDArray dLc_din_xy = new ActivationSigmoid().backprop(preSigmoidPredictedXYCenterGrid, dLc_dxy).getFirst();
        // Only responsible boxes receive this IOU gradient.
        Broadcast.mul(dLc_din_wh, mask1_ij_obj, dLc_din_wh, new int[]{0, 1, 3, 4});
        Broadcast.mul(dLc_din_xy, mask1_ij_obj, dLc_din_xy, new int[]{0, 1, 3, 4});
        epsWH.addi(dLc_din_wh);
        epsXY.addi(dLc_din_xy);
        return epsOut;
    }

    /**
     * Repeats a per-cell array for each bounding box.
     *
     * @param perCell  array of shape [mb, channels, h, w]
     * @return a new 'c'-order array of shape [mb, b, channels, h, w], where every box slice equals {@code perCell}
     */
    private static INDArray repeatForEachBox(INDArray perCell, int mb, int b, int channels, int h, int w) {
        INDArray out = Nd4j.createUninitialized(new int[]{mb, b, channels, h, w}, 'c');
        for (int i = 0; i < b; ++i) {
            out.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point((long) i), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()}).assign(perCell);
        }
        return out;
    }

    /** Applies the YOLO output activations by delegating to {@link YoloUtils#activate} with this layer's box priors. */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        return YoloUtils.activate(this.layerConf().getBoundingBoxes(), this.input, workspaceMgr);
    }

    public Layer clone() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public boolean needsLabels() {
        return true;
    }

    /**
     * Computes and stores the minibatch-averaged score plus the given full-network regularization terms.
     *
     * @return the new score (same as {@link #score()})
     */
    @Override
    public double computeScore(double fullNetworkL1, double fullNetworkL2, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;
        this.computeBackpropGradientAndScore(workspaceMgr, true, false);
        return this.score();
    }

    @Override
    public double score() {
        return this.score;
    }

    /**
     * Computes the IOU between every predicted box and the label box of its grid cell, together with the IOU's
     * derivatives with respect to the predicted centre and size.
     * <p>
     * All positions are in grid-cell units. Predicted centres are {@code predictedXYinGridBox} plus the cell's
     * (x, y) index.
     *
     * @param labelTL              label top-left x/y, shape [mb, 2, h, w]
     * @param labelBR              label bottom-right x/y, shape [mb, 2, h, w]
     * @param predictedWH          predicted width/height, shape [mb, B, 2, h, w]
     * @param predictedXYinGridBox predicted in-cell x/y, shape [mb, B, 2, h, w]
     * @param objectPresentMask    object-present mask, shape [mb, h, w]
     * @return IOU of shape [mb, B, h, w], plus dIOU/dxy and dIOU/dwh of shape [mb, B, 2, h, w].
     *         Values are zeroed where there is no intersection or no object.
     */
    private static IOURet calculateIOULabelPredicted(INDArray labelTL, INDArray labelBR, INDArray predictedWH, INDArray predictedXYinGridBox, INDArray objectPresentMask) {
        int mb = (int) labelTL.size(0);
        int h = (int) labelTL.size(2);
        int w = (int) labelTL.size(3);
        int b = (int) predictedWH.size(1);
        INDArray labelWH = labelBR.sub(labelTL);
        int gridH = (int) labelTL.size(2);
        int gridW = (int) labelTL.size(3);

        // Cell index grid, shape [2, gridH, gridW]: grid[0] holds the column (x) index and grid[1] the row (y) index.
        INDArray linspaceX = Nd4j.linspace(0L, (long) (gridW - 1), (long) gridW);
        INDArray linspaceY = Nd4j.linspace(0L, (long) (gridH - 1), (long) gridH);
        INDArray grid = Nd4j.createUninitialized(new int[]{2, gridH, gridW}, 'c');
        INDArray gridX = grid.get(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()});
        INDArray gridY = grid.get(new INDArrayIndex[]{NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()});
        Broadcast.copy(gridX, linspaceX, gridX, new int[]{1});
        Broadcast.copy(gridY, linspaceY, gridY, new int[]{0});

        // Absolute predicted centre = in-cell offset + cell index; then the predicted box corners.
        INDArray predictedXY = Nd4j.createUninitialized(predictedXYinGridBox.shape(), predictedXYinGridBox.ordering());
        Broadcast.add(predictedXYinGridBox, grid, predictedXY, new int[]{2, 3, 4});
        INDArray halfWH = predictedWH.mul((Number) 0.5);
        INDArray predictedTL_XY = halfWH.rsub(predictedXY);
        INDArray predictedBR_XY = halfWH.add(predictedXY);

        // Intersection rectangle and its (signed) area.
        INDArray maxTL = Nd4j.createUninitialized(predictedTL_XY.shape(), predictedTL_XY.ordering());
        Broadcast.max(predictedTL_XY, labelTL, maxTL, new int[]{0, 2, 3, 4});
        INDArray minBR = Nd4j.createUninitialized(predictedBR_XY.shape(), predictedBR_XY.ordering());
        Broadcast.min(predictedBR_XY, labelBR, minBR, new int[]{0, 2, 3, 4});
        INDArray diff = minBR.sub(maxTL);
        INDArray intersectionArea = diff.prod(new int[]{2});
        Broadcast.mul(intersectionArea, objectPresentMask, intersectionArea, new int[]{0, 2, 3});

        // intMask = 1 where the boxes overlap (and an object is present), else 0.
        INDArray noIntMask1 = Nd4j.createUninitialized(maxTL.shape(), maxTL.ordering());
        INDArray noIntMask2 = Nd4j.createUninitialized(maxTL.shape(), maxTL.ordering());
        Broadcast.lt(predictedBR_XY, labelTL, noIntMask1, new int[]{0, 2, 3, 4});
        Broadcast.gt(predictedTL_XY, labelBR, noIntMask2, new int[]{0, 2, 3, 4});
        noIntMask1 = Transforms.or(noIntMask1.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}), noIntMask1.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}));
        noIntMask2 = Transforms.or(noIntMask2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}), noIntMask2.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}));
        INDArray noIntMask = Transforms.or(noIntMask1, noIntMask2);
        INDArray intMask = Nd4j.getExecutioner().execAndReturn((TransformOp) new Not(noIntMask, noIntMask, (Number) 0.0));
        Broadcast.mul(intMask, objectPresentMask, intMask, new int[]{0, 2, 3});
        intersectionArea.muli(intMask);

        // Union = areaPredicted + areaLabel - intersection. The 0/0 cases (no overlap) produce NaN, replaced by 0.
        INDArray areaPredicted = predictedWH.prod(new int[]{2});
        Broadcast.mul(areaPredicted, objectPresentMask, areaPredicted, new int[]{0, 2, 3});
        INDArray areaLabel = labelWH.prod(new int[]{1});
        INDArray unionArea = Broadcast.add(areaPredicted, areaLabel, areaPredicted.dup(), new int[]{0, 2, 3});
        unionArea.subi(intersectionArea);
        unionArea.muli(intMask);
        INDArray iou = intersectionArea.div(unionArea);
        BooleanIndexing.replaceWhere(iou, (Number) 0.0, Conditions.isNan());
        Broadcast.mul(iou, objectPresentMask, iou, new int[]{0, 2, 3});

        // Derivatives of the intersection area I = diffX * diffY.
        //   maskMaxTL: predicted TL > label TL, so the intersection's TL follows the prediction.
        //   maskMinBR: predicted BR < label BR, so the intersection's BR follows the prediction.
        // This gives dI/dx = (maskMinBR_x - maskMaxTL_x) * diffY and dI/dw = 0.5 * (maskMinBR_x + maskMaxTL_x) * diffY.
        // The y/h derivatives are analogous.
        INDArray maskMaxTL = Nd4j.createUninitialized(maxTL.shape(), maxTL.ordering());
        Broadcast.gt(predictedTL_XY, labelTL, maskMaxTL, new int[]{0, 2, 3, 4});
        INDArray maskMinBR = Nd4j.createUninitialized(maxTL.shape(), maxTL.ordering());
        Broadcast.lt(predictedBR_XY, labelBR, maskMinBR, new int[]{0, 2, 3, 4});
        INDArray dI_dxy = maskMinBR.sub(maskMaxTL);
        INDArray dI_dwh = maskMinBR.add(maskMaxTL).muli((Number) 0.5);
        dI_dxy.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}).muli(diff.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}));
        dI_dxy.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}).muli(diff.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}));
        dI_dwh.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}).muli(diff.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}));
        dI_dwh.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}).muli(diff.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}));

        // IOU = I/U with U = Ap + Al - I. This gives:
        //   dIOU/dxy = dI/dxy * (U + I) / U^2
        //   dIOU/dw  = (dI/dw * (U + I) - I * h) / U^2, and likewise for h with w.
        INDArray uPlusI = unionArea.add(intersectionArea);
        INDArray u2 = unionArea.mul(unionArea);
        INDArray uPlusIDivU2 = uPlusI.div(u2);
        BooleanIndexing.replaceWhere(uPlusIDivU2, (Number) 0.0, Conditions.isNan());
        INDArray dIOU_dxy = Nd4j.createUninitialized(new int[]{mb, b, 2, h, w}, 'c');
        Broadcast.mul(dI_dxy, uPlusIDivU2, dIOU_dxy, new int[]{0, 1, 3, 4});
        // predictedHW holds predictedWH with the w/h channels swapped, so each size derivative uses the other side.
        INDArray predictedHW = Nd4j.createUninitialized(new int[]{mb, b, 2, h, w}, predictedWH.ordering());
        predictedHW.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}).assign(predictedWH.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}));
        predictedHW.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(1L), NDArrayIndex.all(), NDArrayIndex.all()}).assign(predictedWH.get(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(0L), NDArrayIndex.all(), NDArrayIndex.all()}));
        INDArray Ihw = Nd4j.createUninitialized(predictedHW.shape(), predictedHW.ordering());
        Broadcast.mul(predictedHW, intersectionArea, Ihw, new int[]{0, 1, 3, 4});
        INDArray dIOU_dwh = Nd4j.createUninitialized(new int[]{mb, b, 2, h, w}, 'c');
        Broadcast.mul(dI_dwh, uPlusI, dIOU_dwh, new int[]{0, 1, 3, 4});
        dIOU_dwh.subi(Ihw);
        Broadcast.div(dIOU_dwh, u2, dIOU_dwh, new int[]{0, 1, 3, 4});
        BooleanIndexing.replaceWhere(dIOU_dwh, (Number) 0.0, Conditions.isNan());
        return new IOURet(iou, dIOU_dxy, dIOU_dwh);
    }

    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(this.gradient(), this.score());
    }

    /**
     * Computes the score of each example: the summed loss over boxes and cells, not divided by the minibatch size,
     * plus the given full-network regularization terms. {@link #score()} is not updated.
     */
    @Override
    public INDArray computeScoreForExamples(double fullNetworkL1, double fullNetworkL2, LayerWorkspaceMgr workspaceMgr) {
        this.fullNetworkL1 = fullNetworkL1;
        this.fullNetworkL2 = fullNetworkL2;
        return this.computeBackpropGradientAndScore(workspaceMgr, false, true);
    }

    @Override
    public double f1Score(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int numLabels() {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSetIterator iter) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int[] predict(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List<String> predict(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public INDArray labelProbabilities(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public void clearNoiseWeightParams() {
    }

    /**
     * Delegates to {@link YoloUtils#getPredictedObjects} using this layer's bounding-box priors.
     * {@code threshold} and {@code 0.0} are passed through as its last two arguments.
     */
    public List<DetectedObject> getPredictedObjects(INDArray networkOutput, double threshold) {
        return YoloUtils.getPredictedObjects(this.layerConf().getBoundingBoxes(), networkOutput, threshold, 0.0);
    }

    /**
     * Returns the confidence of one bounding box of one example, for every grid position.
     * <p>
     * Uses the same per-box channel layout as the layer input, {@code [minibatch, B*(5+C), H, W]}. Each box occupies
     * a contiguous block of {@code 5+C} channels, so box {@code bbNumber}'s confidence is at channel
     * {@code bbNumber*(5+C) + 4}.
     *
     * @param networkOutput network output activations
     * @param example       example index within the minibatch
     * @param bbNumber      bounding box index, in {@code [0, B)}
     * @return the confidence slice over all grid positions
     * @throws IllegalArgumentException if the output shape does not match this layer's B, or bbNumber is out of range
     */
    public INDArray getConfidenceMatrix(INDArray networkOutput, int example, int bbNumber) {
        long perBox = this.channelsPerBox(networkOutput);
        long numBoxes = networkOutput.size(1) / perBox;
        if (bbNumber < 0 || bbNumber >= numBoxes) {
            throw new IllegalArgumentException("Invalid bounding box number " + bbNumber + ": must be in range [0, " + numBoxes + ")");
        }
        return networkOutput.get(new INDArrayIndex[]{NDArrayIndex.point((long) example), NDArrayIndex.point(bbNumber * perBox + 4L), NDArrayIndex.all(), NDArrayIndex.all()});
    }

    /**
     * Returns the probability of the given class for every grid position, taken from bounding box 0.
     * <p>
     * Class probabilities are stored separately for each box. This method therefore reads box 0, which is exact
     * when the layer has a single box prior. Use
     * {@link #getProbabilityMatrix(INDArray, int, int, int)} to select a different box.
     *
     * @param networkOutput network output activations
     * @param example       example index within the minibatch
     * @param classNumber   class index, in {@code [0, C)}
     * @return the class probability slice over all grid positions
     */
    public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int classNumber) {
        return this.getProbabilityMatrix(networkOutput, example, 0, classNumber);
    }

    /**
     * Returns the probability of the given class, as predicted by one bounding box, for every grid position.
     * The value is read from channel {@code bbNumber*(5+C) + 5 + classNumber} of the output.
     *
     * @param networkOutput network output activations, shape [minibatch, B*(5+C), H, W]
     * @param example       example index within the minibatch
     * @param bbNumber      bounding box index, in {@code [0, B)}
     * @param classNumber   class index, in {@code [0, C)}
     * @return the class probability slice over all grid positions
     * @throws IllegalArgumentException if the output shape does not match this layer's B, or an index is out of range
     */
    public INDArray getProbabilityMatrix(INDArray networkOutput, int example, int bbNumber, int classNumber) {
        long perBox = this.channelsPerBox(networkOutput);
        long numBoxes = networkOutput.size(1) / perBox;
        long numClasses = perBox - 5L;
        if (bbNumber < 0 || bbNumber >= numBoxes) {
            throw new IllegalArgumentException("Invalid bounding box number " + bbNumber + ": must be in range [0, " + numBoxes + ")");
        }
        if (classNumber < 0 || classNumber >= numClasses) {
            throw new IllegalArgumentException("Invalid class number " + classNumber + ": must be in range [0, " + numClasses + ")");
        }
        return networkOutput.get(new INDArrayIndex[]{NDArrayIndex.point((long) example), NDArrayIndex.point(bbNumber * perBox + 5L + classNumber), NDArrayIndex.all(), NDArrayIndex.all()});
    }

    /**
     * Returns the number of channels per bounding box (5+C) in a network output array.
     *
     * @throws IllegalArgumentException if the array is not rank 4, or its channel count is not a positive multiple of
     *                                  (at least 5 times) the number of box priors
     */
    private long channelsPerBox(INDArray networkOutput) {
        long numBoxes = this.layerConf().getBoundingBoxes().size(0);
        if (networkOutput.rank() != 4 || networkOutput.size(1) % numBoxes != 0 || networkOutput.size(1) / numBoxes < 5) {
            throw new IllegalArgumentException("Expected network output with shape [minibatch, B*(5+C), H, W] with B=" + numBoxes + " but got shape " + Arrays.toString(networkOutput.shape()));
        }
        return networkOutput.size(1) / numBoxes;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }

    @Override
    public INDArray getLabels() {
        return this.labels;
    }

    /** Result of {@link #calculateIOULabelPredicted}: the IOU and its derivatives w.r.t. predicted xy and wh. */
    private static class IOURet {
        private INDArray iou;
        private INDArray dIOU_dxy;
        private INDArray dIOU_dwh;

        public IOURet(INDArray iou, INDArray dIOU_dxy, INDArray dIOU_dwh) {
            this.iou = iou;
            this.dIOU_dxy = dIOU_dxy;
            this.dIOU_dwh = dIOU_dwh;
        }

        public INDArray getIou() {
            return this.iou;
        }

        public INDArray getDIOU_dxy() {
            return this.dIOU_dxy;
        }

        public INDArray getDIOU_dwh() {
            return this.dIOU_dwh;
        }

        public void setIou(INDArray iou) {
            this.iou = iou;
        }

        public void setDIOU_dxy(INDArray dIOU_dxy) {
            this.dIOU_dxy = dIOU_dxy;
        }

        public void setDIOU_dwh(INDArray dIOU_dwh) {
            this.dIOU_dwh = dIOU_dwh;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof IOURet)) {
                return false;
            }
            IOURet other = (IOURet) o;
            if (!other.canEqual(this)) {
                return false;
            }
            INDArray this$iou = this.getIou();
            INDArray other$iou = other.getIou();
            if (this$iou == null ? other$iou != null : !this$iou.equals(other$iou)) {
                return false;
            }
            INDArray this$dIOU_dxy = this.getDIOU_dxy();
            INDArray other$dIOU_dxy = other.getDIOU_dxy();
            if (this$dIOU_dxy == null ? other$dIOU_dxy != null : !this$dIOU_dxy.equals(other$dIOU_dxy)) {
                return false;
            }
            INDArray this$dIOU_dwh = this.getDIOU_dwh();
            INDArray other$dIOU_dwh = other.getDIOU_dwh();
            return !(this$dIOU_dwh == null ? other$dIOU_dwh != null : !this$dIOU_dwh.equals(other$dIOU_dwh));
        }

        protected boolean canEqual(Object other) {
            return other instanceof IOURet;
        }

        @Override
        public int hashCode() {
            int result = 1;
            INDArray $iou = this.getIou();
            result = result * 59 + ($iou == null ? 43 : $iou.hashCode());
            INDArray $dIOU_dxy = this.getDIOU_dxy();
            result = result * 59 + ($dIOU_dxy == null ? 43 : $dIOU_dxy.hashCode());
            INDArray $dIOU_dwh = this.getDIOU_dwh();
            result = result * 59 + ($dIOU_dwh == null ? 43 : $dIOU_dwh.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "Yolo2OutputLayer.IOURet(iou=" + this.getIou() + ", dIOU_dxy=" + this.getDIOU_dxy() + ", dIOU_dwh=" + this.getDIOU_dwh() + ")";
        }
    }
}

