<!-- training-gpt-from-scratch.html (65 KB) — extraction artifact removed: the file-viewer's line-number gutter (digits 1–749 run together) preceded the document here -->
  1. <!DOCTYPE html>
  2. <html lang="en" data-default-color-scheme="auto">
  3. <head><!-- hexo injector head_begin start -->
  4. <script defer src="https://api.limour.top/vue/0d2f95c1-755d-436b-adf8-eee12a80ed32/script.js"></script>
  5. <!-- hexo injector head_begin end -->
  6. <meta charset="UTF-8">
  7. <link rel="apple-touch-icon" sizes="76x76" href="https://img.limour.top/2023/08/29/64ee07361815a.webp">
  8. <link rel="icon" href="https://img.limour.top/2023/08/29/64ee07361815a.webp">
  9. <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=5.0, shrink-to-fit=no">
  10. <meta http-equiv="x-ua-compatible" content="ie=edge">
  11. <meta name="theme-color" content="#2f4154">
  12. <meta name="author" content="Limour">
  13. <meta name="keywords" content="">
  14. <meta name="description" content="探索整个过程,从在一台搭载1660Ti显卡的笔记本电脑上构建 Tokenizer,定义带有 RoPE 的 Transformer,一直到训练、保存模型和可视化训练过程。沉浸在从零开始训练 GPT 的旅程中,深入了解每一个步骤。跳入深度学习的世界,释放在你的便携1660Ti笔记本上的强大潜能。">
  15. <title>【探索】从零开始训练 GPT - Limour&#39;s Blog</title>
  16. <link rel="stylesheet" href="https://jscdn.limour.top/npm/bootstrap@4.6.1/dist/css/bootstrap.min.css" />
  17. <link rel="stylesheet" href="https://jscdn.limour.top/npm/github-markdown-css@4.0.0/github-markdown.min.css" />
  18. <link rel="stylesheet" href="https://jscdn.limour.top/npm/hint.css@2.7.0/hint.min.css" />
  19. <!-- 主题依赖的图标库,不要自行修改 -->
  20. <!-- Do not modify the link that theme dependent icons -->
  21. <link rel="stylesheet" href="//at.alicdn.com/t/c/font_1749284_5i9bdhy70f8.css">
  22. <link rel="stylesheet" href="//at.alicdn.com/t/font_1736178_lbnruvf0jn.css">
  23. <link rel="stylesheet" href="/css/main.css" />
  24. <link id="highlight-css" rel="stylesheet" href="/css/highlight.css" />
  25. <link id="highlight-css-dark" rel="stylesheet" href="/css/highlight-dark.css" />
  26. <link rel="stylesheet" href="/theme-inject/custom.css">
  27. <link rel="stylesheet" href="/theme-inject/iconfont.css">
  28. <script id="fluid-configs">
