package cc.redberry.core.transformations;

import cc.redberry.core.indexgenerator.IndexGenerator;
import cc.redberry.core.indexmapping.IndexMapping;
import cc.redberry.core.indexmapping.IndexMappingBuffer;
import cc.redberry.core.indexmapping.IndexMappingBufferRecord;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.indices.IndicesUtils;
import cc.redberry.core.indices.SimpleIndices;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.Tensors;
import cc.redberry.core.tensor.iterator.TraverseGuide;
import cc.redberry.core.tensor.iterator.TraverseState;
import cc.redberry.core.tensor.iterator.TreeTraverseIterator;
import cc.redberry.core.utils.ArraysUtils;
import cc.redberry.core.utils.IntArrayList;
import cc.redberry.core.utils.TensorUtils;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:cc/redberry/core/transformations/ApplyIndexMapping.class */
/**
 * Transformation that renames the free indices of a tensor according to a {@code from -> to}
 * mapping. Dummy indices of the tensor whose names collide with a target name or with a
 * "forbidden" name are renamed to freshly generated names, so the result stays well-formed.
 */
public final class ApplyIndexMapping implements Transformation {
    private final int[] from;
    private final int[] to;
    private final int[] forbidden;

    /**
     * {@link IndexMapping} backed by two parallel arrays: index names (with type, without state)
     * in {@code from} are mapped to the corresponding entries of {@code to}.
     *
     * <p>{@code from} must be sorted ascending, because lookups use binary search. Indices whose
     * name is not present in {@code from} are returned unchanged.
     */
    public static final class IndexMapper implements IndexMapping {
        private final int[] from;
        private final int[] to;

        /**
         * @param from sorted source names (with type, without state)
         * @param to   target values aligned with {@code from}
         */
        public IndexMapper(int[] from, int[] to) {
            this.from = from;
            this.to = to;
        }

        @Override // cc.redberry.core.indexmapping.IndexMapping
        public int map(int index) {
            int position = Arrays.binarySearch(this.from, IndicesUtils.getNameWithType(index));
            if (position < 0) {
                return index;
            }
            // Combine the original index's raw state with the target value.
            return IndicesUtils.getRawStateInt(index) ^ this.to[position];
        }
    }

    /**
     * @param from      indices to rename; must match the free indices of the transformed tensor
     * @param to        replacement indices, aligned with {@code from}
     * @param forbidden names that dummy indices must not take; {@code null} means none
     */
    public ApplyIndexMapping(int[] from, int[] to, int[] forbidden) {
        this.from = from.clone();
        this.to = to.clone();
        // Defensive copy, consistent with from/to.
        this.forbidden = forbidden == null ? new int[0] : forbidden.clone();
    }

    /**
     * Applies the configured mapping to {@code tensor}.
     *
     * @throws IllegalArgumentException if the configured {@code from} indices are not the free
     *                                  indices of {@code tensor}, or {@code from}/{@code to} differ
     *                                  in length
     */
    @Override // cc.redberry.core.transformations.Transformation
    public Tensor transform(Tensor tensor) {
        // applyIndexMapping performs the consistency check itself.
        return applyIndexMapping(tensor, this.from, this.to, this.forbidden);
    }

    /**
     * Verifies that {@code from} contains exactly the free indices of {@code tensor}, compared
     * as multisets of raw index values (including state).
     */
    private static void checkConsistent(Tensor tensor, int[] from) {
        int[] freeIndices = tensor.getIndices().getFreeIndices().getAllIndices().copy();
        Arrays.sort(freeIndices);
        int[] sortedFrom = from.clone();
        Arrays.sort(sortedFrom);
        if (!Arrays.equals(freeIndices, sortedFrom)) {
            throw new IllegalArgumentException("From indices are not equal to free indices of tensor.");
        }
    }

    /**
     * Returns the names (with type, without state) of the free indices of {@code tensor},
     * sorted ascending.
     */
    private static int[] sortedFreeIndexNames(Tensor tensor) {
        int[] names = tensor.getIndices().getFreeIndices().getAllIndices().copy();
        for (int i = names.length - 1; i >= 0; --i) {
            names[i] = IndicesUtils.getNameWithType(names[i]);
        }
        Arrays.sort(names);
        return names;
    }

