package org.linqs.psl.runtime;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.config.RuntimeOptions;
import org.linqs.psl.database.DataStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.database.loading.Inserter;
import org.linqs.psl.database.rdbms.RDBMSDataStore;
import org.linqs.psl.database.rdbms.driver.H2DatabaseDriver;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.grounding.MemoryGroundRuleStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.arithmetic.AbstractGroundArithmeticRule;
import org.linqs.psl.model.rule.logical.AbstractGroundLogicalRule;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.model.term.ConstantType;
import org.linqs.psl.model.term.UniqueStringID;
import org.linqs.psl.parser.ModelLoader;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.StringUtils;

/**
 * Grounds a PSL model described entirely with strings and arrays, and reports every ground rule
 * as an index-based {@link GroundRuleInfo}.
 *
 * <p>All predicates are registered with {@link ConstantType#UniqueStringID} arguments. Each call to
 * {@link #ground} creates its own H2 data store and closes it before returning (also on failure).
 */
public final class GroundingAPI {
    /** Name of the partition that receives atoms supplied with a truth value. */
    public static final String PARTITION_OBS = "observed";
    /** Name of the partition that receives atoms supplied without a truth value. */
    public static final String PARTITION_UNOBS = "unobserved";

    /** Index-based description of a single ground rule. */
    public static final class GroundRuleInfo {
        /** Position of the originating rule in the {@code rules} array given to ground() (-1 if not found). */
        public int ruleIndex;
        /** "|" for logical rules; the comparator's string form for arithmetic rules. */
        public String operator;
        /** 0 for logical rules; the rule's constant for arithmetic rules. */
        public float constant;
        /** One coefficient per entry of {@link #atoms}. */
        public float[] coefficients;
        /** Caller-supplied atom indexes (column 0 of {@code atomInfo}) of the atoms in this ground rule. */
        public int[] atoms;

        public GroundRuleInfo(int ruleIndex, String operator, float constant, float[] coefficients, int[] atoms) {
            this.ruleIndex = ruleIndex;
            this.operator = operator;
            this.constant = constant;
            this.coefficients = coefficients;
            this.atoms = atoms;
        }

        @Override
        public String toString() {
            return String.format("Rule: %d, Operator: %s, Constant: %f, coefficients: [%s], atoms: [%s].",
                    Integer.valueOf(this.ruleIndex), this.operator, Float.valueOf(this.constant),
                    StringUtils.join(", ", this.coefficients), StringUtils.join(", ", this.atoms));
        }
    }

    /**
     * Grounds {@code rules} against the supplied atoms and returns one {@link GroundRuleInfo} per ground rule.
     *
     * <p>Side effect: sets the global log level to {@code "WARN"}.
     *
     * @param rules textual rules, each parsed with {@code ModelLoader.loadRule}.
     * @param predicateNames names of the predicates to register.
     * @param predicateArities arity of each predicate; must be as long as {@code predicateNames}
     *     (checked only when assertions are enabled).
     * @param atomInfo one row per atom: {@code [atomIndex, predicateIndex, value]}. {@code atomIndex} is the
     *     integer reported back in {@link GroundRuleInfo#atoms}; {@code predicateIndex} indexes
     *     {@code predicateNames} (rows outside that range are ignored); an empty {@code value} stores the atom
     *     in the unobserved partition, otherwise it is parsed as a double and stored as an observation.
     * @param atomArguments the argument strings of each atom, parallel to {@code atomInfo}.
     * @return the ground rules, in the order the ground rule store iterates them.
     */
    public static GroundRuleInfo[] ground(String[] rules, String[] predicateNames, int[] predicateArities,
            String[][] atomInfo, String[][] atomArguments) {
        Logger.setLevel("WARN");
        assert predicateNames.length == predicateArities.length;

        // NOTE(review): every call uses the same default on-disk H2 path, so concurrent calls are likely
        // unsafe; the trailing `true` presumably requests a clean database — confirm with H2DatabaseDriver.
        RDBMSDataStore dataStore = new RDBMSDataStore(new H2DatabaseDriver(
                H2DatabaseDriver.Type.Disk, RuntimeOptions.DB_H2_PATH.defaultValue().toString(), true));
        try {
            registerPredicates(predicateNames, predicateArities, dataStore);

            // Parsed after registration, presumably because rules refer to the predicates by name.
            List<Rule> parsedRules = new ArrayList<>(rules.length);
            for (String rule : rules) {
                parsedRules.add(ModelLoader.loadRule(rule));
            }

            // Kept as one expression so partitions are requested in the original order:
            // unobserved first, then loadData(), then observed.
            Database database = dataStore.getDatabase(
                    dataStore.getPartition(PARTITION_UNOBS),
                    loadData(dataStore, predicateNames, atomInfo, atomArguments),
                    dataStore.getPartition(PARTITION_OBS));
            try {
                PersistedAtomManager atomManager = new PersistedAtomManager(database);
                MemoryGroundRuleStore groundRuleStore = new MemoryGroundRuleStore();
                try {
                    Map<GroundAtom, Integer> atomIndexes =
                            buildAtomMap(predicateNames, atomInfo, atomArguments, atomManager);
                    Grounding.groundAll(parsedRules, atomManager, groundRuleStore);
                    return mapGroundRules(parsedRules, atomIndexes, groundRuleStore);
                } finally {
                    groundRuleStore.close();
                }
            } finally {
                database.close();
            }
        } finally {
            dataStore.close();
        }
    }

