DPO(直接選好最適化)とは何か、メモ

  • 最近、自作の指示応答データセットを使った微調整によってLLMにパーソナライズされた知識を追加することを試みているのですが、その際にモデルに植え付けられた過剰なアラインメントが知識追加の障壁になる場合があります。
    • 例えばモデルに対し「USER: 好きな色は何色ですか?\nAI: 私の好きな色は青です」と学習させたいのに、モデルが頑として「AI: 私はAIなので人間のような色の好みはありません」と回答し続けるような場合があります。
    • しつこく学習を続ければデータ通りに回答するようにはなりますが、そこまでSFTをやりすぎるとオーバーフィットでモデルの質が劣化します。
  • このような場合、DPO(直接選好最適化)を使って微調整すると過剰なアラインメントを簡単に除去できる可能性があります。
  • ということで、DPOとは何かを開設する記事をまとめて斜め読みしてみました。具体的なコード等は割愛しています。

メモ1:「Mistral-7BをDPOでファインチューンする」

towardsdatascience.com

  • 人間のフィードバックからの強化学習 (RLHF)は、学習済み言語モデルの回答からバイアスや有害性を取り除くことを主な目的として考案された。
  • しかし最近では、RLHFとその派生手法がアラインメントだけでなくパフォーマンスの向上にも寄与することが分かっている。
  • そもそもRLHFの概念はロボット工学で長い間使用されてきたが、OpenAIの論文「人間の好みから言語モデルを微調整する」でLLM向けに普及した。
  • この論文のフレームワークでは、1)まず「報酬モデル」を人間のフィードバックを近似するようにトレーニングし、2)次に近似ポリシー最適化(PPO)アルゴリズムを使い、報酬モデルを介して対象モデルをポリシー微調整する。

  • PPOのポイントは、ポリシーを小さく段階的に更新する点にある。大規模な更新は不安定性や最適化の失敗につながる可能性があるため。
  • とはいえPPOはやはり不安定で(損失が発散しやすい)、再現が困難で(ハイパーパラメータが多い、ランダムシードに敏感)、しかも計算コストが高いという課題がある。
  • そこで、直接選好最適化 (DPO)アルゴリズムが良い代案になる。
  • DPOは、タスクを分類問題として扱うことで制御を簡素化する。具体的には、学習対象モデル(ポリシーモデル) とそのコピーである参照モデルを使用する。
  • DPOはLLM自体を報酬モデルの代わりとして使用することで、大規模なサンプリング、報酬モデルの調節、複雑なハイパーパラメータ調整を必要とせずに、モデルの出力を人間の好みに適合させる(PPOより安定し、効率的で、計算負荷が少ない)。

メモ2:「DPOによる言語モデルのファインチューニング」

www.cerebras.net

  • DPOはRLHFに代わる手法として登場したものであり、RLHFの基本原理を基礎としながらも異なる実装を持つ。DPOでは、RLHF的な損失関数を選好推定のためのBradley-Terryモデルと組み合わせて利用する。
  • これにより学習プロセスが単純化され、RLHFで採用されている複数モデルによる学習と複雑な強化学習最適化は不要になり、モデル収束の安定性も高まる。
  • BTLM(Cerebrasの独自LLM)等を用いた検証では、DPOの適用により会話タスクにおけるモデルの熟練度が向上するだけでなく、その他のさまざまな下流タスクのパフォーマンスも向上した。

  • また、DPOによる事前学習データの忘却について検証したところ、1)DPOによる知識の忘却はごく限られること、2)betaパラメータが大きいほど情報を保持しやすいこと、が示唆された。

  • 一方で、beta値が小さい(0.1 未満) と、チャット指向のデータセット(Anthropic-HH)において最適な結果が得られる。

  • 指示応答や要約に重点を置いたデータセット(SHP, Redditから収集した選好データ)ならbeta値が大きい(0.3 ~ 0.5の範囲)ほうが効果的だが、総じてチャットスタイルのデータセット(Anthropic-HH)の方がDPOに適している。
  • SFTとDPOの学習に異なるデータセットを使用することが効果的な可能性が示唆されるが、これは元のDPO論文の提案とは食い違っている。

