package ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.featuretransform;

import ai.libs.jaicore.basic.sets.Pair;
import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import ai.libs.jaicore.ml.core.learner.ASupervisedLearner;
import ai.libs.jaicore.ml.ranking.RankingPredictionBatch;
import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.IPLDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.BilinFunction;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.DyadRankingFeatureTransformNegativeLogLikelihood;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.DyadRankingFeatureTransformNegativeLogLikelihoodDerivative;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientDescendableFunction;
import ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientFunction;
import ai.libs.jaicore.ml.ranking.label.learner.clusterbased.customdatatypes.Ranking;
import edu.stanford.nlp.optimization.QNMinimizer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.core.exception.PredictionException;
import org.api4.java.ai.ml.core.exception.TrainingException;
import org.api4.java.ai.ml.ranking.IRanking;
import org.api4.java.ai.ml.ranking.IRankingPredictionBatch;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/algorithm/featuretransform/FeatureTransformPLDyadRanker.class */
/**
 * Plackett-Luce dyad ranker that scores each dyad by a linear model over a feature transform of the dyad.
 *
 * <p>The skill of a dyad {@code d} is {@code exp(w · φ(d))}, where {@code φ} is the configured
 * {@link IDyadFeatureTransform} and {@code w} is learned in {@link #fit(IDyadRankingDataset)} by quasi-Newton
 * minimization. A ranking lists its top-ranked dyad at index 0, which matches the Plackett-Luce likelihood used
 * in this class. Predictions therefore order dyads by decreasing skill.
 */
public class FeatureTransformPLDyadRanker extends ASupervisedLearner<IDyadRankingInstance, IDyadRankingDataset, IRanking<IDyad>, IRankingPredictionBatch> implements IPLDyadRanker {
    private static final Logger log = LoggerFactory.getLogger(FeatureTransformPLDyadRanker.class);

    /** Value every component of the weight vector starts from before optimization. */
    private static final double INITIAL_WEIGHT = 0.3d;

    /** Function tolerance handed to the quasi-Newton minimizer. */
    private static final double FUNCTION_TOLERANCE = 0.01d;

    private final IDyadFeatureTransform featureTransform;

    /** Learned weight vector; {@code null} until {@link #fit(IDyadRankingDataset)} completes successfully. */
    private IVector w;

    // NOTE(review): these two objectives are initialized in fit(...) but the optimizer below uses BilinFunction
    // instead; confirm whether they are still needed.
    private final IDyadRankingFeatureTransformPLGradientDescendableFunction negativeLogLikelihood;
    private final IDyadRankingFeatureTransformPLGradientFunction negativeLogLikelihoodDerivative;

    /** Creates a ranker that uses {@link BiliniearFeatureTransform} as its feature transform. */
    public FeatureTransformPLDyadRanker() {
        this(new BiliniearFeatureTransform());
    }

    /**
     * Creates a ranker with the given feature transform.
     *
     * @param iDyadFeatureTransform transform mapping a dyad to the feature vector scored by {@code w}; must not be null
     * @throws NullPointerException if {@code iDyadFeatureTransform} is null
     */
    public FeatureTransformPLDyadRanker(IDyadFeatureTransform iDyadFeatureTransform) {
        this.featureTransform = Objects.requireNonNull(iDyadFeatureTransform, "iDyadFeatureTransform");
        this.negativeLogLikelihood = new DyadRankingFeatureTransformNegativeLogLikelihood();
        this.negativeLogLikelihoodDerivative = new DyadRankingFeatureTransformNegativeLogLikelihoodDerivative();
    }

    /**
     * Computes the utility {@code w · φ(dyad)} of a dyad. This is the log of its Plackett-Luce skill.
     *
     * @param iDyad dyad to score
     * @return the (unexponentiated) utility of the dyad
     */
    private double computeUtilityForDyad(IDyad iDyad) {
        IVector transformed = this.featureTransform.transform(iDyad);
        double utility = this.w.dotProduct(transformed);
        log.debug("Feature transform for dyad {} is {}. \n Dot-Product is {} and skill is {}", iDyad, transformed, utility, Math.exp(utility));
        return utility;
    }

    /**
     * Computes the Plackett-Luce log-likelihood of a weight vector on a dataset.
     *
     * <p>For each ranking {@code d_0, ..., d_{m-1}} this adds
     * {@code Σ_i [u_i - log Σ_{j≥i} exp(u_j)]}, where {@code u_k = w · φ(d_k)}. Working in log space avoids the
     * underflow to 0 that the raw product of probabilities suffers on larger datasets. The suffix sums are
     * accumulated right-to-left, which keeps the cost linear in the ranking length.
     *
     * @param iVector weight vector to evaluate
     * @param iDyadRankingDataset dataset of observed rankings
     * @return the log-likelihood (a value {@code <= 0} for finite utilities)
     */
    private double logLikelihoodOfParameter(IVector iVector, IDyadRankingDataset iDyadRankingDataset) {
        double logLikelihood = 0.0d;
        for (int i = 0; i < iDyadRankingDataset.size(); i++) {
            IDyadRankingInstance instance = (IDyadRankingInstance) iDyadRankingDataset.get(i);
            int numberOfRankedElements = instance.getNumberOfRankedElements();
            double[] utilities = new double[numberOfRankedElements];
            for (int k = 0; k < numberOfRankedElements; k++) {
                utilities[k] = iVector.dotProduct(this.featureTransform.transform((IDyad) instance.getLabel().get(k)));
            }
            // logSuffix holds log Σ_{j>=k} exp(u_j) while walking from the last element to the first.
            double logSuffix = Double.NEGATIVE_INFINITY;
            for (int k = numberOfRankedElements - 1; k >= 0; k--) {
                logSuffix = logAddExp(logSuffix, utilities[k]);
                logLikelihood += utilities[k] - logSuffix;
            }
        }
        return logLikelihood;
    }

