diff --git a/examples/scripts/code_runner.py b/examples/scripts/code_runner.py index b4d93c8a..150412a2 100644 --- a/examples/scripts/code_runner.py +++ b/examples/scripts/code_runner.py @@ -864,10 +864,6 @@ def _compile_one_kernel(kernel): logger.debug(f"Tensor order: {list(tensors.keys())}") logger.debug(f"func_args count: {len(func_args)}") - # Create and initialize runtime (including kernel registration) - logger.info("=== Initializing Runtime ===") - runtime = Runtime() - # Build environment for runtime initialization run_env = _kernel_config_runtime_env(self._kernel_config, self.kernels_dir) if run_env: @@ -883,6 +879,14 @@ def _compile_one_kernel(kernel): initial_outputs = {k: v.clone() for k, v in outputs.items()} + logger.info("=== Initializing Runtime ===") + runtime = Runtime() + + # Enable profiling if requested (only first round) + if self.enable_profiling: + runtime.enable_profiling(True) + logger.info("Profiling enabled") + for round_idx in range(self.repeat_rounds): if self.repeat_rounds > 1: logger.info(f"--- Round {round_idx + 1}/{self.repeat_rounds} ---") @@ -890,13 +894,6 @@ def _compile_one_kernel(kernel): for k, v in initial_outputs.items(): outputs[k].copy_(v) - runtime = Runtime() - - # Enable profiling if requested (only first round) - if self.enable_profiling and round_idx == 0: - runtime.enable_profiling(True) - logger.info("Profiling enabled") - with _temporary_env(run_env): runtime.initialize( orch_so_binary, @@ -917,7 +914,10 @@ def _compile_one_kernel(kernel): orch_thread_num=self.orch_thread_num, ) - runtime.finalize() + if round_idx < self.repeat_rounds - 1: + runtime.finalize_round() + else: + runtime.finalize() self._compare_with_golden(outputs, golden) logger.info(f"=== Case {case_idx + 1}/{total_cases} Passed ===") diff --git a/python/bindings.py b/python/bindings.py index 6474f35d..5dfb85f9 100644 --- a/python/bindings.py +++ b/python/bindings.py @@ -141,6 +141,20 @@ def _setup_functions(self): 
self.lib.finalize_runtime.argtypes = [c_void_p] self.lib.finalize_runtime.restype = c_int + # reinit_runtime - lightweight re-init for subsequent rounds + self.lib.reinit_runtime.argtypes = [ + c_void_p, # runtime + POINTER(c_uint64), # func_args + c_int, # func_args_count + POINTER(c_int), # arg_types + POINTER(c_uint64), # arg_sizes + ] + self.lib.reinit_runtime.restype = c_int + + # finalize_runtime_round - copy results back without freeing resources + self.lib.finalize_runtime_round.argtypes = [c_void_p] + self.lib.finalize_runtime_round.restype = c_int + # Note: register_kernel has been internalized into init_runtime # Kernel binaries are now passed directly to init_runtime() @@ -209,6 +223,7 @@ def __init__(self, lib: CDLL): size = lib.get_runtime_size() self._buffer = ctypes.create_string_buffer(size) self._handle = ctypes.cast(self._buffer, c_void_p) + self._initialized = False def initialize( self, @@ -248,6 +263,20 @@ def initialize( func_args = func_args or [] func_args_count = len(func_args) + # If already initialized, delegate to lightweight reinit + if self._initialized: + rc = self.lib.reinit_runtime( + self._handle, + (c_uint64 * len(func_args))(*func_args) if func_args else None, + func_args_count, + (c_int * len(arg_types))(*arg_types) if arg_types else None, + (c_uint64 * len(arg_sizes))(*arg_sizes) if arg_sizes else None, + ) + if rc == 0: + return + # Reinit not supported by this runtime, fallback to full finalize + init + self.finalize() + # Convert func_args to ctypes array if func_args_count > 0: func_args_array = (c_uint64 * func_args_count)(*func_args) @@ -310,6 +339,7 @@ def initialize( ) if rc != 0: raise RuntimeError(f"init_runtime failed: {rc}") + self._initialized = True def finalize(self) -> None: """ @@ -322,10 +352,76 @@ def finalize(self) -> None: Raises: RuntimeError: If finalization fails """ + if not self._initialized: + return rc = self.lib.finalize_runtime(self._handle) if rc != 0: raise RuntimeError(f"finalize_runtime 
failed: {rc}") + self._initialized = False + + def reinitialize( + self, + func_args: Optional[List[int]] = None, + arg_types: Optional[List[int]] = None, + arg_sizes: Optional[List[int]] = None, + ) -> None: + """ + Lightweight re-initialization for subsequent rounds within the same case. + + Skips kernel upload, GM heap/shared memory allocation, and orch SO copy. + Only re-copies input/inout tensor data to existing device memory. + + Args: + func_args: Arguments for orchestration (host pointers, sizes, etc.) + arg_types: Array describing each argument's type + arg_sizes: Array of sizes for pointer arguments (0 for scalars) + + Raises: + RuntimeError: If re-initialization fails + """ + func_args = func_args or [] + func_args_count = len(func_args) + + if func_args_count > 0: + func_args_array = (c_uint64 * func_args_count)(*func_args) + else: + func_args_array = None + + if arg_types is not None and len(arg_types) > 0: + arg_types_array = (c_int * len(arg_types))(*arg_types) + else: + arg_types_array = None + + if arg_sizes is not None and len(arg_sizes) > 0: + arg_sizes_array = (c_uint64 * len(arg_sizes))(*arg_sizes) + else: + arg_sizes_array = None + + rc = self.lib.reinit_runtime( + self._handle, + func_args_array, + func_args_count, + arg_types_array, + arg_sizes_array, + ) + if rc != 0: + raise RuntimeError(f"reinit_runtime failed: {rc}") + + def finalize_round(self) -> None: + """ + Round-level finalize: copy results back but keep device resources alive. + + Copies output/inout tensors from device to host without freeing + device memory or kernel binaries. Use between rounds in the same case. 
+ + Raises: + RuntimeError: If round finalization fails + """ + rc = self.lib.finalize_runtime_round(self._handle) + if rc != 0: + # Not supported by this runtime, fallback to full finalize + self.finalize() def enable_profiling(self, enabled: bool = True) -> None: """ diff --git a/src/a2a3/platform/include/host/pto_runtime_c_api.h b/src/a2a3/platform/include/host/pto_runtime_c_api.h index bafca1b4..4792e4ef 100644 --- a/src/a2a3/platform/include/host/pto_runtime_c_api.h +++ b/src/a2a3/platform/include/host/pto_runtime_c_api.h @@ -173,6 +173,41 @@ int launch_runtime(RuntimeHandle runtime, size_t aicore_size, int orch_thread_num); +/** + * Lightweight re-initialization for subsequent rounds within the same case. + * + * Skips kernel upload, GM heap allocation, shared memory allocation, and + * orchestration SO copy. Only re-copies input/inout tensor data to device + * using the existing device memory allocations from the first round. + * + * Must be called after a successful init_runtime() + finalize_runtime_round() + * sequence. The Runtime handle must not have been fully finalized. + * + * @param runtime Runtime handle (previously initialized) + * @param func_args Arguments for orchestration (host pointers, sizes, etc.) + * @param func_args_count Number of arguments + * @param arg_types Array describing each argument's type (ArgType enum) + * @param arg_sizes Array of sizes for pointer arguments (0 for scalars) + * @return 0 on success, -1 on failure + */ +int reinit_runtime(RuntimeHandle runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes); + +/** + * Round-level finalize: copy results back but keep device resources alive. + * + * Copies output/inout tensors from device to host, but does NOT free + * device memory, kernel binaries, or call the Runtime destructor. + * Use this between rounds within the same case. 
+ * + * @param runtime Runtime handle to finalize for this round + * @return 0 on success, -1 on failure + */ +int finalize_runtime_round(RuntimeHandle runtime); + /** * Finalize and cleanup a runtime instance. * diff --git a/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp b/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp index 95448c9a..d099d324 100644 --- a/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp +++ b/src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp @@ -31,6 +31,12 @@ int init_runtime_impl(Runtime* runtime, const size_t* kernel_sizes, int kernel_count); int validate_runtime_impl(Runtime* runtime); +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes); +int validate_runtime_round_impl(Runtime* runtime); /* Forward declarations for device memory functions used in init_runtime */ void* device_malloc(size_t size); @@ -199,6 +205,34 @@ int launch_runtime(RuntimeHandle runtime, } } +int reinit_runtime(RuntimeHandle runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return reinit_runtime_impl(r, func_args, func_args_count, arg_types, arg_sizes); + } catch (...) { + return -1; + } +} + +int finalize_runtime_round(RuntimeHandle runtime) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return validate_runtime_round_impl(r); + } catch (...) 
{ + return -1; + } +} + int finalize_runtime(RuntimeHandle runtime) { if (runtime == NULL) { return -1; diff --git a/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp b/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp index 07c59c40..e5c6d41e 100644 --- a/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp +++ b/src/a2a3/platform/sim/host/pto_runtime_c_api.cpp @@ -34,6 +34,12 @@ int init_runtime_impl(Runtime* runtime, const size_t* kernel_sizes, int kernel_count); int validate_runtime_impl(Runtime* runtime); +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes); +int validate_runtime_round_impl(Runtime* runtime); /* Forward declarations */ void* device_malloc(size_t size); @@ -202,6 +208,34 @@ int launch_runtime(RuntimeHandle runtime, } } +int reinit_runtime(RuntimeHandle runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return reinit_runtime_impl(r, func_args, func_args_count, arg_types, arg_sizes); + } catch (...) { + return -1; + } +} + +int finalize_runtime_round(RuntimeHandle runtime) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return validate_runtime_round_impl(r); + } catch (...) 
{ + return -1; + } +} + int finalize_runtime(RuntimeHandle runtime) { if (runtime == NULL) { return -1; diff --git a/src/a2a3/runtime/aicpu_build_graph/host/runtime_maker.cpp b/src/a2a3/runtime/aicpu_build_graph/host/runtime_maker.cpp index e1fb639d..d39a67ba 100644 --- a/src/a2a3/runtime/aicpu_build_graph/host/runtime_maker.cpp +++ b/src/a2a3/runtime/aicpu_build_graph/host/runtime_maker.cpp @@ -331,6 +331,25 @@ int validate_runtime_impl(Runtime* runtime) { return rc; } +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + (void)func_args; + (void)func_args_count; + (void)arg_types; + (void)arg_sizes; + std::cerr << "Error: reinit_runtime_impl not supported for aicpu_build_graph runtime\n"; + return -1; +} + +int validate_runtime_round_impl(Runtime* runtime) { + (void)runtime; + std::cerr << "Error: validate_runtime_round_impl not supported for aicpu_build_graph runtime\n"; + return -1; +} + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/src/a2a3/runtime/host_build_graph/host/runtime_maker.cpp b/src/a2a3/runtime/host_build_graph/host/runtime_maker.cpp index 91db0ca0..69dd0f15 100644 --- a/src/a2a3/runtime/host_build_graph/host/runtime_maker.cpp +++ b/src/a2a3/runtime/host_build_graph/host/runtime_maker.cpp @@ -228,6 +228,25 @@ int validate_runtime_impl(Runtime *runtime) { return rc; } +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + (void)func_args; + (void)func_args_count; + (void)arg_types; + (void)arg_sizes; + LOG_ERROR("reinit_runtime_impl not supported for host_build_graph runtime"); + return -1; +} + +int validate_runtime_round_impl(Runtime* runtime) { + (void)runtime; + LOG_ERROR("validate_runtime_round_impl not supported for host_build_graph runtime"); + return -1; +} + #ifdef __cplusplus } /* extern "C" */ #endif diff --git 
a/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp b/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp index 3b3c5e33..ef7f13bf 100644 --- a/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp +++ b/src/a2a3/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp @@ -327,6 +327,118 @@ extern "C" int init_runtime_impl(Runtime *runtime, return 0; } +/** + * Lightweight re-initialization for subsequent rounds within the same case. + * + * Skips kernel upload, GM heap/shared memory allocation, and orch SO copy. + * Only re-copies input/inout tensor data to existing device memory. + * + * @param runtime Pointer to previously initialized Runtime + * @param func_args Arguments for orchestration (host pointers, sizes, etc.) + * @param func_args_count Number of arguments + * @param arg_types Array describing each argument's type (ArgType enum) + * @param arg_sizes Array of sizes for pointer arguments (0 for scalars) + * @return 0 on success, -1 on failure + */ +extern "C" int reinit_runtime_impl(Runtime *runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + if (runtime == nullptr) { + LOG_ERROR("Runtime pointer is null"); + return -1; + } + + TensorPair* tensor_pairs = runtime->get_tensor_pairs(); + int pair_idx = 0; + + for (int i = 0; i < func_args_count; i++) { + if (arg_types[i] == ARG_SCALAR) continue; + + if (pair_idx >= runtime->get_tensor_pair_count()) { + LOG_ERROR("reinit: tensor_pair index out of range at arg %d", i); + return -1; + } + + const TensorPair& pair = tensor_pairs[pair_idx]; + void* host_ptr = reinterpret_cast<void*>(func_args[i]); + size_t size = arg_sizes[i]; + + if (arg_types[i] == ARG_INPUT_PTR || arg_types[i] == ARG_INOUT_PTR) { + int rc = runtime->host_api.copy_to_device(pair.dev_ptr, host_ptr, size); + if (rc != 0) { + LOG_ERROR("reinit: failed to re-copy arg %d to device", i); + return -1; + } + LOG_INFO("reinit: arg %d re-copied %zu bytes", i, size); + } + 
pair_idx++; + } + + LOG_INFO("reinit complete: re-copied %d tensor args", pair_idx); + return 0; +} + +/** + * Round-level validate: copy results back but keep device resources alive. + * + * Same copy-back logic as validate_runtime_impl, but does NOT free + * device memory, kernel binaries, or clear tensor pairs. + * + * @param runtime Pointer to Runtime + * @return 0 on success, -1 on failure + */ +extern "C" int validate_runtime_round_impl(Runtime *runtime) { + if (runtime == nullptr) { + LOG_ERROR("Runtime pointer is null"); + return -1; + } + + int rc = 0; + LOG_INFO("=== Round Finalize: Copying Results Back ==="); + + TensorPair* tensor_pairs = runtime->get_tensor_pairs(); + int tensor_pair_count = runtime->get_tensor_pair_count(); + + void* pto2_sm = runtime->get_pto2_gm_sm_ptr(); + uint64_t graph_out_ptr = 0; + uint64_t graph_out_size = 0; + + if (pto2_sm != nullptr) { + PTO2SharedMemoryHeader host_header; + int hdr_rc = runtime->host_api.copy_from_device(&host_header, pto2_sm, sizeof(PTO2SharedMemoryHeader)); + if (hdr_rc == 0) { + graph_out_ptr = host_header.graph_output_ptr; + graph_out_size = host_header.graph_output_size; + } + } + + bool first_output_tensor = true; + for (int i = 0; i < tensor_pair_count; i++) { + const TensorPair& pair = tensor_pairs[i]; + if (pair.dev_ptr == nullptr || pair.host_ptr == nullptr) continue; + + void* src_ptr = pair.dev_ptr; + size_t copy_size = pair.size; + + if (first_output_tensor && graph_out_ptr != 0 && graph_out_size > 0) { + src_ptr = reinterpret_cast<void*>(static_cast<uintptr_t>(graph_out_ptr)); + copy_size = static_cast<size_t>(graph_out_size); + first_output_tensor = false; + } + + int copy_rc = runtime->host_api.copy_from_device(pair.host_ptr, src_ptr, copy_size); + if (copy_rc != 0) { + LOG_ERROR("Round finalize: failed to copy tensor %d from device", i); + rc = copy_rc; + } + } + + LOG_INFO("=== Round Finalize Complete ==="); + return rc; +} + /** * Validate runtime results and cleanup. 
* diff --git a/src/a5/platform/include/host/pto_runtime_c_api.h b/src/a5/platform/include/host/pto_runtime_c_api.h index c33b41a2..7b70861a 100644 --- a/src/a5/platform/include/host/pto_runtime_c_api.h +++ b/src/a5/platform/include/host/pto_runtime_c_api.h @@ -173,6 +173,41 @@ int launch_runtime(RuntimeHandle runtime, size_t aicore_size, int orch_thread_num); +/** + * Lightweight re-initialization for subsequent rounds within the same case. + * + * Skips kernel upload, GM heap allocation, shared memory allocation, and + * orchestration SO copy. Only re-copies input/inout tensor data to device + * using the existing device memory allocations from the first round. + * + * Must be called after a successful init_runtime() + finalize_runtime_round() + * sequence. The Runtime handle must not have been fully finalized. + * + * @param runtime Runtime handle (previously initialized) + * @param func_args Arguments for orchestration (host pointers, sizes, etc.) + * @param func_args_count Number of arguments + * @param arg_types Array describing each argument's type (ArgType enum) + * @param arg_sizes Array of sizes for pointer arguments (0 for scalars) + * @return 0 on success, -1 on failure + */ +int reinit_runtime(RuntimeHandle runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes); + +/** + * Round-level finalize: copy results back but keep device resources alive. + * + * Copies output/inout tensors from device to host, but does NOT free + * device memory, kernel binaries, or call the Runtime destructor. + * Use this between rounds within the same case. + * + * @param runtime Runtime handle to finalize for this round + * @return 0 on success, -1 on failure + */ +int finalize_runtime_round(RuntimeHandle runtime); + /** * Finalize and cleanup a runtime instance. 
* diff --git a/src/a5/platform/onboard/host/pto_runtime_c_api.cpp b/src/a5/platform/onboard/host/pto_runtime_c_api.cpp index 95448c9a..d099d324 100644 --- a/src/a5/platform/onboard/host/pto_runtime_c_api.cpp +++ b/src/a5/platform/onboard/host/pto_runtime_c_api.cpp @@ -31,6 +31,12 @@ int init_runtime_impl(Runtime* runtime, const size_t* kernel_sizes, int kernel_count); int validate_runtime_impl(Runtime* runtime); +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes); +int validate_runtime_round_impl(Runtime* runtime); /* Forward declarations for device memory functions used in init_runtime */ void* device_malloc(size_t size); @@ -199,6 +205,34 @@ int launch_runtime(RuntimeHandle runtime, } } +int reinit_runtime(RuntimeHandle runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return reinit_runtime_impl(r, func_args, func_args_count, arg_types, arg_sizes); + } catch (...) { + return -1; + } +} + +int finalize_runtime_round(RuntimeHandle runtime) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return validate_runtime_round_impl(r); + } catch (...) 
{ + return -1; + } +} + int finalize_runtime(RuntimeHandle runtime) { if (runtime == NULL) { return -1; diff --git a/src/a5/platform/sim/host/pto_runtime_c_api.cpp b/src/a5/platform/sim/host/pto_runtime_c_api.cpp index 07c59c40..e5c6d41e 100644 --- a/src/a5/platform/sim/host/pto_runtime_c_api.cpp +++ b/src/a5/platform/sim/host/pto_runtime_c_api.cpp @@ -34,6 +34,12 @@ int init_runtime_impl(Runtime* runtime, const size_t* kernel_sizes, int kernel_count); int validate_runtime_impl(Runtime* runtime); +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes); +int validate_runtime_round_impl(Runtime* runtime); /* Forward declarations */ void* device_malloc(size_t size); @@ -202,6 +208,34 @@ int launch_runtime(RuntimeHandle runtime, } } +int reinit_runtime(RuntimeHandle runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return reinit_runtime_impl(r, func_args, func_args_count, arg_types, arg_sizes); + } catch (...) { + return -1; + } +} + +int finalize_runtime_round(RuntimeHandle runtime) { + if (runtime == NULL) { + return -1; + } + try { + Runtime* r = static_cast<Runtime*>(runtime); + return validate_runtime_round_impl(r); + } catch (...) 
{ + return -1; + } +} + int finalize_runtime(RuntimeHandle runtime) { if (runtime == NULL) { return -1; diff --git a/src/a5/runtime/host_build_graph/host/runtime_maker.cpp b/src/a5/runtime/host_build_graph/host/runtime_maker.cpp index 91db0ca0..69dd0f15 100644 --- a/src/a5/runtime/host_build_graph/host/runtime_maker.cpp +++ b/src/a5/runtime/host_build_graph/host/runtime_maker.cpp @@ -228,6 +228,25 @@ int validate_runtime_impl(Runtime *runtime) { return rc; } +int reinit_runtime_impl(Runtime* runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + (void)func_args; + (void)func_args_count; + (void)arg_types; + (void)arg_sizes; + LOG_ERROR("reinit_runtime_impl not supported for host_build_graph runtime"); + return -1; +} + +int validate_runtime_round_impl(Runtime* runtime) { + (void)runtime; + LOG_ERROR("validate_runtime_round_impl not supported for host_build_graph runtime"); + return -1; +} + #ifdef __cplusplus } /* extern "C" */ #endif diff --git a/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp b/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp index ae22d562..abe04d13 100644 --- a/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp +++ b/src/a5/runtime/tensormap_and_ringbuffer/host/runtime_maker.cpp @@ -318,6 +318,118 @@ extern "C" int init_runtime_impl(Runtime *runtime, return 0; } +/** + * Lightweight re-initialization for subsequent rounds within the same case. + * + * Skips kernel upload, GM heap/shared memory allocation, and orch SO copy. + * Only re-copies input/inout tensor data to existing device memory. + * + * @param runtime Pointer to previously initialized Runtime + * @param func_args Arguments for orchestration (host pointers, sizes, etc.) 
+ * @param func_args_count Number of arguments + * @param arg_types Array describing each argument's type (ArgType enum) + * @param arg_sizes Array of sizes for pointer arguments (0 for scalars) + * @return 0 on success, -1 on failure + */ +extern "C" int reinit_runtime_impl(Runtime *runtime, + uint64_t* func_args, + int func_args_count, + int* arg_types, + uint64_t* arg_sizes) { + if (runtime == nullptr) { + LOG_ERROR("Runtime pointer is null"); + return -1; + } + + TensorPair* tensor_pairs = runtime->get_tensor_pairs(); + int pair_idx = 0; + + for (int i = 0; i < func_args_count; i++) { + if (arg_types[i] == ARG_SCALAR) continue; + + if (pair_idx >= runtime->get_tensor_pair_count()) { + LOG_ERROR("reinit: tensor_pair index out of range at arg %d", i); + return -1; + } + + const TensorPair& pair = tensor_pairs[pair_idx]; + void* host_ptr = reinterpret_cast<void*>(func_args[i]); + size_t size = arg_sizes[i]; + + if (arg_types[i] == ARG_INPUT_PTR || arg_types[i] == ARG_INOUT_PTR) { + int rc = runtime->host_api.copy_to_device(pair.dev_ptr, host_ptr, size); + if (rc != 0) { + LOG_ERROR("reinit: failed to re-copy arg %d to device", i); + return -1; + } + LOG_INFO("reinit: arg %d re-copied %zu bytes", i, size); + } + pair_idx++; + } + + LOG_INFO("reinit complete: re-copied %d tensor args", pair_idx); + return 0; +} + +/** + * Round-level validate: copy results back but keep device resources alive. + * + * Same copy-back logic as validate_runtime_impl, but does NOT free + * device memory, kernel binaries, or clear tensor pairs. 
+ * + * @param runtime Pointer to Runtime + * @return 0 on success, -1 on failure + */ +extern "C" int validate_runtime_round_impl(Runtime *runtime) { + if (runtime == nullptr) { + LOG_ERROR("Runtime pointer is null"); + return -1; + } + + int rc = 0; + LOG_INFO("=== Round Finalize: Copying Results Back ==="); + + TensorPair* tensor_pairs = runtime->get_tensor_pairs(); + int tensor_pair_count = runtime->get_tensor_pair_count(); + + void* pto2_sm = runtime->get_pto2_gm_sm_ptr(); + uint64_t graph_out_ptr = 0; + uint64_t graph_out_size = 0; + + if (pto2_sm != nullptr) { + PTO2SharedMemoryHeader host_header; + int hdr_rc = runtime->host_api.copy_from_device(&host_header, pto2_sm, sizeof(PTO2SharedMemoryHeader)); + if (hdr_rc == 0) { + graph_out_ptr = host_header.graph_output_ptr; + graph_out_size = host_header.graph_output_size; + } + } + + bool first_output_tensor = true; + for (int i = 0; i < tensor_pair_count; i++) { + const TensorPair& pair = tensor_pairs[i]; + if (pair.dev_ptr == nullptr || pair.host_ptr == nullptr) continue; + + void* src_ptr = pair.dev_ptr; + size_t copy_size = pair.size; + + if (first_output_tensor && graph_out_ptr != 0 && graph_out_size > 0) { + src_ptr = reinterpret_cast<void*>(static_cast<uintptr_t>(graph_out_ptr)); + copy_size = static_cast<size_t>(graph_out_size); + first_output_tensor = false; + } + + int copy_rc = runtime->host_api.copy_from_device(pair.host_ptr, src_ptr, copy_size); + if (copy_rc != 0) { + LOG_ERROR("Round finalize: failed to copy tensor %d from device", i); + rc = copy_rc; + } + } + + LOG_INFO("=== Round Finalize Complete ==="); + return rc; +} + /** * Validate runtime results and cleanup. *