package org.deeplearning4j.models.rntn;

import akka.actor.ActorSystem;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.util.MultiDimensionalMap;
import org.deeplearning4j.util.MultiDimensionalSet;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/rntn/RNTN.class */
/**
 * Recursive neural tensor network (RNTN) that composes node vectors bottom-up over binarized trees
 * and produces a {@code numOuts}-class prediction at every internal node.
 *
 * <p>Only the "simplified" model is implemented: every node maps to the single category {@code ""},
 * so the model holds one binary transform matrix (plus, optionally, one bilinear tensor and one binary
 * classification matrix), one unary classification matrix and a word-vector table. Instances are
 * created with {@link Builder}.
 */
public class RNTN implements Serializable {

    /** Objective value (scaled error plus regularisation) from the last {@link #getValueGradient(int)} call. */
    protected double value;
    /** Number of classes predicted at every node. */
    private int numOuts = 3;
    /** Dimensionality of node vectors and word vectors. */
    private int numHidden;
    private RandomGenerator rng;
    /** Whether the bilinear tensor term is used when composing two children. */
    private boolean useFloatTensors;
    /** When true, binary nodes use the unary classification matrix instead of their own. */
    private boolean combineClassification;
    /** Only the simplified (single-category) model is implemented. */
    private boolean simplifiedModel;
    private boolean randomFeatureVectors;
    /** Multiplier applied to randomly initialised parameters. */
    private double scalingForInit;
    /** Key of the feature vector used for words missing from the feature-vector table. */
    public static final String UNKNOWN_FEATURE = "UNK";
    private boolean lowerCasefeatureNames;
    protected ActivationFunction activationFunction;
    protected ActivationFunction outputActivation = Activations.softMaxRows();
    protected AdaGrad paramAdaGrad;
    // Regularisation strengths; float literals reproduce the exact values the model has always used.
    private double regTransformMatrix = 0.001f;
    private double regClassification = 1e-4f;
    private double regWordVector = 1e-4f;
    private int adagradResetFrequency;
    private double regTransformINDArray;
    private MultiDimensionalMap<String, String, INDArray> binaryTransform;
    private MultiDimensionalMap<String, String, INDArray> binaryINd4j;
    private Map<String, INDArray> unaryClassification;
    private Map<String, INDArray> featureVectors;
    private MultiDimensionalMap<String, String, INDArray> binaryClassification;
    private int numBinaryMatrices;
    /** Number of scalars in one binary transform matrix. */
    private int binaryTransformSize;
    /** Number of scalars in one binary tensor (0 when tensors are disabled). */
    private int binaryINd4jize;
    /** Number of scalars in one binary classification matrix (0 when classification is combined). */
    private int binaryClassificationSize;
    private int numUnaryMatrices;
    /** Number of scalars in one unary classification matrix. */
    private int unaryClassificationSize;
    private INDArray identity;
    private List<Tree> trainingTrees;
    private Map<Integer, Float> classWeights;
    private static final Logger log = LoggerFactory.getLogger(RNTN.class);
    /** Not serialised; recreated lazily by {@link #getActorSystem()} after deserialisation. */
    private transient ActorSystem rnTnActorSystem = ActorSystem.create("RNTN");

    /** Fluent builder for {@link RNTN}. */
    public static class Builder {
        private int numHidden;
        private RandomGenerator rng;
        private boolean useINd4j;
        private boolean randomFeatureVectors;
        private boolean lowerCasefeatureNames;
        private int adagradResetFrequency;
        private double regTransformINDArray;
        private Map<String, INDArray> featureVectors;
        private int numBinaryMatrices;
        private int binaryTransformSize;
        private int binaryINd4jize;
        private int binaryClassificationSize;
        private int numUnaryMatrices;
        private int unaryClassificationSize;
        private Map<Integer, Float> classWeights;
        private boolean combineClassification = true;
        private boolean simplifiedModel = true;
        private double scalingForInit = 0.001f;
        private ActivationFunction activationFunction = Activations.sigmoid();
        /** Output activation override; {@code null} keeps the model default ({@code Activations.softMaxRows()}). */
        private ActivationFunction outputActivationFunction;

        /** Sets the activation applied to each node's class scores. */
        public Builder withOutputActivation(ActivationFunction activationFunction) {
            this.outputActivationFunction = activationFunction;
            return this;
        }

        /**
         * Not supported: turning a {@link Word2Vec} model into the {@code Map<String, INDArray>} this
         * builder needs is not implemented here.
         *
         * @throws UnsupportedOperationException always; use {@link #setFeatureVectors(Map)} together with
         *     {@link #setNumHidden(int)} instead
         */
        public Builder setFeatureVectors(Word2Vec word2Vec) {
            // This method used to call itself unconditionally, which always ended in a StackOverflowError.
            throw new UnsupportedOperationException(
                    "Feature vectors cannot be taken from a Word2Vec model directly; use setFeatureVectors(Map) "
                            + "and setNumHidden(word2Vec.getLayerSize()) instead");
        }

