class RollingManifestWriter:
    """Write manifest entries across one or more manifest files.

    As opposed to ManifestWriter, a rolling writer could produce multiple manifest files.
    Underlying writers are obtained lazily from ``supplier``. Before each entry is added,
    the current file is rolled (closed and replaced on demand) when it holds a non-zero
    multiple of ``_ROWS_DIVISOR`` rows and its ``tell()`` has reached
    ``target_file_size_in_bytes``.

    Intended to be used as a context manager; ``to_manifest_files()`` is only available
    once the ``with`` block has exited.
    """

    # The size of the current file is only inspected when its row count is a
    # multiple of this value, so ``tell()`` is not called for every entry.
    _ROWS_DIVISOR = 250

    def __init__(
        self,
        supplier: Callable[[], ManifestWriter],
        target_file_size_in_bytes: int,
    ) -> None:
        """Create a rolling writer.

        Args:
            supplier: Zero-argument callable returning a fresh, not yet entered
                ManifestWriter each time a new manifest file is needed.
            target_file_size_in_bytes: Size threshold compared against the current
                writer's ``tell()`` when deciding whether to roll.
        """
        self._supplier = supplier
        self._target_file_size_in_bytes = target_file_size_in_bytes
        self._manifest_files: list[ManifestFile] = []
        self._current_writer: ManifestWriter | None = None
        self._current_file_rows: int = 0
        self._closed: bool = False

    def __enter__(self) -> RollingManifestWriter:
        """Open the rolling manifest writer."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the rolling manifest writer and finalize all manifests."""
        try:
            self._close_current_writer(exc_type, exc_value, traceback)
        finally:
            # Mark as closed even when finalizing the last file fails, so no
            # further entries can be added afterwards.
            self._closed = True

    def _ensure_open(self) -> None:
        """Raise RuntimeError when the rolling writer has already been closed."""
        if self._closed:
            raise RuntimeError("Cannot add entry to closed manifest writer")

    def _get_current_writer(self) -> ManifestWriter:
        """Return the writer for the next entry, rolling or opening a file if needed."""
        if self._should_roll_to_new_file():
            self._close_current_writer()
        # Explicit None check: do not depend on the truthiness of the writer object.
        if self._current_writer is None:
            self._current_writer = self._supplier()
            self._current_writer.__enter__()
        return self._current_writer

    def _should_roll_to_new_file(self) -> bool:
        """Tell whether the current file must be closed before the next entry."""
        if self._current_writer is None or self._current_file_rows == 0:
            return False
        if self._current_file_rows % self._ROWS_DIVISOR != 0:
            return False
        return self._current_writer.tell() >= self._target_file_size_in_bytes

    def _close_current_writer(
        self,
        exc_type: type[BaseException] | None = None,
        exc_value: BaseException | None = None,
        traceback: TracebackType | None = None,
    ) -> None:
        """Close the current writer, if any, and record its manifest file.

        The exception triple is forwarded to the underlying writer's ``__exit__``
        when the file contains at least one row. The writer reference and row
        counter are reset even if closing raises, so a failed close is never
        retried on the same writer by a later roll or by ``__exit__``.
        """
        writer = self._current_writer
        if writer is None:
            return
        try:
            if self._current_file_rows > 0:
                writer.__exit__(exc_type, exc_value, traceback)
                self._manifest_files.append(writer.to_manifest_file())
            else:
                # A writer was opened but no entry was recorded for it, so no
                # manifest is registered.
                # NOTE(review): presumably ManifestWriter.__exit__ raises ValueError
                # for an empty manifest, which is why it is ignored here - confirm.
                try:
                    writer.__exit__(None, None, None)
                except ValueError:
                    pass
        finally:
            self._current_writer = None
            self._current_file_rows = 0

    def add_entry(self, entry: ManifestEntry) -> RollingManifestWriter:
        """Forward ``entry`` to the current writer's ``add_entry`` and count the row.

        Raises:
            RuntimeError: If the rolling writer has already been closed.
        """
        self._ensure_open()
        self._get_current_writer().add_entry(entry)
        self._current_file_rows += 1
        return self

    def add(self, entry: ManifestEntry) -> RollingManifestWriter:
        """Forward ``entry`` to the current writer's ``add`` and count the row.

        Raises:
            RuntimeError: If the rolling writer has already been closed.
        """
        self._ensure_open()
        self._get_current_writer().add(entry)
        self._current_file_rows += 1
        return self

    def to_manifest_files(self) -> list[ManifestFile]:
        """Return the manifest files produced so far, in the order they were closed.

        Raises:
            RuntimeError: If the rolling writer has not been closed yet.
        """
        if not self._closed:
            raise RuntimeError("Cannot create manifest files from unclosed writer")
        return self._manifest_files
