package org.campagnelab.dl.framework.tools;

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Date;
import java.util.Properties;
import org.apache.commons.io.FileUtils;
import org.campagnelab.dl.framework.architecture.graphs.ComputationGraphAssembler;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.models.ComputationGraphSaver;
import org.campagnelab.dl.framework.models.ModelLoader;
import org.campagnelab.dl.framework.tools.arguments.AbstractTool;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Tool that initializes a new model from a pretrained one.
 *
 * <p>It instantiates the domain descriptor class named by the arguments (passing it the
 * pretraining model path), assembles and initializes a fresh computation graph from that
 * descriptor, copies the parameters of every input and component layer of the pretrained model
 * into the new graph, saves the result under the output model path, and finally copies the
 * pretrained model's {@code config.properties} and {@code domain.properties} next to it.
 *
 * @param <RecordType> record type handled by the domain descriptor
 */
public abstract class TransferPretrainingModelParameters<RecordType>
        extends AbstractTool<TransferPretrainingModelParametersArguments> {

    private static final Logger LOG = LoggerFactory.getLogger(TransferPretrainingModelParameters.class);

    private static final String CONFIG_PROPERTIES = "config.properties";
    private static final String DOMAIN_PROPERTIES = "domain.properties";

    private DomainDescriptor<RecordType> domainDescriptor;
    private ComputationGraphAssembler assembler;
    private ComputationGraph computationGraph;
    private String modelPath;

    /**
     * Runs the parameter and properties transfer.
     *
     * @throws IllegalArgumentException if no pretraining model path was supplied
     * @throws RuntimeException wrapping any failure while building the graph, transferring
     *     parameters, or reading/writing files
     */
    @Override // org.campagnelab.dl.framework.tools.arguments.AbstractTool
    public void execute() {
        // Everything below (descriptor, parameters, properties files) is sourced from this path;
        // without it the properties files would be looked up relative to the working directory.
        if (args().pretrainingModelPath == null) {
            throw new IllegalArgumentException("A pretraining model path is required to transfer parameters");
        }
        try {
            // Default output directory is named after the current time in milliseconds.
            this.modelPath = args().modelPath != null
                    ? args().modelPath
                    : "models/" + Long.toString(System.currentTimeMillis());
            FileUtils.forceMkdir(new File(this.modelPath));
            this.domainDescriptor = createDomainDescriptor(args().pretrainingModelPath);
            this.assembler = this.domainDescriptor.getComputationalGraph();
            buildComputationGraph();
            transferParams();
            transferProperties();
        } catch (Exception e) {
            throw new RuntimeException("Couldn't transfer parameters", e);
        }
    }

    /**
     * Reflectively instantiates the domain descriptor class named by
     * {@code args().domainDescriptorName()} through its single-{@code String} constructor.
     *
     * @param pretrainingModelPath value passed to the descriptor's constructor
     * @return the new descriptor
     * @throws ReflectiveOperationException if the class or constructor cannot be found or invoked
     */
    @SuppressWarnings("unchecked") // the named class is trusted to be a DomainDescriptor<RecordType>
    private DomainDescriptor<RecordType> createDomainDescriptor(String pretrainingModelPath)
            throws ReflectiveOperationException {
        return (DomainDescriptor<RecordType>) Class.forName(args().domainDescriptorName())
                .getConstructor(String.class)
                .newInstance(pretrainingModelPath);
    }

    /**
     * Configures {@link #assembler} from the domain descriptor (inputs, outputs, losses, hidden
     * node counts), then creates and initializes {@link #computationGraph}.
     *
     * @throws IllegalStateException if the assembler is missing or an input shape is not 2-D
     */
    private void buildComputationGraph() {
        // Configure the same assembler instance that transferParams() later reads names from.
        ComputationGraphAssembler graphAssembler = this.assembler;
        if (graphAssembler == null) {
            throw new IllegalStateException("Computational Graph assembler must be defined.");
        }
        graphAssembler.setArguments(args());
        for (String inputName : graphAssembler.getInputNames()) {
            // Copy so that widening below never mutates the descriptor's own array.
            int[] inputShape = this.domainDescriptor.getNumInputs(inputName).clone();
            if (inputShape.length != 2) {
                throw new IllegalStateException(String.format(
                        "Invalid size for domain descriptor feature %s: expected 2 dimensions, got %d",
                        inputName, inputShape.length));
            }
            // Widen the first dimension by one unless an eosIndex is set to some other value.
            // NOTE(review): presumably reserves a slot for an end-of-sequence token — confirm.
            if (args().eosIndex == null || args().eosIndex.intValue() == inputShape[0]) {
                inputShape[0]++;
            }
            graphAssembler.setNumInputs(inputName, inputShape);
        }
        for (String outputName : graphAssembler.getOutputNames()) {
            graphAssembler.setNumOutputs(outputName, this.domainDescriptor.getNumOutputs(outputName));
            graphAssembler.setLossFunction(outputName, this.domainDescriptor.getOutputLoss(outputName));
        }
        for (String componentName : graphAssembler.getComponentNames()) {
            graphAssembler.setNumHiddenNodes(componentName, this.domainDescriptor.getNumHiddenNodes(componentName));
        }
        this.computationGraph = graphAssembler.createComputationalGraph(this.domainDescriptor);
        this.computationGraph.init();
    }

    /**
     * Loads the pretrained model, copies its input and component layer parameters into
     * {@link #computationGraph}, and saves the result under {@link #modelPath} using
     * {@code args().modelPrefix} (default {@code "pretraining"}).
     *
     * @throws IOException if loading or saving a model fails
     * @throws RuntimeException if the pretrained model is not a usable {@link ComputationGraph}
     */
    private void transferParams() throws IOException {
        String pretrainingModelPath = args().pretrainingModelPath;
        Object loaded = new ModelLoader(pretrainingModelPath).loadModel(args().pretrainingModelName);
        ComputationGraph pretrained = loaded instanceof ComputationGraph ? (ComputationGraph) loaded : null;
        if (pretrained == null || pretrained.getUpdater() == null || pretrained.getLayers() == null) {
            throw new RuntimeException(String.format("Unable to load model for pretraining from %s", pretrainingModelPath));
        }
        for (String inputName : this.assembler.getInputNames()) {
            copyLayerParams(pretrained, inputName);
        }
        for (String componentName : this.assembler.getComponentNames()) {
            copyLayerParams(pretrained, componentName);
        }
        String prefix = args().modelPrefix != null ? args().modelPrefix : "pretraining";
        new ComputationGraphSaver(this.modelPath).saveModel(this.computationGraph, prefix);
        LOG.info("Saved model with transferred parameters to {} (prefix {})", this.modelPath, prefix);
    }

    /**
     * Copies the parameters of layer {@code name} from {@code pretrained} into
     * {@link #computationGraph}.
     *
     * <p>Names that resolve to no layer in either graph carry no parameters and are skipped
     * (NOTE(review): presumably graph input vertices — confirm). A name that resolves to a layer in
     * only one of the two graphs indicates an architecture mismatch and is rejected.
     *
     * @throws IllegalStateException if exactly one of the two graphs has a layer named {@code name}
     */
    private void copyLayerParams(ComputationGraph pretrained, String name) {
        boolean inTarget = this.computationGraph.getLayer(name) != null;
        boolean inSource = pretrained.getLayer(name) != null;
        if (!inTarget && !inSource) {
            return;
        }
        if (inTarget != inSource) {
            throw new IllegalStateException(String.format(
                    "Layer %s exists in %s graph only; cannot transfer its parameters",
                    name, inTarget ? "the new" : "the pretrained"));
        }
        this.computationGraph.getLayer(name).setParams(pretrained.getLayer(name).params());
    }

    /**
     * Copies {@code config.properties} and {@code domain.properties} from the pretraining model
     * directory into {@link #modelPath}, stamping each with a comment naming the source directory.
     *
     * @throws IOException if a source file cannot be read or a destination file cannot be written
     */
    private void transferProperties() throws IOException {
        String source = args().pretrainingModelPath;
        // Read both files before writing either, so no properties file is written unless both
        // sources could be read.
        Properties config = loadProperties(new File(source, CONFIG_PROPERTIES));
        Properties domain = loadProperties(new File(source, DOMAIN_PROPERTIES));
        storeProperties(config, new File(this.modelPath, CONFIG_PROPERTIES),
                String.format("Config properties created via pretraining parameter transfer from %s", source));
        storeProperties(domain, new File(this.modelPath, DOMAIN_PROPERTIES),
                String.format("Domain properties created via pretraining parameter transfer from %s", source));
    }

    /** Reads {@code file} into a new {@link Properties}, always closing the reader. */
    private static Properties loadProperties(File file) throws IOException {
        Properties properties = new Properties();
        try (FileReader reader = new FileReader(file)) {
            properties.load(reader);
        }
        return properties;
    }

    /** Writes {@code properties} to {@code file} with the given header comment, always closing the writer. */
    private static void storeProperties(Properties properties, File file, String comment) throws IOException {
        try (FileWriter writer = new FileWriter(file)) {
            properties.store(writer, comment);
        }
    }
}
