package org.tensorflow.framework.optimizers;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
import org.tensorflow.op.OpScope;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.NoOp;
import org.tensorflow.op.core.Variable;
import org.tensorflow.types.family.TType;

/**
 * Base class for graph-mode optimizers.
 *
 * <p>An optimizer computes the gradients of a loss with respect to every {@value #VARIABLE_V2}
 * operation found in a {@link Graph} and builds one update operation per variable through {@link
 * #applyDense(Ops, Output, Output)}. Subclasses may allocate per-variable "slot" variables in
 * {@link #createSlots(List)}, emit shared setup in {@link #prepare(String)}, and customize the
 * grouping of the updates in {@link #finish(List, String)}.
 */
public abstract class Optimizer {
    /** Operation type used by {@link #computeGradients(Operand)} to discover variables. */
    public static final String VARIABLE_V2 = "VariableV2";

    protected final Graph graph;
    protected final Ops tf;

    /** Slot variables, keyed first by slot name and then by the owning variable's op name. */
    private final Map<String, Map<String, Variable<?>>> slots = new HashMap<>();

    // NOTE(review): not referenced in this class; presumably populated by subclasses.
    protected final List<Variable<?>> globals = new ArrayList<>();

    /**
     * A gradient paired with the variable it was computed for.
     *
     * @param <T> tensor type shared by the gradient and the variable
     */
    public static class GradAndVar<T extends TType> {
        private final Output<T> gradient;
        private final Output<T> variable;

        /**
         * @param gradient the gradient output
         * @param variable the variable output the gradient belongs to
         */
        public GradAndVar(Output<T> gradient, Output<T> variable) {
            this.gradient = gradient;
            this.variable = variable;
        }

        /** Returns the gradient output. */
        public Output<T> getGradient() {
            return gradient;
        }

        /** Returns the variable output. */
        public Output<T> getVariable() {
            return variable;
        }
    }

    /** Optional attributes for optimizers. */
    public static class Options {
        protected String sharedName;

        private Options() {
        }

        /**
         * Sets the shared name.
         *
         * @param sharedName the shared name
         * @return this options instance, for chaining
         */
        public Options sharedName(String sharedName) {
            this.sharedName = sharedName;
            return this;
        }
    }

    /**
     * Creates an optimizer whose {@link Ops} instance is named after {@link #getOptimizerName()}.
     *
     * <p>Note: this invokes the overridable {@code getOptimizerName()} from the constructor, before
     * any subclass fields are initialized, so implementations should return a constant.
     *
     * @param graph the graph the optimizer builds its operations into
     */
    public Optimizer(Graph graph) {
        this.graph = graph;
        this.tf = Ops.create(graph).withName(getOptimizerName());
    }

    /**
     * Creates an optimizer whose {@link Ops} instance uses the given name.
     *
     * @param graph the graph the optimizer builds its operations into
     * @param name the name passed to {@link Ops#withName(String)}
     */
    public Optimizer(Graph graph, String name) {
        this.graph = graph;
        this.tf = Ops.create(graph).withName(name);
    }

    /**
     * Builds a slot name of the form {@code <variable op name>-<slotName>}.
     *
     * @param variable the variable the slot belongs to
     * @param slotName the slot name
     * @return the combined name
     */
    public static String createName(Output<? extends TType> variable, String slotName) {
        return variable.op().name() + "-" + slotName;
    }

    /** Returns the {@link Ops} instance this optimizer builds operations with. */
    public final Ops getTF() {
        return tf;
    }

    /**
     * Builds the operations that minimize {@code loss}, naming the final op
     * {@code <optimizer name>-minimize}.
     *
     * @param loss the value to minimize
     * @return the op that applies all updates
     */
    public Op minimize(Operand<?> loss) {
        return minimize(loss, getOptimizerName() + "-minimize");
    }

    /**
     * Computes gradients of {@code loss} and applies them.
     *
     * @param loss the value to minimize
     * @param name the name of the final grouping op
     * @return the op that applies all updates
     */
    public Op minimize(Operand<?> loss, String name) {
        return applyGradients(computeGradients(loss), name);
    }

    /**
     * Computes the gradient of {@code loss} with respect to every {@value #VARIABLE_V2} operation
     * currently in the graph.
     *
     * @param loss the value to differentiate
     * @param <T> unused by callers; kept for source compatibility
     * @return one gradient/variable pair per variable, in graph iteration order
     */
    public <T extends TType> List<GradAndVar<?>> computeGradients(Operand<?> loss) {
        List<Operation> variableOps = new ArrayList<>();
        graph.operations().forEachRemaining(op -> {
            if (op.type().equals(VARIABLE_V2)) {
                variableOps.add(op);
            }
        });

        Output<?>[] variables = new Output<?>[variableOps.size()];
        for (int i = 0; i < variables.length; i++) {
            // Output 0 of each variable op is what gradients are taken with respect to.
            variables[i] = variableOps.get(i).output(0);
        }

        Output<?>[] gradients = graph.addGradients(loss.asOutput(), variables);
        List<GradAndVar<?>> pairs = new ArrayList<>(variables.length);
        for (int i = 0; i < variables.length; i++) {
            pairs.add(pair(gradients[i], variables[i]));
        }
        return pairs;
    }

