package com.github.aaronshan.functions.string;

import com.github.aaronshan.functions.utils.Failures;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceUtf8;
import io.airlift.slice.Slices;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;

/**
 * Hive UDF computing the Levenshtein (edit) distance between two strings.
 *
 * <p>Distance is measured in Unicode code points: each argument is converted to a UTF-8
 * {@link Slice} and decoded into an array of code points before comparison. Returns {@code NULL}
 * if either argument is {@code NULL}. Fails (via {@code Failures.checkCondition}) when the
 * inputs would require more than {@link #MAX_CELLS} dynamic-programming cells.
 */
@Description(name = "levenshtein_distance", value = "_FUNC_(string, string) - computes Levenshtein distance between two strings.", extended = "Example:\n > select _FUNC_(string, string) from src;")
public class UDFStringLevenshteinDistance extends UDF {
    /** Upper bound on {@code longer.length * (shorter.length - 1)}; guards against huge inputs. */
    private static final long MAX_CELLS = 1000000L;

    // Reused across calls: the same writable instance is returned on every non-null evaluation.
    private LongWritable result = new LongWritable(0);

    /**
     * Computes the Levenshtein distance between {@code text} and {@code text2}.
     *
     * @param text  first string, may be {@code null}
     * @param text2 second string, may be {@code null}
     * @return the edit distance in code points, or {@code null} if either input is {@code null}
     * @throws HiveException if an input is not valid UTF-8 or the inputs are too large
     */
    public LongWritable evaluate(Text text, Text text2) throws HiveException {
        if (text == null || text2 == null) {
            return null;
        }
        int[] longer = castToCodePoints(Slices.utf8Slice(text.toString()));
        int[] shorter = castToCodePoints(Slices.utf8Slice(text2.toString()));

        // Put the longer sequence in the outer loop so the DP row is sized by the shorter one.
        // A temporary is required: assigning the references directly to each other would leave
        // both pointing at the same array, and the result would be 0.
        if (longer.length < shorter.length) {
            int[] tmp = longer;
            longer = shorter;
            shorter = tmp;
        }
        if (shorter.length == 0) {
            this.result.set(longer.length);
            return this.result;
        }
        // Multiply in long arithmetic so very long inputs cannot overflow int and bypass the guard.
        Failures.checkCondition((long) longer.length * (shorter.length - 1) <= MAX_CELLS, "The combined inputs for Levenshtein distance are too large", new Object[0]);
        this.result.set(computeDistance(longer, shorter));
        return this.result;
    }

    /**
     * Single-row Wagner-Fischer dynamic program.
     *
     * <p>After processing {@code outer[0..i]}, {@code distances[j]} holds the edit distance between
     * {@code outer[0..i]} and {@code inner[0..j]} (inclusive prefixes).
     *
     * @param outer code points iterated in the outer loop
     * @param inner code points sizing the DP row; must be non-empty
     * @return the edit distance between {@code outer} and {@code inner}
     */
    private static int computeDistance(int[] outer, int[] inner) {
        int[] distances = new int[inner.length];
        for (int j = 0; j < inner.length; j++) {
            distances[j] = j + 1;
        }
        for (int i = 0; i < outer.length; i++) {
            // Value of the previous row one column to the left (the "diagonal" neighbour).
            int diagonal = distances[0];
            if (outer[i] == inner[0]) {
                distances[0] = i;
            } else {
                // The implicit column-0 entry of the previous row is i.
                distances[0] = Math.min(i, distances[0]) + 1;
            }
            for (int j = 1; j < inner.length; j++) {
                int nextDiagonal = distances[j];
                if (outer[i] == inner[j]) {
                    distances[j] = diagonal;
                } else {
                    distances[j] = Math.min(distances[j - 1], Math.min(diagonal, distances[j])) + 1;
                }
                diagonal = nextDiagonal;
            }
        }
        return distances[inner.length - 1];
    }

    /**
     * Decodes a UTF-8 slice into an array of Unicode code points.
     *
     * @throws HiveException if the slice contains invalid UTF-8
     */
    private static int[] castToCodePoints(Slice slice) throws HiveException {
        int[] codePoints = new int[safeCountCodePoints(slice)];
        int position = 0;
        for (int k = 0; k < codePoints.length; k++) {
            int codePoint = SliceUtf8.getCodePointAt(slice, position);
            codePoints[k] = codePoint;
            // Derive the byte length from the decoded code point rather than decoding a second time.
            position += SliceUtf8.lengthOfCodePoint(codePoint);
        }
        return codePoints;
    }

    /**
     * Counts the code points in a UTF-8 slice, validating the encoding as it goes.
     *
     * @throws HiveException if an invalid UTF-8 sequence is encountered
     */
    private static int safeCountCodePoints(Slice slice) throws HiveException {
        int count = 0;
        int position = 0;
        while (position < slice.length()) {
            int codePoint = SliceUtf8.tryGetCodePointAt(slice, position);
            if (codePoint < 0) {
                throw new HiveException("Invalid UTF-8 encoding in characters: " + slice.toStringUtf8());
            }
            position += SliceUtf8.lengthOfCodePoint(codePoint);
            count++;
        }
        return count;
    }
}