        /** Sets the node/word vector dimensionality. */
        public Builder setNumHidden(int i) {
            this.numHidden = i;
            return this;
        }

        /** Sets the random generator; a seeded MersenneTwister is used when none is given. */
        public Builder setRng(RandomGenerator randomGenerator) {
            this.rng = randomGenerator;
            return this;
        }

        /** Enables or disables the bilinear tensor term. */
        public Builder setUseTensors(boolean z) {
            this.useINd4j = z;
            return this;
        }

        /** Whether binary nodes reuse the unary classification matrix. */
        public Builder setCombineClassification(boolean z) {
            this.combineClassification = z;
            return this;
        }

        /** Selects the simplified model; {@code false} is not implemented and fails in {@link #build()}. */
        public Builder setSimplifiedModel(boolean z) {
            this.simplifiedModel = z;
            return this;
        }

        /** Sets the randomFeatureVectors flag. */
        public Builder setRandomFeatureVectors(boolean z) {
            this.randomFeatureVectors = z;
            return this;
        }

        /** Sets the multiplier applied to randomly initialised parameters. */
        public Builder setScalingForInit(double d) {
            this.scalingForInit = d;
            return this;
        }

        /** Whether words are lower-cased before looking up their feature vector. */
        public Builder setLowerCasefeatureNames(boolean z) {
            this.lowerCasefeatureNames = z;
            return this;
        }

        /** Sets the activation used to compute node vectors. */
        public Builder setActivationFunction(ActivationFunction activationFunction) {
            this.activationFunction = activationFunction;
            return this;
        }

        /** Sets the AdaGrad reset frequency (stored on the model). */
        public Builder setAdagradResetFrequency(int i) {
            this.adagradResetFrequency = i;
            return this;
        }

        /** Sets the regularisation strength for the binary tensors. */
        public Builder setRegTransformINDArray(double d) {
            this.regTransformINDArray = d;
            return this;
        }

        /** Sets the word-vector table; the model keeps (and adds the UNK entry to) this very map. */
        public Builder setFeatureVectors(Map<String, INDArray> map) {
            this.featureVectors = map;
            return this;
        }

        /** Passed to the model but recomputed during initialisation. */
        public Builder setNumBinaryMatrices(int i) {
            this.numBinaryMatrices = i;
            return this;
        }

        /** Passed to the model but recomputed during initialisation. */
        public Builder setBinaryTransformSize(int i) {
            this.binaryTransformSize = i;
            return this;
        }

        /** Passed to the model but recomputed during initialisation. */
        public Builder setBinaryINd4jize(int i) {
            this.binaryINd4jize = i;
            return this;
        }

        /** Passed to the model but recomputed during initialisation. */
        public Builder setBinaryClassificationSize(int i) {
            this.binaryClassificationSize = i;
            return this;
        }

        /** Passed to the model but recomputed during initialisation. */
        public Builder setNumUnaryMatrices(int i) {
            this.numUnaryMatrices = i;
            return this;
        }

        /** Passed to the model but recomputed during initialisation. */
        public Builder setUnaryClassificationSize(int i) {
            this.unaryClassificationSize = i;
            return this;
        }

        /** Sets per-gold-label error weights; labels without an entry get weight 1. */
        public Builder setClassWeights(Map<Integer, Float> map) {
            this.classWeights = map;
            return this;
        }

        /** Creates the model and randomly initialises its parameters. */
        public RNTN build() {
            RNTN rntn = new RNTN(numHidden, rng, useINd4j, combineClassification, simplifiedModel,
                    randomFeatureVectors, scalingForInit, lowerCasefeatureNames, activationFunction,
                    adagradResetFrequency, regTransformINDArray, featureVectors, numBinaryMatrices,
                    binaryTransformSize, binaryINd4jize, binaryClassificationSize, numUnaryMatrices,
                    unaryClassificationSize, classWeights);
            // The output activation is not a constructor argument; only override the model default
            // when the caller explicitly chose one.
            if (outputActivationFunction != null) {
                rntn.outputActivation = outputActivationFunction;
            }
            return rntn;
        }
    }

