/*
 * Decompiled with CFR 0.152.
 */
package chat.octet.model.components.criteria.impl;

import chat.octet.model.LlamaService;
import chat.octet.model.beans.Token;
import chat.octet.model.components.criteria.StoppingCriteria;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang3.math.NumberUtils;

/**
 * Stopping criteria that reports a stop once the tail of the generated token list matches one of a
 * configured set of stop sequences.
 *
 * <p>Each stop word is either an integer literal, taken as a single token id, or arbitrary text,
 * converted to a token-id sequence via {@code LlamaService.tokenize(word, false, true)}.
 */
public class StoppingWordCriteria
implements StoppingCriteria {
    /** Non-empty token-id sequences; generation stops when the generated tokens end with one of them. */
    private final List<int[]> stoppingTokens;

    /**
     * Builds the criteria from the given stop words.
     *
     * @param words stop words; each is an integer token id (e.g. {@code "2"}) or text to tokenize
     * @throws NullPointerException if {@code words} or any of its elements is null
     */
    public StoppingWordCriteria(String ... words) {
        Preconditions.checkNotNull(words, "Stopping words cannot be null");
        this.stoppingTokens = Lists.newArrayListWithCapacity(words.length);
        for (String word : words) {
            Preconditions.checkNotNull(word, "Stopping word cannot be null");
            int[] tokens = toTokens(word);
            // An empty sequence is a suffix of every token list and would stop generation
            // immediately, so such words (e.g. "") are ignored.
            if (tokens != null && tokens.length > 0) {
                this.stoppingTokens.add(tokens);
            }
        }
    }

    /**
     * Converts one stop word into token ids.
     *
     * @param word non-null stop word
     * @return a single-element array for integer literals, otherwise the tokenized text
     */
    private static int[] toTokens(String word) {
        if (NumberUtils.isParsable(word)) {
            try {
                return new int[]{Integer.parseInt(word)};
            } catch (NumberFormatException ignored) {
                // isParsable also accepts decimals ("1.5") and values outside the int range;
                // treat those as plain text instead of failing construction.
            }
        }
        return LlamaService.tokenize(word, false, true);
    }

    /**
     * Checks whether the generated tokens end with any configured stop sequence.
     *
     * @param inputTokenIds not used by this criteria
     * @param scores not used by this criteria
     * @param args expected to hold exactly one element: a {@code List} of {@code Token}
     *     (presumably the tokens generated so far — confirm with callers)
     * @return true if the list ends with a stop sequence; false otherwise, including when
     *     {@code args} does not have exactly one element or that element is null
     * @throws ClassCastException if {@code args[0]} is not a list of {@code Token}
     */
    @Override
    public boolean criteria(@Nullable int[] inputTokenIds, @Nonnull float[] scores, Object ... args) {
        if (args == null || args.length != 1 || args[0] == null) {
            return false;
        }
        List<?> generateTokens = (List<?>) args[0];
        int size = generateTokens.size();
        for (int[] tokens : this.stoppingTokens) {
            if (endsWith(generateTokens, size, tokens)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether the last {@code suffix.length} elements of {@code generated} have exactly
     * the ids in {@code suffix}, without allocating intermediate arrays or streams.
     */
    private static boolean endsWith(List<?> generated, int size, int[] suffix) {
        int offset = size - suffix.length;
        if (offset < 0) {
            return false;
        }
        for (int i = 0; i < suffix.length; i++) {
            if (((Token) generated.get(offset + i)).getId() != suffix[i]) {
                return false;
            }
        }
        return true;
    }
}

