library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub gyouzasushi/library

✔ 最小共通祖先 (LCA)
(graph/lowest_common_ancestor.hpp)

概要

使い方

計算量

前処理 $O(n)$、クエリ $O(1)$。

Depends on

Verified with

Code

#pragma once
#include <cassert>
#include <vector>

#include "../datastructure/plus_minus_one_range_minimum.hpp"
// Lowest common ancestor on a tree. build() records an Euler tour
// (vs = visited vertices, dep = their depths, id = first tour position of each
// vertex); the LCA of u and v is the shallowest tour entry between id[u] and
// id[v], found with a +-1 range-minimum structure.
struct lowest_common_ancestor {
public:
    lowest_common_ancestor() {
    }
    // Tree on vertices 0..n-1 (n >= 1), rooted at `root`.
    lowest_common_ancestor(int n, int root = 0)
        : _n(n), _root(root), g(n), id(n), vs(2 * n - 1), dep(2 * n - 1) {
        assert(0 <= root && root < n);
    }
    // Adds the undirected edge {from, to}. The edges must form a tree
    // (n - 1 edges, connected); build() asserts this.
    void add_edge(int from, int to) {
        assert(0 <= from && from < _n);
        assert(0 <= to && to < _n);
        g[from].push_back(to);
        g[to].push_back(from);
    }
    // Builds the Euler tour and the RMQ over its depths. Call after all edges
    // have been added.
    void build() {
        // Explicit-stack DFS instead of recursion, so deep (path-like) trees
        // cannot overflow the call stack. It yields exactly the tour of the
        // recursive formulation: a vertex is recorded when first entered and
        // again after each of its child subtrees is finished.
        struct Frame {
            int pos, pre, d, next;
        };
        int k = 0;
        auto record = [&](int pos, int d) {
            // Overrunning the tour means the input has a cycle or duplicate edge.
            assert(k < int(vs.size()));
            vs[k] = pos;
            dep[k++] = d;
        };
        std::vector<Frame> st;
        st.reserve(_n);
        id[_root] = k;
        record(_root, 0);
        st.push_back({_root, -1, 0, 0});
        while (!st.empty()) {
            Frame& top = st.back();
            if (top.next < int(g[top.pos].size())) {
                const int nxt = g[top.pos][top.next++];
                if (nxt == top.pre) continue;
                // Copy before push_back: it may invalidate `top`.
                const int pos = top.pos;
                const int d = top.d;
                id[nxt] = k;
                record(nxt, d + 1);
                st.push_back({nxt, pos, d + 1, 0});
            } else {
                st.pop_back();
                // Returning to the parent records it again.
                if (!st.empty()) record(st.back().pos, st.back().d);
            }
        }
        // A shorter tour means some vertex is unreachable from the root.
        assert(k == int(vs.size()));
        rmq.init(dep);
    }

    // LCA of u and v for the root given at construction.
    int get(int u, int v) {
        assert(0 <= u && u < _n);
        assert(0 <= v && v < _n);
        int l = std::min(id[u], id[v]);
        int r = std::max(id[u], id[v]) + 1;
        return vs[rmq.prod(l, r)];
    }
    // LCA of u and v when the tree is re-rooted at r. Of the three pairwise
    // LCAs at least two coincide and the remaining one is the answer; XOR
    // isolates it (and returns the common value if all three agree).
    int get(int u, int v, int r) {
        return get(r, u) ^ get(u, v) ^ get(v, r);
    }
    // Number of edges between u and the root.
    int depth(int u) {
        assert(0 <= u && u < _n);
        return dep[id[u]];
    }
    // Number of edges on the path between u and v.
    int dist(int u, int v) {
        return depth(u) + depth(v) - 2 * depth(get(u, v));
    }

private:
    int _n = 0, _root = 0;
    std::vector<std::vector<int>> g;
    std::vector<int> id, vs, dep;
    PlusMinusOneRMQ rmq;
};
#line 2 "graph/lowest_common_ancestor.hpp"
#include <cassert>
#include <vector>

#line 2 "datastructure/plus_minus_one_range_minimum.hpp"
#include <cmath>
#line 4 "datastructure/plus_minus_one_range_minimum.hpp"

