Skip to content
This repository was archived by the owner on Jul 15, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -20,6 +20,9 @@
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.index.Index;

import java.util.stream.DoubleStream;
import java.util.stream.StreamSupport;

/**
* An{@link NdArray} of doubles.
*/
Expand DownExpand Up@@ -68,6 +71,18 @@ public interface DoubleNdArray extends NdArray<Double>{
*/
DoubleNdArray setDouble(double value, long... coordinates);

/**
* Retrieve all scalar values of this array as a stream of doubles.
*
* <p>For{@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
* returned in sequential order.</p>
*
* @return scalar values as a stream
*/
default DoubleStream streamOfDoubles(){
return StreamSupport.stream(scalars().spliterator(), false).mapToDouble(DoubleNdArray::getDouble);
}

@Override
DoubleNdArray slice(Index... indices);

Expand Down
17 changes: 16 additions & 1 deletion ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -20,6 +20,9 @@
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.index.Index;

import java.util.stream.IntStream;
import java.util.stream.StreamSupport;

/**
* An{@link NdArray} of integers.
*/
Expand DownExpand Up@@ -68,6 +71,18 @@ public interface IntNdArray extends NdArray<Integer>{
*/
IntNdArray setInt(int value, long... coordinates);

/**
* Retrieve all scalar values of this array as a stream of integers.
*
* <p>For{@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
* returned in sequential order.</p>
*
* @return scalar values as a stream
*/
default IntStream streamOfInts(){
return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt);
}

@Override
IntNdArray slice(Index... indices);

Expand Down
17 changes: 16 additions & 1 deletion ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -20,6 +20,9 @@
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.index.Index;

import java.util.stream.LongStream;
import java.util.stream.StreamSupport;

/**
* An{@link NdArray} of longs.
*/
Expand DownExpand Up@@ -68,6 +71,18 @@ public interface LongNdArray extends NdArray<Long>{
*/
LongNdArray setLong(long value, long... coordinates);

/**
* Retrieve all scalar values of this array as a stream of longs.
*
* <p>For{@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
* returned in sequential order.</p>
*
* @return scalar values as a stream
*/
default LongStream streamOfLongs(){
return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong);
}

@Override
LongNdArray slice(Index... indices);

Expand Down
17 changes: 16 additions & 1 deletion ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -18,6 +18,9 @@

import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.index.Index;

Expand DownExpand Up@@ -229,6 +232,18 @@ public interface NdArray<T> extends Shaped{
*/
NdArray<T> setObject(T value, long... coordinates);

/**
* Retrieve all scalar values of this array as a stream of objects.
*
* <p>For{@code rank() > 1} arrays, all vectors of the last dimension are collated so that the scalar values are
* returned in sequential order.</p>
*
* @return scalar values as a stream
*/
default Stream<T> streamOfObjects(){
return StreamSupport.stream(scalars().spliterator(), false).map(NdArray::getObject);
}

/**
* Copy the content of this array to the destination array.
*
Expand Down
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -16,10 +16,11 @@
*/
package org.tensorflow.ndarray;

import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

public abstract class IntNdArrayTestBase extends NdArrayTestBase<Integer>{

@Override
Expand DownExpand Up@@ -52,4 +53,29 @@ public void iteratePrimitiveElements(){
assertEquals(9, matrix3d.getInt(0, 0, 4));
assertEquals(7, matrix3d.getInt(0, 1, 2));
}

@Test
public void streamingInts(){
IntNdArray scalar = allocate(Shape.scalar());
scalar.setInt(1);
var values = scalar.streamOfInts().toArray();
assertArrayEquals(new int[]{1}, values);

IntNdArray vector = allocate(Shape.of(5));
vector.setInt(1, 0);
vector.setInt(2, 1);
vector.setInt(3, 2);
vector.setInt(4, 3);
vector.setInt(5, 4);
values = vector.streamOfInts().toArray();
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, values);

IntNdArray matrix = allocate(Shape.of(2, 2));
matrix.setInt(1, 0, 0);
matrix.setInt(2, 0, 1);
matrix.setInt(3, 1, 0);
matrix.setInt(4, 1, 1);
values = matrix.streamOfInts().toArray();
assertArrayEquals(new int[]{1, 2, 3, 4}, values);
}
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -16,6 +16,7 @@
*/
package org.tensorflow.ndarray;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

import org.junit.jupiter.api.Test;
Expand DownExpand Up@@ -52,4 +53,29 @@ public void iteratePrimitiveElements(){
assertEquals(9, matrix3d.getLong(0, 0, 4));
assertEquals(7, matrix3d.getLong(0, 1, 2));
}

@Test
public void streamingLongs(){
LongNdArray scalar = allocate(Shape.scalar());
scalar.setLong(1L);
var values = scalar.streamOfLongs().toArray();
assertArrayEquals(new long[]{1L}, values);

LongNdArray vector = allocate(Shape.of(5));
vector.setLong(1L, 0);
vector.setLong(2L, 1);
vector.setLong(3L, 2);
vector.setLong(4L, 3);
vector.setLong(5L, 4);
values = vector.streamOfLongs().toArray();
assertArrayEquals(new long[]{1L, 2L, 3L, 4L, 5L}, values);

LongNdArray matrix = allocate(Shape.of(2, 2));
matrix.setLong(1L, 0, 0);
matrix.setLong(2L, 0, 1);
matrix.setLong(3L, 1, 0);
matrix.setLong(4L, 1, 1);
values = matrix.streamOfLongs().toArray();
assertArrayEquals(new long[]{1L, 2L, 3L, 4L}, values);
}
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All@@ -16,9 +16,7 @@
*/
package org.tensorflow.ndarray;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.*;
import static org.tensorflow.ndarray.NdArrays.vectorOfObjects;
import static org.tensorflow.ndarray.index.Indices.all;
import static org.tensorflow.ndarray.index.Indices.at;
Expand All@@ -32,6 +30,9 @@

import java.nio.BufferOverflowException;
import java.nio.BufferUnderflowException;
import java.util.List;
import java.util.stream.Collectors;

import org.junit.jupiter.api.Test;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.index.Indices;
Expand DownExpand Up@@ -358,4 +359,29 @@ public void iterateScalarsOnSegmentedElements(){
});
});
}

@Test
public void streamingObjects(){
NdArray<T> scalar = allocate(Shape.scalar());
scalar.setObject(valueOf(1L));
var values = scalar.streamOfObjects().collect(Collectors.toList());
assertIterableEquals(List.of(valueOf(1L)), values);

NdArray<T> vector = allocate(Shape.of(5));
vector.setObject(valueOf(1L), 0);
vector.setObject(valueOf(2L), 1);
vector.setObject(valueOf(3L), 2);
vector.setObject(valueOf(4L), 3);
vector.setObject(valueOf(5L), 4);
values = vector.streamOfObjects().collect(Collectors.toList());
assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L), valueOf(5L)), values);

NdArray<T> matrix = allocate(Shape.of(2, 2));
matrix.setObject(valueOf(1L), 0, 0);
matrix.setObject(valueOf(2L), 0, 1);
matrix.setObject(valueOf(3L), 1, 0);
matrix.setObject(valueOf(4L), 1, 1);
values = matrix.streamOfObjects().collect(Collectors.toList());
assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values);
}
}