Submission #6404218


Source Code

// <bits/stdc++.h> is a non-standard libstdc++ convenience header and is not
// available on every toolchain (the Clang 3.8.0 build of this submission
// failed with "'bits/stdc++.h' file not found"). Include it only when it is
// present, and always include the standard headers this file actually uses.
#if defined(__has_include)
#if __has_include(<bits/stdc++.h>)
#include <bits/stdc++.h>
#endif
#endif
#include <algorithm>  // max, min
#include <cassert>    // assert
#include <cstdio>     // scanf, printf
#include <cstring>    // strlen

using namespace std;

// "chmax": if b is greater than a, store b into a. Returns the (possibly
// updated) value of a by copy.
template <class T>
inline T updmax(T &a, T b) {
	if (a < b) a = b;
	return a;
}

// "chmin": if b is less than a, store b into a. Returns the (possibly
// updated) value of a by copy.
template <class T>
inline T updmin(T &a, T b) {
	if (b < a) a = b;
	return a;
}

template <unsigned long long mod> class modint {
public:
	unsigned long long v;
	modint(const long long x = 0) : v(x % mod) {}
	modint operator+(const modint rhs) { return modint(*this) += rhs; }
	modint operator-(const modint rhs) { return modint(*this) -= rhs; }
	modint operator*(const modint rhs) { return modint(*this) *= rhs; }
	modint operator/(const modint rhs) { return modint(*this) /= rhs; }
	modint operator-() { return modint(mod-this->v); }
	modint &operator+=(const modint rhs) {
		v += rhs.v;
		if (v >= mod)v -= mod;
		return *this;
	}
	modint &operator-=(const modint rhs) {
		if (v < rhs.v)v += mod;
		v -= rhs.v;
		return *this;
	}
	modint &operator*=(const modint rhs) {
		v = v * rhs.v % mod;
		return *this;
	}
	modint inverse(modint a) {
		unsigned long long exp = mod - 2;
		modint ret(1ULL);
		while (exp) {
			if (exp % 2) {
				ret *= a;
			}
			a *= a;
			exp >>= 1;
		}
		return ret;
	}
	modint &operator/=(modint rhs) {
		(*this) *= inverse(rhs);
		return *this;
	}
};

template<class T>
T fastpow(T a, long long p){
	T tmp = a;
	T ret = 1;
	while(p){
		if(p & 1)ret *= tmp;
		p >>= 1;
		tmp = tmp * tmp;
	}
	return ret;
}

// Modulus for every count this program prints. modint's Fermat-based
// inverse() relies on it being prime (998244353 is prime).
constexpr int MOD = 998244353;
using mint = modint<MOD>;  // residue type used by the DP table and the answer

// Input string buffer: up to 200000 characters plus slack and the terminator.
char s[200000 + 5];

// dp[i][r][c] (filled in main): number of strings of length i + 1 over
// letter values {0, 1, 2} with no two equal adjacent letters, letter sum
// congruent to r (mod 3), and last letter c.
mint dp[200000 + 5][3][3];

// Entry point for "Task F - Normalization": reads one string s and prints the
// number of distinct strings obtainable from s, modulo MOD.
//
// NOTE(review): the problem statement is not part of this file. The length-3
// table and the examples below assume the operation "choose i with
// s[i] != s[i+1] and replace both characters with the third letter of
// {a, b, c}" (AtCoder ARC094 F). Confirm against the statement.
//
// Answer by case, as implemented below:
//   * every adjacent pair equal -> 1 (no operation applies)
//   * n == 2                    -> 2 ("xy" can only become "zz")
//   * n == 3                    -> 6 / 7 / 3 (small table, see below)
//   * n >= 4                    -> 3^(n-1)  (strings whose letter sum is
//                                            congruent to sum(s) mod 3)
//                                  - those with no two equal neighbours
//                                  + 1 if s itself has no two equal neighbours
int main()
{
	// The width limit keeps the read inside the 200005-byte buffer `s`.
	if (scanf("%200004s", s) != 1) {
		return 1;
	}
	const int n = static_cast<int>(strlen(s));

	// allsame: every adjacent pair is equal.
	// same:    at least one adjacent pair is equal.
	bool allsame = true, same = false;
	for (int i = 0; i + 1 < n; i++) {
		if (s[i] == s[i + 1]) same = true;
		else allsame = false;
	}
	if (allsame) {
		printf("1\n");
		return 0;
	}
	if (n == 2) {
		printf("2\n");
		return 0;
	}
	if (n == 3) {
		// Reachable-set sizes for length 3, by exhaustive search:
		//   xxy / xyy -> 6  (aab -> acc -> bbc -> baa -> cca -> cbb -> aab)
		//   xyx       -> 7  (aba -> cca and aba -> acc, then the 6-cycle above)
		//   xyz       -> 3  (abc -> ccc and abc -> aaa, both stuck)
		int ans;
		if (same) ans = 6;
		else if (s[0] == s[2]) ans = 7;
		else ans = 3;
		printf("%d\n", ans);
		return 0;
	}

	// dp[i][r][c]: strings of length i + 1 with no two equal adjacent letters,
	// letter sum congruent to r (mod 3), and last letter c.
	dp[0][0][0] = dp[0][1][1] = dp[0][2][2] = 1;
	for (int i = 0; i < n - 1; i++) {
		for (int j = 0; j < 3; j++) {
			for (int k = 0; k < 3; k++) {
				for (int l = 0; l < 3; l++) {
					if (l != k) dp[i + 1][(j + l) % 3][l] += dp[i][j][k];
				}
			}
		}
	}

	int sum = 0;
	for (int i = 0; i < n; i++) sum += s[i] - 'a';
	sum %= 3;

	// For any first n-1 letters exactly one last letter fixes the residue, so
	// 3^(n-1) strings of length n have letter sum congruent to `sum`.
	mint ret = fastpow(mint(3), n - 1)
	         - dp[n - 1][sum][0] - dp[n - 1][sum][1] - dp[n - 1][sum][2]
	         + mint(same ? 0 : 1);
	// v is unsigned long long, so the matching conversion is %llu.
	printf("%llu\n", ret.v);
	return 0;
}

Submission Info

Submission Time
Task F - Normalization
User UminchuR
Language C++14 (Clang 3.8.0)
Score 0
Code Size 2328 Byte
Status CE

Compile Error

./Main.cpp:1:10: fatal error: 'bits/stdc++.h' file not found
#include <bits/stdc++.h>
         ^
1 error generated.