package cc.redberry.core.transformations;

import cc.redberry.concurrent.OutputPort;
import cc.redberry.core.context.ContextManager;
import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.indexmapping.IndexMapping;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.number.Complex;
import cc.redberry.core.tensor.Power;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.SumBuilder;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorBuilder;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensor.iterator.TraverseGuide;
import cc.redberry.core.tensor.iterator.TraverseState;
import cc.redberry.core.tensor.iterator.TreeTraverseIterator;
import cc.redberry.core.utils.Indicator;
import cc.redberry.core.utils.TensorUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;

/* loaded from: input_file:cc/redberry/core/transformations/Expand.class */
/**
 * Transformation that multiplies out products containing sums and natural
 * powers of sums.
 *
 * <p>The tensor tree is walked with {@link TraverseGuide#EXCEPT_FUNCTIONS_AND_FIELDS}.
 * Each node is rewritten in the {@link TraverseState#Leaving} state when it is
 * accepted by the configured {@link Indicator}. Multiplication of two sums can
 * optionally be spread over several worker tasks submitted to
 * {@link ContextManager#getExecutorService()}.
 */
public final class Expand implements Transformation {
    /** Shared empty transformation list (a zero-length array is immutable). */
    private static final Transformation[] NO_TRANSFORMATIONS = new Transformation[0];

    /** Selects which products / powers are expanded. */
    private final Indicator<Tensor> indicator;
    /** Number of worker tasks used when multiplying a pair of sums; values {@code <= 1} mean sequential. */
    private final int threads;

    /**
     * Lazily enumerates all pairwise products {@code sum1[a] * sum2[b]} of two sums.
     *
     * <p>The position counter is an {@link AtomicLong}, so several consumers may call
     * {@link #take()} concurrently; each pair is handed out exactly once.
     */
    public static final class ExpandPairPort implements OutputPort<Tensor> {
        private final Tensor sum1;
        private final Tensor sum2;
        private final AtomicLong atomicLong = new AtomicLong();

        public ExpandPairPort(Sum sum, Sum sum2) {
            this.sum1 = sum;
            this.sum2 = sum2;
        }

        /**
         * Returns the next pairwise product, or {@code null} once all
         * {@code size(sum1) * size(sum2)} products have been handed out.
         */
        @Override
        public Tensor take() {
            long position = atomicLong.getAndIncrement();
            int width = sum2.size();
            // Computed in long: the int product overflows for very large sums.
            long total = (long) sum1.size() * width;
            if (position >= total) {
                return null;
            }
            return Tensors.multiply(sum1.get((int) (position / width)), sum2.get((int) (position % width)));
        }
    }

    /**
     * Index mapping that assigns fresh index names (drawn from an {@link IndexGenerator}
     * seeded with the forbidden names) while preserving each index's state bit.
     *
     * <p>Upper and lower occurrences of the same name are keyed by their
     * name-with-type, so they receive the same new name and contractions survive.
     */
    public static final class IndexMapper implements IndexMapping {
        private final IndexGenerator generator;
        /** name-with-type of an original index -> newly generated name. */
        private final Map<Integer, Integer> map;

        public IndexMapper(int[] iArr) {
            this.generator = new IndexGenerator(iArr);
            this.map = new HashMap<>(iArr.length);
        }

        @Override
        public int map(int i) {
            // Use the same (state-free) key for lookup and insertion; the original keyed
            // the insert by the raw index, splitting upper/lower occurrences of one dummy.
            Integer key = IndicesUtils.getNameWithType(i);
            Integer renamed = map.get(key);
            if (renamed == null) {
                renamed = generator.generate(IndicesUtils.getType(i));
                map.put(key, renamed);
            }
            return IndicesUtils.getRawStateInt(i) ^ renamed;
        }

        /**
         * Forgets all previous name assignments but keeps the generator, so the next
         * renamed copy gets names distinct from every earlier copy.
         * NOTE(review): relies on IndexGenerator never re-issuing a name it already
         * generated — the same assumption a single rename already makes.
         */
        void reset() {
            map.clear();
        }
    }

    /** Drains an {@link ExpandPairPort}, transforms each product and adds it to a shared builder. */
    public static final class Worker implements Runnable {
        private final ExpandPairPort epp;
        private final TensorBuilder builder;
        private final Transformation[] transformations;

        public Worker(ExpandPairPort expandPairPort, TensorBuilder tensorBuilder, Transformation[] transformationArr) {
            this.epp = expandPairPort;
            this.builder = tensorBuilder;
            this.transformations = transformationArr;
        }

        @Override
        public void run() {
            for (Tensor term = epp.take(); term != null; term = epp.take()) {
                // Multiplication and transformations run in parallel; only the put is serialized.
                Tensor transformed = applyTransformations(term, transformations);
                // NOTE(review): the builder is shared by all workers and SOURCE does not show it
                // to be thread-safe, so puts are guarded by its monitor.
                synchronized (builder) {
                    builder.put(transformed);
                }
            }
        }
    }

