/*
 * Decompiled with CFR 0.152.
 */
package org.jsoar.kernel.rete;

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.jsoar.kernel.Agent;
import org.jsoar.kernel.DefaultProductionManager;
import org.jsoar.kernel.Production;
import org.jsoar.kernel.ProductionType;
import org.jsoar.kernel.SoarException;
import org.jsoar.kernel.events.ProductionAddedEvent;
import org.jsoar.kernel.learning.rl.ReinforcementLearning;
import org.jsoar.kernel.memory.PreferenceType;
import org.jsoar.kernel.rete.AlphaMemory;
import org.jsoar.kernel.rete.NodeVarNames;
import org.jsoar.kernel.rete.Rete;
import org.jsoar.kernel.rete.ReteNetConstants;
import org.jsoar.kernel.rete.ReteNode;
import org.jsoar.kernel.rete.ReteNodeType;
import org.jsoar.kernel.rete.ReteTest;
import org.jsoar.kernel.rete.VarLocation;
import org.jsoar.kernel.rete.VarNames;
import org.jsoar.kernel.rhs.AbstractRhsValue;
import org.jsoar.kernel.rhs.Action;
import org.jsoar.kernel.rhs.ActionSupport;
import org.jsoar.kernel.rhs.FunctionAction;
import org.jsoar.kernel.rhs.MakeAction;
import org.jsoar.kernel.rhs.ReteLocation;
import org.jsoar.kernel.rhs.RhsFunctionCall;
import org.jsoar.kernel.rhs.RhsSymbolValue;
import org.jsoar.kernel.rhs.RhsValue;
import org.jsoar.kernel.rhs.UnboundVariable;
import org.jsoar.kernel.symbols.DoubleSymbol;
import org.jsoar.kernel.symbols.IntegerSymbol;
import org.jsoar.kernel.symbols.StringSymbol;
import org.jsoar.kernel.symbols.Symbol;
import org.jsoar.kernel.symbols.SymbolFactoryImpl;
import org.jsoar.kernel.symbols.SymbolImpl;
import org.jsoar.kernel.symbols.Variable;
import org.jsoar.util.Arguments;
import org.jsoar.util.adaptables.Adaptables;
import org.jsoar.util.properties.PropertyKey;
import org.jsoar.util.properties.PropertyManager;

/**
 * Reads a compact binary serialization of a rete network and rebuilds it inside an {@link Agent}.
 *
 * <p>{@link #read(InputStream)} consumes, in order:
 * <ol>
 * <li>the UTF magic string {@link #MAGIC_STRING} followed by the int {@link #FORMAT_VERSION};</li>
 * <li>the symbol table: string, variable, integer and double symbols, in that order;</li>
 * <li>the alpha memories;</li>
 * <li>the beta network, as the children of the rete's dummy top node (productions included);</li>
 * <li>agent properties (boolean and integer valued only).</li>
 * </ol>
 *
 * <p>Symbols and alpha memories are referenced in the stream by index into the tables loaded
 * in steps 2 and 3. Slot 0 of each table is reserved and holds {@code null}.
 */
public class ReteNetReader {
    protected static final String MAGIC_STRING = "JSoarCompactReteNet";
    protected static final int FORMAT_VERSION = 1;
    private final Agent context;
    private final SymbolFactoryImpl syms;
    private final Rete rete;
    private final DefaultProductionManager productionManager;
    private final ReinforcementLearning rl;
    /** Symbol table loaded by {@link #readAllSymbols}; index 0 is {@code null}. */
    private List<Symbol> symbolMap;
    /** Alpha memory table loaded by {@link #readAlphaMemories}; index 0 is {@code null}. */
    private List<AlphaMemory> alphaMemories;

    /**
     * Creates a reader that loads into the given agent.
     *
     * @param context the agent to populate; must not be null and must adapt to
     *        {@link SymbolFactoryImpl}, {@link Rete} and {@link ReinforcementLearning}
     */
    protected ReteNetReader(Agent context) {
        Arguments.checkNotNull(context, "context");
        this.context = context;
        this.syms = Adaptables.require(this.getClass(), context, SymbolFactoryImpl.class);
        this.rete = Adaptables.require(this.getClass(), context, Rete.class);
        this.rl = Adaptables.require(this.getClass(), context, ReinforcementLearning.class);
        this.productionManager = (DefaultProductionManager)context.getProductions();
    }

