Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def pending_handler(self) -> PendingResponse:
"""
return PendingResponse(count=0)

async def partitions_handler(self) -> PartitionsResponse:
async def active_partitions_handler(self) -> PartitionsResponse:
"""
The simple source always returns default partitions.
"""
Expand Down
4 changes: 3 additions & 1 deletion packages/pynumaflow/pynumaflow/proto/sourcer/source.proto
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,10 @@ message PendingResponse {
*/
message PartitionsResponse {
message Result {
// Required field holding the list of partitions.
// Required field holding the list of active partitions.
repeated int32 partitions = 1;
// Total number of partitions in the source.
optional int32 total_partitions = 2;
}
// Required field holding the result.
Result result = 1;
Expand Down
18 changes: 9 additions & 9 deletions packages/pynumaflow/pynumaflow/proto/sourcer/source_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions packages/pynumaflow/pynumaflow/proto/sourcer/source_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ class PendingResponse(_message.Message):
class PartitionsResponse(_message.Message):
__slots__ = ("result",)
class Result(_message.Message):
__slots__ = ("partitions",)
__slots__ = ("partitions", "total_partitions")
PARTITIONS_FIELD_NUMBER: _ClassVar[int]
TOTAL_PARTITIONS_FIELD_NUMBER: _ClassVar[int]
partitions: _containers.RepeatedScalarFieldContainer[int]
def __init__(self, partitions: _Optional[_Iterable[int]] = ...) -> None: ...
total_partitions: int
def __init__(self, partitions: _Optional[_Iterable[int]] = ..., total_partitions: _Optional[int] = ...) -> None: ...
RESULT_FIELD_NUMBER: _ClassVar[int]
result: PartitionsResponse.Result
def __init__(self, result: _Optional[_Union[PartitionsResponse.Result, _Mapping]] = ...) -> None: ...
Expand Down
21 changes: 15 additions & 6 deletions packages/pynumaflow/pynumaflow/sourcer/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,10 @@ def count(self) -> int:
class PartitionsResponse:
"""
PartitionsResponse is the response for the partition request.
It indicates the number of partitions at the user defined source.
A negative count indicates that the partition information is not available.
It indicates the active partitions at the user defined source.

Args:
count: the number of partitions.
partitions: the list of active partitions.
"""

_partitions: list[int]
Expand All @@ -256,7 +255,7 @@ def __init__(self, partitions: list[int]):

@property
def partitions(self) -> list[int]:
"""Returns the list of partitions"""
"""Returns the list of active partitions"""
return self._partitions


Expand Down Expand Up @@ -298,12 +297,22 @@ async def pending_handler(self) -> PendingResponse:
pass

@abstractmethod
async def partitions_handler(self) -> PartitionsResponse:
async def active_partitions_handler(self) -> PartitionsResponse:
"""
The simple source always returns zero to indicate there is no pending record.
Returns the active partitions associated with the source, used by the platform
to determine the partitions to which the watermark should be published.
"""
pass

async def total_partitions_handler(self) -> int | None:
"""
Returns the total number of partitions in the source.
Used by the platform for watermark progression to know when all
processors have reported in.
Returns None by default, indicating the source does not report total partitions.
"""
return None


# Create default partition id from the environment variable "NUMAFLOW_REPLICA"
DefaultPartitionId = int(os.getenv("NUMAFLOW_REPLICA", "0"))
Expand Down
2 changes: 1 addition & 1 deletion packages/pynumaflow/pynumaflow/sourcer/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def pending_handler(self) -> PendingResponse:
'''
return PendingResponse(count=0)

async def partitions_handler(self) -> PartitionsResponse:
async def active_partitions_handler(self) -> PartitionsResponse:
'''
The simple source always returns default partitions.
'''
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def __initialize_handlers(self):
self.__source_ack_handler = self.source_handler.ack_handler
self.__source_nack_handler = self.source_handler.nack_handler
self.__source_pending_handler = self.source_handler.pending_handler
self.__source_partitions_handler = self.source_handler.partitions_handler
self.__source_active_partitions_handler = self.source_handler.active_partitions_handler
self.__source_total_partitions_handler = self.source_handler.total_partitions_handler

async def ReadFn(
self,
Expand Down Expand Up @@ -278,10 +279,11 @@ async def PartitionsFn(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> source_pb2.PartitionsResponse:
"""
PartitionsFn returns the partitions of the user defined source.
PartitionsFn returns the active partitions and total partitions of the user defined source.
"""
try:
partitions = await self.__source_partitions_handler()
partitions = await self.__source_active_partitions_handler()
total_partitions = await self.__source_total_partitions_handler()
except asyncio.CancelledError:
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
_LOGGER.info("Server shutting down, cancelling RPC.")
Expand All @@ -301,8 +303,10 @@ async def PartitionsFn(
return source_pb2.PartitionsResponse(
result=source_pb2.PartitionsResponse.Result(partitions=[])
)
resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions)
return source_pb2.PartitionsResponse(result=resp)
result = source_pb2.PartitionsResponse.Result(
partitions=partitions.partitions, total_partitions=total_partitions
)
return source_pb2.PartitionsResponse(result=result)

def clean_background(self, task):
"""
Expand Down
61 changes: 61 additions & 0 deletions packages/pynumaflow/tests/source/test_async_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ack_req_source_fn,
mock_partitions,
AsyncSource,
AsyncSourceWithTotalPartitions,
mock_offset,
nack_req_source_fn,
)
Expand Down Expand Up @@ -194,6 +195,66 @@ def test_partitions(async_source_server) -> None:
assert response.result.partitions == mock_partitions()


