package com.gengoai.apollo.ml.model;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.DataSetType;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.apollo.ml.encoder.FixedEncoder;
import com.gengoai.apollo.ml.encoder.NoOptEncoder;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.transform.Transformer;
import com.gengoai.apollo.ml.transform.vectorizer.IndexingVectorizer;
import com.gengoai.collection.Iterables;
import com.gengoai.io.Compression;
import com.gengoai.io.MonitoredObject;
import com.gengoai.io.ResourceMonitor;
import com.gengoai.io.Resources;
import com.gengoai.io.resource.Resource;
import com.gengoai.json.Json;
import com.gengoai.reflection.Reflect;
import com.gengoai.reflection.ReflectionException;
import java.io.File;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.NonNull;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

/* loaded from: input_file:com/gengoai/apollo/ml/model/TensorFlowModel.class */
/**
 * Base class for {@link Model} implementations that run inference through a TensorFlow
 * {@link SavedModelBundle}.
 *
 * <p>Subclasses describe the graph's inputs and outputs with {@code TFVarSpec} instances and
 * implement {@link #decodeNDArray(String, NDArray)} to turn fetched output values back into
 * observations. The bundle itself is loaded lazily from the {@code tfmodel} child of
 * {@link #modelFile}.
 */
public abstract class TensorFlowModel implements Model {
    private static final long serialVersionUID = 1;
    /** Input variable specifications, keyed by the datum source name they are read from. */
    protected final Map<String, TFVarSpec> inputs;
    /**
     * Output variable specifications, keyed by the datum source name they are written to.
     * Iteration order matters: outputs are fetched and decoded in this order.
     */
    protected final LinkedHashMap<String, TFVarSpec> outputs;
    /** Fit parameters exposed through {@link #getFitParameters()}. */
    private final FitParameters<?> fitParameters = new FitParameters<>();
    /** Location the model was loaded from; its {@code tfmodel} child holds the saved bundle. */
    protected Resource modelFile;
    /** Vectorizing transformer built from the specs' encoders; transient, so it is rebuilt at runtime. */
    protected volatile transient Transformer transformer;
    /** Lazily loaded TensorFlow bundle; transient, so it is {@code null} until first use. */
    private volatile transient MonitoredObject<SavedModelBundle> model;

    /**
     * Creates a model over shallow copies of the given variable specifications, so later
     * changes to the caller's maps do not affect this model.
     *
     * @param map           the input specifications keyed by source name
     * @param linkedHashMap the output specifications keyed by source name, in fetch order
     * @throws NullPointerException if either argument is {@code null}
     */
    protected TensorFlowModel(@NonNull Map<String, TFVarSpec> map, @NonNull LinkedHashMap<String, TFVarSpec> linkedHashMap) {
        Objects.requireNonNull(map, "inputs is marked non-null but is null");
        Objects.requireNonNull(linkedHashMap, "outputs is marked non-null but is null");
        this.inputs = new HashMap<>(map);
        this.outputs = new LinkedHashMap<>(linkedHashMap);
    }

