Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,26 @@ def resume_orchestration(self, instance_id: str):
self._logger.info(f"Resuming instance '{instance_id}'.")
self._stub.ResumeInstance(req)

def restart_orchestration(self, instance_id: str, *,
restart_with_new_instance_id: bool = False) -> str:
"""Restarts an existing orchestration instance.

Args:
instance_id: The ID of the orchestration instance to restart.
restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID.
If False (default), the restarted orchestration will reuse the same instance ID.

Returns:
The instance ID of the restarted orchestration.
"""
req = pb.RestartInstanceRequest(
instanceId=instance_id,
restartWithNewInstanceId=restart_with_new_instance_id)

self._logger.info(f"Restarting instance '{instance_id}'.")
res: pb.RestartInstanceResponse = self._stub.RestartInstance(req)
return res.instanceId

def purge_orchestration(self, instance_id: str, recursive: bool = True):
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
self._logger.info(f"Purging instance '{instance_id}'.")
Expand Down
1 change: 1 addition & 0 deletions durabletask/internal/PROTO_SOURCE_COMMIT_HASH
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
443b333f4f65a438dc9eb4f090560d232afec4b7
fd9369c6a03d6af4e95285e432b7c4e943c06970
026329c53fe6363985655857b9ca848ec7238bd2
494 changes: 273 additions & 221 deletions durabletask/internal/orchestrator_service_pb2.py

Large diffs are not rendered by default.

249 changes: 208 additions & 41 deletions durabletask/internal/orchestrator_service_pb2.pyi

Large diffs are not rendered by default.

143 changes: 135 additions & 8 deletions durabletask/internal/orchestrator_service_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
from durabletask.internal import orchestrator_service_pb2 as durabletask_dot_internal_dot_orchestrator__service__pb2
from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2

GRPC_GENERATED_VERSION = '1.65.4'
GRPC_GENERATED_VERSION = '1.78.0'
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = '1.66.0'
SCHEDULED_RELEASE_DATE = 'August 6, 2024'
_version_not_supported = False

try:
Expand All @@ -19,15 +17,12 @@
_version_not_supported = True

if _version_not_supported:
warnings.warn(
raise RuntimeError(
f'The grpc package installed is at version {GRPC_VERSION},'
+ f' but the generated code in durabletask/internal/orchestrator_service_pb2_grpc.py depends on'
+ ' but the generated code in durabletask/internal/orchestrator_service_pb2_grpc.py depends on'
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+ f' This warning will become an error in {EXPECTED_ERROR_RELEASE},'
+ f' scheduled for release on {SCHEDULED_RELEASE_DATE}.',
RuntimeWarning
)


Expand Down Expand Up @@ -60,6 +55,11 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceResponse.FromString,
_registered_method=True)
self.RestartInstance = channel.unary_unary(
'/TaskHubSidecarService/RestartInstance',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceResponse.FromString,
_registered_method=True)
self.WaitForInstanceStart = channel.unary_unary(
'/TaskHubSidecarService/WaitForInstanceStart',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceRequest.SerializeToString,
Expand Down Expand Up @@ -95,6 +95,11 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesResponse.FromString,
_registered_method=True)
self.ListInstanceIds = channel.unary_unary(
'/TaskHubSidecarService/ListInstanceIds',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsResponse.FromString,
_registered_method=True)
self.PurgeInstances = channel.unary_unary(
'/TaskHubSidecarService/PurgeInstances',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.PurgeInstancesRequest.SerializeToString,
Expand Down Expand Up @@ -170,6 +175,11 @@ def __init__(self, channel):
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskResponse.FromString,
_registered_method=True)
self.SkipGracefulOrchestrationTerminations = channel.unary_unary(
'/TaskHubSidecarService/SkipGracefulOrchestrationTerminations',
request_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsRequest.SerializeToString,
response_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsResponse.FromString,
_registered_method=True)


