package org.linqs.psl.application.learning.weight.search.grid;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/search/grid/BaseGridSearch.class */
/**
 * Base class for search-based weight learning applications that evaluate a sequence of
 * candidate weight configurations ("locations") and keep the one with the lowest objective.
 *
 * <p>Subclasses decide which location to visit next ({@link #chooseNextLocation()}) and how a
 * location maps to rule weights ({@link #getWeights(float[])}). They are also expected to set
 * {@link #maxNumLocations} and {@link #numLocations}; both start at zero here.
 */
public abstract class BaseGridSearch extends WeightLearningApplication {
    private static final Logger log = Logger.getLogger(BaseGridSearch.class);

    /** Reported in the per-iteration log line alongside {@link #numLocations}; starts at 0. */
    protected int maxNumLocations;

    /** Upper bound on the number of iterations performed by {@link #doLearn()}; starts at 0. */
    protected int numLocations;

    /** Objective value recorded for every inspected location, keyed by {@link #currentLocation}. */
    protected Map<String, Double> objectives;

    /** Identifier of the location currently being inspected; {@code null} until a subclass sets it. */
    protected String currentLocation;

    /**
     * Forwards all arguments to the superclass constructor and initializes the search state.
     *
     * @throws IllegalArgumentException if {@code runValidation} is set after the superclass
     *         constructor has run, since validation is not supported by grid search methods
     */
    public BaseGridSearch(List<Rule> list, Database database, Database database2, Database database3, Database database4, boolean z) {
        super(list, database, database2, database3, database4, Boolean.valueOf(z));

        this.maxNumLocations = 0;
        this.numLocations = this.maxNumLocations;
        this.objectives = new HashMap<>();
        this.currentLocation = null;

        if (this.runValidation) {
            throw new IllegalArgumentException("Validation is not supported by GridSearch weight learning applications.");
        }
    }

    /**
     * Runs the search: for up to {@link #numLocations} iterations, asks the subclass for the next
     * location, applies that location's weights to the mutable rules, scores them with
     * {@link #inspectLocation(float[])}, and remembers the weights with the lowest objective.
     * When the search ends, the best weights found are written back to the mutable rules.
     *
     * <p>If no location is inspected at all, the rule weights are left as they were.
     *
     * @throws IllegalStateException if no evaluation has been set
     */
    @Override // org.linqs.psl.application.learning.weight.WeightLearningApplication
    protected void doLearn() {
        if (this.evaluation == null) {
            throw new IllegalStateException(String.format("No evaluation has been set for weight learning method (%s), which is required for search-based methods.", getClass().getName()));
        }

        double bestObjective = -1.0d;
        float[] bestWeights = new float[this.mutableRules.size()];
        float[] weights = new float[this.mutableRules.size()];

        // Declared outside the loop so we can tell afterwards whether anything was inspected.
        int iteration;
        for (iteration = 0; iteration < this.numLocations; iteration++) {
            if (!chooseNextLocation()) {
                log.debug("Stopping search.");
                break;
            }

            log.debug("Iteration {} / {} ({}) -- Inspecting location {}", iteration, this.numLocations, this.maxNumLocations, this.currentLocation);

            getWeights(weights);
            for (int i = 0; i < this.mutableRules.size(); i++) {
                this.mutableRules.get(i).setWeight(weights[i]);
            }
            log.trace("Weights: {}", weights);

            // The weights just changed, so any previously computed MAP state is stale.
            this.inTrainingMAPState = false;

            double objective = inspectLocation(weights);
            this.objectives.put(this.currentLocation, objective);

            // The first inspected location is always accepted; after that, lower is better.
            if (iteration == 0 || objective < bestObjective) {
                bestObjective = objective;
                System.arraycopy(weights, 0, bestWeights, 0, bestWeights.length);
            }

            log.debug("Weights: {} -- objective: {}", this.currentLocation, objective);
        }

        if (iteration == 0) {
            // Nothing was inspected, so bestWeights still holds its zero-initialized contents.
            // Writing it back would silently overwrite every rule weight with zero.
            log.debug("No locations were inspected; leaving rule weights unchanged.");
        } else {
            for (int i = 0; i < this.mutableRules.size(); i++) {
                this.mutableRules.get(i).setWeight(bestWeights[i]);
            }
        }

        this.inTrainingMAPState = false;
    }

    /**
     * Scores the weights currently applied to the rules by computing the training MAP state and
     * running the configured evaluation over the training map.
     *
     * @param fArr the weights of the location being inspected (not read by this implementation)
     * @return the evaluation's normalized representative metric, negated, so that
     *         {@link #doLearn()} can treat lower values as better
     */
    protected double inspectLocation(float[] fArr) {
        computeTrainingMAPState();
        this.evaluation.compute(this.trainingMap);
        return (-1.0d) * this.evaluation.getNormalizedRepMetric();
    }

    /**
     * Fills the given array with the weights for the current location.
     *
     * @param fArr output buffer with one slot per mutable rule
     */
    protected abstract void getWeights(float[] fArr);

    /**
     * Advances to the next location to inspect.
     *
     * @return {@code false} to stop the search early, {@code true} to inspect another location
     */
    protected abstract boolean chooseNextLocation();
}
