Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions examples/scripts/code_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,10 +864,6 @@ def _compile_one_kernel(kernel):
logger.debug(f"Tensor order: {list(tensors.keys())}")
logger.debug(f"func_args count: {len(func_args)}")

# Create and initialize runtime (including kernel registration)
logger.info("=== Initializing Runtime ===")
runtime = Runtime()

# Build environment for runtime initialization
run_env = _kernel_config_runtime_env(self._kernel_config, self.kernels_dir)
if run_env:
Expand All @@ -883,20 +879,21 @@ def _compile_one_kernel(kernel):

initial_outputs = {k: v.clone() for k, v in outputs.items()}

logger.info("=== Initializing Runtime ===")
runtime = Runtime()

# Enable profiling if requested (only first round)
if self.enable_profiling:
runtime.enable_profiling(True)
logger.info("Profiling enabled")

for round_idx in range(self.repeat_rounds):
if self.repeat_rounds > 1:
logger.info(f"--- Round {round_idx + 1}/{self.repeat_rounds} ---")

for k, v in initial_outputs.items():
outputs[k].copy_(v)

runtime = Runtime()

# Enable profiling if requested (only first round)
if self.enable_profiling and round_idx == 0:
runtime.enable_profiling(True)
logger.info("Profiling enabled")

with _temporary_env(run_env):
runtime.initialize(
orch_so_binary,
Expand All @@ -917,7 +914,10 @@ def _compile_one_kernel(kernel):
orch_thread_num=self.orch_thread_num,
)

runtime.finalize()
if round_idx < self.repeat_rounds - 1:
runtime.finalize_round()
else:
runtime.finalize()
self._compare_with_golden(outputs, golden)

logger.info(f"=== Case {case_idx + 1}/{total_cases} Passed ===")
Expand Down
96 changes: 96 additions & 0 deletions python/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ def _setup_functions(self):
self.lib.finalize_runtime.argtypes = [c_void_p]
self.lib.finalize_runtime.restype = c_int

# reinit_runtime - lightweight re-init for subsequent rounds
self.lib.reinit_runtime.argtypes = [
c_void_p, # runtime
POINTER(c_uint64), # func_args
c_int, # func_args_count
POINTER(c_int), # arg_types
POINTER(c_uint64), # arg_sizes
]
self.lib.reinit_runtime.restype = c_int

# finalize_runtime_round - copy results back without freeing resources
self.lib.finalize_runtime_round.argtypes = [c_void_p]
self.lib.finalize_runtime_round.restype = c_int

# Note: register_kernel has been internalized into init_runtime
# Kernel binaries are now passed directly to init_runtime()

Expand Down Expand Up @@ -209,6 +223,7 @@ def __init__(self, lib: CDLL):
size = lib.get_runtime_size()
self._buffer = ctypes.create_string_buffer(size)
self._handle = ctypes.cast(self._buffer, c_void_p)
self._initialized = False