    /**
     * Renames the free indices of {@code tensor} from {@code from} to {@code to}. Clashing dummy
     * indices are renamed as well.
     *
     * @param tensor    tensor to transform
     * @param from      exactly the free indices of {@code tensor} (any order); not modified
     * @param to        replacement indices aligned with {@code from}; not modified
     * @param forbidden names that dummy indices must avoid; {@code null} means none
     * @return the transformed tensor
     * @throws IllegalArgumentException if {@code from} and {@code to} differ in length, or
     *                                  {@code from} is not the set of free indices of the tensor
     */
    public static Tensor applyIndexMapping(Tensor tensor, int[] from, int[] to, int[] forbidden) {
        if (from.length != to.length) {
            throw new IllegalArgumentException("From and to arrays have different lengths.");
        }
        checkConsistent(tensor, from);
        return unsafeApplyIndexMappingFromClonedSource(tensor, from.clone(), to.clone(), forbidden);
    }

    /** Same as {@link #applyIndexMapping(Tensor, IndexMappingBuffer, int[])} with no forbidden indices. */
    public static Tensor applyIndexMapping(Tensor tensor, IndexMappingBuffer indexMappingBuffer) {
        return applyIndexMapping(tensor, indexMappingBuffer, new int[0]);
    }

    /**
     * Applies the mapping stored in {@code indexMappingBuffer} to {@code tensor}. The result is
     * negated when the buffer's signum is set.
     *
     * @param forbidden names that dummy indices must avoid; {@code null} means none
     * @throws NullPointerException     if {@code indexMappingBuffer} is {@code null}
     * @throws IllegalArgumentException if the buffer's source names are not the free index names
     *                                  of {@code tensor}
     */
    public static Tensor applyIndexMapping(Tensor tensor, IndexMappingBuffer indexMappingBuffer, int[] forbidden) {
        if (indexMappingBuffer == null) {
            throw new NullPointerException("Buffer is null.");
        }
        Map<Integer, IndexMappingBufferRecord> map = indexMappingBuffer.getMap();
        int[] from = new int[map.size()];
        int[] to = new int[map.size()];
        int i = 0;
        for (Map.Entry<Integer, IndexMappingBufferRecord> entry : map.entrySet()) {
            IndexMappingBufferRecord record = entry.getValue();
            from[i] = entry.getKey();
            // Integer.MIN_VALUE sets only the sign bit. IndexMapper XORs the target with the
            // source's raw state, so this presumably flips the state when the record reports
            // differing states.
            to[i] = record.getIndexName() ^ (record.diffStatesInitialized() ? Integer.MIN_VALUE : 0);
            ++i;
        }
        // Map iteration order is unspecified. The helper binary-searches 'from', so sort it and
        // keep 'to' aligned.
        ArraysUtils.quickSort(from, to);
        if (!Arrays.equals(sortedFreeIndexNames(tensor), from)) {
            throw new IllegalArgumentException("From indices are not equal to free indices of tensor.");
        }
        Tensor result = unsafeApplyIndexMappingFromSortedClonedPreparedSource(tensor, from, to, forbidden);
        return indexMappingBuffer.getSignum() ? Tensors.negate(result) : result;
    }

    /** Renames dummy indices of {@code tensor} that clash with {@code forbidden}; the input array is not modified. */
    public static Tensor renameDummy(Tensor tensor, int[] forbidden) {
        return renameDummyFromClonedSource(tensor, forbidden.clone());
    }

    /**
     * Renames dummy indices of {@code tensor} that clash with {@code forbidden} or with the
     * tensor's free index names. Free indices are kept unchanged (identity mapping).
     */
    public static Tensor renameDummyFromClonedSource(Tensor tensor, int[] forbidden) {
        int[] freeNames = sortedFreeIndexNames(tensor);
        return unsafeApplyIndexMappingFromSortedClonedPreparedSource(tensor, freeNames, freeNames, forbidden);
    }

    /**
     * Strips the raw state from each {@code from} entry and XORs the same state into the aligned
     * {@code to} entry. Then it sorts both arrays by {@code from}. Both arrays are modified in
     * place, so callers must pass copies.
     */
    private static Tensor unsafeApplyIndexMappingFromClonedSource(Tensor tensor, int[] from, int[] to, int[] forbidden) {
        for (int i = from.length - 1; i >= 0; --i) {
            int rawState = IndicesUtils.getRawStateInt(from[i]);
            from[i] ^= rawState;
            to[i] ^= rawState;
        }
        ArraysUtils.quickSort(from, to);
        return unsafeApplyIndexMappingFromSortedClonedPreparedSource(tensor, from, to, forbidden);
    }

