package org.linqs.psl.model.predicate.model;

import java.util.Map;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.util.Logger;

/* loaded from: input_file:org/linqs/psl/model/predicate/model/ModelPredicate.class */
/**
 * A {@link StandardPredicate} whose values and labels are delegated to a {@link SupportingModel}.
 *
 * <p>Lifecycle: {@link #loadModel(Map, String)} must be called before any model-backed operation.
 * {@link #runModel()} must additionally be called before {@link #getValue(RandomVariableAtom)}.
 * After {@link #close()}, the supporting model is released and model-backed operations throw
 * {@link IllegalStateException}.
 */
public class ModelPredicate extends StandardPredicate {
    private static final Logger log = Logger.getLogger(ModelPredicate.class);

    /** Config key naming an existing {@link StandardPredicate} to pair with this one as mirrors. */
    private static final String CONFIG_MIRROR = "mirror";

    /** The backing model; set to null once this predicate has been closed. */
    protected SupportingModel model;
    private boolean modelLoaded;
    private boolean modelRan;

    protected ModelPredicate(String str, ConstantType[] constantTypeArr, SupportingModel supportingModel) {
        super(str, constantTypeArr);
        this.model = supportingModel;
        this.modelLoaded = false;
        this.modelRan = false;
    }

    /**
     * Model predicates always report themselves as fixed mirrors.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isFixedMirror() {
        return true;
    }

    /**
     * Closes this predicate and its supporting model (if any), then drops the model reference.
     */
    @Override
    public void close() {
        super.close();
        if (this.model != null) {
            this.model.close();
            this.model = null;
        }

        // Reset the lifecycle flags: otherwise checkModel() would still pass after close()
        // and callers would hit a NullPointerException on the nulled model.
        this.modelLoaded = false;
        this.modelRan = false;
    }

    /**
     * Configures this predicate and loads its supporting model.
     *
     * <p>If {@code map} contains the {@code "mirror"} key, the named predicate and this one are
     * set as each other's mirror before the model is loaded.
     *
     * @param map configuration options; also passed through to the supporting model
     * @param str passed through unchanged to the supporting model's {@code load}
     * @throws IllegalArgumentException if the configured mirror predicate does not exist
     * @throws IllegalStateException if this predicate has no supporting model (e.g. it was closed)
     */
    public void loadModel(Map<String, String> map, String str) {
        // Check first so a closed predicate does not get its mirror state mutated before failing.
        ensureModelPresent();

        if (map.containsKey(CONFIG_MIRROR)) {
            String mirrorName = map.get(CONFIG_MIRROR);
            StandardPredicate standardPredicate = StandardPredicate.get(mirrorName);
            if (standardPredicate == null) {
                throw new IllegalArgumentException(String.format(
                        "Cannot make unknown predicate (%s) a mirror for %s.", mirrorName, this));
            }
            setMirror(standardPredicate);
            standardPredicate.setMirror(this);
        }

        this.model.load(map, str);
        this.modelLoaded = true;
    }

    /**
     * Returns the supporting model's value for an atom, clamped to the range [0, 1].
     *
     * @param randomVariableAtom the atom to query
     * @return the model's value, clamped to [0, 1]
     * @throws IllegalStateException if the model is not loaded or {@link #runModel()} was not called
     */
    public float getValue(RandomVariableAtom randomVariableAtom) {
        checkModel();
        if (!this.modelRan) {
            throw new IllegalStateException("Cannot invoke getValue() before runModel() has been called.");
        }
        return Math.max(0.0f, Math.min(1.0f, this.model.getValue(randomVariableAtom)));
    }

    /**
     * Runs the supporting model and marks it as having been run.
     *
     * @throws IllegalStateException if the model is not loaded
     */
    public void runModel() {
        checkModel();
        this.model.run();
        this.modelRan = true;
    }

    /**
     * Delegates to the supporting model's {@code resetLabels()}.
     *
     * @throws IllegalStateException if the model is not loaded
     */
    public void resetLabels() {
        checkModel();
        this.model.resetLabels();
    }

    /**
     * Returns the supporting model's label for an atom.
     *
     * @param randomVariableAtom the atom to query
     * @return the label reported by the supporting model
     * @throws IllegalStateException if the model is not loaded
     */
    public float getLabel(RandomVariableAtom randomVariableAtom) {
        checkModel();
        return this.model.getLabel(randomVariableAtom);
    }

    /**
     * Sets the supporting model's label for an atom.
     *
     * @param randomVariableAtom the atom to label
     * @param f the label value, passed through unchanged
     * @throws IllegalStateException if the model is not loaded
     */
    public void setLabel(RandomVariableAtom randomVariableAtom, float f) {
        checkModel();
        this.model.setLabel(randomVariableAtom, f);
    }

    /**
     * Delegates to the supporting model's {@code fit()}, with trace logging around the call.
     *
     * @throws IllegalStateException if the model is not loaded
     */
    public void fit() {
        checkModel();
        log.trace("Fitting {} ({}).", this, this.model);
        this.model.fit();
        log.trace("Done fitting {} ({}).", this, this.model);
    }

    /**
     * Delegates to the supporting model's {@code initialFit()}, with trace logging around the call.
     *
     * @throws IllegalStateException if the model is not loaded
     */
    public void initialFit() {
        checkModel();
        log.trace("Initial fitting {} ({}).", this, this.model);
        this.model.initialFit();
        log.trace("Done initial fitting {} ({}).", this, this.model);
    }

    /** Throws if there is no supporting model or it has not been loaded via loadModel(). */
    private void checkModel() {
        ensureModelPresent();
        if (!this.modelLoaded) {
            throw new IllegalStateException("ModelPredicate (" + this + ") has not been initialized via loadModel().");
        }
    }

    /** Throws if the supporting model reference is null (e.g. after close()). */
    private void ensureModelPresent() {
        if (this.model == null) {
            throw new IllegalStateException(
                    "ModelPredicate (" + this + ") has no supporting model (it may have been closed).");
        }
    }

    /**
     * Looks up an existing predicate by name.
     *
     * @param str the predicate name
     * @return the matching ModelPredicate, or {@code null} if no predicate has that name
     * @throws ClassCastException if a predicate with that name exists but is not a ModelPredicate
     */
    public static ModelPredicate get(String str) {
        StandardPredicate standardPredicate = StandardPredicate.get(str);
        if (standardPredicate == null) {
            return null;
        }
        if (standardPredicate instanceof ModelPredicate) {
            return (ModelPredicate) standardPredicate;
        }
        throw new ClassCastException("Predicate (" + str + ") is not a ModelPredicate.");
    }

    /**
     * Returns the existing ModelPredicate with this name, or creates a new one.
     *
     * <p>When a predicate with this name already exists, its argument types are validated against
     * {@code constantTypeArr}. In that case the supplied {@code supportingModel} is not used.
     *
     * @param str the predicate name
     * @param supportingModel the model to back a newly created predicate
     * @param constantTypeArr the argument types
     * @return the existing or newly created predicate
     * @throws ClassCastException if a predicate with that name exists but is not a ModelPredicate
     */
    public static ModelPredicate get(String str, SupportingModel supportingModel, ConstantType... constantTypeArr) {
        ModelPredicate modelPredicate = get(str);
        if (modelPredicate == null) {
            return new ModelPredicate(str, constantTypeArr, supportingModel);
        }
        StandardPredicate.validateTypes(modelPredicate, constantTypeArr);
        return modelPredicate;
    }
}
