package org.campagnelab.dl.framework.domains;

import java.util.Properties;
import java.util.function.Function;
import org.campagnelab.dl.framework.architecture.graphs.ComputationGraphAssembler;
import org.campagnelab.dl.framework.domains.prediction.PredictionInterpreter;
import org.campagnelab.dl.framework.mappers.ConfigurableFeatureMapper;
import org.campagnelab.dl.framework.mappers.ConfigurableLabelMapper;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mappers.LabelMapper;
import org.campagnelab.dl.framework.tools.TrainingArguments;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/domains/PretrainingDomainDescriptor.class */
/**
 * A {@link DomainDescriptor} that adapts an existing ("delegate") domain descriptor for pretraining.
 *
 * <p>Feature and label mappers are instantiated from class names supplied by subclasses
 * ({@link #pretrainingFeatureMapperClassName()}, {@link #pretrainingLabelMapperClassName()}) and
 * configured with this descriptor's model properties. Input, output and mask sizes are derived
 * from the delegate's sizes for its single graph input. Records, prediction interpretation,
 * hidden-node counts, record counts, the computation graph and the loss are forwarded to the
 * delegate.
 *
 * <p>Note: both constructors call the abstract {@code pretraining*} hooks while this object is
 * still being constructed, i.e. before subclass field initializers have run. Implementations of
 * those hooks must therefore not depend on subclass instance fields.
 *
 * @param <RecordType> the record type shared with the delegate descriptor
 */
public abstract class PretrainingDomainDescriptor<RecordType> extends DomainDescriptor<RecordType> {
    private static final Logger LOG = LoggerFactory.getLogger(PretrainingDomainDescriptor.class);

    /** Model property naming the delegate descriptor class (needs a public (String) constructor). */
    private static final String DELEGATE_DESCRIPTOR_PROPERTY = "delegate.domain_descriptor";
    /** Model property holding the delegate's EOS index, or the literal string "null". */
    private static final String EOS_INDEX_PROPERTY = "delegate.eos_index";
    /** Domain property naming the network architecture class; copied from the delegate. */
    private static final String ARCHITECTURE_PROPERTY = "net.architecture.classname";

    /** The wrapped descriptor that most queries are forwarded to. */
    private DomainDescriptor<RecordType> delegate;
    /** Name of the delegate graph's single input. */
    private String inputName;
    private String pretrainingFeatureMapperClassName;
    private String pretrainingLabelMapperClassName;
    /** EOS index read from the model properties; {@code null} when absent or set to "null". */
    private Integer eosIndex;

    /**
     * Recreates a pretraining descriptor from a saved model.
     *
     * <p>Loads properties via {@code loadProperties(str)}, then reflectively instantiates the
     * delegate whose class name is stored under {@code delegate.domain_descriptor}, passing it the
     * same string.
     *
     * @param str the model path, passed unchanged to {@code loadProperties} and to the delegate's
     *     public {@code (String)} constructor
     * @throws RuntimeException if the delegate class property is missing or the delegate cannot be
     *     instantiated; errors raised while configuring from the delegate propagate unwrapped
     */
    public PretrainingDomainDescriptor(String str) {
        super.loadProperties(str);
        String delegateClassName = this.modelProperties.getProperty(DELEGATE_DESCRIPTOR_PROPERTY);
        if (delegateClassName == null) {
            throw new RuntimeException("Delegate domain descriptor not set. Should be set by PretrainModel run");
        }
        try {
            // Reflection cannot verify the RecordType parameter; the caller guarantees it.
            @SuppressWarnings("unchecked")
            DomainDescriptor<RecordType> instance = (DomainDescriptor<RecordType>)
                    Class.forName(delegateClassName).getConstructor(String.class).newInstance(str);
            this.delegate = instance;
        } catch (Exception e) {
            throw new RuntimeException("Invalid instance of delegate domain descriptor. Should have a constructor from a model path and be parameterized on the same RecordType as the PretrainingDomainDescriptor", e);
        }
        // Outside the try so configuration errors (e.g. a multi-input graph) are not misreported
        // as a failure to instantiate the delegate.
        initialize();
    }

    /**
     * Creates a pretraining descriptor wrapping {@code domainDescriptor}.
     *
     * <p>Model and domain properties are obtained from {@link #pretrainingModelProperties} and
     * {@link #pretrainingDomainProperties}; a {@code null} result is treated as empty properties.
     *
     * @param domainDescriptor the descriptor to adapt for pretraining
     * @param trainingArguments arguments forwarded to the property hooks
     */
    public PretrainingDomainDescriptor(DomainDescriptor<RecordType> domainDescriptor, TrainingArguments trainingArguments) {
        Properties modelProps = pretrainingModelProperties(trainingArguments);
        Properties domainProps = pretrainingDomainProperties(trainingArguments);
        super.loadProperties(domainProps == null ? new Properties() : domainProps,
                modelProps == null ? new Properties() : modelProps);
        this.delegate = domainDescriptor;
        initialize();
    }

