下雪了

洛谷P3384 – 树链剖分(模板)

树链剖分（重链剖分）的模板题。

（又因为 int 溢出之类的问题 WA 了……以后干脆直接 #define int long long 好了。）

AC代码:

#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<string>
#include<vector>
#include<cstdio>
#include<stack>
#include<cmath>
#include<ctime>
#include<queue>
#include<deque>
#include<list>
#include<map>
// Widen every `int` written below this line to 64 bits (the post mentions an
// earlier WA caused by int). Consequences visible in this file: main() has to be
// declared `signed main()`, and every scanf reads "int" variables with %lld.
#define int long long
// Contest shorthand macros. Every macro argument is parenthesized so that
// expression arguments (e.g. lowbit(a + b), cos(a + b), ffor(i, 1, n - 1))
// expand with the intended operator precedence.

// Inclusive ascending / descending loops over i.
#define ffor(i, a, b) for(int i = (a); i <= (b); i++)
#define rfor(i, a, b) for(int i = (a); i >= (b); i--)
// Byte-fill a whole array (sizeof must see the array, not a pointer).
#define mes(a,b) memset((a), (b), sizeof(a))
// Trigonometry on angles given in degrees.
#define cos(x) cos((x)*PI/180.0)
#define sin(x) sin((x)*PI/180.0)
#define stop system("pause")
#define see(s,x) cout<<(s)<<'='<<(x)<<endl
#define IMAX 0x7fffffff
#define PI 3.141592654
#define INF 0x3f3f3f3f
#define eps 1e-6
// Lowest set bit of x.
#define lowbit(x) ((x)&(-(x)))
// 64-bit integer shorthand used throughout the file.
using ll = long long;
// Modulus applied to every stored sum and every printed answer; read in main().
ll mod;

// Larger of a and b (b when they are equal).
ll max(ll a, ll b)
{
    if (a > b)
        return a;
    return b;
}

// Smaller of a and b (b when they are equal).
ll min(ll a, ll b)
{
    if (a < b)
        return a;
    return b;
}
using namespace std;

// One segment-tree node. Nodes live in the pool `tab` and are handed out
// sequentially by build(); children are referenced by their index in `tab`.
//   sum  : sum of the node's range (mod `mod`), already including this node's
//          own lazy but not the lazies of its ancestors.
//   lazy : total amount added to this node's whole range. It is never pushed
//          down; queries accumulate it on the way down instead.
struct unit
{
    int l, r;       // tab indices of the left / right child
    ll sum, lazy;
};
unit tab[100005 << 2];  // node pool (build() uses 2n-1 slots, starting at index 1)

// Tree: adjacency lists indexed by node number.
vector<int> G[100005];

// Filled by dfs1: parent, depth, subtree size, heavy child.
int fa[100005], dep[100005], sz[100005], son[100005];

// Filled by dfs2: DFS position of each node (id), weight stored at each DFS
// position (rk[pos] = w of the node placed at pos), and chain head (top).
int id[100005], rk[100005], top[100005];

// Initial node weights, as read in main().
int w[100005];

// n nodes, q queries, rt = segment-tree root index in tab, root = tree root.
int n, q, rt, root;
// Shared counter: DFS position in dfs2, then node allocator in build().
int cnt;

void dfs1(int x, int f, int d)//处理出fa、dep
{
    dep[x] = d;
    fa[x] = f;
    sz[x] = 1;
    int maxson = -1;
    for (auto i : G[x])
    {
        if (i == f)continue;
        dfs1(i, x, d + 1);
        sz[x] += sz[i];
        if (sz[i] > maxson)
        {
            son[x] = i;
            maxson = sz[i];
        }
    }
}

void dfs2(int x, int t)//处理出rk、top、id
{
    id[x] = ++cnt;
    rk[cnt] = w[x];
    top[x] = t;
    if (!son[x])
        return;
    dfs2(son[x], t);
    for (auto i:G[x]) 
    {
        if (i != fa[x] && i != son[x])
            dfs2(i, i);
    }
}

void build(int &x,int l,int r)
{
    x = ++cnt;
    tab[x].lazy = 0;
    if (l == r)
    {
        tab[x].sum = rk[l] % mod;
        return;
    }
    int m = (l + r) >> 1;
    build(tab[x].l, l, m);
    build(tab[x].r, m + 1, r);
    tab[x].sum = (tab[tab[x].l].sum + tab[tab[x].r].sum) % mod;
}

