package com.github.chen0040.rl.learning.sarsa;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.github.chen0040.rl.actionselection.AbstractActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategy;
import com.github.chen0040.rl.actionselection.ActionSelectionStrategyFactory;
import com.github.chen0040.rl.actionselection.EpsilonGreedyActionSelectionStrategy;
import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
import java.io.Serializable;
import java.util.Objects;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:com/github/chen0040/rl/learning/sarsa/SarsaLearner.class */
/**
 * On-policy temporal-difference (SARSA) learner built on a {@link QModel} and an
 * {@link ActionSelectionStrategy}.
 *
 * <p>The learner chooses actions through the configured strategy and, after each
 * transition, moves the stored Q-value of the (state, action) pair towards
 * {@code reward + gamma * Q(nextState, nextAction)}.
 *
 * <p>The {@code model} and {@code actionSelection} bean properties are what
 * {@link #toJson()} / {@link #fromJson(String)} round-trip; the strategy is exposed
 * as a string via {@link ActionSelectionStrategyFactory}.
 */
public class SarsaLearner implements Serializable, Cloneable {
    protected QModel model;
    private ActionSelectionStrategy actionSelectionStrategy;

    /**
     * Serializes this learner to a JSON string using fastjson with the
     * {@code BrowserCompatible} feature enabled.
     *
     * @return the JSON representation of this learner
     */
    public String toJson() {
        return JSON.toJSONString(this, SerializerFeature.BrowserCompatible);
    }

    /**
     * Rebuilds a learner from JSON text, typically produced by {@link #toJson()}.
     *
     * @param str the JSON text
     * @return the learner parsed by fastjson
     */
    public static SarsaLearner fromJson(String str) {
        return JSON.parseObject(str, SarsaLearner.class);
    }

    /**
     * Creates a new learner whose state is copied from this one via
     * {@link #copy(SarsaLearner)}.
     *
     * @return an independent copy of this learner
     */
    public SarsaLearner makeCopy() {
        SarsaLearner duplicate = new SarsaLearner();
        duplicate.copy(this);
        return duplicate;
    }

    /**
     * Overwrites this learner's model and action-selection strategy with copies of
     * the ones held by {@code sarsaLearner}.
     *
     * <p>A field that is unset ({@code null}) on the source -- as is the case for a
     * learner created with the no-arg constructor -- is copied as {@code null}
     * instead of failing.
     *
     * @param sarsaLearner the learner to copy from; must not be {@code null}
     * @throws NullPointerException if {@code sarsaLearner} is {@code null}
     * @throws ClassCastException if the source strategy is not an
     *         {@link AbstractActionSelectionStrategy}, since cloning goes through that type
     */
    public void copy(SarsaLearner sarsaLearner) {
        Objects.requireNonNull(sarsaLearner, "sarsaLearner");

        QModel sourceModel = sarsaLearner.model;
        this.model = sourceModel == null ? null : sourceModel.makeCopy();

        ActionSelectionStrategy sourceStrategy = sarsaLearner.actionSelectionStrategy;
        this.actionSelectionStrategy = sourceStrategy == null
                ? null
                : (ActionSelectionStrategy) ((AbstractActionSelectionStrategy) sourceStrategy).clone();
    }

    /**
     * Compares this learner with another object. Two learners are equal when both
     * their models and their action-selection strategies are equal; unset
     * ({@code null}) fields only match other unset fields.
     *
     * <p>NOTE(review): {@code hashCode()} is not overridden, so equal learners may
     * report different hash codes. Avoid relying on value equality in hash-based
     * collections until that is addressed.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is a {@code SarsaLearner} with an equal
     *         model and an equal strategy
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof SarsaLearner)) {
            // instanceof is false for null, so this also covers obj == null.
            return false;
        }
        SarsaLearner other = (SarsaLearner) obj;
        return Objects.equals(this.model, other.model)
                && Objects.equals(this.actionSelectionStrategy, other.actionSelectionStrategy);
    }

    /**
     * @return the Q-model backing this learner; may be {@code null} if never set
     */
    public QModel getModel() {
        return this.model;
    }

    /**
     * Replaces the Q-model backing this learner.
     *
     * @param qModel the model to use from now on
     */
    public void setModel(QModel qModel) {
        this.model = qModel;
    }

    /**
     * @return the current action-selection strategy in the string form produced by
     *         {@link ActionSelectionStrategyFactory#serialize}
     */
    public String getActionSelection() {
        return ActionSelectionStrategyFactory.serialize(this.actionSelectionStrategy);
    }

    /**
     * Replaces the action-selection strategy with the one described by {@code str}.
     *
     * @param str a strategy description understood by
     *        {@link ActionSelectionStrategyFactory#deserialize}
     */
    public void setActionSelection(String str) {
        this.actionSelectionStrategy = ActionSelectionStrategyFactory.deserialize(str);
    }

