package org.linqs.psl.application.learning.weight.search;

import java.util.List;
import org.linqs.psl.application.learning.weight.VotedPerceptron;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.maxlikelihood.MaxLikelihoodMPE;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/linqs/psl/application/learning/weight/search/InitialWeightHyperband.class */
/**
 * A {@link Hyperband} search that runs an internal weight learner before each evaluation.
 *
 * <p>On every call to {@link #run(double[])}, the internal learner's {@code learn()} is invoked,
 * the resulting weights of the mutable rules are copied into the supplied weight array,
 * and the array is then scored by {@link Hyperband#run(double[])}.
 *
 * <p>The internal learner's class name is read from the {@link #INTERNAL_WLA_KEY} config option
 * (default: {@link #INTERNAL_WLA_DEFAULT}) and must be a {@link VotedPerceptron}.
 */
public class InitialWeightHyperband extends Hyperband {
    /** Prefix shared by this class's configuration options. */
    public static final String CONFIG_PREFIX = "initialweighthyperband";

    /** Config key naming the class of the internal weight learning application. */
    public static final String INTERNAL_WLA_KEY = "initialweighthyperband.internalwla";

    /** Default internal weight learning application: the class name of {@link MaxLikelihoodMPE}. */
    public static final String INTERNAL_WLA_DEFAULT = MaxLikelihoodMPE.class.getName();

    private static final Logger log = LoggerFactory.getLogger(InitialWeightHyperband.class);

    /** The learner run at the start of every {@link #run(double[])}; assigned once at construction. */
    private final VotedPerceptron internalWLA;

    /**
     * Convenience constructor that uses all rules of the given model.
     *
     * @param model the model whose rules are passed on
     * @param database first database, forwarded unchanged to the list-based constructor
     * @param database2 second database, forwarded unchanged to the list-based constructor
     */
    public InitialWeightHyperband(Model model, Database database, Database database2) {
        this(model.getRules(), database, database2);
    }

    /**
     * Builds the search and instantiates the internal learner named by {@link #INTERNAL_WLA_KEY}.
     *
     * <p>NOTE(review): the roles of the two databases are defined by the superclass and
     * {@code WeightLearningApplication.getWLA}; both receive them in the same order.
     *
     * @param rules the rules handed to both this search and the internal learner
     * @param database first database, passed to the superclass and the internal learner
     * @param database2 second database, passed to the superclass and the internal learner
     * @throws ClassCastException if the configured learner is not a {@link VotedPerceptron}
     */
    public InitialWeightHyperband(List<Rule> rules, Database database, Database database2) {
        super(rules, database, database2);

        String internalClassName = Config.getString(INTERNAL_WLA_KEY, INTERNAL_WLA_DEFAULT);
        this.internalWLA = (VotedPerceptron) WeightLearningApplication.getWLA(
                internalClassName, rules, database, database2);
    }

    /**
     * Initializes the internal learner with the same ground model components as this application,
     * after the superclass has finished its own post-initialization.
     *
     * <p>NOTE(review): the decompiler reported that this method's access modifier was changed
     * from {@code protected}; it is kept {@code public} here to match the decompiled class.
     */
    @Override
    public void postInitGroundModel() {
        super.postInitGroundModel();

        // Hand over this application's own components rather than building a second ground model.
        this.internalWLA.initGroundModel(this.reasoner, this.groundRuleStore, this.termStore,
                this.termGenerator, this.atomManager, this.trainingMap);
    }

    /**
     * Sets the budget on the internal learner and then on this search.
     *
     * @param budget the budget value forwarded to both
     */
    @Override
    public void setBudget(double budget) {
        this.internalWLA.setBudget(budget);
        super.setBudget(budget);
    }

    /**
     * Runs the internal learner, records the learned weights, and scores them.
     *
     * <p>NOTE(review): the decompiler reported that this method's access modifier was changed
     * from {@code protected}; it is kept {@code public} here to match the decompiled class.
     *
     * @param weights overwritten in place, one entry per mutable rule, with the weights
     *     the mutable rules hold after the internal learner finishes
     * @return the value returned by {@link Hyperband#run(double[])} for the updated weights
     */
    @Override
    public double run(double[] weights) {
        this.internalWLA.learn();

        // Read the weights back from the rules so the caller's array reflects what was learned.
        int ruleCount = this.mutableRules.size();
        for (int i = 0; i < ruleCount; i++) {
            weights[i] = this.mutableRules.get(i).getWeight();
        }

        return super.run(weights);
    }

    /**
     * Closes this application and then the internal learner.
     *
     * <p>The internal learner is closed even if closing the superclass throws.
     */
    @Override
    public void close() {
        try {
            super.close();
        } finally {
            this.internalWLA.close();
        }
    }
}
