diff --git a/Cargo.lock b/Cargo.lock index e21fb3c..47df7de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -33,7 +33,7 @@ dependencies = [ [[package]] name = "graph_force" -version = "0.1.2" +version = "0.1.3" dependencies = [ "pyo3", "rand", diff --git a/Cargo.toml b/Cargo.toml index 85e1d0c..a0515e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "graph_force" -version = "0.1.2" +version = "0.1.3" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/lib.rs b/src/lib.rs index 9ae3ad4..397d8e0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,37 @@ mod graph; mod model; +mod networkx_model; mod runner; mod spring_model; mod utils; -use std::sync::{Arc, RwLock}; - use pyo3::exceptions; use pyo3::{prelude::*, types::PyIterator}; -#[pyfunction(number_of_nodes, edges, "*", iter = 500, threads = 0)] +#[pyfunction( + number_of_nodes, + edges, + "*", + iter = 500, + threads = 0, + model = "\"spring_model\"" +)] fn layout_from_edge_list( number_of_nodes: usize, edges: &PyAny, iter: usize, threads: usize, + model: &str, ) -> PyResult> { - let model = Arc::new(RwLock::new(spring_model::SimpleSpringModel::new(1.0))); + let model: Box = match model { + "spring_model" => Box::new(spring_model::SimpleSpringModel::new(1.0)), + "networkx_model" => Box::new(networkx_model::NetworkXModel::new()), + _ => { + return Err(PyErr::new::( + "model must be either 'spring_model' or 'networkx_model'", + )) + } + }; let mut edge_matrix = graph::new_edge_matrix(number_of_nodes); match edges.extract::<&PyIterator>() { diff --git a/src/runner.rs b/src/runner.rs index 9f6959f..c7652c4 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -30,16 +30,18 @@ impl Runner { } } - pub fn layout( + pub fn layout( self: &Self, number_of_nodes: usize, edges: EdgeMatrix, - model: Arc>, + model: Box, ) -> Vec<(f32, f32)> { // let edges = connection_matrix(size); let mut nodes = new_node_vector(number_of_nodes); let mut nodes_next = new_node_vector(number_of_nodes); + let model = Arc::new(RwLock::new(model)); + model .write() .unwrap() @@ -109,7 +111,7 @@ mod test { #[test] fn test_layout() { - let model = Arc::new(RwLock::new(MockModel { counter: 0 })); + let model = Box::new(MockModel { counter: 0 }); let runner = Runner::new(10, 1); let edges = graph::new_edge_matrix(3); let result = runner.layout(3, edges, model); diff --git a/tests/test_e2e.py b/tests/test_e2e.py index a9e71b6..e3eba9c 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -22,3 +22,9 @@ def test_tuple_of_edges(): ) assert pos is not None assert len(pos) == 7 + +def test_model_selection(): + edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)] + pos = graph_force.layout_from_edge_list(7, edges, model='networkx_model') + assert pos is not None + assert len(pos) == 7