#include <unordered_map>
class DisjointSet
{
public:
DisjointSet();
~DisjointSet();
void insert(int item);
int find(int item);
void join(int a1, int a2);
bool isConnected(int a1, int a2);
private:
std::unordered_map<int, int> parentMap;
std::unordered_map<int, int> rankMap;
};
DisjointSet::DisjointSet()
{
}
DisjointSet::~DisjointSet()
{
}
void DisjointSet::insert(int item)
{
parentMap[item] = item;
rankMap[item] = 1;
}
int DisjointSet::find(int item)
{
if (parentMap[item] != item)
{
parentMap[item] = find(parentMap[item]);
}
return parentMap[item];
}
void DisjointSet::join(int a1, int a2)
{
int root1 = find(a1);
int root2 = find(a2);
if (rankMap[root1] < rankMap[root2])
{
parentMap[root1] = root2;
}
else if (rankMap[root1] > rankMap[root2])
{
parentMap[root2] = root1;
}
else
{
parentMap[root1] = root2;
rankMap[root2] += 1;
}
}
bool DisjointSet::isConnected(int a1, int a2)
{
return find(a1) == find(a2);
}
网友评论