diff --git a/.gitignore b/.gitignore
index 1b71596..ee3dd5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,4 @@
/target/
-/.vscode/
\ No newline at end of file
+/.vscode/
+fitness-plot.svg
+network.json
\ No newline at end of file
diff --git a/CONTRIBUTING b/CONTRIBUTING
deleted file mode 100644
index 68c1a8c..0000000
--- a/CONTRIBUTING
+++ /dev/null
@@ -1,6 +0,0 @@
-Thanks for contributing to this project.
-
-To get started, check out the [issues page](https://github.com/inflectrix/neat). You can either find a feature/fix from there or start a new issue, then begin implementing it in your own fork of this repo.
-
-Once you are done making the changes you'd like the make, start a pull request to the [dev](https://github.com/inflectrix/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied, someone with management permissions on this repository will merge it.
-
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..8ede4a4
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,7 @@
+Thanks for contributing to this project.
+
+To get started, check out the [issues page](https://github.com/hypercodec/neat). You can either find a feature/fix from there or start a new issue, then begin implementing it in your own fork of this repo.
+
+Once you are done making the changes you'd like to make, start a pull request to the [dev](https://github.com/hypercodec/neat/tree/dev) branch. State your changes and request a review. After all branch rules have been satisfied and the pull request has a valid reason, someone with management permissions on this repository will merge it.
+
+You could also make a draft PR while implementing your features if you want feedback or discussion before finalizing your changes.
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index be4d7b8..3971e78 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,33 +1,56 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
-version = 3
+version = 4
[[package]]
-name = "bincode"
-version = "1.3.3"
+name = "anyhow"
+version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
-dependencies = [
- "serde",
-]
+checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
+
+[[package]]
+name = "atomic_float"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a"
[[package]]
name = "bitflags"
-version = "2.5.0"
+version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
+checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af"
[[package]]
name = "cfg-if"
-version = "1.0.0"
+version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
+
+[[package]]
+name = "chacha20"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "rand_core",
+]
+
+[[package]]
+name = "cpufeatures"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201"
+dependencies = [
+ "libc",
+]
[[package]]
name = "crossbeam-deque"
-version = "0.8.5"
+version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
@@ -44,21 +67,67 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
-version = "0.8.19"
+version = "0.8.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
+
+[[package]]
+name = "darling"
+version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
+checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d"
+dependencies = [
+ "darling_core",
+ "darling_macro",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0"
+dependencies = [
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim",
+ "syn",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.23.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d"
+dependencies = [
+ "darling_core",
+ "quote",
+ "syn",
+]
[[package]]
name = "either"
-version = "1.9.0"
+version = "1.15.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719"
+
+[[package]]
+name = "equivalent"
+version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
+checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
+
+[[package]]
+name = "foldhash"
+version = "0.1.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "genetic-rs"
-version = "0.5.1"
+version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b94601f3db2fb341f71a4470134eb1f71d39f54c2fe264122698eda67cd1c91b"
+checksum = "8d6ebe63406d0b2e849602cb15d0863f1dd82fec4e841371b3650775236b3be2"
dependencies = [
"genetic-rs-common",
"genetic-rs-macros",
@@ -66,21 +135,22 @@ dependencies = [
[[package]]
name = "genetic-rs-common"
-version = "0.5.1"
+version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f41b0e3f6ccb66a00e7fc9170d4e02b1ae80c85f03c67b76b067b3637fd314a"
+checksum = "47620e2765a3b28fc684f2721bc0aa34873f9fbdb226c4929174c4634938c307"
dependencies = [
+ "itertools",
"rand",
"rayon",
- "replace_with",
]
[[package]]
name = "genetic-rs-macros"
-version = "0.5.1"
+version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2d5ec3b9e69a6836bb0f0c8fa6972e6322e0b49108f7b3ed40769feb452c120a"
+checksum = "736f25c1dcdaa86d60482e9c690ff31f931c661aae0d5a0faf340a7b7c5326d2"
dependencies = [
+ "darling",
"genetic-rs-common",
"proc-macro2",
"quote",
@@ -89,107 +159,180 @@ dependencies = [
[[package]]
name = "getrandom"
-version = "0.2.12"
+version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
+checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec"
dependencies = [
"cfg-if",
"libc",
- "wasi",
+ "r-efi",
+ "rand_core",
+ "wasip2",
+ "wasip3",
+]
+
+[[package]]
+name = "hashbrown"
+version = "0.15.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
+dependencies = [
+ "foldhash",
+]
+
+[[package]]
+name = "hashbrown"
+version = "0.16.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
+
+[[package]]
+name = "heck"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
+
+[[package]]
+name = "id-arena"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954"
+
+[[package]]
+name = "ident_case"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
+
+[[package]]
+name = "indexmap"
+version = "2.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017"
+dependencies = [
+ "equivalent",
+ "hashbrown 0.16.1",
+ "serde",
+ "serde_core",
+]
+
+[[package]]
+name = "itertools"
+version = "0.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
+dependencies = [
+ "either",
]
[[package]]
name = "itoa"
-version = "1.0.10"
+version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
+checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "lazy_static"
-version = "1.4.0"
+version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
+
+[[package]]
+name = "leb128fmt"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "libc"
-version = "0.2.153"
+version = "0.2.180"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc"
+
+[[package]]
+name = "log"
+version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
+checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
+
+[[package]]
+name = "memchr"
+version = "2.7.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "neat"
-version = "0.5.1"
+version = "1.0.0"
dependencies = [
- "bincode",
+ "atomic_float",
"bitflags",
"genetic-rs",
"lazy_static",
- "rand",
"rayon",
+ "replace_with",
"serde",
"serde-big-array",
"serde_json",
+ "serde_path_to_error",
]
[[package]]
-name = "ppv-lite86"
-version = "0.2.17"
+name = "prettyplease"
+version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
+checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
+dependencies = [
+ "proc-macro2",
+ "syn",
+]
[[package]]
name = "proc-macro2"
-version = "1.0.78"
+version = "1.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
+checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
-version = "1.0.35"
+version = "1.0.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
+checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4"
dependencies = [
"proc-macro2",
]
[[package]]
-name = "rand"
-version = "0.8.5"
+name = "r-efi"
+version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
-dependencies = [
- "libc",
- "rand_chacha",
- "rand_core",
-]
+checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
-name = "rand_chacha"
-version = "0.3.1"
+name = "rand"
+version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8"
dependencies = [
- "ppv-lite86",
+ "chacha20",
+ "getrandom",
"rand_core",
]
[[package]]
name = "rand_core"
-version = "0.6.4"
+version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
-dependencies = [
- "getrandom",
-]
+checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba"
[[package]]
name = "rayon"
-version = "1.8.1"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051"
+checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [
"either",
"rayon-core",
@@ -197,9 +340,9 @@ dependencies = [
[[package]]
name = "rayon-core"
-version = "1.12.1"
+version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
+checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
@@ -207,22 +350,23 @@ dependencies = [
[[package]]
name = "replace_with"
-version = "0.1.7"
+version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e3a8614ee435691de62bcffcf4a66d91b3594bf1428a5722e79103249a095690"
+checksum = "51743d3e274e2b18df81c4dc6caf8a5b8e15dbe799e0dca05c7617380094e884"
[[package]]
-name = "ryu"
-version = "1.0.17"
+name = "semver"
+version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1"
+checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
[[package]]
name = "serde"
-version = "1.0.197"
+version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
+checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
+ "serde_core",
"serde_derive",
]
@@ -235,11 +379,20 @@ dependencies = [
"serde",
]
+[[package]]
+name = "serde_core"
+version = "1.0.228"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
+dependencies = [
+ "serde_derive",
+]
+
[[package]]
name = "serde_derive"
-version = "1.0.197"
+version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
+checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
@@ -248,20 +401,39 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.114"
+version = "1.0.149"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86"
+dependencies = [
+ "itoa",
+ "memchr",
+ "serde",
+ "serde_core",
+ "zmij",
+]
+
+[[package]]
+name = "serde_path_to_error"
+version = "0.1.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0"
+checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457"
dependencies = [
"itoa",
- "ryu",
"serde",
+ "serde_core",
]
+[[package]]
+name = "strsim"
+version = "0.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
+
[[package]]
name = "syn"
-version = "2.0.51"
+version = "2.0.114"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6ab617d94515e94ae53b8406c628598680aa0c9587474ecbe58188f7b345d66c"
+checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a"
dependencies = [
"proc-macro2",
"quote",
@@ -270,12 +442,164 @@ dependencies = [
[[package]]
name = "unicode-ident"
-version = "1.0.12"
+version = "1.0.22"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
+
+[[package]]
+name = "unicode-xid"
+version = "0.2.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
+
+[[package]]
+name = "wasip2"
+version = "1.0.1+wasi-0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7"
+dependencies = [
+ "wit-bindgen 0.46.0",
+]
+
+[[package]]
+name = "wasip3"
+version = "0.4.0+wasi-0.3.0-rc-2026-01-06"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5"
+dependencies = [
+ "wit-bindgen 0.51.0",
+]
+
+[[package]]
+name = "wasm-encoder"
+version = "0.244.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319"
+dependencies = [
+ "leb128fmt",
+ "wasmparser",
+]
+
+[[package]]
+name = "wasm-metadata"
+version = "0.244.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909"
+dependencies = [
+ "anyhow",
+ "indexmap",
+ "wasm-encoder",
+ "wasmparser",
+]
+
+[[package]]
+name = "wasmparser"
+version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
+checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe"
+dependencies = [
+ "bitflags",
+ "hashbrown 0.15.5",
+ "indexmap",
+ "semver",
+]
+
+[[package]]
+name = "wit-bindgen"
+version = "0.46.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59"
+
+[[package]]
+name = "wit-bindgen"
+version = "0.51.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5"
+dependencies = [
+ "wit-bindgen-rust-macro",
+]
+
+[[package]]
+name = "wit-bindgen-core"
+version = "0.51.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc"
+dependencies = [
+ "anyhow",
+ "heck",
+ "wit-parser",
+]
+
+[[package]]
+name = "wit-bindgen-rust"
+version = "0.51.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21"
+dependencies = [
+ "anyhow",
+ "heck",
+ "indexmap",
+ "prettyplease",
+ "syn",
+ "wasm-metadata",
+ "wit-bindgen-core",
+ "wit-component",
+]
+
+[[package]]
+name = "wit-bindgen-rust-macro"
+version = "0.51.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a"
+dependencies = [
+ "anyhow",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "syn",
+ "wit-bindgen-core",
+ "wit-bindgen-rust",
+]
+
+[[package]]
+name = "wit-component"
+version = "0.244.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2"
+dependencies = [
+ "anyhow",
+ "bitflags",
+ "indexmap",
+ "log",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "wasm-encoder",
+ "wasm-metadata",
+ "wasmparser",
+ "wit-parser",
+]
+
+[[package]]
+name = "wit-parser"
+version = "0.244.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736"
+dependencies = [
+ "anyhow",
+ "id-arena",
+ "indexmap",
+ "log",
+ "semver",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "unicode-xid",
+ "wasmparser",
+]
[[package]]
-name = "wasi"
-version = "0.11.0+wasi-snapshot-preview1"
+name = "zmij"
+version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+checksum = "bd8f3f50b848df28f887acb68e41201b5aea6bc8a8dacc00fb40635ff9a72fea"
diff --git a/Cargo.toml b/Cargo.toml
index 91247ad..065fc57 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,11 +1,11 @@
[package]
name = "neat"
description = "Crate for working with NEAT in rust"
-version = "0.5.1"
+version = "1.0.0"
edition = "2021"
-authors = ["Inflectrix"]
-repository = "https://github.com/inflectrix/neat"
-homepage = "https://github.com/inflectrix/neat"
+authors = ["HyperCodec"]
+repository = "https://github.com/HyperCodec/neat"
+homepage = "https://github.com/HyperCodec/neat"
readme = "README.md"
keywords = ["genetic", "machine-learning", "ai", "algorithm", "evolution"]
categories = ["algorithms", "science", "simulation"]
@@ -17,23 +17,35 @@ rustdoc-args = ["--cfg", "docsrs"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+[[example]]
+name = "basic"
+path = "examples/basic.rs"
+required-features = ["genetic-rs/derive"]
+
+[[example]]
+name = "extra_genes"
+path = "examples/extra_genes.rs"
+required-features = ["genetic-rs/derive"]
+
+[[example]]
+name = "serde"
+path = "examples/serde.rs"
+required-features = ["serde"]
+
[features]
-default = ["max-index"]
-crossover = ["genetic-rs/crossover"]
-rayon = ["genetic-rs/rayon", "dep:rayon"]
-max-index = []
+default = []
serde = ["dep:serde", "dep:serde-big-array"]
-
[dependencies]
-bitflags = "2.5.0"
-genetic-rs = { version = "0.5.1", features = ["derive"] }
-lazy_static = "1.4.0"
-rand = "0.8.5"
-rayon = { version = "1.8.1", optional = true }
-serde = { version = "1.0.197", features = ["derive"], optional = true }
+atomic_float = "1.1.0"
+bitflags = "2.11.0"
+genetic-rs = { version = "1.2.0", features = ["rayon"] }
+lazy_static = "1.5.0"
+rayon = "1.11.0"
+replace_with = "0.1.8"
+serde = { version = "1.0.228", features = ["derive"], optional = true }
serde-big-array = { version = "0.5.1", optional = true }
[dev-dependencies]
-bincode = "1.3.3"
-serde_json = "1.0.114"
+serde_json = "1.0.149"
+serde_path_to_error = "0.1.20"
diff --git a/README.md b/README.md
index ad775e2..d4e3312 100644
--- a/README.md
+++ b/README.md
@@ -1,102 +1,96 @@
# neat
-[github](https://github.com/inflectrix/neat)
+[github](https://github.com/hypercodec/neat)
[crates.io](https://crates.io/crates/neat)
[docs.rs](https://docs.rs/neat)
Implementation of the NEAT algorithm using `genetic-rs`.
### Features
-- rayon - Uses parallelization on the `NeuralNetwork` struct and adds the `rayon` feature to the `genetic-rs` re-export.
-- serde - Adds the NNTSerde struct and allows for serialization of `NeuralNetworkTopology`
-- crossover - Implements the `CrossoverReproduction` trait on `NeuralNetworkTopology` and adds the `crossover` feature to the `genetic-rs` re-export.
+- serde - Implements `Serialize` and `Deserialize` on most of the types in this crate.
-*Do you like this repo and want to support it? If so, leave a ⭐*
+*Do you like this crate and want to support it? If so, leave a ⭐*
-### How To Use
-When working with this crate, you'll want to use the `NeuralNetworkTopology` struct in your agent's DNA and
-the use `NeuralNetwork::from` when you finally want to test its performance. The `genetic-rs` crate is also re-exported with the rest of this crate.
-
-Here's an example of how one might use this crate:
+# How To Use
+The `NeuralNetwork<I, O>` struct is the main type exported by this crate: `I` is the number of input neurons and `O` is the number of output neurons. It implements `GenerateRandom`, `RandomlyMutable`, `Mitosis`, and `Crossover`, each with plenty of room for customization. This means you can use it standalone as your organism's entire genome:
```rust
use neat::*;
-#[derive(Clone, RandomlyMutable, DivisionReproduction)]
-struct MyAgentDNA {
- network: NeuralNetworkTopology<1, 2>,
+fn fitness(net: &NeuralNetwork<5, 6>) -> f32 {
+ // ideally you'd test multiple times for consistency,
+ // but this is just a simple example.
+ // it's also generally good to normalize your inputs between -1..1,
+    // but NEAT is usually flexible enough to still work anyway
+ let inputs = [1.0, 2.0, 3.0, 4.0, 5.0];
+ let outputs = net.predict(inputs);
+
+ // simple fitness: sum of outputs
+ // you should replace this with a real fitness test
+ outputs.iter().sum()
}
-impl GenerateRandom for MyAgentDNA {
- fn gen_random(rng: &mut impl rand::Rng) -> Self {
- Self {
- network: NeuralNetworkTopology::new(0.01, 3, rng),
- }
- }
-}
-
-struct MyAgent {
- network: NeuralNetwork<1, 2>,
- // ... other state
-}
+fn main() {
+ let mut rng = rand::rng();
+ let mut sim = GeneticSim::new(
+ Vec::gen_random(&mut rng, 100),
+ FitnessEliminator::new_without_observer(fitness),
+ CrossoverRepopulator::new(0.25, ReproductionSettings::default()),
+ );
-impl From<&MyAgentDNA> for MyAgent {
- fn from(value: &MyAgentDNA) -> Self {
- Self {
- network: NeuralNetwork::from(&value.network),
- }
- }
+ sim.perform_generations(100);
}
+```
-fn fitness(dna: &MyAgentDNA) -> f32 {
- // agent will simply try to predict whether a number is greater than 0.5
- let mut agent = MyAgent::from(dna);
- let mut rng = rand::thread_rng();
- let mut fitness = 0;
+Or just a part of a more complex genome:
+```rust,ignore
+use neat::*;
- // use repeated tests to avoid situational bias and some local maximums, overall providing more accurate score
- for _ in 0..10 {
-        let n = rng.gen::<f32>();
- let above = n > 0.5;
+#[derive(Clone, Debug)]
+struct PhysicalStats {
+ strength: f32,
+ speed: f32,
+ // ...
+}
- let res = agent.network.predict([n]);
- let resi = res.iter().max_index();
+// ... implement `RandomlyMutable`, `GenerateRandom`, `Crossover`, `Default`, etc.
- if resi == 0 ^ above {
- // agent did not guess correctly, punish slightly (too much will hinder exploration)
- fitness -= 0.5;
+#[derive(Clone, Debug, GenerateRandom, RandomlyMutable, Mitosis, Crossover)]
+#[randmut(create_context = MyGenomeCtx)]
+#[crossover(with_context = MyGenomeCtx)]
+struct MyGenome {
+ brain: NeuralNetwork<4, 2>,
+ stats: PhysicalStats,
+}
- continue;
+impl Default for MyGenomeCtx {
+ fn default() -> Self {
+ Self {
+ brain: ReproductionSettings::default(),
+ stats: PhysicalStats::default(),
}
-
- // agent guessed correctly, they become more fit.
- fitness += 3.;
}
+}
- fitness
+fn fitness(genome: &MyGenome) -> f32 {
+ let inputs = [1.0, 2.0, 3.0, 4.0];
+ let outputs = genome.brain.predict(inputs);
+ // fitness uses both brain output and stats
+    outputs.iter().sum::<f32>() + genome.stats.strength + genome.stats.speed
}
+// main is the exact same as before
fn main() {
- let mut rng = rand::thread_rng();
-
+ let mut rng = rand::rng();
let mut sim = GeneticSim::new(
Vec::gen_random(&mut rng, 100),
- fitness,
- division_pruning_nextgen,
+ FitnessEliminator::new_without_observer(fitness),
+ CrossoverRepopulator::new(0.25, MyGenomeCtx::default()),
);
- // simulate 100 generations
- for _ in 0..100 {
- sim.next_generation();
- }
-
- // display fitness results
- let fits: Vec<_> = sim.entities
- .iter()
- .map(fitness)
- .collect();
-
- dbg!(&fits, fits.iter().max());
+ sim.perform_generations(100);
}
```
+If you want more in-depth examples, look at the [examples](https://github.com/HyperCodec/neat/tree/main/examples). You can also check out the [genetic-rs docs](https://docs.rs/genetic_rs) to see what other options you have to customize your genetic simulation.
+
### License
This crate falls under the `MIT` license
diff --git a/examples/basic.rs b/examples/basic.rs
index 9ad0419..43dd24e 100644
--- a/examples/basic.rs
+++ b/examples/basic.rs
@@ -1,130 +1,74 @@
-//! A basic example of NEAT with this crate. Enable the `crossover` feature for it to use crossover reproduction
-
use neat::*;
-use rand::prelude::*;
-
-#[derive(PartialEq, Clone, Debug, DivisionReproduction, RandomlyMutable)]
-#[cfg_attr(feature = "crossover", derive(CrossoverReproduction))]
-struct AgentDNA {
- network: NeuralNetworkTopology<2, 4>,
-}
-
-impl Prunable for AgentDNA {}
-
-impl GenerateRandom for AgentDNA {
- fn gen_random(rng: &mut impl rand::Rng) -> Self {
- Self {
- network: NeuralNetworkTopology::new(0.01, 3, rng),
- }
- }
-}
-
-#[derive(Debug)]
-struct Agent {
- network: NeuralNetwork<2, 4>,
-}
-
-impl From<&AgentDNA> for Agent {
- fn from(value: &AgentDNA) -> Self {
- Self {
- network: (&value.network).into(),
- }
- }
-}
-
-fn fitness(dna: &AgentDNA) -> f32 {
- let agent = Agent::from(dna);
-
- let mut fitness = 0.;
- let mut rng = rand::thread_rng();
-
- for _ in 0..10 {
- // 10 games
-
- // set up game
- let mut agent_pos: (i32, i32) = (rng.gen_range(0..10), rng.gen_range(0..10));
- let mut food_pos: (i32, i32) = (rng.gen_range(0..10), rng.gen_range(0..10));
-
- while food_pos == agent_pos {
- food_pos = (rng.gen_range(0..10), rng.gen_range(0..10));
- }
- let mut step = 0;
+// approximate the to_degrees function, which should be fairly
+// hard for a traditional network to learn, since it's far from a -1..1 mapping.
+fn fitness(net: &NeuralNetwork<1, 1>) -> f32 {
+ let mut rng = rand::rng();
+ let mut total_fitness = 0.0;
- loop {
- // perform actions in game
- let action = agent.network.predict([
- (food_pos.0 - agent_pos.0) as f32,
- (food_pos.1 - agent_pos.1) as f32,
- ]);
- let action = action.iter().max_index();
-
- match action {
- 0 => agent_pos.0 += 1,
- 1 => agent_pos.0 -= 1,
- 2 => agent_pos.1 += 1,
- _ => agent_pos.1 -= 1,
- }
-
- step += 1;
-
- if agent_pos == food_pos {
- fitness += 10.;
- break; // new game
- } else {
- // lose fitness for being slow and far away
- fitness -=
- (food_pos.0 - agent_pos.0 + food_pos.1 - agent_pos.1).abs() as f32 * 0.001;
- }
-
- // 50 steps per game
- if step == 50 {
- break;
- }
- }
+ // it's good practice to test on multiple inputs to get a more accurate fitness score
+ for _ in 0..100 {
+ let input = rng.random_range(-10.0..10.0);
+ let output = net.predict([input])[0];
+ let expected_output = input.to_degrees();
+
+ // basically just using negative error as fitness.
+ // percentage error doesn't work as well here since
+ // expected_output can be either very small or very large in magnitude.
+ total_fitness -= (output - expected_output).abs();
}
- fitness
+ total_fitness
}
fn main() {
- #[cfg(not(feature = "rayon"))]
- let mut rng = rand::thread_rng();
+ let mut rng = rand::rng();
let mut sim = GeneticSim::new(
- #[cfg(not(feature = "rayon"))]
- Vec::gen_random(&mut rng, 100),
- #[cfg(feature = "rayon")]
- Vec::gen_random(100),
- fitness,
- #[cfg(not(feature = "crossover"))]
- division_pruning_nextgen,
- #[cfg(feature = "crossover")]
- crossover_pruning_nextgen,
+ Vec::gen_random(&mut rng, 250),
+ FitnessEliminator::new_without_observer(fitness),
+ CrossoverRepopulator::new(0.25, ReproductionSettings::default()),
);
- for _ in 0..100 {
+ for i in 0..=150 {
sim.next_generation();
- }
-
- #[cfg(not(feature = "serde"))]
- let mut fits: Vec<_> = sim.genomes.iter().map(fitness).collect();
-
- #[cfg(feature = "serde")]
- let mut fits: Vec<_> = sim.genomes.iter().map(|e| (e, fitness(e))).collect();
-
- #[cfg(not(feature = "serde"))]
- fits.sort_by(|a, b| a.partial_cmp(&b).unwrap());
-
- #[cfg(feature = "serde")]
- fits.sort_by(|(_, a), (_, b)| a.partial_cmp(&b).unwrap());
- dbg!(&fits);
-
- #[cfg(feature = "serde")]
- {
- let intermediate = NNTSerde::from(&fits[0].0.network);
- let serialized = serde_json::to_string(&intermediate).unwrap();
- println!("{}", serialized);
+ // sample a genome to print its fitness.
+ // this value should approach 0 as the generations go on, since the fitness is negative error.
+ // with the way CrossoverRepopulator (and all builtin repopulators) works internally, the parent genomes
+ // (i.e. prev generation champs) are more likely to be at the start of the genomes vector.
+ let sample = &sim.genomes[0];
+ let fit = fitness(sample);
+ println!("Gen {i} sample fitness: {fit}");
+ }
+ println!("Training complete, now you can test the network!");
+
+ let net = &sim.genomes[0];
+ println!("Network in use: {:#?}", net);
+
+ loop {
+ let mut input_text = String::new();
+ println!("Enter a number to convert to degrees (or 'exit' to quit): ");
+ std::io::stdin().read_line(&mut input_text).unwrap();
+ let input_text = input_text.trim();
+ if input_text.eq_ignore_ascii_case("exit") {
+ break;
+ }
+ let input: f32 = match input_text.parse() {
+ Ok(num) => num,
+ Err(_) => {
+ println!("Invalid input, please enter a valid number.");
+ continue;
+ }
+ };
+
+ let output = net.predict([input])[0];
+ let expected_output = input.to_degrees();
+ println!(
+ "Network output: {}, Expected output: {}, Error: {}",
+ output,
+ expected_output,
+ (output - expected_output).abs()
+ );
}
}
diff --git a/examples/extra_genes.rs b/examples/extra_genes.rs
new file mode 100644
index 0000000..c5d7659
--- /dev/null
+++ b/examples/extra_genes.rs
@@ -0,0 +1,372 @@
+use neat::*;
+use std::f32::consts::PI;
+
+// ==========================================================================
+// SIMULATION CONSTANTS - Adjust these to experiment with different dynamics
+// ==========================================================================
+
+// World/Environment Settings
+const WORLD_WIDTH: f32 = 800.0;
+const WORLD_HEIGHT: f32 = 600.0;
+const INITIAL_FOOD_COUNT: usize = 20;
+const FOOD_RESPAWN_THRESHOLD: usize = 10;
+const FOOD_DETECTION_DISTANCE: f32 = 10.0;
+
+// Energy/Food Settings
+const BASE_FOOD_ENERGY: f32 = 20.0; // Energy from each food item
+const STRENGTH_ENERGY_MULTIPLIER: f32 = 10.0; // Extra energy per strength stat
+const MOVEMENT_ENERGY_COST: f32 = 0.2; // Cost per unit of movement
+const IDLE_ENERGY_COST: f32 = 0.1; // Cost per timestep just existing
+
+// Fitness Settings
+const FITNESS_PER_FOOD: f32 = 100.0; // Points per food eaten
+
+// Physical Stats - Min/Max Bounds
+const SPEED_MIN: f32 = 0.5;
+const SPEED_MAX: f32 = 6.0;
+const STRENGTH_MIN: f32 = 0.2;
+const STRENGTH_MAX: f32 = 4.0;
+const SENSE_RANGE_MIN: f32 = 30.0;
+const SENSE_RANGE_MAX: f32 = 250.0;
+const ENERGY_CAPACITY_MIN: f32 = 50.0;
+const ENERGY_CAPACITY_MAX: f32 = 400.0;
+
+// Physical Stats - Initial Generation Range
+const SPEED_INIT_MIN: f32 = 1.0;
+const SPEED_INIT_MAX: f32 = 5.0;
+const STRENGTH_INIT_MIN: f32 = 0.5;
+const STRENGTH_INIT_MAX: f32 = 3.0;
+const SENSE_RANGE_INIT_MIN: f32 = 50.0;
+const SENSE_RANGE_INIT_MAX: f32 = 200.0;
+const ENERGY_CAPACITY_INIT_MIN: f32 = 100.0;
+const ENERGY_CAPACITY_INIT_MAX: f32 = 300.0;
+
+// Mutation Settings
+const SPEED_MUTATION_PROB: f32 = 0.3;
+const SPEED_MUTATION_RANGE: f32 = 0.5;
+const STRENGTH_MUTATION_PROB: f32 = 0.2;
+const STRENGTH_MUTATION_RANGE: f32 = 0.3;
+const SENSE_MUTATION_PROB: f32 = 0.2;
+const SENSE_MUTATION_RANGE: f32 = 20.0;
+const CAPACITY_MUTATION_PROB: f32 = 0.2;
+const CAPACITY_MUTATION_RANGE: f32 = 30.0;
+
+// Genetic Algorithm Settings
+const POPULATION_SIZE: usize = 150;
+const HIGHEST_GENERATION: usize = 250;
+const SIMULATION_TIMESTEPS: usize = 500;
+const MUTATION_RATE: f32 = 0.3;
+
+/// Mutation settings for physical stats
+#[derive(Clone, Debug)]
+struct PhysicalStatsMutationSettings {
+ speed_prob: f32,
+ speed_range: f32,
+ strength_prob: f32,
+ strength_range: f32,
+ sense_prob: f32,
+ sense_range: f32,
+ capacity_prob: f32,
+ capacity_range: f32,
+}
+
+impl Default for PhysicalStatsMutationSettings {
+ fn default() -> Self {
+ Self {
+ speed_prob: SPEED_MUTATION_PROB,
+ speed_range: SPEED_MUTATION_RANGE,
+ strength_prob: STRENGTH_MUTATION_PROB,
+ strength_range: STRENGTH_MUTATION_RANGE,
+ sense_prob: SENSE_MUTATION_PROB,
+ sense_range: SENSE_MUTATION_RANGE,
+ capacity_prob: CAPACITY_MUTATION_PROB,
+ capacity_range: CAPACITY_MUTATION_RANGE,
+ }
+ }
+}
+
+/// Physical traits/stats for an organism
+#[derive(Clone, Debug, PartialEq)]
+struct PhysicalStats {
+ /// Speed multiplier (faster = longer strides but more energy cost)
+ speed: f32,
+ /// Strength stat (affects energy from food)
+ strength: f32,
+ /// Sense range (how far it can detect food)
+ sense_range: f32,
+ /// Energy capacity (larger = can go longer without food)
+ energy_capacity: f32,
+}
+
+impl PhysicalStats {
+ fn clamp(&mut self) {
+ self.speed = self.speed.clamp(SPEED_MIN, SPEED_MAX);
+ self.strength = self.strength.clamp(STRENGTH_MIN, STRENGTH_MAX);
+ self.sense_range = self.sense_range.clamp(SENSE_RANGE_MIN, SENSE_RANGE_MAX);
+ self.energy_capacity = self
+ .energy_capacity
+ .clamp(ENERGY_CAPACITY_MIN, ENERGY_CAPACITY_MAX);
+ }
+}
+
+impl GenerateRandom for PhysicalStats {
+ fn gen_random(rng: &mut impl rand::Rng) -> Self {
+ let mut stats = PhysicalStats {
+ speed: rng.random_range(SPEED_INIT_MIN..SPEED_INIT_MAX),
+ strength: rng.random_range(STRENGTH_INIT_MIN..STRENGTH_INIT_MAX),
+ sense_range: rng.random_range(SENSE_RANGE_INIT_MIN..SENSE_RANGE_INIT_MAX),
+ energy_capacity: rng.random_range(ENERGY_CAPACITY_INIT_MIN..ENERGY_CAPACITY_INIT_MAX),
+ };
+ stats.clamp();
+ stats
+ }
+}
+
+impl RandomlyMutable for PhysicalStats {
+ type Context = PhysicalStatsMutationSettings;
+
+ fn mutate(&mut self, context: &Self::Context, _severity: f32, rng: &mut impl rand::Rng) {
+        if rng.random::<f32>() < context.speed_prob {
+ self.speed += rng.random_range(-context.speed_range..context.speed_range);
+ }
+        if rng.random::<f32>() < context.strength_prob {
+ self.strength += rng.random_range(-context.strength_range..context.strength_range);
+ }
+        if rng.random::<f32>() < context.sense_prob {
+ self.sense_range += rng.random_range(-context.sense_range..context.sense_range);
+ }
+        if rng.random::<f32>() < context.capacity_prob {
+ self.energy_capacity +=
+ rng.random_range(-context.capacity_range..context.capacity_range);
+ }
+ self.clamp();
+ }
+}
+
+impl Crossover for PhysicalStats {
+ type Context = PhysicalStatsMutationSettings;
+
+ fn crossover(
+ &self,
+ other: &Self,
+ context: &Self::Context,
+ _severity: f32,
+ rng: &mut impl rand::Rng,
+ ) -> Self {
+ let mut child = PhysicalStats {
+ speed: (self.speed + other.speed) / 2.0
+ + rng.random_range(-context.speed_range..context.speed_range),
+ strength: (self.strength + other.strength) / 2.0
+ + rng.random_range(-context.strength_range..context.strength_range),
+ sense_range: (self.sense_range + other.sense_range) / 2.0
+ + rng.random_range(-context.sense_range..context.sense_range),
+ energy_capacity: (self.energy_capacity + other.energy_capacity) / 2.0
+ + rng.random_range(-context.capacity_range..context.capacity_range),
+ };
+ child.clamp();
+ child
+ }
+}
+
+/// A complete organism genome containing both neural network and physical traits
+#[derive(Clone, Debug, PartialEq, GenerateRandom, RandomlyMutable, Crossover)]
+#[randmut(create_context = OrganismCtx)]
+#[crossover(with_context = OrganismCtx)]
+struct OrganismGenome {
+ brain: NeuralNetwork<8, 2>,
+ stats: PhysicalStats,
+}
+
+/// Running instance of an organism with current position and energy
+struct OrganismInstance {
+ genome: OrganismGenome,
+ x: f32,
+ y: f32,
+ angle: f32,
+ energy: f32,
+ lifetime: usize,
+ food_eaten: usize,
+}
+
+impl OrganismInstance {
+ fn new(genome: OrganismGenome) -> Self {
+ let energy = genome.stats.energy_capacity;
+ Self {
+ genome,
+            x: rand::random::<f32>() * WORLD_WIDTH,
+            y: rand::random::<f32>() * WORLD_HEIGHT,
+            angle: rand::random::<f32>() * 2.0 * PI,
+ energy,
+ lifetime: 0,
+ food_eaten: 0,
+ }
+ }
+
+ /// Simulate one timestep: sense food, decide movement, consume energy, age
+ fn step(&mut self, food_sources: &[(f32, f32)]) {
+ self.lifetime += 1;
+
+ // find nearest food
+ let mut nearest_food_dist = f32::INFINITY;
+ let mut nearest_food_angle = 0.0;
+ let mut nearest_food_x_diff = 0.0;
+ let mut nearest_food_y_diff = 0.0;
+
+ for &(fx, fy) in food_sources {
+ let dx = fx - self.x;
+ let dy = fy - self.y;
+ let dist = (dx * dx + dy * dy).sqrt();
+
+ if dist < self.genome.stats.sense_range && dist < nearest_food_dist {
+ nearest_food_dist = dist;
+ nearest_food_angle = (dy.atan2(dx) - self.angle).sin();
+ nearest_food_x_diff = (dx / 100.0).clamp(-1.0, 1.0);
+ nearest_food_y_diff = (dy / 100.0).clamp(-1.0, 1.0);
+ }
+ }
+
+ let sense_food = if nearest_food_dist < self.genome.stats.sense_range {
+ 1.0
+ } else {
+ 0.0
+ };
+
+ // Create inputs for neural network:
+ // 0: current energy level (0-1)
+ // 1: food detected (0 or 1)
+ // 2: nearest food angle (normalized)
+ // 3: nearest food x diff
+ // 4: nearest food y diff
+ // 5: speed stat (normalized)
+ // 6: energy capacity (normalized)
+ // 7: age (slow-paced, up to 1 at age 1000)
+ let inputs = [
+ (self.energy / self.genome.stats.energy_capacity).clamp(0.0, 1.0),
+ sense_food,
+ nearest_food_angle,
+ nearest_food_x_diff,
+ nearest_food_y_diff,
+ (self.genome.stats.speed / 5.0).clamp(0.0, 1.0),
+ (self.genome.stats.energy_capacity / 200.0).clamp(0.0, 1.0),
+ (self.lifetime as f32 / 1000.0).clamp(0.0, 1.0),
+ ];
+
+ // get movement outputs from neural network
+ let outputs = self.genome.brain.predict(inputs);
+ let move_forward = (outputs[0] * self.genome.stats.speed).clamp(-5.0, 5.0);
+ let turn = (outputs[1] * PI / 4.0).clamp(-PI / 8.0, PI / 8.0);
+
+ // update position and angle
+ self.angle += turn;
+ self.x += move_forward * self.angle.cos();
+ self.y += move_forward * self.angle.sin();
+
+ // wrap around world
+ if self.x < 0.0 {
+ self.x += WORLD_WIDTH;
+ } else if self.x >= WORLD_WIDTH {
+ self.x -= WORLD_WIDTH;
+ }
+ if self.y < 0.0 {
+ self.y += WORLD_HEIGHT;
+ } else if self.y >= WORLD_HEIGHT {
+ self.y -= WORLD_HEIGHT;
+ }
+
+ // consume energy for movement
+ let movement_cost = (move_forward.abs() / self.genome.stats.speed).max(0.5);
+ self.energy -= movement_cost * MOVEMENT_ENERGY_COST;
+
+ // consume energy for existing
+ self.energy -= IDLE_ENERGY_COST;
+ }
+
+ /// Check if organism lands on food and consume it
+ fn eat(&mut self, food_sources: &mut Vec<(f32, f32)>) {
+ food_sources.retain(|&(fx, fy)| {
+ let dx = fx - self.x;
+ let dy = fy - self.y;
+ let dist = (dx * dx + dy * dy).sqrt();
+ if dist < FOOD_DETECTION_DISTANCE {
+ // ate food
+ self.energy +=
+ BASE_FOOD_ENERGY + (self.genome.stats.strength * STRENGTH_ENERGY_MULTIPLIER);
+ self.energy = self.energy.min(self.genome.stats.energy_capacity);
+ self.food_eaten += 1;
+ false
+ } else {
+ true
+ }
+ });
+ }
+
+ fn is_alive(&self) -> bool {
+ self.energy > 0.0
+ }
+
+ fn fitness(&self) -> f32 {
+ let food_fitness = (self.food_eaten as f32) * FITNESS_PER_FOOD;
+ food_fitness
+ }
+}
+
+/// Evaluate an organism's fitness by running a simulation
+fn evaluate_organism(genome: &OrganismGenome) -> f32 {
+ let mut rng = rand::rng();
+
+ let mut food_sources: Vec<(f32, f32)> = (0..INITIAL_FOOD_COUNT)
+ .map(|_| {
+ (
+ rng.random_range(0.0..WORLD_WIDTH),
+ rng.random_range(0.0..WORLD_HEIGHT),
+ )
+ })
+ .collect();
+
+ let mut instance = OrganismInstance::new(genome.clone());
+
+ for _ in 0..SIMULATION_TIMESTEPS {
+ if instance.is_alive() {
+ instance.step(&food_sources);
+ instance.eat(&mut food_sources);
+ }
+
+ // respawn food
+ if food_sources.len() < FOOD_RESPAWN_THRESHOLD {
+ food_sources.push((
+ rng.random_range(0.0..WORLD_WIDTH),
+ rng.random_range(0.0..WORLD_HEIGHT),
+ ));
+ }
+ }
+
+ instance.fitness()
+}
+
+fn main() {
+ let mut rng = rand::rng();
+
+ println!("Starting genetic NEAT simulation with physical traits");
+ println!("Population: {} organisms", POPULATION_SIZE);
+ println!("Each has: Neural Network Brain + Physical Stats (Speed, Strength, Sense Range, Energy Capacity)\n");
+
+ let mut sim = GeneticSim::new(
+ Vec::gen_random(&mut rng, POPULATION_SIZE),
+ FitnessEliminator::new_without_observer(evaluate_organism),
+ CrossoverRepopulator::new(MUTATION_RATE, OrganismCtx::default()),
+ );
+
+ for generation in 0..=HIGHEST_GENERATION {
+ sim.next_generation();
+
+ let sample = &sim.genomes[0];
+ let fitness = evaluate_organism(sample);
+
+ println!(
+ "Gen {}: Sample fitness: {:.1} | Speed: {:.2}, Strength: {:.2}, Sense: {:.1}, Capacity: {:.1}",
+ generation, fitness, sample.stats.speed, sample.stats.strength, sample.stats.sense_range, sample.stats.energy_capacity
+ );
+ }
+
+ println!("\nSimulation complete!");
+}
diff --git a/examples/serde.rs b/examples/serde.rs
new file mode 100644
index 0000000..90b4e81
--- /dev/null
+++ b/examples/serde.rs
@@ -0,0 +1,36 @@
+use neat::{activation::register_activation, *};
+
+const OUTPUT_PATH: &str = "network.json";
+
+fn magic_activation(x: f32) -> f32 {
+ // just a random activation function to show that it gets serialized and deserialized correctly.
+ (x * 2.0).sin()
+}
+
+fn main() {
+ // custom activation functions must be registered before deserialization, since the network needs to know how to deserialize them.
+ register_activation(activation_fn!(magic_activation));
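+    // activation functions are stored and looked up by their stringified name,
+    // so any program deserializing this network must register the same name first.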
+
+ let mut rng = rand::rng();
+ let mut net = NeuralNetwork::<10, 10>::new(&mut rng);
+
+ println!("Mutating network...");
+
+ for _ in 0..100 {
+ net.mutate(&MutationSettings::default(), 0.25, &mut rng);
+ }
+
+ let file =
+ std::fs::File::create(OUTPUT_PATH).expect("Failed to create file for network output");
+ serde_json::to_writer_pretty(file, &net).expect("Failed to write network to file");
+
+ println!("Network saved to {OUTPUT_PATH}");
+
+    // reopen the file for reading: the write handle's cursor is left at EOF,
+    // so reading back through the same handle would require seeking to the start first.
+ let file = std::fs::File::open(OUTPUT_PATH).expect("Failed to open network file for reading");
+ let net2: NeuralNetwork<10, 10> =
+ serde_json::from_reader(file).expect("Failed to parse network from file");
+ assert_eq!(net, net2);
+ println!("Network successfully loaded from file and matches original!");
+}
diff --git a/src/topology/activation.rs b/src/activation.rs
similarity index 67%
rename from src/topology/activation.rs
rename to src/activation.rs
index a711851..3ea25b6 100644
--- a/src/topology/activation.rs
+++ b/src/activation.rs
@@ -1,7 +1,13 @@
+/// Contains some builtin activation functions ([`sigmoid`], [`relu`], etc.)
+pub mod builtin;
+
+use bitflags::bitflags;
+use builtin::*;
+
+use genetic_rs::prelude::{rand, RngExt};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
-use bitflags::bitflags;
use lazy_static::lazy_static;
use std::{
collections::HashMap,
@@ -15,11 +21,11 @@ use crate::NeuronLocation;
#[macro_export]
macro_rules! activation_fn {
($F: path) => {
- ActivationFn::new(Arc::new($F), ActivationScope::default(), stringify!($F).into())
+ $crate::activation::ActivationFn::new(std::sync::Arc::new($F), $crate::activation::NeuronScope::default(), stringify!($F).into())
};
($F: path, $S: expr) => {
- ActivationFn::new(Arc::new($F), $S, stringify!($F).into())
+ $crate::activation::ActivationFn::new(std::sync::Arc::new($F), $S, stringify!($F).into())
};
{$($F: path),*} => {
@@ -51,13 +57,13 @@ pub fn batch_register_activation(acts: impl IntoIterator<Item = ActivationFn>) {
/// A registry of the different possible activation functions.
pub struct ActivationRegistry {
/// The currently-registered activation functions.
-    pub fns: HashMap<String, ActivationFn>,
+ pub fns: HashMap<&'static str, ActivationFn>,
}
impl ActivationRegistry {
/// Registers an activation function.
pub fn register(&mut self, activation: ActivationFn) {
- self.fns.insert(activation.name.clone(), activation);
+ self.fns.insert(activation.name, activation);
}
/// Registers multiple activation functions at once.
@@ -67,19 +73,45 @@ impl ActivationRegistry {
}
}
- /// Gets a Vec of all the activation functions registered. Unless you need an owned value, use [fns][ActivationRegistry::fns].values() instead.
+ /// Gets a Vec of all the activation functions registered. Use [fns][ActivationRegistry::fns] if you only need an iterator.
    pub fn activations(&self) -> Vec<ActivationFn> {
self.fns.values().cloned().collect()
}
/// Gets all activation functions that are valid for a scope.
-    pub fn activations_in_scope(&self, scope: ActivationScope) -> Vec<ActivationFn> {
+    pub fn activations_in_scope(&self, scope: NeuronScope) -> Vec<ActivationFn> {
let acts = self.activations();
acts.into_iter()
- .filter(|a| a.scope != ActivationScope::NONE && a.scope.contains(scope))
+ .filter(|a| a.scope.contains(scope))
.collect()
}
+
+ /// Clears all existing values in the activation registry.
+ pub fn clear(&mut self) {
+ self.fns.clear();
+ }
+
+ /// Fetches a random activation fn that applies to the provided scope.
+ pub fn random_activation_in_scope(
+ &self,
+ scope: NeuronScope,
+ rng: &mut impl rand::Rng,
+ ) -> ActivationFn {
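+        // NOTE: this assumes at least one registered activation matches `scope`;
+        // if none does, the scan below will never terminate.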
+ let mut iter = self.fns.values().cycle();
+        let num_iterations = rng.random_range(0..self.fns.len());
+
+ for _ in 0..num_iterations {
+ iter.next().unwrap();
+ }
+
+ let mut val = iter.next().unwrap();
+ while !val.scope.contains(scope) {
+ val = iter.next().unwrap();
+ }
+
+ val.clone()
+ }
}
impl Default for ActivationRegistry {
@@ -88,51 +120,18 @@ impl Default for ActivationRegistry {
fns: HashMap::new(),
};
+ // TODO add a way to disable this
s.batch_register(activation_fn! {
- sigmoid => ActivationScope::HIDDEN | ActivationScope::OUTPUT,
- relu => ActivationScope::HIDDEN | ActivationScope::OUTPUT,
- linear_activation => ActivationScope::INPUT | ActivationScope::HIDDEN | ActivationScope::OUTPUT,
- f32::tanh => ActivationScope::HIDDEN | ActivationScope::OUTPUT
+ sigmoid => NeuronScope::HIDDEN | NeuronScope::OUTPUT,
+ relu => NeuronScope::HIDDEN | NeuronScope::OUTPUT,
+ linear_activation => NeuronScope::INPUT | NeuronScope::HIDDEN | NeuronScope::OUTPUT,
+ f32::tanh => NeuronScope::HIDDEN | NeuronScope::OUTPUT
});
s
}
}
-bitflags! {
- /// Specifies where an activation function can occur
- #[derive(Copy, Clone, Debug, Eq, PartialEq)]
- pub struct ActivationScope: u8 {
- /// Whether the activation can be applied to the input layer.
- const INPUT = 0b001;
-
- /// Whether the activation can be applied to the hidden layer.
- const HIDDEN = 0b010;
-
- /// Whether the activation can be applied to the output layer.
- const OUTPUT = 0b100;
-
- /// The activation function will not be randomly placed anywhere
- const NONE = 0b000;
- }
-}
-
-impl Default for ActivationScope {
- fn default() -> Self {
- Self::HIDDEN
- }
-}
-
-impl From<&NeuronLocation> for ActivationScope {
- fn from(value: &NeuronLocation) -> Self {
- match value {
- NeuronLocation::Input(_) => Self::INPUT,
- NeuronLocation::Hidden(_) => Self::HIDDEN,
- NeuronLocation::Output(_) => Self::OUTPUT,
- }
- }
-}
-
/// A trait that represents an activation method.
pub trait Activation {
/// The activation function.
@@ -152,16 +151,18 @@ pub struct ActivationFn {
pub func: Arc,
/// The scope defining where the activation function can appear.
- pub scope: ActivationScope,
- pub(crate) name: String,
+ pub scope: NeuronScope,
+
+ /// The name of the activation function, used for debugging and serialization.
+ pub name: &'static str,
}
impl ActivationFn {
/// Creates a new ActivationFn object.
pub fn new(
func: Arc,
- scope: ActivationScope,
- name: String,
+ scope: NeuronScope,
+ name: &'static str,
) -> Self {
Self { func, name, scope }
}
@@ -169,7 +170,7 @@ impl ActivationFn {
impl fmt::Debug for ActivationFn {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- writeln!(f, "{}", self.name)
+ write!(f, "{}", self.name)
}
}
@@ -182,7 +183,7 @@ impl PartialEq for ActivationFn {
#[cfg(feature = "serde")]
impl Serialize for ActivationFn {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
- serializer.serialize_str(&self.name)
+ serializer.serialize_str(self.name)
}
}
@@ -196,7 +197,7 @@ impl<'a> Deserialize<'a> for ActivationFn {
let reg = ACTIVATION_REGISTRY.read().unwrap();
- let f = reg.fns.get(&name);
+ let f = reg.fns.get(name.as_str());
if f.is_none() {
panic!("Activation function {name} not found");
@@ -206,17 +207,36 @@ impl<'a> Deserialize<'a> for ActivationFn {
}
}
-/// The sigmoid activation function.
-pub fn sigmoid(n: f32) -> f32 {
- 1. / (1. + std::f32::consts::E.powf(-n))
+bitflags! {
+ /// Specifies where an activation function can occur
+ #[derive(Copy, Clone, Debug, Eq, PartialEq)]
+ pub struct NeuronScope: u8 {
+ /// Whether the activation can be applied to the input layer.
+ const INPUT = 0b001;
+
+ /// Whether the activation can be applied to the hidden layer.
+ const HIDDEN = 0b010;
+
+ /// Whether the activation can be applied to the output layer.
+ const OUTPUT = 0b100;
+
+ /// The activation function will not be randomly placed anywhere
+ const NONE = 0b000;
+ }
}
-/// The ReLU activation function.
-pub fn relu(n: f32) -> f32 {
- n.max(0.)
+impl Default for NeuronScope {
+ fn default() -> Self {
+ Self::HIDDEN
+ }
}
-/// Activation function that does nothing.
-pub fn linear_activation(n: f32) -> f32 {
- n
+impl<L: AsRef<NeuronLocation>> From<L> for NeuronScope {
+ fn from(value: L) -> Self {
+ match value.as_ref() {
+ NeuronLocation::Input(_) => Self::INPUT,
+ NeuronLocation::Hidden(_) => Self::HIDDEN,
+ NeuronLocation::Output(_) => Self::OUTPUT,
+ }
+ }
}
diff --git a/src/activation/builtin.rs b/src/activation/builtin.rs
new file mode 100644
index 0000000..651f3f3
--- /dev/null
+++ b/src/activation/builtin.rs
@@ -0,0 +1,14 @@
+/// The sigmoid activation function. Scales all values nonlinearly to the range (0, 1).
+pub fn sigmoid(n: f32) -> f32 {
+ 1. / (1. + std::f32::consts::E.powf(-n))
+}
+
+/// The ReLU activation function. Equal to `n.max(0)`
+pub fn relu(n: f32) -> f32 {
+ n.max(0.)
+}
+
+/// Activation function that does nothing.
+pub fn linear_activation(n: f32) -> f32 {
+ n
+}
diff --git a/src/lib.rs b/src/lib.rs
index ee9f769..572c7af 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,25 +1,41 @@
-//! A simple crate that implements the Neuroevolution Augmenting Topologies algorithm using [genetic-rs](https://crates.io/crates/genetic-rs)
-//! ### Feature Roadmap:
-//! - [x] base (single-core) crate
-//! - [x] rayon
-//! - [x] serde
-//! - [x] crossover
-//!
-//! You can get started by looking at [genetic-rs docs](https://docs.rs/genetic-rs) and checking the examples for this crate.
-
+#![doc = include_str!("../README.md")]
#![warn(missing_docs)]
-#![cfg_attr(docsrs, feature(doc_cfg))]
-/// A module containing the [`NeuralNetworkTopology`] struct. This is what you want to use in the DNA of your agent, as it is the thing that goes through nextgens and suppors mutation.
-pub mod topology;
+/// Contains the types surrounding activation functions.
+pub mod activation;
+
+/// Contains the [`NeuralNetwork`] and related types.
+pub mod neuralnet;
+
+pub use neuralnet::*;
+
+pub use genetic_rs::{self, prelude::*};
+
+/// A trait for getting the index of the maximum element.
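+///
+/// A quick sketch of the intended usage (via the blanket iterator impl below):
+/// ```rust,ignore
+/// use neat::MaxIndex;
+/// assert_eq!(vec![0.1_f32, 0.9, 0.4].into_iter().max_index(), Some(1));
+/// ```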
+pub trait MaxIndex {
+ /// Returns the index of the maximum element.
+    fn max_index(self) -> Option<usize>;
+}
+
+impl<V: PartialOrd, I: Iterator<Item = V>> MaxIndex for I {
+    fn max_index(self) -> Option<usize> {
+ // enumerate now so we don't accidentally
+ // skip the index of the first element
+ let mut iter = self.enumerate();
+
+ let mut max_i = 0;
+ let mut max_v = iter.next()?.1;
-/// A module containing the main [`NeuralNetwork`] struct.
-/// This has state/cache and will run the predictions. Make sure to run [`NeuralNetwork::flush_state`] between uses of [`NeuralNetwork::predict`].
-pub mod runnable;
+ for (i, v) in iter {
+ if v > max_v {
+ max_v = v;
+ max_i = i;
+ }
+ }
-pub use genetic_rs::prelude::*;
-pub use runnable::*;
-pub use topology::*;
+ Some(max_i)
+ }
+}
-#[cfg(feature = "serde")]
-pub use nnt_serde::*;
+#[cfg(test)]
+mod tests;
diff --git a/src/neuralnet.rs b/src/neuralnet.rs
new file mode 100644
index 0000000..9ec184e
--- /dev/null
+++ b/src/neuralnet.rs
@@ -0,0 +1,1317 @@
+use std::{
+ collections::{HashMap, HashSet, VecDeque},
+ ops::{Index, IndexMut},
+ sync::{
+ atomic::{AtomicBool, AtomicUsize, Ordering},
+ Arc,
+ },
+};
+
+use atomic_float::AtomicF32;
+use bitflags::bitflags;
+use genetic_rs::prelude::*;
+use rand::Rng;
+use replace_with::replace_with_or_abort;
+
+use crate::{
+ activation::{builtin::*, *},
+ activation_fn,
+};
+
+use rayon::prelude::*;
+
+#[cfg(feature = "serde")]
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+
+#[cfg(feature = "serde")]
+use serde_big_array::BigArray;
+
+#[cfg(feature = "serde")]
+mod outputs_serde {
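+    // formats like JSON only allow string map keys, so a HashMap keyed by
+    // `NeuronLocation` is round-tripped as a Vec of pairs instead.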
+ use super::*;
+ use std::collections::HashMap;
+
+    pub fn serialize<S>(
+        map: &HashMap<NeuronLocation, f32>,
+        serializer: S,
+    ) -> Result<S::Ok, S::Error>
+ ) -> Result
+ where
+ S: Serializer,
+ {
+ let vec: Vec<(NeuronLocation, f32)> = map.iter().map(|(k, v)| (*k, *v)).collect();
+ vec.serialize(serializer)
+ }
+
+    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashMap<NeuronLocation, f32>, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let vec: Vec<(NeuronLocation, f32)> = Vec::deserialize(deserializer)?;
+ Ok(vec.into_iter().collect())
+ }
+}
+
+/// An abstract neural network type with `I` input neurons and `O` output neurons.
+/// Hidden neurons are not organized into layers, but rather float and link freely
+/// (or at least in any way that doesn't cause a cyclic dependency).
+///
+/// See [`NeuralNetwork::predict`] for usage.
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+pub struct NeuralNetwork<const I: usize, const O: usize> {
+ /// The input layer of neurons. Values specified in [`NeuralNetwork::predict`] will start here.
+ #[cfg_attr(feature = "serde", serde(with = "BigArray"))]
+ pub input_layer: [Neuron; I],
+
+ /// The hidden layer(s) of neurons. They are not actually layered, but rather free-floating.
+    pub hidden_layers: Vec<Neuron>,
+
+ /// The output layer of neurons. Their values will be returned from [`NeuralNetwork::predict`].
+ #[cfg_attr(feature = "serde", serde(with = "BigArray"))]
+ pub output_layer: [Neuron; O],
+}
+
+impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
+ // TODO option to set default output layer activations
+ /// Creates a new random neural network with the given settings.
+ pub fn new(rng: &mut impl rand::Rng) -> Self {
+ let mut output_layer = Vec::with_capacity(O);
+
+ for _ in 0..O {
+ output_layer.push(Neuron::new_with_activation(
+ HashMap::new(),
+ activation_fn!(sigmoid),
+ rng,
+ ));
+ }
+
+ let mut input_layer = Vec::with_capacity(I);
+
+ for _ in 0..I {
+ let mut already_chosen = HashSet::new();
+ let num_outputs = rng.random_range(1..=O);
+ let mut outputs = HashMap::new();
+
+ for _ in 0..num_outputs {
+ let mut j = rng.random_range(0..O);
+ while already_chosen.contains(&j) {
+ j = rng.random_range(0..O);
+ }
+
+ output_layer[j].input_count += 1;
+ already_chosen.insert(j);
+
+ outputs.insert(NeuronLocation::Output(j), rng.random());
+ }
+
+ input_layer.push(Neuron::new_with_activation(
+ outputs,
+ activation_fn!(linear_activation),
+ rng,
+ ));
+ }
+
+ let input_layer = input_layer.try_into().unwrap();
+ let output_layer = output_layer.try_into().unwrap();
+
+ Self {
+ input_layer,
+ hidden_layers: vec![],
+ output_layer,
+ }
+ }
+
+ /// Runs the neural network, propagating values from input to output layer.
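+    ///
+    /// A minimal usage sketch (hypothetical; assumes crate name `neat` and the `rand::rng()` API used elsewhere in this crate):
+    /// ```no_run
+    /// use neat::NeuralNetwork;
+    ///
+    /// let mut rng = rand::rng();
+    /// // 2 inputs, 1 output
+    /// let net: NeuralNetwork<2, 1> = NeuralNetwork::new(&mut rng);
+    /// let [out] = net.predict([0.5, -0.25]);
+    /// assert!((0.0..=1.0).contains(&out)); // output neurons default to sigmoid
+    /// ```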
+ pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
+ let cache = Arc::new(NeuralNetCache::from(self));
+ cache.prime_inputs(inputs);
+
+ (0..I)
+ .into_par_iter()
+ .for_each(|i| self.eval(NeuronLocation::Input(i), cache.clone()));
+
+ let mut outputs = [0.0; O];
+ for (i, output) in outputs.iter_mut().enumerate().take(O) {
+ let n = &self.output_layer[i];
+ let val = cache.get(NeuronLocation::Output(i));
+ *output = n.activate(val);
+ }
+
+ outputs
+ }
+
+    fn eval(&self, loc: NeuronLocation, cache: Arc<NeuralNetCache<I, O>>) {
+ if !cache.claim(loc) {
+ // some other thread is already
+ // waiting to do this task, currently doing it, or done.
+ // no need to do it again.
+ return;
+ }
+
+ while !cache.is_ready(loc) {
+ // essentially spinlocks until the dependency tasks are complete,
+ // while letting this thread do some work on random tasks.
+ rayon::yield_now();
+ }
+
+ let n = &self[loc];
+ let val = n.activate(cache.get(loc));
+
+ n.outputs.par_iter().for_each(|(&loc2, weight)| {
+ cache.add(loc2, val * weight);
+ self.eval(loc2, cache.clone());
+ });
+ }
+
+ /// Get a neuron at the specified [`NeuronLocation`].
+ pub fn get_neuron(&self, loc: NeuronLocation) -> Option<&Neuron> {
+ if !self.neuron_exists(loc) {
+ None
+ } else {
+ Some(&self[loc])
+ }
+ }
+
+    /// Returns whether a neuron exists at the given location.
+ pub fn neuron_exists(&self, loc: NeuronLocation) -> bool {
+ match loc {
+ NeuronLocation::Input(i) => i < I,
+ NeuronLocation::Hidden(i) => i < self.hidden_layers.len(),
+ NeuronLocation::Output(i) => i < O,
+ }
+ }
+
+ /// Get a mutable reference to the neuron at the specified [`NeuronLocation`].
+ pub fn get_neuron_mut(&mut self, loc: NeuronLocation) -> Option<&mut Neuron> {
+ if !self.neuron_exists(loc) {
+ None
+ } else {
+ Some(&mut self[loc])
+ }
+ }
+
+ /// Adds a new neuron to hidden layer. Updates [`input_count`][Neuron::input_count]s automatically.
+ /// Removes any output connections that point to invalid neurons or would result in cyclic linkage.
+ /// Returns whether all output connections were valid.
+ /// Due to the cyclic check, this function has time complexity O(nm), where n is the number of neurons
+ /// and m is the number of output connections.
+ pub fn add_neuron(&mut self, mut n: Neuron) -> bool {
+ let mut valid = true;
+ let new_loc = NeuronLocation::Hidden(self.hidden_layers.len());
+        let outputs = n.outputs.keys().cloned().collect::<Vec<_>>();
+ for loc in outputs {
+ if !self.neuron_exists(loc)
+ || !self.is_connection_safe(Connection {
+ from: new_loc,
+ to: loc,
+ })
+ {
+ n.outputs.remove(&loc);
+ valid = false;
+ continue;
+ }
+
+ let n = &mut self[loc];
+ n.input_count += 1;
+ }
+
+ self.hidden_layers.push(n);
+
+ valid
+ }
+
+ /// Split a [`Connection`] into two of the same weight, joined by a new [`Neuron`] in the hidden layer(s).
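+    ///
+    /// Illustrative sketch (hypothetical):
+    /// ```no_run
+    /// # use neat::*;
+    /// # let mut rng = rand::rng();
+    /// # let mut net: NeuralNetwork<1, 1> = NeuralNetwork::new(&mut rng);
+    /// // turns `Input(0) -> Output(0)` into `Input(0) -> Hidden(n) -> Output(0)`,
+    /// // carrying the original weight on both halves
+    /// net.split_connection(
+    ///     Connection {
+    ///         from: NeuronLocation::Input(0),
+    ///         to: NeuronLocation::Output(0),
+    ///     },
+    ///     &mut rng,
+    /// );
+    /// ```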
+ pub fn split_connection(&mut self, connection: Connection, rng: &mut impl Rng) {
+ let new_loc = NeuronLocation::Hidden(self.hidden_layers.len());
+
+ let a = &mut self[connection.from];
+ let w = a
+ .outputs
+ .remove(&connection.to)
+ .expect("invalid connection.to");
+
+ a.outputs.insert(new_loc, w);
+
+ let mut outputs = HashMap::new();
+ outputs.insert(connection.to, w);
+ let mut new_n = Neuron::new(outputs, NeuronScope::HIDDEN, rng);
+ new_n.input_count = 1;
+ self.hidden_layers.push(new_n);
+ }
+
+ /// Adds a connection but does not check for cyclic linkages.
+ pub fn add_connection_unchecked(&mut self, connection: Connection, weight: f32) {
+ let a = &mut self[connection.from];
+ a.outputs.insert(connection.to, weight);
+
+ let b = &mut self[connection.to];
+ b.input_count += 1;
+ }
+
+    /// Returns `false` if the connection would create a cycle, already exists, or has otherwise invalid endpoints.
+ /// Can be O(n) over the number of neurons in the network.
+ pub fn is_connection_safe(&self, connection: Connection) -> bool {
+ if connection.from.is_output()
+ || connection.to.is_input()
+ || connection.from == connection.to
+ || (self.neuron_exists(connection.from)
+ && self[connection.from].outputs.contains_key(&connection.to))
+ {
+ return false;
+ }
+ let mut visited = HashSet::from([connection.from]);
+ self.dfs(&mut visited, connection.to)
+ }
+
+    fn dfs(&self, visited: &mut HashSet<NeuronLocation>, current: NeuronLocation) -> bool {
+ if !visited.insert(current) {
+ return false;
+ }
+
+ let n = &self[current];
+ for loc in n.outputs.keys() {
+ if !self.dfs(visited, *loc) {
+ return false;
+ }
+ }
+
+ true
+ }
+
+ /// Safe, checked add connection method. Returns false if it aborted due to cyclic linkage.
+ /// Note that checking for cyclic linkage is O(n) over all neurons in the network, which
+ /// may be expensive for larger networks.
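+    ///
+    /// Illustrative sketch (hypothetical):
+    /// ```no_run
+    /// # use neat::*;
+    /// # let mut rng = rand::rng();
+    /// # let mut net: NeuralNetwork<2, 2> = NeuralNetwork::new(&mut rng);
+    /// let conn = Connection {
+    ///     from: NeuronLocation::Input(0),
+    ///     to: NeuronLocation::Output(1),
+    /// };
+    /// if !net.add_connection(conn, 0.5) {
+    ///     // the edge already existed or would have been cyclic/invalid
+    /// }
+    /// ```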
+ pub fn add_connection(&mut self, connection: Connection, weight: f32) -> bool {
+ if !self.is_connection_safe(connection) {
+ return false;
+ }
+
+ self.add_connection_unchecked(connection, weight);
+
+ true
+ }
+
+ /// Attempts to add a random connection, retrying if unsafe.
+ /// Returns the connection if it established one before reaching max_retries.
+ pub fn add_random_connection(
+ &mut self,
+ max_retries: usize,
+ rng: &mut impl rand::Rng,
+    ) -> Option<Connection> {
+ for _ in 0..max_retries {
+ let a = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
+ let b = self.random_location_in_scope(rng, !NeuronScope::INPUT);
+
+ let conn = Connection { from: a, to: b };
+ if self.add_connection(conn, rng.random()) {
+ return Some(conn);
+ }
+ }
+
+ None
+ }
+
+ /// Attempts to get a random connection, retrying if the neuron it found
+ /// doesn't have any outbound connections.
+ /// Returns the connection if it found one before reaching max_retries.
+ pub fn get_random_connection(
+ &mut self,
+ max_retries: usize,
+ rng: &mut impl rand::Rng,
+    ) -> Option<Connection> {
+ for _ in 0..max_retries {
+ let a = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
+ let an = &self[a];
+ if an.outputs.is_empty() {
+ continue;
+ }
+
+ let mut iter = an
+ .outputs
+ .keys()
+ .skip(rng.random_range(0..an.outputs.len()));
+ let b = iter.next().unwrap();
+
+ let conn = Connection { from: a, to: *b };
+ return Some(conn);
+ }
+
+ None
+ }
+
+ /// Attempts to remove a random connection, retrying if the neuron it found
+ /// doesn't have any outbound connections. Also removes hanging neurons created
+ /// by removing the connection.
+ ///
+ /// Returns the connection if it removed one before reaching max_retries.
+ pub fn remove_random_connection(
+ &mut self,
+ max_retries: usize,
+ rng: &mut impl rand::Rng,
+    ) -> Option<Connection> {
+ if let Some(conn) = self.get_random_connection(max_retries, rng) {
+ self.remove_connection(conn);
+ Some(conn)
+ } else {
+ None
+ }
+ }
+
+ /// Mutates a connection's weight.
+ pub fn mutate_weight(&mut self, connection: Connection, amount: f32, rng: &mut impl Rng) {
+ let n = &mut self[connection.from];
+ n.mutate_weight(connection.to, amount, rng).unwrap();
+ }
+
+ /// Get a random valid location within the network.
+ pub fn random_location(&self, rng: &mut impl Rng) -> NeuronLocation {
+ if self.hidden_layers.is_empty() {
+ if rng.random_range(0..=1) != 0 {
+ return NeuronLocation::Input(rng.random_range(0..I));
+ }
+ return NeuronLocation::Output(rng.random_range(0..O));
+ }
+
+ match rng.random_range(0..3) {
+ 0 => NeuronLocation::Input(rng.random_range(0..I)),
+ 1 => NeuronLocation::Hidden(rng.random_range(0..self.hidden_layers.len())),
+ 2 => NeuronLocation::Output(rng.random_range(0..O)),
+ _ => unreachable!(),
+ }
+ }
+
+ /// Get a random valid location within a [`NeuronScope`].
+ pub fn random_location_in_scope(
+ &self,
+ rng: &mut impl rand::Rng,
+ scope: NeuronScope,
+ ) -> NeuronLocation {
+ if scope == NeuronScope::NONE {
+ panic!("cannot select from empty scope");
+ }
+
+ let mut layers = Vec::with_capacity(3);
+ if scope.contains(NeuronScope::INPUT) {
+ layers.push((NeuronLocation::Input(0), I));
+ }
+ if scope.contains(NeuronScope::HIDDEN) && !self.hidden_layers.is_empty() {
+ layers.push((NeuronLocation::Hidden(0), self.hidden_layers.len()));
+ }
+ if scope.contains(NeuronScope::OUTPUT) {
+ layers.push((NeuronLocation::Output(0), O));
+ }
+
+ let (mut loc, size) = layers[rng.random_range(0..layers.len())];
+ loc.set_inner(rng.random_range(0..size));
+ loc
+ }
+
+ /// Remove a connection and indicate whether the destination neuron became hanging
+ /// (with the exception of output layer neurons).
+ /// Returns `true` if the destination neuron has input_count == 0 and should be removed.
+ /// Callers must handle the removal of the destination neuron if needed.
+ pub fn remove_connection_raw(&mut self, connection: Connection) -> bool {
+ let a = self
+ .get_neuron_mut(connection.from)
+ .expect("invalid connection.from");
+ if a.outputs.remove(&connection.to).is_none() {
+ panic!("invalid connection.to");
+ }
+
+ let b = &mut self[connection.to];
+
+ // if the invariants held at the beginning of the call,
+ // this should never underflow, but some cases like remove_cycles
+ // may temporarily break invariants.
+ b.input_count = b.input_count.saturating_sub(1);
+
+ // signal removal
+ connection.to.is_hidden() && b.input_count == 0
+ }
+
+ /// Remove a connection from the network.
+ /// This will also deal with hanging neurons iteratively to avoid recursion that
+ /// can invalidate stored indices during nested deletions.
+ /// This method is preferable to [`remove_connection_raw`][NeuralNetwork::remove_connection_raw] for a majority of usecases,
+ /// as it preserves the invariants of the neural network.
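+    ///
+    /// Illustrative sketch (hypothetical):
+    /// ```no_run
+    /// # use neat::*;
+    /// # let mut rng = rand::rng();
+    /// # let mut net: NeuralNetwork<1, 1> = NeuralNetwork::new(&mut rng);
+    /// # let conn = net.get_random_connection(10, &mut rng).unwrap();
+    /// // returns true if the destination became hanging and was pruned with it
+    /// let _pruned = net.remove_connection(conn);
+    /// ```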
+ pub fn remove_connection(&mut self, conn: Connection) -> bool {
+ if self.remove_connection_raw(conn) {
+ self.remove_neuron(conn.to);
+ return true;
+ }
+ false
+ }
+
+ /// Remove a neuron and downshift all connection indices to compensate for it.
+ /// Returns the number of neurons removed that were under the index of the removed neuron (including itself).
+ /// This will also deal with hanging neurons iteratively to avoid recursion that
+ /// can invalidate stored indices during nested deletions.
+ pub fn remove_neuron(&mut self, loc: NeuronLocation) -> usize {
+ if !loc.is_hidden() {
+ panic!("cannot remove neurons in input or output layer");
+ }
+
+ let initial_i = loc.unwrap();
+
+ let mut work = VecDeque::new();
+ work.push_back(loc);
+
+ let mut removed = 0;
+ while let Some(cur_loc) = work.pop_front() {
+ // if the neuron was already removed due to earlier deletions, skip.
+ // i don't think it realistically should ever happen, but just in case.
+ if !self.neuron_exists(cur_loc) {
+ continue;
+ }
+
+ let outputs = {
+ let n = &self[cur_loc];
+                n.outputs.keys().cloned().collect::<Vec<_>>()
+ };
+
+ for target in outputs {
+ if self.remove_connection_raw(Connection {
+ from: cur_loc,
+ to: target,
+ }) {
+ // target became hanging; schedule it for removal.
+ work.push_back(target);
+ }
+ }
+
+ // Re-check that the neuron still exists and is hidden before removing.
+ if !self.neuron_exists(cur_loc) || !cur_loc.is_hidden() {
+ continue;
+ }
+
+ let i = cur_loc.unwrap();
+ if i < self.hidden_layers.len() {
+ self.hidden_layers.remove(i);
+ if i <= initial_i {
+ removed += 1;
+ }
+ self.downshift_connections(i, &mut work); // O(n^2) bad, but we can optimize later if it's a problem.
+ }
+ }
+
+ removed
+ }
+
+    fn downshift_connections(&mut self, i: usize, work: &mut VecDeque<NeuronLocation>) {
+ self.input_layer
+ .par_iter_mut()
+ .for_each(|n| n.downshift_outputs(i));
+
+ self.hidden_layers
+ .par_iter_mut()
+ .for_each(|n| n.downshift_outputs(i));
+
+ work.par_iter_mut().for_each(|loc| match loc {
+ NeuronLocation::Hidden(j) if *j > i => *j -= 1,
+ _ => {}
+ });
+ }
+
+ /// Runs the `callback` on the weights of the neural network in parallel, allowing it to modify weight values.
+ pub fn update_weights(&mut self, callback: impl Fn(&NeuronLocation, &mut f32) + Sync) {
+ for n in &mut self.input_layer {
+ n.outputs
+ .par_iter_mut()
+ .for_each(|(loc, w)| callback(loc, w));
+ }
+
+ for n in &mut self.hidden_layers {
+ n.outputs
+ .par_iter_mut()
+ .for_each(|(loc, w)| callback(loc, w));
+ }
+ }
+
+ /// Runs the `callback` on the neurons of the neural network in parallel, allowing it to modify neuron values.
+ pub fn mutate_neurons(&mut self, callback: impl Fn(&mut Neuron) + Sync) {
+ self.input_layer.par_iter_mut().for_each(&callback);
+ self.hidden_layers.par_iter_mut().for_each(&callback);
+ self.output_layer.par_iter_mut().for_each(&callback);
+ }
+
+ /// Mutates the activation functions of the neurons in the neural network.
+ pub fn mutate_activations(&mut self, rate: f32) {
+ let reg = ACTIVATION_REGISTRY.read().unwrap();
+        self.mutate_activations_with_reg(rate, &reg);
+ }
+
+ /// Mutates the activation functions of the neurons in the neural network, using a provided registry.
+ pub fn mutate_activations_with_reg(&mut self, rate: f32, reg: &ActivationRegistry) {
+ self.input_layer.par_iter_mut().for_each(|n| {
+ let mut rng = rand::rng();
+ if rng.random_bool(rate as f64) {
+                n.mutate_activation(&reg.activations_in_scope(NeuronScope::INPUT), &mut rng);
+ }
+ });
+ self.hidden_layers.par_iter_mut().for_each(|n| {
+ let mut rng = rand::rng();
+ if rng.random_bool(rate as f64) {
+                n.mutate_activation(&reg.activations_in_scope(NeuronScope::HIDDEN), &mut rng);
+ }
+ });
+ self.output_layer.par_iter_mut().for_each(|n| {
+ let mut rng = rand::rng();
+ if rng.random_bool(rate as f64) {
+                n.mutate_activation(&reg.activations_in_scope(NeuronScope::OUTPUT), &mut rng);
+ }
+ });
+ }
+
+ /// Recounts inputs for all neurons in the network
+ /// and removes any invalid connections.
+ pub fn reset_input_counts(&mut self) {
+ self.clear_input_counts();
+
+ for i in 0..I {
+ self.reset_inputs_for_neuron(NeuronLocation::Input(i));
+ }
+
+ for i in 0..self.hidden_layers.len() {
+ self.reset_inputs_for_neuron(NeuronLocation::Hidden(i));
+ }
+ }
+
+ fn reset_inputs_for_neuron(&mut self, loc: NeuronLocation) {
+        let outputs = self[loc].outputs.keys().cloned().collect::<Vec<_>>();
+ let outputs2 = outputs
+ .into_iter()
+ .filter(|&loc| {
+ if !self.neuron_exists(loc) {
+ return false;
+ }
+
+ let target = &mut self[loc];
+ target.input_count += 1;
+ true
+ })
+            .collect::<HashSet<_>>();
+
+ self[loc].outputs.retain(|loc, _| outputs2.contains(loc));
+ }
+
+ fn clear_input_counts(&mut self) {
+ self.input_layer
+ .par_iter_mut()
+ .for_each(|n| n.input_count = 0);
+ self.hidden_layers
+ .par_iter_mut()
+ .for_each(|n| n.input_count = 0);
+ self.output_layer
+ .par_iter_mut()
+ .for_each(|n| n.input_count = 0);
+ }
+
+ /// Iterates over the network and removes any hanging neurons in the hidden layer(s).
+ pub fn prune_hanging_neurons(&mut self) {
+ let mut i = 0;
+ while i < self.hidden_layers.len() {
+ let mut new_i = i + 1;
+ if self.hidden_layers[i].input_count == 0 {
+ // this saturating_sub is a code smell but it works and avoids some edge cases where indices can get messed up.
+ new_i = new_i.saturating_sub(self.remove_neuron(NeuronLocation::Hidden(i)));
+ }
+ i = new_i;
+ }
+ }
+
+ /// Uses DFS to find and remove all cycles in O(n+e) time.
+    /// Expects [`reset_input_counts`][NeuralNetwork::reset_input_counts] and [`prune_hanging_neurons`][NeuralNetwork::prune_hanging_neurons] to be called afterwards to restore invariants.
+ pub fn remove_cycles(&mut self) {
+ let mut visited = HashMap::new();
+        let mut edges_to_remove: HashSet<Connection> = HashSet::new();
+
+ for i in 0..I {
+ self.remove_cycles_dfs(
+ &mut visited,
+ &mut edges_to_remove,
+ None,
+ NeuronLocation::Input(i),
+ );
+ }
+
+ // unattached cycles (will cause problems since they
+ // never get deleted by input_count == 0)
+ for i in 0..self.hidden_layers.len() {
+ let loc = NeuronLocation::Hidden(i);
+ if !visited.contains_key(&loc) {
+ self.remove_cycles_dfs(&mut visited, &mut edges_to_remove, None, loc);
+ }
+ }
+
+ for conn in edges_to_remove {
+ // only doing raw here since we recalculate input counts and
+ // prune hanging neurons later.
+ self.remove_connection_raw(conn);
+ }
+ }
+
+ // colored dfs
+ fn remove_cycles_dfs(
+ &mut self,
+        visited: &mut HashMap<NeuronLocation, u8>,
+        edges_to_remove: &mut HashSet<Connection>,
+ prev: Option,
+ current: NeuronLocation,
+ ) {
+        if let Some(&existing) = visited.get(&current) {
+ if existing == 0 {
+ // part of current dfs - found a cycle
+ // prev must exist here since visited would be empty on first call.
+ let prev = prev.unwrap();
+            if self[prev].outputs.contains_key(&current) {
+ edges_to_remove.insert(Connection {
+ from: prev,
+ to: current,
+ });
+ }
+ }
+
+ // already fully visited, no need to check again
+ return;
+ }
+
+ visited.insert(current, 0);
+
+        let outputs = self[current].outputs.keys().cloned().collect::<Vec<_>>();
+ for loc in outputs {
+ self.remove_cycles_dfs(visited, edges_to_remove, Some(current), loc);
+ }
+
+ visited.insert(current, 1);
+ }
+
+ /// Performs just the mutations that modify the graph structure of the neural network,
+ /// and not the internal mutations that only modify values such as activation functions, weights, and biases.
+ pub fn perform_graph_mutations(
+ &mut self,
+ settings: &MutationSettings,
+ rate: f32,
+ rng: &mut impl rand::Rng,
+ ) {
+ // TODO maybe allow specifying probability
+ // for each type of mutation
+ if settings
+ .allowed_mutations
+ .contains(GraphMutations::SPLIT_CONNECTION)
+ && rng.random_bool(rate as f64)
+ {
+ // split connection
+ if let Some(conn) = self.get_random_connection(settings.max_split_retries, rng) {
+ self.split_connection(conn, rng);
+ }
+ }
+
+ if settings
+ .allowed_mutations
+ .contains(GraphMutations::ADD_CONNECTION)
+ && rng.random_bool(rate as f64)
+ {
+ // add connection
+ self.add_random_connection(settings.max_add_retries, rng);
+ }
+
+ if settings
+ .allowed_mutations
+ .contains(GraphMutations::REMOVE_CONNECTION)
+ && rng.random_bool(rate as f64)
+ {
+ // remove connection
+ self.remove_random_connection(settings.max_remove_retries, rng);
+ }
+ }
+
+ /// Performs just the mutations that modify internal values such as activation functions, weights, and biases,
+ /// and not the graph mutations that modify the structure of the neural network.
+ pub fn perform_internal_mutations(&mut self, settings: &MutationSettings, rate: f32) {
+ self.mutate_activations(rate);
+ self.mutate_weights(settings.weight_mutation_amount);
+ }
+
+ /// Same as [`mutate`][NeuralNetwork::mutate] but allows specifying a custom activation registry for activation mutations.
+ pub fn mutate_with_reg(
+ &mut self,
+ settings: &MutationSettings,
+ rate: f32,
+ rng: &mut impl rand::Rng,
+ reg: &ActivationRegistry,
+ ) {
+ self.perform_graph_mutations(settings, rate, rng);
+ self.mutate_activations_with_reg(rate, reg);
+ self.mutate_weights(settings.weight_mutation_amount);
+ }
+
+ /// Mutates all weights by a random amount up to `max_amount` in either direction.
+ pub fn mutate_weights(&mut self, max_amount: f32) {
+ self.update_weights(|_, w| {
+ let mut rng = rand::rng();
+ let amount = rng.random_range(-max_amount..max_amount);
+ *w += amount;
+ });
+ }
+}
+
+impl<const I: usize, const O: usize> Index<NeuronLocation> for NeuralNetwork<I, O> {
+ type Output = Neuron;
+
+ fn index(&self, loc: NeuronLocation) -> &Self::Output {
+ match loc {
+ NeuronLocation::Input(i) => &self.input_layer[i],
+ NeuronLocation::Hidden(i) => &self.hidden_layers[i],
+ NeuronLocation::Output(i) => &self.output_layer[i],
+ }
+ }
+}
+
+impl<const I: usize, const O: usize> GenerateRandom for NeuralNetwork<I, O> {
+ fn gen_random(rng: &mut impl rand::Rng) -> Self {
+ Self::new(rng)
+ }
+}
+
+impl<const I: usize, const O: usize> IndexMut<NeuronLocation> for NeuralNetwork<I, O> {
+ fn index_mut(&mut self, loc: NeuronLocation) -> &mut Self::Output {
+ match loc {
+ NeuronLocation::Input(i) => &mut self.input_layer[i],
+ NeuronLocation::Hidden(i) => &mut self.hidden_layers[i],
+ NeuronLocation::Output(i) => &mut self.output_layer[i],
+ }
+ }
+}
+
+/// The mutation settings for [`NeuralNetwork`].
+/// Used as the mutation context in [`NeuralNetwork::mutate`], as well as by [`NeuralNetwork::divide`] and [`NeuralNetwork::crossover`] via [`ReproductionSettings`].
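+///
+/// Construction sketch (hypothetical):
+/// ```
+/// # use neat::*;
+/// let settings = MutationSettings {
+///     mutation_rate: 0.05,
+///     weight_mutation_amount: 0.25,
+///     ..Default::default()
+/// };
+/// ```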
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+#[derive(Debug, Clone, PartialEq)]
+pub struct MutationSettings {
+ /// The chance of each mutation type to occur.
+ pub mutation_rate: f32,
+
+ /// The maximum amount that the weights will be mutated by in one mutation pass.
+ pub weight_mutation_amount: f32,
+
+ /// The maximum amount that biases will be mutated by in one mutation pass.
+ pub bias_mutation_amount: f32,
+
+ /// The maximum number of retries for adding connections.
+ pub max_add_retries: usize,
+
+ /// The maximum number of retries for removing connections.
+ pub max_remove_retries: usize,
+
+ /// The maximum number of retries for splitting connections.
+ pub max_split_retries: usize,
+
+ /// The types of graph mutations to allow during mutation.
+ /// Graph mutations are mutations that modify the structure of the neural network,
+ /// such as adding/removing connections and adding neurons.
+ pub allowed_mutations: GraphMutations,
+}
+
+impl Default for MutationSettings {
+ fn default() -> Self {
+ Self {
+ mutation_rate: 0.01,
+ weight_mutation_amount: 0.5,
+ bias_mutation_amount: 0.5,
+ max_add_retries: 10,
+ max_remove_retries: 10,
+ max_split_retries: 10,
+ allowed_mutations: GraphMutations::default(),
+ }
+ }
+}
+
+bitflags! {
+ /// The types of graph mutations to allow during mutation.
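+    ///
+    /// Flags combine bitwise; for example (hypothetical usage),
+    /// `GraphMutations::SPLIT_CONNECTION | GraphMutations::ADD_CONNECTION`
+    /// allows structural growth but never removes connections.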
+ #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+ pub struct GraphMutations: u8 {
+ /// Mutation that splits an existing connection into two via a hidden neuron.
+ const SPLIT_CONNECTION = 0b00000001;
+ /// Mutation that adds a new connection between neurons.
+ const ADD_CONNECTION = 0b00000010;
+ /// Mutation that removes an existing connection.
+ const REMOVE_CONNECTION = 0b00000100;
+ }
+}
+
+impl Default for GraphMutations {
+ fn default() -> Self {
+ Self::all()
+ }
+}
+
+#[cfg(feature = "serde")]
+impl Serialize for GraphMutations {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ self.bits().serialize(serializer)
+ }
+}
+
+#[cfg(feature = "serde")]
+impl<'de> Deserialize<'de> for GraphMutations {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ let bits = u8::deserialize(deserializer)?;
+ GraphMutations::from_bits(bits)
+ .ok_or_else(|| serde::de::Error::custom("invalid bit pattern for GraphMutations"))
+ }
+}
+
+impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetwork<I, O> {
+ type Context = MutationSettings;
+
+ fn mutate(&mut self, settings: &MutationSettings, rate: f32, rng: &mut impl Rng) {
+ let reg = ACTIVATION_REGISTRY.read().unwrap();
+        self.mutate_with_reg(settings, rate, rng, &reg);
+ }
+}
+
+/// The settings used for [`NeuralNetwork`] reproduction.
+#[derive(Debug, Clone, PartialEq)]
+pub struct ReproductionSettings {
+ /// The mutation settings to use during reproduction.
+ pub mutation: MutationSettings,
+
+ /// The number of times to apply mutation during reproduction.
+ pub mutation_passes: usize,
+}
+
+impl Default for ReproductionSettings {
+ fn default() -> Self {
+ Self {
+ mutation: MutationSettings::default(),
+ mutation_passes: 3,
+ }
+ }
+}
+
+impl<const I: usize, const O: usize> Mitosis for NeuralNetwork<I, O> {
+ type Context = ReproductionSettings;
+
+ fn divide(
+ &self,
+ settings: &ReproductionSettings,
+ rate: f32,
+ rng: &mut impl prelude::Rng,
+ ) -> Self {
+ let mut child = self.clone();
+
+ for _ in 0..settings.mutation_passes {
+ child.mutate(&settings.mutation, rate, rng);
+ }
+
+ child
+ }
+}
+
+impl<const I: usize, const O: usize> Crossover for NeuralNetwork<I, O> {
+ type Context = ReproductionSettings;
+
+ fn crossover(
+ &self,
+ other: &Self,
+ settings: &ReproductionSettings,
+ rate: f32,
+ rng: &mut impl rand::Rng,
+ ) -> Self {
+ // merge (temporarily breaking invariants) and then resolve invariants.
+ let mut child = NeuralNetwork {
+ input_layer: self.input_layer.clone(),
+ hidden_layers: vec![],
+ output_layer: self.output_layer.clone(),
+ };
+
+ for i in 0..I {
+ if rng.random_bool(0.5) {
+ child.input_layer[i] = other.input_layer[i].clone();
+ }
+ }
+
+ for i in 0..O {
+ if rng.random_bool(0.5) {
+ child.output_layer[i] = other.output_layer[i].clone();
+ }
+ }
+
+ let larger;
+ let smaller;
+ if self.hidden_layers.len() >= other.hidden_layers.len() {
+ larger = &self.hidden_layers;
+ smaller = &other.hidden_layers;
+ } else {
+ larger = &other.hidden_layers;
+ smaller = &self.hidden_layers;
+ }
+
+ for i in 0..larger.len() {
+ if i < smaller.len() {
+ if rng.random_bool(0.5) {
+ child.hidden_layers.push(smaller[i].clone());
+ } else {
+ child.hidden_layers.push(larger[i].clone());
+ }
+ continue;
+ }
+
+ // larger is the only one with spare neurons, add them.
+ child.hidden_layers.push(larger[i].clone());
+ }
+
+ // resolve invariants
+ child.remove_cycles();
+ child.reset_input_counts();
+ child.prune_hanging_neurons();
+
+ for _ in 0..settings.mutation_passes {
+ child.mutate(&settings.mutation, rate, rng);
+ }
+
+ child
+ }
+}
+
+fn output_exists(loc: NeuronLocation, hidden_len: usize, output_len: usize) -> bool {
+ match loc {
+ NeuronLocation::Input(_) => false,
+ NeuronLocation::Hidden(i) => i < hidden_len,
+ NeuronLocation::Output(i) => i < output_len,
+ }
+}
+
+/// A helper struct for operations on connections between neurons.
+/// It does not contain information about the weight.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+pub struct Connection {
+ /// The source of the connection.
+ pub from: NeuronLocation,
+
+ /// The destination of the connection.
+ pub to: NeuronLocation,
+}
+
+/// A stateless neuron. Contains info about bias, activation, and connections.
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+pub struct Neuron {
+ /// The input count used in [`NeuralNetCache`]. Not safe to modify.
+ pub input_count: usize,
+
+ /// The connections and weights to other neurons.
+ #[cfg_attr(feature = "serde", serde(with = "outputs_serde"))]
+    pub outputs: HashMap<NeuronLocation, f32>,
+
+ /// The initial value of the neuron.
+ pub bias: f32,
+
+ /// The activation function applied to the value before propagating to [`outputs`][Neuron::outputs].
+ pub activation_fn: ActivationFn,
+}
+
+impl Neuron {
+ /// Creates a new neuron with a specified activation function and outputs.
+ pub fn new_with_activation(
+        outputs: HashMap<NeuronLocation, f32>,
+ activation_fn: ActivationFn,
+ rng: &mut impl Rng,
+ ) -> Self {
+ Self {
+ input_count: 0,
+ outputs,
+ bias: rng.random(),
+ activation_fn,
+ }
+ }
+
+ /// Creates a new neuron with the given output locations.
+ /// Chooses a random activation function within the specified scope.
+ pub fn new(
+        outputs: HashMap<NeuronLocation, f32>,
+ scope: NeuronScope,
+ rng: &mut impl Rng,
+ ) -> Self {
+ let reg = ACTIVATION_REGISTRY.read().unwrap();
+ let act = reg.random_activation_in_scope(scope, rng);
+
+ Self::new_with_activation(outputs, act, rng)
+ }
+
+ /// Creates a new neuron with the given outputs.
+ /// Takes a collection of activation functions and chooses a random one from them to use.
+ pub fn new_with_activations(
+        outputs: HashMap<NeuronLocation, f32>,
+ activations: &[ActivationFn],
+ rng: &mut impl Rng,
+ ) -> Self {
+ // TODO maybe Result instead.
+ if activations.is_empty() {
+ panic!("Empty activations list provided");
+ }
+
+ Self::new_with_activation(
+ outputs,
+ activations[rng.random_range(0..activations.len())].clone(),
+ rng,
+ )
+ }
+
+ /// Runs the [activation function][Neuron::activation_fn] on the given value and returns it.
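+    ///
+    /// Sketch (hypothetical doctest, using the builtin `relu`):
+    /// ```no_run
+    /// # use neat::{activation::builtin::relu, *};
+    /// # use std::collections::HashMap;
+    /// # let mut rng = rand::rng();
+    /// let n = Neuron::new_with_activation(HashMap::new(), activation_fn!(relu), &mut rng);
+    /// assert_eq!(n.activate(-3.0), 0.0);
+    /// assert_eq!(n.activate(2.0), 2.0);
+    /// ```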
+ pub fn activate(&self, v: f32) -> f32 {
+ self.activation_fn.func.activate(v)
+ }
+
+ /// Randomly mutates the specified weight with the rate.
+ pub fn mutate_weight(
+ &mut self,
+ output: NeuronLocation,
+ rate: f32,
+ rng: &mut impl Rng,
+    ) -> Option<f32> {
+ if let Some(w) = self.outputs.get_mut(&output) {
+ *w += rng.random_range(-rate..=rate);
+ return Some(*w);
+ }
+
+ None
+ }
+
+ /// Get a random output location and weight.
+ pub fn random_output(&self, rng: &mut impl Rng) -> (NeuronLocation, f32) {
+ // will panic if outputs is empty
+ let i = rng.random_range(0..self.outputs.len());
+ let x = self.outputs.iter().nth(i).unwrap();
+ (*x.0, *x.1)
+ }
+
+ pub(crate) fn downshift_outputs(&mut self, i: usize) {
+ replace_with_or_abort(&mut self.outputs, |o| {
+ o.into_par_iter()
+ .map(|(loc, w)| match loc {
+ NeuronLocation::Hidden(j) if j > i => (NeuronLocation::Hidden(j - 1), w),
+ _ => (loc, w),
+ })
+ .collect()
+ });
+ }
+
+ /// Removes any outputs pointing to a nonexistent neuron.
+ pub fn prune_invalid_outputs(&mut self, hidden_len: usize, output_len: usize) {
+ self.outputs
+ .retain(|loc, _| output_exists(*loc, hidden_len, output_len));
+ }
+
+ /// Replaces the activation function with a random one.
+ pub fn mutate_activation(&mut self, activations: &[ActivationFn], rng: &mut impl Rng) {
+ if activations.is_empty() {
+ panic!("Empty activations list provided");
+ }
+
+ self.activation_fn = activations[rng.random_range(0..activations.len())].clone();
+ }
+}
+
+/// A pseudo-pointer that identifies a neuron by layer and index. Used for indexing and caching.
+#[derive(Hash, Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
+#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
+pub enum NeuronLocation {
+ /// Points to a neuron in the input layer at contained index.
+ Input(usize),
+
+ /// Points to a neuron in the hidden layer at contained index.
+ Hidden(usize),
+
+ /// Points to a neuron in the output layer at contained index.
+ Output(usize),
+}
+
+impl NeuronLocation {
+ /// Returns `true` if it points to the input layer. Otherwise, returns `false`.
+ pub fn is_input(&self) -> bool {
+ matches!(self, Self::Input(_))
+ }
+
+ /// Returns `true` if it points to the hidden layer. Otherwise, returns `false`.
+ pub fn is_hidden(&self) -> bool {
+ matches!(self, Self::Hidden(_))
+ }
+
+ /// Returns `true` if it points to the output layer. Otherwise, returns `false`.
+ pub fn is_output(&self) -> bool {
+ matches!(self, Self::Output(_))
+ }
+
+ /// Retrieves the index value, regardless of layer. Does not consume.
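+    ///
+    /// For example (hypothetical doctest):
+    /// ```
+    /// # use neat::NeuronLocation;
+    /// // the layer is discarded; only the index remains
+    /// assert_eq!(NeuronLocation::Hidden(3).unwrap(), 3);
+    /// ```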
+ pub fn unwrap(&self) -> usize {
+ match self {
+ Self::Input(i) => *i,
+ Self::Hidden(i) => *i,
+ Self::Output(i) => *i,
+ }
+ }
+
+ /// Sets the inner index value without changing the layer.
+ pub fn set_inner(&mut self, v: usize) {
+ // there's gotta be a cleaner way of doing this
+ match self {
+ Self::Input(i) => *i = v,
+ Self::Hidden(i) => *i = v,
+ Self::Output(i) => *i = v,
+ }
+ }
+}
+
+impl AsRef<NeuronLocation> for NeuronLocation {
+ fn as_ref(&self) -> &NeuronLocation {
+ self
+ }
+}
+
+/// Handles the state of a single neuron for [`NeuralNetCache`].
+#[derive(Debug, Default)]
+pub struct NeuronCache {
+ /// The value of the neuron.
+ pub value: AtomicF32,
+
+ /// The expected input count.
+ pub expected_inputs: usize,
+
+ /// The number of inputs that have finished evaluating.
+ pub finished_inputs: AtomicUsize,
+
+ /// Whether or not a thread has claimed this neuron to work on it.
+ pub claimed: AtomicBool,
+}
+
+impl NeuronCache {
+ /// Creates a new [`NeuronCache`] given relevant info.
+    /// Use [`NeuronCache::from`] instead to create a cache entry for an existing [`Neuron`].
+ pub fn new(bias: f32, expected_inputs: usize) -> Self {
+ Self {
+ value: AtomicF32::new(bias),
+ expected_inputs,
+ ..Default::default()
+ }
+ }
+}
+
+impl From<&Neuron> for NeuronCache {
+ fn from(value: &Neuron) -> Self {
+ Self {
+ value: AtomicF32::new(value.bias),
+ expected_inputs: value.input_count,
+ finished_inputs: AtomicUsize::new(0),
+ claimed: AtomicBool::new(false),
+ }
+ }
+}
+
+/// A cache type used in [`NeuralNetwork::predict`] to track state.
+#[derive(Debug)]
+pub struct NeuralNetCache<const I: usize, const O: usize> {
+ /// The input layer cache.
+ pub input_layer: [NeuronCache; I],
+
+ /// The hidden layer(s) cache.
+    pub hidden_layers: Vec<NeuronCache>,
+
+ /// The output layer cache.
+ pub output_layer: [NeuronCache; O],
+}
+
+impl<const I: usize, const O: usize> NeuralNetCache<I, O> {
+ /// Gets the value of a neuron at the given location.
+ pub fn get(&self, loc: NeuronLocation) -> f32 {
+ match loc {
+ NeuronLocation::Input(i) => self.input_layer[i].value.load(Ordering::SeqCst),
+ NeuronLocation::Hidden(i) => self.hidden_layers[i].value.load(Ordering::SeqCst),
+ NeuronLocation::Output(i) => self.output_layer[i].value.load(Ordering::SeqCst),
+ }
+ }
+
+ /// Adds a value to the neuron at the specified location and increments [`finished_inputs`][NeuronCache::finished_inputs].
+    pub fn add(&self, loc: impl AsRef<NeuronLocation>, n: f32) -> f32 {
+ match loc.as_ref() {
+ NeuronLocation::Input(i) => self.input_layer[*i].value.fetch_add(n, Ordering::SeqCst),
+ NeuronLocation::Hidden(i) => {
+ let c = &self.hidden_layers[*i];
+ let v = c.value.fetch_add(n, Ordering::SeqCst);
+ c.finished_inputs.fetch_add(1, Ordering::SeqCst);
+ v
+ }
+ NeuronLocation::Output(i) => {
+ let c = &self.output_layer[*i];
+ let v = c.value.fetch_add(n, Ordering::SeqCst);
+ c.finished_inputs.fetch_add(1, Ordering::SeqCst);
+ v
+ }
+ }
+ }
+
+ /// Returns whether [`finished_inputs`][NeuronCache::finished_inputs] matches [`expected_inputs`][NeuronCache::expected_inputs].
+ pub fn is_ready(&self, loc: NeuronLocation) -> bool {
+ match loc {
+ NeuronLocation::Input(_) => true, // input neurons are always ready since they don't wait for any inputs
+ NeuronLocation::Hidden(i) => {
+ let c = &self.hidden_layers[i];
+ c.finished_inputs.load(Ordering::SeqCst) >= c.expected_inputs
+ }
+ NeuronLocation::Output(i) => {
+ let c = &self.output_layer[i];
+ c.finished_inputs.load(Ordering::SeqCst) >= c.expected_inputs
+ }
+ }
+ }
+
+ /// Adds the input values to the input layer of neurons.
+ pub fn prime_inputs(&self, inputs: [f32; I]) {
+ for (i, v) in inputs.into_iter().enumerate() {
+ self.input_layer[i].value.fetch_add(v, Ordering::SeqCst);
+ }
+ }
+
+ /// Attempts to claim a neuron. Returns false if it has already been claimed.
+    pub fn claim(&self, loc: impl AsRef<NeuronLocation>) -> bool {
+ match loc.as_ref() {
+ NeuronLocation::Input(i) => self.input_layer[*i]
+ .claimed
+ .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+ .is_ok(),
+ NeuronLocation::Hidden(i) => self.hidden_layers[*i]
+ .claimed
+ .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+ .is_ok(),
+ NeuronLocation::Output(i) => self.output_layer[*i]
+ .claimed
+ .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
+ .is_ok(),
+ }
+ }
+}
+
+impl<const I: usize, const O: usize> From<&NeuralNetwork<I, O>> for NeuralNetCache<I, O> {
+    fn from(net: &NeuralNetwork<I, O>) -> Self {
+ let input_layer: Vec<_> = net.input_layer.par_iter().map(|n| n.into()).collect();
+ let input_layer = input_layer.try_into().unwrap();
+
+ let hidden_layers = net.hidden_layers.par_iter().map(|n| n.into()).collect();
+
+ let output_layer: Vec<_> = net.output_layer.par_iter().map(|n| n.into()).collect();
+ let output_layer = output_layer.try_into().unwrap();
+
+ Self {
+ input_layer,
+ hidden_layers,
+ output_layer,
+ }
+ }
+}
diff --git a/src/runnable.rs b/src/runnable.rs
deleted file mode 100644
index 5b28f54..0000000
--- a/src/runnable.rs
+++ /dev/null
@@ -1,300 +0,0 @@
-use crate::topology::*;
-
-#[cfg(not(feature = "rayon"))]
-use std::{cell::RefCell, rc::Rc};
-
-#[cfg(feature = "rayon")]
-use rayon::prelude::*;
-#[cfg(feature = "rayon")]
-use std::sync::{Arc, RwLock};
-
-/// A runnable, stateful Neural Network generated from a [NeuralNetworkTopology]. Use [`NeuralNetwork::from`] to go from stateless to runnable.
-/// Because this has state, you need to run [`NeuralNetwork::flush_state`] between [`NeuralNetwork::predict`] calls.
-#[derive(Debug)]
-#[cfg(not(feature = "rayon"))]
-pub struct NeuralNetwork<const I: usize, const O: usize> {
-    input_layer: [Rc<RefCell<Neuron>>; I],
-    hidden_layers: Vec<Rc<RefCell<Neuron>>>,
-    output_layer: [Rc<RefCell<Neuron>>; O],
-}
-
-/// Parallelized version of the [`NeuralNetwork`] struct.
-#[derive(Debug)]
-#[cfg(feature = "rayon")]
-pub struct NeuralNetwork<const I: usize, const O: usize> {
-    input_layer: [Arc<RwLock<Neuron>>; I],
-    hidden_layers: Vec<Arc<RwLock<Neuron>>>,
-    output_layer: [Arc<RwLock<Neuron>>; O],
-}
-
-impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
- /// Predicts an output for the given inputs.
- #[cfg(not(feature = "rayon"))]
- pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
- for (i, v) in inputs.iter().enumerate() {
- let mut nw = self.input_layer[i].borrow_mut();
- nw.state.value = *v;
- nw.state.processed = true;
- }
-
- (0..O)
- .map(NeuronLocation::Output)
- .map(|loc| self.process_neuron(loc))
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap()
- }
-
- /// Parallelized prediction of outputs from inputs.
- #[cfg(feature = "rayon")]
- pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
- inputs.par_iter().enumerate().for_each(|(i, v)| {
- let mut nw = self.input_layer[i].write().unwrap();
- nw.state.value = *v;
- nw.state.processed = true;
- });
-
- (0..O)
- .map(NeuronLocation::Output)
-            .collect::<Vec<_>>()
- .into_par_iter()
- .map(|loc| self.process_neuron(loc))
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap()
- }
-
- #[cfg(not(feature = "rayon"))]
- fn process_neuron(&self, loc: NeuronLocation) -> f32 {
- let n = self.get_neuron(loc);
-
- {
- let nr = n.borrow();
-
- if nr.state.processed {
- return nr.state.value;
- }
- }
-
- let mut n = n.borrow_mut();
-
- for (l, w) in n.inputs.clone() {
- n.state.value += self.process_neuron(l) * w;
- }
-
- n.activate();
-
- n.state.value
- }
-
- #[cfg(feature = "rayon")]
- fn process_neuron(&self, loc: NeuronLocation) -> f32 {
- let n = self.get_neuron(loc);
-
- {
- let nr = n.read().unwrap();
-
- if nr.state.processed {
- return nr.state.value;
- }
- }
-
- let val: f32 = n
- .read()
- .unwrap()
- .inputs
- .par_iter()
- .map(|&(n2, w)| {
- let processed = self.process_neuron(n2);
- processed * w
- })
- .sum();
-
- let mut nw = n.write().unwrap();
- nw.state.value += val;
- nw.activate();
-
- nw.state.value
- }
-
- #[cfg(not(feature = "rayon"))]
-    fn get_neuron(&self, loc: NeuronLocation) -> Rc<RefCell<Neuron>> {
- match loc {
- NeuronLocation::Input(i) => self.input_layer[i].clone(),
- NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
- NeuronLocation::Output(i) => self.output_layer[i].clone(),
- }
- }
-
- #[cfg(feature = "rayon")]
-    fn get_neuron(&self, loc: NeuronLocation) -> Arc<RwLock<Neuron>> {
- match loc {
- NeuronLocation::Input(i) => self.input_layer[i].clone(),
- NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
- NeuronLocation::Output(i) => self.output_layer[i].clone(),
- }
- }
-
- /// Flushes the network's state after a [prediction][NeuralNetwork::predict].
- #[cfg(not(feature = "rayon"))]
- pub fn flush_state(&self) {
- for n in &self.input_layer {
- n.borrow_mut().flush_state();
- }
-
- for n in &self.hidden_layers {
- n.borrow_mut().flush_state();
- }
-
- for n in &self.output_layer {
- n.borrow_mut().flush_state();
- }
- }
-
- /// Flushes the neural network's state.
- #[cfg(feature = "rayon")]
- pub fn flush_state(&self) {
- self.input_layer
- .par_iter()
- .for_each(|n| n.write().unwrap().flush_state());
-
- self.hidden_layers
- .par_iter()
- .for_each(|n| n.write().unwrap().flush_state());
-
- self.output_layer
- .par_iter()
- .for_each(|n| n.write().unwrap().flush_state());
- }
-}
-
-impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NeuralNetwork<I, O> {
- #[cfg(not(feature = "rayon"))]
-    fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
- let input_layer = value
- .input_layer
- .iter()
- .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let hidden_layers = value
- .hidden_layers
- .iter()
- .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
- .collect();
-
- let output_layer = value
- .output_layer
- .iter()
- .map(|n| Rc::new(RefCell::new(Neuron::from(&n.read().unwrap().clone()))))
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- Self {
- input_layer,
- hidden_layers,
- output_layer,
- }
- }
-
- #[cfg(feature = "rayon")]
-    fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
- let input_layer = value
- .input_layer
- .iter()
- .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let hidden_layers = value
- .hidden_layers
- .iter()
- .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
- .collect();
-
- let output_layer = value
- .output_layer
- .iter()
- .map(|n| Arc::new(RwLock::new(Neuron::from(&n.read().unwrap().clone()))))
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- Self {
- input_layer,
- hidden_layers,
- output_layer,
- }
- }
-}
-
-/// A state-filled neuron.
-#[derive(Clone, Debug)]
-pub struct Neuron {
- inputs: Vec<(NeuronLocation, f32)>,
- bias: f32,
-
- /// The current state of the neuron.
- pub state: NeuronState,
-
- /// The neuron's activation function
- pub activation: ActivationFn,
-}
-
-impl Neuron {
- /// Flushes a neuron's state. Called by [`NeuralNetwork::flush_state`]
- pub fn flush_state(&mut self) {
- self.state.value = self.bias;
- }
-
- /// Applies the activation function to the neuron
- pub fn activate(&mut self) {
- self.state.value = self.activation.func.activate(self.state.value);
- }
-}
-
-impl From<&NeuronTopology> for Neuron {
- fn from(value: &NeuronTopology) -> Self {
- Self {
- inputs: value.inputs.clone(),
- bias: value.bias,
- state: NeuronState {
- value: value.bias,
- ..Default::default()
- },
- activation: value.activation.clone(),
- }
- }
-}
-
-/// A state used in [`Neuron`]s for cache.
-#[derive(Clone, Debug, Default)]
-pub struct NeuronState {
- /// The current value of the neuron. Initialized to a neuron's bias when flushed.
- pub value: f32,
-
- /// Whether or not [`value`][NeuronState::value] has finished processing.
- pub processed: bool,
-}
-
-/// A blanket trait for iterators meant to help with interpreting the output of a [`NeuralNetwork`]
-#[cfg(feature = "max-index")]
-pub trait MaxIndex {
- /// Retrieves the index of the max value.
- fn max_index(self) -> usize;
-}
-
-#[cfg(feature = "max-index")]
-impl<I: Iterator<Item = T>, T: PartialOrd> MaxIndex for I {
- // slow and lazy implementation but it works (will prob optimize in the future)
- fn max_index(self) -> usize {
- self.enumerate()
- .max_by(|(_, v), (_, v2)| v.partial_cmp(v2).unwrap())
- .unwrap()
- .0
- }
-}
diff --git a/src/tests.rs b/src/tests.rs
new file mode 100644
index 0000000..b1345a2
--- /dev/null
+++ b/src/tests.rs
@@ -0,0 +1,339 @@
+use std::collections::HashMap;
+
+use crate::{activation::builtin::linear_activation, *};
+use genetic_rs::prelude::rand::{rngs::StdRng, SeedableRng};
+use rayon::prelude::*;
+
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+enum GraphCheckState {
+ CurrentCycle,
+ Checked,
+}
+
+fn assert_graph_invariants<const I: usize, const O: usize>(net: &NeuralNetwork<I, O>) {
+ let mut visited = HashMap::new();
+
+ for i in 0..I {
+ dfs(net, NeuronLocation::Input(i), &mut visited);
+ }
+
+ for i in 0..net.hidden_layers.len() {
+ let loc = NeuronLocation::Hidden(i);
+ if !visited.contains_key(&loc) {
+ panic!("hanging neuron: {loc:?}");
+ }
+ }
+}
+
+// simple colored dfs for checking graph invariants.
+fn dfs<const I: usize, const O: usize>(
+    net: &NeuralNetwork<I, O>,
+    loc: NeuronLocation,
+    visited: &mut HashMap<NeuronLocation, GraphCheckState>,
+) {
+ if let Some(existing) = visited.get(&loc) {
+ match *existing {
+ GraphCheckState::CurrentCycle => panic!("cycle detected on {loc:?}"),
+ GraphCheckState::Checked => return,
+ }
+ }
+
+ visited.insert(loc, GraphCheckState::CurrentCycle);
+
+ for loc2 in net[loc].outputs.keys() {
+ dfs(net, *loc2, visited);
+ }
+
+ visited.insert(loc, GraphCheckState::Checked);
+}
+
+struct InputCountsCache<const O: usize> {
+    hidden_layers: Vec<usize>,
+ output: [usize; O],
+}
+
+impl<const O: usize> InputCountsCache<O> {
+ fn tally(&mut self, loc: NeuronLocation) {
+ match loc {
+ NeuronLocation::Input(_) => panic!("input neurons can't have inputs"),
+ NeuronLocation::Hidden(i) => self.hidden_layers[i] += 1,
+ NeuronLocation::Output(i) => self.output[i] += 1,
+ }
+ }
+}
+
+// asserts that cached/tracked values are correct. mainly only used for
+// input count and such
+fn assert_cache_consistency<const I: usize, const O: usize>(net: &NeuralNetwork<I, O>) {
+ let mut cache = InputCountsCache {
+ hidden_layers: vec![0; net.hidden_layers.len()],
+ output: [0; O],
+ };
+
+ for i in 0..I {
+ let n = &net[NeuronLocation::Input(i)];
+ for loc in n.outputs.keys() {
+ cache.tally(*loc);
+ }
+ }
+
+ for n in &net.hidden_layers {
+ for loc in n.outputs.keys() {
+ cache.tally(*loc);
+ }
+ }
+
+ for (i, x) in cache.hidden_layers.into_iter().enumerate() {
+ if x == 0 {
+ // redundant because of graph invariants, but better safe than sorry
+ panic!("found hanging neuron");
+ }
+
+ assert_eq!(x, net.hidden_layers[i].input_count);
+ }
+
+ for (i, x) in cache.output.into_iter().enumerate() {
+ assert_eq!(x, net.output_layer[i].input_count);
+ }
+}
+
+fn assert_network_invariants<const I: usize, const O: usize>(net: &NeuralNetwork<I, O>) {
+ assert_graph_invariants(net);
+ assert_cache_consistency(net);
+ // TODO other invariants
+}
+
+const TEST_COUNT: u64 = 1000;
+fn rng_test(test: impl Fn(&mut StdRng) + Sync) {
+ (0..TEST_COUNT).into_par_iter().for_each(|seed| {
+ let mut rng = StdRng::seed_from_u64(seed);
+ test(&mut rng);
+ });
+}
+
+#[test]
+fn create_network() {
+ rng_test(|rng| {
+ let net = NeuralNetwork::<10, 10>::new(rng);
+ assert_network_invariants(&net);
+ });
+}
+
+#[test]
+fn split_connection() {
+ // rng doesn't matter here since it's just adding bias in eval
+ let mut rng = StdRng::seed_from_u64(0xabcdef);
+
+ let mut net = NeuralNetwork::<1, 1>::new(&mut rng);
+ assert_network_invariants(&net);
+
+ net.split_connection(
+ Connection {
+ from: NeuronLocation::Input(0),
+ to: NeuronLocation::Output(0),
+ },
+ &mut rng,
+ );
+ assert_network_invariants(&net);
+
+ assert_eq!(
+ *net.input_layer[0].outputs.keys().next().unwrap(),
+ NeuronLocation::Hidden(0)
+ );
+ assert_eq!(
+ *net.hidden_layers[0].outputs.keys().next().unwrap(),
+ NeuronLocation::Output(0)
+ );
+}
+
+#[test]
+fn add_connection() {
+ let mut rng = StdRng::seed_from_u64(0xabcdef);
+ let mut net = NeuralNetwork {
+ input_layer: [Neuron::new_with_activation(
+ HashMap::new(),
+ activation_fn!(linear_activation),
+ &mut rng,
+ )],
+ hidden_layers: vec![],
+ output_layer: [Neuron::new_with_activation(
+ HashMap::new(),
+ activation_fn!(linear_activation),
+ &mut rng,
+ )],
+ };
+ assert_network_invariants(&net);
+
+ let mut conn = Connection {
+ from: NeuronLocation::Input(0),
+ to: NeuronLocation::Output(0),
+ };
+ assert!(net.add_connection(conn, 0.1));
+ assert_network_invariants(&net);
+
+ assert!(!net.add_connection(conn, 0.1));
+ assert_network_invariants(&net);
+
+ let mut outputs = HashMap::new();
+ outputs.insert(NeuronLocation::Output(0), 0.1);
+ let n = Neuron::new_with_activation(outputs, activation_fn!(linear_activation), &mut rng);
+
+ net.add_neuron(n.clone());
+ // temporarily broken invariants bc of hanging neuron
+
+ conn.to = NeuronLocation::Hidden(0);
+ assert!(net.add_connection(conn, 0.1));
+ assert_network_invariants(&net);
+
+ net.add_neuron(n);
+
+ conn.to = NeuronLocation::Hidden(1);
+ assert!(net.add_connection(conn, 0.1));
+ assert_network_invariants(&net);
+
+ conn.from = NeuronLocation::Hidden(0);
+ assert!(net.add_connection(conn, 0.1));
+ assert_network_invariants(&net);
+
+ net.split_connection(conn, &mut rng);
+ assert_network_invariants(&net);
+
+ conn.from = NeuronLocation::Hidden(2);
+ conn.to = NeuronLocation::Hidden(0);
+
+ assert!(!net.add_connection(conn, 0.1));
+ assert_network_invariants(&net);
+
+ // random stress testing
+ rng_test(|rng| {
+ let mut net = NeuralNetwork::<10, 10>::new(rng);
+ assert_network_invariants(&net);
+ for _ in 0..50 {
+ net.add_random_connection(10, rng);
+ assert_network_invariants(&net);
+ }
+ });
+}
+
+#[test]
+fn remove_connection() {
+ let mut rng = StdRng::seed_from_u64(0xabcdef);
+ let mut net = NeuralNetwork {
+ input_layer: [Neuron::new_with_activation(
+ HashMap::from([
+ (NeuronLocation::Output(0), 0.1),
+ (NeuronLocation::Hidden(0), 1.0),
+ ]),
+ activation_fn!(linear_activation),
+ &mut rng,
+ )],
+ hidden_layers: vec![Neuron {
+ input_count: 1,
+ outputs: HashMap::new(), // not sure whether i want neurons with no outputs to break the invariant/be removed
+ bias: 0.0,
+ activation_fn: activation_fn!(linear_activation),
+ }],
+ output_layer: [Neuron {
+ input_count: 1,
+ outputs: HashMap::new(),
+ bias: 0.0,
+ activation_fn: activation_fn!(linear_activation),
+ }],
+ };
+ assert_network_invariants(&net);
+
+ assert!(!net.remove_connection(Connection {
+ from: NeuronLocation::Input(0),
+ to: NeuronLocation::Output(0)
+ }));
+ assert_network_invariants(&net);
+
+ assert!(net.remove_connection(Connection {
+ from: NeuronLocation::Input(0),
+ to: NeuronLocation::Hidden(0)
+ }));
+ assert_network_invariants(&net);
+
+ rng_test(|rng| {
+ let mut net = NeuralNetwork::<10, 10>::new(rng);
+ assert_network_invariants(&net);
+
+ for _ in 0..70 {
+ net.add_random_connection(10, rng);
+ assert_network_invariants(&net);
+
+ if rng.random_bool(0.25) {
+ // rng allows network to form more complex edge cases.
+ net.remove_random_connection(5, rng);
+ // don't need to remove neuron since this
+ // method handles it automatically.
+ assert_network_invariants(&net);
+ }
+ }
+ });
+}
+
+// TODO remove_neuron test
+
+const NUM_MUTATIONS: usize = 50;
+const MUTATION_RATE: f32 = 0.25;
+#[test]
+fn mutate() {
+ rng_test(|rng| {
+ let mut net = NeuralNetwork::<10, 10>::new(rng);
+ assert_network_invariants(&net);
+
+ let settings = MutationSettings::default();
+
+ for _ in 0..NUM_MUTATIONS {
+ net.mutate(&settings, MUTATION_RATE, rng);
+ assert_network_invariants(&net);
+ }
+ });
+}
+
+#[test]
+fn crossover() {
+ rng_test(|rng| {
+ let mut net1 = NeuralNetwork::<10, 10>::new(rng);
+ assert_network_invariants(&net1);
+
+ let mut net2 = NeuralNetwork::<10, 10>::new(rng);
+ assert_network_invariants(&net2);
+
+ let settings = ReproductionSettings::default();
+
+ for _ in 0..NUM_MUTATIONS {
+ let a = net1.crossover(&net2, &settings, MUTATION_RATE, rng);
+ assert_network_invariants(&a);
+
+ let b = net2.crossover(&net1, &settings, MUTATION_RATE, rng);
+ assert_network_invariants(&b);
+
+ net1 = a;
+ net2 = b;
+ }
+ });
+}
+
+#[cfg(feature = "serde")]
+mod serde {
+ use super::rng_test;
+ use crate::*;
+
+ #[test]
+ fn full_serde() {
+ rng_test(|rng| {
+ let net1 = NeuralNetwork::<10, 10>::new(rng);
+
+ let mut buf = Vec::new();
+ let writer = std::io::Cursor::new(&mut buf);
+ let mut serializer = serde_json::Serializer::new(writer);
+
+ serde_path_to_error::serialize(&net1, &mut serializer).unwrap();
+ let serialized = serde_json::to_string(&net1).unwrap();
+ let net2: NeuralNetwork<10, 10> = serde_json::from_str(&serialized).unwrap();
+ assert_eq!(net1, net2);
+ });
+ }
+}
diff --git a/src/topology/mod.rs b/src/topology/mod.rs
deleted file mode 100644
index 02ad296..0000000
--- a/src/topology/mod.rs
+++ /dev/null
@@ -1,626 +0,0 @@
-/// Contains useful structs for serializing/deserializing a [`NeuronTopology`]
-#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
-#[cfg(feature = "serde")]
-pub mod nnt_serde;
-
-/// Contains structs and traits used for activation functions.
-pub mod activation;
-
-pub use activation::*;
-
-use std::{
- collections::HashSet,
- sync::{Arc, RwLock},
-};
-
-use genetic_rs::prelude::*;
-use rand::prelude::*;
-
-#[cfg(feature = "serde")]
-use serde::{Deserialize, Serialize};
-
-use crate::activation_fn;
-
-/// A stateless neural network topology.
-/// This is the struct you want to use in your agent's inheritance.
-/// See [`NeuralNetwork::from`][crate::NeuralNetwork::from] for how to convert this to a runnable neural network.
-#[derive(Debug)]
-pub struct NeuralNetworkTopology<const I: usize, const O: usize> {
- /// The input layer of the neural network. Uses a fixed length of `I`.
-    pub input_layer: [Arc<RwLock<NeuronTopology>>; I],
-
- /// The hidden layers of the neural network. Because neurons have a flexible connection system, all of them exist in the same flat vector.
-    pub hidden_layers: Vec<Arc<RwLock<NeuronTopology>>>,
-
-    /// The output layer of the neural network. Uses a fixed length of `O`.
-    pub output_layer: [Arc<RwLock<NeuronTopology>>; O],
-
- /// The mutation rate used in [`NeuralNetworkTopology::mutate`] after crossover/division.
- pub mutation_rate: f32,
-
- /// The number of mutation passes (and thus, maximum number of possible mutations that can occur for each entity in the generation).
- pub mutation_passes: usize,
-}
-
-impl<const I: usize, const O: usize> NeuralNetworkTopology<I, O> {
- /// Creates a new [`NeuralNetworkTopology`].
- pub fn new(mutation_rate: f32, mutation_passes: usize, rng: &mut impl Rng) -> Self {
-        let input_layer: [Arc<RwLock<NeuronTopology>>; I] = (0..I)
- .map(|_| {
- Arc::new(RwLock::new(NeuronTopology::new_with_activation(
- vec![],
- activation_fn!(linear_activation),
- rng,
- )))
- })
-            .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let mut output_layer = Vec::with_capacity(O);
-
- for _ in 0..O {
- // random number of connections to random input neurons.
- let input = (0..rng.gen_range(1..=I))
- .map(|_| {
- let mut already_chosen = Vec::new();
- let mut i = rng.gen_range(0..I);
- while already_chosen.contains(&i) {
- i = rng.gen_range(0..I);
- }
-
- already_chosen.push(i);
-
- NeuronLocation::Input(i)
- })
- .collect();
-
- output_layer.push(Arc::new(RwLock::new(NeuronTopology::new_with_activation(
- input,
- activation_fn!(sigmoid),
- rng,
- ))));
- }
-
- let output_layer = output_layer.try_into().unwrap();
-
- Self {
- input_layer,
- hidden_layers: vec![],
- output_layer,
- mutation_rate,
- mutation_passes,
- }
- }
-
- /// Creates a new connection between the neurons.
- /// If the connection is cyclic, it does not add a connection and returns false.
- /// Otherwise, it returns true.
- pub fn add_connection(
- &mut self,
- from: NeuronLocation,
- to: NeuronLocation,
- weight: f32,
- ) -> bool {
- if self.is_connection_cyclic(from, to) {
- return false;
- }
-
- // Add the connection since it is not cyclic
- self.get_neuron(to)
- .write()
- .unwrap()
- .inputs
- .push((from, weight));
-
- true
- }
-
- fn is_connection_cyclic(&self, from: NeuronLocation, to: NeuronLocation) -> bool {
- if to.is_input() || from.is_output() {
- return true;
- }
-
- let mut visited = HashSet::new();
- self.dfs(from, to, &mut visited)
- }
-
- // TODO rayon implementation
- fn dfs(
- &self,
- current: NeuronLocation,
- target: NeuronLocation,
- visited: &mut HashSet,
- ) -> bool {
- if current == target {
- return true;
- }
-
- visited.insert(current);
-
- let n = self.get_neuron(current);
- let nr = n.read().unwrap();
-
- for &(input, _) in &nr.inputs {
- if !visited.contains(&input) && self.dfs(input, target, visited) {
- return true;
- }
- }
-
-        visited.remove(&current);
- false
- }
-
- /// Gets a neuron pointer from a [`NeuronLocation`].
- /// You shouldn't ever need to directly call this unless you are doing complex custom mutations.
-    pub fn get_neuron(&self, loc: NeuronLocation) -> Arc<RwLock<NeuronTopology>> {
- match loc {
- NeuronLocation::Input(i) => self.input_layer[i].clone(),
- NeuronLocation::Hidden(i) => self.hidden_layers[i].clone(),
- NeuronLocation::Output(i) => self.output_layer[i].clone(),
- }
- }
-
- /// Gets a random neuron and its location.
-    pub fn rand_neuron(&self, rng: &mut impl Rng) -> (Arc<RwLock<NeuronTopology>>, NeuronLocation) {
- match rng.gen_range(0..3) {
- 0 => {
- let i = rng.gen_range(0..self.input_layer.len());
- (self.input_layer[i].clone(), NeuronLocation::Input(i))
- }
- 1 if !self.hidden_layers.is_empty() => {
- let i = rng.gen_range(0..self.hidden_layers.len());
- (self.hidden_layers[i].clone(), NeuronLocation::Hidden(i))
- }
- _ => {
- let i = rng.gen_range(0..self.output_layer.len());
- (self.output_layer[i].clone(), NeuronLocation::Output(i))
- }
- }
- }
-
- fn delete_neuron(&mut self, loc: NeuronLocation) -> NeuronTopology {
- if !loc.is_hidden() {
- panic!("Invalid neuron deletion");
- }
-
- let index = loc.unwrap();
- let neuron = Arc::into_inner(self.hidden_layers.remove(index)).unwrap();
-
- for n in &self.hidden_layers {
- let mut nw = n.write().unwrap();
-
- nw.inputs = nw
- .inputs
- .iter()
- .filter_map(|&(input_loc, w)| {
- if !input_loc.is_hidden() {
- return Some((input_loc, w));
- }
-
- if input_loc.unwrap() == index {
- return None;
- }
-
- if input_loc.unwrap() > index {
- return Some((NeuronLocation::Hidden(input_loc.unwrap() - 1), w));
- }
-
- Some((input_loc, w))
- })
- .collect();
- }
-
- for n2 in &self.output_layer {
- let mut nw = n2.write().unwrap();
- nw.inputs = nw
- .inputs
- .iter()
- .filter_map(|&(input_loc, w)| {
- if !input_loc.is_hidden() {
- return Some((input_loc, w));
- }
-
- if input_loc.unwrap() == index {
- return None;
- }
-
- if input_loc.unwrap() > index {
- return Some((NeuronLocation::Hidden(input_loc.unwrap() - 1), w));
- }
-
- Some((input_loc, w))
- })
- .collect();
- }
-
- neuron.into_inner().unwrap()
- }
-}
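For orientation while reading this deleted code: a minimal sketch of how the constructor and connection API above fit together. The `neat::*` import (i.e. these items being re-exported at the crate root) and the `rand` 0.8-style `thread_rng` are assumptions; everything else follows the signatures shown.

```rust
use neat::*; // assumed re-exports: NeuralNetworkTopology, NeuronLocation

fn main() {
    let mut rng = rand::thread_rng();

    // 3 input neurons, 2 output neurons, 1% mutation rate, 3 mutation passes.
    let mut net = NeuralNetworkTopology::<3, 2>::new(0.01, 3, &mut rng);

    // Input -> output can never close a cycle, so this connection is accepted.
    assert!(net.add_connection(NeuronLocation::Input(0), NeuronLocation::Output(1), 0.5));

    // Anything feeding *into* the input layer is rejected as cyclic.
    assert!(!net.add_connection(NeuronLocation::Output(1), NeuronLocation::Input(0), 0.5));
}
```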
-
-// Implemented manually because a derived `Clone` would copy the `Arc`s and
-// alias the neurons instead of deep-copying them.
-impl<const I: usize, const O: usize> Clone for NeuralNetworkTopology<I, O> {
- fn clone(&self) -> Self {
- let input_layer = self
- .input_layer
- .iter()
- .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone())))
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let hidden_layers = self
- .hidden_layers
- .iter()
- .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone())))
- .collect();
-
- let output_layer = self
- .output_layer
- .iter()
- .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone())))
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- Self {
- input_layer,
- hidden_layers,
- output_layer,
- mutation_rate: self.mutation_rate,
- mutation_passes: self.mutation_passes,
- }
- }
-}
-
-impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetworkTopology<I, O> {
- fn mutate(&mut self, rate: f32, rng: &mut impl rand::Rng) {
- for _ in 0..self.mutation_passes {
- if rng.gen::<f32>() <= rate {
- // split preexisting connection
- let (mut n2, _) = self.rand_neuron(rng);
-
- while n2.read().unwrap().inputs.is_empty() {
- (n2, _) = self.rand_neuron(rng);
- }
-
- let mut n2 = n2.write().unwrap();
- let i = rng.gen_range(0..n2.inputs.len());
- let (loc, w) = n2.inputs.remove(i);
-
- let loc3 = NeuronLocation::Hidden(self.hidden_layers.len());
-
- let n3 = NeuronTopology::new(vec![loc], ActivationScope::HIDDEN, rng);
-
- self.hidden_layers.push(Arc::new(RwLock::new(n3)));
-
- n2.inputs.insert(i, (loc3, w));
- }
-
- if rng.gen::<f32>() <= rate {
- // add a connection
- let (_, mut loc1) = self.rand_neuron(rng);
- let (_, mut loc2) = self.rand_neuron(rng);
-
- while loc1.is_output() || !self.add_connection(loc1, loc2, rng.gen::<f32>()) {
- (_, loc1) = self.rand_neuron(rng);
- (_, loc2) = self.rand_neuron(rng);
- }
- }
-
- if rng.gen::<f32>() <= rate && !self.hidden_layers.is_empty() {
- // remove a neuron
- let (_, mut loc) = self.rand_neuron(rng);
-
- while !loc.is_hidden() {
- (_, loc) = self.rand_neuron(rng);
- }
-
- // delete the neuron
- self.delete_neuron(loc);
- }
-
- if rng.gen::<f32>() <= rate {
- // mutate a connection
- let (mut n, _) = self.rand_neuron(rng);
-
- while n.read().unwrap().inputs.is_empty() {
- (n, _) = self.rand_neuron(rng);
- }
-
- let mut n = n.write().unwrap();
- let i = rng.gen_range(0..n.inputs.len());
- let (_, w) = &mut n.inputs[i];
- *w += rng.gen_range(-1.0..1.0) * rate;
- }
-
- if rng.gen::<f32>() <= rate {
- // mutate bias
- let (n, _) = self.rand_neuron(rng);
- let mut n = n.write().unwrap();
-
- n.bias += rng.gen_range(-1.0..1.0) * rate;
- }
-
- if rng.gen::<f32>() <= rate && !self.hidden_layers.is_empty() {
- // mutate activation function
- let reg = ACTIVATION_REGISTRY.read().unwrap();
- let activations = reg.activations_in_scope(ActivationScope::HIDDEN);
-
- let (mut n, mut loc) = self.rand_neuron(rng);
-
- while !loc.is_hidden() {
- (n, loc) = self.rand_neuron(rng);
- }
-
- let mut nw = n.write().unwrap();
-
- // cloning here is not ideal, but it's not a significant efficiency issue
- nw.activation = activations[rng.gen_range(0..activations.len())].clone();
- }
- }
- }
-}
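A sketch of driving the mutation operator directly (normally the genetic algorithm invokes it through `divide`/`crossover`). Import paths are assumptions; the semantics mirror the impl above: each pass rolls every mutation kind independently against `rate`.

```rust
use neat::*; // assumed re-exports, including the RandomlyMutable trait

fn main() {
    let mut rng = rand::thread_rng();
    let mut net = NeuralNetworkTopology::<2, 1>::new(0.25, 5, &mut rng);

    // Runs `mutation_passes` (here 5) passes; each pass may split a
    // connection, add a connection, delete a hidden neuron, nudge a weight,
    // nudge a bias, and swap a hidden neuron's activation function.
    let rate = net.mutation_rate;
    net.mutate(rate, &mut rng);

    println!("hidden neurons after mutation: {}", net.hidden_layers.len());
}
```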
-
-impl<const I: usize, const O: usize> DivisionReproduction for NeuralNetworkTopology<I, O> {
- fn divide(&self, rng: &mut impl rand::Rng) -> Self {
- let mut child = self.clone();
- child.mutate(self.mutation_rate, rng);
- child
- }
-}
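`divide` is the asexual path: clone the parent, then self-mutate at the stored rate. A tiny usage sketch under the same assumed imports:

```rust
use neat::*; // assumed re-exports, including the DivisionReproduction trait

fn main() {
    let mut rng = rand::thread_rng();
    let parent = NeuralNetworkTopology::<2, 1>::new(0.1, 3, &mut rng);

    // Clone the parent, then mutate the clone at the parent's stored rate.
    let child = parent.divide(&mut rng);
    assert_eq!(child.mutation_rate, parent.mutation_rate);
}
```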
-
-impl<const I: usize, const O: usize> PartialEq for NeuralNetworkTopology<I, O> {
- fn eq(&self, other: &Self) -> bool {
- if self.mutation_rate != other.mutation_rate
- || self.mutation_passes != other.mutation_passes
- {
- return false;
- }
-
- for i in 0..I {
- if *self.input_layer[i].read().unwrap() != *other.input_layer[i].read().unwrap() {
- return false;
- }
- }
-
- if self.hidden_layers.len() != other.hidden_layers.len() {
- return false;
- }
-
- for i in 0..self.hidden_layers.len() {
- if *self.hidden_layers[i].read().unwrap() != *other.hidden_layers[i].read().unwrap() {
- return false;
- }
- }
-
- for i in 0..O {
- if *self.output_layer[i].read().unwrap() != *other.output_layer[i].read().unwrap() {
- return false;
- }
- }
-
- true
- }
-}
-
-#[cfg(feature = "serde")]
-impl<const I: usize, const O: usize> From<nnt_serde::NNTSerde<I, O>>
- for NeuralNetworkTopology<I, O>
-{
- fn from(value: nnt_serde::NNTSerde<I, O>) -> Self {
- let input_layer = value
- .input_layer
- .into_iter()
- .map(|n| Arc::new(RwLock::new(n)))
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let hidden_layers = value
- .hidden_layers
- .into_iter()
- .map(|n| Arc::new(RwLock::new(n)))
- .collect();
-
- let output_layer = value
- .output_layer
- .into_iter()
- .map(|n| Arc::new(RwLock::new(n)))
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- NeuralNetworkTopology {
- input_layer,
- hidden_layers,
- output_layer,
- mutation_rate: value.mutation_rate,
- mutation_passes: value.mutation_passes,
- }
- }
-}
-
-#[cfg(feature = "crossover")]
-impl<const I: usize, const O: usize> CrossoverReproduction for NeuralNetworkTopology<I, O> {
- fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self {
- let input_layer = self
- .input_layer
- .iter()
- .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone())))
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let len = self.hidden_layers.len().max(other.hidden_layers.len());
- let mut hidden_layers = Vec::with_capacity(len);
-
- for i in 0..len {
- // 50/50 chance to take each hidden neuron from either parent,
- // falling back to whichever parent has a neuron at this index.
- let source = if (rng.gen::<f32>() <= 0.5 && i < self.hidden_layers.len())
- || i >= other.hidden_layers.len()
- {
- &self.hidden_layers[i]
- } else {
- &other.hidden_layers[i]
- };
-
- let mut n = source.read().unwrap().clone();
-
- n.inputs
- .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers));
- hidden_layers.push(Arc::new(RwLock::new(n)));
- }
-
- let mut output_layer: [Arc<RwLock<NeuronTopology>>; O] = self
- .output_layer
- .iter()
- .map(|n| Arc::new(RwLock::new(n.read().unwrap().clone())))
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- for (i, n) in self.output_layer.iter().enumerate() {
- if rng.gen::<f32>() <= 0.5 {
- let mut n = n.read().unwrap().clone();
-
- n.inputs
- .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers));
- output_layer[i] = Arc::new(RwLock::new(n));
-
- continue;
- }
-
- let mut n = other.output_layer[i].read().unwrap().clone();
-
- n.inputs
- .retain(|(l, _)| input_exists(*l, &input_layer, &hidden_layers));
- output_layer[i] = Arc::new(RwLock::new(n));
- }
-
- let mut child = Self {
- input_layer,
- hidden_layers,
- output_layer,
- mutation_rate: self.mutation_rate,
- mutation_passes: self.mutation_passes,
- };
-
- child.mutate(self.mutation_rate, rng);
-
- child
- }
-}
-
-#[cfg(feature = "crossover")]
-fn input_exists<const I: usize>(
- loc: NeuronLocation,
- input: &[Arc<RwLock<NeuronTopology>>; I],
- hidden: &[Arc<RwLock<NeuronTopology>>],
-) -> bool {
- match loc {
- NeuronLocation::Input(i) => i < input.len(),
- NeuronLocation::Hidden(i) => i < hidden.len(),
- NeuronLocation::Output(_) => false,
- }
-}
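Under the `crossover` feature, a child takes roughly half of each parent's neurons, prunes inputs that no longer resolve (via `input_exists`), then runs a mutation pass. A hedged sketch, assuming the trait is re-exported at the crate root:

```rust
use neat::*; // assumed re-exports, including the CrossoverReproduction trait

fn main() {
    let mut rng = rand::thread_rng();
    let a = NeuralNetworkTopology::<4, 2>::new(0.1, 2, &mut rng);
    let b = NeuralNetworkTopology::<4, 2>::new(0.1, 2, &mut rng);

    // The child inherits `mutation_rate`/`mutation_passes` from `self` (here `a`).
    let child = a.crossover(&b, &mut rng);
    assert_eq!(child.mutation_passes, a.mutation_passes);
}
```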
-
-/// A stateless version of [`Neuron`][crate::Neuron].
-#[derive(PartialEq, Debug, Clone)]
-#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-pub struct NeuronTopology {
- /// The input locations and weights.
- pub inputs: Vec<(NeuronLocation, f32)>,
-
- /// The neuron's bias.
- pub bias: f32,
-
- /// The neuron's activation function.
- pub activation: ActivationFn,
-}
-
-impl NeuronTopology {
- /// Creates a new neuron with the given input locations.
- pub fn new(
- inputs: Vec<NeuronLocation>,
- current_scope: ActivationScope,
- rng: &mut impl Rng,
- ) -> Self {
- let reg = ACTIVATION_REGISTRY.read().unwrap();
- let activations = reg.activations_in_scope(current_scope);
-
- Self::new_with_activations(inputs, activations, rng)
- }
-
- /// Takes a collection of activation functions and chooses a random one to use.
- pub fn new_with_activations(
- inputs: Vec<NeuronLocation>,
- activations: impl IntoIterator<Item = ActivationFn>,
- rng: &mut impl Rng,
- ) -> Self {
- let mut activations: Vec<_> = activations.into_iter().collect();
-
- Self::new_with_activation(
- inputs,
- activations.remove(rng.gen_range(0..activations.len())),
- rng,
- )
- }
-
- /// Creates a neuron with the activation.
- pub fn new_with_activation(
- inputs: Vec<NeuronLocation>,
- activation: ActivationFn,
- rng: &mut impl Rng,
- ) -> Self {
- let inputs = inputs
- .into_iter()
- .map(|i| (i, rng.gen_range(-1.0..1.0)))
- .collect();
-
- Self {
- inputs,
- bias: rng.gen(),
- activation,
- }
- }
-}
-
-/// A pseudo-pointer that identifies a neuron by its layer and index, used to make structural conversions fast and easy to write.
-#[derive(Hash, Clone, Copy, Debug, Eq, PartialEq)]
-#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
-pub enum NeuronLocation {
- /// Points to a neuron in the input layer at the contained index.
- Input(usize),
-
- /// Points to a neuron in the hidden layer at the contained index.
- Hidden(usize),
-
- /// Points to a neuron in the output layer at the contained index.
- Output(usize),
-}
-
-impl NeuronLocation {
- /// Returns `true` if it points to the input layer. Otherwise, returns `false`.
- pub fn is_input(&self) -> bool {
- matches!(self, Self::Input(_))
- }
-
- /// Returns `true` if it points to the hidden layer. Otherwise, returns `false`.
- pub fn is_hidden(&self) -> bool {
- matches!(self, Self::Hidden(_))
- }
-
- /// Returns `true` if it points to the output layer. Otherwise, returns `false`.
- pub fn is_output(&self) -> bool {
- matches!(self, Self::Output(_))
- }
-
- /// Retrieves the index value, regardless of layer. Does not consume.
- pub fn unwrap(&self) -> usize {
- match self {
- Self::Input(i) => *i,
- Self::Hidden(i) => *i,
- Self::Output(i) => *i,
- }
- }
-}
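`NeuronLocation` is just a (layer, index) pair, and `unwrap` deliberately discards the layer tag. A quick sketch (the import path is an assumption):

```rust
use neat::NeuronLocation; // assumed crate-root re-export

fn main() {
    let loc = NeuronLocation::Hidden(3);
    assert!(loc.is_hidden() && !loc.is_input() && !loc.is_output());
    assert_eq!(loc.unwrap(), 3); // index only; the layer tag is dropped
}
```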
diff --git a/src/topology/nnt_serde.rs b/src/topology/nnt_serde.rs
deleted file mode 100644
index 14f392c..0000000
--- a/src/topology/nnt_serde.rs
+++ /dev/null
@@ -1,71 +0,0 @@
-use super::*;
-use serde::{Deserialize, Serialize};
-use serde_big_array::BigArray;
-
-/// A serializable wrapper for [`NeuralNetworkTopology`]. See [`NNTSerde::from`] for conversion.
-#[derive(Serialize, Deserialize)]
-pub struct NNTSerde<const I: usize, const O: usize> {
- #[serde(with = "BigArray")]
- pub(crate) input_layer: [NeuronTopology; I],
-
- pub(crate) hidden_layers: Vec,
-
- #[serde(with = "BigArray")]
- pub(crate) output_layer: [NeuronTopology; O],
-
- pub(crate) mutation_rate: f32,
- pub(crate) mutation_passes: usize,
-}
-
-impl<const I: usize, const O: usize> From<&NeuralNetworkTopology<I, O>> for NNTSerde<I, O> {
- fn from(value: &NeuralNetworkTopology<I, O>) -> Self {
- let input_layer = value
- .input_layer
- .iter()
- .map(|n| n.read().unwrap().clone())
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- let hidden_layers = value
- .hidden_layers
- .iter()
- .map(|n| n.read().unwrap().clone())
- .collect();
-
- let output_layer = value
- .output_layer
- .iter()
- .map(|n| n.read().unwrap().clone())
- .collect::<Vec<_>>()
- .try_into()
- .unwrap();
-
- Self {
- input_layer,
- hidden_layers,
- output_layer,
- mutation_rate: value.mutation_rate,
- mutation_passes: value.mutation_passes,
- }
- }
-}
-
-#[cfg(test)]
-#[test]
-fn serde() {
- let mut rng = rand::thread_rng();
- let nnt = NeuralNetworkTopology::<10, 10>::new(0.1, 3, &mut rng);
- let nnts = NNTSerde::from(&nnt);
-
- let encoded = bincode::serialize(&nnts).unwrap();
-
- if option_env!("TEST_CREATEFILE").is_some() {
- std::fs::write("serde-test.nn", &encoded).unwrap();
- }
-
- let decoded: NNTSerde<10, 10> = bincode::deserialize(&encoded).unwrap();
- let nnt2: NeuralNetworkTopology<10, 10> = decoded.into();
-
- dbg!(nnt, nnt2);
-}
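The test above already round-trips through `bincode`; for completeness, a sketch of persisting a topology to disk and restoring it. The crate-root re-export of `NNTSerde` and the bincode 1.x API are assumptions based on that test.

```rust
use neat::*; // assumed re-exports: NeuralNetworkTopology, NNTSerde

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let mut rng = rand::thread_rng();
    let net = NeuralNetworkTopology::<10, 10>::new(0.1, 3, &mut rng);

    // Serialize through the NNTSerde wrapper and write to disk.
    std::fs::write("network.nn", bincode::serialize(&NNTSerde::from(&net))?)?;

    // Read back and convert into a usable topology.
    let decoded: NNTSerde<10, 10> = bincode::deserialize(&std::fs::read("network.nn")?)?;
    let restored: NeuralNetworkTopology<10, 10> = decoded.into();

    assert_eq!(net, restored); // PartialEq compares structure, not pointers
    Ok(())
}
```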