メモ3:「DPOによるLLMの選好チューニング」

huggingface.co

  • この投稿では、有望な3つのLLMアライメントアルゴリズムである直接選好最適化(DPO)、同一選好最適化 (IPO)、カーネマン-トヴェルスキー最適化 (KTO) の実証的評価を行っている。
  • DPO:DPOは、LLMを人間やAIの嗜好にアライメントするための有望な選択肢として登場した。強化学習に基づく従来のアライメント手法とは異なり、DPOはアライメントの定式化を選好データセット上で直接最適化できる単純な損失関数として再構成する。
  • IPO:DPO の欠点の1つは、選好データセットに過剰適合する傾向があること。これを回避するため、Google DeepMindの研究者は、Identity Preference Optimisation (IPO)を導入した。これは、DPO損失に正則化項を追加し、早期停止などのトリックを必要とせずにモデルを収束までトレーニングできる。
  • KTO:ほとんどのアライメント方法と同様に、DPOではペアの選好データセットが必要で、アノテーターが有用性や有害性などの一連の基準に従って、どの応答がより良いかをラベル付けしなければいけない。現実的には、この作業には時間とコストがかかる。そこで、ContextualAIはKahneman-Tversky 最適化 (KTO)と呼ばれる代替案を考案した。これは、損失関数を「良い」または「悪い」とラベル付けされた個々の例 (たとえば、チャット UI に表示される 👍 または 👎 アイコン) に基づいて定義できる。
  • 以下は、ペア選好データセット(HuggingFaceH4/orca_dpo_pairs)を用いてZephyr-7b-beta-SFTおよびOpenHermes-2.5-Mistral-7Bのポリシー最適化を実行した結果である。

  • 少なくとも適切なbeta値においては、DPOが最もパフォーマンスの優れたLLMアラインメントアルゴリズムであることが示唆された。

メモ4:「Starling-LM-7B-beta」

sc-bakushu.hatenablog.com

  • 現在、7Bモデルで最もチャット性能が高いと評価されている「Starling-LM-7B-beta」では、DPOではなくPPOを採用している。
  • これはStarling-LMがすでにC-RLFTで微調整済みのOpenchat-3.5をベースにしているため。(DPOと類似性のある)C-RLFTのような手法で学習済みのモデルをさらに鍛えるには、DPOではなくPPOのほうが有望とのこと。

雑感

  • OpenHermes-2.5-Mistral-7BやOpenchat-3.5のような徹底したファインチューニングを経たチャットモデルにおいては、DPOによるチャット性能向上の余地は限られるようです。
  • やはりPPOとDPOは別物であって、パラメータの最適化と計算資源の確保の問題さえクリアできるならPPOのほうが有効性が高いように見えます。とはいえ、現実的に素人が個人でやるならPPOの選択肢は無さそうですが。
  • 知識追加のためのファインチューニングの観点では、おそらくSFTで知識そのものをインプットしたあとで、学習させた知識を出しやすくするための仕上げの工程でDPOを使うことになると思います。DPOによる知識の忘却はそれほど心配なさそうです(SFTモデルにDPOを上掛けしたモデルがたくさんある)。
  • DPOにおいて重要となるbeta値は、通常0.1程度が標準になっているようですが、その最適値は、目的/モデル/データセットに応じて完全にケースバイケースになるようです。

 

Starling-7B: RLAIF で LLM の有用性と無害性を向上させる

  • お馴染みのLMSYS Chatbot Arena ELOランキングが更新されていました。

 

  • Claude 3シリーズのレーティングの高さも目を引く一方、Mistral 7BベースのStarling-LM-7B-Betaが小型モデルとしては際立ったスコアを示しています。
    • これはStarling-LM-7B-alphaの後継として今年3月にリリースされたモデルです。前モデル同様、Mistral 7BベースのOpenchat-3.5をもとに、ポリシー最適化でファインチューンしたモデルです。
    • サンプル数が少ないので暫定ですがArena ELOレーティングはClaude 2相当で、GPT-3.5/Mixtral 8x7B/Gemini Proを上回っています。
    • もちろんコーディングや計算/推論、多言語スキルなど含めた汎用性能では大型モデルに劣ると思いますが、7Bモデルでも特定の用途(ここでは英語チャットボット)に特化すれば高性能に到達できるというのは面白いです。
  • 個人的には、Starling-LM-7B-Alphaは受け答えがあまりに「疑似ChatGPT」感があって敬遠した記憶があるのですが、改めて興味を持ったので当時のStarling-LM開発チームのブログ投稿(Alphaリリース時のもの)に目を通してみました。

