Home BBPE笔记记录
Post
Cancel

BBPE笔记记录

Code: https://github.com/OctopusMind/BBPE

1、正则表达式分词

支持多语种分词

pat_str = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"

各部分含义如下:

  1. 's|'t|'re|'ve|'m|'ll|'d
    • 匹配一些英文中最常见的缩写后缀(类似于词素/词缀):
      • ’s(is 或 has 的缩写,如 it’s, John’s)
      • ‘t(not 的缩写,如 can’t, won’t)
      • ‘re(are 的缩写,如 you’re)
      • ‘ve(have 的缩写,如 I’ve)
      • ‘m(am 的缩写,如 I’m)
      • ‘ll(will 的缩写,如 I’ll)
      • ‘d(would 或 had 的缩写,如 I’d)
  2. ?[\p{L}]+
    • 匹配以可选空格开始,后面连着至少一个“字母”(Unicode类别 L,涵盖所有语言的字母)。
    • 例: hello、你好、Привет。
    • 注意前面的空格是可选的,目的是保留词前可能的空格信息(对BPE分词很关键)。
  3. ?[\p{N}]+
    • 匹配以可选空格开始,后面连着至少一个“数字”(Unicode类别 N)。
    • 例: 2024、345(全角数字也可,因范围广)。
    • 同样空格是 token 的一部分。
  4. ?[^\s\p{L}\p{N}]+
    • 匹配以可选空格开始,后面是至少一个既不是空白(\s),也不是字母(\p{L}),也不是数字(\p{N})的字符。
    • 例:标点、符号,如: !、,、@# 等。
  5. \s+(?!\S)
    • 匹配连续的空白字符(\s+),但后面不能跟非空白字符((?!\S))。即只能匹配那些在文本末尾的空格。
    • 例:句子末尾的多个空格。
  6. \s+
    • 匹配一个或多个空白字符。用于捕获普通的空格分隔。
    • 注意:整体顺序也很重要,整体分词时通常优先匹配前面的规则。

2、utf-8字节值到unicode可打印字符的映射