    /**
     * Loads a model from a resource directory.
     *
     * <p>The directory's {@code __class__} child names the concrete model class, which is
     * instantiated reflectively. Every child matching {@code *.encoder.json.gz} is then parsed
     * as an {@link Encoder} and assigned to the variable named by the file's base name.
     *
     * @param resource the directory containing the model
     * @return the loaded model, with its transformer built and {@code modelFile} set
     * @throws IOException          if the named class cannot be resolved, is not a
     *                              {@code TensorFlowModel}, cannot be instantiated, or a
     *                              resource cannot be read
     * @throws NullPointerException if {@code resource} is {@code null}
     */
    public static Model load(@NonNull Resource resource) throws IOException {
        if (resource == null) {
            throw new NullPointerException("resource is marked non-null but is null");
        }
        final String className = resource.getChild("__class__").readToString().strip();
        final Class<?> modelClass = Reflect.getClassForNameQuietly(className);
        // The "quiet" lookup does not report failures itself, so surface them here as a
        // descriptive IOException rather than an NPE or ClassCastException further down.
        if (modelClass == null) {
            throw new IOException("Unable to resolve model class '" + className + "' declared in " + resource.descriptor());
        }
        if (!TensorFlowModel.class.isAssignableFrom(modelClass)) {
            throw new IOException("Class '" + className + "' declared in " + resource.descriptor() + " is not a TensorFlowModel");
        }
        try {
            TensorFlowModel loaded = (TensorFlowModel) Reflect.onClass(modelClass).allowPrivilegedAccess().create().get();
            for (Resource encoderFile : resource.getChildren("*.encoder.json.gz")) {
                String variableName = encoderFile.baseName().replace(".encoder.json.gz", "").strip();
                loaded.setEncoder(variableName, (Encoder) Json.parse(encoderFile, Encoder.class));
            }
            // Build the transformer only after the encoders are restored, since it wraps them.
            loaded.transformer = loaded.createTransformer();
            loaded.modelFile = resource;
            return loaded;
        } catch (ReflectionException e) {
            throw new IOException(e);
        }
    }

    /**
     * Returns the maximum sequence length to use when allocating batch arrays for the given
     * data set. The value is passed as the second argument to
     * {@code TFVarSpec.createBatchNDArray} by {@link #createTensors(DataSet)}.
     *
     * <p>This default implementation ignores the data set and returns {@code 0}; the method is
     * overridable so subclasses can supply a real length.
     *
     * @param dataSet the batch about to be converted to tensors
     * @return {@code 0} in this base implementation
     */
    protected int calculate_max_sequence_length(DataSet dataSet) {
        return 0;
    }

    /**
     * Converts a (already vectorized) batch into the tensors fed to the TensorFlow session.
     *
     * <p>One batch array is allocated per input, each datum's values are copied into row
     * {@code i} of the matching array, and every array is finally converted to a tensor.
     *
     * @param dataSet the batch to convert
     * @return the input tensors keyed by each input's serving name; the caller owns them
     */
    protected Map<String, Tensor<?>> createTensors(DataSet dataSet) {
        final int batchSize = (int) dataSet.size();
        final int maxSequenceLength = calculate_max_sequence_length(dataSet);

        // Allocate one batch-sized array per input variable.
        final Map<String, NDArray> batches = new HashMap<>();
        for (Map.Entry<String, TFVarSpec> input : this.inputs.entrySet()) {
            batches.put(input.getKey(), input.getValue().createBatchNDArray(batchSize, maxSequenceLength));
        }

        // Copy each datum into its row of every input's batch array.
        int row = 0;
        for (Iterator<Datum> rows = dataSet.iterator(); rows.hasNext(); row++) {
            final Datum datum = rows.next();
            for (Map.Entry<String, TFVarSpec> input : this.inputs.entrySet()) {
                final String name = input.getKey();
                input.getValue().updateBatch(batches.get(name), row, datum.get(name).asNDArray());
            }
        }

        // Tensors are keyed by serving name, not by the datum source name.
        final Map<String, Tensor<?>> tensors = new HashMap<>();
        for (Map.Entry<String, TFVarSpec> input : this.inputs.entrySet()) {
            final TFVarSpec spec = input.getValue();
            tensors.put(spec.getServingName(), spec.toTensor(batches.get(input.getKey())));
        }
        return tensors;
    }

