package ai.entrolution.thylacine.model.optimization.gradientdescent;

import ai.entrolution.thylacine.model.components.posterior.Posterior;
import ai.entrolution.thylacine.model.core.AsyncImplicits;
import ai.entrolution.thylacine.model.core.telemetry.OptimisationTelemetryUpdate;
import ai.entrolution.thylacine.model.core.values.IndexedVectorCollection;
import ai.entrolution.thylacine.model.core.values.modelparameters.ModelParameterContext;
import ai.entrolution.thylacine.model.optimization.ModelParameterOptimizer;
import ai.entrolution.thylacine.model.optimization.line.GoldenSectionSearch;
import ai.entrolution.thylacine.model.optimization.line.LineEvaluationResult;
import ai.entrolution.thylacine.util.ScalaVectorOps$Implicits$;
import cats.effect.implicits$;
import cats.effect.kernel.Async$;
import cats.effect.kernel.syntax.GenSpawnOps$;
import cats.syntax.FlatMapOps$;
import cats.syntax.package$all$;
import scala.Function1;
import scala.MatchError;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.immutable.Vector;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: ConjugateGradientEngine.scala */
@ScalaSignature(bytes = "\u0006\u0005\u0005-d!C\u0005\u000b!\u0003\r\taFA\u0015\u0011\u00159\u0004\u0001\"\u00019\u0011\u001da\u0004A1A\u0005\u0012uBQ!\u0013\u0001\u0007\u0012)CQA\u0014\u0001\u0007\u0012=CQa\u0015\u0001\u0007\u0012QCQ!\u0019\u0001\u0007\u0012\tDQ\u0001\u001a\u0001\u0005\n\u0015Dq!!\t\u0001\t#\t\u0019CA\fD_:TWoZ1uK\u001e\u0013\u0018\rZ5f]R,enZ5oK*\u00111\u0002D\u0001\u0010OJ\fG-[3oi\u0012,7oY3oi*\u0011QBD\u0001\r_B$\u0018.\\5{CRLwN\u001c\u0006\u0003\u001fA\tQ!\\8eK2T!!\u0005\n\u0002\u0013QD\u0017\u0010\\1dS:,'BA\n\u0015\u0003-)g\u000e\u001e:pYV$\u0018n\u001c8\u000b\u0003U\t!!Y5\u0004\u0001U\u0011\u0001$J\n\u0005\u0001ey\u0012\u0007\u0005\u0002\u001b;5\t1DC\u0001\u001d\u0003\u0015\u00198-\u00197b\u0013\tq2D\u0001\u0004B]f\u0014VM\u001a\t\u0004A\u0005\u001aS\"\u0001\u0007\n\u0005\tb!aF'pI\u0016d\u0007+\u0019:b[\u0016$XM](qi&l\u0017N_3s!\t!S\u0005\u0004\u0001\u0005\u000b\u0019\u0002!\u0019A\u0014\u0003\u0003\u0019+\"\u0001K\u0018\u0012\u0005%b\u0003C\u0001\u000e+\u0013\tY3DA\u0004O_RD\u0017N\\4\u0011\u0005ii\u0013B\u0001\u0018\u001c\u0005\r\te.\u001f\u0003\u0006a\u0015\u0012\r\u0001\u000b\u0002\u0005?\u0012\"\u0013\u0007E\u00023k\rj\u0011a\r\u0006\u0003i1\tA\u0001\\5oK&\u0011ag\r\u0002\u0014\u000f>dG-\u001a8TK\u000e$\u0018n\u001c8TK\u0006\u00148\r[\u0001\u0007I%t\u0017\u000e\u001e\u0013\u0015\u0003e\u0002\"A\u0007\u001e\n\u0005mZ\"\u0001B+oSR\fq\u0002^3mK6,GO]=Qe\u00164\u0017\u000e_\u000b\u0002}A\u0011qH\u0012\b\u0003\u0001\u0012\u0003\"!Q\u000e\u000e\u0003\tS!a\u0011\f\u0002\rq\u0012xn\u001c;?\u0013\t)5$\u0001\u0004Qe\u0016$WMZ\u0005\u0003\u000f\"\u0013aa\u0015;sS:<'BA#\u001c\u0003Q\u0019wN\u001c<fe\u001e,gnY3UQJ,7\u000f[8mIV\t1\n\u0005\u0002\u001b\u0019&\u0011Qj\u0007\u0002\u0007\t>,(\r\\3\u000235Lg.[7v[:+XNY3s\u001f\u001aLE/\u001a:bi&|gn]\u000b\u0002!B\u0011!$U\u0005\u0003%n\u00111!\u00138u\u0003]IG/\u001a:bi&|g.\u00169eCR,7)\u00197mE\u0006\u001c7.F\u0001V!\u0011Qb\u000b\u00171\n\u0005][\"!\u0003$v]\u000e$\u0018n\u001c82!\tIf,D\u0001[\u0015\tYF,A\u00
05uK2,W.\u001a;ss*\u0011QLD\u0001\u0005G>\u0014X-\u0003\u0002`5\nYr\n\u001d;j[&\u001c\u0018\r^5p]R+G.Z7fiJLX\u000b\u001d3bi\u0016\u00042\u0001J\u0013:\u0003MI7oQ8om\u0016\u0014x-\u001a3DC2d'-Y2l+\u0005\u0019\u0007\u0003\u0002\u000eWs\u0001\f1cY1mGVd\u0017\r^3OKb$Hj\\4QI\u001a$\u0002B\u001a?\u0002\u0004\u0005e\u0011Q\u0004\t\u0004I\u0015:\u0007\u0003\u0002\u000ei\u0017*L!![\u000e\u0003\rQ+\b\u000f\\33!\tY\u0017P\u0004\u0002mm:\u0011Q\u000e\u001e\b\u0003]Nt!a\u001c:\u000f\u0005A\fX\"\u0001\n\n\u0005E\u0011\u0012BA\b\u0011\u0013\tif\"\u0003\u0002v9\u00061a/\u00197vKNL!a\u001e=\u0002/%sG-\u001a=fIZ+7\r^8s\u0007>dG.Z2uS>t'BA;]\u0013\tQ8P\u0001\rN_\u0012,G\u000eU1sC6,G/\u001a:D_2dWm\u0019;j_:T!a\u001e=\t\u000bu<\u0001\u0019\u0001@\u0002%M$\u0018M\u001d;j]\u001e,e/\u00197vCRLwN\u001c\t\u0003e}L1!!\u00014\u0005Qa\u0015N\\3Fm\u0006dW/\u0019;j_:\u0014Vm];mi\"9\u0011QA\u0004A\u0002\u0005\u001d\u0011\u0001\u00059sKZLw.^:He\u0006$\u0017.\u001a8u!\u0015\tI!a\u0005L\u001d\u0011\tY!a\u0004\u000f\u0007\u0005\u000bi!C\u0001\u001d\u0013\r\t\tbG\u0001\ba\u0006\u001c7.Y4f\u0013\u0011\t)\"a\u0006\u0003\rY+7\r^8s\u0015\r\t\tb\u0007\u0005\b\u000379\u0001\u0019AA\u0004\u0003]\u0001(/\u001a<j_V\u001c8+Z1sG\"$\u0015N]3di&|g\u000e\u0003\u0004\u0002 \u001d\u0001\r\u0001U\u0001\nSR,'/\u0019;j_:\facY1mGVd\u0017\r^3NCbLW.^7M_\u001e\u0004FM\u001a\u000b\u0004M\u0006\u0015\u0002BBA\u0014\u0011\u0001\u0007!.\u0001\u0006ti\u0006\u0014H/\u001b8h!R\u0014b!a\u000b\u00020\u0005MbABA\u0017\u0001\u0001\tIC\u0001\u0007=e\u00164\u0017N\\3nK:$h\b\u0005\u0003\u00022\u0001\u0019S\"\u0001\u0006\u0013\r\u0005U\u0012qGA 
\r\u0019\ti\u0003\u0001\u0001\u00024A)\u0011\u0011HA\u001eG5\tA,C\u0002\u0002>q\u0013a\"Q:z]\u000eLU\u000e\u001d7jG&$8\u000f\r\u0003\u0002B\u0005\u001d\u0004#CA\"\u0003\u001b\u001a\u0013\u0011KA3\u001b\t\t)E\u0003\u0003\u0002H\u0005%\u0013!\u00039pgR,'/[8s\u0015\r\tYED\u0001\u000bG>l\u0007o\u001c8f]R\u001c\u0018\u0002BA(\u0003\u000b\u0012\u0011\u0002U8ti\u0016\u0014\u0018n\u001c:1\t\u0005M\u0013\u0011\r\t\b\u0003+\nYfIA0\u001b\t\t9F\u0003\u0003\u0002Z\u0005%\u0013!\u00029sS>\u0014\u0018\u0002BA/\u0003/\u0012Q\u0001\u0015:j_J\u00042\u0001JA1\t)\t\u0019\u0007AA\u0001\u0002\u0003\u0015\t\u0001\u000b\u0002\u0004?\u0012\n\u0004c\u0001\u0013\u0002h\u0011Q\u0011\u0011\u000e\u0001\u0002\u0002\u0003\u0005)\u0011\u0001\u0015\u0003\u0007}##\u0007")
/* loaded from: input_file:ai/entrolution/thylacine/model/optimization/gradientdescent/ConjugateGradientEngine.class */
/*
 * Conjugate-gradient maximiser for a posterior's log-pdf (decompiled from the Scala trait
 * ConjugateGradientEngine).
 *
 * The default methods cast `this` to Posterior, AsyncImplicits and ModelParameterContext, so an
 * implementing class is expected to also implement those types.
 *
 * Iteration scheme (see calculateNextLogPdf):
 *   1. g_new = -(gradient of the log-pdf) at the current point, flattened to a vector.
 *   2. beta  = max(((g_new - g_prev) . g_new) / |g_prev|^2, 0)   (Polak-Ribiere, "PR+";
 *              beta = 0 if |g_prev|^2 is zero).
 *   3. dir   = g_new + beta * dir_prev.
 *   4. Line search from the current point along dir via searchDirectionAlong.
 *   5. Report the log-pdf change to iterationUpdateCallback. Iteration stops once the change is
 *      <= convergenceThreshold and the iteration counter is >= minimumNumberOfIterations.
 */
public interface ConjugateGradientEngine<F> extends ModelParameterOptimizer<F>, GoldenSectionSearch<F> {
    /**
     * Scala-compiler-generated setter for the trait field {@link #telemetryPrefix()}; presumably
     * invoked only by the trait initialiser.
     */
    void ai$entrolution$thylacine$model$optimization$gradientdescent$ConjugateGradientEngine$_setter_$telemetryPrefix_$eq(String str);

    /** Prefix passed into every {@link OptimisationTelemetryUpdate} this engine emits. */
    String telemetryPrefix();

    /**
     * Iteration continues while the per-iteration log-pdf improvement is strictly greater than
     * this value (or the iteration counter is still below {@link #minimumNumberOfIterations()}).
     */
    double convergenceThreshold();

    /** Iteration continues while the 1-based iteration counter is strictly below this value. */
    int minimumNumberOfIterations();

    /**
     * Invoked once per iteration with a telemetry update built from the new log-pdf value, its
     * improvement over the previous iterate, and {@link #telemetryPrefix()}. The returned effect
     * is started on its own fiber and is not joined.
     */
    Function1<OptimisationTelemetryUpdate, F> iterationUpdateCallback();

    /**
     * Invoked once when iteration stops. The returned effect is started on its own fiber and is
     * not joined.
     */
    Function1<BoxedUnit, F> isConvergedCallback();

    /**
     * Performs one conjugate-gradient iteration starting at {@code lineEvaluationResult} and
     * recurses until the convergence criterion is met.
     *
     * @param lineEvaluationResult the current iterate (log-pdf value, flat parameter vector and
     *                             parameter collection)
     * @param vector               the previous negated log-pdf gradient, as a flat vector
     * @param vector2              the previous search direction, as a flat vector
     * @param i                    the 1-based iteration counter
     * @return an effect yielding {@code (logPdf, modelParameterCollection)} of the final iterate
     */
    /* JADX INFO: Access modifiers changed from: private */
    /* JADX WARN: Multi-variable type inference failed */
    default F calculateNextLogPdf(LineEvaluationResult lineEvaluationResult, Vector<Object> vector, Vector<Object> vector2, int i) {
        // Negated gradient of the log-pdf at the current iterate.
        return (F) package$all$.MODULE$.toFlatMapOps(package$all$.MODULE$.toFlatMapOps(package$all$.MODULE$.toFunctorOps(((Posterior) this).logPdfGradientAt(lineEvaluationResult.modelParameterArgument()), ((AsyncImplicits) this).asyncF()).map(indexedVectorCollection -> {
            return indexedVectorCollection.rawScalarMultiplyWith(-1.0d);
        }), ((AsyncImplicits) this).asyncF()).flatMap(indexedVectorCollection2 -> {
            // vector3: the new (negated) gradient, flattened.
            return package$all$.MODULE$.toFlatMapOps(Async$.MODULE$.apply(((AsyncImplicits) this).asyncF()).delay(() -> {
                return ((ModelParameterContext) this).modelParameterCollectionToVectorValues(indexedVectorCollection2);
            }), ((AsyncImplicits) this).asyncF()).flatMap(vector3 -> {
                // Polak-Ribiere denominator: squared magnitude of the PREVIOUS gradient, not the new one.
                return package$all$.MODULE$.toFlatMapOps(Async$.MODULE$.apply(((AsyncImplicits) this).asyncF()).delay(() -> {
                    return ScalaVectorOps$Implicits$.MODULE$.VectorOps(vector).magnitudeSquared();
                }), ((AsyncImplicits) this).asyncF()).flatMap(obj -> {
                    return $anonfun$calculateNextLogPdf$6(this, vector3, vector, vector2, lineEvaluationResult, BoxesRunTime.unboxToDouble(obj));
                });
            });
        }), ((AsyncImplicits) this).asyncF()).flatMap(tuple3 -> {
            // tuple3 = (new iterate, new gradient, new search direction), built in $anonfun$calculateNextLogPdf$8.
            if (tuple3 == null) {
                throw new MatchError(tuple3);
            }
            LineEvaluationResult lineEvaluationResult2 = (LineEvaluationResult) tuple3._1();
            Vector vector3 = (Vector) tuple3._2();
            Vector vector4 = (Vector) tuple3._3();
            // Improvement in log-pdf achieved by this iteration's line search.
            return package$all$.MODULE$.toFlatMapOps(Async$.MODULE$.apply(((AsyncImplicits) this).asyncF()).delay(() -> {
                return lineEvaluationResult2.result() - lineEvaluationResult.result();
            }), ((AsyncImplicits) this).asyncF()).flatMap(obj -> {
                return $anonfun$calculateNextLogPdf$16(this, lineEvaluationResult2, i, vector3, vector4, BoxesRunTime.unboxToDouble(obj));
            });
        });
    }

    /**
     * Entry point: evaluates the log-pdf and its negated gradient at the start point, then runs
     * the conjugate-gradient iteration. The first search direction is the negated gradient itself.
     *
     * @param indexedVectorCollection the starting model parameters
     * @return an effect yielding {@code (logPdf, modelParameterCollection)} of the final iterate
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // ai.entrolution.thylacine.model.optimization.ModelParameterOptimizer
    default F calculateMaximumLogPdf(IndexedVectorCollection indexedVectorCollection) {
        return (F) package$all$.MODULE$.toFlatMapOps(((Posterior) this).logPdfAt(indexedVectorCollection), ((AsyncImplicits) this).asyncF()).flatMap(obj -> {
            return $anonfun$calculateMaximumLogPdf$1(this, indexedVectorCollection, BoxesRunTime.unboxToDouble(obj));
        });
    }

    /**
     * Builds the new search direction {@code vector + d * vector2} and line-searches along it
     * from {@code lineEvaluationResult}.
     *
     * @param vector               the new (negated) gradient
     * @param vector2              the previous search direction
     * @param lineEvaluationResult the current iterate (line-search start point)
     * @param d                    the Polak-Ribiere coefficient beta
     * @return an effect yielding {@code Tuple3(newIterate, vector, newSearchDirection)}
     */
    /* JADX WARN: Multi-variable type inference failed */
    static /* synthetic */ Object $anonfun$calculateNextLogPdf$8(ConjugateGradientEngine conjugateGradientEngine, Vector vector, Vector vector2, LineEvaluationResult lineEvaluationResult, double d) {
        return package$all$.MODULE$.toFlatMapOps(Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).delay(() -> {
            return ScalaVectorOps$Implicits$.MODULE$.VectorOps(vector).add(ScalaVectorOps$Implicits$.MODULE$.VectorOps(vector2).scalarMultiplyWith(d));
        }), ((AsyncImplicits) conjugateGradientEngine).asyncF()).flatMap(vector3 -> {
            // NOTE(review): the direction is built from the NEGATED log-pdf gradient; this assumes
            // searchDirectionAlong searches both signs of step along the line - confirm.
            return package$all$.MODULE$.toFlatMapOps(conjugateGradientEngine.searchDirectionAlong(new Tuple2(BoxesRunTime.boxToDouble(lineEvaluationResult.result()), lineEvaluationResult.vectorArgument()), vector3), ((AsyncImplicits) conjugateGradientEngine).asyncF()).flatMap(tuple2 -> {
                // tuple2 = (log-pdf value, flat parameter vector) at the line-search optimum.
                return package$all$.MODULE$.toFunctorOps(Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).delay(() -> {
                    return new LineEvaluationResult(tuple2._1$mcD$sp(), (Vector) tuple2._2(), ((ModelParameterContext) conjugateGradientEngine).vectorValuesToModelParameterCollection((Vector) tuple2._2()));
                }), ((AsyncImplicits) conjugateGradientEngine).asyncF()).map(lineEvaluationResult2 -> {
                    return new Tuple3(lineEvaluationResult2, vector, vector3);
                });
            });
        });
    }

    /**
     * Computes the Polak-Ribiere ("PR+") coefficient
     * {@code max(((vector - vector2) . vector) / d, 0)} and continues with the direction update.
     *
     * @param vector               the new (negated) gradient
     * @param vector2              the previous (negated) gradient
     * @param vector3              the previous search direction
     * @param lineEvaluationResult the current iterate
     * @param d                    the squared magnitude of the previous gradient ({@code vector2})
     */
    /* JADX WARN: Multi-variable type inference failed */
    static /* synthetic */ Object $anonfun$calculateNextLogPdf$6(ConjugateGradientEngine conjugateGradientEngine, Vector vector, Vector vector2, Vector vector3, LineEvaluationResult lineEvaluationResult, double d) {
        return package$all$.MODULE$.toFlatMapOps(Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).delay(() -> {
            // A zero (or NaN) denominator would make beta NaN, and Math.max(NaN, 0.0) is NaN, which
            // would poison the new search direction. Fall back to steepest descent (beta = 0).
            if (!(d > 0.0d)) {
                return 0.0d;
            }
            return Math.max(ScalaVectorOps$Implicits$.MODULE$.VectorOps(ScalaVectorOps$Implicits$.MODULE$.VectorOps(vector).subtract(vector2)).dotProductWith(vector) / d, 0.0d);
        }), ((AsyncImplicits) conjugateGradientEngine).asyncF()).flatMap(obj -> {
            return $anonfun$calculateNextLogPdf$8(conjugateGradientEngine, vector, vector3, lineEvaluationResult, BoxesRunTime.unboxToDouble(obj));
        });
    }

    /**
     * Emits telemetry for the finished iteration, then either recurses or stops.
     *
     * @param lineEvaluationResult the new iterate
     * @param i                    the 1-based counter of the iteration just finished
     * @param vector               the new (negated) gradient
     * @param vector2              the new search direction
     * @param d                    the log-pdf improvement over the previous iterate
     */
    /* JADX WARN: Multi-variable type inference failed */
    static /* synthetic */ Object $anonfun$calculateNextLogPdf$16(ConjugateGradientEngine conjugateGradientEngine, LineEvaluationResult lineEvaluationResult, int i, Vector vector, Vector vector2, double d) {
        // The telemetry callback runs on its own fiber; the fiber is never joined.
        return package$all$.MODULE$.toFlatMapOps(GenSpawnOps$.MODULE$.start$extension(implicits$.MODULE$.genSpawnOps(conjugateGradientEngine.iterationUpdateCallback().apply(new OptimisationTelemetryUpdate(lineEvaluationResult.result(), d, conjugateGradientEngine.telemetryPrefix())), ((AsyncImplicits) conjugateGradientEngine).asyncF()), ((AsyncImplicits) conjugateGradientEngine).asyncF()), ((AsyncImplicits) conjugateGradientEngine).asyncF()).flatMap(fiber -> {
            return Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).ifM(Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).delay(() -> {
                // Keep going while still improving by more than the threshold, or below the minimum count.
                return d > conjugateGradientEngine.convergenceThreshold() || i < conjugateGradientEngine.minimumNumberOfIterations();
            }), () -> {
                return conjugateGradientEngine.calculateNextLogPdf(lineEvaluationResult, vector, vector2, i + 1);
            }, () -> {
                // Converged: fire the convergence callback on its own fiber, then return the final iterate.
                return FlatMapOps$.MODULE$.$greater$greater$extension(package$all$.MODULE$.catsSyntaxFlatMapOps(GenSpawnOps$.MODULE$.start$extension(implicits$.MODULE$.genSpawnOps(conjugateGradientEngine.isConvergedCallback().apply(BoxedUnit.UNIT), ((AsyncImplicits) conjugateGradientEngine).asyncF()), ((AsyncImplicits) conjugateGradientEngine).asyncF()), ((AsyncImplicits) conjugateGradientEngine).asyncF()), () -> {
                    return Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).pure(new Tuple2(BoxesRunTime.boxToDouble(lineEvaluationResult.result()), lineEvaluationResult.modelParameterArgument()));
                }, ((AsyncImplicits) conjugateGradientEngine).asyncF());
            });
        });
    }

    /**
     * Continuation of {@link #calculateMaximumLogPdf}: computes the negated start gradient and the
     * flat start vector, then starts the iteration at counter 1. The negated start gradient is
     * used as both the "previous gradient" and the "previous direction".
     *
     * @param indexedVectorCollection the starting model parameters
     * @param d                       the log-pdf value at the start point
     */
    /* JADX WARN: Multi-variable type inference failed */
    static /* synthetic */ Object $anonfun$calculateMaximumLogPdf$1(ConjugateGradientEngine conjugateGradientEngine, IndexedVectorCollection indexedVectorCollection, double d) {
        return package$all$.MODULE$.toFlatMapOps(package$all$.MODULE$.toFunctorOps(((Posterior) conjugateGradientEngine).logPdfGradientAt(indexedVectorCollection), ((AsyncImplicits) conjugateGradientEngine).asyncF()).map(indexedVectorCollection2 -> {
            return ((ModelParameterContext) conjugateGradientEngine).modelParameterCollectionToVectorValues(indexedVectorCollection2.rawScalarMultiplyWith(-1.0d));
        }), ((AsyncImplicits) conjugateGradientEngine).asyncF()).flatMap(startGradient -> {
            return package$all$.MODULE$.toFlatMapOps(Async$.MODULE$.apply(((AsyncImplicits) conjugateGradientEngine).asyncF()).delay(() -> {
                return ((ModelParameterContext) conjugateGradientEngine).modelParameterCollectionToVectorValues(indexedVectorCollection);
            }), ((AsyncImplicits) conjugateGradientEngine).asyncF()).flatMap(startVector -> {
                // Distinct names: the flat parameter vector goes into the LineEvaluationResult; the
                // negated gradient seeds both the previous gradient and the previous direction.
                return conjugateGradientEngine.calculateNextLogPdf(new LineEvaluationResult(d, startVector, indexedVectorCollection), startGradient, startGradient, 1);
            });
        });
    }
}