用于初始化词汇表:BBPE以字节为单位,初始化词汇表大小为256(0x00 -> 0xFF),将(0->255)映射到unicode字符串

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def bytes_to_unicode():
    """
    返回utf-8字节列表和到unicode字符串的映射。我们特别避免映射到bbpe代码所依赖的空白/控制字符。
    可逆的bbpe代码在unicode字符串上工作。这意味着如果您想避免UNKs,您需要在您的词汇表中使用大量的unicode字符。
    当你有一个10B的token数据集时,你最终需要大约5K才能获得良好的覆盖。这是你正常情况下的一个显著比例,
    比如说,32K的词汇量。为了避免这种情况,我们希望查找表介于utf-8字节和unicode字符串之间。
    """
    # 初始化bs均为可打印字符
    bs = (
            # 33 - 126
            list(range(ord("!"), ord("~") + 1)) + \
            # 162 - 172
            list(range(ord("¡"), ord("¬") + 1)) + \
            # 174 - 255
            list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    # 处理非可打印字符(映射到其他可打印的unicode字符)
    for b in range(2 ** 8):
        # 如果不在bs中,说明是非可打印字符,需要映射到其他可打印的unicode字符 (256 + n)
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    # 将cs转换为unicode字符串
    cs = [chr(n) for n in cs]
    # 构建bs到cs的映射
    return dict(zip(bs, cs))

3、词汇表训练

从初始词表(字节的所有表示:0x00->0xff)开始,统计语料中相邻词汇的共现频次,然后合并频次最高的一对作为新的词汇加入到词汇表中,迭代进行直到满足预设要求(如词汇表大小)。

BPE的最小词汇是字符级别,BBPE的最小词汇是字节级别。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    @staticmethod
    def train_tokenizer(data, vocab_size, vocab_outfile=None, merges_outfile=None):
        """
        :param data: 训练文本
        :param vocab_size: 保留词表的大小
        :param vocab_outfile: 保存词表的文件名
        :param merges_outfile: 保存合并字节的词表
        """

        if vocab_size < 256:
            raise ValueError("vocab_size must be greater than 256")

        # 字节到unicode字符映射
        byte_encoder = bytes_to_unicode()
        # 正则表达式分词模式
        pat_str = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
        # 分词并转换为字节,并映射到unicode字符(通过byte_encoder)
        split_words = [
            [byte_encoder[b] for b in token.encode("utf-8")] for token in re.findall(pat_str, data)
        ]
        # 词汇表初始化为基本的词汇,即bytes_to_unicode所包含的内容
        vocab = set(byte_encoder.values())
        merges = []

        # 构建词汇表,直到满足所需的词汇量
        while len(vocab) < vocab_size:
            print(len(vocab))
            pair_freq = Counter()
            # 找出出现频次最多的字节对
            for split_word in split_words:
                pair_freq.update(zip(split_word[:-1], split_word[1:]))
            most_common_pair = pair_freq.most_common(1)[0][0]

            # 更新词汇表和合并列表
            new_token = most_common_pair[0] + most_common_pair[1]
            # 添加新词汇
            vocab.add(new_token)
            # 记录合并操作
            merges.append(most_common_pair)

            # 对数据执行合并
            new_split_words = []
            for split_word in split_words:
                i = 0
                new_word = []
                # 对于单词中的每个重字符,尝试合并
                while i < len(split_word) - 1:
                    # 如果(split_word[i], split_word[i + 1])与新合并的词汇一致,则添加新词汇,并跳过下一个字符
                    if (split_word[i], split_word[i + 1]) == most_common_pair:
                        new_word.append(new_token)
                        i += 2
                    else:
                        new_word.append(split_word[i])
                        i += 1
                # 边界处理,如果i等于split_word的长度-1,则添加最后一个字符
                if i == len(split_word) - 1:
                    new_word.append(split_word[i])
                new_split_words.append(new_word)
            # 更新split_words,用于下一轮迭代
            split_words = new_split_words

        vocab = sorted(list(vocab))
        # 保存文件
        if merges_outfile != None:
            with open(merges_outfile, "w", encoding="utf-8") as f:
                for merge in merges:
                    f.write(merge[0] + " " + merge[1] + "\n")
        if vocab_outfile != None:
            with open(vocab_outfile, "w", encoding="utf-8") as f:
                json.dump({v: i for i, v in enumerate(vocab)}, f, ensure_ascii=False)

4、字符串编解码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class BBPETokenizer(nn.Module):

    def __init__(self, vocab_path: str, merges_path: str):
        super().__init__()
        with open(vocab_path, "r", encoding="utf-8") as f:  # 获得词表
            vocab = json.load(f)
        with open(merges_path, "r", encoding="utf-8") as f:  # 获得合并token规则词表
            merges = f.read()

        # 将合并存储为元组列表,删除最后一个空白行
        merges = [tuple(merge_str.split()) for merge_str in merges.split("\n")[:-1]]

        # token到BBPE解码索引映射
        self.encoder = vocab
        self.decoder = {v: k for k, v in self.encoder.items()}

        # 字节到unicode字符映射,256个字符
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        self.bbpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

        # 预标记化拆分正则表达式模式
        self.pat = re.compile(r"""
                                 's|'t|'re|'ve|'m|'ll|'d|  # 常见的缩写
                                 \ ?\p{L}+|\ ?\p{N}+|  # 可选空格,后跟1+ unicode字母或数字
                                 \ ?[^\s\p{L}\p{N}]+|  # 可选空格,后面跟着1+非空白/字母/数字
                                 \s+(?!\S)|  # 1+空白字符,后面没有非空白字符
                                 \s+  # 1+空格字符
                                 """, re.X)

    def forward(self, text):
        if isinstance(text, list):
            # 批量编码
            tokens = self.encode_batch(text)
            tokens = [token for row in tokens for token in row]
        else:
            # 编码字符串
            tokens = self.encode(text)
        return torch.tensor(tokens)

    def bbpe(self, token):
        '''
        对token应用合并规则
        '''
        if token in self.cache:
            return self.cache[token]

        chars = [i for i in token]
        # 对于每个合并规则,尝试合并任何相邻的字符对
        for pair in self.bbpe_ranks.keys():
            i = 0
            while i < len(chars) - 1:
                if chars[i] == pair[0] and chars[i + 1] == pair[1]:
                    chars = chars[:i] + ["".join(pair)] + chars[i + 2:]
                else:
                    i += 1
        self.cache[token] = chars
        return chars

    def encode(self, text: str) -> list[int]:
        '''
        将字符串编码为BBPE token
        '''
        bbpe_tokens_id = []
        # pattern使用要输入BBPE算法的正则表达式模式拆分文本
        for token in re.findall(self.pat, text):
            # 将token转换为其字节表示,将字节映射到其unicode表示
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            # 对token执行bbpe合并,然后根据编码器将结果映射到它们的bbpe索引
            bbpe_tokens_id.extend(self.encoder[bpe_token] for bpe_token in self.bbpe(token))
        return bbpe_tokens_id

    def tokenize(self, text):
        """
        获得编码后的字符
        :param text: 文本
        :return: 返回编码后的字符
        """
        bbpe_tokens = []
        # pattern使用要输入BBPE算法的正则表达式模式拆分文本
        for token in re.findall(self.pat, text):
            # 将token转换为其字节表示,将字节映射到其unicode表示
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            # 对token执行bbpe合并,然后根据编码器获得结果
            bbpe_tokens.extend(bpe_token for bpe_token in self.bbpe(token))
        return bbpe_tokens

    def encode_batch(self, batch: list[str], num_threads=4):
        '''
        将字符串列表编码为BBPE token列表
        '''
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            result = executor.map(self.encode, batch)
        return list(result)

    def decode(self, tokens) -> str:
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace")
        return text
This post is licensed under CC BY 4.0 by the author.

RoPE + YARN

-