V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
爱意满满的作品展示区。
lwch
V2EX  ›  分享创造

我用 GO 语言封装了一个机器学习框架,并实现了一个小型的 GPT 模型来对对联

  •  
  •   lwch · 2023-06-15 11:29:19 +08:00 · 1514 次点击
    这是一个创建于 578 天前的主题,其中的信息可能已经有所发展或是发生改变。

    libgotorch

    首先利用 libtorch 库封装了一个libgotorch库,已支持最新的 libtorch2.0.1

    问题一:cgo 中返回的 tensor 对象在栈上,直接使用可能会有内存安全问题

    我做了一层简单的封装来使其创建到堆上,但其引发的问题是需要手动管理内存,因此我编写了 mmgr 包在每一个 tensor 对象创建的时候自动加入 mmgr 的 storage 当中,最后在每一轮训练完毕后通过 GC 方法释放堆上的 tensor 对象

    问题二:windows 下的 libtorch 库通过 msvc 编译,提供的是 C++接口,无法在 mingw 中无法正常链接

    解决方案是通过在封装一个动态链接库并暴露 C 语言接口,在 mingw 中即可正常链接

    通过解决以上两个问题,已可以在 go 语言中使用 libtorch 库并实现自己的模型了

    对对联

    下面进入正题,我在 tnn 库中实现了一个小型的 GPT 模型来实现对对联:couplet,下面让我们来看一下最终效果

    $ go run main.go evaluate --model model7M 晚风摇树树还挺
    load embedding...
    model loaded
    inputs: [472 3 462 148 148 342 1516]
    map[4.278747:[醉] 5.084207:[润] 8.868446:[晨]]
    map[3.8447263:[花] 4.750472:[润] 8.635651:[露]]
    map[5.46043:[花] 6.7003703:[露] 10.768249:[润]]
    map[4.3850584:[露] 4.875666:[润] 9.896332:[花]]
    map[3.6241615:[红] 5.611262:[润] 10.782802:[花]]
    map[4.3855276:[花] 5.48069:[红] 9.480111:[更]]
    map[3.7904112:[心] 4.269902:[花] 10.3220415:[红]]
    晨露润花花更红
    
    $ go run main.go evaluate --model model7M 投石向天跟命斗
    load embedding...
    model loaded
    inputs: [1233 190 383 11 2623 620 490]
    map[5.7068815:[门] 5.7826476:[问] 9.79136:[闭]]
    map[3.0136497:[问] 3.1092193:[人] 8.903796:[门]]
    map[3.021591:[还] 3.448888:[歌] 8.96453:[问]]
    map[4.9368696:[地] 5.7390223:[时] 9.438878:[卷]]
    map[3.5542138:[话] 3.858942:[时] 8.253393:[与]]
    map[3.025545:[与] 3.2461479:[卷] 9.06726:[时]]
    map[4.250452:[时] 4.712057:[舟] 10.401218:[争]]
    闭门问卷与时争
    

    注意:该模型仅训练了开源数据集couplet-dataset中的前 1 万个样本

    模型的参数结构如下:

    +------------------------+---------+
    |          NAME          |  COUNT  |
    +------------------------+---------+
    | transformer0_attention |    1872 |
    | transformer0_dense     | 1256640 |
    | transformer0_output    | 1254960 |
    | transformer1_attention |    1872 |
    | transformer1_dense     | 1256640 |
    | transformer1_output    | 1254960 |
    | output                 | 2488596 |
    | total                  | 7515540 |
    +------------------------+---------+
    
    train 200, cost=2h15m7.877395694s, loss=3.665343e-02
    

    整个模型共有 751 万个参数,模型包含 2 个 transformer 模块,由于在训练时只使用了 8 个 float32 来对每一个字进行表征,因此 attention 层的参数量较少,其他参数配置如下:

    const embeddingDim = 8 // 8 个 float32 表示一个字向量
    const paddingSize = 70 // 最长为 34*2 ,因此 padding 长度必须大于 68
    const heads = 4
    const batchSize = 128
    const epoch = 200
    const lr = 0.001
    const transformerSize = 2
    

    最后让我们来看看模型的泛化能力如何

    $ go run main.go evaluate --model model7M 我是谁
    load embedding...
    model loaded
    inputs: [85 62 191]
    map[4.3809786:[雨] 4.9436274:[染] 7.105626:[绿]]
    map[3.8163047:[水] 4.013789:[东] 4.088595:[得]]
    map[4.872726:[唱] 5.4107614:[兰] 6.3983927:[发]]
    绿得发
    
    $ go run main.go evaluate --model ./model7M 我在哪
    load embedding...
    model loaded
    inputs: [85 99 1151]
    map[1.480957:[思] 2.002811:[得] 4.0260763:[寻]]
    map[3.4100764:[女] 3.868993:[对] 4.448501:[得]]
    map[2.2672489:[年] 2.3772364:[历] 4.946753:[谁]]
    寻得谁
    

    效果不是很理想,可能还是跟训练的样本数量太少有关

    另外还有一些示例可在 example 目录下找到,如使用 RNN 来学习如何画 sin 曲线等

    最后是项目地址:

    第 1 条附言  ·  2023-06-17 23:19:19 +08:00

    修复了FFN层的实现方式问题,现在参数数量看起来比较正常

    +------------------------+---------+
    |          NAME          |  COUNT  |
    +------------------------+---------+
    | transformer0_attention |   62208 |
    | transformer0_dense     |   66048 |
    | transformer0_output    |   65664 |
    | transformer1_attention |   62208 |
    | transformer1_dense     |   66048 |
    | transformer1_output    |   65664 |
    | transformer2_attention |   62208 |
    | transformer2_dense     |   66048 |
    | transformer2_output    |   65664 |
    | transformer3_attention |   62208 |
    | transformer3_dense     |   66048 |
    | transformer3_output    |   65664 |
    | output                 |  572244 |
    | total                  | 1347924 |
    +------------------------+---------+
    

    另外增加了各种mask的支持,修复了loss函数使用问题

    5 条回复    2024-04-11 09:57:55 +08:00
    coosir
        1
    coosir  
       2023-06-15 11:42:08 +08:00
    强哦,对联对得倒是挺好的
    vus520
        2
    vus520  
       2023-06-16 10:17:31 +08:00
    关注下,如果能把 huggingface 的包实现一遍真是造福我众
    allegory
        3
    allegory  
       278 天前
    libtorch 都还是 beta 版,你再来一个 go 的封装,稳定性/正确性如何保证?不过还是很强
    lwch
        4
    lwch  
    OP
       277 天前
    @allegory libtorch 跟着 pytorch 的版本走的,现在已经 2.2.2 了
    lwch
        5
    lwch  
    OP
       277 天前
    @allegory 我用他来实现了一个小型的 llama 模型大约 1.5 亿参数量在 CPU 上训练速度还行
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   1017 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 24ms · UTC 21:54 · PVG 05:54 · LAX 13:54 · JFK 16:54
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.