package org.cleartk.ml.tksvmlight.kernel;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.cleartk.ml.tksvmlight.TreeFeature;
import org.cleartk.ml.tksvmlight.model.LexicalFunctionModel;
import org.cleartk.util.treebank.TopTreebankNode;
import org.cleartk.util.treebank.TreebankFormatParser;
import org.cleartk.util.treebank.TreebankNode;

/* loaded from: input_file:org/cleartk/ml/tksvmlight/kernel/SyntacticSemanticTreeKernel.class */
/**
 * Tree kernel that sums, over every pair of nodes drawn from two parsed trees, a
 * decayed count of the subtree structure the two nodes share. Leaves are compared
 * through a {@link LexicalFunctionModel} instead of by exact equality.
 *
 * <p>Parsed trees and (when normalisation is enabled) self-similarities are cached per
 * input string, so repeated evaluations against the same tree do not re-parse or
 * re-score it.
 *
 * <p>NOTE(review): {@link #trees} is a plain {@link HashMap} while the normaliser cache
 * is a {@link ConcurrentHashMap}. If instances are shared between threads the tree
 * cache is not guarded by this class -- confirm the intended threading model.
 */
public class SyntacticSemanticTreeKernel extends TreeKernel_ImplBase {
    private static final long serialVersionUID = -4333998714675622519L;

    /** Conventional decay factor; the constructor always takes an explicit value. */
    public static final double LAMBDA_DEFAULT = 0.4d;

    /** Decay factor multiplied in once per matched node. */
    private final double lambda;

    /** Whether kernel values are divided by the geometric mean of the two self-similarities. */
    private final boolean normalize;

    /** Cache: tree string -> similarity of that tree with itself. */
    private final ConcurrentHashMap<String, Double> normalizers = new ConcurrentHashMap<>();

    /** Cache: tree string -> parsed tree. */
    HashMap<String, TopTreebankNode> trees;

    /** Supplies the similarity score used when two leaves are compared. */
    private final LexicalFunctionModel lexModel;

    /**
     * Creates a kernel with empty caches.
     *
     * @param lexicalFunctionModel model queried for leaf-to-leaf similarity
     * @param d decay factor (lambda) applied at every matched node
     * @param z {@code true} to normalise kernel values by the trees' self-similarities
     */
    public SyntacticSemanticTreeKernel(LexicalFunctionModel lexicalFunctionModel, double d, boolean z) {
        this.lexModel = lexicalFunctionModel;
        this.lambda = d;
        this.normalize = z;
        this.trees = new HashMap<>();
    }

    /**
     * Evaluates the kernel on the string forms of the two feature values.
     *
     * @param treeFeature first tree feature
     * @param treeFeature2 second tree feature
     * @return the (optionally normalised) kernel value
     */
    @Override // org.cleartk.ml.tksvmlight.kernel.TreeKernel_ImplBase, org.cleartk.ml.tksvmlight.kernel.ComposableTreeKernel
    public double evaluate(TreeFeature treeFeature, TreeFeature treeFeature2) {
        return sstk(treeFeature.getValue().toString(), treeFeature2.getValue().toString());
    }

    /**
     * Parses (or fetches from cache) both trees and returns their similarity, divided by
     * {@code sqrt(sim(a, a) * sim(b, b))} when normalisation is enabled.
     *
     * @param str treebank-format string of the first tree
     * @param str2 treebank-format string of the second tree
     * @return the kernel value; {@code 0.0} when normalising and either self-similarity is zero
     */
    private double sstk(String str, String str2) {
        TopTreebankNode first = parsedTree(str);
        TopTreebankNode second = parsedTree(str2);
        if (!this.normalize) {
            return sim(first, second);
        }
        double selfFirst = selfSimilarity(str, first);
        double selfSecond = selfSimilarity(str2, second);
        double normProduct = selfFirst * selfSecond;
        if (normProduct == 0.0d) {
            // A zero self-similarity would make the quotient NaN or infinite; treat the
            // degenerate pair as having no similarity instead of leaking that value.
            return 0.0d;
        }
        return sim(first, second) / Math.sqrt(normProduct);
    }

    /** Returns the parsed tree for {@code treeString}, parsing and caching it on first use. */
    private TopTreebankNode parsedTree(String treeString) {
        TopTreebankNode tree = this.trees.get(treeString);
        if (tree == null) {
            tree = TreebankFormatParser.parse(treeString);
            this.trees.put(treeString, tree);
        }
        return tree;
    }

    /** Returns {@code sim(tree, tree)}, computing and caching it under {@code key} on first use. */
    private double selfSimilarity(String key, TreebankNode tree) {
        Double cached = this.normalizers.get(key);
        if (cached == null) {
            cached = Double.valueOf(sim(tree, tree));
            this.normalizers.put(key, cached);
        }
        return cached.doubleValue();
    }

    /**
     * Sums {@link #numCommonSubtrees} over every pair of nodes taken from the node lists
     * of the two trees.
     */
    private double sim(TreebankNode treebankNode, TreebankNode treebankNode2) {
        List<TreebankNode> leftNodes = TreeKernelUtils.getNodeList(treebankNode);
        List<TreebankNode> rightNodes = TreeKernelUtils.getNodeList(treebankNode2);
        double total = 0.0d;
        for (TreebankNode left : leftNodes) {
            for (TreebankNode right : rightNodes) {
                total += numCommonSubtrees(left, right);
            }
        }
        return total;
    }

    /**
     * Scores the structure shared by two nodes.
     *
     * <ul>
     *   <li>Different child counts or different node types: {@code 0}.</li>
     *   <li>Both leaves: {@code lambda *} the lexical similarity of the two values.</li>
     *   <li>Otherwise the children must agree in type position by position (else
     *       {@code 0}); the score is {@code lambda * product(1 + score(child_i, child2_i))}.</li>
     * </ul>
     */
    private double numCommonSubtrees(TreebankNode treebankNode, TreebankNode treebankNode2) {
        List<?> children = treebankNode.getChildren();
        List<?> children2 = treebankNode2.getChildren();
        int size = children.size();
        if (size != children2.size() || !treebankNode.getType().equals(treebankNode2.getType())) {
            return 0.0d;
        }
        if (treebankNode.isLeaf() && treebankNode2.isLeaf()) {
            return this.lambda * this.lexModel.getLexicalSimilarity(treebankNode.getValue(), treebankNode2.getValue());
        }
        // All child types must line up before any child pair is scored.
        for (int i = 0; i < size; i++) {
            TreebankNode child = (TreebankNode) children.get(i);
            TreebankNode child2 = (TreebankNode) children2.get(i);
            if (!child.getType().equals(child2.getType())) {
                return 0.0d;
            }
        }
        double product = 1.0d;
        for (int i = 0; i < size; i++) {
            product *= 1.0d + numCommonSubtrees((TreebankNode) children.get(i), (TreebankNode) children2.get(i));
        }
        return this.lambda * product;
    }
}
