
I Trained a Predictive Text Model for My Own Mobile Input Method

kingzcheung

The predictive text feature of traditional input methods seems to be built on N-gram language models. Admittedly, I don't really understand N-grams, and I don't know whether modern input methods have brought today's popular Transformer-based language models into their prediction features.

At its core, input-method prediction just means guessing the most likely next word from the current input. As it happens, large language models (LLMs) do the same thing: they predict the next token from the tokens that came before.

A mobile input method is mostly used for everyday chat, where the set of commonly used words is fairly small and predictions don't need to be especially accurate.

So I figured I could design a very small Transformer following the same approach used for LLMs, with the goal of keeping the quantized weights under 20 MB.

At that size the vocabulary can't be too large; otherwise the embedding layer would hold too many of the parameters and make the model top-heavy. It can't be too small either, because a tiny vocabulary degrades prediction quality.
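A rough worked example of that trade-off, counting only the input embedding (V × d weights) at one byte per weight under int8 and ignoring quantization scales and file overhead. The hidden size d = 320 is taken from the configuration in the next section; the 32,000-token vocabulary is a hypothetical comparison point, not part of the experiment.

$$
\begin{aligned}
P_{\text{emb}} &= V \times d \\
8192 \times 320 &= 2{,}621{,}440 \text{ weights} \approx 2.6\ \text{MB} \\
32000 \times 320 &= 10{,}240{,}000 \text{ weights} \approx 10.2\ \text{MB}
\end{aligned}
$$

If the output projection is not tied to the embedding, it costs another V × d weights, so the hypothetical 32,000-token vocabulary would spend about 20.5 MB on these two matrices alone, already over the 20 MB budget before a single Transformer layer is counted.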

My Experiment

I designed a decoder-only Transformer with a vocabulary of only 8,192 tokens for the experiment. The approximate configuration:

```python
# Excerpt from a classmethod of the model's config class (the enclosing class isn't shown in the post)
return cls(
    vocab_size=8192,
    hidden_dim=320,
    num_heads=4,
    num_layers=6,
    ffn_dim=1024,
    max_seq_len=64,
    dropout=0.1,
)
```

The printed PyTorch module structure follows. Note that it shows a hidden size of 512, an FFN size of 2048 and 128 position embeddings, so it does not match the configuration above exactly.

```text
DecoderTransformer(
  (token_embedding): Embedding(8192, 512)
  (position_embedding): Embedding(128, 512)
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attention): MultiHeadSelfAttention(
        (q_proj): Linear(in_features=512, out_features=512, bias=False)
        (k_proj): Linear(in_features=512, out_features=512, bias=False)
        (v_proj): Linear(in_features=512, out_features=512, bias=False)
        (o_proj): Linear(in_features=512, out_features=512, bias=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffn): FeedForward(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=512, out_features=8192, bias=False)
  (dropout): Dropout(p=0.1, inplace=False)
)
```
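As a sanity check on size, the sketch below counts parameters layer by layer following the printed structure: bias-free q/k/v/o projections, biased fc1/fc2, two affine LayerNorms per block, a final LayerNorm and a bias-free lm_head. The helper `count_params` is hypothetical, and whether lm_head shares its weights with token_embedding is an assumption exposed as the `tied_head` flag, since a module printout alone can't tell. num_heads does not appear because splitting the hidden size into heads does not change the projection sizes.

```python
def count_params(
    vocab_size: int,
    hidden_dim: int,
    num_layers: int,
    ffn_dim: int,
    max_positions: int,
    tied_head: bool = False,
) -> int:
    """Hypothetical helper: parameter count for the printed decoder layout."""
    d = hidden_dim
    token_emb = vocab_size * d
    pos_emb = max_positions * d
    # q_proj, k_proj, v_proj, o_proj: four d x d matrices, bias=False
    attention = 4 * d * d
    # fc1 (d -> ffn_dim) and fc2 (ffn_dim -> d), both with bias
    ffn = d * ffn_dim + ffn_dim + ffn_dim * d + d
    # ln1 and ln2: an elementwise weight and bias of size d each
    norms = 2 * (2 * d)
    blocks = num_layers * (attention + ffn + norms)
    final_norm = 2 * d
    # lm_head: d -> vocab_size, bias=False; reuses the embedding matrix if tied
    lm_head = 0 if tied_head else d * vocab_size
    return token_emb + pos_emb + blocks + final_norm + lm_head


print(count_params(8192, 512, 6, 2048, 128))  # printed structure: 27357184
print(count_params(8192, 320, 6, 1024, 64))   # configuration: 11669504
```

At roughly one byte per int8 weight these counts correspond to about 27.4 MB for the printed structure and about 11.7 MB for the configuration (about 9.0 MB with a tied head). The real file size also depends on which tensors the quantizer actually converts to int8 and on format overhead, so these are estimates and need not match the roughly 17 MB quantized size mentioned below.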

For daily input, max_seq_len=64 is sufficient in my opinion.

With int8 quantization, the model is around 17 MB.
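The post doesn't describe how the weights were quantized. For readers unfamiliar with int8 quantization, here is a generic sketch of symmetric per-tensor quantization, one common scheme and not necessarily the one used here: each float weight w is stored as q = round(w / s) clipped to [-127, 127] with scale s = max|w| / 127, so it takes one byte instead of four and is recovered approximately as q × s. The function names are hypothetical.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization of a non-empty weight list."""
    max_abs = max(abs(w) for w in weights)
    # One scale for the whole tensor; the largest magnitude maps to +/-127.
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Approximate reconstruction of the float weights."""
    return [v * scale for v in q]


weights = [0.5, -1.27, 0.01, 1.0]
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
print(q)  # [50, -127, 1, 100]
```

Because every value is rounded to the nearest multiple of the scale, the reconstruction error per weight is at most half a scale step; a single outlier weight enlarges the scale and therefore the error for all the others, which is why real toolkits often use per-channel scales instead.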

Having gone through the whole training process, I took away two key points:

  1. Data quality is critical: it directly affects whether the model can converge at all.
  2. Parameter count matters too. Larger models really do perform better, and the size above doesn't feel quite sufficient to me. That may also come down to my own limitations, since I didn't have time to try better architectures.

Inference

There are many inference frameworks on Android. I'm familiar with Tencent's ncnn, but unfortunately it doesn't seem to support Transformers. I then tried Alibaba's MNN, but after a day of struggling there were still major issues. So I switched to onnxruntime and, with vibe coding, got it working on the first try.
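The Android inference code isn't shown in the post. As a Python sketch rather than the app's actual code, one prediction step for a decoder-only model looks like this: keep at most the last max_seq_len token ids, run the model, take the logits at the last position, and return the top-k tokens. With onnxruntime's Python API the model call would be `session.run(None, {input_name: ids})` on an `InferenceSession`, but the input name, dtype and output layout depend on how the model was exported, so a hypothetical stand-in `fake_logits` replaces the model here; `predict_next` is likewise a hypothetical name.

```python
import heapq
import math
from typing import Callable, FrozenSet, List, Sequence, Tuple

MAX_SEQ_LEN = 64  # max_seq_len from the configuration in the post


def predict_next(
    token_ids: Sequence[int],
    run_model: Callable[[List[int]], List[List[float]]],
    k: int = 5,
    banned: FrozenSet[int] = frozenset(),
) -> List[Tuple[int, float]]:
    """Return up to k (token_id, probability) pairs for the next token."""
    # Learned position embeddings only cover a fixed number of positions,
    # so keep just the most recent tokens.
    context = list(token_ids)[-MAX_SEQ_LEN:]
    if not context:
        return []
    # run_model returns one row of vocabulary logits per input position;
    # only the last row predicts the next token.
    logits = run_model(context)[-1]
    # Numerically stable softmax over the full vocabulary.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    # Drop banned ids (e.g. special tokens) after normalizing.
    probs = [(i, e / total) for i, e in enumerate(exps) if i not in banned]
    return heapq.nlargest(k, probs, key=lambda pair: pair[1])


def fake_logits(context: List[int]) -> List[List[float]]:
    """Hypothetical stand-in for the exported model: a 6-token vocabulary that
    strongly favors (t + 1) % 6 and weakly favors (t + 2) % 6 after token t."""
    rows = []
    for t in context:
        row = [0.0] * 6
        row[(t + 1) % 6] = 3.0
        row[(t + 2) % 6] = 1.0
        rows.append(row)
    return rows


print(predict_next([0, 1, 2], fake_logits, k=2))  # token 3 first, then token 4
```

The returned token ids still have to be mapped back to words through the tokenizer's vocabulary before they are shown as candidates.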

N-gram

If predictions came only from the Transformer, the fixed weights would mean the same input always produces the same suggestions in daily use, and the model would never update or record the user's habits. So we still need an N-gram model fused with the Transformer.

Specifically, the user's N-grams are stored in a Trie (prefix tree).
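The post doesn't include the Trie code. A minimal sketch, assuming the input has already been segmented into words, each path from the root is a word sequence of up to max_n words, and each node stores how often that sequence was typed. The class and method names are hypothetical. Looking up the node for the last n-1 words gives the counts of every word the user has typed after them, with back-off to shorter contexts when the full one is unseen.

```python
from typing import Dict, List, Sequence, Tuple


class TrieNode:
    __slots__ = ("children", "count")

    def __init__(self) -> None:
        self.children: Dict[str, "TrieNode"] = {}
        self.count = 0


class NGramTrie:
    """Stores every n-gram (n <= max_n) of the user's typed words with its frequency."""

    def __init__(self, max_n: int = 3) -> None:
        self.root = TrieNode()
        self.max_n = max_n

    def record(self, words: Sequence[str]) -> None:
        # For each start position, walk words[i:i + max_n] and bump the count of
        # every node on the path, i.e. of every 1..max_n-gram starting at i.
        for i in range(len(words)):
            node = self.root
            for w in words[i:i + self.max_n]:
                node = node.children.setdefault(w, TrieNode())
                node.count += 1

    def next_words(self, context: Sequence[str]) -> List[Tuple[str, int]]:
        """Words seen after the longest known suffix of context, most frequent first."""
        # Try the last (max_n - 1) words first, then back off to shorter suffixes;
        # the empty suffix falls back to overall word frequencies.
        first = max(0, len(context) - (self.max_n - 1))
        for start in range(first, len(context) + 1):
            node = self.root
            for w in context[start:]:
                node = node.children.get(w)
                if node is None:
                    break
            if node is not None and node.children:
                ranked = [(w, child.count) for w, child in node.children.items()]
                return sorted(ranked, key=lambda wc: (-wc[1], wc[0]))
        return []


trie = NGramTrie(max_n=3)
for sentence in (["see", "you", "tomorrow"], ["see", "you", "later"], ["see", "you", "later"]):
    trie.record(sentence)
print(trie.next_words(["see", "you"]))  # [('later', 2), ('tomorrow', 1)]
```

Sharing prefixes keeps storage small, since "see", "see you" and "see you later" all live on one path; a real implementation would also persist the counts and probably decay or cap old ones.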

Fusion formula (linear interpolation):

$$\text{Score} = \lambda \cdot \text{modelScore} + (1 - \lambda) \cdot \text{userScore}$$

The general process is as follows:

  1. User input → record to the Trie → save frequencies
  2. Prediction request → model candidate words + user N-gram candidate words
  3. Fusion scoring → return the sorted candidate words
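A sketch of the fusion step under stated assumptions: the model score is taken to be the softmax probability, and the user score is the word's share of the user's N-gram counts for the current context, so both lie in [0, 1]. Those normalizations, the default λ = 0.7 and the name `fuse_candidates` are assumptions for illustration; the post only gives the interpolation formula. A word missing from one source simply gets 0 from it.

```python
from typing import Dict, List, Tuple


def fuse_candidates(
    model_scores: Dict[str, float],
    user_counts: Dict[str, int],
    lam: float = 0.7,
    k: int = 5,
) -> List[Tuple[str, float]]:
    """Score = lam * modelScore + (1 - lam) * userScore, best first."""
    total = sum(user_counts.values())
    # Turn raw N-gram counts into relative frequencies in [0, 1].
    user_scores = {w: c / total for w, c in user_counts.items()} if total else {}
    # Union of both candidate sets.
    candidates = set(model_scores) | set(user_scores)
    fused = {
        w: lam * model_scores.get(w, 0.0) + (1 - lam) * user_scores.get(w, 0.0)
        for w in candidates
    }
    # Sort by score, breaking ties alphabetically so the order is deterministic.
    return sorted(fused.items(), key=lambda ws: (-ws[1], ws[0]))[:k]


model = {"tomorrow": 0.5, "later": 0.3, "soon": 0.2}
user = {"later": 3, "there": 1}
print(fuse_candidates(model, user))  # "later" ranks first
```

Here "later" scores 0.7 × 0.3 + 0.3 × 0.75 = 0.435 and overtakes the model's top guess "tomorrow" (0.35), which is exactly the personalization the fusion is meant to add; a larger λ trusts the model more.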

All of this is computed locally.

Final Thoughts

Limited by the data (I only found some open-source Chinese datasets and some Zhihu Q&A data) and by the parameter count, the current model's predictions aren't ideal.

For personal use, it should work better if I could export my own WeChat chat history and train on it.

The model still needs more time for optimization; for now it is merely usable.

Resources

Input method project: https://github.com/ximeiorg/Kime

Model training project: https://github.com/ximeiorg/predictive-text

Weights download:
