传统的输入法的联想功能,好像是基于N-gram语言模型来实现的。当然,我并不了解N-gram 。也不太了解现代的输入法的联想功能是否已经加入了现在的流行的语言模型(Transformer)。
输入法的联想功能,无非就是基于现打的词,去预测下一个可能的词,巧的是,LLM 大模型也是基于现有的词元来预测下一个词元。
手机输入法一般只作为日常的聊天使用,常用词不会太多,预测的性能也不需要很高的精度。
因此,我可以按设计llm 大模型的思路来设计一个非常小的Transformer 模型。目标是量化后权重大小至少小于20mb。
如此小的尺寸,词汇就无法做到太大,否则会因为 embedding 层参数过大而造成模型头重脚轻。但是也不能过小,因为太小的词汇量会导致预测性能下降。
我的实验
我设计了一个只有8192 个词汇的 decode only 的Transformer 模型作为实验对象。大概参数:
return cls(
vocab_size=8192,
hidden_dim=320,
num_heads=4,
num_layers=6,
ffn_dim=1024,
max_seq_len=64,
dropout=0.1,
)
DecoderTransformer(
(token_embedding): Embedding(8192, 512)
(position_embedding): Embedding(128, 512)
(blocks): ModuleList(
(0-5): 6 x TransformerBlock(
(attention): MultiHeadSelfAttention(
(q_proj): Linear(in_features=512, out_features=512, bias=False)
(k_proj): Linear(in_features=512, out_features=512, bias=False)
(v_proj): Linear(in_features=512, out_features=512, bias=False)
(o_proj): Linear(in_features=512, out_features=512, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
)
(ffn): FeedForward(
(fc1): Linear(in_features=512, out_features=2048, bias=True)
(fc2): Linear(in_features=2048, out_features=512, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(ln1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(ln_final): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(lm_head): Linear(in_features=512, out_features=8192, bias=False)
(dropout): Dropout(p=0.1, inplace=False)
)
对于日常输入来说,max_seq_len=64 我觉得已经够用了。这个模型尺寸在int8 的量化下,大小为 17M 左右。
整体训练下来,有2点:
-
数据质量很重要,它能直接影响模型是否能收敛。
-
模型的参数量也很重要,更大的尺寸确实能获取更好的性能,至少上面这个尺寸是我感觉不太够的,当然也可能是我能力问题,同时也没有时间测试更好的结构。
推理
安卓上的推理框架有很多个,我对腾讯的ncnn 比较熟悉,可惜它好像不支持Transformer。尝试了阿里的 mnn 框架,结果折腾了一天,都有大问题。于是换成onnxruntime 框架,vibe coding 一次搞定。
N-gram
如果只是单纯地从 Transformer 模型做词的联想预测的话,由于模型权重是固定的,所以有可能日常使用会碰到一些结果也是固定的。而且这并没有更新和记录自己的习惯,因此还是需要使用N-gram 算法和Transformer 模型 来做融合。
具体的做法是,使用Trie 前缀树存储 N-gram:
-
每个节点存储 count(出现次数)、frequency(频率)、prefixCount(前缀计数)
-
支持 Bigram(二元组)和 Trigram(三元组)两种N-gram
融合公式(线性插值):
大体流程如下:
用户输入 → 记录到Trie → 保存频率
↓
获取联想词请求 → 模型候选词 + 用户 N-gram 候选词
↓
融合评分 → 返回排序后的候选词
这些全都在本地计算。
最后
受限于数据(我只找了部分开源的中文数据集以及一些知乎问答数据集)大小,以及模型参数量的大小。目前这个模型的预测性能并没有做到很理想的状态。
自己使用的话,如果有办法导出个人微信聊天记录用于训练,应该会好一些。
这个模型项目的训练还需要时间优化,目前只能说是可用。
资料
输入法项目: https://github.com/ximeiorg/Kime
模型训练项目:https://github.com/ximeiorg/predictive-text
权重下载: