package org.nd4j.samediff.frameworkimport.registry;

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.function.BiConsumer;
import kotlin.Metadata;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.internal.Intrinsics;
import org.apache.commons.collections4.MultiSet;
import org.apache.commons.collections4.MultiValuedMap;
import org.apache.commons.collections4.multimap.HashSetValuedHashMap;
import org.apache.commons.io.FileUtils;
import org.jetbrains.annotations.NotNull;
import org.nd4j.ir.MapperNamespace;
import org.nd4j.ir.OpNamespace;
import org.nd4j.samediff.frameworkimport.IRProtobufExtensionsKt;
import org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoaderHolder;
import org.nd4j.samediff.frameworkimport.process.MappingProcess;
import org.nd4j.samediff.frameworkimport.process.MappingProcessLoader;
import org.nd4j.shade.protobuf.GeneratedMessageV3;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.ProtocolMessageEnum;
import org.nd4j.shade.protobuf.TextFormat;

/* compiled from: OpMappingRegistry.kt */
@Metadata(mv = {1, 4, 2}, bv = {1, 0, 3}, k = 1, d1 = {"��\u0080\u0001\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0010��\n��\n\u0002\u0010\u000e\n��\n\u0002\u0018\u0002\n\u0002\b\u0004\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0007\n\u0002\u0018\u0002\n\u0002\u0018\u0002\n\u0002\b\u0003\n\u0002\u0010\u000b\n\u0002\b\u0002\n\u0002\u0010\"\n��\n\u0002\u0010\u0002\n��\n\u0002\u0018\u0002\n��\n\u0002\u0018\u0002\n\u0002\b\t\n\u0002\u0018\u0002\n\u0002\b\u0002\n\u0002\u0018\u0002\n\u0002\b\b\n\u0002\u0010$\n\u0002\b\u0002\u0018��*\b\b��\u0010\u0001*\u00020\u0002*\b\b\u0001\u0010\u0003*\u00020\u0002*\b\b\u0002\u0010\u0004*\u00020\u0002*\b\b\u0003\u0010\u0005*\u00020\u0002*\b\b\u0004\u0010\u0006*\u00020\u0007*\b\b\u0005\u0010\b*\u00020\u0002*\b\b\u0006\u0010\t*\u00020\u00022\u00020\nB\u0015\u0012\u0006\u0010\u000b\u001a\u00020\f\u0012\u0006\u0010\r\u001a\u00020\u000e¢\u0006\u0002\u0010\u000fJ\u000e\u0010!\u001a\u00020\"2\u0006\u0010#\u001a\u00020\fJ\f\u0010$\u001a\b\u0012\u0004\u0012\u00020\f0%J@\u0010&\u001a\u00020'2\u0006\u0010(\u001a\u00020)20\u0010*\u001a,\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006\u0012\u0004\u0012\u00028\u00040+J@\u0010,\u001a\u00020'2\u0006\u0010-\u001a\u00020\f20\u0010*\u001a,\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006\u0012\u0004\u0012\u00028\u00040+J\u0013\u0010.\u001a\u00028\u00022\u0006\u0010/\u001a\u00020\f¢\u0006\u0002\u00100J\u000e\u00101\u001a\u00020\u00142\u0006\u0010/\u001a\u00020\fJ8\u00102\u001a,\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006
\u0012\u0004\u0012\u00028\u00040\u001e2\u0006\u0010#\u001a\u00020\fJ\f\u00103\u001a\b\u0012\u0004\u0012\u00020\f0%J\f\u00104\u001a\b\u0012\u0004\u0012\u00020\f05J\f\u00106\u001a\b\u0012\u0004\u0012\u00020\f0%J\u000e\u00107\u001a\u0002082\u0006\u00109\u001a\u00020\fJ\u001b\u0010:\u001a\u00020'2\u0006\u0010/\u001a\u00020\f2\u0006\u0010;\u001a\u00028\u0002¢\u0006\u0002\u0010<J@\u0010=\u001a\u00020'2\u0006\u0010#\u001a\u00020\f20\u0010>\u001a,\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006\u0012\u0004\u0012\u00028\u00040\u001eJ\u0016\u0010?\u001a\u00020'2\u0006\u0010/\u001a\u00020\f2\u0006\u0010;\u001a\u00020\u0014J\u001a\u0010@\u001a\u00020'2\u0012\u0010\u001a\u001a\u000e\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00028\u00020AJ\u0006\u0010B\u001a\u00020'R\u0011\u0010\u000b\u001a\u00020\f¢\u0006\b\n��\u001a\u0004\b\u0010\u0010\u0011R-\u0010\u0012\u001a\u001e\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00020\u00140\u0013j\u000e\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00020\u0014`\u0015¢\u0006\b\n��\u001a\u0004\b\u0016\u0010\u0017R\u0011\u0010\r\u001a\u00020\u000e¢\u0006\b\n��\u001a\u0004\b\u0018\u0010\u0019R-\u0010\u001a\u001a\u001e\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00028\u00020\u0013j\u000e\u0012\u0004\u0012\u00020\f\u0012\u0004\u0012\u00028\u0002`\u0015¢\u0006\b\n��\u001a\u0004\b\u001b\u0010\u0017RG\u0010\u001c\u001a8\u0012\u0004\u0012\u00020\f\u0012.\u0012,\u0012\u0004\u0012\u00028��\u0012\u0004\u0012\u00028\u0002\u0012\u0004\u0012\u00028\u0001\u0012\u0004\u0012\u00028\u0003\u0012\u0004\u0012\u00028\u0005\u0012\u0004\u0012\u00028\u0006\u0012\u0004\u0012\u00028\u00040\u001e0\u001d¢\u0006\b\n��\u001a\u0004\b\u001f\u0010 ¨\u0006C"}, d2 = {"Lorg/nd4j/samediff/frameworkimport/registry/OpMappingRegistry;", "GRAPH_TYPE", "Lorg/nd4j/shade/protobuf/GeneratedMessageV3;", "NODE_TYPE", "OP_DEF_TYPE", "TENSOR_TYPE", 
"DATA_TYPE", "Lorg/nd4j/shade/protobuf/ProtocolMessageEnum;", "ATTRIBUTE_TYPE", "ATTRIBUTE_VALUE_TYPE", "", "inputFrameworkName", "", "nd4jOpDescriptors", "Lorg/nd4j/ir/OpNamespace$OpDescriptorList;", "(Ljava/lang/String;Lorg/nd4j/ir/OpNamespace$OpDescriptorList;)V", "getInputFrameworkName", "()Ljava/lang/String;", "nd4jOpDefs", "Ljava/util/HashMap;", "Lorg/nd4j/ir/OpNamespace$OpDescriptor;", "Lkotlin/collections/HashMap;", "getNd4jOpDefs", "()Ljava/util/HashMap;", "getNd4jOpDescriptors", "()Lorg/nd4j/ir/OpNamespace$OpDescriptorList;", "opDefList", "getOpDefList", "registeredOps", "Lorg/apache/commons/collections4/MultiValuedMap;", "Lorg/nd4j/samediff/frameworkimport/process/MappingProcess;", "getRegisteredOps", "()Lorg/apache/commons/collections4/MultiValuedMap;", "hasMappingOpProcess", "", "inputFrameworkOpName", "inputFrameworkOpNames", "", "loadFromDefinitions", "", "mapperDeclarations", "Lorg/nd4j/ir/MapperNamespace$MappingDefinitionSet;", "mappingProcessLoader", "Lorg/nd4j/samediff/frameworkimport/process/MappingProcessLoader;", "loadFromFile", "mapperDeclarationsFile", "lookupInputFrameworkOpDef", "name", "(Ljava/lang/String;)Lorg/nd4j/shade/protobuf/GeneratedMessageV3;", "lookupNd4jOpDef", "lookupOpMappingProcess", "mappedNd4jOpNames", "mappingProcessNames", "Lorg/apache/commons/collections4/MultiSet;", "nd4jOpNames", "opTypeForName", "Lorg/nd4j/ir/OpNamespace$OpDescriptor$OpDeclarationType;", "nd4jOpName", "registerInputFrameworkOpDef", "opDef", "(Ljava/lang/String;Lorg/nd4j/shade/protobuf/GeneratedMessageV3;)V", "registerMappingProcess", "processToRegister", "registerNd4jOpDef", "registerOpDefs", "", "saveProcessesAndRuleSet", "samediff-import-api"})
/* loaded from: input_file:org/nd4j/samediff/frameworkimport/registry/OpMappingRegistry.class */
/*
 * Registry of op-mapping processes for a single input framework.
 *
 * It keeps three lookup tables:
 *  - registeredOps: input-framework op name -> mapping processes (a name may have several),
 *  - opDefList:     input-framework op name -> that framework's op definition,
 *  - nd4jOpDefs:    nd4j op name -> nd4j op descriptor.
 *
 * The backing maps are plain hash maps and no synchronization is performed.
 */