// Fluid theme bootstrap: exposes the shared Fluid namespace and the
// build-time CONFIG object that later theme scripts (utils.js,
// color-schema.js, etc.) read from.
var Fluid = window.Fluid || {};
// Shallow-copy any pre-existing context so this script never clobbers
// state set by an earlier inline script.
// NOTE(review): this statement has no trailing semicolon and relies on ASI.
Fluid.ctx = Object.assign({}, Fluid.ctx)
// Site/theme configuration injected by Hexo at generation time:
// typing effect, anchors, progress bar, code-block options, TOC,
// lazy-loading, web-analytics providers, and local-search settings.
var CONFIG = {"hostname":"hexo.limour.top","root":"/","version":"1.9.8","typing":{"enable":false,"typeSpeed":70,"cursorChar":"_","loop":false,"scope":[]},"anchorjs":{"enable":true,"element":"h1,h2,h3,h4,h5,h6","placement":"left","visible":"hover","icon":"§"},"progressbar":{"enable":true,"height_px":3,"color":"#29d","options":{"showSpinner":false,"trickleSpeed":100}},"code_language":{"enable":true,"default":"TEXT"},"copy_btn":true,"image_caption":{"enable":true},"image_zoom":{"enable":false,"img_url_replace":["",""]},"toc":{"enable":true,"placement":"right","headingSelector":"h1,h2,h3,h4,h5,h6","collapseDepth":0},"lazyload":{"enable":true,"loading_img":"https://jscdn.limour.top/gh/Limour-dev/Sakurairo_Vision/load_svg/inload.svg","onlypost":false,"offset_factor":2},"web_analytics":{"enable":false,"follow_dnt":true,"baidu":null,"google":{"measurement_id":null},"tencent":{"sid":null,"cid":null},"leancloud":{"app_id":null,"app_key":null,"server_url":null,"path":"window.location.pathname","ignore_local":false},"umami":{"src":null,"website_id":null,"domains":null,"start_time":"2024-01-01T00:00:00.000Z","token":null,"api_server":null},"woyaola":null,"cnzz":null},"search_path":"/local-search.xml","include_content_in_search":true};
// When follow_dnt is enabled, record whether the visitor has opted out
// of tracking so analytics can be skipped downstream.
if (CONFIG.web_analytics.follow_dnt) {
// Check the standard property plus legacy/vendor-prefixed locations;
// NOTE(review): presumably browsers report values like '1', 'yes' or 'on' —
// the startsWith checks below cover those three forms.
var dntVal = navigator.doNotTrack || window.doNotTrack || navigator.msDoNotTrack;
// dnt is truthy only when a DNT value exists AND signals opt-out.
Fluid.ctx.dnt = dntVal && (dntVal.startsWith('1') || dntVal.startsWith('yes') || dntVal.startsWith('on'));
}
  36. </script>
  37. <script src="/js/utils.js" ></script>
  38. <script src="/js/color-schema.js" ></script>
  39. <link rel="canonical" href="https://hexo.limour.top/training-gpt-from-scratch"/>
  40. <meta name="generator" content="Hexo 7.1.1"><link rel="alternate" href="/atom.xml" title="Limour's Blog" type="application/atom+xml">
  41. <link rel="alternate" href="/rss2.xml" title="Limour's Blog" type="application/rss+xml">
  42. </head>
  43. <body>
  44. <header>
  45. <div class="header-inner" style="height: 70vh;">
  46. <nav id="navbar" class="navbar fixed-top navbar-expand-lg navbar-dark scrolling-navbar">
  47. <div class="container">
  48. <a class="navbar-brand" href="/">
  49. <strong>Limour&#39;s Blog</strong>
  50. </a>
  51. <button id="navbar-toggler-btn" class="navbar-toggler" type="button" data-toggle="collapse"
  52. data-target="#navbarSupportedContent"
  53. aria-controls="navbarSupportedContent" aria-expanded="false" aria-label="Toggle navigation">
  54. <div class="animated-icon"><span></span><span></span><span></span></div>
  55. </button>
  56. <!-- Collapsible content -->
  57. <div class="collapse navbar-collapse" id="navbarSupportedContent">
  58. <ul class="navbar-nav ml-auto text-center">
  59. <li class="nav-item">
  60. <a class="nav-link" href="https://hexo.limour.top/" target="_self">
  61. <i class="iconfont icon-home-fill"></i>
  62. <span>Home</span>
  63. </a>
  64. </li>
  65. <li class="nav-item">
  66. <a class="nav-link" href="/archives/" target="_self">
  67. <i class="iconfont icon-archive-fill"></i>
  68. <span>Archive1</span>
  69. </a>
  70. </li>
  71. <li class="nav-item">
  72. <a class="nav-link" href="https://occdn.limour.top/archives/" target="_self">
  73. <i class="iconfont icon-archive-fill"></i>
  74. <span>Archive2</span>
  75. </a>
  76. </li>
  77. <li class="nav-item">
  78. <a class="nav-link" href="https://b.limour.top/archives/" target="_self">
  79. <i class="iconfont icon-archive-fill"></i>
  80. <span>Archive3</span>
  81. </a>
  82. </li>
  83. <li class="nav-item">
  84. <a class="nav-link" href="https://od.limour.top/" target="_self">
  85. <i class="iconfont icon-onedrive"></i>
  86. <span>Alist</span>
  87. </a>
  88. </li>
  89. <li class="nav-item">
  90. <a class="nav-link" href="https://orcid.org/0000-0001-8897-1685" target="_self">
  91. <i class="iconfont icon-orcid"></i>
  92. <span>Orcid</span>
  93. </a>
  94. </li>
  95. <li class="nav-item">
  96. <a class="nav-link" href="/links/" target="_self">
  97. <i class="iconfont icon-link-fill"></i>
  98. <span>Links</span>
  99. </a>
  100. </li>
  101. <li class="nav-item">
  102. <a class="nav-link" href="/atom.xml" target="_self">
  103. <i class="iconfont icon-rss"></i>
  104. <span>RSS</span>
  105. </a>
  106. </li>
  107. <li class="nav-item" id="search-btn">
  108. <a class="nav-link" target="_self" href="javascript:;" data-toggle="modal" data-target="#modalSearch" aria-label="Search">
  109. <i class="iconfont icon-search"></i>
  110. </a>
  111. </li>
  112. <li class="nav-item" id="color-toggle-btn">
  113. <a class="nav-link" target="_self" href="javascript:;" aria-label="Color Toggle">
  114. <i class="iconfont icon-dark" id="color-toggle-icon"></i>
  115. </a>
  116. </li>
  117. </ul>
  118. </div>
  119. </div>
  120. </nav>
  121. <div id="banner" class="banner" parallax="true"
  122. style="background: url('https://img.limour.top/2023/08/29/64ee08e108638.webp') no-repeat center center; background-size: cover;">
  123. <div class="full-bg-img">
  124. <div class="mask flex-center" style="background-color: rgba(0, 0, 0, 0.3)">
  125. <div class="banner-text text-center fade-in-up">
  126. <div class="h2">
  127. <span id="subtitle">【探索】从零开始训练 GPT</span>
  128. </div>
  129. <div class="mt-3">
  130. <span class="post-meta mr-2">
  131. <i class="iconfont icon-author" aria-hidden="true"></i>
  132. Limour
  133. </span>
  134. <span class="post-meta">
  135. <i class="iconfont icon-date-fill" aria-hidden="true"></i>
  136. <time datetime="2024-01-18 22:19" pubdate>
  137. January 18, 2024 pm
  138. </time>
  139. </span>
  140. </div>
  141. <div class="mt-1">
  142. <span class="post-meta mr-2">
  143. <i class="iconfont icon-chart"></i>
  144. 1.7k words
  145. </span>
  146. <span class="post-meta mr-2">
  147. <i class="iconfont icon-clock-fill"></i>
  148. 15 mins
  149. </span>
  150. </div>
  151. </div>
  152. </div>
  153. </div>
  154. </div>
  155. </div>
  156. </header>
  157. <main>
  158. <div class="container-fluid nopadding-x">
  159. <div class="row nomargin-x">
  160. <div class="side-col d-none d-lg-block col-lg-2">
  161. </div>
  162. <div class="col-lg-8 nopadding-x-md">
  163. <div class="container nopadding-x-md" id="board-ctn">
  164. <div id="board">
  165. <article class="post-content mx-auto">
  166. <h1 id="seo-header">【探索】从零开始训练 GPT</h1>
  167. <p id="updated-time" class="note note-info" style="">
  168. Last updated on March 19, 2024 pm
  169. </p>
  170. <div class="markdown-body">
  171. <p><img src="https://img.limour.top/2024/01/18/65a93c6a8065a.webp" srcset="https://jscdn.limour.top/gh/Limour-dev/Sakurairo_Vision/load_svg/inload.svg" lazyload alt="训练中..."></p>
  172. <h2 id="预期结构">预期结构</h2>
  173. <ul>
  174. <li><a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQ=" rel="noopener external nofollow noreferrer">相关代码已经放到 Github</a></li>
  175. </ul>
  176. <figure class="highlight txt"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><code class="hljs txt">HelloGPT(<br> (tok_embeddings): Embedding(32765, 768)<br> (rotary_emb): RotaryEmbedding(head_dim=64, max_seq_len=1024)<br> (layers): ModuleList(<br> (0-11): 12 x Decoder(<br> (ln1): RMSNorm(hidden_size=768, eps=1e-06)<br> (attn): Attention(<br> (q_proj): Linear(in_features=768, out_features=768, bias=False)<br> (k_proj): Linear(in_features=768, out_features=768, bias=False)<br> (v_proj): Linear(in_features=768, out_features=768, bias=False)<br> (o_proj): Linear(in_features=768, out_features=768, bias=False)<br> )<br> (ln2): RMSNorm(hidden_size=768, eps=1e-06)<br> (mlp): MLP(<br> (gate_proj): Linear(in_features=768, out_features=1536, bias=False)<br> (up_proj): Linear(in_features=768, out_features=1536, bias=False)<br> (down_proj): Linear(in_features=1536, out_features=768, bias=False)<br> )<br> )<br> )<br> (norm): RMSNorm(hidden_size=768, eps=1e-06)<br> (ln2): Linear(in_features=768, out_features=32765, bias=False)<br>)<br></code></pre></td></tr></table></figure>
  177. <h2 id="配置环境">配置环境</h2>
  178. <ul>
  179. <li><a href="/-ji-lu--an-zhuang-conda-bing-geng-huan-qing-hua-yuan">安装conda</a></li>
  180. </ul>
  181. <figure class="highlight powershell"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><code class="hljs powershell"><span class="hljs-built_in">cd</span> E:\GPT<br>conda install mamba <span class="hljs-literal">-c</span> conda<span class="hljs-literal">-forge</span><br>mamba create <span class="hljs-literal">-n</span> HelloGPT pytorch pytorch<span class="hljs-literal">-cuda</span>=<span class="hljs-number">12.1</span> <span class="hljs-literal">-c</span> pytorch <span class="hljs-literal">-c</span> nvidia <span class="hljs-literal">-c</span> conda<span class="hljs-literal">-forge</span><br>conda activate HelloGPT<br>conda install numpy transformers tiktoken tensorboard sentencepiece<span class="hljs-literal">-python</span> jieba emoji <span class="hljs-literal">-c</span> conda<span class="hljs-literal">-forge</span><br>pip install opencc<span class="hljs-literal">-python-reimplemented</span> <span class="hljs-literal">-i</span> https://pypi.tuna.tsinghua.edu.cn/simple<br>python test_cuda.py<br>python test_SPDA.py<br>D:\vscode\Code.exe<br></code></pre></td></tr></table></figure>
  182. <h2 id="准备数据">准备数据</h2>
  183. <ul>
  184. <li>下载 <a href="https://hexo.limour.top/go/#aHR0cHM6Ly9odWdnaW5nZmFjZS5jby9jb2xsZWN0aW9ucy9MaW1vdXIvcjE4LW5vdmVscy1nYWxnYW1lLTY1OThmMTY4OTRjYWRjOWNkY2IzZjNhYg==" rel="noopener external nofollow noreferrer">h-corpus-2023</a></li>
  185. </ul>
  186. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span 