    /**
     * Reads a complete rete network from {@code is} into the agent. The stream is not closed.
     *
     * @param is the serialized rete network
     * @throws IOException if reading from the stream fails
     * @throws SoarException if the magic string or version does not match, or the content is malformed
     */
    public void read(InputStream is) throws IOException, SoarException {
        DataInputStream dis = new DataInputStream(is);
        String magic = dis.readUTF();
        if (!MAGIC_STRING.equals(magic)) {
            throw new SoarException("Input does not appear to be a valid JSoar rete net");
        }
        int version = dis.readInt();
        if (version != FORMAT_VERSION) {
            throw new SoarException(String.format("Unsupported JSoar rete net version. Expected %d, got %d", FORMAT_VERSION, version));
        }
        this.readAllSymbols(dis);
        this.readAlphaMemories(dis);
        this.readChildrenOfNode(dis);
        this.readProperties(dis);
    }

    /**
     * Reads a count followed by (name, value) property records and applies them to the agent.
     * Only {@link Boolean} and {@link Integer} typed properties are supported.
     */
    private void readProperties(DataInputStream dis) throws IOException, SoarException {
        int numProperties = dis.readInt();
        PropertyManager properties = this.context.getProperties();
        for (int i = 0; i < numProperties; ++i) {
            String name = dis.readUTF();
            PropertyKey<?> propertyKey = properties.getKey(name);
            if (propertyKey == null) {
                throw new SoarException("Unknown property " + name);
            }
            if (propertyKey.getType().equals(Boolean.class)) {
                // Safe: the key's declared value type was just checked to be Boolean.
                @SuppressWarnings("unchecked")
                PropertyKey<Boolean> booleanKey = (PropertyKey<Boolean>) propertyKey;
                properties.set(booleanKey, dis.readBoolean());
            } else if (propertyKey.getType().equals(Integer.class)) {
                // Safe: the key's declared value type was just checked to be Integer.
                @SuppressWarnings("unchecked")
                PropertyKey<Integer> integerKey = (PropertyKey<Integer>) propertyKey;
                properties.set(integerKey, dis.readInt());
            } else {
                throw new SoarException(String.format("Unhandled property type \"%s\" for property %s.", propertyKey.getType(), name));
            }
        }
    }

    /** Reads the top-level beta nodes (children of the rete's dummy top node) and their subtrees. */
    private void readChildrenOfNode(DataInputStream dis) throws IOException, SoarException {
        int numNodes = dis.readInt();
        for (int i = 0; i < numNodes; ++i) {
            this.readNodeAndChildren(dis, this.rete.dummy_top_node);
        }
    }

