package org.tribuo.interop.onnx.extractors;

import ai.onnxruntime.OrtException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Iterator;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.interop.onnx.extractors.BERTFeatureExtractor;

/**
 * Tests for {@link BERTFeatureExtractor}: loading tokenizer configurations from JSON and
 * producing CLS- and MEAN-pooled embeddings with the bundled {@code tinybert.onnx} model.
 */
public class BERTFeatureExtractorTest {

    /** Sentence tokenized and embedded by {@link #testEmbedding()}. */
    private static final String SENTENCE = "It is a truth universally acknowledged,";

    /** Absolute per-element tolerance used when comparing embeddings. */
    private static final double TOLERANCE = 1.0E-7;

    /** Expected CLS-pooled embedding of {@link #SENTENCE} from {@code tinybert.onnx}. */
    private static final double[] EXPECTED_CLS_EMBEDDING = {
        -0.09987275302410126, -0.08381578326225281, -0.17915815114974976, 0.1595402956008911, 0.12995541095733643,
        0.02285454422235489, 0.16443753242492676, -0.05802210792899132, 0.25674450397491455, -0.09596416354179382,
        0.08692581206560135, -0.17145220935344696, 0.05614880844950676, 0.14230673015117645, 0.09240773320198059,
        0.03262120857834816, 0.05173583701252937, 0.3492385447025299, -0.010329307056963444, 0.22916817665100098,
        0.1269291639328003, 0.033620379865169525, 0.12352693825960159, 0.0520106665790081, -0.012766036204993725,
        0.029396483674645424, -0.09637446701526642, 0.1646318882703781, -0.08488218486309052, -0.11151651293039322,
        -0.14075034856796265, -0.1965598613023758, -0.25300613045692444, 0.1736740618944168, 0.19785678386688232,
        -0.07669950276613235, 0.03425660356879234, 0.15457485616207123, 0.005061550531536341, 0.09869188815355301,
        -0.06988175213336945, -0.1692686229944229, -0.03754367679357529, -0.18752126395702362, -0.2161409854888916,
        -0.23712319135665894, 0.03122984990477562, 0.2796807289123535, -0.19152438640594482, -0.16166169941425323
    };