    /**
     * @param indicator selects which products / powers are expanded
     * @param i         number of worker tasks for multiplying pairs of sums ({@code <= 1}: sequential)
     */
    public Expand(Indicator<Tensor> indicator, int i) {
        this.indicator = indicator;
        this.threads = i;
    }

    /** Expands everything, sequentially. */
    public Expand() {
        this.indicator = Indicator.TRUE_INDICATOR;
        this.threads = 1;
    }

    @Override
    public Tensor transform(Tensor tensor) {
        return expand(tensor, indicator, NO_TRANSFORMATIONS, threads);
    }

    public static Tensor expand(Tensor tensor) {
        return expand(tensor, Indicator.TRUE_INDICATOR, NO_TRANSFORMATIONS, 1);
    }

    public static Tensor expand(Tensor tensor, int i) {
        return expand(tensor, Indicator.TRUE_INDICATOR, NO_TRANSFORMATIONS, i);
    }

    public static Tensor expand(Tensor tensor, Transformation... transformationArr) {
        return expand(tensor, Indicator.TRUE_INDICATOR, transformationArr, 1);
    }

    public static Tensor expand(Tensor tensor, Transformation[] transformationArr, int i) {
        return expand(tensor, Indicator.TRUE_INDICATOR, transformationArr, i);
    }

    /**
     * Expands products of sums and natural powers of sums inside {@code tensor}.
     *
     * @param tensor            the expression to expand
     * @param indicator         only nodes for which {@code indicator.is(node)} holds are expanded
     * @param transformationArr applied, in order, to every product term produced
     * @param i                 number of worker tasks for multiplying pairs of sums ({@code <= 1}: sequential)
     * @return the expanded expression
     */
    public static Tensor expand(Tensor tensor, Indicator<Tensor> indicator, Transformation[] transformationArr, int i) {
        TreeTraverseIterator iterator = new TreeTraverseIterator(tensor, TraverseGuide.EXCEPT_FUNCTIONS_AND_FIELDS);
        TraverseState state;
        while ((state = iterator.next()) != null) {
            if (state != TraverseState.Leaving) {
                continue;
            }
            Tensor current = iterator.current();
            if (current instanceof Product && indicator.is(current)) {
                iterator.set(expandProductOfSums(current, indicator, transformationArr, i));
            } else if (current instanceof Power
                    && current.get(0) instanceof Sum
                    && TensorUtils.isNatural(current.get(1))
                    && indicator.is(current)) {
                int exponent = ((Complex) current.get(1)).getReal().intValue();
                iterator.set(expandPower((Sum) current.get(0), exponent, transformationArr, i));
            }
        }
        return iterator.result();
    }

    /**
     * Multiplies out the sums among the factors of {@code product}.
     *
     * <p>Scalar factors (no indices) and indexed factors are handled separately:
     * sums within each group are multiplied together pairwise. The scalar part is
     * distributed over the indexed part only when the latter ends up as a sum.
     * The product is returned unchanged when nothing needed multiplying out.
     */
    private static Tensor expandProductOfSums(Tensor product, Indicator<Tensor> indicator,
                                              Transformation[] transformations, int threads) {
        List<Tensor> scalarFactors = new ArrayList<>();
        List<Tensor> indexedFactors = new ArrayList<>();
        Sum scalarSum = null;
        Sum indexedSum = null;
        boolean multiplied = false;

        for (int k = product.size() - 1; k >= 0; --k) {
            Tensor factor = product.get(k);
            boolean isSum = factor instanceof Sum;
            if (factor.getIndices().size() == 0) {
                if (!isSum) {
                    scalarFactors.add(factor);
                } else if (scalarSum == null) {
                    scalarSum = (Sum) factor;
                } else {
                    Tensor result = expandPairOfSumsConcurrent((Sum) factor, scalarSum, transformations, threads);
                    multiplied = true;
                    if (result instanceof Sum) {
                        scalarSum = (Sum) result;
                    } else {
                        scalarFactors.add(result);
                        scalarSum = null;
                    }
                }
            } else if (!isSum) {
                indexedFactors.add(factor);
            } else if (indexedSum == null) {
                indexedSum = (Sum) factor;
            } else {
                // Products of indexed terms may contain further expandable structure, so recurse.
                Tensor result = expand(expandPairOfSumsConcurrent((Sum) factor, indexedSum, transformations, threads),
                        indicator, transformations, threads);
                multiplied = true;
                if (result instanceof Sum) {
                    indexedSum = (Sum) result;
                } else {
                    indexedFactors.add(result);
                    indexedSum = null;
                }
            }
        }

        if (!multiplied && indexedSum == null && (scalarSum == null || scalarFactors.isEmpty())) {
            return product;
        }

        Tensor scalarPart = Tensors.multiply(scalarFactors.toArray(new Tensor[0]));
        if (scalarSum != null) {
            scalarPart = Tensors.multiplySumElementsOnFactor(scalarSum, scalarPart);
        }
        Tensor indexedPart = Tensors.multiply(indexedFactors.toArray(new Tensor[0]));
        if (indexedSum != null) {
            indexedPart = Tensors.multiplySumElementsOnFactor(indexedSum, indexedPart);
        }
        return indexedPart instanceof Sum
                ? Tensors.multiplySumElementsOnFactorAndExpandScalars((Sum) indexedPart, scalarPart)
                : Tensors.multiply(scalarPart, indexedPart);
    }

