ファインチューニングの終焉:全ては継続的な事前学習

  • 少し古いですが「ファインチューニングの終焉(The End of Fine-tuning)」という記事が面白かったので、簡単なメモです。

www.latent.space

  • この記事に登場するFast.aiのJeremy Howardさんは「事前学習→追加学習→RLHF」のように、フェーズごとにデータセットの種類・手法を完全に切り替えていく言語モデルの学習方法は「もはや正しくない」と主張しています。

「私は今では誰もが採用している 3 段階の学習アプローチを最初に考案しましたが、これは実は間違っており、使用すべきではないというのが私の今の見解です。正しい方法は、ファインチューニングという考えを捨て去ることです。そのようなものは存在しません。ただ事前学習を継続するだけです。

事前学習とは、最初から、関心のあるあらゆる種類のデータ、関心のあるあらゆる種類の問題、指示、演習、コード、汎用ドキュメントの完成などをすべて含めるようにすることです。そして、トレーニングするにつれて、徐々にそれをキュレートし、徐々に品質を高め、実行したいタスクの種類にますます特化していきます。ただし、データを捨てることはありません...

それが今の私の見解です。ULMFiT は間違ったアプローチだと思います。それが、いわゆるアライメント税(alignment tax)が数多く見られる理由です。実際には、人々が間違ったトレーニングをしているからだと思います。」

  • このコメントで言及されている「アライメント税」は、LLMが有害な出力をしないように微調整するアラインメントによってモデル性能が低下してしまう現象を指しているのだと思います。
  • そのほか、Llama-2にコーディング特化の学習を加えたCodeLlamaが一般タスクでの性能を落とした例にも言及しつつ、過去に獲得した知識を追加学習中に失ってしまう「壊滅的忘却」を防ぐためには常に学習データの多様性を維持し続けることが重要だ、というのが論旨みたいです。
    • 最終的な用途に合わせて重要なデータの構成比率を増やしていく必要はありつつも、それ以外のデータを完全に捨ててしまうと学習によって汎用性能が劣化してしまう。
  • そういえば、以前にも似たような投稿をメモしていました。微調整中も1/4くらいは事前学習データを混ぜたほうがいいのでは?という内容です。

sc-bakushu.hatenablog.com

 

【LLM論文を読む&試す】LLMの刈り込み(Pruning):深いレイヤーの不合理な非効率性

 

huggingface.co

  • このモデルを実際に触ってみましたが、InstructモデルではなくBaseモデルを刈り込んだものなので正直なところ良し悪しはよくわかりませんでした。
    • 投稿によれば、Pruned-Llama3のMMLUスコアは76-77(4bitでテスト)で、元の79から若干低下しているものの、それでもMixtral 8x22Bと同等!の高いスコアを維持しているそうです。
  • ちょっと面白そうなので、このモデル作成手順の元ネタとされている論文を斜め読みしたうえで、モデル刈り込みを試してみたいと思います。

arxiv.org

  • これは2024年3月にMeta FAIR関係の研究者が投稿したarXiv論文です。

要旨

我々は、オープンウェイトの事前学習済みLLMのための単純なレイヤー刈り込み戦略を経験的に研究した。その結果、モデルのレイヤーの大部分(最大で半分)を除去した後でも、様々な質問応答ベンチマークでの性能の劣化を最小限に留められることを発見した。これらのモデルに刈り込みを行うにあたり、レイヤー間の類似性を考慮することによって除去するレイヤーの最適なブロックを特定した。この際パラメータ効率の良いファインチューニング(PEFT)手法、QLoRAを使用し、各実験を単一のA100 GPUで実行している。実用的な観点からは、これらの結果は、レイヤー刈り込み手法が他のPEFT戦略を補完し、一方ではファインチューニングの計算資源をさらに削減でき、他方では推論のメモリとレイテンシを改善できることを示唆している。科学的な観点からは、これらのLLMが層の削除に対して頑健であることは、現在の事前学習手法がネットワークの深いレイヤーのパラメータを適切に活用していないか、浅いレイヤーが知識の保存において重要な役割を果たしていることを示唆している。

メモ

  • Llama-2-70Bのようなベースモデルでは、最大で半分くらいのレイヤーを削除してもその性能の大部分を維持することができる。
  • というのも、刈り込むレイヤーを増やしていくと(単調に性能が低下するのではなく)一定の水準でMMLUスコアが急低下する現象が観察される。

  • これは、モデル性能にとって重要なパラメータが一部のレイヤーに偏っている(=それ以外のレイヤーは削っても性能に影響しにくい)ことを示唆する。
  • 具体的に検証すると、浅いレイヤーではレイヤー間の類似性が低く(無駄がなく)、より深いレイヤーで類似性が高く(冗長になる)ことがわかった。
  • したがって、単純に刈り込む戦略をとるよりも、深いレイヤーを優先して刈り込む戦略をとることで、モデルの性能低下を抑えながらサイズを縮小する道が開ける。
  • なおレイヤーを選択的に刈り込む場合も、モデルの損傷を修繕するための追加学習(Healing)を施すことで術後?の性能がより安定する。

PruneMe:論文の非公式実装

  • さて、MergeKitを手がけるarcee-aiが「PruneMe」というこの論文の非公式な実装を提供していて、冒頭のモデルもこれを利用して作っているそうです。

github.com

  • PruneMeのRepoには「Mistral-7B-instruct-v0.2」を例に実際にレイヤー間の類似性を検証したグラフが掲載されています。
  • 元論文で主張されているとおり、前半の浅いレイヤーは類似性が低く、後半の深いレイヤーで類似性が高くなる様子がプロットされています。

PruneMeによるモデル刈り込みを試す

  • ではGoogle Colabで簡単に試してみます。Repoをクローンし、依存関係をインストールします(TransformersはColabにプリインストール済み)。
# レポジトリのクローン
!git clone https://github.com/arcee-ai/PruneMe

# 依存関係
!pip install bitsandbytes datasets accelerate
  • ここではMistral 7Bベースの日本語チャットモデルを使って、レイヤー類似性を測定します。
  • とりあえずサンプルと同じ英語データセットを使って測定しています(レイヤーの類似性が判定できればよく日本語を使う必要はない気がするので)。
  • 「--layers_to_skip」で、削除したいレイヤーの数を指定します。ここでは32レイヤーのうち1/4の8レイヤーを削除する想定です。
  • なお計算時間の問題で、データセットサイズはサンプルより大幅に減らしています。