    /**
     * Per-position model outputs for {@link #SENTENCE}. The expected MEAN-pooled embedding is the
     * average of every row except the first and last (presumably the [CLS] and [SEP] positions;
     * 13 rows versus the 11 tokens returned by {@code tokenize}).
     */
    private static final double[][] EXPECTED_TOKEN_EMBEDDINGS = {
        {
            0.15068432688713074, 1.291318655014038, -1.2186375856399536, -0.1044885441660881, -0.4916091561317444,
            0.3050057888031006, 0.02411005087196827, 0.6914452314376831, -0.9399610161781311, 0.27564719319343567,
            0.3189747631549835, -1.7623217105865479, 1.0957914590835571, -0.5502046346664429, 0.9324173927307129,
            -1.5440735816955566, -0.6300070285797119, 0.16870944201946259, -0.23421932756900787, 1.5419358015060425,
            0.8080865144729614, 0.025547025725245476, 0.8078239560127258, -0.393135666847229, 0.7077765464782715,
            -0.9835149645805359, 1.161102294921875, 1.3422735929489136, -0.22637833654880524, 1.3680726289749146,
            -0.5111463069915771, -0.5181847810745239, 1.517228364944458, -1.2042882442474365, 0.8298169374465942,
            -1.6887481212615967, 1.2908772230148315, -0.5774198174476624, 1.1692675352096558, -0.14680930972099304,
            0.8950840830802917, -0.5876469016075134, -0.19654417037963867, 0.5392388105392456, -1.7276521921157837,
            -0.611663281917572, -0.16960109770298004, -0.049318667501211166, -3.06258487701416, 0.8719245195388794
        },
        {
            1.0965361595153809, -1.1880749464035034, 0.3113195300102234, -0.3655802607536316, -1.5703083276748657,
            -1.1889897584915161, 0.7478023171424866, -0.39957600831985474, -1.3990098237991333, -0.8826226592063904,
            0.9648794531822205, -1.6420735120773315, 0.024305380880832672, 0.24259810149669647, 1.059976577758789,
            -1.037889838218689, 0.26438167691230774, 0.6861287355422974, 0.5879555940628052, 0.24698519706726074,
            1.4645506143569946, -0.006917477119714022, 1.4397014379501343, 1.1548478603363037, 0.6714677214622498,
            0.6217436790466309, 0.9137017130851746, 1.1192125082015991, -0.3414439857006073, 0.35929620265960693,
            -0.33042025566101074, 1.302490234375, 0.9483469724655151, -0.8030039668083191, -0.703107476234436,
            -1.240598201751709, 0.8048795461654663, 0.6972590088844299, -0.7481553554534912, -1.8371062278747559,
            0.3635812997817993, 1.581039547920227, -2.5392677783966064, 0.9263530373573303, -1.9384249448776245,
            -0.011463197879493237, 0.07263753563165665, -0.3027907609939575, -0.09641885757446289, -0.10073360800743103
        },
        {
            -1.8009929656982422, 1.7652853727340698, -1.0941085815429688, 0.7871851325035095, 0.4216032326221466,
            -1.0672706365585327, -0.8045699596405029, 0.9037136435508728, -2.143636465072632, -0.15777075290679932,
            0.8275935649871826, -1.1132452487945557, -0.6541622877120972, 0.07551676779985428, 0.05679287016391754,
            -0.2197706550359726, -1.9069557189941406, -0.4351661801338196, 1.4026257991790771, 0.8644767999649048,
            -0.4851606786251068, 1.1354173421859741, 1.2878172397613525, -0.3415570855140686, 0.987838089466095,
            -0.5241823196411133, 1.1697313785552979, 1.0625451803207397, 1.6975505352020264, -0.3230779767036438,
            -0.5761590003967285, 0.603135883808136, -0.2959841191768646, 0.3705970048904419, 0.21842513978481293,
            0.26845696568489075, 1.1077520847320557, 1.4424563646316528, 0.4948478937149048, -1.5140318870544434,
            -0.18938101828098297, -0.7575010657310486, -0.7218572497367859, 0.7249206304550171, -1.6331541538238525,
            -0.8491988182067871, -1.0908783674240112, 1.3952903747558594, 0.2033471167087555, -0.5751500129699707
        },
        {
            0.4679090082645416, 0.07996948063373566, -1.6602343320846558, 0.20126283168792725, 0.06442692875862122,
            -0.08430163562297821, -0.7668660283088684, 0.7862005829811096, -1.517468810081482, 0.6565523147583008,
            0.5232337117195129, -0.4334183931350708, -0.9862611889839172, 0.38712450861930847, 0.0734504759311676,
            0.6105567812919617, 0.2887442409992218, -1.2738916873931885, -0.6189880967140198, -0.43752387166023254,
            0.2172611504793167, -1.9195876121520996, 1.0163580179214478, -0.5466981530189514, 0.42114630341529846,
            2.586289644241333, -1.0503790378570557, 0.5456677079200745, -0.2378668189048767, -0.9922659993171692,
            0.15639954805374146, 0.8844059705734253, 0.9527930617332458, 0.22640955448150635, 1.870503544807434,
            0.43025293946266174, 1.041293978691101, 1.3405168056488037, 0.3268677890300751, -1.6156721115112305,
            -1.0111846923828125, 1.5576460361480713, -1.6189697980880737, 1.4933661222457886, -1.9079337120056152,
            -0.8028690218925476, 0.37441205978393555, -0.5571766495704651, 0.5627382397651672, -0.10420186072587967
        },
        {
            -1.9169484376907349, -2.682704210281372, 0.5323475003242493, 0.323544442653656, 0.5599474310874939,
            0.8109522461891174, 0.874146580696106, 0.8635344505310059, -1.7794417142868042, 0.1830170750617981,
            -0.30078354477882385, -1.4007086753845215, 0.11536423116922379, 2.058515787124634, 1.2980265617370605,
            -0.26386335492134094, -0.9431896209716797, 0.6774272322654724, 0.1838160753250122, 0.6408257484436035,
            -0.28405261039733887, -1.5093880891799927, -0.5268791317939758, -1.393326759338379, 0.32029837369918823,
            1.4354349374771118, 0.17443843185901642, 0.681001603603363, -0.34653395414352417, 0.7144835591316223,
            0.36404064297676086, -0.4427735507488251, 0.9191749095916748, -0.441799134016037, 0.7796847224235535,
            -0.8655410408973694, 1.3878437280654907, -0.8746767640113831, 1.1680442094802856, -0.6320986151695251,
            0.1265067160129547, 1.5239441394805908, 0.06903105229139328, 0.39993590116500854, -0.24691283702850342,
            -1.212052345275879, 1.1660674810409546, -0.46258115768432617, -1.5795762538909912, -0.24556468427181244
        },
        {
            1.169201135635376, -0.6180689930915833, -1.2667185068130493, 0.1621074080467224, -0.825954258441925,
            1.011772871017456, 0.08253408223390579, 0.16792099177837372, -1.3966248035430908, 1.5229543447494507,
            -0.9187943339347839, -1.4296988248825073, -1.2473783493041992, 0.33634138107299805, 1.6199390888214111,
            -0.8765130639076233, -0.4082728326320648, 0.4292406439781189, -0.3038734495639801, 2.0255155686754733E-4,
            -0.10973295569419861, 0.10103811323642731, -0.7566922903060913, 1.201561450958252, 0.04530053585767746,
            1.48433256149292, -1.4958760738372803, 0.29490426182746887, 0.735073447227478, 0.6239333748817444,
            -1.2725467681884766, 0.8728184700012207, -0.7653454542160034, -0.13305020332336426, 1.9468834400177002,
            2.4297072887420654, 1.7113466262817383, -1.1554838418960571, 0.035768039524555206, 0.26373907923698425,
            -0.03857085481286049, -1.030165672302246, -1.9738726615905762, 0.3481074869632721, -0.4252174496650696,
            -0.04754994064569473, 0.6649491190910339, -1.1836568117141724, -0.29658934473991394, 0.7145687341690063
        },
        {
            -0.5735183358192444, 0.005896596238017082, 0.07026499509811401, -1.0516016483306885, 0.40990760922431946,
            -1.519545078277588, -0.12028831243515015, -1.4255272150039673, -0.6732795834541321, -0.5314444303512573,
            0.8614209890365601, -0.18908481299877167, 1.2515850067138672, -0.33477455377578735, 1.0043144226074219,
            0.26335152983665466, -0.7736467719078064, -0.610704779624939, 0.14823174476623535, 0.7582481503486633,
            -0.89017254114151, -0.6541666388511658, 1.7902003526687622, 0.12512031197547913, 0.03991509601473808,
            1.8844478130340576, 0.12046905606985092, -0.7644900679588318, 0.8469402194023132, 0.18888333439826965,
            2.0620357990264893, 1.4848750829696655, 0.49371078610420227, -1.1692612171173096, 0.34266752004623413,
            -1.111390471458435, -1.9194800853729248, 0.303196519613266, 0.20584909617900848, -1.9672980308532715,
            1.3575668334960938, 2.0061933994293213, -0.8416401743888855, -0.019423190504312515, -1.9798663854599,
            -0.2901528477668762, 0.8275874853134155, 0.6406838893890381, -0.22013679146766663, 0.1373300552368164
        },
        {
            -0.5070935487747192, -0.8618849515914917, 0.7598404288291931, -0.7632965445518494, 0.7372108101844788,
            -0.7502579689025879, -0.4673466384410858, 0.03350372612476349, -0.1724679172039032, 0.34873172640800476,
            -0.5511385202407837, -0.8617361783981323, 1.2812525033950806, 1.1977880001068115, 0.422703355550766,
            -2.4478912353515625, 0.27947863936424255, 0.6899591684341431, 0.023330552503466606, 3.2386202812194824,
            1.835152506828308, -1.3643983602523804, 1.940202236175537, -0.5567695498466492, -0.9043402075767517,
            0.2982616722583771, 0.7104737758636475, 0.05262554809451103, 0.8592529296875, -0.3310755491256714,
            0.11795064806938171, 0.5258741974830627, -0.18802325427532196, -0.26788589358329773, 0.6526803970336914,
            -1.298613429069519, -1.7447354793548584, 0.7002922892570496, 0.5828584432601929, -1.5788142681121826,
            0.30589115619659424, 0.20253418385982513, -1.2974342107772827, -0.036622676998376846, -1.264859914779663,
            -0.3871012032032013, -0.4832276999950409, 0.10778568685054779, 0.7777668833732605, 0.40499362349510193
        },
        {
            0.9790793061256409, -0.8327827453613281, 0.5742382407188416, -0.241596519947052, -0.20028072595596313,
            -0.7769615054130554, -1.190406322479248, 0.508434534072876, -2.475158452987671, 0.5680020451545715,
            -1.662092924118042, -1.3921698331832886, 2.20672345161438, -1.5228992700576782, -0.39632880687713623,
            -0.27906090021133423, -0.15907344222068787, 0.3922439217567444, -0.787753701210022, 1.533280611038208,
            1.4667662382125854, 0.3229754567146301, 0.08585838973522186, -1.5718507766723633, 0.4096935987472534,
            0.715343177318573, 0.9958964586257935, 1.2487760782241821, -0.08343572914600372, -1.0164155960083008,
            0.3275226056575775, 0.6343787908554077, -0.7280924320220947, 0.6641810536384583, 1.6676288843154907,
            -1.5158360004425049, 1.9924885034561157, 0.584326446056366, -0.05175713077187538, -0.8921176195144653,
            -0.05891117826104164, 0.4360177516937256, 0.3757384717464447, -0.19090303778648376, -1.121803641319275,
            -0.07946958392858505, 0.7248428463935852, 0.06145211309194565, -0.8474668860435486, 0.598736047744751
        },
        {
            0.13182871043682098, -1.8527439832687378, -0.24037398397922516, -0.4527086615562439, -0.5142684578895569,
            0.578464150428772, 0.6006605625152588, 0.656891942024231, 0.1326184719800949, 0.03020528331398964,
            2.6893508434295654, -2.0991835594177246, -0.8301790356636047, 0.11594317108392715, -0.5993587970733643,
            1.4296510219573975, -0.6532222032546997, 1.4397293329238892, 1.2564067840576172, 1.098759412765503,
            0.2947686016559601, -0.41795819997787476, 1.5545347929000854, 0.52629554271698, -0.21128879487514496,
            0.6580076217651367, -0.5271479487419128, 1.1180450916290283, 0.5557465553283691, 1.3636577129364014,
            -1.094968318939209, 0.9319839477539062, -0.5854425430297852, 0.2713215947151184, 1.3153860569000244,
            -0.935158371925354, -0.6154105067253113, -1.1272748708724976, 1.2313356399536133, -1.4124794006347656,
            -0.03874369338154793, -1.23979914188385, -1.7541760206222534, -0.29308855533599854, -0.5822644233703613,
            0.013161801733076572, -0.06327825039625168, -0.11618737131357193, -1.5397554636001587, -0.19829346239566803
        },
        {
            0.0013332264497876167, 0.8481882810592651, -0.768883466720581, -0.10661126673221588, -0.30787110328674316,
            -0.47086861729621887, 0.15401172637939453, 0.10515493154525757, -1.9277739524841309, -0.8757210373878479,
            -0.4928605854511261, -2.2564198970794678, -0.548955500125885, -0.17002613842487335, 0.2518070638179779,
            -0.4958566725254059, 0.24115560948848724, 0.46547016501426697, -0.18646040558815002, 1.8159278631210327,
            1.3635350465774536, -0.6588341593742371, -1.2847622632980347, 1.0262277126312256, 0.4574947953224182,
            2.02490234375, 0.5681936144828796, 0.5467365384101868, -1.489783525466919, -0.33221685886383057,
            -0.5433152318000793, 1.5180481672286987, 0.21392902731895447, 0.592125654220581, 0.8234230279922485,
            -1.286543607711792, 0.3742519021034241, -0.8470395803451538, 2.015620470046997, -0.266152948141098,
            -0.509074866771698, 2.9326205253601074, -0.3784724473953247, -0.010701287537813187, -0.5838567614555359,
            0.7221097946166992, -0.19488774240016937, -0.41090619564056396, -1.174034833908081, -0.4833768904209137
        },
        {
            -1.1144920587539673, -0.35063648223876953, 1.8217822313308716, -0.48949581384658813, -1.2925562858581543,
            -0.5098785161972046, -0.5128171443939209, 0.1419551968574524, -2.438873767852783, -0.7029894590377808,
            -0.42537757754325867, 1.2728787660598755, -0.9750368595123291, -0.18741849064826965, -1.15322744846344,
            -0.4969038665294647, 0.2774584889411926, 0.23619344830513, 0.6430550217628479, 2.301358938217163,
            -0.10586179792881012, -0.2300921380519867, -0.5464906096458435, 0.46304431557655334, 0.40598222613334656,
            0.3899063467979431, -0.020308958366513252, 1.85163414478302, -0.8247016072273254, -0.18942953646183014,
            -0.8463541865348816, 0.9607068300247192, -0.2659664750099182, 0.9000354409217834, 1.2827112674713135,
            -0.007229707669466734, 1.2139497995376587, 0.7489311695098877, 0.4424838423728943, -1.2714074850082397,
            0.4331207871437073, 1.0890902280807495, -0.4623361825942993, 2.1549410820007324, -1.7559632062911987,
            -0.791874349117279, -0.15250569581985474, -1.5794051885604858, -0.15555831789970398, 0.8239690065383911
        },
        {
            -0.5861448049545288, -0.6640603542327881, -1.0603013038635254, -1.3815442323684692, -2.1465418338775635,
            0.47405195236206055, -0.7235577702522278, 1.4718230962753296, -1.3939424753189087, 0.04069202393293381,
            1.031272530555725, -0.1701117753982544, 0.3608003556728363, -0.6034207940101624, -0.7339737415313721,
            -0.4035419523715973, 0.25463899970054626, -0.7091950178146362, 1.2676637172698975, 0.6636833548545837,
            0.9009960293769836, 0.9830664396286011, 0.5832557082176208, -0.7218805551528931, 0.9672513008117676,
            -1.8420109748840332, 1.5450838804244995, -0.05154690146446228, -1.3365509510040283, 0.3695466220378876,
            0.24466310441493988, 1.4004534482955933, -1.043494701385498, 0.4102145731449127, 1.4667704105377197,
            0.701165497303009, 1.6584479808807373, -0.8489638566970825, 1.3032140731811523, -0.9052810072898865,
            -1.1861495971679688, 1.2016456127166748, -0.8434803485870361, 0.4377703070640564, -0.4729674160480499,
            -1.5361210107803345, 1.1898629665374756, 0.727762758731842, 0.2389160394668579, -0.5299307107925415
        }
    };

