🐹 임베딩 모델이란?
임베딩 모델은 텍스트, 이미지, 음성과 같은 데이터를 벡터 공간으로 변환하는 역할을 하는 모델을 말한다.
벡터들은 숫자 값으로 이루어진 고차원 공간에서 표현된다. 이를 통해 데이터 간의 의미적 유사성을 비교하거나 분석할 수 있다.
M3-Embedding모델에 대해 정리해봤다.
https://arxiv.org/pdf/2402.03216
Abstract
M3-Embedding은 multi-linguality(다국어 지원), multi-functionality(다기능성), muti-granularity(다중 그레뉼러리티)를 갖추고 있다
이 모델은 Dense Retrieval, multi-vector retrieval, sparse retrieval의 세가지의 주요 검색 기능을 동시에 수행할 수 있다. 또한 짧은 문장에서 부터 최대 8192토큰까지 처리할 수 있다.
이 논문에서는 새로운 self knowledge distillation접근 방식을 제안한다. 이는 다양한 검색기능에서 생성된 relevance scores를 학습 품질을 향상시키기 위한 teacher signal로 통합한다.
이 모델은 배치전략을 최적화 하여 대규모 배치 크기와 높은 처리량(throughput)을 가능하게 하여 임베딩의 변별력을 향상시킨다.
M3-Embedding은 multilingual, cross-lingual, 긴 문서에서 최첨단 성능을 보여주며, 우수한 성과를 기록했다.
Introduction
텍스트 임베딩은 NLP에서 DNN의 응용에 있어 중요한 구성 요소이다. 텍스트 데이터를 Latent Space에서 표현하여 데이터의 내재된 의미를 Output Embeddings을 통해 나타낼 수 있다.
사전 학습된 언어 모델의 발전에 따라 텍스트 임베딩의 품질이 크게 향상되었으며 이는 정보 검색 시스템(IR)에서 필수적인 요소 이다. 임베딩 기반 IR 응용은 텍스트 임베딩을 사용하여 쿼리와 관련된 항목을 벡터 유사도에 기반하여 검색하는 형태이다.
하지만 기존의 텍스트 임베딩 모델들은 유연성(Versatility)에서 다음과 같은 제한점을 가진다.
- 언어 편향
- 대부분의 임베딩 모델은 영어에 최적화되어 있으며, 다른 언어에서는 유사한 성능을 보장하지 못함.
- 단일 검색 기능 지원
- 기존 모델은 단일 검색 기능에만 최적화
- 긴 문서 처리 한계
M3-Embedding은 3가지 주요 특징을 통해 기존 한계를 극복한다.
- Multi-Linguality
- 100개 이상의 언어를 처리하며, 다국어 간의 의미적 유사성을 유지할 수 있는 통합 공간 제공
- Multi-Functionality
- dense retrieval, sparse retrieval, multi-vector retrieval 등 다양한 검색 기능 수행
- Multi-Granularity
- 짧은 문장에서 긴 문서까지 다양한 텍스트 길이 처리 가능 (최대 8,192 토큰)
M3-Embedding의 학습 과정은 3가지의 challenge를 보여준다.
- Self-Knowledge Distillation (다중 검색 기능을 공동으로 학습)
- [CLS] 임베딩은 dense retrieval, 그외 임베딩은 sparse retrieval, multi-vector retrieval용으로 사용
- dense retrieval, sparse retrieval, multi-vector retrieval 에서 생성된 관련성 점수(relevance scores)를 학습 품질을 높이기 위한 교사 신호(Teacher Signal)로 활용
- 앙상블 원리
- 효율적인 배치 전략
- 대규모 배치(batch size)와 높은 처리량(throughput)을 활용하여 임베딩 변별력을 향상.
- 문장 길이에 따라 그룹을 정하고, 짧은 문장에 더 많은 배치 적용
- 고품질 학습 데이터 구축
- 대규모 다국어 코퍼스(Multi-lingual Corpus), 증강 데이터(Synthesized Data), 및 데이터 샘플링(Data Sampling)을 통해 학습 데이터셋 품질을 개선
M3-Embedding
이 모델은 3가지의 주요 특징을 제공한다
- 다양한 언어와 다양한 길이의 input data를 처리한다
- x언어에서 y언어로의 교차 검색이 가능하다
- dense retrieval, sparse retrieval, multi-vector retrieval 와 같은 다양한 검색 작업 수행이 가능하다.
Data Curation
이 모델은 대량의 다양한 데이터로 학습되었다. 3가지의 주요 데이터 소스를 활용했다.
- 비지도 학습 데이터
- Wikipedia, CC100, MTP, 및 CC-News와 같은 다국어 코퍼스(Multi-Lingual Corpus)에서 추출한 데이터
- 데이터는 194개 언어와 265개의 언어 간 조합에서 12억 개의 텍스트 쌍을 포함
- 라벨링 된 데이터
- NLLB 팀과 CCMatrix 등에서 번역된 데이터가 포함
- 데이터에서 관련성이 낮은 부분을 제거하고 의미적으로 중요한 텍스트를 필터링
- fine tuning데이터
- HotpotQA, TriviaQA, SQuAD 등에서 가져온 고품질 다국어 텍스트 데이터
- Fine-tuning 과정에서 추가적인 학습 데이터로 사용되며, 긴 문서 처리 및 언어 간 작업을 위한 모델의 성능을 강화
최종적으로 긴 문서 검색 및 다국어 작업의 부족한 데이터를 보완하기 위해 합성 데이터를 생성했다.
이는 긴 문서 데이터를 위해 wikipedia, wudao, mC4와 같은 곳에서 랜덤하게 구절을 고른다. 질문은 GPT-3.5를 이용하여 생성되었다.
Hybrid Retrieval
M3-Embedding은 다양한 검색 기능을 통합하여 사용한다.
- Dense Retrieval 벡터 유사도를 기반으로 데이터 검색
- Lexical Retrieval 키워드 기반의 희소 검색
- Multi-Vector Retrieval 여러 의미적 특징을 다중 벡터로 표현하여 세밀한 검색 가능
Self-Knowledge Distillation
임베딩 모델은 positive samples과 negative samples을 구별하도록 학습된다. 각 검색 방법에서 쿼리의 양성 샘플에 대해 높은 점수를 부여하고, 음성 샘플에는 낮은 점수를 부여하도록 한다. 따라서 학습 과정은 InfoNCE 손실 함수를 최소화하는 데 집중한다.
- p∗: 쿼리 q에 대한 양성 샘플
- P′: 쿼리 q에 대한 음성 샘플
- s(⋅)은 검색 방법 중 하나를 나타내며, {s_dense(⋅),s_lex(⋅),s_mul(⋅)}로 정의
증류 과정은 아래를 따른다.
통합 점수 계산
s_inter(각 검색 방법에서 계산된 예측 점수를 가중 합)를 교사신호로 사용하며 각 검색의 손실 함수 수정
- p(⋅) : 소프트맥스 활성화 함수
- s_s: s_dence, s_lex, s_mul중 하나
최종적으로 통합 및 정규화된 수정된 손실 함수는 다음과 같다
수정된 손실함수와 원래 손실 함수를 선형 결합하여 최종 손실함수를 정의한다.
학습과정은 다단계의 workflow를 가진다.
- 모델은 대규모 비지도 데이터로 pre training되며 이 단계에서는 dense retrieval만 학습한다
- 여기에서 자기 지식 증류가 사용되며 3개의 retrieval 기능이 통합되도록 fine tuning 된다
초기에는 sparse retrieval이 낮은 성능을 보이므로 이를 반영하여 가중치를 조정했다. 라벨 데이터와 합성 데이터 모두를 사용하여 모델 초기화를 강화한다.
Efficient Batch
임베딩 모델은 다양하고 많은 다국어 데이터를 학습하여 각 언어의 일반적 의미를 포착해야한다. 배치 사이즈는 가능한 크게 유지하며 많은 배치 샘플을 활용하여 텍스트 임베딩 변별력을 보장해야한다. GPU의 메모리와 연산 능력에는 한계가 존재하여 일반적으로 입력 데이터를 짧은 시퀀스로 자르거나 작은 배치크기를 사용해야한다. M3-Embedding은 짧은 시퀀스와 긴 시퀀스를 모두 효율적으로 처리해야하기 때문에 위 방법은 적합하지 않다.
배치 전략은 아래와 같다
- 시퀀스 길이 기반 그룹화
- 학습 데이터는 시퀀스 길이에 따라 사전에 그룹화
- 미니 배치 내의 샘플은 동일한 그룹에서 선택 (유사한 시퀀스 길이로 인해 padding의 양을 줄이기/ GPU의 효율성을 높이기)
- 고정된 랜덤 시드 사용
- 여러 GPU에서 데이터를 샘플링할 때, 랜덤 시드는 항상 고정
- 이는 로드 밸런스를 유지하고 학습 과정에서의 변동성을 줄이며 시간 낭비를 최소화
- 긴 시퀀스 데이터 처리
- 긴 시퀀스 데이터는 서브 배치로 나뉘어 관리
- 각 서브 배치는 Gradient Checkpointing을 사용해 메모리 사용량을 절감하며, 생성된 임베딩을 효율적으로 수집