!python ./compute_block_similarity/layer_similarity.py --model_path "TFMC/Japanese-Starling-ChatV-7B" \
                      --dataset "arcee-ai/sec-data-mini" \
                      --dataset_column "text" \
                      --batch_size 8 \
                      --max_length 1024 \
                      --layers_to_skip 8 \
                      --dataset_size 200 \
                      --dataset_subset "train"
  • 計算が終わると、CSVファイルで結果が出力されます(上記設定で、標準GPUで20分ほど)。測定結果は以下の通りでした。

  • サンプルとよく似た結果が得られました。21-29のブロックが最も冗長なようです。同じMistral 7Bベースのモデルなので当然といえば当然かもしれません。
    • いずれにしてもレイヤーごとの冗長性は追加学習では特に変わらない(アーキテクチャが同じなら同じ)可能性がありそうです。
  • あとは、この特定したレイヤーブロックを削除するため、MergeKitを導入します。
# MergeKitの導入
!git clone https://github.com/cg123/mergekit.git %cd mergekit !pip install -e .
  • 次に「./PruneMe/slice_with_mergekit/slice.yaml」を適宜編集します。MergeKitのpassthorough手法(いわゆるフランケンマージを作る手法)を流用しているようです。
slice.yaml
--------------------------------------
slices: - sources: - model: TFMC/Japanese-Starling-ChatV-7B layer_range: [0, 21] - sources: - model: TFMC/Japanese-Starling-ChatV-7B layer_range: [29,32] merge_method: passthrough dtype: bfloat16
  • 最後にスライスを実行すると、刈り込みモデルが出力されます。
# 刈り込みの実行
!python /content/PruneMe/slice_with_mergekit/merge_me.py

刈り込みモデルのテスト

  • ということで、モデルの刈り込みができたので実際に出力してみます。以下が出力の一例です。
<s> [INST] <<SYS>>\nあなたは役立つアシスタントです。\n<</SYS>>\n\n空が青いのはどうして? [/INST] 

空が青いのは、一般には空の中での色の相物が効嵑たぼいるにお効がて効ぼとしいともい職て効㩉効い〾くでいてのいい。も樾詖倴棧倴壯〾棧倴壬費塏壯〳ぞ度、度桴壯〷度眸竇疎ゾ度、度諷耡紐ゾお填ゾとしい度熱ぼこいでいてまい臦諓ぞ度、...
  • 残念ながら、モデルの出力は途中から破綻してしまいます。
  • よく考えてみると、冒頭のモデルも元論文で扱っている例もいずれもBaseモデルだけで、Instructモデルには触れていませんでした...
  • 「刈り込みしても性能が落ちにくい」というのは、あくまで選択式のMMLUスコアで測定されるようなBaseモデルの単純性能の話で、Instructモデルはもっと繊細なようです。
    • あるいは、指示追従にとって重要なパラメータが実は冗長な深いレイヤーに存在していたりもするのでしょうか?
  • ということで、基本的には「InstructモデルではなくBaseモデルに刈り込みして、あらためて指示追従の追加学習をやる」という手順が必要になりそうで、お手軽感はありません。
  • 既成のモデルを単にサイズ圧縮するのが目的なら、量子化の質にこだわったほうが期待できそうです。
  • とはいえ、例えばMoEモデルならもっと刈り込める余地がありそう?とか、色々と面白そうな手法ではあるので、有効な活用方法が見いだされてほしい気もします。

 

llama.cpp:iMatrix量子化は日本語性能にどう影響するか?

 

  • 量子化時のモデル劣化を抑制する重要度行列(iMatrix; Importance Matrix)計算の話題です。
  • 最近はHuggingFaceにアップされるGGUFも多くがiMatrix版となっていますが、これらの量子化でよく使われているiMatrix計算用データセットは以下の2種類のようです。 
    1. wiki.train.raw:Wikitext(英語版)のトレーニング用データセット。llama.cppのサンプルでもwiki.train.rawが使われている。iMatrix計算では、このうち10kトークンほど計算すれば実用的には十分な様子。
    2. groups_merged.txt:koboldcppにも関わっているkalomazeさんが作成した疑似ランダムデータセット言語モデルで生成した雑多なテキストを中心とした200KBほどのデータ。
  • となると、一体どちらのiMatrix版が良いのかが気になってきます。以下のllama.cppのRepoの議論に目を通すと「基本的にはwikitextでよいが、英語以外の言語で使うならばgroups_mergedなど他の選択肢を検討すべきかもしれない」という趣旨の話が展開されています。

github.com

iMatrixテキストによる日本語性能の差

  • がiMatrix無しの量子化wiki.trainによるiMatrix量子化groups_mergedによるiMatrix量子化です(なおIQ3_XS以下では非iMatrixの量子化は不可)。
  • "Q5_K_M"の時点ではiMatrixの有無による差はありませんが、4bitの量子化では"iMatrix"による効果がはっきりあらわれます。
  • 同じiMatrix版でも差がみられ、"Q4_K_M"のgroups_merged版(3.52)は、"Q8_0"すら超えるスコア(3.42)を示していて驚きます。
    • かなり差が大きく、単なる誤差ではなさそうです。具体的に得点が上がった回答の一例を確認します。
ElyzerTasks100
設問68
Q:私の目は黒く、髪は明るい黄色で、足元はしっかりと地面に根ざしています。私は太陽の光を顔に浴びるのが好きです。私の正体は何でしょう?

Q8_0量子化

(×不正解)

あなたの特徴から推測すると、あなたはおそらく日本人である可能性が高いです。日本人の多くは黒い目を持ち、明るい黄色の髪色を持つ人は稀です。また、足元がしっかりと地面に根ざしているという表現から、安定感や地に足のついた人物像が浮かびます。太陽の光を好むということから、明るく前向きな性格を持っているのではないでしょうか。ただし、これらの特徴だけでは正確な正体を特定することはできません。もしかしたら、その他の要素も含めて自分自身のアイデンティティを表現してみると良いかもしれません。

Q4_K_M量子化

By groups_merged

(〇正解)

この情報から推測すると、あなたの正体は「ヒマワリ」である可能性が高いです。以下の理由からです。