    /** Registers every predicate with {@code UniqueStringID} for each of its arguments. */
    private static void registerPredicates(String[] predicateNames, int[] predicateArities, DataStore dataStore) {
        for (int predicateIndex = 0; predicateIndex < predicateNames.length; predicateIndex++) {
            ConstantType[] argumentTypes = new ConstantType[predicateArities[predicateIndex]];
            for (int argIndex = 0; argIndex < argumentTypes.length; argIndex++) {
                argumentTypes[argIndex] = ConstantType.UniqueStringID;
            }
            dataStore.registerPredicate(StandardPredicate.get(predicateNames[predicateIndex], argumentTypes));
        }
    }

    /** Parses column 1 (the predicate index) of every {@code atomInfo} row exactly once. */
    private static int[] parsePredicateIndexes(String[][] atomInfo) {
        int[] predicateIndexes = new int[atomInfo.length];
        for (int atomIndex = 0; atomIndex < atomInfo.length; atomIndex++) {
            predicateIndexes[atomIndex] = Integer.parseInt(atomInfo[atomIndex][1]);
        }
        return predicateIndexes;
    }

    /**
     * Inserts every atom into the observed partition (when it has a value) or the unobserved partition.
     *
     * @return the predicates none of whose atoms were unobserved (including predicates with no atoms);
     *     these are passed to getDatabase() as its predicate set.
     */
    private static Set<StandardPredicate> loadData(DataStore dataStore, String[] predicateNames,
            String[][] atomInfo, String[][] atomArguments) {
        int[] atomPredicates = parsePredicateIndexes(atomInfo);
        Set<StandardPredicate> fullyObservedPredicates = new HashSet<>(predicateNames.length);

        // Predicate-major order is kept so rows reach the database in the same order as before.
        for (int predicateIndex = 0; predicateIndex < predicateNames.length; predicateIndex++) {
            StandardPredicate predicate = StandardPredicate.get(predicateNames[predicateIndex]);
            int arity = predicate.getArity();
            Inserter observedInserter = dataStore.getInserter(predicate, dataStore.getPartition(PARTITION_OBS));
            Inserter unobservedInserter = dataStore.getInserter(predicate, dataStore.getPartition(PARTITION_UNOBS));
            boolean fullyObserved = true;

            for (int atomIndex = 0; atomIndex < atomInfo.length; atomIndex++) {
                if (atomPredicates[atomIndex] != predicateIndex) {
                    continue;
                }

                // Fresh array per atom so nothing the inserter may retain is overwritten later.
                Object[] arguments = new Object[arity];
                System.arraycopy(atomArguments[atomIndex], 0, arguments, 0, arity);

                String value = atomInfo[atomIndex][2];
                if (value.isEmpty()) {
                    fullyObserved = false;
                    unobservedInserter.insert(arguments);
                } else {
                    observedInserter.insertValue(Double.parseDouble(value), arguments);
                }
            }

            if (fullyObserved) {
                fullyObservedPredicates.add(predicate);
            }
        }
        return fullyObservedPredicates;
    }

