Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/activation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ impl ActivationRegistry {

/// Gets all activation functions that are valid for a scope.
pub fn activations_in_scope(&self, scope: NeuronScope) -> Vec<ActivationFn> {
if scope == NeuronScope::NONE {
return Vec::new();
}

let acts = self.activations();

acts.into_iter()
Expand Down Expand Up @@ -197,13 +201,11 @@ impl<'a> Deserialize<'a> for ActivationFn {

let reg = ACTIVATION_REGISTRY.read().unwrap();

let f = reg.fns.get(name.as_str());

if f.is_none() {
panic!("Activation function {name} not found");
}
let f = reg.fns.get(name.as_str()).ok_or_else(|| {
serde::de::Error::custom(format!("Activation function {name} not found"))
})?;

Ok(f.unwrap().clone())
Ok(f.clone())
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/neuralnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
/// doesn't have any outbound connections.
/// Returns the connection if it found one before reaching max_retries.
pub fn get_random_connection(
&mut self,
&self,
max_retries: usize,
rng: &mut impl rand::Rng,
) -> Option<Connection> {
Expand Down