#line 3 "datastructure/static_range_minimum.hpp"
// Sparse table over (value, index) pairs: O(n log n) build, O(1) minimum
// queries on half-open ranges. Pairs compare lexicographically, so among equal
// values the pair with the smaller second component is returned.
struct StaticRMQ {
public:
    void init(const std::vector<std::pair<int, int>>& _v) {
        _n = int(_v.size());
        // floor_log2[i] = floor(log2(i)) for 1 <= i <= _n. At least two slots
        // are kept so that an empty input never indexes past the end.
        floor_log2.assign(std::max(_n + 1, 2), 0);
        for (int i = 2; i <= _n; i++) floor_log2[i] = floor_log2[i >> 1] + 1;
        const int levels = floor_log2[_n] + 1;
        // d[i][b] = min of _v[i .. i + 2^b).
        d.assign(_n, std::vector<std::pair<int, int>>(levels));
        for (int i = 0; i < _n; i++) d[i][0] = _v[i];
        for (int b = 0; b + 1 < levels; b++) {
            for (int i = 0; i + (1 << (b + 1)) <= _n; i++) {
                d[i][b + 1] = std::min(d[i][b], d[i + (1 << b)][b]);
            }
        }
    }
    // Minimum of _v[l..r); {1 << 30, 1 << 30} when the range is empty.
    std::pair<int, int> prod(int l, int r) const {
        if (!(l < r)) return PINF;
        const int b = floor_log2[r - l];
        return std::min(d[l][b], d[r - (1 << b)][b]);
    }

private:
    // static rather than a const data member: a const member would delete the
    // implicit copy/move assignment of this struct and of everything holding it.
    static constexpr std::pair<int, int> PINF{1 << 30, 1 << 30};
    int _n = 0;
    std::vector<std::vector<std::pair<int, int>>> d;
    std::vector<int> floor_log2;
};
#line 6 "datastructure/plus_minus_one_range_minimum.hpp"
// Range-argmin for sequences whose adjacent elements differ by exactly +1 or
// -1 (e.g. depths along an Euler tour). The array is cut into blocks of size
// s; whole blocks are handled by a sparse table over block minima, and partial
// blocks by lookup tables indexed by the block's up/down step pattern.
struct PlusMinusOneRMQ {
public:
    void init(const std::vector<int>& _v) {
        _n = int(_v.size());
        v = _v;
        // Block size ~ log2(n) / 2 keeps the 2^(s-1) step patterns at O(sqrt n).
        // std::log2(0) is -inf and converting it to int is undefined, so guard n == 0.
        s = _n > 0 ? std::max(1, int(std::log2(_n) / 2)) : 1;
        B = (_n + s - 1) / s;

        // Per block: its minimum as (value, index) for the sparse table, and its
        // step pattern (bit k set <=> v[lo + k] < v[lo + k + 1]).
        std::vector<std::pair<int, int>> block_min(B);
        pattern.assign(B, 0);
        for (int b = 0; b < B; b++) {
            const int lo = b * s;
            const int hi = std::min(_n, lo + s);
            int min_j = lo;
            int bits = 0;
            for (int j = lo; j < hi; j++) {
                if (v[j] < v[min_j]) min_j = j;
                if (j + 1 < hi && v[j] < v[j + 1]) bits |= 1 << (j - lo);
            }
            block_min[b] = {v[min_j], min_j};
            pattern[b] = bits;
        }
        sparse_table.init(block_min);

        // lookup_table[p][l][r] (0 <= l < r <= s): in-block offset of the
        // leftmost minimum of positions [l, r) for step pattern p; INF if r <= l.
        const int patterns = 1 << (s - 1);
        lookup_table.assign(
            patterns,
            std::vector<std::vector<int>>(s + 1, std::vector<int>(s + 1, INF)));
        for (int p = 0; p < patterns; p++) {
            for (int l = 0; l <= s; l++) {
                int best = l;    // leftmost argmin seen so far
                int best_h = 0;  // its height relative to position l
                int h = 0;       // height of the current position relative to l
                for (int r = l + 1; r <= s; r++) {
                    lookup_table[p][l][r] = best;
                    h += (p >> (r - 1) & 1) ? 1 : -1;
                    if (h < best_h) {
                        best_h = h;
                        best = r;
                    }
                }
            }
        }
    }
    // Index of the leftmost minimum of v[l..r). Requires 0 <= l < r <= n.
    int prod(int l, int r) {
        const int m1 = (l + s - 1) / s;  // first block starting at or after l
        const int m2 = r / s;            // block containing position r
        const int l1 = s * m1;           // start of block m1
        const int r1 = s * m2;           // start of block m2
        if (m2 < m1) {
            // l is mid-block and r lies in that same block.
            return lookup_table[pattern[m2]][l - r1][r - r1] + r1;
        }
        int ret = INF;
        if (l < l1) {
            // Suffix [l, l1) of block m1 - 1.
            const int base = l1 - s;
            ret = argmin(ret, lookup_table[pattern[m1 - 1]][l - base][s] + base);
        }
        if (m1 < m2) {
            // Whole blocks [m1, m2).
            ret = argmin(ret, sparse_table.prod(m1, m2).second);
        }
        if (r1 < r) {
            // Prefix [r1, r) of block m2.
            ret = argmin(ret, lookup_table[pattern[m2]][0][r - r1] + r1);
        }
        return ret;
    }

private:
    // static rather than a const data member: a const member would delete the
    // implicit copy/move assignment of this struct and of everything holding it.
    static constexpr int INF = 1 << 30;
    int _n = 0;
    int s = 1, B = 0;
    StaticRMQ sparse_table;
    std::vector<std::vector<std::vector<int>>> lookup_table;
    std::vector<int> pattern, v;
    // Index holding the smaller value; ties and INF sentinels resolve to the
    // smaller index.
    int argmin(int i, int j) const {
        if (i >= INF || j >= INF || v[i] == v[j]) return std::min(i, j);
        return v[i] < v[j] ? i : j;
    }
};
#line 6 "graph/lowest_common_ancestor.hpp"
// Lowest common ancestor on a tree. build() records an Euler tour
// (vs = visited vertices, dep = their depths, id = first tour position of each
// vertex); the LCA of u and v is the shallowest tour entry between id[u] and
// id[v], found with a +-1 range-minimum structure.
struct lowest_common_ancestor {
public:
    lowest_common_ancestor() {
    }
    // Tree on vertices 0..n-1 (n >= 1), rooted at `root`.
    lowest_common_ancestor(int n, int root = 0)
        : _n(n), _root(root), g(n), id(n), vs(2 * n - 1), dep(2 * n - 1) {
        assert(0 <= root && root < n);
    }
    // Adds the undirected edge {from, to}. The edges must form a tree
    // (n - 1 edges, connected); build() asserts this.
    void add_edge(int from, int to) {
        assert(0 <= from && from < _n);
        assert(0 <= to && to < _n);
        g[from].push_back(to);
        g[to].push_back(from);
    }
    // Builds the Euler tour and the RMQ over its depths. Call after all edges
    // have been added.
    void build() {
        // Explicit-stack DFS instead of recursion, so deep (path-like) trees
        // cannot overflow the call stack. It yields exactly the tour of the
        // recursive formulation: a vertex is recorded when first entered and
        // again after each of its child subtrees is finished.
        struct Frame {
            int pos, pre, d, next;
        };
        int k = 0;
        auto record = [&](int pos, int d) {
            // Overrunning the tour means the input has a cycle or duplicate edge.
            assert(k < int(vs.size()));
            vs[k] = pos;
            dep[k++] = d;
        };
        std::vector<Frame> st;
        st.reserve(_n);
        id[_root] = k;
        record(_root, 0);
        st.push_back({_root, -1, 0, 0});
        while (!st.empty()) {
            Frame& top = st.back();
            if (top.next < int(g[top.pos].size())) {
                const int nxt = g[top.pos][top.next++];
                if (nxt == top.pre) continue;
                // Copy before push_back: it may invalidate `top`.
                const int pos = top.pos;
                const int d = top.d;
                id[nxt] = k;
                record(nxt, d + 1);
                st.push_back({nxt, pos, d + 1, 0});
            } else {
                st.pop_back();
                // Returning to the parent records it again.
                if (!st.empty()) record(st.back().pos, st.back().d);
            }
        }
        // A shorter tour means some vertex is unreachable from the root.
        assert(k == int(vs.size()));
        rmq.init(dep);
    }

    // LCA of u and v for the root given at construction.
    int get(int u, int v) {
        assert(0 <= u && u < _n);
        assert(0 <= v && v < _n);
        int l = std::min(id[u], id[v]);
        int r = std::max(id[u], id[v]) + 1;
        return vs[rmq.prod(l, r)];
    }
    // LCA of u and v when the tree is re-rooted at r. Of the three pairwise
    // LCAs at least two coincide and the remaining one is the answer; XOR
    // isolates it (and returns the common value if all three agree).
    int get(int u, int v, int r) {
        return get(r, u) ^ get(u, v) ^ get(v, r);
    }
    // Number of edges between u and the root.
    int depth(int u) {
        assert(0 <= u && u < _n);
        return dep[id[u]];
    }
    // Number of edges on the path between u and v.
    int dist(int u, int v) {
        return depth(u) + depth(v) - 2 * depth(get(u, v));
    }

private:
    int _n = 0, _root = 0;
    std::vector<std::vector<int>> g;
    std::vector<int> id, vs, dep;
    PlusMinusOneRMQ rmq;
};
Back to top page