package ai.libs.jaicore.ml.ranking.dyad.learner.zeroshot.inputoptimization;

import ai.libs.jaicore.ml.ranking.dyad.learner.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.ranking.dyad.learner.zeroshot.util.InputOptListener;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:ai/libs/jaicore/ml/ranking/dyad/learner/zeroshot/inputoptimization/PLNetInputOptimizer.class */
/**
 * Optimizes an input vector of a PL-Net dyad ranker by gradient descent on the network input.
 *
 * <p>The gradient of an {@link InputOptimizerLoss} (taken with respect to the network output) is
 * back-propagated to the network input, and the input is moved against that gradient. Two
 * penalty terms push each component back into the range {@code [0, 1]}. A mask restricts which
 * components may change.
 */
public class PLNetInputOptimizer {
    /** Optional observer notified after every optimization step; {@code null} disables reporting. */
    private InputOptListener listener;

    /**
     * Optimizes {@code input} with a constant learning rate, changing only the components in
     * {@code indexRange}.
     *
     * @param pLNetDyadRanker ranker whose underlying network is differentiated
     * @param iNDArray starting point; it is not modified (a copy is optimized)
     * @param inputOptimizerLoss loss whose gradient w.r.t. the network output drives the updates
     * @param d learning rate used for every step
     * @param i number of gradient steps
     * @param pair half-open index range {@code [first, second)} of components to optimize, or
     *     {@code null} to optimize every component
     * @return the best input found, or a copy of {@code iNDArray} if no step improved on it
     */
    public INDArray optimizeInput(PLNetDyadRanker pLNetDyadRanker, INDArray iNDArray,
            InputOptimizerLoss inputOptimizerLoss, double d, int i, Pair<Integer, Integer> pair) {
        return optimizeInput(pLNetDyadRanker, iNDArray, inputOptimizerLoss, d, i, createMask(iNDArray, pair));
    }

    /**
     * Optimizes {@code input} with a learning rate that moves linearly from {@code d} toward
     * {@code d2}, changing only the components in {@code indexRange}.
     *
     * @param pLNetDyadRanker ranker whose underlying network is differentiated
     * @param iNDArray starting point; it is not modified (a copy is optimized)
     * @param inputOptimizerLoss loss whose gradient w.r.t. the network output drives the updates
     * @param d learning rate at the first step
     * @param d2 learning rate approached at the last step
     * @param i number of gradient steps
     * @param pair half-open index range {@code [first, second)} of components to optimize, or
     *     {@code null} to optimize every component
     * @return the best input found, or a copy of {@code iNDArray} if no step improved on it
     */
    public INDArray optimizeInput(PLNetDyadRanker pLNetDyadRanker, INDArray iNDArray,
            InputOptimizerLoss inputOptimizerLoss, double d, double d2, int i, Pair<Integer, Integer> pair) {
        return optimizeInput(pLNetDyadRanker, iNDArray, inputOptimizerLoss, d, d2, i, createMask(iNDArray, pair));
    }

    /**
     * Optimizes {@code input} with a constant learning rate, changing only the components where
     * {@code iNDArray2} is non-zero.
     *
     * @param pLNetDyadRanker ranker whose underlying network is differentiated
     * @param iNDArray starting point; it is not modified (a copy is optimized)
     * @param inputOptimizerLoss loss whose gradient w.r.t. the network output drives the updates
     * @param d learning rate used for every step
     * @param i number of gradient steps
     * @param iNDArray2 multiplicative mask applied to the gradient and the bounds check
     * @return the best input found, or a copy of {@code iNDArray} if no step improved on it
     */
    public INDArray optimizeInput(PLNetDyadRanker pLNetDyadRanker, INDArray iNDArray,
            InputOptimizerLoss inputOptimizerLoss, double d, int i, INDArray iNDArray2) {
        // Same initial and final learning rate => constant schedule.
        return optimizeInput(pLNetDyadRanker, iNDArray, inputOptimizerLoss, d, d, i, iNDArray2);
    }