"Starling-7B: RLAIF で LLM の有用性と無害性を向上させる"

starling.cs.berkeley.edu

概要

教師あり微調整(SFT)は、特にChatGPT/GPT-4(Alpaca、Vicuna、OpenHermes 2.5、Openchat 3.5を含む)から抽出された高品質なデータを活用する場合、言語モデルからチャットボットシステムを開発する際に顕著な効果を発揮します。しかし、人間のフィードバックからの強化学習(RLHF)やAIのフィードバック(RLAIF)が、高品質な嗜好データをスケーリングする際に、どの程度モデルを強化できるかは、未解決の問題のままです。Zephyr-7B、Neural-Chat-7B、Tulu-2-DPO-70Bなどのオープンソースコミュニティにおける初期の試みは、直接選好最適化(DPO)を採用していましたが、OpenHermes 2.5やOpenchat 3.5のような主要なSFTモデルと比較した場合、MT Bench(およびChatbot Arenaの一部)でのパフォーマンスは、RLHFの可能性を十分に示すものではありませんでした。

RLHFのより詳細な研究を促進するためには、チャットに特化した高品質のランキングデータセットが不可欠です。私たちは183KチャットプロンプトからなるGPT-4ラベル付きランキングデータセットNectarを公開します。各プロンプトはGPT-4、GPT-3.5-instruct、GPT-3.5-turbo、Mistral-7B-Instruct、Llama2-7Bのような様々なモデルから抽出された7つの回答を含み、合計380万のペアワイズ比較を提供します(GPT-4に順位を求める際、位置の偏りを軽減するためにかなりの努力が払われましたが、その詳細は以下のデータセットのセクションで詳しく説明します)。

オープンソースの報酬モデルは著しく少ないのが現状です。我々は、NectarデータセットのKワイズロスで学習させた報酬モデルStarling-RM-7B-alphaを公開することで、このギャップに対処します。

私たちは学習した報酬モデルを使用して、Openchat 3.5の言語モデルを微調整しました。その結果、MT-Benchのスコアは7.81から8.09に、AlpacaEvalのスコアは88.51%から91.99%に向上しました。どちらの指標もチャットボットの有用性を評価するものです。

メモ

  • 「RLAIF」とは「人間のフィードバックからの強化学習(RLHF)」をAIで代替した「AIのフィードバックからの強化学習」を指す。
  • RLHF/RLAIFは、すでに教師あり微調整(SFT)を済ませたモデルに対し、学習の最終工程として行うもの。AIチャットボットとして人間の嗜好にあった好ましい回答ができるようにポリシー最適化する。
  • Starling-LM-7B-alphaの具体的な作成手順は以下の通り。まず、様々な言語モデルによって同一のプロンプトに対する7つの応答パターンを出力させる。

  • 次に、その7応答に対してGPT-4に「回答としての好ましさ」を評価させ、順位付けする。このプロセスにより、著者らはNectarという183Kのランキングデータセットを作成している。
  • さらに、このランキングデータセットをもとに報酬モデル(Reward Model)を作成する。ここでは、Llama-2-Chat-7BにNectarデータセットで追加学習させ、Starling-RM-7B-alphaという報酬モデルを作出している。

  • 最後に、教師あり微調整でトレーニング済みの初期モデル(ここではMistral 7B v0.1をSFTしたOpenchat-3.5)に対し、先ほどの報酬モデルを用いたポリシー微調整を施し、Starling-LM-7B-alphaが完成する。
  • なお、この研究では複数のポリシー最適化の手法を予備実験によって比較検討している(DPO、APA、PPO、P3O)。
    • DPOは実装がシンプルで、事前に収集されたオフラインの嗜好データセットに基づいて言語モデルを直接更新する。手軽でありローカルLLMコミュニティでもよく用いられる。
    • 対照的に、PPOなどのオンラインRL手法では、現在の言語モデルを使用して新しい応答をサンプリングし、トレーニング済みの報酬モデルを使用して新しい応答にスコアを付け、新しい応答の報酬情報を使用して言語モデルを更新する。PPOのパラメータ最適化については試行錯誤が必要。
    • DPOを用いた予備実験では、初期モデルOpenchat 3.5と比べて大幅な改善は見られなかった。Openchat-3.5自体がC-RLFTというオフラインの嗜好に基づく強化学習を受けたモデルであり、機能的に類似したDPOの重ね掛けでは追加的な効果が得られなかった可能性がある。
    • この研究では最終的に、報酬モデルを用いたPPOの派生手法であるAPAによって学習したチェックポイントを採用した。
  • ここでのポリシー最適化は、トレーニング速度の向上のためモデルの最後の4層のみを解凍して行っている。モデルは、バッチサイズ28、合計10,000ステップで8個のA100 GPUでトレーニングされた。

