package com.mwt.explorers;

import com.mwt.consumers.ConsumeScorer;
import com.mwt.misc.DecisionTuple;
import com.mwt.scorers.Scorer;
import com.mwt.utilities.PRG;
import java.util.ArrayList;

/* loaded from: input_file:com/mwt/explorers/SoftmaxExplorer.class */
/**
 * Explorer that picks an action by sampling from a softmax distribution built from the scores
 * returned by a {@link Scorer}.
 *
 * <p>Each action {@code a} receives the unnormalised weight
 * {@code exp(lambda * (score[a] - maxScore))}; the weights are normalised to probabilities and one
 * action is drawn using a {@link PRG} seeded with the value passed to
 * {@link #chooseAction(long, Object)}. When exploration is disabled the highest-scoring action is
 * returned with probability {@code 1.0}.
 *
 * <p>Action ids placed in the returned {@link DecisionTuple} are 1-based.
 *
 * @param <T> type of the context object handed to the scorer
 */
public class SoftmaxExplorer<T> implements Explorer<T>, ConsumeScorer<T> {
    private Scorer<T> defaultScorer;
    private boolean explore = true;
    private final float lambda;
    private final int numActions;

    /**
     * Creates a softmax explorer.
     *
     * @param scorer     scorer used to score every action for a given context
     * @param lambda     multiplier applied to the score differences before exponentiation
     * @param numActions number of actions the scorer is expected to score
     * @throws IllegalArgumentException if {@code numActions} is less than 1
     */
    public SoftmaxExplorer(Scorer<T> scorer, float lambda, int numActions) {
        if (numActions < 1) {
            throw new IllegalArgumentException("Number of actions must be at least 1.");
        }
        this.defaultScorer = scorer;
        this.lambda = lambda;
        this.numActions = numActions;
    }

    /**
     * Returns the number of actions expected for the given context. This implementation ignores the
     * context and returns the value supplied at construction time.
     *
     * @param context context for which the action count is requested
     * @return the expected number of actions
     */
    protected int getNumActions(T context) {
        return this.numActions;
    }

    /**
     * Replaces the scorer used by subsequent calls to {@link #chooseAction(long, Object)}.
     *
     * @param scorer the new scorer
     */
    @Override // com.mwt.consumers.ConsumeScorer
    public void updateScorer(Scorer<T> scorer) {
        this.defaultScorer = scorer;
    }

    /**
     * Chooses an action for the given context.
     *
     * @param seed    seed for the pseudo-random generator used to draw the action
     * @param context context passed to the scorer
     * @return a tuple holding the 1-based id of the chosen action, the probability with which it
     *         was chosen, and {@code true} as the third constructor argument (its meaning is
     *         defined by {@link DecisionTuple})
     * @throws RuntimeException if the scorer does not return exactly
     *                          {@link #getNumActions(Object)} scores
     */
    @Override // com.mwt.explorers.Explorer
    public DecisionTuple chooseAction(long seed, T context) {
        PRG prg = new PRG(seed);
        ArrayList<Float> scores = this.defaultScorer.scoreActions(context);
        int size = scores.size();
        if (size != getNumActions(context)) {
            throw new RuntimeException("The number of scores returned by the scorer must equal number of actions");
        }

        int chosenIndex = indexOfMax(scores);
        float probability = 1.0f;

        if (this.explore) {
            float maxScore = scores.get(chosenIndex);

            // Subtracting the true maximum means the largest score always contributes exp(0) = 1,
            // so the total cannot underflow to zero; for non-negative lambda every exponent is <= 0.
            float[] weights = new float[size];
            float total = 0.0f;
            for (int k = 0; k < size; k++) {
                weights[k] = (float) Math.exp(this.lambda * (scores.get(k) - maxScore));
                total += weights[k];
            }

            float draw = prg.uniformUnitInterval();

            // Fallback when rounding keeps the cumulative sum from exceeding the draw: take the last
            // action and report its real probability rather than zero.
            chosenIndex = size - 1;
            probability = weights[size - 1] / total;

            float cumulative = 0.0f;
            for (int k = 0; k < size; k++) {
                float p = weights[k] / total;
                cumulative += p;
                if (cumulative > draw) {
                    chosenIndex = k;
                    probability = p;
                    break;
                }
            }
        }

        // Action ids are 1-based.
        return new DecisionTuple(chosenIndex + 1, probability, true);
    }

    /**
     * Turns sampling on or off. When off, {@link #chooseAction(long, Object)} returns the
     * highest-scoring action with probability {@code 1.0}.
     *
     * @param explore {@code true} to sample from the softmax distribution
     */
    @Override // com.mwt.explorers.Explorer
    public void enableExplore(boolean explore) {
        this.explore = explore;
    }

    /**
     * Returns the index of the largest score; ties resolve to the lowest index. The search is
     * seeded with the first element so that all-negative score lists are handled correctly.
     */
    private static int indexOfMax(ArrayList<Float> scores) {
        int best = 0;
        float bestScore = scores.get(0);
        for (int k = 1; k < scores.size(); k++) {
            float candidate = scores.get(k);
            if (candidate > bestScore) {
                bestScore = candidate;
                best = k;
            }
        }
        return best;
    }
}