class TaskHubSidecarServiceServicer(object):
Expand Down Expand Up @@ -203,6 +213,13 @@ def RewindInstance(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def RestartInstance(self, request, context):
"""Restarts an orchestration instance.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def WaitForInstanceStart(self, request, context):
"""Waits for an orchestration instance to reach a running or completion state.
"""
Expand Down Expand Up @@ -253,6 +270,12 @@ def QueryInstances(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ListInstanceIds(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def PurgeInstances(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand Down Expand Up @@ -353,6 +376,14 @@ def AbandonTaskEntityWorkItem(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def SkipGracefulOrchestrationTerminations(self, request, context):
""""Skip" graceful termination of orchestrations by immediately changing their status in storage to "terminated".
Note that a maximum of 500 orchestrations can be terminated at a time using this method.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -376,6 +407,11 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RewindInstanceResponse.SerializeToString,
),
'RestartInstance': grpc.unary_unary_rpc_method_handler(
servicer.RestartInstance,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceResponse.SerializeToString,
),
'WaitForInstanceStart': grpc.unary_unary_rpc_method_handler(
servicer.WaitForInstanceStart,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.GetInstanceRequest.FromString,
Expand Down Expand Up @@ -411,6 +447,11 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.QueryInstancesResponse.SerializeToString,
),
'ListInstanceIds': grpc.unary_unary_rpc_method_handler(
servicer.ListInstanceIds,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsResponse.SerializeToString,
),
'PurgeInstances': grpc.unary_unary_rpc_method_handler(
servicer.PurgeInstances,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.PurgeInstancesRequest.FromString,
Expand Down Expand Up @@ -486,6 +527,11 @@ def add_TaskHubSidecarServiceServicer_to_server(servicer, server):
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.AbandonEntityTaskResponse.SerializeToString,
),
'SkipGracefulOrchestrationTerminations': grpc.unary_unary_rpc_method_handler(
servicer.SkipGracefulOrchestrationTerminations,
request_deserializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsRequest.FromString,
response_serializer=durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'TaskHubSidecarService', rpc_method_handlers)
Expand Down Expand Up @@ -605,6 +651,33 @@ def RewindInstance(request,
metadata,
_registered_method=True)

@staticmethod
def RestartInstance(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/RestartInstance',
durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.RestartInstanceResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def WaitForInstanceStart(request,
target,
Expand Down Expand Up @@ -794,6 +867,33 @@ def QueryInstances(request,
metadata,
_registered_method=True)

@staticmethod
def ListInstanceIds(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/ListInstanceIds',
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.ListInstanceIdsResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def PurgeInstances(request,
target,
Expand Down Expand Up @@ -1198,3 +1298,30 @@ def AbandonTaskEntityWorkItem(request,
timeout,
metadata,
_registered_method=True)

@staticmethod
def SkipGracefulOrchestrationTerminations(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/TaskHubSidecarService/SkipGracefulOrchestrationTerminations',
durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsRequest.SerializeToString,
durabletask_dot_internal_dot_orchestrator__service__pb2.SkipGracefulOrchestrationTerminationsResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class ProtoTaskHubSidecarServiceStub(Protocol):
StartInstance: Callable[..., Any]
GetInstance: Callable[..., Any]
RewindInstance: Callable[..., Any]
RestartInstance: Callable[..., Any]
WaitForInstanceStart: Callable[..., Any]
WaitForInstanceCompletion: Callable[..., Any]
RaiseEvent: Callable[..., Any]
Expand Down
77 changes: 77 additions & 0 deletions tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")


def _get_credential():
"""Returns DefaultAzureCredential if endpoint is https, otherwise None (for emulator)."""
if endpoint.startswith("https://"):
from azure.identity import DefaultAzureCredential
return DefaultAzureCredential()
return None


def test_empty_orchestration():

invoked = False
Expand Down Expand Up @@ -371,6 +379,75 @@ def child(ctx: task.OrchestrationContext, _):
assert state is None


def test_restart_with_same_instance_id():
def orchestrator(ctx: task.OrchestrationContext, _):
result = yield ctx.call_activity(say_hello, input="World")
return result

def say_hello(ctx: task.ActivityContext, input: str):
return f"Hello, {input}!"

credential = _get_credential()

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential) as w:
w.add_orchestrator(orchestrator)
w.add_activity(say_hello)
w.start()

task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential)
id = task_hub_client.schedule_new_orchestration(orchestrator)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Hello, World!")

# Restart the orchestration with the same instance ID
restarted_id = task_hub_client.restart_orchestration(id)
assert restarted_id == id

state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Hello, World!")


def test_restart_with_new_instance_id():
def orchestrator(ctx: task.OrchestrationContext, _):
result = yield ctx.call_activity(say_hello, input="World")
return result

def say_hello(ctx: task.ActivityContext, input: str):
return f"Hello, {input}!"

credential = _get_credential()

# Start a worker, which will connect to the sidecar in a background thread
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential) as w:
w.add_orchestrator(orchestrator)
w.add_activity(say_hello)
w.start()

task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
taskhub=taskhub_name, token_credential=credential)
id = task_hub_client.schedule_new_orchestration(orchestrator)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED

# Restart the orchestration with a new instance ID
restarted_id = task_hub_client.restart_orchestration(id, restart_with_new_instance_id=True)
assert restarted_id != id

state = task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30)
assert state is not None
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Hello, World!")


# def test_continue_as_new():
# all_results = []

Expand Down
Loading
Loading