Это продолжениепредыдущей публикациипро реставрацию ruGPT3XL. Для тех кто не читал, кратенько, я конвертировал древний Megatron-LM чекпоинт в HuggingFace-формат, залил веса на HF, накатил поддержку GGUF в llama.cpp и подумал, что всё. Но нет.
По ходу тестов, проведённых разными людьми удалось выявить ряд недоработок, которые я по мере обнаружения правил, ну а после того, как удалось получить стабильную и рабочую версию мне захотелось решить одну старую проблему, которая меня в ruGPT3 моделях очень беспокоила, это проблема маленького контекста в смешные 2k токенов.
Решил поднять контекст до 8k.
PPL, Sparse Attention и Triton
После прошлойпубликациина Хабре меня резонноспросили, а на каких метриках вообще проверялось качество конвертированной модели? Я честно не знал, что ответить, так как гонял MERA в отрыве от оригинала, потому что оригинальную модель через древние Megatron-LM, DeepSpeed и Apex мне запустить так и не удалось, очень старый стек.
Смеркалось, свербило.
Решил взять метрику Perplexity (PPL), она очень простая, плюс указана в карточках всех оригинальных моделей, понятно как считать и что ожидать. Единственная проблема в том, что нужен датасет, на котором тестировали оригиналы, а такого у меня нет, и у SberDevices скорее всего тоже, так как пять лет прошло с тех пор.
Взял датасетgazetaИльи Гусева@Takagi, в нём около 60k русскоязычных новостных статей, все примеры умещаются в 2k токенов, датасет небольшой и всем известный. Написал скрипт расчёта примерно по методологии из оригинальной публикации про ruGPT3, заодно прогнал все четыре размера семейства: ruGPT3small, ruGPT3medium, ruGPT3large и мой ruGPT3XL с наивным dense attention.
Получилась такая вот табличка:
Циферка для ruGPT3small отсутствовала в карточке модели, поэтому там прочерк. Корреляция между замерами на gazeta и оригинально заявленными значениями получилась вполне приличной (R = 0.93):
PPL 50.1 WTF
PPL конвертированного ruGPT3XL первым прогоном показал50.1, а оригинальная модель в своей карточке имеет 12.05. Ошибка в расчётах? Не похоже, ведь у остальных трёх моделей семейства цифры PPL более менее похожие. Значил дело в чём-то другом.
Начал копать. Оказалось, кодовый агент при конвертации решил схалтурить и выбросил механизмSparse Attention, заменив его на обычныйnn.MultiheadAttentionиз GPT-2. Это, конечно, “работает”, модель генерирует текст, вот только веса-то оптимизированы под разреженное внимание, а не под плотное, математика другая, поэтому результат на контексте больше 128 токенов ожидаемо слабый.
Благодаря тому, что я потратил время на детальное изучение исходников Megatron-LM при первой конвертации, понять где именно проблема было несложно. Объяснил агенту что не так, показал примеры кода с правильным механизмом, дал почитать оригинальную публикацию про ruGPT3, и спустя несколько итераций получил исправленныйmodeling_rugpt3xl.pyс репликой Sparse Attention из Megatron-LM.
Sparse Attention, зачем нужен и чем от обычного отличается
Стандартный causal self-attention (то, что в оригинальном GPT-2) - это плотная матрица, где каждый токен смотрит на все предыдущие токены. Память и вычисления растут квадратично от длины последовательности, удвоили контекст, получили в четыре раза больше операций с матрицей внимания и в четыре раза больше памяти потребляем.
Sparse Attention делает то же самое, но с прорежённой маской, в ruGPT3XL используется alternating-паттерн из статьи “Generating Long Sequences with Sparse Transformers” (arxiv:1904.10509):
- Чётные слои (0, 2, 4, …) - block-sparse attention, каждый токен видит только ограниченное локальное окно (128 токенов) плюс несколько “глобальных” блоков через регулярные интервалы. Разные головы внимания используют разные позиции глобальных блоков.
- Нечётные слои (1, 3, 5, …) - обычный плотный causal attention.
Теоретически это даёт почти линейный рост памяти вместо квадратичного. На практике для ruGPT3XL при увеличении контекста в 4 раза память на KV+активации растёт примерно в 3-4 раза (а не в 16x), замеры чуть ниже.
Разница в PPL между sparse и dense режимом для ruGPT3XL на датасете gazeta уже видна на графике выше, но если совсем кратко:
Механизм внимания
PPL (test, gazeta)
Dense (как в GPT-2)
Sparse alternating (оригинал)
После исправления PPL понизился с 50.1 до 11.68, это уже похоже на правду и хорошо коррелирует с заявленными 12.05 у оригинала.
Параллельно выяснилось, что в GGUF-версии та же история - прошлый патч в llama.cpp (PR #21011) добавлял конвертацию весов через архитектуруLLM_ARCH_GPT2, но сама sparse attention там не была реализована. Значит, GGUF-модель тоже считала dense внимание. Пришлось делать новый патч (PR #21161) он добавляет полноценную поддержку ruGPT3XL как отдельной архитектуры со sparse attention.
После релиза фикса механизма внимания один хабровчанин вкомментарияхуказал, что в реализации sparse attention была ошибка:
При обучении маска строилась некорректно для батчей длиннее одного примера, из-за чего обучение падало. Исправил эту проблему.
Ну и под конец добавил поддержкуTritonдля ускорения sparse-операций на GPU.
В преобразованной модели изначально внимание реализовывалось какявноеmatmul+softmax+matmulв режиме PyTorch, это математически корректно, но снижает производительность по сравнению с решениями на базе Triton доступными на современных графических процессорах NVIDIA.
На графике четыре режима работы механизма внимания при обучении (синтетический луп, AdamW, fp16, RTX 4090). Серый столбик - исходныйeagerрежим (~6280 tok/s), это тот самый явный matmul+softmax+matmul. Синий - переключение наF.scaled_dot_product_attention(SDPA) при том же размере батча: +40% почти бесплатно, просто меняется путь исполнения внутри PyTorch. Голубой - SDPA с бо́льшим батчом (5×2048 вместо 2×2048), SDPA эффективнее использует память и позволяет запихнуть больше. Зелёный - SDPA плюсtorch.compileс Inductor-бэкендом: итого ×1.85 к baseline, компилятор дополнительно сплавляет поэлементные операции и местами генерирует Triton-ядра. Числа внутри столбиков - кратность ускорения относительно eager.
Контекст 8k
Откуда идея
2048 токенов - это больная тема для всего семейства ruGPT3, в своё время на этом сгорело не мало моих нервов, пришлось изобретать sliding window в чатах, костыльные стратегии фильтрации датасетов чтобы не поймать OOM, чанковать документы. Всё это конечно же опыт, он мне позже пригодился и не раз, но осадочек остался.
Теперь, когда у меня есть рабочая современная версия модели, грех не попробовать исправить и этот фатальный недостаток.
Вопрос “насколько реально расширить контекст у ruGPT3XL” нетривиальный из-за двух особенностей архитектуры:
- у модели используется “Learned Absolute Positional Embeddings” (Learned APEs), таблица позицийembed_positions, по-простомуnn.Embedding(2048, 2048), обученная вместе со всеми остальными весами. В отличие от Rotary Positional Embeddings (RoPE) таблица APE не умеет экстраполировать - если модель никогда не видела позицию с индексом 2049, она понятия не имеет что туда подставить.
- sparse attention (о которой было выше), сетка разреженного внимания строится изmax_position_embeddings // sparse_block_size, то есть тоже зависит от лимита контекста.
На эту тему нашёл пару релевантных работ:
- “Extending Input Contexts of Language Models through Training on Segmented Sequences” (arXiv:2310.14633) про эксперименты на GPT-2-подобных моделях с dense вниманием, с подъёмом контекста с 1-2k до 4k, там же интерполяция позиционных эмбеддингов, сегментированное обучение и прочите радости ИИинженера.
- “The Impact of Positional Encoding on Length Generalization in Transformers” arXiv:2305.19466) о том, почему APE без дообучения не будет работать на новых длинах, как-раз к нашей модельке применимо.
Короче, нельзя просто взять и увеличитьmax_position_embeddingsв конфиге, ничего хорошего не выйдет, требуется дообучение, а вот после дообучения и с правильной инициализацией уже вполне реально.
Про память и вычисления
Важное следствие sparse attention для планирования экспериментов. Если бы была плотная матрица внимания, переход с контекста L на 4L дал бы примерно 16-кратный рост памяти на self-attention. У ruGPT3XL благодаря alternating sparse-паттерну это скорее 3-4x на практике. Это означает, что 8k контекст в принципе влезет на RTX 4090 с 48 ГБ, причём даже при full обучении (полной разморозкой всех весов модели).
Стратегия расширения
Без изобретения велосипеда ничего не выйдет, так что вот три принципа, которым я решил следовать:
1. Тайлинг позиционных эмбеддингов
Первым напрашивался вариант с линейной интерполяцией, тупо в лоб взять существующую матрицу 2048 x 2048 и заскейлить её до нужного размера - это сработало плохо, интерполяция меняет все 2048 строк, в том числе те, для которых у модели всё итак работало. PPL на коротком контексте сразу после такой операции улетел за сотку, не наш вариант ;)
Покумекав вспомнил про метод тайлинга (зацикливания, ну или проще дублирования), оригинальные позиции 0-2047 копируются буквально, а новые заполняются циклически:
Смысл в том, что модель с первых шагов дообучения хотя бы не паникует на новых индексах, а короткий контекст работает точно так же, как и раньше.
2. Смешанный датасет
60% длинных примеров (несколько статей gazeta, склеенных через EOS до целевой длины) и 40% коротких чанков до половины целевой длины. Длинные обучают новые позиции, короткие не дают модели забыть как работать с привычными контекстами. Без коротких примеров PPL на 2k стремительно деградировал.
3. Ступенчатое расширение
Сначала 2k -> 4k, потом берём обученную 4k-модель и делаем 4k -> 8k. Сразу прыгнуть с 2k на 8k гипотетически можно, но это значит тайлить позиции 2048-8191 из диапазона 0-2047, что довольно грубо, да и модель за три эпохи на небольшом датасете может не успеть освоить такой диапазон, а большее количество эпох может привести к оверфигу (переобучению), чего я бы не хотел допустить.
Параметры обучения
ДатасетIlyaGusev/gazeta, сплитtrain, 3 эпохи на каждый шаг,lr=5e-6с cosine decay, gradient checkpointing, bfloat16, RTX 4090 (48 ГБ) и так далее.
По времени ушло:
- ~2.6 часа на шаг 2k->4k
- ~3.9 часа на 4k->8k
При обучении на 8k контексте CUDA фрагментировала память в процессе backpropagation и падала с OOM на середине - при этом на GPU формально оставался ~1 ГБ свободной памяти, но PyTorch не мог нарезать из него нужный смежный кусок. Решается одной строкой в переменных окружения:PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True. После этого пиковое потребление упало с 46.8 до 38.5 ГБ и обучение дошло до конца без приключений.
Полученная моделька вот тут:evilfreelancer/ruGPT3XL-8k
Perplexity 8k
Тест на сплитеtestдатасетаIlyaGusev/gazeta:
PPL @ 2048
PPL @ 4096
PPL @ 8192
ruGPT3XL (baseline)
ruGPT3XL-4k
ruGPT3XL-8k
Регрессия на исходном 2k контексте всего +0.09 к baseline - модель не разучилась работать с короткими последовательностями.
4k в финальной 8k-модели оказался даже лучше, чем у промежуточного чекпоинта (11.99 vs 12.04) - continued pretraining чуть подтянул общее качество.
На 8k получаем 13.00, что для четырёхкратного расширения контекста вполне достойно.
Длина контекста
KV + активации
Веса модели - ~2.67 ГБ в bfloat16, до 2k overhead растёт почти линейно, что подтверждает работу sparse attention, дальше становится квадратичнее.
Скорость генерации
Длина промпта
На коротких промптах модель летает благодаря KV-кешу, с 2k на 4k скорость падает в 5.6 раза, даже с KV-кешом при каждом autoregressive шаге нужно протащить внимание через всю историю.
Зато переход 4k -> 8k (2x по длине) даёт только 1.8x замедление (67 -> 38 tok/s), хотя памяти надо уже гораздо больше.
Конвертированная ruGPT3XL теперь работает правильно, PPL соответствует оригиналу, sparse attention реализован и в transformers-версии, и в llama.cpp, контекст растянут с 2k до 8k с минимальной регрессией на коротких последовательностях.
На RTX 4090 ruGPT3XL-8k пригодна на любой длине контекста, на бюджетных 8-12 ГБ карточках комфортно до 4k, что уже в два раза лучше оригинала.
Следующий очевидный шаг - instruction tuning, но это уже другая история.
- evilfreelancer/ruGPT3XL- базовая модель (2k)
- evilfreelancer/ruGPT3XL-8k- модель с расширенным контекстом (8k)
- evilfreelancer/ruGPT3XL-GGUF- GGUF версия (2k)
- EvilFreelancer/rugpt3xl-convert- исходники конвертации и скрипты
- PR #21011 в llama.cpp- GGUF конвертация (смерджен)
- PR #21161 в llama.cpp- sparse attention в llama.cpp (ещё висит)
- IlyaGusev/gazeta- датасет для тестирования и дообучения
Послесловие
Вот такой вот занятный эксперимент у меня получился, надеюсь интерес к моему маленькому проекту у читателей и подписчиков сохранится, так как хочется попробовать ещё парочку занятных вещей типа квантизацию и обучения в mxfp4, а так же конвертацию модельки в MoE формат, плюс на очереди ещё пухляшка ruGPT 3.5 на 13B параметров, короче есть ещё чем заняться.
Ну я в свою очередь благодарю вас за прочтение, надеюсь мои наработки пригодятся, буду рад фидбеку в комментариях или втелеграме.