integration of networkx model

This commit is contained in:
Niko Abeler 2022-11-22 22:35:51 +01:00
parent a503ba80ed
commit 0a548bbd32
5 changed files with 32 additions and 9 deletions

2
Cargo.lock generated
View File

@ -33,7 +33,7 @@ dependencies = [
[[package]]
name = "graph_force"
version = "0.1.2"
version = "0.1.3"
dependencies = [
"pyo3",
"rand",

View File

@ -1,6 +1,6 @@
[package]
name = "graph_force"
version = "0.1.2"
version = "0.1.3"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@ -1,22 +1,37 @@
mod graph;
mod model;
mod networkx_model;
mod runner;
mod spring_model;
mod utils;
use std::sync::{Arc, RwLock};
use pyo3::exceptions;
use pyo3::{prelude::*, types::PyIterator};
#[pyfunction(number_of_nodes, edges, "*", iter = 500, threads = 0)]
#[pyfunction(
number_of_nodes,
edges,
"*",
iter = 500,
threads = 0,
model = "\"spring_model\""
)]
fn layout_from_edge_list(
number_of_nodes: usize,
edges: &PyAny,
iter: usize,
threads: usize,
model: &str,
) -> PyResult<Vec<(f32, f32)>> {
let model = Arc::new(RwLock::new(spring_model::SimpleSpringModel::new(1.0)));
let model: Box<dyn model::ForceModel + Send + Sync> = match model {
"spring_model" => Box::new(spring_model::SimpleSpringModel::new(1.0)),
"networkx_model" => Box::new(networkx_model::NetworkXModel::new()),
_ => {
return Err(PyErr::new::<exceptions::PyValueError, _>(
"model must be either 'spring_model' or 'networkx_model'",
))
}
};
let mut edge_matrix = graph::new_edge_matrix(number_of_nodes);
match edges.extract::<&PyIterator>() {

View File

@ -30,16 +30,18 @@ impl Runner {
}
}
pub fn layout<T: 'static + ForceModel + Sync + Send>(
pub fn layout(
self: &Self,
number_of_nodes: usize,
edges: EdgeMatrix,
model: Arc<RwLock<T>>,
model: Box<dyn ForceModel + Send + Sync>,
) -> Vec<(f32, f32)> {
// let edges = connection_matrix(size);
let mut nodes = new_node_vector(number_of_nodes);
let mut nodes_next = new_node_vector(number_of_nodes);
let model = Arc::new(RwLock::new(model));
model
.write()
.unwrap()
@ -109,7 +111,7 @@ mod test {
#[test]
fn test_layout() {
let model = Arc::new(RwLock::new(MockModel { counter: 0 }));
let model = Box::new(MockModel { counter: 0 });
let runner = Runner::new(10, 1);
let edges = graph::new_edge_matrix(3);
let result = runner.layout(3, edges, model);

View File

@ -22,3 +22,9 @@ def test_tuple_of_edges():
)
assert pos is not None
assert len(pos) == 7
def test_model_selection():
edges = [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6)]
pos = graph_force.layout_from_edge_list(7, edges, model='networkx_model')
assert pos is not None
assert len(pos) == 7