From ada2f2e9b1c4d43a3ff6203e6c9d78f462e3f0ed Mon Sep 17 00:00:00 2001 From: Niko Abeler Date: Sun, 20 Nov 2022 21:44:24 +0100 Subject: [PATCH] more tests + allowing generators as input --- .gitignore | 3 ++- requirements.txt | 1 + src/graph.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 33 ++++++++++++++++++++++++++++----- src/runner.rs | 41 +++++------------------------------------ tests/test_e2e.py | 24 ++++++++++++++++++++++++ 6 files changed, 103 insertions(+), 42 deletions(-) create mode 100644 tests/test_e2e.py diff --git a/.gitignore b/.gitignore index 8f41e77..89ce369 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target result/ -venv/ \ No newline at end of file +venv/ +*/__pycache__/ \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 6fdfcae..7d151d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ maturin==0.14.1 +pytest==7.2.0 diff --git a/src/graph.rs b/src/graph.rs index 4649e5c..633e483 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -34,3 +34,46 @@ pub fn new_node_vector(size: usize) -> NodeVector { } Arc::new(nodes) } + +pub fn add_edge(matrix: &mut EdgeMatrix, i: usize, j: usize) { + let mut edges = matrix.write().unwrap(); + edges[i][j].weight = 1.0; + edges[j][i].weight = 1.0; +} + +#[cfg(test)] +mod test { + + use super::*; + + #[test] + fn test_new_edge_matrix() { + let matrix = new_edge_matrix(5); + let edges = matrix.read().unwrap(); + assert_eq!(edges.len(), 5); + for row in edges.iter() { + assert_eq!(row.len(), 5); + for edge in row.iter() { + assert_eq!(edge.weight, 0.0); + } + } + } + + #[test] + fn test_new_node_vector() { + let nodes = new_node_vector(5); + assert_eq!(nodes.len(), 5); + } + + #[test] + fn test_add_edge() { + let mut matrix = new_edge_matrix(5); + add_edge(&mut matrix, 0, 1); + let edges = matrix.read().unwrap(); + assert_eq!(edges[0][1].weight, 1.0); + assert_eq!(edges[1][0].weight, 1.0); + + assert_eq!(edges[0][0].weight, 0.0); + assert_eq!(edges[1][1].weight, 0.0); + } +} diff --git a/src/lib.rs b/src/lib.rs index 70b963a..9ae3ad4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,22 +6,45 @@ mod utils; use std::sync::{Arc, RwLock}; -use pyo3::prelude::*; +use pyo3::exceptions; +use pyo3::{prelude::*, types::PyIterator}; -/// Formats the sum of two numbers as string. #[pyfunction(number_of_nodes, edges, "*", iter = 500, threads = 0)] fn layout_from_edge_list( number_of_nodes: usize, - edges: Vec<(u32, u32)>, + edges: &PyAny, iter: usize, threads: usize, ) -> PyResult> { let model = Arc::new(RwLock::new(spring_model::SimpleSpringModel::new(1.0))); + + let mut edge_matrix = graph::new_edge_matrix(number_of_nodes); + match edges.extract::<&PyIterator>() { + Ok(iter) => { + for edge in iter { + let edge = edge?; + let edge = edge.extract::<(usize, usize)>()?; + graph::add_edge(&mut edge_matrix, edge.0, edge.1); + } + } + Err(_) => match edges.extract::>() { + Ok(edge) => { + for edge in edge { + graph::add_edge(&mut edge_matrix, edge.0, edge.1); + } + } + Err(_) => { + return Err(PyErr::new::( + "Edges must be an iterable of (int, int)", + )); + } + }, + } + let r = runner::Runner::new(iter, threads); - Ok(r.layout(number_of_nodes, edges, model)) + Ok(r.layout(number_of_nodes, edge_matrix, model)) } -/// A Python module implemented in Rust. #[pymodule] fn graph_force(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(layout_from_edge_list, m)?)?; diff --git a/src/runner.rs b/src/runner.rs index e38be93..9f6959f 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -1,4 +1,4 @@ -use crate::graph::{new_edge_matrix, new_node_vector, EdgeMatrix}; +use crate::graph::{new_node_vector, EdgeMatrix}; use crate::model::ForceModel; use crate::utils; use std::sync::{Arc, RwLock}; @@ -33,11 +33,10 @@ impl Runner { pub fn layout( self: &Self, number_of_nodes: usize, - edge_list: Vec<(u32, u32)>, + edges: EdgeMatrix, model: Arc>, ) -> Vec<(f32, f32)> { // let edges = connection_matrix(size); - let edges = edge_matrix_from_edge_list(number_of_nodes, edge_list); let mut nodes = new_node_vector(number_of_nodes); let mut nodes_next = new_node_vector(number_of_nodes); @@ -85,21 +84,9 @@ impl Runner { } } -fn edge_matrix_from_edge_list(number_of_nodes: usize, edge_list: Vec<(u32, u32)>) -> EdgeMatrix { - let matrix_ptr = new_edge_matrix(number_of_nodes as usize); - { - let mut matrix = matrix_ptr.write().unwrap(); - for (node_a, node_b) in edge_list { - matrix[node_a as usize][node_b as usize].weight = 1.0; - matrix[node_b as usize][node_a as usize].weight = 1.0; - } - } - matrix_ptr -} - #[cfg(test)] mod test { - use crate::graph::{Node, NodeVector}; + use crate::graph::{self, Node, NodeVector}; use super::*; @@ -120,30 +107,12 @@ mod test { } } - #[test] - fn test_edge_matrix_from_edge_list() { - let edge_list = vec![(0, 1), (1, 2)]; - let matrix = edge_matrix_from_edge_list(3, edge_list); - let matrix = matrix.read().unwrap(); - assert_eq!(matrix[0][1].weight, 1.0); - assert_eq!(matrix[1][0].weight, 1.0); - assert_eq!(matrix[1][2].weight, 1.0); - assert_eq!(matrix[2][1].weight, 1.0); - - assert_eq!(matrix[0][0].weight, 0.0); - assert_eq!(matrix[1][1].weight, 0.0); - assert_eq!(matrix[2][2].weight, 0.0); - - assert_eq!(matrix[0][2].weight, 0.0); - assert_eq!(matrix[2][0].weight, 0.0); - } - #[test] fn test_layout() { - let edge_list = vec![]; let model = Arc::new(RwLock::new(MockModel { counter: 0 })); let runner = Runner::new(10, 1); - let result = runner.layout(3, edge_list, model); + let edges = graph::new_edge_matrix(3); + let result = runner.layout(3, edges, model); assert_eq!(result, vec![(0.0, 10.0), (1.0, 10.0), (2.0, 10.0)]); } } diff --git a/tests/test_e2e.py b/tests/test_e2e.py new file mode 100644 index 0000000..a9e71b6 --- /dev/null +++ b/tests/test_e2e.py @@ -0,0 +1,24 @@ +import graph_force + +def test_list_of_edges(): + edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)] + pos = graph_force.layout_from_edge_list(7, edges) + assert pos is not None + assert len(pos) == 7 + + +def test_iterator_of_edges(): + pos = graph_force.layout_from_edge_list( + 7, + ((0, i + 1) for i in range(6)) + ) + assert pos is not None + assert len(pos) == 7 + +def test_tuple_of_edges(): + pos = graph_force.layout_from_edge_list( + 7, + ((0,1), (1,2), (2,3), (3,4), (4,5), (5,6)) + ) + assert pos is not None + assert len(pos) == 7