class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">import</span> os<br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">Fileset</span>(<span class="hljs-title class_ inherited__">list</span>):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, path, ext=<span class="hljs-string">&#x27;&#x27;</span>, _read=<span class="hljs-literal">None</span></span>):<br> <span class="hljs-keyword">if</span> <span class="hljs-built_in">isinstance</span>(path, <span class="hljs-built_in">str</span>):<br> self.root = path<br> self.extend(f <span class="hljs-keyword">for</span> f <span class="hljs-keyword">in</span> os.listdir(self.root) <span class="hljs-keyword">if</span> f.endswith(ext))<br> self._read = _read<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__getitem__</span>(<span class="hljs-params">self, index</span>):<br> <span class="hljs-keyword">if</span> <span class="hljs-built_in">isinstance</span>(index, <span class="hljs-built_in">int</span>): <span class="hljs-comment"># index是索引</span><br> <span class="hljs-keyword">if</span> self._read:<br> <span class="hljs-keyword">return</span> self._read(os.path.join(self.root, <span class="hljs-built_in">super</span>().__getitem__(index)))<br> <span class="hljs-keyword">else</span>:<br> <span class="hljs-keyword">return</span> os.path.join(self.root, <span class="hljs-built_in">super</span>().__getitem__(index))<br> <span class="hljs-keyword">else</span>: <span class="hljs-comment"># index是切片</span><br> fileset = Fileset(<span class="hljs-literal">None</span>)<br> fileset.root = self.root<br> fileset._read = 
self._read<br> fileset.extend(<span class="hljs-built_in">super</span>().__getitem__(index))<br> <span class="hljs-keyword">return</span> fileset<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">getFileName</span>(<span class="hljs-params">self, index</span>):<br> fname, ext = os.path.splitext(<span class="hljs-built_in">super</span>().__getitem__(index))<br> <span class="hljs-keyword">return</span> fname<br><br><br><span class="hljs-keyword">from</span> tokenizer <span class="hljs-keyword">import</span> tokenizer<br>token_eos = <span class="hljs-number">2</span><br><br><br><span class="hljs-keyword">def</span> <span class="hljs-title function_">readOne</span>(<span class="hljs-params">filePath</span>):<br> retn = []<br> <span class="hljs-keyword">with</span> <span class="hljs-built_in">open</span>(file=filePath, encoding=<span class="hljs-string">&#x27;utf-8&#x27;</span>) <span class="hljs-keyword">as</span> f:<br> <span class="hljs-keyword">for</span> line <span class="hljs-keyword">in</span> f:<br> retn += tokenizer.encode(line).ids<br> retn.append(token_eos)<br> <span class="hljs-keyword">return</span> retn<br><br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">Hcorpus</span>():<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, path, ext=<span class="hljs-string">&#x27;txt&#x27;</span>, fileset_idx=<span class="hljs-number">0</span>, fileset_sub_idx=<span class="hljs-number">0</span></span>):<br> self.fileset = Fileset(path, ext, readOne)<br> self.fileset_idx = fileset_idx<br> self.fileset_sub_idx = fileset_sub_idx<br> <span class="hljs-keyword">if</span> self.fileset_sub_idx &lt; <span class="hljs-number">0</span>: <span class="hljs-comment"># 再读上一个太复杂了,直接放弃</span><br> self.fileset_sub_idx = <span class="hljs-number">0</span><br> <span class="hljs-keyword">if</span> self.fileset_idx &gt;= <span 
class="hljs-built_in">len</span>(self.fileset):<br> self.fileset_idx = <span class="hljs-number">0</span><br> self.cache = self.fileset[self.fileset_idx]<br> self.fileset_idx += <span class="hljs-number">1</span><br> self.cache_idx = self.fileset_sub_idx<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__call__</span>(<span class="hljs-params">self, size=<span class="hljs-number">512</span></span>):<br> <span class="hljs-keyword">while</span> <span class="hljs-built_in">len</span>(self.cache) &lt; self.cache_idx + size:<br> <span class="hljs-keyword">if</span> self.fileset_idx &gt;= <span class="hljs-built_in">len</span>(self.fileset):<br> self.fileset_idx = <span class="hljs-number">0</span><br> self.fileset_sub_idx = self.cache_idx - <span class="hljs-built_in">len</span>(self.cache)<br> self.cache = self.cache[self.cache_idx:] + self.fileset[self.fileset_idx]<br> self.cache_idx = <span class="hljs-number">0</span><br> self.fileset_idx += <span class="hljs-number">1</span><br> retn = self.cache[self.cache_idx:self.cache_idx + size]<br> self.cache_idx += size<br> self.fileset_sub_idx += size<br> <span class="hljs-keyword">return</span> retn<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__repr__</span>(<span class="hljs-params">self</span>):<br> <span class="hljs-keyword">return</span> <span class="hljs-string">f&quot;Hcorpus(r&#x27;<span class="hljs-subst">&#123;self.fileset.root&#125;</span>&#x27;, fileset_idx=<span class="hljs-subst">&#123;self.fileset_idx-<span class="hljs-number">1</span>&#125;</span>, fileset_sub_idx=<span class="hljs-subst">&#123;self.fileset_sub_idx&#125;</span>)&quot;</span><br></code></pre></td></tr></table></figure>
  187. <h2 id="训练Tokenizer">训练Tokenizer</h2>
  188. <ul>
  189. <li><a href="https://hexo.limour.top/go/#aHR0cHM6Ly9odWdnaW5nZmFjZS5jby9kb2NzL3Rva2VuaXplcnMvcXVpY2t0b3Vy" rel="noopener external nofollow noreferrer">tokenizer 包的文档</a></li>
  190. <li>繁体转换成简体:<a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQvYmxvYi9tYWluL3RyYWluX3Rva2VuaXplcl9wcmUucHk=" rel="noopener external nofollow noreferrer">train_tokenizer_pre.py</a></li>
  191. <li>获取常用 emoji:<a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQvYmxvYi9tYWluL3RtcF9lbW9qaS5weQ==" rel="noopener external nofollow noreferrer">tmp_emoji.py</a></li>
  192. <li>分词统计词频:<a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQvYmxvYi9tYWluL3RyYWluX3Rva2VuaXplcl9qaWViYS5weQ==" rel="noopener external nofollow noreferrer">tokenizer_jieba.py</a></li>
  193. <li>区分词性并构造 BPE 语料:<a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQvYmxvYi9tYWluL3RyYWluX3Rva2VuaXplcl9qaWViYV9zdGF0aXN0aWNzLnB5" rel="noopener external nofollow noreferrer">train_tokenizer_jieba_statistics.py</a></li>
  194. <li>训练 BPE 模型:<a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQvYmxvYi9tYWluL3RyYWluX3Rva2VuaXplci5weQ==" rel="noopener external nofollow noreferrer">train_tokenizer.py</a></li>
  195. <li>最终训练好的 BPE 模型:<a href="https://hexo.limour.top/go/#aHR0cHM6Ly9naXRodWIuY29tL0xpbW91ci1kZXYvSGVsbG9HUFQvYmxvYi9tYWluL0hlbGxvQlBFLnRva2VuaXplci5qc29u" rel="noopener external nofollow noreferrer">HelloBPE.tokenizer.json</a></li>
  196. </ul>
  197. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">from</span> tokenizers <span class="hljs-keyword">import</span> Tokenizer<br>tokenizer = Tokenizer.from_file(<span class="hljs-string">&quot;HelloBPE.tokenizer.json&quot;</span>)<br></code></pre></td></tr></table></figure>
  198. <h2 id="定义模型">定义模型</h2>
  199. <h3 id="定义-Decoder">定义 Decoder</h3>
  200. <h4 id="定义-RMSnorm">定义 RMSnorm</h4>
  201. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">RMSNorm</span>(nn.Module):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, dim: <span class="hljs-built_in">int</span>, eps: <span class="hljs-built_in">float</span> = <span class="hljs-number">1e-6</span></span>):<br> <span class="hljs-built_in">super</span>().__init__()<br> self.eps = eps<br> self.weight = nn.Parameter(torch.ones(dim))<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, x</span>):<br> x = x * torch.rsqrt(x.<span class="hljs-built_in">pow</span>(<span class="hljs-number">2</span>).mean(-<span class="hljs-number">1</span>, keepdim=<span class="hljs-literal">True</span>) + self.eps)<br> <span class="hljs-keyword">return</span> x * self.weight<br></code></pre></td></tr></table></figure>
  202. <h4 id="定义-RoPE">定义 RoPE</h4>
  203. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">RotaryEmbedding</span>(nn.Module):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, head_dim: <span class="hljs-built_in">int</span>, max_seq_len: <span class="hljs-built_in">int</span>, device=device, theta: <span class="hljs-built_in">float</span> = <span class="hljs-number">10000.0</span></span>):<br> <span class="hljs-built_in">super</span>().__init__()<br> self.head_dim = head_dim<br> self.set_max_seq_len(max_seq_len, device, theta)<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">set_max_seq_len</span>(<span class="hljs-params">self, max_seq_len: <span class="hljs-built_in">int</span>, device=device, theta: <span class="hljs-built_in">float</span> = <span class="hljs-number">10000.0</span></span>):<br> self.max_seq_len = max_seq_len<br> freqs = <span class="hljs-number">1.0</span> / (theta ** (torch.arange(<span class="hljs-number">0</span>, self.head_dim, <span class="hljs-number">2</span>).<span 
class="hljs-built_in">float</span>().to(device) / self.head_dim))<br> t = torch.arange(max_seq_len, device=device) <span class="hljs-comment"># type: ignore</span><br> freqs = torch.outer(t, freqs).<span class="hljs-built_in">float</span>() <span class="hljs-comment"># 外积</span><br> self.freqs_cis = torch.polar(torch.ones_like(freqs), freqs) <span class="hljs-comment"># 复数,模 1,角度 freqs</span><br> self.freqs_cis.requires_grad = <span class="hljs-literal">False</span> <span class="hljs-comment"># filter(lambda p : p.requires_grad, model.parameters())</span><br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">rotary_emb</span>(<span class="hljs-params">self, x</span>):<br> x_ = torch.view_as_complex(x.<span class="hljs-built_in">float</span>().reshape(*x.shape[:-<span class="hljs-number">1</span>], -<span class="hljs-number">1</span>, <span class="hljs-number">2</span>))<br> x_out = torch.view_as_real(x_ * self.local_freqs_cis).flatten(<span class="hljs-number">3</span>)<br> <span class="hljs-keyword">return</span> x_out.type_as(x)<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, start_pos: <span class="hljs-built_in">int</span>, seqlen: <span class="hljs-built_in">int</span></span>):<br> self.local_freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen].view(<span class="hljs-number">1</span>, seqlen, <span class="hljs-number">1</span>, -<span class="hljs-number">1</span>) <span class="hljs-comment"># cacheKV 相关,可忽略</span><br> self.local_freqs_cis.requires_grad = <span class="hljs-literal">False</span><br> <span class="hljs-keyword">return</span> self.rotary_emb<br></code></pre></td></tr></table></figure>
  204. <h4 id="定义-Attention">定义 Attention</h4>
  205. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">Attention</span>(nn.Module):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, hidden_size, n_heads, cacheKV, max_batch_size, max_seq_len, device=device</span>):<br> <span class="hljs-built_in">super</span>().__init__()<br> self.n_heads = n_heads<br> self.head_dim = hidden_size // n_heads<br> self.q_proj = nn.Linear(hidden_size, hidden_size, bias=<span class="hljs-literal">False</span>)<br> self.k_proj = nn.Linear(hidden_size, hidden_size, bias=<span class="hljs-literal">False</span>)<br> self.v_proj = nn.Linear(hidden_size, hidden_size, bias=<span class="hljs-literal">False</span>)<br> self.o_proj = nn.Linear(hidden_size, hidden_size, bias=<span class="hljs-literal">False</span>)<br><br> <span 
class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, hidden_states, rotary_emb, start_pos=<span class="hljs-number">0</span>, mask=<span class="hljs-literal">None</span>, is_causal=<span class="hljs-literal">True</span></span>):<br> bsz, seqlen, hidden_size = hidden_states.shape<br><br> q = self.q_proj(hidden_states)<br> k = self.k_proj(hidden_states)<br> v = self.v_proj(hidden_states)<br><br> q = q.view(bsz, seqlen, self.n_heads, self.head_dim)<br> k = k.view(bsz, seqlen, self.n_heads, self.head_dim)<br> v = v.view(bsz, seqlen, self.n_heads, self.head_dim)<br><br> q = rotary_emb(q)<br> k = rotary_emb(k)<br><br> q = q.transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>) <span class="hljs-comment"># (bs, n_heads, seqlen, head_dim)</span><br> k = k.transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>) <span class="hljs-comment"># (bs, n_local_heads, cache_len + seqlen, head_dim)</span><br> v = v.transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>) <span class="hljs-comment"># (bs, n_local_heads, cache_len + seqlen, head_dim)</span><br><br> output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)<br><br> output = output.transpose(<span class="hljs-number">1</span>, <span class="hljs-number">2</span>).contiguous().view(bsz, seqlen, hidden_size)<br> <span class="hljs-keyword">return</span> self.o_proj(output)<br></code></pre></td></tr></table></figure>
  206. <h4 id="定义-MLP">定义 MLP</h4>
  207. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">MLP</span>(nn.Module):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, hidden_size</span>):<br> <span class="hljs-built_in">super</span>().__init__()<br> intermediate_size = <span class="hljs-built_in">int</span>(<span class="hljs-number">2</span> * hidden_size)<br> self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=<span class="hljs-literal">False</span>)<br> self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=<span class="hljs-literal">False</span>)<br> self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=<span class="hljs-literal">False</span>)<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, x</span>):<br> gate = F.silu(self.gate_proj(x))<br> intermediate_states = self.up_proj(x)<br> <span class="hljs-keyword">return</span> self.down_proj(gate * intermediate_states)<br></code></pre></td></tr></table></figure>
  208. <h4 id="组装-Decoder">组装 Decoder</h4>
  209. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">Decoder</span>(nn.Module):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, hidden_size, n_heads, cacheKV, max_batch_size, max_seq_len</span>):<br> <span class="hljs-built_in">super</span>().__init__()<br> self.ln1 = RMSNorm(hidden_size)<br> self.attn = Attention(hidden_size, n_heads, cacheKV, max_batch_size, max_seq_len)<br> self.ln2 = RMSNorm(hidden_size)<br> self.mlp = MLP(hidden_size)<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, x, rotary_emb, start_pos, mask=<span class="hljs-literal">None</span>, is_causal=<span class="hljs-literal">True</span></span>):<br> x = x + self.attn(self.ln1(x), rotary_emb, start_pos, mask, is_causal)<br> <span class="hljs-keyword">return</span> x + self.mlp(self.ln2(x))<br></code></pre></td></tr></table></figure>
  210. <h3 id="组装模型">组装模型</h3>
  211. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">HelloGPT</span>(nn.Module):<br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">__init__</span>(<span class="hljs-params">self, vocab_size=<span class="hljs-number">32765</span>, hidden_size=<span class="hljs-number">768</span>, n_heads=<span class="hljs-number">12</span>, max_seq_len=<span class="hljs-number">1024</span>, n_layers=<span class="hljs-number">12</span>, cacheKV=<span class="hljs-literal">False</span>, max_batch_size=<span class="hljs-number">1</span></span>):<br> <span class="hljs-built_in">super</span>().__init__()<br> <span class="hljs-comment"># hidden_size &gt; 8.33 * ln(vocab_size)</span><br> self.tok_embeddings = nn.Embedding(vocab_size, hidden_size)<br> self.rotary_emb = RotaryEmbedding(hidden_size // n_heads, max_seq_len * <span class="hljs-number">2</span>)<br> self.rotary_emb.requires_grad = <span class="hljs-literal">False</span><br> self.layers = nn.ModuleList()<br> <span class="hljs-keyword">for</span> layer_id <span 