    /**
     * Resolves a classpath resource that lives next to this test class into a filesystem path.
     *
     * @param name resource file name, relative to this class's package
     * @return the resource as a {@link Path}
     * @throws URISyntaxException if the resource URL cannot be converted to a URI
     */
    private static Path resourcePath(String name) throws URISyntaxException {
        URL url = BERTFeatureExtractorTest.class.getResource(name);
        // Fail with a clear message instead of an NPE when a fixture is missing.
        Assertions.assertNotNull(url, () -> "Missing test resource: " + name);
        return Paths.get(url.toURI());
    }

    /**
     * Averages {@code rows} element-wise, skipping the first and last row.
     *
     * @param rows at least three rows of equal length
     * @return the mean of rows {@code 1 .. rows.length - 2}
     */
    private static double[] meanOfInteriorRows(double[][] rows) {
        double[] mean = new double[rows[0].length];
        for (int row = 1; row < rows.length - 1; row++) {
            for (int col = 0; col < mean.length; col++) {
                mean[col] += rows[row][col];
            }
        }
        int count = rows.length - 2;
        for (int col = 0; col < mean.length; col++) {
            mean[col] /= count;
        }
        return mean;
    }

    @Test
    public void testTokenizerLoading() throws URISyntaxException, IOException {
        // Cased BERT tokenizer: its token id map must contain exactly the entries of the vocab file.
        List<String> vocab = Files.readAllLines(resourcePath("bert-base-cased-vocab.txt"), StandardCharsets.UTF_8);
        BERTFeatureExtractor.TokenizerConfig cased =
                BERTFeatureExtractor.loadTokenizer(resourcePath("bert-base-cased-tokenizer.json"));
        Assertions.assertEquals(vocab.size(), cased.tokenIDs.size());
        for (String token : vocab) {
            Assertions.assertTrue(cased.tokenIDs.containsKey(token),
                    () -> "Tokenizer is missing vocab entry '" + token + "'");
        }
        Assertions.assertEquals(100, cased.maxInputCharsPerWord);
        Assertions.assertFalse(cased.lowercase);
        Assertions.assertFalse(cased.stripAccents);
        Assertions.assertEquals("[UNK]", cased.unknownToken);
        Assertions.assertEquals("[CLS]", cased.classificationToken);
        Assertions.assertEquals("[SEP]", cased.separatorToken);

        // TinyBERT tokenizer: same special tokens, but lowercases and strips accents.
        BERTFeatureExtractor.TokenizerConfig tiny =
                BERTFeatureExtractor.loadTokenizer(resourcePath("tinybert-tokenizer.json"));
        Assertions.assertEquals(100, tiny.maxInputCharsPerWord);
        Assertions.assertTrue(tiny.lowercase);
        Assertions.assertTrue(tiny.stripAccents);
        Assertions.assertEquals("[UNK]", tiny.unknownToken);
        Assertions.assertEquals("[CLS]", tiny.classificationToken);
        Assertions.assertEquals("[SEP]", tiny.separatorToken);
    }