    /**
     * Builds the transformer that vectorizes data for this model.
     *
     * <p>An {@link IndexingVectorizer} is created for every input and then every output whose
     * encoder is not a {@link NoOptEncoder}; each vectorizer uses the variable's name as its
     * source.
     *
     * @return a new transformer wrapping the vectorizers, inputs first, then outputs
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected Transformer createTransformer() {
        final List<Map.Entry<String, TFVarSpec>> specs = new ArrayList<>(this.inputs.entrySet());
        specs.addAll(this.outputs.entrySet());

        final List<IndexingVectorizer> vectorizers = new ArrayList<>();
        for (Map.Entry<String, TFVarSpec> spec : specs) {
            final Encoder encoder = spec.getValue().getEncoder();
            if (encoder instanceof NoOptEncoder) {
                continue;
            }
            vectorizers.add((IndexingVectorizer) new IndexingVectorizer(encoder).source(spec.getKey()));
        }
        // Raw cast retained from the original call site.
        return new Transformer((List) vectorizers);
    }

    /**
     * Writes the decoded model outputs for one batch row into the given datum.
     *
     * <p>The {@code i}-th fetched array is paired with the {@code i}-th entry of
     * {@link #outputs}. When an array's shape order exceeds 2 the row is taken with
     * {@code slice}; otherwise with {@code getRow}.
     *
     * @param datum the datum to update (mutated and returned)
     * @param list  the fetched output arrays, in output declaration order
     * @param j     the row of the batch that belongs to {@code datum}
     * @return the same {@code datum} instance
     */
    private Datum decode(Datum datum, List<NDArray> list, long j) {
        final int row = (int) j;
        int position = 0;
        for (Map.Entry<String, TFVarSpec> output : this.outputs.entrySet()) {
            final String name = output.getKey();
            final NDArray fetched = list.get(position);
            final NDArray rowValues;
            if (fetched.shape().order() > 2) {
                rowValues = fetched.slice(row);
            } else {
                rowValues = fetched.getRow(row);
            }
            datum.put(name, decodeNDArray(name, rowValues));
            position++;
        }
        return datum;
    }

    /**
     * Converts the values fetched for one output and one batch row into an observation.
     *
     * @param str     the output name (a key of {@link #outputs})
     * @param nDArray the slice or row of the fetched output array belonging to a single datum
     * @return the observation to store on the datum under {@code str}
     */
    protected abstract Observation decodeNDArray(String str, NDArray nDArray);

    /**
     * Fits this model's encoders on the data set and exports the vectorized data.
     *
     * <p>No TensorFlow training happens here: a fresh transformer is fitted and applied, the
     * resulting encoders are stored on the matching variable specs, the transformed data set
     * is persisted to a temporary file, and that file's location is printed to standard out.
     *
     * @param dataSet the data to fit the encoders on
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataset is marked non-null but is null");
        final DataSet encoded = createTransformer().fitAndTransform(dataSet);
        encoded.getMetadata().forEach((name, metadata) -> setEncoder(name, metadata.getEncoder()));
        final Resource exportFile = Resources.temporaryFile();
        encoded.persist(exportFile);
        System.out.println("DataSet saved to: " + exportFile.descriptor());
    }

    /**
     * Returns the fit parameters instance held by this model.
     *
     * @return the same {@link FitParameters} object on every call
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public FitParameters<?> getFitParameters() {
        return this.fitParameters;
    }

    /**
     * Returns the names of the inputs this model reads.
     *
     * @return a read-only view of the input names
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public final Set<String> getInputs() {
        final Set<String> inputNames = this.inputs.keySet();
        return Collections.unmodifiableSet(inputNames);
    }

    /**
     * Returns the names of the outputs this model produces.
     *
     * @return a read-only view of the output names, in declaration order
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public final Set<String> getOutputs() {
        final Set<String> outputNames = this.outputs.keySet();
        return Collections.unmodifiableSet(outputNames);
    }

    /**
     * Returns the TensorFlow bundle, loading it from the {@code tfmodel} child of
     * {@link #modelFile} with the {@code serve} tag on first use. The transformer is rebuilt
     * as part of the same initialization.
     *
     * <p>The transformer is assigned <em>before</em> the bundle is published to the volatile
     * {@code model} field. A thread that sees a non-null {@code model} and skips the lock is
     * therefore guaranteed to also see the initialized transformer.
     *
     * @return the loaded bundle
     * @throws IllegalStateException if no model location has been set
     */
    private SavedModelBundle getTensorFlowModel() {
        MonitoredObject<SavedModelBundle> bundle = this.model;
        if (bundle == null) {
            synchronized (this) {
                bundle = this.model;
                if (bundle == null) {
                    if (this.modelFile == null) {
                        throw new IllegalStateException("No model location is set; the TensorFlow bundle cannot be loaded");
                    }
                    String bundlePath = ((File) this.modelFile.getChild("tfmodel").asFile().orElseThrow()).getAbsolutePath();
                    bundle = ResourceMonitor.monitor(SavedModelBundle.load(bundlePath, "serve"));
                    this.transformer = createTransformer();
                    // Publish last: this volatile write makes the transformer visible too.
                    this.model = bundle;
                }
            }
        }
        return (SavedModelBundle) bundle.object;
    }