public final class OpMappingRegistry<GRAPH_TYPE extends GeneratedMessageV3, NODE_TYPE extends GeneratedMessageV3, OP_DEF_TYPE extends GeneratedMessageV3, TENSOR_TYPE extends GeneratedMessageV3, DATA_TYPE extends ProtocolMessageEnum, ATTRIBUTE_TYPE extends GeneratedMessageV3, ATTRIBUTE_VALUE_TYPE extends GeneratedMessageV3> {

    /** Input-framework op name to the mapping processes registered for it (values held in hash sets). */
    @NotNull
    private final MultiValuedMap<String, MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE>> registeredOps;

    /** Input-framework op name to its op definition; lazily filled by {@link #lookupInputFrameworkOpDef}. */
    @NotNull
    private final HashMap<String, OP_DEF_TYPE> opDefList;

    /** nd4j op name to its descriptor, populated through {@link #registerNd4jOpDef}. */
    @NotNull
    private final HashMap<String, OpNamespace.OpDescriptor> nd4jOpDefs;

    /** Name of the input framework; also the prefix of the file written by {@link #saveProcessesAndRuleSet}. */
    @NotNull
    private final String inputFrameworkName;

    /** Complete nd4j op descriptor list, searched by {@link #opTypeForName}. */
    @NotNull
    private final OpNamespace.OpDescriptorList nd4jOpDescriptors;