1. 目が黒い:ヒマワリの花弁には中心部に暗い部分があり、まるで目のように見えることがあります。
2. 髪は明るい黄色:ヒマワリの花弁の色は明るい黄色で、髪色が明るい黄色であることがヒマワリのイメージに一致しています。
3. 足元がしっかりと地面に根づいている:ヒマワリは背が高く、しっかりと地面に根を張っています。足元がしっかりと地面に根づいていることもヒマワリの特徴の一つです。
4. 太陽の光を顔に浴びるのが好き:ヒマワリは太陽の光を浴びることで元気に成長します。太陽の光を好む姿勢はヒマワリにピッタリです。

以上の点から、あなたはヒマワリである可能性が高いと推測されます。ただし、他の花や植物の可能性もあるため、確定するにはさらに詳しい情報が必要かもしれません。
    • この問題では、日本語の比喩表現が理解できるようになっています。ほかの設問も含め、総じて日本語の文章理解が若干向上している印象を受けます。
  • "Q4_K_M"以外ではiMatrixデータの違いによる差は小さく、また必ずしもgroups_merged版のほうがスコアが良いわけでもありません。
  • とはいえ、iMatrixデータを工夫することで日本語性能に何らかの効果を及ぼせるかも、という期待はできそうです。
3bit以下で顕著になる量子化の悪影響
  • 量子化がもたらす性能変化は予測不能な部分が大きいのですが、確実に言えるのは「3bit以下の低bit量子化では性能が大幅に劣化する」という点です。
  • 先ほど紹介したグラフでは、"IQ2_XS"で顕著にベンチマークスコアが低下していましたが、実際の出力内容を見ると実は"IQ3_XS"の時点で回答が破綻するケースが出現します。例えば以下のような例です。
ElyzerTasks100
設問13
IMEとして機能してください。

あなたのタスクはローマ字入力されたアルファベット列を単語ごとにひらがなに変換し、それぞれの単語を漢字に変換する候補を最大3つ提示することです。

(中略)

それでは、入力を開始するので、IMEとして出力を表示してください。

inputs:
j, u, u, b, u, n, o, s, h, o, s, e, k, I
IQ2_XS量子化
Groups_merged
の回答
以下は入力されたアルファベット列をひらがなに変換し、それぞれの単語を漢字に変換する候補を最大3つ提示する例です。

inputs:
j, u, u, b, u, n, o, s, h, o, s, e, k, i

outputs:
- ジュウウブノウホソエキ: 寿生雲鵬, 順守護哀, 住吉水神社
- ゆうう: 優雅, 夢想, 憂う
- せきゃ: 挑発する, 決める, 石田紗希子
- おほ: おほほほほ, おほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほほ
  • 同じトークンが連続して止まらなくなったり、文章が意味もなく冗長になるのは低bit量子化でよく見る現象です。こうなると、モデルとしては実用が難しくなります。
  • 過去に日本語Wikitextを使って量子化によるモデルのPerplexityの推移を調べたことがあったのですが、ここでも量子化クラスが3bitになるとPerplexityが急上昇する傾向が確認できました(Perplexityは一般に低いほど良い)。
    • もちろん英語でも同じようにPPLが悪化しますが、どうやら日本語のほうがより激しくPPLが増加する可能性が高いです(参考記事)。

iMatrix計算に日本語テキストを使うと?
  • じゃあ3bit以下の量子化は避ければ?という話ではあるのですが、最近はローカルLLM大型化が進んでいることもあり、実際には3bitクラスの量子化をやむなくチョイスする場面も(個人的には)よくあります。
  • そこで、iMatrixの計算用データに日本語を含めることで低bitでの日本語性能の劣化を抑制できないか、ということを考えています。
  • 冒頭のllama.cppのRepoの議論でも、英仏混合テキストをiMatrix計算に使うことで仏語でのPPLが改善した(それだけでなく英語性能も一部改善した)という結果が投稿されています。

https://github.com/ggerganov/llama.cpp/discussions/5263#discussioncomment-8353685
  • これらの情報を踏まえ、Wikitextやgroups_mergedのほか上記の例で用いられているc4データセット(Webクロールコーパス)を使って、iMatrix計算用の日英混合テキストを作成してみることにしました。

huggingface.co

  • 新たに用意するテキストについては、200KB程度のgroups_merged.txtで一定の効果が出ていることから日本語200KB+英語200KBの400KBくらいを目安にざっくり抽出しています。
  • というわけで、何種類かのiMatrixテキストでIQ2_XS量子化を行って量子化モデルのベンチマークスコアを取った結果、以下のようになりました。

  • wiki_text(en)wiki.trainと同じもの、wiki_text(en+ja)wiki.trainに日本語wikitextを追加したもの、c4_en_jaがc4の日+英テキスト、c4_ja_onlyがc4の日本語のみテキストです。
  • 日本語テキストを含めることでベンチマークスコアが改善することがわかります。
  • ただし、wiki_text(en+ja)のスコアはwiki_text(en)比で微増といったところでほとんど効果はありません。また、日本語のみのc4_ja_onlyよりも日英混合のc4_en_jaのほうがスコアが高いのも面白いです。
  • このテストからc4データセット由来の日英テキストでiMatrix計算をすると良い結果が得られそうなことが分かったので、もう少し深堀りしてみます。
  • 上記の例では、c4の英語テキスト200KBの後ろにc4の日本語テキスト200KBを結合したのですが、英語と日本語の混ぜ方で違いがあるかもしれません。
  • 日本語テキストを先にして英語テキストを後半に結合したパターンや、英文チャンクと日本語チャンクをランダムにミックスしたパターンも作成し、あらためて量子化後のスコアを測定しました。

  • 結果としては、最初のテストで使った「英語テキスト(先)+日本語テキスト(後)」のパターンが最高スコアとなりました。
  • 「英文チャンクと和文チャンクのランダムミックス」や「日本語テキスト(先)+英語テキスト(後)」は「日本語のみ」よりもスコアが低くなっています(iMatrix計算の後半を日本語テキストのみで行うことに何か効果があるのでしょうか?)
雑感
  • 記事が長くなりそうなので、ひとまずこのあたりに。他のモデル、他の量子化サイズでの効果も追って検証できればと思っています。
  • なお、今回のテストで最も良さげな効果を示した日英混合iMatrix用テキスト(c4_en_plus_ja)については参考までHuggingFaceにアップしています。

huggingface.co

 