class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(n_layers):<br> self.layers.append(Decoder(hidden_size, n_heads, cacheKV, max_batch_size, max_seq_len))<br> self.norm = RMSNorm(hidden_size)<br> self.ln2 = nn.Linear(hidden_size, vocab_size, bias=<span class="hljs-literal">False</span>)<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, input_ids: torch.Tensor, start_pos=<span class="hljs-number">0</span>, no_mask=<span class="hljs-literal">True</span></span>):<br> _bsz, seqlen = input_ids.shape<br> h = self.tok_embeddings(input_ids)<br><br> <span class="hljs-comment"># 预计算,减少每一层的重复计算</span><br> rotary_emb = self.rotary_emb(start_pos, seqlen)<br> <span class="hljs-keyword">for</span> layer <span class="hljs-keyword">in</span> self.layers:<br> h = layer(h, rotary_emb, start_pos)<br><br> h = self.norm(h)<br> h = self.ln2(h)<br> <span class="hljs-keyword">return</span> h.<span class="hljs-built_in">float</span>()<br></code></pre></td></tr></table></figure>
  212. <h2 id="训练模型">训练模型</h2>
  213. <h3 id="数据载入">数据载入</h3>
  214. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><code class="hljs python">data = Hcorpus(<span class="hljs-string">r&#x27;D:\datasets\h-corpus&#x27;</span>)<br><span class="hljs-keyword">def</span> <span class="hljs-title function_">get_batch</span>(<span class="hljs-params">size=<span class="hljs-number">512</span>, bsz=<span class="hljs-number">8</span></span>):<br> x = []<br> y = []<br> <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(bsz):<br> tmp = data(size+<span class="hljs-number">1</span>)<br> x.append(tmp[:size])<br> y.append(tmp[<span class="hljs-number">1</span>:])<br> <span class="hljs-keyword">return</span> torch.tensor(x).to(device), torch.tensor(y).to(device)<br></code></pre></td></tr></table></figure>
  215. <h3 id="模型载入">模型载入</h3>
  216. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><code class="hljs python">model = HelloGPT(n_layers=<span class="hljs-number">8</span>, max_seq_len=<span class="hljs-number">768</span>)<br>model.to(device)<br></code></pre></td></tr></table></figure>
  217. <h3 id="训练模型-2">训练模型</h3>
  218. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-comment">## 初始化训练器</span><br>criterion = nn.CrossEntropyLoss() <span class="hljs-comment"># 交叉熵损失函数</span><br>optimizer = torch.optim.Adam(train_parameters, lr=<span class="hljs-number">6e-4</span>) <span class="hljs-comment"># Adam 优化器</span><br>scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=<span class="hljs-number">5</span>, T_mult=<span class="hljs-number">2</span>) <span class="hljs-comment"># 余弦退火学习率</span><br>torch.manual_seed(<span class="hljs-number">1337</span>) <span class="hljs-comment"># 魔术随机种子</span><br><br>total_loss = <span class="hljs-number">0</span><br>print_iter = <span class="hljs-number">20</span><br><span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">1</span>, <span class="hljs-number">100001</span>):<br> optimizer.zero_grad(set_to_none=<span class="hljs-literal">True</span>) <span class="hljs-comment"># 清空梯度,节省显存</span><br> x, y = get_batch(size=<span class="hljs-number">384</span>, bsz=<span class="hljs-number">4</span>) <span class="hljs-comment"># x 是训练语料 y 是 x 
