// Gives declared sorts and datatypes canonical names ("SORT<k>", "DT<k>",
// "UNUSED_SORT") and replaces the contents of sort_key_map with a map keyed by
// those new names. Datatype declarations in datatype_blocks are renamed in
// place through dt_name_map.
//
// sort_key_map     in/out: name -> sort; swapped with the rebuilt map at the end.
// datatype_blocks  in/out: datatype declarations whose names are rewritten.
360void Kernel::sort_propagate(std::unordered_map<std::string, node::Sort> &sort_key_map, std::vector<std::vector<parser::Parser::DTTypeDecl>> &datatype_blocks) {
// Replacement for sort_key_map; it is swapped in as the last step.
361 std::unordered_map<std::string, node::Sort> new_sort_key_map;
// Sorts of the datatype declarations that were registered in sort_key_map,
// each mapped to the counter value _cnt.
362 std::unordered_map<node::Sort, size_t> tmp_dts;
// Move every datatype that is present in sort_key_map into tmp_dts and erase
// it from sort_key_map, so only non-datatype names stay in sort_key_map.
364 for (
auto &block : datatype_blocks) {
366 for (
size_t i = 0, isz = block.size(); i < isz; ++i) {
367 auto &td = block.at(i);
368 if (sort_key_map.find(td.name) != sort_key_map.end()) {
369 tmp_dts.emplace(sort_key_map.at(td.name), _cnt);
370 sort_key_map.erase(td.name);
// sort_idx gives each declared sort (isDec()) a dense index in first-seen
// order; graph and reversed_graph receive one empty entry per indexed sort.
375 std::unordered_map<node::Sort, size_t> sort_idx;
376 std::vector<std::vector<size_t>> graph;
377 std::vector<std::vector<std::pair<size_t, size_t>>> reversed_graph;
379 for (
size_t i = 0, isz =
d_nodes.size(); i < isz; ++i) {
380 if (
d_nodes.at(i)->getSort()->isDec()) {
381 if (sort_idx.find(
d_nodes.at(i)->getSort()) == sort_idx.end()) {
// NOTE(review): the stored index is the size *before* insertion only because
// C++17 sequences the right operand of '=' before the left one.
382 sort_idx[
d_nodes.at(i)->getSort()] = sort_idx.size();
383 graph.emplace_back();
384 reversed_graph.emplace_back();
// Child sorts that are declared sorts or datatypes but have no index are
// renamed "UNUSED_SORT". emplace keeps the first sort registered under that
// key; later ones are renamed but not stored in new_sort_key_map.
389 for (
size_t i = 0, isz =
d_nodes.size(); i < isz; ++i) {
390 if (
d_nodes.at(i)->getSort()->children.empty())
392 for (
auto &child_sort :
d_nodes.at(i)->getSort()->children) {
393 if ((child_sort->isDec() || child_sort->isDatatype()) && sort_idx.find(child_sort) == sort_idx.end()) {
394 new_sort_key_map.emplace(
"UNUSED_SORT", child_sort);
395 child_sort->setName(
"UNUSED_SORT");
// Append a copy of every node's adjacency list from d_graph. Copied indices
// are shifted by sort_idx.size() because the sort entries occupy the first
// sort_idx.size() slots of graph.
400 for (
size_t i = 0, isz =
d_nodes.size(); i < isz; ++i) {
401 graph.emplace_back(
d_graph.at(i));
403 for (
auto &child : graph.back()) {
404 child += sort_idx.size();
// Apply the same shift to the first component of the newest reversed_graph entry.
406 for (
auto &[father_index, oper_index] : reversed_graph.back()) {
407 father_index += sort_idx.size();
// A node with a declared sort gets an edge to its sort entry, and the sort
// entry records the (shifted) node index with second component 0.
409 if (
d_nodes.at(i)->getSort()->isDec()) {
410 size_t sort_index = sort_idx.at(
d_nodes.at(i)->getSort());
413 graph.at(i + sort_idx.size()).emplace_back(sort_index);
414 reversed_graph.at(sort_index).emplace_back(i + sort_idx.size(), 0);
// Prepend sort_idx.size() zero entries to hash_table: reverse, grow at the
// back with zeros, reverse again.
419 std::reverse(hash_table.begin(), hash_table.end());
420 hash_table.resize(hash_table.size() + sort_idx.size(), 0);
421 std::reverse(hash_table.begin(), hash_table.end());
// Visit the datatype sorts collected in tmp_dts that also have a sort index.
422 for (
const auto &[sort, index] : tmp_dts) {
423 if (sort_idx.find(sort) != sort_idx.end()) {
// processing = 0, 1, ..., d_graph.size() - 1.
438 std::vector<size_t> processing(
d_graph.size());
439 std::iota(processing.begin(), processing.end(), 0);
// symbols starts with the sort indices 0 .. sort_idx.size() - 1; shifted
// indices of nodes whose sort is not a declared sort are appended after them.
441 std::vector<size_t> symbols(sort_idx.size());
442 std::iota(symbols.begin(), symbols.end(), 0);
443 for (
size_t i = 0, isz =
d_symbols.size(); i < isz; ++i) {
445 if (!
d_nodes.at(idx)->getSort()->isDec()) {
446 symbols.emplace_back(idx + sort_idx.size());
// Read the d_hash_table value of every sort entry (the first sort_idx.size()
// elements of symbols) and sort the values so equal hashes become adjacent.
450 std::vector<size_t> unique_hashes(sort_idx.size());
458 std::transform(symbols.begin(), symbols.begin() + sort_idx.size(), unique_hashes.begin(), [
this](
size_t index) { return d_hash_table.at(index); });
459 std::sort(unique_hashes.begin(), unique_hashes.end());
// Look for a hash value shared by two adjacent entries; hash_value records it.
// is_select presumably flags that such a collision was found (TODO confirm).
464 size_t hash_value = 0;
465 bool is_select =
false;
466 for (
size_t i = 1, isz = unique_hashes.size(); i < isz; ++i)
467 if (unique_hashes[i] == unique_hashes[i - 1]) {
468 hash_value = unique_hashes.at(i);
// Scan the sort entries of symbols; sort_index takes the entry at position i.
474 size_t sort_index = 0;
475 for (
size_t i = 0, isz = sort_idx.size(); i < isz; ++i) {
477 sort_index = symbols.at(i);
// (size_t, sort) pairs, one slot reserved per indexed sort. std::sort orders
// them by the size_t component first (std::pair comparison).
487 std::vector<std::pair<size_t, node::Sort>> sort_hashes;
488 sort_hashes.reserve(sort_idx.size());
489 for (
const auto &[sort, index] : sort_idx) {
493 std::sort(sort_hashes.begin(), sort_hashes.end());
// Hand out canonical names in sorted order. A sort whose current name is no
// longer a key of sort_key_map (datatype names were erased from it above)
// becomes "DT<dt_count>", and dt_name_map remembers old name -> new name.
// The "SORT<sort_count>" assignments presumably form the else branch.
494 std::unordered_map<std::string, std::string> dt_name_map;
495 size_t sort_count = 0, dt_count = 0;
496 for (
size_t i = 0, isz = sort_hashes.size(); i < isz; ++i) {
497 if (sort_key_map.find(sort_hashes.at(i).second->name) == sort_key_map.end()) {
498 dt_name_map[sort_hashes.at(i).second->name] =
"DT" + std::to_string(dt_count);
499 sort_hashes.at(i).second->setName(
"DT" + std::to_string(dt_count));
500 new_sort_key_map[
"DT" + std::to_string(dt_count)] = sort_hashes.at(i).second;
504 sort_hashes.at(i).second->setName(
"SORT" + std::to_string(sort_count));
505 new_sort_key_map[
"SORT" + std::to_string(sort_count)] = sort_hashes.at(i).second;
// Node name -> shifted graph index (node position + sort_idx.size()).
510 std::unordered_map<std::string, size_t> uf_name_id;
511 for (
size_t i = 0, isz =
d_nodes.size(); i < isz; ++i) {
513 uf_name_id[
d_nodes.at(i)->getName()] = i + sort_idx.size();
// Rename the datatype declarations that received a "DT<k>" name, then look up
// their constructors and selectors by name in uf_name_id.
516 for (
auto &block : datatype_blocks) {
517 for (
size_t i = 0, isz = block.size(); i < isz; ++i) {
518 auto &td = block.at(i);
519 if (dt_name_map.find(td.name) != dt_name_map.end()) {
520 td.name = dt_name_map.at(td.name);
521 for (
auto &cd : td.ctors) {
522 if (uf_name_id.find(cd.name) != uf_name_id.end()) {
529 for (
size_t i = 0, isz = cd.selectors.size(); i < isz; ++i) {
530 auto &sd = cd.selectors.at(i);
531 if (uf_name_id.find(sd.name) != uf_name_id.end()) {
// Second pass over the same selectors, with the same uf_name_id lookup.
543 for (
size_t i = 0, isz = cd.selectors.size(); i < isz; ++i) {
544 auto &sd = cd.selectors.at(i);
545 if (uf_name_id.find(sd.name) != uf_name_id.end()) {
// Publish the rebuilt map; the caller's sort_key_map now holds the new names.
556 sort_key_map.swap(new_sort_key_map);
// Both are references into nm, so the swap() calls further down replace
// nm's own function list and name -> node map.
576 std::vector<std::string> &function_names = nm.getFunctionNames();
577 std::unordered_map<std::string, std::shared_ptr<stabilizer::parser::DAGNode>> &function_key_map = nm.getFunKeyMap();
// Number the defined functions (isFuncDef()) in the order they appear in
// function_names.
579 std::unordered_map<std::string, size_t> def_fun_id;
580 size_t def_fun_count = 0;
581 for (
const auto &func_name : function_names) {
582 if (function_key_map.at(func_name)->isFuncDef())
583 def_fun_id[func_name] = def_fun_count++;
// Visit the nodes that are function definitions.
586 for (
size_t i = 0, isz =
d_nodes.size(); i < isz; ++i) {
587 if (
d_nodes.at(i)->isFuncDef()) {
// Taken only when at least one node has a declared or datatype sort.
594 if (std::any_of(
d_nodes.begin(),
d_nodes.end(), [
this](
const node::Node &node) { return node->getSort()->isDec() || node->getSort()->isDatatype(); })) {
// Empty nm's sort-name table and datatype blocks.
599 nm.getSortNames().clear();
600 nm.getDatatypeBlocks().clear();
// Keyed by hash value; presumably an occurrence count per hash (TODO confirm).
604 std::unordered_map<size_t, size_t> hash_count;
// var_names / temp_var_names are references into nm; the new_* maps are the
// replacements that get swapped in after renaming.
627 std::unordered_map<std::string, size_t> &var_names = nm.getVarNames();
628 std::unordered_map<std::string, size_t> new_var_names;
629 std::unordered_map<std::string, size_t> &temp_var_names = nm.getTempVarNames();
630 std::unordered_map<std::string, size_t> new_temp_var_names;
// old variable name -> canonical name
631 std::unordered_map<std::string, std::string> var_names_map;
// old function name -> canonical name
633 std::unordered_map<std::string, std::string> function_names_map;
634 size_t symbol_count = 0, uf_count = 0;
// Uninterpreted functions are renamed "UF<uf_count>".
642 function_names_map[uf_name] =
"UF" + std::to_string(uf_count);
// Variables are renamed "S<symbol_count>".
653 var_names_map[var_name] =
"S" + std::to_string(symbol_count);
// Pairs that are later unpacked as (hash_value, node index).
669 std::vector<std::pair<size_t, size_t>> dt_testers;
// Apply the function renaming to every node whose name has a mapping.
670 for (
size_t i = 0, isz =
d_nodes.size(); i < isz; ++i) {
672 std::string vuf_name =
d_nodes.at(i)->getName();
673 if (function_names_map.find(vuf_name) != function_names_map.end()) {
675 d_nodes.at(i)->setName(function_names_map.at(vuf_name));
// Order dt_testers by their first component only.
684 std::sort(dt_testers.begin(), dt_testers.end(), [](
const auto &a,
const auto &b) {
685 return a.first < b.first;
// Give each entry's node a "UF<uf_count>" name, recording old name -> new name.
688 for (
const auto &[hash_value, idx] : dt_testers) {
692 function_names_map[
d_nodes.at(idx)->getName()] =
"UF" + std::to_string(uf_count);
693 d_nodes.at(idx)->setName(
"UF" + std::to_string(uf_count));
// NOTE(review): plain auto makes new_blocks a copy of nm's blocks, not a
// reference; it is swapped back into nm at the end.
697 auto new_blocks = nm.getDatatypeBlocks();
// Rewrite each datatype block: order its declarations by name, then rename
// constructors and selectors.
699 for (
auto &block : nm.getDatatypeBlocks()) {
700 std::sort(block.begin(), block.end(), [](
const auto &a,
const auto &b) {
701 return a.name < b.name;
703 auto new_block = block;
705 for (
auto &td : block) {
// Declarations whose name starts with neither "DT" nor "SORT".
706 if (!td.name.starts_with(
"DT") && !td.name.starts_with(
"SORT")) {
709 auto new_ctors = td.ctors;
711 for (
auto &cd : td.ctors) {
712 if (function_names_map.find(cd.name) != function_names_map.end()) {
713 cd.name = function_names_map.at(cd.name);
// Copy-then-clear yields an empty container of the selector list's type.
718 auto new_selectors = cd.selectors;
719 new_selectors.clear();
// Selector handling, in order of the visible branches:
//  - name has a mapping: rename it and keep it;
//  - its sort is already named "SORT..." / "DT...": keep it;
//  - its sort is a declared sort or datatype: rename that sort to
//    "UNUSED_SORT", register it in nm's sort names, keep the selector.
720 for (
auto &sd : cd.selectors) {
721 if (function_names_map.find(sd.name) != function_names_map.end()) {
722 sd.name = function_names_map.at(sd.name);
723 new_selectors.emplace_back(sd);
728 else if (sd.sort->name.starts_with(
"SORT") || sd.sort->name.starts_with(
"DT")) {
730 new_selectors.emplace_back(sd);
732 else if (sd.sort->isDec() || sd.sort->isDatatype()) {
733 sd.sort->setName(
"UNUSED_SORT");
734 nm.getSortNames().emplace(
"UNUSED_SORT", sd.sort);
736 new_selectors.emplace_back(sd);
740 new_selectors.emplace_back(sd);
// For a constructor with an empty name, check whether any kept selector
// still has a name before installing the rebuilt selector list.
744 if (cd.name.empty()) {
745 bool selected =
false;
746 for (
const auto &sd : new_selectors) {
747 if (!sd.name.empty()) {
749 cd.selectors.swap(new_selectors);
757 cd.selectors.swap(new_selectors);
// Selectors left without a name get a placeholder "VAR<idx>"; these
// placeholders are replaced by "DT_SEL<idx>" further down.
792 for (
size_t i = 0, isz = cd.selectors.size(); i < isz; ++i) {
793 if (cd.selectors.at(i).name.empty())
794 cd.selectors.at(i).name =
"VAR" + std::to_string(idx++);
799 new_ctors.emplace_back(cd);
// need_two: the declaration had at least two constructors before the swap.
// has_zero / only_zero presumably track whether some / all constructors have
// no selectors (TODO confirm).
801 bool need_two = td.ctors.size() >= 2;
802 bool has_zero =
false;
803 bool only_zero =
true;
804 for (
const auto &cd : td.ctors) {
805 if (!cd.selectors.empty()) {
811 if (has_zero && !only_zero) {
// Install the rebuilt constructors and order them: by name; for equal names
// and equal selector counts by the first differing selector name; the final
// comparison is by selector count.
815 td.ctors.swap(new_ctors);
816 std::sort(td.ctors.begin(), td.ctors.end(), [](
const auto &a,
const auto &b) {
817 if (a.name != b.name)
818 return a.name < b.name;
819 else if (a.selectors.size() == b.selectors.size()) {
820 for (size_t i = 0, isz = a.selectors.size(); i < isz; ++i) {
821 if (a.selectors.at(i).name != b.selectors.at(i).name)
822 return a.selectors.at(i).name < b.selectors.at(i).name;
827 return a.selectors.size() < b.selectors.size();
// Constructors are renamed "CON<con_idx>" in the sorted order.
830 for (
auto &cd : td.ctors) {
832 cd.name =
"CON" + std::to_string(con_idx++);
// Pad with fresh constructors until the declaration has at least one, or at
// least two when need_two is set.
835 while (td.ctors.empty() || (td.ctors.size() < 2 && need_two)) {
836 td.ctors.emplace_back();
838 td.ctors.back().name =
"CON" + std::to_string(con_idx++);
839 else if (td.ctors.empty() || !has_zero || td.ctors.front().selectors.empty()) {
840 td.ctors.back().name =
"CON" + std::to_string(con_idx++);
// This padding constructor gets one selector "<ctor name>_TVAR0" of a newly
// created SK_DEC sort named "UNUSED_SORT", which is registered in nm.
841 td.ctors.back().selectors.emplace_back(td.ctors.back().name +
"_TVAR0", std::make_shared<parser::Sort>(
parser::SORT_KIND::SK_DEC,
"UNUSED_SORT"));
842 td.ctors.back().selectors.back().sort->setName(
"UNUSED_SORT");
843 nm.getSortNames().emplace(
"UNUSED_SORT", td.ctors.back().selectors.back().sort);
846 td.ctors.back().name =
"CON" + std::to_string(con_idx++);
// Replace the "VAR..." placeholder selector names with "DT_SEL<idx>".
850 for (
auto &cd : td.ctors) {
851 for (
auto &sd : cd.selectors) {
852 if (sd.name.starts_with(
"VAR"))
853 sd.name =
"DT_SEL" + std::to_string(idx++);
859 new_block.emplace_back(td);
861 block.swap(new_block);
863 new_blocks.emplace_back(block);
// Publish the rewritten datatype blocks to nm.
869 nm.getDatatypeBlocks().swap(new_blocks);
// Rebuild both variable tables under the canonical names, keeping the stored
// index. The visible branch copies only names that have a mapping.
871 for (
const auto &[name, index] : var_names) {
872 auto itr = var_names_map.find(name);
873 if (itr != var_names_map.end()) {
874 new_var_names[itr->second] = index;
877 for (
const auto &[name, index] : temp_var_names) {
878 auto itr = var_names_map.find(name);
879 if (itr != var_names_map.end()) {
880 new_temp_var_names[itr->second] = index;
884 var_names.swap(new_var_names);
885 temp_var_names.swap(new_temp_var_names);
// node -> position in d_nodes; the comparators below use it to reach
// d_hash_table and d_context_hash.
887 std::unordered_map<node::Node, size_t> node2index;
// Function nodes collected into three lists, renamed FDEC / FREC / FDEF below.
889 std::vector<node::Node> func_dec, func_rec, func_def;
// name -> positions in d_nodes of the nodes carrying that name.
891 std::unordered_map<std::string, std::vector<size_t>> uf_buckets;
892 std::unordered_map<std::string, std::vector<size_t>> func_buckets;
893 for (
size_t i = 0, isz = d_nodes.size(); i < isz; ++i) {
894 if (d_nodes.at(i)->isUFName()) {
// Rename through function_names_map and fold the new name into the node's hash.
901 d_nodes.at(i)->setName(function_names_map.at(d_nodes.at(i)->getName()));
902 util::hash_combine(d_hash_table.at(i), std::hash<std::string>{}(d_nodes.at(i)->getName()));
909 node2index.emplace(d_nodes.at(i), i);
910 auto children = d_nodes.at(i)->getChildren();
// For commutative nodes, order the children from offset del onwards by
// context hash, with the node hash as tie-breaker.
912 if (is_commutative(i)) {
917 std::sort(children.begin() + del, children.end(), [
this, &node2index](
const node::Node &a,
const node::Node &b) {
918 if (d_context_hash.at(node2index.at(a)) == d_context_hash.at(node2index.at(b)))
919 return d_hash_table.at(node2index.at(a)) < d_hash_table.at(node2index.at(b));
921 return d_context_hash.at(node2index.at(a)) < d_context_hash.at(node2index.at(b));
925 d_nodes.at(i)->replace_children(children);
934 func_dec.emplace_back(d_nodes.at(i));
936 func_rec.emplace_back(d_nodes.at(i));
938 func_def.emplace_back(d_nodes.at(i));
941 func_buckets[d_nodes.at(i)->getName()].emplace_back(i);
// Order each function list by node hash so the numbering below follows it.
945 std::sort(func_dec.begin(), func_dec.end(), [
this, &node2index](
const node::Node &a,
const node::Node &b) {
946 return d_hash_table.at(node2index.at(a)) < d_hash_table.at(node2index.at(b));
948 std::sort(func_rec.begin(), func_rec.end(), [
this, &node2index](
const node::Node &a,
const node::Node &b) {
949 return d_hash_table.at(node2index.at(a)) < d_hash_table.at(node2index.at(b));
951 std::sort(func_def.begin(), func_def.end(), [
this, &node2index](
const node::Node &a,
const node::Node &b) {
952 return d_hash_table.at(node2index.at(a)) < d_hash_table.at(node2index.at(b));
// Rename by position in the sorted lists: "FDEC<i>", "FREC<i>", "FDEF<i>".
955 for (
size_t i = 0, isz = func_dec.size(); i < isz; i++) {
956 function_names_map[func_dec.at(i)->getName()] =
"FDEC" + std::to_string(i);
957 func_dec.at(i)->setName(
"FDEC" + std::to_string(i));
959 for (
size_t i = 0, isz = func_rec.size(); i < isz; i++) {
960 function_names_map[func_rec.at(i)->getName()] =
"FREC" + std::to_string(i);
961 func_rec.at(i)->setName(
"FREC" + std::to_string(i));
// For definitions, every node in the name's bucket is renamed too.
963 for (
size_t i = 0, isz = func_def.size(); i < isz; i++) {
964 for (
const auto &index : func_buckets[func_def.at(i)->getName()]) {
965 d_nodes.at(index)->setName(
"FDEF" + std::to_string(i));
// NOTE(review): if func_def.at(i) is itself in its bucket, getName() below
// already returns the new name, so the map key would be "FDEF<i>" rather than
// the old name - confirm this is intended.
968 function_names_map[func_def.at(i)->getName()] =
"FDEF" + std::to_string(i);
969 func_def.at(i)->setName(
"FDEF" + std::to_string(i));
970 util::hash_combine(d_hash_table.at(node2index.at(func_def.at(i))), std::hash<std::string>{}(
"FDEF" + std::to_string(i)));
// Rebuild the name -> node map under the new names and split the new names
// by kind (declared / recursive / defined).
973 std::unordered_map<std::string, std::shared_ptr<stabilizer::parser::DAGNode>> new_function_key_map;
975 std::vector<std::string> ndec, nrec, ndef;
976 for (
const auto &name : function_names) {
977 auto itr = function_names_map.find(name);
979 if (itr != function_names_map.end()) {
980 if (function_key_map.at(name)->isFuncDec())
981 ndec.emplace_back(itr->second);
982 else if (function_key_map.at(name)->isFuncRec())
983 nrec.emplace_back(itr->second);
984 else if (function_key_map.at(name)->isFuncDef())
985 ndef.emplace_back(itr->second);
987 new_function_key_map[itr->second] = function_key_map.at(name);
988 new_function_key_map[itr->second]->setName(itr->second);
989 auto fun = new_function_key_map[itr->second];
// For a defined function, children 1..n-1 are renamed "VAR0", "VAR1", ...;
// child 0 is left untouched by this loop.
990 if (fun->isFuncDef()) {
993 for (
size_t i = 1, isz = fun->getChildrenSize(); i < isz; ++i) {
994 fun->getChild(i)->setName(
"VAR" + std::to_string(i - 1));
// New function list: declared names sorted, then recursive names sorted,
// then defined names in the order they were encountered (ndef is not sorted).
999 std::vector<std::string> new_function_names;
1000 std::sort(ndec.begin(), ndec.end());
1001 for (
const auto &name : ndec)
1002 new_function_names.emplace_back(name);
1003 std::sort(nrec.begin(), nrec.end());
1004 for (
const auto &name : nrec)
1005 new_function_names.emplace_back(name);
1006 for (
const auto &name : ndef)
1007 new_function_names.emplace_back(name);
// Publish the rebuilt list and map to nm through the references taken earlier.
1009 function_names.swap(new_function_names);
1010 function_key_map.swap(new_function_key_map);
// NOTE(review): plain auto - if nm.assertions() returns a reference this is a
// copy, which is why the result is handed back via replace_assertions().
1012 auto assertions = nm.assertions();
// Order assertions by context hash, with the node hash as tie-breaker.
1014 std::sort(assertions.begin(), assertions.end(), [
this, &node2index](
const node::Node &a,
const node::Node &b) {
1015 if (d_context_hash.at(node2index.at(a)) == d_context_hash.at(node2index.at(b)))
1016 return d_hash_table.at(node2index.at(a)) < d_hash_table.at(node2index.at(b));
1018 return d_context_hash.at(node2index.at(a)) < d_context_hash.at(node2index.at(b));
1022 nm.replace_assertions(assertions);