package org.linqs.psl.model.deep;

import com.healthmarketscience.sqlbuilder.SqlObjectList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.util.FileUtils;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.StringUtils;

/**
 * A {@link DeepModel} attached to a single deep {@link Predicate}.
 *
 * <p>Required predicate options: {@value #CONFIG_ENTITY_DATA_MAP_PATH}, {@value #CONFIG_ENTITY_ARGUMENT_INDEXES},
 * and {@value #CONFIG_CLASS_SIZE}. Data is exchanged with the external model through {@code sharedBuffer}:
 * data indexes (a count followed by the indexes) and gradients are written out, and one prediction per tracked
 * atom is read back into the {@link AtomStore}.
 *
 * <p>NOTE(review): this source was recovered with the JADX decompiler, and {@code mapEntitiesFromFileToAtoms}
 * could not be decompiled (its stub always throws). Restore it from the original source before use.
 */
public class DeepModelPredicate extends DeepModel {
    private static final Logger log = Logger.getLogger(DeepModelPredicate.class);
    // NOTE(review): unused in the recovered code; presumably the column delimiter of the entity data map file
    // read by mapEntitiesFromFileToAtoms -- TODO confirm.
    private static final String DELIM = "\t";

    public static final String CONFIG_ENTITY_DATA_MAP_PATH = "entity-data-map-path";
    public static final String CONFIG_ENTITY_ARGUMENT_INDEXES = "entity-argument-indexes";
    public static final String CONFIG_CLASS_SIZE = "class-size";

    /** Separator between an application name and an option key, e.g. {@code "<application>::<key>"}. */
    private static final String APPLICATION_OPTION_SEPARATOR = "::";

    private AtomStore atomStore;
    private Predicate predicate;
    private int classSize;

    // Atom store indexes of the atoms whose values are predicted by (and whose gradients are sent to) the model.
    private int[] atomIndexes;
    // Indexes written to the shared buffer for every fit/predict/eval call
    // (presumably row indexes into the external model's data -- TODO confirm).
    private int[] dataIndexes;
    // Scratch buffer: gradients[i] holds symbolicGradients[atomIndexes[i]] while writing fit data.
    private float[] gradients;
    // Gradients indexed by atom store index; supplied via setSymbolicGradients (held by reference, not copied).
    private float[] symbolicGradients;

    // Staging lists populated during init() (presumably by mapEntitiesFromFileToAtoms), then copied into
    // atomIndexes/dataIndexes and cleared.
    private ArrayList<Integer> validAtomIndexes;
    private ArrayList<Integer> validDataIndexes;

    public DeepModelPredicate(Predicate predicate) {
        super("DeepModelPredicate");
        this.atomStore = null;
        this.predicate = predicate;
        this.classSize = -1;
        this.atomIndexes = null;
        this.dataIndexes = null;
        this.gradients = null;
        this.symbolicGradients = null;
        this.validAtomIndexes = new ArrayList<>();
        this.validDataIndexes = new ArrayList<>();
    }

    /**
     * Returns a copy of this model.
     *
     * <p>The copy shares this instance's options map, application, port, Python server process, shared
     * memory, socket streams and atom store by reference. It receives its own copies of the index lists,
     * index arrays and gradient arrays.
     */
    public DeepModelPredicate copy() {
        DeepModelPredicate duplicate = new DeepModelPredicate(this.predicate);

        duplicate.pythonOptions = this.pythonOptions;
        duplicate.application = this.application;
        // Release the new instance's own port before it adopts ours
        // (presumably the DeepModel constructor reserved one -- TODO confirm).
        freePort(duplicate.port);
        duplicate.port = this.port;
        duplicate.pythonModule = this.pythonModule;
        duplicate.sharedMemoryPath = this.sharedMemoryPath;
        duplicate.pythonServerProcess = this.pythonServerProcess;
        duplicate.sharedFile = this.sharedFile;
        duplicate.sharedBuffer = this.sharedBuffer;
        duplicate.socket = this.socket;
        duplicate.socketInput = this.socketInput;
        duplicate.socketOutput = this.socketOutput;
        duplicate.serverOpen = this.serverOpen;

        duplicate.atomStore = this.atomStore;
        duplicate.classSize = this.classSize;
        duplicate.atomIndexes = (this.atomIndexes == null)
                ? null : Arrays.copyOf(this.atomIndexes, this.atomIndexes.length);
        duplicate.dataIndexes = (this.dataIndexes == null)
                ? null : Arrays.copyOf(this.dataIndexes, this.dataIndexes.length);
        duplicate.validAtomIndexes = new ArrayList<>(this.validAtomIndexes);
        duplicate.validDataIndexes = new ArrayList<>(this.validDataIndexes);
        duplicate.gradients = (this.gradients == null)
                ? null : Arrays.copyOf(this.gradients, this.gradients.length);
        duplicate.symbolicGradients = (this.symbolicGradients == null)
                ? null : Arrays.copyOf(this.symbolicGradients, this.symbolicGradients.length);

        return duplicate;
    }