    /** Numerically stable {@code log(exp(a) + exp(b))}; {@code -∞} acts as the neutral element. */
    private static double logAddExp(double a, double b) {
        if (a == Double.NEGATIVE_INFINITY) {
            return b;
        }
        if (b == Double.NEGATIVE_INFINITY) {
            return a;
        }
        return Math.max(a, b) + Math.log1p(Math.exp(-Math.abs(a - b)));
    }

    /**
     * Learns the weight vector {@code w} by minimizing a {@link BilinFunction} objective with {@link QNMinimizer}.
     *
     * <p>The transformed vector length is taken from the alternative and context lengths of the first dyad of the
     * first instance. This assumes all dyads in the dataset share these dimensions (TODO confirm with callers).
     * {@code w} is replaced only after optimization succeeds, so a failed fit leaves the previous model intact.
     *
     * @param iDyadRankingDataset training rankings; must be non-empty
     * @throws TrainingException if the dataset is empty
     * @throws InterruptedException if interrupted during training
     */
    public void fit(IDyadRankingDataset iDyadRankingDataset) throws TrainingException, InterruptedException {
        if (iDyadRankingDataset.size() == 0) {
            throw new TrainingException("Cannot train the ranker on an empty dataset.");
        }
        Map<IDyadRankingInstance, Map<IDyad, IVector>> preComputedFeatureTransforms = this.featureTransform.getPreComputedFeatureTransforms(iDyadRankingDataset);
        this.negativeLogLikelihood.initialize(iDyadRankingDataset, preComputedFeatureTransforms);
        this.negativeLogLikelihoodDerivative.initialize(iDyadRankingDataset, preComputedFeatureTransforms);

        IDyad firstDyad = (IDyad) ((IDyadRankingInstance) iDyadRankingDataset.get(0)).getLabel().get(0);
        int alternativeLength = firstDyad.getAlternative().length();
        int contextLength = firstDyad.getContext().length();
        int transformedLength = this.featureTransform.getTransformedVectorLength(alternativeLength, contextLength);

        IVector initialW = new DenseDoubleVector(transformedLength, INITIAL_WEIGHT);
        // The likelihood pass costs a full sweep over the data; only pay for it when the result is actually logged.
        if (log.isDebugEnabled()) {
            log.debug("Log-likelihood of the initial w is {}", this.logLikelihoodOfParameter(initialW, iDyadRankingDataset));
        }

        BilinFunction objective = new BilinFunction(preComputedFeatureTransforms, iDyadRankingDataset, transformedLength);
        double[] optimum = new QNMinimizer().minimize(objective, FUNCTION_TOLERANCE, initialW.asArray());
        this.w = new DenseDoubleVector(optimum);
        log.debug("Finished optimizing, the final w is {}", this.w);
    }

    /**
     * Ranks the dyads of an instance by decreasing Plackett-Luce skill, so the most preferred dyad comes first.
     *
     * @param iDyadRankingInstance instance whose dyads are ranked
     * @return ranking of the instance's dyads, best first
     * @throws PredictionException if the ranker has not been trained
     * @throws InterruptedException if interrupted during prediction
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public IRanking<IDyad> predict(IDyadRankingInstance iDyadRankingInstance) throws PredictionException, InterruptedException {
        if (this.w == null) {
            throw new PredictionException("The Ranker has not been trained yet.");
        }
        log.debug("Predicting ranking for instance {}", iDyadRankingInstance);
        List<Pair<Double, IDyad>> utilityPerDyad = new ArrayList<>();
        for (IDyad dyad : iDyadRankingInstance) {
            utilityPerDyad.add(new Pair<>(this.computeUtilityForDyad(dyad), dyad));
        }
        // Sort by utility rather than exp(utility): the order is the same, but ties cannot arise from exp overflow.
        // Descending, because index 0 is the top of a ranking. The sort is stable, so equal utilities keep input order.
        List<IDyad> ranked = utilityPerDyad.stream()
                .sorted((a, b) -> Double.compare(b.getX(), a.getX()))
                .map(Pair::getY)
                .collect(Collectors.toList());
        return new Ranking<>(ranked);
    }

    /**
     * Ranks every instance of the batch independently using {@link #predict(IDyadRankingInstance)}.
     *
     * @param iDyadRankingInstanceArr instances to rank
     * @return the rankings, in the same order as the input instances
     * @throws PredictionException if the ranker has not been trained
     * @throws InterruptedException if interrupted during prediction
     */
    @Override // ai.libs.jaicore.ml.core.learner.ASupervisedLearner
    public IRankingPredictionBatch predict(IDyadRankingInstance[] iDyadRankingInstanceArr) throws PredictionException, InterruptedException {
        List<IRanking<?>> rankings = new ArrayList<>(iDyadRankingInstanceArr.length);
        for (IDyadRankingInstance instance : iDyadRankingInstanceArr) {
            rankings.add(this.predict(instance));
        }
        return new RankingPredictionBatch(rankings);
    }
}
