package com.github.chen0040.rl.models;

import com.github.chen0040.rl.utils.IndexValue;
import com.github.chen0040.rl.utils.Matrix;
import com.github.chen0040.rl.utils.Vec;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:com/github/chen0040/rl/models/QModel.class */
public class QModel {
    /** Default value of {@link #gamma} for every constructor. */
    private static final double DEFAULT_GAMMA = 0.7d;
    /** Value every alpha cell is initialised to by the sized constructors. */
    private static final double DEFAULT_ALPHA = 0.1d;
    /** Initial Q-value used by {@link #QModel(int, int)}. */
    private static final double DEFAULT_INITIAL_Q = 0.1d;

    /** Q-value table, laid out as {@code stateCount x actionCount}; {@code null} after the no-arg constructor. */
    private Matrix Q;
    /**
     * Per-(state, action) alpha values, same layout as {@link #Q} (presumably learning rates read
     * by the learner — TODO confirm); {@code null} after the no-arg constructor.
     */
    private Matrix alpha;
    /** Gamma coefficient (presumably the discount factor — TODO confirm). */
    private double gamma = DEFAULT_GAMMA;
    private int stateCount;
    private int actionCount;

    /**
     * Creates a model with a {@code stateCount x actionCount} Q-table whose every cell is
     * {@code initialQ}, and an alpha table of the same size filled with 0.1. Gamma is 0.7.
     *
     * @param stateCount  number of states (rows)
     * @param actionCount number of actions (columns)
     * @param initialQ    value every Q cell starts with
     */
    public QModel(int stateCount, int actionCount, double initialQ) {
        this.stateCount = stateCount;
        this.actionCount = actionCount;
        this.Q = new Matrix(stateCount, actionCount);
        this.alpha = new Matrix(stateCount, actionCount);
        this.Q.setAll(initialQ);
        this.alpha.setAll(DEFAULT_ALPHA);
    }

    /**
     * Creates a model whose Q-table is initialised to 0.1 everywhere.
     *
     * @param stateCount  number of states (rows)
     * @param actionCount number of actions (columns)
     */
    public QModel(int stateCount, int actionCount) {
        this(stateCount, actionCount, DEFAULT_INITIAL_Q);
    }

    /**
     * Creates an empty model: both tables are {@code null} and both counts are 0. Intended to be
     * filled via {@link #copy(QModel)} or the setters; the per-cell accessors throw
     * {@link NullPointerException} until then.
     */
    public QModel() {
    }

    /**
     * Two models are equal when gamma, state/action counts, and both tables (compared with
     * {@code Matrix.equals}, {@code null}-safely) are equal. Gamma is compared with
     * {@link Double#compare} so a NaN gamma still equals itself (reflexivity).
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof QModel)) {
            return false;
        }
        QModel other = (QModel) obj;
        return Double.compare(this.gamma, other.gamma) == 0
                && this.stateCount == other.stateCount
                && this.actionCount == other.actionCount
                && Objects.equals(this.Q, other.Q)
                && Objects.equals(this.alpha, other.alpha);
    }

    /**
     * Hash consistent with {@link #equals(Object)}. Only the scalar fields are mixed in: the tables
     * are left out because it is not visible here whether {@code Matrix} overrides
     * {@code hashCode} consistently with its {@code equals}. Leaving fields out of the hash never
     * breaks the contract. Note that the model is mutable, so do not mutate it while it is a hash key.
     */
    @Override
    public int hashCode() {
        long gammaBits = Double.doubleToLongBits(this.gamma);
        int result = (int) (gammaBits ^ (gammaBits >>> 32));
        result = 31 * result + this.stateCount;
        result = 31 * result + this.actionCount;
        return result;
    }

    /**
     * Returns a new model holding a copy of this one's state, with the tables copied via
     * {@code Matrix.clone()} (see {@link #copy(QModel)}).
     */
    @Override
    public Object clone() {
        QModel copy = new QModel();
        copy.copy(this);
        return copy;
    }

    /**
     * Overwrites this model's state with that of {@code other}. Non-null tables are copied via
     * {@code Matrix.clone()}; {@code null} tables stay {@code null}.
     *
     * @param other model to copy from; must not be {@code null}
     */
    public void copy(QModel other) {
        this.gamma = other.gamma;
        this.stateCount = other.stateCount;
        this.actionCount = other.actionCount;
        this.Q = other.Q == null ? null : (Matrix) other.Q.clone();
        this.alpha = other.alpha == null ? null : (Matrix) other.alpha.clone();
    }

