package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/**
 * Normalizer for {@link MultiDataSet}s in which every feature (input) array and every label
 * (output) array can be normalized with its own {@link NormalizerStrategy}.
 *
 * <p>A global strategy may be configured for all inputs and/or all outputs. A per-index strategy,
 * when registered, takes precedence over the global one for that index. Arrays for which no
 * strategy applies are skipped by {@code fit}, {@code preProcess} and the revert methods.
 */
public class MultiNormalizerHybrid extends AbstractNormalizer implements MultiDataNormalization, Serializable {
    /** Stats per input index; null until {@code fit} or {@link #setInputStats} assigns it. */
    private Map<Integer, NormalizerStats> inputStats;
    /** Stats per output index; null until {@code fit} or {@link #setOutputStats} assigns it. */
    private Map<Integer, NormalizerStats> outputStats;
    /** Strategy used for every input index without a per-input override (may be null). */
    private NormalizerStrategy globalInputStrategy;
    /** Strategy used for every output index without a per-output override (may be null). */
    private NormalizerStrategy globalOutputStrategy;
    /** Per-input overrides of {@link #globalInputStrategy}, keyed by feature-array index. */
    private Map<Integer, NormalizerStrategy> perInputStrategies = new HashMap<>();
    /** Per-output overrides of {@link #globalOutputStrategy}, keyed by label-array index. */
    private Map<Integer, NormalizerStrategy> perOutputStrategies = new HashMap<>();

    /**
     * Uses a {@link StandardizeStrategy} as the global input strategy.
     *
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid standardizeAllInputs() {
        this.globalInputStrategy = new StandardizeStrategy();
        return this;
    }

    /**
     * Uses a default-constructed {@link MinMaxStrategy} as the global input strategy.
     *
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleAllInputs() {
        this.globalInputStrategy = new MinMaxStrategy();
        return this;
    }

    /**
     * Uses {@code new MinMaxStrategy(rangeFrom, rangeTo)} as the global input strategy.
     *
     * @param rangeFrom first bound forwarded to {@link MinMaxStrategy} (presumably the target minimum)
     * @param rangeTo second bound forwarded to {@link MinMaxStrategy} (presumably the target maximum)
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleAllInputs(double rangeFrom, double rangeTo) {
        this.globalInputStrategy = new MinMaxStrategy(rangeFrom, rangeTo);
        return this;
    }

    /**
     * Registers a {@link StandardizeStrategy} for one input, overriding the global input strategy.
     *
     * @param input index of the feature array
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid standardizeInput(int input) {
        this.perInputStrategies.put(input, new StandardizeStrategy());
        return this;
    }

    /**
     * Registers a default {@link MinMaxStrategy} for one input, overriding the global input strategy.
     *
     * @param input index of the feature array
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleInput(int input) {
        this.perInputStrategies.put(input, new MinMaxStrategy());
        return this;
    }

    /**
     * Registers {@code new MinMaxStrategy(rangeFrom, rangeTo)} for one input.
     *
     * @param input index of the feature array
     * @param rangeFrom first bound forwarded to {@link MinMaxStrategy}
     * @param rangeTo second bound forwarded to {@link MinMaxStrategy}
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleInput(int input, double rangeFrom, double rangeTo) {
        this.perInputStrategies.put(input, new MinMaxStrategy(rangeFrom, rangeTo));
        return this;
    }

    /**
     * Uses a {@link StandardizeStrategy} as the global output strategy.
     *
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid standardizeAllOutputs() {
        this.globalOutputStrategy = new StandardizeStrategy();
        return this;
    }

    /**
     * Uses a default-constructed {@link MinMaxStrategy} as the global output strategy.
     *
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleAllOutputs() {
        this.globalOutputStrategy = new MinMaxStrategy();
        return this;
    }

    /**
     * Uses {@code new MinMaxStrategy(rangeFrom, rangeTo)} as the global output strategy.
     *
     * @param rangeFrom first bound forwarded to {@link MinMaxStrategy}
     * @param rangeTo second bound forwarded to {@link MinMaxStrategy}
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleAllOutputs(double rangeFrom, double rangeTo) {
        this.globalOutputStrategy = new MinMaxStrategy(rangeFrom, rangeTo);
        return this;
    }

    /**
     * Registers a {@link StandardizeStrategy} for one output, overriding the global output strategy.
     *
     * @param output index of the label array
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid standardizeOutput(int output) {
        this.perOutputStrategies.put(output, new StandardizeStrategy());
        return this;
    }

    /**
     * Registers a default {@link MinMaxStrategy} for one output, overriding the global output strategy.
     *
     * @param output index of the label array
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleOutput(int output) {
        this.perOutputStrategies.put(output, new MinMaxStrategy());
        return this;
    }

    /**
     * Registers {@code new MinMaxStrategy(rangeFrom, rangeTo)} for one output.
     *
     * @param output index of the label array
     * @param rangeFrom first bound forwarded to {@link MinMaxStrategy}
     * @param rangeTo second bound forwarded to {@link MinMaxStrategy}
     * @return this normalizer, for chaining
     */
    public MultiNormalizerHybrid minMaxScaleOutput(int output, double rangeFrom, double rangeTo) {
        this.perOutputStrategies.put(output, new MinMaxStrategy(rangeFrom, rangeTo));
        return this;
    }

