@@ -1,10 +1,12 @@
absl-py
aiohttp
aqtp
array-record
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
drjax
evaluate
flax
gcsfs
google-api-python-client
@@ -20,10 +22,12 @@ jsonlines
math-verify
ml-collections
ml-goodput-measurement
nltk
numpy
omegaconf
optax
orbax-checkpoint
pandas
pathwaysutils
pillow
pre-commit
216 changes: 216 additions & 0 deletions src/maxtext/eval/README.md
@@ -0,0 +1,216 @@
# MaxText Model Evaluation Framework

A vLLM-native evaluation framework for MaxText models.

## Quick Start

### eval_runner: With MaxText checkpoint

```bash
python -m maxtext.eval.runner.eval_runner \
    --config src/maxtext/eval/configs/mlperf.yml \
    --checkpoint_path gs://<bucket>/checkpoints/0/items \
    --model_name llama3.1-8b \
    --hf_path meta-llama/Llama-3.1-8B-Instruct \
    --base_output_directory gs://<bucket>/ \
    --run_name mlperf_eval_run \
    --hf_token $HF_TOKEN
```

### eval_runner: With HF model

Use `--hf_mode` with a public HF model to test the framework
without any MaxText checkpoint.

```bash
python -m maxtext.eval.runner.eval_runner \
    --config src/maxtext/eval/configs/mlperf.yml \
    --hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
    --model_name tinyllama \
    --base_output_directory /tmp/eval_test/ \
    --run_name smoke_test \
    --hf_mode \
    --num_samples 20 \
    --tensor_parallel_size 1
```

### lm_eval_runner

Uses lm-evaluation-harness with loglikelihood scoring.

Requires: `pip install "lm_eval[api]"`

```bash
python -m maxtext.eval.runner.lm_eval_runner \
    --checkpoint_path gs://<bucket>/checkpoints/0/items \
    --model_name llama3.1-8b \
    --hf_path meta-llama/Llama-3.1-8B-Instruct \
    --tasks mmlu gpqa \
    --base_output_directory gs://<bucket>/ \
    --run_name my_run \
    --max_model_len 8192 \
    --tensor_parallel_size 4 \
    --hf_token $HF_TOKEN
```

### evalchemy_runner

Uses [mlfoundations/evalchemy](https://github.com/mlfoundations/evalchemy), which
extends lm-evaluation-harness with chat-completions-based benchmarks. Importing
`evalchemy` registers its tasks; evaluation is then driven via `lm_eval.simple_evaluate`.

Requires: `pip install evalchemy`

```bash
python -m maxtext.eval.runner.evalchemy_runner \
    --checkpoint_path gs://<bucket>/checkpoints/0/items \
    --model_name llama3.1-8b \
    --hf_path meta-llama/Llama-3.1-8B-Instruct \
    --tasks ifeval math500 gpqa_diamond \
    --base_output_directory gs://<bucket>/ \
    --run_name my_run \
    --max_model_len 8192 \
    --tensor_parallel_size 4 \
    --hf_token $HF_TOKEN
```

## HuggingFace Token

Llama, Gemma, and other gated models require a HuggingFace token. You must
also have accepted the model license on huggingface.co.

In `MaxTextForCausalLM` mode, the token is needed only to download the
tokenizer, not the model weights.

The token is resolved in this order of preference:

1. `--hf_token` — forwarded to the server and tokenizer loading.
2. `HF_TOKEN` environment variable (picked up automatically if `--hf_token` is not set).

```bash
# Pass hf_token.
python -m maxtext.eval.runner.eval_runner ... --hf_token hf_...

# Or export env variable.
export HF_TOKEN=hf_...
python -m maxtext.eval.runner.eval_runner ...
```

### Configuration (eval_runner)

| Flag | Description |
|---|---|
| `--config` | Path to benchmark YAML |
| `--base_config` | Path to MaxText config |
| `--checkpoint_path` | MaxText Orbax checkpoint; enables MaxTextForCausalLM mode |
| `--hf_path` | HF model ID or tokenizer dir |
| `--model_name` | MaxText model name (e.g. `llama3.1-8b`) |
| `--base_output_directory` | GCS or local base directory for results |
| `--run_name` | Run name, used in results path |
| `--hf_token` | HuggingFace token for gated models |
| `--num_samples` | Limit number of eval samples |
| `--hf_mode` | Force HF safetensors mode (disables MaxTextForCausalLM mode) |
| `--tensor_parallel_size` | vLLM tensor parallelism |
| `--max_num_batched_tokens` | vLLM scheduler tokens per step |
| `--max_num_seqs` | vLLM max concurrent sequences (KV cache cap) |

### Configuration (lm_eval_runner)

| Flag | Description |
|---|---|
| `--checkpoint_path` | MaxText Orbax checkpoint; enables MaxTextForCausalLM mode |
| `--model_name` | MaxText model name |
| `--hf_path` | HF model ID for tokenizer |
| `--tasks` | Space-separated lm-eval task names (e.g. `mmlu gpqa`) |
| `--base_output_directory` | GCS or local base directory for results |
| `--run_name` | Run name |
| `--max_model_len` | vLLM max context length |
| `--tensor_parallel_size` | Number of chips |
| `--num_fewshot` | Few-shot examples per task (default: 0) |
| `--num_samples` | Limit samples per task (default: full dataset) |
| `--hf_token` | HuggingFace token for gated models |
| `--hf_mode` | Force HF safetensors mode |

### Configuration (evalchemy_runner)

| Flag | Description |
|---|---|
| `--checkpoint_path` | MaxText Orbax checkpoint; enables MaxTextForCausalLM mode |
| `--model_name` | MaxText model name |
| `--hf_path` | HF model ID for tokenizer |
| `--tasks` | Space-separated task names from the table above |
| `--base_output_directory` | GCS or local base directory for results |
| `--run_name` | Run name |
| `--max_model_len` | vLLM max context length |
| `--tensor_parallel_size` | Number of chips |
| `--num_fewshot` | Few-shot examples per task (default: 0) |
| `--num_samples` | Limit samples per task (default: full dataset) |
| `--hf_token` | HuggingFace token for gated models |
| `--hf_mode` | Force HF safetensors mode |

## Async Evaluation for RL Training

After an RL training run saves a checkpoint, you can evaluate it asynchronously
on a separate machine/job using `evalchemy_runner`.

**Prerequisites:**
- The checkpoint has been written to GCS
- evalchemy is installed (`pip show evalchemy` to check, `pip install evalchemy` to install)
- `HF_TOKEN` is exported (needed only for the tokenizer download)

**Supported math tasks:** `math500`, `aime24`, `aime25`, `amc23`, `gsm8k`

```bash
STEP=1000 # training step to evaluate
MODEL=qwen3-30b-a3b
HF_PATH=Qwen/Qwen3-30B-A3B
CHECKPOINT=gs://<bucket>/run/checkpoints/${STEP}/items
OUTPUT=gs://<bucket>/eval/

python -m maxtext.eval.runner.evalchemy_runner \
    --checkpoint_path ${CHECKPOINT} \
    --model_name ${MODEL} \
    --hf_path ${HF_PATH} \
    --tasks math500 aime24 gsm8k \
    --base_output_directory ${OUTPUT} \
    --run_name rl_${MODEL}_step${STEP} \
    --max_model_len 8192 \
    --tensor_parallel_size 8 \
    --hf_token $HF_TOKEN
```

Results are written to `${OUTPUT}/rl_${MODEL}_step${STEP}/eval_results/` as JSON,
and optionally uploaded to GCS via `--gcs_results_path`.

**Notes:**
- `--hf_path` is required since vLLM uses it to fetch the model architecture
config and tokenizer even when loading weights from the MaxText checkpoint.
- Do not run this on the same machine as an active training job; both use vLLM
  and will contend for TPU HBM.
- To limit dataset size, add `--num_samples 50`.
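
A training-side automation can generate one such invocation per saved step. The helper below sketches that wiring (the function name and task list are illustrative, not part of the framework; actually launching the command, e.g. via `subprocess.run`, is left to the caller):

```python
def build_eval_command(step: int, model: str, hf_path: str, bucket: str) -> list[str]:
  """Assemble the evalchemy_runner argv for one saved checkpoint step."""
  return [
      "python", "-m", "maxtext.eval.runner.evalchemy_runner",
      "--checkpoint_path", f"gs://{bucket}/run/checkpoints/{step}/items",
      "--model_name", model,
      "--hf_path", hf_path,
      "--tasks", "math500", "aime24", "gsm8k",
      "--base_output_directory", f"gs://{bucket}/eval/",
      "--run_name", f"rl_{model}_step{step}",
      "--max_model_len", "8192",
      "--tensor_parallel_size", "8",
  ]
```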

## Adding a New Benchmark

For custom datasets not covered by lm-eval or evalchemy:

1. Implement `BenchmarkDataset` in `src/maxtext/eval/datasets/`:

```python
from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

class MyDataset(BenchmarkDataset):
  name = "my_benchmark"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # Load the dataset, build prompts, and return a list of SampleRequest.
    ...
```

2. Register it in `src/maxtext/eval/datasets/registry.py`:

```python
from maxtext.eval.datasets.my_dataset import MyDataset
DATASET_REGISTRY["my_benchmark"] = MyDataset
```

3. Add a scorer in `src/maxtext/eval/scoring/` and register it in
`src/maxtext/eval/scoring/registry.py`.
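
For step 3, a minimal scorer might look like the sketch below. The class shape is an assumption mirroring the dataset side; check the actual base class in `src/maxtext/eval/scoring/` before copying it:

```python
class ExactMatchScorer:
  """Toy scorer: fraction of predictions that exactly match their reference."""

  name = "my_benchmark"

  def score(self, predictions: list[str], references: list[str]) -> dict[str, float]:
    # Compare stripped strings pairwise; an empty benchmark scores 0.0.
    if not references:
      return {"exact_match": 0.0}
    hits = sum(p.strip() == r.strip() for p, r in zip(predictions, references))
    return {"exact_match": hits / len(references)}
```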
8 changes: 8 additions & 0 deletions src/maxtext/eval/configs/base_eval.yml
@@ -0,0 +1,8 @@
# Base evaluation configuration.

temperature: 0.0
concurrency: 64
server_host: "localhost"
server_port: 8000
tensor_parallel_size: 4
num_samples: null
5 changes: 5 additions & 0 deletions src/maxtext/eval/configs/mlperf.yml
@@ -0,0 +1,5 @@
# MLPerf OpenOrca evaluation config.

benchmark: "mlperf_openorca"
max_tokens: 1024
num_samples: 5000
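
The intended layering is that benchmark values override the `base_eval.yml` defaults; conceptually it behaves like a dict merge (a sketch of the presumed semantics, not the actual config loader):

```python
base_cfg = {
    "temperature": 0.0,
    "concurrency": 64,
    "tensor_parallel_size": 4,
    "num_samples": None,
}
mlperf_cfg = {
    "benchmark": "mlperf_openorca",
    "max_tokens": 1024,
    "num_samples": 5000,
}
# Later dicts win, so mlperf.yml's num_samples overrides the base default.
cfg = {**base_cfg, **mlperf_cfg}
```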
57 changes: 57 additions & 0 deletions src/maxtext/eval/datasets/base.py
@@ -0,0 +1,57 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract base classes for benchmark datasets."""

from __future__ import annotations

import abc
from typing import NamedTuple


class SampleRequest(NamedTuple):
  """A single inference request with its ground-truth reference.

  Attributes:
    prompt: The full text prompt to send to the model (after chat templating).
    reference: Ground-truth answer/label used by the scorer.
    metadata: Optional dict of extra fields forwarded to the scorer
      (e.g. {"subject": "college_math"} for per-subject MMLU stats).
  """

  prompt: str
  reference: str
  metadata: dict | None = None


class BenchmarkDataset(abc.ABC):
  """Abstract base class for benchmark datasets."""

  name: str

  @abc.abstractmethod
  def sample_requests(
      self,
      num_samples: int | None,
      tokenizer,
  ) -> list[SampleRequest]:
    """Load the dataset and return a list of SampleRequests.

    Args:
      num_samples: If not None, truncate to this number of samples.
      tokenizer: A HuggingFace tokenizer used for chat templating.
        Implementations that do not require tokenization may ignore this
        parameter.

    Returns:
      List of SampleRequest objects ready for inference.
    """
63 changes: 63 additions & 0 deletions src/maxtext/eval/datasets/mlperf.py
@@ -0,0 +1,63 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLPerf OpenOrca summarisation dataset."""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

_SYSTEM_PROMPT = (
    "You are a helpful assistant. Summarize the following conversation."
)


class MlperfOpenOrcaDataset(BenchmarkDataset):
  """MLPerf OpenOrca — summarisation benchmark used in MLPerf Inference.

  Uses the Open-Orca/OpenOrca HuggingFace dataset.
  """

  name = "mlperf_openorca"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # pylint: disable=import-outside-toplevel
    import datasets as hf_datasets

    ds = hf_datasets.load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)

    requests = []
    for row in ds:
      if not row.get("response", "").strip():
        continue

      system_prompt = row.get("system_prompt", _SYSTEM_PROMPT) or _SYSTEM_PROMPT
      question = row["question"]
      reference = row["response"]

      if tokenizer is not None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      else:
        prompt = f"{system_prompt}\n\nUser: {question}\nAssistant:"

      requests.append(SampleRequest(prompt=prompt, reference=reference))

      if num_samples is not None and len(requests) >= num_samples:
        break

    return requests