package org.campagnelab.dl.framework.tools;

import it.unimi.dsi.fastutil.io.FastBufferedOutputStream;
import it.unimi.dsi.logging.ProgressLogger;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Date;
import java.util.List;
import java.util.Properties;
import org.apache.commons.io.FilenameUtils;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.iterators.MultiDataSetIteratorAdapter;
import org.campagnelab.dl.framework.tools.arguments.AbstractTool;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/tools/MapMultiDatasetFeatures.class */
/**
 * Base tool that maps the records of a domain to ND4J {@link MultiDataSet} mini-batches.
 *
 * <p>The mini-batches are written to {@code <outputBasename>.cf}. Each one is stored as a 4-byte
 * big-endian length prefix followed by the bytes produced by {@code MultiDataSet.save}. The
 * mapping parameters are written to the properties file {@code <outputBasename>.cfp}.
 *
 * @param <RecordType> type of the domain records being mapped
 */
public abstract class MapMultiDatasetFeatures<RecordType> extends AbstractTool<MapFeaturesArguments> {
    private static final Logger LOG = LoggerFactory.getLogger(MapMultiDatasetFeatures.class);

    /**
     * Queue length passed to {@link AsyncMultiDataSetIterator} when the source iterator supports
     * asynchronous prefetching.
     */
    private static final int ASYNC_QUEUE_LENGTH = 12;

    /**
     * Number of records written by the last {@link #execute()}. For each written mini-batch, this
     * sums the first-dimension size of its first feature array.
     */
    private int numRecordsWritten;

    public void setNumRecordsWritten(int i) {
        this.numRecordsWritten = i;
    }

    /**
     * Returns the domain descriptor used to iterate over and map records.
     *
     * @return the domain descriptor for this tool
     */
    protected abstract DomainDescriptor<RecordType> domainDescriptor();

    /**
     * Creates the argument holder for this tool.
     *
     * @return a new {@link MapMultiDatasetFeaturesArguments}
     */
    @Override
    public MapFeaturesArguments createArguments() {
        return new MapMultiDatasetFeaturesArguments();
    }

    /**
     * Returns this tool's arguments, cast to {@link MapMultiDatasetFeaturesArguments}.
     *
     * <p>The cast is expected to succeed: the arguments come either from
     * {@link #createArguments()} or from {@link #setArguments}, and both supply that type.
     *
     * @return the arguments of this tool
     */
    @Override
    public MapFeaturesArguments args2() {
        return (MapMultiDatasetFeaturesArguments) this.arguments;
    }

    /**
     * Writes the mapped mini-batches to {@code <outputBasename>.cf} and their description to
     * {@code <outputBasename>.cfp}.
     *
     * <p>I/O failures are logged rather than propagated. A failure to load the training set is
     * rethrown as a {@link RuntimeException}.
     */
    @Override
    public void execute() {
        DomainDescriptor<RecordType> domainDescriptor = domainDescriptor();
        MultiDataSetIterator iterator = createIterator(domainDescriptor);
        String outputFilename = args2().outputBasename + ".cf";
        try {
            long numDatasets = writeMultiDataSets(domainDescriptor, iterator, outputFilename);
            writeProperties(domainDescriptor, numDatasets);
        } catch (FileNotFoundException e) {
            LOG.error("Unable to create output file: " + outputFilename, e);
        } catch (IOException e) {
            LOG.error("Unable to write to output file: " + outputFilename, e);
        }
    }

    /**
     * Returns the iterator to map. If one was supplied in the arguments, that iterator is used;
     * otherwise an adapter over the training sets is built. Either way, the result is wrapped in
     * an {@link AsyncMultiDataSetIterator} when the iterator reports async support.
     */
    private MultiDataSetIterator createIterator(DomainDescriptor<RecordType> domainDescriptor) {
        MultiDataSetIterator iterator = args2().adapter;
        if (iterator == null) {
            try {
                iterator = new MultiDataSetIteratorAdapter<RecordType>(
                        domainDescriptor.getRecordIterable(args2().trainingSets, args2().cacheN),
                        args2().miniBatchSize, domainDescriptor) {
                    @Override
                    public String getBasename() {
                        // Qualified so these resolve to the enclosing tool, not the adapter.
                        return MapMultiDatasetFeatures.this.buildBaseName(
                                MapMultiDatasetFeatures.this.args2().trainingSets);
                    }
                };
            } catch (IOException e) {
                throw new RuntimeException("Unable to load training set ", e);
            }
        }
        return iterator.asyncSupported()
                ? new AsyncMultiDataSetIterator(iterator, ASYNC_QUEUE_LENGTH)
                : iterator;
    }

