Speculative Decoding
paper
Inputs: $M_p$, $M_q$, prefix.
Sample $\gamma$ guesses $x_1, \dots, x_\gamma$ from $M_q$ autoregressively.
for $i = 1$ to $\gamma$ do
  $q_i(x) \leftarrow M_q(\text{prefix} + [x_1, \dots, x_{i-1}])$
  $x_i \sim q_i(x)$
end for
Run $M_p$ in parallel.
$p_1(x), \dots, p_{\gamma+1}(x) \leftarrow M_p(\text{prefix}), \dots, M_p(\text{prefix} + [x_1, \dots, x_\gamma])$
Determine the number of accepted guesses $n$.
$r_1 \sim U(0,1), \dots, r_\gamma \sim U(0,1)$
$n \leftarrow \min\left(\{\, i - 1 \mid 1 \le i \le \gamma,\ r_i > p_i(x_i)/q_i(x_i) \,\} \cup \{\gamma\}\right)$
Adjust the distribution from $M_p$ if needed.
$\tilde p(x) \leftarrow p_{n+1}(x)$
if $n < \gamma$ then
  $\tilde p(x) \leftarrow \mathrm{norm}(\max\{0,\ p_{n+1}(x) - q_{n+1}(x)\})$
end if
Return one token from $M_p$, and $n$ tokens from $M_q$.
$t \sim \tilde p(x)$
return $\text{prefix} + [x_1, \dots, x_n, t]$
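A minimal Python sketch of one step of the algorithm, assuming each model is a hypothetical function that maps a list of token ids to a NumPy probability vector over a small vocabulary. Indices are 0-based, so `ps[n]` plays the role of $p_{n+1}$. The target distributions are computed one prefix at a time for readability; a real implementation would score all $\gamma + 1$ prefixes in a single batched forward pass of $M_p$.

```python
import numpy as np

def speculative_step(target_probs, draft_probs, prefix, gamma, rng):
    """One speculative decoding step (sketch).

    target_probs / draft_probs are hypothetical stand-ins for M_p and M_q:
    each maps a list of token ids to a 1-D probability vector over the vocabulary.
    """
    # Sample gamma guesses x_1..x_gamma from the draft model autoregressively.
    guesses, qs = [], []
    for _ in range(gamma):
        q = draft_probs(prefix + guesses)
        qs.append(q)
        guesses.append(int(rng.choice(len(q), p=q)))

    # Target distributions p_1..p_{gamma+1}, one per prefix (batched in practice).
    ps = [target_probs(prefix + guesses[:i]) for i in range(gamma + 1)]

    # n = 0-based index of the first rejected guess, or gamma if none is rejected.
    r = rng.uniform(size=gamma)
    rejected = [i for i in range(gamma)
                if r[i] > ps[i][guesses[i]] / qs[i][guesses[i]]]
    n = min(rejected, default=gamma)

    # After a rejection, sample from norm(max(0, p_{n+1} - q_{n+1})) instead of p_{n+1}.
    p_adj = ps[n]
    if n < gamma:
        residual = np.maximum(0.0, ps[n] - qs[n])
        p_adj = residual / residual.sum()

    # Return the n accepted draft tokens plus one token from the (adjusted) target.
    t = int(rng.choice(len(p_adj), p=p_adj))
    return prefix + guesses[:n] + [t]


# Toy usage with context-independent distributions over 3 tokens (made-up numbers).
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.2, 0.6, 0.2])
rng = np.random.default_rng(0)
print(speculative_step(lambda s: P, lambda s: Q, [], gamma=4, rng=rng))
```

Drawing all $r_i$ up front and taking the first failing index matches the $\min(\cdot)$ in the pseudocode. When $M_q$ equals $M_p$ every ratio is 1, so all $\gamma$ guesses are accepted and the step emits $\gamma + 1$ tokens.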
We now show that for any distributions $p(x)$ and $q(x)$, tokens produced by speculative sampling from $p(x)$ and $q(x)$ are distributed identically to tokens sampled from $p(x)$ alone. Let $\beta$ be the acceptance probability (see the Definition below).
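The adjusted distribution can be rewritten as
$$
\begin{aligned}
\tilde p(x) &= \mathrm{norm}(\max\{0,\ p(x) - q(x)\}) \\
&= \frac{p(x) - \min\{q(x), p(x)\}}{\sum_{x'} \left(p(x') - \min\{q(x'), p(x')\}\right)} \\
&= \frac{p(x) - \min\{q(x), p(x)\}}{1 - \beta},
\end{aligned}
$$
so the normalizing constant of $\tilde p(x)$ is $1 - \beta$. The second line uses $\max\{0, a - b\} = a - \min\{a, b\}$, and the last follows immediately from Lemma 3.3 and Theorem 3.5, which together give $\sum_{x'} \min\{q(x'), p(x')\} = \beta$.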
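As a concrete check of the normalizing constant, take two made-up distributions over a three-token vocabulary:

$$
\begin{aligned}
p &= (0.5,\ 0.3,\ 0.2), \qquad q = (0.2,\ 0.6,\ 0.2), \\
\min\{p, q\} &= (0.2,\ 0.3,\ 0.2), \qquad \beta = \textstyle\sum_x \min\{p(x), q(x)\} = 0.7, \\
\max\{0,\ p - q\} &= (0.3,\ 0,\ 0), \qquad \textstyle\sum_x \max\{0,\ p(x) - q(x)\} = 0.3 = 1 - \beta, \\
\tilde p &= (1,\ 0,\ 0).
\end{aligned}
$$

All of $\tilde p$'s mass sits on the first token, the only one where $p$ exceeds $q$.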
Now
$$P(x = x') = P(\text{guess accepted},\ x = x') + P(\text{guess rejected},\ x = x'),$$
where
$$P(\text{guess accepted},\ x = x') = q(x') \min\left\{1, \frac{p(x')}{q(x')}\right\} = \min\{q(x'), p(x')\},$$
and, since a guess is rejected with probability $1 - \beta$ and the token is then drawn from $\tilde p$,
$$P(\text{guess rejected},\ x = x') = (1 - \beta)\,\tilde p(x') = p(x') - \min\{q(x'), p(x')\}.$$
Overall,
$$P(x = x') = \min\{p(x'), q(x')\} + p(x') - \min\{p(x'), q(x')\} = p(x'),$$
as desired.
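The identity can also be checked numerically. This sketch (the function name `speculative_marginal` is hypothetical) evaluates the accepted and rejected terms of the derivation exactly for random distributions drawn with NumPy and compares their sum with $p$. It covers a single proposed token, which is the case the proof analyzes.

```python
import numpy as np

def speculative_marginal(p, q):
    """Exact distribution of the token produced by one speculative sampling round."""
    p, q = np.asarray(p, float), np.asarray(q, float)
    accepted = np.minimum(p, q)           # P(guess accepted, x = x') = q(x') min(1, p(x')/q(x'))
    beta = accepted.sum()                 # acceptance probability
    residual = np.maximum(0.0, p - q)     # unnormalized adjusted distribution
    if residual.sum() == 0.0:             # p == q: a guess is never rejected
        return accepted
    p_tilde = residual / residual.sum()   # norm(max{0, p - q})
    rejected = (1.0 - beta) * p_tilde     # P(guess rejected, x = x')
    return accepted + rejected

rng = np.random.default_rng(0)
for _ in range(5):
    p, q = rng.dirichlet(np.ones(10)), rng.dirichlet(np.ones(10))
    assert np.allclose(speculative_marginal(p, q), p)
```

The `p == q` branch avoids dividing by zero; in that case $\beta = 1$ and the rejected term vanishes anyway.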
Definition
The acceptance rate $\beta_{x_{<t}}$, given a prefix $x_{<t}$, is the probability of accepting $x_t \sim q(x_t \mid x_{<t})$ by speculative sampling.
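For a fixed prefix, write $p = p(\cdot \mid x_{<t})$ and $q = q(\cdot \mid x_{<t})$. The definition can then be estimated by simulation: draw $x \sim q$ and accept unless $r > p(x)/q(x)$ with $r \sim U(0,1)$, mirroring the acceptance test in the algorithm. The sketch uses made-up distributions and a hypothetical helper `estimate_acceptance_rate`. Since $P(\text{accept}) = \sum_x q(x) \min\{1, p(x)/q(x)\}$, the estimate should approach $\sum_x \min\{p(x), q(x)\}$.

```python
import numpy as np

def estimate_acceptance_rate(p, q, num_samples, rng):
    """Monte Carlo estimate of beta for fixed next-token distributions p and q."""
    x = rng.choice(len(q), size=num_samples, p=q)   # proposals x ~ q
    r = rng.uniform(size=num_samples)               # r ~ U(0, 1)
    accepted = r <= p[x] / q[x]                     # accept unless r > p(x)/q(x)
    return accepted.mean()

# Made-up distributions standing in for p(. | x_<t) and q(. | x_<t).
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.6, 0.2])
rng = np.random.default_rng(0)
estimate = estimate_acceptance_rate(p, q, 100_000, rng)
exact = np.minimum(p, q).sum()   # sum_x min(p(x), q(x)), i.e. 0.7 up to rounding
print(estimate, exact)
```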
Lemma
Define
$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x |q(x) - M(x)|,$$
where $M(x) = \frac{p(x) + q(x)}{2}$. Then
$$D_{LK}(p, q) = 1 - \sum_x \min\{p(x), q(x)\}.$$
Proof.
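$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x \frac{|p(x) - q(x)|}{2} = 1 - \sum_x \frac{p(x) + q(x) - |p(x) - q(x)|}{2} = 1 - \sum_x \min\{p(x), q(x)\},$$
where the third step uses $\sum_x \frac{p(x) + q(x)}{2} = 1$ and the last uses $\min\{a, b\} = \frac{a + b - |a - b|}{2}$.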
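A quick numerical sanity check of the Lemma on random distributions from NumPy's Dirichlet sampler; `d_lk` is a hypothetical helper name. Because $\beta = \sum_x \min\{p(x), q(x)\}$, the last printed value is also $1 - \beta$ for this pair of distributions.

```python
import numpy as np

def d_lk(p, q):
    """D_LK(p, q) = sum_x |p(x) - M(x)| with M(x) = (p(x) + q(x)) / 2."""
    m = (p + q) / 2
    return np.abs(p - m).sum()

rng = np.random.default_rng(0)
p = rng.dirichlet(np.ones(8))   # random distributions over 8 tokens
q = rng.dirichlet(np.ones(8))
m = (p + q) / 2
print(d_lk(p, q))
print(np.abs(q - m).sum())          # symmetric form: agrees up to rounding
print(1 - np.minimum(p, q).sum())   # closed form from the Lemma: agrees up to rounding
```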