Starling-7B-alphaとStarling-7B-betaの違い

  • 上述のブログ記事は、Starling-7B-alpha公開時のもので、後継モデルであるbetaではいくつかの点で事情が異なっています。
  • Starling-LM-7B-betaでは、Yi-34Bベースの新たな報酬モデルStarling-RM-34Bを作成し、さらにポリシー微調整にはAPAではなくPPOを採用しているようです。

 

消費者向けAIチャットサービスの収益化問題

  • ChatGPTのような消費者向けAIチャットサービスの収益化問題に関する記事がRedditで共有されていました。

www.businessinsider.com

記事によると

  • 最近「Inflection AI」というAIスタートアップから主要メンバーがMicrosoftに引き抜かれ、会社が瓦解しかかっていることが話題になっています。
  • Inflection AIはAIチャットボット"Pi"や大規模言語モデルInflectionシリーズを手掛け、昨年にはビル・ゲイツ氏やNVIDIAなどから10億ドルを調達していました。
    • "Pi"は、ChatGPTブーム期の2023年3月に「共感性を持つAI」という謳い文句でローンチされ、その後は大きく話題になることもなく今に至ります。
  • Piだけでなく、ChatGPTやClaudeのような消費者向けAIチャットサービス全般の利用者数が伸び悩んでおり、事業としての持続性の問題に直面しています。
  • OpenAIは現在、企業向けのサービスに力を入れており、そちらに収益性の糸口を見出そうとしているように見えます。

雑感

  • 昨年のChatGPT流行期に騒がれた「Googleの検索サービスが一気にAIチャットに取って代わられる」という予測なども、今のところ下火になっています。
  • 例えば、AIチャットボットが一般的に活用される用途は相変わらず「コーディング支援」「翻訳」「要約」あたりに偏っている印象がありますし、単純な調べ物はチャットを介さず自分で検索したほうが結局速い、という経験を多くのユーザーが持っていると思います。
    • 問題の一部は、検索機能や端末との統合性が不十分であるという点に起因していて、今後AndroidiPhoneでAIシステムが拡充されれば使い勝手が向上するはずです。
    • 一方で、チャットインターフェースは単純な情報検索には不向きであるという事実は変えづらいように思えます。
  • 今の言語モデルベースのAIサービスの抱える問題は、電気自動車(EV)の抱える問題と似ているかもしれません。
    • つまり、超長期的には有望なのは間違いない一方で、現時点ではコストが高いわりに融通が効かず、なおかつそれを活用するためのインフラ整備も不十分です。
    • そのため、アーリーアダプターによる採用が一巡してしまうと、それ以上のマーケット拡大が難しいという課題に直面します。
    • しかも、参入が比較的容易で競争が激しいため、価格を上げて収益性を確保することもままならず、ひたすら事業資金を飲み込まれてしまいます。
    • そのため、規模の小さいプレイヤーから順番に資金繰りが難しくなり、競争からフェードアウトしていくことになります。
  • 目下のAIバブルの規模はEVバブルと比べても格段に大きいので、万が一これが一気に弾けるとなると、色々な意味で厳しい感じがします。
  • 現状のローカルLLMの領域も、ビッグテックやその影響下のAIベンチャーが公開するベースモデルに依存しています。なので、クローズドなAIによる寡占も困るけれどバブル崩壊で開発が滞るのも困る、というのが都合のいい本音です。

 

