package org.linqs.psl.model.predicate;

import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamException;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.util.Logger;

/* loaded from: input_file:org/linqs/psl/model/predicate/DeepPredicate.class */
/**
 * A {@link StandardPredicate} that owns a {@link DeepModelPredicate} and forwards
 * init / fit / predict / eval / save / close requests to it.
 *
 * Instances are obtained through the static {@code get} factories; the constructor is protected.
 * Java serialization is explicitly refused (see the private serialization hooks at the bottom).
 */
public class DeepPredicate extends StandardPredicate {
    private static final Logger log = Logger.getLogger(DeepPredicate.class);

    // Replaceable through setDeepModel(), so it may be any reference a caller supplied (including null).
    private DeepModelPredicate deepModel;

    /**
     * Build the predicate and attach a freshly constructed {@link DeepModelPredicate} that is handed this instance.
     *
     * @param str the predicate name, passed to the superclass constructor.
     * @param constantTypeArr the argument types, passed to the superclass constructor.
     */
    protected DeepPredicate(String str, ConstantType[] constantTypeArr) {
        super(str, constantTypeArr);
        this.deepModel = new DeepModelPredicate(this);
    }

    /**
     * Give the deep model its atom store and then initialize it.
     *
     * @param atomStore the store handed to {@code DeepModelPredicate.setAtomStore}.
     * @param str forwarded unchanged to {@code DeepModelPredicate.initDeepModel};
     *        NOTE(review): presumably names the application/phase being run -- confirm against DeepModelPredicate.
     */
    public void initDeepPredicate(AtomStore atomStore, String str) {
        this.deepModel.setAtomStore(atomStore);
        this.deepModel.initDeepModel(str);
    }

    /**
     * Hand the supplied gradients to the deep model and then run a fit step.
     *
     * @param fArr forwarded to {@code DeepModelPredicate.setSymbolicGradients} before fitting.
     */
    public void fitDeepPredicate(float[] fArr) {
        this.deepModel.setSymbolicGradients(fArr);
        this.deepModel.fitDeepModel();
    }

    /**
     * @return the deep model currently attached to this predicate.
     */
    public DeepModelPredicate getDeepModel() {
        return this.deepModel;
    }

    /**
     * Replace the attached deep model. The previous model is not closed by this call.
     *
     * @param deepModelPredicate the model to use from now on.
     */
    public void setDeepModel(DeepModelPredicate deepModelPredicate) {
        this.deepModel = deepModelPredicate;
    }

    /**
     * Run a prediction with the flag set to {@code false}.
     *
     * @return whatever {@code DeepModelPredicate.predictDeepModel(false)} returns.
     */
    public float predictDeepModel() {
        return this.deepModel.predictDeepModel(false);
    }

    /**
     * Run a prediction, forwarding the flag to the deep model.
     *
     * @param bool forwarded unchanged to {@code DeepModelPredicate.predictDeepModel};
     *        NOTE(review): looks like a "learning mode" switch -- confirm against DeepModelPredicate.
     * @return whatever the deep model's prediction call returns.
     */
    public float predictDeepModel(Boolean bool) {
        return this.deepModel.predictDeepModel(bool);
    }

    /**
     * Delegate to {@code DeepModelPredicate.evalDeepModel}.
     */
    public void evalDeepModel() {
        this.deepModel.evalDeepModel();
    }

    /**
     * Delegate to {@code DeepModelPredicate.saveDeepModel}.
     */
    public void saveDeepModel() {
        this.deepModel.saveDeepModel();
    }

    /**
     * Close the superclass state and then the attached deep model, holding this instance's monitor.
     *
     * The deep model is closed in a {@code finally} block so it is still released when
     * {@code super.close()} throws, and the call is skipped when no model is attached
     * (for example after {@code setDeepModel(null)}).
     */
    @Override
    public synchronized void close() {
        try {
            super.close();
        } finally {
            DeepModelPredicate model = this.deepModel;
            if (model != null) {
                model.close();
            }
        }
    }

    /**
     * Look up an existing predicate by name and return it as a DeepPredicate.
     *
     * @param str the predicate name handed to {@code StandardPredicate.get}.
     * @return the matching predicate, or {@code null} when the lookup finds nothing.
     * @throws ClassCastException if a predicate with that name exists but is not a DeepPredicate.
     */
    public static DeepPredicate get(String str) {
        StandardPredicate standardPredicate = StandardPredicate.get(str);
        if (standardPredicate == null) {
            return null;
        }

        if (standardPredicate instanceof DeepPredicate) {
            return (DeepPredicate) standardPredicate;
        }

        throw new ClassCastException("Predicate (" + str + ") is not a DeepPredicate.");
    }

    /**
     * Get the predicate with the given name, constructing it when the lookup finds nothing.
     * An existing predicate is checked against the requested types with
     * {@code StandardPredicate.validateTypes} before it is returned.
     *
     * @param str the predicate name.
     * @param constantTypeArr the argument types the predicate must have.
     * @return the existing or newly constructed predicate.
     * @throws ClassCastException if a predicate with that name exists but is not a DeepPredicate.
     */
    public static DeepPredicate get(String str, ConstantType... constantTypeArr) {
        DeepPredicate deepPredicate = get(str);
        if (deepPredicate == null) {
            // NOTE(review): registration of the new instance presumably happens in a superclass constructor -- confirm.
            return new DeepPredicate(str, constantTypeArr);
        }

        StandardPredicate.validateTypes(deepPredicate, constantTypeArr);
        return deepPredicate;
    }

    /**
     * Serialization hook: always refuses to write this object.
     *
     * @throws NotSerializableException always, carrying this class's name.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        throw new NotSerializableException(DeepPredicate.class.getName());
    }

    /**
     * Serialization hook: always refuses to read this object.
     *
     * @throws NotSerializableException always, carrying this class's name.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        throw new NotSerializableException(DeepPredicate.class.getName());
    }

    /**
     * Serialization hook: always refuses to initialize this object from a stream that has no data for it.
     *
     * @throws NotSerializableException always, carrying this class's name.
     */
    private void readObjectNoData() throws ObjectStreamException {
        throw new NotSerializableException(DeepPredicate.class.getName());
    }
}
