package dev.langchain4j.classification;

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.segment.TextSegment;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.assertj.core.api.WithAssertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:dev/langchain4j/classification/TextClassifierTest.class */
class TextClassifierTest implements WithAssertions {

    /* loaded from: input_file:dev/langchain4j/classification/TextClassifierTest$CatClassifier.class */
    public static class CatClassifier implements TextClassifier<Categories> {
        public List<Categories> classify(String str) {
            HashSet hashSet = new HashSet();
            if (str.contains("cat")) {
                hashSet.add(Categories.CAT);
            }
            if (str.contains("dog")) {
                hashSet.add(Categories.DOG);
            }
            if (str.contains("fish")) {
                hashSet.add(Categories.FISH);
            }
            return new ArrayList(hashSet);
        }
    }

    /* loaded from: input_file:dev/langchain4j/classification/TextClassifierTest$Categories.class */
    public enum Categories {
        CAT,
        DOG,
        FISH
    }

    TextClassifierTest() {
    }

    @Test
    public void test() {
        CatClassifier catClassifier = new CatClassifier();
        assertThat(catClassifier.classify("cat fish")).containsOnly(new Categories[]{Categories.CAT, Categories.FISH});
        assertThat(catClassifier.classify(TextSegment.from("dog fish"))).containsOnly(new Categories[]{Categories.DOG, Categories.FISH});
        assertThat(catClassifier.classify(Document.from("dog cat"))).containsOnly(new Categories[]{Categories.CAT, Categories.DOG});
    }
}
