Skip to content

Commit c9a4a70

Browse files
authored
Fix regex lookahead for code input tokenization (#314)
1 parent 6d671d2 commit c9a4a70

5 files changed

Lines changed: 57 additions & 17 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ build/
2727

2828
# data
2929
/data/
30+
*.log

chatglm.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,6 +1766,7 @@ TiktokenCoreBPE::TiktokenCoreBPE(std::unordered_map<std::string, int> encoder,
17661766
: regex(std::make_unique<RE2>("(" + pattern + ")")), encoder(std::move(encoder)),
17671767
special_tokens_encoder(std::move(special_tokens_encoder)) {
17681768
CHATGLM_CHECK(regex->ok()) << regex->error();
1769+
CHATGLM_CHECK(regex->NumberOfCapturingGroups() <= 2) << "unimplemented";
17691770

17701771
decoder.reserve(this->encoder.size());
17711772
for (const auto &[token, rank] : this->encoder) {
@@ -1853,24 +1854,24 @@ std::vector<int> TiktokenCoreBPE::byte_pair_encode(const std::string &piece,
18531854

18541855
std::vector<int> TiktokenCoreBPE::_encode_ordinary_native(const std::string &text) const {
18551856
std::vector<int> ret;
1856-
re2::StringPiece input = text;
1857-
re2::StringPiece prev_input = input;
1857+
re2::StringPiece input(text);
1858+
re2::StringPiece prev_input(input);
18581859
std::string piece;
1859-
while (RE2::FindAndConsume(&input, *regex, &piece)) {
1860-
// recover input in case of negative lookahead
1861-
if (prev_input.find(piece) == 0) {
1862-
input = prev_input.substr(piece.size());
1863-
prev_input = input;
1864-
} else {
1865-
std::cerr << "[WARN] chatglm.cpp: encounter unknown token\n";
1860+
std::string piece2;
1861+
while (RE2::FindAndConsume(&input, *regex, &piece, &piece2)) {
1862+
if (!piece2.empty()) {
1863+
// workaround for RE2's lack of lookahead: if the sub-group matched, use it as the piece and rewind input to just after it
1864+
auto pos = prev_input.find(piece2);
1865+
input = prev_input.substr(pos + piece2.length());
1866+
piece = std::move(piece2);
18661867
}
1867-
18681868
if (auto it = encoder.find(piece); it != encoder.end()) {
18691869
ret.emplace_back(it->second);
18701870
} else {
18711871
std::vector<int> bpe_ids = byte_pair_encode(piece, encoder);
18721872
ret.insert(ret.end(), bpe_ids.begin(), bpe_ids.end());
18731873
}
1874+
prev_input = input;
18741875
}
18751876
return ret;
18761877
}
@@ -1930,7 +1931,7 @@ ChatGLM4Tokenizer::ChatGLM4Tokenizer(const std::string &vocab_text) {
19301931
observation_token_id = special_tokens_encoder.at("<|observation|>");
19311932

19321933
const std::string pattern =
1933-
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|\s)|\s+)";
1934+
R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|(\s+)(?:\s)|\s+)";
19341935
core_bpe = TiktokenCoreBPE(std::move(mergeable_ranks), std::move(special_tokens_encoder), pattern);
19351936
}
19361937

chatglm_cpp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import chatglm_cpp._C as _C
77
from chatglm_cpp._C import ChatMessage
88

9-
__version__ = "0.3.3"
9+
__version__ = "0.3.4"
1010

1111

1212
@dataclass

chatglm_test.cpp

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,7 @@ TEST(Pipeline, ChatGLM4) {
13381338
Pipeline pipeline(model_path.string());
13391339
ASSERT_TRUE(dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get()));
13401340
ASSERT_TRUE(dynamic_cast<ChatGLM4ForCausalLM *>(pipeline.model.get()));
1341+
auto tokenizer = dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get());
13411342

13421343
// const std::string system_tool_call =
13431344
// read_text(fs::path(__FILE__).parent_path() / "examples/system/function_call.txt");
@@ -1346,8 +1347,6 @@ TEST(Pipeline, ChatGLM4) {
13461347

13471348
// tiktoken
13481349
{
1349-
auto tokenizer = dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get());
1350-
13511350
// taken from:
13521351
// https://github.com/ggerganov/llama.cpp/blob/4bfe50f741479c1df1c377260c3ff5702586719e/convert-hf-to-gguf.py#L413
13531352
const std::string chktxt =
@@ -1372,7 +1371,30 @@ TEST(Pipeline, ChatGLM4) {
13721371
498, 2704, 30, 364, 44, 537, 2704, 358, 3278, 1281, 432, 11, 364, 35,
13731372
498, 1075, 1045, 15231, 30, 1205, 6, 42368, 264, 63409, 43};
13741373

1375-
std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(chktxt);
1374+
const std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(chktxt);
1375+
EXPECT_EQ(ref_ids, out_ids);
1376+
}
1377+
{
1378+
const std::string text = R"(
1379+
```c++
1380+
#include <iostream>
1381+
1382+
int main() {
1383+
printf("hello world\n"); // say hello
1384+
}
1385+
```
1386+
1387+
```python
1388+
if __name__ == '__main__':
1389+
print('hello world') # say hello
1390+
```
1391+
)";
1392+
const std::vector<int> ref_ids = {198, 73022, 66, 22879, 1067, 366, 9661, 1339, 396, 1887, 368,
1393+
341, 262, 4100, 445, 14978, 1879, 1699, 5038, 262, 442, 1977,
1394+
23745, 198, 532, 13865, 19288, 73022, 12663, 198, 333, 1304, 606,
1395+
563, 621, 12106, 3817, 16165, 262, 1173, 492, 14978, 1879, 863,
1396+
286, 671, 1977, 23745, 198, 13865, 3989};
1397+
const std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(text);
13761398
EXPECT_EQ(ref_ids, out_ids);
13771399
}
13781400
// tokenizer

tests/test_convert.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,22 @@ def make_glm4_pipeline_data():
803803
chktxt = "\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````\"\"\"\"......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL"
804804
print("tiktoken:", tokenizer.tokenizer.encode(chktxt, disallowed_special=()))
805805

806+
chktxt = r"""
807+
```c++
808+
#include <iostream>
809+
810+
int main() {
811+
printf("hello world\n"); // say hello
812+
}
813+
```
814+
815+
```python
816+
if __name__ == '__main__':
817+
print('hello world') # say hello
818+
```
819+
"""
820+
print("tiktoken:", tokenizer.tokenizer.encode(chktxt, disallowed_special=()))
821+
806822
# tokenizer
807823
inputs = tokenizer("你好")
808824
print(f"encode: {inputs=}")
@@ -861,8 +877,8 @@ def main():
861877
# make_data_baichuan7b_model()
862878
# make_data_baichuan13b_model()
863879
# make_internlm_model()
864-
make_data_glm4_model()
865-
# make_glm4_pipeline_data()
880+
# make_data_glm4_model()
881+
make_glm4_pipeline_data()
866882

867883

868884
if __name__ == "__main__":

0 commit comments

Comments
 (0)