package org.deeplearning4j.zoo.util.imagenet;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.zoo.util.BaseLabels;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.databind.ObjectMapper;

/* loaded from: input_file:org/deeplearning4j/zoo/util/imagenet/ImageNetLabels.class */
/**
 * Class labels for ImageNet models, read from the bundled {@code imagenet_class_index.json}
 * classpath resource.
 *
 * <p>The resource is parsed as a JSON object keyed by the decimal class index ({@code "0"},
 * {@code "1"}, ...). Each value is a list whose second element is the human-readable label.
 */
public class ImageNetLabels extends BaseLabels {
    /** Classpath resource (relative to this class) holding the class-index JSON. */
    private static final String JSON_RESOURCE = "imagenet_class_index.json";

    /** Number of highest-scoring classes reported per batch row by {@link #decodePredictions}. */
    private static final int TOP_N = 5;

    private ArrayList<String> predictionLabels;

    /**
     * Loads the ImageNet labels from the classpath.
     *
     * @throws IOException if the label resource is missing, unreadable or malformed
     */
    public ImageNetLabels() throws IOException {
        // NOTE(review): the explicit reset before reloading suggests the superclass constructor
        // may already invoke getLabels(); confirm against BaseLabels before simplifying.
        this.predictionLabels = null;
        this.predictionLabels = getLabels();
    }

    /**
     * Returns the label list, loading and caching it on first use.
     *
     * @return labels indexed by class number (the cached instance, not a copy)
     * @throws IOException if the label resource is missing, unreadable or malformed
     */
    @Override
    protected ArrayList<String> getLabels() throws IOException {
        if (this.predictionLabels == null) {
            this.predictionLabels = loadLabels();
        }
        return this.predictionLabels;
    }

    /** Parses the class-index JSON resource into a list ordered by class index. */
    private ArrayList<String> loadLabels() throws IOException {
        try (InputStream in = getClass().getResourceAsStream(JSON_RESOURCE)) {
            if (in == null) {
                throw new IOException("Resource not found: " + JSON_RESOURCE);
            }
            Map<?, ?> index = new ObjectMapper().readValue(in, HashMap.class);
            ArrayList<String> labels = new ArrayList<>(index.size());
            for (int i = 0; i < index.size(); i++) {
                Object entry = index.get(String.valueOf(i));
                // Each entry must be a list with at least two elements; the label is the second.
                if (!(entry instanceof List) || ((List<?>) entry).size() < 2) {
                    throw new IOException(
                            "Malformed or missing entry for class index " + i + " in " + JSON_RESOURCE);
                }
                labels.add(String.valueOf(((List<?>) entry).get(1)));
            }
            return labels;
        }
    }

    /**
     * Returns the label for the given class index.
     *
     * @param i class index
     * @return the label stored at that index
     */
    @Override
    public String getLabel(int i) {
        return this.predictionLabels.get(i);
    }

    /**
     * Formats the top-scoring classes of every row of {@code predictions} as text.
     *
     * <p>Each row (batch entry) gets a header followed by up to {@value #TOP_N} lines of the form
     * {@code "\t<score*100>%, <label>"}, in descending score order. The batch number appears in
     * the header only when there is more than one row. The input array is not modified.
     *
     * @param predictions score matrix with one row per batch entry and one column per class
     *     (assumes column {@code k} corresponds to {@code getLabel(k)} - TODO confirm for callers)
     * @return the formatted description
     */
    public String decodePredictions(INDArray predictions) {
        StringBuilder description = new StringBuilder();
        long batchSize = predictions.size(0);
        for (int batch = 0; batch < batchSize; batch++) {
            if (batch > 0) {
                // Keep each batch's header on its own line instead of gluing it to the last label.
                description.append('\n');
            }
            description.append("Predictions for batch ");
            if (batchSize > 1) {
                description.append(batch);
            }
            description.append(" :");

            // Work on a copy so that masking selected entries does not touch the caller's array.
            INDArray row = predictions.getRow(batch).dup();
            int count = (int) Math.min(TOP_N, row.length());
            // The counter is per batch, so every row gets its own top-N list.
            for (int rank = 0; rank < count; rank++) {
                int classIndex = Nd4j.argMax(row, 1).getInt(0, 0);
                // `row` is a single-row copy, so the score lives in row 0, not row `batch`.
                float score = row.getFloat(0, classIndex);
                // Mask with -inf (not 0) so the same class cannot be picked again,
                // even when the remaining scores are zero or negative.
                row.putScalar(0, classIndex, Double.NEGATIVE_INFINITY);
                description.append("\n\t")
                        .append(String.format("%3f", score * 100.0f))
                        .append("%, ")
                        .append(this.predictionLabels.get(classIndex));
            }
        }
        return description.toString();
    }
}