    /**
     * Returns the stats recorded for one input; calls {@code assertIsFit()} first.
     *
     * @param input index of the feature array
     * @return the stats for that index, or {@code null} if none were recorded for it
     */
    public NormalizerStats getInputStats(int input) {
        return getInputStats().get(input);
    }

    /**
     * Returns the stats recorded for one output; calls {@code assertIsFit()} first.
     *
     * @param output index of the label array
     * @return the stats for that index, or {@code null} if none were recorded for it
     */
    public NormalizerStats getOutputStats(int output) {
        return getOutputStats().get(output);
    }

    /**
     * Returns the stats map for all inputs, after calling {@code assertIsFit()}
     * (presumably throws when the normalizer has not been fit — see {@link AbstractNormalizer}).
     *
     * @return the live internal map (not a copy)
     */
    public Map<Integer, NormalizerStats> getInputStats() {
        assertIsFit();
        return this.inputStats;
    }

    /**
     * Returns the stats map for all outputs, after calling {@code assertIsFit()}.
     *
     * @return the live internal map (not a copy)
     */
    public Map<Integer, NormalizerStats> getOutputStats() {
        assertIsFit();
        return this.outputStats;
    }

    /**
     * Fits input and output stats on a single {@link MultiDataSet}, replacing any previous stats.
     *
     * @param dataSet the data to collect statistics from
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public void fit(@NonNull MultiDataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked @NonNull but is null");
        }
        Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
        Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();
        fitPartial(dataSet, inputBuilders, outputBuilders);
        this.inputStats = buildAllStats(inputBuilders);
        this.outputStats = buildAllStats(outputBuilders);
    }

    /**
     * Fits input and output stats on every batch of {@code iterator}, replacing any previous stats.
     * The iterator is reset before it is consumed; it is not reset afterwards.
     *
     * @param iterator source of the batches to collect statistics from
     * @throws NullPointerException if {@code iterator} is null
     */
    @Override
    public void fit(@NonNull MultiDataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        Map<Integer, NormalizerStats.Builder> inputBuilders = new HashMap<>();
        Map<Integer, NormalizerStats.Builder> outputBuilders = new HashMap<>();
        iterator.reset();
        while (iterator.hasNext()) {
            fitPartial(iterator.next(), inputBuilders, outputBuilders);
        }
        this.inputStats = buildAllStats(inputBuilders);
        this.outputStats = buildAllStats(outputBuilders);
    }

    /** Feeds one batch into the (lazily created) stats builders for inputs and outputs. */
    private void fitPartial(MultiDataSet dataSet, Map<Integer, NormalizerStats.Builder> inputBuilders,
                    Map<Integer, NormalizerStats.Builder> outputBuilders) {
        ensureStatsBuilders(inputBuilders, this.globalInputStrategy, this.perInputStrategies,
                        dataSet.numFeatureArrays());
        ensureStatsBuilders(outputBuilders, this.globalOutputStrategy, this.perOutputStrategies,
                        dataSet.numLabelsArrays());
        for (Map.Entry<Integer, NormalizerStats.Builder> entry : inputBuilders.entrySet()) {
            int index = entry.getKey();
            entry.getValue().add2(dataSet.getFeatures(index), dataSet.getFeaturesMaskArray(index));
        }
        for (Map.Entry<Integer, NormalizerStats.Builder> entry : outputBuilders.entrySet()) {
            int index = entry.getKey();
            entry.getValue().add2(dataSet.getLabels(index), dataSet.getLabelsMaskArray(index));
        }
    }

    /**
     * Creates one stats builder per index that has an applicable strategy. Only runs while
     * {@code builders} is empty, so the array count of the first batch decides which indices
     * are tracked for the rest of the fit.
     */
    private void ensureStatsBuilders(Map<Integer, NormalizerStats.Builder> builders, NormalizerStrategy globalStrategy,
                    Map<Integer, NormalizerStrategy> perArrayStrategies, int numArrays) {
        if (!builders.isEmpty()) {
            return;
        }
        for (int index = 0; index < numArrays; index++) {
            NormalizerStrategy strategy = getStrategy(globalStrategy, perArrayStrategies, index);
            if (strategy != null) {
                builders.put(index, strategy.newStatsBuilder());
            }
        }
    }

