package com.gengoai.apollo.ml.model.embedding;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.collection.Sets;
import com.gengoai.collection.multimap.HashSetMultimap;
import com.gengoai.collection.multimap.Multimap;
import com.gengoai.io.resource.Resource;
import com.gengoai.math.Math2;
import com.gengoai.string.Strings;
import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/embedding/FaruquiRetrofitting.class */
/**
 * Retrofits word vectors toward a semantic lexicon using the iterative update of Faruqui et al. (2015),
 * "Retrofitting Word Vectors to Semantic Lexicons".
 * <p>
 * Every input vector is first unit-normalized. Then, for {@code iterations} passes, each vocabulary word that
 * is also a lexicon key and has at least one in-vocabulary neighbor is replaced by
 * {@code (n * start + sum(current neighbor vectors)) / (2 * n)}, where {@code start} is its unit-normalized
 * input vector and {@code n} its number of in-vocabulary neighbors. Updates within a pass are visible to later
 * words in the same pass. The result is a new {@link PreTrainedWordEmbedding} whose vectors are
 * unit-normalized again.
 * <p>
 * The lexicon starts empty; until {@link #setLexicon(Resource)} is called, {@link #apply(WordEmbedding)} only
 * unit-normalizes the input vectors.
 */
public class FaruquiRetrofitting implements Retrofitting {
    private static final long serialVersionUID = 1L;

    /** Number of update passes performed by {@link #apply(WordEmbedding)}. */
    private final int iterations;

    /** Normalized head word -> normalized neighbor words, filled by {@link #setLexicon(Resource)}. */
    private final Multimap<String, String> lexicon;

    /** Creates a retrofitter that performs 25 update passes. */
    public FaruquiRetrofitting() {
        this(25);
    }

    /**
     * Creates a retrofitter that performs the given number of update passes.
     *
     * @param iterations number of passes; a value {@code <= 0} performs no updates
     */
    public FaruquiRetrofitting(int iterations) {
        this.lexicon = new HashSetMultimap<>();
        this.iterations = iterations;
    }

    /**
     * Retrofits the given embedding to the current lexicon.
     *
     * @param origVectors the embedding to retrofit
     * @return a new {@link PreTrainedWordEmbedding} with the same alphabet and dimension as {@code origVectors},
     *         holding unit-normalized vectors; words not covered by the lexicon keep their unit-normalized
     *         input vector
     * @throws NullPointerException if {@code origVectors} is null
     */
    public WordEmbedding apply(@NonNull WordEmbedding origVectors) {
        if (origVectors == null) {
            throw new NullPointerException("origVectors is marked non-null but is null");
        }
        Set<String> vocabulary = new HashSet<>(origVectors.getAlphabet());
        Set<String> retrofitTerms = Sets.intersection(vocabulary, this.lexicon.keySet());

        // `start` holds the fixed unit-normalized input vectors (the "data estimate");
        // `retrofitted` holds the vectors being updated across passes.
        Map<String, NDArray> start = new HashMap<>();
        Map<String, NDArray> retrofitted = new HashMap<>();
        for (String word : vocabulary) {
            NDArray unit = origVectors.embed(word).unitize();
            retrofitted.put(word, unit);
            // NOTE(review): m1copy() looks like a decompiler-renamed copy(); confirm against NDArray.
            start.put(word, unit.m1copy());
        }

        for (int pass = 0; pass < this.iterations; pass++) {
            for (String term : retrofitTerms) {
                Set<String> neighbors = Sets.intersection(this.lexicon.get(term), vocabulary);
                int neighborCount = neighbors.size();
                if (neighborCount == 0) {
                    // No in-vocabulary neighbors: keep the current estimate.
                    continue;
                }
                // The start vector is weighted by the neighbor count; each neighbor contributes weight 1.
                NDArray updated = start.get(term).mul(neighborCount);
                for (String neighbor : neighbors) {
                    updated.addi(retrofitted.get(neighbor));
                }
                updated.divi((float) (2.0d * neighborCount));
                retrofitted.put(term, updated);
            }
        }

        PreTrainedWordEmbedding result = new PreTrainedWordEmbedding();
        result.vectorStore = new InMemoryVectorStore(origVectors.dimension());
        for (String word : origVectors.getAlphabet()) {
            int index = result.vectorStore.addOrGetIndex(word);
            if (retrofitted.containsKey(word)) {
                result.vectorStore.updateVector(index, retrofitted.get(word).unitize());
            } else {
                result.vectorStore.updateVector(index, origVectors.embed(word).unitize());
            }
        }
        return result;
    }

    /**
     * Adds the lexicon entries found in {@code resource} to {@code multimap}.
     * <p>
     * Each string passed to the {@code resource.forEach} callback (presumably one line) is lower-cased,
     * trimmed and split on whitespace. The first token is the head word; every following token is added as
     * one of its neighbors. All tokens go through {@link #norm(String)}.
     */
    private void loadLexicon(Resource resource, Multimap<String, String> multimap) throws IOException {
        resource.forEach(line -> {
            String[] tokens = line.toLowerCase().trim().split("\\s+");
            String head = norm(tokens[0]);
            for (int i = 1; i < tokens.length; i++) {
                multimap.put(head, norm(tokens[i]));
            }
        });
    }

    /**
     * Normalizes a lexicon token.
     * <p>
     * The checks are applied in order:
     * <ul>
     *   <li>tokens accepted by {@code Math2.tryParseDouble} become {@code ---num---};</li>
     *   <li>tokens for which {@code Strings.isPunctuation} is true become {@code ---punc---};</li>
     *   <li>anything else is lower-cased, with underscores replaced by spaces.</li>
     * </ul>
     */
    private String norm(String token) {
        if (Math2.tryParseDouble(token) != null) {
            return "---num---";
        }
        if (Strings.isPunctuation(token)) {
            return "---punc---";
        }
        return token.toLowerCase().replace('_', ' ');
    }

    /**
     * Replaces the current lexicon with the entries read from {@code resource}.
     * <p>
     * The existing lexicon is cleared before reading. If reading fails partway, the lexicon is left holding
     * only the entries read so far.
     *
     * @param resource the lexicon source: whitespace-separated "head neighbor1 neighbor2 ..." entries
     * @throws IOException if reading the resource fails
     */
    public void setLexicon(Resource resource) throws IOException {
        this.lexicon.clear();
        loadLexicon(resource, this.lexicon);
    }
}
