package org.nd4j.autodiff.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JException;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/autodiff/util/SameDiffUtils.class */
/**
 * Static helper methods used around {@link SameDiff} graphs: merging per-batch output maps,
 * creating {@link ExternalErrorsFunction} instances, reshaping reduction results and validating
 * that a variable belongs to a given {@link SameDiff} instance.
 *
 * <p>This class only contains static methods and cannot be instantiated.
 */
public class SameDiffUtils {

    /**
     * Merges a list of output maps into a single map by concatenating, for every key, all arrays
     * stored under that key along dimension 0. Arrays are concatenated in the order in which their
     * maps appear in {@code list}.
     *
     * @param list output maps to merge (presumably one map per batch, keyed by output name)
     * @return a new map holding one concatenated array per key encountered in {@code list}
     * @throws ND4JException if concatenating the arrays of any key fails; the original exception
     *                       is preserved as the cause
     */
    public static Map<String, INDArray> stackOutputs(List<Map<String, INDArray>> list) {
        // Group the arrays by key, preserving the order in which the maps were supplied.
        Map<String, List<INDArray>> arraysByName = new HashMap<>();
        for (Map<String, INDArray> batchOutputs : list) {
            for (Map.Entry<String, INDArray> output : batchOutputs.entrySet()) {
                List<INDArray> arrays = arraysByName.get(output.getKey());
                if (arrays == null) {
                    arrays = new ArrayList<>();
                    arraysByName.put(output.getKey(), arrays);
                }
                arrays.add(output.getValue());
            }
        }

        Map<String, INDArray> stacked = new HashMap<>();
        for (Map.Entry<String, List<INDArray>> entry : arraysByName.entrySet()) {
            try {
                stacked.put(entry.getKey(), Nd4j.concat(0, entry.getValue().toArray(new INDArray[0])));
            } catch (Exception e) {
                // Any failure is rewrapped so callers see a single exception type with the cause attached.
                throw new ND4JException("Error concatenating batch outputs", e);
            }
        }
        return stacked;
    }

    /**
     * Extracts the value stored under one key from each map in the list.
     *
     * @param list maps to read from
     * @param str  key to look up in every map
     * @return a new list with one element per input map, in the same order; an element is
     *         {@code null} when {@link Map#get(Object)} returns {@code null} for that map
     */
    public static List<INDArray> getSingleOutput(List<Map<String, INDArray>> list, String str) {
        List<INDArray> values = new ArrayList<>();
        for (Map<String, INDArray> outputs : list) {
            values.add(outputs.get(str));
        }
        return values;
    }

