Into your bookmarks it goes: a comprehensive guide to speeding up large language models
Recently, a blogger named Theia Vogel put together a long post that comprehensively surveys methods for accelerating LLM inference, walking through each technique in detail. It is well worth bookmarking and revisiting for LLM researchers.
The original blog post follows.
A while back, I made a transformer by hand, using the classic autoregressive sampler, roughly like this:
def generate(prompt: str, tokens_to_generate: int) -> str:
    tokens = tokenize(prompt)
    for i in range(tokens_to_generate):
        next_token = model(tokens)
        tokens.append(next_token)
    return detokenize(tokens)
This approach to inference is elegant and gets at the core of how LLMs work. Autoregressive decoding like this runs fine when the model has only a few thousand parameters, but for real models it is far too slow. Why is that, and how can we make it faster?

Why is naive inference so slow?

There are several metrics we might care about:
Time to First Token (TTFT): how long does it take between receiving the prompt and returning the first token?
Generation latency: how long does it take between receiving the prompt and returning the final token?
Throughput
Hardware utilization: how efficiently are we using the hardware's compute, memory bandwidth, and other capabilities?
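As a rough illustration (not from the original post), here is a minimal sketch of measuring the first three of these around a streaming generator; `generate_stream` is a hypothetical helper that yields tokens one at a time:

import time

def measure(prompt: str, tokens_to_generate: int) -> dict:
    # Wraps a hypothetical `generate_stream` generator and records time to
    # first token, total generation latency, and throughput for one request.
    start = time.perf_counter()
    ttft = None
    n_tokens = 0
    for _ in generate_stream(prompt, tokens_to_generate):
        if ttft is None:
            ttft = time.perf_counter() - start  # time to first token
        n_tokens += 1
    total = time.perf_counter() - start         # generation latency
    return {"ttft_s": ttft, "latency_s": total, "throughput_tok_per_s": n_tokens / total}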
Hardware
def foo(x):
    s = torch.sin(x)
    c = torch.cos(x)
    return s + c
"trace.enabled": True, "trace.graph_diagram": True}) > compiled_foo = torch.compile(foo, options={
# call with an arbitrary value to trigger JIT >
10))) > compiled_foo(torch.tensor(range(
Writing FX graph to file: .../graph_diagram.svg
[2023-11-25 17:31:09,833] [6/0] torch._inductor.debug: [WARNING] model__24_inference_60 debug trace: /tmp/...zfa7e2jl.debug
tensor([ 1.0000, 1.3818, 0.4932, -0.8489, -1.4104, -0.6753, 0.6808, 1.4109,
0.8439, -0.4990])
extern "C" void kernel(const long* in_ptr0,
float* out_ptr0)
{
{
for(long i0=static_cast<long>(0L); i0<static_cast<long>(10L); i0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i0)];
auto tmp1 = static_cast<float>(tmp0);
auto tmp2 = std::sin(tmp1);
auto tmp3 = std::cos(tmp1);
auto tmp4 = tmp2 + tmp3;
out_ptr0[static_cast<long>(i0)] = tmp4;
}
}
}
> x = torch.rand((10_000, 10_000))
> %timeit foo(x)
246 ms ± 8.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit compiled_foo(x)
91.3 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# (for small inputs `compiled_foo` was actually slower--not sure why)
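The `gbreak` function being compiled next is not shown in this excerpt. Judging from the FX graphs and outputs printed below (sin + cos, a sum compared against 0, then a tan-based subtraction), it is roughly the following sketch; the branch the example input doesn't exercise is an assumption:

# Reconstructed sketch of `gbreak` (not shown in this excerpt), based on the
# FX graphs and outputs below: a data-dependent `if` forces torch.compile to
# split the function into two graphs.
def gbreak(x):
    r = torch.sin(x) + torch.cos(x)
    if r.sum() < 0:
        return r + torch.tan(x)  # assumed branch; not taken for the example input
    else:
        return r - torch.tan(x)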
"trace.enabled": True, "trace.graph_diagram": True}) compiled_gbreak = torch.compile(gbreak, options={
10))) compiled_gbreak(torch.tensor(range(
Writing FX graph to file: .../model__27_inference_63.9/graph_diagram.svg
[2023-11-25 17:59:32,823] [9/0] torch._inductor.debug: [WARNING] model__27_inference_63 debug trace: /tmp/torchinductor_user/p3/cp3the7mcowef7zjn7p5rugyrjdm6bhi36hf5fl4nqhqpfdqaczp.debug
Writing FX graph to file: .../graph_diagram.svg
[2023-11-25 17:59:34,815] [10/0] torch._inductor.debug: [WARNING] model__28_inference_64 debug trace: /tmp/torchinductor_user/nk/cnkikooz2z5sms2emkvwj5sml5ik67aqigynt7mp72k3muuvodlu.debug
tensor([ 1.0000, -0.1756, 2.6782, -0.7063, -2.5683, 2.7053, 0.9718, 0.5394,
7.6436, -0.0467])
extern "C" void kernel(const long* in_ptr0,
float* out_ptr0,
float* out_ptr1,
bool* out_ptr2)
{
{
{
float tmp_acc0 = 0;
for(long i0=static_cast<long>(0L); i0<static_cast<long>(10L); i0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i0)];
auto tmp1 = static_cast<float>(tmp0);
auto tmp2 = std::sin(tmp1);
auto tmp3 = std::cos(tmp1);
auto tmp4 = tmp2 + tmp3;
out_ptr0[static_cast<long>(i0)] = tmp4;
tmp_acc0 = tmp_acc0 + tmp4;
}
out_ptr1[static_cast<long>(0L)] = tmp_acc0;
}
}
{
auto tmp0 = out_ptr1[static_cast<long>(0L)];
auto tmp1 = static_cast<float>(0.0);
auto tmp2 = tmp0 < tmp1;
out_ptr2[static_cast<long>(0L)] = tmp2;
}
}
extern "C" void kernel(const float* in_ptr0,
const long* in_ptr1,
float* out_ptr0)
{
{
for(long i0=static_cast<long>(0L); i0<static_cast<long>(10L); i0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(i0)];
auto tmp1 = in_ptr1[static_cast<long>(i0)];
auto tmp2 = static_cast<float>(tmp1);
auto tmp3 = std::cos(tmp2);
auto tmp4 = tmp0 - tmp3;
out_ptr0[static_cast<long>(i0)] = tmp4;
}
}
}
# get an explanation for a given input
>>> explained = torch._dynamo.explain(gbreak)(torch.tensor(range(10)))
# there's a break, because of a jump (if) on line 3
>>> explained.break_reasons
[GraphCompileReason(reason='generic_jump TensorVariable()', user_stack=[<FrameSummary file <stdin>, line 3 in gbreak>], graph_break=True)]
# there are two graphs, since there's a break
>>> explained.graphs
[GraphModule(), GraphModule()]
# let's see what each graph implements, without needing to dive into the kernels!
>>> for g in explained.graphs:
... g.graph.print_tabular()
... print()
...
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ------------ --------
placeholder l_x_ L_x_ () {}
call_function sin <built-in method sin of type object at 0x7fd57167aaa0> (l_x_,) {}
call_function cos <built-in method cos of type object at 0x7fd57167aaa0> (l_x_,) {}
call_function add <built-in function add> (sin, cos) {}
call_method sum_1 sum (add,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((add, lt),) {}
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ----------- --------
placeholder l_x_ L_x_ () {}
placeholder l_r_ L_r_ () {}
call_function tan <built-in method tan of type object at 0x7fd57167aaa0> (l_x_,) {}
call_function sub <built-in function sub> (l_r_, tan) {}
output output output ((sub,),) {}
# pretty cool!
Batching
20 tokens x 1 sequence = ~70ms
20 tokens x 5 sequences = ~220ms (~350ms if scaling were linear)
20 tokens x 10 sequences = ~400ms (~700ms if scaling were linear)
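For reference, a rough sketch (not from the original post, assuming the Hugging Face `gpt2` model and `tokenizer` used later in this article) of how numbers like these can be measured, by generating for a whole batch of sequences in one call:

import time
import torch

# Illustrative only: time 20-token greedy generation at a few batch sizes.
tokenizer.padding_side = "left"            # decoder-only models should pad on the left
tokenizer.pad_token = tokenizer.eos_token

for batch_size in (1, 5, 10):
    prompts = ["The quick brown fox"] * batch_size
    batch = tokenizer(prompts, return_tensors="pt", padding=True)
    start = time.perf_counter()
    with torch.no_grad():
        gpt2.generate(**batch, max_new_tokens=20, do_sample=False,
                      pad_token_id=tokenizer.eos_token_id)
    print(f"batch of {batch_size}: {time.perf_counter() - start:.3f}s")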
>>> gpt2.transformer.h[0].attn.c_attn.weight.dtype
torch.float32
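A quick back-of-the-envelope sketch (not from the original post) of why that dtype matters for a memory-bandwidth-bound workload: at float32 every parameter is 4 bytes, so each decoding step that touches all of GPT-2's weights moves roughly half a gigabyte.

# Rough arithmetic: bytes of weights read per forward pass, fp32 vs fp16.
n_params = sum(p.numel() for p in gpt2.parameters())
print(f"{n_params / 1e6:.0f}M parameters")
print(f"~{n_params * 4 / 1e9:.2f} GB of weights at float32")
print(f"~{n_params * 2 / 1e9:.2f} GB of weights at float16")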
KV cache
# the gpt2 tokenizer produces 3 tokens for this string
" A B C").input_ids > tokens = tokenizer(
> tokens
[317, 347, 327]
# if we put that into the model, we get 3 rows of logits
> logits = gpt2(input_ids=torch.tensor(tokens)).logits.squeeze()
> logits.shape
torch.Size([3, 50257])
# and if we argmax those, we see the model is predicting a next token
# for _every_ prompt token!
> for i, y in enumerate(logits.argmax(-1)):
...     print(f"{tokenizer.decode(tokens[:i+1])!r} -> {tokenizer.decode(y)!r}")
' A' -> '.'
' A B' -> ' C'
' A B C' -> ' D'
tokens
[317, 347, 327] # the " A B C" string from before
key_values = gpt2(input_ids=torch.tensor(tokens)).past_key_values
tuple(tuple(x.shape for x in t) for t in key_values)
((torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])),
(torch.Size([1, 12, 3, 64]), torch.Size([1, 12, 3, 64])))
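A small sketch (assuming the same `gpt2`, `tokens`, and `key_values` as above) of what the cache buys us: to extend the sequence, we feed only the newly generated token together with `past_key_values`, instead of re-running the whole prompt through the model.

import torch

# The model's next-token prediction after " A B C" (the ' D' from earlier).
next_token = gpt2(input_ids=torch.tensor(tokens)).logits[0, -1].argmax()

# Feed just that one token, plus the cached keys/values for the 3 prompt tokens.
out = gpt2(
    input_ids=next_token.reshape(1, 1),  # a single new token
    past_key_values=key_values,          # KV cache for " A B C"
    use_cache=True,
)
print(out.logits.shape)  # torch.Size([1, 1, 50257]) -- one new row of logits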
Naively pre-allocating a full-length, contiguous KV cache for every request has a few problems:
we have to pre-allocate more space than we may end up needing;
that reserved space cannot be used by other requests, even while it is still unused;
and requests that share a prefix cannot share the KV cache for that prefix.
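One way around all three problems (the idea behind paged-attention-style systems; the sketch below is a toy illustration, not from the original post) is to carve the cache into fixed-size blocks and give each sequence a block table pointing into a shared pool, so memory is claimed on demand and identical prefixes can point at the same physical blocks.

# Toy sketch of block-based (paged) KV cache bookkeeping; names and sizes are illustrative.
BLOCK_SIZE = 16  # tokens per physical block

class BlockPool:
    def __init__(self, n_blocks: int):
        self.free = list(range(n_blocks))

    def alloc(self) -> int:
        return self.free.pop()        # hand out any free physical block

    def release(self, block: int):
        self.free.append(block)

class Sequence:
    def __init__(self, pool: BlockPool):
        self.pool = pool
        self.block_table: list[int] = []  # logical block index -> physical block
        self.n_tokens = 0

    def append_token(self):
        # claim a new block only when the current one is full (or on the first token)
        if self.n_tokens % BLOCK_SIZE == 0:
            self.block_table.append(self.pool.alloc())
        self.n_tokens += 1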
Speculative decoding
>>> for i, y in enumerate(logits.argmax(-1)):
... print(f"{tokenizer.decode(tokens[:i+1])!r} -> {tokenizer.decode(y)!r}")
' A' -> '.'
' A B' -> ' C'
' A B C' -> ' D'
def generate(prompt: str, tokens_to_generate: int) -> str:
    tokens: list[int] = tokenize(prompt)
    GOING, TO = tokenize(" going to")
    for i in range(tokens_to_generate):
        if tokens[-1] == GOING:
            # do our speculative decoding trick
            logits = model.forward(tokens + [TO])
            # the token the model predicts will follow "... going"
            going_pred = argmax(logits[-2, :])
            # the token the model predicts will follow "... going to"
            to_pred = argmax(logits[-1, :])
            if going_pred == TO:
                # if our guess was correct, accept "to" and the next token after
                tokens += [TO, to_pred]
            else:
                # otherwise, accept the real next token
                # (e.g. "for" if the true generation was "going for broke")
                tokens += [going_pred]
        else:
            # do normal single-token generation
            logits = model.forward(tokens)
            tokens += [argmax(logits[-1])]
    return detokenize(tokens)
def generate(prompt: str, tokens_to_generate: int, n_draft: int = 8) -> str:
    tokens: list[int] = tokenize(prompt)
    for i in range(tokens_to_generate):
        # generate `n_draft` draft tokens in the usual autoregressive way
        draft = tokens[:]
        for _ in range(n_draft):
            logits = draft_model.forward(draft)
            draft.append(argmax(logits[-1]))

        # run the draft tokens through the oracle model all at once
        logits = model.forward(draft)
        checked = logits[len(tokens) - 1 :].argmax(-1)

        # find the index of the first draft/oracle mismatch—we'll accept every
        # token before it
        # (the index might be past the end of the draft, if every draft token
        # was correct)
        n_accepted = next(
            idx + 1
            for idx, (checked, draft) in enumerate(
                # we add None here because the oracle model generates one extra
                # token (the prediction for the last draft token)
                zip(checked, draft[len(tokens) :] + [None])
            )
            if checked != draft
        )
        tokens.extend(checked[:n_accepted])
    return detokenize(tokens)
def speculative_threshold(
    prompt: str,
    max_draft: int = 16,
    threshold: float = 0.4,
    threshold_all_correct_boost: float = 0.1,
):
    tokens = encoder.encode(prompt)

    # homegrown KV cache setup has an `n_tokens` method that returns the length
    # of the cached sequence, and a `truncate` method to truncate that sequence
    # to a specific token
    model_kv = gpt2.KVCache()
    draft_kv = gpt2.KVCache()

    while True:
        # generate up to `max_draft` draft tokens autoregressively, stopping
        # early if we fall below `threshold`
        draft = tokens[:]
        drafted_probs = []
        for _ in range(max_draft):
            logits = draft_model.forward(draft[draft_kv.n_tokens() :], draft_kv)
            next_id = np.argmax(logits[-1])
            next_prob = gpt2.softmax(logits[-1])[next_id]
            if not len(drafted_probs):
                drafted_probs.append(next_prob)
            else:
                drafted_probs.append(next_prob * drafted_probs[-1])
            draft.append(int(next_id))
            if drafted_probs[-1] < threshold:
                break
        n_draft = len(draft) - len(tokens)

        # run draft tokens through the oracle model
        logits = model.forward(draft[model_kv.n_tokens() :], model_kv)
        checked = logits[-n_draft - 1 :].argmax(-1)
        n_accepted = next(
            idx + 1
            for idx, (checked, draft) in enumerate(
                zip(checked, draft[len(tokens) :] + [None])
            )
            if checked != draft
        )
        yield from checked[:n_accepted]
        tokens.extend(checked[:n_accepted])

        if n_accepted <= n_draft:
            # adjust threshold towards prob of last accepted token, if we
            # ignored any draft tokens
            threshold = (threshold + drafted_probs[n_accepted - 1]) / 2
        else:
            # otherwise, lower the threshold slightly, we're probably being
            # too conservative
            threshold -= threshold_all_correct_boost

        # clamp to avoid pathological thresholds
        threshold = min(max(threshold, 0.05), 0.95)

        # don't include oracle token in kv cache
        model_kv.truncate(len(tokens) - 1)
        draft_kv.truncate(len(tokens) - 1)