    /**
     * Creates an empty registry for one input framework.
     *
     * @param inputFrameworkName name of the input framework this registry serves
     * @param nd4jOpDescriptors  nd4j op descriptors used by {@link #opTypeForName}
     * @throws NullPointerException if either argument is null
     */
    public OpMappingRegistry(@NotNull String inputFrameworkName, @NotNull OpNamespace.OpDescriptorList nd4jOpDescriptors) {
        this.inputFrameworkName = Objects.requireNonNull(inputFrameworkName, "inputFrameworkName");
        this.nd4jOpDescriptors = Objects.requireNonNull(nd4jOpDescriptors, "nd4jOpDescriptors");
        this.registeredOps = new HashSetValuedHashMap<>();
        this.opDefList = new HashMap<>();
        this.nd4jOpDefs = new HashMap<>();
    }

    /** Returns the live, mutable multimap of registered mapping processes. */
    @NotNull
    public final MultiValuedMap<String, MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE>> getRegisteredOps() {
        return this.registeredOps;
    }

    /** Returns the live, mutable map of input-framework op definitions. */
    @NotNull
    public final HashMap<String, OP_DEF_TYPE> getOpDefList() {
        return this.opDefList;
    }

    /** Returns the live, mutable map of registered nd4j op descriptors. */
    @NotNull
    public final HashMap<String, OpNamespace.OpDescriptor> getNd4jOpDefs() {
        return this.nd4jOpDefs;
    }

    /** Returns the input framework name this registry was created for. */
    @NotNull
    public final String getInputFrameworkName() {
        return this.inputFrameworkName;
    }

    /** Returns the nd4j op descriptor list this registry was created with. */
    @NotNull
    public final OpNamespace.OpDescriptorList getNd4jOpDescriptors() {
        return this.nd4jOpDescriptors;
    }

    /**
     * Collects the op names reported by every registered mapping process.
     *
     * @return a new, sorted set of {@link MappingProcess#opName()} values
     */
    @NotNull
    public final Set<String> mappedNd4jOpNames() {
        SortedSet<String> opNames = new TreeSet<>();
        for (MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> process : this.registeredOps.values()) {
            opNames.add(process.opName());
        }
        return opNames;
    }

    /**
     * Returns the input-framework op names that have mapping processes, as a multiset view in
     * which each name occurs once per process registered under it.
     */
    @NotNull
    public final MultiSet<String> mappingProcessNames() {
        return this.registeredOps.keys();
    }

    /** Returns a live key-set view of the registered nd4j op names. */
    @NotNull
    public final Set<String> nd4jOpNames() {
        return this.nd4jOpDefs.keySet();
    }

