Fast I/O 구현 코드

안녕하세요 jinhan814입니다. 제가 문제를 풀면서 자주 사용하는 FastIO 구현 코드를 공유하면 좋을 거 같아서 BOJ 블로그에 글을 작성해보았습니다.؜

(참고 : https://blog.naver.com/jinhan814/222266396476)

؜

서론

백준 문제를 풀다 보면 '시간 초과'를 정말 많이 만나게 됩니다. 이렇게 시간 초과가 나는 경우는 크게 세 가지가 있습니다.

  1. 무한 루프가 생기는 경우
  2. 문제에서 요구하는 시간복잡도를 맞추지 못한 경우
  3. 시간복잡도는 맞췄지만, 시간 초과가 나는 경우

이번 게시글에서는 3번 케이스처럼 아쉽게 시간 초과가 나는 경우, 그중 문제의 입출력 양이 매우 많은 경우에 유용하게 사용할 수 있는 코드를 알아보겠습니다.

؜

본론

입출력 연산은 c++의 연산 중 가장 무거운 연산 중 하나입니다. 일반적인 연산은 대략 1초에 1억 번 정도의 연산이 가능하다고 알려져있지만, 입출력의 경우엔 입력되는 값이 10만 개만 되더라도 실행 시간에 적지 않은 영향을 미칩니다. (참고)

이런 경우 c++의 cin, cout 대신 fread, fwrite 등의 Low-Level I/O 함수를 이용하면 입출력을 더 빠르게 처리할 수 있습니다. 하지만 이런 로우 레벨 함수들은 cin, cout처럼 다양한 자료형을 처리하는 기능이 구현되어있지 않아서 char[]에 입력을 받은 뒤 직접 파싱을 해주면서 사용할 자료형으로 변환하는 작업이 필요합니다.

다행히도 이런 변환 작업을 구현해둔 코드가 여러 개 있어서 이 부분은 그냥 구현 코드를 가져다 쓰면 됩니다. 이제 제 FastIO 구현 코드를 알아보겠습니다. 본 구현 코드는 topology님의 구현을 참고해 작성되었습니다.

؜

제 코드는 크게 4가지 부분으로 구성되어있습니다.

  1. fread(또는 mmap)를 이용해 입력을 처리하는 클래스
  2. fwrite(또는 write)를 이용해 출력을 처리하는 클래스
  3. cin, cout의 >>, << 연산자를 지원하기 위한 연산자 오버로딩
  4. 코드의 cin, cout을 대체하기 위한 매크로들

1, 2번 부분은 FastIO 구현을 참고하시면 어렵지 않게 이해할 수 있으실 겁니다. 3, 4번 부분에서는 cin, cout을 매크로로 정의해두고 연산자 오버로딩을 해뒀기 때문에 기존의 cin, cout을 이용해 입력을 받도록 작성된 코드에 그냥 구현 코드를 복사 붙여넣기만 하시면 cin, cout이 fread, fwrite로 구현된 빠른 입출력 함수로 대체됩니다.

؜

+) 빠른 A+B(BOJ 15552) 문제에서 실행 시간을 측정한 결과는 다음과 같습니다. 일반적으로 mmap 코드가 더 빠르지만 메모리를 더 많이 사용합니다.

  1. cin, cout + ios::sync_with_stdio(0), cin.tie(0) (244 ms, 코드)
  2. fread, fwrite (28 ms, 코드)
  3. mmap, write (24 ms, 코드)

؜

코드

//구현 1. fread/fwrite 이용
#pragma GCC optimize("O3")
#pragma GCC target("avx2")
#include <bits/stdc++.h>
#include <unistd.h>
using namespace std;

/////////////////////////////////////////////////////////////////////////////////////////////
/*
 * Author : jinhan814
 * Date : 2021-05-06
 * Source : https://blog.naver.com/jinhan814/222266396476
 * Description : FastIO implementation for cin, cout.
 */
constexpr int SZ = 1 << 20;

