package org.deeplearning4j.rl4j.learning.sync;

import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for learning algorithms trained through a synchronous, epoch-by-epoch loop
 * driven by {@link #train()}.
 *
 * <p>Subclasses supply the per-epoch work via {@link #preEpoch()}, {@link #trainEpoch()} and
 * {@link #postEpoch()}; this class handles looping until the configured maximum step count,
 * periodic model checkpoints, and statistics/info bookkeeping through the data manager.
 *
 * @param <O>  observation type
 * @param <A>  action type
 * @param <AS> action space type
 * @param <NN> neural network type
 */
public abstract class SyncLearning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
        extends Learning<O, A, AS, NN> {

    private static final Logger log = LoggerFactory.getLogger(SyncLearning.class);

    /** Minimum number of steps that must elapse between two consecutive model checkpoints. */
    private static final int MODEL_SAVE_FREQ = 100000;

    /** Value of the step counter when the model was last saved. */
    private int lastSave;

    /**
     * Creates a synchronous learner.
     *
     * @param lConfiguration learning configuration, forwarded to {@link Learning}
     */
    public SyncLearning(ILearning.LConfiguration lConfiguration) {
        super(lConfiguration);
        // Start one full interval "in the past" so that the first completed epoch triggers a
        // checkpoint (assuming the step counter starts at a non-negative value).
        this.lastSave = -MODEL_SAVE_FREQ;
    }

    /**
     * Runs the training loop until the step counter reaches the configured maximum step.
     *
     * <p>Each iteration calls {@link #preEpoch()}, {@link #trainEpoch()}, {@link #postEpoch()},
     * then advances the epoch counter. The model is saved through the data manager whenever at
     * least {@code MODEL_SAVE_FREQ} steps have passed since the previous save. After that, the
     * epoch's statistics are appended and the learner info is rewritten.
     *
     * <p>Any exception raised during training is logged and ends training. It is not
     * propagated to the caller.
     */
    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public void train() {
        try {
            log.info("training starting.");
            getDataManager().writeInfo(this);
            // NOTE(review): nothing in this class advances the step counter; presumably the
            // subclass's trainEpoch() does, otherwise this loop never terminates.
            while (getStepCounter() < getConfiguration().getMaxStep()) {
                preEpoch();
                DataManager.StatEntry statEntry = trainEpoch();
                postEpoch();
                incrementEpoch();

                if (getStepCounter() - this.lastSave >= MODEL_SAVE_FREQ) {
                    getDataManager().save(this);
                    this.lastSave = getStepCounter();
                }

                getDataManager().appendStat(statEntry);
                getDataManager().writeInfo(this);
                log.info("Epoch: {}, reward: {}", getEpochCounter(), statEntry.getReward());
            }
        } catch (Exception e) {
            // Boundary of the training loop: report the failure and stop. The logger records
            // the stack trace, so no separate printStackTrace() is needed.
            log.error("Training failed.", e);
        }
    }

    /** Hook invoked at the start of every epoch, before {@link #trainEpoch()}. */
    protected abstract void preEpoch();

    /** Hook invoked after {@link #trainEpoch()} and before the epoch counter is advanced. */
    protected abstract void postEpoch();

    /**
     * Performs one training epoch.
     *
     * @return statistics for the epoch; appended to the data manager, and its reward is logged
     */
    protected abstract DataManager.StatEntry trainEpoch();
}
