/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.policy;

import java.io.IOException;
import java.util.Random;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.ac.ActorCriticCompGraph;
import org.deeplearning4j.rl4j.network.ac.ActorCriticSeparate;
import org.deeplearning4j.rl4j.network.ac.IActorCritic;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Policy that picks discrete actions from the output of an actor-critic network.
 *
 * <p>The second array returned by {@code IActorCritic.outputAll(input)} is treated as the
 * distribution over actions. When a {@link Random} is available the action is sampled from
 * that distribution; when {@link #rd} is {@code null} the action is chosen with
 * {@code Learning.getMaxAction(output)} instead.
 *
 * @param <O> observation type handled by this policy
 */
public class ACPolicy<O extends Encodable> extends Policy<O, Integer> {
    /**
     * Largest leftover sampling mass that is still attributed to float rounding rather than to
     * a malformed network output. {@code Random.nextFloat()} can return values just below 1.0
     * while a float distribution may sum to slightly less than 1.0, so a tiny positive remainder
     * after walking every action does not mean the output is invalid.
     */
    private static final float ROUNDING_TOLERANCE = 1.0e-4f;

    private final IActorCritic actorCritic;
    /** Source of randomness for sampling; {@code null} switches {@link #nextAction} to max-action selection. */
    Random rd;

    /**
     * Creates a policy whose random generator is seeded from the first network returned by
     * {@code getNeuralNetworks()}.
     *
     * <p>If that network is neither a {@link ComputationGraph} nor a {@link MultiLayerNetwork},
     * no generator is created and {@link #rd} stays {@code null}.
     *
     * @param IActorCritic2 actor-critic network backing this policy
     */
    public ACPolicy(IActorCritic IActorCritic2) {
        this.actorCritic = IActorCritic2;
        this.rd = ACPolicy.seededRandomFor(IActorCritic2.getNeuralNetworks()[0]);
    }

    /**
     * Creates a policy that uses the supplied random generator.
     *
     * @param IActorCritic2 actor-critic network backing this policy
     * @param rd generator used for sampling; {@code null} makes {@link #nextAction} use
     *           {@code Learning.getMaxAction}
     */
    public ACPolicy(IActorCritic IActorCritic2, Random rd) {
        this.actorCritic = IActorCritic2;
        this.rd = rd;
    }

    /** Builds a generator from the seed stored in the network's default configuration, or {@code null} for other network types. */
    private static Random seededRandomFor(NeuralNetwork nn) {
        if (nn instanceof ComputationGraph) {
            return new Random(((ComputationGraph)nn).getConfiguration().getDefaultConfiguration().getSeed());
        }
        if (nn instanceof MultiLayerNetwork) {
            return new Random(((MultiLayerNetwork)nn).getDefaultConfiguration().getSeed());
        }
        return null;
    }

    /**
     * Loads a policy through {@code ActorCriticCompGraph.load(path)}.
     *
     * @param path location passed to {@code ActorCriticCompGraph.load}
     * @throws IOException if loading fails
     */
    public static <O extends Encodable> ACPolicy<O> load(String path) throws IOException {
        return new ACPolicy<O>(ActorCriticCompGraph.load(path));
    }

    /**
     * Loads a policy through {@code ActorCriticCompGraph.load(path)} and uses the given generator.
     *
     * @param path location passed to {@code ActorCriticCompGraph.load}
     * @param rd generator used for sampling (may be {@code null}, see {@link #nextAction})
     * @throws IOException if loading fails
     */
    public static <O extends Encodable> ACPolicy<O> load(String path, Random rd) throws IOException {
        return new ACPolicy<O>(ActorCriticCompGraph.load(path), rd);
    }

    /**
     * Loads a policy through {@code ActorCriticSeparate.load(pathValue, pathPolicy)}.
     *
     * @param pathValue first location passed to {@code ActorCriticSeparate.load}
     * @param pathPolicy second location passed to {@code ActorCriticSeparate.load}
     * @throws IOException if loading fails
     */
    public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy) throws IOException {
        return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy));
    }

    /**
     * Loads a policy through {@code ActorCriticSeparate.load(pathValue, pathPolicy)} and uses the
     * given generator.
     *
     * @param pathValue first location passed to {@code ActorCriticSeparate.load}
     * @param pathPolicy second location passed to {@code ActorCriticSeparate.load}
     * @param rd generator used for sampling (may be {@code null}, see {@link #nextAction})
     * @throws IOException if loading fails
     */
    public static <O extends Encodable> ACPolicy<O> load(String pathValue, String pathPolicy, Random rd) throws IOException {
        return new ACPolicy<O>(ActorCriticSeparate.load(pathValue, pathPolicy), rd);
    }

    /** Returns the actor-critic network this policy was constructed with. */
    @Override
    public IActorCritic getNeuralNet() {
        return this.actorCritic;
    }

    /**
     * Chooses the next action for the given input.
     *
     * <p>Without a generator the result of {@code Learning.getMaxAction} is returned. Otherwise a
     * value drawn from {@code rd.nextFloat()} is walked through the output entries in index order
     * and the first index whose entry exceeds the remaining value is returned.
     *
     * @param input network input
     * @return index of the selected action
     * @throws RuntimeException if the walk ends with a remainder too large to be rounding error,
     *         or no entry is positive
     */
    @Override
    public Integer nextAction(INDArray input) {
        INDArray output = this.actorCritic.outputAll(input)[1];
        if (this.rd == null) {
            return Learning.getMaxAction(output);
        }
        float rVal = this.rd.nextFloat();
        long actionCount = output.length();
        for (int i = 0; (long)i < actionCount; ++i) {
            float probability = output.getFloat((long)i);
            if (rVal < probability) {
                return i;
            }
            rVal -= probability;
        }
        // Falling through with a tiny remainder means the entries summed to just under the drawn
        // value because of float rounding; assign the draw to the last action that has any mass.
        // A NaN remainder fails this comparison and is still reported as an invalid distribution.
        if (rVal < ROUNDING_TOLERANCE) {
            for (long i = actionCount - 1L; i >= 0L; --i) {
                if (output.getFloat(i) > 0.0f) {
                    return (int)i;
                }
            }
        }
        throw new RuntimeException("Output from network is not a probability distribution: " + output);
    }

    /**
     * Saves the underlying network via its single-file {@code save} method.
     *
     * @param filename destination passed to {@code IActorCritic.save(String)}
     * @throws IOException if saving fails
     */
    public void save(String filename) throws IOException {
        this.actorCritic.save(filename);
    }

    /**
     * Saves the underlying network via its two-file {@code save} method.
     *
     * @param filenameValue first destination passed to {@code IActorCritic.save(String, String)}
     * @param filenamePolicy second destination passed to {@code IActorCritic.save(String, String)}
     * @throws IOException if saving fails
     */
    public void save(String filenameValue, String filenamePolicy) throws IOException {
        this.actorCritic.save(filenameValue, filenamePolicy);
    }
}

