diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs index 1e9f1590d4..82ee06246b 100644 --- a/crates/burn-import/src/burn/ty.rs +++ b/crates/burn-import/src/burn/ty.rs @@ -1,11 +1,9 @@ -use onnx_ir::ir::ElementType; use proc_macro2::Ident; use proc_macro2::Span; use proc_macro2::TokenStream; use quote::quote; use crate::burn::ToTokens; -use onnx_ir::ir::{ArgType, Argument as OnnxArgument, TensorType as OnnxTensorType}; #[derive(Debug, Clone)] pub struct TensorType { diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index da83b96290..244cd0091c 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -66,7 +66,7 @@ use super::op_configuration::{ use onnx_ir::{ convert_constant_value, ir::{ - self, ArgType, Argument as OnnxArgument, Data, ElementType, Node as OnnxNode, NodeType, + ArgType, Argument as OnnxArgument, Data, ElementType, Node as OnnxNode, NodeType, OnnxGraph, TensorType as OnnxTensorType, }, parse_onnx, @@ -728,11 +728,7 @@ impl ParsedOnnxGraph { } fn sum_conversion(node: OnnxNode) -> SumNode { - let inputs = node - .inputs - .iter() - .map(|input| TensorType::from(input)) - .collect(); + let inputs = node.inputs.iter().map(TensorType::from).collect(); let output = TensorType::from(node.outputs.first().unwrap()); SumNode::new(inputs, output) @@ -784,11 +780,7 @@ impl ParsedOnnxGraph { } fn concat_conversion(node: OnnxNode) -> ConcatNode { - let inputs = node - .inputs - .iter() - .map(|input| TensorType::from(input)) - .collect(); + let inputs = node.inputs.iter().map(TensorType::from).collect(); let output = TensorType::from(node.outputs.first().unwrap()); let dim = concat_config(&node);