    private RNTN(int numHidden, RandomGenerator rng, boolean useFloatTensors, boolean combineClassification,
            boolean simplifiedModel, boolean randomFeatureVectors, double scalingForInit,
            boolean lowerCasefeatureNames, ActivationFunction activationFunction, int adagradResetFrequency,
            double regTransformINDArray, Map<String, INDArray> featureVectors, int numBinaryMatrices,
            int binaryTransformSize, int binaryINd4jize, int binaryClassificationSize, int numUnaryMatrices,
            int unaryClassificationSize, Map<Integer, Float> classWeights) {
        this.numHidden = numHidden;
        this.rng = rng;
        this.useFloatTensors = useFloatTensors;
        this.combineClassification = combineClassification;
        this.simplifiedModel = simplifiedModel;
        this.randomFeatureVectors = randomFeatureVectors;
        this.scalingForInit = scalingForInit;
        this.lowerCasefeatureNames = lowerCasefeatureNames;
        this.activationFunction = activationFunction;
        this.adagradResetFrequency = adagradResetFrequency;
        this.regTransformINDArray = regTransformINDArray;
        this.featureVectors = featureVectors;
        this.numBinaryMatrices = numBinaryMatrices;
        this.binaryTransformSize = binaryTransformSize;
        this.binaryINd4jize = binaryINd4jize;
        this.binaryClassificationSize = binaryClassificationSize;
        this.numUnaryMatrices = numUnaryMatrices;
        this.unaryClassificationSize = unaryClassificationSize;
        this.classWeights = classWeights;
        init();
    }

