Traditional input methods seem to implement predictive text with N-gram language models. Of course, I don't really understand N-grams, and I don't know whether modern input methods have brought today's popular language models (Transformers) into their prediction features.
An input method's prediction feature does nothing more than guess the next likely word from what has been typed so far. Coincidentally, large language models (LLMs) also predict the next token from the existing tokens.
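In probability terms, both approaches estimate the same quantity: the distribution of the next word $w_t$ given what came before. They differ in how much context they use and how the estimate is produced. This is the standard textbook formulation, not something specific to this project. Here $n$ is the N-gram order, $C(\cdot)$ a count, $L$ the context window (max_seq_len), $h_{t-1}$ the Transformer's final hidden state at the last position, and $W_{\text{head}}$ the output projection:

```latex
% N-gram: condition only on the previous n-1 words, estimated from counts
P(w_t \mid w_1,\dots,w_{t-1}) \approx P(w_t \mid w_{t-n+1},\dots,w_{t-1})
  = \frac{C(w_{t-n+1},\dots,w_{t-1},w_t)}{\sum_{w'} C(w_{t-n+1},\dots,w_{t-1},w')}

% Transformer LM: condition on up to L previous tokens, estimated by the network
P_\theta(w_t \mid w_{t-L},\dots,w_{t-1}) = \operatorname{softmax}\!\left(W_{\text{head}}\, h_{t-1}\right)_{w_t}
```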
Mobile input methods are mostly used for everyday chatting, where the set of commonly used words is small and predictions don't need to be highly accurate.
So I could follow the way LLMs are designed and build a very small Transformer model, with the goal of keeping the quantized weights under 20MB.
At that size the vocabulary can't be too large, or the embedding layer would hold too many of the parameters and make the model top-heavy. It can't be too small either, because a vocabulary that is too small degrades prediction quality.
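To see why the vocabulary dominates at this scale: the input embedding and the output projection each hold vocab_size × hidden_dim weights. The small sketch below does that arithmetic. It assumes hidden_dim=320 from the config below, an untied output layer (an assumption; tying would halve these numbers), and 1 byte per weight at int8. The function name is hypothetical.

```python
HIDDEN_DIM = 320               # hidden size from the config in this post
BUDGET_BYTES = 20_000_000      # ~20MB target; int8 means 1 byte per weight


def vocab_params(vocab_size, hidden_dim=HIDDEN_DIM, tied=False):
    """Weights tied to the vocabulary: token embedding plus lm_head (unless tied)."""
    per_matrix = vocab_size * hidden_dim
    # Assumption: untied lm_head by default, so two vocab_size x hidden_dim matrices
    return per_matrix if tied else 2 * per_matrix


for v in (4096, 8192, 32000):
    p = vocab_params(v)
    print(f"vocab={v:>6}: {p:>11,} weights = {p / BUDGET_BYTES:.0%} of the budget")
```

With hidden_dim=320, an 8192-token vocabulary already costs about 5.2M weights, roughly a quarter of the 20MB budget. A 32000-token vocabulary would exceed the whole budget on its own.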
My Experiment
I designed a decoder-only Transformer with a vocabulary of just 8192 tokens as the experiment. The approximate configuration:
```python
# Excerpt from the model config's factory method (the enclosing class is not shown)
return cls(
    vocab_size=8192,
    hidden_dim=320,
    num_heads=4,
    num_layers=6,
    ffn_dim=1024,
    max_seq_len=64,
    dropout=0.1,
)
```
The printed PyTorch module:
```text
DecoderTransformer(
  (token_embedding): Embedding(8192, 512)
  (position_embedding): Embedding(128, 512)
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (q_proj): Linear(in_features=512, out_features=512, bias=False)
        (k_proj): Linear(in_features=512, out_features=512, bias=False)
        (v_proj): Linear(in_features=512, out_features=512, bias=False)
        (o_proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffn): FeedForward(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=512, out_features=8192, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)
```
For daily input, max_seq_len=64 is sufficient in my opinion.
With int8 quantization, the model comes to about 17MB.
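One way to sanity-check the size is to count parameters layer by layer from the module printout. The config (hidden_dim=320, ffn_dim=1024, max_seq_len=64) and the printed structure (hidden size 512, FFN size 2048, 128 positions) disagree, so this sketch counts both. Whether lm_head shares weights with token_embedding is not visible from the printout, so the tied_head flag is an assumption. The function name is hypothetical.

```python
def param_count(vocab, hidden, ffn, layers, positions, tied_head=False):
    """Count trainable parameters of a decoder laid out like the printout."""
    embeddings = vocab * hidden + positions * hidden          # token + learned position embeddings
    attention = 4 * hidden * hidden                           # q/k/v/o projections, bias=False
    ffn_params = hidden * ffn + ffn + ffn * hidden + hidden   # fc1 and fc2, with biases
    layer_norms = 2 * (2 * hidden)                            # ln1 and ln2: weight + bias each
    blocks = layers * (attention + ffn_params + layer_norms)
    final_norm = 2 * hidden                                   # ln_final
    head = 0 if tied_head else hidden * vocab                 # lm_head, bias=False (assumed untied)
    return embeddings + blocks + final_norm + head


config = param_count(vocab=8192, hidden=320, ffn=1024, layers=6, positions=64)
printed = param_count(vocab=8192, hidden=512, ffn=2048, layers=6, positions=128)
print(f"config:  {config:,} params, ~{config / 1e6:.1f} MB at int8")
print(f"printed: {printed:,} params, ~{printed / 1e6:.1f} MB at int8")
```

At one byte per weight, this gives about 11.7M parameters (≈11.7MB) for the config and about 27.4M (≈27.4MB) for the printed structure, or 23.2M if the head is tied. The reported ~17MB falls between these, so the actual file size likely depends on which tensors were quantized and on format overhead, neither of which is stated.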
After training, two points stood out:
- Data quality matters a lot; it directly affects whether the model converges at all.
- Parameter count matters too. Larger models really do perform better, and the size above doesn't feel like quite enough to me. That may also be down to my own limitations; I didn't have time to try better architectures.
Inference
There are many inference frameworks on Android. I'm familiar with Tencent's ncnn, but unfortunately it doesn't seem to support Transformers. I tried Alibaba's MNN next, but after struggling with it for a day there were still major issues. So I switched to onnxruntime and, with vibe coding, got it working on the first try.
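Around the onnxruntime call, the app still has to trim the typed context to the model's window and turn the output logits into a ranked candidate list. Below is a minimal pure-Python sketch of those two steps. It assumes the model returns one logit per vocabulary entry for the last position, and the function names are hypothetical. The session call itself is omitted because the exported model's input and output names are not known. Python is used only for illustration; the Android app itself would do this in its own language.

```python
import math

MAX_SEQ_LEN = 64  # max_seq_len from the model config


def prepare_context(token_ids, max_seq_len=MAX_SEQ_LEN):
    """Keep only the most recent tokens so the input fits the context window."""
    return token_ids[-max_seq_len:]


def top_k_next_tokens(logits, k=5):
    """Softmax over the last position's logits; return the k most likely (token_id, prob)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]
```

In the app, the trimmed ids would be fed to the onnxruntime session, and the top-k ids mapped back to words through the 8192-entry vocabulary before being fused with the user's N-gram candidates.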
N-gram
If word predictions come only from the Transformer, the fixed weights mean the same context always produces the same suggestions in daily use. The model also neither learns nor records the user's habits. So an N-gram model of the user's own input is still needed, fused with the Transformer's predictions.
The N-grams are stored in a Trie (prefix tree):
- Each node stores count (number of occurrences), frequency, and prefixCount (prefix count)
- Both bigrams (2-grams) and trigrams (3-grams) are supported
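Here is a simplified Python sketch of such a trie, assuming words are already segmented. It keeps only a per-node count, not the frequency/prefixCount fields described above. It records every bigram and trigram of a typed sentence. To predict, it looks up the trigram history first and backs off to the bigram history when the trigram history has no recorded continuation. Class and method names are hypothetical, and the backoff rule is an assumption.

```python
class TrieNode:
    def __init__(self):
        self.children = {}  # next word -> TrieNode
        self.count = 0      # times the n-gram ending at this node was recorded


class NGramTrie:
    """Hypothetical user N-gram store: every bigram and trigram is a root-to-node path."""

    def __init__(self, orders=(2, 3)):
        self.root = TrieNode()
        self.orders = orders

    def record(self, words):
        """Record all bigrams and trigrams of one typed, already-segmented sentence."""
        for n in self.orders:
            for i in range(len(words) - n + 1):
                node = self.root
                for w in words[i:i + n]:
                    node = node.children.setdefault(w, TrieNode())
                node.count += 1  # only the end node of the n-gram is incremented

    def _find(self, history):
        node = self.root
        for w in history:
            node = node.children.get(w)
            if node is None:
                return None
        return node

    def predict(self, context, k=5):
        """Return up to k (word, prob) pairs, trying the longest history first."""
        for n in sorted(self.orders, reverse=True):
            history = context[-(n - 1):]
            if len(history) < n - 1:
                continue  # not enough context for this order
            node = self._find(history)
            if node is None:
                continue  # history never seen: back off to a shorter one
            total = sum(c.count for c in node.children.values())
            if total == 0:
                continue  # history seen, but never followed by another word
            probs = {w: c.count / total for w, c in node.children.items() if c.count > 0}
            return sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
        return []
```

For example, after recording 我想吃饭, 我想睡觉 and 我想吃面, `predict(["我", "想"])` ranks 吃 (2/3) ahead of 睡 (1/3). Because counts only grow as the user types, the user's own phrasing gradually takes over these probabilities.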
Fusion formula (linear interpolation):
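One common way to write such a linear interpolation is shown below as a sketch. $\lambda$ is a mixing weight whose value in the project is not given. $P_\theta$ is the Transformer's next-token probability over the last $L$ tokens ($L$ = max_seq_len), and $C(\cdot)$ counts n-grams in the user's trie. The trigram-to-bigram backoff in the user model is an assumption here.

```latex
\mathrm{score}(w_t) = \lambda\, P_\theta(w_t \mid w_{t-L},\dots,w_{t-1})
  + (1-\lambda)\, P_{\text{user}}(w_t \mid w_{t-2}, w_{t-1}), \qquad 0 \le \lambda \le 1

P_{\text{user}}(w_t \mid w_{t-2}, w_{t-1}) =
\begin{cases}
  \dfrac{C(w_{t-2}, w_{t-1}, w_t)}{\sum_{w'} C(w_{t-2}, w_{t-1}, w')} & \text{if } \sum_{w'} C(w_{t-2}, w_{t-1}, w') > 0 \\[2ex]
  \dfrac{C(w_{t-1}, w_t)}{\sum_{w'} C(w_{t-1}, w')} & \text{otherwise}
\end{cases}
```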
The overall flow:
```text
User input → Record to Trie → Save frequency
                 ↓
Prediction request → Model candidates + user N-gram candidates
                 ↓
Fusion scoring → Return sorted candidates
```
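The fusion-scoring step can be sketched minimally as follows. It assumes both sources already return probability dictionaries over candidate words and uses a hypothetical fixed weight lam=0.7, since the project's actual weights are not stated. Each candidate's score is lam times its model probability plus (1 - lam) times its user N-gram probability. A word absent from one source contributes 0 from that source.

```python
def fuse(model_probs, user_probs, lam=0.7, k=5):
    """Rank candidates by lam * P_model + (1 - lam) * P_user.

    model_probs / user_probs map candidate word -> probability; a word missing from
    one source contributes 0 from that source. lam=0.7 is an arbitrary example value.
    """
    words = set(model_probs) | set(user_probs)
    scores = {
        w: lam * model_probs.get(w, 0.0) + (1 - lam) * user_probs.get(w, 0.0)
        for w in words
    }
    # Sort by descending score; break ties alphabetically so the order is deterministic.
    return sorted(scores.items(), key=lambda kv: (-kv[1], kv[0]))[:k]
```

Raising lam trusts the fixed model more. Lowering it lets the user's own phrasing, which keeps accumulating in the trie, take over.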
All of this is computed locally.
Final Thoughts
Because of limited data (I only found some open-source Chinese datasets and some Zhihu Q&A datasets) and the small parameter count, the model's predictions are not yet ideal.
For personal use, training on my own exported WeChat chat history should work better.
The model still needs more time and optimization; for now it is merely usable.
Resources
Input method project: https://github.com/ximeiorg/Kime
Model training project: https://github.com/ximeiorg/predictive-text
Weights download:
- Huggingface: https://huggingface.co/rkingzhong/predictive-text-small
- ModelScope: https://www.modelscope.cn/models/bikeand/predictive-text-small