def initialize(
self,
Expand Down Expand Up @@ -248,6 +263,20 @@ def initialize(
func_args = func_args or []
func_args_count = len(func_args)

# If already initialized, delegate to lightweight reinit
if self._initialized:
rc = self.lib.reinit_runtime(
self._handle,
(c_uint64 * len(func_args))(*func_args) if func_args else None,
func_args_count,
(c_int * len(arg_types))(*arg_types) if arg_types else None,
(c_uint64 * len(arg_sizes))(*arg_sizes) if arg_sizes else None,
)
if rc == 0:
return
# Reinit not supported by this runtime, fallback to full finalize + init
self.finalize()

# Convert func_args to ctypes array
if func_args_count > 0:
func_args_array = (c_uint64 * func_args_count)(*func_args)
Expand Down Expand Up @@ -310,6 +339,7 @@ def initialize(
)
if rc != 0:
raise RuntimeError(f"init_runtime failed: {rc}")
self._initialized = True

def finalize(self) -> None:
"""
Expand All @@ -322,10 +352,76 @@ def finalize(self) -> None:
Raises:
RuntimeError: If finalization fails
"""
if not self._initialized:
return

rc = self.lib.finalize_runtime(self._handle)
if rc != 0:
raise RuntimeError(f"finalize_runtime failed: {rc}")
self._initialized = False

def reinitialize(
self,
func_args: Optional[List[int]] = None,
arg_types: Optional[List[int]] = None,
arg_sizes: Optional[List[int]] = None,
) -> None:
"""
Lightweight re-initialization for subsequent rounds within the same case.

Skips kernel upload, GM heap/shared memory allocation, and orch SO copy.
Only re-copies input/inout tensor data to existing device memory.

Args:
func_args: Arguments for orchestration (host pointers, sizes, etc.)
arg_types: Array describing each argument's type
arg_sizes: Array of sizes for pointer arguments (0 for scalars)

Raises:
RuntimeError: If re-initialization fails
"""
func_args = func_args or []
func_args_count = len(func_args)

if func_args_count > 0:
func_args_array = (c_uint64 * func_args_count)(*func_args)
else:
func_args_array = None

if arg_types is not None and len(arg_types) > 0:
arg_types_array = (c_int * len(arg_types))(*arg_types)
else:
arg_types_array = None

if arg_sizes is not None and len(arg_sizes) > 0:
arg_sizes_array = (c_uint64 * len(arg_sizes))(*arg_sizes)
else:
arg_sizes_array = None

rc = self.lib.reinit_runtime(
self._handle,
func_args_array,
func_args_count,
arg_types_array,
arg_sizes_array,
)
if rc != 0:
raise RuntimeError(f"reinit_runtime failed: {rc}")

def finalize_round(self) -> None:
"""
Round-level finalize: copy results back but keep device resources alive.

Copies output/inout tensors from device to host without freeing
device memory or kernel binaries. Use between rounds in the same case.

Raises:
RuntimeError: If round finalization fails
"""
rc = self.lib.finalize_runtime_round(self._handle)
if rc != 0:
# Not supported by this runtime, fallback to full finalize
self.finalize()

def enable_profiling(self, enabled: bool = True) -> None:
"""
Expand Down
35 changes: 35 additions & 0 deletions src/a2a3/platform/include/host/pto_runtime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,41 @@ int launch_runtime(RuntimeHandle runtime,
size_t aicore_size,
int orch_thread_num);

/**
* Lightweight re-initialization for subsequent rounds within the same case.
*
* Skips kernel upload, GM heap allocation, shared memory allocation, and
* orchestration SO copy. Only re-copies input/inout tensor data to device
* using the existing device memory allocations from the first round.
*
* Must be called after a successful init_runtime() + finalize_runtime_round()
* sequence. The Runtime handle must not have been fully finalized.
*
* @param runtime Runtime handle (previously initialized)
* @param func_args Arguments for orchestration (host pointers, sizes, etc.)
* @param func_args_count Number of arguments
* @param arg_types Array describing each argument's type (ArgType enum)
* @param arg_sizes Array of sizes for pointer arguments (0 for scalars)
* @return 0 on success, -1 on failure
*/
int reinit_runtime(RuntimeHandle runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes);

/**
* Round-level finalize: copy results back but keep device resources alive.
*
* Copies output/inout tensors from device to host, but does NOT free
* device memory, kernel binaries, or call the Runtime destructor.
* Use this between rounds within the same case.
*
* @param runtime Runtime handle to finalize for this round
* @return 0 on success, -1 on failure
*/
int finalize_runtime_round(RuntimeHandle runtime);

/**
* Finalize and cleanup a runtime instance.
*
Expand Down
34 changes: 34 additions & 0 deletions src/a2a3/platform/onboard/host/pto_runtime_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ int init_runtime_impl(Runtime* runtime,
const size_t* kernel_sizes,
int kernel_count);
int validate_runtime_impl(Runtime* runtime);
int reinit_runtime_impl(Runtime* runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes);
int validate_runtime_round_impl(Runtime* runtime);

/* Forward declarations for device memory functions used in init_runtime */
void* device_malloc(size_t size);
Expand Down Expand Up @@ -199,6 +205,34 @@ int launch_runtime(RuntimeHandle runtime,
}
}

int reinit_runtime(RuntimeHandle runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return reinit_runtime_impl(r, func_args, func_args_count, arg_types, arg_sizes);
} catch (...) {
return -1;
}
}

int finalize_runtime_round(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return validate_runtime_round_impl(r);
} catch (...) {
return -1;
}
}

int finalize_runtime(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
Expand Down
34 changes: 34 additions & 0 deletions src/a2a3/platform/sim/host/pto_runtime_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ int init_runtime_impl(Runtime* runtime,
const size_t* kernel_sizes,
int kernel_count);
int validate_runtime_impl(Runtime* runtime);
int reinit_runtime_impl(Runtime* runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes);
int validate_runtime_round_impl(Runtime* runtime);

/* Forward declarations */
void* device_malloc(size_t size);
Expand Down Expand Up @@ -202,6 +208,34 @@ int launch_runtime(RuntimeHandle runtime,
}
}

int reinit_runtime(RuntimeHandle runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return reinit_runtime_impl(r, func_args, func_args_count, arg_types, arg_sizes);
} catch (...) {
return -1;
}
}

int finalize_runtime_round(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
}
try {
Runtime* r = static_cast<Runtime*>(runtime);
return validate_runtime_round_impl(r);
} catch (...) {
return -1;
}
}

int finalize_runtime(RuntimeHandle runtime) {
if (runtime == NULL) {
return -1;
Expand Down
19 changes: 19 additions & 0 deletions src/a2a3/runtime/aicpu_build_graph/host/runtime_maker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,25 @@ int validate_runtime_impl(Runtime* runtime) {
return rc;
}

int reinit_runtime_impl(Runtime* runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes) {
(void)func_args;
(void)func_args_count;
(void)arg_types;
(void)arg_sizes;
std::cerr << "Error: reinit_runtime_impl not supported for aicpu_build_graph runtime\n";
return -1;
}

int validate_runtime_round_impl(Runtime* runtime) {
(void)runtime;
std::cerr << "Error: validate_runtime_round_impl not supported for aicpu_build_graph runtime\n";
return -1;
}

#ifdef __cplusplus
} /* extern "C" */
#endif
19 changes: 19 additions & 0 deletions src/a2a3/runtime/host_build_graph/host/runtime_maker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,25 @@ int validate_runtime_impl(Runtime *runtime) {
return rc;
}

int reinit_runtime_impl(Runtime* runtime,
uint64_t* func_args,
int func_args_count,
int* arg_types,
uint64_t* arg_sizes) {
(void)func_args;
(void)func_args_count;
(void)arg_types;
(void)arg_sizes;
LOG_ERROR("reinit_runtime_impl not supported for host_build_graph runtime");
return -1;
}

int validate_runtime_round_impl(Runtime* runtime) {
(void)runtime;
LOG_ERROR("validate_runtime_round_impl not supported for host_build_graph runtime");
return -1;
}

#ifdef __cplusplus
} /* extern "C" */
#endif
Loading
Loading