more tests + allowing generators as input

This commit is contained in:
Niko Abeler 2022-11-20 21:44:24 +01:00
parent 2d02427c1f
commit ada2f2e9b1
6 changed files with 103 additions and 42 deletions

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
/target
result/
venv/
venv/
*/__pycache__/

View File

@ -1 +1,2 @@
maturin==0.14.1
pytest==7.2.0

View File

@ -34,3 +34,46 @@ pub fn new_node_vector(size: usize) -> NodeVector {
}
Arc::new(nodes)
}
pub fn add_edge(matrix: &mut EdgeMatrix, i: usize, j: usize) {
let mut edges = matrix.write().unwrap();
edges[i][j].weight = 1.0;
edges[j][i].weight = 1.0;
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_new_edge_matrix() {
let matrix = new_edge_matrix(5);
let edges = matrix.read().unwrap();
assert_eq!(edges.len(), 5);
for row in edges.iter() {
assert_eq!(row.len(), 5);
for edge in row.iter() {
assert_eq!(edge.weight, 0.0);
}
}
}
#[test]
fn test_new_node_vector() {
let nodes = new_node_vector(5);
assert_eq!(nodes.len(), 5);
}
#[test]
fn test_add_edge() {
let mut matrix = new_edge_matrix(5);
add_edge(&mut matrix, 0, 1);
let edges = matrix.read().unwrap();
assert_eq!(edges[0][1].weight, 1.0);
assert_eq!(edges[1][0].weight, 1.0);
assert_eq!(edges[0][0].weight, 0.0);
assert_eq!(edges[1][1].weight, 0.0);
}
}

View File

@ -6,22 +6,45 @@ mod utils;
use std::sync::{Arc, RwLock};
use pyo3::prelude::*;
use pyo3::exceptions;
use pyo3::{prelude::*, types::PyIterator};
/// Formats the sum of two numbers as string.
#[pyfunction(number_of_nodes, edges, "*", iter = 500, threads = 0)]
fn layout_from_edge_list(
number_of_nodes: usize,
edges: Vec<(u32, u32)>,
edges: &PyAny,
iter: usize,
threads: usize,
) -> PyResult<Vec<(f32, f32)>> {
let model = Arc::new(RwLock::new(spring_model::SimpleSpringModel::new(1.0)));
let mut edge_matrix = graph::new_edge_matrix(number_of_nodes);
match edges.extract::<&PyIterator>() {
Ok(iter) => {
for edge in iter {
let edge = edge?;
let edge = edge.extract::<(usize, usize)>()?;
graph::add_edge(&mut edge_matrix, edge.0, edge.1);
}
}
Err(_) => match edges.extract::<Vec<(usize, usize)>>() {
Ok(edge) => {
for edge in edge {
graph::add_edge(&mut edge_matrix, edge.0, edge.1);
}
}
Err(_) => {
return Err(PyErr::new::<exceptions::PyTypeError, _>(
"Edges must be an iterable of (int, int)",
));
}
},
}
let r = runner::Runner::new(iter, threads);
Ok(r.layout(number_of_nodes, edges, model))
Ok(r.layout(number_of_nodes, edge_matrix, model))
}
/// A Python module implemented in Rust.
#[pymodule]
fn graph_force(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(layout_from_edge_list, m)?)?;

View File

@ -1,4 +1,4 @@
use crate::graph::{new_edge_matrix, new_node_vector, EdgeMatrix};
use crate::graph::{new_node_vector, EdgeMatrix};
use crate::model::ForceModel;
use crate::utils;
use std::sync::{Arc, RwLock};
@ -33,11 +33,10 @@ impl Runner {
pub fn layout<T: 'static + ForceModel + Sync + Send>(
self: &Self,
number_of_nodes: usize,
edge_list: Vec<(u32, u32)>,
edges: EdgeMatrix,
model: Arc<RwLock<T>>,
) -> Vec<(f32, f32)> {
// let edges = connection_matrix(size);
let edges = edge_matrix_from_edge_list(number_of_nodes, edge_list);
let mut nodes = new_node_vector(number_of_nodes);
let mut nodes_next = new_node_vector(number_of_nodes);
@ -85,21 +84,9 @@ impl Runner {
}
}
fn edge_matrix_from_edge_list(number_of_nodes: usize, edge_list: Vec<(u32, u32)>) -> EdgeMatrix {
let matrix_ptr = new_edge_matrix(number_of_nodes as usize);
{
let mut matrix = matrix_ptr.write().unwrap();
for (node_a, node_b) in edge_list {
matrix[node_a as usize][node_b as usize].weight = 1.0;
matrix[node_b as usize][node_a as usize].weight = 1.0;
}
}
matrix_ptr
}
#[cfg(test)]
mod test {
use crate::graph::{Node, NodeVector};
use crate::graph::{self, Node, NodeVector};
use super::*;
@ -120,30 +107,12 @@ mod test {
}
}
#[test]
fn test_edge_matrix_from_edge_list() {
let edge_list = vec![(0, 1), (1, 2)];
let matrix = edge_matrix_from_edge_list(3, edge_list);
let matrix = matrix.read().unwrap();
assert_eq!(matrix[0][1].weight, 1.0);
assert_eq!(matrix[1][0].weight, 1.0);
assert_eq!(matrix[1][2].weight, 1.0);
assert_eq!(matrix[2][1].weight, 1.0);
assert_eq!(matrix[0][0].weight, 0.0);
assert_eq!(matrix[1][1].weight, 0.0);
assert_eq!(matrix[2][2].weight, 0.0);
assert_eq!(matrix[0][2].weight, 0.0);
assert_eq!(matrix[2][0].weight, 0.0);
}
#[test]
fn test_layout() {
let edge_list = vec![];
let model = Arc::new(RwLock::new(MockModel { counter: 0 }));
let runner = Runner::new(10, 1);
let result = runner.layout(3, edge_list, model);
let edges = graph::new_edge_matrix(3);
let result = runner.layout(3, edges, model);
assert_eq!(result, vec![(0.0, 10.0), (1.0, 10.0), (2.0, 10.0)]);
}
}

24
tests/test_e2e.py Normal file
View File

@ -0,0 +1,24 @@
import graph_force
def test_list_of_edges():
edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
pos = graph_force.layout_from_edge_list(7, edges)
assert pos is not None
assert len(pos) == 7
def test_iterator_of_edges():
pos = graph_force.layout_from_edge_list(
7,
((0, i + 1) for i in range(6))
)
assert pos is not None
assert len(pos) == 7
def test_tuple_of_edges():
pos = graph_force.layout_from_edge_list(
7,
((0,1), (1,2), (2,3), (3,4), (4,5), (5,6))
)
assert pos is not None
assert len(pos) == 7