    /** Returns the internal Q-table itself (not a copy). */
    public Matrix getQ() {
        return this.Q;
    }

    /** Returns Q(state, action). */
    public double getQ(int state, int action) {
        return this.Q.get(state, action);
    }

    /** Replaces the Q-table reference (the matrix is stored, not copied). */
    public void setQ(Matrix matrix) {
        this.Q = matrix;
    }

    /** Sets Q(state, action) to {@code value}. */
    public void setQ(int state, int action, double value) {
        this.Q.set(state, action, value);
    }

    /** Returns the internal alpha table itself (not a copy). */
    public Matrix getAlpha() {
        return this.alpha;
    }

    /** Returns alpha(state, action). */
    public double getAlpha(int state, int action) {
        return this.alpha.get(state, action);
    }

    /** Replaces the alpha-table reference (the matrix is stored, not copied). */
    public void setAlpha(Matrix matrix) {
        this.alpha = matrix;
    }

    /** Sets every cell of the alpha table to {@code value}. */
    public void setAlpha(double value) {
        this.alpha.setAll(value);
    }

    public double getGamma() {
        return this.gamma;
    }

    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    public int getStateCount() {
        return this.stateCount;
    }

    public int getActionCount() {
        return this.actionCount;
    }

    /**
     * Returns the action with the largest Q-value in {@code state}, restricted to
     * {@code actionsAtState}. This delegates to {@code Vec.indexWithMaxValue}; how that helper treats a
     * {@code null} or empty set is not visible here.
     */
    public IndexValue actionWithMaxQAtState(int state, Set<Integer> actionsAtState) {
        return this.Q.getRow(state).indexWithMaxValue(actionsAtState);
    }

    /** Sets every Q cell to {@code initialQ}. */
    private void reset(double initialQ) {
        this.Q.setAll(initialQ);
    }

    /**
     * Samples an action for {@code state} by soft-max (Boltzmann) selection. Each candidate action
     * {@code a} is chosen with probability {@code exp(Q(a)) / sum_b exp(Q(b))}.
     *
     * <p>The previous implementation sampled in proportion to the raw Q-values. That is undefined
     * for negative Q-values: some actions became unreachable, and sometimes no action was returned
     * at all.
     *
     * @param state          state (row of the Q-table) to choose an action for
     * @param actionsAtState candidate actions, or {@code null} for all {@code 0..actionCount-1};
     *                       never modified
     * @param random         source of randomness
     * @return the chosen action and its Q-value, or a default-constructed {@link IndexValue} when
     *         there are no candidate actions
     */
    public IndexValue actionWithSoftMaxQAtState(int state, Set<Integer> actionsAtState, Random random) {
        Vec row = this.Q.getRow(state);

        List<Integer> actions = new ArrayList<>();
        if (actionsAtState == null) {
            for (int action = 0; action < this.actionCount; action++) {
                actions.add(action);
            }
        } else {
            actions.addAll(actionsAtState);
        }

        IndexValue result = new IndexValue();
        int n = actions.size();
        if (n == 0) {
            return result;
        }

        // Cache the candidate Q-values and locate the maximum (also the fallback choice).
        double[] qValues = new double[n];
        double maxQ = Double.NEGATIVE_INFINITY;
        int bestAction = actions.get(0);
        for (int k = 0; k < n; k++) {
            qValues[k] = row.get(actions.get(k));
            if (qValues[k] > maxQ) {
                maxQ = qValues[k];
                bestAction = actions.get(k);
            }
        }

        // Shift by the max before exponentiating so exp() cannot overflow; the resulting
        // probabilities are unchanged by the shift.
        double[] cumulative = new double[n];
        double total = 0.0d;
        for (int k = 0; k < n; k++) {
            total += Math.exp(qValues[k] - maxQ);
            cumulative[k] = total;
        }

        // nextDouble() is in [0, 1), so threshold < total and the strict '>' test always selects
        // an action with non-zero weight when total is finite.
        double threshold = random.nextDouble() * total;
        int chosen = bestAction; // used only if the weights are non-finite (e.g. infinite/NaN Q)
        for (int k = 0; k < n; k++) {
            if (cumulative[k] > threshold) {
                chosen = actions.get(k);
                break;
            }
        }

        result.setIndex(chosen);
        result.setValue(row.get(chosen));
        return result;
    }
}
