diff --git a/src/activation.rs b/src/activation.rs index dadb9d2..1ddfcc8 100644 --- a/src/activation.rs +++ b/src/activation.rs @@ -80,6 +80,10 @@ impl ActivationRegistry { /// Gets all activation functions that are valid for a scope. pub fn activations_in_scope(&self, scope: NeuronScope) -> Vec { + if scope == NeuronScope::NONE { + return Vec::new(); + } + let acts = self.activations(); acts.into_iter() @@ -197,13 +201,11 @@ impl<'a> Deserialize<'a> for ActivationFn { let reg = ACTIVATION_REGISTRY.read().unwrap(); - let f = reg.fns.get(name.as_str()); - - if f.is_none() { - panic!("Activation function {name} not found"); - } + let f = reg.fns.get(name.as_str()).ok_or_else(|| { + serde::de::Error::custom(format!("Activation function {name} not found")) + })?; - Ok(f.unwrap().clone()) + Ok(f.clone()) } } diff --git a/src/neuralnet.rs b/src/neuralnet.rs index 86ac803..2097309 100644 --- a/src/neuralnet.rs +++ b/src/neuralnet.rs @@ -346,7 +346,7 @@ impl NeuralNetwork { /// doesn't have any outbound connections. /// Returns the connection if it found one before reaching max_retries. pub fn get_random_connection( - &mut self, + &self, max_retries: usize, rng: &mut impl rand::Rng, ) -> Option {