package com.mwt.explorers;

import com.mwt.consumers.ConsumeScorer;
import com.mwt.misc.DecisionTuple;
import com.mwt.scorers.Scorer;
import com.mwt.utilities.PRG;
import java.util.ArrayList;
import java.util.Iterator;

/* loaded from: input_file:com/mwt/explorers/GenericExplorer.class */
/**
 * Explorer that samples an action from the distribution described by a scorer.
 *
 * <p>For every decision the scorer is asked for one non-negative score per action. The
 * scores are normalized by their sum and an action is drawn from the resulting
 * distribution using a pseudo-random generator seeded by the caller.
 *
 * @param <T> the context type handed to the scorer
 */
public class GenericExplorer<T> implements Explorer<T>, ConsumeScorer<T> {
    /** Scorer consulted by {@link #chooseAction(long, Object)}; replaceable via {@link #updateScorer(Scorer)}. */
    private Scorer<T> defaultScorer;

    /**
     * Flag written by {@link #enableExplore(boolean)}. Nothing in this class reads it, so
     * {@link #chooseAction(long, Object)} behaves the same whatever its value.
     */
    private boolean explore = true;

    /** Fixed number of actions supplied at construction time. */
    private final int numActions;

    /**
     * Creates an explorer over a fixed number of actions.
     *
     * @param scorer the scorer that produces one score per action
     * @param i      the number of actions; must be at least 1
     * @throws IllegalArgumentException if {@code i} is less than 1
     */
    public GenericExplorer(Scorer<T> scorer, int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of actions must be at least 1.");
        }
        this.defaultScorer = scorer;
        this.numActions = i;
    }

    /**
     * Returns the number of actions expected for the given context. This implementation
     * ignores the context and returns the count passed to the constructor.
     *
     * @param t the decision context (unused here)
     * @return the number of actions
     */
    protected int getNumActions(T t) {
        return this.numActions;
    }

    /**
     * Replaces the scorer used for subsequent decisions.
     *
     * @param scorer the new scorer
     */
    @Override
    public void updateScorer(Scorer<T> scorer) {
        this.defaultScorer = scorer;
    }

    /**
     * Samples an action in proportion to the scores returned by the current scorer.
     *
     * @param j seed for the pseudo-random generator used for this decision
     * @param t the decision context passed to the scorer
     * @return a decision tuple holding the one-based action id and the normalized
     *         probability of that action
     * @throws RuntimeException if the scorer returns a number of scores different from the
     *         number of actions, if any score is negative or NaN, if all scores are zero,
     *         or if the sum of the scores is not finite
     */
    @Override
    public DecisionTuple chooseAction(long j, T t) {
        PRG prg = new PRG(j);
        ArrayList<Float> scores = this.defaultScorer.scoreActions(t);
        int count = scores.size();
        if (count != getNumActions(t)) {
            throw new RuntimeException("The number of weights returned by the scorer must equal number of actions");
        }

        float total = 0.0f;
        for (Float score : scores) {
            float value = score.floatValue();
            // Negated comparison on purpose: NaN fails every ordered comparison, so a plain
            // "value < 0" test would let it through and poison the normalization below.
            if (!(value >= 0.0f)) {
                throw new RuntimeException("Scores must be non-negative.");
            }
            total += value;
        }
        if (total == 0.0f) {
            throw new RuntimeException("At least one score must be positive.");
        }
        if (Float.isInfinite(total)) {
            // Dividing by an infinite sum would turn every probability into 0 or NaN.
            throw new RuntimeException("The sum of the scores must be finite.");
        }

        // presumably a draw from the unit interval - confirm against PRG
        float draw = prg.uniformUnitInterval();
        float cumulative = 0.0f;
        for (int index = 0; index < count; index++) {
            float probability = scores.get(index).floatValue() / total;
            cumulative += probability;
            if (cumulative > draw) {
                // Action ids are one-based.
                return new DecisionTuple(index + 1, probability, true);
            }
        }

        // Rounding can leave the cumulative sum at or below the draw. Fall back to the last
        // action that actually has a positive score and report its true probability, rather
        // than returning an action with a reported probability of zero.
        for (int index = count - 1; index >= 0; index--) {
            float value = scores.get(index).floatValue();
            if (value > 0.0f) {
                return new DecisionTuple(index + 1, value / total, true);
            }
        }

        // Not reachable: a positive total of non-negative scores implies a positive score.
        throw new IllegalStateException("No action with a positive score was found.");
    }

    /**
     * Stores the exploration flag. This class never reads the stored value, so calling this
     * method does not change what {@link #chooseAction(long, Object)} returns.
     *
     * @param z the flag value to store
     */
    @Override
    public void enableExplore(boolean z) {
        this.explore = z;
    }
}