class INPUT {
private:
    char read_buf[SZ];
    int read_idx, next_idx;
    bool __END_FLAG__, __GETLINE_FLAG__;
public:
    explicit operator bool() { return !__END_FLAG__; }
    bool IsBlank(char c) { return c == ' ' || c == '\n'; }
    bool IsEnd(char c) { return c == '\0'; }
    char _ReadChar() {
        if (read_idx == next_idx) {
            next_idx = fread(read_buf, sizeof(char), SZ, stdin);
            if (next_idx == 0) return 0;
            read_idx = 0;
        }
        return read_buf[read_idx++];
    }
    char ReadChar() {
        char ret = _ReadChar();
        for (; IsBlank(ret); ret = _ReadChar());
        return ret;
    }
    template<typename T> T ReadInt() {
        T ret = 0; char cur = _ReadChar(); bool flag = 0;
        for (; IsBlank(cur); cur = _ReadChar());
        if (cur == '-') flag = 1, cur = _ReadChar();
        for (; !IsBlank(cur) && !IsEnd(cur); cur = _ReadChar()) ret = 10 * ret + (cur & 15);
        if (IsEnd(cur)) __END_FLAG__ = 1;
        return flag ? -ret : ret;
    }
    string ReadString() {
        string ret; char cur = _ReadChar();
        for (; IsBlank(cur); cur = _ReadChar());
        for (; !IsBlank(cur) && !IsEnd(cur); cur = _ReadChar()) ret.push_back(cur);
        if (IsEnd(cur)) __END_FLAG__ = 1;
        return ret;
    }
    double ReadDouble() {
        string ret = ReadString();
        return stod(ret);
    }
    string getline() {
        string ret; char cur = _ReadChar();
        for (; cur != '\n' && !IsEnd(cur); cur = _ReadChar()) ret.push_back(cur);
        if (__GETLINE_FLAG__) __END_FLAG__ = 1;
        if (IsEnd(cur)) __GETLINE_FLAG__ = 1;
        return ret;
    }
    friend INPUT& getline(INPUT& in, string& s) { s = in.getline(); return in; }
} _in;

class OUTPUT {
private:
    char write_buf[SZ];
    int write_idx;
public:
    ~OUTPUT() { Flush(); }
    explicit operator bool() { return 1; }
    void Flush() {
        fwrite(write_buf, sizeof(char), write_idx, stdout);
        write_idx = 0;
    }
    void WriteChar(char c) {
        if (write_idx == SZ) Flush();
        write_buf[write_idx++] = c;
    }
    template<typename T> int GetSize(T n) {
        int ret = 1;
        for (n = n >= 0 ? n : -n; n >= 10; n /= 10) ret++;
        return ret;
    }
    template<typename T> void WriteInt(T n) {
        int sz = GetSize(n);
        if (write_idx + sz >= SZ) Flush();
        if (n < 0) write_buf[write_idx++] = '-', n = -n;
        for (int i = sz; i-- > 0; n /= 10) write_buf[write_idx + i] = n % 10 | 48;
        write_idx += sz;
    }
    void WriteString(string s) { for (auto& c : s) WriteChar(c); }
    void WriteDouble(double d) { WriteString(to_string(d)); }
} _out;

/* operators */
INPUT& operator>> (INPUT& in, char& i) { i = in.ReadChar(); return in; }
INPUT& operator>> (INPUT& in, string& i) { i = in.ReadString(); return in; }
template<typename T, typename std::enable_if_t<is_arithmetic_v<T>>* = nullptr>
INPUT& operator>> (INPUT& in, T& i) {
    if constexpr (is_floating_point_v<T>) i = in.ReadDouble();
    else if constexpr (is_integral_v<T>) i = in.ReadInt<T>(); return in; }