    /** Maps each supplied atom, as resolved by {@code atomManager}, to its caller-supplied atom index. */
    private static Map<GroundAtom, Integer> buildAtomMap(String[] predicateNames, String[][] atomInfo,
            String[][] atomArguments, AtomManager atomManager) {
        int[] atomPredicates = parsePredicateIndexes(atomInfo);
        Map<GroundAtom, Integer> atomIndexes =
                new HashMap<>(atomManager.getCachedObsCount() + atomManager.getCachedRVACount());

        for (int predicateIndex = 0; predicateIndex < predicateNames.length; predicateIndex++) {
            StandardPredicate predicate = StandardPredicate.get(predicateNames[predicateIndex]);
            int arity = predicate.getArity();

            for (int atomIndex = 0; atomIndex < atomInfo.length; atomIndex++) {
                if (atomPredicates[atomIndex] != predicateIndex) {
                    continue;
                }

                // Fresh array per atom so a retained reference is never mutated afterwards.
                Constant[] arguments = new Constant[arity];
                for (int argIndex = 0; argIndex < arity; argIndex++) {
                    arguments[argIndex] = new UniqueStringID(atomArguments[atomIndex][argIndex]);
                }
                int suppliedIndex = Integer.parseInt(atomInfo[atomIndex][0]);
                atomIndexes.put(atomManager.getAtom(predicate, arguments), Integer.valueOf(suppliedIndex));
            }
        }
        return atomIndexes;
    }

    /** Converts every ground rule in the store into a {@link GroundRuleInfo}. */
    private static GroundRuleInfo[] mapGroundRules(List<Rule> rules, Map<GroundAtom, Integer> atomIndexes,
            GroundRuleStore groundRuleStore) {
        GroundRuleInfo[] infos = new GroundRuleInfo[(int) groundRuleStore.size()];
        int next = 0;
        for (GroundRule groundRule : groundRuleStore.getGroundRules()) {
            int ruleIndex = rules.indexOf(groundRule.getRule());
            if (groundRule instanceof AbstractGroundLogicalRule) {
                infos[next++] = mapLogicalGroundRule(ruleIndex, atomIndexes, (AbstractGroundLogicalRule) groundRule);
            } else if (groundRule instanceof AbstractGroundArithmeticRule) {
                infos[next++] = mapArithmeticGroundRule(ruleIndex, atomIndexes,
                        (AbstractGroundArithmeticRule) groundRule);
            } else {
                throw new IllegalStateException(
                        "Unsupported ground rule type: " + groundRule.getClass().getName());
            }
        }
        return infos;
    }

    /** Logical rule: operator "|", constant 0, +1 per positive atom followed by -1 per negative atom. */
    private static GroundRuleInfo mapLogicalGroundRule(int ruleIndex, Map<GroundAtom, Integer> atomIndexes,
            AbstractGroundLogicalRule groundRule) {
        float[] coefficients = new float[groundRule.size()];
        int[] atoms = new int[groundRule.size()];
        int next = 0;
        for (GroundAtom atom : groundRule.getPositiveAtoms()) {
            coefficients[next] = 1.0f;
            atoms[next] = lookupAtomIndex(atomIndexes, atom);
            next++;
        }
        for (GroundAtom atom : groundRule.getNegativeAtoms()) {
            coefficients[next] = -1.0f;
            atoms[next] = lookupAtomIndex(atomIndexes, atom);
            next++;
        }
        return new GroundRuleInfo(ruleIndex, "|", 0.0f, coefficients, atoms);
    }

    /** Arithmetic rule: comparator string, constant and coefficients taken from the rule, atoms in rule order. */
    private static GroundRuleInfo mapArithmeticGroundRule(int ruleIndex, Map<GroundAtom, Integer> atomIndexes,
            AbstractGroundArithmeticRule groundRule) {
        GroundAtom[] orderedAtoms = groundRule.getOrderedAtoms();
        int[] atoms = new int[orderedAtoms.length];
        for (int position = 0; position < orderedAtoms.length; position++) {
            atoms[position] = lookupAtomIndex(atomIndexes, orderedAtoms[position]);
        }
        return new GroundRuleInfo(ruleIndex, groundRule.getComparator().toString(), groundRule.getConstant(),
                groundRule.getCoefficients(), atoms);
    }

    /** Returns the caller-supplied index of {@code atom}, failing with a descriptive error if it has none. */
    private static int lookupAtomIndex(Map<GroundAtom, Integer> atomIndexes, GroundAtom atom) {
        Integer index = atomIndexes.get(atom);
        if (index == null) {
            throw new IllegalStateException("Ground rule references an atom with no supplied index: " + atom);
        }
        return index.intValue();
    }
}
