package hex.segments;

import hex.ModelBuilderTest;
import hex.segments.SegmentModelsBuilder;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import water.DKV;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.parser.BufferedString;

/**
 * Parameterized test of {@link SegmentModelsBuilder}: builds one dummy model per segment of the
 * iris dataset (segmented by the {@code class} column) and checks that every model was trained
 * on exactly the rows of its own segment.
 *
 * <p>The test runs once with the builder's default parallelism ({@code null} parameter) and once
 * with an explicit parallelism of 2.
 */
@RunWith(Parameterized.class)
public class SegmentModelsBuilderTest extends TestUtil {

    /**
     * Parallelism to configure on the segment-models builder; {@code null} leaves the builder's
     * own default untouched. Must be public for JUnit's field injection.
     */
    @Parameterized.Parameter
    public Integer parallelism;

    /**
     * Dummy-model action that reports which segment the model was trained on.
     *
     * <p>It asserts that the {@code class} column of the training frame is constant (i.e. the
     * frame holds a single segment only) and returns the label of that class. The test below
     * expects the returned label to show up in the model output's {@code _msg} field.
     */
    private static class GetSegment extends ModelBuilderTest.DummyAction<GetSegment> {

        @Override
        protected String run(ModelBuilderTest.DummyModelParameters parms) {
            Vec segmentVec = parms.train().vec("class");
            // A per-segment training frame must contain only one level of the segment column.
            Assert.assertTrue(segmentVec.isConst());
            int level = (int) segmentVec.at(0L);
            return segmentVec.domain()[level];
        }
    }

    /** Waits until a cloud of at least one node is available before any test runs. */
    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Supplies the values injected into {@link #parallelism}.
     *
     * @return {@code null} (keep the builder default) and {@code 2} (explicit parallelism)
     */
    @Parameterized.Parameters
    public static Object[] data() {
        return new Object[]{null, 2};
    }

    /**
     * Builds segment models over three explicitly enumerated {@code class} segments and verifies
     * that the resulting frame has one row per segment and that each referenced model recorded
     * the label of the segment it belongs to.
     */
    @Test
    public void buildSegmentModels() {
        try {
            Scope.enter();
            Frame iris = Scope.track(parse_test_file("./smalldata/junit/iris.csv"));

            // Segment definitions: one row per categorical level of "class", deliberately out of order.
            long[] segmentLevels = {2, 0, 1};
            Frame segments = new Frame(Key.make());
            segments.add("class", Vec.makeVec(segmentLevels, iris.vec("class").domain(), Vec.VectorGroup.VG_LEN1.addVec()));
            DKV.put(segments);
            Scope.track_generic(segments);

            ModelBuilderTest.DummyModelParameters parms = new ModelBuilderTest.DummyModelParameters();
            parms._makeModel = true;
            parms._action = new GetSegment();
            parms._train = iris._key;
            parms._response_column = "sepal_wid";

            SegmentModelsBuilder.SegmentModelsParameters segmentParms = new SegmentModelsBuilder.SegmentModelsParameters();
            segmentParms._segments = segments._key;
            if (parallelism != null) {
                segmentParms._parallelism = parallelism;
            }

            SegmentModels segmentModels = new SegmentModelsBuilder(segmentParms, parms).buildSegmentModels().get();
            Scope.track_generic(segmentModels);

            Frame result = segmentModels.toFrame();
            Scope.track(result);
            System.out.println(result.toTwoDimTable());
            Assert.assertEquals(segmentLevels.length, result.numRows());

            // Vec.Reader is an inner class of Vec, hence the qualified instance creation.
            Vec classVec = result.vec("class");
            Vec.Reader classReader = classVec.new Reader();
            Vec modelVec = result.vec("model");
            Vec.Reader modelReader = modelVec.new Reader();

            for (int row = 0; row < segmentLevels.length; row++) {
                String expectedSegment = classVec.domain()[(int) classReader.at(row)];
                String modelId = modelReader.atStr(new BufferedString(), row).toString();
                // Explicit type witness: the target type is not propagated through the chained get().
                ModelBuilderTest.DummyModel model = Key.<ModelBuilderTest.DummyModel>make(modelId).get();
                Assert.assertNotNull(model);
                Assert.assertEquals(expectedSegment, ((ModelBuilderTest.DummyModelOutput) model._output)._msg);
                model.remove();
            }
        } finally {
            // Single cleanup path for both success and failure.
            Scope.exit();
        }
    }
}
