diff --git a/src/src_sharpSAT/MainSolver/DecisionTree.cpp b/src/src_sharpSAT/MainSolver/DecisionTree.cpp index a9a3f11..1cf6dd5 100644 --- a/src/src_sharpSAT/MainSolver/DecisionTree.cpp +++ b/src/src_sharpSAT/MainSolver/DecisionTree.cpp @@ -803,16 +803,22 @@ void DTNode::smooth(int &num_nodes, CMainSolver &solver, set &literals) { // If the counts are the same, then it is already smooth if (variables.size() != (*it)->numVariables()) { - // Create the new AND child - DTNode* newAnd = new DTNode(DT_AND, num_nodes++); - - (*it)->parentDeleted(this); - - toAdd.insert(newAnd); - toRemove.insert(*it); - - newAnd->addChild(*it, true); - + + DTNode *childAnd; + + if (DT_AND == (*it)->getType()) { + childAnd = *it; + } else { + // Create the new AND child + childAnd = new DTNode(DT_AND, num_nodes++); + cout << "Creating a new and for node " << id << endl; + cout << "New node count: " << num_nodes << endl; + (*it)->parentDeleted(this); + toAdd.insert(childAnd); + toRemove.insert(*it); + childAnd->addChild(*it, true); + } + // Add all of the missing variables set::iterator var_it; for (var_it = variables.begin(); var_it != variables.end(); var_it++) @@ -820,14 +826,12 @@ void DTNode::smooth(int &num_nodes, CMainSolver &solver, set &literals) int var = *var_it; if (!((*it)->hasVariable(var))) { - DTNode* newOr = new DTNode(DT_OR, num_nodes++); - newAnd->addChild(newOr, true); - newOr->addChild(solver.get_lit_node_full(var), true); - newOr->addChild(solver.get_lit_node_full(-1 * var), true); + cout << var << endl; + childAnd->addChild(solver.get_universal_or(var), true); } } // Record the new values - newAnd->smooth(num_nodes, solver, literals); + childAnd->smooth(num_nodes, solver, literals); } } diff --git a/src/src_sharpSAT/MainSolver/MainSolver.cpp b/src/src_sharpSAT/MainSolver/MainSolver.cpp index 825cdd4..a0c2454 100755 --- a/src/src_sharpSAT/MainSolver/MainSolver.cpp +++ b/src/src_sharpSAT/MainSolver/MainSolver.cpp @@ -69,6 +69,10 @@ void CMainSolver::solve(const char *lpstrFileName) //Original: litNodes.push_back(new DTNode(i, true, num_Nodes++)); litNodes.push_back(new DTNode(-1 * i, true, num_Nodes++)); + + universalOrNodes.push_back(new DTNode(DT_OR, num_Nodes++)); + universalOrNodes.back()->addChild(litNodes[litNodes.size()-1]); + universalOrNodes.back()->addChild(litNodes[litNodes.size()-2]); } toSTDOUT("#Vars:" << countAllVars() << endl); diff --git a/src/src_sharpSAT/MainSolver/MainSolver.h b/src/src_sharpSAT/MainSolver/MainSolver.h index 1b2fd09..7c2886e 100644 --- a/src/src_sharpSAT/MainSolver/MainSolver.h +++ b/src/src_sharpSAT/MainSolver/MainSolver.h @@ -59,6 +59,7 @@ class CMainSolver: public CInstanceGraph int num_Nodes; bool enable_DT_recording; vector litNodes; + vector universalOrNodes; vector > dirtyLitNodes; ////-----------//// @@ -231,6 +232,10 @@ class CMainSolver: public CInstanceGraph public: + DTNode * get_universal_or(int var) + { + return universalOrNodes[var]; + } DTNode * get_lit_node(int lit) { if (lit < 0)