Support PositiveIndexKernel and dispatching via TaskParameter #728

Open
kalama-ai wants to merge 7 commits into main from feat/support-positive-index-kernel

@kalama-ai
Collaborator

Adds dispatching between BoTorch's PositiveIndexKernel and GPyTorch's IndexKernel for transfer learning.

  1. New task_correlation Parameter on TaskParameter

The correlation mode can be specified when creating a TaskParameter:

# Use PositiveIndexKernel (default)
task_param = TaskParameter(
    name="task",
    values=["source_task_1", "source_task_2", "target_task"],
    active_values=["target_task"],
    task_correlation=TaskCorrelation.POSITIVE, 
)

# Use IndexKernel
task_param = TaskParameter(
    name="task",
    values=["source_task_1", "source_task_2", "target_task"],
    active_values=["target_task"],
    task_correlation=TaskCorrelation.UNKNOWN,
)

  2. Kernel Dispatching

The GaussianProcessSurrogate selects the kernel based on the task_correlation:

  • TaskCorrelation.POSITIVE → uses botorch.models.kernels.PositiveIndexKernel
  • TaskCorrelation.UNKNOWN → uses gpytorch.kernels.IndexKernel

  3. Integrated both modes into benchmarks.
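
The kernel dispatching described above can be sketched roughly as a lookup from enum member to kernel. This is a minimal, library-free sketch and not the PR's implementation: the `TaskCorrelation` stand-in with its member values, the `KERNEL_PATHS` table, and `select_task_kernel_path` are hypothetical names used for illustration, and the kernels appear only as the dotted paths quoted above rather than as imported classes.

```python
from enum import Enum


class TaskCorrelation(Enum):
    """Stand-in for the PR's enum; the member values here are assumed."""

    POSITIVE = "positive"
    UNKNOWN = "unknown"


# Kernel locations as quoted in the PR description (strings only, nothing is imported).
KERNEL_PATHS = {
    TaskCorrelation.POSITIVE: "botorch.models.kernels.PositiveIndexKernel",
    TaskCorrelation.UNKNOWN: "gpytorch.kernels.IndexKernel",
}


def select_task_kernel_path(correlation: TaskCorrelation) -> str:
    """Return the dotted path of the task kernel for a correlation mode."""
    if correlation not in KERNEL_PATHS:
        raise ValueError(f"No task kernel registered for {correlation!r}")
    return KERNEL_PATHS[correlation]
```

A table like this keeps the mapping in one place, so a new correlation mode would only need a new entry.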

- PositiveIndexKernel will normalize the diagonal elements of the target task
- requires integer value of target task to identify index
- only single index supported
…with PositiveIndexKernel

- new property transfer_mode returning the TL mode of the searchspace, if a
TaskParameter is provided (required for dispatching between two kernels)
- new property target_task_idxs returning the indices of the active_values
of the TaskParameter in its computational representation (this is required for
normalization in PositiveIndexKernel)
…GaussianProcessSurrogate

- add required properties (transfer_mode and target_task_idxs) to _ModelContext
- dispatch between two kernels given transfer_mode
- if transfer_mode is `joint_pos` (use PositiveIndexKernel) we implicitly assume
only one active_value for identifying the target_task (a wrong configuration will raise an
error in TaskParameter)
- `TransferMode` was replaced by `TaskCorrelation` because it was
hard to understand without knowing about the kernels
- positive correlation -> use PositiveIndexKernel
- unknown correlation -> use IndexKernel since it might be more robust
# See base class.

task_correlation: TaskCorrelation = field(default=TaskCorrelation.POSITIVE)
"""Task correlation. Defaults to positive correlation via PositiveIndexKernel."""
Collaborator

Suggested change
"""Task correlation. Defaults to positive correlation via PositiveIndexKernel."""
"""Task correlation influencing which kernel will be used by default for task parameters."""

"""
# Check POSITIVE constraint: must have exactly one active value
# Note: _active_values is the internal field, could be None
if value == TaskCorrelation.POSITIVE and self._active_values is not None:
Collaborator

Use `value is TaskCorrelation.POSITIVE` (always check sentinels via `is`, never `==`).

Why are you using `._active_values` and not `.active_values`? The latter takes care of defaulting to values if the user does not specify anything, and it can never become None.
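
Both points of this comment can be illustrated with a small sketch. It is not the PR's code: `TaskParameterSketch` is a hypothetical, heavily simplified stand-in for TaskParameter, the enum member values are assumed, and the fallback inside `active_values` only mirrors the defaulting behavior described in the comment.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class TaskCorrelation(Enum):
    """Stand-in for the PR's enum; the member values here are assumed."""

    POSITIVE = "positive"
    UNKNOWN = "unknown"


@dataclass(frozen=True)
class TaskParameterSketch:
    """Hypothetical, heavily simplified stand-in for TaskParameter."""

    values: Tuple[str, ...]
    task_correlation: TaskCorrelation = TaskCorrelation.POSITIVE
    _active_values: Optional[Tuple[str, ...]] = None

    @property
    def active_values(self) -> Tuple[str, ...]:
        # Fall back to all values when nothing was specified, so this is never None.
        if self._active_values is None:
            return self.values
        return self._active_values

    def __post_init__(self) -> None:
        # Enum members are singletons, so an identity check is sufficient.
        if self.task_correlation is TaskCorrelation.POSITIVE:
            n_active = len(self.active_values)
            if n_active != 1:
                raise ValueError(
                    f"Task correlation '{TaskCorrelation.POSITIVE.value}' requires "
                    f"exactly one active value, but {n_active} were provided."
                )
```

One consequence of validating against the public property in this sketch: a POSITIVE parameter with several values and no explicit active values is rejected, whereas a guard of the form `self._active_values is not None` lets that case pass. Which behavior is intended is a decision for the PR.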

if len(self._active_values) > 1:
    raise ValueError(
        f"Task correlation '{TaskCorrelation.POSITIVE.value}' requires "
        f"one active value, but {len(self._active_values)} were provided: "
Collaborator

Suggested change
f"one active value, but {len(self._active_values)} were provided: "
f"exactly one active value, but {len(self._active_values)} were provided: "

return 1

@property
def target_task_idxs(self) -> list[int] | None:
Collaborator

I would always prefer returning tuples in such cases unless there is a limitation that really requires a list.
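
A tuple-returning variant could look like the sketch below. The free function and its signature are hypothetical, and the sketch assumes that the integer encoding of a task equals its position in `values`, which the PR does not state explicitly.

```python
from typing import Sequence, Tuple


def target_task_idxs(
    values: Sequence[str], active_values: Sequence[str]
) -> Tuple[int, ...]:
    """Return the positions of the active values as an immutable tuple.

    Assumption: the integer encoding of a task equals its position in ``values``.
    """
    return tuple(values.index(v) for v in active_values)
```

A tuple cannot be mutated by callers and is hashable, which makes it the safer return type for a read-only property.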

Comment on lines +73 to +82
@property
def task_correlation(self) -> TaskCorrelation | None:
    """Get the task correlation mode of the task parameter, if available."""
    return self.searchspace.task_correlation

@property
def target_task_idxs(self) -> list[int] | None:
    """Determine target task index for PositiveIndexKernel normalization."""
    return self.searchspace.target_task_idxs

Collaborator

How necessary are these helpers? I can get the same values via `gp.searchspace.x`, which is not tremendously worse than `gp.x`.

Comment on lines +195 to +205
elif context.task_correlation == TaskCorrelation.POSITIVE:
    task_covar_module = (
        botorch.models.kernels.positive_index.PositiveIndexKernel(
            num_tasks=context.n_tasks,
            active_dims=context.task_idx,
            rank=context.n_tasks,  # TODO: make controllable
            target_task_index=context.target_task_idxs[0],
        )
    )
    covar_module = base_covar_module * task_covar_module
elif context.task_correlation == TaskCorrelation.UNKNOWN:
Collaborator

Just for our common understanding: these parts will eventually have to be outsourced to a default_task_kernel_factory or similar (not needed in this PR).

"source_data_seed": settings.random_seed + mc_iter,
}
result.update(metrics)
results.append(result)
Collaborator

since you expanded the benchmarks: are they still feasible or are they now timing out due to the longer runtime?

data: The benchmark data.
target_tasks: The target tasks for transfer learning.
source_tasks: The source tasks for transfer learning.
task_correlation: The task correlation mode (UNKNOWN or POSITIVE).
Collaborator

Please do not hardcode the possible enum values in such comments (this is hard to maintain). Maybe link the actual enum instead.

active_values=["Target_Function"],
task_correlation=TaskCorrelation.POSITIVE,
)
params_tl_index = params + [task_param_index]
Collaborator

@Scienfitz Scienfitz Feb 4, 2026

Your approach here to expand the benchmarks is to copy the code and make one entry for POSITIVE and one entry for UNKNOWN.

Instead you could loop over the possible values of TaskCorrelation and automatically create as many searchspaces with autogenerated names etc.

That would have the major advantage that you'd never have to touch this code again to incorporate choices that might be added in the future.
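
The loop suggested here could be sketched as follows. The helper name `build_task_parameter_kwargs`, the `tl_` name prefix, and the enum member values are hypothetical; the sketch only builds the keyword arguments per enum member and leaves the construction of TaskParameter objects and searchspaces to the benchmark code.

```python
from enum import Enum
from typing import Any, Dict, Sequence


class TaskCorrelation(Enum):
    """Stand-in for the PR's enum; the member values here are assumed."""

    POSITIVE = "positive"
    UNKNOWN = "unknown"


def build_task_parameter_kwargs(
    values: Sequence[str], active_values: Sequence[str]
) -> Dict[str, Dict[str, Any]]:
    """Create one set of TaskParameter keyword arguments per enum member."""
    return {
        # The scenario name is derived from the member, so new members are
        # picked up without touching this function.
        f"tl_{correlation.name.lower()}": {
            "name": "task",
            "values": list(values),
            "active_values": list(active_values),
            "task_correlation": correlation,
        }
        for correlation in TaskCorrelation
    }
```

Each entry could then be unpacked into the parameter constructor, giving one benchmark scenario per correlation mode.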

encoding: CategoricalEncoding = field(default=CategoricalEncoding.INT, init=False)
# See base class.

task_correlation: TaskCorrelation = field(default=TaskCorrelation.POSITIVE)
Collaborator

@Scienfitz Scienfitz Feb 4, 2026

The only big potential problem with this PR I can spot is the naming of this attribute.

In isolation the name is totally accurate and fine. But we already have plans to expand this attribute to potentially have more choices, e.g. RGPE, MeanTransfer, CovarTransfer etc. (with names yet to be decided), and the name correlation is then not appropriate anymore. Instead this attribute embodies something like TL_MODE or TL_METHOD or TL_ALGORITHM.

Now of course we could change the name of the attribute later, but since this is merged to main and potentially released before we have the other choices, we would introduce a breaking change that has to be deprecated. So it would be beneficial to avoid that situation.

Here are two proposals for how to do that:

  • make this attribute private for now, indicating to users that it's not fully public and can change at any moment
  • decide on the attribute name already now, which should be doable because it will have to be a rather generic one (see proposals above)

@AdrianSosic do you agree with this issue of the attribute name?

Collaborator

@kalama-ai @AdrianSosic can you quickly comment on the state of this PR? If I remember correctly, this was one of the PRs that are somewhat depending on the current refactoring work of Adrian. Has this code here already been rebased and is thus ready to review? Or do I misremember?