微調整データセットには事前学習データも混ぜたほうがいい?

投稿の要旨

  • 言語モデルのファインチューンは基本的に「加算的」ではなく「破壊的」な側面を持つ最適化プロセスである。
  • 他方で、ファインチューンで広く使われている指示応答型データセットは一定の偏りや癖を持っている。
    • 特定のプロンプト形式で統一されていたり、
    • GPT-4で機械的に合成されたデータを使っていたり、
    • (データがマルチターン形式でない限り)ごく狭い範囲のコンテキスト内で予測を行うように変質したりする。
  • そのため指示応答型データだけで追加学習させると、しばしばwikitextなどで測定される当惑度(perplexity)が悪化し、モデルの一般性能が低下する。
  • そこで、RedPajamaサブセットのような事前学習用データセットをファインチューンデータに混ぜることで、このような問題を回避できる可能性が高まる。
    • その比率について検証した研究は見当たらないが、さしあたってファインチューンデータの25%程度含めるのがよい出発点になるのではないか。

雑感

  • ファインチューンによる過剰最適化によって、モデルの性能が低下したり以前持っていた知識を忘れたりする事象はよく問題になります。
  • 例えば知識学習の文脈では、こちらの研究のように前段階の学習内容を繰り返しながら後続の学習を行う(Replay Buffer)ことで、壊滅的忘却を防ぐケースがあります。
  • 単に無関係のテキストを混ぜるだけだとファインチューン精度の低下が心配です。できればファインチューンデータのフォーマットを多様化したり、学習する知識領域を周辺隣接分野に広げたりすることで過剰最適化の防止と学習精度の維持を両立できないかな、と考えています。
  • 個人的には、単純な知識学習後のPerplexityやHellaSwagスコアの変化を観察していますが、今のところ性能面の悪影響は見られません(むしろPerplexityが改善したりする)。r=64程度の高めのランクでがっつりファインチューンしていますが、データセットが小さいと影響が出にくいのかもしれません。

 

Mistral AI のCEO、Arthur Mensch の対談メモ

www.youtube.com

Mistral AI と Figma のCEOの対談に関する投稿がRedditに上がっていた(文字起こしのリンクが貼られている)。目を通して気になった点を適当にメモしておく。

  • Llama-7Bのような小型のモデルはコミュニティの需要が大きい一方で、改善の余地が多くあると気づいたので、まず小型モデルに狙いを定めた。
  • Mistral 7Bの開発・リリースには4か月かかった。500台のGPUを使い、5人のチームでほとんど休暇を取らずに作業した。個人的にはAI開発チームは4-5人程度の規模がベストだと考えている。
  • 当面は新しいオープンソースモデル(汎用モデルと、金融など特定領域に特化したモデル)のリリースを控えているほか、Mistralのウェブプラットフォームの機能拡充を進めている。
  • Microsoftと提携しAzureでMistralのモデルが採用されたことで、1000社ほどの顧客を獲得できた。
  • 資金調達により計算資源が増えたことでより大型のモデルを開発する余力ができた。ただし、当社はあくまで推論の効率性を重視していて、オープンソースの小型モデルも引き続きリリースしていく。
  • 2年ほど前まではRLHF(人間のフィードバックによる強化学習)が非常に重要だった。今では言語モデル自身を使って強化学習できるので、確かにその重要性は以前ほどではない。一方で、LLM開発が隆盛を極めていることでRLHFを低コストで行えるような環境も整備されつつある。
  • 今後3年以内には、多くのホワイトカラー業務でAIが人間を代替できる状況が生まれているのではないか。AIエージェントをデプロイし、評価し、ロバストで信頼性の高いものにする方法を見出すことが重要。
  • いわゆるnext token predictionだけでは、多彩な科学領域で実用的なツールとなることは難しい。
  • 昨今のGPU不足と計算コスト高騰は、ハードウェア分野での競争が進むことで次第に緩和されるだろう。NVIDIAのチップにはメモリ帯域幅の問題があり、Transformersに最適化したカスタムなチップが登場すればコストは大幅に削減される。
  • EUなどのAI規制に関しては、喧伝されている実存的リスクは定義があいまいで、科学的根拠にも欠けている。いくつもの異なる議論がごちゃまぜにされている。
  • 音声AIなどによるディープフェイクはもちろん大きな懸念があるが、当社はひとまずテキスト生成にフォーカスしており、この領域では現実的なリスクを制御できると考えている。
  • 多くのLLMは英語中心だが、英語は言語のひとつに過ぎず、我々は欧州の諸言語に注力し、そこに大きなマーケットを見出した。他にもアジアではアジアの言語に優れたモデルに対するマーケットがあるのだろうが、そこは我々の力の及ぶところではない。
  • 今後も、効率的なオープンソースモデルと強力なクローズドAPIを並行して提供する戦略を維持する予定。
  • 当社のAI開発チームでは、インフラスタックからパイプラインの作成、抽出、変換、ロード、数学的考察まで、あらゆることができる人材を探してきたが、そのようなフルスタックのAIエンジニアは行動に偏りがある傾向があった。
  • 私たちがフォーカスしたのは、退屈な裏方仕事も嫌がらずにこなしてくれるような利己的でない人材を探すことで、それが実際にチームに生産的な結果をもたらした。
  • Mistral Large に勝る 7B モデルが実現できるかどうかは少し難しい。方法はあるかもしれない。特定のタスクに絞れば非常に強力な7Bモデルを作れるだろうが、例えばこのサイズで多言語モデルを作るのはおそらく良いアイディアではない。

 

