Slide 26
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
prompt_base = "ユーザー: {}<NL>システム: "  # rinna's instruction-sft prompt format; "<NL>" separates turns
start = time.perf_counter()
tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", use_fast=False)
end = time.perf_counter()
print("Tokenizer loaded:"+str(end-start))
start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft")
# With 12-16 GB of GPU memory, load in float16 to just barely fit the model
#model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", torch_dtype=torch.float16)
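# (Not on the slide) A further memory-saving sketch: assuming the optional
# bitsandbytes and accelerate packages are installed, 8-bit quantized loading
# roughly halves memory again relative to float16 (skip the model.to("cuda")
# call below in that case, since device_map places the weights itself):
#model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-sft", load_in_8bit=True, device_map="auto")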
end = time.perf_counter()
print("CausalLM loaded:"+str(end-start))
if torch.cuda.is_available():
    model = model.to("cuda")
    print("CUDA is available")
# Tokenize the prompt, sample a reply, and return the decoded text
def encoding(prompt):
    start = time.perf_counter()
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            do_sample=True,
            max_new_tokens=256,
            temperature=0.9,
            top_k=50,
            repetition_penalty=1.0,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    # Drop the prompt tokens from the output, then restore newlines from "<NL>"
    output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
    output = output.replace("<NL>", "\n")
    end = time.perf_counter()
    print(f"Encoding completed: {end - start:.2f}s")
    return output
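# (Not on the slide) To stream tokens as they are generated instead of waiting
# for the whole reply, transformers' TextStreamer can be passed to generate();
# a sketch assuming the same tokenizer as above:
#from transformers import TextStreamer
#streamer = TextStreamer(tokenizer, skip_prompt=True)
# ...then add streamer=streamer to the model.generate() call above.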
# (continued)
def do_conversation():
    # Read one user utterance; typing "end" terminates the loop
    text = input("Neox-3.6b>")
    if text == "end":
        return False
    prompt = prompt_base.format(text)
    result = encoding(prompt)
    print(result)
    return True
while True:
    res = do_conversation()
    if not res:
        break
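# (Not on the slide) prompt_base carries only a single turn; the model card's
# format joins a whole dialogue history with "<NL>". A multi-turn sketch,
# assuming a hypothetical `history` list of (speaker, text) pairs such as
# [("ユーザー", "..."), ("システム", "...")]:
#def build_prompt(history):
#    turns = [f"{s}: {t}" for s, t in history]
#    return "<NL>".join(turns) + "<NL>" + "システム: "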