package gov.sandia.cognition.text.algorithm;

import gov.sandia.cognition.learning.algorithm.minimization.matrix.ConjugateGradientMatrixSolver;
import gov.sandia.cognition.learning.algorithm.semisupervised.valence.MultipartiteValenceMatrix;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.Pair;
import java.lang.Comparable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Spreads valence scores from a set of seeded (user-weighted) terms and documents to
 * every term and document in a corpus, using the term/document occurrence weights as
 * the links between them.
 * <p>
 * Typical use: register documents with {@link #addDocumentTermOccurrences} or
 * {@link #addDocumentTermWeights}, seed known scores with {@link #addWeightedTerm} /
 * {@link #addWeightedDocument}, optionally call {@link #centerWeightsRange()}, and then
 * call {@link #spreadValence()}.
 * <p>
 * The propagation itself is delegated to a {@link MultipartiteValenceMatrix} (terms are
 * partition 0, documents are partition 1). That matrix is solved with a
 * {@link ConjugateGradientMatrixSolver}.
 * <p>
 * Instances keep their state in plain {@link HashMap}s with no synchronization.
 *
 * @param <TermType> the term identifier type; it is sorted to assign matrix indices
 *     deterministically
 * @param <DocIdType> the document identifier type; it is sorted for the same reason
 */
public class ValenceSpreader<TermType extends Comparable<TermType>, DocIdType extends Comparable<DocIdType>> {

    /** Seeded term scores: first = weight (valence), second = trust. */
    private final Map<TermType, Pair<Double, Double>> weightedTerms = new HashMap<>();

    /** Seeded document scores: first = weight (valence), second = trust. */
    private final Map<DocIdType, Pair<Double, Double>> weightedDocuments = new HashMap<>();

    /** Each document's term -> occurrence-weight map. */
    private final Map<DocIdType, Map<TermType, Double>> documents = new HashMap<>();

    /** Tolerance handed to the conjugate-gradient solver. */
    private double tolerance = 1.0E-5d;

    /** Thread count handed to the valence matrix. */
    private int numThreads = 2;

    /**
     * Output of {@link #spreadValence(int)}.
     *
     * @param <TermType> the term identifier type
     * @param <DocIdType> the document identifier type
     */
    public static class Result<TermType, DocIdType> {
        /** Spread score for every term that occurs in at least one registered document. */
        public Map<TermType, Double> termWeights;

        /** Spread score for every registered document. */
        public Map<DocIdType, Double> documentWeights;
    }

    /**
     * Sets the number of threads the valence matrix may use (default 2).
     *
     * @param numThreads the thread count; must be at least 1
     * @throws IllegalArgumentException if {@code numThreads <= 0}
     */
    public void setNumThreads(int numThreads) {
        if (numThreads <= 0) {
            throw new IllegalArgumentException("Unable to set the number of threads to less than 1");
        }
        this.numThreads = numThreads;
    }

    /**
     * Sets the tolerance passed to the conjugate-gradient solver (default 1e-5).
     *
     * @param tolerance the solver tolerance; must be strictly positive
     * @throws IllegalArgumentException if {@code tolerance <= 0}
     */
    public void setIterativeSolverTolerance(double tolerance) {
        if (tolerance <= 0.0d) {
            throw new IllegalArgumentException("Unable to set the tolerance to a value less than or equal to zero.");
        }
        this.tolerance = tolerance;
    }

    /**
     * Seeds a term with the given weight and a trust of 1.0.
     *
     * @param term the term to seed
     * @param weight the term's known valence
     */
    public void addWeightedTerm(TermType term, double weight) {
        addWeightedTerm(term, weight, 1.0d);
    }

    /**
     * Seeds a term with the given weight and trust. This replaces any earlier seed for
     * the same term.
     *
     * @param term the term to seed
     * @param weight the term's known valence
     * @param trust how strongly the weight should be trusted; must be positive
     * @throws IllegalArgumentException if {@code trust <= 0}
     */
    public void addWeightedTerm(TermType term, double weight, double trust) {
        if (trust <= 0.0d) {
            throw new IllegalArgumentException("Trust must be greater than 0.  Input: " + trust);
        }
        this.weightedTerms.put(term, new DefaultInputOutputPair<Double, Double>(weight, trust));
    }

    /**
     * Seeds a document with the given weight and a trust of 1.0.
     *
     * @param docId the document to seed
     * @param weight the document's known valence
     */
    public void addWeightedDocument(DocIdType docId, double weight) {
        addWeightedDocument(docId, weight, 1.0d);
    }

    /**
     * Seeds a document with the given weight and trust. This replaces any earlier seed
     * for the same document. Seeds for documents that are never registered with
     * {@link #addDocumentTermOccurrences} / {@link #addDocumentTermWeights} are ignored
     * by {@link #spreadValence(int)}.
     *
     * @param docId the document to seed
     * @param weight the document's known valence
     * @param trust how strongly the weight should be trusted; must be positive
     * @throws IllegalArgumentException if {@code trust <= 0}
     */
    public void addWeightedDocument(DocIdType docId, double weight, double trust) {
        if (trust <= 0.0d) {
            throw new IllegalArgumentException("Trust must be greater than 0.  Input: " + trust);
        }
        this.weightedDocuments.put(docId, new DefaultInputOutputPair<Double, Double>(weight, trust));
    }

    /**
     * Registers a document as containing each of the given terms with weight 1.0. This
     * replaces any terms previously registered for {@code docId}.
     *
     * @param docId the document identifier
     * @param terms the distinct terms that occur in the document
     */
    public void addDocumentTermOccurrences(DocIdType docId, Set<TermType> terms) {
        Map<TermType, Double> termWeights = new HashMap<>(terms.size());
        for (TermType term : terms) {
            termWeights.put(term, 1.0d);
        }
        this.documents.put(docId, termWeights);
    }

    /**
     * Registers a document with explicit per-term weights. The map is copied, so later
     * changes to the caller's map are not seen. This replaces any terms previously
     * registered for {@code docId}.
     *
     * @param docId the document identifier
     * @param termWeights each term in the document mapped to its weight
     */
    public void addDocumentTermWeights(DocIdType docId, Map<TermType, Double> termWeights) {
        this.documents.put(docId, new HashMap<>(termWeights));
    }

    /**
     * Linearly rescales the weights (first pair element) of {@code map} in place so that
     * the smallest becomes -1 and the largest becomes +1. Trust values are kept as they
     * are. An empty map, or one whose weights are all equal, has no usable range and is
     * left unchanged (the old code produced NaN in that case).
     */
    private static <Type> void centerMap(Map<Type, Pair<Double, Double>> map) {
        if (map.isEmpty()) {
            return;
        }
        // Seed with infinities. Double.MIN_VALUE is the smallest *positive* double, so
        // using it as the initial maximum was wrong for all-negative weights.
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        for (Pair<Double, Double> pair : map.values()) {
            double weight = pair.getFirst();
            min = Math.min(weight, min);
            max = Math.max(weight, max);
        }
        double range = max - min;
        if (!(range > 0.0d)) {
            // Degenerate range: dividing would give an infinite scale and NaN weights.
            return;
        }
        double scale = 2.0d / range;
        for (Map.Entry<Type, Pair<Double, Double>> entry : map.entrySet()) {
            Pair<Double, Double> pair = entry.getValue();
            double centered = ((pair.getFirst() - min) * scale) - 1.0d;
            entry.setValue(new DefaultInputOutputPair<Double, Double>(centered, pair.getSecond()));
        }
    }

    /**
     * Rescales the seeded term weights and the seeded document weights, each set
     * independently, into the range [-1, 1]. See {@link #centerMap}.
     */
    public void centerWeightsRange() {
        centerMap(this.weightedTerms);
        centerMap(this.weightedDocuments);
    }

    /**
     * Spreads valence with the default power of 10.
     *
     * @return the spread term and document scores
     */
    public Result<TermType, DocIdType> spreadValence() {
        return spreadValence(10);
    }

    /**
     * Spreads the seeded valence across all registered terms and documents.
     * <p>
     * Terms that occur in no registered document, and documents that were seeded but
     * never registered, have no node in the graph, so their seeds are ignored.
     *
     * @param power the power argument forwarded to {@link MultipartiteValenceMatrix};
     *     must be positive
     * @return a score for every term occurring in some registered document and for
     *     every registered document
     * @throws IllegalArgumentException if {@code power <= 0}
     */
    public Result<TermType, DocIdType> spreadValence(int power) {
        if (power <= 0) {
            throw new IllegalArgumentException("Unable to work with non-positive power: " + power);
        }

        // Collect every term used by any document. Sort terms and documents so that
        // matrix indices are assigned deterministically.
        Set<TermType> allTerms = new HashSet<>();
        for (Map<TermType, Double> termWeights : this.documents.values()) {
            allTerms.addAll(termWeights.keySet());
        }
        List<TermType> terms = new ArrayList<>(allTerms);
        Collections.sort(terms);
        List<DocIdType> docIds = new ArrayList<>(this.documents.keySet());
        Collections.sort(docIds);

        int numTerms = terms.size();
        int numDocs = docIds.size();
        Map<TermType, Integer> termIndex = buildIndex(terms);
        Map<DocIdType, Integer> docIndex = buildIndex(docIds);

        // Partition 0 holds the terms and partition 1 holds the documents.
        ArrayList<Integer> partitionSizes = new ArrayList<>(2);
        partitionSizes.add(numTerms);
        partitionSizes.add(numDocs);
        MultipartiteValenceMatrix matrix =
            new MultipartiteValenceMatrix(partitionSizes, power, this.numThreads);

        for (int docIdx = 0; docIdx < numDocs; docIdx++) {
            for (Map.Entry<TermType, Double> occurrence : this.documents.get(docIds.get(docIdx)).entrySet()) {
                matrix.addRelationship(0, termIndex.get(occurrence.getKey()).intValue(), 1, docIdx,
                    occurrence.getValue().doubleValue());
            }
        }

        for (Map.Entry<TermType, Pair<Double, Double>> seed : this.weightedTerms.entrySet()) {
            Integer idx = termIndex.get(seed.getKey());
            if (idx != null) {
                matrix.setElementsScore(0, idx.intValue(), seed.getValue().getSecond().doubleValue(),
                    seed.getValue().getFirst().doubleValue());
            }
        }
        for (Map.Entry<DocIdType, Pair<Double, Double>> seed : this.weightedDocuments.entrySet()) {
            // Previously this unboxed a null (NPE) for seeded but unregistered documents.
            Integer idx = docIndex.get(seed.getKey());
            if (idx != null) {
                matrix.setElementsScore(1, idx.intValue(), seed.getValue().getSecond().doubleValue(),
                    seed.getValue().getFirst().doubleValue());
            }
        }

        Vector init = matrix.init();
        ConjugateGradientMatrixSolver solver = new ConjugateGradientMatrixSolver(init, init, this.tolerance);
        // Solution layout: the term scores first, then the document scores.
        Vector scores = solver.learn(matrix).getOutput();

        Result<TermType, DocIdType> result = new Result<>();
        result.termWeights = new HashMap<>(numTerms);
        result.documentWeights = new HashMap<>(numDocs);
        for (int i = 0; i < numTerms; i++) {
            result.termWeights.put(terms.get(i), scores.getElement(i));
        }
        for (int i = 0; i < numDocs; i++) {
            result.documentWeights.put(docIds.get(i), scores.getElement(numTerms + i));
        }
        return result;
    }

    /** Maps each element of {@code items} to its position in the list. */
    private static <T> Map<T, Integer> buildIndex(List<T> items) {
        Map<T, Integer> index = new HashMap<>(items.size());
        for (int i = 0; i < items.size(); i++) {
            index.put(items.get(i), i);
        }
        return index;
    }
}