    /**
     * Copies selected settings from the delegate, resolves the EOS index and mapper class names,
     * and shares the delegate's computation graph with this descriptor.
     *
     * @throws IllegalArgumentException if the delegate graph does not have exactly one input, or
     *     the EOS index property is neither absent, "null", nor an integer
     */
    private void initialize() {
        String delegateEos = this.delegate.modelProperties.getProperty(EOS_INDEX_PROPERTY);
        if (delegateEos != null) {
            this.modelProperties.setProperty(EOS_INDEX_PROPERTY, delegateEos);
        }
        String delegateArchitecture = this.delegate.domainProperties.getProperty(ARCHITECTURE_PROPERTY);
        if (delegateArchitecture != null) {
            this.domainProperties.setProperty(ARCHITECTURE_PROPERTY, delegateArchitecture);
        }
        // Previously an absent property caused a NullPointerException here.
        this.eosIndex = parseEosIndex(this.modelProperties.getProperty(EOS_INDEX_PROPERTY));
        this.pretrainingFeatureMapperClassName = pretrainingFeatureMapperClassName();
        this.pretrainingLabelMapperClassName = pretrainingLabelMapperClassName();

        ComputationGraphAssembler graph = this.delegate.getComputationalGraph();
        String[] inputNames = graph.getInputNames();
        if (inputNames.length != 1) {
            throw new IllegalArgumentException("Graph should only have one input, found " + inputNames.length);
        }
        this.inputName = inputNames[0];
        // Both descriptors share the same assembler instance.
        this.computationGraphAssembler = graph;
        this.delegate.computationGraphAssembler = graph;
        initializeArchitecture();
    }

    /**
     * Parses the EOS index property value.
     *
     * @param value the raw property value, possibly {@code null}
     * @return {@code null} when the value is absent or the literal "null", else the parsed integer
     * @throws NumberFormatException if the value is present but not an integer
     */
    private static Integer parseEosIndex(String value) {
        if (value == null) {
            return null;
        }
        String trimmed = value.trim();
        return "null".equals(trimmed) ? null : Integer.valueOf(trimmed);
    }

    /**
     * Instantiates the pretraining feature mapper through its public no-arg constructor and
     * configures it with this descriptor's model properties. The input name is ignored.
     *
     * @throws IllegalArgumentException if the class cannot be instantiated or is not a
     *     {@link ConfigurableFeatureMapper}
     */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public FeatureMapper getFeatureMapper(String str) {
        String className = this.pretrainingFeatureMapperClassName;
        try {
            FeatureMapper featureMapper = (FeatureMapper) Class.forName(className).getConstructor().newInstance();
            ((ConfigurableFeatureMapper) featureMapper).configure(this.modelProperties);
            return featureMapper;
        } catch (Exception e) {
            throw new IllegalArgumentException("Invalid instance of feature mapper: " + className, e);
        }
    }

    /**
     * Instantiates the pretraining label mapper through its public no-arg constructor and
     * configures it with this descriptor's model properties. The output name is ignored.
     *
     * @throws IllegalArgumentException if the class cannot be instantiated or is not a
     *     {@link ConfigurableLabelMapper}
     */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public LabelMapper getLabelMapper(String str) {
        String className = this.pretrainingLabelMapperClassName;
        try {
            LabelMapper labelMapper = (LabelMapper) Class.forName(className).getConstructor().newInstance();
            ((ConfigurableLabelMapper) labelMapper).configure(this.modelProperties);
            return labelMapper;
        } catch (Exception e) {
            // Message previously (incorrectly) said "feature mapper".
            throw new IllegalArgumentException("Invalid instance of label mapper: " + className, e);
        }
    }