OUTPUT& operator<< (OUTPUT& out, char i) { out.WriteChar(i); return out; }
OUTPUT& operator<< (OUTPUT& out, string i) { out.WriteString(i); return out; }
template<typename T, typename std::enable_if_t<is_arithmetic_v<T>>* = nullptr>
OUTPUT& operator<< (OUTPUT& out, T i) {
    if constexpr (is_floating_point_v<T>) out.WriteDouble(i);
    else if constexpr (is_integral_v<T>) out.WriteInt<T>(i); return out; }

/* macros */
#define fastio 1
#define cin _in
#define cout _out
#define istream INPUT
#define ostream OUTPUT
/////////////////////////////////////////////////////////////////////////////////////////////
//구현 2. mmap/write 이용
#pragma GCC optimize("O3")
#pragma GCC target("avx2")
#include <bits/stdc++.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <unistd.h>
using namespace std;

/////////////////////////////////////////////////////////////////////////////////////////////
/*
 * Author : jinhan814
 * Date : 2021-05-06
 * Source : https://blog.naver.com/jinhan814/222266396476
 * Description : FastIO implementation for cin, cout. (mmap ver.)
 */
constexpr int SZ = 1 << 20;

class INPUT {
private:
    char* p;
    bool __END_FLAG__, __GETLINE_FLAG__;
public:
    explicit operator bool() { return !__END_FLAG__; }
    INPUT() {
        struct stat st; fstat(0, &st);
        p = (char*)mmap(0, st.st_size, PROT_READ, MAP_SHARED, 0, 0);
    }
    bool IsBlank(char c) { return c == ' ' || c == '\n'; }
    bool IsEnd(char c) { return c == '\0'; }
    char _ReadChar() { return *p++; }
    char ReadChar() {
        char ret = _ReadChar();
        for (; IsBlank(ret); ret = _ReadChar());
        return ret;
    }
    template<typename T> T ReadInt() {
        T ret = 0; char cur = _ReadChar(); bool flag = 0;
        for (; IsBlank(cur); cur = _ReadChar());
        if (cur == '-') flag = 1, cur = _ReadChar();
        for (; !IsBlank(cur) && !IsEnd(cur); cur = _ReadChar()) ret = 10 * ret + (cur & 15);
        if (IsEnd(cur)) __END_FLAG__ = 1;
        return flag ? -ret : ret;
    }
    string ReadString() {
        string ret; char cur = _ReadChar();
        for (; IsBlank(cur); cur = _ReadChar());
        for (; !IsBlank(cur) && !IsEnd(cur); cur = _ReadChar()) ret.push_back(cur);
        if (IsEnd(cur)) __END_FLAG__ = 1;
        return ret;
    }
    double ReadDouble() {
        string ret = ReadString();
        return stod(ret);
    }
    string getline() {
        string ret; char cur = _ReadChar();
        for (; cur != '\n' && !IsEnd(cur); cur = _ReadChar()) ret.push_back(cur);
        if (__GETLINE_FLAG__) __END_FLAG__ = 1;
        if (IsEnd(cur)) __GETLINE_FLAG__ = 1;
        return ret;
    }
    friend INPUT& getline(INPUT& in, string& s) { s = in.getline(); return in; }
} _in;

class OUTPUT {
private:
    char write_buf[SZ];
    int write_idx;
public:
    ~OUTPUT() { Flush(); }
    explicit operator bool() { return 1; }
    void Flush() {
        write(1, write_buf, write_idx);
        write_idx = 0;
    }
    void WriteChar(char c) {
        if (write_idx == SZ) Flush();
        write_buf[write_idx++] = c;
    }
    template<typename T> int GetSize(T n) {
        int ret = 1;
        for (n = n >= 0 ? n : -n; n >= 10; n /= 10) ret++;
        return ret;
    }
    template<typename T> void WriteInt(T n) {
        int sz = GetSize(n);
        if (write_idx + sz >= SZ) Flush();
        if (n < 0) write_buf[write_idx++] = '-', n = -n;
        for (int i = sz; i --> 0; n /= 10) write_buf[write_idx + i] = n % 10 | 48;
        write_idx += sz;
    }
    void WriteString(string s) { for (auto& c : s) WriteChar(c); }
    void WriteDouble(double d) { WriteString(to_string(d)); }
} _out;

