package org.deeplearning4j.spark.impl.multilayer.scoring;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.class */
/**
 * Supplies the {@link VariationalAutoencoder} layer used by
 * {@link BaseVaeReconstructionProbWithKeyFunction}, taken from a {@link MultiLayerNetwork} that is
 * rebuilt from a broadcast JSON configuration and broadcast parameters.
 *
 * @param <K> key type carried through the keyed scoring function
 */
public class VaeReconstructionProbWithKeyFunction<K> extends BaseVaeReconstructionProbWithKeyFunction<K> {

    /**
     * Creates the function; every argument is forwarded unchanged to the superclass constructor.
     *
     * @param params broadcast network parameters, later copied into the rebuilt network
     * @param jsonConfig broadcast {@link MultiLayerConfiguration} serialized as JSON
     * @param useLogProbability presumably selects log-probability vs. probability output
     *        (TODO confirm against BaseVaeReconstructionProbWithKeyFunction)
     * @param batchSize presumably the minibatch size used while scoring (TODO confirm)
     * @param numSamples presumably the number of samples used to estimate the reconstruction
     *        probability (TODO confirm)
     */
    public VaeReconstructionProbWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig,
                    boolean useLogProbability, int batchSize, int numSamples) {
        super(params, jsonConfig, useLogProbability, batchSize, numSamples);
    }

    /**
     * Rebuilds the network from the broadcast configuration and parameters and returns its first
     * layer as a {@link VariationalAutoencoder}.
     *
     * @return layer 0 of the rebuilt network
     * @throws IllegalStateException if the broadcast parameter count differs from the number of
     *         parameters in the network built from the configuration
     * @throws RuntimeException if layer 0 is not a {@link VariationalAutoencoder}
     */
    @Override
    public VariationalAutoencoder getVaeLayer() {
        // Use the public value() accessor for both broadcasts (it validates the broadcast first).
        MultiLayerNetwork network =
                        new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) this.jsonConfig.value()));
        network.init();

        // Copy the broadcast parameters rather than handing the shared broadcast array to the network.
        INDArray paramsCopy = ((INDArray) this.params.value()).unsafeDuplication();
        if (paramsCopy.length() != network.numParams(false)) {
            throw new IllegalStateException(
                            "Network did not have same number of parameters as the broadcast set parameters");
        }
        network.setParameters(paramsCopy);

        // getLayer returns the generic layer type: check the concrete type *before* narrowing so a
        // wrong layer produces the descriptive error below rather than a ClassCastException.
        Object firstLayer = network.getLayer(0);
        if (!(firstLayer instanceof VariationalAutoencoder)) {
            throw new RuntimeException(
                            "Cannot use VaeReconstructionProbWithKeyFunction on network that doesn't have a VAE "
                                            + "layer as layer 0. Layer type: " + firstLayer.getClass());
        }
        return (VariationalAutoencoder) firstLayer;
    }
}
