Conversation
Codecov Report
✅ All modified and coverable lines are covered by tests.

🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
The implementation of the vLLM-based evaluation framework is a strong addition to MaxText, providing native support for custom benchmarks, lm-evaluation-harness, and evalchemy. The code is well-structured, but it needs critical updates to correctly support multi-host TPU environments, where lead-rank coordination is essential.
🔍 General Feedback
- **Rank Coordination:** In multi-host TPU setups, client-side operations (warmup, generation, reporting) must be restricted to `jax.process_index() == 0` to avoid redundant work and failures on non-lead ranks.
- **Configurability:** Key parameters like request timeouts should be configurable via the CLI/config files rather than being hardcoded.
- **Efficiency:** Minor optimizations in NLTK data handling and FastAPI request processing would improve the overall robustness and performance of the evaluation tool.
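The rank-coordination point can be sketched as a small guard helper. This is an illustration only: `run_on_lead_rank` is a hypothetical name, not a MaxText API, and a plain integer stands in for `jax.process_index()`.

```python
from typing import Any, Callable


def run_on_lead_rank(process_index: int, work: Callable[[], dict]) -> dict[str, Any]:
  """Run client-side work (warmup, generation, reporting) only on the lead rank.

  Non-lead ranks skip the HTTP-client work entirely and return an empty
  result, mirroring the `if jax.process_index() == 0` guard suggested below.
  """
  if process_index == 0:
    return work()
  return {}
```

In the real runner, `process_index` would come from `jax.process_index()` and `work` would wrap the warmup/generate/report sequence.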
```python
warmup_server(base_url=base_url, model=model_name, sample_requests=requests)

# Generate responses.
logger.info("Generating responses for %d prompts.", len(prompts))
t0 = time.time()
results = generate_batch(
    prompts=prompts,
    base_url=base_url,
    model=model_name,
    max_tokens=max_tokens,
    temperature=temperature,
    concurrency=concurrency,
)
elapsed = time.time() - t0
logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)

# Score.
responses = [r.text for r in results]
errors = [r for r in results if r.error]
if errors:
  logger.warning("%d generation errors (out of %d).", len(errors), len(results))

scorer = get_scorer(benchmark)
scores = scorer(responses, references)
logger.info("Scores: %s", scores)

# Write results.
generation_stats = {
    "total_samples": len(prompts),
    "num_errors": len(errors),
    "elapsed_s": round(elapsed, 2),
    "samples_per_second": round(len(prompts) / elapsed, 2),
    "total_prompt_tokens": sum(r.prompt_tokens for r in results),
    "total_completion_tokens": sum(r.completion_tokens for r in results),
}
output = write_results(
    benchmark=benchmark,
    model_name=model_name,
    scores=scores,
    generation_stats=generation_stats,
    config=cfg,
    results_path=results_path,
)

# Optional GCS upload.
if gcs_results_path:
  from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
  upload_results(output["local_path"], gcs_results_path)
```
🔴 In a multi-host TPU environment, this script runs on all ranks. Currently, all ranks will attempt to send HTTP requests to localhost, which will fail on ranks 1..N (as the server only binds to Rank 0's localhost). Furthermore, reporting and GCS uploads should only be performed by the lead rank to avoid race conditions and redundant work.
Suggested change:

```python
with VllmServerManager(
    model_path=hf_path,
    checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
    maxtext_model_name=model_name if use_maxtext_adapter else None,
    host=server_host,
    port=server_port,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    max_num_seqs=max_num_seqs,
    env=server_env,
) as server:
  if jax.process_index() == 0:
    base_url = server.base_url
    # Warmup server.
    warmup_server(base_url=base_url, model=model_name, sample_requests=requests)
    # Generate responses.
    logger.info("Generating responses for %d prompts.", len(prompts))
    t0 = time.time()
    results = generate_batch(
        prompts=prompts,
        base_url=base_url,
        model=model_name,
        max_tokens=max_tokens,
        temperature=temperature,
        concurrency=concurrency,
    )
    elapsed = time.time() - t0
    logger.info("Generation completed in %.1fs (%.1f samples/s).", elapsed, len(prompts) / elapsed)
    # Score.
    responses = [r.text for r in results]
    errors = [r for r in results if r.error]
    if errors:
      logger.warning("%d generation errors (out of %d).", len(errors), len(results))
    scorer = get_scorer(benchmark)
    scores = scorer(responses, references)
    logger.info("Scores: %s", scores)
    # Write results.
    generation_stats = {
        "total_samples": len(prompts),
        "num_errors": len(errors),
        "elapsed_s": round(elapsed, 2),
        "samples_per_second": round(len(prompts) / elapsed, 2),
        "total_prompt_tokens": sum(r.prompt_tokens for r in results),
        "total_completion_tokens": sum(r.completion_tokens for r in results),
    }
    output = write_results(
        benchmark=benchmark,
        model_name=model_name,
        scores=scores,
        generation_stats=generation_stats,
        config=cfg,
        results_path=results_path,
    )
    # Optional GCS upload.
    if gcs_results_path:
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)
    return output
  else:
    # Non-rank-0 processes wait for the server to finish or participate in distributed work.
    # The VllmServerManager context ensures they participate in LLM initialization.
    pass
return {}
```
```python
model_args = _build_model_args(
    base_url=server.base_url,
    tokenizer_path=hf_path,
    model_name=model_name,
    hf_token=token,
)
logger.info("Running lm-eval tasks %s via local-completions at %s", lm_tasks, server.base_url)
lm_results = lm_eval_lib.simple_evaluate(
    model="local-completions",
    model_args=model_args,
    tasks=lm_tasks,
    num_fewshot=num_fewshot,
    limit=num_samples,
    log_samples=False,
)

scores = _map_lm_eval_results(lm_results, tasks)
logger.info("lm-eval scores: %s", scores)

output = write_results(
    benchmark="+".join(tasks),
    model_name=model_name,
    scores=scores,
    generation_stats={"lm_eval_config": lm_results.get("config", {})},
    config=cfg,
    results_path=results_path,
)

if gcs_results_path:
  from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
  upload_results(output["local_path"], gcs_results_path)
```
🔴 This integration lacks a rank check. In a multi-host TPU setup, ranks 1..N will attempt to run simple_evaluate against Rank 0's localhost and fail. All client-side logic (warmup, evaluation, and reporting) should be guarded.
Suggested change:

```python
with VllmServerManager(
    model_path=hf_path,
    checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
    maxtext_model_name=model_name if use_maxtext_adapter else None,
    host=server_host,
    port=server_port,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    max_num_seqs=max_num_seqs,
    env=server_env,
) as server:
  if jax.process_index() == 0:
    warmup_server(base_url=server.base_url, model=model_name)
    model_args = _build_model_args(
        base_url=server.base_url,
        tokenizer_path=hf_path,
        model_name=model_name,
        hf_token=token,
    )
    logger.info("Running lm-eval tasks %s via local-completions at %s", lm_tasks, server.base_url)
    lm_results = lm_eval_lib.simple_evaluate(
        model="local-completions",
        model_args=model_args,
        tasks=lm_tasks,
        num_fewshot=num_fewshot,
        limit=num_samples,
        log_samples=False,
    )
    scores = _map_lm_eval_results(lm_results, tasks)
    logger.info("lm-eval scores: %s", scores)
    output = write_results(
        benchmark="+".join(tasks),
        model_name=model_name,
        scores=scores,
        generation_stats={"lm_eval_config": lm_results.get("config", {})},
        config=cfg,
        results_path=results_path,
    )
    if gcs_results_path:
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)
    return output
  else:
    pass
return {}
```
```python
warmup_server(base_url=server.base_url, model=model_name)

model_args = _build_model_args(
    base_url=server.base_url,
    tokenizer_path=hf_path,
    model_name=model_name,
    hf_token=token,
)
logger.info(
    "Running evalchemy tasks %s via local-chat-completions at %s",
    lm_eval_tasks,
    server.base_url,
)
evalchemy_results = lm_eval.simple_evaluate(
    model="local-chat-completions",
    model_args=model_args,
    tasks=lm_eval_tasks,
    num_fewshot=num_fewshot,
    limit=num_samples,
    log_samples=False,
)

scores = _map_evalchemy_results(evalchemy_results, tasks)
logger.info("evalchemy scores: %s", scores)

output = write_results(
    benchmark="+".join(tasks),
    model_name=model_name,
    scores=scores,
    generation_stats={"evalchemy_config": evalchemy_results.get("config", {})},
    config=cfg,
    results_path=results_path,
)

if gcs_results_path:
  from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
  upload_results(output["local_path"], gcs_results_path)
```
🔴 This integration also lacks a rank check. Ranks 1..N should not attempt to perform evaluation or reporting.
Suggested change:

```python
with VllmServerManager(
    model_path=hf_path,
    checkpoint_path=checkpoint_path if use_maxtext_adapter else None,
    maxtext_model_name=model_name if use_maxtext_adapter else None,
    host=server_host,
    port=server_port,
    tensor_parallel_size=tensor_parallel_size,
    max_model_len=max_model_len,
    max_num_batched_tokens=max_num_batched_tokens,
    max_num_seqs=max_num_seqs,
    env=server_env,
) as server:
  if jax.process_index() == 0:
    warmup_server(base_url=server.base_url, model=model_name)
    model_args = _build_model_args(
        base_url=server.base_url,
        tokenizer_path=hf_path,
        model_name=model_name,
        hf_token=token,
    )
    logger.info(
        "Running evalchemy tasks %s via local-chat-completions at %s",
        lm_eval_tasks,
        server.base_url,
    )
    evalchemy_results = lm_eval.simple_evaluate(
        model="local-chat-completions",
        model_args=model_args,
        tasks=lm_eval_tasks,
        num_fewshot=num_fewshot,
        limit=num_samples,
        log_samples=False,
    )
    scores = _map_evalchemy_results(evalchemy_results, tasks)
    logger.info("evalchemy scores: %s", scores)
    output = write_results(
        benchmark="+".join(tasks),
        model_name=model_name,
        scores=scores,
        generation_stats={"evalchemy_config": evalchemy_results.get("config", {})},
        config=cfg,
        results_path=results_path,
    )
    if gcs_results_path:
      from maxtext.eval.reporting.gcs_reporter import upload_results  # pylint: disable=import-outside-toplevel
      upload_results(output["local_path"], gcs_results_path)
    return output
  else:
    pass
return {}
```
```python
import evaluate  # pylint: disable=import-outside-toplevel
import nltk  # pylint: disable=import-outside-toplevel

nltk.download("punkt", quiet=True)
```
🟡 Downloading NLTK data on every score_batch call is inefficient and can cause issues in environments with restricted internet access. It's better to perform these downloads once during initialization or provide a way to use pre-downloaded data.
Suggested change:

```python
def score_batch(preds: list[str], refs: list[str]) -> dict:
  """Compute ROUGE scores for a batch of predictions and references."""
  import evaluate
  import nltk
  import numpy as np

  # TODO: Consider moving these to a module-level initialization step.
  nltk.download("punkt", quiet=True)
  nltk.download("punkt_tab", quiet=True)
```
```python
_DEFAULT_MAX_TOKENS = 1024
_DEFAULT_TEMPERATURE = 0.0
_COMPLETIONS_PATH = "/v1/completions"
_REQUEST_TIMEOUT_S = 600  # (TODO): Check if this is reasoanable.
```
🟡 Hardcoding the request timeout to 600 seconds is inflexible. It would be better to make this value configurable through the cfg dictionary to allow users to adjust it based on their model size and prompt length.
Suggested change:

```diff
-_REQUEST_TIMEOUT_S = 600  # (TODO): Check if this is reasoanable.
+_REQUEST_TIMEOUT_S = 600  # (TODO): Make this configurable via cfg.
```
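One way to act on this suggestion is a small resolver that prefers a cfg override and falls back to the hardcoded default. A sketch only: the `request_timeout_s` key name is an assumption, not an existing cfg field.

```python
from typing import Optional

_REQUEST_TIMEOUT_S = 600  # Fallback when cfg does not override it.


def resolve_request_timeout(cfg: Optional[dict]) -> float:
  """Read the request timeout from cfg, falling back to the default.

  The hypothetical "request_timeout_s" key lets users tune the timeout
  for larger models or longer prompts without editing code.
  """
  if cfg and "request_timeout_s" in cfg:
    return float(cfg["request_timeout_s"])
  return float(_REQUEST_TIMEOUT_S)
```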
```python
  return {"status": "ok"}


@app.post("/v1/completions")
async def completions(request: fastapi.Request):
```
🟡 Calling the synchronous llm.generate directly inside an async FastAPI endpoint blocks the event loop. While this might be acceptable for evaluation scripts where only one client is active, it prevents concurrent request handling within the server. Consider using anyio.to_thread.run_sync or transitioning to vLLM's AsyncLLMEngine for more efficient server operation.
```python
@app.post("/v1/completions")
async def completions(request: fastapi.Request):
    # ...
    # This blocks the FastAPI event loop:
    outputs = llm.generate(prompts, SamplingParams(**sp_kwargs))
```
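The thread off-load fix can be sketched with the stdlib equivalent `asyncio.to_thread` (in a FastAPI endpoint, `anyio.to_thread.run_sync` plays the same role). A stand-in function replaces `llm.generate` here so the sketch runs without vLLM.

```python
import asyncio
import time


def blocking_generate(prompt: str) -> str:
  # Stand-in for the synchronous llm.generate(...) call.
  time.sleep(0.05)
  return prompt.upper()


async def completions(prompt: str) -> str:
  # Off-load the blocking call to a worker thread so the event loop
  # stays free to serve other requests concurrently.
  return await asyncio.to_thread(blocking_generate, prompt)


async def serve_three() -> list[str]:
  # Three requests overlap instead of running back-to-back.
  return await asyncio.gather(*(completions(p) for p in ("a", "b", "c")))
```

Calling `blocking_generate` directly inside the coroutine would serialize the three requests; routing it through a worker thread lets them run concurrently.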
RissyRan left a comment:
Thanks for the PR! I’ve done a high-level pass, though I haven’t done a deep dive into the code just yet. Have you had a chance to run any benchmarks? I'm curious if you're seeing decent scores. Also, is multi-host benchmarking for large models within the scope of this PR?
Requires: `pip install evalchemy`

```bash
python -m maxtext.eval.runner.evalchemy_runner \
```
Have you tested whether this works at large scale, i.e., a workload on v5p-64?
| `--hf_token` | HuggingFace token for gated models |
| `--num_samples` | Limit number of eval samples |
| `--hf_mode` | Force HF safetensors mode (disables MaxTextForCausalLM mode) |
| `--tensor_parallel_size` | vLLM tensor parallelism |
Is this the only sharding strategy supported?
```bash
python -m maxtext.eval.runner.eval_runner ...
```

### Configuration (eval_runner)
Do you think it would be useful to group the common configs together, and then document only the ones specific to eval_runner, lm_eval_runner, and evalchemy_runner?
```python
logger = logging.getLogger(__name__)


# Maps MaxText benchmark names to lm-eval task names.
_TASK_MAP: dict[str, str] = {
```
Wondering if we could run lm-eval tasks directly instead of adding an extra mapping layer here? That way, any future benchmarks in lm-eval-harness could be used directly. A similar comment applies to the other frameworks, where applicable.
Description
Implement an evaluation framework with a vLLM backend. Requirements, design, and further details: go/eval-framework-vllm
Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned. This label is used for administrative purposes; please do not add it manually.
Notice 2: For external contributions, our settings currently require approval from a MaxText maintainer to trigger CI tests.
Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.