LoRAのランク(r)は高いほうがいいのか?

  • LoRAファインチューンでは様々なハイパーパラメータがあります。モデルとデータセットに合ったパラメータを選ぶことで、学習速度・精度が変わります。
  • 今日は主要なハイパーパラメータの一つであるLoRAランク (r)が気になったので、簡単な備忘録を書いておきます。

LoRAのランクとは

  • LoRA(Low Rank Adaptation)は、モデルの追加学習の際に膨大な重み行列を低ランク行列に分解することで計算資源を節約する手法です。
  • 低ランクに分解するとモデルが追加学習するパラメータ数が減るので、ランクを低い値に設定するほど計算コストが低下します。
  • 学習パラメータ数を減らすと学習精度も悪化するようにに思えますが「実はファインチューンの際の重みの変化は低い固有のランクを持っているため、低ランクでの学習でもフルのファインチューンに匹敵する精度を得られる」というのがLoRA論文の主張です。
    • 言語モデルにはある種の冗長性があって、重要なパラメータは低い固有次元に偏っている、と理解すればいいようです。
  • さて、このLoRAによる学習時の具体的なランクの値としては、8,16,32,64, 128 あたりがよく使われますが、この値の選び方について所説あるようです(別に8の倍数である必要はありません)。

「だからランクは低くて十分だよ」説

  • 上述のとおり、おおもとのLoRA論文では低いランクでも十分高い精度が期待できるとされています。具体的には、ランク8程度でもランク64とほとんど変わらない学習パフォーマンスを得たようです(GPT-3の場合)。
  • LoRAに量子化を組み合わてさらなる効率化を図ったQLoRA論文でも、8~64のどのランクを選んでも学習パフォーマンスの差はみられないという分析を示しています(Llama-1-7B、4bitのQLoRAの場合)。
  • 通常のファインチューンとして行うならランクは8ないし16(高くても32程度)でOK、というのがLoRAの一般的な考え方です。

「とはいえランクは高い方がいいよ」説

  • とはいえ計算資源に余裕があれば無理にランクを下げる必要はありません。場合によっては64や128などのランクを使いたくなるかもしれません。
  • 例えば前回のDoRA論文などをみると、フルのファインチューンとLoRAでは依然として学習精度にギャップが残るとする主張も根強いようです。
  • ランクを上げれば学習パラメータ数が増えるので、実質的にフルファインチューンに近くなるという想定ができます。
  • ドメイン知識追加の目的で最適なLoRAハイパーパラメータを探索した研究では、100程度の高いランクが効果的と示唆しています(Llama-2の場合)。

