class EagleSpeculator:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Cache configuration values and pre-allocate the persistent buffers.

    Args:
        vllm_config: Engine configuration. Its ``speculative_config`` must
            not be ``None``; the scheduler and model configs supply the
            buffer sizes and dtype.
        device: Device on which every persistent tensor is allocated.
    """
    self.vllm_config = vllm_config
    self.device = device

    # Speculative-decoding settings.
    spec_config = vllm_config.speculative_config
    self.speculative_config = spec_config
    assert spec_config is not None
    self.method = spec_config.method
    num_steps = spec_config.num_speculative_tokens
    self.num_speculative_steps = num_steps
    draft_config = spec_config.draft_model_config
    self.draft_model_config = draft_config

    # Scheduler limits that bound the buffer sizes below.
    sched_config = vllm_config.scheduler_config
    self.scheduler_config = sched_config
    max_reqs = sched_config.max_num_seqs
    self.max_num_reqs = max_reqs
    max_tokens = sched_config.max_num_batched_tokens
    self.max_num_tokens = max_tokens
    self.max_model_len = vllm_config.model_config.max_model_len

    # The hidden size is read from the draft model config because the draft
    # model's hidden size can differ from the target model's
    # (e.g., Llama 3.3 70B).
    hidden_size = draft_config.get_hidden_size()
    self.hidden_size = hidden_size
    vocab_size = draft_config.get_vocab_size()
    self.vocab_size = vocab_size
    pin_memory = is_pin_memory_available()
    self.pin_memory = pin_memory
    dtype = vllm_config.model_config.dtype
    self.dtype = dtype

    self.input_buffers = InputBuffers(
        max_num_reqs=max_reqs,
        max_num_tokens=max_tokens,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
    )

    # [max_num_tokens, hidden_size]: hidden states fed to the draft model.
    self.hidden_states = torch.zeros(
        max_tokens, hidden_size, dtype=dtype, device=device
    )
    # [max_num_reqs]: per-request sampling temperature.
    self.temperature = torch.zeros(max_reqs, dtype=torch.float32, device=device)
    # [max_num_reqs]: per-request sampling seed.
    self.seeds = torch.zeros(max_reqs, dtype=torch.int64, device=device)
    # [max_num_reqs, num_speculative_steps]: one column per draft step.
    self.draft_tokens = torch.zeros(
        max_reqs, num_steps, dtype=torch.int64, device=device
    )

    self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None:
    """Load the draft model and point its ``lm_head`` at the target's.

    The draft model is built from ``self.draft_model_config`` under the
    ``"eagle_head"`` model tag and stored in ``self.model``.

    Args:
        target_model: The target model. If it has an ``lm_head`` attribute,
            that same module object is installed as the draft model's
            ``lm_head`` (any existing draft ``lm_head`` is deleted first).
    """
    from vllm.compilation.backends import set_model_tag

    with set_model_tag("eagle_head"):
        self.model = get_model(
            vllm_config=self.vllm_config,
            model_config=self.draft_model_config,
        )
    draft_model = self.model

    # LM-head sharing is currently always enabled.
    share_lm_head = True
    if not share_lm_head or not hasattr(target_model, "lm_head"):
        return

    # Drop the draft model's own head (if any) before rebinding the attribute
    # to the target model's head.
    if hasattr(draft_model, "lm_head"):
        del draft_model.lm_head
    draft_model.lm_head = target_model.lm_head
def set_attn(
    self,
    kv_cache_config: KVCacheConfig,
    attn_metadata_builders: list[AttentionMetadataBuilder],
    block_tables: BlockTables,
) -> None:
    """Remember the attention-related objects used by the draft steps.

    The stored objects are read later when capturing the CUDA graph, when
    computing slot mappings, and when building attention metadata for the
    eager decode path.

    Args:
        kv_cache_config: Stored as ``self.kv_cache_config``.
        attn_metadata_builders: Stored as ``self.attn_metadata_builders``.
        block_tables: Stored as ``self.block_tables``.
    """
    self.block_tables = block_tables
    self.attn_metadata_builders = attn_metadata_builders
    self.kv_cache_config = kv_cache_config
@torch.inference_mode()
def run_model(
    self,
    num_tokens: int,
    attn_metadata: dict[str, Any],
    num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one forward pass of the draft model on the first ``num_tokens`` slots.

    Inputs are read from the persistent buffers (``input_ids``, ``positions``
    and ``self.hidden_states``); the forward context is entered with the CUDA
    graph runtime mode set to ``CUDAGraphMode.NONE``.

    Args:
        num_tokens: Number of leading buffer entries to feed to the model.
        attn_metadata: Attention metadata passed to the forward context.
        num_tokens_across_dp: Passed through to the forward context.

    Returns:
        ``(last_hidden_states, hidden_states)``. For the ``"mtp"`` method the
        model's single output is returned in both positions; otherwise the
        model's output is unpacked as a pair.
    """
    buffers = self.input_buffers
    with set_forward_context(
        attn_metadata,
        self.vllm_config,
        num_tokens=num_tokens,
        cudagraph_runtime_mode=CUDAGraphMode.NONE,
        num_tokens_across_dp=num_tokens_across_dp,
    ):
        model_output = self.model(
            input_ids=buffers.input_ids.gpu[:num_tokens],
            positions=buffers.positions[:num_tokens],
            hidden_states=self.hidden_states[:num_tokens],
        )

    if self.method == "mtp":
        # A single tensor serves as both outputs.
        return model_output, model_output

    last_hidden_states, hidden_states = model_output
    return last_hidden_states, hidden_states