    /**
     * Reads one beta node record, creates the node under {@code parent}, then recursively reads
     * its children. Production (P) nodes also build and register the {@link Production}.
     */
    private void readNodeAndChildren(DataInputStream dis, ReteNode parent) throws IOException, SoarException {
        ReteNodeType type = readEnum(dis, ReteNodeType.class);
        ReteNode New = null;
        // Unhashed node variants carry no left hash location; (-1, -1) is passed for them.
        VarLocation left_hash_loc = new VarLocation(-1, -1);
        switch (type) {
            case MEMORY_BNODE: {
                left_hash_loc = this.readLeftHashLoc(dis);
                // fall through: the rest of the record matches the unhashed variant
            }
            case UNHASHED_MEMORY_BNODE: {
                New = ReteNode.make_new_mem_node(this.rete, parent, type, left_hash_loc);
                break;
            }
            case MP_BNODE: {
                left_hash_loc = this.readLeftHashLoc(dis);
                // fall through: the rest of the record matches the unhashed variant
            }
            case UNHASHED_MP_BNODE: {
                AlphaMemory am = this.readAlphaMemoryReference(dis);
                ReteTest other_tests = this.readTestList(dis);
                boolean left_unlinked_flag = dis.readBoolean();
                New = ReteNode.make_new_mp_node(this.rete, parent, type, left_hash_loc, am, other_tests, left_unlinked_flag);
                break;
            }
            case POSITIVE_BNODE:
            case UNHASHED_POSITIVE_BNODE: {
                AlphaMemory am = this.readAlphaMemoryReference(dis);
                ReteTest other_tests = this.readTestList(dis);
                boolean left_unlinked_flag = dis.readBoolean();
                New = ReteNode.make_new_positive_node(this.rete, parent, type, am, other_tests, left_unlinked_flag);
                break;
            }
            case NEGATIVE_BNODE: {
                left_hash_loc = this.readLeftHashLoc(dis);
                // fall through: the rest of the record matches the unhashed variant
            }
            case UNHASHED_NEGATIVE_BNODE: {
                AlphaMemory am = this.readAlphaMemoryReference(dis);
                ReteTest other_tests = this.readTestList(dis);
                New = ReteNode.make_new_negative_node(this.rete, parent, type, left_hash_loc, am, other_tests);
                break;
            }
            case CN_PARTNER_BNODE: {
                // The NCC's top node is found by walking up this many real parents from the partner's parent.
                int levelsUp = dis.readInt();
                ReteNode ncc_top = parent;
                while (levelsUp-- > 0) {
                    ncc_top = ncc_top.real_parent_node();
                }
                New = ReteNode.make_new_cn_node(this.rete, ncc_top, parent);
                break;
            }
            case P_BNODE: {
                String name = dis.readUTF();
                String doc = dis.readUTF();
                ProductionType prodType = readEnum(dis, ProductionType.class);
                Production.Support declaredSupport = readEnum(dis, Production.Support.class);
                Action actionList = this.readActionList(dis);
                Production prod = Production.newBuilder().name(name).documentation(doc).type(prodType).support(declaredSupport).actions(actionList).build();
                int numUnboundVariables = readCount(dis, "RHS unbound variable");
                this.rete.update_max_rhs_unbound_variables(numUnboundVariables);
                ArrayList<Variable> unboundVars = new ArrayList<Variable>(numUnboundVariables);
                for (int i = 0; i < numUnboundVariables; ++i) {
                    unboundVars.add(this.getSymbol(dis.readInt()).asVariable());
                }
                prod.setRhsUnboundVariables(unboundVars);
                this.rl.addProduction(prod);
                New = ReteNode.make_new_production_node(this.rete, parent, prod);
                boolean hasNodeVariableNames = dis.readBoolean();
                New.b_p().parents_nvn = hasNodeVariableNames ? this.readNodeVarNames(dis, parent, this.symbolMap) : null;
                this.rete.update_node_with_matches_from_above(New);
                this.productionManager.addProductionToNameTypeMaps(prod);
                this.context.getEvents().fireEvent(new ProductionAddedEvent(this.context, prod));
                break;
            }
            default: {
                throw new SoarException("Unhandled ReteNodeType: " + type);
            }
        }
        int numChildren = dis.readInt();
        while (numChildren-- > 0) {
            this.readNodeAndChildren(dis, New);
        }
    }

    /**
     * Reads an alpha memory index, resolves it against the alpha memory table and increments
     * the memory's reference count on behalf of the beta node being built.
     *
     * @throws SoarException if the index is 0 (the reserved null slot) or out of range
     */
    private AlphaMemory readAlphaMemoryReference(DataInputStream dis) throws IOException, SoarException {
        int index = dis.readInt();
        if (index <= 0 || index >= this.alphaMemories.size()) {
            throw new SoarException(String.format("Invalid alpha memory index %d", index));
        }
        AlphaMemory am = this.alphaMemories.get(index);
        ++am.reference_count;
        return am;
    }

    /** Reads a left hash location stored as (field_num, levels_up). */
    private VarLocation readLeftHashLoc(DataInputStream dis) throws IOException {
        int field_num = dis.readInt();
        int levels_up = dis.readInt();
        return new VarLocation(levels_up, field_num);
    }

    /** Reads a count followed by that many actions and links them into a list; returns null when empty. */
    private Action readActionList(DataInputStream dis) throws IOException, SoarException {
        Action first = null;
        Action last = null;
        int count = dis.readInt();
        while (count-- > 0) {
            Action a = this.readAction(dis);
            if (last == null) {
                first = a;
            } else {
                last.next = a;
            }
            last = a;
        }
        if (last != null) {
            last.next = null;
        }
        return first;
    }

