【求助帖】 点击停止按钮,生成并没有结束 #264

Open
opened 2024-10-22 10:29:06 +08:00 by rebibabo · 2 comments

在上周三的Chinese LLaMA Alpaca系列模型OpenAI API调用实现课程中,GPU释放问题部分,我原封不动的粘贴了老师的代码,并且在try代码中,打印了chunk_str,当我在NextChat中点击停止的时候,并没有触发except asyncio.CancelledError as e,然后停止,而是一直打印chunk_str直到最后一个字输出,NextChat点击暂停触发的是什么事件呢?这个except asyncio.CancelledError as e应该改成什么呢
image

在上周三的Chinese LLaMA Alpaca系列模型OpenAI API调用实现课程中,GPU释放问题部分,我原封不动的粘贴了老师的代码,并且在try代码中,打印了chunk_str,当我在NextChat中点击停止的时候,并没有触发except asyncio.CancelledError as e,然后停止,而是一直打印chunk_str直到最后一个字输出,NextChat点击暂停触发的是什么事件呢?这个except asyncio.CancelledError as e应该改成什么呢 ![image](/attachments/54eb5b1f-1ff2-4f88-b57e-d1ac4711f54d)
Owner

可以将整个代码文件附上来么?便于排查

可以将整个代码文件附上来么?便于排查
Author

好的,下面是代码

import argparse
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from threading import Thread
from sse_starlette.sse import EventSourceResponse
import asyncio
from transformers import StoppingCriteria, StoppingCriteriaList

parser = argparse.ArgumentParser()
parser.add_argument('--port', default=19327, type=int)
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--lora_model', default=None, type=str,help="If None, perform inference on the base model")
parser.add_argument('--tokenizer_path',default=None,type=str)
parser.add_argument('--gpus', default="0", type=str)
parser.add_argument('--load_in_8bit',action='store_true', help='Load the model in 8bit mode')
parser.add_argument('--load_in_4bit',action='store_true', help='Load the model in 4bit mode')
parser.add_argument('--only_cpu',action='store_true',help='Only use CPU for inference')
parser.add_argument('--use_flash_attention_2', action='store_true', help="Use flash-attention2 to accelerate inference")
args = parser.parse_args()
if args.only_cpu is True:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("Quantization is unavailable on CPU.")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments")
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
    BitsAndBytesConfig
)
from peft import PeftModel

import sys

parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)

from openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ChatCompletionResponseStreamChoice,
    DeltaMessage,
)

class StopOnEvent(StoppingCriteria):
    def __init__(self) -> None:
        self.is_stop = False
        super().__init__()
        
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.is_stop:
            return True
        return False

load_type = torch.float16

# Move the model to the MPS device if available
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
print(f"Using device: {device}")

if args.tokenizer_path is None:
    args.tokenizer_path = args.lora_model
    if args.lora_model is None:
        args.tokenizer_path = args.base_model
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
if args.load_in_4bit or args.load_in_8bit:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        bnb_4bit_compute_dtype=load_type,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model,
    torch_dtype=load_type,
    low_cpu_mem_usage=True,
    device_map='auto' if not args.only_cpu else None,
    #load_in_4bit=args.load_in_4bit,
    #load_in_8bit=args.load_in_8bit,
    quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
    attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa",
    trust_remote_code=True
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
tokenizer_vocab_size = len(tokenizer)
print(f"Vocab of the base model: {model_vocab_size}")
print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
if model_vocab_size != tokenizer_vocab_size:
    print("Resize model embeddings to fit tokenizer")
    base_model.resize_token_embeddings(tokenizer_vocab_size)
if args.lora_model is not None:
    print("loading peft model")
    model = PeftModel.from_pretrained(
        base_model,
        args.lora_model,
        torch_dtype=load_type,
        device_map="auto",
    )
else:
    model = base_model

if device == torch.device("cpu"):
    model.float()

model.eval()

DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。"