/* operators */
INPUT& operator>> (INPUT& in, char& i) { i = in.ReadChar(); return in; }
INPUT& operator>> (INPUT& in, string& i) { i = in.ReadString(); return in; }
template<typename T, typename std::enable_if_t<is_arithmetic_v<T>>* = nullptr>
INPUT& operator>> (INPUT& in, T& i) {
    if constexpr (is_floating_point_v<T>) i = in.ReadDouble();
    else if constexpr (is_integral_v<T>) i = in.ReadInt<T>(); return in; }

OUTPUT& operator<< (OUTPUT& out, char i) { out.WriteChar(i); return out; }
OUTPUT& operator<< (OUTPUT& out, string i) { out.WriteString(i); return out; }
template<typename T, typename std::enable_if_t<is_arithmetic_v<T>>* = nullptr>
OUTPUT& operator<< (OUTPUT& out, T i) {
    if constexpr (is_floating_point_v<T>) out.WriteDouble(i);
    else if constexpr (is_integral_v<T>) out.WriteInt<T>(i); return out; }

/* macros */
#define fastio 1
#define cin _in
#define cout _out
#define istream INPUT
#define ostream OUTPUT
/////////////////////////////////////////////////////////////////////////////////////////////

؜

사용 시 주의점

첫 번째 구현과 두 번째 구현 모두 헤더부터 밑의 ///...//까지 복사한 뒤 코드 상단에 붙여넣기해서 사용하시면 됩니다. (예시)

구현 코드 중 SZ는 입출력 버퍼의 크기를 지정하는 변수로 작게 잡으면 메모리를 적게 사용하는 대신 버퍼를 비우는 일이 많아져서 시간이 오래 걸리게 되고, 크게 잡으면 메모리를 많이 사용하는 대신 시간이 적게 걸리게 됩니다. 이 부분은 각 문제마다 적절히 SZ값을 바꿔가면서 설정해주시면 좋을 거 같습니다.

실수 입출력의 경우 구현에 stod, to_string 함수를 가져다 사용했기 때문에 fixed, setprecision등의 기능을 지원하지 않습니다.

p.s. 위 코드는 c++17에서 작성되었기 때문에 다른 버전에서는 잘 작동하지 않을 수 있습니다.

؜

여담

문제를 풀다보면 입력이 음이 아닌 정수만 들어오는 등의 제약조건이 있는 경우가 있습니다. 이런 경우엔 굳이 '-' 등이 존재하는지 확인할 필요가 없기 때문에 이런 부분을 제외하고 직접 파싱을 해주면 입출력을 더 빠르게 처리할 수 있습니다. (예시)

이외에도 문자열만 입력으로 들어오는 경우 char[]에 입력을 받은 뒤 입력 버퍼 자체를 문자열로 이용해주면 string을 이용하는 것보다 빠르게 문제를 풀 수 있습니다. (예시)

자세한 구현 원리가 궁금하시거나 위의 예시처럼 문제에 맞게 FastIO를 커스텀해서 사용하고 싶으시다면 우측의 링크를 참고해주세요. (참고)

؜

reference

topology님 FastIO 구현 코드 (제출 번호 5178388번)

https://cgiosy.github.io/posts/fast-io/

https://panty.run/fastio/


댓글 (7개) 댓글 쓰기


gkswns3708 3년 전

최곱니다


ejqmf94 3년 전

감사합니다


ilbul2 3년 전

혹시 참새를 싫어하나요?


hs9200 2년 전

미쳤다...!


notfound404 2년 전

잘 배워갑니다.


dohoon 2년 전

백준 블로그 중 가장 유익했어요


pianop 2년 전

감사합니다 선생님! 열심히 공부해서 잘 사용할 수 있도록 하겠습니다!