    /** Builds every stats builder into a fresh index-to-stats map. */
    private Map<Integer, NormalizerStats> buildAllStats(@NonNull Map<Integer, NormalizerStats.Builder> builders) {
        if (builders == null) {
            throw new NullPointerException("builders is marked @NonNull but is null");
        }
        Map<Integer, NormalizerStats> stats = new HashMap<>(builders.size());
        for (Map.Entry<Integer, NormalizerStats.Builder> entry : builders.entrySet()) {
            stats.put(entry.getKey(), entry.getValue().build());
        }
        return stats;
    }

    /**
     * Normalizes {@code data}; delegates to {@link #preProcess(MultiDataSet)}.
     *
     * @throws NullPointerException if {@code data} is null
     */
    @Override
    public void transform(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        preProcess(data);
    }

    /**
     * Applies each index's strategy to the feature and label arrays of {@code data} using the
     * fitted stats. Calls {@code assertIsFit()} via the stats getters before touching any array.
     *
     * @throws NullPointerException if {@code data} is null
     */
    @Override
    public void preProcess(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        preProcessArrays(data.getFeatures(), data.getFeaturesMaskArrays(), this.globalInputStrategy,
                        this.perInputStrategies, getInputStats());
        preProcessArrays(data.getLabels(), data.getLabelsMaskArrays(), this.globalOutputStrategy,
                        this.perOutputStrategies, getOutputStats());
    }

    /** Runs the applicable strategy over each array; a null {@code arrays} is a no-op. */
    private void preProcessArrays(INDArray[] arrays, INDArray[] masks, NormalizerStrategy globalStrategy,
                    Map<Integer, NormalizerStrategy> perArrayStrategies, Map<Integer, NormalizerStats> stats) {
        if (arrays == null) {
            return;
        }
        for (int i = 0; i < arrays.length; i++) {
            NormalizerStrategy strategy = getStrategy(globalStrategy, perArrayStrategies, i);
            if (strategy != null) {
                strategy.preProcess(arrays[i], masks == null ? null : masks[i], stats.get(i));
            }
        }
    }

