package moa.recommender.rc.predictor.impl;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import moa.recommender.rc.data.RecommenderData;
import moa.recommender.rc.utils.Pair;
import moa.recommender.rc.utils.Rating;
import moa.recommender.rc.utils.SparseVector;
import moa.recommender.rc.utils.Updatable;

/* loaded from: input_file:moa/recommender/rc/predictor/impl/BRISMFPredictor.class */
/**
 * Matrix-factorization rating predictor with user and item bias terms, trained by stochastic
 * gradient descent (named after the BRISMF model: biased regularized incremental simultaneous
 * matrix factorization).
 *
 * <p>Every user and item is represented by a {@code float[nFeatures]} vector:
 * <ul>
 *   <li>component 0 of a user vector is pinned to 1 and never updated, so component 0 of an item
 *       vector acts as that item's bias;</li>
 *   <li>component 1 of an item vector is pinned to 1 and never updated, so component 1 of a user
 *       vector acts as that user's bias;</li>
 *   <li>components 2.. are the latent factors.</li>
 * </ul>
 * A prediction is the data set's global mean plus the dot product of the two vectors, clamped to
 * the data set's [min, max] rating range.
 *
 * <p>The constructor registers this predictor with its {@link RecommenderData} via
 * {@code attachUpdatable}; the {@link Updatable} callbacks below then add, retrain or drop
 * individual user/item vectors.
 */
public class BRISMFPredictor implements Updatable {
    private static final double DEFAULT_LEARNING_RATE = 0.01d;
    private static final double DEFAULT_REGULARIZATION = 0.02d;
    private static final int DEFAULT_ITERATIONS = 30;
    /** Minimum decrease in validation RMSE for a full-training epoch to count as an improvement. */
    private static final double MIN_RMSE_IMPROVEMENT = 1.0E-4d;
    /** Below this many ratings a user/item is always retrained on a new rating. */
    private static final double ALWAYS_RETRAIN_BELOW = 5.0d;
    /** Base of the retraining probability {@code RETRAIN_DECAY^count} once a count is reached. */
    private static final double RETRAIN_DECAY = 0.99d;

    protected RecommenderData data;
    /** Length of every feature vector, including the two bias slots. */
    protected int nFeatures;
    /** Feature vectors keyed by user id (component 0 pinned to 1, component 1 = user bias). */
    protected HashMap<Integer, float[]> userFeature;
    /** Feature vectors keyed by item id (component 0 = item bias, component 1 pinned to 1). */
    protected HashMap<Integer, float[]> itemFeature;
    /** Fixed-seed randomness used for feature initialisation and retraining decisions. */
    protected Random rnd;
    /** SGD learning rate. */
    protected double lRate;
    /** L2 regularization factor applied in every SGD step. */
    protected double rFactor;
    /** Number of passes used when (re)training a single user or item vector. */
    protected int nIterations;

    public void setLRate(double d) {
        this.lRate = d;
    }

    public void setRFactor(double d) {
        this.rFactor = d;
    }

    public void setNIterations(int i) {
        this.nIterations = i;
    }

    public RecommenderData getData() {
        return this.data;
    }

    /**
     * Creates a predictor with learning rate 0.01 and regularization factor 0.02.
     *
     * @param nFeatures length of each feature vector, including the two bias slots; at least 2
     * @param data rating data; this predictor registers itself with it
     * @param train whether to run {@link #train()} immediately
     * @throws IllegalArgumentException if {@code nFeatures < 2}
     */
    public BRISMFPredictor(int nFeatures, RecommenderData data, boolean train) {
        this(nFeatures, data, DEFAULT_LEARNING_RATE, DEFAULT_REGULARIZATION, train);
    }

    /**
     * Creates a predictor with an explicit learning rate and regularization factor.
     *
     * @param nFeatures length of each feature vector, including the two bias slots; at least 2
     * @param data rating data; this predictor registers itself with it
     * @param lRate SGD learning rate
     * @param rFactor L2 regularization factor
     * @param train whether to run {@link #train()} immediately
     * @throws IllegalArgumentException if {@code nFeatures < 2}
     */
    public BRISMFPredictor(int nFeatures, RecommenderData data, double lRate, double rFactor, boolean train) {
        if (nFeatures < 2) {
            // Slots 0 and 1 are the pinned/bias components; fewer would crash resetFeatures later.
            throw new IllegalArgumentException(
                    "nFeatures must be at least 2 (two bias slots), got " + nFeatures);
        }
        this.data = data;
        this.nFeatures = nFeatures;
        this.userFeature = new HashMap<>();
        this.itemFeature = new HashMap<>();
        this.rnd = new Random(12345L);
        this.lRate = lRate;
        this.rFactor = rFactor;
        this.nIterations = DEFAULT_ITERATIONS;
        // NOTE(review): publishes 'this' and calls the overridable train() during construction;
        // kept as-is for compatibility with existing callers.
        data.attachUpdatable(this);
        if (train) {
            train();
        }
    }

