Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions encoderfile-py/python/tests/assets/test_config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
encoderfile:
name: my-model-2
path: models/token_classification
model_type: token_classification
output_path: ./test-model.encoderfile
transform: |
--- Applies a softmax across token classification logits.
--- Each token classification is normalized independently.
---
--- Args:
--- arr (Tensor): A tensor of shape [batch_size, n_tokens, n_labels].
--- The softmax is applied along the third axis (n_labels).
---
--- Returns:
--- Tensor: The input tensor with softmax-normalized embeddings.
---@param arr Tensor
---@return Tensor
function Postprocess(arr)
return arr:softmax(3)
end
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
encoderfile:
name: my-model-2-lua
path: models/dummy_electra_token_classifier
path: models/token_classification
model_type: token_classification
output_path: ./test-model-lua.encoderfile
output_path: ./test-model.encoderfile
lua_libs:
- table
- math
Expand Down
24 changes: 24 additions & 0 deletions encoderfile-py/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

import yaml


def asset_path(filename: str) -> str:
"""
Returns the absolute path to an asset file in the assets directory,
regardless of the current working directory.
"""
base_dir = os.path.dirname(os.path.abspath(__file__))
return os.path.join(base_dir, "assets", filename)


def load_yaml_asset(filename):
"""
Loads a yaml asset file from the assets directory.
"""
path = asset_path(filename)
if filename.endswith((".yml", ".yaml")):
with open(path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
else:
raise ValueError("Only yaml files are supported for this fixture.")
5 changes: 0 additions & 5 deletions encoderfile-py/python/tests/test_all.py

This file was deleted.

38 changes: 38 additions & 0 deletions encoderfile-py/python/tests/test_encoderfile_stub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll probably rename the file. It doesn't make much sense to mention the stubbing frankly :-/

from encoderfile import (
EncoderfileBuilder,
ModelConfig,
EncoderfileConfig,
InspectInfo,
inspect,
)
from conftest import asset_path, load_yaml_asset


def test_encoderfilebuilder_from_config_returns_builder():
config_path = asset_path("test_config.yml")
builder = EncoderfileBuilder.from_config(config_path)
assert isinstance(builder, EncoderfileBuilder)


@pytest.mark.parametrize("config_filename", ["test_config.yml", "test_config_lua.yml"])
def test_encoderfilebuilder_build_runs(config_filename):
config_path = asset_path(config_filename)
config_info = load_yaml_asset(config_filename)
builder = EncoderfileBuilder.from_config(config_path)
# Should not raise
builder.build(working_dir=None, version=None, no_download=True)
result = inspect(config_info["encoderfile"]["output_path"])
assert isinstance(result, InspectInfo)
assert isinstance(result.model_config, ModelConfig)
assert isinstance(result.encoderfile_config, EncoderfileConfig)
print(result.model_config)
print(result.encoderfile_config)
assert result.encoderfile_config.name == config_info["encoderfile"]["name"]
assert (
result.encoderfile_config.transform.strip()
== config_info["encoderfile"].get("transform").strip()
)
assert result.encoderfile_config.lua_libs == config_info["encoderfile"].get(
"lua_libs"
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires-python = ">=3.13"
dependencies = [ "pip>=25.2", "protobuf>=6.33.0"]

[dependency-groups]
dev = [ "pre-commit>=4.3.0", "ruff>=0.14.2", "mypy>=1.19"]
dev = [ "pre-commit>=4.3.0", "ruff>=0.14.2", "mypy>=1.19", "pytest"]
docs = [ "mkdocs>=1.6.1", "mkdocs-include-markdown-plugin>=7.2.0", "mkdocs-material>=9.7.0",]
setup = [ "onnxruntime>=1.23.1", "optimum[onnxruntime]>=2.0.0", "transformers>=4.55.0",]
models = [ "torch>=2.9.0", "click", {include-group = "setup"} ]
Expand Down
2 changes: 2 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.