package com.github.chen0040.rl.actionselection;

import com.github.chen0040.rl.models.QModel;
import com.github.chen0040.rl.utils.IndexValue;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:com/github/chen0040/rl/actionselection/GibbsSoftMaxActionSelectionStrategy.class */
/**
 * Action-selection strategy that samples an action at random with probability
 * proportional to {@code exp(Q(state, action))} (Gibbs / soft-max sampling over
 * the Q-values reported by the supplied {@link QModel}).
 */
public class GibbsSoftMaxActionSelectionStrategy extends AbstractActionSelectionStrategy {
    /** Source of randomness used for the sampling draw. */
    private Random random;

    /** Creates a strategy backed by a freshly constructed {@link Random}. */
    public GibbsSoftMaxActionSelectionStrategy() {
        this.random = new Random();
    }

    /**
     * Creates a strategy backed by the given random source.
     *
     * @param random the generator used by {@link #selectAction}; it is stored as-is
     */
    public GibbsSoftMaxActionSelectionStrategy(Random random) {
        this.random = random;
    }

    /**
     * Returns a new strategy instance.
     *
     * <p>Note: the copy is built with the no-argument constructor, so it gets its own
     * new {@link Random} rather than the generator held by this instance.
     *
     * @return a new {@code GibbsSoftMaxActionSelectionStrategy}
     */
    @Override
    public Object clone() {
        return new GibbsSoftMaxActionSelectionStrategy();
    }

    /**
     * Samples one action for state {@code i} with probability proportional to
     * {@code exp(Q(i, action))}.
     *
     * <p>The weights are computed as {@code exp(Q - maxQ)} instead of {@code exp(Q)}.
     * The resulting distribution is mathematically the same, but the shifted form
     * cannot overflow to infinity for large Q-values, and the largest weight is
     * always 1 so the total cannot underflow to zero.
     *
     * @param i      the state whose Q-values are consulted
     * @param qModel the model providing the Q-values and the action count
     * @param set    the candidate actions, or {@code null} to consider every action
     *               index from 0 up to {@code qModel.getActionCount()} (exclusive)
     * @return the chosen action index and its Q-value; index {@code -1} with value
     *         {@link Double#NEGATIVE_INFINITY} when no action could be selected
     *         (for example when there are no candidates)
     */
    @Override
    public IndexValue selectAction(int i, QModel qModel, Set<Integer> set) {
        ArrayList<Integer> actions = candidateActions(qModel, set);
        int actionCount = actions.size();

        // Read each candidate's Q-value once and track the maximum for the shift below.
        double[] qValues = new double[actionCount];
        double maxQ = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < actionCount; k++) {
            qValues[k] = qModel.getQ(i, actions.get(k).intValue());
            maxQ = Math.max(maxQ, qValues[k]);
        }

        // Only shift by a finite maximum: subtracting an infinite or NaN maximum would
        // turn every weight into NaN, so those degenerate inputs use the raw exponent.
        double shift = (Double.isNaN(maxQ) || Double.isInfinite(maxQ)) ? 0.0d : maxQ;

        // Running totals of the weights; cumulative[k] is the upper edge of action k's slice.
        double[] cumulative = new double[actionCount];
        double total = 0.0d;
        for (int k = 0; k < actionCount; k++) {
            total += Math.exp(qValues[k] - shift);
            cumulative[k] = total;
        }

        IndexValue selected = new IndexValue();
        selected.setIndex(-1);
        selected.setValue(Double.NEGATIVE_INFINITY);

        // Draw a point in [0, total) and pick the first slice that reaches it.
        double threshold = total * this.random.nextDouble();
        for (int k = 0; k < actionCount; k++) {
            if (cumulative[k] >= threshold) {
                selected.setValue(qValues[k]);
                selected.setIndex(actions.get(k).intValue());
                break;
            }
        }
        return selected;
    }

    /** Lists the candidate actions: the given set in iteration order, or every action index of the model when the set is null. */
    private static ArrayList<Integer> candidateActions(QModel qModel, Set<Integer> set) {
        ArrayList<Integer> actions = new ArrayList<Integer>();
        if (set == null) {
            for (int actionId = 0; actionId < qModel.getActionCount(); actionId++) {
                actions.add(Integer.valueOf(actionId));
            }
        } else {
            for (Integer actionId : set) {
                actions.add(actionId);
            }
        }
        return actions;
    }
}