    /**
     * Fills {@code feats} with uniform noise in [-0.01, 0.01) and pins the proper bias slot to 1.
     *
     * @param feats vector to initialise in place
     * @param isUser {@code true} for a user vector (pins component 0), {@code false} for an item
     *     vector (pins component 1)
     */
    private void resetFeatures(float[] feats, boolean isUser) {
        for (int k = 0; k < feats.length; k++) {
            feats[k] = 0.01f * ((this.rnd.nextFloat() * 2.0f) - 1.0f);
        }
        feats[isUser ? 0 : 1] = 1.0f;
    }

    /**
     * Performs one regularized SGD step on a single component.
     *
     * @param own current value of the component being updated
     * @param other matching component of the partner (item for a user, user for an item) vector
     * @param err prediction error (actual minus predicted rating)
     * @return the updated component value
     */
    private float sgdUpdate(float own, double other, double err) {
        return (float) (own + this.lRate * (err * other - this.rFactor * own));
    }

    /**
     * Predicts the rating of a user for an item. If either id has no feature vector, the clamped
     * global mean is returned.
     */
    public double predictRating(int i, int i2) {
        return predictRating(this.userFeature.get(i), this.itemFeature.get(i2));
    }

    /**
     * Predicts a rating from explicit feature vectors.
     *
     * @param userFeats user vector, or {@code null}
     * @param itemFeats item vector, or {@code null}
     * @return global mean plus the dot product (only when both vectors are non-null), clamped to
     *     the data's [min, max] rating range
     */
    public double predictRating(float[] userFeats, float[] itemFeats) {
        double prediction = this.data.getGlobalMean();
        if (userFeats != null && itemFeats != null) {
            for (int k = 0; k < this.nFeatures; k++) {
                // Each product is evaluated in float precision, then accumulated in double.
                prediction += userFeats[k] * itemFeats[k];
            }
        }
        if (prediction < this.data.getMinRating()) {
            prediction = this.data.getMinRating();
        } else if (prediction > this.data.getMaxRating()) {
            prediction = this.data.getMaxRating();
        }
        return prediction;
    }

    /**
     * Trains a fresh user vector against the current item vectors; item vectors are not modified.
     *
     * @param items ids of the items the user rated
     * @param ratings ratings aligned index-by-index with {@code items}
     * @param nIters number of passes over the ratings
     * @return the new user vector; items without a feature vector are skipped
     */
    public float[] trainUserFeats(List<Integer> items, List<Double> ratings, int nIters) {
        float[] feats = new float[this.nFeatures];
        resetFeatures(feats, true);
        int n = items.size();
        for (int iter = 0; iter < nIters; iter++) {
            for (int j = 0; j < n; j++) {
                float[] itemFeats = this.itemFeature.get(items.get(j));
                if (itemFeats == null) {
                    continue;
                }
                double err = ratings.get(j).doubleValue() - predictRating(feats, itemFeats);
                // Component 0 stays pinned to 1; component 1 (user bias) and the factors learn.
                for (int k = 1; k < this.nFeatures; k++) {
                    feats[k] = sgdUpdate(feats[k], itemFeats[k], err);
                }
            }
        }
        return feats;
    }

    /**
     * Trains a fresh item vector against the current user vectors; user vectors are not modified.
     *
     * @param itemId id of the item (currently unused by the computation)
     * @param users ids of the users who rated the item
     * @param ratings ratings aligned index-by-index with {@code users}
     * @param nIters number of passes over the ratings
     * @return the new item vector; users without a feature vector are skipped
     */
    public float[] trainItemFeats(int itemId, List<Integer> users, List<Double> ratings, int nIters) {
        float[] feats = new float[this.nFeatures];
        resetFeatures(feats, false);
        int n = users.size();
        for (int iter = 0; iter < nIters; iter++) {
            for (int j = 0; j < n; j++) {
                float[] userFeats = this.userFeature.get(users.get(j));
                if (userFeats == null) {
                    continue;
                }
                double err = ratings.get(j).doubleValue() - predictRating(userFeats, feats);
                // Component 0 (item bias) learns, component 1 stays pinned to 1, factors learn.
                feats[0] = sgdUpdate(feats[0], userFeats[0], err);
                for (int k = 2; k < this.nFeatures; k++) {
                    feats[k] = sgdUpdate(feats[k], userFeats[k], err);
                }
            }
        }
        return feats;
    }

    public void trainUser(int i, List<Integer> list, List<Double> list2, int i2) {
        this.userFeature.put(i, trainUserFeats(list, list2, i2));
    }