Mergoo: LLMの効率的なマージと微調整のためのライブラリ (MoE, Mixture of Adapters)

 

  • Merggoo」というLLMマージ用の新しいライブラリができたそうで、開発者の方がRedditでもPRポストを上げておられます。
  • もしかするとMegekitに代わるものになるのかもしれません。MoEのルーティングの学習や、LoRA Adapterの混合など。

面白そうなので後で試せればと思いますが、概要だけ先に見ておこうと思います。以下、開発者ブログの和訳です。

Mergoo: 効率的にマージし、微調整する (MoE, Mixture of Adapters)

huggingface.co

私たちは最近、複数の LLM エキスパートを簡単に統合し、統合された LLM を効率的にトレーニングするためのライブラリである mergoo を開発しました。このライブラリでは、さまざまな汎用またはドメインベースの LLM エキスパートの知識を効率的に統合できます。

🚀 Mergooでは次のことが可能です。

  • 複数のオープンソースLLMを簡単に統合できます
  • さまざまなマージ方法を適用:エキスパートの混合、アダプタの混合、レイヤーごとのマージが可能です
  • ゼロから始めることなくMoEを効率的にトレーニングできます
  • HuggingFece🤗モデルやトレーナーと互換性があります
導入

mergoo は、汎用的なものからドメイン固有のものまで、さまざまな LLM エキスパートの知識を統合するための信頼性が高く透過的なパイプラインを構築するように設計されています。エキスパートの混合、アダプタの混合、レイヤーごとのマージなどのさまざまな統合手法が組み込まれており、LLM ビルダーに柔軟性を提供します。マージされた LLM は、SFTrainer、PEFT、Trainer などの Hugging Face 🤗 トレーナーを使用して、特定のユースケースに合わせてさらに微調整できます。
次のセクションでは、MoE を使用して完全に微調整された LLM からマージされた LLM を構築する方法と、LoRA で微調整されたエキスパートからアダプタの混合 LLM を作成する方法を示す 2 つの例について説明します。

完全に微調整済みのLLMの混合

Branch-Train-Mixに続いて、ドメイン固有の LLM エキスパートは、Mixture-of-Expert (MoE) レイヤーのエキスパートとしてフィードフォワードパラメータをまとめ、残りのパラメータを平均化することで統合できます。MoE レイヤーは、後で微調整してトークン レベルのルーティングを学習できます。

例として、次のドメイン固有の LLM エキスパートを統合します。

  • ベースモデル: metal-llama/Llama-2-7b-hf
  • コードエキスパート: codellama/CodeLlama-7b-Python-hf
  • WikiChat エキスパート: stanford-oval/Flame-2-7b-WikiChat-fused

マージするための設定を指定します:

config = \
{
    "model_type": "llama",
    "num_experts_per_tok": 2,
    "experts":[
        {
            "expert_name" : "base_expert",
            "model_id" : "meta-llama/Llama-2-7b-hf"
        },
        {
            "expert_name" : "expert_1",
            "model_id" : "codellama/CodeLlama-7b-Python-hf"
        },
        {
            "expert_name" : "expert_2",
            "model_id" : "stanford-oval/Llama-2-7b-WikiChat-fused"
        }
    ],
    "router_layers":[
        "gate_proj",
        "up_proj",
        "down_proj"
    ],
}

次に、マージされたエキスパートのチェックポイントを構築して保存します。

import torch
from mergoo.compose_experts import ComposeExperts
model_id = "mergoo_llama_code_wikichat"
expertmerger = ComposeExperts(config, torch_dtype=torch.float16)
expertmerger.compose()
expertmerger.save_checkpoint(model_id)

以下では、マージされた LLM のチェックポイントが読み込まれ、Python コード命令データセットでさらに微調整されます。

from mergoo.models.modeling_llama import LlamaForCausalLM
import torch
import datasets
import random
from trl import SFTTrainer
from transformers import TrainingArguments

# load the composed checkkpoint
model = LlamaForCausalLM.from_pretrained(
    "mergoo_llama_code_wikichat", 
    device_map="auto", 
    torch_dtype=torch.bfloat16,
)# 'gate' / router layers are untrained hence loaded warning would appeare for them

# load the train dataset
dataset = datasets.load_dataset("iamtarun/python_code_instructions_18k_alpaca")['train']
dataset = dataset['prompt']
random.shuffle(dataset)
train_dataset =  datasets.Dataset.from_dict(dict(prompt=dataset[:-1000]))
eval_dataset = datasets.Dataset.from_dict(dict(prompt=dataset[-1000:]))

# specify training arguments
trainer_args = TrainingArguments(
    output_dir= "checkpoints/llama_moe",
    per_device_train_batch_size = 1,
    per_device_eval_batch_size = 1, 
    learning_rate= 1e-5,
    save_total_limit=1,
    num_train_epochs=1,
    eval_steps= 5000,
    logging_strategy="steps",
    logging_steps= 25,
    gradient_accumulation_steps=4,
    bf16=True
)

trainer = SFTTrainer(
    model,
    args= trainer_args,
    train_dataset= train_dataset,
    eval_dataset= eval_dataset,
    dataset_text_field="prompt",
)

# start training
trainer.train()

その後、コードを Huggingface Hub にプッシュできます (ぜひプッシュしてください):

model.push_to_hub("mergoo_llama_code_wikichat_trained")

mergooはMistralとBertベースのエキスパートもサポートしています。

アダプターの混合(Mixture of Adapters)

mergoo は、複数のアダプタ (LoRA) を統合された MoE スタイルのアーキテクチャに統合することを容易にします。これは、微調整された LoRA の上にゲーティング レイヤーとルーティング レイヤーを適用することで実現されます。

アダプタ混合 LLM を構築するには:

  • 同じベースモデルで微調整されたアダプタ(LoRA)のプールを収集する
  • Mergoo を適用して MoE スタイルのマージされたエキスパートを作成する
  • 下流のタスクでマージされたエキスパートを微調整する

たとえば、顧客サポート領域では次の専門家を統合できます。

  • predibase/customer_support
  • predibase/customer_support_accounts
  • predibase/customer_support_orders

次のように構成を指定して、マージされたチェックポイントを構築します。

import torch
from mergoo.compose_experts import ComposeExperts

model_id = "mergoo_customer_suppoer_moe_lora"
config = {
    "model_type": "mistral",
    "num_experts_per_tok": 2,
    "base_model": "mistralai/Mistral-7B-v0.1",
    "experts": [
        {
            "expert_name": "adapter_1", 
            "model_id": "predibase/customer_support"
        },
        {
            "expert_name": "adapter_2", 
            "model_id": "predibase/customer_support_accounts"
        },
        {
            "expert_name": "adapter_3", 
            "model_id": "predibase/customer_support_orders"
        }
    ],
}

