package org.nd4j.linalg.lossfunctions;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/LossUtil.class */
public class LossUtil {

    /**
     * Decides whether {@code mask} should be treated as a per-output mask for {@code to}.
     *
     * <p>Any mask that is not a column vector is reported as per-output. A column-vector mask is
     * reported as per-output only when its shape is identical to the shape of {@code to}.
     * Otherwise it is treated as per-example.
     *
     * @param to   the array the mask would be applied to (e.g. labels / scores)
     * @param mask the candidate mask array
     * @return {@code true} if the mask is per-output, {@code false} if it is per-example
     */
    public static boolean isPerOutputMasking(INDArray to, INDArray mask) {
        if (mask.isColumnVector()) {
            // A column vector is normally a per-example mask. It only counts as per-output
            // when the target itself has exactly the same (column-vector) shape.
            return Arrays.equals(to.shape(), mask.shape());
        }
        return true;
    }

    /**
     * Multiplies {@code to} by {@code mask}. The method returns nothing, so the masked result is
     * expected to be written back into {@code to} via the in-place ND4J operations.
     *
     * <p>Two mask layouts are accepted:
     * <ul>
     *   <li>a column vector (per-example masking), applied with {@code muliColumnVector};</li>
     *   <li>an array with exactly the same shape as {@code to} (per-output masking), applied
     *       element-wise with {@code muli}.</li>
     * </ul>
     *
     * @param to   the array to mask
     * @param mask the mask array
     * @throws IllegalStateException if {@code mask} is neither a column vector nor the same shape
     *                               as {@code to}
     */
    public static void applyMask(INDArray to, INDArray mask) {
        if (mask.isColumnVector()) {
            // Per-example case: one mask value per row.
            to.muliColumnVector(mask);
            return;
        }

        boolean shapesMatch = Arrays.equals(to.shape(), mask.shape());
        if (!shapesMatch) {
            String maskShape = Arrays.toString(mask.shape());
            String outputShape = Arrays.toString(to.shape());
            throw new IllegalStateException("Invalid mask array: per-example masking should be a column vector, per output masking arrays should be the same shape as the labels array. Mask shape: " + maskShape + ", output shape: " + outputShape);
        }

        // Per-output case: element-wise product with a same-shaped mask.
        to.muli(mask);
    }
}