    /**
     * Runs one batch through the TensorFlow graph and returns the decoded results.
     *
     * <p>The batch is vectorized with the transformer, converted to input tensors, fed to a
     * session runner that fetches every output in declaration order, and each fetched tensor
     * is converted to an {@link NDArray}. Every datum of the vectorized batch is then updated
     * with its decoded outputs.
     *
     * <p>The bundle is resolved before the transformer is used, because the transformer is
     * transient and is recreated by that call; this keeps deserialized instances from failing
     * on a {@code null} transformer. All input and fetched tensors are closed even when
     * running or decoding throws.
     *
     * @param dataSet the raw batch to process
     * @return the batch's datums with the model outputs added, in batch order
     */
    protected final List<Datum> processBatch(DataSet dataSet) {
        // Must come first: ensures this.transformer has been (re)created.
        final SavedModelBundle bundle = getTensorFlowModel();
        final DataSet transformed = this.transformer.transform(dataSet);
        final Session.Runner runner = bundle.session().runner();

        final List<Datum> decoded = new ArrayList<>();
        final Map<String, Tensor<?>> feeds = createTensors(transformed);
        try {
            feeds.forEach(runner::feed);
            this.outputs.forEach((name, spec) -> runner.fetch(spec.getServingName()));

            final List<NDArray> fetched = new ArrayList<>();
            final List<Tensor<?>> results = runner.run();
            try {
                for (Tensor<?> result : results) {
                    fetched.add(NDArrayFactory.ND.fromTensorFlowTensor(result));
                }
            } finally {
                // Release native memory for every fetched tensor, converted or not.
                results.forEach(Tensor::close);
            }

            transformed.stream().zipWithIndex().forEachLocal((datum, l) -> {
                decoded.add(decode(datum, fetched, l.longValue()));
            });
        } finally {
            feeds.values().forEach(Tensor::close);
        }
        return decoded;
    }

    /**
     * Writes this model's fitted encoders into the given resource directory.
     *
     * <p>For every input and output whose encoder is neither a {@link NoOptEncoder} nor a
     * {@link FixedEncoder}, the encoder is written as GZIP-compressed pretty JSON to a child
     * named {@code <variable>.encoder.json.gz}. This method writes only those encoder files.
     *
     * @param resource the directory to write into
     * @throws IOException          if an encoder cannot be written
     * @throws NullPointerException if {@code resource} is {@code null}
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void save(@NonNull Resource resource) throws IOException {
        Objects.requireNonNull(resource, "resource is marked non-null but is null");

        final List<Map.Entry<String, TFVarSpec>> specs = new ArrayList<>(this.inputs.entrySet());
        specs.addAll(this.outputs.entrySet());

        for (Map.Entry<String, TFVarSpec> spec : specs) {
            final Encoder encoder = spec.getValue().getEncoder();
            if (encoder instanceof NoOptEncoder || encoder instanceof FixedEncoder) {
                continue;
            }
            final Resource target = resource.getChild(spec.getKey() + ".encoder.json.gz").setCompression(Compression.GZIP);
            Json.dumpPretty(spec.getValue().encoder, target);
        }
    }

    /**
     * Assigns an encoder to the input or output variable with the given name. Inputs take
     * precedence when a name is present in both maps.
     *
     * @param str     the variable name
     * @param encoder the encoder to assign
     * @throws IllegalArgumentException if {@code str} names neither an input nor an output
     */
    protected void setEncoder(String str, Encoder encoder) {
        final TFVarSpec spec = this.inputs.containsKey(str) ? this.inputs.get(str) : this.outputs.get(str);
        if (spec == null) {
            // Map.get returns null for unknown keys; fail with the name instead of a bare NPE.
            throw new IllegalArgumentException("Unknown variable '" + str + "'; known inputs are "
                    + this.inputs.keySet() + " and known outputs are " + this.outputs.keySet());
        }
        spec.setEncoder(encoder);
    }

