AI测试 本地运行 Gemma 的 pytorch 集成

yyintech · 2024年03月01日 · 最后由 yyintech 回复于 2024年03月01日 · 2907 次阅读


Gemma 是 Google 在2024年2月21日发布的一款轻量的开源大模型,采用了和 Google Gemini 模型一样的技术。有猜测 Google 在毫无预告的情况下急忙发布 Gemma 是对 Meta 的 Llama3 的截胡,但不管怎么说作为名厂名牌的大模型,自然要上手尝试尝试。

这次发布的 Gemma 有 2B 参数和 7B 参数两个版本,两个版本又分别提供了预训练 (Pretrained) 和指令调试 (Instruction tuned) 两个版本。预训练版本做了基础训练,而指令调试版本做了根据人类语言交互的特定训练调整,所以如果直接拿来做会话使用可以下载 it 版本。2B 和 7B 在于参数量的多少,7B 需要更多的资源去运行。

好了,前面啰嗦了一堆背景,为了引出这里介绍 2b-it 版本地部署的原因——耗资源少且可以本地使用会话。

准备环境

  • 安装 python venv,命名 gemma-torch conda env create -n "gemma-torch"
  • 激活虚拟环境 conda activate gemma-torch
  • 安装依赖的库 pip install torch immutabledict sentencepiece numpy packaging  后面两个库不是官方文档里要求的,但是根据我执行报错,需要安装。另外上面命令也取消了-q -U 简单粗暴也方便观察。

为了后续用代码连接 kaggle 下载模型,还需要安装 kagglehub 包:

pip install kagglehub

连接 kaggle

这一步的目的是从 kaggle 上面下载模型。

  • 首先获取 kaggle 的访问权限 登录 kaggle,在设置页面 的 API 一节点击按钮 “Create New Token”,会触发下载 kaggle.json。 ​​
  • 配置环境 将 kaggle.json 文件拷贝到~/.kaggle/目录下。并在~/.bash_profile 中设置环境变量 KAGGLE_CONFIG_DIR 为~/.kaggle。

这样就可以通过下面代码访问 (后面的代码写到一块,不需要此处执行)。

import kagglehub

kagglehub.login()

运行代码

经过前面的配置后,可以代码本地运行 2b-it 模型了。不过加载模型还需要 gemma_pytorch 包。
从 github 仓库 clone 到本地:

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git

将下载好的 gemma_pytorch 文件夹放到下面脚本文件同一级目录下 ,并在~/.bash_profile 中设置 PYTHONPATH 环境变量包含该文件夹路径。

最后运行脚本 (gemma_torch.py):

# Choose variant and machine type
import kagglehub
import os
import sys
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM
import torch


VARIANT = '2b-it'
#如果是cpu运行,将下面cuda改成cpu,不过巨慢
MACHINE_TYPE = 'cuda'

# Load model weights
# 模型下载到了~/.cache目录下
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)

# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)

一点后话

能用 GPU 还是上 GPU 吧,我本地用的 CPU 笔记本跑的巨慢。

可以在线使用 colab,具体步骤参考这个帖子 (昨天 Google 发布了最新的开源模型 Gemma,今天我来体验一下_gemma_lm.generate-CSDN 博客)。

不过我在使用过程中发现 T4 经常在预测执行时报 OOM,导致无法产出结果。

参考资料:
pytorch 中使用 Gemma: https://ai.google.dev/gemma/docs/pytorch_gemma

官方文档地址:https://ai.google.dev/gemma/docs 

需要 登录 后方可回复, 如果你还没有账号请点击这里 注册