    /**
     * Allocates and randomly initialises all parameters and recomputes the size bookkeeping.
     *
     * @throws UnsupportedOperationException if the simplified model is not selected
     */
    private void init() {
        if (rng == null) {
            rng = new MersenneTwister(123);
        }
        if (!simplifiedModel) {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        if (featureVectors == null) {
            featureVectors = new HashMap<>();
        }
        // Keep caller-supplied class weights; they used to be unconditionally replaced by an empty map.
        if (classWeights == null) {
            classWeights = new HashMap<>();
        }

        // In the simplified model the only binary production and the only unary category are "".
        MultiDimensionalSet binaryProductions = MultiDimensionalSet.hashSet();
        binaryProductions.add("", "");
        HashSet<String> unaryProductions = new HashSet<>();
        unaryProductions.add("");

        identity = Nd4j.eye(numHidden);
        binaryTransform = MultiDimensionalMap.newTreeBackedMap();
        binaryINd4j = MultiDimensionalMap.newTreeBackedMap();
        binaryClassification = MultiDimensionalMap.newTreeBackedMap();
        Iterator productionIterator = binaryProductions.iterator();
        while (productionIterator.hasNext()) {
            Pair production = (Pair) productionIterator.next();
            String left = basicCategory((String) production.getFirst());
            String right = basicCategory((String) production.getSecond());
            if (binaryTransform.contains(left, right)) {
                continue;
            }
            binaryTransform.put(left, right, randomTransformMatrix());
            if (useFloatTensors) {
                binaryINd4j.put(left, right, randomBinaryINDArray());
            }
            if (!combineClassification) {
                binaryClassification.put(left, right, randomClassificationMatrix());
            }
        }
        numBinaryMatrices = binaryTransform.size();
        binaryTransformSize = numHidden * (2 * numHidden + 1);
        binaryINd4jize = useFloatTensors ? 4 * numHidden * numHidden * numHidden : 0;
        binaryClassificationSize = combineClassification ? 0 : numOuts * (numHidden + 1);

        unaryClassification = new TreeMap<>();
        for (String category : unaryProductions) {
            String basic = basicCategory(category);
            if (!unaryClassification.containsKey(basic)) {
                unaryClassification.put(basic, randomClassificationMatrix());
            }
        }
        numUnaryMatrices = unaryClassification.size();
        unaryClassificationSize = numOuts * (numHidden + 1);

        featureVectors.put(UNKNOWN_FEATURE, randomWordVector());
    }

    /** Random (numHidden, 2*numHidden, 2*numHidden) tensor, uniform in +-1/(4*numHidden), times scalingForInit. */
    INDArray randomBinaryINDArray() {
        double range = 1.0f / (4.0f * numHidden);
        return Nd4j.rand(new int[] {numHidden, numHidden * 2, numHidden * 2}, -range, range, rng)
                .muli(scalingForInit);
    }

    /**
     * Random numHidden x (2*numHidden + 1) transform: one random block per child, zero bias column,
     * all scaled by scalingForInit.
     */
    public INDArray randomTransformMatrix() {
        INDArray transform = Nd4j.create(numHidden, numHidden * 2 + 1);
        INDArray leftBlock = randomTransformBlock();
        transform.put(new NDArrayIndex[] {NDArrayIndex.interval(0, leftBlock.rows()),
                NDArrayIndex.interval(0, leftBlock.columns())}, leftBlock);
        transform.put(new NDArrayIndex[] {NDArrayIndex.interval(0, leftBlock.rows()),
                NDArrayIndex.interval(numHidden, numHidden + leftBlock.columns())}, randomTransformBlock());
        return Nd4j.getBlasWrapper().scal(scalingForInit, transform);
    }

    /** Identity plus uniform noise in +-1/(2*sqrt(numHidden)). */
    public INDArray randomTransformBlock() {
        double range = 1.0d / (Math.sqrt(numHidden) * 2.0d);
        return Nd4j.rand(numHidden, numHidden, -range, range, rng).add(identity);
    }

    /** Random numOuts x (numHidden + 1) classifier with a zero bias column, scaled by scalingForInit. */
    INDArray randomClassificationMatrix() {
        double range = 1.0d / Math.sqrt(numHidden);
        INDArray classifier = Nd4j.zeros(numOuts, numHidden + 1);
        classifier.put(new NDArrayIndex[] {NDArrayIndex.interval(0, numOuts), NDArrayIndex.interval(0, numHidden)},
                Nd4j.rand(numOuts, numHidden, -range, range, rng));
        return Nd4j.getBlasWrapper().scal(scalingForInit, classifier);
    }

    /** Random numHidden x 1 word vector from Nd4j.rand. */
    INDArray randomWordVector() {
        return Nd4j.rand(numHidden, 1, rng);
    }

    /**
     * Trains on {@code list}: for every tree, computes the full-batch gradient over all trees in
     * {@code list} and subtracts it from the parameters.
     */
    public void fit(List<Tree> list) {
        this.trainingTrees = list;
        for (Tree tree : list) {
            forwardPropagateTree(tree);
            setParameters(getParameters().subi(getValueGradient(0)));
        }
    }

    /**
     * Copies consecutive entries of {@code iNDArray} into every array produced by {@code itArr}, in order.
     *
     * @throws AssertionError if the vector is not entirely consumed
     */
    public void setParams(INDArray iNDArray, Iterator<? extends INDArray>... itArr) {
        int offset = 0;
        for (Iterator<? extends INDArray> iterator : itArr) {
            while (iterator.hasNext()) {
                INDArray target = iterator.next();
                for (int j = 0; j < target.length(); j++) {
                    target.put(j, iNDArray.getScalar(offset));
                    offset++;
                }
            }
        }
        if (offset != iNDArray.length()) {
            throw new AssertionError("Did not entirely use the theta vector");
        }
    }

    /** Binary transform matrix for a node with two children. */
    public INDArray getWForNode(Tree tree) {
        if (tree.children().size() == 2) {
            return (INDArray) binaryTransform.get(basicCategory(tree.children().get(0).value()),
                    basicCategory(tree.children().get(1).value()));
        }
        if (tree.children().size() == 1) {
            throw new AssertionError("No unary applyTransformToOrigin matrices, only unary classification");
        }
        throw new AssertionError("Unexpected tree children size of " + tree.children().size());
    }

    /** Binary tensor for a node with two children; requires tensors to be enabled. */
    public INDArray getINDArrayForNode(Tree tree) {
        if (!useFloatTensors) {
            throw new AssertionError("Not using INd4j");
        }
        if (tree.children().size() == 2) {
            return (INDArray) binaryINd4j.get(basicCategory(tree.children().get(0).value()),
                    basicCategory(tree.children().get(1).value()));
        }
        if (tree.children().size() == 1) {
            throw new AssertionError("No unary applyTransformToOrigin matrices, only unary classification");
        }
        throw new AssertionError("Unexpected tree children size of " + tree.children().size());
    }

    /** Classification matrix used for {@code tree}'s prediction. */
    public INDArray getClassWForNode(Tree tree) {
        if (combineClassification) {
            return unaryClassification.get("");
        }
        if (tree.children().size() == 2) {
            return (INDArray) binaryClassification.get(basicCategory(tree.children().get(0).value()),
                    basicCategory(tree.children().get(1).value()));
        }
        if (tree.children().size() != 1) {
            throw new AssertionError("Unexpected tree children size of " + tree.children().size());
        }
        return unaryClassification.get(basicCategory(tree.children().get(0).value()));
    }

    /**
     * Tensor gradient: slice i is {@code deltaFull[i] * c c^T}, where c stacks the two child vectors.
     * The result has the same (n, 2n, 2n) shape as the parameter tensor.
     */
    private INDArray getINDArrayGradient(INDArray deltaFull, INDArray leftVector, INDArray rightVector) {
        int size = deltaFull.length();
        INDArray gradient = Nd4j.create(new int[] {size, size * 2, size * 2});
        INDArray childrenVector = Nd4j.concat(0, new INDArray[] {leftVector, rightVector});
        // Loop-invariant outer product; scaling a copy per slice avoids mutating childrenVector.
        INDArray outer = childrenVector.mmul(childrenVector.transpose());
        for (int slice = 0; slice < size; slice++) {
            gradient.putSlice(slice, outer.mul(deltaFull.getDouble(slice)));
        }
        return gradient;
    }

    /** Feature vector for {@code str} (UNK if unknown), returned as a column vector. */
    public INDArray getFeatureVector(String str) {
        INDArray vector = featureVectors.get(getVocabWord(str));
        if (vector.isRowVector()) {
            vector = vector.transpose();
        }
        return vector;
    }

    /** Lookup key for {@code str}: optionally lower-cased, or {@link #UNKNOWN_FEATURE} if absent. */
    public String getVocabWord(String str) {
        if (lowerCasefeatureNames) {
            str = str.toLowerCase();
        }
        return featureVectors.containsKey(str) ? str : UNKNOWN_FEATURE;
    }

    /** Maps any label to the single category "" of the simplified model. */
    public String basicCategory(String str) {
        if (simplifiedModel) {
            return "";
        }
        throw new IllegalStateException("Only simplified model enabled");
    }

    public INDArray getUnaryClassification(String str) {
        return unaryClassification.get(basicCategory(str));
    }

    /** Binary classification matrix, or the unary one when classification is combined. */
    public INDArray getBinaryClassification(String str, String str2) {
        if (combineClassification) {
            return unaryClassification.get("");
        }
        return (INDArray) binaryClassification.get(basicCategory(str), basicCategory(str2));
    }

    public INDArray getBinaryTransform(String str, String str2) {
        return (INDArray) binaryTransform.get(basicCategory(str), basicCategory(str2));
    }

    public INDArray getBinaryINDArray(String str, String str2) {
        return (INDArray) binaryINd4j.get(basicCategory(str), basicCategory(str2));
    }

    /**
     * Total number of scalars in the flattened parameter vector of {@link #getParameters()}: binary
     * transforms, binary classifiers and tensors, unary classifiers, then word vectors.
     */
    public int getNumParameters() {
        // Per-matrix element counts, not map sizes, so this matches what setParams consumes.
        int total = numBinaryMatrices * (binaryTransformSize + binaryClassificationSize + binaryINd4jize);
        total += numUnaryMatrices * unaryClassificationSize;
        total += featureVectors.size() * numHidden;
        return total;
    }

    /** All parameters flattened in the order consumed by {@link #setParameters(INDArray)}. */
    public INDArray getParameters() {
        return Nd4j.toFlattened(getNumParameters(), new Iterator[] {binaryTransform.values().iterator(),
                binaryClassification.values().iterator(), binaryINd4j.values().iterator(),
                unaryClassification.values().iterator(), featureVectors.values().iterator()});
    }

    /**
     * Replaces every derivative with {@code scale * derivative + regCost * current} and returns the
     * L2 cost {@code regCost / 2 * sum(current^2)}. The current parameters are left untouched.
     */
    double scaleAndRegularize(MultiDimensionalMap<String, String, INDArray> multiDimensionalMap,
            MultiDimensionalMap<String, String, INDArray> multiDimensionalMap2, double d, double d2) {
        double cost = 0.0d;
        for (MultiDimensionalMap.Entry<String, String, INDArray> entry : multiDimensionalMap2.entrySet()) {
            INDArray current = entry.getValue();
            INDArray derivative = (INDArray) multiDimensionalMap.get(entry.getFirstKey(), entry.getSecondKey());
            // mul (not BlasWrapper.scal / muli) returns a copy, so the model parameters are not rescaled.
            multiDimensionalMap.put(entry.getFirstKey(), entry.getSecondKey(),
                    derivative.mul(d).add(current.mul(d2)));
            cost += sumOfSquares(current) * d2 / 2.0d;
        }
        return cost;
    }

    /** Map-keyed variant of the matrix-map {@code scaleAndRegularize}; current values are left untouched. */
    double scaleAndRegularize(Map<String, INDArray> map, Map<String, INDArray> map2, double d, double d2) {
        double cost = 0.0d;
        for (Map.Entry<String, INDArray> entry : map2.entrySet()) {
            INDArray current = entry.getValue();
            map.put(entry.getKey(), map.get(entry.getKey()).mul(d).add(current.mul(d2)));
            cost += sumOfSquares(current) * d2 / 2.0d;
        }
        return cost;
    }

    /**
     * Tensor variant of the matrix-map {@code scaleAndRegularize}. It previously used {@code muli} and
     * overwrote the model's tensors; now it delegates to the non-mutating implementation.
     */
    double scaleAndRegularizeINDArray(MultiDimensionalMap<String, String, INDArray> multiDimensionalMap,
            MultiDimensionalMap<String, String, INDArray> multiDimensionalMap2, double d, double d2) {
        return scaleAndRegularize(multiDimensionalMap, multiDimensionalMap2, d, d2);
    }

    /** Sum of squared elements, read through Number so float and double arrays both work. */
    private static double sumOfSquares(INDArray array) {
        return ((Number) array.mul(array).sum(Integer.MAX_VALUE).element()).doubleValue();
    }

    /** Backpropagates from {@code tree} with a zero incoming delta; public for the parallel workers. */
    public void backpropDerivativesAndError(Tree tree, MultiDimensionalMap<String, String, INDArray> multiDimensionalMap,
            MultiDimensionalMap<String, String, INDArray> multiDimensionalMap2,
            MultiDimensionalMap<String, String, INDArray> multiDimensionalMap3, Map<String, INDArray> map,
            Map<String, INDArray> map2) {
        backpropDerivativesAndError(tree, multiDimensionalMap, multiDimensionalMap2, multiDimensionalMap3, map, map2,
                Nd4j.create(numHidden, 1));
    }

    /**
     * Accumulates gradients for {@code tree} and its descendants and sets each node's error.
     *
     * @param binaryTD transform gradients
     * @param binaryCD binary classification gradients
     * @param binaryINDArrayTD tensor gradients
     * @param unaryCD unary classification gradients
     * @param wordVectorD word-vector gradients
     * @param deltaUp error signal arriving from the parent, already multiplied by f'(node)
     */
    private void backpropDerivativesAndError(Tree tree, MultiDimensionalMap<String, String, INDArray> binaryTD,
            MultiDimensionalMap<String, String, INDArray> binaryCD,
            MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD, Map<String, INDArray> unaryCD,
            Map<String, INDArray> wordVectorD, INDArray deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        INDArray currentVector = tree.vector();
        String category = basicCategory(tree.label());

        INDArray goldLabels = Nd4j.create(numOuts, 1);
        int goldClass = tree.goldLabel();
        if (goldClass >= 0) {
            assert goldClass < numOuts : "Tried adding a label that was >= to the number of configured outputs "
                    + numOuts + " with label " + goldClass;
            goldLabels.putScalar(goldClass, 1.0f);
        }
        Float weight = classWeights.get(Integer.valueOf(goldClass));
        if (weight == null) {
            weight = Float.valueOf(1.0f);
        }
        INDArray predictions = tree.prediction();
        INDArray deltaClass = goldClass >= 0
                ? Nd4j.getBlasWrapper().scal(weight.floatValue(), predictions.sub(goldLabels))
                : Nd4j.create(predictions.rows(), predictions.columns());
        INDArray localCD = deltaClass.mmul(Nd4j.appendBias(new INDArray[] {currentVector}).transpose());
        double error = -((Number) Transforms.log(predictions).muli(goldLabels).sum(Integer.MAX_VALUE).element())
                .doubleValue();
        tree.setError(error * weight.floatValue());

        if (tree.isPreTerminal()) {
            unaryCD.put(category, unaryCD.get(category).add(localCD));
            // NOTE(review): the word is read via label() here but via value() in forwardPropagateTree.
            String word = getVocabWord(tree.children().get(0).label());
            // Chain rule: multiply by the activation derivative, as the binary branch below does.
            INDArray deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass)
                    .get(new NDArrayIndex[] {NDArrayIndex.interval(0, numHidden), NDArrayIndex.interval(0, 1)})
                    .mul(activationFunction.applyDerivative(currentVector));
            INDArray deltaFull = deltaFromClass.add(deltaUp);
            wordVectorD.put(word, wordVectorD.get(word).add(deltaFull));
            return;
        }

        String leftCategory = basicCategory(tree.children().get(0).label());
        String rightCategory = basicCategory(tree.children().get(1).label());
        if (combineClassification) {
            unaryCD.put("", unaryCD.get("").add(localCD));
        } else {
            binaryCD.put(leftCategory, rightCategory,
                    ((INDArray) binaryCD.get(leftCategory, rightCategory)).add(localCD));
        }

        INDArray deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass)
                .get(new NDArrayIndex[] {NDArrayIndex.interval(0, numHidden), NDArrayIndex.interval(0, 1)})
                .mul(activationFunction.applyDerivative(currentVector));
        INDArray deltaFull = deltaFromClass.add(deltaUp);

