package ai.libs.jaicore.ml.ranking.dyad.learner.optimizing;

import ai.libs.jaicore.math.linearalgebra.DenseDoubleVector;
import java.util.Map;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/optimizing/DyadRankingFeatureTransformNegativeLogLikelihoodDerivative.class */
/**
 * Gradient function used together with a feature-transform dyad ranking model (as the
 * class and interface names indicate, the derivative of a negative log-likelihood).
 *
 * <p>For a weight vector {@code w} and every component index {@code i}, the value computed is
 *
 * <pre>
 *   - sum_n sum_{m &lt; M_n - 1} z_{n,m}[i]
 *   + sum_n sum_{m &lt; M_n - 1} ( sum_{l &gt;= m} z_{n,l}[i] * exp(w . z_{n,l}) )
 *                              / ( sum_{l &gt;= m} exp(w . z_{n,l}) )
 * </pre>
 *
 * where {@code n} runs over the instances of the dataset, {@code M_n} is the number of ranked
 * elements of instance {@code n} and {@code z} are the vectors stored in the feature-transform map.
 *
 * <p>{@link #initialize} must be called before {@link #apply}; otherwise the unset fields lead to a
 * {@link NullPointerException}.
 */
public class DyadRankingFeatureTransformNegativeLogLikelihoodDerivative implements IDyadRankingFeatureTransformPLGradientFunction {
    /** Dataset whose instances are iterated whenever the gradient is evaluated. */
    private IDyadRankingDataset dataset;

    /** For every instance, the feature-transform vector stored for each of its dyads. */
    private Map<IDyadRankingInstance, Map<IDyad, IVector>> featureTransforms;

    /**
     * Stores the dataset and the feature transforms that subsequent {@link #apply} calls read.
     * The arguments are kept by reference; no copy is made.
     *
     * @param iDyadRankingDataset dataset to iterate over
     * @param map feature-transform vectors, keyed by instance and then by dyad
     */
    @Override // ai.libs.jaicore.ml.ranking.dyad.learner.optimizing.IDyadRankingFeatureTransformPLGradientFunction
    public void initialize(IDyadRankingDataset iDyadRankingDataset, Map<IDyadRankingInstance, Map<IDyad, IVector>> map) {
        this.dataset = iDyadRankingDataset;
        this.featureTransforms = map;
    }

    /**
     * Evaluates the gradient at the given weight vector.
     *
     * @param iVector weight vector at which the gradient is evaluated
     * @return a new vector of the same length holding one partial derivative per component
     */
    public IVector apply(IVector iVector) {
        int dimension = iVector.length();
        DenseDoubleVector gradient = new DenseDoubleVector(dimension);
        for (int index = 0; index < dimension; index++) {
            gradient.setValue(index, computeDerivativeForIndex(index, iVector));
        }
        return gradient;
    }

    /**
     * Computes the partial derivative with respect to component {@code i} of the weight vector.
     *
     * <p>The exponential weights of each suffix {@code l >= m} are shifted by the largest score in
     * that suffix before exponentiation. The ratio is mathematically unchanged, but large dot
     * products no longer overflow to {@code Infinity} (which produced {@code NaN}), and strongly
     * negative ones no longer underflow to an all-zero denominator (which silently dropped the term).
     *
     * @param i index of the weight component
     * @param iVector weight vector at which the derivative is evaluated
     * @return the partial derivative for component {@code i}
     */
    private double computeDerivativeForIndex(int i, IVector iVector) {
        double weightedFeatureSum = 0.0d;
        double observedFeatureSum = 0.0d;
        int instanceCount = this.dataset.size();
        for (int n = 0; n < instanceCount; n++) {
            IDyadRankingInstance instance = (IDyadRankingInstance) this.dataset.get(n);
            int rankedElements = instance.getNumberOfRankedElements();
            if (rankedElements < 2) {
                // With fewer than two ranked elements there is no position m < M - 1 to contribute.
                continue;
            }
            Map<IDyad, IVector> transforms = this.featureTransforms.get(instance);

            // Scores and i-th components do not depend on the rank position m, so compute them once.
            // NOTE(review): candidates are looked up via getAttributeValue(l) while the observed dyad
            // below is looked up via getLabel().get(m); this presumes both index the dyads in the
            // same (ranking) order - confirm against the IDyadRankingInstance implementations.
            double[] scores = new double[rankedElements];
            double[] components = new double[rankedElements];
            for (int l = 0; l < rankedElements; l++) {
                IVector candidate = transforms.get(instance.getAttributeValue(l));
                scores[l] = iVector.dotProduct(candidate);
                components[l] = candidate.getValue(i);
            }

            for (int m = 0; m < rankedElements - 1; m++) {
                IDyad observed = (IDyad) instance.getLabel().get(m);
                observedFeatureSum += transforms.get(observed).getValue(i);

                // Largest score of the suffix; used to shift the exponents into a safe range.
                double shift = Double.NEGATIVE_INFINITY;
                for (int l = m; l < rankedElements; l++) {
                    shift = Math.max(shift, scores[l]);
                }
                if (shift == Double.NEGATIVE_INFINITY) {
                    // Every weight in the suffix is exactly zero; as before, skip the undefined ratio.
                    continue;
                }

                double numerator = 0.0d;
                double denominator = 0.0d;
                for (int l = m; l < rankedElements; l++) {
                    double weight = Math.exp(scores[l] - shift);
                    numerator += components[l] * weight;
                    denominator += weight;
                }
                // For a finite shift the maximal element contributes exp(0) = 1, so denominator >= 1.
                weightedFeatureSum += numerator / denominator;
            }
        }
        return (-observedFeatureSum) + weightedFeatureSum;
    }
}