「むしろランクは低い方がいいよ」説

  • 一方で、ランクを上げると学習が不安定化するので値はむしろ低いほうがよい、とする検証結果もあります。
  • ランクの違いによるLlama-2モデルの挙動を調べた研究では、1)モデルの当惑度(perplexity)はLoRAランクを上げても改善しないうえ、r=2048など極端な値をとればむしろ悪化する、2)高いLoRAランクをとると勾配ノルムが破綻する、という結果を示しています。
  • 他にもモデルに日本語知識の追加を試みた検証では、Llama-2-7Bの場合にランクが低いほうが学習が安定したと報告しています(ただし13Bではランクの影響なし)。

雑感

  • 前提として、適切なランクの設定にはモデル、データセット、エポック数や学習率など他の要素との兼ね合いが重要で、一般化するのはあまりに難しい気がします。
  • そのうえで強いて言うなら、会話スタイルなど形式面のファインチューンを行いたい場合は8~16の低いランク、知識の追加など内容理解を伴うファインチューンを行うならもっと高めのランクがよさそうな印象です。
  • 一方で、高すぎるランクを使うと学習が不安定化するリスクがあり、また、当然ながら必要な計算コストも膨らみます。基本的に128~などの高いランクは避けたほうが無難に見えます。

補足:LoRAのアルファについて

  • ランク(r)とは別にLoRAにはアルファ(α)というハイパーパラメータもあります。
  • アルファはLoRAにおけるスケーリングファクターで(weight += (lora_B @ lora_A) * (alpha / r))、学習がモデルの重みに与える影響を強さを決定します。例えばr=32, α=16なら、32/16=2倍のスケールとなります。
  • よく使われるのはLoRA論文の実装で使われた2倍(つまりランクの半分のアルファ)ですが、スケールは1倍がよいという人、2.5倍や3倍がよいという人もいて、ケースバイケースのようです。
  • QLoRA論文では、アルファは常に学習率に比例するとして、ランクとは独立にアルファを固定しています(ランクより小さい、つまり1倍未満のスケールを採用している)。
  • アルファを高めに設定しても基本的に問題なさそうですが、学習率が高めになる場合には、対応して下げざるを得なくなるかもしれません。

参考

LoRA論文:[2106.09685] LoRA: Low-Rank Adaptation of Large Language Models

QLoRA論文:[2305.14314] QLoRA: Efficient Finetuning of Quantized LLMs

DoRA論文:[2402.09353] DoRA: Weight-Decomposed Low-Rank Adaptation

ドメイン知識追加論文:[2312.03360] Teaching Specific Scientific Knowledge into Large Language Models through Additional Training

rsLoRA論文:[2312.03732] A Rank Stabilization Scaling Factor for Fine-Tuning with LoRA

日本語知識追加:大規模言語モデルのFine-tuningによるドメイン知識獲得の検討 - Preferred Networks Research & Development

【LLM論文を読む】DoRA:Weight-Decomposed Low-Rank Adaptation(重み分解LoRA)

  • ここ数日「Stable Knowledge Editing」を参考にしながら、LoRAファインチューンによるLLMへの知識の追加を試しています。
  • LoRAのハイパーパラメータ調整のコツを調べるなかで、「DoRA(重み分解LoRA)」という別のLoRA派生手法の存在を知りました。HuggingFaceのPEFTライブラリでも対応しているツールのようです。
  • DoRA論文は、2024年2月にNVIDIA香港科技大学の研究者によりarXivに投稿されています。

arxiv.org

概要

広く使われているパラメータ効率的ファインチューニング(PEFT)手法の中で、LoRA(低ランク適応)とその亜種は、追加の推論コストを回避できることから、かなりの人気を得ている。しかし、これらの手法と完全なファインチューニング(FT)との間には、まだしばしば精度のギャップが存在する。本研究では、まずFTとLoRAの本質的な違いを調べるために、新しい重み分解分析を導入する。得られた知見からFTの学習能力に似せることを目指し、重み分解低ランク適応(DoRA)を提案する。DoRAは、事前に学習された重みを大きさと方向の2つの成分に分解し、微調整を行うもので、特に方向の更新にLoRAを採用することで、学習可能なパラメータ数を効率的に最小化する。DoRAを採用することで、LoRAの学習能力と学習の安定性の両方を向上させると同時に、推論オーバーヘッドの追加を回避している。DoRAは、LLaMA、LLaVA、VL-BARTのファインチューニングにおいて、常識推論、視覚命令チューニング、画像/動画テキスト理解などの様々な下流タスクにおいて、一貫してLoRAを凌駕する。