    /**
     * Creates an {@link ExternalErrorsFunction} for the given variables.
     *
     * @param sameDiff      SameDiff instance handed to the new function
     * @param map           arrays handed to the new function; may be {@code null} (the two-argument
     *                      overload passes {@code null}). Presumably external gradients keyed by
     *                      variable name - confirm against {@link ExternalErrorsFunction}.
     * @param sDVariableArr variables the function is created for; at least one is required
     * @return the newly created function
     * @throws IllegalArgumentException presumably, via {@code Preconditions.checkArgument}, when
     *                                  {@code sDVariableArr} is {@code null} or empty
     */
    public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map<String, INDArray> map, SDVariable... sDVariableArr) {
        Preconditions.checkArgument(sDVariableArr != null && sDVariableArr.length > 0, "Require at least one SDVariable to be specified when using external errors: got %s", sDVariableArr);
        ExternalErrorsFunction externalErrorsFunction = new ExternalErrorsFunction(sameDiff, Arrays.asList(sDVariableArr), map);
        // Called for its side effect only; the returned variable is intentionally discarded.
        // NOTE(review): presumably this forces creation of the op's output variable - confirm.
        externalErrorsFunction.outputVariable();
        return externalErrorsFunction;
    }

    /**
     * Creates an {@link ExternalErrorsFunction} for the given variables without supplying any
     * arrays; delegates to {@link #externalErrors(SameDiff, Map, SDVariable...)} with a
     * {@code null} map.
     *
     * @param sameDiff      SameDiff instance handed to the new function
     * @param sDVariableArr variables the function is created for; at least one is required
     * @return the newly created function
     */
    public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, SDVariable[] sDVariableArr) {
        return externalErrors(sameDiff, null, sDVariableArr);
    }

    /**
     * Expands the dimensions of a variable once per entry of {@code iArr}, unless the variable can
     * be returned as-is.
     *
     * <p>The variable is returned unchanged when {@code Shape.isWholeArray(i, iArr)} is true, or
     * when {@code i == 2} and {@code iArr} has exactly one entry. Otherwise {@code expandDims} is
     * applied for each entry of {@code iArr}, in array order, each time on the result of the
     * previous step.
     *
     * @param i          presumably the rank of the original (pre-reduction) array - confirm
     * @param iArr       presumably the dimensions that were reduced - confirm
     * @param sDVariable variable to expand
     * @return {@code sDVariable} itself, or the variable produced by the chained expansions
     */
    public static SDVariable reductionBroadcastableWithOrigShape(int i, int[] iArr, SDVariable sDVariable) {
        if (Shape.isWholeArray(i, iArr) || (i == 2 && iArr.length == 1)) {
            return sDVariable;
        }
        SDVariable expanded = sDVariable;
        for (int dimension : iArr) {
            expanded = expanded.getSameDiff().expandDims(expanded, dimension);
        }
        return expanded;
    }

    /**
     * Reshapes {@code sDVariable3} to the shape computed by
     * {@link #reductionShape(SDVariable, SDVariable, boolean)} from the shape of
     * {@code sDVariable} and {@code sDVariable2}, with the boolean flag set to {@code true}.
     *
     * @param sDVariable  variable whose {@code shape()} is used as input to the shape computation
     * @param sDVariable2 second input of the shape computation (presumably the reduction axes -
     *                    confirm)
     * @param sDVariable3 variable to reshape
     * @return the reshaped variable
     */
    public static SDVariable reductionBroadcastableWithOrigShape(SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3) {
        SDVariable targetShape = reductionShape(sDVariable.shape(), sDVariable2, true);
        return sDVariable3.reshape(targetShape);
    }

    /**
     * Creates a {@link ReductionShape} op in the SameDiff instance of {@code sDVariable} and
     * returns its output variable.
     *
     * @param sDVariable  first op input; its SameDiff instance owns the new op
     * @param sDVariable2 second op input
     * @param z           flag forwarded to {@link ReductionShape} (presumably "keep dimensions" -
     *                    confirm)
     * @return the output variable of the new op
     */
    public static SDVariable reductionShape(SDVariable sDVariable, SDVariable sDVariable2, boolean z) {
        ReductionShape shapeOp = new ReductionShape(sDVariable.getSameDiff(), sDVariable, sDVariable2, z);
        return shapeOp.outputVariable();
    }

    /**
     * Checks that a variable is non-null and belongs to the expected SameDiff instance.
     *
     * @param sameDiff             SameDiff instance the variable must belong to (compared by
     *                             reference)
     * @param sDVariable           variable to validate
     * @param differentialFunction function that is only used in the error message
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, when
     *                               {@code sDVariable} is {@code null} or is attached to a
     *                               different SameDiff instance
     */
    public static void validateDifferentialFunctionSameDiff(SameDiff sameDiff, SDVariable sDVariable, DifferentialFunction differentialFunction) {
        Preconditions.checkState(sDVariable != null, "Passed in function was null.");
        // A single check with a descriptive message: a preceding message-less check on the same
        // condition would always fail first and hide this diagnostic.
        Preconditions.checkState(sDVariable.getSameDiff() == sameDiff, "Function applications must be contained in same sameDiff. The left %s must match this function %s", sDVariable, differentialFunction);
    }

    /** Not instantiable: static utility class. */
    private SameDiffUtils() {
    }
}