        INDArray leftVector = tree.children().get(0).vector();
        INDArray rightVector = tree.children().get(1).vector();
        INDArray childrenVector = Nd4j.appendBias(new INDArray[] {leftVector, rightVector});
        // The transform gradient uses the full delta (class error plus deltaUp), like the tensor gradient.
        binaryTD.put(leftCategory, rightCategory, ((INDArray) binaryTD.get(leftCategory, rightCategory))
                .add(deltaFull.mmul(childrenVector.transpose())));

        INDArray deltaDown;
        if (useFloatTensors) {
            binaryINDArrayTD.put(leftCategory, rightCategory, ((INDArray) binaryINDArrayTD.get(leftCategory,
                    rightCategory)).add(getINDArrayGradient(deltaFull, leftVector, rightVector)));
            deltaDown = computeINDArrayDeltaDown(deltaFull, leftVector, rightVector,
                    getBinaryTransform(leftCategory, rightCategory), getBinaryINDArray(leftCategory, rightCategory));
        } else {
            deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
        }

        // Chain rule: children receive f'(child) .* (their half of deltaDown).
        INDArray leftDerivative = activationFunction.applyDerivative(leftVector);
        INDArray rightDerivative = activationFunction.applyDerivative(rightVector);
        int rows = deltaFull.rows();
        INDArray leftDeltaDown = deltaDown.get(
                new NDArrayIndex[] {NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, 1)});
        INDArray rightDeltaDown = deltaDown.get(
                new NDArrayIndex[] {NDArrayIndex.interval(rows, rows * 2), NDArrayIndex.interval(0, 1)});
        backpropDerivativesAndError(tree.children().get(0), binaryTD, binaryCD, binaryINDArrayTD, unaryCD,
                wordVectorD, leftDerivative.mul(leftDeltaDown));
        backpropDerivativesAndError(tree.children().get(1), binaryTD, binaryCD, binaryINDArrayTD, unaryCD,
                wordVectorD, rightDerivative.mul(rightDeltaDown));
    }

    /**
     * Delta passed to the two children when tensors are used: the bias-free part of W^T * deltaFull
     * plus sum_i (T_i + T_i^T) * (deltaFull[i] * c), where c stacks the child vectors.
     */
    private INDArray computeINDArrayDeltaDown(INDArray deltaFull, INDArray leftVector, INDArray rightVector,
            INDArray W, INDArray Wt) {
        INDArray wtDelta = W.transpose().mmul(deltaFull);
        // Drop the bias row: keep the first 2n entries of the (2n + 1) x 1 result.
        INDArray wtDeltaNoBias = wtDelta.isMatrix()
                ? wtDelta.get(new NDArrayIndex[] {NDArrayIndex.interval(0, deltaFull.rows() * 2),
                        NDArrayIndex.interval(0, 1)})
                : wtDelta.get(new NDArrayIndex[] {NDArrayIndex.interval(0, deltaFull.rows() * 2)});
        int size = deltaFull.length();
        INDArray deltaTensor = Nd4j.create(size * 2, 1);
        INDArray fullVector = Nd4j.concat(0, new INDArray[] {leftVector, rightVector});
        for (int slice = 0; slice < size; slice++) {
            // mul returns a copy; scaling fullVector in place would compound across slices.
            INDArray scaledFullVector = fullVector.mul(deltaFull.getDouble(slice));
            INDArray symmetricSlice = Wt.slice(slice).add(Wt.slice(slice).transpose());
            deltaTensor = deltaTensor.add(symmetricSlice.mmul(scaledFullVector));
        }
        return deltaTensor.add(wtDeltaNoBias);
    }

    /**
     * Computes the vector and prediction of {@code tree} and, recursively, of its descendants.
     *
     * @throws AssertionError on a leaf or on a node that is not binarized
     */
    public void forwardPropagateTree(Tree tree) {
        if (tree.isLeaf()) {
            throw new AssertionError("We should not have reached leaves in forwardPropagate");
        }
        INDArray classification;
        INDArray nodeVector;
        if (tree.isPreTerminal()) {
            classification = getUnaryClassification(tree.label());
            INDArray wordVector = getFeatureVector(tree.children().get(0).value());
            if (wordVector == null) {
                wordVector = featureVectors.get(UNKNOWN_FEATURE);
            }
            nodeVector = (INDArray) activationFunction.apply(wordVector);
        } else {
            if (tree.children().size() == 1) {
                throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
            }
            if (tree.children().size() != 2) {
                throw new AssertionError("Tree not correctly binarized");
            }
            forwardPropagateTree(tree.firstChild());
            forwardPropagateTree(tree.lastChild());
            String leftLabel = tree.children().get(0).label();
            String rightLabel = tree.children().get(1).label();
            INDArray transform = getBinaryTransform(leftLabel, rightLabel);
            classification = getBinaryClassification(leftLabel, rightLabel);
            INDArray leftVector = tree.children().get(0).vector();
            INDArray rightVector = tree.children().get(1).vector();
            INDArray preActivation = transform.mmul(Nd4j.appendBias(new INDArray[] {leftVector, rightVector}));
            if (useFloatTensors) {
                preActivation = preActivation.add(Nd4j.bilinearProducts(getBinaryINDArray(leftLabel, rightLabel),
                        Nd4j.concat(0, new INDArray[] {leftVector, rightVector})));
            }
            nodeVector = (INDArray) activationFunction.apply(preActivation);
        }
        INDArray biasedVector = Nd4j.appendBias(new INDArray[] {nodeVector});
        if (biasedVector.rows() != classification.columns()) {
            biasedVector = biasedVector.transpose();
        }
        tree.setPrediction((INDArray) outputActivation.apply(classification.mmul(biasedVector)));
        tree.setVector(nodeVector);
    }

    /** Forward-propagates every tree and returns the root predictions in order. */
    public List<INDArray> output(List<Tree> list) {
        List<INDArray> predictions = new ArrayList<>();
        for (Tree tree : list) {
            forwardPropagateTree(tree);
            predictions.add(tree.prediction());
        }
        return predictions;
    }

    /** Forward-propagates every tree and returns the BLAS {@code iamax} index of each root prediction. */
    public List<Integer> predict(List<Tree> list) {
        List<Integer> labels = new ArrayList<>();
        for (Tree tree : list) {
            forwardPropagateTree(tree);
            labels.add(Integer.valueOf(Nd4j.getBlasWrapper().iamax(tree.prediction())));
        }
        return labels;
    }

    /** Writes {@code iNDArray} back into the parameters, in {@link #getParameters()} order. */
    public void setParameters(INDArray iNDArray) {
        setParams(iNDArray, binaryTransform.values().iterator(), binaryClassification.values().iterator(),
                binaryINd4j.values().iterator(), unaryClassification.values().iterator(),
                featureVectors.values().iterator());
    }

    /** Actor system for parallel work; recreated if lost through (transient) serialisation. */
    private ActorSystem getActorSystem() {
        if (rnTnActorSystem == null) {
            rnTnActorSystem = ActorSystem.create("RNTN");
        }
        return rnTnActorSystem;
    }

    /**
     * Computes the objective and its gradient over the training trees set by {@link #fit(List)}.
     *
     * <p>Stores the objective in {@link #getValue()} and returns the flattened gradient, laid out like
     * {@link #getParameters()} and scaled element-wise by AdaGrad learning rates.
     *
     * @param i unused
     * @throws IllegalStateException if there are no training trees
     */
    public INDArray getValueGradient(int i) {
        if (trainingTrees == null || trainingTrees.isEmpty()) {
            throw new IllegalStateException("No training trees available; call fit(List<Tree>) first");
        }
        final MultiDimensionalMap<String, String, INDArray> binaryTD = MultiDimensionalMap.newTreeBackedMap();
        final MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD = MultiDimensionalMap.newTreeBackedMap();
        final MultiDimensionalMap<String, String, INDArray> binaryCD = MultiDimensionalMap.newTreeBackedMap();
        // Insertion-ordered so the flattened gradient lines up with the parameter maps' iteration order.
        final Map<String, INDArray> unaryCD = new LinkedHashMap<>();
        final Map<String, INDArray> wordVectorD = new LinkedHashMap<>();

        for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTransform.entrySet()) {
            INDArray param = entry.getValue();
            binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(param.rows(), param.columns()));
        }
        if (!combineClassification) {
            for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryClassification.entrySet()) {
                INDArray param = entry.getValue();
                binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(param.rows(), param.columns()));
            }
        }
        if (useFloatTensors) {
            for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryINd4j.entrySet()) {
                // Same shape as the parameter tensor, so getINDArrayGradient results can be added to it.
                binaryINDArrayTD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(entry.getValue().shape()));
            }
        }
        for (Map.Entry<String, INDArray> entry : unaryClassification.entrySet()) {
            unaryCD.put(entry.getKey(), Nd4j.create(entry.getValue().rows(), entry.getValue().columns()));
        }
        for (Map.Entry<String, INDArray> entry : featureVectors.entrySet()) {
            wordVectorD.put(entry.getKey(), Nd4j.create(entry.getValue().rows(), entry.getValue().columns()));
        }

        final CopyOnWriteArrayList<Tree> forwardPropagated = new CopyOnWriteArrayList<>();
        Parallelization.iterateInParallel(trainingTrees, new Parallelization.RunnableWithParams<Tree>() {
            public void run(Tree tree, Object[] objArr) {
                Tree copy = new Tree(tree);
                copy.connect(new ArrayList<Tree>(tree.children()));
                RNTN.this.forwardPropagateTree(copy);
                forwardPropagated.add(copy);
            }
        }, getActorSystem());

        final AtomicDouble error = new AtomicDouble(0.0d);
        final Object gradientLock = new Object();
        Parallelization.iterateInParallel(forwardPropagated, new Parallelization.RunnableWithParams<Tree>() {
            public void run(Tree tree, Object[] objArr) {
                // Backprop accumulates with put(k, get(k).add(x)); serialise it so concurrent trees
                // do not lose each other's updates.
                synchronized (gradientLock) {
                    RNTN.this.backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryINDArrayTD, unaryCD,
                            wordVectorD);
                }
                error.addAndGet(tree.errorSum());
            }
        }, new Parallelization.RunnableWithParams<Tree>() {
            public void run(Tree tree, Object[] objArr) {
            }
        }, getActorSystem(), new Object[] {binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD});

        double scale = 1.0d / trainingTrees.size();
        value = error.doubleValue() * scale;
        value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix);
        value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification);
        value += scaleAndRegularizeINDArray(binaryINDArrayTD, binaryINd4j, scale, regTransformINDArray);
        value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification);
        value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector);

        INDArray gradient = Nd4j.toFlattened(getNumParameters(), new Iterator[] {binaryTD.values().iterator(),
                binaryCD.values().iterator(), binaryINDArrayTD.values().iterator(), unaryCD.values().iterator(),
                wordVectorD.values().iterator()});
        if (paramAdaGrad == null) {
            paramAdaGrad = new AdaGrad(1, gradient.columns());
        }
        gradient.muli(paramAdaGrad.getLearningRates(gradient));
        return gradient;
    }

    /** Objective value from the most recent {@link #getValueGradient(int)} call. */
    public double getValue() {
        return value;
    }
}