    /**
     * Validates the predicate options, maps entities from the entity data map file onto atoms, and builds the
     * index and gradient arrays.
     *
     * @return a size derived from the number of entities and the class size (presumably the shared-memory
     *     buffer size requested by DeepModel -- TODO confirm)
     * @throws IllegalArgumentException if a required option is missing or the class size is not a positive int
     * @throws ArithmeticException if the computed size overflows an {@code int}
     */
    @Override
    public int init() {
        log.debug("Initializing deep model predicate: {}", this.predicate.getName());
        validateOptions();

        this.classSize = Integer.parseInt(this.pythonOptions.get(CONFIG_CLASS_SIZE));
        if (this.classSize <= 0) {
            throw new IllegalArgumentException(String.format(
                    "A DeepPredicate must have a positive class size (\"%s\"), got %d.",
                    CONFIG_CLASS_SIZE, this.classSize));
        }

        String entityDataMapPath = FileUtils.makePath(
                this.pythonOptions.get("relative-dir"), this.pythonOptions.get(CONFIG_ENTITY_DATA_MAP_PATH));
        int numEntityArguments = StringUtils.splitInt(
                this.pythonOptions.get(CONFIG_ENTITY_ARGUMENT_INDEXES), SqlObjectList.DEFAULT_DELIMITER).length;
        int numEntities = mapEntitiesFromFileToAtoms(entityDataMapPath, this.atomStore, numEntityArguments);

        this.atomIndexes = this.validAtomIndexes.stream().mapToInt(Integer::intValue).toArray();
        this.dataIndexes = this.validDataIndexes.stream().mapToInt(Integer::intValue).toArray();
        // New float arrays are zero-initialized.
        this.gradients = new float[this.atomIndexes.length];

        this.validAtomIndexes.clear();
        this.validDataIndexes.clear();

        // One count header, one index per entity, and classSize values per entity. Integer.SIZE/Float.SIZE
        // are bit widths (32), matching the original literal. Exact arithmetic fails loudly instead of
        // silently wrapping to a bogus (possibly negative) size.
        int indexSize = Math.multiplyExact(numEntities, Integer.SIZE);
        int predictionSize = Math.multiplyExact(Math.multiplyExact(numEntities, this.classSize), Float.SIZE);
        return Math.addExact(Integer.SIZE, Math.addExact(indexSize, predictionSize));
    }

    /**
     * Writes the data indexes followed by one gradient per tracked atom to the shared buffer.
     *
     * @throws IllegalStateException if {@link #setSymbolicGradients(float[])} has not been called
     */
    @Override
    public void writeFitData() {
        log.debug("Writing fit data for deep model predicate: {}", this.predicate.getName());
        if (this.symbolicGradients == null) {
            throw new IllegalStateException(String.format(
                    "Symbolic gradients must be set before writing fit data for deep model predicate: %s.",
                    this.predicate.getName()));
        }

        for (int i = 0; i < this.gradients.length; i++) {
            this.gradients[i] = this.symbolicGradients[this.atomIndexes[i]];
        }
        writeDataIndexData();
        writeGradientData(this.gradients);
    }

    /** Writes the data indexes to the shared buffer ahead of a prediction request. */
    @Override
    public void writePredictData() {
        log.debug("Writing predict data for deep model predicate: {}", this.predicate.getName());
        writeDataIndexData();
    }

    /**
     * Reads one prediction per tracked atom from the shared buffer and stores it in both the atom store's
     * value array and the corresponding {@link RandomVariableAtom}.
     *
     * @return the sum of absolute differences between the previous and the new atom values
     * @throws IllegalStateException if the external model reported a different number of predictions than
     *     there are tracked atoms
     */
    @Override
    public float readPredictData() {
        log.debug("Reading predict data for deep model predicate: {}", this.predicate.getName());

        int numPredictions = this.sharedBuffer.getInt();
        if (numPredictions != this.atomIndexes.length) {
            throw new IllegalStateException(String.format(
                    "External model did not make the desired number of predictions, got %d, expected %d.",
                    numPredictions, this.atomIndexes.length));
        }

        float[] atomValues = this.atomStore.getAtomValues();
        float totalMovement = 0.0f;
        for (int atomIndex : this.atomIndexes) {
            float newValue = this.sharedBuffer.getFloat();
            totalMovement += Math.abs(atomValues[atomIndex] - newValue);
            atomValues[atomIndex] = newValue;
            ((RandomVariableAtom) this.atomStore.getAtom(atomIndex)).setValue(newValue);
        }
        return totalMovement;
    }

    /** Writes the data indexes to the shared buffer ahead of an evaluation request. */
    @Override
    public void writeEvalData() {
        log.debug("Writing eval data for deep model predicate: {}", this.predicate.getName());
        writeDataIndexData();
    }