移动了一位,当做预测目标</span><br> y_ = model(x) <span class="hljs-comment"># 通过 x 预测的 y</span><br> loss = criterion(y_.view(-<span class="hljs-number">1</span>, <span class="hljs-number">32765</span>), y.view(-<span class="hljs-number">1</span>)) <span class="hljs-comment"># 计算损失</span><br> loss.backward() <span class="hljs-comment"># 反向传播梯度</span><br> torch.nn.utils.clip_grad_norm_(train_parameters, <span class="hljs-number">0.5</span>) <span class="hljs-comment"># 梯度裁剪,减轻过拟合</span><br> optimizer.step() <span class="hljs-comment"># 通过梯度优化训练参数</span><br> scheduler.step() <span class="hljs-comment"># 计算下一步的学习率</span><br> total_loss += loss <span class="hljs-comment"># 累计损失</span><br><br> <span class="hljs-keyword">if</span> epoch % print_iter == <span class="hljs-number">0</span>:<br> <span class="hljs-built_in">print</span>(data)<br> <span class="hljs-built_in">print</span>(<span class="hljs-string">f&#x27;epoch: <span class="hljs-subst">&#123;epoch&#125;</span> lr: <span class="hljs-subst">&#123;scheduler.get_last_lr()[<span class="hljs-number">0</span>]:<span class="hljs-number">.4</span>e&#125;</span> loss: <span class="hljs-subst">&#123;total_loss / print_iter:<span class="hljs-number">.4</span>e&#125;</span>&#x27;</span>)<br> total_loss = <span class="hljs-number">0</span><br></code></pre></td></tr></table></figure>
  219. <h3 id="保存读取">保存读取</h3>
  220. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">with</span> <span class="hljs-built_in">open</span>(<span class="hljs-string">&#x27;tmp_training.pkl&#x27;</span>, <span class="hljs-string">&#x27;rb&#x27;</span>) <span class="hljs-keyword">as</span> file:<br> epoch = pickle.load(file) <span class="hljs-comment"># 读取 epoch 位置</span><br> tmp_fileset_idx = pickle.load(file) <span class="hljs-comment"># 读取 data 位置</span><br> tmp_fileset_sub_idx = pickle.load(file)<br><span class="hljs-comment"># 恢复数据位置</span><br>data = Hcorpus(<span class="hljs-string">r&#x27;D:\datasets\h-corpus&#x27;</span>, fileset_idx=tmp_fileset_idx-<span class="hljs-number">1</span>, fileset_sub_idx=tmp_fileset_sub_idx)<br>model = torch.load(<span class="hljs-string">f&#x27;tmp_model_<span class="hljs-subst">&#123;epoch&#125;</span>.pth&#x27;</span>) <span class="hljs-comment"># 恢复模型</span><br><span class="hljs-built_in">print</span>(<span class="hljs-string">f&#x27;start from epoch: <span class="hljs-subst">&#123;epoch&#125;</span> data: <span class="hljs-subst">&#123;data&#125;</span>&#x27;</span>)<br><br>save_iter = <span class="hljs-number">5000</span><br><span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span 
