diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py
index fdb0679..2662e74 100644
--- a/src/datacustomcode/__init__.py
+++ b/src/datacustomcode/__init__.py
@@ -17,7 +17,9 @@
 from datacustomcode.credentials import AuthType, Credentials
 from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
 from datacustomcode.io.writer.print import PrintDataCloudWriter
-from datacustomcode.proxy.client.local_proxy_client import LocalProxyClientProvider
+from datacustomcode.proxy.client.LocalProxyClientProvider import (
+    LocalProxyClientProvider,
+)
 
 __all__ = [
     "AuthType",
diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py
index d1a1138..01aed31 100644
--- a/src/datacustomcode/client.py
+++ b/src/datacustomcode/client.py
@@ -107,8 +107,9 @@ class Client:
     _reader: BaseDataCloudReader
     _writer: BaseDataCloudWriter
     _file: DefaultFindFilePath
-    _proxy: BaseProxyClient
+    _proxy: Optional[BaseProxyClient]
     _data_layer_history: dict[DataCloudObjectType, set[str]]
+    _code_type: str
 
     def __new__(
         cls,
@@ -116,11 +117,13 @@
         writer: Optional["BaseDataCloudWriter"] = None,
         proxy: Optional[BaseProxyClient] = None,
         spark_provider: Optional["BaseSparkSessionProvider"] = None,
+        code_type: str = "script",
     ) -> Client:
+        if "function" in code_type:
+            return cls._new_function_client()
+
         if cls._instance is None:
             cls._instance = super().__new__(cls)
-
-            spark = None
             # Initialize Readers and Writers from config
             # and/or provided reader and writer
             if reader is None or writer is None:
@@ -139,22 +142,6 @@
                     provider = DefaultSparkSessionProvider()
                 spark = provider.get_session(config.spark_config)
-            elif (
-                proxy is None
-                and config.proxy_config is not None
-                and config.spark_config is not None
-            ):
-                # Both reader and writer provided; we still need spark for proxy init
-                provider = (
-                    spark_provider
-                    if spark_provider is not None
-                    else (
-                        config.spark_provider_config.to_object()
-                        if config.spark_provider_config is not None
-                        else DefaultSparkSessionProvider()
-                    )
-                )
-                spark = provider.get_session(config.spark_config)
 
             if config.reader_config is None and reader is None:
                 raise ValueError(
                     "Reader config is required when reader is not provided"
                 )
             elif reader is None or (
                 config.reader_config is not None and config.reader_config.force
             ):
-                if config.proxy_config is None:
-                    raise ValueError(
-                        "Proxy config is required when reader is built from config"
-                    )
-                assert (
-                    spark is not None
-                )  # set in "reader is None or writer is None" branch
-                assert config.reader_config is not None  # ensured by branch condition
-                proxy_init = config.proxy_config.to_object(spark)
-
-                reader_init = config.reader_config.to_object(spark)
+                reader_init = config.reader_config.to_object(spark)  # type: ignore
             else:
                 reader_init = reader
-                if proxy is not None:
-                    proxy_init = proxy
-                elif config.proxy_config is None:
-                    raise ValueError("Proxy config is required when reader is provided")
-                else:
-                    assert (
-                        spark is not None
-                    )  # set in "both provided; proxy from config" branch
-                    proxy_init = config.proxy_config.to_object(spark)
 
             if config.writer_config is None and writer is None:
                 raise ValueError(
                     "Writer config is required when writer is not provided"
                 )
             elif writer is None or (
                 config.writer_config is not None and config.writer_config.force
             ):
-                assert spark is not None  # set when reader or writer from config
-                assert config.writer_config is not None  # ensured by branch condition
-                writer_init = config.writer_config.to_object(spark)
+                writer_init = config.writer_config.to_object(spark)  # type: ignore
             else:
                 writer_init = writer
+
             cls._instance._reader = reader_init
             cls._instance._writer = writer_init
             cls._instance._file = DefaultFindFilePath()
-            cls._instance._proxy = proxy_init
             cls._instance._data_layer_history = {
                 DataCloudObjectType.DLO: set(),
                 DataCloudObjectType.DMO: set(),
@@ -209,6 +175,16 @@
             raise ValueError("Cannot set reader or writer after client is initialized")
         return cls._instance
 
+    @classmethod
+    def _new_function_client(cls) -> Client:
+        cls._instance = super().__new__(cls)
+        cls._instance._proxy = (
+            config.proxy_config.to_object()  # type: ignore
+            if config.proxy_config is not None
+            else None
+        )
+        return cls._instance
+
     def read_dlo(self, name: str) -> PySparkDataFrame:
         """Read a DLO from Data Cloud.
 
@@ -260,6 +236,8 @@ def write_to_dmo(
         return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)
 
     def call_llm_gateway(self, LLM_MODEL_ID: str, prompt: str, maxTokens: int) -> str:
+        if self._proxy is None:
+            raise ValueError("No proxy configured; set proxy or proxy_config")
         return self._proxy.call_llm_gateway(LLM_MODEL_ID, prompt, maxTokens)
 
     def find_file_path(self, file_name: str) -> Path:
diff --git a/src/datacustomcode/config.py b/src/datacustomcode/config.py
index b1edfc4..602e182 100644
--- a/src/datacustomcode/config.py
+++ b/src/datacustomcode/config.py
@@ -38,6 +38,7 @@
 from datacustomcode.io.base import BaseDataAccessLayer
 from datacustomcode.io.reader.base import BaseDataCloudReader  # noqa: TCH001
 from datacustomcode.io.writer.base import BaseDataCloudWriter  # noqa: TCH001
+from datacustomcode.proxy.base import BaseProxyAccessLayer
 from datacustomcode.proxy.client.base import BaseProxyClient  # noqa: TCH001
 from datacustomcode.spark.base import BaseSparkSessionProvider
 
@@ -93,6 +94,23 @@ class SparkConfig(ForceableConfig):
 _P = TypeVar("_P", bound=BaseSparkSessionProvider)
+_PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
+
+
+class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
+    """Config for proxy clients whose constructors do not require a Spark session."""
+
+    model_config = ConfigDict(validate_default=True, extra="forbid")
+    type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
+    type_config_name: str = Field(
+        description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClientProvider').",
+    )
'LocalProxyClient').", + ) + options: dict[str, Any] = Field(default_factory=dict) + + def to_object(self) -> _PX: + type_ = self.type_base.subclass_from_config_name(self.type_config_name) + return cast(_PX, type_(**self.options)) + class SparkProviderConfig(ForceableConfig, Generic[_P]): model_config = ConfigDict(validate_default=True, extra="forbid") @@ -110,7 +128,7 @@ def to_object(self) -> _P: class ClientConfig(BaseModel): reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None - proxy_config: Union[AccessLayerObjectConfig[BaseProxyClient], None] = None + proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None spark_config: Union[SparkConfig, None] = None spark_provider_config: Union[ SparkProviderConfig[BaseSparkSessionProvider], None diff --git a/src/datacustomcode/proxy/base.py b/src/datacustomcode/proxy/base.py index cba92f6..71cf314 100644 --- a/src/datacustomcode/proxy/base.py +++ b/src/datacustomcode/proxy/base.py @@ -19,6 +19,6 @@ from datacustomcode.mixin import UserExtendableNamedConfigMixin -class BaseDataAccessLayer(ABC, UserExtendableNamedConfigMixin): +class BaseProxyAccessLayer(ABC, UserExtendableNamedConfigMixin): def __init__(self): pass diff --git a/src/datacustomcode/proxy/client/local_proxy_client.py b/src/datacustomcode/proxy/client/LocalProxyClientProvider.py similarity index 94% rename from src/datacustomcode/proxy/client/local_proxy_client.py rename to src/datacustomcode/proxy/client/LocalProxyClientProvider.py index 2c2f962..515db00 100644 --- a/src/datacustomcode/proxy/client/local_proxy_client.py +++ b/src/datacustomcode/proxy/client/LocalProxyClientProvider.py @@ -22,5 +22,8 @@ class LocalProxyClientProvider(BaseProxyClient): CONFIG_NAME = "LocalProxyClientProvider" + def __init__(self, **kwargs: object) -> None: + pass + def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: return f"Hello, thanks for using {llmModelId}. So many tokens: {maxTokens}" diff --git a/src/datacustomcode/proxy/client/base.py b/src/datacustomcode/proxy/client/base.py index 3c4a56b..5c840a0 100644 --- a/src/datacustomcode/proxy/client/base.py +++ b/src/datacustomcode/proxy/client/base.py @@ -16,13 +16,12 @@ from abc import abstractmethod -from datacustomcode.io.base import BaseDataAccessLayer +from datacustomcode.proxy.base import BaseProxyAccessLayer -class BaseProxyClient(BaseDataAccessLayer): - def __init__(self, spark=None, **kwargs): - if spark is not None: - super().__init__(spark) +class BaseProxyClient(BaseProxyAccessLayer): + def __init__(self): + pass @abstractmethod def call_llm_gateway(self, llmModelId: str, prompt: str, maxTokens: int) -> str: ...