package se.sics.kompics.network.data.policies;

import com.google.common.collect.TreeMultimap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.TreeMap;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.fitting.PolynomialCurveFitter;
import org.apache.commons.math3.fitting.WeightedObservedPoints;
import org.jscience.mathematics.number.Rational;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.SparseMatrix;
import org.ujmp.core.calculation.Calculation;
import se.sics.kompics.config.Config;
import se.sics.kompics.network.data.Statistics;

/* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner.class */
/**
 * {@link ProtocolRatioPolicy} that learns a ratio by temporal-difference updates over a
 * discretised state space.
 *
 * <p>States are integer indices; {@code stateToRatio} maps an index to
 * {@code stepSize * (index - middleState)}. Actions are signed state offsets stored in
 * {@link #actions}; throughout the class an "action" passed between methods is an <em>index</em>
 * into that array, not the offset itself.
 */
public class TDRatioLearner implements ProtocolRatioPolicy {
    private static final Logger LOG = LoggerFactory.getLogger(TDRatioLearner.class);
    /** Source of every {@code kompics.net.data.td.*} setting read by this class and its policies. */
    private final Config config;
    /** Factor applied to (TD error * trace) when an action value is updated. */
    private final double alpha;
    /** Factor applied to the value of the newly chosen action when computing the TD error. */
    private final double gamma;
    /** Multiplied with {@link #gamma} to shrink every trace after each update. */
    private final double lambda;
    /** Ratio distance between two neighbouring states; built as 1/N from the configured step size. */
    private final Rational stepSize;
    /** Index of the current state. */
    private int state;
    /** Sorted state offsets that can be applied to the current state. */
    private final int[] actions;
    /** Action-value estimates; implementation selected by configuration. */
    private final ActionValueEstimator Q;
    /** Per (state, action) trace weights; always backed by the MATRIX implementation. */
    private final ActionValueEstimator e;
    /** Policy used to pick the next action index. */
    private final DerivedPolicy policy;
    /** Index (into {@link #actions}) of the action chosen by the previous update. */
    private int lastAction;
    /** State the learner was in when {@link #lastAction} was chosen. */
    private int lastState;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: se.sics.kompics.network.data.policies.TDRatioLearner$1, reason: invalid class name */
    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$1.class */
    /**
     * Compiler-generated lookup table for a {@code switch} over
     * {@link ActionValueEstimator.Implementation}, surfaced by the decompiler.
     *
     * <p>The array is indexed by {@code ordinal()} and holds the case label (1 = MATRIX,
     * 2 = COLLAPSED, 3 = FUNCTION). Slots that are never assigned stay 0 and therefore fall
     * through to the {@code default} branch of the switch that reads this table.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$se$sics$kompics$network$data$policies$TDRatioLearner$ActionValueEstimator$Implementation = new int[ActionValueEstimator.Implementation.values().length];

        static {
            // Each constant is registered in its own try block so that a constant missing at
            // runtime (NoSuchFieldError) only leaves its own slot unset.
            try {
                $SwitchMap$se$sics$kompics$network$data$policies$TDRatioLearner$ActionValueEstimator$Implementation[ActionValueEstimator.Implementation.MATRIX.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
                // Intentionally ignored: slot stays 0.
            }
            try {
                $SwitchMap$se$sics$kompics$network$data$policies$TDRatioLearner$ActionValueEstimator$Implementation[ActionValueEstimator.Implementation.COLLAPSED.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
                // Intentionally ignored: slot stays 0.
            }
            try {
                $SwitchMap$se$sics$kompics$network$data$policies$TDRatioLearner$ActionValueEstimator$Implementation[ActionValueEstimator.Implementation.FUNCTION.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
                // Intentionally ignored: slot stays 0.
            }
        }
    }

    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$AVEFactory.class */
    /** Creates {@link ActionValueEstimator} instances bound to the enclosing learner. */
    class AVEFactory {
        AVEFactory() {
        }