    /** Forwards to the delegate. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public PredictionInterpreter getPredictionInterpreter(String str) {
        return this.delegate.getPredictionInterpreter(str);
    }

    /** Forwards to the delegate. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public Function<String, ? extends Iterable<RecordType>> getRecordIterable() {
        return this.delegate.getRecordIterable();
    }

    /**
     * Returns the pretraining input shape derived from the delegate's two-dimensional shape
     * {@code [features, length]}.
     *
     * <p>The returned array is a copy; the delegate's array is never mutated.
     *
     * @param str input name; a mismatch with the graph's single input is logged, not rejected
     * @throws IllegalArgumentException if the delegate shape is not two-dimensional
     */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public int[] getNumInputs(String str) {
        warnIfUnexpectedInput(str);
        int[] dims = this.delegate.getNumInputs(str).clone();
        if (dims.length != 2) {
            throw new IllegalArgumentException("Delegate number of inputs should be two dimensional");
        }
        // One extra feature when there is no EOS index, or when the EOS index equals the delegate's
        // feature count (presumably a dedicated slot for the EOS marker — TODO confirm).
        if (this.eosIndex == null || this.eosIndex.intValue() == dims[0]) {
            dims[0] += 1;
        }
        // Length 2n+1; presumably input sequence + separator + reconstructed sequence — TODO confirm.
        dims[1] = dims[1] * 2 + 1;
        return dims;
    }

    /** Outputs have the same shape as the single input; {@code str} is ignored. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public int[] getNumOutputs(String str) {
        return getNumInputs(this.inputName);
    }

    /**
     * Returns the mask shape derived from the delegate's one-dimensional mask {@code [length]} as
     * {@code [2 * length + 1]} (a copy; the delegate's array is not mutated).
     *
     * @param str input name; a mismatch with the graph's single input is logged, not rejected
     * @throws IllegalArgumentException if the delegate mask is not one-dimensional
     */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public int[] getNumMaskInputs(String str) {
        warnIfUnexpectedInput(str);
        int[] dims = this.delegate.getNumMaskInputs(str).clone();
        if (dims.length != 1) {
            throw new IllegalArgumentException("Delegate mask should be one dimensional");
        }
        dims[0] = dims[0] * 2 + 1;
        return dims;
    }

    /** Output masks have the same shape as the input mask; {@code str} is ignored. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public int[] getNumMaskOutputs(String str) {
        return getNumMaskInputs(this.inputName);
    }

    /** Logs a warning (without failing) when {@code name} is not the graph's single input. */
    private void warnIfUnexpectedInput(String name) {
        if (!name.equals(this.inputName)) {
            LOG.warn("Invalid input name; given {} but should be {}", name, this.inputName);
        }
    }

    /** Forwards to the delegate. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public int getNumHiddenNodes(String str) {
        return this.delegate.getNumHiddenNodes(str);
    }

    /** Forwards to the delegate. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public long getNumRecords(String[] strArr) {
        return this.delegate.getNumRecords(strArr);
    }

    /**
     * Forwards to the delegate.
     *
     * <p>NOTE(review): this re-queries the delegate instead of returning the assembler cached in
     * {@code computationGraphAssembler} during construction; whether both are the same instance
     * depends on the delegate's implementation — confirm this is intended.
     */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public ComputationGraphAssembler getComputationalGraph() {
        return this.delegate.getComputationalGraph();
    }

    /** Forwards to the delegate. */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public ILossFunction getOutputLoss(String str) {
        return this.delegate.getOutputLoss(str);
    }

    /**
     * Fully-qualified name of the feature mapper used for pretraining. The class needs a public
     * no-arg constructor and must implement {@link ConfigurableFeatureMapper}. Called during
     * construction.
     */
    public abstract String pretrainingFeatureMapperClassName();

    /**
     * Fully-qualified name of the label mapper used for pretraining. The class needs a public
     * no-arg constructor and must implement {@link ConfigurableLabelMapper}. Called during
     * construction.
     */
    public abstract String pretrainingLabelMapperClassName();

    /** Model properties for a new pretraining model; {@code null} means empty. Called during construction. */
    public abstract Properties pretrainingModelProperties(TrainingArguments trainingArguments);

    /** Domain properties for a new pretraining model; {@code null} means empty. Called during construction. */
    public abstract Properties pretrainingDomainProperties(TrainingArguments trainingArguments);

    /**
     * Extends the superclass cache id by XOR-folding in the hash codes of every key and value of
     * {@code inputsPaddedEos()} (when non-null), and returns the combined hash in hexadecimal.
     * XOR makes the result independent of the map's iteration order.
     */
    @Override // org.campagnelab.dl.framework.domains.DomainDescriptor
    public String produceCacheUniqueId(int i) {
        int hashCode = super.produceCacheUniqueId(i).hashCode();
        if (inputsPaddedEos() != null) {
            for (String str : inputsPaddedEos().keySet()) {
                hashCode = (hashCode ^ str.hashCode()) ^ inputsPaddedEos().get(str).hashCode();
            }
        }
        return Integer.toHexString(hashCode);
    }
}