    /**
     * Reads a single RHS action: its kind, optional preference type, support, and then either a
     * function call or the id/attr/value (and, for binary preferences, referent) of a make action.
     */
    private Action readAction(DataInputStream dis) throws IOException, SoarException {
        Action a;
        int type = dis.readInt();
        ReteNetConstants.Action actionType = ReteNetConstants.Action.fromOrdinal(type);
        if (actionType == ReteNetConstants.Action.MAKE_ACTION) {
            a = new MakeAction();
        } else if (actionType == ReteNetConstants.Action.FUNCALL_ACTION) {
            a = new FunctionAction(null);
        } else {
            throw new SoarException(String.format("Unknown Action type %d.", type));
        }
        boolean hasPreferenceType = dis.readBoolean();
        a.preference_type = hasPreferenceType ? readEnum(dis, PreferenceType.class) : null;
        a.support = readEnum(dis, ActionSupport.class);
        if (actionType == ReteNetConstants.Action.FUNCALL_ACTION) {
            FunctionAction fa = a.asFunctionAction();
            fa.call = this.readRHSValue(dis).asFunctionCall();
        } else {
            MakeAction ma = a.asMakeAction();
            ma.id = this.readRHSValue(dis);
            ma.attr = this.readRHSValue(dis);
            ma.value = this.readRHSValue(dis);
            // A referent is only serialized for binary preferences.
            ma.referent = a.preference_type != null && a.preference_type.isBinary() ? this.readRHSValue(dis) : null;
        }
        return a;
    }

    /**
     * Reads one RHS value: a symbol, a (possibly nested) function call, a rete location, or an
     * unbound variable index. Function calls naming an unregistered RHS function are loaded
     * anyway, with a warning printed.
     */
    private RhsValue readRHSValue(DataInputStream dis) throws IOException, SoarException {
        AbstractRhsValue rv;
        int type = dis.readInt();
        ReteNetConstants.RHS rhsType = ReteNetConstants.RHS.fromOrdinal(type);
        if (rhsType == null) {
            throw new SoarException("Unhandled RHS type: " + type);
        }
        switch (rhsType) {
            case RHS_SYMBOL: {
                SymbolImpl sym = this.getSymbol(dis.readInt());
                rv = new RhsSymbolValue(sym);
                break;
            }
            case RHS_FUNCALL: {
                SymbolImpl sym = this.getSymbol(dis.readInt());
                boolean isStandalone = dis.readBoolean();
                if (sym == null || sym.asString() == null) {
                    throw new SoarException("RHS function name must be a string symbol, got " + sym);
                }
                String functionName = sym.asString().getValue();
                if (this.context.getRhsFunctions().getHandler(functionName) == null) {
                    this.context.getPrinter().warn("\nWARNING: Loaded a rete network that references undefined RHS function %s\n", functionName);
                }
                RhsFunctionCall funCall = new RhsFunctionCall(sym.asString(), isStandalone);
                int count = dis.readInt();
                while (count-- > 0) {
                    funCall.addArgument(this.readRHSValue(dis));
                }
                rv = funCall;
                break;
            }
            case RHS_RETELOC: {
                int field_num = dis.readInt();
                int levels_up = dis.readInt();
                rv = ReteLocation.create(field_num, levels_up);
                break;
            }
            case RHS_UNBOUND_VAR: {
                int index = dis.readInt();
                this.rete.update_max_rhs_unbound_variables(index + 1);
                rv = UnboundVariable.create(index);
                break;
            }
            default: {
                throw new SoarException("Unhandled RHS type: " + type);
            }
        }
        return rv;
    }

    /** Reads a count followed by that many rete tests and links them into a list; returns null when empty. */
    private ReteTest readTestList(DataInputStream dis) throws IOException, SoarException {
        ReteTest first = null;
        ReteTest last = null;
        int count = dis.readInt();
        while (count-- > 0) {
            ReteTest rt = this.readTest(dis);
            if (last == null) {
                first = rt;
            } else {
                last.next = rt;
            }
            last = rt;
        }
        if (last != null) {
            last.next = null;
        }
        return first;
    }