    /**
     * Creates an empty learner with no model and no strategy. Both must be supplied
     * (for example through the setters or JSON deserialization) before the learner
     * can select actions or be updated.
     */
    public SarsaLearner() {
    }

    /**
     * Creates a learner with alpha {@code 0.1}, gamma {@code 0.7} and {@code 0.1} as
     * the third {@link QModel} constructor argument.
     *
     * @param i  first {@link QModel} constructor argument (the number of states)
     * @param i2 second {@link QModel} constructor argument (the number of actions)
     */
    public SarsaLearner(int i, int i2) {
        this(i, i2, 0.1d, 0.7d, 0.1d);
    }

    /**
     * Creates a learner around an existing model and strategy. Both references are
     * stored as given, without copying.
     *
     * @param qModel the Q-model to learn into
     * @param actionSelectionStrategy the strategy used to choose actions
     */
    public SarsaLearner(QModel qModel, ActionSelectionStrategy actionSelectionStrategy) {
        this.model = qModel;
        this.actionSelectionStrategy = actionSelectionStrategy;
    }

    /**
     * Creates a learner with a fresh {@link QModel} and an epsilon-greedy
     * action-selection strategy.
     *
     * @param i  first {@link QModel} constructor argument (the number of states)
     * @param i2 second {@link QModel} constructor argument (the number of actions)
     * @param d  value passed to {@link QModel#setAlpha} (the learning rate)
     * @param d2 value passed to {@link QModel#setGamma} (the discount factor)
     * @param d3 third {@link QModel} constructor argument (presumably the initial
     *           Q-value -- confirm against {@code QModel})
     */
    public SarsaLearner(int i, int i2, double d, double d2, double d3) {
        this.model = new QModel(i, i2, d3);
        this.model.setAlpha(d);
        this.model.setGamma(d2);
        this.actionSelectionStrategy = new EpsilonGreedyActionSelectionStrategy();
    }

    /**
     * Small console demo: runs 1000 SARSA steps against random transitions and
     * random rewards, printing each step.
     *
     * @param strArr ignored
     */
    public static void main(String[] strArr) {
        final int stateCount = 100;
        final int actionCount = 10;

        SarsaLearner learner = new SarsaLearner(stateCount, actionCount);
        Random random = new Random();

        int currentState = random.nextInt(stateCount);
        int currentAction = learner.selectAction(currentState).getIndex();

        for (int step = 0; step < 1000; step++) {
            System.out.println("Controller does action-" + currentAction);

            // The successor state must come from the state space, not the action
            // space; otherwise only the first few states are ever visited.
            int nextState = random.nextInt(stateCount);
            double reward = random.nextDouble();
            System.out.println("Now the new state is " + nextState);
            System.out.println("Controller receives Reward = " + reward);

            int nextAction = learner.selectAction(nextState).getIndex();
            System.out.println("Controller is expected to do action-" + nextAction);

            learner.update(currentState, currentAction, nextState, nextAction, reward);

            currentState = nextState;
            currentAction = nextAction;
        }
    }

    /**
     * Asks the action-selection strategy to choose an action for a state.
     *
     * @param i   the state index
     * @param set forwarded unchanged to the strategy as its third argument
     *            (presumably the set of permitted action indices -- confirm against
     *            the strategy); {@link #selectAction(int)} passes {@code null}
     * @return the strategy's result for the given state and this learner's model
     */
    public IndexValue selectAction(int i, Set<Integer> set) {
        return this.actionSelectionStrategy.selectAction(i, this.model, set);
    }

    /**
     * Chooses an action for a state, passing {@code null} as the action set.
     *
     * @param i the state index
     * @return the strategy's result for the given state
     */
    public IndexValue selectAction(int i) {
        return selectAction(i, null);
    }

    /**
     * Applies one SARSA update to the Q-value of the pair ({@code i}, {@code i2}):
     * {@code Q(s,a) <- Q(s,a) + alpha(s,a) * (r + gamma * Q(s',a') - Q(s,a))}.
     *
     * @param i  the state {@code s} in which the action was taken
     * @param i2 the action {@code a} that was taken
     * @param i3 the state {@code s'} reached afterwards
     * @param i4 the action {@code a'} chosen for {@code s'}
     * @param d  the reward {@code r} received for the transition
     */
    public void update(int i, int i2, int i3, int i4, double d) {
        double currentQ = this.model.getQ(i, i2);
        double alpha = this.model.getAlpha(i, i2);
        double gamma = this.model.getGamma();
        double nextQ = this.model.getQ(i3, i4);

        // TD target: immediate reward plus discounted value of the pair actually chosen next.
        double target = d + (gamma * nextQ);
        this.model.setQ(i, i2, currentQ + (alpha * (target - currentQ)));
    }
}