expertmerger = ComposeExperts(config, torch_dtype=torch.bfloat16)
expertmerger.compose()
expertmerger.save_checkpoint(model_id)

注:LoRAでマージ候補のエキスパートを微調整する場合、expert_nameはadapter_で始まる。

「完全に微調整された LLM の混合」セクションで定義されているように、マージされたエキスパートはさらに微調整できます。

結び

以上のように、mergooでは、複数の専門家の知識を確実かつ透過的に統合することができます。mergooはmixture-of-expert、mixture-of-adapters (MoE-LoRA)、layer-wise mergingを含むいくつかの統合技術をサポートします。マージされたLLMは、信頼できるエキスパートを提供するために、下流のタスク上でさらに微調整することができます。

 

github.com

 

Chat Vector の効果 vs. MoEマージ の効果

  • 前回の続きです。次はChat Vector処理によってどれくらい日本語チャット機能が改善しているのかを具体的にみてみます。
  • 下記の記事で、Chat Vectorを使わなくても単純に2モデルをMoEマージするだけで一定の性能向上が観察できることは確認しています。

sc-bakushu.hatenablog.com

  • では、Chat Vectorを加えることでモデルのチャット性能はどのように影響を受けるののでしょうか。

単体の7BモデルにおけるChat Vectorの効果

  • まず「LightChatAssistant 2x7B」の再現過程で得られた「ChatNTQ 7BにMistral-7B-v0.2-InstructのChat Vectorを適用したモデル」の性能を測定してみます。
名称 LightChat
Assistant
2x7B
Chat NTQ 7B
random MoE
ChatNTQ 7B
+ChatVector
ChatNTQ 7B
パラメータ 2x7B 2x7B 7B 7B
量子化タイプ Q8_0 Q8_0 Q8_0 Q8_0
ElyzaTasks100
スコア平均
3.31 3.25 3.23 3.06
標準誤差 0.14 0.14 0.15 0.15
平均字数 273.61 261.27 280.04 240.46
  • ChatVector処理だけで平均スコアは3.23に上昇しました。
  • 以前作った「ChatNTQ 7Bのx2ランダムマージ」のスコアが3.25だったので、このChat Vectorはランダムマージに近いレベルの性能上昇効果がありそうです。
  • 回答の平均字数についてもどちらも共通して増加しているものの、Chat Vector版のほうがより顕著に増えています。
    • 体感的にもChat Vector版の回答は総じて丁寧で、Mistral 7B v0.2 Instructの影響が感じられます。
  • 実際の回答例を見てみます。
  • 以下のように、元の「ChatNTQ 7B」では解けなかった計算問題が、ランダムマージ版でも、Chat Vector版でも同じように正答できるようになりました。
