import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from threading import Thread
from typing import Dict
CHECKPOINT = "tiiuae/falcon-7b-instruct"


class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        # bfloat16 halves the memory footprint relative to float32, and
        # device_map="auto" lets accelerate place weights on available GPUs.
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

    def predict(self, request: Dict):
        prompt = request["prompt"]
        # No padding: Falcon's tokenizer has no pad token, so padding=True
        # on a single prompt would raise an error.
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        # device_map="auto" decides placement, so follow the model's device
        # rather than hardcoding "cuda".
        input_ids = inputs["input_ids"].to(self.model.device)
        # skip_prompt=True streams only newly generated text, not a prompt echo.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
        # do_sample=True is required for temperature/top_p/top_k to take effect.
        generation_config = GenerationConfig(do_sample=True, temperature=1.0, top_p=0.95, top_k=40)
        # generate() blocks until completion, so run it on a background thread
        # and let the streamer hand tokens back as they are produced. The
        # return value is discarded, so return_dict_in_generate/output_scores
        # would only waste memory here.
        def generate():
            self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=150,
                streamer=streamer,
            )

        thread = Thread(target=generate)
        thread.start()

        def stream_output():
            for text in streamer:
                yield text
            # Reclaim the worker thread once the streamer is exhausted.
            thread.join()

        return stream_output()
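

# A minimal local smoke test, as a sketch: it assumes this file is executed
# directly on a machine with a GPU large enough for Falcon-7B, rather than
# through a model server that calls load() and predict() for you. The prompt
# string is an arbitrary example.
if __name__ == "__main__":
    model = Model()
    model.load()
    # predict() returns a generator; iterate it to print tokens as they stream.
    for chunk in model.predict({"prompt": "Write a haiku about GPUs."}):
        print(chunk, end="", flush=True)
    print()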