4. 既存 AI モデルへの新しいオブジェクトのクラスの追加

TAO Toolkit では、学習済みのモデルを適応させて新しいカスタム クラスを容易に追加できます。

4.1 課題

公開されているモデルを利用して開発を始めると、アプリケーションに適したものとは異なるクラスが含まれている可能性があります。モデルをカスタマイズする場合、モデルのアーキテクチャをきちんと理解していなければ、クラスを追加または削除することはできません。必要な変更を行うためには、何千行ものコードを調べなければならず、モデルのテンプレート、データローダー、損失関数を変更することになります。新しいクラスを追加する、または既存のモデルからいくつかのクラスを削除することは、転移学習に伴う作業の 1 つです。

例として、人物を検出する既存のモデルにヘルメットのクラスを追加する場合を考えてみます。タスクは、人物とヘルメットの両方を検出することです。この新規のモデルをトレーニングしてヘルメットのクラスを追加する場合、人物とヘルメットの両方が含まれる、適切にアノテーションされたデータセットを用意する必要があります。ヘルメットのデータだけを用意し、人物のデータがないと、モデルはヘルメットに対しては高いパフォーマンスを発揮するようになりますが、人物に対するパフォーマンスは低下します。そのため、データセットには両方のクラスが含まれていることが非常に重要です。一般的に、新しいクラスを追加する転移学習では、既存のクラスと新しいクラスを網羅した代表的なデータが必要になります。

4.2 解決策

NVIDIA の学習済みモデルは、一般的なオブジェクトに対してあらかじめトレーニングされているため、推測とラベル生成に利用できます。そのため、作業する必要のあるタスクは、アプリケーションに必要なカスタム クラスへのラベル付けのみになります。そして、モデルをデータセット全体でトレーニングすれば、元の機能を維持しつつ、カスタム クラスにも対応できるようになります。

前述の人物とヘルメットの例で言えば、ヘルメットのラベル付きデータセットしかなくても、NVIDIA の PeopleNet モデルの推論を利用することで、人物と顔のアノテーションが可能です。NVIDIA の PeopleNet モデルは、数百万枚の顔や人物の画像でトレーニングされているため、これらのクラスをゼロから学習する必要はありません。まず、PeopleNet で推論を行い、人物と顔のラベルを作成してから、推論されたラベルをヘルメット クラスのラベルと統合します。

注: PeopleNet を使用して、人物や顔のクラスのグラウンド トゥルースを生成する場合は、これらのクラスの偽陽性や偽陰性にご注意ください。手動でのクリーンアップが必要になる場合があります。

PeopleNet モデルは人物や顔を既に高い精度で検出できるため、ヘルメット クラスをトレーニングするだけで済みます。それには、トレーニングの仕様の中で、クラスの重み付けをヘルメットに対しては重く、人物や顔に対しては軽くします。これにより、人物や顔を検出するモデルの性能を維持しながら、新しいクラスについて精度の高いトレーニングが可能になります。

4.3 結論

このタスクでは、オープン ソース3 のヘルメット検出用データセット (トレーニング用に 611 枚の画像、検証用に 152 枚の画像) を使用しました。トレーニングでは、新しいヘルメットのクラスに 0.8、人物や顔のクラスに 0.1 の重み付けを設定しました。これは、新しいヘルメットのクラスに比重を置いてモデルをトレーニングするためです。PeopleNet モデルを人物、ヘルメット、顔のデータセットでトレーニングした結果、100 エポック以内でヘルメット クラスの AP が 80% に達しました。

図 9. 再学習後のヘルメット クラスの精度
図 9. 再学習後のヘルメット クラスの精度
図 10. 人物、顔、ヘルメットの推論
図 10. 人物、顔、ヘルメットの推論

学習済みのモデルを使用し、自前のデータセットで推論を行った後、元のクラスに加えてカスタム クラスにも対応するように再トレーニングするという発想の大枠は、PeopleNet に限らず応用でき、TAO の学習済みモデルすべてにこの手法が使えます。ただし、注意点があります。AI モデルを使用してアノテーションを行うと、偽陽性や偽陰性が発生して手動でのクリーンアップが必要になる可能性があります。

このタスクは、Kaggle のヘルメット検出用データセットを用いて実行しました。ガイド付きの完全なタスク実施手順は、TAO タスク GitHub リポジトリで入手できます。


3 https://www.kaggle.com/andrewmvd/helmet-detection