Code: https://github.com/OctopusMind/BBPE
1、正则表达式分词
支持多语种分词
pat_str = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
各部分含义如下:
's|'t|'re|'ve|'m|'ll|'d
- 匹配一些英文中最常见的缩写后缀(类似于词素/词缀):
- ’s(is 或 has 的缩写,如 it’s, John’s)
- ‘t(not 的缩写,如 can’t, won’t)
- ‘re(are 的缩写,如 you’re)
- ‘ve(have 的缩写,如 I’ve)
- ‘m(am 的缩写,如 I’m)
- ‘ll(will 的缩写,如 I’ll)
- ‘d(would 或 had 的缩写,如 I’d)
- 匹配一些英文中最常见的缩写后缀(类似于词素/词缀):
?[\p{L}]+
- 匹配以可选空格开始,后面连着至少一个“字母”(Unicode类别 L,涵盖所有语言的字母)。
- 例: hello、你好、Привет。
- 注意前面的空格是可选的,目的是保留词前可能的空格信息(对BPE分词很关键)。
?[\p{N}]+
- 匹配以可选空格开始,后面连着至少一个“数字”(Unicode类别 N)。
- 例: 2024、345(全角数字也可,因范围广)。
- 同样空格是 token 的一部分。
?[^\s\p{L}\p{N}]+
- 匹配以可选空格开始,后面是至少一个既不是空白(\s),也不是字母(\p{L}),也不是数字(\p{N})的字符。
- 例:标点、符号,如: !、,、@# 等。
\s+(?!\S)
- 匹配连续的空白字符(\s+),但后面不能跟非空白字符((?!\S))。即只能匹配那些在文本末尾的空格。
- 例:句子末尾的多个空格。
\s+
- 匹配一个或多个空白字符。用于捕获普通的空格分隔。
- 注意:整体顺序也很重要,整体分词时通常优先匹配前面的规则。
2、utf-8字节值到unicode可打印字符的映射
用于初始化词汇表:BBPE以字节为单位,初始化词汇表大小为256(0x00 -> 0xFF),将(0->255)映射到unicode字符串
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def bytes_to_unicode():
"""
返回utf-8字节列表和到unicode字符串的映射。我们特别避免映射到bbpe代码所依赖的空白/控制字符。
可逆的bbpe代码在unicode字符串上工作。这意味着如果您想避免UNKs,您需要在您的词汇表中使用大量的unicode字符。
当你有一个10B的token数据集时,你最终需要大约5K才能获得良好的覆盖。这是你正常情况下的一个显著比例,
比如说,32K的词汇量。为了避免这种情况,我们希望查找表介于utf-8字节和unicode字符串之间。
"""
# 初始化bs均为可打印字符
bs = (
# 33 - 126
list(range(ord("!"), ord("~") + 1)) + \
# 162 - 172
list(range(ord("¡"), ord("¬") + 1)) + \
# 174 - 255
list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
# 处理非可打印字符(映射到其他可打印的unicode字符)
for b in range(2 ** 8):
# 如果不在bs中,说明是非可打印字符,需要映射到其他可打印的unicode字符 (256 + n)
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
# 将cs转换为unicode字符串
cs = [chr(n) for n in cs]
# 构建bs到cs的映射
return dict(zip(bs, cs))
3、词汇表训练
从初始词表(字节的所有表示:0x00->0xff)开始,统计语料中相邻词汇的共现频次,然后合并频次最高的一对作为新的词汇加入到词汇表中,迭代进行直到满足预设要求(如词汇表大小)。
BPE的最小词汇是字符级别,BBPE的最小词汇是字节级别。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
@staticmethod
def train_tokenizer(data, vocab_size, vocab_outfile=None, merges_outfile=None):
"""
:param data: 训练文本
:param vocab_size: 保留词表的大小
:param vocab_outfile: 保存词表的文件名
:param merges_outfile: 保存合并字节的词表
"""
if vocab_size < 256:
raise ValueError("vocab_size must be greater than 256")
# 字节到unicode字符映射
byte_encoder = bytes_to_unicode()
# 正则表达式分词模式
pat_str = r"'s|'t|'re|'ve|'m|'ll|'d| ?[\p{L}]+| ?[\p{N}]+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
# 分词并转换为字节,并映射到unicode字符(通过byte_encoder)
split_words = [
[byte_encoder[b] for b in token.encode("utf-8")] for token in re.findall(pat_str, data)
]
# 词汇表初始化为基本的词汇,即bytes_to_unicode所包含的内容
vocab = set(byte_encoder.values())
merges = []
# 构建词汇表,直到满足所需的词汇量
while len(vocab) < vocab_size:
print(len(vocab))
pair_freq = Counter()
# 找出出现频次最多的字节对
for split_word in split_words:
pair_freq.update(zip(split_word[:-1], split_word[1:]))
most_common_pair = pair_freq.most_common(1)[0][0]
# 更新词汇表和合并列表
new_token = most_common_pair[0] + most_common_pair[1]
# 添加新词汇
vocab.add(new_token)
# 记录合并操作
merges.append(most_common_pair)
# 对数据执行合并
new_split_words = []
for split_word in split_words:
i = 0
new_word = []
# 对于单词中的每个重字符,尝试合并
while i < len(split_word) - 1:
# 如果(split_word[i], split_word[i + 1])与新合并的词汇一致,则添加新词汇,并跳过下一个字符
if (split_word[i], split_word[i + 1]) == most_common_pair:
new_word.append(new_token)
i += 2
else:
new_word.append(split_word[i])
i += 1
# 边界处理,如果i等于split_word的长度-1,则添加最后一个字符
if i == len(split_word) - 1:
new_word.append(split_word[i])
new_split_words.append(new_word)
# 更新split_words,用于下一轮迭代
split_words = new_split_words
vocab = sorted(list(vocab))
# 保存文件
if merges_outfile != None:
with open(merges_outfile, "w", encoding="utf-8") as f:
for merge in merges:
f.write(merge[0] + " " + merge[1] + "\n")
if vocab_outfile != None:
with open(vocab_outfile, "w", encoding="utf-8") as f:
json.dump({v: i for i, v in enumerate(vocab)}, f, ensure_ascii=False)
4、字符串编解码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class BBPETokenizer(nn.Module):
def __init__(self, vocab_path: str, merges_path: str):
super().__init__()
with open(vocab_path, "r", encoding="utf-8") as f: # 获得词表
vocab = json.load(f)
with open(merges_path, "r", encoding="utf-8") as f: # 获得合并token规则词表
merges = f.read()
# 将合并存储为元组列表,删除最后一个空白行
merges = [tuple(merge_str.split()) for merge_str in merges.split("\n")[:-1]]
# token到BBPE解码索引映射
self.encoder = vocab
self.decoder = {v: k for k, v in self.encoder.items()}
# 字节到unicode字符映射,256个字符
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
self.bbpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
# 预标记化拆分正则表达式模式
self.pat = re.compile(r"""
's|'t|'re|'ve|'m|'ll|'d| # 常见的缩写
\ ?\p{L}+|\ ?\p{N}+| # 可选空格,后跟1+ unicode字母或数字
\ ?[^\s\p{L}\p{N}]+| # 可选空格,后面跟着1+非空白/字母/数字
\s+(?!\S)| # 1+空白字符,后面没有非空白字符
\s+ # 1+空格字符
""", re.X)
def forward(self, text):
if isinstance(text, list):
# 批量编码
tokens = self.encode_batch(text)
tokens = [token for row in tokens for token in row]
else:
# 编码字符串
tokens = self.encode(text)
return torch.tensor(tokens)
def bbpe(self, token):
'''
对token应用合并规则
'''
if token in self.cache:
return self.cache[token]
chars = [i for i in token]
# 对于每个合并规则,尝试合并任何相邻的字符对
for pair in self.bbpe_ranks.keys():
i = 0
while i < len(chars) - 1:
if chars[i] == pair[0] and chars[i + 1] == pair[1]:
chars = chars[:i] + ["".join(pair)] + chars[i + 2:]
else:
i += 1
self.cache[token] = chars
return chars
def encode(self, text: str) -> list[int]:
'''
将字符串编码为BBPE token
'''
bbpe_tokens_id = []
# pattern使用要输入BBPE算法的正则表达式模式拆分文本
for token in re.findall(self.pat, text):
# 将token转换为其字节表示,将字节映射到其unicode表示
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
# 对token执行bbpe合并,然后根据编码器将结果映射到它们的bbpe索引
bbpe_tokens_id.extend(self.encoder[bpe_token] for bpe_token in self.bbpe(token))
return bbpe_tokens_id
def tokenize(self, text):
"""
获得编码后的字符
:param text: 文本
:return: 返回编码后的字符
"""
bbpe_tokens = []
# pattern使用要输入BBPE算法的正则表达式模式拆分文本
for token in re.findall(self.pat, text):
# 将token转换为其字节表示,将字节映射到其unicode表示
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
# 对token执行bbpe合并,然后根据编码器获得结果
bbpe_tokens.extend(bpe_token for bpe_token in self.bbpe(token))
return bbpe_tokens
def encode_batch(self, batch: list[str], num_threads=4):
'''
将字符串列表编码为BBPE token列表
'''
with ThreadPoolExecutor(max_workers=num_threads) as executor:
result = executor.map(self.encode, batch)
return list(result)
def decode(self, tokens) -> str:
if isinstance(tokens, torch.Tensor):
tokens = tokens.tolist()
text = "".join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors="replace")
return text