package org.deeplearning4j.arbiter.optimize.generator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import org.apache.commons.math3.random.RandomAdaptor;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.util.LeafUtils;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"})
/**
 * Candidate generator that performs a grid search over the unique leaves of a {@link ParameterSpace}.
 *
 * <p>Each unique leaf is split into a finite number of grid values:
 * <ul>
 *   <li>{@link DiscreteParameterSpace}: all of its values ({@code numValues()});</li>
 *   <li>{@link IntegerParameterSpace}: {@code min(max - min + 1, discretizationCount)};</li>
 *   <li>{@link FixedValue}: exactly one value;</li>
 *   <li>any other leaf: {@code discretizationCount} values.</li>
 * </ul>
 * The candidates are the Cartesian product of these grids. Each candidate is identified by an integer
 * index in {@code [0, totalNumCandidates)}. The index is decoded by {@link #indexToValues} into one
 * value in {@code [0, 1]} per leaf, and those values are passed to {@code parameterSpace.getValue}.
 */
public class GridSearchCandidateGenerator extends BaseCandidateGenerator {
    private static final Logger log = LoggerFactory.getLogger(GridSearchCandidateGenerator.class);

    /** Grid size for leaves that are not discrete/integer/fixed; also caps the grid size of integer leaves. */
    private final int discretizationCount;
    /** Order in which grid candidates are issued. */
    private final Mode mode;
    /** Grid size of each unique leaf, in the order returned by {@code LeafUtils.getUniqueObjects}. */
    private int[] numValuesPerParam;
    /** Product of {@link #numValuesPerParam}: total number of grid candidates. */
    private int totalNumCandidates;
    /** Candidate indices that have not been issued yet, in issue order. */
    private Queue<Integer> order;

    /** Order in which grid candidates are visited. */
    public enum Mode {
        /** Candidates are issued in index order 0, 1, 2, ... */
        Sequential,
        /** Candidates are issued in a random permutation, shuffled with the generator's {@code rng}. */
        RandomOrder
    }