    /**
     * Runs masked gradient descent on a copy of the input, tracking the best admissible point.
     *
     * <p>At step {@code s} the learning rate is {@code (1 - s/n) * d + (s/n) * d2}, where {@code n}
     * is the number of steps. It therefore starts at {@code d} and approaches (but, for
     * {@code n > 0}, never exactly reaches) {@code d2}.
     *
     * <p>The incumbent starts as a copy of the input. It is replaced whenever a step yields a
     * network output strictly greater than the best output seen so far and every masked component
     * lies in {@code [0, 1]}.
     *
     * @param pLNetDyadRanker ranker whose underlying network is differentiated
     * @param iNDArray starting point; it is not modified (a copy is optimized)
     * @param inputOptimizerLoss loss whose gradient w.r.t. the network output drives the updates
     * @param d learning rate at the first step
     * @param d2 learning rate approached at the last step
     * @param i number of gradient steps; values {@code <= 0} perform no steps
     * @param iNDArray2 multiplicative mask applied to the gradient and the bounds check
     *     (presumably the same length as {@code iNDArray}; not checked here)
     * @return the best admissible input found, or a copy of {@code iNDArray} if none improved on it
     */
    public INDArray optimizeInput(PLNetDyadRanker pLNetDyadRanker, INDArray iNDArray,
            InputOptimizerLoss inputOptimizerLoss, double d, double d2, int i, INDArray iNDArray2) {
        final INDArray current = iNDArray.dup();
        // Accumulated penalties for components below 0 (lower) and above 1 (upper). Both are
        // clamped at 0, so they only act while a component is (or recently was) out of bounds.
        final INDArray lowerPenalty = Nd4j.zeros(current.shape());
        final INDArray upperPenalty = Nd4j.zeros(current.shape());
        final INDArray ones = Nd4j.ones(current.shape());

        double bestOutput = pLNetDyadRanker.getPlNet().output(current).getDouble(0L);
        INDArray best = current.dup();

        for (int step = 0; step < i; step++) {
            // Cast before dividing: plain int division is 0 for every step < i, which would
            // pin the learning rate to d and silently ignore d2.
            final double progress = (double) step / i;
            final double learningRate = (1.0d - progress) * d + progress * d2;

            final INDArray gradient = computeInputDerivative(pLNetDyadRanker, current, inputOptimizerLoss);
            // Apply the penalties from the previous step before updating them for this one.
            gradient.subi(lowerPenalty);
            gradient.addi(upperPenalty);
            lowerPenalty.subi(current);
            upperPenalty.addi(current.sub(ones));
            BooleanIndexing.replaceWhere(lowerPenalty, 0.0d, Conditions.lessThan(0.0d));
            BooleanIndexing.replaceWhere(upperPenalty, 0.0d, Conditions.lessThan(0.0d));

            gradient.muli(iNDArray2);
            gradient.muli(learningRate);
            current.subi(gradient);

            final double output = pLNetDyadRanker.getPlNet().output(current).getDouble(0L);
            if (listener != null) {
                listener.reportOptimizationStep(current, output);
            }

            if (output > bestOutput && isWithinUnitInterval(current.dup().muli(iNDArray2))) {
                best = current.dup();
                bestOutput = output;
            }
        }
        return best;
    }

    /**
     * Builds a 1-D mask of length {@code input.length()}.
     *
     * @param input array whose length determines the mask length
     * @param range half-open range {@code [first, second)} set to 1 (all other entries 0), or
     *     {@code null} for a mask of all ones
     * @return the freshly allocated mask
     */
    private static INDArray createMask(INDArray input, Pair<Integer, Integer> range) {
        if (range == null) {
            return Nd4j.ones(new long[] {input.length()});
        }
        final INDArray mask = Nd4j.zeros(new long[] {input.length()});
        mask.get(NDArrayIndex.interval(range.getFirst().intValue(), range.getSecond().intValue())).assign(1.0d);
        return mask;
    }

    /** Returns whether every entry of {@code values} lies in the closed interval {@code [0, 1]}. */
    private static boolean isWithinUnitInterval(INDArray values) {
        return BooleanIndexing.and(values, Conditions.greaterThanOrEqual(0.0d))
                && BooleanIndexing.and(values, Conditions.lessThanOrEqual(1.0d));
    }

    /**
     * Back-propagates the loss gradient at the network output down to the network input.
     *
     * <p>Side effect: sets the network's input and runs a forward pass on it. Concurrent use of
     * the same ranker from several threads is therefore presumably unsafe.
     *
     * @param pLNetDyadRanker ranker providing the network
     * @param iNDArray point at which the derivative is evaluated
     * @param inputOptimizerLoss loss providing the scalar gradient w.r.t. the network output
     * @return the gradient of the loss with respect to {@code iNDArray}
     */
    private static INDArray computeInputDerivative(PLNetDyadRanker pLNetDyadRanker, INDArray iNDArray,
            InputOptimizerLoss inputOptimizerLoss) {
        final MultiLayerNetwork network = pLNetDyadRanker.getPlNet();
        final INDArray epsilon = Nd4j.create(new double[] {inputOptimizerLoss.lossGradient(network.output(iNDArray))});
        network.setInput(iNDArray);
        // Forward pass populates the activations that backpropGradient needs.
        network.feedForward(false, false);
        return (INDArray) network.backpropGradient(epsilon, (LayerWorkspaceMgr) null).getSecond();
    }

    /**
     * Sets the observer notified after every optimization step.
     *
     * @param inputOptListener the listener, or {@code null} to disable reporting
     */
    public void setListener(InputOptListener inputOptListener) {
        this.listener = inputOptListener;
    }
}
