package ai.konduit.serving.models.samediff.step.trainer;

import ai.konduit.serving.annotation.json.JsonName;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import io.swagger.v3.oas.annotations.media.Schema;
import java.util.List;
import java.util.Objects;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Pipeline step configuration for training a SameDiff model.
 *
 * <p>Instances can be created in three ways:
 * <ul>
 *   <li>the no-arg constructor plus the fluent setters;</li>
 *   <li>{@link #builder()};</li>
 *   <li>the {@code @JsonProperty}-annotated constructor.</li>
 * </ul>
 * The no-arg constructor and the builder start from the same defaults: {@code l1 = l2 = -1},
 * {@code weightDecayApplyLearningRate = true}, {@code numEpochs = 1},
 * {@code initialLossType = FLOAT}, and {@code debugMode = verboseMode = false}.
 *
 * <p>Instances are mutable (fluent setters) and perform no synchronization.
 */
@JsonName("SAMEDIFF_TRAINING")
@Schema(description = "A pipeline step that configures a SameDiff model that is to be trained.")
public class SameDiffTrainerStep implements PipelineStep {

    // Shared defaults, so the no-arg constructor, the JSON constructor and the builder agree.
    private static final double DEFAULT_L1 = -1.0d;
    private static final double DEFAULT_L2 = -1.0d;
    private static final boolean DEFAULT_WEIGHT_DECAY_APPLY_LEARNING_RATE = true;
    private static final int DEFAULT_NUM_EPOCHS = 1;
    private static final DataType DEFAULT_INITIAL_LOSS_TYPE = DataType.FLOAT;
    private static final boolean DEFAULT_DEBUG_MODE = false;
    private static final boolean DEFAULT_VERBOSE_MODE = false;

    @Schema(description = "Specifies the location of a saved model file.")
    private String modelUri;

    @Schema(description = "An L1 regularization coefficient for application during training. Set this value for l1 regularization. Not applied by default.")
    private double l1 = DEFAULT_L1;

    @Schema(description = "An L2 regularization coefficient for application during training. Set this value for l2 regularization. Not applied by default.")
    private double l2 = DEFAULT_L2;

    @Schema(description = "A weight regularization coefficient for application during training. Set this value to enable weight decay. Disabled by default.")
    private double weightDecayCoefficient;

    @Schema(description = "Whether to apply learning rate during weight decay, defaults to true")
    private boolean weightDecayApplyLearningRate = DEFAULT_WEIGHT_DECAY_APPLY_LEARNING_RATE;

    @Schema(description = "Specifies the location of the model save path")
    private String modelSaveOutputPath;

    @Schema(description = "Specifies the number of epochs to run training for")
    private int numEpochs = DEFAULT_NUM_EPOCHS;

    @Schema(description = "A list of names of the loss variables- the names of the targets to train against for the loss function")
    private List<String> lossVariables;

    @Schema(description = "A list of names of the input variables- the names of the input variables for training")
    private List<String> inputFeatures;

    @Schema(description = "A list of names of the labels variables- the names of the true labels for prediction to calculate error against")
    private List<String> labels;

    @Schema(description = "A list of names of the prediction variables- the names of the prediction labels for prediction to calculate error against")
    private List<String> targetVariables;

    @Schema(description = "The updater to use for training. When specifying an updater on the command line, the type is needed. Valid types include:  AMSGRAD,ADABELIEF,ADAGRAD,ADADELTA,ADAMAX,ADAM,NADAM,NESTEROVS,NOOP,RMSPROP,SGD . Each field for the updater must be specified in terms of field name = value separated by commas. Relevant updaters and their fields can be found here: https://github.com/eclipse/deeplearning4j/tree/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/config")
    private IUpdater updater;

    @Schema(description = "The learning rate to use for training")
    private double learningRate;

    @Schema(description = "The learning rate schedule to use for training. When specifying a learning rate or momentum schedule, comma separated values with key=value for each field is required. Valid values include: poly,step,cycle,fixed,inverse,sigmoid,exponential. Relevant schedules and their fields can be found here: https://github.com/eclipse/deeplearning4j/tree/master/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/schedule - it is recommended when specifying this value on the command line to use \" to ensure the value gets parsed properly.")
    private ISchedule learningRateSchedule;

    @Schema(description = "The initial loss type for the training, defaults to float")
    private DataType initialLossType = DEFAULT_INITIAL_LOSS_TYPE;

    @Schema(description = "The loss function to use for training models")
    private LossFunctions.LossFunction lossFunction;

    @Schema(description = "Enable debug mode, defaults to false")
    private boolean debugMode = DEFAULT_DEBUG_MODE;

    @Schema(description = "Enable verbose mode, defaults to false")
    private boolean verboseMode = DEFAULT_VERBOSE_MODE;

    /**
     * Self-typed builder for {@link SameDiffTrainerStep}.
     *
     * <p>Every field starts at the same default value as the corresponding field of the step.
     * Unset properties therefore keep their defaults instead of collapsing to Java zero values.
     *
     * @param <C> the concrete step type produced by {@link #build()}
     * @param <B> the concrete builder type returned by the chained setters
     */
    public static abstract class SameDiffTrainerStepBuilder<C extends SameDiffTrainerStep, B extends SameDiffTrainerStepBuilder<C, B>> {
        private String modelUri;
        private double l1 = DEFAULT_L1;
        private double l2 = DEFAULT_L2;
        private double weightDecayCoefficient;
        private boolean weightDecayApplyLearningRate = DEFAULT_WEIGHT_DECAY_APPLY_LEARNING_RATE;
        private String modelSaveOutputPath;
        private int numEpochs = DEFAULT_NUM_EPOCHS;
        private List<String> lossVariables;
        private List<String> inputFeatures;
        private List<String> labels;
        private List<String> targetVariables;
        private IUpdater updater;
        private double learningRate;
        private ISchedule learningRateSchedule;
        private DataType initialLossType = DEFAULT_INITIAL_LOSS_TYPE;
        private LossFunctions.LossFunction lossFunction;
        private boolean debugMode = DEFAULT_DEBUG_MODE;
        private boolean verboseMode = DEFAULT_VERBOSE_MODE;

        /** Returns {@code this}, typed as the concrete builder, for chaining. */
        protected abstract B self();

        /** Creates the step from the values accumulated so far. */
        public abstract C build();

        public B modelUri(String modelUri) {
            this.modelUri = modelUri;
            return self();
        }

        public B l1(double l1) {
            this.l1 = l1;
            return self();
        }

        public B l2(double l2) {
            this.l2 = l2;
            return self();
        }

        public B weightDecayCoefficient(double weightDecayCoefficient) {
            this.weightDecayCoefficient = weightDecayCoefficient;
            return self();
        }

        public B weightDecayApplyLearningRate(boolean weightDecayApplyLearningRate) {
            this.weightDecayApplyLearningRate = weightDecayApplyLearningRate;
            return self();
        }

        public B modelSaveOutputPath(String modelSaveOutputPath) {
            this.modelSaveOutputPath = modelSaveOutputPath;
            return self();
        }

        public B numEpochs(int numEpochs) {
            this.numEpochs = numEpochs;
            return self();
        }

        public B lossVariables(List<String> lossVariables) {
            this.lossVariables = lossVariables;
            return self();
        }

        public B inputFeatures(List<String> inputFeatures) {
            this.inputFeatures = inputFeatures;
            return self();
        }

        public B labels(List<String> labels) {
            this.labels = labels;
            return self();
        }

        public B targetVariables(List<String> targetVariables) {
            this.targetVariables = targetVariables;
            return self();
        }

        public B updater(IUpdater updater) {
            this.updater = updater;
            return self();
        }

        public B learningRate(double learningRate) {
            this.learningRate = learningRate;
            return self();
        }

        public B learningRateSchedule(ISchedule learningRateSchedule) {
            this.learningRateSchedule = learningRateSchedule;
            return self();
        }

        public B initialLossType(DataType initialLossType) {
            this.initialLossType = initialLossType;
            return self();
        }

        public B lossFunction(LossFunctions.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return self();
        }

        public B debugMode(boolean debugMode) {
            this.debugMode = debugMode;
            return self();
        }

        public B verboseMode(boolean verboseMode) {
            this.verboseMode = verboseMode;
            return self();
        }

        @Override
        public String toString() {
            return "SameDiffTrainerStep.SameDiffTrainerStepBuilder(modelUri=" + this.modelUri + ", l1=" + this.l1 + ", l2=" + this.l2 + ", weightDecayCoefficient=" + this.weightDecayCoefficient + ", weightDecayApplyLearningRate=" + this.weightDecayApplyLearningRate + ", modelSaveOutputPath=" + this.modelSaveOutputPath + ", numEpochs=" + this.numEpochs + ", lossVariables=" + this.lossVariables + ", inputFeatures=" + this.inputFeatures + ", labels=" + this.labels + ", targetVariables=" + this.targetVariables + ", updater=" + this.updater + ", learningRate=" + this.learningRate + ", learningRateSchedule=" + this.learningRateSchedule + ", initialLossType=" + this.initialLossType + ", lossFunction=" + this.lossFunction + ", debugMode=" + this.debugMode + ", verboseMode=" + this.verboseMode + ")";
        }
    }

    /** Concrete builder returned by {@link #builder()}. */
    private static final class SameDiffTrainerStepBuilderImpl extends SameDiffTrainerStepBuilder<SameDiffTrainerStep, SameDiffTrainerStepBuilderImpl> {
        private SameDiffTrainerStepBuilderImpl() {
        }

        @Override
        protected SameDiffTrainerStepBuilderImpl self() {
            return this;
        }

        @Override
        public SameDiffTrainerStep build() {
            return new SameDiffTrainerStep(this);
        }
    }

    /**
     * Creates a fully specified step. The {@code @JsonProperty} annotations name the JSON
     * properties for each argument.
     *
     * <p>NOTE(review): the primitive parameters cannot express "absent", so whatever value is
     * passed overwrites the defaults ({@code l1 = l2 = -1}, {@code numEpochs = 1},
     * {@code weightDecayApplyLearningRate = true}). Only {@code initialLossType} falls back to
     * its default ({@code FLOAT}) when {@code null}.
     *
     * <p>When {@code learningRate > 0}, a {@code learningRateSchedule} is given and an
     * {@code updater} is present, the learning rate and schedule are also pushed into the
     * updater via {@link IUpdater#setLrAndSchedule(double, ISchedule)}.
     */
    public SameDiffTrainerStep(@JsonProperty("modelUri") String modelUri, @JsonProperty("l1") double l1, @JsonProperty("l2") double l2, @JsonProperty("modelSaveOutputPath") String modelSaveOutputPath, @JsonProperty("numEpochs") int numEpochs, @JsonProperty("inputFeatures") List<String> inputFeatures, @JsonProperty("lossVariables") List<String> lossVariables, @JsonProperty("labels") List<String> labels, @JsonProperty("targetVariables") List<String> targetVariables, @JsonProperty("weightDecayCoefficient") double weightDecayCoefficient, @JsonProperty("weightDecayApplyLearningRate") boolean weightDecayApplyLearningRate, @JsonProperty("updater") IUpdater updater, @JsonProperty("learningRate") double learningRate, @JsonProperty("learningRateSchedule") ISchedule learningRateSchedule, @JsonProperty("initialLossType") DataType initialLossType, @JsonProperty("lossFunction") LossFunctions.LossFunction lossFunction, @JsonProperty("debugMode") boolean debugMode, @JsonProperty("verboseMode") boolean verboseMode) {
        this.modelUri = modelUri;
        this.l1 = l1;
        this.l2 = l2;
        this.modelSaveOutputPath = modelSaveOutputPath;
        this.numEpochs = numEpochs;
        this.lossVariables = lossVariables;
        this.inputFeatures = inputFeatures;
        this.targetVariables = targetVariables;
        this.labels = labels;
        this.weightDecayApplyLearningRate = weightDecayApplyLearningRate;
        this.weightDecayCoefficient = weightDecayCoefficient;
        this.learningRate = learningRate;
        this.learningRateSchedule = learningRateSchedule;
        this.updater = updater;
        this.lossFunction = lossFunction;
        if (initialLossType != null) {
            this.initialLossType = initialLossType;
        }
        // Without the null check, supplying a schedule but no updater threw a NullPointerException.
        if (updater != null && learningRate > 0.0d && learningRateSchedule != null) {
            updater.setLrAndSchedule(learningRate, learningRateSchedule);
        }
        this.debugMode = debugMode;
        this.verboseMode = verboseMode;
    }

    /**
     * Copies every value from {@code builder}. The builder's fields carry the same defaults as
     * this class, so properties never set on the builder keep their defaults. As in the JSON
     * constructor, a {@code null} {@code initialLossType} falls back to {@code FLOAT}.
     */
    protected SameDiffTrainerStep(SameDiffTrainerStepBuilder<?, ?> builder) {
        this.modelUri = builder.modelUri;
        this.l1 = builder.l1;
        this.l2 = builder.l2;
        this.weightDecayCoefficient = builder.weightDecayCoefficient;
        this.weightDecayApplyLearningRate = builder.weightDecayApplyLearningRate;
        this.modelSaveOutputPath = builder.modelSaveOutputPath;
        this.numEpochs = builder.numEpochs;
        this.lossVariables = builder.lossVariables;
        this.inputFeatures = builder.inputFeatures;
        this.labels = builder.labels;
        this.targetVariables = builder.targetVariables;
        this.updater = builder.updater;
        this.learningRate = builder.learningRate;
        this.learningRateSchedule = builder.learningRateSchedule;
        this.initialLossType = builder.initialLossType != null ? builder.initialLossType : DEFAULT_INITIAL_LOSS_TYPE;
        this.lossFunction = builder.lossFunction;
        this.debugMode = builder.debugMode;
        this.verboseMode = builder.verboseMode;
    }

    /** Returns a new builder pre-populated with this class's defaults. */
    public static SameDiffTrainerStepBuilder<?, ?> builder() {
        return new SameDiffTrainerStepBuilderImpl();
    }

    // ---- Fluent getters ----

    public String modelUri() {
        return this.modelUri;
    }

    public double l1() {
        return this.l1;
    }

    public double l2() {
        return this.l2;
    }

    public double weightDecayCoefficient() {
        return this.weightDecayCoefficient;
    }

    public boolean weightDecayApplyLearningRate() {
        return this.weightDecayApplyLearningRate;
    }

    public String modelSaveOutputPath() {
        return this.modelSaveOutputPath;
    }

    public int numEpochs() {
        return this.numEpochs;
    }

    public List<String> lossVariables() {
        return this.lossVariables;
    }

    public List<String> inputFeatures() {
        return this.inputFeatures;
    }

    public List<String> labels() {
        return this.labels;
    }

    public List<String> targetVariables() {
        return this.targetVariables;
    }

    public IUpdater updater() {
        return this.updater;
    }

    public double learningRate() {
        return this.learningRate;
    }

    public ISchedule learningRateSchedule() {
        return this.learningRateSchedule;
    }

    public DataType initialLossType() {
        return this.initialLossType;
    }

    public LossFunctions.LossFunction lossFunction() {
        return this.lossFunction;
    }

    public boolean debugMode() {
        return this.debugMode;
    }

    public boolean verboseMode() {
        return this.verboseMode;
    }

    // ---- Fluent, chainable setters; each mutates this instance and returns it ----

    public SameDiffTrainerStep modelUri(String modelUri) {
        this.modelUri = modelUri;
        return this;
    }

    public SameDiffTrainerStep l1(double l1) {
        this.l1 = l1;
        return this;
    }

    public SameDiffTrainerStep l2(double l2) {
        this.l2 = l2;
        return this;
    }

    public SameDiffTrainerStep weightDecayCoefficient(double weightDecayCoefficient) {
        this.weightDecayCoefficient = weightDecayCoefficient;
        return this;
    }

    public SameDiffTrainerStep weightDecayApplyLearningRate(boolean weightDecayApplyLearningRate) {
        this.weightDecayApplyLearningRate = weightDecayApplyLearningRate;
        return this;
    }

    public SameDiffTrainerStep modelSaveOutputPath(String modelSaveOutputPath) {
        this.modelSaveOutputPath = modelSaveOutputPath;
        return this;
    }

    public SameDiffTrainerStep numEpochs(int numEpochs) {
        this.numEpochs = numEpochs;
        return this;
    }

    public SameDiffTrainerStep lossVariables(List<String> lossVariables) {
        this.lossVariables = lossVariables;
        return this;
    }

    public SameDiffTrainerStep inputFeatures(List<String> inputFeatures) {
        this.inputFeatures = inputFeatures;
        return this;
    }

    public SameDiffTrainerStep labels(List<String> labels) {
        this.labels = labels;
        return this;
    }

    public SameDiffTrainerStep targetVariables(List<String> targetVariables) {
        this.targetVariables = targetVariables;
        return this;
    }

    public SameDiffTrainerStep updater(IUpdater updater) {
        this.updater = updater;
        return this;
    }

    public SameDiffTrainerStep learningRate(double learningRate) {
        this.learningRate = learningRate;
        return this;
    }

    public SameDiffTrainerStep learningRateSchedule(ISchedule learningRateSchedule) {
        this.learningRateSchedule = learningRateSchedule;
        return this;
    }

    public SameDiffTrainerStep initialLossType(DataType initialLossType) {
        this.initialLossType = initialLossType;
        return this;
    }

    public SameDiffTrainerStep lossFunction(LossFunctions.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
        return this;
    }

    public SameDiffTrainerStep debugMode(boolean debugMode) {
        this.debugMode = debugMode;
        return this;
    }

    public SameDiffTrainerStep verboseMode(boolean verboseMode) {
        this.verboseMode = verboseMode;
        return this;
    }

    /**
     * Value equality over every configuration field. Doubles are compared with
     * {@link Double#compare}, so {@code NaN} equals {@code NaN} and {@code 0.0 != -0.0}.
     * {@link #canEqual} keeps the relation symmetric for subclasses.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SameDiffTrainerStep)) {
            return false;
        }
        SameDiffTrainerStep other = (SameDiffTrainerStep) obj;
        return other.canEqual(this)
                && Double.compare(l1(), other.l1()) == 0
                && Double.compare(l2(), other.l2()) == 0
                && Double.compare(weightDecayCoefficient(), other.weightDecayCoefficient()) == 0
                && weightDecayApplyLearningRate() == other.weightDecayApplyLearningRate()
                && numEpochs() == other.numEpochs()
                && Double.compare(learningRate(), other.learningRate()) == 0
                && debugMode() == other.debugMode()
                && verboseMode() == other.verboseMode()
                && Objects.equals(modelUri(), other.modelUri())
                && Objects.equals(modelSaveOutputPath(), other.modelSaveOutputPath())
                && Objects.equals(lossVariables(), other.lossVariables())
                && Objects.equals(inputFeatures(), other.inputFeatures())
                && Objects.equals(labels(), other.labels())
                && Objects.equals(targetVariables(), other.targetVariables())
                && Objects.equals(updater(), other.updater())
                && Objects.equals(learningRateSchedule(), other.learningRateSchedule())
                && Objects.equals(initialLossType(), other.initialLossType())
                && Objects.equals(lossFunction(), other.lossFunction());
    }

    /** Subclass hook used by {@link #equals(Object)} to keep equality symmetric. */
    protected boolean canEqual(Object obj) {
        return obj instanceof SameDiffTrainerStep;
    }

    /**
     * Hash consistent with {@link #equals(Object)}. It uses the same field order and constants
     * as the previous generated implementation, so hash values are unchanged.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Double.hashCode(l1());
        result = result * prime + Double.hashCode(l2());
        result = result * prime + Double.hashCode(weightDecayCoefficient());
        result = result * prime + booleanHash(weightDecayApplyLearningRate());
        result = result * prime + numEpochs();
        result = result * prime + Double.hashCode(learningRate());
        result = result * prime + booleanHash(debugMode());
        result = result * prime + booleanHash(verboseMode());
        result = result * prime + nullableHash(modelUri());
        result = result * prime + nullableHash(modelSaveOutputPath());
        result = result * prime + nullableHash(lossVariables());
        result = result * prime + nullableHash(inputFeatures());
        result = result * prime + nullableHash(labels());
        result = result * prime + nullableHash(targetVariables());
        result = result * prime + nullableHash(updater());
        result = result * prime + nullableHash(learningRateSchedule());
        result = result * prime + nullableHash(initialLossType());
        result = result * prime + nullableHash(lossFunction());
        return result;
    }

    /** Hash contribution of a boolean field (79 for true, 97 for false). */
    private static int booleanHash(boolean value) {
        return value ? 79 : 97;
    }

    /** Hash contribution of a nullable reference field (43 for null). */
    private static int nullableHash(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    @Override
    public String toString() {
        return "SameDiffTrainerStep(modelUri=" + modelUri() + ", l1=" + l1() + ", l2=" + l2() + ", weightDecayCoefficient=" + weightDecayCoefficient() + ", weightDecayApplyLearningRate=" + weightDecayApplyLearningRate() + ", modelSaveOutputPath=" + modelSaveOutputPath() + ", numEpochs=" + numEpochs() + ", lossVariables=" + lossVariables() + ", inputFeatures=" + inputFeatures() + ", labels=" + labels() + ", targetVariables=" + targetVariables() + ", updater=" + updater() + ", learningRate=" + learningRate() + ", learningRateSchedule=" + learningRateSchedule() + ", initialLossType=" + initialLossType() + ", lossFunction=" + lossFunction() + ", debugMode=" + debugMode() + ", verboseMode=" + verboseMode() + ")";
    }

    /** Creates a step with all defaults; configure it afterwards with the fluent setters. */
    public SameDiffTrainerStep() {
        // All defaults come from the field initializers.
    }
}