def test_partitions_default_total_partitions_is_none(async_source_server) -> None:
    """
    A source that does not override total_partitions_handler must leave the
    optional total_partitions field unset in the PartitionsFn response.
    """
    with grpc.insecure_channel(server_port) as channel:
        client = source_pb2_grpc.SourceStub(channel)
        reply = client.PartitionsFn(request=_empty_pb2.Empty())

    partitions_result = reply.result
    assert partitions_result.partitions == mock_partitions()
    # Presence check rather than value check: an unset optional int32 reads as 0.
    assert not partitions_result.HasField("total_partitions")


server_port_tp = "unix:///tmp/async_source_tp.sock"


def NewAsyncSourcerWithTotalPartitions():
    """Return the servicer of a SourceAsyncServer wrapping AsyncSourceWithTotalPartitions."""
    return SourceAsyncServer(sourcer_instance=AsyncSourceWithTotalPartitions()).servicer


async def start_server_tp(udfs):
    """
    Start an asyncio gRPC server that serves ``udfs`` on ``server_port_tp``.

    Returns:
        A ``(server, address)`` tuple: the started gRPC server and the
        address it was bound to.
    """
    address = server_port_tp
    grpc_server = grpc.aio.server()
    source_pb2_grpc.add_SourceServicer_to_server(udfs, grpc_server)
    grpc_server.add_insecure_port(address)
    logging.info("Starting server on %s", address)
    await grpc_server.start()
    return grpc_server, address


@pytest.fixture(scope="module")
def async_source_server_with_total_partitions():
    """
    Module-scoped fixture for the total-partitions source server.

    Creates a loop via create_async_loop(), starts the server built from
    NewAsyncSourcerWithTotalPartitions() through start_async_server(), yields
    the loop to the tests, and hands both to teardown_async_server() afterwards.
    """
    event_loop = create_async_loop()
    running_server = start_async_server(
        event_loop,
        start_server_tp(NewAsyncSourcerWithTotalPartitions()),
    )

    yield event_loop

    teardown_async_server(event_loop, running_server)


def test_partitions_with_total_partitions(async_source_server_with_total_partitions) -> None:
    """
    A source that implements total_partitions_handler must have its value
    carried through the PartitionsFn gRPC response.
    """
    with grpc.insecure_channel(server_port_tp) as channel:
        client = source_pb2_grpc.SourceStub(channel)
        reply = client.PartitionsFn(request=_empty_pb2.Empty())

    partitions_result = reply.result
    assert partitions_result.partitions == mock_partitions()
    assert partitions_result.total_partitions == 10


@pytest.mark.parametrize(
"max_threads_arg,expected",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def nack_handler(self, nack_request: NackRequest):
async def pending_handler(self) -> PendingResponse:
return PendingResponse(count=0)

async def partitions_handler(self) -> PartitionsResponse:
async def active_partitions_handler(self) -> PartitionsResponse:
return PartitionsResponse(partitions=[])


Expand Down Expand Up @@ -194,7 +194,7 @@ async def _run():
async def _cancelled_partitions():
raise asyncio.CancelledError()

handler.partitions_handler = _cancelled_partitions
handler.active_partitions_handler = _cancelled_partitions

servicer = AsyncSourceServicer(source_handler=handler)
shutdown_event = asyncio.Event()
Expand Down
38 changes: 36 additions & 2 deletions packages/pynumaflow/tests/source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,44 @@ async def nack_handler(self, nack_request: NackRequest):
async def pending_handler(self) -> PendingResponse:
return PendingResponse(count=10)

async def partitions_handler(self) -> PartitionsResponse:
async def active_partitions_handler(self) -> PartitionsResponse:
return PartitionsResponse(partitions=mock_partitions())


class AsyncSourceWithTotalPartitions(Sourcer):
    """Test source that reports both its active partitions and a total partition count."""

    async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator):
        """Put ten mock messages with identical fields onto ``output``; ``datum`` is not read."""
        # Field values are computed once and shared by every emitted message.
        message_fields = {
            "payload": b"payload:test_mock_message",
            "keys": ["test_key"],
            "offset": mock_offset(),
            "event_time": mock_event_time(),
        }
        for _ in range(10):
            await output.put(Message(**message_fields))

    async def ack_handler(self, ack_request: AckRequest):
        """Accept the ack request and do nothing."""
        return None

    async def nack_handler(self, nack_request: NackRequest):
        """Accept the nack request and do nothing."""
        return None

    async def pending_handler(self) -> PendingResponse:
        """Always report ten pending records."""
        return PendingResponse(count=10)

    async def active_partitions_handler(self) -> PartitionsResponse:
        """Report the partitions returned by mock_partitions() as active."""
        return PartitionsResponse(partitions=mock_partitions())

    async def total_partitions_handler(self) -> int | None:
        """Report a fixed total of ten partitions."""
        return 10


def read_req_source_fn() -> ReadRequest:
request = source_pb2.ReadRequest.Request(
num_records=10,
Expand Down Expand Up @@ -102,5 +136,5 @@ async def nack_handler(self, nack_request: NackRequest):
async def pending_handler(self) -> PendingResponse:
raise RuntimeError("Got a runtime error from pending handler.")

async def partitions_handler(self) -> PartitionsResponse:
async def active_partitions_handler(self) -> PartitionsResponse:
raise RuntimeError("Got a runtime error from partition handler.")
Loading
Loading