Source code for a3m.server.rpc.client

import logging
from collections.abc import Callable

import tenacity
from grpc import Channel
from grpc import RpcError

from a3m import __version__
from a3m.api.transferservice import v1beta1 as transfer_service_api

logger = logging.getLogger(name=__name__)


#: Per-call deadline, in seconds, used when ``Client`` gets no ``rpc_timeout``.
_GRPC_DEFAULT_TIMEOUT_SECS: int = 30

#: gRPC metadata key under which the client reports its version.
_VERSION_METADATA_KEY: str = "version"


class Client:
    """a3m gRPC API client.

    Wraps the generated ``TransferServiceStub``. Every unary call goes through
    :meth:`_unary_call`. That method applies the client's timeout and
    ``wait_for_ready`` setting, and attaches version metadata.
    """

    def __init__(
        self,
        channel: Channel,
        rpc_timeout: float | None = _GRPC_DEFAULT_TIMEOUT_SECS,
        wait_for_ready: bool = False,
    ):
        """Create a client bound to ``channel``.

        :param channel: gRPC channel used to build the transfer service stub.
        :param rpc_timeout: Deadline in seconds passed as ``timeout`` to every
            RPC. ``None`` lets gRPC apply no deadline.
        :param wait_for_ready: Forwarded as ``wait_for_ready`` to every RPC.
        """
        self.transfer_stub = transfer_service_api.service_pb2_grpc.TransferServiceStub(
            channel
        )
        self.rpc_timeout = rpc_timeout
        self.wait_for_ready = wait_for_ready

    def _unary_call(self, api_method, request):
        """Invoke ``api_method`` with ``request`` and the common call options.

        :param api_method: A unary stub method, e.g. ``self.transfer_stub.Read``.
        :param request: The request message to send.
        :returns: Whatever ``api_method`` returns.
        :raises RpcError: Logged as a warning, then re-raised unchanged.
        """
        # Derive a short name for logging, e.g. ``SubmitRequest`` -> ``Submit``.
        rpc_name = request.__class__.__name__.replace("Request", "")
        logger.debug("RPC call %s with request: %r", rpc_name, request)
        try:
            return api_method(
                request,
                timeout=self.rpc_timeout,
                metadata=Client.version_metadata(),
                wait_for_ready=self.wait_for_ready,
            )
        except RpcError as e:
            logger.warning("RPC call %s got error %s", rpc_name, e)
            raise

    @staticmethod
    def version_metadata():
        """Return gRPC metadata pairs carrying this client's a3m version."""
        return ((_VERSION_METADATA_KEY, __version__),)

    def submit(
        self,
        url: str,
        name: str,
        config: (
            transfer_service_api.request_response_pb2.ProcessingConfig | None
        ) = None,
    ):
        """Submit a transfer through the ``Submit`` RPC.

        :param url: Location of the content to transfer.
        :param name: Name to give the transfer.
        :param config: Optional processing configuration. ``None`` leaves the
            field unset in the request.
        :returns: The response message of the ``Submit`` RPC.
        """
        request = transfer_service_api.request_response_pb2.SubmitRequest(
            name=name, url=url, config=config
        )
        return self._unary_call(self.transfer_stub.Submit, request)

    def read(self, package_id: str):
        """Fetch the current state of ``package_id`` through the ``Read`` RPC.

        :returns: The ``ReadResponse`` message for the package.
        """
        request = transfer_service_api.request_response_pb2.ReadRequest(id=package_id)
        return self._unary_call(self.transfer_stub.Read, request)

    def wait_until_complete(
        self,
        package_id: str,
        spin_cb: Callable | None = None,
        *,
        poll_interval: float = 1,
    ) -> transfer_service_api.request_response_pb2.ReadResponse:
        """Block until processing of a package has completed.

        Polls :meth:`read` every ``poll_interval`` seconds while the package
        status is ``PACKAGE_STATUS_PROCESSING``. There is no overall time
        limit. An exception raised by :meth:`read` is not retried and
        propagates to the caller.

        :param package_id: Identifier of the package to wait for.
        :param spin_cb: Optional callable. It receives tenacity's retry state
            after each poll that found the package still processing, e.g. to
            drive a progress spinner.
        :param poll_interval: Seconds to wait between polls (default ``1``).
        :returns: The first ``ReadResponse`` whose status is no longer
            ``PACKAGE_STATUS_PROCESSING``.
        """
        # Module-level enum value; look it up once instead of on every poll.
        processing = (
            transfer_service_api.request_response_pb2.PACKAGE_STATUS_PROCESSING
        )

        def _should_continue(
            resp: transfer_service_api.request_response_pb2.ReadResponse,
        ) -> bool:
            return resp.status == processing

        def _callback(retry_state):
            if spin_cb is not None:
                spin_cb(retry_state)

        @tenacity.retry(
            wait=tenacity.wait_fixed(poll_interval),
            retry=tenacity.retry_if_result(_should_continue),
            after=_callback,
        )
        def _poll():
            """Read the package state; retried while it is still processing."""
            return self.read(package_id)

        return _poll()

    def list_tasks(self, job_id: str):
        """Call the ``ListTasks`` RPC for ``job_id`` and return its response."""
        request = transfer_service_api.request_response_pb2.ListTasksRequest(
            job_id=job_id
        )
        return self._unary_call(self.transfer_stub.ListTasks, request)