def generate_draft(
self,
num_reqs: int,
attn_metadata: dict[str, Any],
num_tokens_across_dp: torch.Tensor | None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, num_tokens_across_dp
)
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
# used for draft and target sampling.
draft_tokens = gumbel_sample(
logits,
self.temperature[:num_reqs],
self.seeds[:num_reqs],
pos + 1,
apply_temperature=True,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
if step < self.num_speculative_steps - 1:
# Update the inputs for the next step.
update_eagle_inputs(
draft_tokens,
hidden_states,
self.input_buffers,
self.hidden_states,
self.max_model_len,
)
self.block_tables.compute_slot_mappings(query_start_loc, pos)
def capture_model(self) -> None:
    """Hand ``generate_draft`` to the CUDA graph manager for capture.

    With exactly one speculative step this is a no-op: ``generate_draft``
    would run zero iterations in that configuration, so there is nothing to
    capture.
    """
    if self.num_speculative_steps != 1:
        logger.info("Capturing model for Eagle speculator...")
        self.cudagraph_manager.capture(
            self.generate_draft,
            self.input_buffers,
            self.block_tables,
            self.attn_metadata_builders,
            self.kv_cache_config,
        )
@torch.inference_mode()
def propose(
    self,
    input_batch: InputBatch,
    sampling_metadata: SamplingMetadata,
    # [num_tokens, hidden_size]
    last_hidden_states: torch.Tensor,
    # num_layers x [num_tokens, hidden_size]
    aux_hidden_states: list[torch.Tensor] | None,
    # [num_reqs]
    num_sampled: torch.Tensor,
    # [num_reqs]
    num_rejected: torch.Tensor,
    # [max_num_reqs, 1]
    last_sampled: torch.Tensor,
    # [num_reqs]
    next_prefill_tokens: torch.Tensor,
) -> torch.Tensor:
    """Propose draft tokens for every request in ``input_batch``.

    The method first runs the draft model once over the target model's
    (padded) token layout to sample the first draft token per request, then,
    if more than one speculative step is configured, generates the remaining
    draft tokens one token per request per step, either through a CUDA graph
    (when the manager reports a graph size for ``num_reqs``) or eagerly via
    ``generate_draft``.

    Args:
        input_batch: Batch description of the target model's step; provides
            ``num_tokens_after_padding``, ``attn_metadata``, ``num_reqs``,
            ``cu_num_logits``, ``positions`` and ``seq_lens``.
        sampling_metadata: Source of the per-logit ``temperature`` and
            ``seeds`` tensors, gathered at ``cu_num_logits``.
        last_hidden_states: Target-model hidden states, used as the draft
            model's input when ``aux_hidden_states`` is empty or ``None``.
        aux_hidden_states: Optional list of hidden-state tensors; when
            non-empty they are concatenated on the last dimension and passed
            through ``self.model.combine_hidden_states`` (eagle3 only).
        num_sampled: Forwarded to ``prepare_eagle_inputs``.
        num_rejected: Forwarded to ``prepare_eagle_inputs`` and
            ``prepare_eagle_decode``.
        last_sampled: Forwarded to ``prepare_eagle_inputs``.
        next_prefill_tokens: Forwarded to ``prepare_eagle_inputs``.

    Returns:
        Draft token ids. With a single speculative step, the sampled tokens
        viewed as ``[num_reqs, 1]``. Otherwise the first ``num_reqs`` rows of
        the persistent ``self.draft_tokens`` buffer, which this method
        overwrites on every call.
    """
    # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
    # number of rejected tokens, we maintain the size of eagle's input_ids and
    # hidden_states the same as the target model's. This means, we pad each
    # request's query length to include any rejected positions. By doing so,
    # we can also reuse the attention metadata (e.g., query_start_loc,
    # seq_lens) of the target model.
    if aux_hidden_states:
        assert self.method == "eagle3"
        hidden_states = self.model.combine_hidden_states(
            torch.cat(aux_hidden_states, dim=-1)
        )
    else:
        hidden_states = last_hidden_states
    num_tokens = input_batch.num_tokens_after_padding
    # Copy into the persistent buffer that run_model reads from.
    self.hidden_states[:num_tokens] = hidden_states

    # Get the input ids and last token indices for the speculator.
    last_token_indices = prepare_eagle_inputs(
        self.input_buffers,
        input_batch,
        num_sampled,
        num_rejected,
        last_sampled,
        next_prefill_tokens,
    )

    # Prefill: Run the eagle speculator with eager mode.
    # TODO(woosuk): Support CUDA graph for prefill.
    last_hidden_states, hidden_states = self.run_model(
        num_tokens,
        input_batch.attn_metadata,
        num_tokens_across_dp=None,  # FIXME
    )
    # Only the hidden state at each request's last token is used for sampling.
    sample_hidden_states = last_hidden_states[last_token_indices]
    logits = self.model.compute_logits(sample_hidden_states)

    num_reqs = input_batch.num_reqs
    cu_num_logits = input_batch.cu_num_logits[:num_reqs]
    # NOTE(woosuk): For draft sampling, we only consider the temperature
    # and ignore the other sampling parameters such as top_k and top_p,
    # for simplicity and performance.
    # While this may slightly degrade the acceptance rate, it does not
    # affect the output distribution after rejection sampling.
    temperature = self.temperature[:num_reqs]
    seeds = self.seeds[:num_reqs]
    pos = self.input_buffers.positions[:num_reqs]
    # Gather the values and copy them to the pre-allocated buffers.
    # `out=` writes through the slices above, so generate_draft (which
    # re-slices the same buffers) sees these values.
    torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
    torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
    torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
    # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
    # used for draft and target sampling.
    draft_tokens = gumbel_sample(
        logits, temperature, seeds, pos + 1, apply_temperature=True
    )
    if self.num_speculative_steps == 1:
        # Early exit: no decode steps are needed, so the persistent
        # draft_tokens buffer is not touched.
        return draft_tokens.view(-1, 1)

    # Save the draft tokens for the first step.
    self.draft_tokens[:num_reqs, 0] = draft_tokens
    # Prepare the inputs for the decode steps.
    # NOTE(review): presumably this rewrites the input buffers (input ids,
    # positions, query_start_loc, seq_lens) to a one-token-per-request
    # layout; confirm against prepare_eagle_decode.
    prepare_eagle_decode(
        draft_tokens,
        hidden_states,
        last_token_indices,
        input_batch.seq_lens,
        num_rejected,
        self.input_buffers,
        self.hidden_states,
        self.max_model_len,
        self.max_num_reqs,
    )
    query_start_loc = self.input_buffers.query_start_loc
    query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
    slot_mappings = self.block_tables.compute_slot_mappings(
        query_start_loc_gpu, pos
    )

    cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
    if cudagraph_size is not None:
        # Run CUDA graph.
        # NOTE(review): presumably replays the graph captured from
        # generate_draft, which fills columns 1.. of self.draft_tokens.
        self.cudagraph_manager.run(cudagraph_size)
        return self.draft_tokens[:num_reqs]

    # Run eager mode.
    # CPU-side query_start_loc for one token per request: 0, 1, ..., num_reqs.
    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
    query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
    # HACK(woosuk): the CPU-side sequence lengths are filled with
    # max_model_len instead of the real per-request lengths.
    seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
    block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
    # FIXME(woosuk): This is UNSAFE!!
    attn_metadata = build_attn_metadata(
        attn_metadata_builders=self.attn_metadata_builders,
        num_reqs=num_reqs,
        num_tokens=num_reqs,
        query_start_loc_gpu=query_start_loc_gpu,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=self.input_buffers.seq_lens[:num_reqs],
        seq_lens_np=seq_lens_np,
        num_computed_tokens_cpu=None,  # FIXME
        block_tables=block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=self.kv_cache_config,
    )
    # Fills self.draft_tokens[:num_reqs, 1:] in place.
    self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None)  # FIXME
    return self.draft_tokens[:num_reqs]