diff --git a/src/activation.rs b/src/activation.rs index 3ea25b6..dadb9d2 100644 --- a/src/activation.rs +++ b/src/activation.rs @@ -93,24 +93,24 @@ impl ActivationRegistry { } /// Fetches a random activation fn that applies to the provided scope. + /// + /// # Panics + /// + /// Panics if there are no activation functions registered for the given scope. pub fn random_activation_in_scope( &self, scope: NeuronScope, rng: &mut impl rand::Rng, ) -> ActivationFn { - let mut iter = self.fns.values().cycle(); - let num_iterations = rng.random_range(0..self.fns.len() - 1); + let activations = self.activations_in_scope(scope); - for _ in 0..num_iterations { - iter.next().unwrap(); - } - - let mut val = iter.next().unwrap(); - while !val.scope.contains(scope) { - val = iter.next().unwrap(); - } + assert!( + !activations.is_empty(), + "no activation functions registered for scope {:?}", + scope + ); - val.clone() + activations[rng.random_range(0..activations.len())].clone() } }