    /**
     * Reads a single rete test: its raw type code and right field number, followed by a
     * type-specific payload.
     */
    private ReteTest readTest(DataInputStream dis) throws IOException, SoarException {
        int type = dis.readInt();
        int right_field_num = dis.readInt();
        // Throwaway instance used only to classify the raw type code.
        ReteTest probe = new ReteTest(type);
        if (probe.test_is_constant_relational_test()) {
            SymbolImpl sym = this.getSymbol(dis.readInt());
            // Constant relational codes have a zero offset, so the raw code is the relation.
            return ReteTest.createConstantTest(type, right_field_num, sym);
        }
        if (probe.test_is_variable_relational_test()) {
            int field_num = dis.readInt();
            int levels_up = dis.readInt();
            // Variable relational codes are offset by 16; strip it to recover the relation.
            return ReteTest.createVariableTest(type - 16, right_field_num, new VarLocation(levels_up, field_num));
        }
        switch (type) {
            case 32: {
                int count = readCount(dis, "disjunction symbol");
                ArrayList<SymbolImpl> disjuncts = new ArrayList<SymbolImpl>(count);
                while (count-- > 0) {
                    disjuncts.add(this.getSymbol(dis.readInt()));
                }
                return ReteTest.createDisjunctionTest(right_field_num, disjuncts);
            }
            case 48:
                return ReteTest.createGoalIdTest();
            case 49:
                return ReteTest.createImpasseIdTest();
            default:
                throw new SoarException("Unknown test type: " + probe + " (" + type + ")");
        }
    }

    /**
     * Reads the symbol table: lists of string, variable, integer and double symbols, in that
     * order, each created through the agent's symbol factory. Index 0 is reserved for null.
     */
    private void readAllSymbols(DataInputStream dis) throws IOException, SoarException {
        ArrayList<Symbol> result = new ArrayList<Symbol>();
        result.add(null);
        result.addAll(this.readSymbolList(dis, new SymbolReader<StringSymbol>(){

            @Override
            public StringSymbol read(DataInputStream dis) throws IOException {
                return ReteNetReader.this.syms.createString(dis.readUTF());
            }
        }));
        result.addAll(this.readSymbolList(dis, new SymbolReader<Variable>(){

            @Override
            public Variable read(DataInputStream dis) throws IOException {
                return ReteNetReader.this.syms.make_variable(dis.readUTF());
            }
        }));
        result.addAll(this.readSymbolList(dis, new SymbolReader<IntegerSymbol>(){

            @Override
            public IntegerSymbol read(DataInputStream dis) throws IOException {
                return ReteNetReader.this.syms.createInteger(dis.readLong());
            }
        }));
        result.addAll(this.readSymbolList(dis, new SymbolReader<DoubleSymbol>(){

            @Override
            public DoubleSymbol read(DataInputStream dis) throws IOException {
                return ReteNetReader.this.syms.createDouble(dis.readDouble());
            }
        }));
        this.symbolMap = result;
    }

    /** Reads a non-negative size followed by that many symbols decoded by {@code reader}. */
    private <T extends Symbol> List<T> readSymbolList(DataInputStream dis, SymbolReader<T> reader) throws IOException, SoarException {
        int size = dis.readInt();
        if (size < 0) {
            throw new SoarException(String.format("Invalid symbol list size %d", size));
        }
        ArrayList<T> result = new ArrayList<T>(size);
        for (int i = 0; i < size; ++i) {
            result.add(reader.read(dis));
        }
        return result;
    }

    /**
     * Resolves a symbol table index. Index 0 is valid and yields {@code null}.
     *
     * @throws SoarException if the index is outside the table
     */
    private SymbolImpl getSymbol(int index) throws SoarException {
        if (index < 0 || index >= this.symbolMap.size()) {
            throw new SoarException(String.format("Invalid symbol index %d", index));
        }
        return (SymbolImpl)this.symbolMap.get(index);
    }

    /**
     * Reads the alpha memory table. Each record is (id, attr, value, acceptable), with symbols
     * given as symbol indexes; each is resolved via {@code find_or_make_alpha_mem}.
     */
    private void readAlphaMemories(DataInputStream dis) throws IOException, SoarException {
        int count = dis.readInt();
        if (count < 0) {
            throw new SoarException(String.format("Invalid alpha memory list size %d", count));
        }
        ArrayList<AlphaMemory> ams = new ArrayList<AlphaMemory>(count);
        ams.add(null);
        for (int i = 0; i < count; ++i) {
            SymbolImpl id = this.getSymbol(dis.readInt());
            SymbolImpl attr = this.getSymbol(dis.readInt());
            SymbolImpl value = this.getSymbol(dis.readInt());
            boolean acceptable = dis.readBoolean();
            ams.add(this.rete.find_or_make_alpha_mem(id, attr, value, acceptable));
        }
        this.alphaMemories = ams;
    }

