From dcdc12bf01fe086ce14d03a558cfc909557a9c30 Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Wed, 25 Jan 2023 11:20:45 -0500 Subject: [PATCH 1/2] Value streaming for NdArrays --- .../org/tensorflow/ndarray/DoubleNdArray.java | 17 +++++++++- .../org/tensorflow/ndarray/IntNdArray.java | 20 ++++++++++- .../org/tensorflow/ndarray/LongNdArray.java | 17 +++++++++- .../java/org/tensorflow/ndarray/NdArray.java | 18 +++++++++- .../ndarray/IntNdArrayTestBase.java | 32 +++++++++++++++-- .../ndarray/LongNdArrayTestBase.java | 28 ++++++++++++++- .../tensorflow/ndarray/NdArrayTestBase.java | 34 ++++++++++++++++--- 7 files changed, 154 insertions(+), 12 deletions(-) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java index 80e99b0..b0e6dab 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ import org.tensorflow.ndarray.buffer.DoubleDataBuffer; import org.tensorflow.ndarray.index.Index; +import java.util.stream.DoubleStream; +import java.util.stream.StreamSupport; + /** * An {@link NdArray} of doubles. */ @@ -68,6 +71,18 @@ public interface DoubleNdArray extends NdArray { */ DoubleNdArray setDouble(double value, long... coordinates); + /** + * Retrieve all scalar values of this array as a stream of doubles. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are + * returned in sequential order.

+ * + * @return scalar values as a stream + */ + default DoubleStream streamOfDoubles() { + return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble); + } + @Override DoubleNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java index aa2cc65..b1c44f8 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,12 @@ import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.index.Index; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + /** * An {@link NdArray} of integers. */ @@ -68,6 +74,18 @@ public interface IntNdArray extends NdArray { */ IntNdArray setInt(int value, long... coordinates); + /** + * Retrieve all scalar values of this array as a stream of integers. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are + * returned in sequential order.

+ * + * @return scalar values as a stream + */ + default IntStream streamOfInts() { + return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt); + } + @Override IntNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java index 3e5be6d..e7bd266 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ import org.tensorflow.ndarray.buffer.LongDataBuffer; import org.tensorflow.ndarray.index.Index; +import java.util.stream.LongStream; +import java.util.stream.StreamSupport; + /** * An {@link NdArray} of longs. */ @@ -68,6 +71,18 @@ public interface LongNdArray extends NdArray { */ LongNdArray setLong(long value, long... coordinates); + /** + * Retrieve all scalar values of this array as a stream of longs. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are + * returned in sequential order.

+ * + * @return scalar values as a stream + */ + default LongStream streamOfLongs() { + return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong); + } + @Override LongNdArray slice(Index... indices); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java index 6686abd..879bd6c 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,10 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.index.Index; @@ -229,6 +233,18 @@ public interface NdArray extends Shaped { */ NdArray setObject(T value, long... coordinates); + /** + * Retrieve all scalar values of this array as a stream of objects. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are + * returned in sequential order.

+ * + * @return scalar values as a stream + */ + default Stream streamOfObjects() { + return StreamSupport.stream(scalars().spliterator(), false).map(NdArray::getObject); + } + /** * Copy the content of this array to the destination array. * diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java index 1a3c7cb..0fba8fb 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,10 +16,11 @@ */ package org.tensorflow.ndarray; -import static org.junit.jupiter.api.Assertions.assertEquals; - import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + public abstract class IntNdArrayTestBase extends NdArrayTestBase { @Override @@ -52,4 +53,29 @@ public void iteratePrimitiveElements() { assertEquals(9, matrix3d.getInt(0, 0, 4)); assertEquals(7, matrix3d.getInt(0, 1, 2)); } + + @Test + public void streamingInts() { + IntNdArray scalar = allocate(Shape.scalar()); + scalar.setInt(1); + var values = scalar.streamOfInts().toArray(); + assertArrayEquals(new int[]{1}, values); + + IntNdArray vector = allocate(Shape.of(5)); + vector.setInt(1, 0); + vector.setInt(2, 1); + vector.setInt(3, 2); + vector.setInt(4, 3); + vector.setInt(5, 4); + values = vector.streamOfInts().toArray(); + assertArrayEquals(new int[]{1, 2, 3, 4, 5}, values); + + IntNdArray matrix = allocate(Shape.of(2, 2)); + matrix.setInt(1, 0, 0); + matrix.setInt(2, 0, 1); + matrix.setInt(3, 1, 0); + matrix.setInt(4, 1, 1); + values = matrix.streamOfInts().toArray(); + assertArrayEquals(new int[]{1, 2, 3, 4}, values); + } } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java index b91c19d..520d9bc 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import org.junit.jupiter.api.Test; @@ -52,4 +53,29 @@ public void iteratePrimitiveElements() { assertEquals(9, matrix3d.getLong(0, 0, 4)); assertEquals(7, matrix3d.getLong(0, 1, 2)); } + + @Test + public void streamingLongs() { + LongNdArray scalar = allocate(Shape.scalar()); + scalar.setLong(1L); + var values = scalar.streamOfLongs().toArray(); + assertArrayEquals(new long[]{1L}, values); + + LongNdArray vector = allocate(Shape.of(5)); + vector.setLong(1L, 0); + vector.setLong(2L, 1); + vector.setLong(3L, 2); + vector.setLong(4L, 3); + vector.setLong(5L, 4); + values = vector.streamOfLongs().toArray(); + assertArrayEquals(new long[]{1L, 2L, 3L, 4L, 5L}, values); + + LongNdArray matrix = allocate(Shape.of(2, 2)); + matrix.setLong(1L, 0, 0); + matrix.setLong(2L, 0, 1); + matrix.setLong(3L, 1, 0); + matrix.setLong(4L, 1, 1); + values = matrix.streamOfLongs().toArray(); + assertArrayEquals(new long[]{1L, 2L, 3L, 4L}, values); + } } diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java index c44db8d..8a09ec7 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java @@ -1,5 +1,5 @@ /* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. + Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -16,9 +16,7 @@ */ package org.tensorflow.ndarray; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; import static org.tensorflow.ndarray.NdArrays.vectorOfObjects; import static org.tensorflow.ndarray.index.Indices.all; import static org.tensorflow.ndarray.index.Indices.at; @@ -32,6 +30,9 @@ import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; +import java.util.List; +import java.util.stream.Collectors; + import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.buffer.DataBuffer; import org.tensorflow.ndarray.index.Indices; @@ -358,4 +359,29 @@ public void iterateScalarsOnSegmentedElements() { }); }); } + + @Test + public void streamingObjects() { + NdArray scalar = allocate(Shape.scalar()); + scalar.setObject(valueOf(1L)); + var values = scalar.streamOfObjects().collect(Collectors.toList()); + assertIterableEquals(List.of(valueOf(1L)), values); + + NdArray vector = allocate(Shape.of(5)); + vector.setObject(valueOf(1L), 0); + vector.setObject(valueOf(2L), 1); + vector.setObject(valueOf(3L), 2); + vector.setObject(valueOf(4L), 3); + vector.setObject(valueOf(5L), 4); + values = vector.streamOfObjects().collect(Collectors.toList()); + assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L), valueOf(5L)), values); + + NdArray matrix = allocate(Shape.of(2, 2)); + matrix.setObject(valueOf(1L), 0, 0); + matrix.setObject(valueOf(2L), 0, 1); + matrix.setObject(valueOf(3L), 1, 0); + matrix.setObject(valueOf(4L), 1, 1); + values = matrix.streamOfObjects().collect(Collectors.toList()); + assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values); + } } From fa6b9be4b5c2cfd574c730bcd78e71185a9c794e Mon Sep 17 00:00:00 2001 From: Karl Lessard Date: Wed, 25 Jan 2023 13:09:13 -0500 Subject: [PATCH 2/2] Remove unnecessary imports --- ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java | 3 --- ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java | 1 - 2 files changed, 4 deletions(-) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java index b1c44f8..e6a5cf0 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java @@ -20,10 +20,7 @@ import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.index.Index; -import java.util.stream.DoubleStream; import java.util.stream.IntStream; -import java.util.stream.LongStream; -import java.util.stream.Stream; import java.util.stream.StreamSupport; /** diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java index 879bd6c..f1e84d4 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java @@ -18,7 +18,6 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.stream.DoubleStream; import java.util.stream.Stream; import java.util.stream.StreamSupport;