    @Test
    public void testEmbedding() throws URISyntaxException, OrtException {
        Path tokenizerPath = resourcePath("tinybert-tokenizer.json");
        Path modelPath = resourcePath("tinybert.onnx");
        LabelFactory labelFactory = new LabelFactory();

        // CLS pooling. try-with-resources closes the extractor even when an assertion fails.
        try (BERTFeatureExtractor<Label> extractor = new BERTFeatureExtractor<>(labelFactory, modelPath,
                tokenizerPath, BERTFeatureExtractor.OutputPooling.CLS, 512, false)) {
            List<String> tokens = extractor.tokenize(SENTENCE);
            Assertions.assertEquals(11, tokens.size());
            Assertions.assertArrayEquals(EXPECTED_CLS_EMBEDDING, extractor.extractFeatures(tokens), TOLERANCE);
        }

        // MEAN pooling: expected value is the mean of the interior per-position outputs.
        try (BERTFeatureExtractor<Label> extractor = new BERTFeatureExtractor<>(labelFactory, modelPath,
                tokenizerPath, BERTFeatureExtractor.OutputPooling.MEAN, 512, false)) {
            double[] actual = extractor.extractFeatures(extractor.tokenize(SENTENCE));
            Assertions.assertArrayEquals(meanOfInteriorRows(EXPECTED_TOKEN_EMBEDDINGS), actual, TOLERANCE);
        }
    }
}