    /**
     * Reverts normalization of all feature and label arrays of {@code data}.
     *
     * @throws NullPointerException if {@code data} is null
     */
    @Override
    public void revert(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked @NonNull but is null");
        }
        revertFeatures(data.getFeatures(), data.getFeaturesMaskArrays());
        revertLabels(data.getLabels(), data.getLabelsMaskArrays());
    }

    /** @return {@link NormalizerType#MULTI_HYBRID} */
    @Override
    public NormalizerType getType() {
        return NormalizerType.MULTI_HYBRID;
    }

    /**
     * Reverts every feature array without masks.
     *
     * @throws NullPointerException if {@code features} is null
     */
    @Override
    public void revertFeatures(@NonNull INDArray[] features) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        revertFeatures(features, null);
    }

    /**
     * Reverts every feature array. Nothing is returned, so the strategies presumably modify the
     * arrays in place.
     *
     * @param features feature arrays to revert
     * @param masks optional masks, indexed like {@code features}; may be null
     * @throws NullPointerException if {@code features} is null
     */
    @Override
    public void revertFeatures(@NonNull INDArray[] features, INDArray[] masks) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        for (int i = 0; i < features.length; i++) {
            revertFeatures(features, masks, i);
        }
    }

    /**
     * Reverts a single feature array, if a strategy applies to {@code input}.
     *
     * @param features feature arrays
     * @param masks optional masks; may be null
     * @param input index of the array to revert
     * @throws NullPointerException if {@code features} is null
     */
    public void revertFeatures(@NonNull INDArray[] features, INDArray[] masks, int input) {
        if (features == null) {
            throw new NullPointerException("features is marked @NonNull but is null");
        }
        NormalizerStrategy strategy = getStrategy(this.globalInputStrategy, this.perInputStrategies, input);
        if (strategy != null) {
            strategy.revert(features[input], masks == null ? null : masks[input], getInputStats(input));
        }
    }

    /**
     * Reverts every label array without masks.
     *
     * @throws NullPointerException if {@code labels} is null
     */
    @Override
    public void revertLabels(@NonNull INDArray[] labels) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        revertLabels(labels, null);
    }

    /**
     * Reverts every label array.
     *
     * @param labels label arrays to revert
     * @param masks optional masks, indexed like {@code labels}; may be null
     * @throws NullPointerException if {@code labels} is null
     */
    @Override
    public void revertLabels(@NonNull INDArray[] labels, INDArray[] masks) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        for (int i = 0; i < labels.length; i++) {
            revertLabels(labels, masks, i);
        }
    }

    /**
     * Reverts a single label array, if a strategy applies to {@code output}.
     *
     * @param labels label arrays
     * @param masks optional masks; may be null
     * @param output index of the array to revert
     * @throws NullPointerException if {@code labels} is null
     */
    public void revertLabels(@NonNull INDArray[] labels, INDArray[] masks, int output) {
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        NormalizerStrategy strategy = getStrategy(this.globalOutputStrategy, this.perOutputStrategies, output);
        if (strategy != null) {
            strategy.revert(labels[output], masks == null ? null : masks[output], getOutputStats(output));
        }
    }

    /**
     * Resolves the strategy for {@code index}. A key present in {@code perArrayStrategies} wins,
     * even when it is mapped to null; otherwise {@code globalStrategy} (possibly null) is used.
     */
    private NormalizerStrategy getStrategy(NormalizerStrategy globalStrategy,
                    Map<Integer, NormalizerStrategy> perArrayStrategies, int index) {
        if (perArrayStrategies.containsKey(index)) {
            return perArrayStrategies.get(index);
        }
        return globalStrategy;
    }

    /** The normalizer counts as fit once input stats have been assigned. */
    @Override
    protected boolean isFit() {
        return this.inputStats != null;
    }

    /**
     * Field-wise equality. Fields are read directly rather than through
     * {@link #getInputStats()}/{@link #getOutputStats()}, because those getters call
     * {@code assertIsFit()} and would make {@code equals} fail on an unfit normalizer.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiNormalizerHybrid)) {
            return false;
        }
        MultiNormalizerHybrid other = (MultiNormalizerHybrid) obj;
        return other.canEqual(this)
                        && Objects.equals(this.inputStats, other.inputStats)
                        && Objects.equals(this.outputStats, other.outputStats)
                        && Objects.equals(this.globalInputStrategy, other.globalInputStrategy)
                        && Objects.equals(this.globalOutputStrategy, other.globalOutputStrategy)
                        && Objects.equals(this.perInputStrategies, other.perInputStrategies)
                        && Objects.equals(this.perOutputStrategies, other.perOutputStrategies);
    }

    /** Symmetry hook for {@link #equals(Object)} in subclasses. */
    protected boolean canEqual(Object obj) {
        return obj instanceof MultiNormalizerHybrid;
    }

    /**
     * Hash over the same fields as {@link #equals(Object)}. It uses the same 59/43 mixing as before,
     * so fitted instances keep their previous hash values. Like {@code equals}, it does not call
     * {@code assertIsFit()}.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        Object[] fields = {this.inputStats, this.outputStats, this.globalInputStrategy,
                        this.globalOutputStrategy, this.perInputStrategies, this.perOutputStrategies};
        int result = 1;
        for (Object field : fields) {
            result = result * prime + (field == null ? nullHash : field.hashCode());
        }
        return result;
    }

    /** Replaces the per-input stats map; a non-null value makes the normalizer count as fit. */
    public void setInputStats(Map<Integer, NormalizerStats> inputStats) {
        this.inputStats = inputStats;
    }

    /** Replaces the per-output stats map. */
    public void setOutputStats(Map<Integer, NormalizerStats> outputStats) {
        this.outputStats = outputStats;
    }

    /** Sets the strategy used for inputs without a per-input override. */
    public void setGlobalInputStrategy(NormalizerStrategy globalInputStrategy) {
        this.globalInputStrategy = globalInputStrategy;
    }

    /** Sets the strategy used for outputs without a per-output override. */
    public void setGlobalOutputStrategy(NormalizerStrategy globalOutputStrategy) {
        this.globalOutputStrategy = globalOutputStrategy;
    }

    /** Replaces the per-input strategy map (stored as given, not copied). */
    public void setPerInputStrategies(Map<Integer, NormalizerStrategy> perInputStrategies) {
        this.perInputStrategies = perInputStrategies;
    }

    /** Replaces the per-output strategy map (stored as given, not copied). */
    public void setPerOutputStrategies(Map<Integer, NormalizerStrategy> perOutputStrategies) {
        this.perOutputStrategies = perOutputStrategies;
    }

    /** @return the global input strategy, possibly null */
    public NormalizerStrategy getGlobalInputStrategy() {
        return this.globalInputStrategy;
    }

    /** @return the global output strategy, possibly null */
    public NormalizerStrategy getGlobalOutputStrategy() {
        return this.globalOutputStrategy;
    }

    /** @return the live per-input strategy map */
    public Map<Integer, NormalizerStrategy> getPerInputStrategies() {
        return this.perInputStrategies;
    }

    /** @return the live per-output strategy map */
    public Map<Integer, NormalizerStrategy> getPerOutputStrategies() {
        return this.perOutputStrategies;
    }
}
