package ai.djl.modality.rl.agent;

import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.translate.Batchifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/modality/rl/agent/QAgent.class */
/**
 * An {@link RlAgent} that scores every candidate action with the trainer's model and acts
 * greedily on those scores.
 *
 * <p>For each candidate action the model input is the observation followed by the action (see
 * {@link #buildInputs(NDList, List)}); all candidates are batched together with the configured
 * {@link Batchifier} and the model is expected to return a single array of scores.
 */
public class QAgent implements RlAgent {
    private final Trainer trainer;
    private final float rewardDiscount;
    private final Batchifier batchifier;

    /**
     * Creates an agent that batches its inputs with {@link Batchifier#STACK}.
     *
     * @param trainer the trainer whose model scores (observation, action) pairs
     * @param rewardDiscount the factor applied to the best post-step score when building the
     *     training target
     */
    public QAgent(Trainer trainer, float rewardDiscount) {
        this(trainer, rewardDiscount, Batchifier.STACK);
    }

    /**
     * Creates an agent with an explicit batchifier.
     *
     * @param trainer the trainer whose model scores (observation, action) pairs
     * @param rewardDiscount the factor applied to the best post-step score when building the
     *     training target
     * @param batchifier the batchifier used to combine the per-action inputs into one batch
     */
    public QAgent(Trainer trainer, float rewardDiscount, Batchifier batchifier) {
        this.trainer = trainer;
        this.rewardDiscount = rewardDiscount;
        this.batchifier = batchifier;
    }

    /**
     * Returns the action from the environment's action space with the highest model score.
     *
     * @param rlEnv the environment supplying the current observation and action space
     * @param training ignored by this agent; the choice is always greedy
     * @return the highest-scoring action
     */
    @Override // ai.djl.modality.rl.agent.RlAgent
    public NDList chooseAction(RlEnv rlEnv, boolean training) {
        ActionSpace actionSpace = rlEnv.getActionSpace();
        NDList[] inputs = buildInputs(rlEnv.getObservation(), actionSpace);
        NDArray actionScores =
                this.trainer
                        .evaluate(this.batchifier.batchify(inputs))
                        .singletonOrThrow()
                        .squeeze(-1);
        int bestAction = Math.toIntExact(actionScores.argMax().getLong());
        return actionSpace.get(bestAction);
    }

    /**
     * Runs one forward/backward pass per step and then notifies the trainer's listeners.
     *
     * <p>For every step the batch consists of the taken action paired with the
     * <em>pre</em>-observation (row 0) followed by each post-step action paired with the
     * post-observation (rows 1..n). Row 0 is the prediction; the target is {@code rewardDiscount *
     * max(rows 1..n) + reward}, or {@code reward} alone when there are no post-step actions.
     *
     * @param stepArr the environment steps to train on
     */
    @Override // ai.djl.modality.rl.agent.RlAgent
    public void trainBatch(RlEnv.Step[] stepArr) {
        TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        null, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        for (RlEnv.Step step : stepArr) {
            // The taken action must be scored against the state it was chosen in, i.e. the
            // observation from before the step, not the one the step produced.
            NDList[] preInput =
                    buildInputs(
                            step.getPreObservation(), Collections.singletonList(step.getAction()));
            NDList[] postInputs =
                    buildInputs(step.getPostObservation(), step.getPostActionSpace());
            NDList[] allInputs =
                    Stream.concat(Arrays.stream(preInput), Arrays.stream(postInputs))
                            .toArray(NDList[]::new);

            try (GradientCollector collector = this.trainer.newGradientCollector()) {
                NDArray results =
                        this.trainer
                                .forward(this.batchifier.batchify(allInputs))
                                .singletonOrThrow()
                                .squeeze(-1);
                NDList preQ = new NDList(results.get(0));

                // With no post-step actions there is nothing to bootstrap from: use 0.
                NDArray bestNext =
                        results.size() > 1
                                ? results.get("1:").max()
                                : results.getManager().create(0.0f);
                NDList targetQ = new NDList(bestNext.mul(this.rewardDiscount).add(step.getReward()));

                // Loss is evaluated as (labels, predictions), matching how the two lists are
                // recorded in batchData below.
                collector.backward(this.trainer.getLoss().evaluate(targetQ, preQ));
                batchData.getLabels().put(targetQ.get(0).getDevice(), targetQ);
                batchData.getPredictions().put(preQ.get(0).getDevice(), preQ);
            }
        }
        this.trainer.notifyListeners(
                listener -> listener.onTrainingBatch(this.trainer, batchData));
    }

    /**
     * Builds one model input per action by concatenating the observation with that action.
     *
     * @param observation the observation placed first in every input
     * @param actions the candidate actions, one input is produced for each
     * @return an array with the same length and order as {@code actions}
     */
    private NDList[] buildInputs(NDList observation, List<NDList> actions) {
        NDList[] inputs = new NDList[actions.size()];
        for (int i = 0; i < actions.size(); i++) {
            inputs[i] = new NDList().addAll(observation).addAll(actions.get(i));
        }
        return inputs;
    }
}