    /**
     * Expands {@code base^power} by repeated multiplication with copies of {@code base}.
     * Each copy has its indices renamed to names not used by {@code base} or by any
     * earlier copy, so dummy indices never clash.
     */
    private static Tensor expandPower(Sum base, int power, Transformation[] transformations, int threads) {
        Set<Integer> names = TensorUtils.getAllIndicesNames(base);
        int[] forbidden = new int[names.size()];
        int n = 0;
        for (int name : names) {
            forbidden[n++] = name;
        }
        IndexMapper mapper = new IndexMapper(forbidden);

        Tensor result = base;
        for (int remaining = power - 1; remaining >= 1; --remaining) {
            // Fresh mapping per copy: reusing it would give every copy identical dummies.
            mapper.reset();
            Sum copy = (Sum) renameDummy(base, mapper);
            result = result instanceof Sum
                    ? expandPairOfSumsConcurrent((Sum) result, copy, transformations, threads)
                    : multiplyTermBySum(result, copy, transformations);
        }
        return result;
    }

    /** Multiplies a single (non-sum) {@code factor} onto every term of {@code sum}, transforming each product. */
    private static Tensor multiplyTermBySum(Tensor factor, Sum sum, Transformation[] transformations) {
        SumBuilder builder = new SumBuilder();
        for (int k = 0; k < sum.size(); ++k) {
            builder.put(applyTransformations(Tensors.multiply(factor, sum.get(k)), transformations));
        }
        return builder.build();
    }

    /** Applies {@code mapper} to the indices of every simple tensor / field inside {@code tensor}. */
    private static Tensor renameDummy(Tensor tensor, IndexMapper mapper) {
        TreeTraverseIterator iterator = new TreeTraverseIterator(tensor);
        TraverseState state;
        while ((state = iterator.next()) != null) {
            if (state != TraverseState.Leaving || !(iterator.current() instanceof SimpleTensor)) {
                continue;
            }
            SimpleTensor simple = (SimpleTensor) iterator.current();
            SimpleIndices oldIndices = simple.getIndices();
            SimpleIndices newIndices = oldIndices.applyIndexMapping((IndexMapping) mapper);
            if (newIndices == oldIndices) {
                continue; // same instance returned: nothing to replace
            }
            if (simple instanceof TensorField) {
                iterator.set(Tensors.setIndicesToField((TensorField) simple, newIndices));
            } else {
                iterator.set(Tensors.setIndicesToSimpleTensor(simple, newIndices));
            }
        }
        return iterator.result();
    }

    /** Applies {@code transformations} to {@code tensor} in array order. */
    private static Tensor applyTransformations(Tensor tensor, Transformation[] transformations) {
        Tensor result = tensor;
        for (Transformation transformation : transformations) {
            result = transformation.transform(result);
        }
        return result;
    }

    public static Tensor expandPairOfSums(Sum sum, Sum sum2) {
        return expandPairOfSums(sum, sum2, NO_TRANSFORMATIONS);
    }

    /** Sequentially multiplies out {@code sum * sum2}, transforming every product term. */
    public static Tensor expandPairOfSums(Sum sum, Sum sum2, Transformation[] transformationArr) {
        ExpandPairPort port = new ExpandPairPort(sum, sum2);
        SumBuilder builder = new SumBuilder();
        for (Tensor term = port.take(); term != null; term = port.take()) {
            builder.put(applyTransformations(term, transformationArr));
        }
        return builder.build();
    }

    public static Tensor expandPairOfSumsConcurrent(Sum sum, Sum sum2, int i) {
        return expandPairOfSumsConcurrent(sum, sum2, NO_TRANSFORMATIONS, i);
    }

    /**
     * Multiplies out {@code sum * sum2} using {@code i} worker tasks.
     *
     * @param i number of worker tasks; any value {@code <= 1} falls back to the sequential path
     *          (previously 0 silently produced an empty result and negatives threw)
     * @throws RuntimeException wrapping a worker failure or an interruption (interrupt flag restored)
     */
    public static Tensor expandPairOfSumsConcurrent(Sum sum, Sum sum2, Transformation[] transformationArr, int i) {
        if (i <= 1) {
            return expandPairOfSums(sum, sum2, transformationArr);
        }
        ExpandPairPort port = new ExpandPairPort(sum, sum2);
        SumBuilder builder = new SumBuilder();
        Future<?>[] futures = new Future<?>[i];
        for (int k = 0; k < i; ++k) {
            futures[k] = ContextManager.getExecutorService().submit(new Worker(port, builder, transformationArr));
        }
        try {
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
        return builder.build();
    }
}
