package edu.cmu.dynet.examples;

import edu.cmu.dynet.internal.ComputationGraph;
import edu.cmu.dynet.internal.Dim;
import edu.cmu.dynet.internal.DynetParams;
import edu.cmu.dynet.internal.Expression;
import edu.cmu.dynet.internal.FloatVector;
import edu.cmu.dynet.internal.LongVector;
import edu.cmu.dynet.internal.Parameter;
import edu.cmu.dynet.internal.ParameterCollection;
import edu.cmu.dynet.internal.SWIGTYPE_p_float;
import edu.cmu.dynet.internal.SimpleSGDTrainer;
import edu.cmu.dynet.internal.dynet_swig;

/* loaded from: input_file:edu/cmu/dynet/examples/XorExample.class */
/**
 * Trains a small two-layer network on the XOR function using DyNet's SWIG-generated Java bindings.
 *
 * <p>The model is {@code prediction = V * tanh(W * x + b) + a}, trained with a
 * {@code SimpleSGDTrainer} to minimize the squared distance between the prediction and the target.
 * Inputs and targets are encoded as {@code -1.0f} / {@code +1.0f}. The computation graph is built
 * once. Each training step only overwrites the native buffers the input nodes were created from,
 * then re-runs forward and backward. The code relies on the input nodes reading those buffers at
 * forward time.
 */
public class XorExample {
    /** Width of the hidden {@code tanh} layer. */
    static final int HIDDEN_SIZE = 8;
    /** Number of passes (epochs) over the XOR truth table. */
    static final int ITERATIONS = 30;
    /** Number of rows in the two-input XOR truth table, i.e. training examples per epoch. */
    private static final int NUM_EXAMPLES = 4;

    /**
     * Builds a {@link Dim} whose extents are the given values, in order.
     *
     * <p>Declared as varargs; existing callers that pass an {@code int[]} still compile unchanged.
     *
     * @param dims extent of each dimension
     * @return a new {@code Dim} with one entry per element of {@code dims}
     */
    static Dim makeDim(int... dims) {
        LongVector extents = new LongVector();
        for (int extent : dims) {
            extents.add(extent);
        }
        return new Dim(extents);
    }

    /**
     * Initializes DyNet, builds the XOR model once, and trains it, printing the mean loss over the
     * four XOR examples after every epoch.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        System.out.println("Running XOR example");
        dynet_swig.initialize(new DynetParams());
        System.out.println("Dynet initialized!");

        ParameterCollection model = new ParameterCollection();
        SimpleSGDTrainer trainer = new SimpleSGDTrainer(model);
        ComputationGraph cg = ComputationGraph.getNew();

        // Creation order is kept identical to the original so parameter initialization is unchanged.
        Parameter hiddenWeightsParam = model.add_parameters(makeDim(HIDDEN_SIZE, 2));
        Parameter hiddenBiasParam = model.add_parameters(makeDim(HIDDEN_SIZE));
        Parameter outputWeightsParam = model.add_parameters(makeDim(1, HIDDEN_SIZE));
        Parameter outputBiasParam = model.add_parameters(makeDim(1));
        Expression hiddenWeights = dynet_swig.parameter(cg, hiddenWeightsParam);
        Expression hiddenBias = dynet_swig.parameter(cg, hiddenBiasParam);
        Expression outputWeights = dynet_swig.parameter(cg, outputWeightsParam);
        Expression outputBias = dynet_swig.parameter(cg, outputBiasParam);

        // Two-element input buffer; rewritten in place for each training example below.
        FloatVector xValues = new FloatVector(2L);
        Expression x = dynet_swig.input(cg, makeDim(2), xValues);

        // Natively allocated scalar holding the target; it must be released with delete_floatp.
        SWIGTYPE_p_float yValue = dynet_swig.new_floatp();
        try {
            dynet_swig.floatp_assign(yValue, 0.0f);

            // Same node-creation order as the original single nested expression.
            Expression hidden =
                    dynet_swig.tanh(dynet_swig.exprPlus(dynet_swig.exprTimes(hiddenWeights, x), hiddenBias));
            Expression prediction =
                    dynet_swig.exprPlus(dynet_swig.exprTimes(outputWeights, hidden), outputBias);
            Expression target = dynet_swig.input(cg, yValue);
            Expression lossExpr = dynet_swig.squared_distance(prediction, target);

            System.out.println();
            System.out.println("Training...");
            for (int iter = 0; iter < ITERATIONS; iter++) {
                float epochLoss = 0.0f;
                for (int example = 0; example < NUM_EXAMPLES; example++) {
                    // Walk the truth table: bit 0 of `example` is x1, bit 1 is x2.
                    boolean x1 = example % 2 > 0;
                    boolean x2 = (example / 2) % 2 > 0;
                    xValues.set(0, x1 ? 1.0f : -1.0f);
                    xValues.set(1, x2 ? 1.0f : -1.0f);
                    dynet_swig.floatp_assign(yValue, x1 != x2 ? 1.0f : -1.0f);
                    epochLoss += dynet_swig.as_scalar(cg.forward(lossExpr));
                    cg.backward(lossExpr);
                    trainer.update();
                }
                trainer.update_epoch();
                System.out.println("iter = " + iter + ", loss = " + (epochLoss / NUM_EXAMPLES));
            }
        } finally {
            // Release the native scalar even if training throws.
            dynet_swig.delete_floatp(yValue);
        }
    }
}