    /**
     * Reads a varnames record: null, a single variable (one symbol index), or a counted list of
     * variable symbol indexes.
     */
    private Object readVarNames(DataInputStream dis) throws SoarException, IOException {
        int type = dis.readInt();
        ReteNetConstants.VarName varNameType = ReteNetConstants.VarName.fromOrdinal(type);
        if (varNameType != null) {
            switch (varNameType) {
                case VARNAME_NULL: {
                    return null;
                }
                case VARNAME_ONE_VAR: {
                    int index = dis.readInt();
                    return VarNames.one_var_to_varnames(this.getSymbol(index).asVariable());
                }
                case VARNAME_LIST: {
                    int count = dis.readInt();
                    if (count < 0) {
                        throw new SoarException(String.format("Count of varnames list record must be non-negative, got %d", count));
                    }
                    LinkedList<Variable> vars = new LinkedList<Variable>();
                    for (int i = 0; i < count; ++i) {
                        // Each list entry is its own symbol index in the stream.
                        vars.add(this.getSymbol(dis.readInt()).asVariable());
                    }
                    return VarNames.var_list_to_varnames(vars);
                }
            }
        }
        throw new SoarException(String.format("Invalid varnames record type. Expected 0, 1, or 2, got %d", type));
    }

    /**
     * Reads the node variable names for {@code node} and, recursively, its ancestors up to the
     * dummy top node, which has none.
     *
     * <p>For a CN node, the names for the NCC's subconditions are read first, starting at the
     * partner's parent. The parent chain is then walked up to this node's parent to find the
     * names that apply above the NCC.
     */
    private NodeVarNames readNodeVarNames(DataInputStream dis, ReteNode node, List<Symbol> symbolMap) throws SoarException, IOException {
        if (node.node_type == ReteNodeType.DUMMY_TOP_BNODE) {
            return null;
        }
        if (node.node_type == ReteNodeType.CN_BNODE) {
            ReteNode temp = node.b_cn().partner.parent;
            NodeVarNames bottom_of_subconditions = this.readNodeVarNames(dis, temp, symbolMap);
            NodeVarNames nvn_for_ncc = bottom_of_subconditions;
            while (temp != node.parent) {
                temp = temp.real_parent_node();
                nvn_for_ncc = nvn_for_ncc.parent;
            }
            return NodeVarNames.createForNcc(nvn_for_ncc, bottom_of_subconditions);
        }
        Object id = this.readVarNames(dis);
        Object attr = this.readVarNames(dis);
        Object value = this.readVarNames(dis);
        NodeVarNames parent = this.readNodeVarNames(dis, node.real_parent_node(), symbolMap);
        return NodeVarNames.newInstance(parent, id, attr, value);
    }

    /** Reads a count and rejects negative values as malformed input. */
    private static int readCount(DataInputStream dis, String what) throws IOException, SoarException {
        int count = dis.readInt();
        if (count < 0) {
            throw new SoarException(String.format("Invalid %s count %d", what, count));
        }
        return count;
    }

    /**
     * Reads a UTF enum constant name and resolves it against {@code enumType}. An unknown name
     * becomes a {@link SoarException} rather than an unchecked exception.
     */
    private static <E extends Enum<E>> E readEnum(DataInputStream dis, Class<E> enumType) throws IOException, SoarException {
        String name = dis.readUTF();
        try {
            return Enum.valueOf(enumType, name);
        } catch (IllegalArgumentException e) {
            SoarException ex = new SoarException(String.format("Invalid %s value \"%s\"", enumType.getSimpleName(), name));
            ex.initCause(e);
            throw ex;
        }
    }

    /** Decodes one symbol of type {@code T} from the stream. */
    private static interface SymbolReader<T extends Symbol> {
        public T read(DataInputStream var1) throws IOException;
    }
}