簡単な説明

先述のPEFTのページでは、以下のようにDoRAを紹介しています。

重み分解LoRA (DoRA):

この手法では、重み行列の更新を「大きさ」と「方向」の2つに分解します。そのうえで、方向は通常のLoRAで処理され、大きさは別の学習可能なパラメータで処理されます。これにより、特に低いランク(r)においてLoRAの性能を改善できます。現在、DoRAは非量子化線形レイヤーのみサポートします。DoRAは純粋なLoRAより大きなオーバーヘッドがあるため、推論時には重みをマージしておくことを推奨します。

  • 重み行列を大きさと方向に分解したうえで、パラメータ数の多い「方向」にだけLoRAをかける方法のようです。
  • LoRAのランク(r)が小さく設定しても(学習パラメータ数を減らしても)高い学習精度が期待できるとされています。
  • ただ、現時点で量子化に未対応らしいので、単に学習効率を上げるだけならQLoRAの方が速そうです。実用的には学習精度の向上に注目です。

メモ

  • 一般的に「LoRAのFTに対する精度の低さは、学習可能なパラメータの数が少ないことに起因する(=LoRAでもパラメータ数を増やせばFTの精度に近似する)」とみなされているが、その点の具体的な検証は乏しい。
  • そこで本論文では、まずLoRAとFTの学習パターンの違いを調べるため、重み行列を大きさと方向の2つの別々の成分に分解して分析した。

  • 結果、上図のようにLoRAはすべての中間ステップで一貫した正の傾きを示し、変化の方向と大きさが比例した一方、FTでは緩やかな負の傾きを持った多様な学習パターンを示した。
  • この違いが、おそらくFTとLoRAの学習能力の差を反映する。LoRAは「大きさ」を大幅に変更しながら「方向」を僅かに変えるような微妙な調整は苦手である。
  • そこで本論文では、LoRAの学習パターンをよりFTに近づけ、LoRAよりも学習能力を向上させる変形LoRAとして「DoRA」を提案する。
  • 重み分解LoRA(DoRA)ではまず、事前に訓練された重みを大きさ成分と方向成分に分解し、その両方を微調整する。その際、方向成分はパラメータ数が大きいため、LoRAでさらに分解し、効率的な微調整を行う。
  • 重み分解によって単に学習が効率化されるだけでなく、方向更新を最適化するプロセスがより安定し、LoRAに比べFTに近い学習パターンを実現できる。
  • 以下は、Llama-7B/13Bを8種類の常識推論用のデータセットで学習し、それぞれ対応するベンチマークで評価したもの。LoRA/DoRAではr=32を用い、DoRA†は半分のr=16を使用。

  • 以下はMT-Bench(マルチターンの記述式ベンチマーク)による評価結果。Llama-1/2をAlpacaデータセットでLoRA/DoRAで学習させて評価。なお、VeRAは別のLoRA効率化手法で、DVoRAはVeRAとDoRAを組み合わせたもの。

  • 少ないデータセット数でもDoRAは他の手法に比べて高いスコアを示す。

  • なお、本論文ではマルチモーダルモデルの画像理解タスクにおけるDoRAの有効性についても検証している(割愛)。

雑感

  • ちょうど今DoRAによる学習を試していますが、やはり量子化が使えないと低ランク(とりあえず16で実施)でも学習が遅いので、個人用途では扱いにくい印象です。量子化実装に勝手に期待したいです。
  • ところで、付表にあったランク(r)ごとの学習後ベンチマークがちょっと気になりました。r=64でHellaSwagのスコアだけ大幅に低下しています。

  • 過学習でモデルが壊れてしまったのでしょうか?ほかのベンチマークではr=32と大差がありません。
  • 知識編集目的の場合、基本的にランクは大きいほうがいいとされていますが、性能低下にも気を配る必要はありそうです。