    /**
     * Creates a grid search generator. This constructor is also used by Jackson.
     *
     * @param parameterSpace      space whose unique leaves define the grid axes
     * @param discretizationCount number of grid values for leaves without a natural finite size
     * @param mode                order in which candidates are issued; must not be null
     * @param dataParameters      data parameters attached to every generated {@link Candidate}
     * @param initDone            forwarded unchanged to the base-class constructor
     */
    public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace<?> parameterSpace,
                                        @JsonProperty("discretizationCount") int discretizationCount,
                                        @JsonProperty("mode") Mode mode,
                                        @JsonProperty("dataParameters") Map<String, Object> dataParameters,
                                        @JsonProperty("initDone") boolean initDone) {
        super(parameterSpace, dataParameters, initDone);
        this.discretizationCount = discretizationCount;
        this.mode = mode;
        initialize();
    }

    /**
     * Creates a grid search generator with {@code initDone = false}.
     *
     * @param parameterSpace      space whose unique leaves define the grid axes
     * @param discretizationCount number of grid values for leaves without a natural finite size
     * @param mode                order in which candidates are issued; must not be null
     * @param dataParameters      data parameters attached to every generated {@link Candidate}
     */
    public GridSearchCandidateGenerator(ParameterSpace<?> parameterSpace, int discretizationCount, Mode mode,
                                        Map<String, Object> dataParameters) {
        this(parameterSpace, discretizationCount, mode, dataParameters, false);
    }

    /**
     * Computes the grid size of every unique leaf and the total candidate count. It then builds the queue
     * of candidate indices in the order given by {@link #mode}.
     *
     * @throws IllegalStateException if the total number of candidates is {@code >= Integer.MAX_VALUE}
     * @throws NullPointerException  if {@code mode} is null
     */
    @Override
    public void initialize() {
        super.initialize();

        List<ParameterSpace> leaves = LeafUtils.getUniqueObjects(this.parameterSpace.collectLeaves());
        int numLeaves = leaves.size();
        this.numValuesPerParam = new int[numLeaves];

        long product = 1;
        for (int i = 0; i < numLeaves; i++) {
            this.numValuesPerParam[i] = gridSizeOf(leaves.get(i));
            try {
                // multiplyExact: a plain long multiply could silently wrap for enormous grids
                product = Math.multiplyExact(product, (long) this.numValuesPerParam[i]);
            } catch (ArithmeticException overflow) {
                throw new IllegalStateException(
                        "Invalid search: number of candidates overflows a long (must be < Integer.MAX_VALUE)", overflow);
            }
        }
        if (product >= Integer.MAX_VALUE) {
            throw new IllegalStateException("Invalid search: cannot process search with " + product
                    + " candidates >= Integer.MAX_VALUE");
        }

        this.order = new ConcurrentLinkedQueue<>();
        this.totalNumCandidates = (int) product;

        if (this.mode == null) {
            throw new NullPointerException("mode must not be null");
        }
        switch (this.mode) {
            case Sequential:
                for (int idx = 0; idx < this.totalNumCandidates; idx++) {
                    this.order.add(idx);
                }
                return;
            case RandomOrder:
                List<Integer> shuffled = new ArrayList<>(this.totalNumCandidates);
                for (int idx = 0; idx < this.totalNumCandidates; idx++) {
                    shuffled.add(idx);
                }
                Collections.shuffle(shuffled, new RandomAdaptor(this.rng));
                this.order.addAll(shuffled);
                return;
            default:
                throw new IllegalStateException("Unsupported grid search mode: " + this.mode);
        }
    }

    /** Returns the number of grid values contributed by one unique leaf parameter space. */
    private int gridSizeOf(ParameterSpace leaf) {
        if (leaf instanceof DiscreteParameterSpace) {
            return ((DiscreteParameterSpace) leaf).numValues();
        }
        if (leaf instanceof IntegerParameterSpace) {
            IntegerParameterSpace integerSpace = (IntegerParameterSpace) leaf;
            // Widen to long: max - min + 1 overflows int for very wide integer ranges
            long rangeSize = (long) integerSpace.getMax() - integerSpace.getMin() + 1;
            return (int) Math.min(rangeSize, (long) this.discretizationCount);
        }
        if (leaf instanceof FixedValue) {
            return 1;
        }
        return this.discretizationCount;
    }

    /**
     * Returns true while at least one candidate index has not been issued.
     * NOTE(review): this check and the removal in {@link #getCandidate()} are separate operations, so
     * concurrent callers may still see an empty queue in getCandidate.
     */
    @Override
    public boolean hasMoreCandidates() {
        return !this.order.isEmpty();
    }

    /**
     * Removes the next candidate index from the queue and decodes it into per-leaf values. It then asks the
     * parameter space for the corresponding configuration. If {@code getValue} throws, the exception is
     * logged and stored in the returned candidate instead of being propagated.
     *
     * @return the next candidate
     * @throws java.util.NoSuchElementException if no candidates remain
     */
    @Override
    public Candidate getCandidate() {
        int candidateIdx = this.order.remove();
        double[] values = indexToValues(this.numValuesPerParam, candidateIdx, this.totalNumCandidates);

        Object config = null;
        Exception failure = null;
        try {
            config = this.parameterSpace.getValue(values);
        } catch (Exception e) {
            log.warn("Error getting configuration for candidate", e);
            failure = e;
        }
        return new Candidate(config, this.candidateCounter.getAndIncrement(), values, this.dataParameters, failure);
    }

    /** Always returns null; this generator does not report a candidate type. */
    @Override
    public Class<?> getCandidateType() {
        return null;
    }

    /**
     * Decodes a candidate index into one value in {@code [0, 1]} per parameter.
     *
     * <p>The index is first split into a per-parameter grid coordinate (mixed radix). The first parameter
     * varies fastest: for sizes {@code [3, 2]}, index 0 gives {@code [0, 0]}, 1 gives {@code [1, 0]}, and
     * 3 gives {@code [0, 1]}. Each coordinate {@code c} of a parameter with {@code n > 1} grid values is then
     * mapped to {@code c / (n - 1)}. Parameters with {@code n <= 1} map to 0.0.
     *
     * @param numValuesPerParam grid size of each parameter
     * @param candidateIdx      candidate index, expected in {@code [0, product)}
     * @param product           product of {@code numValuesPerParam}
     * @return one value in {@code [0, 1]} per parameter
     */
    public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) {
        int[] gridCoord = new int[numValuesPerParam.length];
        int denom = product;
        int remainder = candidateIdx;
        for (int p = gridCoord.length - 1; p >= 0; p--) {
            denom /= numValuesPerParam[p];
            gridCoord[p] = remainder / denom;
            remainder %= denom;
        }

        double[] values = new double[numValuesPerParam.length];
        for (int p = 0; p < values.length; p++) {
            int n = numValuesPerParam[p];
            // Floating-point division: integer division would collapse every coordinate except the last to 0.0
            values[p] = (n <= 1) ? 0.0d : gridCoord[p] / (double) (n - 1);
        }
        return values;
    }

    @Override
    public String toString() {
        return "GridSearchCandidateGenerator(mode=" + this.mode + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GridSearchCandidateGenerator)) {
            return false;
        }
        GridSearchCandidateGenerator other = (GridSearchCandidateGenerator) obj;
        if (!other.canEqual(this) || !super.equals(obj) || this.discretizationCount != other.discretizationCount) {
            return false;
        }
        // Enum constants are singletons, so identity comparison is exact and null-safe
        return this.mode == other.mode
                && Arrays.equals(this.numValuesPerParam, other.numValuesPerParam)
                && getTotalNumCandidates() == other.getTotalNumCandidates();
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof GridSearchCandidateGenerator;
    }

    @Override
    public int hashCode() {
        int result = super.hashCode() * 59 + this.discretizationCount;
        result = result * 59 + (this.mode == null ? 43 : this.mode.hashCode());
        result = result * 59 + Arrays.hashCode(this.numValuesPerParam);
        result = result * 59 + getTotalNumCandidates();
        return result;
    }

    /** Returns the total number of grid candidates (product of the per-leaf grid sizes). */
    public int getTotalNumCandidates() {
        return this.totalNumCandidates;
    }
}