    /** Retrains user {@code i} from its stored ratings using {@code i2} passes. */
    public void trainUser(int i, int i2) {
        List<Integer> items = new ArrayList<>();
        List<Double> ratings = new ArrayList<>();
        collectRatings(this.data.getRatingsUser(i), items, ratings);
        trainUser(i, items, ratings, i2);
    }

    public void trainUser(int i, List<Integer> list, List<Double> list2) {
        this.userFeature.put(i, trainUserFeats(list, list2, this.nIterations));
    }

    /** Retrains item {@code i} from its stored ratings using {@link #nIterations} passes. */
    public void trainItem(int i) {
        List<Integer> users = new ArrayList<>();
        List<Double> ratings = new ArrayList<>();
        collectRatings(this.data.getRatingsItem(i), users, ratings);
        trainItem(i, users, ratings);
    }

    /** Retrains item {@code i} from its stored ratings using {@code i2} passes. */
    public void trainItem(int i, int i2) {
        List<Integer> users = new ArrayList<>();
        List<Double> ratings = new ArrayList<>();
        collectRatings(this.data.getRatingsItem(i), users, ratings);
        trainItem(i, users, ratings, i2);
    }

    /** Retrains user {@code i} from its stored ratings using {@link #nIterations} passes. */
    public void trainUser(int i) {
        List<Integer> items = new ArrayList<>();
        List<Double> ratings = new ArrayList<>();
        collectRatings(this.data.getRatingsUser(i), items, ratings);
        trainUser(i, items, ratings);
    }

    public void trainItem(int i, List<Integer> list, List<Double> list2) {
        this.itemFeature.put(i, trainItemFeats(i, list, list2, this.nIterations));
    }

    public void trainItem(int i, List<Integer> list, List<Double> list2, int i2) {
        this.itemFeature.put(i, trainItemFeats(i, list, list2, i2));
    }

    /**
     * Copies a sparse rating vector into two parallel lists (ids and ratings), in iteration order.
     */
    private static void collectRatings(SparseVector source, List<Integer> ids, List<Double> values) {
        Iterator<Pair<Integer, Double>> it = source.iterator();
        while (it.hasNext()) {
            Pair<Integer, Double> entry = it.next();
            ids.add(entry.getFirst());
            values.add(entry.getSecond());
        }
    }

    /**
     * Like {@link #collectRatings(SparseVector, List, List)}, but the rating for {@code id} is
     * replaced by {@code value}, or the pair is appended if {@code id} is not present.
     */
    private static void collectRatingsWith(SparseVector source, int id, double value,
            List<Integer> ids, List<Double> values) {
        boolean found = false;
        Iterator<Pair<Integer, Double>> it = source.iterator();
        while (it.hasNext()) {
            Pair<Integer, Double> entry = it.next();
            ids.add(entry.getFirst());
            if (entry.getFirst().intValue() == id) {
                found = true;
                values.add(value);
            } else {
                values.add(entry.getSecond());
            }
        }
        if (!found) {
            ids.add(id);
            values.add(value);
        }
    }

    /**
     * Re-initialises all user and item vectors and trains them with SGD over all ratings.
     *
     * <p>Every {@code step}-th rating ({@code step = max(20, numRatings / 1_000_000)}) is held out
     * as a validation set on the first epoch and never trained on. Training stops after the first
     * epoch whose validation RMSE does not drop by at least {@code 1e-4}. A NaN RMSE (empty data
     * set or diverging training) also stops it, instead of looping forever. Each epoch prints
     * "RMSE seconds" to stdout.
     */
    public void train() {
        this.userFeature.clear();
        this.itemFeature.clear();
        int numRatings = this.data.getNumRatings();
        initFeatures(this.data.getUsers().iterator(), this.userFeature, true);
        initFeatures(this.data.getItems().iterator(), this.itemFeature, false);

        int step = Math.max(20, numRatings / 1000000);
        List<Rating> validation = new ArrayList<>(numRatings / step);
        double lastRmse = 1.0E20d;
        int epoch = 0;
        boolean improving;
        do {
            long start = System.currentTimeMillis();
            Iterator<Rating> ratings = this.data.ratingIterator();
            int index = 0;
            while (ratings.hasNext()) {
                Rating r = ratings.next();
                if (index % step != 0) {
                    trainOnRating(r);
                } else if (epoch == 0) {
                    validation.add(r);
                }
                index++;
            }
            double rmse = validationRmse(validation);
            System.out.println(rmse + " " + ((System.currentTimeMillis() - start) / 1000));
            // Written so that a NaN RMSE counts as "not improving" and ends the loop.
            improving = rmse + MIN_RMSE_IMPROVEMENT < lastRmse;
            lastRmse = rmse;
            epoch++;
        } while (improving);
    }

