diff --git a/pkg/adaptation/adaptation_suite_test.go b/pkg/adaptation/adaptation_suite_test.go index 4e3bc639..d77efeec 100644 --- a/pkg/adaptation/adaptation_suite_test.go +++ b/pkg/adaptation/adaptation_suite_test.go @@ -201,6 +201,24 @@ var _ = Describe("Plugin connection", func() { s.Cleanup() }) + It("should reject plugins with an invalid name", func() { + var ( + validPlugin = &mockPlugin{ + name: "abcd-0123+EFGH_4567.ijkl", + idx: "05", + } + invalidPlugin = &mockPlugin{ + name: "foo,bar", + idx: "10", + } + ) + + s.Startup() + + Expect(validPlugin.Start(s.dir)).To(Succeed()) + Expect(invalidPlugin.Start(s.dir)).ToNot(Succeed()) + }) + It("should configure the plugin", func() { var ( plugin = s.plugins[0] diff --git a/pkg/adaptation/plugin.go b/pkg/adaptation/plugin.go index 50fbebfe..0441a0c3 100644 --- a/pkg/adaptation/plugin.go +++ b/pkg/adaptation/plugin.go @@ -435,9 +435,9 @@ func (p *plugin) qualifiedName() string { // RegisterPlugin handles the plugin's registration request. func (p *plugin) RegisterPlugin(ctx context.Context, req *RegisterPluginRequest) (*RegisterPluginResponse, error) { if p.isExternal() { - if req.PluginName == "" { - p.regC <- fmt.Errorf("plugin %q registered with an empty name", p.qualifiedName()) - return &RegisterPluginResponse{}, errors.New("invalid (empty) plugin name") + if err := api.CheckPluginName(req.PluginName); err != nil { + p.regC <- fmt.Errorf("plugin registered with an invalid name: %w", err) + return &RegisterPluginResponse{}, fmt.Errorf("invalid plugin name: %w", err) } if err := api.CheckPluginIndex(req.PluginIdx); err != nil { p.regC <- fmt.Errorf("plugin %q registered with an invalid index: %w", req.PluginName, err) diff --git a/pkg/api/owners.go b/pkg/api/owners.go index 05c1e7c9..86e305e2 100644 --- a/pkg/api/owners.go +++ b/pkg/api/owners.go @@ -475,14 +475,7 @@ func (f *FieldOwners) ClaimMount(destination, plugin string) error { } func (f *FieldOwners) ClaimHooks(plugin string) error { - plugins := plugin - - if current, ok := f.simpleOwner(Field_OciHooks.Key()); ok { - f.clearSimple(Field_OciHooks.Key(), plugin) - plugins = current + "," + plugin - } - - f.claimSimple(Field_OciHooks.Key(), plugins) + f.accumulateSimple(Field_OciHooks.Key(), plugin) return nil } @@ -678,6 +671,14 @@ func (f *FieldOwners) ClearRdt(plugin string) { f.clearSimple(Field_RdtEnableMonitoring.Key(), plugin) } +func (f *FieldOwners) accumulateSimple(field int32, plugin string) { + old, ok := f.simpleOwner(field) + if ok { + plugin = old + "," + plugin + } + f.Simple[field] = plugin +} + func (f *FieldOwners) Conflict(field int32, plugin, other string, qualifiers ...string) error { return fmt.Errorf("plugins %q and %q both tried to set %s", plugin, other, qualify(field, qualifiers...)) diff --git a/pkg/api/owners_test.go b/pkg/api/owners_test.go index 939274af..4efb53b0 100644 --- a/pkg/api/owners_test.go +++ b/pkg/api/owners_test.go @@ -113,3 +113,24 @@ func TestCompoundClaims(t *testing.T) { require.Equal(t, api.Field_Annotations.String(), "Annotations", "annotation field name") } + +func TestAccumulatingOwnership(t *testing.T) { + o := api.NewOwningPlugins() + + // claim OCI hooks of ctr0 + err := o.ClaimHooks("ctr0", "plugin0") + require.NoError(t, err, "ctr0 OCI hooks by plugin0") + + // claim OCI hooks of ctr0 + err = o.ClaimHooks("ctr0", "plugin1") + require.NoError(t, err, "ctr0 OCI hooks by plugin1") + + // claim OCI hooks of ctr0 + err = o.ClaimHooks("ctr0", "plugin2") + require.NoError(t, err, "ctr0 OCI hooks by plugin2") + + owners, ok := o.HooksOwner("ctr0") + require.True(t, ok, "ctr0 has hooks owners") + require.Equal(t, "plugin0,plugin1,plugin2", owners, "ctr0 hooks owners") + +} diff --git a/pkg/api/plugin.go b/pkg/api/plugin.go index 7eaf8920..e4bfa2e7 100644 --- a/pkg/api/plugin.go +++ b/pkg/api/plugin.go @@ -17,6 +17,7 @@ package api import ( + "errors" "fmt" "strings" ) @@ -57,3 +58,23 @@ func CheckPluginIndex(idx string) error { } return nil } + +// CheckPluginName verifies that a plugin name is not empty and only contains +// characters from the set [a-zA-Z0-9_.+-]. +func CheckPluginName(name string) error { + if name == "" { + return errors.New("invalid plugin name: name is empty") + } + + for _, r := range name { + switch { + case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z': + case '0' <= r && r <= '9': + case r == '-', r == '_', r == '.', r == '+': + default: + return fmt.Errorf("invalid plugin name %q: contains invalid character %q", name, r) + } + } + + return nil +}