How to Implement Task Chaining
Goal: Create a workflow where one task's output feeds into the next task.
When to use this: your problem requires multiple steps, each of which builds on the results of the previous ones.
1. Define dependent tasks
from corral.backend.task import TaskDefinition, TaskGroup
# Task 1: Retrieve data.
# scoring_fn is called with the agent's submitted answer; any truthy
# submission scores 1.0, anything falsy (including None) scores 0.0.
retrieve_task = TaskDefinition(
    name="retrieve",
    description="Fetch molecule structure from database",
    tools=["database_query"],
    scoring_fn=lambda x: 1.0 if x else 0.0,
    submission_format={"structure": "string"},
    initial_input={"molecule_id": "mol_123"},
)

# Task 2: Analyze (uses output from Task 1).
# The leading `x and` makes a missing (None) submission score 0.0 instead of
# raising TypeError from `"properties" in None`.
# NOTE(review): the scorer looks for a "properties" key while
# submission_format declares "analysis" -- confirm the expected answer shape.
analyze_task = TaskDefinition(
    name="analyze",
    description="Analyze the molecule structure",
    tools=["structure_analyzer"],
    scoring_fn=lambda x: 1.0 if x and "properties" in x else 0.0,
    submission_format={"analysis": "dict"},
    input_from_tasks=["retrieve"],  # Dependency on the "retrieve" task
)

# Group both tasks; the dict keys match the names used in input_from_tasks.
task_group = TaskGroup(
    group_id="workflow",
    tasks={"retrieve": retrieve_task, "analyze": analyze_task},
)
2. Create environments that use task results
class ChainedEnvironment(Environment):
    """Environment for a single task inside a chained TaskGroup.

    The prompt is extended with whatever the group reports as input for this
    task, and scoring records this task's answer and score back into the group.

    NOTE(review): ``Environment`` and ``get_tool`` are not imported in this
    snippet; bring them into scope from your project before running it.
    """

    def __init__(self, task_id: str, task_group: TaskGroup):
        super().__init__(task_id)
        self.task_group = task_group
        self.current_task = task_group.tasks[task_id]
        # Register each tool named by the task definition.
        for name in self.current_task.tools:
            tool = get_tool(name)
            self.add_tool(tool)

    def get_task_prompt(self) -> str:
        """Return the task description, plus upstream results when present."""
        upstream = self.task_group.get_task_input(self.task_id)
        description = self.current_task.description
        if not upstream:
            return description
        return description + f"\n\nPrevious results:\n{upstream}"

    def score(self) -> float:
        """Score the submitted answer and record it for downstream tasks."""
        answer = self.state.submitted_answer
        result = self.current_task.scoring_fn(answer)
        # Persist answer and score in the group so dependent tasks can read them.
        self.task_group.store_result(self.task_id, {"answer": answer}, result)
        return result
3. Run the chained workflow
# Create one environment per task id (iterating a mapping yields its keys).
environments = {
    task_id: ChainedEnvironment(task_id, task_group)
    for task_id in task_group.tasks
}

# NOTE(review): `runner` is not defined in this snippet, and `environments`
# is never handed to it here -- create the runner and register the
# environments according to your runner's API.
# Assumes the runner executes tasks in dependency order (retrieve before
# analyze) -- confirm against the runner's documentation.
runner.bench(task_ids=list(environments), trials_per_task=1)
Done: the tasks now run in sequence, and each task's stored result is passed as input to the tasks that depend on it.