    /**
     * Serializes the iterator's mini-batches to {@code outputFilename}, each one length-prefixed.
     *
     * <p>Iteration stops when the iterator is exhausted, when more than {@code writeAtMostN}
     * records have been written, or when more than {@code cacheN} records have been written. The
     * writeAtMostN limit is checked before a batch is written, so every batch in the file is
     * counted in {@link #numRecordsWritten} and in the returned count.
     *
     * @return the number of mini-batches written
     * @throws IOException if the file cannot be created or written
     */
    private long writeMultiDataSets(DomainDescriptor<RecordType> domainDescriptor,
                                    MultiDataSetIterator iterator,
                                    String outputFilename) throws IOException {
        long numDatasets = 0;
        try (FastBufferedOutputStream output =
                     new FastBufferedOutputStream(new FileOutputStream(outputFilename))) {
            ProgressLogger progressLogger = new ProgressLogger(LOG);
            progressLogger.expectedUpdates =
                    Math.min(domainDescriptor.getNumRecords(args2().getTrainingSets()), args2().cacheN)
                            / args2().miniBatchSize;
            progressLogger.displayLocalSpeed = true;
            progressLogger.itemsName = "miniBatch";
            progressLogger.start();

            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            long writeAtMostN = args2().writeAtMostN;
            this.numRecordsWritten = 0;
            while (iterator.hasNext()) {
                // Check before writing so that no uncounted batch ends up in the file.
                if (this.numRecordsWritten > writeAtMostN) {
                    break;
                }
                MultiDataSet multiDataSet = iterator.next();
                buffer.reset();
                multiDataSet.save(buffer);
                writeLengthPrefixed(output, buffer.toByteArray());
                progressLogger.lightUpdate();

                numDatasets++;
                this.numRecordsWritten += multiDataSet.getFeatures()[0].size(0);
                if (this.numRecordsWritten > args2().cacheN) {
                    break;
                }
            }
            progressLogger.stop();
        }
        return numDatasets;
    }

    /** Writes {@code bytes.length} as a 4-byte big-endian int, followed by {@code bytes}. */
    private static void writeLengthPrefixed(OutputStream output, byte[] bytes) throws IOException {
        int length = bytes.length;
        output.write((length >> 24) & 0xFF);
        output.write((length >> 16) & 0xFF);
        output.write((length >> 8) & 0xFF);
        output.write(length & 0xFF);
        output.write(bytes);
    }

    /**
     * Writes the description of the mapped dataset to {@code <outputBasename>.cfp}.
     *
     * @param numDatasets number of mini-batches written to the {@code .cf} file
     * @throws IOException if the properties file cannot be written
     */
    private void writeProperties(DomainDescriptor<RecordType> domainDescriptor, long numDatasets)
            throws IOException {
        Properties properties = new Properties();
        properties.put("domainDescriptor", domainDescriptor.getClass().getCanonicalName());
        properties.put("multiDataSet", "true");
        properties.put("miniBatchSize", Integer.toString(args2().miniBatchSize));
        if (args2().domainDescriptor != null) {
            args2().domainDescriptor.putProperties(properties);
        } else {
            properties.put("featureMapper", args2().featureMapperClassname);
        }
        properties.put("isTrio", Boolean.toString(args2().isTrio));
        properties.put("numRecords", Long.toString(this.numRecordsWritten));
        properties.put("numDatasets", Long.toString(numDatasets));

        String[] inputNames = domainDescriptor.getComputationalGraph().getInputNames();
        for (String inputName : inputNames) {
            int[] dims = domainDescriptor.getNumInputs(inputName);
            for (int dim = 0; dim < dims.length; dim++) {
                properties.put(inputName + ".numFeatures.dim" + dim, Integer.toString(dims[dim]));
            }
        }
        if (inputNames.length == 1) {
            properties.put("numFeatures",
                    Integer.toString(domainDescriptor.getNumInputs(inputNames[0])[0]));
        }
        properties.put("stored", args2().trainingSets.toString());

        // try-with-resources closes the writer; it was previously leaked.
        try (FileWriter writer = new FileWriter(new File(args2().outputBasename + ".cfp"))) {
            properties.store(writer, new Date().toString());
        }
    }

    /**
     * Builds a base name for the given training-set files.
     *
     * <p>For a single file, this is the file's base name. Otherwise it is {@code "multiset-"}
     * followed by a seed XOR-ed with the hash code of each file's base name. Because XOR is
     * order-independent, the order of the files does not matter, and a base name listed twice
     * cancels itself out.
     *
     * @param list training-set file names
     * @return the derived base name
     */
    private String buildBaseName(List<String> list) {
        if (list.size() == 1) {
            return FilenameUtils.getBaseName(list.get(0));
        }
        long hash = 8723872838723L;
        for (String filename : list) {
            hash ^= FilenameUtils.getBaseName(filename).hashCode();
        }
        return "multiset-" + Long.toString(hash);
    }

    public int getNumRecordsWritten() {
        return this.numRecordsWritten;
    }

    public void setArguments(MapMultiDatasetFeaturesArguments<RecordType> mapMultiDatasetFeaturesArguments) {
        this.arguments = mapMultiDatasetFeaturesArguments;
    }
}