    /**
     * Pairs a gradient with its variable. The gradient of a variable has the variable's type, so
     * both wildcard outputs are viewed under one common type parameter.
     */
    @SuppressWarnings("unchecked")
    private static <T extends TType> GradAndVar<T> pair(Output<?> gradient, Output<?> variable) {
        return new GradAndVar<>((Output<T>) gradient, (Output<T>) variable);
    }

    /**
     * Creates slots for all variables, then builds one update op per pair whose gradient is still
     * open, grouped by {@link #finish(List, String)}.
     *
     * @param gradsAndVars the gradient/variable pairs to apply
     * @param name the name of the final grouping op; {@code name + "/prepare"} is passed to
     *     {@link #prepare(String)}
     * @return the op that applies all updates
     */
    public Op applyGradients(List<GradAndVar<? extends TType>> gradsAndVars, String name) {
        List<Output<? extends TType>> variables =
                gradsAndVars.stream().map(GradAndVar::getVariable).collect(Collectors.toList());
        createSlots(variables);

        // Closed gradients are excluded both from the control inputs and from the updates.
        List<Op> openGradients =
                gradsAndVars.stream()
                        .map(GradAndVar::getGradient)
                        .filter(gradient -> !gradient.isClosed())
                        .collect(Collectors.toList());
        Ops controlledTf = tf.withControlDependencies(openGradients);

        List<Op> updates = new ArrayList<>();
        prepare(name + "/prepare").ifPresent(updates::add);
        for (GradAndVar<? extends TType> gradAndVar : gradsAndVars) {
            if (!gradAndVar.getGradient().isClosed()) {
                updates.add(applyDense(controlledTf, gradAndVar));
            }
        }
        return finish(updates, name);
    }

    /**
     * Looks up a slot previously created with {@link #createSlot(Output, String, Operand)}.
     *
     * @param variable the variable the slot belongs to
     * @param slotName the slot name
     * @return the slot variable, or empty if none was created
     */
    public <T extends TType> Optional<Variable<T>> getSlot(Output<T> variable, String slotName) {
        return getSlot(variable.op().name(), slotName);
    }

    // Slots are stored type-erased; this assumes the caller asks for the same T the slot was
    // created with (createSlot builds the slot with its variable's type).
    @SuppressWarnings("unchecked")
    private <T extends TType> Optional<Variable<T>> getSlot(String varName, String slotName) {
        Map<String, Variable<?>> slotsForName = slots.get(slotName);
        if (slotsForName == null) {
            return Optional.empty();
        }
        return Optional.ofNullable((Variable<T>) slotsForName.get(varName));
    }

    /**
     * Creates a slot variable with the same shape and type as {@code variable}, assigns it
     * {@code initializer} in the init scope, and records it under {@code slotName}.
     *
     * @param variable the variable the slot belongs to
     * @param slotName the slot name
     * @param initializer the initial value of the slot
     */
    public <T extends TType> void createSlot(
            Output<T> variable, String slotName, Operand<T> initializer) {
        Variable<T> slot =
                tf.withInitScope()
                        .withName(createName(variable, slotName))
                        .variable(variable.shape(), variable.type());
        tf.withInitScope().assign(slot, initializer);
        slots.computeIfAbsent(slotName, key -> new HashMap<>()).put(variable.op().name(), slot);
    }

    /**
     * Hook for building shared setup before the per-variable updates.
     *
     * @param scopeName suggested name for the setup op
     * @return the setup op, or empty (the default) if none is needed
     */
    protected Optional<Op> prepare(String scopeName) {
        return Optional.empty();
    }

    /**
     * Hook for creating slot variables; the default does nothing.
     *
     * @param variables the variables being optimized
     */
    protected void createSlots(List<Output<? extends TType>> variables) {
    }

    private <T extends TType> Op applyDense(Ops ops, GradAndVar<T> gradAndVar) {
        return applyDense(ops, gradAndVar.getGradient(), gradAndVar.getVariable());
    }

    /**
     * Builds the update op for a single variable.
     *
     * @param ops op builder carrying control dependencies on all open gradients
     * @param gradient the gradient
     * @param variable the variable to update
     * @return the update op
     */
    protected abstract <T extends TType> Op applyDense(Ops ops, Output<T> gradient, Output<T> variable);

    /**
     * Groups the update ops into a single no-op that has control dependencies on all of them.
     *
     * @param updateOperations the ops to group
     * @param name the name of the grouping op
     * @return the grouping op
     */
    public Op finish(List<Op> updateOperations, String name) {
        return NoOp.create(
                new OpScope(graph).withName(name).withControlDependencies(updateOperations));
    }

    /** Returns the name of this optimizer. */
    public abstract String getOptimizerName();
}