        /**
         * Builds the estimator for the requested implementation.
         *
         * <p>Switches on the enum constant itself rather than on an ordinal-indexed lookup table,
         * so the mapping no longer depends on declaration order or on unrelated integer constants.
         *
         * @param implementation which estimator to create; must not be {@code null}
         * @param rational       step size handed to the estimator's constructor
         * @return a new estimator of the requested kind
         * @throws IllegalArgumentException if the implementation has no matching estimator
         */
        public ActionValueEstimator getInstance(ActionValueEstimator.Implementation implementation, Rational rational) {
            switch (implementation) {
                case MATRIX:
                    return new BasicMatrixEstimator(rational);
                case COLLAPSED:
                    return new CollapsedMatrixEstimator(rational);
                case FUNCTION:
                    return new FunctionApproximationEstimator(rational);
                default:
                    // Unreachable for the constants declared today; fail loudly instead of
                    // handing a null estimator to the learner's constructor.
                    throw new IllegalArgumentException("Unsupported ActionValueEstimator implementation: " + implementation);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$ActionValueEstimator.class */
    /**
     * Storage for one value per (state, action) pair.
     *
     * <p>In every method, a state is a state index and an action is an index into the enclosing
     * learner's {@code actions} array (not the offset stored there).
     */
    public interface ActionValueEstimator {

        /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$ActionValueEstimator$Implementation.class */
        /** Available estimator implementations, selectable by name from configuration. */
        public enum Implementation {
            /** Selects {@code BasicMatrixEstimator}. */
            MATRIX,
            /** Selects {@code CollapsedMatrixEstimator}. */
            COLLAPSED,
            /** Selects {@code FunctionApproximationEstimator}. */
            FUNCTION
        }

        /** @return the number of state indices this estimator covers */
        int numStates();

        /**
         * @param i state index
         * @return index of the action the estimator currently rates highest from state {@code i}
         */
        int bestActionAt(int i);

        /**
         * @param i      state index
         * @param random source of randomness used for the pick
         * @return index of a randomly selected action from state {@code i}
         */
        int randomActionAt(int i, Random random);

        /**
         * @param i  state index
         * @param i2 action index
         * @return the value stored for that pair
         */
        double at(int i, int i2);

        /**
         * Stores a value for a (state, action) pair.
         *
         * @param d  value to store
         * @param i  state index
         * @param i2 action index
         */
        void set(double d, int i, int i2);

        /** @return the index of the state in the middle of the state range */
        int middleState();

        /** @return the highest valid state index */
        int maxState();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$BasicMatrixEstimator.class */
    /**
     * Estimator that keeps one matrix cell per (state, action index) pair.
     *
     * <p>Rows are states, columns are action indices. The matrix has
     * {@code 2 / stepSize + 1} rows and one column per entry of the learner's {@code actions}.
     */
    public class BasicMatrixEstimator implements ActionValueEstimator {
        private final int stateMiddleIndex;
        private final int height;
        private final int width;
        private final Matrix m;

        /**
         * @param rational step size; determines the number of states as {@code 2 / rational + 1}
         */
        public BasicMatrixEstimator(Rational rational) {
            this.width = TDRatioLearner.this.actions.length;
            this.height = Rational.valueOf(2L, 1L).divide(rational).intValue() + 1;
            TDRatioLearner.LOG.trace("Initialising {}x{} Matrix", Integer.valueOf(this.height), Integer.valueOf(this.width));
            // Small tables are stored densely, larger ones sparsely.
            if (Math.max(this.width, this.height) < 10) {
                this.m = DenseMatrix.Factory.zeros(this.height, this.width);
            } else {
                this.m = SparseMatrix.Factory.zeros(this.height, this.width);
            }
            this.stateMiddleIndex = this.height / 2;
        }

        /** Limits a state index to the range {@code [0, maxState()]}. */
        private int clampState(int candidate) {
            if (candidate < 0) {
                return 0;
            }
            return candidate >= this.height ? maxState() : candidate;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int numStates() {
            return this.height;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public double at(int i, int i2) {
            return this.m.getAsDouble(new long[]{i, i2});
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public void set(double d, int i, int i2) {
            this.m.setAsDouble(d, new long[]{i, i2});
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int middleState() {
            return this.stateMiddleIndex;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int maxState() {
            return this.height - 1;
        }

        @Override
        public String toString() {
            return "BasicMatrixEstimator with Q=\n" + this.m;
        }

        /**
         * Returns the column index holding the largest value in row {@code i}.
         *
         * <p>The row is held as a {@code Matrix} (not {@code SparseMatrix}) because the row list
         * may hand back either representation; a missing row is treated as all zeros.
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int bestActionAt(int i) {
            Matrix row = (Matrix) this.m.getRowList().get(i);
            if (row == null) {
                row = SparseMatrix.Factory.zeros(1L, TDRatioLearner.this.actions.length);
            }
            return row.indexOfMax(Calculation.Ret.NEW, 1).getAsInt(new long[]{0, 0});
        }

        /**
         * Picks uniformly among the distinct states reachable from {@code i} and returns the
         * lowest action index that leads to the picked state.
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int randomActionAt(int i, Random random) {
            final int[] offsets = TDRatioLearner.this.actions;
            // target state -> indices of the actions that reach it (several may clamp to one state)
            TreeMultimap<Integer, Integer> reachable = TreeMultimap.create();
            for (int actionIndex = 0; actionIndex < offsets.length; actionIndex++) {
                reachable.put(Integer.valueOf(clampState(i + offsets[actionIndex])), Integer.valueOf(actionIndex));
            }
            Object[] targetStates = reachable.keySet().toArray();
            Integer pickedState = (Integer) targetStates[random.nextInt(targetStates.length)];
            // NOTE(review): the tie-break compares action *indices* (all non-negative), so it
            // always selects the lowest index; confirm this was not meant to compare |offset|.
            int lowestIndex = Integer.MAX_VALUE;
            for (Integer actionIndex : reachable.get(pickedState)) {
                if (actionIndex.intValue() < lowestIndex) {
                    lowestIndex = actionIndex.intValue();
                }
            }
            return lowestIndex;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$CollapsedMatrixEstimator.class */
    /**
     * Estimator that stores a single value per <em>target state</em>.
     *
     * <p>A (state, action index) pair is addressed by the state it leads to, i.e.
     * {@code clamp(state + actions[actionIndex])}, so all pairs that reach the same state share
     * one cell.
     */
    public class CollapsedMatrixEstimator implements ActionValueEstimator {
        private final int stateMiddleIndex;
        private final double[] m;

        /**
         * @param rational step size; determines the number of states as {@code 2 / rational + 1}
         */
        public CollapsedMatrixEstimator(Rational rational) {
            int stateCount = Rational.valueOf(2L, 1L).divide(rational).intValue() + 1;
            // A fresh double[] is already zero-filled.
            this.m = new double[stateCount];
            this.stateMiddleIndex = stateCount / 2;
        }

        /** Limits a state index to the range {@code [0, maxState()]}. */
        private int clampState(int candidate) {
            if (candidate < 0) {
                return 0;
            }
            return candidate >= this.m.length ? maxState() : candidate;
        }

        /** Index of the cell addressed by taking action {@code actionIndex} in {@code state}. */
        private int targetOf(int state, int actionIndex) {
            return clampState(state + TDRatioLearner.this.actions[actionIndex]);
        }

        /** Maps every reachable target state to the indices of the actions that reach it. */
        private TreeMultimap<Integer, Integer> reachableFrom(int state) {
            TreeMultimap<Integer, Integer> reachable = TreeMultimap.create();
            for (int actionIndex = 0; actionIndex < TDRatioLearner.this.actions.length; actionIndex++) {
                reachable.put(Integer.valueOf(targetOf(state, actionIndex)), Integer.valueOf(actionIndex));
            }
            return reachable;
        }

        /**
         * Lowest action index leading to {@code targetState}, or {@code Integer.MAX_VALUE} when
         * no action reaches it.
         */
        private int lowestActionIndex(TreeMultimap<Integer, Integer> reachable, Integer targetState) {
            // NOTE(review): indices are non-negative, so this is simply the smallest index;
            // confirm the tie-break was not meant to prefer the smallest |offset|.
            int lowest = Integer.MAX_VALUE;
            for (Integer actionIndex : reachable.get(targetState)) {
                if (actionIndex.intValue() < lowest) {
                    lowest = actionIndex.intValue();
                }
            }
            return lowest;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int numStates() {
            return this.m.length;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public double at(int i, int i2) {
            return this.m[targetOf(i, i2)];
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public void set(double d, int i, int i2) {
            this.m[targetOf(i, i2)] = d;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int middleState() {
            return this.stateMiddleIndex;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int maxState() {
            return this.m.length - 1;
        }

        /**
         * Returns the lowest action index that leads to the best-valued reachable state.
         *
         * <p>The current state {@code i} is the baseline: another target only wins when its value
         * is strictly greater. Ties are resolved towards the lower state index because targets
         * are visited in ascending order.
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int bestActionAt(int i) {
            TreeMultimap<Integer, Integer> reachable = reachableFrom(i);
            int bestState = i;
            double bestValue = this.m[i];
            for (Integer target : reachable.keySet()) {
                double value = this.m[target.intValue()];
                if (value > bestValue) {
                    bestState = target.intValue();
                    bestValue = value;
                }
            }
            TDRatioLearner.LOG.trace("Of the target states {}, the best is {} (val: {}) in {}", new Object[]{reachable, Integer.valueOf(bestState), Double.valueOf(bestValue), Arrays.toString(this.m)});
            return lowestActionIndex(reachable, Integer.valueOf(bestState));
        }

        /**
         * Picks uniformly among the distinct states reachable from {@code i} and returns the
         * lowest action index that leads to the picked state.
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int randomActionAt(int i, Random random) {
            TreeMultimap<Integer, Integer> reachable = reachableFrom(i);
            Object[] targetStates = reachable.keySet().toArray();
            Integer pickedState = (Integer) targetStates[random.nextInt(targetStates.length)];
            return lowestActionIndex(reachable, pickedState);
        }

        @Override
        public String toString() {
            return "CollapsedMatrixEstimator with Q=\n" + Arrays.toString(this.m);
        }
    }

    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$DerivedPolicy.class */
    /** Strategy that decides which action the learner takes next. */
    interface DerivedPolicy {
        /** @return the index (into the learner's {@code actions} array) of the chosen action */
        int chooseAction();
    }

    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$EpsilonGreedy.class */
    /**
     * Policy that explores with probability epsilon and otherwise follows the estimator's best
     * action. Epsilon shrinks by a fixed delta on every call until it reaches its floor.
     */
    class EpsilonGreedy implements DerivedPolicy {
        private double epsilon;
        private final double epsilonDelta;
        private final double minEpsilon;
        // Fixed seed: the sequence of explore/exploit decisions is reproducible across runs.
        private final Random RAND = new Random(1);

        /** Reads the starting epsilon, its per-call decrement and its floor from configuration. */
        EpsilonGreedy() {
            final Config settings = TDRatioLearner.this.config;
            this.epsilon = ((Double) settings.getValue("kompics.net.data.td.epsilonGreedy.epsilon", Double.class)).doubleValue();
            this.epsilonDelta = ((Double) settings.getValue("kompics.net.data.td.epsilonGreedy.epsilonDelta", Double.class)).doubleValue();
            this.minEpsilon = ((Double) settings.getValue("kompics.net.data.td.epsilonGreedy.minEpsilon", Double.class)).doubleValue();
        }

        /**
         * Decays epsilon, then draws once to decide between a random and a greedy action.
         *
         * @return index of the selected action
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.DerivedPolicy
        public int chooseAction() {
            final boolean aboveFloor = this.epsilon > this.minEpsilon;
            if (aboveFloor) {
                final double decayed = this.epsilon - this.epsilonDelta;
                this.epsilon = Math.max(decayed, this.minEpsilon);
            }

            final int currentState = TDRatioLearner.this.state;
            final int selected;
            if (this.RAND.nextDouble() < this.epsilon) {
                selected = TDRatioLearner.this.Q.randomActionAt(currentState, this.RAND);
                TDRatioLearner.LOG.trace("Selected {} randomly.", Integer.valueOf(TDRatioLearner.this.actions[selected]));
            } else {
                selected = TDRatioLearner.this.Q.bestActionAt(currentState);
                TDRatioLearner.LOG.trace("Selected {} greedily", Integer.valueOf(TDRatioLearner.this.actions[selected]));
            }
            return selected;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:se/sics/kompics/network/data/policies/TDRatioLearner$FunctionApproximationEstimator.class */
    /**
     * Estimator that stores one value per target state (like the collapsed variant) and fills in
     * states whose stored value is exactly 0 with a polynomial fitted through the non-zero ones.
     */
    public class FunctionApproximationEstimator implements ActionValueEstimator {
        private final int stateMiddleIndex;
        private final double[] m;
        private PolynomialFunction approxFunction;

        /**
         * @param rational step size; determines the number of states as {@code 2 / rational + 1}
         */
        public FunctionApproximationEstimator(Rational rational) {
            int stateCount = Rational.valueOf(2L, 1L).divide(rational).intValue() + 1;
            // A fresh double[] is already zero-filled.
            this.m = new double[stateCount];
            this.stateMiddleIndex = stateCount / 2;
            this.approxFunction = new PolynomialFunction(new double[]{0.0d});
        }

        /**
         * Refits {@link #approxFunction} through all non-zero cells.
         *
         * <p>No points: the previous function is kept. One point: a synthetic point
         * {@code (-1, 2 * max)} is added so a line can be fitted. Two points: degree 1.
         * Three or more: degree 2.
         */
        private void updateFunction() {
            WeightedObservedPoints samples = new WeightedObservedPoints();
            int sampleCount = 0;
            double largest = 0.0d;
            for (int stateIndex = 0; stateIndex < this.m.length; stateIndex++) {
                double value = this.m[stateIndex];
                if (value != 0.0d) {
                    sampleCount++;
                    samples.add(stateIndex, value);
                    if (value > largest) {
                        largest = value;
                    }
                }
            }
            if (sampleCount == 0) {
                return;
            }
            if (sampleCount == 1) {
                samples.add(-1.0d, 2.0d * largest);
                sampleCount++;
            }
            int degree = sampleCount == 2 ? 1 : 2;
            this.approxFunction = new PolynomialFunction(PolynomialCurveFitter.create(degree).fit(samples.toList()));
        }

        /** Limits a state index to the range {@code [0, maxState()]}. */
        private int clampState(int candidate) {
            if (candidate < 0) {
                return 0;
            }
            return candidate >= this.m.length ? maxState() : candidate;
        }

        /** Index of the cell addressed by taking action {@code actionIndex} in {@code state}. */
        private int targetOf(int state, int actionIndex) {
            return clampState(state + TDRatioLearner.this.actions[actionIndex]);
        }

        /** Maps every reachable target state to the indices of the actions that reach it. */
        private TreeMultimap<Integer, Integer> reachableFrom(int state) {
            TreeMultimap<Integer, Integer> reachable = TreeMultimap.create();
            for (int actionIndex = 0; actionIndex < TDRatioLearner.this.actions.length; actionIndex++) {
                reachable.put(Integer.valueOf(targetOf(state, actionIndex)), Integer.valueOf(actionIndex));
            }
            return reachable;
        }

        /**
         * Lowest action index leading to {@code targetState}, or {@code Integer.MAX_VALUE} when
         * no action reaches it.
         */
        private int lowestActionIndex(TreeMultimap<Integer, Integer> reachable, Integer targetState) {
            // NOTE(review): indices are non-negative, so this is simply the smallest index;
            // confirm the tie-break was not meant to prefer the smallest |offset|.
            int lowest = Integer.MAX_VALUE;
            for (Integer actionIndex : reachable.get(targetState)) {
                if (actionIndex.intValue() < lowest) {
                    lowest = actionIndex.intValue();
                }
            }
            return lowest;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int numStates() {
            return this.m.length;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public double at(int i, int i2) {
            return this.m[targetOf(i, i2)];
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public void set(double d, int i, int i2) {
            this.m[targetOf(i, i2)] = d;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int middleState() {
            return this.stateMiddleIndex;
        }

        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int maxState() {
            return this.m.length - 1;
        }

        /**
         * Returns the lowest action index that leads to the best-valued reachable state.
         *
         * <p>A stored value of exactly 0 is treated as "no estimate yet" and replaced by the
         * fitted polynomial's value; the polynomial is refitted at most once per call, and only
         * when such a cell is actually encountered. The baseline is the raw stored value of the
         * current state {@code i}.
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int bestActionAt(int i) {
            TreeMultimap<Integer, Integer> reachable = reachableFrom(i);
            TreeMap<Integer, Double> effectiveValues = new TreeMap<Integer, Double>();
            boolean refitted = false;
            int bestState = i;
            double bestValue = this.m[i];
            for (Integer target : reachable.keySet()) {
                double value = this.m[target.intValue()];
                if (value == 0.0d) {
                    if (!refitted) {
                        updateFunction();
                        refitted = true;
                    }
                    value = this.approxFunction.value(target.intValue());
                }
                effectiveValues.put(target, Double.valueOf(value));
                if (value > bestValue) {
                    bestState = target.intValue();
                    bestValue = value;
                }
            }
            TDRatioLearner.LOG.trace("Of the target states {} with values {}, the best is {} (val: {}) in {}", new Object[]{reachable, effectiveValues, Integer.valueOf(bestState), Double.valueOf(bestValue), Arrays.toString(this.m)});
            return lowestActionIndex(reachable, Integer.valueOf(bestState));
        }

        /**
         * Picks uniformly among the distinct states reachable from {@code i} and returns the
         * lowest action index that leads to the picked state.
         */
        @Override // se.sics.kompics.network.data.policies.TDRatioLearner.ActionValueEstimator
        public int randomActionAt(int i, Random random) {
            TreeMultimap<Integer, Integer> reachable = reachableFrom(i);
            Object[] targetStates = reachable.keySet().toArray();
            Integer pickedState = (Integer) targetStates[random.nextInt(targetStates.length)];
            return lowestActionIndex(reachable, pickedState);
        }

        /** Identifies this estimator by its own name (previously mislabelled as the collapsed one). */
        @Override
        public String toString() {
            return "FunctionApproximationEstimator with Q=\n" + Arrays.toString(this.m);
        }
    }

    /**
     * Builds a learner from the {@code kompics.net.data.td.*} configuration keys.
     *
     * <p>Reads alpha, gamma, lambda, the step-size denominator, the action list and the name of
     * the action-value estimator implementation. The trace table always uses the MATRIX
     * implementation. The learner starts in the middle state with the zero-offset action recorded
     * as the previous action.
     *
     * @param config configuration providing the keys above
     */
    public TDRatioLearner(Config config) {
        this.config = config;
        this.alpha = ((Double) this.config.getValue("kompics.net.data.td.alpha", Double.class)).doubleValue();
        this.gamma = ((Double) this.config.getValue("kompics.net.data.td.gamma", Double.class)).doubleValue();
        this.lambda = ((Double) this.config.getValue("kompics.net.data.td.lambda", Double.class)).doubleValue();
        this.stepSize = Rational.valueOf(1L, ((Long) this.config.getValue("kompics.net.data.td.stepSize", Long.class)).longValue());
        this.actions = completeActions(this.config.getValues("kompics.net.data.td.actions", Integer.class));
        ActionValueEstimator.Implementation valueOf = ActionValueEstimator.Implementation.valueOf((String) this.config.getValue("kompics.net.data.td.actionValueEstimator", String.class));
        AVEFactory aVEFactory = new AVEFactory();
        this.Q = aVEFactory.getInstance(valueOf, this.stepSize);
        this.e = aVEFactory.getInstance(ActionValueEstimator.Implementation.MATRIX, this.stepSize);
        LOG.trace("Initialised TD with actions {}, and {} states, stepSize {}", new Object[]{Arrays.toString(this.actions), Integer.valueOf(this.Q.numStates()), this.stepSize});
        this.policy = new EpsilonGreedy();
        // lastAction is an INDEX into actions everywhere else (update() passes it to Q.at/e.set and
        // overwrites it with chooseAction()'s index). The array is sorted and symmetric around 0,
        // so the middle index is the zero-offset action. Storing actions[length / 2] (the value 0)
        // here would make the first update credit index 0, i.e. the most negative action.
        this.lastAction = this.actions.length / 2;
        this.state = this.Q.middleState();
        this.lastState = this.state;
    }

    /**
     * Feeds one reward into the learner, picks the next action and moves to the resulting state.
     *
     * <p>The TD error is {@code reward + gamma * Q(s', a') - Q(s, a)}, where (s, a) is the
     * previous state/action and (s', a') the current state with the freshly chosen action. The
     * trace for (s, a) is reset to 1 and the traces of the other actions in s to 0; then every
     * (state, action) pair has its value moved by {@code alpha * error * trace} and its trace
     * scaled by {@code gamma * lambda}.
     *
     * @param d  reward for the previous step; NaN skips learning and returns the current ratio
     * @param d2 not used by this policy
     * @return the ratio corresponding to the state after applying the chosen action
     */
    @Override // se.sics.kompics.network.data.policies.ProtocolRatioPolicy
    public Rational update(double d, double d2) {
        if (Double.isNaN(d)) {
            return stateToRatio(this.state);
        }

        final int prevAction = this.lastAction;
        final int prevState = this.lastState;
        final int currentState = this.state;
        final int nextAction = this.policy.chooseAction();

        final double tdError = (d + (this.gamma * this.Q.at(currentState, nextAction))) - this.Q.at(prevState, prevAction);

        // Reset the traces of the previous state: 1 for the action taken, 0 for all others.
        this.e.set(1.0d, prevState, prevAction);
        for (int actionIndex = 0; actionIndex < this.actions.length; actionIndex++) {
            if (actionIndex == prevAction) {
                continue;
            }
            this.e.set(0.0d, prevState, actionIndex);
        }

        // Both factors are constant for the whole sweep.
        final double valueStep = this.alpha * tdError;
        final double traceDecay = this.gamma * this.lambda;
        final int stateCount = this.Q.numStates();
        for (int stateIndex = 0; stateIndex < stateCount; stateIndex++) {
            for (int actionIndex = 0; actionIndex < this.actions.length; actionIndex++) {
                final double trace = this.e.at(stateIndex, actionIndex);
                final double updatedValue = this.Q.at(stateIndex, actionIndex) + (valueStep * trace);
                this.Q.set(updatedValue, stateIndex, actionIndex);
                this.e.set(traceDecay * trace, stateIndex, actionIndex);
            }
        }

        this.lastState = currentState;
        this.lastAction = nextAction;
        this.state = applyAction(nextAction, currentState);

        final Rational ratio = stateToRatio(this.state);
        LOG.info("Updated learner: r={}, s={}, a={}, s'={}, a'={}, s''={} ({}), \n Q={} \n e={}", new Object[]{Double.valueOf(d), Integer.valueOf(prevState), Integer.valueOf(prevAction), Integer.valueOf(currentState), Integer.valueOf(nextAction), Integer.valueOf(this.state), ratio, this.Q, this.e});
        return ratio;
    }

    /**
     * Positions the learner at the state that corresponds to the given ratio.
     *
     * @param rational ratio to start from
     */
    @Override // se.sics.kompics.network.data.policies.ProtocolRatioPolicy
    public void initialState(Rational rational) {
        final int startState = ratioToState(rational);
        this.state = startState;
    }

    /**
     * Converts a ratio to the nearest state index; the inverse of {@code stateToRatio}.
     *
     * <p>{@code stateToRatio} yields {@code stepSize * (state - middle)}, so the state is
     * {@code ratio / stepSize + middle}. Ratios above 1 map to the highest state and ratios
     * below -1 to state 0.
     *
     * @param rational ratio to convert
     * @return state index in {@code [0, maxState]}
     */
    private int ratioToState(Rational rational) {
        if (rational.isGreaterThan(Rational.ONE)) {
            return this.Q.maxState();
        }
        // The lower bound is -1 (additive opposite). The multiplicative inverse of 1 is 1 again,
        // which would send every ratio below 1 to state 0.
        if (rational.isLessThan(Rational.ONE.opposite())) {
            return 0;
        }
        // Divide by the step size to get the signed number of steps away from the middle state.
        int stepsFromMiddle = rational.divide(this.stepSize).round().intValue();
        return stepsFromMiddle + this.Q.middleState();
    }

    /**
     * Converts a state index to its ratio: the step size times the signed distance from the
     * middle state.
     *
     * @param i state index
     * @return the ratio represented by that state
     */
    private Rational stateToRatio(int i) {
        final int stepsFromMiddle = i - this.Q.middleState();
        return this.stepSize.times(stepsFromMiddle);
    }

    /**
     * Applies an action's offset to a state and keeps the result inside the valid state range.
     *
     * @param i  index of the action in {@code actions}
     * @param i2 state the action is applied to
     * @return the resulting state, limited to {@code [0, maxState]}
     */
    private int applyAction(int i, int i2) {
        final int moved = i2 + this.actions[i];
        if (moved < 0) {
            return 0;
        }
        final int upperBound = this.Q.maxState();
        return Math.min(moved, upperBound);
    }

    /**
     * Expands the configured offsets into a sorted, symmetric action set without duplicates.
     *
     * <p>Every offset is paired with its negation and 0 is always included. Duplicates are
     * removed, so a configuration that lists both {@code k} and {@code -k}, or repeats a value,
     * does not produce redundant actions.
     *
     * @param list configured offsets; must not be {@code null} or contain {@code null}
     * @return ascending array of distinct offsets, always containing 0
     */
    private static int[] completeActions(List<Integer> list) {
        // Worst case: every entry contributes itself and its negation, plus one slot for 0.
        int[] candidates = new int[(2 * list.size()) + 1];
        int filled = 0;
        for (Integer configured : list) {
            int offset = configured.intValue();
            candidates[filled++] = offset;
            candidates[filled++] = -offset;
        }
        candidates[filled++] = 0;
        Arrays.sort(candidates);

        // Compact in place: after sorting, equal values are adjacent.
        int distinct = 0;
        for (int k = 0; k < candidates.length; k++) {
            if (k == 0 || candidates[k] != candidates[k - 1]) {
                candidates[distinct++] = candidates[k];
            }
        }
        return Arrays.copyOf(candidates, distinct);
    }
}
