package org.campagnelab.dl.framework.domains;

import com.google.common.collect.Iterables;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.List;
import java.util.Properties;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.campagnelab.dl.framework.architecture.graphs.ComputationGraphAssembler;
import org.campagnelab.dl.framework.domains.prediction.PredictionInterpreter;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mappers.LabelMapper;
import org.campagnelab.dl.framework.models.ModelLoader;
import org.campagnelab.dl.framework.performance.PerformanceMetricDescriptor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Base class describing a learning domain: the feature mappers, label mappers, prediction
 * interpreters, record sources and computation graph assembler a model uses, plus the
 * persistence of that configuration to {@code domain.properties} / {@code config.properties}.
 *
 * <p>Subclasses supply the domain-specific pieces through the abstract methods. This class
 * provides tensor shape computation, a default "score" performance metric, and loading/saving
 * of the property files.
 *
 * @param <RecordType> type of record produced by the iterables from {@link #getRecordIterable()}
 */
public abstract class DomainDescriptor<RecordType> {
    /** Properties loaded from (or intended for) the model's {@code domain.properties} file. */
    protected Properties domainProperties;
    /** Properties loaded from the model's {@code config.properties} file. */
    protected Properties modelProperties;
    /** Assembler instantiated reflectively by {@link #initializeArchitecture(String)}. */
    protected ComputationGraphAssembler computationGraphAssembler;

    /**
     * Returns the feature mapper for the named graph input.
     *
     * @param inputName one of the names returned by {@code getComputationalGraph().getInputNames()}
     */
    public abstract FeatureMapper getFeatureMapper(String inputName);

    /**
     * Returns the label mapper for the named graph output.
     *
     * @param outputName one of the names returned by {@code getComputationalGraph().getOutputNames()}
     */
    public abstract LabelMapper getLabelMapper(String outputName);

    /**
     * Returns the prediction interpreter for the named graph output.
     *
     * @param outputName one of the names returned by {@code getComputationalGraph().getOutputNames()}
     */
    public abstract PredictionInterpreter getPredictionInterpreter(String outputName);

    /**
     * Returns a function that maps a record-source identifier (presumably a dataset file name
     * or path; confirm with implementations) to an iterable over its records.
     */
    public abstract Function<String, ? extends Iterable<RecordType>> getRecordIterable();

    /** Returns the computation graph assembler for this domain. */
    public abstract ComputationGraphAssembler getComputationalGraph();

    /**
     * Returns the per-example dimensions of the named input. {@link #getShape} supports 1 or 2
     * dimensions.
     */
    public abstract int[] getNumInputs(String inputName);

    /**
     * Returns the per-example dimensions of the named output. {@link #getShape} supports 1 or 2
     * dimensions.
     */
    public abstract int[] getNumOutputs(String outputName);

    /** Returns the per-example dimensions of the mask for the named input. */
    public abstract int[] getNumMaskInputs(String inputName);

    /** Returns the per-example dimensions of the mask for the named output. */
    public abstract int[] getNumMaskOutputs(String outputName);

    /**
     * Returns the number of hidden nodes for the named component. NOTE(review): the meaning of
     * the argument is not visible in this class; confirm with implementations.
     */
    public abstract int getNumHiddenNodes(String componentName);

    /** Returns the loss function used for the named output. */
    public abstract LossFunctions.LossFunction getOutputLoss(String outputName);

    /**
     * Returns the number of records in the given sources. NOTE(review): presumably the same
     * identifiers accepted by {@link #getRecordIterable()}; confirm with implementations.
     */
    public abstract long getNumRecords(String[] recordSources);

    /**
     * Returns the default performance-metric descriptor. It exposes a single metric, "score",
     * which is also the early-stopping metric and is estimated by
     * {@link #estimateScore(ComputationGraph, String, MultiDataSetIterator, long)}.
     */
    public PerformanceMetricDescriptor<RecordType> performanceDescritor() {
        return new PerformanceMetricDescriptor<RecordType>() {
            @Override
            public String[] performanceMetrics() {
                return new String[] {"score"};
            }

            @Override
            public boolean largerValueIsBetterPerformance(String metricName) {
                // Lower values are better for every metric reported here.
                return false;
            }

            @Override
            public double estimateMetric(ComputationGraph graph, String metricName,
                    MultiDataSetIterator iterator, long scoreN) {
                return DomainDescriptor.this.estimateScore(graph, metricName, iterator, scoreN);
            }

            @Override
            public String earlyStoppingMetric() {
                return "score";
            }
        };
    }

    /**
     * Estimates a performance metric of {@code computationGraph} on data from {@code iterator}.
     *
     * @param computationGraph graph to evaluate
     * @param metricName metric to compute; only {@code "score"} is supported
     * @param iterator source of evaluation batches; consumed, never reset
     * @param scoreN stop after the cumulative number of examples exceeds this value
     * @return the mean non-NaN batch score, or {@code NaN} when no batch produced a valid score
     * @throws IllegalArgumentException if {@code metricName} is not recognized
     */
    protected double estimateScore(ComputationGraph computationGraph, String metricName,
            MultiDataSetIterator iterator, long scoreN) {
        switch (metricName) {
            case "score":
                return averageScore(computationGraph, iterator, scoreN);
            default:
                throw new IllegalArgumentException("metric name not recognized: " + metricName);
        }
    }

    /** Averages {@code graph.score(batch)} over non-NaN batches until the example budget is exceeded. */
    private static double averageScore(ComputationGraph graph, MultiDataSetIterator iterator,
            long maxExamples) {
        double sum = 0.0;
        long scoredBatches = 0;
        long examplesSeen = 0;
        while (iterator.hasNext()) {
            MultiDataSet batch = iterator.next();
            double batchScore = graph.score(batch);
            // NaN batches are excluded from both the sum and the count. Counting them in the
            // denominator would bias the mean toward 0, i.e. toward a falsely "better" score.
            if (!Double.isNaN(batchScore)) {
                sum += batchScore;
                scoredBatches++;
            }
            examplesSeen += batch.getFeatures()[0].size(0);
            if (examplesSeen > maxExamples) {
                break;
            }
        }
        return scoredBatches == 0 ? Double.NaN : sum / scoredBatches;
    }

    /** Returns the shape of a minibatch of {@code batchSize} examples for the named input. */
    public int[] getInputShape(int batchSize, String inputName) {
        return getShape(batchSize, () -> getNumInputs(inputName));
    }

    /** Returns the shape of a minibatch of {@code batchSize} labels for the named output. */
    public int[] getLabelShape(int batchSize, String outputName) {
        return getShape(batchSize, () -> getNumOutputs(outputName));
    }

    /** Returns the shape of a minibatch of {@code batchSize} input masks for the named input. */
    public int[] getInputMaskShape(int batchSize, String inputName) {
        return getShape(batchSize, () -> getNumMaskInputs(inputName));
    }

    /** Returns the shape of a minibatch of {@code batchSize} label masks for the named output. */
    public int[] getLabelMaskShape(int batchSize, String outputName) {
        return getShape(batchSize, () -> getNumMaskOutputs(outputName));
    }

    /**
     * Prepends {@code batchSize} to the per-example dimensions produced by {@code dimsSupplier}.
     *
     * @return {@code {batchSize, d0}} or {@code {batchSize, d0, d1}}
     * @throws UnsupportedOperationException if the supplier yields zero or more than two dimensions
     *     (more than two also fails the assertion when assertions are enabled)
     */
    public int[] getShape(int batchSize, Supplier<int[]> dimsSupplier) {
        int[] dims = dimsSupplier.get();
        assert dims.length <= 2;
        switch (dims.length) {
            case 1:
                return new int[] {batchSize, dims[0]};
            case 2:
                return new int[] {batchSize, dims[0], dims[1]};
            default:
                throw new UnsupportedOperationException();
        }
    }

    /** Returns the number of outputs declared by the computation graph. */
    public int getNumModelOutputs() {
        return getComputationalGraph().getOutputNames().length;
    }

    /** Returns true if the computation graph declares an output named {@code outputName}. */
    public boolean hasOutput(String outputName) {
        for (String candidate : getComputationalGraph().getOutputNames()) {
            if (outputName.equals(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Returns one feature mapper per graph input, in the order of the graph's input names. */
    public FeatureMapper[] featureMappers() {
        String[] inputNames = getComputationalGraph().getInputNames();
        FeatureMapper[] mappers = new FeatureMapper[inputNames.length];
        for (int index = 0; index < inputNames.length; index++) {
            mappers[index] = getFeatureMapper(inputNames[index]);
        }
        return mappers;
    }

    /** Returns the number of inputs declared by the computation graph. */
    public int getNumModelInputs() {
        return getComputationalGraph().getInputNames().length;
    }

    /**
     * Loads {@code domain.properties} and {@code config.properties} from the directory returned
     * by {@code ModelLoader.getModelPath(modelPrefix)}, replacing any previously loaded properties.
     *
     * @throws RuntimeException wrapping the {@link IOException} if either file cannot be read
     */
    public void loadProperties(String modelPrefix) {
        this.domainProperties = new Properties();
        this.modelProperties = new Properties();
        String domainFile = ModelLoader.getModelPath(modelPrefix) + "/domain.properties";
        String configFile = ModelLoader.getModelPath(modelPrefix) + "/config.properties";
        try {
            loadInto(this.domainProperties, domainFile);
            loadInto(this.modelProperties, configFile);
        } catch (IOException e) {
            throw new RuntimeException("Unable to load domain properties in model path " + modelPrefix, e);
        }
    }

    /** Reads {@code path} into {@code target}, always closing the reader. */
    private static void loadInto(Properties target, String path) throws IOException {
        try (FileReader reader = new FileReader(path)) {
            target.load(reader);
        }
    }

    /** Replaces the domain and model properties with copies of the given ones. */
    public void loadProperties(Properties domainProps, Properties modelProps) {
        this.domainProperties = new Properties();
        this.modelProperties = new Properties();
        this.domainProperties.putAll(domainProps);
        this.modelProperties.putAll(modelProps);
    }

    /**
     * Writes the properties produced by {@link #putProperties(Properties)} to
     * {@code domain.properties} in the directory from {@code ModelLoader.getModelPath(modelPrefix)}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be written
     */
    public void writeProperties(String modelPrefix) {
        Properties properties = new Properties();
        String domainFile = ModelLoader.getModelPath(modelPrefix) + "/domain.properties";
        putProperties(properties);
        try (FileWriter writer = new FileWriter(domainFile)) {
            properties.store(writer, "Domain properties created with " + getClass().getCanonicalName());
        } catch (IOException e) {
            throw new RuntimeException("Unable to write domain descriptor properties to " + domainFile, e);
        }
    }

    /**
     * Records the architecture class, and per-input / per-output mapper and interpreter classes
     * plus their sizes, into {@code properties}.
     */
    public void putProperties(Properties properties) {
        // Requires initializeArchitecture() (or a subclass) to have set computationGraphAssembler.
        properties.put("net.architecture.classname", this.computationGraphAssembler.getClass().getCanonicalName());
        String[] inputNames = getComputationalGraph().getInputNames();
        String[] outputNames = getComputationalGraph().getOutputNames();
        for (String inputName : inputNames) {
            FeatureMapper featureMapper = getFeatureMapper(inputName);
            properties.put(inputName + ".featureMapper", featureMapper.getClass().getCanonicalName());
            properties.put(inputName + ".featureMapper.numFeatures", Integer.toString(featureMapper.numberOfFeatures()));
        }
        for (String outputName : outputNames) {
            LabelMapper labelMapper = getLabelMapper(outputName);
            properties.put(outputName + ".labelMapper", labelMapper.getClass().getCanonicalName());
            properties.put(outputName + ".labelMapper.numLabels", Integer.toString(labelMapper.numberOfLabels()));
            properties.put(outputName + ".predictionInterpreter", getPredictionInterpreter(outputName).getClass().getCanonicalName());
        }
    }

    /**
     * Concatenates the records of every source in {@code recordSources}, in order, and truncates
     * the result to at most {@code maxRecords} records.
     */
    public Iterable<RecordType> getRecordIterable(List<String> recordSources, int maxRecords) {
        List<Iterable<RecordType>> perSource = recordSources.stream()
                .<Iterable<RecordType>>map(source -> getRecordIterable().apply(source))
                .collect(Collectors.toList());
        Iterable<RecordType> allRecords = Iterables.concat(perSource);
        return Iterables.limit(allRecords, maxRecords);
    }

    /**
     * Instantiates the computation graph assembler named by {@code className} through its no-arg
     * constructor.
     *
     * @throws RuntimeException wrapping the cause if the class cannot be loaded, instantiated or cast
     */
    protected void initializeArchitecture(String className) {
        try {
            this.computationGraphAssembler =
                    (ComputationGraphAssembler) Class.forName(className).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Unable to load computation graph: " + className, e);
        }
    }

    /** Instantiates the assembler named by the {@code net.architecture.classname} domain property. */
    protected void initializeArchitecture() {
        initializeArchitecture(this.domainProperties.getProperty("net.architecture.classname"));
    }
}
