Qwen-7B Pretrained Embedding

Source: https://modelscope.cn/models/sccHyFuture/Qwen-7B-Embedding

Details

Qwen-7B-Embedding

1- Basic usage example

import torch
from torch import nn
from modelscope.hub.snapshot_download import snapshot_download
from modelscope import AutoTokenizer

# Download the extracted embedding weights from ModelScope
snapshot_download(model_id="sccHyFuture/Qwen-7B-Embedding", cache_dir="./")

# The tokenizer comes from the original Qwen-7B model
tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen-7B", trust_remote_code=True)
ipts = tokenizer('一只猫', return_tensors='pt')['input_ids']  # '一只猫' = 'a cat'

# Rebuild an nn.Embedding layer from the saved state dict
embed_dict = torch.load('./sccHyFuture/Qwen-7B-Embedding/qwen_7b_embed.pth')
vocab_size, embd_dim = embed_dict['weight'].size()
embed = nn.Embedding(vocab_size, embd_dim)
embed(ipts)  # output with randomly initialized weights (for comparison)
embed.load_state_dict(embed_dict)
embed(ipts)  # output with the pretrained Qwen-7B embedding weights
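To sanity-check the loaded weights, here is a minimal sketch that turns token embeddings into sentence vectors and compares them by cosine similarity. The mean-pooling choice, the sentence_vec helper, and the sample phrases are our own assumptions for illustration, not part of the original:

import torch
from torch import nn
from modelscope import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen-7B", trust_remote_code=True)
embed_dict = torch.load('./sccHyFuture/Qwen-7B-Embedding/qwen_7b_embed.pth')
embed = nn.Embedding(*embed_dict['weight'].size())
embed.load_state_dict(embed_dict)

def sentence_vec(text: str) -> torch.Tensor:
    # Mean-pool the token embeddings into a single fixed-size vector
    ids = tokenizer(text, return_tensors='pt')['input_ids']
    with torch.no_grad():
        return embed(ids).mean(dim=1).squeeze(0)

# Related phrases ('a cat' vs. 'a kitten') should score higher than unrelated ones
a, b = sentence_vec('一只猫'), sentence_vec('一只小猫')
print(torch.cosine_similarity(a, b, dim=0).item())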

2- chromadb Embedding replacement example

See chromadb_embd.py for details.

import os
from typing import cast

import torch
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from modelscope import AutoTokenizer
from modelscope.hub.snapshot_download import snapshot_download


class QwenEmbeddingFunction(EmbeddingFunction):
    def __init__(self,
            model_name: str = "qwen-7B",
            device: str = "cpu",
            normalize_embeddings: bool = False,
            max_length: int = 128
        ):
        self.model_name = model_name
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        self.max_length = max_length
        # pad_token must be set explicitly; Qwen's tokenizer defines none by default
        self.tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen-7B", trust_remote_code=True, pad_token='<|endoftext|>')
        self.emb = self._load_embd().to(self.device)

    @torch.no_grad()
    def __call__(self, input: Documents) -> Embeddings:
        # Pad to max_length (not just the batch maximum) so every document
        # yields a vector of the same dimensionality
        tk = self.tokenizer(input, return_tensors='pt', padding='max_length', truncation=True, max_length=self.max_length)['input_ids'].to(self.device)
        emb_out = self.emb(tk)
        if self.normalize_embeddings:
            # L2-normalize each token embedding (dim=-1 is the embedding axis)
            emb_out = torch.nn.functional.normalize(emb_out, p=2, dim=-1)
        # Flatten the (max_length, embd_dim) token embeddings into one vector per document
        return cast(
            Embeddings,
            [emb.cpu().numpy().flatten().tolist() for emb in emb_out]
        )

    def _load_embd(self):
        # Download the weights from ModelScope on first use, then load locally
        sep = os.path.sep
        embed_name_id = 'sccHyFuture/Qwen-7B-Embedding'
        local_embed_file = f'.{sep}{embed_name_id.replace("/", sep)}{sep}qwen_7b_embed.pth'
        if not os.path.exists(local_embed_file):
            snapshot_download(model_id=embed_name_id, cache_dir="./")

        embed_dict = torch.load(local_embed_file, map_location=self.device)
        vocab_size, embd_dim = embed_dict['weight'].size()
        embed = torch.nn.Embedding(vocab_size, embd_dim)
        embed.load_state_dict(embed_dict)
        return embed
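As a usage sketch, the function can be dropped into a chromadb collection. The collection name, documents, and ids below are made up for illustration, and this assumes a chromadb release whose create_collection accepts an embedding_function argument:

import chromadb

client = chromadb.Client()  # in-memory client
collection = client.create_collection(
    name="qwen_demo",  # hypothetical collection name
    embedding_function=QwenEmbeddingFunction(),
)
collection.add(documents=["一只猫", "一只狗"], ids=["1", "2"])
print(collection.query(query_texts=["小猫"], n_results=1))

Note that flattening produces vectors of max_length × embd_dim dimensions, so storage and query cost grow quickly with max_length.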