    /** Closes the underlying model and resets all per-initialization state (the atom store is kept). */
    @Override
    public void close() {
        super.close();
        this.classSize = -1;
        this.atomIndexes = null;
        this.dataIndexes = null;
        this.gradients = null;
        this.symbolicGradients = null;
        this.validAtomIndexes.clear();
        this.validDataIndexes.clear();
    }

    /** Sets the atom store without initializing the model. */
    public void setAtomStore(AtomStore atomStore) {
        setAtomStore(atomStore, false);
    }

    /**
     * Sets the atom store.
     *
     * @param initializeModel if true, {@link #init()} is run immediately afterwards (its return value is discarded)
     */
    public void setAtomStore(AtomStore atomStore, boolean initializeModel) {
        this.atomStore = atomStore;
        if (initializeModel) {
            init();
        }
    }

    /**
     * Sets the gradients, indexed by atom store index, used by {@link #writeFitData()}.
     * The array is held by reference, not copied.
     */
    public void setSymbolicGradients(float[] symbolicGradients) {
        this.symbolicGradients = symbolicGradients;
    }

    /**
     * Merges the predicate's options into {@code pythonOptions}, checks that the required options are present,
     * and additionally exposes every {@code "<application>::<key>"} option under the plain {@code "<key>"}.
     *
     * @throws IllegalArgumentException if a required option is missing
     */
    private void validateOptions() {
        for (Map.Entry<String, Object> entry : this.predicate.getPredicateOptions().entrySet()) {
            // Option values are typed Object (e.g. a numeric class size); a String cast would throw for those.
            Object value = entry.getValue();
            this.pythonOptions.put(entry.getKey(), (value == null) ? null : value.toString());
        }

        String entityDataMapPath = this.pythonOptions.get(CONFIG_ENTITY_DATA_MAP_PATH);
        if (entityDataMapPath == null
                || FileUtils.makePath(this.pythonOptions.get("relative-dir"), entityDataMapPath) == null) {
            throw new IllegalArgumentException(String.format("A DeepPredicate must have an entity to data map path (\"%s\") specified in predicate config.", CONFIG_ENTITY_DATA_MAP_PATH));
        }
        if (this.pythonOptions.get(CONFIG_ENTITY_ARGUMENT_INDEXES) == null) {
            throw new IllegalArgumentException(String.format("A DeepPredicate must have entity argument indexes (\"%s\") specified in predicate config.", CONFIG_ENTITY_ARGUMENT_INDEXES));
        }
        if (this.pythonOptions.get(CONFIG_CLASS_SIZE) == null) {
            throw new IllegalArgumentException(String.format("A DeepPredicate must have a class size (\"%s\") specified in predicate config.", CONFIG_CLASS_SIZE));
        }

        // Collect first, then merge: inserting new keys into a map while iterating it can throw
        // ConcurrentModificationException. Only keys that start with this application's prefix are
        // matched (contains() also matched other applications), and everything after the prefix is kept.
        String prefix = this.application + APPLICATION_OPTION_SEPARATOR;
        Map<String, String> unprefixedOptions = new HashMap<>();
        for (Map.Entry<String, String> entry : this.pythonOptions.entrySet()) {
            String key = entry.getKey();
            if (key != null && key.startsWith(prefix) && key.length() > prefix.length()) {
                unprefixedOptions.put(key.substring(prefix.length()), entry.getValue());
            }
        }
        this.pythonOptions.putAll(unprefixedOptions);
    }

    /*
     * NOTE(review): JADX failed to decompile this method ("Method dump skipped, instructions count: 447"), so the
     * stub below always throws and init() cannot currently succeed. Restore the original implementation.
     * From its call site in init() it presumably reads the entity data map file at entityDataMapPath, fills
     * validAtomIndexes/validDataIndexes, and returns the number of entities -- TODO confirm against the original.
     */
    private int mapEntitiesFromFileToAtoms(String entityDataMapPath, AtomStore atomStore, int numEntityArguments) {
        throw new UnsupportedOperationException("Method not decompiled: org.linqs.psl.model.deep.DeepModelPredicate.mapEntitiesFromFileToAtoms(java.lang.String, org.linqs.psl.database.AtomStore, int):int");
    }

    /** Writes each gradient value to the shared buffer, in order, without a length header. */
    private void writeGradientData(float[] gradientValues) {
        for (float gradient : gradientValues) {
            this.sharedBuffer.putFloat(gradient);
        }
    }

    /** Writes the number of data indexes followed by each data index to the shared buffer. */
    private void writeDataIndexData() {
        this.sharedBuffer.putInt(this.dataIndexes.length);
        for (int dataIndex : this.dataIndexes) {
            this.sharedBuffer.putInt(dataIndex);
        }
    }
}
