Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/maxtext/common/gcloud_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@ def get_bucket(self, *a, **k): # pylint: disable=unused-argument
def bucket(self, *a, **k): # pylint: disable=unused-argument
return _StubBucket()

def list_blobs(self, *a, **k): # pylint: disable=unused-argument
return iter([])

return SimpleNamespace(Client=_StubClient, _IS_STUB=True)


Expand Down
21 changes: 16 additions & 5 deletions tests/integration/smoke/train_tokenizer_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import unittest
import pytest

from maxtext.common.gcloud_stub import is_decoupled
from maxtext.input_pipeline import input_pipeline_utils
from maxtext.trainers.tokenizer import train_tokenizer
from tests.utils.test_helpers import get_test_dataset_path


class TrainTokenizerFormatTest(unittest.TestCase):
Expand Down Expand Up @@ -49,17 +51,26 @@ def _run_format_test(self, file_pattern, file_type):

@pytest.mark.cpu_only
def test_parquet(self):
self._run_format_test("gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", "parquet")
path = os.path.join(get_test_dataset_path(), "hf", "c4", "c4-train-00000-of-01637.parquet")
self._run_format_test(path, "parquet")

@pytest.mark.cpu_only
def test_arrayrecord(self):
self._run_format_test(
"gs://maxtext-dataset/array-record/c4/en/3.0.1/c4-train.array_record-00000-of-01024", "arrayrecord"
)
dataset_root = get_test_dataset_path()
if is_decoupled():
path = os.path.join(dataset_root, "c4", "en", "3.0.1", "c4-train.array_record-00000-of-00008")
else:
path = os.path.join(dataset_root, "array-record", "c4", "en", "3.0.1", "c4-train.array_record-00000-of-01024")
self._run_format_test(path, "arrayrecord")

@pytest.mark.cpu_only
def test_tfrecord(self):
self._run_format_test("gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-00000-of-01024", "tfrecord")
dataset_root = get_test_dataset_path()
if is_decoupled():
path = os.path.join(dataset_root, "c4", "en", "3.0.1", "__local_c4_builder-train.tfrecord-00000-of-00008")
else:
path = os.path.join(dataset_root, "c4", "en", "3.0.1", "c4-train.tfrecord-00000-of-01024")
self._run_format_test(path, "tfrecord")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/unit/diloco_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def loss_fn(params, batch):
chex.assert_trees_all_equal(diloco_test_state.params, step_three_outer_params)

@pytest.mark.cpu_only
@pytest.mark.tpu_backend
def test_diloco_qwen3_moe_two_slices(self):
temp_dir = gettempdir()
compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_diloco_qwen3_moe.pickle")
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/grain_data_processing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,9 @@ def setUp(self):
grain_train_files = os.path.join(
dataset_root,
"c4",
"array-record",
"en",
"3.0.1",
"c4-train.array_record-00000-of-01024",
"c4-train.array_record-00000-of-00008",
)
base_output_directory = get_test_base_output_directory()
else:
Expand Down Expand Up @@ -384,7 +383,7 @@ def setUp(self):
"c4",
"en",
"3.0.1",
"c4-train.tfrecord-00000-of-01024",
"__local_c4_builder-train.tfrecord-00000-of-00008",
)
base_output_directory = get_test_base_output_directory()
else:
Expand Down Expand Up @@ -427,6 +426,7 @@ def setUp(self):
self.train_iter = grain_data_processing.make_grain_train_iterator(self.config, self.mesh, self.process_indices)


@pytest.mark.external_training
class GrainSFTParquetProcessingTest(unittest.TestCase):
"""Tests the SFT pipeline end-to-end using the real ultrachat_200k parquet dataset."""

Expand Down
Loading