@pytest.mark.parametrize("format_version", [1, 2])
def test_rolling_manifest_writer_stays_in_one_file_under_target(format_version: TableVersion) -> None:
    """100 entries stay below the 250-row check interval, so one manifest is produced."""
    with TemporaryDirectory() as tmpdir:
        schema = Schema(NestedField(1, "id", IntegerType(), required=True))
        writer_supplier = _create_manifest_writer_supplier(tmpdir, format_version, schema)
        manifest_entries = [_create_simple_entry(index) for index in range(100)]

        with RollingManifestWriter(supplier=writer_supplier, target_file_size_in_bytes=10000) as writer:
            for manifest_entry in manifest_entries:
                writer.add_entry(manifest_entry)

        produced = writer.to_manifest_files()
        assert len(produced) == 1


@pytest.mark.parametrize("format_version", [1, 2])
def test_rolling_manifest_writer_splits_when_over_target(format_version: TableVersion) -> None:
    """A tiny target size makes the writer roll as soon as a size check happens."""
    with TemporaryDirectory() as tmpdir:
        schema = Schema(NestedField(1, "id", IntegerType(), required=True))
        writer_supplier = _create_manifest_writer_supplier(tmpdir, format_version, schema)
        manifest_entries = [_create_simple_entry(index) for index in range(500)]

        with RollingManifestWriter(supplier=writer_supplier, target_file_size_in_bytes=1) as writer:
            for manifest_entry in manifest_entries:
                writer.add_entry(manifest_entry)

        manifest_files = writer.to_manifest_files()
        # The size is only checked when the current file holds a multiple of 250 rows.
        # With target=1 the writer rolls once, when entry 251 arrives; the second file
        # (entries 251-500) is finalized when the rolling writer is closed.
        assert len(manifest_files) == 2

        # Once the context manager has exited, further writes are rejected.
        with pytest.raises(RuntimeError, match="Cannot add entry to closed"):
            writer.add_entry(manifest_entries[0])


@pytest.mark.parametrize("format_version", [1, 2])
def test_rolling_manifest_writer_empty(format_version: TableVersion) -> None:
    """Closing a rolling writer that never received an entry yields no manifests."""
    with TemporaryDirectory() as tmpdir:
        schema = Schema(NestedField(1, "id", IntegerType(), required=True))
        writer_supplier = _create_manifest_writer_supplier(tmpdir, format_version, schema)

        with RollingManifestWriter(supplier=writer_supplier, target_file_size_in_bytes=42) as writer:
            pass

        assert writer.to_manifest_files() == []


def _create_manifest_writer_supplier(
    tmpdir: str,
    format_version: TableVersion,
    schema: Schema,
    snapshot_id: int = 1,
) -> Callable[[], ManifestWriter]:
    """Build a supplier that opens a new, sequentially numbered manifest under ``tmpdir``."""
    next_index = 0
    file_io = PyArrowFileIO()

    def _supplier() -> ManifestWriter:
        nonlocal next_index
        output_file = file_io.new_output(f"{tmpdir}/manifest-{next_index}.avro")
        # Advance only after the output file was created, so a failure does not skip a number.
        next_index += 1
        return write_manifest(
            format_version=format_version,
            spec=UNPARTITIONED_PARTITION_SPEC,
            schema=schema,
            output_file=output_file,
            snapshot_id=snapshot_id,
            avro_compression="null",
        )

    return _supplier


def _create_simple_entry(
    i: int,
    status: ManifestEntryStatus = ManifestEntryStatus.ADDED,
    sequence_number: int | None = 1,
) -> ManifestEntry:
    """Create a manifest entry for a one-record parquet data file named after ``i``."""
    return ManifestEntry.from_args(
        status=status,
        snapshot_id=1,
        sequence_number=sequence_number,
        data_sequence_number=1,
        file_sequence_number=1,
        data_file=DataFile.from_args(
            content=DataFileContent.DATA,
            file_path=f"data-{i}.parquet",
            file_format=FileFormat.PARQUET,
            partition=Record(),
            record_count=1,
            file_size_in_bytes=1000,
        ),
    )