// Range add without push-down ("permanent" lazy marks): add `value` to every
// position in [L, R]; x is the node covering [l, r].
// Invariant kept for every node:
//   sum == sum(left) + sum(right) + lazy * (r - l + 1)   (mod `mod`)
// That is, a node's sum includes its own lazy. When re-summing the children
// after recursing, this node's lazy must be added back in.
void update(int x, int L, int R, int l, int r, ll value)
{
    // Normalise to [0, mod). spls() forwards the raw input value, and an
    // unreduced lazy can grow to ~1e5 * 2^31. Multiplied by a range length of
    // ~1e5 below, that would overflow 64 bits. Keeping value and lazy below mod
    // bounds every product by mod * n.
    value %= mod;
    if (value < 0) value += mod;

    if (L <= l && R >= r)
    {
        tab[x].sum = (tab[x].sum + (r - l + 1) * value) % mod;
        tab[x].lazy = (tab[x].lazy + value) % mod;
        return;
    }
    int m = (l + r) >> 1;
    if (L <= m) update(tab[x].l, L, R, l, m, value);
    if (R > m) update(tab[x].r, L, R, m + 1, r, value);
    tab[x].sum = (tab[tab[x].l].sum + tab[tab[x].r].sum + tab[x].lazy * (r - l + 1) % mod) % mod;
}

// Range-sum query over [L, R] (mod `mod`); x is the node covering [l, r].
// Because lazy marks are never pushed down, `add` carries the sum of the
// lazies of all ancestors above x (callers start with 0). Each fully covered
// node contributes add * (its length) on top of its stored sum.
ll check(int x, int L, int R, int l, int r, ll add)
{
    if (l >= L && r <= R)
        return (tab[x].sum + add * (r - l + 1)) % mod;

    const ll carried = (add + tab[x].lazy) % mod;   // include this node's lazy for the children
    const int m = (l + r) >> 1;
    ll total = 0;
    if (L <= m)
        total = (total + check(tab[x].l, L, R, l, m, carried)) % mod;
    if (m < R)
        total = (total + check(tab[x].r, L, R, m + 1, r, carried)) % mod;
    return total;
}

void lpls(int x,int y,int z)
{
    z %= mod;
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        update(rt, id[top[x]], id[x], 1, n, z);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y])
        swap(x, y);
    update(rt, id[x], id[y], 1, n, z);
}

// Path query: sum (mod `mod`) of node values on the tree path x..y, gathered
// one heavy-chain segment at a time, in the same way as lpls().
ll lcheck(int x, int y)
{
    ll total = 0;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        total += check(rt, id[top[x]], id[x], 1, n, 0);
        total %= mod;
    }
    if (dep[x] > dep[y])
        swap(x, y);
    return (total + check(rt, id[x], id[y], 1, n, 0)) % mod;
}

// Subtree update: add y to every node in x's subtree. Those nodes occupy the
// consecutive DFS positions [id[x], id[x] + sz[x] - 1].
void spls(int x, int y)
{
    const int first = id[x];
    const int last = first + sz[x] - 1;
    update(rt, first, last, 1, n, y);
}

// Subtree query: sum (mod `mod`) of all node values in x's subtree, i.e. of
// DFS positions [id[x], id[x] + sz[x] - 1].
ll scheck(int x)
{
    const int first = id[x];
    const int last = first + sz[x] - 1;
    return check(rt, first, last, 1, n, 0) % mod;
}

signed main()
{
    cin >> n >> q >> root >> mod;
    ffor(i, 1, n)
        scanf("%lld", &w[i]);
    ffor(i, 1, n - 1)
    {
        int x, y;
        scanf("%lld%lld", &x, &y);
        G[x].push_back(y);
        G[y].push_back(x);
    }
    cnt = 0;
    dfs1(root, 0, 1);
    dfs2(root, root);
    cnt = 0;
    build(rt, 1, n);
    ffor(i, 1, q)
    {
        int o;
        scanf("%lld", &o);
        int x, y, z;
        switch (o)
        {
        case 1:
            scanf("%lld%lld%lld", &x, &y, &z);
            lpls(x, y, z);
            break;
        case 2:
            scanf("%lld%lld", &x, &y);
            printf("%lld\n",lcheck(x, y));
            break;
        case 3:
            scanf("%lld%lld", &x, &y);
            spls(x, y);
            break;
        default:
            scanf("%lld", &x);
            printf("%lld\n", scheck(x));
            break;
        }
    }
    return 0;
}

 


Add Your Comment

* Indicates Required Field

Your email address will not be published.

*