    /** Returns a live key-set view of the known input-framework op names. */
    @NotNull
    public final Set<String> inputFrameworkOpNames() {
        return this.opDefList.keySet();
    }

    /**
     * Looks up a previously registered nd4j op descriptor.
     *
     * @param name nd4j op name
     * @return the descriptor registered under {@code name}
     * @throws NullPointerException if nothing is registered under {@code name}
     */
    @NotNull
    public final OpNamespace.OpDescriptor lookupNd4jOpDef(@NotNull String name) {
        Objects.requireNonNull(name, "name");
        OpNamespace.OpDescriptor descriptor = this.nd4jOpDefs.get(name);
        return Objects.requireNonNull(descriptor, () -> "No nd4j op descriptor registered for " + name);
    }

    /**
     * Registers every entry of {@code opDefs} as an input-framework op definition, replacing any
     * existing definition with the same name.
     *
     * @param opDefs op name to op definition
     */
    public final void registerOpDefs(@NotNull Map<String, ? extends OP_DEF_TYPE> opDefs) {
        Objects.requireNonNull(opDefs, "opDefs");
        opDefs.forEach(this::registerInputFrameworkOpDef);
    }

    /**
     * Registers (or replaces) an nd4j op descriptor.
     *
     * @param name  nd4j op name
     * @param opDef descriptor to store under {@code name}
     */
    public final void registerNd4jOpDef(@NotNull String name, @NotNull OpNamespace.OpDescriptor opDef) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(opDef, "opDef");
        this.nd4jOpDefs.put(name, opDef);
    }

    /**
     * Looks up an input-framework op definition.
     *
     * <p>If no definitions are known yet, the framework's full definition list is first loaded
     * from {@link OpDescriptorLoaderHolder}. That load only happens while the map is empty, so it
     * is skipped once anything has been registered.
     *
     * @param name input-framework op name
     * @return the op definition for {@code name}
     * @throws NullPointerException if no definition exists for {@code name}
     */
    @NotNull
    public final OP_DEF_TYPE lookupInputFrameworkOpDef(@NotNull String name) {
        Objects.requireNonNull(name, "name");
        if (this.opDefList.isEmpty()) {
            Map<String, OP_DEF_TYPE> frameworkOpDefs = OpDescriptorLoaderHolder.INSTANCE.listForFramework(this.inputFrameworkName);
            this.opDefList.putAll(frameworkOpDefs);
        }
        OP_DEF_TYPE opDef = this.opDefList.get(name);
        return Objects.requireNonNull(opDef, () -> "No " + this.inputFrameworkName + " op definition found for " + name);
    }

    /**
     * Registers (or replaces) an input-framework op definition.
     *
     * @param name  input-framework op name
     * @param opDef op definition to store under {@code name}
     */
    public final void registerInputFrameworkOpDef(@NotNull String name, @NotNull OP_DEF_TYPE opDef) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(opDef, "opDef");
        this.opDefList.put(name, opDef);
    }

    /**
     * Adds a mapping process for an input-framework op. Processes for the same name are kept in a
     * hash set, so a process equal to one already registered under that name is stored only once.
     *
     * @param inputFrameworkOpName input-framework op name
     * @param processToRegister    mapping process to add
     */
    public final void registerMappingProcess(@NotNull String inputFrameworkOpName, @NotNull MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> processToRegister) {
        Objects.requireNonNull(inputFrameworkOpName, "inputFrameworkOpName");
        Objects.requireNonNull(processToRegister, "processToRegister");
        this.registeredOps.put(inputFrameworkOpName, processToRegister);
    }

    /** Returns whether at least one mapping process is registered for {@code inputFrameworkOpName}. */
    public final boolean hasMappingOpProcess(@NotNull String inputFrameworkOpName) {
        Objects.requireNonNull(inputFrameworkOpName, "inputFrameworkOpName");
        return this.registeredOps.containsKey(inputFrameworkOpName);
    }

    /**
     * Returns a mapping process registered for an input-framework op.
     *
     * <p>When several processes share the name, the one returned is the first in hash-set
     * iteration order, which is unspecified.
     *
     * @param inputFrameworkOpName input-framework op name
     * @return a registered mapping process
     * @throws IllegalArgumentException if no process is registered for the name
     */
    @NotNull
    public final MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> lookupOpMappingProcess(@NotNull String inputFrameworkOpName) {
        Objects.requireNonNull(inputFrameworkOpName, "inputFrameworkOpName");
        if (!this.registeredOps.containsKey(inputFrameworkOpName)) {
            throw new IllegalArgumentException("No import process defined for " + inputFrameworkOpName);
        }
        Collection<MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE>> processes = this.registeredOps.get(inputFrameworkOpName);
        return processes.iterator().next();
    }

    /**
     * Resolves the declaration type of an nd4j op from the descriptor list.
     *
     * @param nd4jOpName nd4j op name passed to {@code findOp}
     * @return the op's declaration type
     */
    @NotNull
    public final OpNamespace.OpDescriptor.OpDeclarationType opTypeForName(@NotNull String nd4jOpName) {
        Objects.requireNonNull(nd4jOpName, "nd4jOpName");
        OpNamespace.OpDescriptor descriptor = IRProtobufExtensionsKt.findOp(this.nd4jOpDescriptors, nd4jOpName);
        return descriptor.getOpDeclarationType();
    }

    /**
     * Parses a protobuf text-format {@code MappingDefinitionSet} file (read as UTF-8, as the text
     * format requires) and registers a mapping process for each declaration.
     *
     * @param mapperDeclarationsFile path of the text-format file
     * @param mappingProcessLoader   factory turning declarations into processes
     * @throws UncheckedIOException if the file cannot be read or parsed
     */
    public final void loadFromFile(@NotNull String mapperDeclarationsFile, @NotNull MappingProcessLoader<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> mappingProcessLoader) {
        Objects.requireNonNull(mapperDeclarationsFile, "mapperDeclarationsFile");
        Objects.requireNonNull(mappingProcessLoader, "mappingProcessLoader");
        MapperNamespace.MappingDefinitionSet.Builder builder = MapperNamespace.MappingDefinitionSet.newBuilder();
        try {
            String text = FileUtils.readFileToString(new File(mapperDeclarationsFile), StandardCharsets.UTF_8);
            // TextFormat.ParseException extends IOException, so parse failures land here too.
            TextFormat.merge(text, builder);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to load mapping definitions from " + mapperDeclarationsFile, e);
        }
        loadFromDefinitions(builder.build(), mappingProcessLoader);
    }

    /**
     * Creates a mapping process for each declaration in {@code mapperDeclarations} and registers it
     * under the declaration's input-framework op name.
     *
     * @param mapperDeclarations   declarations to load
     * @param mappingProcessLoader factory turning declarations into processes
     */
    public final void loadFromDefinitions(@NotNull MapperNamespace.MappingDefinitionSet mapperDeclarations, @NotNull MappingProcessLoader<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> mappingProcessLoader) {
        Objects.requireNonNull(mapperDeclarations, "mapperDeclarations");
        Objects.requireNonNull(mappingProcessLoader, "mappingProcessLoader");
        for (MapperNamespace.MapperDeclaration declaration : mapperDeclarations.getMappingsList()) {
            MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> process = mappingProcessLoader.createProcess(declaration);
            registerMappingProcess(declaration.getInputFrameworkOpName(), process);
        }
    }

    /**
     * Serializes every registered mapping process into one {@code MappingDefinitionSet} and writes
     * its protobuf text form to {@code <inputFrameworkName>-processes.pbtxt}, resolved relative to
     * the current working directory.
     *
     * @throws UncheckedIOException if the file cannot be written
     */
    public final void saveProcessesAndRuleSet() {
        List<MapperNamespace.MapperDeclaration> declarations = new ArrayList<>();
        // Walk key by key so declarations are emitted in the registry's map order.
        for (Collection<MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE>> processes : this.registeredOps.asMap().values()) {
            for (MappingProcess<GRAPH_TYPE, OP_DEF_TYPE, NODE_TYPE, TENSOR_TYPE, ATTRIBUTE_TYPE, ATTRIBUTE_VALUE_TYPE, DATA_TYPE> process : processes) {
                declarations.add(process.serialize());
            }
        }
        MapperNamespace.MappingDefinitionSet definitions = MapperNamespace.MappingDefinitionSet.newBuilder()
                .addAllMappings(declarations)
                .build();
        File outputFile = new File(this.inputFrameworkName + "-processes.pbtxt");
        try {
            FileUtils.write(outputFile, definitions.toString(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write mapping processes to " + outputFile, e);
        }
    }
}
