This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub shibh308/library
struct ZDD{ struct Node{ int label; using Index = typename MemoryPool<Node>::Index; Index c[2]; Node(){} Node(int label, Index nil) : label(label){ c[0] = c[1] = nil; } Node(int label, Index l, Index r) : label(label){ c[0] = l; c[1] = r; } }; using Index = typename MemoryPool<Node>::Index; Index nil, true_idx, false_idx; MemoryPool<Node> pool; // 引数は(Data, label, 最後どう進んだか) // secondが-1ならDataを表す, 0/1なら終端状態を表す ZDD(){ nil = {-1}; true_idx = {-2}; false_idx = {-3}; } template<typename Data> Index build(function<pair<Data, int>(Data, int, bool)> f, Data init){ unordered_map<int, map<Data, Index>> node_map; Index root = pool.alloc(); pool[root] = Node(0, nil); node_map[0][init] = root; build<Data>(root, init, f, node_map); return compress(root); } template<typename Data> void build(Index pi, Data& d, function<pair<Data, int>(Data, int, bool)>& f, unordered_map<int, map<Data, Index>>& node_map){ auto& x = pool[pi]; for(int i = 0; i < 2; ++i){ int fl; Data y; tie(y, fl) = f(d, x.label, i); if(fl == 0){ x.c[i] = false_idx; } else if(fl == 1){ x.c[i] = true_idx; } else if(node_map[x.label + 1].find(y) == node_map[x.label + 1].end()){ Index new_idx = pool.alloc(); pool[new_idx] = Node(x.label + 1, nil); node_map[x.label + 1][y] = new_idx; x.c[i] = new_idx; build(node_map[x.label + 1][y], y, f, node_map); }else{ x.c[i] = node_map[x.label + 1][y]; } } } Index compress(Index pi, unordered_map<int, unordered_map<uint64_t, Index>>& node_map, unordered_map<int, Index>& replace_map){ // 同型で子が一致する頂点を破滅 if(pi == true_idx || pi == false_idx) return pi; if(replace_map.find(pi.idx) != replace_map.end()) return replace_map[pi.idx]; auto& p = pool[pi]; Index li = compress(p.c[0], node_map, replace_map); Index ri = compress(p.c[1], node_map, replace_map); p.c[0] = li; p.c[1] = ri; uint64_t h = uint32_t(li.idx) | ((1uLL * uint32_t(ri.idx)) << 32); if(node_map[p.label].find(h) == node_map[p.label].end()) node_map[p.label][h] = pi; else pool.free(pi); return replace_map[pi.idx] = node_map[p.label][h]; } Index compress(Index pi){ unordered_map<int, unordered_map<uint64_t, Index>> node_map; unordered_map<int, Index> replace_map; return compress(pi, node_map, replace_map); } i64 linear_func_max(Index pi, vector<i64>& a, unordered_map<int, i64>& res){ if(pi == true_idx) return 0; if(pi == false_idx) return -INF; if(res.find(pi.idx) == res.end()) res[pi.idx] = max(linear_func_max(pool[pi].c[0], a, res), linear_func_max(pool[pi].c[1], a, res) + a[pool[pi].label]); return res[pi.idx]; } i64 linear_func_max(Index root, vector<i64> a){ unordered_map<int, i64> res; return linear_func_max(root, a, res); } i64 count(Index pi, unordered_map<int, i64>& res){ if(pi == true_idx) return 1; if(pi == false_idx) return 0; if(res.find(pi.idx) == res.end()) res[pi.idx] = count(pool[pi].c[0], res) + count(pool[pi].c[1], res); return res[pi.idx]; } i64 count(Index root){ unordered_map<int, i64> res; return count(root, res); } double get_per(Index pi, vector<double>& a, unordered_map<int, double>& res){ if(pi == true_idx) return 1; if(pi == false_idx) return 0; if(res.find(pi.idx) == res.end()) res[pi.idx] = a[pool[pi].label] * get_per(pool[pi].c[0], a, res) + (1.0 - a[pool[pi].label]) * get_per(pool[pi].c[1], a, res); return res[pi.idx]; } double get_per(Index root, vector<double>& a){ unordered_map<int, double> res; return get_per(root, a, res); } Index apply(Index li, Index ri, function<Index(Index, Index)>& f, unordered_map<uint64_t, Index>& node_map){ Index res = f(li, ri); if(res != nil) return res; int lidx, ridx; tie(lidx, ridx) = minmax(li.idx, ri.idx); uint64_t h = lidx | ((1uLL * ridx) << 32); if(node_map.find(h) != node_map.end()) return node_map[h]; if(pool[li].label < pool[ri].label) swap(li, ri); auto& l = pool[li]; auto& r = pool[ri]; Index idx = pool.alloc(); if(l.label == r.label) pool[idx] = Node(l.label, apply(l.c[0], r.c[0], f, node_map), apply(l.c[1], r.c[1], f, node_map)); else pool[idx] = Node(l.label, apply(l.c[0], ri, f, node_map), apply(l.c[1], ri, f, node_map)); return node_map[h] = idx; } Index apply_and(Index li, Index ri){ unordered_map<uint64_t, Index> node_map; function<Index(Index, Index)> f = [&](Index li, Index ri){ if(li == true_idx) return ri; if(ri == true_idx) return li; if(li == false_idx || ri == false_idx) return false_idx; else return nil; }; return apply(li, ri, f, node_map); } Index apply_or(Index li, Index ri){ unordered_map<uint64_t, Index> node_map; function<Index(Index, Index)> f = [&](Index li, Index ri){ if(li == false_idx) return ri; if(ri == false_idx) return li; if(li == true_idx || ri == true_idx) return true_idx; else return nil; }; return apply(li, ri, f, node_map); } };
#line 1 "lib/classes/zdd.cpp" struct ZDD{ struct Node{ int label; using Index = typename MemoryPool<Node>::Index; Index c[2]; Node(){} Node(int label, Index nil) : label(label){ c[0] = c[1] = nil; } Node(int label, Index l, Index r) : label(label){ c[0] = l; c[1] = r; } }; using Index = typename MemoryPool<Node>::Index; Index nil, true_idx, false_idx; MemoryPool<Node> pool; // 引数は(Data, label, 最後どう進んだか) // secondが-1ならDataを表す, 0/1なら終端状態を表す ZDD(){ nil = {-1}; true_idx = {-2}; false_idx = {-3}; } template<typename Data> Index build(function<pair<Data, int>(Data, int, bool)> f, Data init){ unordered_map<int, map<Data, Index>> node_map; Index root = pool.alloc(); pool[root] = Node(0, nil); node_map[0][init] = root; build<Data>(root, init, f, node_map); return compress(root); } template<typename Data> void build(Index pi, Data& d, function<pair<Data, int>(Data, int, bool)>& f, unordered_map<int, map<Data, Index>>& node_map){ auto& x = pool[pi]; for(int i = 0; i < 2; ++i){ int fl; Data y; tie(y, fl) = f(d, x.label, i); if(fl == 0){ x.c[i] = false_idx; } else if(fl == 1){ x.c[i] = true_idx; } else if(node_map[x.label + 1].find(y) == node_map[x.label + 1].end()){ Index new_idx = pool.alloc(); pool[new_idx] = Node(x.label + 1, nil); node_map[x.label + 1][y] = new_idx; x.c[i] = new_idx; build(node_map[x.label + 1][y], y, f, node_map); }else{ x.c[i] = node_map[x.label + 1][y]; } } } Index compress(Index pi, unordered_map<int, unordered_map<uint64_t, Index>>& node_map, unordered_map<int, Index>& replace_map){ // 同型で子が一致する頂点を破滅 if(pi == true_idx || pi == false_idx) return pi; if(replace_map.find(pi.idx) != replace_map.end()) return replace_map[pi.idx]; auto& p = pool[pi]; Index li = compress(p.c[0], node_map, replace_map); Index ri = compress(p.c[1], node_map, replace_map); p.c[0] = li; p.c[1] = ri; uint64_t h = uint32_t(li.idx) | ((1uLL * uint32_t(ri.idx)) << 32); if(node_map[p.label].find(h) == node_map[p.label].end()) node_map[p.label][h] = pi; else pool.free(pi); return replace_map[pi.idx] = node_map[p.label][h]; } Index compress(Index pi){ unordered_map<int, unordered_map<uint64_t, Index>> node_map; unordered_map<int, Index> replace_map; return compress(pi, node_map, replace_map); } i64 linear_func_max(Index pi, vector<i64>& a, unordered_map<int, i64>& res){ if(pi == true_idx) return 0; if(pi == false_idx) return -INF; if(res.find(pi.idx) == res.end()) res[pi.idx] = max(linear_func_max(pool[pi].c[0], a, res), linear_func_max(pool[pi].c[1], a, res) + a[pool[pi].label]); return res[pi.idx]; } i64 linear_func_max(Index root, vector<i64> a){ unordered_map<int, i64> res; return linear_func_max(root, a, res); } i64 count(Index pi, unordered_map<int, i64>& res){ if(pi == true_idx) return 1; if(pi == false_idx) return 0; if(res.find(pi.idx) == res.end()) res[pi.idx] = count(pool[pi].c[0], res) + count(pool[pi].c[1], res); return res[pi.idx]; } i64 count(Index root){ unordered_map<int, i64> res; return count(root, res); } double get_per(Index pi, vector<double>& a, unordered_map<int, double>& res){ if(pi == true_idx) return 1; if(pi == false_idx) return 0; if(res.find(pi.idx) == res.end()) res[pi.idx] = a[pool[pi].label] * get_per(pool[pi].c[0], a, res) + (1.0 - a[pool[pi].label]) * get_per(pool[pi].c[1], a, res); return res[pi.idx]; } double get_per(Index root, vector<double>& a){ unordered_map<int, double> res; return get_per(root, a, res); } Index apply(Index li, Index ri, function<Index(Index, Index)>& f, unordered_map<uint64_t, Index>& node_map){ Index res = f(li, ri); if(res != nil) return res; int lidx, ridx; tie(lidx, ridx) = minmax(li.idx, ri.idx); uint64_t h = lidx | ((1uLL * ridx) << 32); if(node_map.find(h) != node_map.end()) return node_map[h]; if(pool[li].label < pool[ri].label) swap(li, ri); auto& l = pool[li]; auto& r = pool[ri]; Index idx = pool.alloc(); if(l.label == r.label) pool[idx] = Node(l.label, apply(l.c[0], r.c[0], f, node_map), apply(l.c[1], r.c[1], f, node_map)); else pool[idx] = Node(l.label, apply(l.c[0], ri, f, node_map), apply(l.c[1], ri, f, node_map)); return node_map[h] = idx; } Index apply_and(Index li, Index ri){ unordered_map<uint64_t, Index> node_map; function<Index(Index, Index)> f = [&](Index li, Index ri){ if(li == true_idx) return ri; if(ri == true_idx) return li; if(li == false_idx || ri == false_idx) return false_idx; else return nil; }; return apply(li, ri, f, node_map); } Index apply_or(Index li, Index ri){ unordered_map<uint64_t, Index> node_map; function<Index(Index, Index)> f = [&](Index li, Index ri){ if(li == false_idx) return ri; if(ri == false_idx) return li; if(li == true_idx || ri == true_idx) return true_idx; else return nil; }; return apply(li, ri, f, node_map); } };