    /** Creates a randomly initialised vector for every id yielded by {@code ids}. */
    private void initFeatures(Iterator<Integer> ids, HashMap<Integer, float[]> target, boolean isUser) {
        while (ids.hasNext()) {
            float[] feats = new float[this.nFeatures];
            resetFeatures(feats, isUser);
            target.put(ids.next(), feats);
        }
    }

    /**
     * Applies one simultaneous SGD step to the user and item vectors of a single rating. Both
     * vectors are assumed to exist (created by {@link #train()}).
     */
    private void trainOnRating(Rating r) {
        float[] u = this.userFeature.get(r.userID);
        float[] v = this.itemFeature.get(r.itemID);
        double err = r.rating - predictRating(u, v);
        v[0] = sgdUpdate(v[0], u[0], err); // item bias (u[0] is pinned to 1)
        u[1] = sgdUpdate(u[1], v[1], err); // user bias (v[1] is pinned to 1)
        for (int k = 2; k < this.nFeatures; k++) {
            float oldUser = u[k]; // the item update must see the pre-step user value
            u[k] = sgdUpdate(u[k], v[k], err);
            v[k] = sgdUpdate(v[k], oldUser, err);
        }
    }

    /** Root-mean-square error of the current model on {@code validation}; NaN if it is empty. */
    private double validationRmse(List<Rating> validation) {
        double sse = 0.0d;
        for (Rating r : validation) {
            sse += Math.pow(r.rating - predictRating(r.userID, r.itemID), 2.0d);
        }
        return Math.sqrt(sse / validation.size());
    }

    /** Returns the live (mutable) feature vector of user {@code i}, or {@code null}. */
    public float[] getUserFeatures(int i) {
        return this.userFeature.get(i);
    }

    /** Returns the live (mutable) feature vector of item {@code i}, or {@code null}. */
    public float[] getItemFeatures(int i) {
        return this.itemFeature.get(i);
    }

    public int getNumFeatures() {
        return this.nFeatures;
    }

    /** Trains a vector for a newly added user unless it has no ratings. */
    @Override // moa.recommender.rc.utils.Updatable
    public void updateNewUser(int i, List<Integer> list, List<Double> list2) {
        if (!list.isEmpty()) {
            trainUser(i, list, list2);
        }
    }

    /** Trains a vector for a newly added item unless it has no ratings. */
    @Override // moa.recommender.rc.utils.Updatable
    public void updateNewItem(int i, List<Integer> list, List<Double> list2) {
        if (!list.isEmpty()) {
            trainItem(i, list, list2);
        }
    }

    @Override // moa.recommender.rc.utils.Updatable
    public void updateRemoveUser(int i) {
        this.userFeature.remove(i);
    }

    @Override // moa.recommender.rc.utils.Updatable
    public void updateRemoveItem(int i) {
        this.itemFeature.remove(i);
    }

    /**
     * Reacts to a rating being set by possibly retraining the user's and the item's vectors.
     *
     * <p>An entity with fewer than 5 stored ratings is always retrained. Otherwise it is retrained
     * with probability {@code 0.99^count}. The given rating overrides (or is appended to) the
     * stored ratings used for retraining. The user is considered first, then the item.
     */
    @Override // moa.recommender.rc.utils.Updatable
    public void updateSetRating(int i, int i2, double d) {
        double userCount = this.data.countRatingsUser(i);
        double itemCount = this.data.countRatingsItem(i2);
        if (userCount < ALWAYS_RETRAIN_BELOW
                || this.rnd.nextDouble() < Math.pow(RETRAIN_DECAY, userCount)) {
            List<Integer> items = new ArrayList<>();
            List<Double> ratings = new ArrayList<>();
            collectRatingsWith(this.data.getRatingsUser(i), i2, d, items, ratings);
            trainUser(i, items, ratings);
        }
        if (itemCount < ALWAYS_RETRAIN_BELOW
                || this.rnd.nextDouble() < Math.pow(RETRAIN_DECAY, itemCount)) {
            List<Integer> users = new ArrayList<>();
            List<Double> ratings = new ArrayList<>();
            // The item's rating list is keyed by *user* id, so a missing entry adds user i.
            collectRatingsWith(this.data.getRatingsItem(i2), i, d, users, ratings);
            trainItem(i2, users, ratings);
        }
    }

    /** Rating removal leaves all feature vectors unchanged. */
    @Override // moa.recommender.rc.utils.Updatable
    public void updateRemoveRating(int i, int i2) {
    }

    /** Predicts ratings of user {@code i} for each item in {@code list}, in order. */
    public List<Double> predictRatings(int i, List<Integer> list) {
        List<Double> predictions = new ArrayList<>(list.size());
        for (Integer itemId : list) {
            predictions.add(predictRating(i, itemId.intValue()));
        }
        return predictions;
    }
}