TEMPLATE_WITH_SYSTEM_PROMPT = (
    """<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""
)
TEMPLATE_WITHOUT_SYSTEM_PROMPT = """<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"""

def generate_prompt(
    message, response="", with_system_prompt=False, system_prompt=None
):
    if with_system_prompt is True:
        if system_prompt is None:
            system_prompt = DEFAULT_SYSTEM_PROMPT
        prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map(
            {"message": message, "system_prompt": system_prompt}
        )
    else:
        prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({"message": message})
    if len(response) > 0:
        prompt += " " + response
    return prompt


def generate_completion_prompt(message: str):
    """Generate prompt for completion"""
    return generate_prompt(message, response="", with_system_prompt=False)


def generate_chat_prompt(messages: list):
    """Generate prompt for chat completion"""

    system_msg = None
    for msg in messages:
        if msg.role == "system":
            system_msg = msg.content
    prompt = ""
    is_first_user_content = True
    for msg in messages:
        if msg.role == "system":
            continue
        if msg.role == "user":
            if is_first_user_content is True:
                prompt += generate_prompt(
                    msg.content, with_system_prompt=True, system_prompt=system_msg
                )
                is_first_user_content = False
            else:
                prompt += generate_prompt(msg.content, with_system_prompt=False)
        if msg.role == "assistant":
            prompt += f"{msg.content}" + "<|eot_id|>"
    return prompt


def predict(
    input,
    max_new_tokens=1024,
    top_p=0.9,
    temperature=0.2,
    top_k=40,
    num_beams=1,
    repetition_penalty=1.1,
    do_sample=True,
    **kwargs,
):
    """
    Main inference method
    type(input) == str -> /v1/completions
    type(input) == list -> /v1/chat/completions
    """
    if isinstance(input, str):
        prompt = generate_completion_prompt(input)
    else:
        prompt = generate_chat_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
        **kwargs,
    )
    generation_config.return_dict_in_generate = True
    generation_config.output_scores = False
    generation_config.max_new_tokens = max_new_tokens
    generation_config.repetition_penalty = float(repetition_penalty)

    # c.f. llama-3-instruct generation_config
    llama3_eos_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    # For the reason why pad_token_id = eos_token_id, see:
    # https://github.com/meta-llama/llama-recipes/blob/f7aa02af9f2c427ebb70853191b72636130b9df5/src/llama_recipes/finetuning.py#L141
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            eos_token_id=llama3_eos_ids,
            pad_token_id=tokenizer.eos_token_id,
            generation_config=generation_config,
        )
    s = generation_output.sequences[0]
    output = tokenizer.decode(s, skip_special_tokens=True)
    #output = output.split("<|eot_id|>")[-1].strip()
    output = output.split("assistant\n\n")[-1].strip()
    return output


def stream_predict(
    input,
    max_new_tokens=1024,
    top_p=0.9,
    temperature=0.2,
    top_k=40,
    num_beams=4,
    repetition_penalty=1.1,
    do_sample=True,
    model_id="llama-3-chinese",
    **kwargs,
):
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
    )
    chunk = ChatCompletionResponse(
        model=model_id,
        choices=[choice_data],
        object="chat.completion.chunk",
    )
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    if isinstance(input, str):
        prompt = generate_completion_prompt(input)
    else:
        prompt = generate_chat_prompt(input)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
        **kwargs,
    )

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        streamer=streamer,
        input_ids=input_ids,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=False,
        max_new_tokens=max_new_tokens,
        repetition_penalty=float(repetition_penalty),
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=[tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for new_text in streamer:
        choice_data = ChatCompletionResponseStreamChoice(
            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
        )
        chunk = ChatCompletionResponse(
            model=model_id, choices=[choice_data], object="chat.completion.chunk"
        )
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    choice_data = ChatCompletionResponseStreamChoice(
        index=0, delta=DeltaMessage(), finish_reason="stop"
    )
    chunk = ChatCompletionResponse(
        model=model_id, choices=[choice_data], object="chat.completion.chunk"
    )
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    yield "[DONE]"


def get_embedding(input):
    """Get embedding main function"""
    with torch.no_grad():
        encoding = tokenizer(input, padding=True, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)
        model_output = model(input_ids, attention_mask, output_hidden_states=True)
        data = model_output.hidden_states[-1]
        mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
        masked_embeddings = data * mask
        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        seq_length = torch.sum(mask, dim=1)
        embedding = sum_embeddings / seq_length
        normalized_embeddings = F.normalize(embedding, p=2, dim=1)
        ret = normalized_embeddings.squeeze(0).tolist()
    return ret


app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """Creates a completion for the chat message"""
    async def event_publisher(
        input,
        max_new_tokens=4096,
        top_p=0.75,
        temperature=0.1,
        top_k=40,
        num_beams=4,
        repetition_penalty=1.0,
        do_sample=True,
        model_id="chinese-llama-alpaca-2",
        **kwargs,
    ):
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role="assistant"),
            finish_reason=None
        )
        chunk = ChatCompletionResponse(
            model=model_id,
            choices=[choice_data],
            object="chat.completion.chunk",
        )
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
        
        if isinstance(input, str):
            prompt = generate_completion_prompt(input)
        else:
            prompt = generate_chat_prompt(input)

        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            do_sample=do_sample,
            **kwargs,
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        stop = StopOnEvent()
        generation_kwargs = dict(
            streamer=streamer,
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=False,
            max_new_tokens=max_new_tokens,
            repetition_penalty=float(repetition_penalty),
            stopping_criteria=StoppingCriteriaList([stop]),
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
        )
        i = 0
        Thread(target=model.generate, kwargs=generation_kwargs).start()
        try:
            for new_text in streamer:
                choice_data = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(content=new_text),
                    finish_reason=None
                )
                chunk = ChatCompletionResponse(
                    model=model_id,
                    choices=[choice_data],
                    object="chat.completion.chunk"
                )
                chunk_str = "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
                print(chunk_str)
                i += 1
                yield chunk_str
                await asyncio.sleep(0.03)
        except asyncio.CancelledError as e:
            print(f"Disconnected from client (via refresh/close)")
            stop.is_stop = True
            raise e
    
    msgs = request.messages
    if isinstance(msgs, str):
        msgs = [ChatMessage(role="user", content=msgs)]
    else:
        msgs = [ChatMessage(role=x["role"], content=x["content"]) for x in msgs]
    if request.stream:
        # generate = stream_predict(
        generate = event_publisher(
            input=msgs,
            max_new_tokens=request.max_tokens,
            top_p=request.top_p,
            top_k=request.top_k,
            temperature=request.temperature,
            num_beams=request.num_beams,
            repetition_penalty=request.repetition_penalty,
            do_sample=request.do_sample,
        )
        return EventSourceResponse(generate, media_type="text/event-stream")
    output = predict(
        input=msgs,
        max_new_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        num_beams=request.num_beams,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )
    choices = [
        ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs)
    ]
    choices += [
        ChatCompletionResponseChoice(
            index=len(choices), message=ChatMessage(role="assistant", content=output)
        )
    ]
    return ChatCompletionResponse(choices=choices)


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
    """Creates a completion"""
    output = predict(
        input=request.prompt,
        max_new_tokens=request.max_tokens,
        top_p=request.top_p,
        top_k=request.top_k,
        temperature=request.temperature,
        num_beams=request.num_beams,
        repetition_penalty=request.repetition_penalty,
        do_sample=request.do_sample,
    )
    choices = [CompletionResponseChoice(index=0, text=output)]
    return CompletionResponse(choices=choices)


@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingsRequest):
    """Creates text embedding"""
    embedding = get_embedding(request.input)
    data = [{"object": "embedding", "embedding": embedding, "index": 0}]
    return EmbeddingsResponse(data=data)


if __name__ == "__main__":
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config["formatters"]["access"][
        "fmt"
    ] = "%(asctime)s - %(levelname)s - %(message)s"
    log_config["formatters"]["default"][
        "fmt"
    ] = "%(asctime)s - %(levelname)s - %(message)s"
    uvicorn.run(app, host="0.0.0.0", port=args.port, workers=1, log_config=log_config)

好的,下面是代码 ```python import argparse import os from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware import uvicorn from threading import Thread from sse_starlette.sse import EventSourceResponse import asyncio from transformers import StoppingCriteria, StoppingCriteriaList parser = argparse.ArgumentParser() parser.add_argument('--port', default=19327, type=int) parser.add_argument('--base_model', default=None, type=str, required=True) parser.add_argument('--lora_model', default=None, type=str,help="If None, perform inference on the base model") parser.add_argument('--tokenizer_path',default=None,type=str) parser.add_argument('--gpus', default="0", type=str) parser.add_argument('--load_in_8bit',action='store_true', help='Load the model in 8bit mode') parser.add_argument('--load_in_4bit',action='store_true', help='Load the model in 4bit mode') parser.add_argument('--only_cpu',action='store_true',help='Only use CPU for inference') parser.add_argument('--use_flash_attention_2', action='store_true', help="Use flash-attention2 to accelerate inference") args = parser.parse_args() if args.only_cpu is True: args.gpus = "" if args.load_in_8bit or args.load_in_4bit: raise ValueError("Quantization is unavailable on CPU.") if args.load_in_8bit and args.load_in_4bit: raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments") os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus import torch import torch.nn.functional as F from transformers import ( AutoModelForCausalLM, AutoTokenizer, GenerationConfig, TextIteratorStreamer, BitsAndBytesConfig ) from peft import PeftModel import sys parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(parent_dir) from openai_api_protocol import ( ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatCompletionResponseChoice, CompletionRequest, CompletionResponse, CompletionResponseChoice, EmbeddingsRequest, EmbeddingsResponse, ChatCompletionResponseStreamChoice, DeltaMessage, ) class StopOnEvent(StoppingCriteria): def __init__(self) -> None: self.is_stop = False super().__init__() def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: if self.is_stop: return True return False load_type = torch.float16 # Move the model to the MPS device if available if torch.backends.mps.is_available(): device = torch.device("mps") else: if torch.cuda.is_available(): device = torch.device(0) else: device = torch.device('cpu') print(f"Using device: {device}") if args.tokenizer_path is None: args.tokenizer_path = args.lora_model if args.lora_model is None: args.tokenizer_path = args.base_model tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, legacy=True) if args.load_in_4bit or args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_4bit=args.load_in_4bit, load_in_8bit=args.load_in_8bit, bnb_4bit_compute_dtype=load_type, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) base_model = AutoModelForCausalLM.from_pretrained( args.base_model, torch_dtype=load_type, low_cpu_mem_usage=True, device_map='auto' if not args.only_cpu else None, #load_in_4bit=args.load_in_4bit, #load_in_8bit=args.load_in_8bit, quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None, attn_implementation="flash_attention_2" if args.use_flash_attention_2 else "sdpa", trust_remote_code=True ) model_vocab_size = base_model.get_input_embeddings().weight.size(0) tokenizer_vocab_size = len(tokenizer) print(f"Vocab of the base model: {model_vocab_size}") print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") if model_vocab_size != tokenizer_vocab_size: print("Resize model embeddings to fit tokenizer") base_model.resize_token_embeddings(tokenizer_vocab_size) if args.lora_model is not None: print("loading peft model") model = PeftModel.from_pretrained( base_model, args.lora_model, torch_dtype=load_type, device_map="auto", ) else: model = base_model if device == torch.device("cpu"): model.float() model.eval() DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. 你是一个乐于助人的助手。" TEMPLATE_WITH_SYSTEM_PROMPT = ( """<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""" ) TEMPLATE_WITHOUT_SYSTEM_PROMPT = """<|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""" def generate_prompt( message, response="", with_system_prompt=False, system_prompt=None ): if with_system_prompt is True: if system_prompt is None: system_prompt = DEFAULT_SYSTEM_PROMPT prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map( {"message": message, "system_prompt": system_prompt} ) else: prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({"message": message}) if len(response) > 0: prompt += " " + response return prompt def generate_completion_prompt(message: str): """Generate prompt for completion""" return generate_prompt(message, response="", with_system_prompt=False) def generate_chat_prompt(messages: list): """Generate prompt for chat completion""" system_msg = None for msg in messages: if msg.role == "system": system_msg = msg.content prompt = "" is_first_user_content = True for msg in messages: if msg.role == "system": continue if msg.role == "user": if is_first_user_content is True: prompt += generate_prompt( msg.content, with_system_prompt=True, system_prompt=system_msg ) is_first_user_content = False else: prompt += generate_prompt(msg.content, with_system_prompt=False) if msg.role == "assistant": prompt += f"{msg.content}" + "<|eot_id|>" return prompt def predict( input, max_new_tokens=1024, top_p=0.9, temperature=0.2, top_k=40, num_beams=1, repetition_penalty=1.1, do_sample=True, **kwargs, ): """ Main inference method type(input) == str -> /v1/completions type(input) == list -> /v1/chat/completions """ if isinstance(input, str): prompt = generate_completion_prompt(input) else: prompt = generate_chat_prompt(input) inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to(device) attention_mask = inputs['attention_mask'].to(device) generation_config = GenerationConfig( temperature=temperature, top_p=top_p, top_k=top_k, num_beams=num_beams, do_sample=do_sample, **kwargs, ) generation_config.return_dict_in_generate = True generation_config.output_scores = False generation_config.max_new_tokens = max_new_tokens generation_config.repetition_penalty = float(repetition_penalty) # c.f. llama-3-instruct generation_config llama3_eos_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")] # For the reason why pad_token_id = eos_token_id, see: # https://github.com/meta-llama/llama-recipes/blob/f7aa02af9f2c427ebb70853191b72636130b9df5/src/llama_recipes/finetuning.py#L141 with torch.no_grad(): generation_output = model.generate( input_ids=input_ids, attention_mask=attention_mask, eos_token_id=llama3_eos_ids, pad_token_id=tokenizer.eos_token_id, generation_config=generation_config, ) s = generation_output.sequences[0] output = tokenizer.decode(s, skip_special_tokens=True) #output = output.split("<|eot_id|>")[-1].strip() output = output.split("assistant\n\n")[-1].strip() return output def stream_predict( input, max_new_tokens=1024, top_p=0.9, temperature=0.2, top_k=40, num_beams=4, repetition_penalty=1.1, do_sample=True, model_id="llama-3-chinese", **kwargs, ): choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(role="assistant"), finish_reason=None ) chunk = ChatCompletionResponse( model=model_id, choices=[choice_data], object="chat.completion.chunk", ) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) if isinstance(input, str): prompt = generate_completion_prompt(input) else: prompt = generate_chat_prompt(input) inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to(device) generation_config = GenerationConfig( temperature=temperature, top_p=top_p, top_k=top_k, num_beams=num_beams, do_sample=do_sample, **kwargs, ) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) generation_kwargs = dict( streamer=streamer, input_ids=input_ids, generation_config=generation_config, return_dict_in_generate=True, output_scores=False, max_new_tokens=max_new_tokens, repetition_penalty=float(repetition_penalty), pad_token_id=tokenizer.eos_token_id, eos_token_id=[tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids("<|eot_id|>")] ) Thread(target=model.generate, kwargs=generation_kwargs).start() for new_text in streamer: choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(content=new_text), finish_reason=None ) chunk = ChatCompletionResponse( model=model_id, choices=[choice_data], object="chat.completion.chunk" ) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(), finish_reason="stop" ) chunk = ChatCompletionResponse( model=model_id, choices=[choice_data], object="chat.completion.chunk" ) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) yield "[DONE]" def get_embedding(input): """Get embedding main function""" with torch.no_grad(): encoding = tokenizer(input, padding=True, return_tensors="pt") input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) model_output = model(input_ids, attention_mask, output_hidden_states=True) data = model_output.hidden_states[-1] mask = attention_mask.unsqueeze(-1).expand(data.size()).float() masked_embeddings = data * mask sum_embeddings = torch.sum(masked_embeddings, dim=1) seq_length = torch.sum(mask, dim=1) embedding = sum_embeddings / seq_length normalized_embeddings = F.normalize(embedding, p=2, dim=1) ret = normalized_embeddings.squeeze(0).tolist() return ret app = FastAPI() app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest): """Creates a completion for the chat message""" async def event_publisher( input, max_new_tokens=4096, top_p=0.75, temperature=0.1, top_k=40, num_beams=4, repetition_penalty=1.0, do_sample=True, model_id="chinese-llama-alpaca-2", **kwargs, ): choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(role="assistant"), finish_reason=None ) chunk = ChatCompletionResponse( model=model_id, choices=[choice_data], object="chat.completion.chunk", ) yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) if isinstance(input, str): prompt = generate_completion_prompt(input) else: prompt = generate_chat_prompt(input) inputs = tokenizer(prompt, return_tensors="pt") input_ids = inputs["input_ids"].to(device) generation_config = GenerationConfig( temperature=temperature, top_p=top_p, top_k=top_k, num_beams=num_beams, do_sample=do_sample, **kwargs, ) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) stop = StopOnEvent() generation_kwargs = dict( streamer=streamer, input_ids=input_ids, generation_config=generation_config, return_dict_in_generate=True, output_scores=False, max_new_tokens=max_new_tokens, repetition_penalty=float(repetition_penalty), stopping_criteria=StoppingCriteriaList([stop]), pad_token_id=tokenizer.eos_token_id, eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")] ) i = 0 Thread(target=model.generate, kwargs=generation_kwargs).start() try: for new_text in streamer: choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(content=new_text), finish_reason=None ) chunk = ChatCompletionResponse( model=model_id, choices=[choice_data], object="chat.completion.chunk" ) chunk_str = "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False)) print(chunk_str) i += 1 yield chunk_str await asyncio.sleep(0.03) except asyncio.CancelledError as e: print(f"Disconnected from client (via refresh/close)") stop.is_stop = True raise e msgs = request.messages if isinstance(msgs, str): msgs = [ChatMessage(role="user", content=msgs)] else: msgs = [ChatMessage(role=x["role"], content=x["content"]) for x in msgs] if request.stream: # generate = stream_predict( generate = event_publisher( input=msgs, max_new_tokens=request.max_tokens, top_p=request.top_p, top_k=request.top_k, temperature=request.temperature, num_beams=request.num_beams, repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) return EventSourceResponse(generate, media_type="text/event-stream") output = predict( input=msgs, max_new_tokens=request.max_tokens, top_p=request.top_p, top_k=request.top_k, temperature=request.temperature, num_beams=request.num_beams, repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) choices = [ ChatCompletionResponseChoice(index=i, message=msg) for i, msg in enumerate(msgs) ] choices += [ ChatCompletionResponseChoice( index=len(choices), message=ChatMessage(role="assistant", content=output) ) ] return ChatCompletionResponse(choices=choices) @app.post("/v1/completions") async def create_completion(request: CompletionRequest): """Creates a completion""" output = predict( input=request.prompt, max_new_tokens=request.max_tokens, top_p=request.top_p, top_k=request.top_k, temperature=request.temperature, num_beams=request.num_beams, repetition_penalty=request.repetition_penalty, do_sample=request.do_sample, ) choices = [CompletionResponseChoice(index=0, text=output)] return CompletionResponse(choices=choices) @app.post("/v1/embeddings") async def create_embeddings(request: EmbeddingsRequest): """Creates text embedding""" embedding = get_embedding(request.input) data = [{"object": "embedding", "embedding": embedding, "index": 0}] return EmbeddingsResponse(data=data) if __name__ == "__main__": log_config = uvicorn.config.LOGGING_CONFIG log_config["formatters"]["access"][ "fmt" ] = "%(asctime)s - %(levelname)s - %(message)s" log_config["formatters"]["default"][ "fmt" ] = "%(asctime)s - %(levelname)s - %(message)s" uvicorn.run(app, host="0.0.0.0", port=args.port, workers=1, log_config=log_config) ```
Sign in to join this conversation.
No Milestone
No project
No Assignees
2 Participants
Notifications
Due Date
The due date is invalid or out of range. Please use the format 'yyyy-mm-dd'.

No due date set.

Dependencies

No dependencies set.

Reference: HswOAuth/llm_course#264
No description provided.