    /**
     * Core routine. Expects {@code from} sorted ascending and holding names without state, with
     * {@code to} aligned to it.
     *
     * <p>Builds the full renaming and applies it:
     * <ul>
     *   <li>{@code from -> to};</li>
     *   <li>each dummy name that collides with a {@code to} or {@code forbidden} name goes to a
     *       freshly generated name.</li>
     * </ul>
     */
    private static Tensor unsafeApplyIndexMappingFromSortedClonedPreparedSource(Tensor tensor, int[] from, int[] to, int[] forbidden) {
        if (forbidden == null) {
            forbidden = new int[0];
        }

        // Dummy index names: all index names of the tensor minus its free index names.
        Set<Integer> dummyNames = TensorUtils.getAllIndicesNames(tensor);
        Indices freeIndices = tensor.getIndices().getFreeIndices();
        for (int i = freeIndices.size() - 1; i >= 0; --i) {
            dummyNames.remove(Integer.valueOf(IndicesUtils.getNameWithType(freeIndices.get(i))));
        }

        // Names dummies must not collide with: mapping targets plus forbidden names (sorted).
        int[] occupied = new int[to.length + forbidden.length];
        System.arraycopy(to, 0, occupied, 0, to.length);
        System.arraycopy(forbidden, 0, occupied, to.length, forbidden.length);
        for (int i = occupied.length - 1; i >= 0; --i) {
            occupied[i] = IndicesUtils.getNameWithType(occupied[i]);
        }
        Arrays.sort(occupied);

        // The generator is seeded with occupied names and existing dummy names so that freshly
        // generated names avoid both.
        int[] reserved = Arrays.copyOf(occupied, occupied.length + dummyNames.size());
        int position = occupied.length;
        for (Integer dummy : dummyNames) {
            reserved[position++] = dummy;
        }
        IndexGenerator generator = new IndexGenerator(reserved);

        IntArrayList mappedFrom = new IntArrayList(from.length);
        IntArrayList mappedTo = new IntArrayList(to.length);
        mappedFrom.addAll(from);
        mappedTo.addAll(to);
        for (Integer dummy : dummyNames) {
            int name = dummy;
            if (Arrays.binarySearch(occupied, name) >= 0 && Arrays.binarySearch(from, name) < 0) {
                mappedFrom.add(name);
                mappedTo.add(generator.generate(IndicesUtils.getType(name)));
            }
        }

        int[] fromArray = mappedFrom.toArray();
        int[] toArray = mappedTo.toArray();
        // IndexMapper binary-searches its 'from' array.
        ArraysUtils.quickSort(fromArray, toArray);
        return applyIndexMapping(tensor, new IndexMapper(fromArray, toArray));
    }

    /**
     * Applies {@code indexMapper} to every simple tensor reachable from {@code tensor}.
     * Traversal does not descend into function or field arguments
     * ({@link TraverseGuide#EXCEPT_FUNCTIONS_AND_FIELDS}).
     *
     * @return the tensor with all visited simple tensors re-indexed
     */
    public static Tensor applyIndexMapping(Tensor tensor, IndexMapper indexMapper) {
        TreeTraverseIterator iterator = new TreeTraverseIterator(tensor, TraverseGuide.EXCEPT_FUNCTIONS_AND_FIELDS);
        TraverseState state;
        while ((state = iterator.next()) != null) {
            if (state == TraverseState.Leaving || !(iterator.current() instanceof SimpleTensor)) {
                continue;
            }
            SimpleTensor simpleTensor = (SimpleTensor) iterator.current();
            SimpleIndices oldIndices = simpleTensor.getIndices();
            SimpleIndices newIndices = oldIndices.applyIndexMapping((IndexMapping) indexMapper);
            // Reference comparison: presumably applyIndexMapping returns the same instance when
            // no index changed, in which case the node is left untouched.
            if (oldIndices == newIndices) {
                continue;
            }
            if (simpleTensor instanceof TensorField) {
                iterator.set(Tensors.setIndicesToField((TensorField) simpleTensor, newIndices));
            } else {
                iterator.set(Tensors.setIndicesToSimpleTensor(simpleTensor, newIndices));
            }
        }
        return iterator.result();
    }
}
