diff --git a/src/lib.rs b/src/lib.rs index 9b0140f..40e4fa6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod runner; mod spring_model; mod utils; +use graph::NodeVector; use pyo3::exceptions; use pyo3::{prelude::*, types::PyIterator}; @@ -21,18 +22,41 @@ fn pick_model(model: &str) -> Result, P } } -#[pyfunction(file_path, "*", iter = 500, threads = 0, model = "\"spring_model\"")] +fn initial_pos_to_node_vector(initial_pos: Option>) -> Option { + match initial_pos { + Some(pos) => { + let nodes = graph::new_node_vector(pos.len()); + for (i, (x, y)) in pos.iter().enumerate() { + let mut node = nodes[i].write().unwrap(); + node.x = *x; + node.y = *y; + } + Some(nodes) + } + None => None, + } +} + +#[pyfunction( + file_path, + "*", + iter = 500, + threads = 0, + model = "\"spring_model\"", + initial_pos = "None" +)] fn layout_from_edge_file( file_path: &str, iter: usize, threads: usize, model: &str, + initial_pos: Option>, ) -> PyResult> { let (size, matrix) = reader::read_graph(file_path); let model = pick_model(model)?; let r = runner::Runner::new(iter, threads); - Ok(r.layout(size, matrix, model)) + Ok(r.layout(size, matrix, model, initial_pos_to_node_vector(initial_pos))) } #[pyfunction( @@ -41,7 +65,8 @@ fn layout_from_edge_file( "*", iter = 500, threads = 0, - model = "\"spring_model\"" + model = "\"spring_model\"", + initial_pos = "None" )] fn layout_from_edge_list( number_of_nodes: usize, @@ -49,6 +74,7 @@ fn layout_from_edge_list( iter: usize, threads: usize, model: &str, + initial_pos: Option>, ) -> PyResult> { let model: Box = pick_model(model)?; @@ -83,7 +109,12 @@ fn layout_from_edge_list( } let r = runner::Runner::new(iter, threads); - Ok(r.layout(number_of_nodes, edge_matrix, model)) + Ok(r.layout( + number_of_nodes, + edge_matrix, + model, + initial_pos_to_node_vector(initial_pos), + )) } #[pymodule] diff --git a/src/runner.rs b/src/runner.rs index c7652c4..8591c26 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -1,4 +1,4 @@ -use crate::graph::{new_node_vector, EdgeMatrix}; +use crate::graph::{new_node_vector, EdgeMatrix, NodeVector}; use crate::model::ForceModel; use crate::utils; use std::sync::{Arc, RwLock}; @@ -35,9 +35,13 @@ impl Runner { number_of_nodes: usize, edges: EdgeMatrix, model: Box, + initial_pos: Option, ) -> Vec<(f32, f32)> { // let edges = connection_matrix(size); - let mut nodes = new_node_vector(number_of_nodes); + let mut nodes = match initial_pos { + Some(pos) => pos, + None => new_node_vector(number_of_nodes), + }; let mut nodes_next = new_node_vector(number_of_nodes); let model = Arc::new(RwLock::new(model)); @@ -114,7 +118,7 @@ mod test { let model = Box::new(MockModel { counter: 0 }); let runner = Runner::new(10, 1); let edges = graph::new_edge_matrix(3); - let result = runner.layout(3, edges, model); + let result = runner.layout(3, edges, model, None); assert_eq!(result, vec![(0.0, 10.0), (1.0, 10.0), (2.0, 10.0)]); } } diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 84833dd..3746553 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -39,4 +39,12 @@ def test_from_file(): pos = graph_force.layout_from_edge_file('/tmp/edges.bin') assert pos is not None - assert len(pos) == 10 \ No newline at end of file + assert len(pos) == 10 + +def test_initial_pos(): + edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)] + initial = [(i, i) for i in range(7)] + pos = graph_force.layout_from_edge_list(7, edges, iter=0, initial_pos=initial) + assert pos is not None + assert len(pos) == 7 + assert pos == initial