Speculative Decoding

Algorithm of Speculative Decoding Step

Inputs: $M_p$, $M_q$, $\text{prefix}$.

Sample $\gamma$ guesses $x_{1,\dots,\gamma}$ from $M_q$ autoregressively.

for $i = 1$ to $\gamma$ do

$$q_i(x) \leftarrow M_q(\text{prefix} + [x_1, \dots, x_{i-1}])$$

$$x_i \sim q_i(x)$$

end for

Run $M_p$ in parallel.

$$p_1(x), \dots, p_{\gamma+1}(x) \leftarrow M_p(\text{prefix}), \dots, M_p(\text{prefix} + [x_1, \dots, x_\gamma])$$

Determine the number of accepted guesses $n$.

$$r_1 \sim U(0, 1), \dots, r_\gamma \sim U(0, 1)$$

$$n \leftarrow \min\left(\{i - 1 \mid 1 \le i \le \gamma,\ r_i > p_i(x)/q_i(x)\} \cup \{\gamma\}\right)$$

Adjust the distribution from $M_p$ if needed.

$$\tilde{p}(x) \leftarrow p_{n+1}(x)$$

if $n < \gamma$ then

$$\tilde{p}(x) \leftarrow \text{norm}(\max\{0, p_{n+1}(x) - q_{n+1}(x)\})$$

end if

Return one token from $M_p$, and $n$ tokens from $M_q$.

$$t \sim \tilde{p}(x)$$

return $\text{prefix} + [x_1, \dots, x_n, t]$
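
The step above can be sketched in Python with NumPy. This is a minimal illustration, not a reference implementation: the `target` and `draft` callables are hypothetical stand-ins for $M_p$ and $M_q$, each assumed to map a list of token ids to a next-token probability vector, and the target model is called in a plain loop here where a real system would batch the calls into one parallel pass.

```python
import numpy as np

def speculative_step(target, draft, prefix, gamma, rng):
    """One speculative decoding step (illustrative sketch).

    `target` and `draft` are hypothetical callables standing in for M_p and
    M_q: each takes a list of token ids and returns a probability vector.
    """
    # Sample gamma guesses from the draft model M_q autoregressively.
    guesses, q = [], []
    for _ in range(gamma):
        q_i = draft(prefix + guesses)
        q.append(q_i)
        guesses.append(int(rng.choice(len(q_i), p=q_i)))

    # Evaluate the target model M_p on all gamma + 1 prefixes.
    # A real system would batch these into a single parallel call.
    p = [target(prefix + guesses[:i]) for i in range(gamma + 1)]

    # Determine the number of accepted guesses n (indices are 0-based here).
    r = rng.random(gamma)
    rejected = [i for i in range(gamma)
                if r[i] > p[i][guesses[i]] / q[i][guesses[i]]]
    n = min(rejected + [gamma])

    # Adjust the distribution from M_p if a guess was rejected.
    p_adj = p[n]
    if n < gamma:
        residual = np.maximum(0.0, p[n] - q[n])
        p_adj = residual / residual.sum()

    # Return n tokens from M_q plus one token from the (adjusted) M_p.
    t = int(rng.choice(len(p_adj), p=p_adj))
    return prefix + guesses[:n] + [t]
```

When the two models agree exactly, every ratio $p_i(x)/q_i(x)$ equals 1, so all $\gamma$ guesses are accepted and the step returns $\gamma + 1$ new tokens. In the worst case the first guess is rejected and the step still returns one token from the adjusted distribution.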

Correctness of Speculative Sampling

We will now show that for any distributions $p(x)$ and $q(x)$, the tokens sampled via speculative sampling from $p(x)$ and $q(x)$ are distributed identically to those sampled from $p(x)$ alone. Let $\beta$ be the acceptance probability (Definition).

Note that as

$$\begin{aligned}
\tilde{p}(x) &= \text{norm}(\max\{0, p(x) - q(x)\}) \\
&= \frac{p(x) - \min\{q(x), p(x)\}}{\sum_{x'} \left(p(x') - \min\{q(x'), p(x')\}\right)} \\
&= \frac{p(x) - \min\{q(x), p(x)\}}{1 - \beta},
\end{aligned}$$

the normalizing constant for the adjusted distribution $\tilde{p}(x)$ is $1 - \beta$, where the last equation follows immediately from Lemma 3.3 and Theorem 3.5.

Now:

$$P(x = x') = P(\text{guess accepted}, x = x') + P(\text{guess rejected}, x = x')$$

Where:

$$P(\text{guess accepted}, x = x') = q(x') \min\left\{1, \frac{p(x')}{q(x')}\right\} = \min\{q(x'), p(x')\}$$

And:

$$P(\text{guess rejected}, x = x') = (1 - \beta)\,\tilde{p}(x') = p(x') - \min\{q(x'), p(x')\}$$

Overall:

$$P(x = x') = \min\{p(x'), q(x')\} + p(x') - \min\{p(x'), q(x')\} = p(x').$$

As desired.
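
The argument can also be checked numerically on a small example. The sketch below assumes NumPy and uses a hypothetical helper name, `speculative_marginal`, together with two arbitrary three-token distributions chosen only for illustration. It adds the accepted and rejected contributions from the proof and compares the total with $p$.

```python
import numpy as np

def speculative_marginal(p, q):
    """Marginal distribution of the token produced by speculative sampling.

    Hypothetical helper for illustration: sums the "guess accepted" and
    "guess rejected" terms from the proof for every token at once.
    """
    accepted = np.minimum(p, q)          # P(guess accepted, x = x')
    beta = accepted.sum()                # acceptance probability
    residual = np.maximum(0.0, p - q)
    if residual.sum() == 0.0:            # p == q: nothing is ever rejected
        return accepted
    p_adj = residual / residual.sum()    # norm(max{0, p - q})
    rejected = (1.0 - beta) * p_adj      # P(guess rejected, x = x')
    return accepted + rejected

# Arbitrary example distributions (illustrative values only).
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.2, 0.6])
assert np.allclose(speculative_marginal(p, q), p)
```

Here the acceptance probability is 0.6, the residual $\max\{0, p - q\}$ sums to $0.4 = 1 - \beta$, and the two contributions add back up to $p$.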

Definition

The acceptance rate $\beta_{x_{<t}}$, given a prefix $x_{<t}$, is the probability of accepting $x_t \sim q(x_t \mid x_{<t})$ by speculative sampling.
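
For a fixed prefix, write $p$ and $q$ for the next-token distributions of the target and draft models. Assuming the acceptance rule used in the algorithm (a guess $x \sim q$ is accepted with probability $\min\{1, p(x)/q(x)\}$), the acceptance rate has a closed form. The short derivation below is a sketch added for illustration:

$$\beta = \mathbb{E}_{x \sim q(x)}\left[\min\left\{1, \frac{p(x)}{q(x)}\right\}\right] = \sum_x q(x) \min\left\{1, \frac{p(x)}{q(x)}\right\} = \sum_x \min\{q(x), p(x)\}$$

This is the sum over $x'$ of the "guess accepted" term in the correctness proof, which is why $1 - \beta$ appears there as the normalizing constant of the adjusted distribution.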

Lemma

Define

$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x |q(x) - M(x)|$$

where $M(x) = (p(x) + q(x))/2$. Then

$$D_{LK}(p, q) = 1 - \sum_x \min\{p(x), q(x)\}$$

Proof.

$$D_{LK}(p, q) = \sum_x |p(x) - M(x)| = \sum_x \frac{|p(x) - q(x)|}{2} = 1 - \sum_x \frac{p(x) + q(x) - |p(x) - q(x)|}{2} = 1 - \sum_x \min\{p(x), q(x)\}$$
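
The proof relies on three facts: $p(x) - M(x) = (p(x) - q(x))/2$, the sum $\sum_x (p(x) + q(x))/2 = 1$, and the identity $\min\{a, b\} = (a + b - |a - b|)/2$. As a sanity check, the sketch below evaluates both sides of the lemma with NumPy. The function name `d_lk` and the example distributions are assumptions made for illustration.

```python
import numpy as np

def d_lk(p, q):
    """D_LK(p, q) computed directly from its definition (illustrative)."""
    m = (p + q) / 2.0                    # M(x) = (p(x) + q(x)) / 2
    return np.abs(p - m).sum()

# Arbitrary example distributions (illustrative values only).
p = np.array([0.5, 0.3, 0.2])
q = np.array([0.2, 0.2, 0.6])

lhs = d_lk(p, q)                         # sum_x |p(x) - M(x)|
rhs = 1.0 - np.minimum(p, q).sum()       # 1 - sum_x min{p(x), q(x)}
assert np.isclose(lhs, rhs)
assert np.isclose(d_lk(p, q), d_lk(q, p))  # the definition is symmetric
```

For these values both sides equal 0.4, so the overlap $\sum_x \min\{p(x), q(x)\}$ is 0.6.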