class="hljs-built_in">range</span>(<span class="hljs-number">1</span>, <span class="hljs-number">100001</span>):<br> <span class="hljs-keyword">pass</span><br> <span class="hljs-keyword">if</span> epoch % save_iter == <span class="hljs-number">0</span>:<br> optimizer.zero_grad(set_to_none=<span class="hljs-literal">True</span>) <span class="hljs-comment"># 清空梯度,节省显存</span><br> <span class="hljs-keyword">with</span> <span class="hljs-built_in">open</span>(<span class="hljs-string">&#x27;tmp_training.pkl&#x27;</span>, <span class="hljs-string">&#x27;wb&#x27;</span>) <span class="hljs-keyword">as</span> file:<br> pickle.dump(epoch, file) <span class="hljs-comment"># 保存 epoch 位置</span><br> pickle.dump(data.fileset_idx, file) <span class="hljs-comment"># 保存 data 位置</span><br> pickle.dump(data.fileset_sub_idx, file)<br> torch.save(model, <span class="hljs-string">f&#x27;tmp_model_<span class="hljs-subst">&#123;epoch&#125;</span>.pth&#x27;</span>) <span class="hljs-comment"># 保存模型</span><br> <span class="hljs-built_in">print</span>(<span class="hljs-string">f&#x27;save to tmp_model_<span class="hljs-subst">&#123;epoch&#125;</span>.pth&#x27;</span>)<br></code></pre></td></tr></table></figure>
  221. <h3 id="可视化">可视化</h3>
  222. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><code class="hljs python">writer = SummaryWriter(<span class="hljs-string">&#x27;logs&#x27;</span>) <span class="hljs-comment"># tensorboard --logdir logs</span><br><span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(<span class="hljs-number">1</span>, <span class="hljs-number">100001</span>):<br> <span class="hljs-keyword">pass</span><br> writer.add_scalar(<span class="hljs-string">&#x27;lr&#x27;</span>, scheduler.get_last_lr()[<span class="hljs-number">0</span>], epoch)<br> writer.add_scalar(<span class="hljs-string">&#x27;loss&#x27;</span>, loss, epoch)<br> <span class="hljs-keyword">if</span> epoch % print_iter == <span class="hljs-number">0</span>:<br> <span class="hljs-keyword">pass</span><br> writer.add_scalar(<span class="hljs-string">&#x27;total_loss&#x27;</span>, total_loss / print_iter, epoch)<br>writer.close()<br></code></pre></td></tr></table></figure>
  223. <h2 id="附加-streaming-llm">附加 streaming_llm</h2>
  224. <figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><code class="hljs python"><span class="hljs-keyword">class</span> <span class="hljs-title class_">RotaryEmbedding</span>(nn.Module):<br> <span class="hljs-keyword">pass</span><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">inverse_rotary_emb</span>(<span class="hljs-params">self, x</span>):<br> x_ = torch.view_as_complex(x.<span class="hljs-built_in">float</span>().reshape(*x.shape[:-<span class="hljs-number">1</span>], -<span class="hljs-number">1</span>, <span class="hljs-number">2</span>))<br> x_out = torch.view_as_real(x_ * self.local_freqs_cis_inverse).flatten(<span class="hljs-number">3</span>)<br> <span class="hljs-keyword">return</span> 
x_out.type_as(x)<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">inverse_forward</span>(<span class="hljs-params">self, start_pos: <span class="hljs-built_in">int</span>, seqlen: <span class="hljs-built_in">int</span></span>):<br> self.local_freqs_cis_inverse = self.freqs_cis[start_pos: start_pos + seqlen].view(<span class="hljs-number">1</span>, seqlen, <span class="hljs-number">1</span>, -<span class="hljs-number">1</span>) <span class="hljs-comment"># cacheKV 相关,可忽略</span><br> self.local_freqs_cis_inverse = self.local_freqs_cis_inverse.conj() <span class="hljs-comment"># 乘上共轭就旋转回去了</span><br> self.local_freqs_cis.requires_grad = <span class="hljs-literal">False</span><br> <span class="hljs-keyword">return</span> self.inverse_rotary_emb<br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">Attention</span>(nn.Module):<br> <span class="hljs-keyword">pass</span><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">forward</span>(<span class="hljs-params">self, hidden_states, rotary_emb, start_pos=<span class="hljs-number">0</span>, mask=<span class="hljs-literal">None</span>, is_causal=<span class="hljs-literal">True</span></span>):<br> <span class="hljs-keyword">pass</span><br> <span class="hljs-keyword">if</span> self.cacheKV: <span class="hljs-comment"># cacheKV 相关,可忽略</span><br> self.cache_k[:bsz, start_pos: start_pos + seqlen] = k<br> self.cache_v[:bsz, start_pos: start_pos + seqlen] = v<br> k = self.cache_k[:bsz, : start_pos + seqlen]<br> v = self.cache_v[:bsz, : start_pos + seqlen]<br><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">streaming_llm</span>(<span class="hljs-params">self, start_pos, seqlen, to_pos, inverse_rotary_emb, rotary_emb, bsz</span>):<br> k = self.cache_k[:bsz, start_pos: start_pos + seqlen]<br> v = self.cache_v[:bsz, start_pos: start_pos + seqlen]<br> k = inverse_rotary_emb(k)<br> k = rotary_emb(k)<br> self.cache_k[:bsz, 
to_pos: to_pos + seqlen] = k<br> self.cache_v[:bsz, to_pos: to_pos + seqlen] = v<br><br><span class="hljs-keyword">class</span> <span class="hljs-title class_">HelloGPT</span>(nn.Module):<br> <span class="hljs-keyword">pass</span><br> <span class="hljs-keyword">def</span> <span class="hljs-title function_">streaming_llm</span>(<span class="hljs-params">self, start_pos, seqlen, to_pos, max_batch_size=<span class="hljs-number">1</span></span>):<br> rotary_emb = self.rotary_emb(to_pos, seqlen)<br> inverse_rotary_emb = self.rotary_emb.inverse_forward(start_pos, seqlen)<br> <span class="hljs-keyword">for</span> layer <span class="hljs-keyword">in</span> self.layers:<br> layer.attn.streaming_llm(start_pos, seqlen, to_pos, inverse_rotary_emb, rotary_emb, max_batch_size)<br></code></pre></td></tr></table></figure>
  225. </div>
  226. <hr/>
  227. <div>
  228. <div class="post-metas my-3">
  229. <div class="post-meta">
  230. <i class="iconfont icon-tags"></i>
  231. <a href="/tags/%E6%8E%A2%E7%B4%A2/" class="print-no-link">#探索</a>
  232. <a href="/tags/llama/" class="print-no-link">#llama</a>
  233. </div>
  234. </div>
  235. <div class="license-box my-3">
  236. <div class="license-title">
  237. <div>【探索】从零开始训练 GPT</div>
  238. <div>https://hexo.limour.top/training-gpt-from-scratch</div>
  239. </div>
  240. <div class="license-meta">
  241. <div class="license-meta-item">
  242. <div>Author</div>
  243. <div>Limour</div>
  244. </div>
  245. <div class="license-meta-item license-meta-date">
  246. <div>Posted on</div>
  247. <div>January 18, 2024</div>
  248. </div>
  249. <div class="license-meta-item license-meta-date">
  250. <div>Updated on</div>
  251. <div>March 19, 2024</div>
  252. </div>
  253. <div class="license-meta-item">
  254. <div>Licensed under</div>
  255. <div>
  256. <a class="print-no-link" target="_blank" href="https://creativecommons.org/licenses/by-nc-sa/4.0/">
  257. <span class="hint--top hint--rounded" aria-label="BY - Attribution">
  258. <i class="iconfont icon-cc-by"></i>
  259. </span>
  260. </a>
  261. <a class="print-no-link" target="_blank" href="https://creativecommons.org/licenses/by-nc-sa/4.0/">
  262. <span class="hint--top hint--rounded" aria-label="NC - Non-commercial">
  263. <i class="iconfont icon-cc-nc"></i>
  264. </span>
  265. </a>
  266. <a class="print-no-link" target="_blank" href="https://creativecommons.org/licenses/by-nc-sa/4.0/">
  267. <span class="hint--top hint--rounded" aria-label="SA - Share-alike">
  268. <i class="iconfont icon-cc-sa"></i>
  269. </span>
  270. </a>
  271. </div>
  272. </div>
  273. </div>
  274. <div class="license-icon iconfont"></div>
  275. </div>
  276. <div class="post-prevnext my-3">
  277. <article class="post-prev col-6">
  278. <a href="/Convert-BlueLM-7B-Chat-to-the-standard-GGUF-model" title="【探索】将BlueLM-7B-Chat转换为标准的GGUF模型">
  279. <i class="iconfont icon-arrowleft"></i>
  280. <span class="hidden-mobile">【探索】将BlueLM-7B-Chat转换为标准的GGUF模型</span>
  281. <span class="visible-mobile">Previous</span>
  282. </a>
  283. </article>
  284. <article class="post-next col-6">
  285. <a href="/Azure-AI-prevents-reverse-wool-shearing" title="【避坑】Azure AI 避免反向薅羊毛">
  286. <span class="hidden-mobile">【避坑】Azure AI 避免反向薅羊毛</span>
  287. <span class="visible-mobile">Next</span>
  288. <i class="iconfont icon-arrowright"></i>
  289. </a>
  290. </article>
  291. </div>
  292. </div>
  293. <article id="comments" lazyload>
  294. <div id="waline"></div>
  295. <script type="text/javascript">
  296. Fluid.utils.loadComments('#waline', function() {
  297. Fluid.utils.createCssLink('https://cdn.staticfile.org/waline/2.15.5/waline.css')
  298. Fluid.utils.createScript('https://cdn.staticfile.org/waline/2.15.5/waline.js', function() {
  299. var options = Object.assign(
  300. {"serverURL":"https://comments.limour.top","path":"window.location.pathname","meta":["nick","mail","link"],"requiredMeta":["nick"],"lang":"zh-CN","emoji":["https://jscdn.limour.top/gh/walinejs/emojis/weibo"],"dark":"html[data-user-color-scheme=\"dark\"]","wordLimit":0,"pageSize":10},
  301. {
  302. el: '#waline',
  303. path: window.location.pathname
  304. }
  305. )
  306. Waline.init(options);
  307. Fluid.utils.waitElementVisible('#waline .vcontent', () => {
  308. var imgSelector = '#waline .vcontent img:not(.vemoji)';
  309. Fluid.plugins.imageCaption(imgSelector);
  310. Fluid.plugins.fancyBox(imgSelector);
  311. })
  312. });
  313. });
  314. </script>
  315. <noscript>Please enable JavaScript to view the comments</noscript>
  316. </article>
  317. </article>
  318. </div>
  319. </div>
  320. </div>
  321. <div class="side-col d-none d-lg-block col-lg-2">
  322. <aside class="sidebar" style="margin-left: -1rem">
  323. <div id="toc">
  324. <p class="toc-header">
  325. <i class="iconfont icon-list"></i>
  326. <span>Table of Contents</span>
  327. </p>
  328. <div class="toc-body" id="toc-body"></div>
  329. </div>
  330. </aside>
  331. </div>
  332. </div>
  333. </div>
  334. <a id="scroll-top-button" aria-label="TOP" href="#" role="button">
  335. <i class="iconfont icon-arrowup" aria-hidden="true"></i>
  336. </a>
  337. <div class="modal fade" id="modalSearch" tabindex="-1" role="dialog" aria-labelledby="ModalLabel"
  338. aria-hidden="true">
  339. <div class="modal-dialog modal-dialog-scrollable modal-lg" role="document">
  340. <div class="modal-content">
  341. <div class="modal-header text-center">
  342. <h4 class="modal-title w-100 font-weight-bold">Search</h4>
  343. <button type="button" id="local-search-close" class="close" data-dismiss="modal" aria-label="Close">
  344. <span aria-hidden="true">&times;</span>
  345. </button>
  346. </div>
  347. <div class="modal-body mx-3">
  348. <div class="md-form mb-5">
  349. <input type="text" id="local-search-input" class="form-control validate">
  350. <label data-error="x" data-success="v" for="local-search-input">Keyword</label>
  351. </div>
  352. <div class="list-group" id="local-search-result"></div>
  353. </div>
  354. </div>
  355. </div>
  356. </div>
  357. </main>
  358. <footer>
  359. <div class="footer-inner">
  360. <div class="footer-content">
  361. <a target="_blank" rel="nofollow noopener" href="http://www.beian.gov.cn/portal/registerSystemInfo?recordcode=43130202000203"><img src="https://img.limour.top/2023/08/27/64eadeb81d6a0.webp" srcset="https://jscdn.limour.top/gh/Limour-dev/Sakurairo_Vision/load_svg/inload.svg" lazyload>湘公网安备43130202000203号 </a> <a target="_blank" rel="nofollow noopener" href="https://beian.miit.gov.cn/">湘ICP备20008299号 </a> <a target="_blank" rel="nofollow noopener" href="https://icp.gov.moe/?keyword=20210128">萌ICP备20210128号</a> <br> <a href="https://www.foreverblog.cn/" target="_blank"> <img src="https://img.foreverblog.cn/logo_en_default.png" srcset="https://jscdn.limour.top/gh/Limour-dev/Sakurairo_Vision/load_svg/inload.svg" lazyload alt="" style="width:auto;height:24px"> </a> <br> <a href="https://hexo.io" target="_blank" rel="nofollow noopener"><span>Hexo</span></a> <i class="iconfont icon-love"></i> <a href="https://github.com/fluid-dev/hexo-theme-fluid" target="_blank" rel="nofollow noopener"><span>Fluid</span></a> <i class="iconfont icon-love"></i> <a href="https://github.com/limour-blog/limour-blog.github.io" target="_blank" rel="nofollow noopener"><span>SRC</span></a> <i class="iconfont icon-love"></i> <a href="https://web.archive.org/web/20231130095837/https://effectiveacceleration.tech/" target="_blank" rel="nofollow noopener"><span>e/Acc</span></a>
  362. </div>
  363. </div>
  364. </footer>
  365. <!-- Scripts -->
  366. <script src="https://jscdn.limour.top/npm/nprogress@0.2.0/nprogress.min.js" ></script>
  367. <link rel="stylesheet" href="https://jscdn.limour.top/npm/nprogress@0.2.0/nprogress.min.css" />
  368. <script>
  369. NProgress.configure({"showSpinner":false,"trickleSpeed":100})
  370. NProgress.start()
  371. window.addEventListener('load', function() {
  372. NProgress.done();
  373. })
  374. </script>
  375. <script src="https://jscdn.limour.top/npm/jquery@3.6.4/dist/jquery.min.js" ></script>
  376. <script src="https://jscdn.limour.top/npm/bootstrap@4.6.1/dist/js/bootstrap.min.js" ></script>
  377. <script src="/js/events.js" ></script>
  378. <script src="/js/plugins.js" ></script>
  379. <script src="/js/img-lazyload.js" ></script>
  380. <script>
  381. Fluid.utils.createScript('https://jscdn.limour.top/npm/tocbot@4.20.1/dist/tocbot.min.js', function() {
  382. var toc = jQuery('#toc');
  383. if (toc.length === 0 || !window.tocbot) { return; }
  384. var boardCtn = jQuery('#board-ctn');
  385. var boardTop = boardCtn.offset().top;
  386. window.tocbot.init(Object.assign({
  387. tocSelector : '#toc-body',
  388. contentSelector : '.markdown-body',
  389. linkClass : 'tocbot-link',
  390. activeLinkClass : 'tocbot-active-link',
  391. listClass : 'tocbot-list',
  392. isCollapsedClass: 'tocbot-is-collapsed',
  393. collapsibleClass: 'tocbot-is-collapsible',
  394. scrollSmooth : true,
  395. includeTitleTags: true,
  396. headingsOffset : -boardTop,
  397. }, CONFIG.toc));
  398. if (toc.find('.toc-list-item').length > 0) {
  399. toc.css('visibility', 'visible');
  400. }
  401. Fluid.events.registerRefreshCallback(function() {
  402. if ('tocbot' in window) {
  403. tocbot.refresh();
  404. var toc = jQuery('#toc');
  405. if (toc.length === 0 || !tocbot) {
  406. return;
  407. }
  408. if (toc.find('.toc-list-item').length > 0) {
  409. toc.css('visibility', 'visible');
  410. }
  411. }
  412. });
  413. });
  414. </script>
  415. <script src=https://lib.baomitu.com/clipboard.js/2.0.11/clipboard.min.js></script>
  416. <script>Fluid.plugins.codeWidget();</script>
  417. <script>
  418. Fluid.utils.createScript('https://jscdn.limour.top/npm/anchor-js@4.3.1/anchor.min.js', function() {
  419. window.anchors.options = {
  420. placement: CONFIG.anchorjs.placement,
  421. visible : CONFIG.anchorjs.visible
  422. };
  423. if (CONFIG.anchorjs.icon) {
  424. window.anchors.options.icon = CONFIG.anchorjs.icon;
  425. }
  426. var el = (CONFIG.anchorjs.element || 'h1,h2,h3,h4,h5,h6').split(',');
  427. var res = [];
  428. for (var item of el) {
  429. res.push('.markdown-body > ' + item.trim());
  430. }
  431. if (CONFIG.anchorjs.placement === 'left') {
  432. window.anchors.options.class = 'anchorjs-link-left';
  433. }
  434. window.anchors.add(res.join(', '));
  435. Fluid.events.registerRefreshCallback(function() {
  436. if ('anchors' in window) {
  437. anchors.removeAll();
  438. var el = (CONFIG.anchorjs.element || 'h1,h2,h3,h4,h5,h6').split(',');
  439. var res = [];
  440. for (var item of el) {
  441. res.push('.markdown-body > ' + item.trim());
  442. }
  443. if (CONFIG.anchorjs.placement === 'left') {
  444. anchors.options.class = 'anchorjs-link-left';
  445. }
  446. anchors.add(res.join(', '));
  447. }
  448. });
  449. });
  450. </script>
  451. <script>Fluid.plugins.imageCaption();</script>
  452. <script src="/js/local-search.js" ></script>
  453. <!-- 主题的启动项,将它保持在最底部 -->
  454. <!-- the boot of the theme, keep it at the bottom -->
  455. <script src="/js/boot.js" ></script>
  456. <noscript>
  457. <div class="noscript-warning">Blog works best with JavaScript enabled</div>
  458. </noscript>
  459. <!-- hexo injector body_end start -->
  460. <script defer src="/theme-inject/timeliness.js"></script>
  461. <!-- hexo injector body_end end --></body>
  462. </html>