/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.policy;

import java.io.IOException;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * A {@link Policy} whose action choices come from a deep Q-network ({@link IDQN}).
 *
 * <p>Each decision runs the observation through the wrapped network and hands the
 * resulting output to {@link Learning#getMaxAction}.
 *
 * @param <O> observation type; must be {@link Encodable}
 */
public class DQNPolicy<O extends Encodable> extends Policy<O, Integer> {

    /** Network queried for every decision; assigned once in the constructor. */
    private final IDQN dqn;

    /**
     * Wraps an existing network in a policy.
     *
     * @param dqn network to delegate to; stored as-is, with no null check
     */
    public DQNPolicy(IDQN dqn) {
        this.dqn = dqn;
    }

    /**
     * Restores a network with {@link DQN#load(String)} and builds a policy around it.
     *
     * @param path location passed unchanged to {@code DQN.load}
     * @param <O> observation type of the resulting policy
     * @return a new policy wrapping the restored network
     * @throws IOException propagated from {@code DQN.load}
     */
    public static <O extends Encodable> DQNPolicy<O> load(String path) throws IOException {
        final IDQN restored = DQN.load(path);
        return new DQNPolicy<O>(restored);
    }

    /**
     * Writes the wrapped network to disk by delegating to its own {@code save} method.
     *
     * @param filename destination passed unchanged to the network
     * @throws IOException propagated from the network's {@code save}
     */
    public void save(String filename) throws IOException {
        dqn.save(filename);
    }

    /**
     * Gives access to the network behind this policy.
     *
     * @return the same {@link IDQN} instance this policy was built with
     */
    @Override
    public IDQN getNeuralNet() {
        return dqn;
    }

    /**
     * Picks an action for the given encoded observation.
     *
     * <p>The network's output for {@code input} is passed to
     * {@link Learning#getMaxAction}. That is presumably the index of the largest
     * output value (greedy action); confirm against {@code Learning}.
     *
     * @param input encoded observation fed to the network
     * @return the action index chosen by {@code Learning.getMaxAction}
     */
    @Override
    public Integer nextAction(INDArray input) {
        return Learning.getMaxAction(dqn.output(input));
    }
}

