A lot of weirdness with LLMs can be traced back to tokenization. For example, GPT-2 was worse at non-English languages partly because of its tokenizer. English text was split into relatively long tokens, while non-English text was split into many more, shorter tokens. As a result, the same content took up more of the context window. Attention had to work across more tokens to cover the same concept, so the model effectively saw less of the text at once than it would for the equivalent English.
The same problem applied to code. The GPT-2 tokenizer handled whitespace in Python poorly, and Python uses a lot of indentation. Each space was often its own token, so code bloated into many more tokens and less of the program fit into the context the attention heads could see.
In both cases, later tokenizers increased the vocabulary size (e.g., adding longer tokens for non-English words and for runs of spaces). Note that a bigger vocabulary isn't strictly better, though. It enlarges the embeddings table that converts tokens into vectors representing their meaning. It also enlarges the final linear layer and softmax at the end of the model, where we predict the next token over the whole vocabulary. There seems to be a sweet spot for the vocabulary size.
Alright, so how do we go about this tokenization business?
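Can we just use Unicode code points directly? There are too many of them (around 150K defined characters), so the vocabulary would be huge, and the Unicode standard keeps changing. What about encodings like UTF-8, UTF-16, and UTF-32? These define how Unicode code points are turned into bytes. UTF-8 is the most common: it maps each code point to 1-4 bytes, while UTF-16 and UTF-32 spend more bytes on ASCII-heavy text. But using raw UTF-8 bytes directly gives a vocabulary of only 256 and very long sequences, which eats up the context window. So we start from UTF-8 bytes and compress them with the Byte Pair Encoding algorithm!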
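Here is a minimal sketch of the trade-off, using Python's built-in str.encode. The sample strings are arbitrary choices, not from the original post. UTF-8 spends 1 byte on ASCII and more on other scripts, while UTF-32 spends 4 bytes on every code point.

```python
# Compare how many bytes each encoding needs for the same short strings.
samples = ["hello", "héllo", "नमस्ते", "👋"]

for s in samples:
    # the "-le" variants avoid the byte-order mark that "utf-16"/"utf-32" prepend
    sizes = {enc: len(s.encode(enc)) for enc in ("utf-8", "utf-16-le", "utf-32-le")}
    print(repr(s), len(s), "code points", sizes)

# Raw UTF-8 bytes give a tiny vocabulary (256 values) but longer sequences:
print(list("hi 👋".encode("utf-8")))  # [104, 105, 32, 240, 159, 145, 139]
```

The emoji alone costs 4 bytes, and each Devanagari character costs 3. This is why byte-level sequences get long quickly for non-ASCII text.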
Byte Pair Encoding Algorithm
Find the pair of adjacent bytes that occurs most often and replace it with a new byte that is not used in the data. Repeat the process until the data cannot be compressed further. We do the same thing for tokenization. We start with the bytes from the UTF-8 encoding, then iteratively find the most common pair and mint a new token for it. In practice we stop when we reach a target vocabulary size, which is a hyperparameter.
Here is the code to implement the byte pair encoding algorithm:
```python
def get_stats(ids):
    # count how often each adjacent pair of ids occurs
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # replace every occurrence of `pair` in `ids` (left to right) with the new token `idx`
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

# ---
# `tokens` is assumed to hold the UTF-8 bytes of the training text,
# e.g. tokens = list(text.encode("utf-8"))
vocab_size = 276  # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens)  # copy so we don't destroy the original list
merges = {}  # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)  # the most frequent pair
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx
```
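To make the training loop concrete, here is a self-contained sketch that runs three merges on the toy string "aaabdaaabac". That string is a classic BPE illustration, chosen here as an assumption. The sketch uses collections.Counter for the pair counts, which is an implementation choice rather than necessarily how the original code does it. Ties between equally frequent pairs go to the pair seen first.

```python
from collections import Counter

def get_stats(ids):
    # frequency of each adjacent pair of ids
    return Counter(zip(ids, ids[1:]))

def merge(ids, pair, idx):
    # replace each non-overlapping occurrence of `pair` (left to right) with `idx`
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train(text, num_merges):
    ids = list(text.encode("utf-8"))
    merges = {}
    for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break  # fewer than two tokens left, nothing to merge
        pair = max(stats, key=stats.get)  # first most-frequent pair wins ties
        idx = 256 + i
        ids = merge(ids, pair, idx)
        merges[pair] = idx
    return ids, merges

ids, merges = train("aaabdaaabac", 3)
print(ids)     # [258, 100, 258, 97, 99]
print(merges)  # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
```

After three merges, the 11 bytes shrink to 5 tokens. Token 256 is "aa", 257 is "aaa", and 258 is "aaab".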
When "training" the tokenizer, you can use different training datasets - you might want a certain mixture of languages (not just all English, but other languages too, even if they don't occur too much in the LLM's training set). Intuitively: suppose you have a bunch of Hindi data in your tokenizer's training dataset. Then there will be more merges of bytes from Hindi encodings, and Hindi will have shorter sequences - which is beneficial for the attention head, since it has less context length in token space to reason over.
Now, given a trained tokenizer, how do we do our encoding and decoding?
```python
# start with the 256 single-byte tokens, then build each merged token
# from its two children, in the order the merges were created
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    # given ids (list of integers), return Python string
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text

print(decode([128]))
```
Some notes: in for (p0, p1), idx in merges.items(), we need merges.items() to yield the merges in the order we inserted them. Why? Each merged token is built from two earlier tokens. If we iterated in some other order, we could reach a merge whose child (a token minted by an earlier merge) doesn't have an entry in the vocab dictionary yet. Luckily, since Python 3.7, dictionaries are guaranteed to preserve insertion order.
Also, in text = tokens.decode("utf-8", errors="replace"), we use errors="replace" because the LLM can output a sequence of tokens whose bytes are not valid UTF-8 (for example, a lone byte 128). By default the errors parameter is "strict", which raises a UnicodeDecodeError. Setting it to "replace" substitutes the replacement character (U+FFFD) for invalid bytes instead.
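Here is a small sketch of why the order matters, using hypothetical merges in which token 257 is built from token 256. Building the vocab in creation order works. Processing the same merges in reverse hits a child that isn't defined yet.

```python
# Hypothetical merges in creation order: 257 is built from 256, 258 from 257.
merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}

def build_vocab(merge_items):
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merge_items:
        vocab[idx] = vocab[p0] + vocab[p1]  # KeyError if a child isn't defined yet
    return vocab

# dicts iterate in insertion order (guaranteed since Python 3.7)
vocab = build_vocab(merges.items())
print(vocab[256], vocab[257], vocab[258])  # b'aa' b'aaa' b'aaab'

failed_on = None
try:
    build_vocab(reversed(list(merges.items())))
except KeyError as e:
    failed_on = e.args[0]
print("reverse order fails on missing token", failed_on)  # 257
```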
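Here is a quick check of the two error modes with Python's bytes.decode. The lone byte 0x80 is an arbitrary example of invalid UTF-8.

```python
lone = bytes([128])  # 0x80 is a continuation byte, so it cannot start a UTF-8 character

strict_failed = False
try:
    lone.decode("utf-8")  # errors="strict" is the default
except UnicodeDecodeError as e:
    strict_failed = True
    print("strict:", e)

replaced = lone.decode("utf-8", errors="replace")
print("replace:", hex(ord(replaced)))  # 0xfffd, the Unicode replacement character
```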
Here is the encoding implementation:
```python
def encode(text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # pick the pair that was merged earliest during training (lowest new index);
        # pairs that were never merged map to infinity
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break  # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens

print(encode(""))
```
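As a sanity check, here is a self-contained sketch that bundles training, encoding, and decoding. The function names follow the post, but the training text and vocab_size=276 are arbitrary assumptions. Decoding an encoding should always give back the original text. Text that resembles the training data should also need fewer tokens than bytes.

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def train(text, vocab_size):
    ids = list(text.encode("utf-8"))
    merges = {}
    for idx in range(256, vocab_size):
        stats = get_stats(ids)
        if not stats:
            break
        pair = max(stats, key=stats.get)
        ids = merge(ids, pair, idx)
        merges[pair] = idx
    # rebuild the vocab from the merges, in creation order
    vocab = {i: bytes([i]) for i in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    return merges, vocab

def encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids

def decode(ids, vocab):
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

training_text = "the cat sat on the mat. the cat ate the rat. " * 20
merges, vocab = train(training_text, vocab_size=276)

sample = "the rat sat on the cat."
ids = encode(sample, merges)
print(len(sample.encode("utf-8")), "bytes ->", len(ids), "tokens")
print(decode(ids, vocab) == sample)  # True: encoding is lossless
```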
Let's look at state-of-the-art tokenizers and see what else they are doing
We found the following from GPT-2's encoder.py file:
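```python
import regex as re  # third-party "regex" package; the stdlib re module does not support \p{L} / \p{N}

gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
```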
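This regex splits the text into chunks before BPE runs. BPE is then applied within each chunk separately, so certain parts can never be merged together. For example, if you feed in "Hello've world123 how's are you!!!?", you get ['Hello', "'ve", ' world', '123', ' how', "'s", ' are', ' you', '!!!?']. You will never be able to merge 'Hello' with "'ve". Essentially, the researchers used human-language heuristics (letters, numbers, punctuation, whitespace, and English contractions) and enforced them on the tokenizer.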
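The chunking step can be approximated with Python's standard re module, which does not support \p{L} or \p{N}. This sketch is an assumption, not GPT-2's exact pattern. It substitutes [^\W\d_] for letters and \d for numbers, so characters such as "_" are classified slightly differently.

```python
import re

# Stdlib approximation of the GPT-2 split pattern (not identical to the original).
approx_pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""")

text = "Hello've world123 how's are you!!!?"
chunks = approx_pat.findall(text)
print(chunks)  # ['Hello', "'ve", ' world', '123', ' how', "'s", ' are', ' you', '!!!?']

# BPE would then run on each chunk's bytes separately,
# so no merge can ever cross a chunk boundary.
chunk_ids = [list(chunk.encode("utf-8")) for chunk in chunks]
```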
TikToken - Official tokenizer library from OpenAI
TikToken is used for inference, not for training tokenizers.
encoder.py from OpenAI's GPT-2 repository
encoder.json is equivalent to our vocab from above (the mapping between tokens and their integer ids), and vocab.bpe is equivalent to our merges dictionary. This tells us that to fully represent a tokenizer, we just need these two objects.
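Here is a minimal sketch of that idea. It uses a hypothetical JSON file format, not OpenAI's actual encoder.json/vocab.bpe layout. Only the merges are saved, in creation order, and the vocab is rebuilt from them when loading.

```python
import json
import os
import tempfile

def save_merges(merges, path):
    # store merges as [p0, p1, idx] triples in creation order (hypothetical format)
    with open(path, "w") as f:
        json.dump([[p0, p1, idx] for (p0, p1), idx in merges.items()], f)

def load_tokenizer(path):
    with open(path) as f:
        triples = json.load(f)
    merges = {(p0, p1): idx for p0, p1, idx in triples}
    # the vocab is fully determined by the merges
    vocab = {i: bytes([i]) for i in range(256)}
    for (p0, p1), idx in merges.items():
        vocab[idx] = vocab[p0] + vocab[p1]
    return merges, vocab

merges = {(97, 97): 256, (256, 97): 257, (257, 98): 258}
path = os.path.join(tempfile.mkdtemp(), "toy_merges.json")
save_merges(merges, path)
loaded_merges, vocab = load_tokenizer(path)
print(vocab[258])  # b'aaab'
```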
Special Tokens
We can introduce special tokens to create structure in our token streams. GPT-2's encoder.json has 50,257 entries: 256 raw byte tokens, 50,000 merges, and 1 special token, <|endoftext|>, which is used to delimit documents from each other. Special tokens live outside the regular BPE algorithm (they are inserted directly rather than produced by merges). They are used a lot in fine-tuning, for example to delimit entire conversations.
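Here is the arithmetic behind 50,257, spelled out. In GPT-2 the special token takes the last id, 50256.

```python
num_byte_tokens = 256  # one token per possible byte value
num_merges = 50_000    # merged tokens learned by BPE
num_special = 1        # <|endoftext|>

vocab_size = num_byte_tokens + num_merges + num_special
endoftext_id = vocab_size - 1  # the special token takes the last id
print(vocab_size, endoftext_id)  # 50257 50256
```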
SentencePiece
SentencePiece is commonly used because, unlike TikToken, it can efficiently both train and run inference for BPE tokenizers; it's used in the Llama and Mistral series. Also unlike TikToken (which converts text to UTF-8 bytes and then runs BPE on the bytes), SentencePiece runs BPE directly on Unicode code points and falls back to UTF-8 bytes for rare code points.
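Here is a toy sketch of the byte-fallback idea only. It is not SentencePiece's API, and it leaves out the BPE merges. A hypothetical vocabulary gives a few code points their own ids. Any other code point is encoded as its UTF-8 bytes, using 256 reserved byte tokens (SentencePiece names these like <0xE3>).

```python
# Hypothetical toy vocabulary: lowercase letters and space get their own ids,
# followed by 256 byte-fallback tokens.
char_vocab = {ch: i for i, ch in enumerate("abcdefghijklmnopqrstuvwxyz ")}
BYTE_OFFSET = len(char_vocab)  # ids BYTE_OFFSET .. BYTE_OFFSET + 255 are byte tokens

def encode_with_byte_fallback(text):
    ids = []
    for ch in text:
        if ch in char_vocab:
            ids.append(char_vocab[ch])  # known code point: one token
        else:
            # rare code point: fall back to one token per UTF-8 byte
            ids.extend(BYTE_OFFSET + b for b in ch.encode("utf-8"))
    return ids

def decode_with_byte_fallback(ids):
    inv = {i: ch for ch, i in char_vocab.items()}
    out, pending = [], bytearray()
    for i in ids:
        if i >= BYTE_OFFSET:
            pending.append(i - BYTE_OFFSET)  # collect consecutive byte tokens
        else:
            if pending:
                out.append(pending.decode("utf-8", errors="replace"))
                pending = bytearray()
            out.append(inv[i])
    if pending:
        out.append(pending.decode("utf-8", errors="replace"))
    return "".join(out)

ids = encode_with_byte_fallback("hi 👋")
print(ids)  # 3 character tokens followed by 4 byte tokens for the emoji
```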
So how does tokenization affect our GPT we developed in the last blog post?
The vocab size shows up in two places in the GPT file. First, it sets the shape of the embeddings table (vocab_size x n_embd). Second, it sets the shape of the lm_head at the very end, after all the decoder blocks, where we output logits for the next token (n_embd x vocab_size). In SOTA models, vocab_size is typically in the tens of thousands up to around 100K, set empirically.
What if we want to extend the vocab size of a pretrained model? This is done pretty often, for example when fine-tuning. As mentioned above, we just need to resize the embeddings table and the final linear layer by adding entries for the new tokens, which are pretty mild model surgeries. New tokens don't only help when turning a pretrained model into a chatbot; you can also fine-tune with new tokens for cool effects. For example, one paper adds new "gist tokens" and does the corresponding mild model surgery (resizing the embeddings table and the final linear head). It then freezes the existing weights and trains only the new weights, so the gist token acts as a replacement for the original long prompt, compressing the prompt into a single token.
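Here is a quick sketch of how many parameters the vocabulary size alone costs, using GPT-2 small's sizes (vocab_size = 50257, n_embd = 768). GPT-2 actually ties the two matrices, so it stores them only once.

```python
def vocab_params(vocab_size, n_embd):
    embedding = vocab_size * n_embd  # token embedding table: (vocab_size, n_embd)
    lm_head = n_embd * vocab_size    # output projection: (n_embd, vocab_size), ignoring bias
    return embedding, lm_head

# GPT-2 small: vocab_size=50257, n_embd=768
emb, head = vocab_params(50257, 768)
print(emb, head)  # 38597376 38597376
```

Both costs grow linearly with vocab_size, which is one side of the sweet-spot trade-off.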
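Here is a minimal sketch of the model surgery in NumPy. NumPy is an assumption; the post's GPT is not necessarily written this way, and the 0.02 initialization scale is an arbitrary choice. The sketch appends rows to the embedding table and columns to the output projection. It also keeps a mask of which token ids would be trained if the original weights are frozen.

```python
import numpy as np

def extend_vocab(wte, lm_head, num_new, std=0.02, seed=0):
    """Add `num_new` tokens: new rows in the embedding table (vocab_size, n_embd)
    and new columns in the output projection (n_embd, vocab_size)."""
    rng = np.random.default_rng(seed)
    vocab_size, n_embd = wte.shape
    new_rows = rng.normal(0.0, std, size=(num_new, n_embd))
    new_cols = rng.normal(0.0, std, size=(n_embd, num_new))
    wte2 = np.concatenate([wte, new_rows], axis=0)
    head2 = np.concatenate([lm_head, new_cols], axis=1)
    # True for the new token ids only: train these, freeze the rest
    trainable = np.zeros(vocab_size + num_new, dtype=bool)
    trainable[vocab_size:] = True
    return wte2, head2, trainable

wte = np.zeros((10, 4))      # toy sizes: vocab_size=10, n_embd=4
lm_head = np.zeros((4, 10))
wte2, head2, trainable = extend_vocab(wte, lm_head, num_new=2)
print(wte2.shape, head2.shape, trainable.sum())  # (12, 4) (4, 12) 2
```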
How do you process not just text into transformers but other modalities like images and videos?
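For different modalities, do you have to change the model internally? The field has largely converged on no: you don't need to change the transformer itself, only how the input is tokenized. For example, an image can be cut into patches, and each patch is turned into a token that the transformer processes like any other.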
Now that we have a good understanding of the tokenization process, why do the following phenomena happen?
Why can't LLMs spell words? Tokenization. Often a whole word (or a large chunk of one) is a single token, and the basic unit the transformer operates on (the basic unit of attention) is the token. The transformer is very good at finding patterns between tokens. But if we ask it to spell a word made of very long tokens (or a word that is itself a single token), it won't do well, because it never directly sees the characters inside a token.
Why can't LLMs do super simple string processing tasks like reversing a string? Same reason - it doesn't have knowledge about the patterns within a token (the token's spelling is an ordering of characters internally, which is what I'm calling a pattern here).
Why are LLMs worse at non-English languages (e.g., Japanese)? We covered this earlier: model providers have less training data for these languages, and tokenization of non-English text is often less efficient. There is often less non-English data in the tokenizer's training set too. That leads to fewer merges for those languages' bytes, which means more tokens for the same text. Performance then suffers, because attention has to reason across a longer token sequence.
Why are LLMs bad at simple arithmetic? Because numbers are split into tokens in arbitrary ways. Sometimes several digits are merged into one token, and sometimes they are split apart, depending on whichever digit sequences happened to be merged during BPE training. Llama 2 splits numbers into individual digits, one token per digit, and I suspect this helps its arithmetic performance.
Why did GPT-2 have more trouble than necessary with Python? Part of it is the model training itself, but tokenization played a role too: the encoding of spaces (which Python has a lot of) was very inefficient. Each space was often its own token, so the same reasoning applies: lower token efficiency means more tokens, which means worse performance on coding tasks.
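Here is a sketch of digit splitting as a pre-tokenization step, using a hypothetical split_digits helper built on Python's re module. It only illustrates the effect and is not Llama 2's actual implementation.

```python
import re

def split_digits(text):
    # every digit becomes its own chunk; runs of non-digits stay together
    return re.findall(r"\d|\D+", text)

print(split_digits("12345+678=13023"))
# ['1', '2', '3', '4', '5', '+', '6', '7', '8', '=', '1', '3', '0', '2', '3']
```

With this split, every number is tokenized the same way regardless of which digit sequences happened to be frequent in the tokenizer's training data.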
Why did my LLM abruptly halt when it saw the string "<|endoftext|>"? This doesn't happen anymore, but it used to. User-supplied text that matched a special token's string could be parsed as that special token. The model then saw a real end-of-text token and stopped, as it would at the end of a document. This was a security hazard: special tokens should never be parseable from user input.
What is this weird warning about "trailing whitespace"? In GPT-style tokenizers, a space is usually attached to the start of the following word's token (e.g., " world") rather than standing alone. If your prompt ends with a trailing space, that space becomes a token of its own. The model has rarely seen a lone space token in training, because that space would normally have been merged into the next word. So we're out of distribution and performance can degrade.
Why does the LLM break if I ask it about "SolidGoldMagikarp"? The tokenizer's training set was very different from the LLM's training set. The tokenizer data contained a lot of Reddit data, including many posts by a user named SolidGoldMagikarp, so BPE ended up giving this username its own token. But that string essentially never appeared in the LLM's training data, so the embedding for that token was never updated during training. At test time, when you say "SolidGoldMagikarp", the model looks up this untrained vector and you get undefined behavior.
Why should I prefer YAML over JSON with LLMs? YAML tends to be more token-efficient: the same data usually encodes into fewer tokens. Fewer tokens means less of the context window is used, which tends to help attention and performance. We pay in tokens, both in dollars and in context size, so it pays to be informed about token density across formats and languages.
The fundamental problem is that the smallest unit an LLM receives is a token, not a character the way we see text. This can lead to unintended consequences such as the examples above. Nonetheless, it's pretty fun to mess around with them.
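Here is a minimal sketch of the safe default, with hypothetical encode/encode_bytes helpers (raw bytes stand in for full BPE). User text is encoded as plain text. Special-token strings are mapped to their reserved ids only when the caller explicitly opts in. OpenAI's tiktoken exposes a similar choice through the allowed_special and disallowed_special arguments of its encode method.

```python
import re

SPECIAL_TOKENS = {"<|endoftext|>": 50256}  # GPT-2's end-of-text id

def encode_bytes(text):
    # stand-in for the full BPE encoder: raw UTF-8 bytes only
    return list(text.encode("utf-8"))

def encode(text, allow_special=False):
    if not allow_special:
        # untrusted text: special-token strings are just ordinary characters
        return encode_bytes(text)
    # trusted text: split out special tokens and map them to their reserved ids
    pattern = "(" + "|".join(re.escape(s) for s in SPECIAL_TOKENS) + ")"
    ids = []
    for part in re.split(pattern, text):  # the capture group keeps the separators
        if part in SPECIAL_TOKENS:
            ids.append(SPECIAL_TOKENS[part])
        elif part:
            ids.extend(encode_bytes(part))
    return ids

user_text = "please repeat <|endoftext|> back to me"
print(50256 in encode(user_text))                      # False
print(50256 in encode(user_text, allow_special=True))  # True
```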
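Here is a sketch of where the stray space comes from. It uses a stdlib approximation of the GPT-2 split pattern, which is an assumption: it replaces \p{L} with [^\W\d_] and \p{N} with \d. Without the trailing space the last chunk is " line". With it, an extra lone " " chunk appears.

```python
import re

# Stdlib approximation of the GPT-2 split pattern (not identical to the original).
approx_pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?[^\W\d_]+| ?\d+| ?[^\s\w]+|\s+(?!\S)|\s+""")

print(approx_pat.findall("Here is a tag line"))
# ['Here', ' is', ' a', ' tag', ' line']
print(approx_pat.findall("Here is a tag line "))
# ['Here', ' is', ' a', ' tag', ' line', ' ']
```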
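Here is a toy sketch of the failure mode. Given the LLM's training data as token ids, it lists the vocabulary entries that never occur there. The function name and data are hypothetical. Those tokens' embedding rows are never looked up during training, so they receive no gradient through the input embedding.

```python
from collections import Counter

def unused_token_ids(vocab_size, training_ids):
    # token ids that never occur in the training data
    counts = Counter(training_ids)
    return [i for i in range(vocab_size) if counts[i] == 0]

# toy example: the model's training data never contains token 7
training_ids = [0, 1, 2, 3, 4, 5, 6, 8, 9, 1, 2]
print(unused_token_ids(10, training_ids))  # [7]
```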
Thanks a lot to Andrej Karpathy!