vLLM based Eval framework#3531

Open
dipannita08 wants to merge 9 commits into main from eval-framework-01

Conversation

@dipannita08

Description

Implement an evaluation framework with a vLLM backend. Requirements, design, further details: go/eval-framework-vllm

The rest of the description should include relevant details and context, for example:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Mar 31, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.


@github-actions

🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


## 📋 Review Summary

The implementation of the vLLM-based evaluation framework is a strong addition to MaxText, providing native support for custom benchmarks, lm-evaluation-harness, and evalchemy. The code is well-structured, but it needs critical updates to correctly support multi-host TPU environments where lead-rank coordination is essential.

🔍 General Feedback

  • Rank Coordination: In multi-host TPU setups, client-side operations (warmup, generation, reporting) must be restricted to jax.process_index() == 0 to avoid redundant work and failures on non-lead ranks.
  • Configurability: Key parameters like request timeouts should be made configurable via the CLI/config files rather than being hardcoded.
  • Efficiency: Minor optimizations in NLTK data handling and FastAPI request processing would improve the overall robustness and performance of the evaluation tool.
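
The rank-coordination pattern described above can be sketched standalone. This is a hypothetical illustration, not the PR's actual code: in the real framework the index would come from `jax.process_index()`, but here it is a plain parameter so the guard can be shown without a JAX cluster.

```python
# Hypothetical sketch of the lead-rank guard recommended above. The real
# framework would pass jax.process_index() as process_index.
def run_on_lead_rank(process_index: int, client_work) -> dict:
  """Run client-side eval work (warmup, generation, reporting) only on rank 0."""
  if process_index == 0:
    return client_work()
  # Non-lead ranks skip HTTP clients, scoring, and GCS uploads entirely.
  return {}


# Simulate four hosts: only rank 0 produces a result.
outputs = [run_on_lead_rank(i, lambda: {"score": 0.5}) for i in range(4)]
```

Ranks 1..N still enter the server context manager (so they participate in distributed model initialization) but return an empty dict instead of issuing requests against rank 0's localhost.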

Comment on lines +183 to +230
```python
warmup_server(base_url=base_url, model=model_name, sample_requests=requests)

# Generate responses.
logger.info("Generating responses for %d prompts.", len(prompts))
t0 = time.time()
results = generate_batch(
    prompts=prompts,
    base_url=base_url,
    model=model_name,
    max_tokens=max_tokens,
    temperature=temperature,
    concurrency=concurrency,
)
elapsed = time.time() - t0
logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)

# Score.
responses = [r.text for r in results]
errors = [r for r in results if r.error]
if errors:
  logger.warning("%d generation errors (out of %d).", len(errors), len(results))

scorer = get_scorer(benchmark)
scores = scorer(responses, references)
logger.info("Scores: %s", scores)

# Write results
generation_stats = {
    "total_samples": len(prompts),
    "num_errors": len(errors),
    "elapsed_s": round(elapsed, 2),
    "samples_per_second": round(len(prompts) / elapsed, 2),
    "total_prompt_tokens": sum(r.prompt_tokens for r in results),
    "total_completion_tokens": sum(r.completion_tokens for r in results),
}
output = write_results(
    benchmark=benchmark,
    model_name=model_name,
    scores=scores,
    generation_stats=generation_stats,
    config=cfg,
    results_path=results_path,
)

# Optional GCS Upload.
if gcs_results_path:
  from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
  upload_results(output["local_path"], gcs_results_path)
```

🔴 In a multi-host TPU environment, this script runs on all ranks. Currently, all ranks will attempt to send HTTP requests to localhost, which will fail on ranks 1..N (as the server only binds to Rank 0's localhost). Furthermore, reporting and GCS uploads should only be performed by the lead rank to avoid race conditions and redundant work.

Suggested change

```python
with VllmServerManager(
    model_path=hf_path,
    checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
    maxtext_model_name=model_name if use_maxtext_adapter else None,
    host=server_host,
    port=server_port,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    max_num_seqs=max_num_seqs,
    env=server_env,
) as server:
  if jax.process_index() == 0:
    base_url = server.base_url
    # Warmup server.
    warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
    # Generate responses.
    logger.info("Generating responses for %d prompts.", len(prompts))
    t0 = time.time()
    results = generate_batch(
        prompts=prompts,
        base_url=base_url,
        model=model_name,
        max_tokens=max_tokens,
        temperature=temperature,
        concurrency=concurrency,
    )
    elapsed = time.time() - t0
    logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)
    # Score.
    responses = [r.text for r in results]
    errors = [r for r in results if r.error]
    if errors:
      logger.warning("%d generation errors (out of %d).", len(errors), len(results))
    scorer = get_scorer(benchmark)
    scores = scorer(responses, references)
    logger.info("Scores: %s", scores)
    # Write results
    generation_stats = {
        "total_samples": len(prompts),
        "num_errors": len(errors),
        "elapsed_s": round(elapsed, 2),
        "samples_per_second": round(len(prompts) / elapsed, 2),
        "total_prompt_tokens": sum(r.prompt_tokens for r in results),
        "total_completion_tokens": sum(r.completion_tokens for r in results),
    }
    output = write_results(
        benchmark=benchmark,
        model_name=model_name,
        scores=scores,
        generation_stats=generation_stats,
        config=cfg,
        results_path=results_path,
    )
    # Optional GCS Upload.
    if gcs_results_path:
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)
    return output
  else:
    # Non-rank-0 processes wait for the server to finish or participate in distributed work.
    # VllmServerManager context ensures they participate in LLM initialization.
    pass
return {}
```

Comment on lines +159 to +189
```python
model_args = _build_model_args(
    base_url=server.base_url,
    tokenizer_path=hf_path,
    model_name=model_name,
    hf_token=token,
)
logger.info("Running lm-eval tasks %s via local-completions at %s", lm_tasks, server.base_url)
lm_results = lm_eval_lib.simple_evaluate(
    model="local-completions",
    model_args=model_args,
    tasks=lm_tasks,
    num_fewshot=num_fewshot,
    limit=num_samples,
    log_samples=False,
)

scores = _map_lm_eval_results(lm_results, tasks)
logger.info("lm-eval scores: %s", scores)

output = write_results(
    benchmark="+".join(tasks),
    model_name=model_name,
    scores=scores,
    generation_stats={"lm_eval_config": lm_results.get("config", {})},
    config=cfg,
    results_path=results_path,
)

if gcs_results_path:
  from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
  upload_results(output["local_path"], gcs_results_path)
```

🔴 This integration lacks a rank check. In a multi-host TPU setup, ranks 1..N will attempt to run simple_evaluate against Rank 0's localhost and fail. All client-side logic (warmup, evaluation, and reporting) should be guarded.

Suggested change

```python
with VllmServerManager(
    model_path=hf_path,
    checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
    maxtext_model_name=model_name if use_maxtext_adapter else None,
    host=server_host,
    port=server_port,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    max_num_seqs=max_num_seqs,
    env=server_env,
) as server:
  if jax.process_index() == 0:
    warmup_server(base_url=server.base_url, model=model_name)
    model_args = _build_model_args(
        base_url=server.base_url,
        tokenizer_path=hf_path,
        model_name=model_name,
        hf_token=token,
    )
    logger.info("Running lm-eval tasks %s via local-completions at %s", lm_tasks, server.base_url)
    lm_results = lm_eval_lib.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=lm_tasks,
        num_fewshot=num_fewshot,
        limit=num_samples,
        log_samples=False,
    )
    scores = _map_lm_eval_results(lm_results, tasks)
    logger.info("lm-eval scores: %s", scores)
    output = write_results(
        benchmark="+".join(tasks),
        model_name=model_name,
        scores=scores,
        generation_stats={"lm_eval_config": lm_results.get("config", {})},
        config=cfg,
        results_path=results_path,
    )
    if gcs_results_path:
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)
    return output
  else:
    pass
return {}
```

Comment on lines +216 to +252
```python
warmup_server(base_url=server.base_url, model=model_name)

model_args = _build_model_args(
    base_url=server.base_url,
    tokenizer_path=hf_path,
    model_name=model_name,
    hf_token=token,
)
logger.info(
    "Running evalchemy tasks %s via local-chat-completions at %s",
    lm_eval_tasks,
    server.base_url,
)
evalchemy_results = lm_eval.simple_evaluate(
    model="local-chat-completions",
    model_args=model_args,
    tasks=lm_eval_tasks,
    num_fewshot=num_fewshot,
    limit=num_samples,
    log_samples=False,
)

scores = _map_evalchemy_results(evalchemy_results, tasks)
logger.info("evalchemy scores: %s", scores)

output = write_results(
    benchmark="+".join(tasks),
    model_name=model_name,
    scores=scores,
    generation_stats={"evalchemy_config": evalchemy_results.get("config", {})},
    config=cfg,
    results_path=results_path,
)

if gcs_results_path:
  from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
  upload_results(output["local_path"], gcs_results_path)
```

🔴 This integration also lacks a rank check. Ranks 1..N should not attempt to perform evaluation or reporting.

Suggested change

```python
with VllmServerManager(
    model_path=hf_path,
    checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
    maxtext_model_name=model_name if use_maxtext_adapter else None,
    host=server_host,
    port=server_port,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    max_num_seqs=max_num_seqs,
    env=server_env,
) as server:
  if jax.process_index() == 0:
    warmup_server(base_url=server.base_url, model=model_name)
    model_args = _build_model_args(
        base_url=server.base_url,
        tokenizer_path=hf_path,
        model_name=model_name,
        hf_token=token,
    )
    logger.info(
        "Running evalchemy tasks %s via local-chat-completions at %s",
        lm_eval_tasks,
        server.base_url,
    )
    evalchemy_results = lm_eval.simple_evaluate(
        model="local-chat-completions",
        model_args=model_args,
        tasks=lm_eval_tasks,
        num_fewshot=num_fewshot,
        limit=num_samples,
        log_samples=False,
    )
    scores = _map_evalchemy_results(evalchemy_results, tasks)
    logger.info("evalchemy scores: %s", scores)
    output = write_results(
        benchmark="+".join(tasks),
        model_name=model_name,
        scores=scores,
        generation_stats={"evalchemy_config": evalchemy_results.get("config", {})},
        config=cfg,
        results_path=results_path,
    )
    if gcs_results_path:
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)
    return output
  else:
    pass
return {}
```

```python
import evaluate  # pylint: disable=import-outside-toplevel
import nltk  # pylint: disable=import-outside-toplevel

nltk.download("punkt", quiet=True)
```

🟡 Downloading NLTK data on every score_batch call is inefficient and can cause issues in environments with restricted internet access. It's better to perform these downloads once during initialization or provide a way to use pre-downloaded data.

Suggested change

```python
def score_batch(preds: list[str], refs: list[str]) -> dict:
  """Compute ROUGE scores for a batch of predictions and references."""
  import evaluate
  import nltk
  import numpy as np

  # (TODO): Consider moving these to a module-level initialization step.
  nltk.download("punkt", quiet=True)
  nltk.download("punkt_tab", quiet=True)
```
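
A minimal sketch of the one-time initialization the reviewer suggests, using `functools.lru_cache` so repeated calls from `score_batch` become no-ops. The helper name is illustrative, not part of the PR; the actual `nltk.download` calls are shown as comments so the sketch stays runnable offline.

```python
import functools


@functools.lru_cache(maxsize=None)
def ensure_nltk_data() -> bool:
  # Hypothetical helper: in the real module this body would run
  #   nltk.download("punkt", quiet=True)
  #   nltk.download("punkt_tab", quiet=True)
  # exactly once per process; later calls hit the cache and skip the network.
  return True


ensure_nltk_data()  # pays the download cost
ensure_nltk_data()  # cached, free
```

Environments with restricted internet access could then pre-stage the data and stub this helper out entirely.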

```python
_DEFAULT_MAX_TOKENS = 1024
_DEFAULT_TEMPERATURE = 0.0
_COMPLETIONS_PATH = "/v1/completions"
_REQUEST_TIMEOUT_S = 600  # (TODO): Check if this is reasoanable.
```

🟡 Hardcoding the request timeout to 600 seconds is inflexible. It would be better to make this value configurable through the cfg dictionary to allow users to adjust it based on their model size and prompt length.

Suggested change

```python
_REQUEST_TIMEOUT_S = 600  # (TODO): Make this configurable via cfg.
```
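
One possible shape for the configurable timeout, sketched under the assumption that `cfg` is a plain dict; the key name `request_timeout_s` is invented for illustration and not defined by the PR.

```python
_DEFAULT_REQUEST_TIMEOUT_S = 600


def get_request_timeout_s(cfg: dict) -> float:
  # Fall back to the current hardcoded default when the key is absent,
  # so existing configs keep working unchanged.
  return float(cfg.get("request_timeout_s", _DEFAULT_REQUEST_TIMEOUT_S))
```

Users running large models or long prompts could then raise the timeout per workload instead of editing source.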

```python
  return {"status": "ok"}


@app.post("/v1/completions")
async def completions(request: fastapi.Request):
```

🟡 Calling the synchronous llm.generate directly inside an async FastAPI endpoint blocks the event loop. While this might be acceptable for evaluation scripts where only one client is active, it prevents concurrent request handling within the server. Consider using anyio.to_thread.run_sync or transitioning to vLLM's AsyncLLMEngine for more efficient server operation.

```python
@app.post("/v1/completions")
async def completions(request: fastapi.Request):
  # ...
  # This blocks the FastAPI event loop:
  outputs = llm.generate(prompts, SamplingParams(**sp_kwargs))
```
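
The thread-offloading fix can be sketched with the stdlib's `asyncio.to_thread` (Python 3.9+), which behaves like the `anyio.to_thread.run_sync` call the reviewer mentions. `blocking_generate` is a placeholder standing in for the synchronous `llm.generate`; none of these names come from the PR.

```python
import asyncio
import time


def blocking_generate(prompt: str) -> str:
  # Placeholder for the synchronous llm.generate call.
  time.sleep(0.01)
  return prompt.upper()


async def completions(prompt: str) -> str:
  # Offload the blocking call to a worker thread so the event loop
  # stays free to accept and serve other requests concurrently.
  return await asyncio.to_thread(blocking_generate, prompt)


result = asyncio.run(completions("hello"))  # "HELLO"
```

Moving to vLLM's async engine would go further, but a thread offload alone already keeps health checks and concurrent clients responsive during generation.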


@RissyRan RissyRan left a comment


Thanks for the PR! I’ve done a high-level pass, though I haven’t done a deep dive into the code just yet. Have you had a chance to run any benchmarks? I'm curious if you're seeing decent scores. Also, is multi-host benchmarking for large models within the scope of this PR?

Requires: `pip install evalchemy`

```bash
python -m maxtext.eval.runner.evalchemy_runner \
```

Have you tested if large scale works? i.e. a workload on v5p-64.

| `--hf_token` | HuggingFace token for gated models |
| `--num_samples` | Limit number of eval samples |
| `--hf_mode` | Force HF safetensors mode (disables MaxTextForCausalLM mode) |
| `--tensor_parallel_size` | vLLM tensor parallelism |

is this the only sharding supported?

```bash
python -m maxtext.eval.runner.eval_runner ...
```

### Configuration (eval_runner)

Do you think it will be useful if we put common configs together? Then only put new ones for eval_runner, lm_eval_runner, and evalchemy_runner?

```python
logger = logging.getLogger(__name__)

# Maps MaxText benchmark names to lm-eval task names.
_TASK_MAP: dict[str, str] = {
```

Wondering if we could directly run lm-eval tasks instead of adding extra mapping layer here? So for any future benchmarks in lm-eval-harness, we could directly use. Similar comments for other framework if applies.