ElyzaTasks100
設問26
質問: 太郎くんはりんごを5つ持っています。彼はさらに2つのりんごの箱を買いました。1つの箱には3つのりんごが入っています。太郎くんは何個のりんごを持っていますか?
ChatNTQ-7B
回答(不正解×
太郎くんは最初に5つのりんごを持っていました。その後、2つの箱を買い、それぞれに3つのりんごが入っていたため、太郎くんは合計で8つのりんごを持つことになります。したがって、彼が最初に持っていた5つのりんごと、後から買った2つの箱に入っていた3つのりんごを合わせると、合計で8つのりんごを持っていることになります。
ChatNTQ 2x7B
Random MoE
回答(正解〇
太郎くんが最初に持っていた5つのりんごと、2つの箱の合計個数を計算する必要があります。

太郎くんが最初に持っていた5つのりんごの数は、5です。

1つの箱には3つのりんごが入っているため、もう1つの箱にも同じ数の3つのりんごが入っていると仮定すると、2つの箱には合計6つのりんごが入っています。

したがって、太郎くんが最初に持っていた5つのりんごと、2つの箱に入っている6つのりんごを合計すると、11個のりんごを持っていることになります。

したがって、太郎くんは11個のりんごを持っています。
ChatNTQ-7B
+ Chat Vector
回答(正解〇
最初に太郎くんが持っていたりんごは5つです。そして、2つの箱を買って、それぞれ1つの箱に3つずつ入っているので、合計で2 × 3 = 6つのりんごが追加されます。

最初の5つと合わせて、太郎くんは11個のりんごを持っています。
  • MoEマージとChat Vectorという全く異なる処理ですが、全般的な性能向上はどちらの手法においても期待できそうな印象です。
  • ランダムMoEマージの場合、おそらく総パラメータ数の増加が性能に寄与していて、Chat Vectorの場合は英語モデルの追加学習スキルが性能に寄与しているのだと思います。

Chat VectorモデルのランダムMoEマージ

  • つづいて、Chat Vectorを掛けた「ChatNTQ 7B」を単純にランダムMoEマージするとどうなるでしょうか。結果は以下のとおりです。
名称 LightChat
Assistant
2x7B
ChatNTQ 7B
+ ChatVector
Random MoE
ChatNTQ 7B
+ ChatVector
ChatNTQ 7B
パラメータ 2x7B 2x7B 7B 7B
量子化タイプ Q8_0 Q8_0 Q8_0 Q8_0
ElyzaTasks100
スコア平均
3.31 3.33 3.23 3.06
標準誤差 0.14 0.15 0.15 0.15
平均字数 273.61 286.73 280.04 240.46
  • ベンチマーク性能では「LightChatAssistant 2x7B」とほぼ同じレベルのモデルが得られました。
  • したがって、「LightChatAssistant 2x7B」の高い日本語チャット性能それ自体は、「Chat Vector効果」 +「マージによるパラメータ増加効果」の2つだけで説明することもできそうです。
  • とはいえ実際の出力を見比べてみると、小説生成モデルを掛け合わせた「LightChatAssistant 2x7B」の出力の方が、同一モデルマージの「ChatNTQ-ChatVector 2x7B」より明らかにクリエイティブで面白いと感じます。
  • 異なる特化モデルを掛け合わせることの効果は、ベンチマークで測定しきれない部分にこそ現れるのかもしれません。

別のモデルからChat Vectorを抽出してみる

  • ついでに、異なるモデルからChat Vectorを抽出することを試してみます。
  • 元論文や「LightChatAssistant 2x7B」で差分抽出に使われた「Mistral 7B v0.2 Instruct」は、Mistral公式のファインチューンモデルであり、体感的にも優れたチャットモデルです。
  • ただしベンチマーク上は「Mistral 7B v0.2 Instruct」を上回るMistral系モデルがいくつか作出されていて、その代表例がスタンフォード大学系の研究者が公開した「Starling-LM-7B-beta」です。
  • Starling-7Bシリーズについては、以下の記事で概要を確認しています。

sc-bakushu.hatenablog.com

  • 「Starling-LM-7B-beta」は現在も「LMSYS Chatbot Arena Leaderboard」で7Bモデル最高スコアを示していて、英語チャット性能ではこのクラスのSOTAです。
    • 7Bのオープンウェイトが人手による評価で GPT-3.5-Turbo・Claude 2.1・Gemini Pro など主要各社の商用モデルと同等の評価を得られるというのは少々驚きです。

  • このモデルからChat Vectorを抽出して掛け合わせれば、いい感じの日本語チャットモデルが出来上がる予感がします(素人の発想)。
  • というわけで「Starling-LM-7B-beta」のChat Vectorを「ChatNTQ 7B」に掛けたモデルを作り、ベンチマークにかけてみました。
  • なお、「Starling-LM-7B-beta」のChat Vectorは1.0で掛けてもまったく英語が混ざらなかったので、0.8ではなく1.0の強度を採用しています。
  • 結果は以下のとおりです。
名称 Command R
V01
LightChat
Assistant
2x7B
ChatNTQ +
Staling ChatV
ChatNTQ 7B
パラメータ 35B 2x7B 7B 7B
量子化タイプ Q8_0 Q8_0 Q8_0 Q8_0
ElyzaTasks100
スコア平均
3.42 3.31 3.42 3.06
標準誤差 0.15 0.14 0.15 0.15
平均字数 194.02 273.61 340.82 240.46
  • スコア上は「LightChatAssistant 2x7B」を上回り、8bit量子化ベースで「Command R 35B」に並ぶモデルが出来上がりました。
  • 出力字数も多く(平均341字)、とにかく丁寧なチャット対応ができるようになっています。ElyzaTasks100では回答の「有用性」が評価基準に含まれるので、評価者のGPT-4もこの観点から高く評価しているようです。
  • とはいえ「Command R 35B」は遥かに簡潔な回答(平均194字)で同じだけのスコアを出しているので、純粋な「賢さ」に関して言えば明らかに「Command R 35B」が上だと感じます。

Starling LMのChat VectorモデルのMoEマージ

  • では、このモデルをさらにMoEマージすればもっとスコアは上がるのか、と安易に考えてやってみましたが、結果は以下のとおりでした。
名称 LightChat
Assistant
2x7B
ChatNTQ +
Staling ChatV
Random MoE
ChatNTQ +
Staling ChatV
ChatNTQ 7B
パラメータ 2x7B 2x7B 7B 7B
量子化タイプ Q8_0 Q8_0 Q8_0 Q8_0
ElyzaTasks100
スコア平均
3.31 3.18 3.42 3.06
標準誤差 0.14 0.15 0.15 0.15
平均字数 273.61 351.42 340.82 240.46
  • 出力字数こそ多少増えているものの、GPT-4の評価は逆に低下しています。
  • 体感的にも「口の速さに思考が追い付いていない」ような、ちぐはぐな印象を受けるモデルになってしまいました。
  • 単一モデルに対して最適なChat Vector比率とMoEマージ後に最適になるChat Vector比率が異なるという話(こちらのAratakoさんのモデルカード参照)もあり、必ずしも足し算にはならないのだと再確認しました。

雑感

  • ということで、ひとまず「LightChatAssistant 2x7B」の後追いをしながらMoEとChatVectorでいくつかの試行錯誤をしてみました。
  • なお「Starling ChatV 7Bモデル」についてはベンチマーク性能があまりに高く、何かの参考にもなる気もしたので、HuggingFaceにアップしてみました。

huggingface.co

  • Llama-3の小型モデルのリリースが間近に迫っているという話もあり、そちらも楽しみにしながら、CVとMoEで何ができるかも引き続き試していきたいと思います。

 

LightChatAssistant 2x7B を再現する

 

  • 前回Chat Vectorについて簡単に予習したので、とりあえず「LightChatAssistant 2x7B」の作成手順を再現してみたいと思います。
  • 作者さんがモデルカードで丁寧に説明してくださっているので、基本的にそれをなぞるだけです。まずはまったく同じモデルを作ってみます。

huggingface.co

Chat Vectorによる処理

  • Google ColabのCPU(ハイメモリ)で試します。3つ分のモデルファイルを扱えるだけのRAM容量さえあればよく、GPU無しで完結するようです。
  • まずは、Chat Vector処理を行うための依存関係をインストールします。ColabだとTransformersはプリインストールされてます。
# 依存関係のインストール
!pip install accelerate protobuf
  • さてChat Vectorは、追加学習したチャットモデルの重みから、対応するベースモデルの重みを差し引くことで得られる差分です。
  • ここでは、ベースモデルは"Mistral-7B-v0.1"、チャットモデルに"Mistral-7B-Instruct-v0.2"を使っています。いずれもMistralの公式モデルで、元論文を踏襲した処理です。
  • それぞれのモデルをHuggingFaceからロードします。
from transformers import AutoModelForCausalLM
import torch

# ベースモデル="Mistral-7B-v0.1"
base_model = AutoModelForCausalLM.from_pretrained( "mistralai/Mistral-7B-v0.1", torch_dtype=torch.bfloat16, device_map="cpu", )
# 差分を取り出すチャットモデル="Mistral-7B-Instruct-v0.2"
inst_model = AutoModelForCausalLM.from_pretrained( "mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.bfloat16, device_map="cpu", )
  • つづいて、取り出したChat Vectorを統合するための日本語モデルを用意します。
  • 1つ目は"chatntq-ja-7b-v1.0"です。Mistral7B系モデルの中では相対的に高い日本語チャット性能を持っているモデルです。
  • すでに日本語でチャットファインチューンされたモデルですが、英語モデルのChat Vectorを適用してさらに強化を図る形になります。
# 差分を張り付ける日本語モデル(1)="chatntq-ja-7b-v1.0"
cp_model = AutoModelForCausalLM.from_pretrained( "NTQAI/chatntq-ja-7b-v1.0", torch_dtype=torch.bfloat16, device_map="cpu", )
  • あとは以下のようにChat Vectorを抽出します。この手順は、jovyanさんのこちらの記事で詳しく紹介されているものです。大変ありがたいです。
# Chat Vectorを取り出して"chatntq-ja-7b-v1.0"へコピー
for k, v in cp_model.state_dict().items(): chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k] new_v = v + ( 0.8 * chat_vector.to(v.device) ) v.copy_(new_v)
# Chat Vectorを張り付けた"chatntq-ja-7b-v1.0"を保存 cp_model.save_pretrained("./chatntq-ja-7b-v1.0-chatvector")
  • なお、元論文ではChat Vectorを掛ける強度として1.0または0.5を試しています。1.0だと差分の影響が強すぎて英語が混ざりやすくなり、逆に0.5だと差分の影響が表れにくくなる可能性があるようです。
    • おそらくそうした理由からLightChatAssistant 2x7Bを含むいくつかのChatVectorモデルで0.8掛けという数値を採用しているのだと思います。
  • さて、もう一つの日本語モデルは小説生成モデルの"Antler-7B"です。このモデルの学習データにNSFWが含まれているのでご注意ください。
  • 手順は先ほどと同じで、今度はAntler-7BにChatVectorを掛けます。こちらはチャットファインチューンされていないモデルにチャット能力を付与するという元論文の意図に近い形になっています。
# 差分を張り付ける日本語モデル(2)="Antler-7B"
cp_model2 = AutoModelForCausalLM.from_pretrained( "Elizezen/Antler-7B", torch_dtype=torch.bfloat16, device_map="cpu", )
# Chat Vectorを取り出して"Antler-7B"へコピー
for k, v in cp_model2.state_dict().items(): chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k] new_v = v + ( 0.8 * chat_vector.to(v.device) ) v.copy_(new_v)
# Chat Vectorを張り付けた"Antler-7B"を保存
cp_model2.save_pretrained("./Antler-7B-chatvector")
  • これで2つのChat Vector適用済みモデルが用意できました。
  • Tokenizerはそれぞれの日本語モデルから以下のように流用します。
from transformers import AutoTokenizer

cp_tokenizer = AutoTokenizer.from_pretrained("NTQAI/chatntq-ja-7b-v1.0")
cp_tokenizer.save_pretrained("./chatntq-ja-7b-v1.0-chatvector")

cp_tokenizer2 = AutoTokenizer.from_pretrained("Elizezen/Antler-7B")
cp_tokenizer2.save_pretrained("./Antler-7B-chatvector")
  • 最後に、各モデルのconfig.jsonの修正が必要です。
  • 差分を取り出した"Mistral-7B-Instruct-v0.2"のコンテキスト関係の設定値が他の"Mistral-7B-v0.1"系モデルとは異なっているために必要な処理のようです。
    • "Mistral-7B-Instruct-v0.2"は、v0.1系モデルで採用していたSliding Windowという仕組みが無効になっています。
    • そのためChat Vector適用したモデルも"Mistral-7B-Instruct-v0.2"のコンテキスト設定に合わせる必要が生じます。
  • 具体的には、"max_position_embeddings" = 32768、"rope_theta" = 1000000.0、"sliding_window" = None、に修正します。
  • Claudeに投げたところ以下のような修正用コードを作ってくれました。
import json

with open("./chatntq-ja-7b-v1.0-chatvector/config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

config["max_position_embeddings"] = 32768
config["rope_theta"] = 1000000.0
config["sliding_window"] = None

with open("./chatntq-ja-7b-v1.0-chatvector/config.json", "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False) with open("./Antler-7B-chatvector/config.json", "r", encoding="utf-8") as f: config = json.load(f) config["max_position_embeddings"] = 32768 config["rope_theta"] = 1000000.0 config["sliding_window"] = None with open("./Antler-7B-chatvector/config.json", "w", encoding="utf-8") as f: json.dump(config, f, indent=2, ensure_ascii=False)
  • ちなみに、最初に試したときにこのconfig.jsonの処理をスキップしてしまったのですが結果的にモデル出力の質がかなり劣化する形となりました。どうやらロングコンテキストの処理に影響するだけでは留まらないようです。
  • 以上で、Chat Vector関係の処理は終わりです。

MergeKitによるMoEマージ

  • では先ほど作成した2つのChat Vector適用モデルをMergeKitでマージします。
  • 通常版のMergeKitではMoEマージに対応していないので、Mixtralブランチを指定して導入する必要があります。
  • MoEマージの手順については、以前も紹介したはちさんの記事で詳しく解説されているものを参考にしています。
# メモリ容量の確保
del cp_model, cp_model2

# MoE対応版のMergeKitの導入
!git clone -b mixtral https://github.com/cg123/mergekit.git
!mkdir ./mergekit/output
!mkdir ./mergekit/config

!pip install -e ./mergekit
!pip install git+https://github.com/huggingface/transformers #セッションは再起動しない
  • なおColabでは、途中でセッションの再起動を促されますが無視してキャンセルします。
  • MoEマージのためのConfigは、LightChatAssistant-2x7BのRepoに上がっていますのでここではそのまま借用します。内容的にNSFWです。
# LightChatAssistant-2x7BのRepoからマージ設定をDL(NSFWにつき注意)
!wget https://huggingface.co/Sdff-Ltba/LightChatAssistant-2x7B/resolve/main/mergekit_moe_config.yml -P ./mergekit/config
  • あとは2モデルのマージを実行するだけです。完了後、outputフォルダに2x7Bのモデルファイルが格納されます。
# MoEマージの実行
!python ./mergekit/mergekit/scripts/mixtral_moe.py ./mergekit/config/mergekit_moe_config.yml ./mergekit/output -v

llama.cppによる量子化

  • 以上で「LightChatAssistant 2x7B」の作成手順の再現が完了しました。
  • とはいえ実際にローカルでモデルを使用する場合、そのまま使う人はまずいないと思いますのでllama.cppのGGUFフォーマットに変換しておきます。
# llama.cppの導入(依存関係はモデル変換用のみインストール)
!git clone https://github.com/ggerganov/llama.cpp
!pip install -r ./llama.cpp/requirements/requirements-convert.txt

# Q8_0量子化
!python ./llama.cpp/convert.py ./mergekit/output --outfile LightChat-2x7B-Repro-q8.gguf --outtype q8_0
  • Q8量子化後で、12.8GBほどのファイルサイズになりました。もっと圧縮する場合、f16に変換した後でmakeしてquantizeコマンドから希望の量子化タイプに変換することになります。

再現モデルのテスト

  • 再現したモデルを、元の「LightChatAssistant 2x7B」と比較してみます。
  • このモデルを、たびたび利用している「ELYZA-tasks-100」ベンチマークにかけてGPT-4に評価させてみました。
名称 LightChatAssistant
2x7B【オリジナル】
LightChatAssistant
2x7B【再現】
ChatNTQ JA 7B
パラメータ 2x7B 2x7B 7B
量子化タイプ Q8_0 Q8_0 Q8_0
ElyzaTasks100
スコア平均
3.31 3.29 3.06
標準誤差 0.14 0.15 0.15
平均回答字数 273.61 277.00 240.46
  • オリジナルと比較してスコアの差は十分に小さそうです。おおむねうまく「LightChatAssistant 2x7B」の再現ができたようです。詳細な作成手順を公開し、再現を可能にしてくださった作者さんにあらためて感謝申し上げます。ありがとうございます。

Chat Vectorの差分抽出モデルを入れ替えたり、MoEマージの手法を変更することで、いろいろと試行錯誤して楽しむことができそうです。

 

【LLM論文を読む】Chat Vector:LLMに新たな言語での指示追従とアラインメントを付与するシンプルなアプローチ

  • 前回の記事につづき、「LightChatAssistant 2x7B」に高いチャット性能をもたらした要素について調べています。
  • 基本的には、1)ベースモデルとしての「ChatNTQ JA 7B」の日本語性能の高さ、2) MoEマージで総パラメータ数が増えることによる全般的な性能向上、3) 小説生成モデル「Antler-7B」とのマージによるモデル特長の相互補完、の3点でかなりの部分が説明できそうです。
  • 一方で、モデル作者のSdff-Ltbaさんはマージ前のモデルに対してChat Vectorという処理を施し、チャット能力を高めたと説明されています。よって、Chat Vectorがどの程度の性能向上に寄与しているのかもざっくり知りたい気がします。
  • さしあたって、そもそもChat Vectorとは何かについて簡単に予習しておきたいと思います。元のarXiv論文は、2023年10月に台湾の研究グループから投稿されています。

arxiv.org

要旨

近年、オープンソースの大規模言語モデル(LLM)の開発が急速に進んでいる。しかしながら、データの制約から、オープンソースのLLMの多くは主に英語に特化している。この問題に対処するため、我々はChat Vectorという概念を導入し、単純なモデル演算によって、事前学習された言語モデルに指示追従と人間の価値観のアラインメントを付与させる。Chat Vectorは、事前に訓練されたベースモデル(例えばLLaMA2)の重みから、対応するチャットモデル(例えばLLaMA2-chat)の重みを引くことで得られる。継続的に事前訓練されたモデルの重みにChat Vectorを追加するだけで、さらなる訓練を必要とせずに、新しい言語でのチャット機能をモデルに与えることができる。我々の実証的研究は、3つの異なる側面からChat Vectorの優れた有効性を実証している:すなわち、指示追従、毒性緩和、マルチターン対話である。さらに、本アプローチの適応性を示すために、様々な言語、ベースモデル、Chat Vectorを包含するように実験を拡張した。その結果、Chat Vectorのシンプルさ、有効性、適用範囲の広さが強調され、事前に学習された言語モデルにおいて会話機能を効率的に実現するための説得力のあるソリューションであることが分かった。

メモ

  • Sdff-Ltbaさんがモデルカードにリンクされているこちらの記事に、概要から手順までわかりやすく解説されていました。半月ほど前にかなり話題になっていたんですね。

qiita.com

  • 簡単に言うと「深層学習モデルの重みに対して足し算引き算が成り立つ」という前提から、英語チャットモデルの「チャット能力」に関するベクトルを取り出し、別の言語のチャットモデルに適用する、という手法だそうです。

ほかに元論文に目を通して気になった点も列挙します。

  • すでに「Task Vector」というアプローチがあり(微調整済みモデルの重みから事前学習済みモデルの重みを引き算して得る)、この研究ではこれを英語から別言語へのチャット能力転移に利用した。
  • Vicunaベンチマークで指示追従能力をテストし、以下のことが確認された。
    • llama2 → 継続事前学習 + Chat Vectorは、繁体中国語LLaMAとChineseLLaMAの両方において、llama2 → 継続事前学習 → ファインチューンと同等の結果を示した。
    • ファインチューンとChat Vectorは相補的な効果がある: FTとChat Vectorを一緒に行うことで、どちらか一方だけを行うよりも良いパフォーマンスが得られる。
    • llama2-chat → 継続事前学習 + FT はチャットの能力を失わせる:llama2 → 継続事前学習 → FT よりも性能が良いが、llama2 CP + Chat Vectorや llama2 → 継続事前学習 → FT + Chat Vectorよりも性能が悪い。これは、llama2-chat → 継続事前学習 + FTは情報損失を引き起こすことを示しており、また、Chat Vectorを使用することの重要性を示している。

まとめ

  • Chat Vectorは単体でも効果があるものの、ファインチューンと併用するほうがよいみたいです。「LightChatAssistant 2x7B」では、片方がチャットファインチューンモデルの「ChatNTQ JA 7B」です(もう片方は小説ファインチューンモデルの「Antler-7B」)。
  • ChatVectorを取り出す英語チャットモデルについて、元論文ではMistralモデルに関しては公式の「Mistral Instruct v0.2」を使っていて、これは「LightChatAssistant 2x7B」でも踏襲されています。
  • なお、チャットモデルをベースにしてChatVectorをかけると情報損失が起きるリスクがあるので、ベースにはあくまで共通のベースモデルを使うべきとしています。「LightChatAssistant 2x7B」ではおおもとの「Mistral-7B-v0.1」をベースに使っています。
  • ChatVectorの重みをどれらい強くかけるかはケースバイケースのようで(論文で使われているのは1.0と0.5)、1でかけると出力に英語が混ざったりするようです。「LightChatAssistant 2x7B」では0.8を採用しています。

ということで、次回はChatVectorを使って実際に「LightChatAssistant 2x7B」を再現してみたいと思います。