    /**
     * Applies the model to a single datum by running it as a one-element in-memory batch.
     *
     * @param datum the datum to process
     * @return the first (only) element of the processed batch
     * @throws NullPointerException if {@code datum} is {@code null}
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public final Datum transform(@NonNull Datum datum) {
        Objects.requireNonNull(datum, "datum is marked non-null but is null");
        final List<Datum> processed = processBatch(DataSetType.InMemory.create(Stream.of(datum)));
        return processed.get(0);
    }

    /**
     * Applies the model to every datum of the data set by mapping the single-datum
     * {@code transform} over it.
     *
     * @param dataSet the data set to process
     * @return the data set produced by {@code dataSet.map}
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public DataSet transform(@NonNull DataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataset is marked non-null but is null");
        return dataSet.map(this::transform);
    }

    /**
     * Lambda deserialization hook: rebuilds a serializable lambda of this class from its
     * {@link SerializedLambda} description.
     *
     * <p>Two lambdas are recognised by implementation method name: the {@code transform}
     * method reference (a {@code SerializableFunction}) and the bi-consumer lambda from
     * {@code processBatch} (a {@code SerializableBiConsumer}). Each is recreated only when
     * every recorded signature detail matches; anything else is rejected.
     *
     * <p>NOTE(review): the {@code synthetic} marker and the non-compilable constructs below
     * ({@code boolean z = -1}, a {@code switch} over a boolean) indicate this is decompiler
     * output of a compiler-generated method rather than hand-written source. Presumably the
     * compiler regenerates it when the class is rebuilt from source — confirm before keeping
     * or removing it.
     *
     * @param serializedLambda the serialized form to resolve
     * @return the recreated lambda
     * @throws IllegalArgumentException if the serialized form matches no known lambda
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // First switch: map the implementation method name to a case selector.
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 1052666732:
                if (implMethodName.equals("transform")) {
                    z = false;
                    break;
                }
                break;
            case 1389842520:
                if (implMethodName.equals("lambda$processBatch$a60f13ac$1")) {
                    z = true;
                    break;
                }
                break;
        }
        // Second switch: verify the full lambda signature, then rebuild it from captured args.
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 5 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/SerializableFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/model/TensorFlowModel") && serializedLambda.getImplMethodSignature().equals("(Lcom/gengoai/apollo/ml/Datum;)Lcom/gengoai/apollo/ml/Datum;")) {
                    TensorFlowModel tensorFlowModel = (TensorFlowModel) serializedLambda.getCapturedArg(0);
                    return tensorFlowModel::transform;
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/SerializableBiConsumer") && serializedLambda.getFunctionalInterfaceMethodName().equals("accept") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)V") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/model/TensorFlowModel") && serializedLambda.getImplMethodSignature().equals("(Ljava/util/List;Ljava/util/List;Lcom/gengoai/apollo/ml/Datum;Ljava/lang/Long;)V")) {
                    // Captured args: the model instance, the result list, and the fetched arrays.
                    TensorFlowModel tensorFlowModel2 = (TensorFlowModel) serializedLambda.getCapturedArg(0);
                    List list = (List) serializedLambda.getCapturedArg(1);
                    List list2 = (List) serializedLambda.getCapturedArg(2);
                    return (datum, l) -> {
                        list.add(decode(datum, list2, l.longValue()));
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
