diff --git a/pom.xml b/pom.xml index f7f90275778..7fa0865af0a 100644 --- a/pom.xml +++ b/pom.xml @@ -31,6 +31,7 @@ + tensorflow-ndarray tensorflow-core tensorflow-framework diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml index 59e1703d355..a4cd84dcf20 100644 --- a/tensorflow-core/tensorflow-core-api/pom.xml +++ b/tensorflow-core/tensorflow-core-api/pom.xml @@ -15,7 +15,6 @@ Platform-dependent native code and pure-Java code for the TensorFlow machine intelligence library. - 1.0.0 1.1.5 false ${project.build.directory}/tf-text-download/ @@ -24,8 +23,8 @@ org.tensorflow - ndarray - ${ndarray.version} + tensorflow-ndarray + ${project.version} org.tensorflow diff --git a/tensorflow-ndarray/pom.xml b/tensorflow-ndarray/pom.xml new file mode 100644 index 00000000000..8f1df831143 --- /dev/null +++ b/tensorflow-ndarray/pom.xml @@ -0,0 +1,101 @@ + + + 4.0.0 + + + org.tensorflow + tensorflow-java + 1.2.0-SNAPSHOT + + tensorflow-ndarray + jar + + TensorFlow NdArray Library + + Utility library for N-dimensional data I/O operations in Java. + + + + org.tensorflow.ndarray + + + + + org.junit.jupiter + junit-jupiter-api + test + + + org.junit.jupiter + junit-jupiter-engine + test + + + org.openjdk.jmh + jmh-core + test + + + org.openjdk.jmh + jmh-generator-annprocess + test + + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + ${java.module.name} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + 1 + false + -Xmx2G + + + + org.apache.maven.plugins + maven-compiler-plugin + + + default-testCompile + + + --add-modules=java.desktop + + + + + + + + + diff --git a/tensorflow-ndarray/src/main/java/module-info.java b/tensorflow-ndarray/src/main/java/module-info.java new file mode 100644 index 00000000000..e9f351a878c --- /dev/null +++ b/tensorflow-ndarray/src/main/java/module-info.java @@ -0,0 +1,38 @@ +/* +Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +module org.tensorflow.ndarray { + requires jdk.unsupported; // required by raw buffer implementations using Unsafe + + exports org.tensorflow.ndarray; + exports org.tensorflow.ndarray.buffer; + exports org.tensorflow.ndarray.buffer.layout; + exports org.tensorflow.ndarray.index; + + // Expose all implementations of our interfaces, so consumers can write custom + // implementations easily by extending from them + exports org.tensorflow.ndarray.impl.buffer; + exports org.tensorflow.ndarray.impl.buffer.adapter; + exports org.tensorflow.ndarray.impl.buffer.layout; + exports org.tensorflow.ndarray.impl.buffer.misc; + exports org.tensorflow.ndarray.impl.buffer.nio; + exports org.tensorflow.ndarray.impl.buffer.raw; + exports org.tensorflow.ndarray.impl.dense; + exports org.tensorflow.ndarray.impl.dimension; + exports org.tensorflow.ndarray.impl.sequence; + exports org.tensorflow.ndarray.impl.sparse; + exports org.tensorflow.ndarray.impl.sparse.slice; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java new file mode 100644 index 00000000000..58f9c71456b --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/BooleanNdArray.java @@ -0,0 +1,115 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of booleans. */ +public interface BooleanNdArray extends NdArray { + + /** + * Returns the boolean value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * BooleanNdArray matrix = NdArrays.ofBooleans(shape(2, 2));  // matrix rank = 2
+   * matrix.getBoolean(0, 1);  // succeeds, returns false
+   * matrix.getBoolean(0);  // throws IllegalRankException
+   *
+   * BooleanNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getBoolean();  // succeeds, returns false
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + boolean getBoolean(long... coordinates); + + /** + * Assigns the boolean value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * BooleanNdArray matrix = NdArrays.ofBooleans(shape(2, 2));  // matrix rank = 2
+   * matrix.setBoolean(true, 0, 1);  // succeeds
+   * matrix.setBoolean(true, 0);  // throws IllegalRankException
+   *
+   * BooleanNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setBoolean(true);  // succeeds
+   * }
+ * + * @param value the value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + BooleanNdArray setBoolean(boolean value, long... coordinates); + + @Override + BooleanNdArray withShape(Shape shape); + + @Override + BooleanNdArray slice(Index... indices); + + @Override + BooleanNdArray get(long... coordinates); + + @Override + BooleanNdArray set(NdArray src, long... coordinates); + + @Override + default Boolean getObject(long... coordinates) { + return getBoolean(coordinates); + } + + @Override + default BooleanNdArray setObject(Boolean value, long... coordinates) { + return setBoolean(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + BooleanNdArray copyTo(NdArray dst); + + @Override + BooleanNdArray copyTo(DataBuffer dst); + + BooleanNdArray copyTo(BooleanDataBuffer dst); + + @Override + BooleanNdArray copyFrom(DataBuffer src); + + BooleanNdArray copyFrom(BooleanDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java new file mode 100644 index 00000000000..9354b474f27 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/ByteNdArray.java @@ -0,0 +1,115 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of bytes. */ +public interface ByteNdArray extends NdArray { + + /** + * Returns the byte value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * ByteNdArray matrix = NdArrays.ofBytes(shape(2, 2));  // matrix rank = 2
+   * matrix.getByte(0, 1);  // succeeds, returns 0
+   * matrix.getByte(0);  // throws IllegalRankException
+   *
+   * ByteNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getByte();  // succeeds, returns 0
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + byte getByte(long... coordinates); + + /** + * Assigns the byte value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * ByteNdArray matrix = NdArrays.ofBytes(shape(2, 2));  // matrix rank = 2
+   * matrix.setByte(10, 0, 1);  // succeeds
+   * matrix.setByte(10, 0);  // throws IllegalRankException
+   *
+   * ByteNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setByte(10);  // succeeds
+   * }
+ * + * @param value the value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + ByteNdArray setByte(byte value, long... coordinates); + + @Override + ByteNdArray withShape(Shape shape); + + @Override + ByteNdArray slice(Index... indices); + + @Override + ByteNdArray get(long... coordinates); + + @Override + ByteNdArray set(NdArray src, long... coordinates); + + @Override + default Byte getObject(long... coordinates) { + return getByte(coordinates); + } + + @Override + default ByteNdArray setObject(Byte value, long... coordinates) { + return setByte(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + ByteNdArray copyTo(NdArray dst); + + @Override + ByteNdArray copyTo(DataBuffer dst); + + ByteNdArray copyTo(ByteDataBuffer dst); + + @Override + ByteNdArray copyFrom(DataBuffer src); + + ByteNdArray copyFrom(ByteDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java new file mode 100644 index 00000000000..f631a1c7522 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/DoubleNdArray.java @@ -0,0 +1,130 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import java.util.stream.DoubleStream; +import java.util.stream.StreamSupport; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of doubles. */ +public interface DoubleNdArray extends NdArray { + + /** + * Returns the double value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * DoubleNdArray matrix = NdArrays.ofDoubles(shape(2, 2));  // matrix rank = 2
+   * matrix.getDouble(0, 1);  // succeeds, returns 0.0
+   * matrix.getDouble(0);  // throws IllegalRankException
+   *
+   * DoubleNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getDouble();  // succeeds, returns 0.0
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + double getDouble(long... coordinates); + + /** + * Assigns the double value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * DoubleNdArray matrix = NdArrays.ofDoubles(shape(2, 2));  // matrix rank = 2
+   * matrix.setDouble(10.0, 0, 1);  // succeeds
+   * matrix.setDouble(10.0, 0);  // throws IllegalRankException
+   *
+   * DoubleNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setDouble(10.0);  // succeeds
+   * }
+ * + * @param value value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + DoubleNdArray setDouble(double value, long... coordinates); + + /** + * Retrieve all scalar values of this array as a stream of doubles. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the + * scalar values are returned in sequential order. + * + * @return scalar values as a stream + */ + default DoubleStream streamOfDoubles() { + return StreamSupport.stream(scalars().spliterator(), false) + .mapToDouble(DoubleNdArray::getDouble); + } + + @Override + DoubleNdArray withShape(Shape shape); + + @Override + DoubleNdArray slice(Index... indices); + + @Override + DoubleNdArray get(long... coordinates); + + @Override + DoubleNdArray set(NdArray src, long... coordinates); + + @Override + default Double getObject(long... coordinates) { + return getDouble(coordinates); + } + + @Override + default DoubleNdArray setObject(Double value, long... coordinates) { + return setDouble(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + DoubleNdArray copyTo(NdArray dst); + + @Override + DoubleNdArray copyTo(DataBuffer dst); + + DoubleNdArray copyTo(DoubleDataBuffer dst); + + @Override + DoubleNdArray copyFrom(DataBuffer src); + + DoubleNdArray copyFrom(DoubleDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java new file mode 100644 index 00000000000..1d564370b51 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/FloatNdArray.java @@ -0,0 +1,115 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of floats. */ +public interface FloatNdArray extends NdArray { + + /** + * Returns the float value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * FloatNdArray matrix = NdArrays.ofFloats(shape(2, 2));  // matrix rank = 2
+   * matrix.getFloat(0, 1);  // succeeds, returns 0.0f
+   * matrix.getFloat(0);  // throws IllegalRankException
+   *
+   * FloatNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getFloat();  // succeeds, returns 0.0f
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + float getFloat(long... coordinates); + + /** + * Assigns the float value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * FloatNdArray matrix = NdArrays.ofFloats(shape(2, 2));  // matrix rank = 2
+   * matrix.setFloat(10.0f, 0, 1);  // succeeds
+   * matrix.setFloat(10.0f, 0);  // throws IllegalRankException
+   *
+   * FloatNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setFloat(10.0f);  // succeeds
+   * }
+ * + * @param value value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + FloatNdArray setFloat(float value, long... coordinates); + + @Override + FloatNdArray withShape(Shape shape); + + @Override + FloatNdArray slice(Index... coordinates); + + @Override + FloatNdArray get(long... coordinates); + + @Override + FloatNdArray set(NdArray src, long... coordinates); + + @Override + default Float getObject(long... coordinates) { + return getFloat(coordinates); + } + + @Override + default FloatNdArray setObject(Float value, long... coordinates) { + return setFloat(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + FloatNdArray copyTo(NdArray dst); + + @Override + FloatNdArray copyTo(DataBuffer dst); + + FloatNdArray copyTo(FloatDataBuffer dst); + + @Override + FloatNdArray copyFrom(DataBuffer src); + + FloatNdArray copyFrom(FloatDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/IllegalRankException.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/IllegalRankException.java new file mode 100644 index 00000000000..4f4efd281b2 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/IllegalRankException.java @@ -0,0 +1,27 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +/** + * Exception thrown when an operation cannot be completed because of the rank of the targeted array. + */ +public class IllegalRankException extends IllegalArgumentException { + + public IllegalRankException(String message) { + super(message); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java new file mode 100644 index 00000000000..74bd7429213 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/IntNdArray.java @@ -0,0 +1,129 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import java.util.stream.IntStream; +import java.util.stream.StreamSupport; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of integers. */ +public interface IntNdArray extends NdArray { + + /** + * Returns the integer value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * IntNdArray matrix = NdArrays.ofInts(shape(2, 2));  // matrix rank = 2
+   * matrix.getInt(0, 1);  // succeeds, returns 0
+   * matrix.getInt(0);  // throws IllegalRankException
+   *
+   * IntNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getInt();  // succeeds, returns 0
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + int getInt(long... coordinates); + + /** + * Assigns the integer value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * IntNdArray matrix = NdArrays.ofInts(shape(2, 2));  // matrix rank = 2
+   * matrix.setInt(10, 0, 1);  // succeeds
+   * matrix.setInt(10, 0);  // throws IllegalRankException
+   *
+   * IntNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setInt(10);  // succeeds
+   * }
+ * + * @param value value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + IntNdArray setInt(int value, long... coordinates); + + /** + * Retrieve all scalar values of this array as a stream of integers. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the + * scalar values are returned in sequential order. + * + * @return scalar values as a stream + */ + default IntStream streamOfInts() { + return StreamSupport.stream(scalars().spliterator(), false).mapToInt(IntNdArray::getInt); + } + + @Override + IntNdArray withShape(Shape shape); + + @Override + IntNdArray slice(Index... indices); + + @Override + IntNdArray get(long... coordinates); + + @Override + IntNdArray set(NdArray src, long... coordinates); + + @Override + default Integer getObject(long... coordinates) { + return getInt(coordinates); + } + + @Override + default IntNdArray setObject(Integer value, long... coordinates) { + return setInt(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + IntNdArray copyTo(NdArray dst); + + @Override + IntNdArray copyTo(DataBuffer dst); + + IntNdArray copyTo(IntDataBuffer dst); + + @Override + IntNdArray copyFrom(DataBuffer src); + + IntNdArray copyFrom(IntDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java new file mode 100644 index 00000000000..dc781be6957 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/LongNdArray.java @@ -0,0 +1,129 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import java.util.stream.LongStream; +import java.util.stream.StreamSupport; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of longs. */ +public interface LongNdArray extends NdArray { + + /** + * Returns the long value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * LongNdArray matrix = NdArrays.ofLongs(shape(2, 2));  // matrix rank = 2
+   * matrix.getLong(0, 1);  // succeeds, returns 0L
+   * matrix.getLong(0);  // throws IllegalRankException
+   *
+   * LongNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getLong();  // succeeds, returns 0L
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + long getLong(long... coordinates); + + /** + * Assigns the long value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * LongNdArray matrix = NdArrays.ofLongs(shape(2, 2));  // matrix rank = 2
+   * matrix.setLong(10L, 0, 1);  // succeeds
+   * matrix.setLong(10L, 0);  // throws IllegalRankException
+   *
+   * LongNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setLong(10L);  // succeeds
+   * }
+ * + * @param value value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + LongNdArray setLong(long value, long... coordinates); + + /** + * Retrieve all scalar values of this array as a stream of longs. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the + * scalar values are returned in sequential order. + * + * @return scalar values as a stream + */ + default LongStream streamOfLongs() { + return StreamSupport.stream(scalars().spliterator(), false).mapToLong(LongNdArray::getLong); + } + + @Override + LongNdArray withShape(Shape shape); + + @Override + LongNdArray slice(Index... indices); + + @Override + LongNdArray get(long... coordinates); + + @Override + LongNdArray set(NdArray src, long... coordinates); + + @Override + default Long getObject(long... coordinates) { + return getLong(coordinates); + } + + @Override + default LongNdArray setObject(Long value, long... coordinates) { + return setLong(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + LongNdArray copyTo(NdArray dst); + + @Override + LongNdArray copyTo(DataBuffer dst); + + LongNdArray copyTo(LongDataBuffer dst); + + @Override + LongNdArray copyFrom(DataBuffer src); + + LongNdArray copyFrom(LongDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java new file mode 100644 index 00000000000..4f7e4fbf5d6 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArray.java @@ -0,0 +1,361 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** + * A data structure of N-dimensions. + * + *

The `NdArray` interface creates an abstraction between the physical storage of a data record, + * which can be linear or segmented, and its logical representation. In general, they achieve better + * performances than standard multi-dimensional arrays in Java by mapping directly linear data + * segments in memory. + * + *

Like {@link DataBuffer}, {@code NdArray} instances support 64-bits indexing so they can be + * used to map very large data records. They also support special coordinates that allow traversing + * their values in any direction or to select only a subset of them. + * + *

Example of usage: + * + *

{@code
+ * // Creates a 2x3x2 matrix (of rank 3)
+ * FloatNdArray matrix3d = NdArrays.ofFloats(shape(2, 3, 2));
+ *
+ * // Initialize sub-matrices data with vectors
+ * matrix.set(NdArrays.vectorOf(1.0f, 2.0f), 0, 0)
+ *       .set(NdArrays.vectorOf(3.0f, 4.0f), 0, 1)
+ *       .set(NdArrays.vectorOf(5.0f, 6.0f), 0, 2)
+ *       .set(NdArrays.vectorOf(7.0f, 8.0f), 1, 0)
+ *       .set(NdArrays.vectorOf(9.0f, 10.0f), 1, 1)
+ *       .set(NdArrays.vectorOf(11.0f, 12.0f), 1, 2);
+ *
+ * // Access the second 3x2 matrix (of rank 2)
+ * FloatNdArray matrix = matrix3d.get(1);
+ *
+ * // Access directly the float value at (1, 0) from the second matrix
+ * assertEquals(9.0f, matrix.getFloat(1, 0));
+ * }
+ * + * @param the type of values to be mapped + */ +public interface NdArray extends Shaped { + + /** + * Returns a sequence of all elements at a given dimension. + * + *

Logically, the N-dimensional array can be flatten in a single vector, where the scalars of + * the {@code (n - 1)}th element precedes those of the {@code (n)}th element, for a total of + * {@link #size()} values. + * + *

For example, given a {@code n x m} matrix on the {@code [x, y]} axes, elements are iterated + * in the following order: + * + *

x0y0, x0y1, ..., x0ym-1, + * x1y0, x1y1, ..., xn-1ym-1 + * + *

The returned sequence can then be iterated to visit each elements, either by calling {@link + * NdArraySequence#forEach(Consumer)} or {@link NdArraySequence#forEachIndexed(BiConsumer)}. + * + *

{@code
+   * // Iterate matrix for initializing each of its vectors
+   * matrixOfFloats.elements(0).forEach(v -> {
+   *   v.set(vector(1.0f, 2.0f, 3.0f));
+   * });
+   *
+   * // Iterate a vector for reading each of its scalar
+   * vectorOfFloats.scalars().forEachIdx((coords, s) -> {
+   *   System.out.println("Value " + s.getFloat() + " found at " + coords);
+   * });
+   * }
+ * + * @param dimensionIdx index of the dimension + * @return an {@code NdArray} sequence + * @throws IllegalArgumentException if {@code dimensionIdx} is greater or equal to the total + * number of dimensions of this array + */ + NdArraySequence> elements(int dimensionIdx); + + /** + * Returns a sequence of all scalars in this array. + * + *

This is equivalent to call {@code elements(shape().numDimensions() - 1)} + * + * @return an {@code NdArray} sequence + */ + NdArraySequence> scalars(); + + /** + * Returns a new N-dimensional view of this array with the given {@code shape}. + * + *

The provided {@code shape} must comply to the following characteristics: + * + *

    + *
  • new shape is known (i.e. has no unknown dimension) + *
  • new shape size is equal to the size of the current shape (i.e. same number of elements) + *
+ * + * For example, + * + *
{@code
+   * NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 1));  // ok
+   * NdArrays.ofInts(Shape.of(2, 3).withShape(Shape.of(3, 2));   // ok
+   * NdArrays.ofInts(Shape.scalar()).withShape(Shape.of(1, 2));  // not ok, sizes are different (1 != 2)
+   * NdArrays.ofInts(Shape.of(2, 3)).withShape(Shape.unknown()); // not ok, new shape unknown
+   * }
+ * + *

Any changes applied to the returned view affect the data of this array as well, as there is + * no copy involved. + * + * @param shape the new shape to apply + * @return a new array viewing the data according to the new shape, or this array if shapes are + * the same + * @throws IllegalArgumentException if the provided {@code shape} is not compliant + * @throws UnsupportedOperationException if this array does not support this operation + */ + NdArray withShape(Shape shape); + + /** + * Creates a multi-dimensional view (or slice) of this array by mapping one or more dimensions to + * the given index selectors. + * + *

Slices allow to traverse an N-dimensional array in any of its axis and/or to filter only + * elements of interest. For example, for a given matrix on the {@code [x, y]} axes, it is + * possible to iterate elements at {@code y=0} for all {@code x}. + * + *

Any changes applied to the returned slice affect the data of this array as well, as there is + * no copy involved. + * + *

Example of usage: + * + *

{@code
+   * FloatNdArray matrix3d = NdArrays.ofFloats(shape(3, 2, 4));  // with [x, y, z] axes
+   *
+   * // Iterates elements on the x axis by preserving only the 3rd value on the z axis,
+   * // (i.e. [x, y, 2])
+   * matrix3d.slice(all(), all(), at(2)).elements(0).forEach(m -> {
+   *   assertEquals(shape(2), m); // y=2, z=0 (scalar)
+   * });
+   *
+   * // Creates a slice that contains only the last element of the y axis and elements with an
+   * // odd `z` coordinate.
+   * FloatNdArray slice = matrix3d.slice(all(), at(1), odd());
+   * assertEquals(shape(3, 2), slice.shape());  // x=3, y=0 (scalar), z=2 (odd coordinates)
+   *
+   * // Iterates backward the elements on the x axis
+   * matrix3d.slice(flip()).elements(0).forEach(m -> {
+   *   assertEquals(shape(2, 4), m);  // y=2, z=4
+   * });
+   * }
+ * + * @param indices index selectors per dimensions, starting from dimension 0 of this array. + * @return the element resulting of the index selection + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + */ + NdArray slice(Index... indices); + + /** + * Returns the N-dimensional element of this array at the given coordinates. + * + *

Elements of any of the dimensions of this array can be retrieved. For example, if the number + * of coordinates is equal to the number of dimensions of this array, then a rank-0 (scalar) array + * is returned, which value can then be obtained by calling `array.getObject()`. + * + *

Any changes applied to the returned elements affect the data of this array as well, as there + * is no copy involved. + * + *

Note that invoking this method is an equivalent and more efficient way to slice this array + * on single scalar, i.e. {@code array.get(x, y, z)} is equal to {@code array.slice(at(x), at(y), + * at(z))} + * + * @param coordinates coordinates of the element to access, none will return this array + * @return the element at this index + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + */ + NdArray get(long... coordinates); + + /** + * Assigns the value of the N-dimensional element found at the given coordinates. + * + *

The number of coordinates provided can be anywhere between 0 and rank - 1. For example: + * + *

{@code
+   * FloatNdArray matrix = NdArrays.ofFloats(shape(2, 2));  // matrix rank = 2
+   * matrix.set(vector(10.0f, 20.0f), 0);  // success
+   * matrix.set(scalar(10.0f), 1, 0); // success
+   * }
+ * + * @param src an array of the values to assign + * @param coordinates coordinates of the element to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + */ + NdArray set(NdArray src, long... coordinates); + + /** + * Returns the value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * FloatNdArray matrix = NdArrays.ofFloats(shape(2, 2));  // matrix rank = 2
+   * matrix.getObject(0, 1);  // succeeds, returns 0.0f
+   * matrix.getObject(0);  // throws IllegalRankException
+   *
+   * FloatNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getObject();  // succeeds, returns 0.0f
+   * }
+ * + * Note: if this array stores values of a primitive type, prefer the usage of the specialized + * method in the subclass for that type. For example, {@code floatArray.getFloat(0); }. + * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + T getObject(long... coordinates); + + /** + * Assigns the value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * FloatNdArray matrix = NdArrays.ofFloats(shape(2, 2));  // matrix rank = 2
+   * matrix.setObject(10.0f, 0, 1);  // succeeds
+   * matrix.setObject(10.0f, 0);  // throws IllegalRankException
+   *
+   * FloatNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setObject(10.0f);  // succeeds
+   * }
+ * + * Note: if this array stores values of a primitive type, prefer the usage of the specialized + * method in the subclass for that type. For example, {@code floatArray.setFloat(10.0f, 0); } + * + * @param value the value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + NdArray setObject(T value, long... coordinates); + + /** + * Retrieve all scalar values of this array as a stream of objects. + * + *

For {@code rank() > 1} arrays, all vectors of the last dimension are collated so that the + * scalar values are returned in sequential order. + * + * @return scalar values as a stream + */ + default Stream streamOfObjects() { + return StreamSupport.stream(scalars().spliterator(), false).map(NdArray::getObject); + } + + /** + * Copy the content of this array to the destination array. + * + *

The {@link #shape()} of the destination array must be equal to the shape of this array, or + * an exception is thrown. After the copy, the content of both arrays can be altered + * independently, without affecting each other. + * + * @param dst array to receive a copy of the content of this array + * @return this array + * @throws IllegalArgumentException if the shape of {@code dst} is not equal to the shape of this + * array + */ + NdArray copyTo(NdArray dst); + + /** + * Copy the content of this N-dimensional array into the destination buffer. + * + *

The size of the buffer must be equal or greater to the {@link #size()} of this array, or an + * exception is thrown. After the copy, content of the buffer and of the array can be altered + * independently, without affecting each other. + * + *

Note: in version 0.4.0 and earlier, this method was named {@code read(DataBuffer)}. It + * has been renamed to explicitly indicate the direction of the data flow to avoid confusion. + * + * @param dst the destination buffer + * @return this array + * @throws java.nio.BufferOverflowException if the buffer cannot hold the content of this array + * @see DataBuffer#size() + */ + NdArray copyTo(DataBuffer dst); + + /** + * Copy the content of the source buffer into this N-dimensional array. + * + *

The size of the buffer must be equal or greater to the {@link #size()} of this array, or an + * exception is thrown. After the copy, content of the buffer and of the array can be altered + * independently, without affecting each other. + * + *

Note: in version 0.4.0 and earlier, this method was named {@code write(DataBuffer)}. + * It has been renamed to explicitly indicate the direction of the data flow to avoid + * confusion. + * + * @param src the source buffer + * @return this array + * @throws java.nio.BufferUnderflowException if the buffer has not enough remaining data to write + * into this array + * @see DataBuffer#size() + */ + NdArray copyFrom(DataBuffer src); + + /** + * Checks equality between n-dimensional arrays. + * + *

An array is equal to another object if this object is another {@link NdArray} of the same + * shape, type and the elements are equal and in the same order. For example: + * + *

{@code
+   * IntNdArray array = NdArrays.ofInts(Shape.of(2, 2))
+   *    .set(NdArrays.vectorOf(1, 2), 0)
+   *    .set(NdArrays.vectorOf(3, 4), 1);
+   *
+   * assertEquals(array, StdArrays.ndCopyOf(new int[][] {{1, 2}, {3, 4}}));  // true
+   * assertEquals(array, StdArrays.ndCopyOf(new Integer[][] {{1, 2}, {3, 4}}));  // true, as Integers are equal to ints
+   * assertNotEquals(array, NdArrays.vectorOf(1, 2, 3, 4));  // false, different shapes
+   * assertNotEquals(array, StdArrays.ndCopyOf(new int[][] {{3, 4}, {1, 2}}));  // false, different order
+   * assertNotEquals(array, StdArrays.ndCopyOf(new long[][] {{1L, 2L}, {3L, 4L}}));  // false, different types
+   * }
+ * + *

Note that the computation required to verify equality between two arrays can be expensive in + * some cases and therefore, it is recommended to not use this method in a critical path where + * performances matter. + * + * @param obj object to compare this array with + * @return true if this array is equal to the provided object + */ + @Override + boolean equals(Object obj); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java new file mode 100644 index 00000000000..bda82cec383 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArraySequence.java @@ -0,0 +1,71 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray; + +import java.util.function.BiConsumer; +import org.tensorflow.ndarray.buffer.DataBufferWindow; + +/** + * A sequence of elements of an N-dimensional array. + * + *

An {@code NdArraySequence} is used to traverse an {@code NdArray} in a given dimension and + * visit each of its elements. For example, given a {@code n x m} matrix on the {@code [x, y]} axes, + * elements are iterated in the following order: + * + *

x0y0, x0y1, ..., x0ym-1, + * x1y0, x1y1, ..., xn-1ym-1 + * + * @param data type of the array being iterated + */ +public interface NdArraySequence> extends Iterable { + + /** + * Visit each elements of this iteration and their respective coordinates. + * + *

Important: the consumer method should not keep a reference to the coordinates as they + * might be mutable and reused during the iteration to improve performance. + * + * @param consumer method to invoke for each elements + */ + void forEachIndexed(BiConsumer consumer); + + /** + * Returns each element as a new slice. + * + *

Unlike conventional Java collections, elements of a {@code NdArraySequence} are transient, + * i.e. new {@code NdArray} instances are allocated for each iteration. To improve performance, + * the same instance can be recycled to view all elements of this sequence, using a {@link + * DataBufferWindow}. + * + *

In some cases though, it might be preferable to disable such optimizations to ensure that + * each element returned is a new slice of the original array. For example, if one or more + * elements visited must live beyond the scope of the sequence iteration, {@code asSlices()} makes + * sure that all elements returned by the sequence are unique instances. + * + *

{@code
+   * final List vectors = new ArrayList<>();
+   * IntNdArray matrix = NdArrays.ofInts(Shape.of(6, 6));
+   * ndArray.elements(0).forEach(e -> vectors::add);  // Not safe, as `e` might always be the same recycled instance
+   * ndArray.elements(0).asSlices().forEach(e -> vectors::add);  // Safe, each `e` is a distinct NdArray instance
+   * }
+ * + * @return a sequence that returns each elements iterated as a new slice + * @see DataBufferWindow + */ + NdArraySequence asSlices(); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java new file mode 100644 index 00000000000..6caceb7d24b --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/NdArrays.java @@ -0,0 +1,819 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.dense.BooleanDenseNdArray; +import org.tensorflow.ndarray.impl.dense.ByteDenseNdArray; +import org.tensorflow.ndarray.impl.dense.DenseNdArray; +import org.tensorflow.ndarray.impl.dense.DoubleDenseNdArray; +import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray; +import org.tensorflow.ndarray.impl.dense.IntDenseNdArray; +import org.tensorflow.ndarray.impl.dense.LongDenseNdArray; +import org.tensorflow.ndarray.impl.dense.ShortDenseNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.FloatSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; + +/** Utility class for instantiating {@link NdArray} objects. */ +public final class NdArrays { + + // BYTE ARRAYS + + /** + * Creates byte scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new byte scalar + */ + public static ByteNdArray scalarOf(byte value) { + return ofBytes(Shape.scalar()).setByte(value); + } + + /** + * Creates a byte vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new byte vector + * @throws IllegalArgumentException if values is null + */ + public static ByteNdArray vectorOf(byte... values) { + if (values == null) { + throw new IllegalArgumentException("Values cannot be null"); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of bytes of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new byte N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static ByteNdArray ofBytes(Shape shape) { + if (shape == null) { + throw new IllegalArgumentException("Shape cannot be null"); + } + return wrap(shape, DataBuffers.ofBytes(shape.size())); + } + + /** + * Wraps a buffer in a byte N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new byte N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static ByteNdArray wrap(Shape shape, ByteDataBuffer buffer) { + return ByteDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of byte values with a default value of zero + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D ByteNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3}. + * @param shape the shape of the dense array represented by this sparse array. + * @return the byte sparse array. + */ + public static ByteSparseNdArray sparseOf(LongNdArray indices, ByteNdArray values, Shape shape) { + return ByteSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of byte values + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non default values. + * @param values A 1-D ByteNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3}. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the byte sparse array. + */ + public static ByteSparseNdArray sparseOf( + LongNdArray indices, ByteNdArray values, byte defaultValue, Shape shape) { + return ByteSparseNdArray.create(indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // LONG ARRAYS + + /** + * Creates long scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new long scalar + */ + public static LongNdArray scalarOf(long value) { + return ofLongs(Shape.scalar()).setLong(value); + } + + /** + * Creates a long vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new long vector + * @throws IllegalArgumentException if values is null + */ + public static LongNdArray vectorOf(long... values) { + if (values == null) { + throw new IllegalArgumentException(); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of longs of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new long N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static LongNdArray ofLongs(Shape shape) { + return wrap(shape, DataBuffers.ofLongs(shape.size())); + } + + /** + * Wraps a buffer in a long N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new long N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static LongNdArray wrap(Shape shape, LongDataBuffer buffer) { + return LongDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of long values with a default value of zero + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D LongNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18L, 3L]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18L}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3L}. + * @param shape the shape of the dense array represented by this sparse array. + * @return the long sparse array. + */ + public static LongSparseNdArray sparseOf(LongNdArray indices, LongNdArray values, Shape shape) { + return LongSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of long values with a default value of zero + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D LongNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18L, 3L]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18L}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3L}. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the long sparse array. + */ + public static LongSparseNdArray sparseOf( + LongNdArray indices, LongNdArray values, long defaultValue, Shape shape) { + return LongSparseNdArray.create(indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // INT ARRAYS + + /** + * Creates long scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new long scalar + */ + public static IntNdArray scalarOf(int value) { + return ofInts(Shape.scalar()).setInt(value); + } + + /** + * Creates a int vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new int vector + * @throws IllegalArgumentException if values is null + */ + public static IntNdArray vectorOf(int... values) { + if (values == null) { + throw new IllegalArgumentException(); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of ints of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new int N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static IntNdArray ofInts(Shape shape) { + return wrap(shape, DataBuffers.ofInts(shape.size())); + } + + /** + * Wraps a buffer in an int N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new int N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static IntNdArray wrap(Shape shape, IntDataBuffer buffer) { + return IntDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of int values with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D IntNdArray of shape {@code [N]}, which supplies the values for each element + * in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code + * values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3}. + * @param shape the shape of the dense array represented by this sparse array. + * @return the int sparse array. + */ + public static IntSparseNdArray sparseOf(LongNdArray indices, IntNdArray values, Shape shape) { + return IntSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of int values + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D IntNdArray of shape {@code [N]}, which supplies the values for each element + * in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code + * values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3}. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the int sparse array. + */ + public static IntSparseNdArray sparseOf( + LongNdArray indices, IntNdArray values, int defaultValue, Shape shape) { + return IntSparseNdArray.create(indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // SHORT ARRAYS + + /** + * Creates short scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new short scalar + */ + public static ShortNdArray scalarOf(short value) { + return ofShorts(Shape.scalar()).setShort(value); + } + + /** + * Creates a short vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new short vector + * @throws IllegalArgumentException if values is null + */ + public static ShortNdArray vectorOf(short... values) { + if (values == null) { + throw new IllegalArgumentException(); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of shorts of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new short N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static ShortNdArray ofShorts(Shape shape) { + return wrap(shape, DataBuffers.ofShorts(shape.size())); + } + + /** + * Wraps a buffer in a short N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new short N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static ShortNdArray wrap(Shape shape, ShortDataBuffer buffer) { + return ShortDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of short values with a default value of zero + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D ShortNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3}. + * @param shape the shape of the dense array represented by this sparse array. + * @return the short sparse array. + */ + public static ShortSparseNdArray sparseOf(LongNdArray indices, ShortNdArray values, Shape shape) { + return ShortSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of short values + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D ShortNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18, 3]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3}. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the short sparse array. + */ + public static ShortSparseNdArray sparseOf( + LongNdArray indices, ShortNdArray values, short defaultValue, Shape shape) { + return ShortSparseNdArray.create(indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // FLOAT ARRAYS + + /** + * Creates float scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new float scalar + */ + public static FloatNdArray scalarOf(float value) { + return ofFloats(Shape.scalar()).setFloat(value); + } + + /** + * Creates a float vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new float vector + * @throws IllegalArgumentException if values is null + */ + public static FloatNdArray vectorOf(float... values) { + if (values == null) { + throw new IllegalArgumentException(); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of floats of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new float N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static FloatNdArray ofFloats(Shape shape) { + return wrap(shape, DataBuffers.ofFloats(shape.size())); + } + + /** + * Wraps a buffer in a float N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new float N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static FloatNdArray wrap(Shape shape, FloatDataBuffer buffer) { + return FloatDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of float values with a default value of zero + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D FloatNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18f, 3.8f]} specifies that element {@code [1,3,1]} of the sparse NdArray has + * a value of {@code 18f}, and element {@code [2,4,0]} of the NdArray has a value of {@code + * 3.8f}. + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static FloatSparseNdArray sparseOf(LongNdArray indices, FloatNdArray values, Shape shape) { + return FloatSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of float values + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D FloatNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18f, 3.8f]} specifies that element {@code [1,3,1]} of the sparse NdArray has + * a value of {@code 18f}, and element {@code [2,4,0]} of the NdArray has a value of {@code + * 3.8f}. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static FloatSparseNdArray sparseOf( + LongNdArray indices, FloatNdArray values, float defaultValue, Shape shape) { + return FloatSparseNdArray.create(indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // DOUBLE ARRAYS + + /** + * Creates double scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new double scalar + */ + public static DoubleNdArray scalarOf(double value) { + return ofDoubles(Shape.scalar()).setDouble(value); + } + + /** + * Creates a double vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new double vector + * @throws IllegalArgumentException if values is null + */ + public static DoubleNdArray vectorOf(double... values) { + if (values == null) { + throw new IllegalArgumentException(); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of doubles of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static DoubleNdArray ofDoubles(Shape shape) { + return wrap(shape, DataBuffers.ofDoubles(shape.size())); + } + + /** + * Wraps a buffer in a double N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new double N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static DoubleNdArray wrap(Shape shape, DoubleDataBuffer buffer) { + return DoubleDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of double values with a default value of zero + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D DoubleNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18, 3.8]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3.8}. + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static DoubleSparseNdArray sparseOf( + LongNdArray indices, DoubleNdArray values, Shape shape) { + return DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of double values + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D DoubleNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[18, 3.8]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4,0]} of the NdArray has a value of {@code 3.8}. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static DoubleSparseNdArray sparseOf( + LongNdArray indices, DoubleNdArray values, double defaultValue, Shape shape) { + return DoubleSparseNdArray.create( + indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // BOOLEAN ARRAYS + + /** + * Creates boolean scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @return new boolean scalar + */ + public static BooleanNdArray scalarOf(boolean value) { + return ofBooleans(Shape.scalar()).setBoolean(value); + } + + /** + * Creates a boolean vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @return new boolean vector + * @throws IllegalArgumentException if values is null + */ + public static BooleanNdArray vectorOf(boolean... values) { + if (values == null) { + throw new IllegalArgumentException(); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of booleans of the given shape. + * + *

All values are initialized to zeros. + * + * @param shape shape of the array + * @return new boolean N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static BooleanNdArray ofBooleans(Shape shape) { + return wrap(shape, DataBuffers.ofBooleans(shape.size())); + } + + /** + * Wraps a buffer in a boolean N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @return new boolean N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static BooleanNdArray wrap(Shape shape, BooleanDataBuffer buffer) { + return BooleanDenseNdArray.create(buffer, shape); + } + + /** + * Creates a Sparse array of boolean values with a default value of 'false' + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D BooleanNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[true, true]} specifies that element {@code [1,3,1]} of the sparse NdArray + * has a value of true, and element {@code [2,4,0]} of the NdArray has a value of true. All + * other values are false. + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static BooleanSparseNdArray sparseOf( + LongNdArray indices, BooleanNdArray values, Shape shape) { + return BooleanSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of boolean values + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D BooleanNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter + * {@code values=[true, true]} specifies that element {@code [1,3,1]} of the sparse NdArray + * has a value of true, and element {@code [2,4,0]} of the NdArray has a value of true. All + * other values are false. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static BooleanSparseNdArray sparseOf( + LongNdArray indices, BooleanNdArray values, boolean defaultValue, Shape shape) { + return BooleanSparseNdArray.create( + indices, values, defaultValue, DimensionalSpace.create(shape)); + } + + // OBJECT ARRAYS + + /** + * Creates scalar (rank 0) initialized with the given value. + * + * @param value scalar value + * @param the data type + * @return new scalar + */ + @SuppressWarnings("unchecked") + public static NdArray scalarOfObject(T value) { + if (value == null) { + throw new IllegalArgumentException(); + } + return ofObjects((Class) value.getClass(), Shape.scalar()).setObject(value); + } + + /** + * Creates a vector (rank 1) initialized with the given values. + * + *

Modifying the data of the returned vector will also impact the values in the array passed in + * parameter. + * + * @param values vector values + * @param the data type + * @return new vector + * @throws IllegalArgumentException if values is null + */ + @SafeVarargs + public static NdArray vectorOfObjects(T... values) { + if (values == null || values.length == 0) { + throw new IllegalArgumentException("Null or zero length input supplied to vectorOfObjects."); + } + return wrap(Shape.of(values.length), DataBuffers.of(values, false, false)); + } + + /** + * Creates an N-dimensional array of the given shape. + * + *

All values are initialized to zeros. + * + * @param clazz class of the data to be stored in this array + * @param shape shape of the array + * @param the data type + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null or has unknown dimensions + */ + public static NdArray ofObjects(Class clazz, Shape shape) { + return wrap(shape, DataBuffers.ofObjects(clazz, shape.size())); + } + + /** + * Wraps a buffer in an N-dimensional array of a given shape. + * + * @param shape shape of the array + * @param buffer buffer to wrap + * @param the data type + * @return new N-dimensional array + * @throws IllegalArgumentException if shape is null, has unknown dimensions or has size bigger in + * the buffer size + */ + public static NdArray wrap(Shape shape, DataBuffer buffer) { + return DenseNdArray.wrap(buffer, shape); + } + + /** + * Creates a Sparse array of values with a null default value + * + * @param type the class type represented by this sparse array. + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D NdArray of shape {@code [N]}, which supplies the values for each element in + * indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code + * values=["one", "two"]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of "one", and element {@code [2,4,0]} of the NdArray has a value of "two"". All other + * values are null. + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static NdArray sparseOfObjects( + Class type, LongNdArray indices, NdArray values, Shape shape) { + return org.tensorflow.ndarray.impl.sparse.SparseNdArray.create( + type, indices, values, DimensionalSpace.create(shape)); + } + + /** + * Creates a Sparse array of values + * + * @param type the class type represented by this sparse array. + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3,1], [2,4,0]]} specifies that the elements with indexes of + * {@code [1,3,1]} and {@code [2,4,0]} have non-default values. + * @param values A 1-D NdArray of shape {@code [N]}, which supplies the values for each element in + * indices. For example, given {@code indices=[[1,3,1], [2,4,0]]}, the parameter {@code + * values=["one", "two"]} specifies that element {@code [1,3,1]} of the sparse NdArray has a + * value of "one", and element {@code [2,4,0]} of the NdArray has a value of "two"". All other + * values are null. + * @param defaultValue Scalar value to set for indices not specified in 'indices' + * @param shape the shape of the dense array represented by this sparse array. + * @return the float sparse array. + */ + public static NdArray sparseOfObjects( + Class type, LongNdArray indices, NdArray values, T defaultValue, Shape shape) { + return org.tensorflow.ndarray.impl.sparse.SparseNdArray.create( + type, indices, values, defaultValue, DimensionalSpace.create(shape)); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/Shape.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/Shape.java new file mode 100644 index 00000000000..9dbbdc44e74 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/Shape.java @@ -0,0 +1,498 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * The shape of a Tensor or {@link NdArray}. + * + *

A {@code Shape} defines sizes along its axes. It may contain an unknown size for one of the + * axes or may be totally unknown, in which case not even the number of axes is known. If the size + * of an axis is unknown, {@link Shape#UNKNOWN_SIZE} should be used as its size. + */ +public final class Shape { + + /** The size of an unknown axis or the total unknown size for an unknown Shape. */ + public static long UNKNOWN_SIZE = -1L; + + /** + * Creates a Shape representing an unknown number of dimensions. + * + * @return A Shape for which {@link Shape#isUnknown()} is true, never null. + */ + public static Shape unknown() { + return new Shape(null); + } + + /** + * Creates a Shape representing a scalar value. + * + * @return A Shape without dimensions for which {@link Shape#isScalar()} is true, never null. + */ + public static Shape scalar() { + return new Shape(new long[0]); + } + + /** + * Create a Shape representing a scalar or an N-dimensional value. + * + *

Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with + * the provided size for each dimension. A -1 indicates that the size of the corresponding + * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For + * example: + * + *

{@code
+   * // A 2-element vector.
+   * Shape vector = Shape.of(2);
+   *
+   * // A 2x3 matrix.
+   * Shape matrix = Shape.of(2, 3);
+   *
+   * // A matrix with 4 columns but an unknown number of rows.
+   * // This is typically used to indicate the shape of tensors that represent
+   * // a variable-sized batch of values. The Shape below might represent a
+   * // variable-sized batch of 4-element vectors.
+   * Shape batch = Shape.of(-1, 4);
+   *
+   * // A scalar. For readability, you should prefer calling Shape.scalar()
+   * Shape scalar = Shape.of()
+   * }
+ * + * @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link + * Shape#UNKNOWN_SIZE} if unknown. + * @return a new shape + */ + public static Shape of(long... dimensionSizes) { + if (dimensionSizes == null || dimensionSizes.length == 0) { + return scalar(); + } + return new Shape(dimensionSizes); + } + + /** + * Returns the total number of elements a Tensor with this Shape would have. + * + *

If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link + * Shape#UNKNOWN_SIZE} is returned. + * + * @return The total number of elements a Tensor with this shape would have if it can be + * calculated, else {@link Shape#UNKNOWN_SIZE}. + */ + public long size() { + if (size == null) { + size = computeSize(dimensionSizes); + } + return size; + } + + /** + * The size of the dimension with the given index. + * + *

If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has + * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned. + * + * @param i the index of the dimension to get the size for. If this Shape has a known number of + * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in + * which case the position is counted from the end of the shape. E.g.: {@code size(-1)} + * returns the size of the last dimension, {@code size(-2)} the size of the second to last + * dimension etc. + * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE} + * otherwise. + * @deprecated Renamed to {@link #get(int)}. + */ + @Deprecated + public long size(int i) { + return get(i); + } + + /** + * The size of the dimension with the given index. + * + *

If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has + * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned. + * + * @param i the index of the dimension to get the size for. If this Shape has a known number of + * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in + * which case the position is counted from the end of the shape. E.g.: {@code size(-1)} + * returns the size of the last dimension, {@code size(-2)} the size of the second to last + * dimension etc. + * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE} + * otherwise. + */ + public long get(int i) { + if (dimensionSizes == null) { + return UNKNOWN_SIZE; + } else if (i >= 0) { + return dimensionSizes[i]; + } else { + return dimensionSizes[dimensionSizes.length + i]; + } + } + + /** + * Returns the number of dimensions of this Shape. -1 if unknown, 0 for a scalar, 1 for a vector, + * 2 for a matrix etc. + */ + public int numDimensions() { + return dimensionSizes != null ? dimensionSizes.length : -1; + } + + /** Returns whether one or more dimensions of this Shape have an unknown size. */ + public boolean hasUnknownDimension() { + if (dimensionSizes == null) { + return true; + } + for (long dimSize : dimensionSizes) { + if (dimSize == UNKNOWN_SIZE) { + return true; + } + } + return false; + } + + /** Returns whether this Shape represents a scalar. */ + public boolean isScalar() { + return dimensionSizes != null && dimensionSizes.length == 0; + } + + /** Returns whether this Shape is the shape of a vector. */ + public boolean isVector() { + return dimensionSizes != null && dimensionSizes.length == 1; + } + + /** Returns whether this Shape is the shape of a matrix */ + public boolean isMatrix() { + return dimensionSizes != null && dimensionSizes.length == 2; + } + + /** Returns whether the number of dimensions of this Shape is unknown. */ + public boolean isUnknown() { + return dimensionSizes == null; + } + + /** + * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change + * this Shape's state. Returns null if {@link Shape#isUnknown()} is true. + */ + public long[] asArray() { + if (this.dimensionSizes == null) { + return null; + } else { + return Arrays.copyOf(dimensionSizes, dimensionSizes.length); + } + } + + /** + * Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change + * this Shape's state. Returns null if {@link Shape#isUnknown()} is true. + */ + public List toListOrNull() { + long[] array = asArray(); + if (array == null) { + return null; + } + + List list = new ArrayList<>(array.length); + for (long l : array) { + list.add(l); + } + + return list; + } + + @Override + public int hashCode() { + return dimensionSizes != null ? Arrays.hashCode(dimensionSizes) : super.hashCode(); + } + + /** + * Equals implementation for Shapes. Two Shapes are considered equal iff: + * + *

+ * + *

    + *
  • the number of dimensions is defined and equal for both + *
  • the size of each dimension is defined and equal for both + *
+ * + *

If either Shape has unknown dimensions (even if they are the same in both) or if either + * shape has an unknown number of dimensions (even if both return {@code true} for {@link + * Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself, + * even if it is unknown or contains unknown dimensions. + */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + // Shapes are equivalent if all of their dimensions are equals + if (obj instanceof Shape) { + Shape otherShape = (Shape) obj; + if (otherShape.hasUnknownDimension()) { + return false; + } + return Arrays.equals(dimensionSizes, otherShape.dimensionSizes); + } + return false; + } + + /** Succinct description of the Shape meant for debugging. */ + @Override + public String toString() { + return Arrays.toString(dimensionSizes); + } + + private Shape(long[] dimensionSizes) { + this.dimensionSizes = dimensionSizes; + } + + private final long[] dimensionSizes; + private Long size; + + /** + * Returns a 1-dimensional Shape with first dimension matching the first dimension of this Shape. + */ + public Shape head() { + return take(1); + } + + /** + * Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this + * shape + * + * @param n the number of leading dimensions to get, must be <= than {@link + * Shape#numDimensions()} + * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of + * this Shape + */ + public Shape take(int n) { + if (n > numDimensions()) { + throw new ArrayIndexOutOfBoundsException( + "Cannot take " + n + " dimensions, shape has only " + numDimensions() + "."); + } + long[] newDimensions = new long[n]; + System.arraycopy(dimensionSizes, 0, newDimensions, 0, n); + return Shape.of(newDimensions); + } + + /** Returns a new Shape, with this Shape's first dimension removed. */ + public Shape tail() { + if (dimensionSizes.length < 2) { + return Shape.of(); + } + return Shape.of(Arrays.copyOfRange(dimensionSizes, 1, dimensionSizes.length)); + } + + /** + * Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this + * Shape. + * + * @param n the number of trailing dimensions to get, must be <= than {@link + * Shape#numDimensions()} + * @return an n-dimensional shape with the dimensions matching the last n dimensions of this + * Shape, never null + */ + public Shape takeLast(int n) { + if (n > numDimensions()) { + throw new ArrayIndexOutOfBoundsException( + "Cannot take last " + n + " dimensions, shape has only " + numDimensions() + "."); + } + long[] newDimensions = new long[n]; + System.arraycopy(dimensionSizes, numDimensions() - n, newDimensions, 0, n); + return Shape.of(newDimensions); + } + + /** + * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code + * begin} to {@code end}. + * + * @param begin Where to start the sub-shape. + * @param end Where to end the sub-shape, exclusive. + * @return the sub-shape bounded by begin and end. + */ + public Shape subShape(int begin, int end) { + if (end > numDimensions()) { + throw new ArrayIndexOutOfBoundsException( + "End index " + + end + + " out of bounds: shape only has " + + numDimensions() + + " dimensions."); + } + if (begin < 0) { + throw new ArrayIndexOutOfBoundsException( + "Begin index " + begin + " out of bounds: cannot be less than 0."); + } + + long[] newDimensions = new long[end - begin]; + System.arraycopy(dimensionSizes, begin, newDimensions, 0, end - begin); + return Shape.of(newDimensions); + } + + /** + * Returns a new Shape, with a new first dimension added. In order for this call to succeed, + * {@link Shape#isUnknown()} must be {@code false}. + * + * @param firstDimension the dimension to prepend + * @return a new shape with the given dimension first, followed by this Shape's dimensions, never + * null + */ + public Shape prepend(long firstDimension) { + long[] newDimensions = new long[dimensionSizes.length + 1]; + newDimensions[0] = firstDimension; + System.arraycopy(dimensionSizes, 0, newDimensions, 1, dimensionSizes.length); + + return Shape.of(newDimensions); + } + + /** + * Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link + * Shape#isUnknown()} must be {@code false}. + * + * @param lastDimension the dimension to append + * @return a new Shape with this Shape's dimensions followed by the given dimension, never null + */ + public Shape append(long lastDimension) { + long[] newDimensions = new long[dimensionSizes.length + 1]; + newDimensions[newDimensions.length - 1] = lastDimension; + System.arraycopy(dimensionSizes, 0, newDimensions, 0, dimensionSizes.length); + + return Shape.of(newDimensions); + } + + /** + * Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the + * other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code + * Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) } + * + * @param other another Shape, must not be {@code null}, must not be unknown + * @return A new Shape consisting of the given Shape's dimensions followed by this Shape's + * dimensions, never null + */ + public Shape prepend(Shape other) { + long[] newDimensions = new long[other.dimensionSizes.length + dimensionSizes.length]; + System.arraycopy(other.dimensionSizes, 0, newDimensions, 0, other.dimensionSizes.length); + System.arraycopy( + dimensionSizes, 0, newDimensions, other.dimensionSizes.length, dimensionSizes.length); + return Shape.of(newDimensions); + } + + /** + * Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the + * other Shape, {@link Shape#isUnknown()} must return false. E.g. @code + * Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) } + * + * @param other another Shape, must not be {@code null}, must not be unknown + * @return A new Shape consisting of this Shape's dimensions followed by the given Shape's + * dimensions + */ + public Shape append(Shape other) { + long[] newDimensions = new long[dimensionSizes.length + other.dimensionSizes.length]; + System.arraycopy(dimensionSizes, 0, newDimensions, 0, dimensionSizes.length); + System.arraycopy( + other.dimensionSizes, 0, newDimensions, dimensionSizes.length, other.dimensionSizes.length); + return Shape.of(newDimensions); + } + + private static long computeSize(long[] dimensionSizes) { + if (dimensionSizes == null) { + return UNKNOWN_SIZE; + } + long computedSize = 1L; + for (long dimensionSize : dimensionSizes) { + if (dimensionSize == UNKNOWN_SIZE) { + return UNKNOWN_SIZE; + } + computedSize *= dimensionSize; + } + return computedSize; + } + + /** + * Determines whether another shape is compatible with this one. + * + *

+ * + *

Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape + * that both shapes can represent. Thus, compatibility allows the shape inference code to reason + * about partially-defined shapes. For example: + * + *

    + *
  • Shape.unknown() is compatible with all shapes. + *
  • Shape(UNKNOWN_SIZE, UNKNOWN_SIZE) is compatible with all two-dimensional + * shapes, such as Shape(32, 784), and also Shape.unknown(). It is + * not compatible with, for example, Shape(UNKNOWN_SIZE) or + * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE). + *
  • Shape(32, UNKNOWN_SIZE) is compatible with all two-dimensional shapes with + * size 32 in the 0th dimension, and also Shape(UNKNOWN_SIZE, UNKNOWN_SIZE) and + * Shape.unknown(). It is not compatible with, for example, Shape(32) + * , Shape(32, UNKNOWN_SIZE, 1) or Shape(64, UNKNOWN_SIZE). + *
  • Shape(32, 784) is compatible with itself, and also + * Shape(32, UNKNOWN_SIZE), Shape(UNKNOWN_SIZE, 784), + * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE) and Shape.unknown(). It is not + * compatible with, for example, Shape(32, 1, 784) or Shape(UNKNOWN_SIZE) + * . + *
+ * + *

The compatibility relation is reflexive and symmetric, but not transitive. For example, + * Shape(32, 784) is compatible with Shape.unknown(), and + * Shape.unknown() is compatible with Shape(4, 4), but Shape(32, 784) + * is not compatible with Shape(4, 4). + * + *

Compatibility is not the same as broadcasting. Compatible shapes must have the same number + * of dimensions and for each dimension pair, one dimension has to equal the other dimensions or + * at least one of the dimensions in the pair has to be UNKNOWN_SIZE. + * + *

Broadcasting allows different dimensions, but paired dimensions have to either be equal, or + * one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape + * is "stretched" with dimensions of 1. + * + * @param shape The other shape + * @return true, if the two shapes are compatible. + */ + public boolean isCompatibleWith(Shape shape) { + if (!this.isUnknown() && !shape.isUnknown()) { + if (numDimensions() != shape.numDimensions()) { + return false; + } + for (int i = 0; i < numDimensions(); i++) { + if (!isCompatible(get(i), shape.get(i))) { + return false; + } + } + } + return true; + } + + /** + * Test to see if two shape dimensions are compatible. + * + *

The dimensions are compatible if either dimension is Shape.UNKNOWN_SIZE or both + * dimensions are equal + * + * @param dim the first dimension + * @param otherDim the second dimension + * @return true, if both dimensions are compatible + */ + public static boolean isCompatible(long dim, long otherDim) { + return dim == Shape.UNKNOWN_SIZE || otherDim == Shape.UNKNOWN_SIZE || dim == otherDim; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/Shaped.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/Shaped.java new file mode 100644 index 00000000000..244550bb4a7 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/Shaped.java @@ -0,0 +1,44 @@ +/* +Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +/** Any data container with a given {@link Shape}. */ +public interface Shaped { + + /** + * @return the shape of this container + */ + Shape shape(); + + /** + * @return the rank of this container + */ + default int rank() { + return shape().numDimensions(); + } + + /** + * Computes and returns the total size of this container, in number of values. + * + *

For example, given a 3x3x2 matrix, the return value will be 18. + * + * @return number of values in this element + */ + default long size() { + return shape().size(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java new file mode 100644 index 00000000000..1cf837cd15e --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/ShortNdArray.java @@ -0,0 +1,115 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.index.Index; + +/** An {@link NdArray} of shorts. */ +public interface ShortNdArray extends NdArray { + + /** + * Returns the short value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * ShortNdArray matrix = NdArrays.ofShorts(shape(2, 2));  // matrix rank = 2
+   * matrix.getShort(0, 1);  // succeeds, returns 0.0f
+   * matrix.getShort(0);  // throws IllegalRankException
+   *
+   * ShortNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.getShort();  // succeeds, returns 0.0f
+   * }
+ * + * @param coordinates coordinates of the scalar to resolve + * @return value of that scalar + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + short getShort(long... coordinates); + + /** + * Assigns the short value of the scalar found at the given coordinates. + * + *

To access the scalar element, the number of coordinates provided must be equal to the number + * of dimensions of this array (i.e. its rank). For example: + * + *

{@code
+   * ShortNdArray matrix = NdArrays.ofShorts(shape(2, 2));  // matrix rank = 2
+   * matrix.setShort(10.0f, 0, 1);  // succeeds
+   * matrix.setShort(10.0f, 0);  // throws IllegalRankException
+   *
+   * ShortNdArray scalar = matrix.get(0, 1);  // scalar rank = 0
+   * scalar.setShort(10.0f);  // succeeds
+   * }
+ * + * @param value value to assign + * @param coordinates coordinates of the scalar to assign + * @return this array + * @throws IndexOutOfBoundsException if some coordinates are outside the limits of their + * respective dimension + * @throws IllegalRankException if number of coordinates is not sufficient to access a scalar + * element + */ + ShortNdArray setShort(short value, long... coordinates); + + @Override + ShortNdArray withShape(Shape shape); + + @Override + ShortNdArray slice(Index... coordinates); + + @Override + ShortNdArray get(long... coordinates); + + @Override + ShortNdArray set(NdArray src, long... coordinates); + + @Override + default Short getObject(long... coordinates) { + return getShort(coordinates); + } + + @Override + default ShortNdArray setObject(Short value, long... coordinates) { + return setShort(value, coordinates); + } + + @Override + NdArraySequence elements(int dimensionIdx); + + @Override + NdArraySequence scalars(); + + @Override + ShortNdArray copyTo(NdArray dst); + + @Override + ShortNdArray copyTo(DataBuffer dst); + + ShortNdArray copyTo(ShortDataBuffer dst); + + @Override + ShortNdArray copyFrom(DataBuffer src); + + ShortNdArray copyFrom(ShortDataBuffer src); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/SparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/SparseNdArray.java new file mode 100644 index 00000000000..ab91d1c1448 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/SparseNdArray.java @@ -0,0 +1,50 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray; + +/** + * Interface for Sparse Arrays + * + * @param the type that the array contains + * @param the type of dense NdArray + */ +public interface SparseNdArray> extends NdArray { + /** + * Gets the Indices + * + *

Indices are a A 2-D long array of shape {@code [N, ndims]}, that specifies the indices of + * the elements in the sparse array that contain nonzero values (elements are zero-indexed). + * + *

For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * coordinates {@code [1,3]} and {@code [2,4]} have nonzero values. + * + * @return the Indices + */ + LongNdArray getIndices(); + + /** + * Gets the values. + * + *

Values are a 1-D array of any type and shape {@code [N]}, that supplies the values for each + * element in indices. + * + *

For example, given {@code indices=[[1,3], [2,4]]}, and {@code values=[18, 3.6]} specifies + * that element {@code [1,3]} of the sparse array has a value of {@code 18}, and element {@code + * [2,4]} of the sparse array has a value of {@code 3.6}. + * + * @return the values + */ + U getValues(); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java new file mode 100644 index 00000000000..3ec5ec77df8 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/StdArrays.java @@ -0,0 +1,3898 @@ +package org.tensorflow.ndarray; + +import java.lang.reflect.Array; +import org.tensorflow.ndarray.buffer.DataBuffers; + +/** Utility class for working with {@link NdArray} instances mixed with standard Java arrays. */ +public final class StdArrays { + + /** + * Copy an array of ints in a new {@link IntNdArray} + * + * @param array source array + * @return the {@code IntNdArray} copy + */ + public static IntNdArray ndCopyOf(int[] array) { + IntNdArray ndArray = NdArrays.ofInts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of ints in a new {@link IntNdArray} + * + * @param array source array + * @return the {@code IntNdArray} copy + */ + public static IntNdArray ndCopyOf(int[][] array) { + IntNdArray ndArray = NdArrays.ofInts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of ints in a new {@link IntNdArray} + * + * @param array source array + * @return the {@code IntNdArray} copy + */ + public static IntNdArray ndCopyOf(int[][][] array) { + IntNdArray ndArray = NdArrays.ofInts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of ints in a new {@link IntNdArray} + * + * @param array source array + * @return the {@code IntNdArray} copy + */ + public static IntNdArray ndCopyOf(int[][][][] array) { + IntNdArray ndArray = NdArrays.ofInts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of ints in a new {@link IntNdArray} + * + * @param array source array + * @return the {@code IntNdArray} copy + */ + public static IntNdArray ndCopyOf(int[][][][][] array) { + IntNdArray ndArray = NdArrays.ofInts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of ints in a new {@link IntNdArray} + * + * @param array source array + * @return the {@code IntNdArray} copy + */ + public static IntNdArray ndCopyOf(int[][][][][][] array) { + IntNdArray ndArray = NdArrays.ofInts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of longs in a new {@link LongNdArray} + * + * @param array source array + * @return the {@code LongNdArray} copy + */ + public static LongNdArray ndCopyOf(long[] array) { + LongNdArray ndArray = NdArrays.ofLongs(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of longs in a new {@link LongNdArray} + * + * @param array source array + * @return the {@code LongNdArray} copy + */ + public static LongNdArray ndCopyOf(long[][] array) { + LongNdArray ndArray = NdArrays.ofLongs(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of longs in a new {@link LongNdArray} + * + * @param array source array + * @return the {@code LongNdArray} copy + */ + public static LongNdArray ndCopyOf(long[][][] array) { + LongNdArray ndArray = NdArrays.ofLongs(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of longs in a new {@link LongNdArray} + * + * @param array source array + * @return the {@code LongNdArray} copy + */ + public static LongNdArray ndCopyOf(long[][][][] array) { + LongNdArray ndArray = NdArrays.ofLongs(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of longs in a new {@link LongNdArray} + * + * @param array source array + * @return the {@code LongNdArray} copy + */ + public static LongNdArray ndCopyOf(long[][][][][] array) { + LongNdArray ndArray = NdArrays.ofLongs(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of longs in a new {@link LongNdArray} + * + * @param array source array + * @return the {@code LongNdArray} copy + */ + public static LongNdArray ndCopyOf(long[][][][][][] array) { + LongNdArray ndArray = NdArrays.ofLongs(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of floats in a new {@link FloatNdArray} + * + * @param array source array + * @return the {@code FloatNdArray} copy + */ + public static FloatNdArray ndCopyOf(float[] array) { + FloatNdArray ndArray = NdArrays.ofFloats(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of floats in a new {@link FloatNdArray} + * + * @param array source array + * @return the {@code FloatNdArray} copy + */ + public static FloatNdArray ndCopyOf(float[][] array) { + FloatNdArray ndArray = NdArrays.ofFloats(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of floats in a new {@link FloatNdArray} + * + * @param array source array + * @return the {@code FloatNdArray} copy + */ + public static FloatNdArray ndCopyOf(float[][][] array) { + FloatNdArray ndArray = NdArrays.ofFloats(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of floats in a new {@link FloatNdArray} + * + * @param array source array + * @return the {@code FloatNdArray} copy + */ + public static FloatNdArray ndCopyOf(float[][][][] array) { + FloatNdArray ndArray = NdArrays.ofFloats(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of floats in a new {@link FloatNdArray} + * + * @param array source array + * @return the {@code FloatNdArray} copy + */ + public static FloatNdArray ndCopyOf(float[][][][][] array) { + FloatNdArray ndArray = NdArrays.ofFloats(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of floats in a new {@link FloatNdArray} + * + * @param array source array + * @return the {@code FloatNdArray} copy + */ + public static FloatNdArray ndCopyOf(float[][][][][][] array) { + FloatNdArray ndArray = NdArrays.ofFloats(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of doubles in a new {@link DoubleNdArray} + * + * @param array source array + * @return the {@code DoubleNdArray} copy + */ + public static DoubleNdArray ndCopyOf(double[] array) { + DoubleNdArray ndArray = NdArrays.ofDoubles(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of doubles in a new {@link DoubleNdArray} + * + * @param array source array + * @return the {@code DoubleNdArray} copy + */ + public static DoubleNdArray ndCopyOf(double[][] array) { + DoubleNdArray ndArray = NdArrays.ofDoubles(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of doubles in a new {@link DoubleNdArray} + * + * @param array source array + * @return the {@code DoubleNdArray} copy + */ + public static DoubleNdArray ndCopyOf(double[][][] array) { + DoubleNdArray ndArray = NdArrays.ofDoubles(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of doubles in a new {@link DoubleNdArray} + * + * @param array source array + * @return the {@code DoubleNdArray} copy + */ + public static DoubleNdArray ndCopyOf(double[][][][] array) { + DoubleNdArray ndArray = NdArrays.ofDoubles(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of doubles in a new {@link DoubleNdArray} + * + * @param array source array + * @return the {@code DoubleNdArray} copy + */ + public static DoubleNdArray ndCopyOf(double[][][][][] array) { + DoubleNdArray ndArray = NdArrays.ofDoubles(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of doubles in a new {@link DoubleNdArray} + * + * @param array source array + * @return the {@code DoubleNdArray} copy + */ + public static DoubleNdArray ndCopyOf(double[][][][][][] array) { + DoubleNdArray ndArray = NdArrays.ofDoubles(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of bytes in a new {@link ByteNdArray} + * + * @param array source array + * @return the {@code ByteNdArray} copy + */ + public static ByteNdArray ndCopyOf(byte[] array) { + ByteNdArray ndArray = NdArrays.ofBytes(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of bytes in a new {@link ByteNdArray} + * + * @param array source array + * @return the {@code ByteNdArray} copy + */ + public static ByteNdArray ndCopyOf(byte[][] array) { + ByteNdArray ndArray = NdArrays.ofBytes(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of bytes in a new {@link ByteNdArray} + * + * @param array source array + * @return the {@code ByteNdArray} copy + */ + public static ByteNdArray ndCopyOf(byte[][][] array) { + ByteNdArray ndArray = NdArrays.ofBytes(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of bytes in a new {@link ByteNdArray} + * + * @param array source array + * @return the {@code ByteNdArray} copy + */ + public static ByteNdArray ndCopyOf(byte[][][][] array) { + ByteNdArray ndArray = NdArrays.ofBytes(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of bytes in a new {@link ByteNdArray} + * + * @param array source array + * @return the {@code ByteNdArray} copy + */ + public static ByteNdArray ndCopyOf(byte[][][][][] array) { + ByteNdArray ndArray = NdArrays.ofBytes(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of bytes in a new {@link ByteNdArray} + * + * @param array source array + * @return the {@code ByteNdArray} copy + */ + public static ByteNdArray ndCopyOf(byte[][][][][][] array) { + ByteNdArray ndArray = NdArrays.ofBytes(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of shorts in a new {@link ShortNdArray} + * + * @param array source array + * @return the {@code ShortNdArray} copy + */ + public static ShortNdArray ndCopyOf(short[] array) { + ShortNdArray ndArray = NdArrays.ofShorts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of shorts in a new {@link ShortNdArray} + * + * @param array source array + * @return the {@code ShortNdArray} copy + */ + public static ShortNdArray ndCopyOf(short[][] array) { + ShortNdArray ndArray = NdArrays.ofShorts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of shorts in a new {@link ShortNdArray} + * + * @param array source array + * @return the {@code ShortNdArray} copy + */ + public static ShortNdArray ndCopyOf(short[][][] array) { + ShortNdArray ndArray = NdArrays.ofShorts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of shorts in a new {@link ShortNdArray} + * + * @param array source array + * @return the {@code ShortNdArray} copy + */ + public static ShortNdArray ndCopyOf(short[][][][] array) { + ShortNdArray ndArray = NdArrays.ofShorts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of shorts in a new {@link ShortNdArray} + * + * @param array source array + * @return the {@code ShortNdArray} copy + */ + public static ShortNdArray ndCopyOf(short[][][][][] array) { + ShortNdArray ndArray = NdArrays.ofShorts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of shorts in a new {@link ShortNdArray} + * + * @param array source array + * @return the {@code ShortNdArray} copy + */ + public static ShortNdArray ndCopyOf(short[][][][][][] array) { + ShortNdArray ndArray = NdArrays.ofShorts(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of booleans in a new {@link BooleanNdArray} + * + * @param array source array + * @return the {@code BooleanNdArray} copy + */ + public static BooleanNdArray ndCopyOf(boolean[] array) { + BooleanNdArray ndArray = NdArrays.ofBooleans(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of booleans in a new {@link BooleanNdArray} + * + * @param array source array + * @return the {@code BooleanNdArray} copy + */ + public static BooleanNdArray ndCopyOf(boolean[][] array) { + BooleanNdArray ndArray = NdArrays.ofBooleans(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of booleans in a new {@link BooleanNdArray} + * + * @param array source array + * @return the {@code BooleanNdArray} copy + */ + public static BooleanNdArray ndCopyOf(boolean[][][] array) { + BooleanNdArray ndArray = NdArrays.ofBooleans(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of booleans in a new {@link BooleanNdArray} + * + * @param array source array + * @return the {@code BooleanNdArray} copy + */ + public static BooleanNdArray ndCopyOf(boolean[][][][] array) { + BooleanNdArray ndArray = NdArrays.ofBooleans(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of booleans in a new {@link BooleanNdArray} + * + * @param array source array + * @return the {@code BooleanNdArray} copy + */ + public static BooleanNdArray ndCopyOf(boolean[][][][][] array) { + BooleanNdArray ndArray = NdArrays.ofBooleans(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of booleans in a new {@link BooleanNdArray} + * + * @param array source array + * @return the {@code BooleanNdArray} copy + */ + public static BooleanNdArray ndCopyOf(boolean[][][][][][] array) { + BooleanNdArray ndArray = NdArrays.ofBooleans(shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy an array of objects in a new {@link NdArray} + * + * @param array source array + * @param data type + * @return the {@code NdArray} copy + */ + public static NdArray ndCopyOf(T[] array) { + @SuppressWarnings("unchecked") + NdArray ndArray = NdArrays.ofObjects(componentTypeOf(array), shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 2-dimensional array of objects in a new {@link NdArray} + * + * @param array source array + * @param data type + * @return the {@code NdArray} copy + */ + public static NdArray ndCopyOf(T[][] array) { + @SuppressWarnings("unchecked") + NdArray ndArray = NdArrays.ofObjects(componentTypeOf(array), shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 3-dimensional array of objects in a new {@link NdArray} + * + * @param array source array + * @param data type + * @return the {@code NdArray} copy + */ + public static NdArray ndCopyOf(T[][][] array) { + @SuppressWarnings("unchecked") + NdArray ndArray = NdArrays.ofObjects(componentTypeOf(array), shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 4-dimensional array of objects in a new {@link NdArray} + * + * @param array source array + * @param data type + * @return the {@code NdArray} copy + */ + public static NdArray ndCopyOf(T[][][][] array) { + @SuppressWarnings("unchecked") + NdArray ndArray = NdArrays.ofObjects(componentTypeOf(array), shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 5-dimensional array of objects in a new {@link NdArray} + * + * @param array source array + * @param data type + * @return the {@code NdArray} copy + */ + public static NdArray ndCopyOf(T[][][][][] array) { + @SuppressWarnings("unchecked") + NdArray ndArray = NdArrays.ofObjects(componentTypeOf(array), shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a 6-dimensional array of objects in a new {@link NdArray} + * + * @param array source array + * @param data type + * @return the {@code NdArray} copy + */ + public static NdArray ndCopyOf(T[][][][][][] array) { + @SuppressWarnings("unchecked") + NdArray ndArray = NdArrays.ofObjects(componentTypeOf(array), shapeOf(array)); + copyTo(array, ndArray); + return ndArray; + } + + /** + * Copy a {@link IntNdArray} in a new 1-dimension standard array of ints + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static int[] array1dCopyOf(IntNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + int[] array = new int[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link IntNdArray} in a new 2-dimension standard array of ints + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static int[][] array2dCopyOf(IntNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + int[][] array = new int[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link IntNdArray} in a new 3-dimension standard array of ints + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static int[][][] array3dCopyOf(IntNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + int[][][] array = new int[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link IntNdArray} in a new 4-dimension standard array of ints + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static int[][][][] array4dCopyOf(IntNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + int[][][][] array = new int[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link IntNdArray} in a new 5-dimension standard array of ints + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static int[][][][][] array5dCopyOf(IntNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + int[][][][][] array = new int[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link IntNdArray} in a new 6-dimension standard array of ints + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static int[][][][][][] array6dCopyOf(IntNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + int[][][][][][] array = new int[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link LongNdArray} in a new 1-dimension standard array of longs + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static long[] array1dCopyOf(LongNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + long[] array = new long[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link LongNdArray} in a new 2-dimension standard array of longs + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static long[][] array2dCopyOf(LongNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + long[][] array = new long[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link LongNdArray} in a new 3-dimension standard array of longs + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static long[][][] array3dCopyOf(LongNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + long[][][] array = new long[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link LongNdArray} in a new 4-dimension standard array of longs + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static long[][][][] array4dCopyOf(LongNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + long[][][][] array = new long[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link LongNdArray} in a new 5-dimension standard array of longs + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static long[][][][][] array5dCopyOf(LongNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + long[][][][][] array = new long[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link LongNdArray} in a new 6-dimension standard array of longs + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static long[][][][][][] array6dCopyOf(LongNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + long[][][][][][] array = new long[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link FloatNdArray} in a new 1-dimension standard array of floats + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static float[] array1dCopyOf(FloatNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + float[] array = new float[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link FloatNdArray} in a new 2-dimension standard array of floats + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static float[][] array2dCopyOf(FloatNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + float[][] array = new float[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link FloatNdArray} in a new 3-dimension standard array of floats + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static float[][][] array3dCopyOf(FloatNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + float[][][] array = new float[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link FloatNdArray} in a new 4-dimension standard array of floats + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static float[][][][] array4dCopyOf(FloatNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + float[][][][] array = new float[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link FloatNdArray} in a new 5-dimension standard array of floats + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static float[][][][][] array5dCopyOf(FloatNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + float[][][][][] array = new float[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link FloatNdArray} in a new 6-dimension standard array of floats + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static float[][][][][][] array6dCopyOf(FloatNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + float[][][][][][] array = new float[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link DoubleNdArray} in a new 1-dimension standard array of doubles + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static double[] array1dCopyOf(DoubleNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + double[] array = new double[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link DoubleNdArray} in a new 2-dimension standard array of doubles + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static double[][] array2dCopyOf(DoubleNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + double[][] array = new double[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link DoubleNdArray} in a new 3-dimension standard array of doubles + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static double[][][] array3dCopyOf(DoubleNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + double[][][] array = new double[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link DoubleNdArray} in a new 4-dimension standard array of doubles + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static double[][][][] array4dCopyOf(DoubleNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + double[][][][] array = new double[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link DoubleNdArray} in a new 5-dimension standard array of doubles + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static double[][][][][] array5dCopyOf(DoubleNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + double[][][][][] array = new double[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link DoubleNdArray} in a new 6-dimension standard array of doubles + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static double[][][][][][] array6dCopyOf(DoubleNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + double[][][][][][] array = new double[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ByteNdArray} in a new 1-dimension standard array of bytes + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static byte[] array1dCopyOf(ByteNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + byte[] array = new byte[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ByteNdArray} in a new 2-dimension standard array of bytes + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static byte[][] array2dCopyOf(ByteNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + byte[][] array = new byte[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ByteNdArray} in a new 3-dimension standard array of bytes + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static byte[][][] array3dCopyOf(ByteNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + byte[][][] array = new byte[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ByteNdArray} in a new 4-dimension standard array of bytes + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static byte[][][][] array4dCopyOf(ByteNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + byte[][][][] array = new byte[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ByteNdArray} in a new 5-dimension standard array of bytes + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static byte[][][][][] array5dCopyOf(ByteNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + byte[][][][][] array = new byte[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ByteNdArray} in a new 6-dimension standard array of bytes + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static byte[][][][][][] array6dCopyOf(ByteNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + byte[][][][][][] array = new byte[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ShortNdArray} in a new 1-dimension standard array of shorts + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static short[] array1dCopyOf(ShortNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + short[] array = new short[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ShortNdArray} in a new 2-dimension standard array of shorts + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static short[][] array2dCopyOf(ShortNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + short[][] array = new short[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ShortNdArray} in a new 3-dimension standard array of shorts + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static short[][][] array3dCopyOf(ShortNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + short[][][] array = new short[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ShortNdArray} in a new 4-dimension standard array of shorts + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static short[][][][] array4dCopyOf(ShortNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + short[][][][] array = new short[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ShortNdArray} in a new 5-dimension standard array of shorts + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static short[][][][][] array5dCopyOf(ShortNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + short[][][][][] array = new short[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link ShortNdArray} in a new 6-dimension standard array of shorts + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static short[][][][][][] array6dCopyOf(ShortNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + short[][][][][][] array = new short[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link BooleanNdArray} in a new 1-dimension standard array of booleans + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static boolean[] array1dCopyOf(BooleanNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 1); + boolean[] array = new boolean[dims[0]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link BooleanNdArray} in a new 2-dimension standard array of booleans + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static boolean[][] array2dCopyOf(BooleanNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 2); + boolean[][] array = new boolean[dims[0]][dims[1]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link BooleanNdArray} in a new 3-dimension standard array of booleans + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static boolean[][][] array3dCopyOf(BooleanNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 3); + boolean[][][] array = new boolean[dims[0]][dims[1]][dims[2]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link BooleanNdArray} in a new 4-dimension standard array of booleans + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static boolean[][][][] array4dCopyOf(BooleanNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 4); + boolean[][][][] array = new boolean[dims[0]][dims[1]][dims[2]][dims[3]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link BooleanNdArray} in a new 5-dimension standard array of booleans + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static boolean[][][][][] array5dCopyOf(BooleanNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 5); + boolean[][][][][] array = new boolean[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link BooleanNdArray} in a new 6-dimension standard array of booleans + * + * @param ndArray source array + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static boolean[][][][][][] array6dCopyOf(BooleanNdArray ndArray) { + int[] dims = computeArrayDims(ndArray, 6); + boolean[][][][][][] array = new boolean[dims[0]][dims[1]][dims[2]][dims[3]][dims[4]][dims[5]]; + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link NdArray NdArray<T>} in a new 1-dimension standard array of objects + * + * @param ndArray source array + * @param objectType type of object + * @param data type + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-1 or has a shape that + * exceeds standard arrays limits + */ + public static T[] array1dCopyOf(NdArray ndArray, Class objectType) { + int[] dims = computeArrayDims(ndArray, 1); + T[] array = (T[]) Array.newInstance(objectType, dims[0]); + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link NdArray NdArray<T>} in a new 2-dimension standard array of objects + * + * @param ndArray source array + * @param objectType type of object + * @param data type + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-2 or has a shape that + * exceeds standard arrays limits + */ + public static T[][] array2dCopyOf(NdArray ndArray, Class objectType) { + int[] dims = computeArrayDims(ndArray, 2); + T[][] array = (T[][]) Array.newInstance(objectType, dims[0], dims[1]); + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link NdArray NdArray<T>} in a new 3-dimension standard array of objects + * + * @param ndArray source array + * @param objectType type of object + * @param data type + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-3 or has a shape that + * exceeds standard arrays limits + */ + public static T[][][] array3dCopyOf(NdArray ndArray, Class objectType) { + int[] dims = computeArrayDims(ndArray, 3); + T[][][] array = (T[][][]) Array.newInstance(objectType, dims[0], dims[1], dims[2]); + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link NdArray NdArray<T>} in a new 4-dimension standard array of objects + * + * @param ndArray source array + * @param objectType type of object + * @param data type + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-4 or has a shape that + * exceeds standard arrays limits + */ + public static T[][][][] array4dCopyOf(NdArray ndArray, Class objectType) { + int[] dims = computeArrayDims(ndArray, 4); + T[][][][] array = (T[][][][]) Array.newInstance(objectType, dims[0], dims[1], dims[2], dims[3]); + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link NdArray NdArray<T>} in a new 5-dimension standard array of objects + * + * @param ndArray source array + * @param objectType type of object + * @param data type + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-5 or has a shape that + * exceeds standard arrays limits + */ + public static T[][][][][] array5dCopyOf(NdArray ndArray, Class objectType) { + int[] dims = computeArrayDims(ndArray, 5); + T[][][][][] array = + (T[][][][][]) Array.newInstance(objectType, dims[0], dims[1], dims[2], dims[3], dims[4]); + copyFrom(ndArray, array); + return array; + } + + /** + * Copy a {@link NdArray NdArray<T>} in a new 6-dimension standard array of objects + * + * @param ndArray source array + * @param objectType type of object + * @param data type + * @return the array copy + * @throws IllegalArgumentException if {@code ndArray} is not of rank-6 or has a shape that + * exceeds standard arrays limits + */ + public static T[][][][][][] array6dCopyOf(NdArray ndArray, Class objectType) { + int[] dims = computeArrayDims(ndArray, 6); + T[][][][][][] array = + (T[][][][][][]) + Array.newInstance(objectType, dims[0], dims[1], dims[2], dims[3], dims[4], dims[5]); + copyFrom(ndArray, array); + return array; + } + + /** + * Copy an array of ints into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(int[] src, IntNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of ints into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(int[][] src, IntNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of ints into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(int[][][] src, IntNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of ints into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(int[][][][] src, IntNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of ints into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(int[][][][][] src, IntNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of ints into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(int[][][][][][] src, IntNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of longs into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(long[] src, LongNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of longs into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(long[][] src, LongNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of longs into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(long[][][] src, LongNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of longs into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(long[][][][] src, LongNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of longs into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(long[][][][][] src, LongNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of longs into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(long[][][][][][] src, LongNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of floats into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(float[] src, FloatNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of floats into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(float[][] src, FloatNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of floats into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(float[][][] src, FloatNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of floats into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(float[][][][] src, FloatNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of floats into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(float[][][][][] src, FloatNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of floats into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(float[][][][][][] src, FloatNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of doubles into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(double[] src, DoubleNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of doubles into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(double[][] src, DoubleNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of doubles into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(double[][][] src, DoubleNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of doubles into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(double[][][][] src, DoubleNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of doubles into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(double[][][][][] src, DoubleNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of doubles into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(double[][][][][][] src, DoubleNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of bytes into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(byte[] src, ByteNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of bytes into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(byte[][] src, ByteNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of bytes into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(byte[][][] src, ByteNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of bytes into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(byte[][][][] src, ByteNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of bytes into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(byte[][][][][] src, ByteNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of bytes into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(byte[][][][][][] src, ByteNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of shorts into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(short[] src, ShortNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of shorts into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(short[][] src, ShortNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of shorts into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(short[][][] src, ShortNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of shorts into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(short[][][][] src, ShortNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of shorts into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(short[][][][][] src, ShortNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of shorts into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(short[][][][][][] src, ShortNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of booleans into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(boolean[] src, BooleanNdArray dst) { + NdArrays.vectorOf(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of booleans into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(boolean[][] src, BooleanNdArray dst) { + dst.elements(0).forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of booleans into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(boolean[][][] src, BooleanNdArray dst) { + dst.elements(1) + .forEachIndexed((idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of booleans into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(boolean[][][][] src, BooleanNdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of booleans into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(boolean[][][][][] src, BooleanNdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf(src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of booleans into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(boolean[][][][][][] src, BooleanNdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOf( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy an array of objects into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-1 array + * @param data type + * @throws IllegalArgumentException if {@code dst} is not of rank-1 or has an incompatible shape + * with the source array + */ + public static void copyTo(T[] src, NdArray dst) { + NdArrays.vectorOfObjects(src).copyTo(dst); + } + + /** + * Copy a 2-dimensional array of objects into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-2 array + * @param data type + * @throws IllegalArgumentException if {@code dst} is not of rank-2 or has an incompatible shape + * with the source array + */ + public static void copyTo(T[][] src, NdArray dst) { + dst.elements(0) + .forEachIndexed((idx, e) -> NdArrays.vectorOfObjects(src[(int) idx[0]]).copyTo(e)); + } + + /** + * Copy a 3-dimensional array of objects into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-3 array + * @param data type + * @throws IllegalArgumentException if {@code dst} is not of rank-3 or has an incompatible shape + * with the source array + */ + public static void copyTo(T[][][] src, NdArray dst) { + dst.elements(1) + .forEachIndexed( + (idx, e) -> NdArrays.vectorOfObjects(src[(int) idx[0]][(int) idx[1]]).copyTo(e)); + } + + /** + * Copy a 4-dimensional array of objects into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-4 array + * @param data type + * @throws IllegalArgumentException if {@code dst} is not of rank-4 or has an incompatible shape + * with the source array + */ + public static void copyTo(T[][][][] src, NdArray dst) { + dst.elements(2) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOfObjects(src[(int) idx[0]][(int) idx[1]][(int) idx[2]]).copyTo(e)); + } + + /** + * Copy a 5-dimensional array of objects into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-5 array + * @param data type + * @throws IllegalArgumentException if {@code dst} is not of rank-5 or has an incompatible shape + * with the source array + */ + public static void copyTo(T[][][][][] src, NdArray dst) { + dst.elements(3) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOfObjects( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]]) + .copyTo(e)); + } + + /** + * Copy a 6-dimensional array of objects into the {@code dst} {@link NdArray} + * + * @param src source array + * @param dst destination rank-6 array + * @param data type + * @throws IllegalArgumentException if {@code dst} is not of rank-6 or has an incompatible shape + * with the source array + */ + public static void copyTo(T[][][][][][] src, NdArray dst) { + dst.elements(4) + .forEachIndexed( + (idx, e) -> + NdArrays.vectorOfObjects( + src[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]]) + .copyTo(e)); + } + + /** + * Copy a {@link NdArray} to an array of ints + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(IntNdArray src, int[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of ints + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(IntNdArray src, int[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of ints + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(IntNdArray src, int[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of ints + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(IntNdArray src, int[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of ints + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(IntNdArray src, int[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of ints + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(IntNdArray src, int[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of longs + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(LongNdArray src, long[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of longs + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(LongNdArray src, long[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of longs + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(LongNdArray src, long[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of longs + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(LongNdArray src, long[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of longs + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(LongNdArray src, long[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of longs + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(LongNdArray src, long[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of floats + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(FloatNdArray src, float[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of floats + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(FloatNdArray src, float[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of floats + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(FloatNdArray src, float[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of floats + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(FloatNdArray src, float[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of floats + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(FloatNdArray src, float[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of floats + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(FloatNdArray src, float[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of doubles + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(DoubleNdArray src, double[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of doubles + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(DoubleNdArray src, double[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of doubles + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(DoubleNdArray src, double[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of doubles + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(DoubleNdArray src, double[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of doubles + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(DoubleNdArray src, double[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of doubles + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(DoubleNdArray src, double[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of bytes + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ByteNdArray src, byte[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of bytes + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ByteNdArray src, byte[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of bytes + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ByteNdArray src, byte[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of bytes + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ByteNdArray src, byte[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of bytes + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ByteNdArray src, byte[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of bytes + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ByteNdArray src, byte[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of shorts + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ShortNdArray src, short[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of shorts + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ShortNdArray src, short[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of shorts + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ShortNdArray src, short[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of shorts + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ShortNdArray src, short[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of shorts + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ShortNdArray src, short[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of shorts + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(ShortNdArray src, short[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of booleans. + * + * @param src source rank-1 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(BooleanNdArray src, boolean[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of booleans + * + * @param src source rank-2 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(BooleanNdArray src, boolean[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of booleans + * + * @param src source rank-3 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(BooleanNdArray src, boolean[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of booleans + * + * @param src source rank-4 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(BooleanNdArray src, boolean[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of booleans + * + * @param src source rank-5 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(BooleanNdArray src, boolean[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of booleans + * + * @param src source rank-6 array + * @param dst destination array + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(BooleanNdArray src, boolean[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Copy a {@link NdArray} to an array of objects + * + * @param src source rank-1 array + * @param dst destination array + * @param data type + * @throws IllegalArgumentException if {@code src} is not of rank-1 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(NdArray src, T[] dst) { + if (src.rank() != 1) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + if (src.size() > dst.length) { + throw new ArrayIndexOutOfBoundsException(String.valueOf(src.size()) + " > " + dst.length); + } + src.copyTo(DataBuffers.of(dst, false, false)); + } + + /** + * Copy a {@link NdArray} to a 2-dimensional array of objects + * + * @param src source rank-2 array + * @param dst destination array + * @param data type + * @throws IllegalArgumentException if {@code src} is not of rank-2 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(NdArray src, T[][] dst) { + if (src.rank() != 2) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(0).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]])); + } + + /** + * Copy a {@link NdArray} to a 3-dimensional array of objects + * + * @param src source rank-3 array + * @param dst destination array + * @param data type + * @throws IllegalArgumentException if {@code src} is not of rank-3 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(NdArray src, T[][][] dst) { + if (src.rank() != 3) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(1).forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]])); + } + + /** + * Copy a {@link NdArray} to a 4-dimensional array of objects + * + * @param src source rank-4 array + * @param dst destination array + * @param data type + * @throws IllegalArgumentException if {@code src} is not of rank-4 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(NdArray src, T[][][][] dst) { + if (src.rank() != 4) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(2) + .forEachIndexed((idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]])); + } + + /** + * Copy a {@link NdArray} to a 5-dimensional array of objects + * + * @param src source rank-5 array + * @param dst destination array + * @param data type + * @throws IllegalArgumentException if {@code src} is not of rank-5 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(NdArray src, T[][][][][] dst) { + if (src.rank() != 5) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(3) + .forEachIndexed( + (idx, e) -> copyFrom(e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]])); + } + + /** + * Copy a {@link NdArray} to a 6-dimensional array of objects + * + * @param src source rank-6 array + * @param dst destination array + * @param data type + * @throws IllegalArgumentException if {@code src} is not of rank-6 + * @throws ArrayIndexOutOfBoundsException if not all elements of {@code src} can fit it the + * destination array + */ + public static void copyFrom(NdArray src, T[][][][][][] dst) { + if (src.rank() != 6) { + throw new IllegalArgumentException( + "Array cannot be copied from NdArray of rank " + src.rank()); + } + src.elements(4) + .forEachIndexed( + (idx, e) -> + copyFrom( + e, dst[(int) idx[0]][(int) idx[1]][(int) idx[2]][(int) idx[3]][(int) idx[4]])); + } + + /** + * Compute the shape of an int array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(int[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional int array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(int[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional int array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(int[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional int array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(int[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional int array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(int[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional int array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(int[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of a long array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(long[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional long array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(long[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional long array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(long[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional long array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(long[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional long array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(long[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional long array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(long[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of a float array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(float[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional float array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(float[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional float array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(float[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional float array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(float[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional float array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(float[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional float array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(float[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of a double array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(double[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional double array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(double[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional double array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(double[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional double array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(double[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional double array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(double[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional double array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(double[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of a byte array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(byte[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional byte array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(byte[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional byte array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(byte[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional byte array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(byte[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional byte array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(byte[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional byte array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(byte[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of a short array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(short[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional short array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(short[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional short array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(short[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional short array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(short[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional short array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(short[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional short array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(short[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of a boolean array. + * + * @param array 1D array + * @return shape of the array + */ + public static Shape shapeOf(boolean[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional boolean array. + * + * @param array 2D array + * @return shape of the array + */ + public static Shape shapeOf(boolean[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional boolean array. + * + * @param array 3D array + * @return shape of the array + */ + public static Shape shapeOf(boolean[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional boolean array. + * + * @param array 4D array + * @return shape of the array + */ + public static Shape shapeOf(boolean[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional boolean array. + * + * @param array 5D array + * @return shape of the array + */ + public static Shape shapeOf(boolean[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional boolean array. + * + * @param array 6D array + * @return shape of the array + */ + public static Shape shapeOf(boolean[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + /** + * Compute the shape of an object array. + * + * @param array 1D array + * @param data type + * @return shape of the array + */ + public static Shape shapeOf(T[] array) { + return Shape.of(array.length); + } + + /** + * Compute the shape of a 2-dimensional object array. + * + * @param array 2D array + * @param data type + * @return shape of the array + */ + public static Shape shapeOf(T[][] array) { + return Shape.of(computeShape(array, new long[2])); + } + + /** + * Compute the shape of a 3-dimensional object array. + * + * @param array 3D array + * @param data type + * @return shape of the array + */ + public static Shape shapeOf(T[][][] array) { + return Shape.of(computeShape(array, new long[3])); + } + + /** + * Compute the shape of a 4-dimensional object array. + * + * @param array 4D array + * @param data type + * @return shape of the array + */ + public static Shape shapeOf(T[][][][] array) { + return Shape.of(computeShape(array, new long[4])); + } + + /** + * Compute the shape of a 5-dimensional object array. + * + * @param array 5D array + * @param data type + * @return shape of the array + */ + public static Shape shapeOf(T[][][][][] array) { + return Shape.of(computeShape(array, new long[5])); + } + + /** + * Compute the shape of a 6-dimensional object array. + * + * @param array 6D array + * @param data type + * @return shape of the array + */ + public static Shape shapeOf(T[][][][][][] array) { + return Shape.of(computeShape(array, new long[6])); + } + + private static void dimSize(int arrayLength, long[] shape, int dimIdx) { + if (shape[dimIdx] == 0) { + shape[dimIdx] = arrayLength; + } else if (shape[dimIdx] != arrayLength) { + shape[dimIdx] = Shape.UNKNOWN_SIZE; + } + } + + private static long[] computeShape(int[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(int[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(int[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(int[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(int[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(long[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(long[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(long[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(long[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(long[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(float[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(float[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(float[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(float[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(float[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(double[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(double[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(double[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(double[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(double[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(byte[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(byte[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(byte[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(byte[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(byte[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(short[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(short[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(short[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(short[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(short[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(boolean[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(boolean[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(boolean[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(boolean[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(boolean[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(T[][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 2); + for (int i = 0; i < array.length; ++i) { + if (array[i] == null) { + throw new IllegalStateException("One of the subarray is null"); + } + dimSize(array[i].length, shape, shape.length - 1); + } + return shape; + } + + private static long[] computeShape(T[][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 3); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(T[][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 4); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(T[][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 5); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static long[] computeShape(T[][][][][][] array, long[] shape) { + if (array == null) { + throw new IllegalStateException("The array or one of its subarray is null"); + } + dimSize(array.length, shape, shape.length - 6); + for (int i = 0; i < array.length; ++i) { + computeShape(array[i], shape); + } + return shape; + } + + private static Class componentTypeOf(Object array) { + Class componentType = array.getClass().getComponentType(); + while (componentType.isArray()) { + componentType = componentType.getComponentType(); + } + return (Class) componentType; + } + + private static int[] computeArrayDims(NdArray ndArray, int expectedRank) { + Shape shape = ndArray.shape(); + if (shape.numDimensions() != expectedRank) { + throw new IllegalArgumentException("NdArray must be of rank " + expectedRank); + } + int[] arrayShape = new int[expectedRank]; + for (int i = 0; i < expectedRank; ++i) { + long dimSize = shape.get(i); + if (dimSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Dimension " + i + " is too large to fit in a standard array (" + shape.get(i) + ")"); + } + arrayShape[i] = (int) dimSize; + } + return arrayShape; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/BooleanDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/BooleanDataBuffer.java new file mode 100644 index 00000000000..f1a1dc3cacf --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/BooleanDataBuffer.java @@ -0,0 +1,159 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of booleans. */ +public interface BooleanDataBuffer extends DataBuffer { + + /** + * Reads the boolean at the given index. + * + * @param index the index from which the float will be read + * @return the boolean at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + boolean getBoolean(long index); + + /** + * Writes the given boolean into this buffer at the given index. + * + * @param value the boolean to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + BooleanDataBuffer setBoolean(boolean value, long index); + + /** + * Bulk get method, using boolean arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default BooleanDataBuffer read(boolean[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using boolean arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + BooleanDataBuffer read(boolean[] dst, int offset, int length); + + /** + * Bulk put method, using boolean arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default BooleanDataBuffer write(boolean[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using boolean arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + BooleanDataBuffer write(boolean[] src, int offset, int length); + + @Override + default Boolean getObject(long index) { + return getBoolean(index); + } + + @Override + default BooleanDataBuffer setObject(Boolean value, long index) { + return setBoolean(value, index); + } + + @Override + BooleanDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default BooleanDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default BooleanDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + BooleanDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/ByteDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/ByteDataBuffer.java new file mode 100644 index 00000000000..72610170fb7 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/ByteDataBuffer.java @@ -0,0 +1,225 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of bytes. */ +public interface ByteDataBuffer extends DataBuffer { + + /** + * Reads the byte at the given index. + * + * @param index the index from which the float will be read + * @return the byte at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + byte getByte(long index); + + /** + * Writes the given byte into this buffer at the given index. + * + * @param value the byte to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + ByteDataBuffer setByte(byte value, long index); + + /** + * Bulk get method, using byte arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default ByteDataBuffer read(byte[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using byte arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + ByteDataBuffer read(byte[] dst, int offset, int length); + + /** + * Bulk put method, using byte arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default ByteDataBuffer write(byte[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using byte arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + ByteDataBuffer write(byte[] src, int offset, int length); + + /** + * Return this byte buffer as a buffer of ints. + * + *

The returned buffer provides a different view on the same memory as the original byte + * buffer, meaning that changing a value in one will affect the other. + * + * @return this buffer as a {@link IntDataBuffer} + * @throws IllegalStateException if this buffer cannot be converted + */ + IntDataBuffer asInts(); + + /** + * Return this byte buffer as a buffer of shorts. + * + *

The returned buffer provides a different view on the same memory as the original byte + * buffer, meaning that changing a value in one will affect the other. + * + * @return this buffer as a {@link ShortDataBuffer} + * @throws IllegalStateException if this buffer cannot be converted + */ + ShortDataBuffer asShorts(); + + /** + * Return this byte buffer as a buffer of longs. + * + *

The returned buffer provides a different view on the same memory as the original byte + * buffer, meaning that changing a value in one will affect the other. + * + * @return this buffer as a {@link LongDataBuffer} + * @throws IllegalStateException if this buffer cannot be converted + */ + LongDataBuffer asLongs(); + + /** + * Return this byte buffer as a buffer of floats. + * + *

The returned buffer provides a different view on the same memory as the original byte + * buffer, meaning that changing a value in one will affect the other. + * + * @return this buffer as a {@link FloatDataBuffer} + * @throws IllegalStateException if this buffer cannot be converted + */ + FloatDataBuffer asFloats(); + + /** + * Return this byte buffer as a buffer of doubles. + * + *

The returned buffer provides a different view on the same memory as the original byte + * buffer, meaning that changing a value in one will affect the other. + * + * @return this buffer as a {@link DoubleDataBuffer} + * @throws IllegalStateException if this buffer cannot be converted + */ + DoubleDataBuffer asDoubles(); + + /** + * Return this byte buffer as a buffer of booleans. + * + *

The returned buffer provides a different view on the same memory as the original byte + * buffer, meaning that changing a value in one will affect the other. + * + * @return this buffer as a {@link BooleanDataBuffer} + * @throws IllegalStateException if this buffer cannot be converted + */ + BooleanDataBuffer asBooleans(); + + @Override + default Byte getObject(long index) { + return getByte(index); + } + + @Override + default ByteDataBuffer setObject(Byte value, long index) { + return setByte(value, index); + } + + @Override + ByteDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default ByteDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default ByteDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + ByteDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBuffer.java new file mode 100644 index 00000000000..e21c45d345c --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBuffer.java @@ -0,0 +1,322 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** + * A container of data of a specific type. + * + *

Instances of {@code DataBuffer} map native or heap memory segments to a linear view that + * supports: + * + *

    + *
  • 64-bits indexing, allowing to work with buffer larger than 231 bytes + *
  • Storage of object of any types and not only primitives + *
  • Generic types allows to work directly with boxed types as well, which does not require + * explicit buffer types as with the standard JDK buffers. + *
+ * + * It is important to note that there is no guarantee the memory managed by a {@code DataBuffer} is + * linear, specially when dealing with non-primitive types or large buffers. + * + * @param type of data stored in this buffer + */ +public interface DataBuffer { + + /** + * Size of the buffer, in elements. + * + *

For exemple, in case of a byte buffer, this value is equal to the number of bytes this + * buffer can hold. For an integer buffer, it is equal to the number of integers, therefore the + * size in bytes of this buffer is {@code size() * Integer.BYTES}. + * + * @return the buffer size + */ + long size(); + + /** + * Tells whether or not this buffer is backed by an accessible array. + * + * @return true if, and only if, this buffer is read-only + */ + boolean isReadOnly(); + + /** + * Reads the value at the given index. + * + *

Important: Usage of this method should be limited to buffers of non-primitive types + * or when the data type is not deterministically known by the caller. In any other case, prefer + * the usage of its primitive variant which will significantly improve performances (e.g. {@code + * IntDataBuffer.getInt(idx)} + * + * @param index the index from which the float will be read + * @return the value at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + T getObject(long index); + + /** + * Writes the given value into this buffer at the given index. + * + *

Important: Usage of this method should be limited to buffers of non-primitive types + * or when the data type is not deterministically known by the caller. In any other case, prefer + * the usage of its primitive variant which will significantly improve performances (e.g. {@code + * IntDataBuffer.setInt(idx)} + * + * @param value the value to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + DataBuffer setObject(T value, long index); + + /** + * Read the references of the objects in this buffer into the destination array. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default DataBuffer read(T[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Read the references of the objects in this buffer into the destination array. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + DataBuffer read(T[] dst, int offset, int length); + + /** + * Write the references of the objects in the source array into this buffer. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default DataBuffer write(T[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using int arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + DataBuffer write(T[] src, int offset, int length); + + /** + * Write the references of the objects in the source array into this buffer. + * + *

If there are more values to copy than the destination buffer size, i.e. {@code size > + * dst.size()}, then no values are transferred and a BufferOverflowException is thrown. On the + * other hand, if there are more values to copy that the source buffer size, i.e. {@code > + * src.size()}, then a BufferUnderfloatException is thrown. + * + *

Otherwise, this method copies {@code n = size} values from this buffer into the destination + * buffer. + * + * @param dst the destination buffer into which values are copied; must not be this buffer + * @param size number of values to copy to the destination buffer + * @return this buffer + * @throws IllegalArgumentException if the destination buffer is this buffer + * @throws ReadOnlyBufferException if the destination buffer is read-only + * @throws java.nio.BufferOverflowException if there is not enough space in destination buffer + * @throws java.nio.BufferUnderflowException if there are not enough values in the source buffer + */ + DataBuffer copyTo(DataBuffer dst, long size); + + /** + * Creates a new buffer whose content is a shared subsequence of this buffer's content, starting + * at the given index. + * + *

The index must not be greater than this buffer size. Changes to this buffer's content will + * be visible in the new buffer and vice versa. The new buffer will be read-only if, and only if, + * this buffer is read-only. + * + *

This call is equivalent to {@link #slice(long, long) slice(index, size() - index)} + * + * @param index index of the first value of the new buffer created, must not be greater than + * {@code size()} + * @return the new buffer + * @throws IllegalArgumentException if index do not pass validation checks + */ + default DataBuffer offset(long index) { + return slice(index, size() - index); + } + + /** + * Creates a new buffer whose content is a shared subsequence of this buffer's content, whose size + * is set to the given value. + * + *

The new size must not be greater than this buffer size. Changes to this buffer's content + * will be visible in the new buffer and vice versa. The new buffer will be read-only if, and only + * if, this buffer is read-only. + * + *

This call is equivalent to {@link #slice(long, long) slice(0, size)} + * + * @param size size of this new buffer + * @return the new buffer + * @throws IllegalArgumentException if index and/or size values do not pass validation checks + */ + default DataBuffer narrow(long size) { + return slice(0, size); + } + + /** + * Creates a new buffer whose content is a shared subsequence of this buffer's content, starting + * at the given index and of the given size. + * + *

The index plus the new size must not be greater than this buffer size. Changes to this + * buffer's content will be visible in the new buffer and vice versa. The new buffer will be + * read-only if, and only if, this buffer is read-only. + * + * @param index index of the first value of the new buffer created + * @param size size of this new buffer, must not be greater than {@code size()} + * @return the new buffer + * @throws IllegalArgumentException if size value do not pass validation checks + */ + DataBuffer slice(long index, long size); + + /** + * Creates a {@link DataBufferWindow} that provides a partial view of this buffer. + * + *

The created window has a fixed size and can {@link DataBufferWindow#slide(long) "slide"} + * along this buffer to provide different views of the data without allocating a new buffer + * instance, like {@link #offset(long)} does. This improves overall performance when this + * operation is repeated frequently. For example: + * + *

{@code
+   * IntDataBuffer bufferA = DataBuffers.ofInts(1024);
+   * // ... init buffer data
+   * IntDataBuffer bufferB = DataBuffers.ofInts(1, 2, 3, 4);
+   *
+   * // Return the index of the first occurrence of bufferB in bufferA using a sliding window
+   * DataBufferWindow windowA = bufferA.window(4);
+   * for (int i = 0; i < bufferA.size() - bufferB.size(); ++i) {
+   *     if (windowA.slideTo(i).buffer().equals(bufferB)) {
+   *         return i;
+   *     }
+   * }
+   * }
+ * + *

The returned object is stateful and is not thread-safe. + * + * @param size size of the window + * @return a new window that starts at the index 0 of this buffer + * @throws UnsupportedOperationException if this type of buffer does not support buffer windows + */ + default DataBufferWindow> window(long size) { + throw new UnsupportedOperationException(); + } + + /** + * Visits the backing storage of this buffer. + * + *

The buffer implementation is responsible of passing back a reference to the actual data + * storage to the provided visitor. The visitor does not have to handle all possible types of data + * storage and can override only methods for storage it is actually interested in. For any other + * type of storage, this call will fallback to {@link DataStorageVisitor#fallback()} so the + * visitor can execute some generic routine if needed. + * + * @param visitor visits the data storage of this buffer + * @param type of value returned by the visitor + * @return the same value returned by the visitor + */ + default R accept(DataStorageVisitor visitor) { + return visitor.fallback(); + } + + /** + * Checks equality between data buffers. + * + *

A data buffer is equal to another object if this object is another {@link DataBuffer} of the + * same size, type and the elements are equal and in the same order. For example: + * + *

{@code
+   * IntDataBuffer buffer = DataBuffers.of(1, 2, 3);
+   *
+   * assertEquals(buffer, DataBuffers.of(1, 2, 3));  // true
+   * assertEquals(buffer, DataBuffers.ofObjects(1, 2, 3));  // true, as Integers are equal to ints
+   * assertNotEquals(buffer, DataBuffers.of(1, 2, 3, 0));  // false, different sizes
+   * assertNotEquals(buffer, DataBuffers.of(1, 3, 2));  // false, different order
+   * assertNotEquals(buffer, DataBuffers.of(1L, 2L, 3L));  // false, different types
+   * }
+ * + *

Note that the computation required to verify equality between two buffers can be expensive + * in some cases and therefore, it is recommended to not use this method in a critical path where + * performances matter. + * + * @param obj object to compare this buffer with + * @return true if this buffer is equal to the provided object + */ + @Override + boolean equals(Object obj); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBufferWindow.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBufferWindow.java new file mode 100644 index 00000000000..c153200fd5f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBufferWindow.java @@ -0,0 +1,88 @@ +package org.tensorflow.ndarray.buffer; + +/** + * A mutable container for viewing part of a {@link DataBuffer}. + * + *

Data buffer windows have a fixed size and can {@link DataBufferWindow#slide(long) "slide"} + * along a buffer to provide different views of the data without allocating a new buffer instance, + * like {@link DataBuffer#offset(long)} does. This improves overall performance when this operation + * is repeated frequently. For example: + * + *

{@code
+ * IntDataBuffer bufferA = DataBuffers.ofInts(1024);
+ * // ... init buffer data
+ * IntDataBuffer bufferB = DataBuffers.ofInts(1, 2, 3, 4);
+ *
+ * // Return the index of the first occurrence of bufferB in bufferA using a sliding window
+ * DataBufferWindow windowA = bufferA.window(4);
+ * for (int i = 0; i < bufferA.size() - bufferB.size(); ++i) {
+ *     if (windowA.slideTo(i).buffer().equals(bufferB)) {
+ *         return i;
+ *     }
+ * }
+ * }
+ * + *

{@code DataBufferWindow} instances are stateful and not thread-safe. + * + * @param the type of buffer being viewed + */ +public interface DataBufferWindow> { + + /** Returns the current offset of this window in the original buffer. */ + long offset(); + + /** Returns the size of this buffer window. */ + long size(); + + /** + * Moves the window at the given position in the original buffer. + * + *

The size of the window remains the same and its offset is set to {@code index}, so that + * accessing the value of {@link #buffer()} at index {@code x} will return the value at {@code + * index + x} in the original buffer. + * + * @param index new offset for this window + * @return this instance + * @throws IndexOutOfBoundsException if the window cannot be slid because it goes beyond the + * original buffer limits + */ + DataBufferWindow slideTo(long index); + + /** + * Moves the window of {@code step} elements in the original buffer. + * + *

The size of the window remains the same and its offset is set to {@code offset() + step}. If + * {@code step} is positive, then the window will slide forward. If it is negative, it will slide + * backward. + * + * @param step value to add to the current offset of this window + * @return this instance + * @throws IndexOutOfBoundsException if the window cannot be slid because it goes beyond the + * original buffer limits + */ + DataBufferWindow slide(long step); + + /** + * Returns the buffer backing this window. + * + *

Each window instance has it's own buffer providing a view onto the original {@link + * DataBuffer}. The buffers are mutated when the window slides to different offsets. For example: + * + *

{@code
+   * IntDataBuffer buffer = DataBuffers.of(0, 1, 2, 3);
+   * DataBufferWindow window = buffer.window(0, 2);
+   *
+   * IntDataBuffer windowBuffer = window.buffer();
+   * assertEquals(0, windowBuffer.getInt(0));
+   * assertEquals(1, windowBuffer.getInt(1));
+   *
+   * window.slideTo(2);
+   * assertEquals(2, windowBuffer.getInt(0));
+   * assertEquals(3, windowBuffer.getInt(1));
+   * assertSame(windowBuffer, window.buffer());
+   * }
+ * + * @return this window's buffer + */ + B buffer(); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBuffers.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBuffers.java new file mode 100644 index 00000000000..cc60aa8a68a --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataBuffers.java @@ -0,0 +1,463 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; +import java.util.Arrays; +import java.util.BitSet; +import org.tensorflow.ndarray.impl.buffer.Validator; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +/** Helper class for creating {@code DataBuffer} instances. */ +public final class DataBuffers { + + /** + * Creates a buffer of bytes that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static ByteDataBuffer ofBytes(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new byte[(int) size], false); + } + return NioDataBufferFactory.create(ByteBuffer.allocate((int) size)); + } + + /** + * Creates a buffer of longs that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static LongDataBuffer ofLongs(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new long[(int) size], false); + } + return NioDataBufferFactory.create(LongBuffer.allocate((int) size)); + } + + /** + * Creates a buffer of integers that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static IntDataBuffer ofInts(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new int[(int) size], false); + } + return NioDataBufferFactory.create(IntBuffer.allocate((int) size)); + } + + /** + * Creates a buffer of shorts that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static ShortDataBuffer ofShorts(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new short[(int) size], false); + } + return NioDataBufferFactory.create(ShortBuffer.allocate((int) size)); + } + + /** + * Creates a buffer of doubles that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static DoubleDataBuffer ofDoubles(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new double[(int) size], false); + } + return NioDataBufferFactory.create(DoubleBuffer.allocate((int) size)); + } + + /** + * Creates a buffer of floats that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static FloatDataBuffer ofFloats(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new float[(int) size], false); + } + return NioDataBufferFactory.create(FloatBuffer.allocate((int) size)); + } + + /** + * Creates a buffer of booleans that can store up to {@code size} values + * + * @param size size of the buffer to allocate + * @return a new buffer + */ + public static BooleanDataBuffer ofBooleans(long size) { + Validator.createArgs(size, MAX_32BITS); + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(new boolean[(int) size], false); + } + return MiscDataBufferFactory.create(new BitSet((int) size), size, false); + } + + /** + * Creates a buffer of references to objects of type {@code clazz` that can store up to `size} + * values. + * + * @param type the type of object stored in this buffer + * @param size size of the buffer to allocate + * @param data type + * @return a new buffer + */ + public static DataBuffer ofObjects(Class type, long size) { + Validator.createArgs(size, MAX_32BITS); + @SuppressWarnings("unchecked") + T[] array = (T[]) Array.newInstance(type, (int) size); + return MiscDataBufferFactory.create(array, false); + } + + /** + * Create a buffer from an array of floats into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(float[], boolean, boolean) of(values, false, + * false}} + * + * @param values float values + * @return a new buffer + */ + public static FloatDataBuffer of(float... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of bytes into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(byte[], boolean, boolean) of(values, false, + * false}} + * + * @param values byte values + * @return a new buffer + */ + public static ByteDataBuffer of(byte... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of longs into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(long[], boolean, boolean) of(values, false, + * false}} + * + * @param values long values + * @return a new buffer + */ + public static LongDataBuffer of(long... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of ints into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(int[], boolean, boolean) of(values, false, + * false}} + * + * @param values int values + * @return a new buffer + */ + public static IntDataBuffer of(int... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of shorts into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(short[], boolean, boolean) of(values, false, + * false}} + * + * @param values short values + * @return a new buffer + */ + public static ShortDataBuffer of(short... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of doubles into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(double[], boolean, boolean) of(array, false, + * false}} + * + * @param values double values + * @return a new buffer + */ + public static DoubleDataBuffer of(double... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of booleans into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(boolean[], boolean, boolean) of(values, false, + * false}} + * + * @param values booleans values + * @return a new buffer + */ + public static BooleanDataBuffer of(boolean... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of objects into a data buffer. + * + *

The returned buffer allows read and write operations and share the memory of the source + * array, which is equivalent to call {@link #of(Object[], boolean, boolean) of(values, false, + * false}} + * + * @param values objects values + * @param data type + * @return a new buffer + */ + @SafeVarargs + public static DataBuffer ofObjects(T... values) { + return of(values, false, false); + } + + /** + * Create a buffer from an array of floats into a data buffer. + * + * @param array array of floats + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static FloatDataBuffer of(float[] array, boolean readOnly, boolean makeCopy) { + float[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + FloatBuffer buf = FloatBuffer.wrap(bufferArray); + return NioDataBufferFactory.create(readOnly ? buf.asReadOnlyBuffer() : buf); + } + + /** + * Create a buffer from an array of bytes into a data buffer. + * + * @param array array of bytes + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static ByteDataBuffer of(byte[] array, boolean readOnly, boolean makeCopy) { + byte[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + ByteBuffer buf = ByteBuffer.wrap(bufferArray); + return NioDataBufferFactory.create(readOnly ? buf.asReadOnlyBuffer() : buf); + } + + /** + * Create a buffer from an array of longs into a data buffer. + * + * @param array array of longs + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static LongDataBuffer of(long[] array, boolean readOnly, boolean makeCopy) { + long[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + LongBuffer buf = LongBuffer.wrap(bufferArray); + return NioDataBufferFactory.create(readOnly ? buf.asReadOnlyBuffer() : buf); + } + + /** + * Create a buffer from an array of ints into a data buffer. + * + * @param array array of ints + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static IntDataBuffer of(int[] array, boolean readOnly, boolean makeCopy) { + int[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + IntBuffer buf = IntBuffer.wrap(bufferArray); + return NioDataBufferFactory.create(readOnly ? buf.asReadOnlyBuffer() : buf); + } + + /** + * Create a buffer from an array of shorts into a data buffer. + * + * @param array array of shorts + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static ShortDataBuffer of(short[] array, boolean readOnly, boolean makeCopy) { + short[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + ShortBuffer buf = ShortBuffer.wrap(bufferArray); + return NioDataBufferFactory.create(readOnly ? buf.asReadOnlyBuffer() : buf); + } + + /** + * Create a buffer from an array of doubles into a data buffer. + * + * @param array array of doubles + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static DoubleDataBuffer of(double[] array, boolean readOnly, boolean makeCopy) { + double[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + DoubleBuffer buf = DoubleBuffer.wrap(bufferArray); + return NioDataBufferFactory.create(readOnly ? buf.asReadOnlyBuffer() : buf); + } + + /** + * Create a buffer from an array of booleans into a data buffer. + * + * @param array array of booleans + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @return a new buffer + */ + public static BooleanDataBuffer of(boolean[] array, boolean readOnly, boolean makeCopy) { + boolean[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + if (RawDataBufferFactory.canBeUsed()) { + return RawDataBufferFactory.create(bufferArray, readOnly); + } + return MiscDataBufferFactory.create(bufferArray, readOnly); + } + + /** + * Create a buffer from an array of objects into a data buffer. + * + * @param array array of objects + * @param readOnly true if the buffer created must be read-only + * @param makeCopy true if the array must be copied, false will wrap the provided array + * @param data type + * @return a new buffer + */ + public static DataBuffer of(T[] array, boolean readOnly, boolean makeCopy) { + T[] bufferArray = makeCopy ? Arrays.copyOf(array, array.length) : array; + return MiscDataBufferFactory.create(bufferArray, readOnly); + } + + /** + * Wraps a JDK NIO {@link ByteBuffer} into a data buffer. + * + * @param buf buffer to wrap + * @return a new buffer + */ + public static ByteDataBuffer of(ByteBuffer buf) { + return NioDataBufferFactory.create(buf.duplicate()); + } + + /** + * Wraps a JDK NIO {@link IntBuffer} into a data buffer. + * + * @param buf buffer to wrap + * @return a new buffer + */ + public static IntDataBuffer of(IntBuffer buf) { + return NioDataBufferFactory.create(buf.duplicate()); + } + + /** + * Wraps a JDK NIO {@link ShortBuffer} into a data buffer. + * + * @param buf buffer to wrap + * @return a new buffer + */ + public static ShortDataBuffer of(ShortBuffer buf) { + return NioDataBufferFactory.create(buf.duplicate()); + } + + /** + * Wraps a JDK NIO {@link LongBuffer} into a data buffer. + * + * @param buf buffer to wrap + * @return a new buffer + */ + public static LongDataBuffer of(LongBuffer buf) { + return NioDataBufferFactory.create(buf.duplicate()); + } + + /** + * Wraps a JDK NIO {@link FloatBuffer} into a data buffer. + * + * @param buf buffer to wrap + * @return a new buffer + */ + public static FloatDataBuffer of(FloatBuffer buf) { + return NioDataBufferFactory.create(buf.duplicate()); + } + + /** + * Wraps a JDK NIO {@link DoubleBuffer} into a data buffer. + * + * @param buf buffer to wrap + * @return a new buffer + */ + public static DoubleDataBuffer of(DoubleBuffer buf) { + return NioDataBufferFactory.create(buf.duplicate()); + } + + /* + * The maximum size for a buffer of this type, i.e. the maximum number of bytes it can store. + *

+ * As the maximum size may vary depending on the JVM implementation and on the platform, this + * property returns a value that is safe for most of them. + */ + static long MAX_32BITS = Integer.MAX_VALUE - 10; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataStorageVisitor.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataStorageVisitor.java new file mode 100644 index 00000000000..fa6ef03e570 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DataStorageVisitor.java @@ -0,0 +1,147 @@ +package org.tensorflow.ndarray.buffer; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; +import java.util.BitSet; + +/** + * Visit the backing storage of {@link DataBuffer} instances. + * + * @param value type returned by the visitor + */ +public interface DataStorageVisitor { + + /** + * Visit the {@link ByteBuffer} backing a given instance of a {@link DataBuffer} + * + * @param buffer underlying buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(ByteBuffer buffer) { + return fallback(); + } + + /** + * Visit the {@link ShortBuffer} backing a given instance of a {@link DataBuffer} + * + * @param buffer underlying buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(ShortBuffer buffer) { + return fallback(); + } + + /** + * Visit the {@link IntBuffer} backing a given instance of a {@link DataBuffer} + * + * @param buffer underlying buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(IntBuffer buffer) { + return fallback(); + } + + /** + * Visit the {@link LongBuffer} backing a given instance of a {@link DataBuffer} + * + * @param buffer underlying buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(LongBuffer buffer) { + return fallback(); + } + + /** + * Visit the {@link FloatBuffer} backing a given instance of a {@link DataBuffer} + * + * @param buffer underlying buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(FloatBuffer buffer) { + return fallback(); + } + + /** + * Visit the {@link DoubleBuffer} backing a given instance of a {@link DataBuffer} + * + * @param buffer underlying buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(DoubleBuffer buffer) { + return fallback(); + } + + /** + * Visit the boolean array backing a given instance of a {@link DataBuffer} + * + * @param array underlying array + * @param offset offset of the buffer within the array + * @param length length of the buffer within the array + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(boolean[] array, int offset, int length) { + return fallback(); + } + + /** + * Visit the bit set backing a given instance of a {@link DataBuffer} + * + * @param bitSet underlying bit set + * @param offset offset of the buffer within the bit set + * @param numBits number of bits used to represent the buffer within the bit set + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(BitSet bitSet, int offset, long numBits) { + return fallback(); + } + + /** + * Visit the object array backing a given instance of a {@link DataBuffer} + * + * @param array underlying array + * @param offset offset of the buffer within the array + * @param length length of the buffer within the array + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(Object[] array, int offset, int length) { + return fallback(); + } + + /** + * Visit the raw memory segment of a given instance of a {@link DataBuffer} + * + * @param address native address of the buffer + * @param length length of the buffer + * @param scale number of bytes required to store a single value of this buffer + * @return any value + * @see DataBuffer#accept(DataStorageVisitor) + */ + default R visit(long address, long length, long scale) { + return fallback(); + } + + /** + * Fallback method called if the visitor implementation does not support the type of backing + * storage for a given {@link DataBuffer} + * + *

The implementor of this interface must override the {@code visit} methods for type of + * storage it supports. If {@link DataBuffer#accept(DataStorageVisitor)} is called on a buffer + * using a different type of storage, the invocation will fallback to this method. + * + * @return any value + */ + R fallback(); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DoubleDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DoubleDataBuffer.java new file mode 100644 index 00000000000..50367c38a06 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/DoubleDataBuffer.java @@ -0,0 +1,159 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of doubles. */ +public interface DoubleDataBuffer extends DataBuffer { + + /** + * Reads the double at the given index. + * + * @param index the index from which the float will be read + * @return the double at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + double getDouble(long index); + + /** + * Writes the given double into this buffer at the given index. + * + * @param value the double to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + DoubleDataBuffer setDouble(double value, long index); + + /** + * Bulk get method, using double arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default DoubleDataBuffer read(double[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using double arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + DoubleDataBuffer read(double[] dst, int offset, int length); + + /** + * Bulk put method, using double arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default DoubleDataBuffer write(double[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using double arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + DoubleDataBuffer write(double[] src, int offset, int length); + + @Override + default Double getObject(long index) { + return getDouble(index); + } + + @Override + default DoubleDataBuffer setObject(Double value, long index) { + return setDouble(value, index); + } + + @Override + DoubleDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default DoubleDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default DoubleDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + DoubleDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/FloatDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/FloatDataBuffer.java new file mode 100644 index 00000000000..45e389a559e --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/FloatDataBuffer.java @@ -0,0 +1,159 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of floats. */ +public interface FloatDataBuffer extends DataBuffer { + + /** + * Reads the float at the given index. + * + * @param index the index from which the float will be read + * @return the float at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + float getFloat(long index); + + /** + * Writes the given float into this buffer at the given index. + * + * @param value the float to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + FloatDataBuffer setFloat(float value, long index); + + /** + * Bulk get method, using float arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default FloatDataBuffer read(float[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using float arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + FloatDataBuffer read(float[] dst, int offset, int length); + + /** + * Bulk put method, using float arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default FloatDataBuffer write(float[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using float arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + FloatDataBuffer write(float[] src, int offset, int length); + + @Override + default Float getObject(long index) { + return getFloat(index); + } + + @Override + default FloatDataBuffer setObject(Float value, long index) { + return setFloat(value, index); + } + + @Override + FloatDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default FloatDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default FloatDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + FloatDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/IntDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/IntDataBuffer.java new file mode 100644 index 00000000000..52e3428f02d --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/IntDataBuffer.java @@ -0,0 +1,159 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of ints. */ +public interface IntDataBuffer extends DataBuffer { + + /** + * Reads the int at the given index. + * + * @param index the index from which the float will be read + * @return the int at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + int getInt(long index); + + /** + * Writes the given int into this buffer at the given index. + * + * @param value the int to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + IntDataBuffer setInt(int value, long index); + + /** + * Bulk get method, using int arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default IntDataBuffer read(int[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using int arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + IntDataBuffer read(int[] dst, int offset, int length); + + /** + * Bulk put method, using int arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default IntDataBuffer write(int[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using int arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + IntDataBuffer write(int[] src, int offset, int length); + + @Override + default Integer getObject(long index) { + return getInt(index); + } + + @Override + default IntDataBuffer setObject(Integer value, long index) { + return setInt(value, index); + } + + @Override + IntDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default IntDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default IntDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + IntDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/LongDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/LongDataBuffer.java new file mode 100644 index 00000000000..89ae7ae3aed --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/LongDataBuffer.java @@ -0,0 +1,159 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of longs. */ +public interface LongDataBuffer extends DataBuffer { + + /** + * Reads the long at the given index. + * + * @param index the index from which the float will be read + * @return the long at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + long getLong(long index); + + /** + * Writes the given long into this buffer at the given index. + * + * @param value the long to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + LongDataBuffer setLong(long value, long index); + + /** + * Bulk get method, using long arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default LongDataBuffer read(long[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using long arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + LongDataBuffer read(long[] dst, int offset, int length); + + /** + * Bulk put method, using long arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default LongDataBuffer write(long[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using long arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + LongDataBuffer write(long[] src, int offset, int length); + + @Override + default Long getObject(long index) { + return getLong(index); + } + + @Override + default LongDataBuffer setObject(Long value, long index) { + return setLong(value, index); + } + + @Override + LongDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default LongDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default LongDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + LongDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/ShortDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/ShortDataBuffer.java new file mode 100644 index 00000000000..1ae128d4e69 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/ShortDataBuffer.java @@ -0,0 +1,159 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; + +/** A {@link DataBuffer} of shorts. */ +public interface ShortDataBuffer extends DataBuffer { + + /** + * Reads the short at the given index. + * + * @param index the index from which the float will be read + * @return the short at the given index + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + */ + short getShort(long index); + + /** + * Writes the given short into this buffer at the given index. + * + * @param value the short to be written + * @param index the index at which the value will be written + * @return this buffer + * @throws IndexOutOfBoundsException if index is negative or not smaller than the buffer size + * @throws ReadOnlyBufferException if this buffer is read-only + */ + ShortDataBuffer setShort(short value, long index); + + /** + * Bulk get method, using short arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code + * dst.length > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = dst.length} values from this buffer into the given + * array. + * + * @param dst the array into which values are to be written + * @return this buffer + * @throws BufferUnderflowException if there are not enough values to copy from this buffer + */ + default ShortDataBuffer read(short[] dst) { + return read(dst, 0, dst.length); + } + + /** + * Bulk get method, using short arrays. + * + *

This method transfers values from this buffer into the given destination array. If there are + * fewer values in the buffer than are required to satisfy the request, that is, if {@code length + * > size()}, then no values are transferred and a BufferUnderflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from this buffer into the given + * array starting at the given offset. + * + * @param dst the array into which values are to be written + * @param offset the offset within the array of the first value to be written; must be + * non-negative and no larger than {@code dst.length} + * @param length the maximum number of values to be written to the given array; must be + * non-negative and no larger than {@code dst.length - offset} + * @return this buffer + * @throws BufferUnderflowException if there are fewer than length values remaining in this buffer + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + */ + ShortDataBuffer read(short[] dst, int offset, int length); + + /** + * Bulk put method, using short arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code src.length > size()}, + * then no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = src.length} values from the given array. + * + * @param src the source array from which values are to be read + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws ReadOnlyBufferException if this buffer is read-only + */ + default ShortDataBuffer write(short[] src) { + return write(src, 0, src.length); + } + + /** + * Bulk put method, using short arrays. + * + *

This method transfers the values in the given source array into this buffer. If there are + * more values in the source array than in this buffer, that is, if {@code length > size()}, then + * no values are transferred and a BufferOverflowException is thrown. + * + *

Otherwise, this method copies {@code n = length} values from the given array into this + * buffer, starting at the given offset. + * + * @param src the source array from which values are to be read + * @param offset the offset within the array of the first value to be read; must be non-negative + * and no larger than {@code src.length} + * @param length the number of values to be read from the given array; must be non-negative and no + * larger than {@code src.length - offset} + * @return this buffer + * @throws BufferOverflowException if there is insufficient space in this buffer for the values in + * the source array + * @throws IndexOutOfBoundsException if the preconditions on the offset and length parameters do + * not hold + * @throws ReadOnlyBufferException if this buffer is read-only + */ + ShortDataBuffer write(short[] src, int offset, int length); + + @Override + default Short getObject(long index) { + return getShort(index); + } + + @Override + default ShortDataBuffer setObject(Short value, long index) { + return setShort(value, index); + } + + @Override + ShortDataBuffer copyTo(DataBuffer dst, long size); + + @Override + default ShortDataBuffer offset(long index) { + return slice(index, size() - index); + } + + @Override + default ShortDataBuffer narrow(long size) { + return slice(0, size); + } + + @Override + ShortDataBuffer slice(long index, long size); + + @Override + default DataBufferWindow window(long size) { + throw new UnsupportedOperationException(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/BooleanDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/BooleanDataLayout.java new file mode 100644 index 00000000000..fd69a957e69 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/BooleanDataLayout.java @@ -0,0 +1,66 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to booleans. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface BooleanDataLayout> extends DataLayout { + + @Override + default BooleanDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a boolean into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the boolean to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Boolean, long) + */ + void writeBoolean(S buffer, boolean value, long index); + + /** + * Reads {@code n = scale()} values from the buffer at the given index and return them as a + * boolean. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the boolean value + * @see #readObject(DataBuffer, long) + */ + boolean readBoolean(S buffer, long index); + + @Override + default void writeObject(S buffer, Boolean value, long index) { + writeBoolean(buffer, value, index); + } + + @Override + default Boolean readObject(S buffer, long index) { + return readBoolean(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/ByteDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/ByteDataLayout.java new file mode 100644 index 00000000000..f9d868dabdd --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/ByteDataLayout.java @@ -0,0 +1,65 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to bytes. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface ByteDataLayout> extends DataLayout { + + @Override + default ByteDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a byte into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the byte to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Byte, long) + */ + void writeByte(S buffer, byte value, long index); + + /** + * Reads {@code n = scale()} values from the buffer at the given index and return them as a byte. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the byte value + * @see #readObject(DataBuffer, long) + */ + byte readByte(S buffer, long index); + + @Override + default void writeObject(S buffer, Byte value, long index) { + writeByte(buffer, value, index); + } + + @Override + default Byte readObject(S buffer, long index) { + return readByte(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DataLayout.java new file mode 100644 index 00000000000..97c26530ddd --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DataLayout.java @@ -0,0 +1,134 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * Converts data stored in a buffer to a given type. + * + *

{@code DataLayout} instances are used to define a custom format for storing and reading data + * of a {@link DataBuffer}. They provide a segregation layer between the type of data stored in the + * buffer (the buffer type) and the type of data manipulated by the end user (the user + * type). + * + *

Since the conversion methods are invoked for every value that is written or read, working with + * data layouts may have a negative impact on the performances so using primitive types directly + * should be preferred whenever possible. + * + *

It is also recommended to implement immutable data layouts so they can be reapplied to + * multiple buffers without reallocating a new instance for each of them. For example: + * + *

+ * class BigIntegerBufferAllocator {
+ *
+ *     public DataBuffer<BigInteger> allocate(long size) {
+ *         return LAYOUT.applyTo(DataBuffers.ofLongs(size * LAYOUT.scale()));  // scale is 1 by default
+ *     }
+ *
+ *     private static final DataLayout<LongDataBuffer, BigInteger> LAYOUT = new DataLayout<LongDataBuffer, BigInteger>() {
+ *
+ *         @Override
+ *         public void writeObject(LongDataBuffer buffer, BigInteger value, long index) {
+ *             buffer.setLong(value.longValue(), index);
+ *         }
+ *
+ *         @Override
+ *         public BigInteger readObject(LongDataBuffer buffer, long index) {
+ *             return BigInteger.valueOf(buffer.getLong(index));
+ *         }
+ *     };
+ * }
+ * 
+ * + * @param type of buffer this layout can be applied to + * @param user data type of this layout + */ +public interface DataLayout, T> { + + /** + * Apply this layout to the provided buffer. + * + *

The returned {@link DataBuffer} instance is simply a wrapper to the original buffer and does + * not have a backing storage of his own. + * + * @param buffer the target buffer to apply this layout to + * @return a buffer with this layout + */ + default DataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a user value into the buffer at the given index after converting it to the buffer type. + * + *

It is the responsibility of the implementors of this interface to write the converted value + * to the given buffer before this call returns, using the most appropriate method. For example, + * for a layout converting a {@code BigInteger} to a single {@code long}, + * + *

+   * @Override
+   * public void writeObject(LongDataBuffer buffer, BigInteger value, long index) {
+   *   buffer.setLong(value.longValue(), index);
+   * }
+   * 
+ * + * If a single user value scales over more than one buffer values, {@code index} indicates the + * starting position of the sequence to be written to the buffer. + * + * @param buffer the buffer to write to + * @param value the value in the user type to convert and write + * @param index index in the buffer where the converted value should be written + */ + void writeObject(S buffer, T value, long index); + + /** + * Reads {@code n = scale()} values from the buffer at the given index and return them as a single + * value in the user type. + * + *

It is the responsibility of the implementors of this interface to read the value to be + * converted from the given buffer, using the most appropriate method. For example, for a layout + * that converting a single {@code long} to a {@code BigInteger}, + * + *

+   * @Override
+   * public BigInteger readObject(LongDataBuffer buffer, long index) {
+   *   return BigInteger.valueOf(buffer.getLong(index));
+   * }
+   * 
+ * + * If a single user value scales over more than one buffer values, {@code index} indicates the + * starting position of the sequence to be read from the buffer. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the converted value + */ + T readObject(S buffer, long index); + + /** + * Indicates the number of buffer values are required to represent a single user value, default is + * 1. + * + *

Scale must be positive and must be an integer, meaning that a single buffer value in a + * buffer cannot be used to represent more than one user value. + */ + default int scale() { + return 1; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DataLayouts.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DataLayouts.java new file mode 100644 index 00000000000..e58ca550636 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DataLayouts.java @@ -0,0 +1,99 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.buffer.layout; + +import java.nio.charset.Charset; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.layout.Bfloat16Layout; +import org.tensorflow.ndarray.impl.buffer.layout.BoolLayout; +import org.tensorflow.ndarray.impl.buffer.layout.Float16Layout; +import org.tensorflow.ndarray.impl.buffer.layout.StringLayout; + +/** + * Exposes {@link DataLayout} instances of data formats frequently used in linear algebra + * computation. + * + *

Example of usage: + * + *

{@code
+ * // Storing boolean values in a ByteDataBuffer
+ * BooleanDataBuffer boolBuffer = DataLayouts.BOOL.applyTo(byteDataBuffer);
+ *
+ * // Allocating a new buffer of 256 half floats
+ * FloatDataBuffer halfBuffer = DataLayouts.FLOAT16.applyTo(DataBuffers.ofShorts(256 * DataLayouts.FLOAT16.scale());
+ * }
+ */ +public final class DataLayouts { + + /** + * Data layout for converting 16-bit bfloats to/from short values. + * + *

This format used to be specific to TensorFlow but has now been adopted more broadly in the + * machine learning field. It is optimized for fast conversion with single-precision 32-bit + * floating points by simply shifting their value and truncating the mantissa to only 7 bits. + * + *

Therefore, this is a lost of precision in the fraction part compared to the IEEE-754 + * half-precision floating point specification (see {@link #FLOAT16} but it has a larger range of + * possible values in the whole part as it preserves the 8-bit exponent and uses the same bias, + * (i.e. an absolute range above 0 of approximately [10-40, 3.39 × + * 1038] + * + *

Some CPUs support the bfloat16 format natively for better performances. + */ + public static final FloatDataLayout BFLOAT16 = new Bfloat16Layout(); + + /** + * Data layout for converting 16-bit half floats to/from short values. + * + *

Half floats are stored in memory accordingly to the IEEE-754 half-precision floating point + * specification, and are converted to/from 32-bit floats in the user space. + * + *

There is a potential loss of precision when converting a single float (32-bit) to a half + * float (16-bit). Absolute range of values above 0 for a half float is approximately [5.96 × + * 10-8, 6.55 × 104] and their decimal part is rounded up to a 10 bits + * mantissa. + * + *

In general, half float computation perform better on GPUs since, in general, CPUs do not + * support this format natively. + */ + public static final FloatDataLayout FLOAT16 = new Float16Layout(); + + /** + * Data layout for converting booleans to/from byte values. + * + *

Since there is no Java NIO boolean buffer, this layout is particularly useful for mapping + * booleans values to standard byte buffers. The conversion between a boolean and a byte requires + * explicit type casting. + */ + public static final BooleanDataLayout BOOL = new BoolLayout(); + + /** + * Creates a data layout for converting strings to/from byte sequences. + * + *

This layout requires a {@code charset} in parameter to specify how the strings must be + * encoded/decoded as byte sequences. So a new layout instance is always returned. + * + * @param charset charset to use + * @return a new string layout + */ + public static DataLayout, String> ofStrings(Charset charset) { + return StringLayout.of(charset); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DoubleDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DoubleDataLayout.java new file mode 100644 index 00000000000..efd1e461802 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/DoubleDataLayout.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to doubles. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface DoubleDataLayout> extends DataLayout { + + @Override + default DoubleDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a double into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the double to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Double, long) + */ + void writeDouble(S buffer, double value, long index); + + /** + * Reads {@code n = scale()} buffer values at the given index and return them as a double. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the double value + * @see #readObject(DataBuffer, long) + */ + double readDouble(S buffer, long index); + + @Override + default void writeObject(S buffer, Double value, long index) { + writeDouble(buffer, value, index); + } + + @Override + default Double readObject(S buffer, long index) { + return readDouble(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/FloatDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/FloatDataLayout.java new file mode 100644 index 00000000000..a57f525d69f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/FloatDataLayout.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to floats. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface FloatDataLayout> extends DataLayout { + + @Override + default FloatDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a float into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the float to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Float, long) + */ + void writeFloat(S buffer, float value, long index); + + /** + * Reads {@code n = scale()} values from the buffer at the given index and return them as a float. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the float value + * @see #readObject(DataBuffer, long) + */ + float readFloat(S buffer, long index); + + @Override + default void writeObject(S buffer, Float value, long index) { + writeFloat(buffer, value, index); + } + + @Override + default Float readObject(S buffer, long index) { + return readFloat(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/IntDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/IntDataLayout.java new file mode 100644 index 00000000000..718deac9b9f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/IntDataLayout.java @@ -0,0 +1,66 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to ints. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface IntDataLayout> extends DataLayout { + + @Override + default IntDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a int into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the int to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Integer, long) + */ + void writeInt(S buffer, int value, long index); + + /** + * Reads {@code n = scale()} values from the buffer at the given index and return them as an int. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the int value + * @see #readObject(DataBuffer, long) + */ + int readInt(S buffer, long index); + + @Override + default void writeObject(S buffer, Integer value, long index) { + writeInt(buffer, value, index); + } + + @Override + default Integer readObject(S buffer, long index) { + return readInt(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/LongDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/LongDataLayout.java new file mode 100644 index 00000000000..de8fddc8407 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/LongDataLayout.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to longs. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface LongDataLayout> extends DataLayout { + + @Override + default LongDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a long into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the long to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Long, long) + */ + void writeLong(S buffer, long value, long index); + + /** + * Reads {@code n = scale()} values from the buffer at the given index and return them as a long. + * + * @param buffer the buffer to read from + * @param index position of the buffer to read in the buffer + * @return the long value + * @see #readObject(DataBuffer, long) + */ + long readLong(S buffer, long index); + + @Override + default void writeObject(S buffer, Long value, long index) { + writeLong(buffer, value, index); + } + + @Override + default Long readObject(S buffer, long index) { + return readLong(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/ShortDataLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/ShortDataLayout.java new file mode 100644 index 00000000000..89c1fd0dec4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/ShortDataLayout.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ +package org.tensorflow.ndarray.buffer.layout; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** + * A {@link DataLayout} that converts data stored in a buffer to shorts. + * + * @param type of buffer this layout can be applied to + * @see DataLayout + */ +public interface ShortDataLayout> extends DataLayout { + + @Override + default ShortDataBuffer applyTo(S buffer) { + return DataBufferAdapterFactory.create(buffer, this); + } + + /** + * Writes a short into the buffer at the given index after converting it to the buffer type. + * + * @param buffer the buffer to write to + * @param value the short to convert and write + * @param index index in the buffer where the converted value should be written + * @see #writeObject(DataBuffer, Short, long) + */ + void writeShort(S buffer, short value, long index); + + /** + * Reads {@code n = scale()} buffer values at the given index and return them as a short. + * + * @param buffer the buffer to read from + * @param index position of the value to read in the buffer + * @return the short value + * @see #readObject(DataBuffer, long) + */ + short readShort(S buffer, long index); + + @Override + default void writeObject(S buffer, Short value, long index) { + writeShort(buffer, value, index); + } + + @Override + default Short readObject(S buffer, long index) { + return readShort(buffer, index); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/AbstractNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/AbstractNdArray.java new file mode 100644 index 00000000000..41f2cb5977f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/AbstractNdArray.java @@ -0,0 +1,98 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl; + +import java.util.Iterator; +import java.util.Objects; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +@SuppressWarnings("unchecked") +public abstract class AbstractNdArray> implements NdArray { + + protected final DimensionalSpace dimensions; + + protected AbstractNdArray(DimensionalSpace dimensions) { + this.dimensions = dimensions; + } + + public abstract U slice(long position, DimensionalSpace dimensions); + + public DimensionalSpace dimensions() { + return dimensions; + } + + @Override + public Shape shape() { + return dimensions.shape(); + } + + @Override + public NdArraySequence scalars() { + // negative if this array is a scalar, should be handled in `elements(dimIdx)` + return (NdArraySequence) elements(shape().numDimensions() - 1); + } + + @Override + public int hashCode() { + return slowHashCode(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof NdArray)) { + return false; + } + return slowEquals((NdArray) obj); + } + + protected void slowCopyTo(NdArray array) { + scalars().forEachIndexed((coords, e) -> array.setObject(e.getObject(), coords)); + } + + protected int slowHashCode() { + final int prime = 31; + int result = 1; + for (NdArray scalar : scalars()) { + result = prime * result + scalar.getObject().hashCode(); + } + result = prime * result + shape().hashCode(); + return result; + } + + protected boolean slowEquals(NdArray array) { + if (!shape() + .equals( + array.shape())) { // this guarantees also that we have the same number of scalar values + return false; + } + for (Iterator> thisIter = scalars().iterator(), + otherIter = array.scalars().iterator(); + thisIter.hasNext(); ) { + // Use Object.equals to handle nulls. + if (!Objects.equals(thisIter.next().getObject(), otherIter.next().getObject())) { + return false; + } + } + return true; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java new file mode 100644 index 00000000000..da7ca2354e0 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/Validator.java @@ -0,0 +1,59 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; + +public class Validator { + + public static void copyToNdArrayArgs(NdArray ndArray, NdArray otherNdArray) { + if (!ndArray.shape().equals(otherNdArray.shape())) { + throw new IllegalArgumentException( + "Can only copy to arrays of the same shape (" + + ndArray.shape() + + " != " + + otherNdArray.shape() + + ")"); + } + } + + public static void copyToBufferArgs(NdArray ndArray, DataBuffer dst) { + if (dst.size() < ndArray.size()) { + throw new BufferOverflowException(); + } + } + + public static void copyFromBufferArgs(NdArray ndArray, DataBuffer src) { + if (src.size() < ndArray.size()) { + throw new BufferUnderflowException(); + } + } + + private static void copyArrayArgs(int arrayLength, int arrayOffset) { + if (arrayOffset < 0) { + throw new IndexOutOfBoundsException("Offset must be non-negative"); + } + if (arrayOffset > arrayLength) { + throw new IndexOutOfBoundsException("Offset must be no larger than array length"); + } + } + + protected Validator() {} +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/AbstractDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/AbstractDataBuffer.java new file mode 100644 index 00000000000..5de34b23f7e --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/AbstractDataBuffer.java @@ -0,0 +1,206 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import org.tensorflow.ndarray.buffer.DataBuffer; + +public abstract class AbstractDataBuffer implements DataBuffer { + + @Override + public DataBuffer read(T[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0; i < length; ++i) { + dst[i + offset] = getObject(i); + } + return this; + } + + @Override + public DataBuffer write(T[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0; i < length; ++i) { + setObject(src[i + offset], i); + } + return this; + } + + @Override + public DataBuffer copyTo(DataBuffer dst, long size) { + return slowCopyTo(dst, size); + } + + @Override + public int hashCode() { + // This hash code computation is generic to all types of data buffers and accurate but not + // optimized + // for performances, it needs to be improved if there is a present use case for such hash codes. + return slowHashCode(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DataBuffer)) { + return false; + } + return slowEquals((DataBuffer) obj); + } + + @SuppressWarnings("unchecked") + protected > U slowCopyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + for (long idx = 0L; idx < size; ++idx) { + dst.setObject(getObject(idx), idx); + } + return (U) this; + } + + protected int slowHashCode() { + final int prime = 31; + int result = 1; + + // First check from the first non-null element if we are dealing with a buffer of arrays + long idx = 0L; + for (; idx < size(); ++idx) { + T o = getObject(idx); + if (o != null) { + if (o.getClass().isArray()) { + result = + prime * result + + arrayHashCode(idx, o.getClass()); // compute hash codes based on array elements + return result; + } + result = prime * result + o.hashCode(); + break; // continue hash code computation without array type check + } + result = prime * result; + } + while (++idx < size()) { + result = prime * result + Objects.hashCode(getObject(idx)); + } + return result; + } + + protected boolean slowEquals(DataBuffer other) { + if (other.size() != size()) { + return false; + } + long idx = 0L; + for (; idx < size(); ++idx) { + Object thisObject = getObject(idx); + if (thisObject != null) { + if (thisObject.getClass().isArray()) { + return arrayEquals(idx, thisObject.getClass(), other); + } + if (!Objects.equals(other.getObject(idx), thisObject)) { + return false; + } + break; // continue equality comparison without array type check + } + if (other.getObject(idx) != null) { + return false; + } + } + while (++idx < size()) { + if (!Objects.equals(other.getObject(idx), getObject(idx))) { + return false; + } + } + return true; + } + + private int arrayHashCode(long startIdx, Class arrayClass) { + ArrayHashCoder hashCoder = ARRAY_HASH_CODERS.getOrDefault(arrayClass, DEFAULT_ARRAY_HASH_CODER); + final int prime = 31; + int result = 1; + for (long idx = startIdx; idx < size(); ++idx) { + result = prime * result + hashCoder.hashCode(this, idx); + } + return result; + } + + private boolean arrayEquals(long startIdx, Class arrayClass, DataBuffer other) { + ArrayComparator comparator = + ARRAY_COMPARATORS.getOrDefault(arrayClass, DEFAULT_ARRAY_COMPARATOR); + for (long idx = startIdx; idx < size(); ++idx) { + if (!comparator.equals(this, other, idx)) { + return false; + } + } + return true; + } + + @FunctionalInterface + private static interface ArrayHashCoder { + int hashCode(DataBuffer buffer, long index); + } + + private static final Map, ArrayHashCoder> ARRAY_HASH_CODERS = new HashMap<>(); + private static final ArrayHashCoder DEFAULT_ARRAY_HASH_CODER; + + @FunctionalInterface + private static interface ArrayComparator { + boolean equals(DataBuffer buffer, DataBuffer otherBuffer, long index); + } + + private static final Map, ArrayComparator> ARRAY_COMPARATORS = new HashMap<>(); + private static final ArrayComparator DEFAULT_ARRAY_COMPARATOR; + + static { + ARRAY_HASH_CODERS.put(byte[].class, (b, idx) -> Arrays.hashCode((byte[]) b.getObject(idx))); + ARRAY_HASH_CODERS.put(int[].class, (b, idx) -> Arrays.hashCode((int[]) b.getObject(idx))); + ARRAY_HASH_CODERS.put(short[].class, (b, idx) -> Arrays.hashCode((short[]) b.getObject(idx))); + ARRAY_HASH_CODERS.put(long[].class, (b, idx) -> Arrays.hashCode((long[]) b.getObject(idx))); + ARRAY_HASH_CODERS.put(float[].class, (b, idx) -> Arrays.hashCode((float[]) b.getObject(idx))); + ARRAY_HASH_CODERS.put(double[].class, (b, idx) -> Arrays.hashCode((double[]) b.getObject(idx))); + ARRAY_HASH_CODERS.put( + boolean[].class, (b, idx) -> Arrays.hashCode((boolean[]) b.getObject(idx))); + DEFAULT_ARRAY_HASH_CODER = (b, idx) -> Arrays.deepHashCode((Object[]) b.getObject(idx)); + + ARRAY_COMPARATORS.put( + byte[].class, + (b1, b2, idx) -> Arrays.equals((byte[]) b1.getObject(idx), (byte[]) b2.getObject(idx))); + ARRAY_COMPARATORS.put( + int[].class, + (b1, b2, idx) -> Arrays.equals((int[]) b1.getObject(idx), (int[]) b2.getObject(idx))); + ARRAY_COMPARATORS.put( + short[].class, + (b1, b2, idx) -> Arrays.equals((short[]) b1.getObject(idx), (short[]) b2.getObject(idx))); + ARRAY_COMPARATORS.put( + long[].class, + (b1, b2, idx) -> Arrays.equals((long[]) b1.getObject(idx), (long[]) b2.getObject(idx))); + ARRAY_COMPARATORS.put( + float[].class, + (b1, b2, idx) -> Arrays.equals((float[]) b1.getObject(idx), (float[]) b2.getObject(idx))); + ARRAY_COMPARATORS.put( + double[].class, + (b1, b2, idx) -> Arrays.equals((double[]) b1.getObject(idx), (double[]) b2.getObject(idx))); + ARRAY_COMPARATORS.put( + boolean[].class, + (b1, b2, idx) -> + Arrays.equals((boolean[]) b1.getObject(idx), (boolean[]) b2.getObject(idx))); + DEFAULT_ARRAY_COMPARATOR = + (b1, b2, idx) -> + Arrays.deepEquals((Object[]) b1.getObject(idx), (Object[]) b2.getObject(idx)); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/AbstractDataBufferWindow.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/AbstractDataBufferWindow.java new file mode 100644 index 00000000000..cf28df86ca1 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/AbstractDataBufferWindow.java @@ -0,0 +1,49 @@ +package org.tensorflow.ndarray.impl.buffer; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBufferWindow; + +public abstract class AbstractDataBufferWindow> + implements DataBufferWindow { + + @Override + public final long offset() { + return offset; + } + + @Override + public final long size() { + return windowBuffer.size(); + } + + @Override + public final DataBufferWindow slideTo(long index) { + if (index < 0 || index > maxOffset) { + throw new IndexOutOfBoundsException(); + } + offset(index); + offset = index; + return this; + } + + @Override + public final DataBufferWindow slide(long step) { + return slideTo(offset + step); + } + + @Override + public final B buffer() { + return windowBuffer; + } + + protected abstract void offset(long offset); + + protected AbstractDataBufferWindow(B windowBuffer, long bufferLimit) { + this.windowBuffer = windowBuffer; + maxOffset = bufferLimit - windowBuffer.size(); + } + + private final B windowBuffer; + private final long maxOffset; + private long offset = 0; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/Validator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/Validator.java new file mode 100644 index 00000000000..d85a6ded17f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/Validator.java @@ -0,0 +1,135 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.nio.ReadOnlyBufferException; +import org.tensorflow.ndarray.buffer.DataBuffer; + +public class Validator { + + public static void createArgs(long size, long maxSize) { + if (size < 0) { + throw new IllegalArgumentException("Size must be non-negative"); + } + if (size > maxSize) { + throw new IllegalArgumentException( + "Buffer size must be no greater than maximum size allowed (" + maxSize + ")"); + } + } + + public static void getArgs(DataBuffer buffer, long index) { + if (index < 0) { + throw new IndexOutOfBoundsException("Index must be non-negative"); + } + if (index >= buffer.size()) { + throw new IndexOutOfBoundsException("Index must be smaller than the buffer size"); + } + } + + public static void setArgs(DataBuffer buffer, long index) { + if (index < 0) { + throw new IndexOutOfBoundsException("Index must be non-negative"); + } + if (index >= buffer.size()) { + throw new IndexOutOfBoundsException("Index must be smaller than the buffer size"); + } + if (buffer.isReadOnly()) { + throw new ReadOnlyBufferException(); + } + } + + public static void copyToArgs(DataBuffer src, DataBuffer dst, long size) { + if (dst == src) { + throw new IllegalArgumentException("Source cannot be the same buffer as destination"); + } + if (size > dst.size()) { + throw new BufferOverflowException(); + } + if (size > src.size()) { + throw new BufferUnderflowException(); + } + if (dst.isReadOnly()) { + throw new ReadOnlyBufferException(); + } + } + + public static void readArgs(DataBuffer buffer, int arrayLength, int offset, int length) { + if (length > buffer.size()) { + throw new BufferUnderflowException(); + } + arrayArgs(arrayLength, offset, length); + } + + public static void writeArgs(DataBuffer buffer, int arrayLength, int offset, int length) { + if (length > buffer.size()) { + throw new BufferOverflowException(); + } + if (buffer.isReadOnly()) { + throw new ReadOnlyBufferException(); + } + arrayArgs(arrayLength, offset, length); + } + + public static void offsetArgs(DataBuffer buffer, long index) { + if (index < 0) { + throw new IllegalArgumentException("Index must be non-negative"); + } + if (index > buffer.size()) { + throw new IllegalArgumentException("Index must not exceed buffer size"); + } + } + + public static void narrowArgs(DataBuffer buffer, long size) { + if (size < 0) { + throw new IllegalArgumentException("Size must be non-negative"); + } + if (size > buffer.size()) { + throw new IllegalArgumentException( + "Cannot narrow a buffer of size " + buffer.size() + " to " + size); + } + } + + public static void sliceArgs(DataBuffer buffer, long index, long size) { + if (index < 0) { + throw new IllegalArgumentException("Index must be non-negative"); + } + if (size < 0) { + throw new IllegalArgumentException("Size must be non-negative"); + } + if (index + size > buffer.size()) { + throw new IllegalArgumentException("Buffer view must not exceed original buffer limits"); + } + } + + private static void arrayArgs(int arrayLength, int offset, int length) { + if (offset < 0) { + throw new IndexOutOfBoundsException("Offset must be non-negative"); + } + if (offset > arrayLength) { + throw new IndexOutOfBoundsException("Offset must be no larger than array length"); + } + if (length < 0) { + throw new IndexOutOfBoundsException("Length must be non-negative"); + } + if (length > arrayLength - offset) { + throw new IndexOutOfBoundsException( + "Length must be no larger than array length minus the offset"); + } + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/AbstractDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/AbstractDataBufferAdapter.java new file mode 100644 index 00000000000..901a7ed8905 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/AbstractDataBufferAdapter.java @@ -0,0 +1,69 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayout; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +@SuppressWarnings("unchecked") +abstract class AbstractDataBufferAdapter, T, U extends DataBuffer> + extends AbstractDataBuffer { + + @Override + public long size() { + return size; + } + + @Override + public boolean isReadOnly() { + return buffer.isReadOnly(); + } + + @Override + public T getObject(long index) { + Validator.getArgs(this, index); + return layout.readObject(buffer, index * layout.scale()); + } + + @Override + public U setObject(T value, long index) { + Validator.setArgs(this, index); + layout.writeObject(buffer, value, index * layout.scale()); + return (U) this; + } + + AbstractDataBufferAdapter(S buffer, DataLayout layout) { + this.buffer = buffer; + this.layout = layout; + size = buffer.size() / layout.scale(); + } + + DataLayout layout() { + return layout; + } + + S buffer() { + return buffer; + } + + private final S buffer; + private final DataLayout layout; + private final long size; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/BooleanDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/BooleanDataBufferAdapter.java new file mode 100644 index 00000000000..f32d2292423 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/BooleanDataBufferAdapter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.layout.BooleanDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class BooleanDataBufferAdapter> + extends AbstractDataBufferAdapter implements BooleanDataBuffer { + + @Override + public boolean getBoolean(long index) { + Validator.getArgs(this, index); + return layout.readBoolean(buffer(), index * layout.scale()); + } + + @Override + public BooleanDataBuffer setBoolean(boolean value, long index) { + Validator.setArgs(this, index); + layout.writeBoolean(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public BooleanDataBuffer read(boolean[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readBoolean(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public BooleanDataBuffer write(boolean[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeBoolean(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public BooleanDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof BooleanDataBuffer) { + BooleanDataBuffer booleanDst = (BooleanDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + booleanDst.setBoolean(getBoolean(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + @SuppressWarnings("unchecked") + public BooleanDataBuffer offset(long index) { + return new BooleanDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public BooleanDataBuffer narrow(long size) { + return new BooleanDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public BooleanDataBuffer slice(long index, long size) { + return new BooleanDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BooleanDataBuffer)) { + return super.equals(obj); + } + BooleanDataBuffer other = (BooleanDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getBoolean(idx) != getBoolean(idx)) { + return false; + } + } + return true; + } + + BooleanDataBufferAdapter(S buffer, BooleanDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private BooleanDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/ByteDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/ByteDataBufferAdapter.java new file mode 100644 index 00000000000..e93ce3054b0 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/ByteDataBufferAdapter.java @@ -0,0 +1,136 @@ +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.ByteDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class ByteDataBufferAdapter> + extends AbstractDataBufferAdapter implements ByteDataBuffer { + + @Override + public byte getByte(long index) { + Validator.getArgs(this, index); + return layout.readByte(buffer(), index * layout.scale()); + } + + @Override + public ByteDataBuffer setByte(byte value, long index) { + Validator.setArgs(this, index); + layout.writeByte(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public ByteDataBuffer read(byte[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readByte(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public ByteDataBuffer write(byte[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeByte(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public ByteDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof ByteDataBuffer) { + ByteDataBuffer byteDst = (ByteDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + byteDst.setByte(getByte(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + public IntDataBuffer asInts() { + throw new IllegalStateException("Byte buffers with layout cannot be converted"); + } + + @Override + public ShortDataBuffer asShorts() { + throw new IllegalStateException("Byte buffers with layout cannot be converted"); + } + + @Override + public LongDataBuffer asLongs() { + throw new IllegalStateException("Byte buffers with layout cannot be converted"); + } + + @Override + public FloatDataBuffer asFloats() { + throw new IllegalStateException("Byte buffers with layout cannot be converted"); + } + + @Override + public DoubleDataBuffer asDoubles() { + throw new IllegalStateException("Byte buffers with layout cannot be converted"); + } + + @Override + public BooleanDataBuffer asBooleans() { + throw new IllegalStateException("Byte buffers with layout cannot be converted"); + } + + @Override + @SuppressWarnings("unchecked") + public ByteDataBuffer offset(long index) { + return new ByteDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public ByteDataBuffer narrow(long size) { + return new ByteDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public ByteDataBuffer slice(long index, long size) { + return new ByteDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ByteDataBuffer)) { + return super.equals(obj); + } + ByteDataBuffer other = (ByteDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getByte(idx) != getByte(idx)) { + return false; + } + } + return true; + } + + ByteDataBufferAdapter(S buffer, ByteDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private ByteDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DataBufferAdapter.java new file mode 100644 index 00000000000..3d5e98111d2 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DataBufferAdapter.java @@ -0,0 +1,49 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayout; + +@SuppressWarnings("unchecked") +class DataBufferAdapter, T> + extends AbstractDataBufferAdapter> { + + @Override + @SuppressWarnings("unchecked") + public DataBuffer offset(long index) { + return new DataBufferAdapter<>((S) buffer().offset(index * layout().scale()), layout()); + } + + @Override + @SuppressWarnings("unchecked") + public DataBuffer narrow(long size) { + return new DataBufferAdapter<>((S) buffer().narrow(size * layout().scale()), layout()); + } + + @Override + @SuppressWarnings("unchecked") + public DataBuffer slice(long index, long size) { + return new DataBufferAdapter<>( + (S) buffer().slice(index * layout().scale(), size * layout().scale()), layout()); + } + + DataBufferAdapter(S buffer, DataLayout layout) { + super(buffer, layout); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DataBufferAdapterFactory.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DataBufferAdapterFactory.java new file mode 100644 index 00000000000..8b82282864e --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DataBufferAdapterFactory.java @@ -0,0 +1,149 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.BooleanDataLayout; +import org.tensorflow.ndarray.buffer.layout.ByteDataLayout; +import org.tensorflow.ndarray.buffer.layout.DataLayout; +import org.tensorflow.ndarray.buffer.layout.DoubleDataLayout; +import org.tensorflow.ndarray.buffer.layout.FloatDataLayout; +import org.tensorflow.ndarray.buffer.layout.IntDataLayout; +import org.tensorflow.ndarray.buffer.layout.LongDataLayout; +import org.tensorflow.ndarray.buffer.layout.ShortDataLayout; + +/** + * Factory of data buffer adapters. + * + *

Data buffer adapters are used to apply a {@link DataLayout} to a buffer. Conceptually, they + * act as a proxy that intercept each I/O call and perform the required type conversions + * after/before delegating the task to the underlying buffer. + */ +public class DataBufferAdapterFactory { + + /** + * Creates an adapter that applies a byte data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > ByteDataBuffer create( + S buffer, ByteDataLayout layout) { + return new ByteDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a boolean data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > BooleanDataBuffer create( + S buffer, BooleanDataLayout layout) { + return new BooleanDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a double data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > DoubleDataBuffer create( + S buffer, DoubleDataLayout layout) { + return new DoubleDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a float data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > FloatDataBuffer create( + S buffer, FloatDataLayout layout) { + return new FloatDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a integer data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > IntDataBuffer create(S buffer, IntDataLayout layout) { + return new IntDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a long data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > LongDataBuffer create( + S buffer, LongDataLayout layout) { + return new LongDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a short data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @return buffer adapter + */ + public static > ShortDataBuffer create( + S buffer, ShortDataLayout layout) { + return new ShortDataBufferAdapter<>(buffer, layout); + } + + /** + * Creates an adapter that applies a data layout to the given buffer. + * + * @param buffer the delegate buffer + * @param layout layout to apply + * @param the type of the buffer + * @param the type of data returned by the layout + * @return buffer adapter + */ + public static , T> DataBuffer create( + S buffer, DataLayout layout) { + return new DataBufferAdapter<>(buffer, layout); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DoubleDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DoubleDataBufferAdapter.java new file mode 100644 index 00000000000..542b4dfe4dd --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/DoubleDataBufferAdapter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DoubleDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class DoubleDataBufferAdapter> + extends AbstractDataBufferAdapter implements DoubleDataBuffer { + + @Override + public double getDouble(long index) { + Validator.getArgs(this, index); + return layout.readDouble(buffer(), index * layout.scale()); + } + + @Override + public DoubleDataBuffer setDouble(double value, long index) { + Validator.setArgs(this, index); + layout.writeDouble(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public DoubleDataBuffer read(double[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readDouble(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public DoubleDataBuffer write(double[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeDouble(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public DoubleDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof DoubleDataBuffer) { + DoubleDataBuffer doubleDst = (DoubleDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + doubleDst.setDouble(getDouble(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + @SuppressWarnings("unchecked") + public DoubleDataBuffer offset(long index) { + return new DoubleDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public DoubleDataBuffer narrow(long size) { + return new DoubleDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public DoubleDataBuffer slice(long index, long size) { + return new DoubleDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DoubleDataBuffer)) { + return super.equals(obj); + } + DoubleDataBuffer other = (DoubleDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getDouble(idx) != getDouble(idx)) { + return false; + } + } + return true; + } + + DoubleDataBufferAdapter(S buffer, DoubleDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private DoubleDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/FloatDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/FloatDataBufferAdapter.java new file mode 100644 index 00000000000..2c581f4b1e0 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/FloatDataBufferAdapter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.layout.FloatDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class FloatDataBufferAdapter> + extends AbstractDataBufferAdapter implements FloatDataBuffer { + + @Override + public float getFloat(long index) { + Validator.getArgs(this, index); + return layout.readFloat(buffer(), index * layout.scale()); + } + + @Override + public FloatDataBuffer setFloat(float value, long index) { + Validator.setArgs(this, index); + layout.writeFloat(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public FloatDataBuffer read(float[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readFloat(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public FloatDataBuffer write(float[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeFloat(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public FloatDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof FloatDataBuffer) { + FloatDataBuffer floatDst = (FloatDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + floatDst.setFloat(getFloat(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + @SuppressWarnings("unchecked") + public FloatDataBuffer offset(long index) { + return new FloatDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public FloatDataBuffer narrow(long size) { + return new FloatDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public FloatDataBuffer slice(long index, long size) { + return new FloatDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof FloatDataBuffer)) { + return super.equals(obj); + } + FloatDataBuffer other = (FloatDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getFloat(idx) != getFloat(idx)) { + return false; + } + } + return true; + } + + FloatDataBufferAdapter(S buffer, FloatDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private FloatDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/IntDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/IntDataBufferAdapter.java new file mode 100644 index 00000000000..7a93d52e6a7 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/IntDataBufferAdapter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.layout.IntDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class IntDataBufferAdapter> + extends AbstractDataBufferAdapter implements IntDataBuffer { + + @Override + public int getInt(long index) { + Validator.getArgs(this, index); + return layout.readInt(buffer(), index * layout.scale()); + } + + @Override + public IntDataBuffer setInt(int value, long index) { + Validator.setArgs(this, index); + layout.writeInt(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public IntDataBuffer read(int[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readInt(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public IntDataBuffer write(int[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeInt(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public IntDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof IntDataBuffer) { + IntDataBuffer intDst = (IntDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + intDst.setInt(getInt(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + @SuppressWarnings("unchecked") + public IntDataBuffer offset(long index) { + return new IntDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public IntDataBuffer narrow(long size) { + return new IntDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public IntDataBuffer slice(long index, long size) { + return new IntDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof IntDataBuffer)) { + return super.equals(obj); + } + IntDataBuffer other = (IntDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getInt(idx) != getInt(idx)) { + return false; + } + } + return true; + } + + IntDataBufferAdapter(S buffer, IntDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private IntDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/LongDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/LongDataBufferAdapter.java new file mode 100644 index 00000000000..db31050b02f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/LongDataBufferAdapter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.layout.LongDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class LongDataBufferAdapter> + extends AbstractDataBufferAdapter implements LongDataBuffer { + + @Override + public long getLong(long index) { + Validator.getArgs(this, index); + return layout.readLong(buffer(), index * layout.scale()); + } + + @Override + public LongDataBuffer setLong(long value, long index) { + Validator.setArgs(this, index); + layout.writeLong(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public LongDataBuffer read(long[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readLong(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public LongDataBuffer write(long[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeLong(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public LongDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof LongDataBuffer) { + LongDataBuffer longDst = (LongDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + longDst.setLong(getLong(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + @SuppressWarnings("unchecked") + public LongDataBuffer offset(long index) { + return new LongDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public LongDataBuffer narrow(long size) { + return new LongDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public LongDataBuffer slice(long index, long size) { + return new LongDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof LongDataBuffer)) { + return super.equals(obj); + } + LongDataBuffer other = (LongDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getLong(idx) != getLong(idx)) { + return false; + } + } + return true; + } + + LongDataBufferAdapter(S buffer, LongDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private LongDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/ShortDataBufferAdapter.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/ShortDataBufferAdapter.java new file mode 100644 index 00000000000..001b31ef176 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/adapter/ShortDataBufferAdapter.java @@ -0,0 +1,117 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.ShortDataLayout; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class ShortDataBufferAdapter> + extends AbstractDataBufferAdapter implements ShortDataBuffer { + + @Override + public short getShort(long index) { + Validator.getArgs(this, index); + return layout.readShort(buffer(), index * layout.scale()); + } + + @Override + public ShortDataBuffer setShort(short value, long index) { + Validator.setArgs(this, index); + layout.writeShort(buffer(), value, index * layout.scale()); + return this; + } + + @Override + public ShortDataBuffer read(short[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + dst[j] = layout.readShort(buffer(), i * layout.scale()); + } + return this; + } + + @Override + public ShortDataBuffer write(short[] src, int offset, int length) { + Validator.writeArgs(this, src.length, offset, length); + for (int i = 0, j = offset; i < length; ++i, ++j) { + layout.writeShort(buffer(), src[j], i * layout.scale()); + } + return this; + } + + @Override + public ShortDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof ShortDataBuffer) { + ShortDataBuffer shortDst = (ShortDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + shortDst.setShort(getShort(idx), idx); + } + return this; + } + return slowCopyTo(dst, size); + } + + @Override + @SuppressWarnings("unchecked") + public ShortDataBuffer offset(long index) { + return new ShortDataBufferAdapter<>((S) buffer().offset(index * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public ShortDataBuffer narrow(long size) { + return new ShortDataBufferAdapter<>((S) buffer().narrow(size * layout.scale()), layout); + } + + @Override + @SuppressWarnings("unchecked") + public ShortDataBuffer slice(long index, long size) { + return new ShortDataBufferAdapter<>( + (S) buffer().slice(index * layout.scale(), size * layout.scale()), layout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ShortDataBuffer)) { + return super.equals(obj); + } + ShortDataBuffer other = (ShortDataBuffer) obj; + if (other.size() != size()) { + return false; + } + for (long idx = 0L; idx < size(); ++idx) { + if (other.getShort(idx) != getShort(idx)) { + return false; + } + } + return true; + } + + ShortDataBufferAdapter(S buffer, ShortDataLayout layout) { + super(buffer, layout); + this.layout = layout; + } + + private ShortDataLayout layout; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/Bfloat16Layout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/Bfloat16Layout.java new file mode 100644 index 00000000000..444e7f8f674 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/Bfloat16Layout.java @@ -0,0 +1,54 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.FloatDataLayout; + +/** + * Data layout that converts 32-bit floats from/to 16-bit, truncating their mantissa to 7 bits but + * preserving the 8-bit exponent with the same bias. + */ +public final class Bfloat16Layout implements FloatDataLayout { + + @Override + public void writeFloat(ShortDataBuffer buffer, float value, long index) { + buffer.setShort(float32to16(value), index); + } + + @Override + public float readFloat(ShortDataBuffer buffer, long index) { + return float16to32(buffer.getShort(index)); + } + + // + // FLOAT 32-bit to/from BFLOAT 16-bit conversions + // + // We simply shift the value from 32-bit to 16-bit and vice-versa. NaN special case is ignored. + // + + // VisibleForTesting + static short float32to16(float f32) { + return (short) (Float.floatToIntBits(f32) >>> 16); + } + + // Visible for testing + static float float16to32(short i16) { + return Float.intBitsToFloat((int) i16 << 16); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/BoolLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/BoolLayout.java new file mode 100644 index 00000000000..ea76aafb823 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/BoolLayout.java @@ -0,0 +1,45 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.layout.BooleanDataLayout; + +/** Data layout that converts booleans from/to bytes. */ +public final class BoolLayout implements BooleanDataLayout { + + @Override + public void writeBoolean(ByteDataBuffer buffer, boolean value, long index) { + buffer.setByte(booleanToByte(value), index); + } + + @Override + public boolean readBoolean(ByteDataBuffer buffer, long index) { + return byteToBoolean(buffer.getByte(index)); + } + + // Visible for testing + static byte booleanToByte(boolean b) { + return (byte) (b ? 0x1 : 0x0); + } + + // Visible for testing + static boolean byteToBoolean(byte b) { + return b != 0x0; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/Float16Layout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/Float16Layout.java new file mode 100644 index 00000000000..7d7a32ad203 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/Float16Layout.java @@ -0,0 +1,122 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.FloatDataLayout; + +/** + * Data layout that converts 32-bit floats from/to 16-bit, accordingly to the IEEE-754 + * half-precision floating point specification. + */ +public final class Float16Layout implements FloatDataLayout { + + @Override + public void writeFloat(ShortDataBuffer buffer, float value, long index) { + buffer.setShort(float32to16(value), index); + } + + @Override + public float readFloat(ShortDataBuffer buffer, long index) { + return float16to32(buffer.getShort(index)); + } + + // + // FLOAT 32-bit to/from 16-bit conversions + // + // The following conversion algorithms are issued from the C++ implementation found in the + // Eigen library used by TensorFlow native library. + // See https://eigen.tuxfamily.org/dox-devel/Half_8h_source.html for more details. + // + + // VisibleForTesting + static short float32to16(float f32) { + int i16; + int i32 = Float.floatToIntBits(f32); + short sign16 = (short) ((i32 >>> 16) & 0x8000); + i32 &= 0x7FFFFFFF; // remove sign + + if (i32 >= (E32BIAS + E16MAX + 1) << E32SHIFT) { + // float32 value is higher than float16 max value (max16 -> 2^15 * 2 -> 2^16) + // - if float32 value is higher than infinite (i.e. s32 > 0), then it is NaN and should also + // be NaN in float16 (0x7e00) + // - else, float16 value is forced to infinite (0x7c00) + i16 = i32 > E32MASK ? 0x7E00 : 0x7C00; + + } else if (i32 < (E32BIAS + E16MIN) << E32SHIFT) { + // float32 abs value is smaller than float16 min abs value (min16 = 2^-14), could also be 0 + // - apply magic number to align significand 10 bits at the bottom on the float and subtract + // bias + i16 = Float.floatToIntBits(Float.intBitsToFloat(i32) + MAGIC_32_16_FLOAT) - MAGIC_32_16; + + } else { + // float32 value can be rounded up to a normalized float16 value (i.e. exp32 = [113(-14), + // 142(15)]) + // - rebase exponent to float16 + // - round up significand to the 13nd bit if s16 is even, on the 12nd bit if it is odd + int round = 0xFFF + ((i32 >>> 13) & 0x1); + i16 = (i32 + ((E16BIAS - E32BIAS) << E32SHIFT) + round) >>> 13; + } + return (short) (i16 | sign16); + } + + // Visible for testing + static float float16to32(short i16) { + int i32 = (i16 & 0x7FFF) << (S32BITS - S16BITS); // remove sign and align in float32 + i32 += (E32BIAS - E16BIAS) << E32SHIFT; // rebase exponent to float32 + + // Handle float16 exponent special cases + switch (i16 & E16MASK) { + case E16MASK: + // float16 value is infinite or NaN + // - adjust float32 exponent one more time + i32 += (E32BIAS - E16BIAS) << E32SHIFT; + break; + case 0x0: + // float16 value is zero or subnormal + // - adjust float32 exponent + // - renormalize using magic number + i32 = Float.floatToIntBits(Float.intBitsToFloat(i32 + (1 << E32SHIFT)) - MAGIC_16_32_FLOAT); + break; + default: + break; + } + return Float.intBitsToFloat(i32 | ((i16 & 0x8000) << 16)); // reapply sign + } + + // float32 format + private static final int E32SHIFT = 23; // position of the exponent in float32 + private static final int E32MASK = 0xFF << E32SHIFT; // mask for float32 exponent (== Infinity) + private static final int E32BIAS = 127; // exponent bias for float32 + private static final int S32BITS = 23; // number of bits in float32 significand + + // float16 format + private static final int E16SHIFT = 10; // position of the exponent in float16 + private static final int E16MASK = 0x1F << E16SHIFT; // mask for float16 exponent (== Infinity) + private static final int E16BIAS = 15; // exponent bias for float16 + private static final int E16MAX = 15; // max value for float16 exponent + private static final int E16MIN = -14; // min value for float16 exponent + private static final int S16BITS = 10; // number of bits in float16 significand + + // magic numbers used when converting denormalized values + private static final int MAGIC_32_16 = + ((E32BIAS - E16BIAS) + (S32BITS - S16BITS) + 1) << E32SHIFT; + private static final float MAGIC_32_16_FLOAT = Float.intBitsToFloat(MAGIC_32_16); + private static final int MAGIC_16_32 = (E32BIAS - E16BIAS + 1) << E32SHIFT; + private static final float MAGIC_16_32_FLOAT = Float.intBitsToFloat(MAGIC_16_32); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/StringLayout.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/StringLayout.java new file mode 100644 index 00000000000..e77427bd4e8 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/layout/StringLayout.java @@ -0,0 +1,46 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import java.nio.charset.Charset; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayout; + +/** Data layout that converts a String to/from a sequence of bytes applying a given charset. */ +public final class StringLayout implements DataLayout, String> { + + public static StringLayout of(Charset charset) { + return new StringLayout(charset); + } + + @Override + public void writeObject(DataBuffer buffer, String value, long index) { + buffer.setObject(value.getBytes(charset), index); + } + + @Override + public String readObject(DataBuffer buffer, long index) { + return new String(buffer.getObject(index), charset); + } + + private StringLayout(Charset charset) { + this.charset = charset; + } + + private final Charset charset; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/ArrayDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/ArrayDataBuffer.java new file mode 100644 index 00000000000..1a26843713f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/ArrayDataBuffer.java @@ -0,0 +1,131 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.misc; + +import java.util.Arrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class ArrayDataBuffer extends AbstractDataBuffer { + + @Override + public long size() { + return length; + } + + @Override + public boolean isReadOnly() { + return readOnly; + } + + @Override + public T getObject(long index) { + Validator.getArgs(this, index); + return values[(int) index + offset]; + } + + @Override + public DataBuffer setObject(T value, long index) { + Validator.setArgs(this, index); + values[(int) index + offset] = value; + return this; + } + + @Override + public DataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor>() { + + @Override + public DataBuffer visit(Object[] array, int arrayOffset, int arrayLength) { + System.arraycopy(values, offset, array, arrayOffset, (int) size); + return ArrayDataBuffer.this; + } + + @Override + public DataBuffer fallback() { + for (int idx = 0; idx < size; ++idx) { + dst.setObject(values[idx + offset], idx); + } + return ArrayDataBuffer.this; + } + }); + } + + @Override + public DataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + return new ArrayDataBuffer<>(values, readOnly, offset + (int) index, (int) size); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(values, offset, length); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DataBuffer)) { + return false; + } + DataBuffer other = (DataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(Object[] array, int arrayOffset, int arrayLength) { + if (offset == 0 + && values.length == length + && arrayOffset == 0 + && array.length == arrayLength) { + return Arrays.deepEquals(array, values); + } + return slowEquals(other); + } + + @Override + public Boolean fallback() { + return slowEquals(other); + } + }); + } + + ArrayDataBuffer(T[] values, boolean readOnly) { + this(values, readOnly, 0, values.length); + } + + private ArrayDataBuffer(T[] values, boolean readOnly, int offset, int length) { + this.values = values; + this.readOnly = readOnly; + this.offset = offset; + this.length = length; + } + + private final T[] values; + private final boolean readOnly; + private final int offset; + private final int length; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/BitSetDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/BitSetDataBuffer.java new file mode 100644 index 00000000000..62b658f7cf4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/BitSetDataBuffer.java @@ -0,0 +1,185 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.misc; + +import java.util.BitSet; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class BitSetDataBuffer extends AbstractDataBuffer implements BooleanDataBuffer { + + @Override + public long size() { + return numBits; + } + + @Override + public boolean isReadOnly() { + return readOnly; + } + + @Override + public boolean getBoolean(long index) { + Validator.getArgs(this, index); + return bitSet.get((int) index + offset); + } + + @Override + public BooleanDataBuffer setBoolean(boolean value, long index) { + Validator.setArgs(this, index); + bitSet.set((int) index + offset, value); + return this; + } + + @Override + public BooleanDataBuffer read(boolean[] dst, int offset, int length) { + Validator.readArgs(this, dst.length, offset, length); + for (int i = this.offset, j = offset; i < this.offset + length; ++i, ++j) { + dst[j] = bitSet.get(i); + } + return this; + } + + @Override + public BooleanDataBuffer write(boolean[] src, int offset, int length) { + Validator.readArgs(this, src.length, offset, length); + for (int i = this.offset, j = offset; i < this.offset + length; ++i, ++j) { + bitSet.set(i, src[j]); + } + return this; + } + + @Override + public BooleanDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public BooleanDataBuffer visit(boolean[] array, int arrayOffset, int arrayLength) { + for (int idx = 0; idx < size; ++idx) { + array[idx + arrayOffset] = bitSet.get(idx + offset); + } + return BitSetDataBuffer.this; + } + + @Override + public BooleanDataBuffer visit(BitSet dstBitSet, int dstOffset, long numBits) { + for (int idx = 0; idx < size; ++idx) { + dstBitSet.set(idx + dstOffset, bitSet.get(idx + offset)); + } + return BitSetDataBuffer.this; + } + + @Override + public BooleanDataBuffer fallback() { + if (dst instanceof BooleanDataBuffer) { + BooleanDataBuffer booleanDst = (BooleanDataBuffer) dst; + for (int idx = 0; idx < size; ++idx) { + booleanDst.setBoolean(bitSet.get(idx + offset), idx); + } + } else { + for (int idx = 0; idx < size; ++idx) { + dst.setObject(bitSet.get(idx + offset), idx); + } + } + return BitSetDataBuffer.this; + } + }); + } + + @Override + public BooleanDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + return new BitSetDataBuffer(bitSet, size, readOnly, offset + (int) index); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(bitSet, offset, numBits); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BooleanDataBuffer)) { + return super.equals(obj); + } + BooleanDataBuffer other = (BooleanDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(boolean[] array, int arrayOffset, int length) { + for (int idx = 0; idx < size(); ++idx) { + if (array[idx + arrayOffset] != bitSet.get(idx + offset)) { + return false; + } + } + return true; + } + + @Override + public Boolean visit(BitSet otherBitSet, int otherOffset, long otherNumBits) { + if (offset == 0 && otherOffset == 0 && numBits == otherNumBits) { + return bitSet.equals(otherBitSet); + } + for (int idx = 0; idx < size(); ++idx) { + if (otherBitSet.get(idx + otherOffset) != bitSet.get(idx + offset)) { + return false; + } + } + return true; + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getBoolean(idx) != bitSet.get(idx + offset)) { + return false; + } + } + return true; + } + }); + } + + BitSetDataBuffer(BitSet bitSet, long numBits, boolean readOnly) { + this(bitSet, numBits, readOnly, 0); + } + + private BitSetDataBuffer(BitSet bitSet, long numBits, boolean readOnly, int offset) { + this.bitSet = bitSet; + this.numBits = numBits; + this.readOnly = readOnly; + this.offset = offset; + } + + private final BitSet bitSet; + private final long numBits; + private final boolean readOnly; + private final int offset; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/BooleanArrayDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/BooleanArrayDataBuffer.java new file mode 100644 index 00000000000..ac1715a7d27 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/BooleanArrayDataBuffer.java @@ -0,0 +1,180 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.misc; + +import java.util.Arrays; +import java.util.BitSet; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +class BooleanArrayDataBuffer extends AbstractDataBuffer implements BooleanDataBuffer { + + @Override + public long size() { + return length; + } + + @Override + public boolean isReadOnly() { + return readOnly; + } + + @Override + public boolean getBoolean(long index) { + Validator.getArgs(this, index); + return values[(int) index + offset]; + } + + @Override + public BooleanDataBuffer setBoolean(boolean value, long index) { + Validator.setArgs(this, index); + values[(int) index + offset] = value; + return this; + } + + @Override + public BooleanDataBuffer read(boolean[] dst, int offset, int length) { + System.arraycopy(values, this.offset, dst, offset, length); + return this; + } + + @Override + public BooleanDataBuffer write(boolean[] src, int offset, int length) { + System.arraycopy(src, offset, values, this.offset, length); + return null; + } + + @Override + public BooleanDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public BooleanDataBuffer visit(boolean[] array, int arrayOffset, int arrayLength) { + System.arraycopy(values, offset, array, arrayOffset, (int) size); + return BooleanArrayDataBuffer.this; + } + + @Override + public BooleanDataBuffer visit(BitSet bitSet, int bitSetOffset, long numBits) { + for (int idx = 0; idx < size; ++idx) { + bitSet.set(idx + bitSetOffset, values[idx + offset]); + } + return BooleanArrayDataBuffer.this; + } + + @Override + public BooleanDataBuffer fallback() { + if (dst instanceof BooleanDataBuffer) { + BooleanDataBuffer booleanDst = (BooleanDataBuffer) dst; + for (int idx = 0; idx < size; ++idx) { + booleanDst.setBoolean(values[idx + offset], idx); + } + } else { + for (int idx = 0; idx < size; ++idx) { + dst.setObject(values[idx + offset], idx); + } + } + return BooleanArrayDataBuffer.this; + } + }); + } + + @Override + public BooleanDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + return new BooleanArrayDataBuffer(values, readOnly, offset + (int) index, (int) size); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(values, offset, length); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BooleanDataBuffer)) { + return super.equals(obj); + } + BooleanDataBuffer other = (BooleanDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(boolean[] array, int arrayOffset, int arrayLength) { + if (offset == 0 + && values.length == length + && arrayOffset == 0 + && array.length == arrayLength) { + return Arrays.equals(array, values); + } + for (int idx = 0; idx < size(); ++idx) { + if (array[idx + arrayOffset] != values[idx + offset]) { + return false; + } + } + return true; + } + + @Override + public Boolean visit(BitSet bitSet, int bitSetOffset, long numBits) { + for (int idx = 0; idx < size(); ++idx) { + if (bitSet.get(idx + bitSetOffset) != values[idx + offset]) { + return false; + } + } + return true; + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getBoolean(idx) != values[idx + offset]) { + return false; + } + } + return true; + } + }); + } + + BooleanArrayDataBuffer(boolean[] values, boolean readOnly) { + this(values, readOnly, 0, values.length); + } + + private BooleanArrayDataBuffer(boolean[] values, boolean readOnly, int offset, int length) { + this.values = values; + this.readOnly = readOnly; + this.offset = offset; + this.length = length; + } + + private final boolean[] values; + private final boolean readOnly; + private final int offset; + private final int length; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/MiscDataBufferFactory.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/MiscDataBufferFactory.java new file mode 100644 index 00000000000..73bbaa2d3d3 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/misc/MiscDataBufferFactory.java @@ -0,0 +1,38 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.misc; + +import java.util.BitSet; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; + +/** Factory of miscellaneous data buffers */ +public class MiscDataBufferFactory { + + public static BooleanDataBuffer create(BitSet bitSet, long numBits, boolean readOnly) { + return new BitSetDataBuffer(bitSet, numBits, readOnly); + } + + public static BooleanDataBuffer create(boolean[] array, boolean readOnly) { + return new BooleanArrayDataBuffer(array, readOnly); + } + + public static DataBuffer create(T[] array, boolean readOnly) { + return new ArrayDataBuffer<>(array, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/AbstractNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/AbstractNioDataBuffer.java new file mode 100644 index 00000000000..3709b97008c --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/AbstractNioDataBuffer.java @@ -0,0 +1,41 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.Buffer; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; + +/** + * Base class for all JDK-based data buffers. + * + * @param type of elements (or values) stored in this buffer + */ +abstract class AbstractNioDataBuffer extends AbstractDataBuffer { + + @Override + public long size() { + return buf().capacity(); + } + + @Override + public boolean isReadOnly() { + return buf().isReadOnly(); + } + + abstract Buffer buf(); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/ByteNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/ByteNioDataBuffer.java new file mode 100644 index 00000000000..9263286ddb4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/ByteNioDataBuffer.java @@ -0,0 +1,184 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.ByteBuffer; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.DataLayouts; +import org.tensorflow.ndarray.impl.buffer.Validator; +import org.tensorflow.ndarray.impl.buffer.adapter.DataBufferAdapterFactory; + +/** A buffer of bytes using a JDK {@link ByteBuffer} for storage. */ +final class ByteNioDataBuffer extends AbstractNioDataBuffer implements ByteDataBuffer { + + @Override + public byte getByte(long index) { + return buf.get((int) index); + } + + @Override + public ByteDataBuffer setByte(byte value, long index) { + buf.put((int) index, value); + return this; + } + + @Override + public ByteDataBuffer read(byte[] dst, int offset, int length) { + buf.duplicate().get(dst, offset, length); + return this; + } + + @Override + public ByteDataBuffer write(byte[] src, int offset, int length) { + buf.duplicate().put(src, offset, length); + return this; + } + + @Override + public ByteDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public ByteDataBuffer visit(ByteBuffer buffer) { + buffer.duplicate().put((ByteBuffer) buf.duplicate().limit((int) size)); + return ByteNioDataBuffer.this; + } + + @Override + public ByteDataBuffer fallback() { + if (dst instanceof ByteDataBuffer) { + ByteDataBuffer byteDst = (ByteDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + byteDst.setByte(getByte(idx), idx); + } + return ByteNioDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public IntDataBuffer asInts() { + return new IntNioDataBuffer(buf.asIntBuffer()); + } + + @Override + public ShortDataBuffer asShorts() { + return new ShortNioDataBuffer(buf.asShortBuffer()); + } + + @Override + public LongDataBuffer asLongs() { + return new LongNioDataBuffer(buf.asLongBuffer()); + } + + @Override + public FloatDataBuffer asFloats() { + return new FloatNioDataBuffer(buf.asFloatBuffer()); + } + + @Override + public DoubleDataBuffer asDoubles() { + return new DoubleNioDataBuffer(buf.asDoubleBuffer()); + } + + @Override + public BooleanDataBuffer asBooleans() { + return DataBufferAdapterFactory.create(this, DataLayouts.BOOL); + } + + @Override + public ByteDataBuffer offset(long index) { + Validator.offsetArgs(this, index); + return new ByteNioDataBuffer(((ByteBuffer) buf.duplicate().position((int) index)).slice()); + } + + @Override + public ByteDataBuffer narrow(long size) { + Validator.narrowArgs(this, size); + return new ByteNioDataBuffer(((ByteBuffer) buf.duplicate().limit((int) size)).slice()); + } + + @Override + public ByteDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + ByteBuffer sliceBuf = buf.duplicate(); + sliceBuf.position((int) index); + sliceBuf.limit((int) index + (int) size); + return new ByteNioDataBuffer(sliceBuf.slice()); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(buf); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ByteDataBuffer)) { + return super.equals(obj); + } + ByteDataBuffer other = (ByteDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(ByteBuffer buffer) { + return buf.equals(buffer); + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getByte(idx) != getByte(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + ByteBuffer buf() { + return buf; + } + + ByteNioDataBuffer(ByteBuffer buf) { + this.buf = buf; + } + + private ByteBuffer buf; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/DoubleNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/DoubleNioDataBuffer.java new file mode 100644 index 00000000000..87d16da292d --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/DoubleNioDataBuffer.java @@ -0,0 +1,146 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.DoubleBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +/** A buffer of bytes using a JDK {@link DoubleBuffer} for storage. */ +final class DoubleNioDataBuffer extends AbstractNioDataBuffer implements DoubleDataBuffer { + + @Override + public double getDouble(long index) { + return buf.get((int) index); + } + + @Override + public DoubleDataBuffer setDouble(double value, long index) { + buf.put((int) index, value); + return this; + } + + @Override + public DoubleDataBuffer read(double[] dst, int offset, int length) { + buf.duplicate().get(dst, offset, length); + return this; + } + + @Override + public DoubleDataBuffer write(double[] src, int offset, int length) { + buf.duplicate().put(src, offset, length); + return this; + } + + @Override + public DoubleDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public DoubleDataBuffer visit(DoubleBuffer buffer) { + buffer.duplicate().put((DoubleBuffer) buf.duplicate().limit((int) size)); + return DoubleNioDataBuffer.this; + } + + @Override + public DoubleDataBuffer fallback() { + if (dst instanceof DoubleDataBuffer) { + DoubleDataBuffer doubleDst = (DoubleDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + doubleDst.setDouble(getDouble(idx), idx); + } + return DoubleNioDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public DoubleDataBuffer offset(long index) { + Validator.offsetArgs(this, index); + return new DoubleNioDataBuffer(((DoubleBuffer) buf.duplicate().position((int) index)).slice()); + } + + @Override + public DoubleDataBuffer narrow(long size) { + Validator.narrowArgs(this, size); + return new DoubleNioDataBuffer(((DoubleBuffer) buf.duplicate().limit((int) size)).slice()); + } + + @Override + public DoubleDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + DoubleBuffer sliceBuf = buf.duplicate(); + sliceBuf.position((int) index); + sliceBuf.limit((int) index + (int) size); + return new DoubleNioDataBuffer(sliceBuf.slice()); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(buf); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DoubleDataBuffer)) { + return super.equals(obj); + } + DoubleDataBuffer other = (DoubleDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(DoubleBuffer buffer) { + return buf.equals(buffer); + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getDouble(idx) != getDouble(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + DoubleBuffer buf() { + return buf; + } + + DoubleNioDataBuffer(DoubleBuffer buf) { + this.buf = buf; + } + + private DoubleBuffer buf; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/FloatNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/FloatNioDataBuffer.java new file mode 100644 index 00000000000..8fc3d2681e6 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/FloatNioDataBuffer.java @@ -0,0 +1,146 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.FloatBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +/** A buffer of bytes using a JDK {@link FloatBuffer} for storage. */ +final class FloatNioDataBuffer extends AbstractNioDataBuffer implements FloatDataBuffer { + + @Override + public float getFloat(long index) { + return buf.get((int) index); + } + + @Override + public FloatDataBuffer setFloat(float value, long index) { + buf.put((int) index, value); + return this; + } + + @Override + public FloatDataBuffer read(float[] dst, int offset, int length) { + buf.duplicate().get(dst, offset, length); + return this; + } + + @Override + public FloatDataBuffer write(float[] src, int offset, int length) { + buf.duplicate().put(src, offset, length); + return this; + } + + @Override + public FloatDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public FloatDataBuffer visit(FloatBuffer buffer) { + buffer.duplicate().put((FloatBuffer) buf.duplicate().limit((int) size)); + return FloatNioDataBuffer.this; + } + + @Override + public FloatDataBuffer fallback() { + if (dst instanceof FloatDataBuffer) { + FloatDataBuffer floatDst = (FloatDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + floatDst.setFloat(getFloat(idx), idx); + } + return FloatNioDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public FloatDataBuffer offset(long index) { + Validator.offsetArgs(this, index); + return new FloatNioDataBuffer(((FloatBuffer) buf.duplicate().position((int) index)).slice()); + } + + @Override + public FloatDataBuffer narrow(long size) { + Validator.narrowArgs(this, size); + return new FloatNioDataBuffer(((FloatBuffer) buf.duplicate().limit((int) size)).slice()); + } + + @Override + public FloatDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + FloatBuffer sliceBuf = buf.duplicate(); + sliceBuf.position((int) index); + sliceBuf.limit((int) index + (int) size); + return new FloatNioDataBuffer(sliceBuf.slice()); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(buf); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof FloatDataBuffer)) { + return super.equals(obj); + } + FloatDataBuffer other = (FloatDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(FloatBuffer buffer) { + return buf.equals(buffer); + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getFloat(idx) != getFloat(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + FloatBuffer buf() { + return buf; + } + + FloatNioDataBuffer(FloatBuffer buf) { + this.buf = buf; + } + + private FloatBuffer buf; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/IntNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/IntNioDataBuffer.java new file mode 100644 index 00000000000..eda8c8f61b5 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/IntNioDataBuffer.java @@ -0,0 +1,146 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.IntBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +/** A buffer of bytes using a JDK {@link IntBuffer} for storage. */ +final class IntNioDataBuffer extends AbstractNioDataBuffer implements IntDataBuffer { + + @Override + public int getInt(long index) { + return buf.get((int) index); + } + + @Override + public IntDataBuffer setInt(int value, long index) { + buf.put((int) index, value); + return this; + } + + @Override + public IntDataBuffer read(int[] dst, int offset, int length) { + buf.duplicate().get(dst, offset, length); + return this; + } + + @Override + public IntDataBuffer write(int[] src, int offset, int length) { + buf.duplicate().put(src, offset, length); + return this; + } + + @Override + public IntDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public IntDataBuffer visit(IntBuffer buffer) { + buffer.duplicate().put((IntBuffer) buf.duplicate().limit((int) size)); + return IntNioDataBuffer.this; + } + + @Override + public IntDataBuffer fallback() { + if (dst instanceof IntDataBuffer) { + IntDataBuffer intDst = (IntDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + intDst.setInt(getInt(idx), idx); + } + return IntNioDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public IntDataBuffer offset(long index) { + Validator.offsetArgs(this, index); + return new IntNioDataBuffer(((IntBuffer) buf.duplicate().position((int) index)).slice()); + } + + @Override + public IntDataBuffer narrow(long size) { + Validator.narrowArgs(this, size); + return new IntNioDataBuffer(((IntBuffer) buf.duplicate().limit((int) size)).slice()); + } + + @Override + public IntDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + IntBuffer sliceBuf = buf.duplicate(); + sliceBuf.position((int) index); + sliceBuf.limit((int) index + (int) size); + return new IntNioDataBuffer(sliceBuf.slice()); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(buf); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof IntDataBuffer)) { + return super.equals(obj); + } + IntDataBuffer other = (IntDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(IntBuffer buffer) { + return buf.equals(buffer); + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getInt(idx) != getInt(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + IntBuffer buf() { + return buf; + } + + IntNioDataBuffer(IntBuffer buf) { + this.buf = buf; + } + + private IntBuffer buf; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/LongNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/LongNioDataBuffer.java new file mode 100644 index 00000000000..ceffbeb216e --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/LongNioDataBuffer.java @@ -0,0 +1,146 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.LongBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +/** A buffer of bytes using a JDK {@link LongBuffer} for storage. */ +final class LongNioDataBuffer extends AbstractNioDataBuffer implements LongDataBuffer { + + @Override + public long getLong(long index) { + return buf.get((int) index); + } + + @Override + public LongDataBuffer setLong(long value, long index) { + buf.put((int) index, value); + return this; + } + + @Override + public LongDataBuffer read(long[] dst, int offset, int length) { + buf.duplicate().get(dst, offset, length); + return this; + } + + @Override + public LongDataBuffer write(long[] src, int offset, int length) { + buf.duplicate().put(src, offset, length); + return this; + } + + @Override + public LongDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public LongDataBuffer visit(LongBuffer buffer) { + buffer.duplicate().put((LongBuffer) buf.duplicate().limit((int) size)); + return LongNioDataBuffer.this; + } + + @Override + public LongDataBuffer fallback() { + if (dst instanceof LongDataBuffer) { + LongDataBuffer longDst = (LongDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + longDst.setLong(getLong(idx), idx); + } + return LongNioDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public LongDataBuffer offset(long index) { + Validator.offsetArgs(this, index); + return new LongNioDataBuffer(((LongBuffer) buf.duplicate().position((int) index)).slice()); + } + + @Override + public LongDataBuffer narrow(long size) { + Validator.narrowArgs(this, size); + return new LongNioDataBuffer(((LongBuffer) buf.duplicate().limit((int) size)).slice()); + } + + @Override + public LongDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + LongBuffer sliceBuf = buf.duplicate(); + sliceBuf.position((int) index); + sliceBuf.limit((int) index + (int) size); + return new LongNioDataBuffer(sliceBuf.slice()); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(buf); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof LongDataBuffer)) { + return super.equals(obj); + } + LongDataBuffer other = (LongDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(LongBuffer buffer) { + return buf.equals(buffer); + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getLong(idx) != getLong(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + LongBuffer buf() { + return buf; + } + + LongNioDataBuffer(LongBuffer buf) { + this.buf = buf; + } + + private LongBuffer buf; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/NioDataBufferFactory.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/NioDataBufferFactory.java new file mode 100644 index 00000000000..e26b2a702ff --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/NioDataBufferFactory.java @@ -0,0 +1,59 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; + +/** Factory of JDK NIO-based data buffers */ +public class NioDataBufferFactory { + + public static ByteDataBuffer create(ByteBuffer buffer) { + return new ByteNioDataBuffer(buffer); + } + + public static DoubleDataBuffer create(DoubleBuffer buffer) { + return new DoubleNioDataBuffer(buffer); + } + + public static FloatDataBuffer create(FloatBuffer buffer) { + return new FloatNioDataBuffer(buffer); + } + + public static IntDataBuffer create(IntBuffer buffer) { + return new IntNioDataBuffer(buffer); + } + + public static LongDataBuffer create(LongBuffer buffer) { + return new LongNioDataBuffer(buffer); + } + + public static ShortDataBuffer create(ShortBuffer buffer) { + return new ShortNioDataBuffer(buffer); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/ShortNioDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/ShortNioDataBuffer.java new file mode 100644 index 00000000000..e6535275f07 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/nio/ShortNioDataBuffer.java @@ -0,0 +1,146 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.ShortBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +/** A buffer of bytes using a JDK {@link ShortBuffer} for storage. */ +final class ShortNioDataBuffer extends AbstractNioDataBuffer implements ShortDataBuffer { + + @Override + public short getShort(long index) { + return buf.get((int) index); + } + + @Override + public ShortDataBuffer setShort(short value, long index) { + buf.put((int) index, value); + return this; + } + + @Override + public ShortDataBuffer read(short[] dst, int offset, int length) { + buf.duplicate().get(dst, offset, length); + return this; + } + + @Override + public ShortDataBuffer write(short[] src, int offset, int length) { + buf.duplicate().put(src, offset, length); + return this; + } + + @Override + public ShortDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public ShortDataBuffer visit(ShortBuffer buffer) { + buffer.duplicate().put((ShortBuffer) buf.duplicate().limit((int) size)); + return ShortNioDataBuffer.this; + } + + @Override + public ShortDataBuffer fallback() { + if (dst instanceof ShortDataBuffer) { + ShortDataBuffer shortDst = (ShortDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + shortDst.setShort(getShort(idx), idx); + } + return ShortNioDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public ShortDataBuffer offset(long index) { + Validator.offsetArgs(this, index); + return new ShortNioDataBuffer(((ShortBuffer) buf.duplicate().position((int) index)).slice()); + } + + @Override + public ShortDataBuffer narrow(long size) { + Validator.narrowArgs(this, size); + return new ShortNioDataBuffer(((ShortBuffer) buf.duplicate().limit((int) size)).slice()); + } + + @Override + public ShortDataBuffer slice(long index, long size) { + Validator.sliceArgs(this, index, size); + ShortBuffer sliceBuf = buf.duplicate(); + sliceBuf.position((int) index); + sliceBuf.limit((int) index + (int) size); + return new ShortNioDataBuffer(sliceBuf.slice()); + } + + @Override + public R accept(DataStorageVisitor visitor) { + return visitor.visit(buf); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ShortDataBuffer)) { + return super.equals(obj); + } + ShortDataBuffer other = (ShortDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(ShortBuffer buffer) { + return buf.equals(buffer); + } + + @Override + public Boolean fallback() { + for (int idx = 0; idx < size(); ++idx) { + if (other.getShort(idx) != getShort(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + ShortBuffer buf() { + return buf; + } + + ShortNioDataBuffer(ShortBuffer buf) { + this.buf = buf; + } + + private ShortBuffer buf; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/AbstractRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/AbstractRawDataBuffer.java new file mode 100644 index 00000000000..0ce4da0f602 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/AbstractRawDataBuffer.java @@ -0,0 +1,94 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBufferWindow; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +@SuppressWarnings("unchecked") +abstract class AbstractRawDataBuffer> extends AbstractDataBuffer { + + public long size() { + return memory.size(); + } + + @Override + public boolean isReadOnly() { + return readOnly; + } + + public B read(Object dst, int dstLength) { + Validator.readArgs(this, dstLength, 0, dstLength); + memory.copyTo(UnsafeMemoryHandle.fromArray(dst, dstLength), dstLength); + return (B) this; + } + + public B read(Object dst, int dstLength, int offset, int length) { + Validator.readArgs(this, dstLength, offset, length); + memory.copyTo(UnsafeMemoryHandle.fromArray(dst, dstLength).offset(offset), length); + return (B) this; + } + + public B write(Object src, int srcLength) { + Validator.writeArgs(this, srcLength, 0, srcLength); + UnsafeMemoryHandle.fromArray(src, srcLength).copyTo(memory, srcLength); + return (B) this; + } + + public B write(Object src, int srcLength, int offset, int length) { + Validator.writeArgs(this, srcLength, offset, length); + UnsafeMemoryHandle.fromArray(src, srcLength).offset(offset).copyTo(memory, length); + return (B) this; + } + + @Override + public B copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + if (dst instanceof AbstractRawDataBuffer) { + AbstractRawDataBuffer unsafeDst = (AbstractRawDataBuffer) dst; + memory.copyTo(unsafeDst.memory, size); + } else { + super.copyTo(dst, size); + } + return (B) this; + } + + @Override + public B slice(long index, long size) { + Validator.sliceArgs(this, index, size); + return instantiate(memory.slice(index, size)); + } + + @Override + public DataBufferWindow window(long size) { + B windowBuffer = instantiate(memory.slice(0, size)); + return new RawDataBufferWindow<>((AbstractRawDataBuffer) windowBuffer, size()); + } + + protected final UnsafeMemoryHandle memory; + protected final boolean readOnly; + + protected abstract B instantiate(UnsafeMemoryHandle region); + + AbstractRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + this.memory = memory; + this.readOnly = readOnly; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/BooleanRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/BooleanRawDataBuffer.java new file mode 100644 index 00000000000..ecde38f72d8 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/BooleanRawDataBuffer.java @@ -0,0 +1,149 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.util.Arrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class BooleanRawDataBuffer extends AbstractRawDataBuffer + implements BooleanDataBuffer { + + @Override + public boolean getBoolean(long index) { + Validator.getArgs(this, index); + return memory.getBoolean(index); + } + + @Override + public BooleanDataBuffer setBoolean(boolean value, long index) { + Validator.setArgs(this, index); + memory.setBoolean(value, index); + return this; + } + + @Override + public BooleanDataBuffer read(boolean[] dst) { + return read(dst, dst.length); + } + + @Override + public BooleanDataBuffer read(boolean[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public BooleanDataBuffer write(boolean[] src) { + return write(src, src.length); + } + + @Override + public BooleanDataBuffer write(boolean[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public BooleanDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public BooleanDataBuffer visit(boolean[] array, int offset, int length) { + memory.copyTo(UnsafeMemoryHandle.fromArray(array, offset, length), size); + return BooleanRawDataBuffer.this; + } + + @Override + public BooleanDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return BooleanRawDataBuffer.this; + } + + @Override + public BooleanDataBuffer fallback() { + if (dst instanceof BooleanDataBuffer) { + BooleanDataBuffer booleanDst = (BooleanDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + booleanDst.setBoolean(getBoolean(idx), idx); + } + return BooleanRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit( + (boolean[]) memory.object, memory.arrayOffset(boolean[].class), (int) memory.size()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BooleanDataBuffer)) { + return super.equals(obj); + } + BooleanDataBuffer other = (BooleanDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(boolean[] array, int offset, int length) { + if (memory.isArray() && memory.arrayOffset(boolean[].class) == 0 && offset == 0) { + boolean[] thisArray = memory.array(); + if (thisArray.length == array.length) { + return Arrays.equals(thisArray, array); + } + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getBoolean(idx) != getBoolean(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected BooleanDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new BooleanRawDataBuffer(memory, readOnly); + } + + BooleanRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/ByteRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/ByteRawDataBuffer.java new file mode 100644 index 00000000000..5916064a516 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/ByteRawDataBuffer.java @@ -0,0 +1,189 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.ByteBuffer; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class ByteRawDataBuffer extends AbstractRawDataBuffer + implements ByteDataBuffer { + + @Override + public byte getByte(long index) { + Validator.getArgs(this, index); + return memory.getByte(index); + } + + @Override + public ByteDataBuffer setByte(byte value, long index) { + Validator.setArgs(this, index); + memory.setByte(value, index); + return this; + } + + @Override + public ByteDataBuffer read(byte[] dst) { + return read(dst, dst.length); + } + + @Override + public ByteDataBuffer read(byte[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public ByteDataBuffer write(byte[] src) { + return write(src, src.length); + } + + @Override + public ByteDataBuffer write(byte[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public ByteDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public ByteDataBuffer visit(ByteBuffer buffer) { + if (buffer.hasArray()) { + memory.copyTo( + UnsafeMemoryHandle.fromArray(buffer.array(), buffer.position(), buffer.limit()), + size); + } else if (memory.isArray()) { + buffer.put(memory.toArrayByteBuffer()); + } else { + slowCopyTo(dst, size); + } + return ByteRawDataBuffer.this; + } + + @Override + public ByteDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return ByteRawDataBuffer.this; + } + + @Override + public ByteDataBuffer fallback() { + if (dst instanceof ByteDataBuffer) { + ByteDataBuffer byteDst = (ByteDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + byteDst.setByte(getByte(idx), idx); + } + return ByteRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public IntDataBuffer asInts() { + return new IntRawDataBuffer(memory.rescale(Integer.BYTES), readOnly); + } + + @Override + public ShortDataBuffer asShorts() { + return new ShortRawDataBuffer(memory.rescale(Short.BYTES), readOnly); + } + + @Override + public LongDataBuffer asLongs() { + return new LongRawDataBuffer(memory.rescale(Long.BYTES), readOnly); + } + + @Override + public FloatDataBuffer asFloats() { + return new FloatRawDataBuffer(memory.rescale(Float.BYTES), readOnly); + } + + @Override + public DoubleDataBuffer asDoubles() { + return new DoubleRawDataBuffer(memory.rescale(Double.BYTES), readOnly); + } + + @Override + public BooleanDataBuffer asBooleans() { + return new BooleanRawDataBuffer(memory.rescale(Byte.BYTES), readOnly); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit(memory.toArrayByteBuffer()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ByteDataBuffer)) { + return super.equals(obj); + } + ByteDataBuffer other = (ByteDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(ByteBuffer buffer) { + if (memory.isArray()) { + return buffer.equals(memory.toArrayByteBuffer()); + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getByte(idx) != getByte(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected ByteDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new ByteRawDataBuffer(memory, readOnly); + } + + ByteRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/DoubleRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/DoubleRawDataBuffer.java new file mode 100644 index 00000000000..c6259f7aa61 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/DoubleRawDataBuffer.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.DoubleBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class DoubleRawDataBuffer extends AbstractRawDataBuffer + implements DoubleDataBuffer { + + @Override + public double getDouble(long index) { + Validator.getArgs(this, index); + return memory.getDouble(index); + } + + @Override + public DoubleDataBuffer setDouble(double value, long index) { + Validator.setArgs(this, index); + memory.setDouble(value, index); + return this; + } + + @Override + public DoubleDataBuffer read(double[] dst) { + return read(dst, dst.length); + } + + @Override + public DoubleDataBuffer read(double[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public DoubleDataBuffer write(double[] src) { + return write(src, src.length); + } + + @Override + public DoubleDataBuffer write(double[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public DoubleDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public DoubleDataBuffer visit(DoubleBuffer buffer) { + if (buffer.hasArray()) { + memory.copyTo( + UnsafeMemoryHandle.fromArray(buffer.array(), buffer.position(), buffer.limit()), + size); + } else if (memory.isArray()) { + buffer.put(memory.toArrayDoubleBuffer()); + } else { + slowCopyTo(dst, size); + } + return DoubleRawDataBuffer.this; + } + + @Override + public DoubleDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return DoubleRawDataBuffer.this; + } + + @Override + public DoubleDataBuffer fallback() { + if (dst instanceof DoubleDataBuffer) { + DoubleDataBuffer doubleDst = (DoubleDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + doubleDst.setDouble(getDouble(idx), idx); + } + return DoubleRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit(memory.toArrayDoubleBuffer()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DoubleDataBuffer)) { + return super.equals(obj); + } + DoubleDataBuffer other = (DoubleDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(DoubleBuffer buffer) { + if (memory.isArray()) { + return buffer.equals(memory.toArrayDoubleBuffer()); + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getDouble(idx) != getDouble(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected DoubleDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new DoubleRawDataBuffer(memory, readOnly); + } + + DoubleRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/FloatRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/FloatRawDataBuffer.java new file mode 100644 index 00000000000..2a9d12f47b3 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/FloatRawDataBuffer.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.FloatBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class FloatRawDataBuffer extends AbstractRawDataBuffer + implements FloatDataBuffer { + + @Override + public float getFloat(long index) { + Validator.getArgs(this, index); + return memory.getFloat(index); + } + + @Override + public FloatDataBuffer setFloat(float value, long index) { + Validator.setArgs(this, index); + memory.setFloat(value, index); + return this; + } + + @Override + public FloatDataBuffer read(float[] dst) { + return read(dst, dst.length); + } + + @Override + public FloatDataBuffer read(float[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public FloatDataBuffer write(float[] src) { + return write(src, src.length); + } + + @Override + public FloatDataBuffer write(float[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public FloatDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public FloatDataBuffer visit(FloatBuffer buffer) { + if (buffer.hasArray()) { + memory.copyTo( + UnsafeMemoryHandle.fromArray(buffer.array(), buffer.position(), buffer.limit()), + size); + } else if (memory.isArray()) { + buffer.put(memory.toArrayFloatBuffer()); + } else { + slowCopyTo(dst, size); + } + return FloatRawDataBuffer.this; + } + + @Override + public FloatDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return FloatRawDataBuffer.this; + } + + @Override + public FloatDataBuffer fallback() { + if (dst instanceof FloatDataBuffer) { + FloatDataBuffer floatDst = (FloatDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + floatDst.setFloat(getFloat(idx), idx); + } + return FloatRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit(memory.toArrayFloatBuffer()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof FloatDataBuffer)) { + return super.equals(obj); + } + FloatDataBuffer other = (FloatDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(FloatBuffer buffer) { + if (memory.isArray()) { + return buffer.equals(memory.toArrayFloatBuffer()); + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getFloat(idx) != getFloat(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected FloatDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new FloatRawDataBuffer(memory, readOnly); + } + + FloatRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/IntRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/IntRawDataBuffer.java new file mode 100644 index 00000000000..891c4c6f3aa --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/IntRawDataBuffer.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.IntBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class IntRawDataBuffer extends AbstractRawDataBuffer + implements IntDataBuffer { + + @Override + public int getInt(long index) { + Validator.getArgs(this, index); + return memory.getInt(index); + } + + @Override + public IntDataBuffer setInt(int value, long index) { + Validator.setArgs(this, index); + memory.setInt(value, index); + return this; + } + + @Override + public IntDataBuffer read(int[] dst) { + return read(dst, dst.length); + } + + @Override + public IntDataBuffer read(int[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public IntDataBuffer write(int[] src) { + return write(src, src.length); + } + + @Override + public IntDataBuffer write(int[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public IntDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public IntDataBuffer visit(IntBuffer buffer) { + if (buffer.hasArray()) { + memory.copyTo( + UnsafeMemoryHandle.fromArray(buffer.array(), buffer.position(), buffer.limit()), + size); + } else if (memory.isArray()) { + buffer.put(memory.toArrayIntBuffer()); + } else { + slowCopyTo(dst, size); + } + return IntRawDataBuffer.this; + } + + @Override + public IntDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return IntRawDataBuffer.this; + } + + @Override + public IntDataBuffer fallback() { + if (dst instanceof IntDataBuffer) { + IntDataBuffer intDst = (IntDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + intDst.setInt(getInt(idx), idx); + } + return IntRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit(memory.toArrayIntBuffer()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof IntDataBuffer)) { + return super.equals(obj); + } + IntDataBuffer other = (IntDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(IntBuffer buffer) { + if (memory.isArray()) { + return buffer.equals(memory.toArrayIntBuffer()); + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getInt(idx) != getInt(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected IntDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new IntRawDataBuffer(memory, readOnly); + } + + IntRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/LongRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/LongRawDataBuffer.java new file mode 100644 index 00000000000..6fb918ae093 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/LongRawDataBuffer.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.LongBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class LongRawDataBuffer extends AbstractRawDataBuffer + implements LongDataBuffer { + + @Override + public long getLong(long index) { + Validator.getArgs(this, index); + return memory.getLong(index); + } + + @Override + public LongDataBuffer setLong(long value, long index) { + Validator.setArgs(this, index); + memory.setLong(value, index); + return this; + } + + @Override + public LongDataBuffer read(long[] dst) { + return read(dst, dst.length); + } + + @Override + public LongDataBuffer read(long[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public LongDataBuffer write(long[] src) { + return write(src, src.length); + } + + @Override + public LongDataBuffer write(long[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public LongDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public LongDataBuffer visit(LongBuffer buffer) { + if (buffer.hasArray()) { + memory.copyTo( + UnsafeMemoryHandle.fromArray(buffer.array(), buffer.position(), buffer.limit()), + size); + } else if (memory.isArray()) { + buffer.put(memory.toArrayLongBuffer()); + } else { + slowCopyTo(dst, size); + } + return LongRawDataBuffer.this; + } + + @Override + public LongDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return LongRawDataBuffer.this; + } + + @Override + public LongDataBuffer fallback() { + if (dst instanceof LongDataBuffer) { + LongDataBuffer longDst = (LongDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + longDst.setLong(getLong(idx), idx); + } + return LongRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit(memory.toArrayLongBuffer()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof LongDataBuffer)) { + return super.equals(obj); + } + LongDataBuffer other = (LongDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(LongBuffer buffer) { + if (memory.isArray()) { + return buffer.equals(memory.toArrayLongBuffer()); + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getLong(idx) != getLong(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected LongDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new LongRawDataBuffer(memory, readOnly); + } + + LongRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/RawDataBufferFactory.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/RawDataBufferFactory.java new file mode 100644 index 00000000000..b185eefa6b5 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/RawDataBufferFactory.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +/** Factory of raw data buffers */ +public class RawDataBufferFactory { + + public static boolean canBeUsed() { + return UnsafeReference.isAvailable(); + } + + public static BooleanDataBuffer create(boolean[] array, boolean readOnly) { + return new BooleanRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + public static ByteDataBuffer create(byte[] array, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + return new ByteRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + public static DoubleDataBuffer create(double[] array, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + return new DoubleRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + public static FloatDataBuffer create(float[] array, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + return new FloatRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + public static IntDataBuffer create(int[] array, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + return new IntRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + public static LongDataBuffer create(long[] array, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + return new LongRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + public static ShortDataBuffer create(short[] array, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + return new ShortRawDataBuffer(UnsafeMemoryHandle.fromArray(array, array.length), readOnly); + } + + protected static BooleanDataBuffer mapNativeBooleans(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new BooleanRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Byte.BYTES), readOnly); + } + + protected static ByteDataBuffer mapNativeBytes(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new ByteRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Byte.BYTES), readOnly); + } + + protected static DoubleDataBuffer mapNativeDoubles(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new DoubleRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Double.BYTES), readOnly); + } + + protected static FloatDataBuffer mapNativeFloats(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new FloatRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Float.BYTES), readOnly); + } + + protected static IntDataBuffer mapNativeInts(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new IntRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Integer.BYTES), readOnly); + } + + protected static LongDataBuffer mapNativeLongs(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new LongRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Long.BYTES), readOnly); + } + + protected static ShortDataBuffer mapNativeShorts(long address, long size, boolean readOnly) { + if (!canBeUsed()) { + throw new IllegalStateException("Raw data buffers are not available"); + } + Validator.createArgs(size, MAX_64BITS); + return new ShortRawDataBuffer( + UnsafeMemoryHandle.fromAddress(address, size, Short.BYTES), readOnly); + } + + /* + * The maximum size for a buffer of this type, i.e. the maximum number of bytes it can store. + *

+ * As the maximum size may vary depending on the JVM implementation and on the platform, this + * property returns a value that is safe for most of them. + */ + static long MAX_32BITS = Integer.MAX_VALUE - 10; + static long MAX_64BITS = Long.MAX_VALUE - 10; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/RawDataBufferWindow.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/RawDataBufferWindow.java new file mode 100644 index 00000000000..0b5aa464ea3 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/RawDataBufferWindow.java @@ -0,0 +1,19 @@ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.buffer.AbstractDataBufferWindow; + +final class RawDataBufferWindow> extends AbstractDataBufferWindow { + + @Override + public void offset(long offset) { + windowMemory.rebase(offset); + } + + > RawDataBufferWindow(R windowBuffer, long bufferLimit) { + super((B) windowBuffer, bufferLimit); + this.windowMemory = windowBuffer.memory; + } + + private final UnsafeMemoryHandle windowMemory; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/ShortRawDataBuffer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/ShortRawDataBuffer.java new file mode 100644 index 00000000000..df3320ff4bd --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/ShortRawDataBuffer.java @@ -0,0 +1,153 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.ShortBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataStorageVisitor; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.Validator; + +final class ShortRawDataBuffer extends AbstractRawDataBuffer + implements ShortDataBuffer { + + @Override + public short getShort(long index) { + Validator.getArgs(this, index); + return memory.getShort(index); + } + + @Override + public ShortDataBuffer setShort(short value, long index) { + Validator.setArgs(this, index); + memory.setShort(value, index); + return this; + } + + @Override + public ShortDataBuffer read(short[] dst) { + return read(dst, dst.length); + } + + @Override + public ShortDataBuffer read(short[] dst, int offset, int length) { + return read(dst, dst.length, offset, length); + } + + @Override + public ShortDataBuffer write(short[] src) { + return write(src, src.length); + } + + @Override + public ShortDataBuffer write(short[] src, int offset, int length) { + return write(src, src.length, offset, length); + } + + @Override + public ShortDataBuffer copyTo(DataBuffer dst, long size) { + Validator.copyToArgs(this, dst, size); + return dst.accept( + new DataStorageVisitor() { + + @Override + public ShortDataBuffer visit(ShortBuffer buffer) { + if (buffer.hasArray()) { + memory.copyTo( + UnsafeMemoryHandle.fromArray(buffer.array(), buffer.position(), buffer.limit()), + size); + } else if (memory.isArray()) { + buffer.put(memory.toArrayShortBuffer()); + } else { + slowCopyTo(dst, size); + } + return ShortRawDataBuffer.this; + } + + @Override + public ShortDataBuffer visit(long address, long length, long scale) { + memory.copyTo(UnsafeMemoryHandle.fromAddress(address, length, scale), size); + return ShortRawDataBuffer.this; + } + + @Override + public ShortDataBuffer fallback() { + if (dst instanceof ShortDataBuffer) { + ShortDataBuffer shortDst = (ShortDataBuffer) dst; + for (long idx = 0L; idx < size; ++idx) { + shortDst.setShort(getShort(idx), idx); + } + return ShortRawDataBuffer.this; + } + return slowCopyTo(dst, size); + } + }); + } + + @Override + public R accept(DataStorageVisitor visitor) { + if (memory.isArray()) { + return visitor.visit(memory.toArrayShortBuffer()); + } + return visitor.visit(memory.byteOffset, memory.byteSize, memory.scale); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ShortDataBuffer)) { + return super.equals(obj); + } + ShortDataBuffer other = (ShortDataBuffer) obj; + if (size() != other.size()) { + return false; + } + return other.accept( + new DataStorageVisitor() { + + @Override + public Boolean visit(ShortBuffer buffer) { + if (memory.isArray()) { + return buffer.equals(memory.toArrayShortBuffer()); + } + return fallback(); + } + + @Override + public Boolean fallback() { + for (long idx = 0L; idx < size(); ++idx) { + if (other.getShort(idx) != getShort(idx)) { + return false; + } + } + return true; + } + }); + } + + @Override + protected ShortDataBuffer instantiate(UnsafeMemoryHandle memory) { + return new ShortRawDataBuffer(memory, readOnly); + } + + ShortRawDataBuffer(UnsafeMemoryHandle memory, boolean readOnly) { + super(memory, readOnly); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/UnsafeMemoryHandle.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/UnsafeMemoryHandle.java new file mode 100644 index 00000000000..e2022cb9dc7 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/UnsafeMemoryHandle.java @@ -0,0 +1,214 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; + +final class UnsafeMemoryHandle { + + static UnsafeMemoryHandle fromArray(Object array, int length) { + return fromArray(array, 0, length); + } + + static UnsafeMemoryHandle fromArray(Object array, int arrayOffset, int length) { + long scale = UnsafeReference.UNSAFE.arrayIndexScale(array.getClass()); + int baseOffset = UnsafeReference.UNSAFE.arrayBaseOffset(array.getClass()); + return new UnsafeMemoryHandle(array, baseOffset + (arrayOffset * scale), length * scale, scale); + } + + static UnsafeMemoryHandle fromAddress(long address, long byteSize, long scale) { + return new UnsafeMemoryHandle(address, byteSize, scale); + } + + long size() { + return size; + } + + byte getByte(long index) { + return UnsafeReference.UNSAFE.getByte(object, align(index)); + } + + void setByte(byte value, long index) { + UnsafeReference.UNSAFE.putByte(object, align(index), value); + } + + boolean getBoolean(long index) { + return UnsafeReference.UNSAFE.getBoolean(object, align(index)); + } + + void setBoolean(boolean value, long index) { + UnsafeReference.UNSAFE.putBoolean(object, align(index), value); + } + + short getShort(long index) { + return UnsafeReference.UNSAFE.getShort(object, align(index)); + } + + void setShort(short value, long index) { + UnsafeReference.UNSAFE.putShort(object, align(index), value); + } + + int getInt(long index) { + return UnsafeReference.UNSAFE.getInt(object, align(index)); + } + + void setInt(int value, long index) { + UnsafeReference.UNSAFE.putInt(object, align(index), value); + } + + float getFloat(long index) { + return UnsafeReference.UNSAFE.getFloat(object, align(index)); + } + + void setFloat(float value, long index) { + UnsafeReference.UNSAFE.putFloat(object, align(index), value); + } + + double getDouble(long index) { + return UnsafeReference.UNSAFE.getDouble(object, align(index)); + } + + void setDouble(double value, long index) { + UnsafeReference.UNSAFE.putDouble(object, align(index), value); + } + + long getLong(long index) { + return UnsafeReference.UNSAFE.getLong(object, align(index)); + } + + void setLong(long value, long index) { + UnsafeReference.UNSAFE.putLong(object, align(index), value); + } + + void copyTo(UnsafeMemoryHandle memory, long length) { + UnsafeReference.UNSAFE.copyMemory( + object, byteOffset, memory.object, memory.byteOffset, length * scale); + } + + UnsafeMemoryHandle offset(long index) { + long offset = scale(index); + return new UnsafeMemoryHandle(object, this.byteOffset + offset, byteSize - offset, scale); + } + + UnsafeMemoryHandle narrow(long size) { + return new UnsafeMemoryHandle(object, byteOffset, scale(size), scale); + } + + UnsafeMemoryHandle slice(long index, long size) { + return new UnsafeMemoryHandle(object, this.byteOffset + scale(index), scale(size), scale); + } + + UnsafeMemoryHandle rescale(long scale) { + if (object != null) { + throw new IllegalStateException("Raw heap memory cannot be rescaled"); + } + return new UnsafeMemoryHandle(null, byteOffset, byteSize, scale); + } + + void rebase(long index) { + byteOffset = baseOffset + scale(index); + } + + boolean isArray() { + return object != null; + } + + @SuppressWarnings("unchecked") + A array() { + return (A) object; + } + + int arrayOffset(Class arrayClass) { + return (int) ((byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(arrayClass)) / scale); + } + + ByteBuffer toArrayByteBuffer() { + return ByteBuffer.wrap( + (byte[]) object, + (int) byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(byte[].class), + (int) size); + } + + ShortBuffer toArrayShortBuffer() { + return ShortBuffer.wrap( + (short[]) object, + (int) ((byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(short[].class)) / scale), + (int) size); + } + + IntBuffer toArrayIntBuffer() { + return IntBuffer.wrap( + (int[]) object, + (int) ((byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(int[].class)) / scale), + (int) size); + } + + LongBuffer toArrayLongBuffer() { + return LongBuffer.wrap( + (long[]) object, + (int) ((byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(long[].class)) / scale), + (int) size); + } + + FloatBuffer toArrayFloatBuffer() { + return FloatBuffer.wrap( + (float[]) object, + (int) ((byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(float[].class)) / scale), + (int) size); + } + + DoubleBuffer toArrayDoubleBuffer() { + return DoubleBuffer.wrap( + (double[]) object, + (int) ((byteOffset - UnsafeReference.UNSAFE.arrayBaseOffset(double[].class)) / scale), + (int) size); + } + + final Object object; + final long baseOffset; + long byteOffset; + final long byteSize; + final long scale; + final long size; + + private UnsafeMemoryHandle(Object object, long baseOffset, long byteSize, long scale) { + this.object = object; + this.baseOffset = baseOffset; + byteOffset = baseOffset; + this.byteSize = byteSize; + this.scale = scale; + size = byteSize / scale; + } + + private UnsafeMemoryHandle(long address, long byteSize, long scale) { + this(null, address, byteSize, scale); + } + + private long align(long index) { + return byteOffset + index * scale; + } + + private long scale(long value) { + return value * scale; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/UnsafeReference.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/UnsafeReference.java new file mode 100644 index 00000000000..d0a4e1a3e89 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/buffer/raw/UnsafeReference.java @@ -0,0 +1,83 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.raw; + +import java.lang.reflect.Field; +import sun.misc.Unsafe; + +final class UnsafeReference { + + static boolean isAvailable() { + return UNSAFE != null; + } + + static final Unsafe UNSAFE; + + static { + Unsafe unsafe = null; + try { + Class clazz = Class.forName("sun.misc.Unsafe"); + Field theUnsafe = clazz.getDeclaredField("theUnsafe"); + theUnsafe.setAccessible(true); + Object instance = theUnsafe.get(null); + if (instance.getClass() == clazz) { + checkMethod(clazz, "getByte", Object.class, long.class); + checkMethod(clazz, "putByte", Object.class, long.class, byte.class); + checkMethod(clazz, "getShort", Object.class, long.class); + checkMethod(clazz, "putShort", Object.class, long.class, short.class); + checkMethod(clazz, "getInt", Object.class, long.class); + checkMethod(clazz, "putInt", Object.class, long.class, int.class); + checkMethod(clazz, "getLong", Object.class, long.class); + checkMethod(clazz, "putLong", Object.class, long.class, long.class); + checkMethod(clazz, "getFloat", Object.class, long.class); + checkMethod(clazz, "putFloat", Object.class, long.class, float.class); + checkMethod(clazz, "getDouble", Object.class, long.class); + checkMethod(clazz, "putDouble", Object.class, long.class, double.class); + checkMethod(clazz, "getBoolean", Object.class, long.class); + checkMethod(clazz, "putBoolean", Object.class, long.class, boolean.class); + checkMethod( + clazz, "copyMemory", Object.class, long.class, Object.class, long.class, long.class); + checkMethod(clazz, "arrayBaseOffset", Class.class); + checkMethod(clazz, "arrayIndexScale", Class.class); + + unsafe = (Unsafe) instance; + } + } catch (ClassNotFoundException + | NoSuchMethodException + | NoSuchFieldException + | SecurityException + | IllegalAccessException + | ClassCastException ex) { + // Do nothing, keep unsafe as null + } + UNSAFE = unsafe; + } + + /** + * Validate that this Unsafe instance exposes this method + * + *

ErrorProne does not like that we do nothing with the returned method... but there is nothing + * to do with it, so disable the check + */ + @SuppressWarnings("ReturnValueIgnored") + private static void checkMethod( + Class unsafeClass, String methodName, Class... parameterTypes) + throws NoSuchMethodException { + unsafeClass.getDeclaredMethod(methodName, parameterTypes); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java new file mode 100644 index 00000000000..399e45d2934 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/AbstractDenseNdArray.java @@ -0,0 +1,190 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBufferWindow; +import org.tensorflow.ndarray.impl.AbstractNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sequence.FastElementSequence; +import org.tensorflow.ndarray.impl.sequence.SingleElementSequence; +import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence; +import org.tensorflow.ndarray.index.Index; + +@SuppressWarnings("unchecked") +public abstract class AbstractDenseNdArray> extends AbstractNdArray { + + @Override + public NdArraySequence elements(int dimensionIdx) { + if (dimensionIdx >= shape().numDimensions()) { + throw new IllegalArgumentException( + "Cannot iterate elements in dimension '" + + dimensionIdx + + "' of array with shape " + + shape()); + } + if (rank() == 0 && dimensionIdx < 0) { + return new SingleElementSequence<>(this); + } + DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); + try { + DataBufferWindow> elemWindow = + buffer().window(elemDims.physicalSize()); + U element = instantiateView(elemWindow.buffer(), elemDims); + return new FastElementSequence(this, dimensionIdx, element, elemWindow); + } catch (UnsupportedOperationException e) { + // If buffer windows are not supported, fallback to slicing (and slower) sequence + return new SlicingElementSequence<>(this, dimensionIdx, elemDims); + } + } + + @Override + public U withShape(Shape shape) { + if (shape == null || shape.isUnknown() || shape.size() != this.shape().size()) { + throw new IllegalArgumentException( + "Shape " + shape + " cannot be used to reshape ndarray of shape " + this.shape()); + } + if (shape.equals(this.shape())) { + return (U) this; + } + return instantiateView(buffer(), DimensionalSpace.create(shape)); + } + + @Override + public U slice(long position, DimensionalSpace sliceDimensions) { + DataBuffer sliceBuffer = buffer().slice(position, sliceDimensions.physicalSize()); + return instantiateView(sliceBuffer, sliceDimensions); + } + + @Override + public U slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + @Override + public U get(long... coords) { + return slice(positionOf(coords, false), dimensions().from(coords.length)); + } + + @Override + public T getObject(long... coords) { + return buffer().getObject(positionOf(coords, true)); + } + + @Override + public U set(NdArray src, long... coordinates) { + src.copyTo((coordinates == null || coordinates.length == 0) ? this : get(coordinates)); + return (U) this; + } + + @Override + public U setObject(T value, long... coords) { + buffer().setObject(value, positionOf(coords, true)); + return (U) this; + } + + @Override + public U copyTo(DataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer(), dimensions(), dst, DataTransfer::ofValue); + return (U) this; + } + + @Override + public U copyFrom(DataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer(), dimensions(), DataTransfer::ofValue); + return (U) this; + } + + @Override + public int hashCode() { + if (dimensions().isSegmented()) { + return slowHashCode(); + } + final int prime = 31; + int result = 1; + result = prime * result + buffer().hashCode(); + result = prime * result + shape().hashCode(); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof AbstractDenseNdArray)) { + return super.equals(obj); + } + AbstractDenseNdArray other = (AbstractDenseNdArray) obj; + if (dimensions().isSegmented() || other.dimensions().isSegmented()) { + return slowEquals(other); + } + if (!shape().equals(other.shape())) { + return false; + } + return buffer().equals(other.buffer()); + } + + /** + * A String showing the type and shape of this dense ndarray. + * + * @return A string containing the type and shape. + */ + @Override + public String toString() { + return this.getClass().getSimpleName() + "(shape=" + this.shape() + ")"; + } + + protected AbstractDenseNdArray(DimensionalSpace dimensions) { + super(dimensions); + } + + protected abstract DataBuffer buffer(); + + abstract U instantiateView(DataBuffer buffer, DimensionalSpace dimensions); + + long positionOf(long[] coords, boolean isValue) { + if (coords == null || coords.length == 0) { + return 0; + } + Validator.coordinates(dimensions, coords, isValue); + return dimensions.positionOf(coords); + } + + @Override + protected void slowCopyTo(NdArray array) { + if (array instanceof AbstractDenseNdArray) { + AbstractDenseNdArray dst = (AbstractDenseNdArray) array; + long offset = 0L; + for (NdArray s : scalars()) { + dst.buffer().setObject(s.getObject(), offset++); + } + } else { + super.slowCopyTo(array); + } + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java new file mode 100644 index 00000000000..ea428b02ca2 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArray.java @@ -0,0 +1,96 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.BooleanNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class BooleanDenseNdArray extends AbstractDenseNdArray + implements BooleanNdArray { + + public static BooleanNdArray create(BooleanDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new BooleanDenseNdArray(buffer, shape); + } + + @Override + public boolean getBoolean(long... indices) { + return buffer.getBoolean(positionOf(indices, true)); + } + + @Override + public BooleanNdArray setBoolean(boolean value, long... indices) { + buffer.setBoolean(value, positionOf(indices, true)); + return this; + } + + @Override + public BooleanNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof BooleanDenseNdArray) { + BooleanDenseNdArray booleanDst = (BooleanDenseNdArray) dst; + DataTransfer.execute( + buffer, + dimensions(), + booleanDst.buffer, + booleanDst.dimensions(), + DataTransfer::ofBoolean); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public BooleanNdArray copyTo(BooleanDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofBoolean); + return this; + } + + @Override + public BooleanNdArray copyFrom(BooleanDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofBoolean); + return this; + } + + protected BooleanDenseNdArray(BooleanDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + BooleanDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new BooleanDenseNdArray((BooleanDataBuffer) buffer, dimensions); + } + + @Override + protected BooleanDataBuffer buffer() { + return buffer; + } + + private final BooleanDataBuffer buffer; + + private BooleanDenseNdArray(BooleanDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java new file mode 100644 index 00000000000..a8aff33063a --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArray.java @@ -0,0 +1,92 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class ByteDenseNdArray extends AbstractDenseNdArray + implements ByteNdArray { + + public static ByteNdArray create(ByteDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new ByteDenseNdArray(buffer, shape); + } + + @Override + public byte getByte(long... indices) { + return buffer.getByte(positionOf(indices, true)); + } + + @Override + public ByteNdArray setByte(byte value, long... indices) { + buffer.setByte(value, positionOf(indices, true)); + return this; + } + + @Override + public ByteNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof ByteDenseNdArray) { + ByteDenseNdArray byteDst = (ByteDenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), byteDst.buffer, byteDst.dimensions(), DataTransfer::ofByte); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public ByteNdArray copyTo(ByteDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofByte); + return this; + } + + @Override + public ByteNdArray copyFrom(ByteDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofByte); + return this; + } + + protected ByteDenseNdArray(ByteDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + ByteDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new ByteDenseNdArray((ByteDataBuffer) buffer, dimensions); + } + + @Override + protected ByteDataBuffer buffer() { + return buffer; + } + + private final ByteDataBuffer buffer; + + private ByteDenseNdArray(ByteDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DataTransfer.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DataTransfer.java new file mode 100644 index 00000000000..aa3c874e021 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DataTransfer.java @@ -0,0 +1,143 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sequence.PositionIterator; + +final class DataTransfer { + + @FunctionalInterface + interface OfValue> { + void copy(B srcBuffer, long srcIndex, B dstBuffer, long dstIndex); + } + + static > void ofValue(B srcBuf, long srcIdx, B dstBuf, long dstIdx) { + dstBuf.setObject(srcBuf.getObject(srcIdx), dstIdx); + } + + static void ofByte(ByteDataBuffer srcBuf, long srcIdx, ByteDataBuffer dstBuf, long dstIdx) { + dstBuf.setByte(srcBuf.getByte(srcIdx), dstIdx); + } + + static void ofInt(IntDataBuffer srcBuf, long srcIdx, IntDataBuffer dstBuf, long dstIdx) { + dstBuf.setInt(srcBuf.getInt(srcIdx), dstIdx); + } + + static void ofLong(LongDataBuffer srcBuf, long srcIdx, LongDataBuffer dstBuf, long dstIdx) { + dstBuf.setLong(srcBuf.getLong(srcIdx), dstIdx); + } + + static void ofDouble(DoubleDataBuffer srcBuf, long srcIdx, DoubleDataBuffer dstBuf, long dstIdx) { + dstBuf.setDouble(srcBuf.getDouble(srcIdx), dstIdx); + } + + static void ofFloat(FloatDataBuffer srcBuf, long srcIdx, FloatDataBuffer dstBuf, long dstIdx) { + dstBuf.setFloat(srcBuf.getFloat(srcIdx), dstIdx); + } + + static void ofShort(ShortDataBuffer srcBuf, long srcIdx, ShortDataBuffer dstBuf, long dstIdx) { + dstBuf.setShort(srcBuf.getShort(srcIdx), dstIdx); + } + + static void ofBoolean( + BooleanDataBuffer srcBuf, long srcIdx, BooleanDataBuffer dstBuf, long dstIdx) { + dstBuf.setBoolean(srcBuf.getBoolean(srcIdx), dstIdx); + } + + static > void execute( + B srcBuffer, + DimensionalSpace srcDimensions, + B dstBuffer, + DimensionalSpace dstDimensions, + OfValue valueTransfer) { + if (srcDimensions.isSegmented() || dstDimensions.isSegmented()) { + int segmentationIdx = + Math.max(srcDimensions.segmentationIdx(), dstDimensions.segmentationIdx()); + copyByElement( + srcBuffer, + PositionIterator.create(srcDimensions, segmentationIdx), + dstBuffer, + PositionIterator.create(dstDimensions, segmentationIdx), + srcDimensions.get(segmentationIdx).elementSize(), + valueTransfer); + } else { + srcBuffer.copyTo(dstBuffer, srcDimensions.physicalSize()); + } + } + + static > void execute( + B srcBuffer, B dstBuffer, DimensionalSpace dstDimensions, OfValue valueTransfer) { + if (dstDimensions.isSegmented()) { + long elementSize = dstDimensions.get(dstDimensions.segmentationIdx()).elementSize(); + copyByElement( + srcBuffer, + PositionIterator.sequence(elementSize, srcBuffer.size()), + dstBuffer, + PositionIterator.create(dstDimensions, dstDimensions.segmentationIdx()), + elementSize, + valueTransfer); + } else { + srcBuffer.copyTo(dstBuffer, dstDimensions.physicalSize()); + } + } + + static > void execute( + B srcBuffer, DimensionalSpace srcDimensions, B dstBuffer, OfValue valueTransfer) { + if (srcDimensions.isSegmented()) { + long elementSize = srcDimensions.get(srcDimensions.segmentationIdx()).elementSize(); + copyByElement( + srcBuffer, + PositionIterator.create(srcDimensions, srcDimensions.segmentationIdx()), + dstBuffer, + PositionIterator.sequence(elementSize, dstBuffer.size()), + elementSize, + valueTransfer); + } else { + srcBuffer.copyTo(dstBuffer, srcDimensions.physicalSize()); + } + } + + private static > void copyByElement( + B srcBuffer, + PositionIterator srcIterator, + B dstBuffer, + PositionIterator dstIterator, + long elementSize, + OfValue valueTransfer) { + if (elementSize == 1) { + while (srcIterator.hasNext()) { + valueTransfer.copy(srcBuffer, srcIterator.nextLong(), dstBuffer, dstIterator.nextLong()); + } + } else { + while (srcIterator.hasNext()) { + srcBuffer + .offset(srcIterator.nextLong()) + .copyTo(dstBuffer.offset(dstIterator.nextLong()), elementSize); + } + } + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java new file mode 100644 index 00000000000..1006b5c05c5 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DenseNdArray.java @@ -0,0 +1,64 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class DenseNdArray extends AbstractDenseNdArray> { + + public static NdArray wrap(DataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new DenseNdArray<>(buffer, shape); + } + + @Override + public NdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof DenseNdArray) { + DenseNdArray denseDst = (DenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), denseDst.buffer, denseDst.dimensions(), DataTransfer::ofValue); + } else { + slowCopyTo(dst); + } + return this; + } + + protected DenseNdArray(DataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + DenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new DenseNdArray<>(buffer, dimensions); + } + + @Override + protected DataBuffer buffer() { + return buffer; + } + + private final DataBuffer buffer; + + private DenseNdArray(DataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java new file mode 100644 index 00000000000..4e9883c9c80 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArray.java @@ -0,0 +1,92 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class DoubleDenseNdArray extends AbstractDenseNdArray + implements DoubleNdArray { + + public static DoubleNdArray create(DoubleDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new DoubleDenseNdArray(buffer, shape); + } + + @Override + public double getDouble(long... indices) { + return buffer.getDouble(positionOf(indices, true)); + } + + @Override + public DoubleNdArray setDouble(double value, long... indices) { + buffer.setDouble(value, positionOf(indices, true)); + return this; + } + + @Override + public DoubleNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof DoubleDenseNdArray) { + DoubleDenseNdArray doubleDst = (DoubleDenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), doubleDst.buffer, doubleDst.dimensions(), DataTransfer::ofDouble); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public DoubleNdArray copyTo(DoubleDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofDouble); + return this; + } + + @Override + public DoubleNdArray copyFrom(DoubleDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofDouble); + return this; + } + + protected DoubleDenseNdArray(DoubleDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + DoubleDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new DoubleDenseNdArray((DoubleDataBuffer) buffer, dimensions); + } + + @Override + protected DoubleDataBuffer buffer() { + return buffer; + } + + private final DoubleDataBuffer buffer; + + private DoubleDenseNdArray(DoubleDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java new file mode 100644 index 00000000000..74369bcf1bc --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArray.java @@ -0,0 +1,92 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class FloatDenseNdArray extends AbstractDenseNdArray + implements FloatNdArray { + + public static FloatNdArray create(FloatDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new FloatDenseNdArray(buffer, shape); + } + + @Override + public float getFloat(long... indices) { + return buffer.getFloat(positionOf(indices, true)); + } + + @Override + public FloatNdArray setFloat(float value, long... indices) { + buffer.setFloat(value, positionOf(indices, true)); + return this; + } + + @Override + public FloatNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof FloatDenseNdArray) { + FloatDenseNdArray floatDst = (FloatDenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), floatDst.buffer, floatDst.dimensions(), DataTransfer::ofFloat); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public FloatNdArray copyTo(FloatDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofFloat); + return this; + } + + @Override + public FloatNdArray copyFrom(FloatDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofFloat); + return this; + } + + protected FloatDenseNdArray(FloatDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + FloatDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new FloatDenseNdArray((FloatDataBuffer) buffer, dimensions); + } + + @Override + public FloatDataBuffer buffer() { + return buffer; + } + + private final FloatDataBuffer buffer; + + private FloatDenseNdArray(FloatDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java new file mode 100644 index 00000000000..e3210b18a7f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArray.java @@ -0,0 +1,92 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class IntDenseNdArray extends AbstractDenseNdArray + implements IntNdArray { + + public static IntNdArray create(IntDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new IntDenseNdArray(buffer, shape); + } + + @Override + public int getInt(long... indices) { + return buffer.getInt(positionOf(indices, true)); + } + + @Override + public IntNdArray setInt(int value, long... indices) { + buffer.setInt(value, positionOf(indices, true)); + return this; + } + + @Override + public IntNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof IntDenseNdArray) { + IntDenseNdArray intDst = (IntDenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), intDst.buffer, intDst.dimensions(), DataTransfer::ofInt); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public IntNdArray copyTo(IntDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofInt); + return this; + } + + @Override + public IntNdArray copyFrom(IntDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofInt); + return this; + } + + protected IntDenseNdArray(IntDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + IntDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new IntDenseNdArray((IntDataBuffer) buffer, dimensions); + } + + @Override + protected IntDataBuffer buffer() { + return buffer; + } + + private final IntDataBuffer buffer; + + private IntDenseNdArray(IntDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java new file mode 100644 index 00000000000..7018f756c4f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArray.java @@ -0,0 +1,92 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class LongDenseNdArray extends AbstractDenseNdArray + implements LongNdArray { + + public static LongNdArray create(LongDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new LongDenseNdArray(buffer, shape); + } + + @Override + public long getLong(long... indices) { + return buffer.getLong(positionOf(indices, true)); + } + + @Override + public LongNdArray setLong(long value, long... indices) { + buffer.setLong(value, positionOf(indices, true)); + return this; + } + + @Override + public LongNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof LongDenseNdArray) { + LongDenseNdArray longDst = (LongDenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), longDst.buffer, longDst.dimensions(), DataTransfer::ofLong); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public LongNdArray copyTo(LongDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofLong); + return this; + } + + @Override + public LongNdArray copyFrom(LongDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofLong); + return this; + } + + protected LongDenseNdArray(LongDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + LongDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new LongDenseNdArray((LongDataBuffer) buffer, dimensions); + } + + @Override + protected LongDataBuffer buffer() { + return buffer; + } + + private final LongDataBuffer buffer; + + private LongDenseNdArray(LongDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java new file mode 100644 index 00000000000..3aa2880adae --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArray.java @@ -0,0 +1,92 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.ShortNdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public class ShortDenseNdArray extends AbstractDenseNdArray + implements ShortNdArray { + + public static ShortNdArray create(ShortDataBuffer buffer, Shape shape) { + Validator.denseShape(buffer, shape); + return new ShortDenseNdArray(buffer, shape); + } + + @Override + public short getShort(long... indices) { + return buffer.getShort(positionOf(indices, true)); + } + + @Override + public ShortNdArray setShort(short value, long... indices) { + buffer.setShort(value, positionOf(indices, true)); + return this; + } + + @Override + public ShortNdArray copyTo(NdArray dst) { + Validator.copyToNdArrayArgs(this, dst); + if (dst instanceof ShortDenseNdArray) { + ShortDenseNdArray shortDst = (ShortDenseNdArray) dst; + DataTransfer.execute( + buffer, dimensions(), shortDst.buffer, shortDst.dimensions(), DataTransfer::ofShort); + } else { + slowCopyTo(dst); + } + return this; + } + + @Override + public ShortNdArray copyTo(ShortDataBuffer dst) { + Validator.copyToBufferArgs(this, dst); + DataTransfer.execute(buffer, dimensions(), dst, DataTransfer::ofShort); + return this; + } + + @Override + public ShortNdArray copyFrom(ShortDataBuffer src) { + Validator.copyFromBufferArgs(this, src); + DataTransfer.execute(src, buffer, dimensions(), DataTransfer::ofShort); + return this; + } + + protected ShortDenseNdArray(ShortDataBuffer buffer, Shape shape) { + this(buffer, DimensionalSpace.create(shape)); + } + + @Override + ShortDenseNdArray instantiateView(DataBuffer buffer, DimensionalSpace dimensions) { + return new ShortDenseNdArray((ShortDataBuffer) buffer, dimensions); + } + + @Override + protected ShortDataBuffer buffer() { + return buffer; + } + + private final ShortDataBuffer buffer; + + private ShortDenseNdArray(ShortDataBuffer buffer, DimensionalSpace dimensions) { + super(dimensions); + this.buffer = buffer; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/Validator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/Validator.java new file mode 100644 index 00000000000..3d2e9c5ed9b --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dense/Validator.java @@ -0,0 +1,49 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.IllegalRankException; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +final class Validator extends org.tensorflow.ndarray.impl.Validator { + + static void coordinates(DimensionalSpace dimensions, long[] coords, boolean isValue) { + if (coords.length > dimensions.numDimensions()) { + throw new IndexOutOfBoundsException(); + } + if (isValue && coords.length != dimensions.numDimensions()) { + throw new IllegalRankException("Not a scalar value"); + } + } + + static void denseShape(DataBuffer buffer, Shape shape) { + if (shape == null) { + throw new IllegalArgumentException("Shape cannot be null"); + } + if (shape.hasUnknownDimension()) { + throw new IllegalArgumentException("Dense arrays cannot have unknown dimension(s)"); + } + if (buffer.size() < shape.size()) { + throw new IllegalArgumentException("Buffer size is smaller than the shape size"); + } + ; + } + + private Validator() {} +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/AbstractDimension.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/AbstractDimension.java new file mode 100644 index 00000000000..4c038ef581e --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/AbstractDimension.java @@ -0,0 +1,41 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dimension; + +abstract class AbstractDimension implements Dimension { + + /** Dimensions are known to be equal if they have the same number of elements */ + @Override + public int hashCode() { + final int prime = 17; + long numElements = numElements(); + return 31 * prime + (int) (numElements ^ (numElements >>> 32)); + } + + /** Dimensions are known to be equal if they have the same number of elements */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj instanceof Dimension) { + Dimension otherDimension = (Dimension) obj; + return numElements() == otherDimension.numElements(); + } + return false; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/Axis.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/Axis.java new file mode 100644 index 00000000000..e031150efc3 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/Axis.java @@ -0,0 +1,61 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dimension; + +final class Axis extends AbstractDimension { + + @Override + public long numElements() { + return numElements; + } + + @Override + public long positionOf(long coord) { + if (coord >= numElements) { + throw new IndexOutOfBoundsException(); + } + return elementSize * coord; + } + + @Override + public boolean isSegmented() { + return false; // all axis are continuous + } + + @Override + public long elementSize() { + return elementSize; + } + + @Override + public long physicalSize() { + return elementSize * numElements; + } + + @Override + public String toString() { + return String.valueOf(numElements); + } + + Axis(long numElements, long elementSize) { + this.numElements = numElements; + this.elementSize = elementSize; + } + + private final long numElements; + private final long elementSize; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/Dimension.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/Dimension.java new file mode 100644 index 00000000000..c24cd825403 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/Dimension.java @@ -0,0 +1,36 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dimension; + +import org.tensorflow.ndarray.index.Index; + +public interface Dimension { + + default Dimension withIndex(Index index) { + return new IndexedDimension(index, this); + } + + long numElements(); + + long elementSize(); + + long physicalSize(); + + long positionOf(long coord); + + boolean isSegmented(); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java new file mode 100644 index 00000000000..598000d23e0 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -0,0 +1,229 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.dimension; + +import java.util.Arrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Index; + +public class DimensionalSpace { + + public static DimensionalSpace create(Shape shape) { + Dimension[] dimensions = new Dimension[shape.numDimensions()]; + + // Start from the last dimension, where all elements are continuous + for (int i = dimensions.length - 1, elementSize = 1; i >= 0; --i) { + dimensions[i] = new Axis(shape.get(i), elementSize); + elementSize *= dimensions[i].numElements(); + } + return new DimensionalSpace(dimensions, shape); + } + + public RelativeDimensionalSpace mapTo(Index[] indices) { + if (dimensions == null) { + throw new ArrayIndexOutOfBoundsException(); + } + int dimIdx = 0; + int indexIdx = 0; + int newDimIdx = 0; + int segmentationIdx = -1; + long initialOffset = 0; + + int newAxes = 0; + boolean seenEllipsis = false; + for (Index idx : indices) { + if (idx.isNewAxis()) { + newAxes += 1; + } + if (idx.isEllipsis()) { + if (seenEllipsis) { + throw new IllegalArgumentException("Only one ellipsis allowed"); + } else { + seenEllipsis = true; + } + } + } + int newLength = dimensions.length + newAxes; + + Dimension[] newDimensions = new Dimension[newLength]; + while (indexIdx < indices.length) { + + if (indices[indexIdx].isPoint()) { + // When an index targets a single point in a given dimension, calculate the offset of this + // point and cumulate the offset of any subsequent point as well + long offset = 0; + do { + offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]); + dimIdx++; + } while (++indexIdx < indices.length && indices[indexIdx].isPoint()); + + // If this is the first index, then the offset is the position of the whole dimension + // space within the original one. If not, then we apply the offset to the last vectorial + // dimension + if (newDimIdx == 0) { + initialOffset = offset; + } else { + long reducedSize = dimensions[dimIdx - 1].elementSize(); + newDimensions[newDimIdx - 1] = + new ReducedDimension(newDimensions[newDimIdx - 1], offset, reducedSize); + segmentationIdx = newDimIdx - 1; + } + + } else if (indices[indexIdx].isNewAxis()) { + long newSize; + if (dimIdx == 0) { + // includes everything. Should really include future reduction (at()) but that doesn't + // seem to cause issues + // elsewhere + newSize = dimensions[0].numElements() * dimensions[0].elementSize(); + } else { + newSize = dimensions[dimIdx - 1].elementSize(); + } + + newDimensions[newDimIdx] = new Axis(1, newSize); + segmentationIdx = newDimIdx; // is this correct? + ++newDimIdx; + ++indexIdx; + } else if (indices[indexIdx].isEllipsis()) { + int remainingDimensions = dimensions.length - dimIdx; + int requiredDimensions = 0; + for (int i = indexIdx + 1; i < indices.length; i++) { + if (!indices[i].isNewAxis()) { + requiredDimensions++; + } + } + // while the number of dimensions left < the number of indices that consume axes + while (remainingDimensions > requiredDimensions) { + Dimension dim = dimensions[dimIdx++]; + if (dim.isSegmented()) { + segmentationIdx = newDimIdx; + } + newDimensions[newDimIdx++] = dim; + remainingDimensions--; + } + indexIdx++; + } else { + // Map any other index to the appropriate dimension of this space + Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]); + newDimensions[newDimIdx] = newDimension; + if (newDimension.isSegmented()) { + segmentationIdx = newDimIdx; + } + ++newDimIdx; + ++indexIdx; + } + } + + // When the number of indices provided is smaller than the number of dimensions in this space, + // we copy the remaining dimensions directly to the new space as well. + for (; dimIdx < dimensions.length; ++dimIdx, ++newDimIdx) { + Dimension dim = dimensions[dimIdx]; + newDimensions[newDimIdx] = dim; + if (dim.isSegmented()) { + segmentationIdx = newDimIdx; + } + } + return new RelativeDimensionalSpace( + Arrays.copyOf(newDimensions, newDimIdx), segmentationIdx, initialOffset); + } + + public DimensionalSpace from(int dimensionStart) { + if (dimensionStart > dimensions.length) { + throw new IndexOutOfBoundsException(); + } + Dimension[] newDimensions = Arrays.copyOfRange(dimensions, dimensionStart, dimensions.length); + if (segmentationIdx >= dimensionStart) { + return new DimensionalSpace(newDimensions, segmentationIdx - dimensionStart); + } + return new DimensionalSpace(newDimensions); + } + + public Shape shape() { + if (shape == null) { + shape = toShape(dimensions); + } + return shape; + } + + public int numDimensions() { + return dimensions.length; + } + + public long numElements(int i) { + return dimensions[i].numElements(); + } + + public long physicalSize() { + return dimensions.length > 0 + ? dimensions[0].physicalSize() + : 1; // dimensions.length == 0 for scalars + } + + public Dimension get(int i) { + return dimensions[i]; + } + + public boolean isSegmented() { + return segmentationIdx >= 0; + } + + public int segmentationIdx() { + return segmentationIdx; + } + + public long positionOf(long[] coords) { + long position = 0L; + for (int i = 0; i < coords.length; ++i) { + position += dimensions[i].positionOf(coords[i]); + } + return position; + } + + /** Succinct description of the shape meant for debugging. */ + @Override + public String toString() { + return Arrays.toString(dimensions); + } + + DimensionalSpace(Dimension[] dimensions, int segmentationIdx) { + this.dimensions = dimensions; + this.segmentationIdx = segmentationIdx; + } + + private DimensionalSpace(Dimension[] dimensions) { + this(dimensions, -1); + } + + private DimensionalSpace(Dimension[] dimensions, Shape shape) { + this(dimensions); + this.shape = shape; + } + + private final Dimension[] dimensions; + private final int segmentationIdx; + private Shape shape; + + private static Shape toShape(Dimension[] dimensions) { + long[] shapeDimSizes = new long[dimensions.length]; + int i = 0; + for (Dimension dimension : dimensions) { + shapeDimSizes[i++] = dimension.numElements(); + } + return Shape.of(shapeDimSizes); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/IndexedDimension.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/IndexedDimension.java new file mode 100644 index 00000000000..6129ff55e71 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/IndexedDimension.java @@ -0,0 +1,69 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dimension; + +import org.tensorflow.ndarray.index.Index; + +final class IndexedDimension extends AbstractDimension { + + @Override + public long numElements() { + return numElements; + } + + @Override + public long positionOf(long coord) { + if (coord >= numElements()) { + throw new IndexOutOfBoundsException(); + } + return originalDimension.positionOf(index.mapCoordinate(coord, originalDimension)); + } + + @Override + public boolean isSegmented() { + // TODO (karllessard) for now we consider all indexed dimensions as segmented but might depend + // on the actual index + return true; + } + + @Override + public long elementSize() { + return originalDimension.elementSize(); // indices do not change the size of an inner element + } + + @Override + public long physicalSize() { + // TODO (karllessard) we consider this dimension takes the same amount of memory that the + // original one but might depend on the actual index + return originalDimension.physicalSize(); + } + + @Override + public String toString() { + return String.valueOf(numElements()); + } + + IndexedDimension(Index index, Dimension originalDimension) { + this.index = index; + this.originalDimension = originalDimension; + this.numElements = index.numElements(originalDimension); + } + + private final Index index; + private final Dimension originalDimension; + private final long numElements; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/ReducedDimension.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/ReducedDimension.java new file mode 100644 index 00000000000..a432b0754dd --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/ReducedDimension.java @@ -0,0 +1,62 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dimension; + +final class ReducedDimension extends AbstractDimension { + + @Override + public long numElements() { + return originalDimension.numElements(); + } + + @Override + public long positionOf(long coord) { + return originalDimension.positionOf(coord) + offset; + } + + @Override + public boolean isSegmented() { + return true; + } + + @Override + public long elementSize() { + return elementSize; + } + + @Override + public long physicalSize() { + // We simplify the computation by assuming that a reduced dimension takes the same amount of + // memory than the original one + return originalDimension.physicalSize(); + } + + @Override + public String toString() { + return String.valueOf(numElements()); + } + + ReducedDimension(Dimension originalDimension, long offset, long elementSize) { + this.originalDimension = originalDimension; + this.offset = offset; + this.elementSize = elementSize; + } + + private final Dimension originalDimension; + private final long offset; + private final long elementSize; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/RelativeDimensionalSpace.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/RelativeDimensionalSpace.java new file mode 100644 index 00000000000..b2d3cdd91a4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/RelativeDimensionalSpace.java @@ -0,0 +1,32 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ + +package org.tensorflow.ndarray.impl.dimension; + +public class RelativeDimensionalSpace extends DimensionalSpace { + + public long position() { + return position; + } + + RelativeDimensionalSpace(Dimension[] dimensions, int segmentationIdx, long position) { + super(dimensions, segmentationIdx); + this.position = position; + } + + private long position; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java new file mode 100644 index 00000000000..8c9c9f86f4c --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/CoordinatesIncrementor.java @@ -0,0 +1,38 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +final class CoordinatesIncrementor { + + boolean increment() { + for (int i = coords.length - 1; i >= 0; --i) { + if ((coords[i] = (coords[i] + 1) % shape[i]) > 0) { + return true; + } + } + return false; + } + + CoordinatesIncrementor(long[] shape, int dimensionIdx) { + this.shape = shape; + this.coords = new long[dimensionIdx + 1]; + } + + final long[] shape; + final long[] coords; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java new file mode 100644 index 00000000000..eec12671911 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/FastElementSequence.java @@ -0,0 +1,87 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import java.util.Iterator; +import java.util.function.BiConsumer; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.buffer.DataBufferWindow; +import org.tensorflow.ndarray.impl.AbstractNdArray; + +/** + * A sequence recycling the same {@code NdArray} instance when iterating its elements + * + * @param Type of the elements + * @param Type of the {@code NdArray} with this sequence + */ +public final class FastElementSequence> implements NdArraySequence { + + public FastElementSequence( + AbstractNdArray ndArray, + int dimensionIdx, + U element, + DataBufferWindow elementWindow) { + this.ndArray = ndArray; + this.dimensionIdx = dimensionIdx; + this.element = element; + this.elementWindow = elementWindow; + } + + @Override + public Iterator iterator() { + return new SequenceIterator(); + } + + @Override + public void forEachIndexed(BiConsumer consumer) { + PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx) + .forEachIndexed( + (long[] coords, long position) -> { + elementWindow.slideTo(position); + consumer.accept(coords, element); + }); + } + + @Override + public NdArraySequence asSlices() { + return new SlicingElementSequence(ndArray, dimensionIdx); + } + + private class SequenceIterator implements Iterator { + + @Override + public boolean hasNext() { + return positionIterator.hasNext(); + } + + @Override + public U next() { + elementWindow.slideTo(positionIterator.nextLong()); + return element; + } + + private final PositionIterator positionIterator = + PositionIterator.create(ndArray.dimensions(), dimensionIdx); + } + + private final AbstractNdArray ndArray; + private final int dimensionIdx; + private final U element; + private final DataBufferWindow elementWindow; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedPositionIterator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedPositionIterator.java new file mode 100644 index 00000000000..30ece1599b6 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedPositionIterator.java @@ -0,0 +1,28 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +public interface IndexedPositionIterator extends PositionIterator { + + @FunctionalInterface + interface CoordsLongConsumer { + void consume(long[] coords, long position); + } + + void forEachIndexed(CoordsLongConsumer consumer); +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java new file mode 100644 index 00000000000..9ba90130ae1 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/IndexedSequentialPositionIterator.java @@ -0,0 +1,53 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +class IndexedSequentialPositionIterator extends SequentialPositionIterator + implements IndexedPositionIterator { + + @Override + public void forEachIndexed(CoordsLongConsumer consumer) { + while (hasNext()) { + consumer.consume(coords, nextLong()); + incrementCoords(); + } + } + + private void incrementCoords() { + for (int i = coords.length - 1; i >= 0; --i) { + if (coords[i] < shape[i] - 1) { + coords[i] += 1L; + return; + } + coords[i] = 0L; + } + } + + IndexedSequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + super(dimensions, dimensionIdx); + this.shape = dimensions.shape().asArray(); + this.coords = new long[dimensionIdx + 1]; + // this.coordsIncrementor = new CoordinatesIncrementor(dimensions.shape().asArray(), + // dimensionIdx); + } + + private final long[] shape; + private final long[] coords; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java new file mode 100644 index 00000000000..789474c58ae --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/NdPositionIterator.java @@ -0,0 +1,70 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import java.util.NoSuchElementException; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +class NdPositionIterator implements IndexedPositionIterator { + + @Override + public boolean hasNext() { + return coords != null; + } + + @Override + public long nextLong() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + long position = dimensions.positionOf(coords); + increment(); + return position; + } + + @Override + public void forEachIndexed(CoordsLongConsumer consumer) { + while (hasNext()) { + consumer.consume(coords, dimensions.positionOf(coords)); + increment(); + } + } + + private void increment() { + if (!increment(coords, dimensions)) { + coords = null; + } + } + + static boolean increment(long[] coords, DimensionalSpace dimensions) { + for (int i = coords.length - 1; i >= 0; --i) { + if ((coords[i] = (coords[i] + 1) % dimensions.get(i).numElements()) > 0) { + return true; + } + } + return false; + } + + NdPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + this.dimensions = dimensions; + this.coords = new long[dimensionIdx + 1]; + } + + private final DimensionalSpace dimensions; + private long[] coords; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java new file mode 100644 index 00000000000..83ed940563c --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/PositionIterator.java @@ -0,0 +1,42 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import java.util.PrimitiveIterator; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +public interface PositionIterator extends PrimitiveIterator.OfLong { + + static PositionIterator create(DimensionalSpace dimensions, int dimensionIdx) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, dimensionIdx); + } + return new SequentialPositionIterator(dimensions, dimensionIdx); + } + + static IndexedPositionIterator createIndexed(DimensionalSpace dimensions, int dimensionIdx) { + if (dimensions.isSegmented()) { + return new NdPositionIterator(dimensions, dimensionIdx); + } + return new IndexedSequentialPositionIterator(dimensions, dimensionIdx); + } + + static PositionIterator sequence(long stride, long end) { + return new SequentialPositionIterator(stride, end); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java new file mode 100644 index 00000000000..65c6fc966cc --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SequentialPositionIterator.java @@ -0,0 +1,55 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import java.util.NoSuchElementException; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +class SequentialPositionIterator implements PositionIterator { + + @Override + public boolean hasNext() { + return index < end; + } + + @Override + public long nextLong() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return stride * index++; + } + + SequentialPositionIterator(DimensionalSpace dimensions, int dimensionIdx) { + long size = 1; + for (int i = 0; i <= dimensionIdx; ++i) { + size *= dimensions.get(i).numElements(); + } + this.stride = dimensions.get(dimensionIdx).elementSize(); + this.end = size; + } + + SequentialPositionIterator(long stride, long end) { + this.stride = stride; + this.end = end; + } + + private final long stride; + private final long end; + private long index; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SingleElementSequence.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SingleElementSequence.java new file mode 100644 index 00000000000..98f7b1919ca --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SingleElementSequence.java @@ -0,0 +1,72 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import java.util.Iterator; +import java.util.function.BiConsumer; +import org.tensorflow.ndarray.IllegalRankException; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.impl.AbstractNdArray; + +/** + * A sequence of one single element + * + * @param Type of the element + * @param Type of the {@code NdArray} with this sequence + */ +public final class SingleElementSequence> implements NdArraySequence { + + public SingleElementSequence(AbstractNdArray ndArray) { + this.ndArray = ndArray; + } + + @Override + public Iterator iterator() { + return new Iterator() { + + @Override + public boolean hasNext() { + return element != null; + } + + @Override + public U next() { + U ret = element; + element = null; + return ret; + } + + @SuppressWarnings("unchecked") + private U element = (U) ndArray; + }; + } + + @Override + public NdArraySequence asSlices() { + return this; // no need to slice, as there are only one element + } + + @Override + public void forEachIndexed(BiConsumer consumer) { + throw new IllegalRankException( + "Single element has no coordinates to iterate on, use forEach()"); + } + + private final AbstractNdArray ndArray; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java new file mode 100644 index 00000000000..9d550d387d6 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sequence/SlicingElementSequence.java @@ -0,0 +1,79 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import java.util.Iterator; +import java.util.function.BiConsumer; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.impl.AbstractNdArray; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +/** + * A sequence creating a new {@code NdArray} instance (slice) for each element of an iteration + * + * @param Type of the element + * @param Type of the {@code NdArray} with this sequence + */ +public final class SlicingElementSequence> implements NdArraySequence { + + public SlicingElementSequence(AbstractNdArray ndArray, int dimensionIdx) { + this(ndArray, dimensionIdx, ndArray.dimensions().from(dimensionIdx + 1)); + } + + public SlicingElementSequence( + AbstractNdArray ndArray, int dimensionIdx, DimensionalSpace elementDimensions) { + this.ndArray = ndArray; + this.dimensionIdx = dimensionIdx; + this.elementDimensions = elementDimensions; + } + + @Override + public Iterator iterator() { + PositionIterator positionIterator = PositionIterator.create(ndArray.dimensions(), dimensionIdx); + return new Iterator() { + + @Override + public boolean hasNext() { + return positionIterator.hasNext(); + } + + @Override + public U next() { + return ndArray.slice(positionIterator.next(), elementDimensions); + } + }; + } + + @Override + public void forEachIndexed(BiConsumer consumer) { + PositionIterator.createIndexed(ndArray.dimensions(), dimensionIdx) + .forEachIndexed( + (long[] coords, long position) -> + consumer.accept(coords, ndArray.slice(position, elementDimensions))); + } + + @Override + public NdArraySequence asSlices() { + return this; + } + + private final AbstractNdArray ndArray; + private final int dimensionIdx; + private final DimensionalSpace elementDimensions; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java new file mode 100644 index 00000000000..2a471aca19f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/AbstractSparseNdArray.java @@ -0,0 +1,557 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.LongStream; +import org.tensorflow.ndarray.IllegalRankException; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.SparseNdArray; +import org.tensorflow.ndarray.impl.AbstractNdArray; +import org.tensorflow.ndarray.impl.dense.AbstractDenseNdArray; +import org.tensorflow.ndarray.impl.dimension.Dimension; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sequence.SingleElementSequence; +import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence; +import org.tensorflow.ndarray.index.Index; + +/** + * Abstract base class for sparse array. + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Object, long...)} methods + * + *

{@code
+ * FloatSparseNdArray st = new FloatSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf(1f, 2f),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[1, 0, 0, 0]
+ *  [0, 0, 2, 0]
+ *  [0, 0, 0, 0]]
+ *
+ * }
+ * + * @param the type that the array contains + * @param the type of dense NdArray + */ +public abstract class AbstractSparseNdArray> extends AbstractNdArray + implements SparseNdArray { + /** + * A 2-D long array of shape {@code [N, ndims]}, that specifies the indices of the elements in the + * sparse array that contain non-default values (elements are zero-indexed). + * + *

For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * coordinates {@code [1,3]} and {@code [2,4]} have non-default values. + */ + private LongNdArray indices; + + /** + * A 1-D array of any type and shape {@code [N]}, that supplies the values for each element in + * indices. + * + *

For example, given {@code indices=[[1,3], [2,4]]}, and {@code values=[18, 3.6]} specifies + * that element {@code [1,3]} of the sparse array has a value of {@code 18}, and element {@code + * [2,4]} of the sparse array has a value of {@code 3.6}. + */ + private U values; + + /** + * Scalar value to set for indices not specified in {@link #getIndices()} This will default to + * zero, false, or the empty string depending on the data type of the values. + */ + private T defaultValue; + + /** + * Scalar NdArray to use for indices not specified in {@link #getIndices()} This will default to + * zero, false, or the empty string depending on the data type of the values, otherwise it will + * contain the defaultValue. + */ + private U defaultArray; + + /** + * Creates an abstract SparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #indices} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + */ + protected AbstractSparseNdArray( + LongNdArray indices, U values, T defaultValue, DimensionalSpace dimensions) { + super(dimensions); + this.indices = indices; + this.values = values; + setDefaultValue(defaultValue); + + // sanity checks on shapes, indices (shape = {@code [N, ndims]}, where N is the number of values + // (shape = {@code [N]}}. + if (this.indices.shape().get(0) != this.values.shape().get(0)) { + throw new IllegalArgumentException( + String.format( + "The number of rows in indices (%d) does not match the number of elements in values(%d).", + this.indices.shape().get(0), this.values.shape().get(0))); + } + + // sanity checks on shapes, indices (shape = {@code [N, ndims]}, where ndims = the number of + // dimensions in the dense shape. + if (this.indices.shape().get(1) != shape().numDimensions()) { + throw new IllegalArgumentException( + String.format( + "The number of columns in indices (%d) does not match the number of dimensions in shape (%d).", + this.indices.shape().get(1), shape().get(0))); + } + } + + /** + * Creates an abstract SparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected AbstractSparseNdArray(T defaultValue, DimensionalSpace dimensions) { + super(dimensions); + setDefaultValue(defaultValue); + } + + /** {@inheritDoc} */ + @Override + public NdArraySequence elements(int dimensionIdx) { + if (dimensionIdx >= shape().numDimensions()) { + throw new IllegalArgumentException( + "Cannot iterate elements in dimension '" + + dimensionIdx + + "' of array with shape " + + shape()); + } + if (rank() == 0 && dimensionIdx < 0) { + return new SingleElementSequence<>(this); + } + DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); + + return new SlicingElementSequence<>(this, dimensionIdx, elemDims); + } + + /** + * Computes the coordinates based on a relative position to the beginning of the dimension space. + * + * @param dimensions the dimension space + * @param position relative position to the beginning of the dimension space. + * @return the coordinates + */ + // TODO should have automatical access to the coordinates from which this position is coming from. + // But that will require some refactoring even at the dense level. + protected long[] toCoordinates(DimensionalSpace dimensions, long position) { + long[] result = new long[dimensions.numDimensions()]; + long p = position; + + for (int dim = 0; dim < dimensions.numDimensions(); dim++) { + Dimension dimension = dimensions.get(dim); + result[dim] = p / dimension.elementSize(); + p = p % dimension.elementSize(); + } + return result; + } + + /** + * Converts the given set of indices coordinates to a long array of coordinates. + * + *

The shape of the NdArray is {@code [ndims]} + * + * @param l the LongNdArray containing the coordinates + * @return the long array containing the coordinates. + */ + protected long[] getIndicesCoordinates(LongNdArray l) { + long[] results = new long[(int) l.size()]; + for (int i = 0; i < l.size(); i++) { + results[i] = l.getLong(i); + } + return results; + } + + /** + * Converts this sparse array to a dense array. + * + * @return the dense array. + */ + public abstract U toDense(); + + @Override + public U withShape(Shape shape) { + throw new UnsupportedOperationException( + "Sparse NdArrays cannot be viewed with a different shape"); + } + + /** {@inheritDoc} */ + @Override + public NdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public NdArray get(long... coordinates) { + return slice(positionOf(coordinates, false), dimensions().from(coordinates.length)); + } + + /** {@inheritDoc} */ + @Override + public T getObject(long... coordinates) { + if (coordinates.length != shape().numDimensions()) { + throw new IllegalRankException( + String.format( + "Length of coordinates (%s)%s does not match the rank %d", + coordinates.length, Arrays.toString(coordinates), shape().numDimensions())); + } + long index = locateIndex(coordinates); + if (index >= 0) { + return getValues().getObject(index); + } else { + return defaultValue; + } + } + + /** {@inheritDoc} */ + @Override + public NdArray setObject(T value, long... coords) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public NdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** + * Creates a dense array of the type that this sparse array represents. + * + * @param shape the shape of the dense array. + * @return the dense of the type that this sparse array represents. + */ + public abstract U createValues(Shape shape); + + /** {@inheritDoc} */ + @Override + public NdArray copyTo(NdArray dst) { + if (dst instanceof AbstractSparseNdArray) { + AbstractSparseNdArray sparse = (AbstractSparseNdArray) dst; + LongNdArray indicesCopy = NdArrays.ofLongs(indices.shape()); + this.indices.copyTo(indicesCopy); + U valuesCopy = createValues(values.shape()); + this.values.copyTo(valuesCopy); + sparse.setIndices(indicesCopy); + sparse.setValues(valuesCopy); + } else { + U dense = toDense(); + dense.copyTo(dst); + } + return this; + } + + /** + * Computes the position within the dense array given by the coordinates + * + * @param coords the coordinates within the dense array + * @param isValue indicator whether the coordinates represents a value or higher level dimension. + * @return the position within the array + */ + protected long positionOf(long[] coords, boolean isValue) { + if (coords == null || coords.length == 0) { + return 0; + } + Validator.coordinates(dimensions, coords, isValue); + return dimensions.positionOf(coords); + } + + /** {@inheritDoc} */ + @Override + protected void slowCopyTo(NdArray array) { + if (array instanceof AbstractDenseNdArray) { + AbstractDenseNdArray dst = (AbstractDenseNdArray) array; + long offset = 0L; + for (NdArray s : scalars()) { + dst.setObject(s.getObject(), offset++); + } + } else if (array instanceof AbstractSparseNdArray) { + AbstractSparseNdArray dst = (AbstractSparseNdArray) array; + indices.copyTo(dst.getIndices()); + values.copyTo(dst.values); + } else { + super.slowCopyTo(array); + } + } + + /** + * Gets the Indices + * + * @return the Indices + */ + public LongNdArray getIndices() { + return indices; + } + + /** + * Sets the Indices + * + * @param indices the Indices + */ + public void setIndices(LongNdArray indices) { + this.indices = indices; + } + + /** + * Gets the values + * + * @return the values + */ + public U getValues() { + return values; + } + + /** + * Sets the values + * + * @param values the values + */ + public void setValues(U values) { + this.values = values; + } + + /** + * Gets the values index by coordinates + * + * @param coordinates the coordinates to locate + * @return index of the coordinates, if the coordinates are contained in the {@code indices} + * array; otherwise, {@code (-(insertion point) - 1)}. The insertion point is defined as the + * point at which the {@code coordinates} would be inserted into the {@code indices} array: + * the index of the first element greater than the key, or {@code indices.shape().get(0)}; if + * all elements in the array are less than the specified key. Note that this guarantees that + * the return value will be {@code >= 0}, only if the coordinates are found. + */ + protected long locateIndex(long[] coordinates) { + long size = indices.shape().get(0); + LongNdArray coordArray = NdArrays.vectorOf(coordinates); + return binarySearch(size, coordArray); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + if (dimensions().isSegmented()) { + return slowHashCode(); + } + final int prime = 31; + int result = 1; + result = prime * result + indices.hashCode(); + result = prime * result + values.hashCode(); + result = prime * result + shape().hashCode(); + return result; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof AbstractSparseNdArray)) { + return super.equals(obj); + } + AbstractSparseNdArray other = (AbstractSparseNdArray) obj; + if (!shape().equals(other.shape())) { + return false; + } + if (!indices.equals(other.indices)) { + return false; + } + return values.equals(other.values); + } + + /** + * A String showing the type, default value, number of elements and the dense shape of this sparse + * ndarray. + * + * @return A string containing the type, default value, number of elements and shape. + */ + @Override + public String toString() { + long numElements = values == null ? 0 : values.size(); + String strDefault; + if (defaultValue == null) { + strDefault = ""; + } else if (defaultValue instanceof Number) { + strDefault = defaultValue.toString(); + } else { + strDefault = "'" + defaultValue + "'"; + } + return this.getClass().getSimpleName() + + "(defaultValue=" + + strDefault + + ", numElements=" + + numElements + + ", shape=" + + this.shape() + + ")"; + } + + /** + * Performs a binary search on the indices array to locate the index of the specified coordinates. + * The indices array must be sorted by coordinates, row major. + * + * @param toIndex the index of the last element (exclusive) to be searched + * @param coordinates the coordinates to locate + * @return index of the coordinates, if the coordinates are contained in the {@code indices} + * array; otherwise, {@code (-(insertion point) - 1)}. The insertion point is defined as the + * point at which the {@code coordinates} would be inserted into the {@code indices} array: + * the index of the first element greater than the key, or {@code indices.shape().get(0)}; if + * all elements in the array are less than the specified key. Note that this guarantees that + * the return value will be @{code >= 0}, only if the coordinates are found. + */ + private long binarySearch(long toIndex, LongNdArray coordinates) { + + long low = 0; + long high = toIndex - 1; + + while (low <= high) { + long mid = (low + high) >>> 1; + LongNdArray comparable = indices.get(mid); + int rc = compareCoordinates(comparable, coordinates); + if (rc < 0) { // less than + low = mid + 1; + } else if (rc > 0) { // higher than + high = mid - 1; + } else { // match + return mid; + } + } + return -(low + 1); // no match + } + + /** + * Sorts the indices and values in ascending row-major coordinates. + * + * @return this instance + */ + @SuppressWarnings("UnusedReturnValue") + public AbstractSparseNdArray sortIndicesAndValues() { + + // indices will contain the indexes into the indices and values ndArrays, resorted. + List indexes = new ArrayList<>(); + // create a range for the length of values + LongStream.range(0, values.size()).forEach(indexes::add); + + // then sort this range based on ascending row-wise coordinates. + indexes.sort((a, b) -> compareCoordinates(indices.get(a), indices.get(b))); + + LongNdArray newIndices = NdArrays.ofLongs(indices.shape()); + U newValues = createValues(values.shape()); + // used the sorted indexes to set up the sorted Indices and Values + for (long i = 0; i < indexes.size(); i++) { + long moveIndex = indexes.get((int) i); + newIndices.set(indices.get(moveIndex), i); + newValues.setObject(values.getObject(moveIndex), i); + } + indices = newIndices; + values = newValues; + return this; + } + + /** + * Compares its two arguments for row major coordinate order. + * + * @return a negative integer, zero, or a positive integer as the first argument is less than, + * equal to, or greater than the second. + */ + private int compareCoordinates(LongNdArray a, LongNdArray b) { + int rc = (int) (a.size() - b.size()); + if (rc != 0) { + return rc; + } + + for (long i = 0; i < a.size(); i++) { + long l = a.getLong(i); + rc = (int) (l - b.getLong(i)); + if (rc != 0) { + return rc; + } + } + return 0; + } + + /** + * Scalar value to set for indices not specified in {@link #indices}, defaults to zero, false, or + * the empty String depending on the data type. + */ + public T getDefaultValue() { + return defaultValue; + } + + /** + * Sets the defaultValue + * + * @param defaultValue the default value + */ + public void setDefaultValue(T defaultValue) { + this.defaultValue = defaultValue; + defaultArray = null; + } + + /** + * Creates the NdArray with the default value as a scalar + * + * @return the default NdArray of the default value as a scalar + */ + public abstract U createDefaultArray(); + + /** + * Scalar NdArray to use for indices not specified in {@link #getIndices()} This will default to + * zero, false, or the empty string depending on the data type of the values, otherwise it will + * contain the {@link #defaultValue}. + */ + public U getDefaultArray() { + if (defaultArray == null) { + defaultArray = createDefaultArray(); + } + return defaultArray; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/BooleanSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/BooleanSparseNdArray.java new file mode 100644 index 00000000000..d000eddaed9 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/BooleanSparseNdArray.java @@ -0,0 +1,428 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.BooleanNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.BooleanSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * sparse array for the boolean data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Boolean, long...)} methods + * + *

{@code
+ * FloatSparseNdArray st = new BooleanSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf(true, true),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[true, false, false, false]
+ *  [false, false, true, false]
+ *  [false, false, false, false]]
+ *
+ * }
+ */ +public class BooleanSparseNdArray extends AbstractSparseNdArray + implements BooleanNdArray { + + /** + * Creates a BooleanSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of Boolean type and shape {@code [N]}, which supplies the values + * for each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the + * parameter {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse + * NdArray has a value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of + * {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected BooleanSparseNdArray( + LongNdArray indices, + BooleanNdArray values, + boolean defaultValue, + DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a BooleanSparseNdArray with a default value of false. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of Boolean type and shape {@code [N]}, which supplies the values + * for each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the + * parameter {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse + * NdArray has a value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of + * {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + BooleanSparseNdArray(LongNdArray indices, BooleanNdArray values, DimensionalSpace dimensions) { + this(indices, values, false, dimensions); + } + + /** + * Creates a BooleanSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + BooleanSparseNdArray(BooleanDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, false, dimensions); + } + + /** + * Creates a BooleanSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + BooleanSparseNdArray( + BooleanDataBuffer dataBuffer, boolean defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled BooleanSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + BooleanSparseNdArray(DimensionalSpace dimensions) { + this(false, dimensions); + } + + /** + * Creates a zero-filled BooleanSparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + BooleanSparseNdArray(boolean defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new BooleanSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create( + LongNdArray indices, BooleanNdArray values, DimensionalSpace dimensions) { + return new BooleanSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new BooleanSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create( + LongNdArray indices, + BooleanNdArray values, + boolean defaultValue, + DimensionalSpace dimensions) { + return new BooleanSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new BooleanSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create( + BooleanDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new BooleanSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new BooleanSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create( + BooleanDataBuffer dataBuffer, boolean defaultValue, DimensionalSpace dimensions) { + return new BooleanSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty BooleanSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create(DimensionalSpace dimensions) { + return new BooleanSparseNdArray(dimensions); + } + + /** + * Creates a new empty BooleanSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create(boolean defaultValue, DimensionalSpace dimensions) { + return new BooleanSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty BooleanSparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create(BooleanDataBuffer buffer, Shape shape) { + return new BooleanSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty BooleanSparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create( + BooleanDataBuffer buffer, boolean defaultValue, Shape shape) { + return new BooleanSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new BooleanSparseNdArray from a BooleanNdArray + * + * @param src the BooleanNdArray + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create(BooleanNdArray src) { + BooleanDataBuffer buffer = DataBuffers.ofBooleans(src.size()); + src.copyTo(buffer); + return new BooleanSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new BooleanSparseNdArray from a BooleanNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param src the BooleanNdArray + * @return the new Sparse Array + */ + public static BooleanSparseNdArray create(BooleanNdArray src, boolean defaultValue) { + BooleanDataBuffer buffer = DataBuffers.ofBooleans(src.size()); + src.copyTo(buffer); + return new BooleanSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } + + /** + * Creates a BooleanNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a BooleanNdArray of the specified shape + */ + public BooleanNdArray createValues(Shape shape) { + return NdArrays.ofBooleans(shape); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new BooleanSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public boolean getBoolean(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray setBoolean(boolean value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray copyTo(DataBuffer dst) { + return copyTo((BooleanDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray copyTo(BooleanDataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Boolean[] defaults = new Boolean[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + boolean value = getValues().getBoolean(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray copyFrom(BooleanDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + boolean[] valuesArray = new boolean[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray copyFrom(DataBuffer src) { + return copyFrom((BooleanDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public BooleanNdArray toDense() { + BooleanDataBuffer dataBuffer = DataBuffers.ofBooleans(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public BooleanNdArray fromDense(BooleanNdArray src) { + BooleanDataBuffer buffer = DataBuffers.ofBooleans(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray slice(Index... indices) { + return (BooleanNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray get(long... coordinates) { + return (BooleanNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray setObject(Boolean value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray copyTo(NdArray dst) { + return (BooleanNdArray) super.copyTo(dst); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/ByteSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/ByteSparseNdArray.java new file mode 100644 index 00000000000..5614c233fe0 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/ByteSparseNdArray.java @@ -0,0 +1,417 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.ByteSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * sparse array for the byte data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Byte, long...)} methods + * + *

{@code
+ * ByteSparseNdArray st = new ByteSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf((byte)1, (byte)255),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[(byte)1, (byte)0, (byte)0, (byte)0]
+ *  [(byte)0, (byte)0, (byte)1, (byte)0]
+ *  [(byte)0, (byte)0, (byte)0, (byte)0]]
+ *
+ * }
+ */ +public class ByteSparseNdArray extends AbstractSparseNdArray + implements ByteNdArray { + + /** + * Creates a ByteSparseNdArray with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of Byte type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected ByteSparseNdArray( + LongNdArray indices, ByteNdArray values, byte defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a ByteSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of Byte type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ByteSparseNdArray(LongNdArray indices, ByteNdArray values, DimensionalSpace dimensions) { + this(indices, values, (byte) 0, dimensions); + } + + /** + * Creates a ByteSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ByteSparseNdArray(ByteDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, (byte) 0, dimensions); + } + + /** + * Creates a ByteSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ByteSparseNdArray(ByteDataBuffer dataBuffer, byte defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled ByteSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ByteSparseNdArray(DimensionalSpace dimensions) { + this((byte) 0, dimensions); + } + + /** + * Creates a zero-filled ByteSparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ByteSparseNdArray(byte defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new ByteSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static ByteSparseNdArray create( + LongNdArray indices, ByteNdArray values, DimensionalSpace dimensions) { + return new ByteSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new ByteSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static ByteSparseNdArray create( + LongNdArray indices, ByteNdArray values, byte defaultValue, DimensionalSpace dimensions) { + return new ByteSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new ByteSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(ByteDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new ByteSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new ByteSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static ByteSparseNdArray create( + ByteDataBuffer dataBuffer, byte defaultValue, DimensionalSpace dimensions) { + return new ByteSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty ByteSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(DimensionalSpace dimensions) { + return new ByteSparseNdArray(dimensions); + } + + /** + * Creates a new empty ByteSparseNdArray from a data buffer + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(byte defaultValue, DimensionalSpace dimensions) { + return new ByteSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty ByteSparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(ByteDataBuffer buffer, Shape shape) { + return new ByteSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty ByteSparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(ByteDataBuffer buffer, byte defaultValue, Shape shape) { + return new ByteSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new ByteSparseNdArray from a ByteNdArray + * + * @param src the ByteNdArray + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(ByteNdArray src) { + ByteDataBuffer buffer = DataBuffers.ofBytes(src.size()); + src.copyTo(buffer); + return new ByteSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new ByteSparseNdArray from a ByteNdArray + * + * @param src the ByteNdArray + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static ByteSparseNdArray create(ByteNdArray src, byte defaultValue) { + ByteDataBuffer buffer = DataBuffers.ofBytes(src.size()); + src.copyTo(buffer); + return new ByteSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a ByteNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a ByteNdArray of the specified shape + */ + public ByteNdArray createValues(Shape shape) { + return NdArrays.ofBytes(shape); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new ByteSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public byte getByte(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray setByte(byte value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray copyTo(DataBuffer dst) { + return copyTo((ByteDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray copyTo(ByteDataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Byte[] defaults = new Byte[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + byte value = getValues().getByte(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray copyFrom(ByteDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + byte[] valuesArray = new byte[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray copyFrom(DataBuffer src) { + return copyFrom((ByteDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public ByteNdArray toDense() { + ByteDataBuffer dataBuffer = DataBuffers.ofBytes(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public ByteNdArray fromDense(ByteNdArray src) { + ByteDataBuffer buffer = DataBuffers.ofBytes(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray slice(Index... indices) { + return (ByteNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray get(long... coordinates) { + return (ByteNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray setObject(Byte value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray copyTo(NdArray dst) { + return (ByteNdArray) super.copyTo(dst); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java new file mode 100644 index 00000000000..2a1611725f4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArray.java @@ -0,0 +1,420 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.DoubleSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * A sparse array for the double data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Double, long...)} methods + * + *

{@code
+ * DoubleSparseNdArray st = new DoubleSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorsOf(new double[] {1, 2}),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[1, 0, 0, 0]
+ *  [0, 0, 2, 0]
+ *  [0, 0, 0, 0]]
+ *
+ * }
+ */ +public class DoubleSparseNdArray extends AbstractSparseNdArray + implements DoubleNdArray { + + /** + * Creates a DoubleSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D DoubleNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected DoubleSparseNdArray( + LongNdArray indices, DoubleNdArray values, double defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a DoubleSparseNdArray with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D DoubleNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + DoubleSparseNdArray(LongNdArray indices, DoubleNdArray values, DimensionalSpace dimensions) { + this(indices, values, 0d, dimensions); + } + + /** + * Creates a DoubleSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + DoubleSparseNdArray(DoubleDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, 0d, dimensions); + } + + /** + * Creates a DoubleSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + DoubleSparseNdArray( + DoubleDataBuffer dataBuffer, double defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled DoubleSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + DoubleSparseNdArray(DimensionalSpace dimensions) { + this(0d, dimensions); + } + + /** + * Creates a zero-filled DoubleSparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + DoubleSparseNdArray(double defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new DoubleSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create( + LongNdArray indices, DoubleNdArray values, DimensionalSpace dimensions) { + return new DoubleSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new DoubleSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create( + LongNdArray indices, DoubleNdArray values, double defaultValue, DimensionalSpace dimensions) { + return new DoubleSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new DoubleSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create( + DoubleDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new DoubleSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new DoubleSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create( + DoubleDataBuffer dataBuffer, double defaultValue, DimensionalSpace dimensions) { + return new DoubleSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty DoubleSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create(DimensionalSpace dimensions) { + return new DoubleSparseNdArray(dimensions); + } + + /** + * Creates a new empty DoubleSparseNdArray from a data buffer + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create(double defaultValue, DimensionalSpace dimensions) { + return new DoubleSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty DoubleSparseNdArray from a double data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create(DoubleDataBuffer buffer, Shape shape) { + return new DoubleSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty DoubleSparseNdArray from a double data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create( + DoubleDataBuffer buffer, double defaultValue, Shape shape) { + return new DoubleSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new DoubleSparseNdArray from a DoubleNdArray + * + * @param src the DoubleNdArray + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create(DoubleNdArray src) { + DoubleDataBuffer buffer = DataBuffers.ofDoubles(src.size()); + src.copyTo(buffer); + return new DoubleSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new DoubleSparseNdArray from a DoubleNdArray + * + * @param src the DoubleNdArray + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static DoubleSparseNdArray create(DoubleNdArray src, double defaultValue) { + DoubleDataBuffer buffer = DataBuffers.ofDoubles(src.size()); + src.copyTo(buffer); + return new DoubleSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a DoubleNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a DoubleNdArray of the specified shape + */ + public DoubleNdArray createValues(Shape shape) { + return NdArrays.ofDoubles(shape); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new DoubleSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public double getDouble(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray setDouble(double value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray copyTo(DataBuffer dst) { + return copyTo((DoubleDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray copyTo(DoubleDataBuffer dst) { + // set buf to the default values, then overwrite with the indices/values. + Double[] defaults = new Double[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + double value = getValues().getDouble(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray copyFrom(DoubleDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + double[] valuesArray = new double[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray copyFrom(DataBuffer src) { + return copyFrom((DoubleDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public DoubleNdArray toDense() { + DoubleDataBuffer dataBuffer = DataBuffers.ofDoubles(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public DoubleNdArray fromDense(DoubleNdArray src) { + DoubleDataBuffer buffer = DataBuffers.ofDoubles(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray slice(Index... indices) { + return (DoubleNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray get(long... coordinates) { + return (DoubleNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray setObject(Double value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray copyTo(NdArray dst) { + return (DoubleNdArray) super.copyTo(dst); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/FloatSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/FloatSparseNdArray.java new file mode 100644 index 00000000000..accb92f385d --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/FloatSparseNdArray.java @@ -0,0 +1,417 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.FloatSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * sparse array for the float data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Float, long...)} methods + * + *

{@code
+ * FloatSparseNdArray st = new FloatSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf(1f, 3.14f}),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[1, 0, 0, 0]
+ *  [0, 0, 3.14, 0]
+ *  [0, 0, 0, 0]]
+ *
+ * }
+ */ +public class FloatSparseNdArray extends AbstractSparseNdArray + implements FloatNdArray { + + /** + * Creates a FloatSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D FloatNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected FloatSparseNdArray( + LongNdArray indices, FloatNdArray values, float defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a FloatSparseNdArray with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D FloatNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + FloatSparseNdArray(LongNdArray indices, FloatNdArray values, DimensionalSpace dimensions) { + this(indices, values, 0f, dimensions); + } + + /** + * Creates a FloatSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + FloatSparseNdArray(FloatDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, 0f, dimensions); + } + + /** + * Creates a FloatSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + FloatSparseNdArray(FloatDataBuffer dataBuffer, float defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled FloatSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + FloatSparseNdArray(DimensionalSpace dimensions) { + this(0f, dimensions); + } + + /** + * Creates a zero-filled FloatSparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + FloatSparseNdArray(float defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new FloatSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static FloatSparseNdArray create( + LongNdArray indices, FloatNdArray values, DimensionalSpace dimensions) { + return new FloatSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new FloatSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static FloatSparseNdArray create( + LongNdArray indices, FloatNdArray values, float defaultValue, DimensionalSpace dimensions) { + return new FloatSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new FloatSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(FloatDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new FloatSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new FloatSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static FloatSparseNdArray create( + FloatDataBuffer dataBuffer, float defaultValue, DimensionalSpace dimensions) { + return new FloatSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty FloatSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(DimensionalSpace dimensions) { + return new FloatSparseNdArray(dimensions); + } + + /** + * Creates a new empty FloatSparseNdArray from a data buffer + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(float defaultValue, DimensionalSpace dimensions) { + return new FloatSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty FloatSparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(FloatDataBuffer buffer, Shape shape) { + return new FloatSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty FloatSparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(FloatDataBuffer buffer, float defaultValue, Shape shape) { + return new FloatSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new FloatSparseNdArray from a FloatNdArray + * + * @param src the FloatNdArray + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(FloatNdArray src) { + FloatDataBuffer buffer = DataBuffers.ofFloats(src.size()); + src.copyTo(buffer); + return new FloatSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new FloatSparseNdArray from a FloatNdArray + * + * @param src the FloatNdArray + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static FloatSparseNdArray create(FloatNdArray src, float defaultValue) { + FloatDataBuffer buffer = DataBuffers.ofFloats(src.size()); + src.copyTo(buffer); + return new FloatSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a FloatNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a FloatNdArray of the specified shape + */ + public FloatNdArray createValues(Shape shape) { + return NdArrays.ofFloats(shape); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new FloatSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public float getFloat(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray setFloat(float value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray copyTo(DataBuffer dst) { + return copyTo((FloatDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray copyTo(FloatDataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Float[] defaults = new Float[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + float value = getValues().getFloat(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray copyFrom(FloatDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + float[] valuesArray = new float[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray copyFrom(DataBuffer src) { + return copyFrom((FloatDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public FloatNdArray toDense() { + FloatDataBuffer dataBuffer = DataBuffers.ofFloats(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public FloatNdArray fromDense(FloatNdArray src) { + FloatDataBuffer buffer = DataBuffers.ofFloats(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray slice(Index... indices) { + return (FloatNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray get(long... coordinates) { + return (FloatNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray setObject(Float value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray copyTo(NdArray dst) { + return (FloatNdArray) super.copyTo(dst); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/IntSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/IntSparseNdArray.java new file mode 100644 index 00000000000..46be8f624cd --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/IntSparseNdArray.java @@ -0,0 +1,433 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.IntSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * sparse array for the int data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Integer, long...)} methods + * + *

{@code
+ * IntSparseNdArray st = new IntSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf(1, 256),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[1, 0, 0, 0]
+ *  [0, 0, 256, 0]
+ *  [0, 0, 0, 0]]
+ *
+ * }
+ */ +public class IntSparseNdArray extends AbstractSparseNdArray + implements IntNdArray { + + /** + * Creates a IntSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D IntNdArray of shape {@code [N]}, which supplies the values for each element + * in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected IntSparseNdArray( + LongNdArray indices, IntNdArray values, int defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a IntSparseNdArray with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D IntNdArray of shape {@code [N]}, which supplies the values for each element + * in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + IntSparseNdArray(LongNdArray indices, IntNdArray values, DimensionalSpace dimensions) { + this(indices, values, 0, dimensions); + } + + /** + * Creates a IntSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + IntSparseNdArray(IntDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, 0, dimensions); + } + + /** + * Creates a IntSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + IntSparseNdArray(IntDataBuffer dataBuffer, int defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled IntSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + IntSparseNdArray(DimensionalSpace dimensions) { + this(0, dimensions); + } + + /** + * Creates a zero-filled IntSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + IntSparseNdArray(int defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new IntSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static IntSparseNdArray create( + LongNdArray indices, IntNdArray values, DimensionalSpace dimensions) { + return new IntSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new IntSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static IntSparseNdArray create( + LongNdArray indices, IntNdArray values, int defaultValue, DimensionalSpace dimensions) { + return new IntSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new IntSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static IntSparseNdArray create(IntDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new IntSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new IntSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static IntSparseNdArray create( + IntDataBuffer dataBuffer, int defaultValue, DimensionalSpace dimensions) { + return new IntSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty IntSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static IntSparseNdArray create(DimensionalSpace dimensions) { + return new IntSparseNdArray(dimensions); + } + + /** + * Creates a new empty IntSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static IntSparseNdArray create(int defaultValue, DimensionalSpace dimensions) { + return new IntSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty IntSparseNdArray from a data buffer + * + * @param shape the shape of the debse array that this sparse array represents + * @return the new Sparse Array + */ + public static IntSparseNdArray create(Shape shape) { + return new IntSparseNdArray(DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty IntSparseNdArray from a data buffer + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the debse array that this sparse array represents + * @return the new Sparse Array + */ + public static IntSparseNdArray create(int defaultValue, Shape shape) { + return new IntSparseNdArray(defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty IntSparseNdArray from a int data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static IntSparseNdArray create(IntDataBuffer buffer, Shape shape) { + return new IntSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty IntSparseNdArray from a int data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static IntSparseNdArray create(IntDataBuffer buffer, int defaultValue, Shape shape) { + return new IntSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new IntSparseNdArray from a IntNdArray + * + * @param src the IntNdArray + * @return the new Sparse Array + */ + public static IntSparseNdArray create(IntNdArray src) { + IntDataBuffer buffer = DataBuffers.ofInts(src.size()); + src.copyTo(buffer); + return new IntSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new IntSparseNdArray from a IntNdArray + * + * @param src the IntNdArray + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static IntSparseNdArray create(IntNdArray src, int defaultValue) { + IntDataBuffer buffer = DataBuffers.ofInts(src.size()); + src.copyTo(buffer); + return new IntSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a IntNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a IntNdArray of the specified shape + */ + public IntNdArray createValues(Shape shape) { + return NdArrays.ofInts(shape); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new IntSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public int getInt(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray setInt(int value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray copyTo(DataBuffer dst) { + return copyTo((IntDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray copyTo(IntDataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Integer[] defaults = new Integer[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + int value = getValues().getInt(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public IntNdArray copyFrom(IntDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + int[] valuesArray = new int[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public IntNdArray copyFrom(DataBuffer src) { + return copyFrom((IntDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public IntNdArray toDense() { + IntDataBuffer dataBuffer = DataBuffers.ofInts(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public IntNdArray fromDense(IntNdArray src) { + IntDataBuffer buffer = DataBuffers.ofInts(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public IntNdArray slice(Index... indices) { + return (IntNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray get(long... coordinates) { + return (IntNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray setObject(Integer value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray copyTo(NdArray dst) { + return (IntNdArray) super.copyTo(dst); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/LongSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/LongSparseNdArray.java new file mode 100644 index 00000000000..098482e4cc0 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/LongSparseNdArray.java @@ -0,0 +1,416 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.LongSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * sparse array for the long data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Long, long...)} methods + * + *

{@code
+ * LongSparseNdArray st = new LongSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf(1L, 256L),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[1, 0, 0, 0]
+ *  [0, 0, 256, 0]
+ *  [0, 0, 0, 0]]
+ *
+ * }
+ */ +public class LongSparseNdArray extends AbstractSparseNdArray + implements LongNdArray { + + /** + * Creates a LongSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D LongNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected LongSparseNdArray( + LongNdArray indices, LongNdArray values, long defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a LongSparseNdArray with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D LongNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + LongSparseNdArray(LongNdArray indices, LongNdArray values, DimensionalSpace dimensions) { + this(indices, values, 0L, dimensions); + } + + /** + * Creates a LongSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + LongSparseNdArray(LongDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, 0L, dimensions); + } + + /** + * Creates a LongSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + LongSparseNdArray(LongDataBuffer dataBuffer, long defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled LongSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + LongSparseNdArray(DimensionalSpace dimensions) { + this(0L, dimensions); + } + + /** + * Creates a zero-filled LongSparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + LongSparseNdArray(long defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new LongSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static LongSparseNdArray create( + LongNdArray indices, LongNdArray values, DimensionalSpace dimensions) { + return new LongSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new LongSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static LongSparseNdArray create( + LongNdArray indices, LongNdArray values, long defaultValue, DimensionalSpace dimensions) { + return new LongSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new LongSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static LongSparseNdArray create(LongDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new LongSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new LongSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static LongSparseNdArray create( + LongDataBuffer dataBuffer, long defaultValue, DimensionalSpace dimensions) { + return new LongSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty LongSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static LongSparseNdArray create(DimensionalSpace dimensions) { + return new LongSparseNdArray(dimensions); + } + + /** + * Creates a new empty LongSparseNdArray from a data buffer + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static LongSparseNdArray create(long defaultValue, DimensionalSpace dimensions) { + return new LongSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty LongSparseNdArray from a long data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static LongSparseNdArray create(LongDataBuffer buffer, Shape shape) { + return new LongSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty LongSparseNdArray from a long data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static LongSparseNdArray create(LongDataBuffer buffer, long defaultValue, Shape shape) { + return new LongSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new LongSparseNdArray from a LongNdArray + * + * @param src the LongNdArray + * @return the new Sparse Array + */ + public static LongSparseNdArray create(LongNdArray src) { + LongDataBuffer buffer = DataBuffers.ofLongs(src.size()); + src.copyTo(buffer); + return new LongSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new LongSparseNdArray from a LongNdArray + * + * @param src the LongNdArray + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static LongSparseNdArray create(LongNdArray src, long defaultValue) { + LongDataBuffer buffer = DataBuffers.ofLongs(src.size()); + src.copyTo(buffer); + return new LongSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a LongNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a LongNdArray of the specified shape + */ + public LongNdArray createValues(Shape shape) { + return NdArrays.ofLongs(shape); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new LongSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public long getLong(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray setLong(long value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray copyTo(DataBuffer dst) { + return copyTo((LongDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray copyTo(LongDataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Long[] defaults = new Long[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicLong i = new AtomicLong(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + long value = getValues().getLong(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public LongNdArray copyFrom(LongDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + long[] valuesArray = new long[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public LongNdArray copyFrom(DataBuffer src) { + return copyFrom((LongDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public LongNdArray toDense() { + LongDataBuffer dataBuffer = DataBuffers.ofLongs(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public LongNdArray fromDense(LongNdArray src) { + LongDataBuffer buffer = DataBuffers.ofLongs(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public LongNdArray slice(Index... indices) { + return (LongNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray get(long... coordinates) { + return (LongNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray setObject(Long value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray copyTo(NdArray dst) { + return (LongNdArray) super.copyTo(dst); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/ShortSparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/ShortSparseNdArray.java new file mode 100644 index 00000000000..f9c9c1bf1c9 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/ShortSparseNdArray.java @@ -0,0 +1,417 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.nio.ReadOnlyBufferException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.ShortNdArray; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.ShortSparseSlice; +import org.tensorflow.ndarray.index.Index; + +/** + * sparse array for the short data type + * + *

A sparse array as two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Short, long...)} methods + * + *

{@code
+ * ShortSparseNdArray st = new ShortSparseNdArray(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf((short)1, (short)256}),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[1, 0, 0, 0]
+ *  [0, 0, 256, 0]
+ *  [0, 0, 0, 0]]
+ *
+ * }
+ */ +public class ShortSparseNdArray extends AbstractSparseNdArray + implements ShortNdArray { + + /** + * Creates a ShortSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D ShortNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected ShortSparseNdArray( + LongNdArray indices, ShortNdArray values, short defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + } + + /** + * Creates a ShortSparseNdArray with a default value of zero. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D ShortNdArray of shape {@code [N]}, which supplies the values for each + * element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter {@code + * values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a value of + * {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ShortSparseNdArray(LongNdArray indices, ShortNdArray values, DimensionalSpace dimensions) { + this(indices, values, (short) 0, dimensions); + } + + /** + * Creates a ShortSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ShortSparseNdArray(ShortDataBuffer dataBuffer, DimensionalSpace dimensions) { + this(dataBuffer, (short) 0, dimensions); + } + + /** + * Creates a ShortSparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ShortSparseNdArray(ShortDataBuffer dataBuffer, short defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled ShortSparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ShortSparseNdArray(DimensionalSpace dimensions) { + this((short) 0, dimensions); + } + + /** + * Creates a zero-filled ShortSparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + ShortSparseNdArray(short defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + } + + /** + * Creates a new ShortSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static ShortSparseNdArray create( + LongNdArray indices, ShortNdArray values, DimensionalSpace dimensions) { + return new ShortSparseNdArray(indices, values, dimensions); + } + + /** + * Creates a new ShortSparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static ShortSparseNdArray create( + LongNdArray indices, ShortNdArray values, short defaultValue, DimensionalSpace dimensions) { + return new ShortSparseNdArray(indices, values, defaultValue, dimensions); + } + + /** + * Creates a new ShortSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(ShortDataBuffer dataBuffer, DimensionalSpace dimensions) { + return new ShortSparseNdArray(dataBuffer, dimensions); + } + + /** + * Creates a new ShortSparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static ShortSparseNdArray create( + ShortDataBuffer dataBuffer, short defaultValue, DimensionalSpace dimensions) { + return new ShortSparseNdArray(dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty ShortSparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(DimensionalSpace dimensions) { + return new ShortSparseNdArray(dimensions); + } + + /** + * Creates a new empty ShortSparseNdArray from a data buffer + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(short defaultValue, DimensionalSpace dimensions) { + return new ShortSparseNdArray(defaultValue, dimensions); + } + + /** + * Creates a new empty ShortSparseNdArray from a short data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(ShortDataBuffer buffer, Shape shape) { + return new ShortSparseNdArray(buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty ShortSparseNdArray from a short data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(ShortDataBuffer buffer, short defaultValue, Shape shape) { + return new ShortSparseNdArray(buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new ShortSparseNdArray from a ShortNdArray + * + * @param src the ShortNdArray + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(ShortNdArray src) { + ShortDataBuffer buffer = DataBuffers.ofShorts(src.size()); + src.copyTo(buffer); + return new ShortSparseNdArray(buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new ShortSparseNdArray from a ShortNdArray + * + * @param src the ShortNdArray + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static ShortSparseNdArray create(ShortNdArray src, short defaultValue) { + ShortDataBuffer buffer = DataBuffers.ofShorts(src.size()); + src.copyTo(buffer); + return new ShortSparseNdArray(buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a ShortNdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a ShortNdArray of the specified shape + */ + public ShortNdArray createValues(Shape shape) { + return NdArrays.ofShorts(shape); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new ShortSparseSlice(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public short getShort(long... coordinates) { + return getObject(coordinates); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray setShort(short value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray copyTo(DataBuffer dst) { + return copyTo((ShortDataBuffer) dst); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray copyTo(ShortDataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Short[] defaults = new Short[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + short value = getValues().getShort(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray copyFrom(ShortDataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (short i = 0; i < src.size(); i++) { + if (!src.getObject(i).equals(getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + short[] valuesArray = new short[values.size()]; + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + setValues(NdArrays.vectorOf(valuesArray)); + return this; + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray copyFrom(DataBuffer src) { + return copyFrom((ShortDataBuffer) src); + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + public ShortNdArray toDense() { + ShortDataBuffer dataBuffer = DataBuffers.ofShorts(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public ShortNdArray fromDense(ShortNdArray src) { + ShortDataBuffer buffer = DataBuffers.ofShorts(src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray slice(Index... indices) { + return (ShortNdArray) super.slice(indices); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray get(long... coordinates) { + return (ShortNdArray) super.get(coordinates); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray setObject(Short value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray copyTo(NdArray dst) { + return (ShortNdArray) super.copyTo(dst); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray createDefaultArray() { + return NdArrays.scalarOf(getDefaultValue()); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/SparseNdArray.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/SparseNdArray.java new file mode 100644 index 00000000000..10a854ff47d --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/SparseNdArray.java @@ -0,0 +1,428 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.slice.ObjectSparseSlice; + +/** + * sparse array for the any data type + * + *

A sparse array has two separate dense arrays: indices, values, and a shape that represents the + * dense shape. + * + *

NOTE: all Sparse Arrays are readonly for the {@link #set(NdArray, long...)} and + * {@link #setObject(Object, long...)} methods + * + *

{@code
+ * SparseNdArray st = new SparseNdArray<>(
+ *      StdArrays.of(new long[][] {{0, 0}, {1, 2}}),
+ *      NdArrays.vectorOf("first", "second"),
+ *      Shape.of(3, 4));
+ *
+ * }
+ * + *

represents the dense array: + * + *

{@code
+ * [[true, false, false, false]
+ *  [false, false, true, false]
+ *  [false, false, false, false]]
+ *
+ * }
+ */ +public class SparseNdArray> extends AbstractSparseNdArray + implements org.tensorflow.ndarray.SparseNdArray { + + private final Class type; + + /** + * Creates a SparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of Boolean type and shape {@code [N]}, which supplies the values + * for each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the + * parameter {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse + * NdArray has a value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of + * {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + protected SparseNdArray( + Class type, LongNdArray indices, U values, T defaultValue, DimensionalSpace dimensions) { + super(indices, values, defaultValue, dimensions); + this.type = type; + } + + /** + * Creates a SparseNdArray with a default value of null. + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of Boolean type and shape {@code [N]}, which supplies the values + * for each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the + * parameter {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse + * NdArray has a value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of + * {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + SparseNdArray(Class type, LongNdArray indices, U values, DimensionalSpace dimensions) { + this(type, indices, values, null, dimensions); + } + + /** + * Creates a SparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + SparseNdArray(Class type, DataBuffer dataBuffer, DimensionalSpace dimensions) { + this(type, dataBuffer, null, dimensions); + } + + /** + * Creates a SparseNdArray + * + * @param dataBuffer a dense dataBuffer used to create the spars array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + SparseNdArray( + Class type, DataBuffer dataBuffer, T defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + this.type = type; + // use write to set up the indices and values + copyFrom(dataBuffer); + } + + /** + * Creates a zero-filled SparseNdArray + * + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + SparseNdArray(Class type, DimensionalSpace dimensions) { + this(type, (T) null, dimensions); + } + + /** + * Creates a zero-filled SparseNdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array, + */ + SparseNdArray(Class type, T defaultValue, DimensionalSpace dimensions) { + super(defaultValue, dimensions); + this.type = type; + } + + /** + * Creates a new SparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, LongNdArray indices, U values, DimensionalSpace dimensions) { + return new SparseNdArray<>(type, indices, values, dimensions); + } + + /** + * Creates a new SparseNdArray + * + * @param indices A 2-D LongNdArray of shape {@code [N, ndims]}, that specifies the indices of the + * elements in the sparse array that contain non-default values (elements are zero-indexed). + * For example, {@code indices=[[1,3], [2,4]]} specifies that the elements with indexes of + * {@code [1,3]} and {@code [2,4]} have non-default values. + * @param values A 1-D NdArray of any type and shape {@code [N]}, which supplies the values for + * each element in indices. For example, given {@code indices=[[1,3], [2,4]]}, the parameter + * {@code values=[18, 3.6]} specifies that element {@code [1,3]} of the sparse NdArray has a + * value of {@code 18}, and element {@code [2,4]} of the NdArray has a value of {@code 3.6}. + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the dense object represented by this sparse array. + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, LongNdArray indices, U values, T defaultValue, DimensionalSpace dimensions) { + return new SparseNdArray<>(type, indices, values, defaultValue, dimensions); + } + + /** + * Creates a new SparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, DataBuffer dataBuffer, DimensionalSpace dimensions) { + return new SparseNdArray<>(type, dataBuffer, dimensions); + } + + /** + * Creates a new SparseNdArray from a data buffer + * + * @param dataBuffer the databuffer containing the dense array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param dimensions the dimensional space for the sparse array + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, DataBuffer dataBuffer, T defaultValue, DimensionalSpace dimensions) { + return new SparseNdArray<>(type, dataBuffer, defaultValue, dimensions); + } + + /** + * Creates a new empty SparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, DimensionalSpace dimensions) { + return new SparseNdArray<>(type, dimensions); + } + + /** + * Creates a new empty SparseNdArray from a data buffer + * + * @param dimensions the dimensions array + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, T defaultValue, DimensionalSpace dimensions) { + return new SparseNdArray<>(type, defaultValue, dimensions); + } + + /** + * Creates a new empty SparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, DataBuffer buffer, Shape shape) { + return new SparseNdArray<>(type, buffer, DimensionalSpace.create(shape)); + } + + /** + * Creates a new empty SparseNdArray from a float data buffer + * + * @param buffer the data buffer + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param shape the shape of the sparse array. + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, DataBuffer buffer, T defaultValue, Shape shape) { + return new SparseNdArray<>(type, buffer, defaultValue, DimensionalSpace.create(shape)); + } + + /** + * Creates a new SparseNdArray from a NdArray + * + * @param src the NdArray + * @return the new Sparse Array + */ + public static > SparseNdArray create(Class type, U src) { + DataBuffer buffer = DataBuffers.ofObjects(type, src.size()); + src.copyTo(buffer); + return new SparseNdArray<>(type, buffer, DimensionalSpace.create(src.shape())); + } + + /** + * Creates a new SparseNdArray from a NdArray + * + * @param defaultValue Scalar value to set for indices not specified in {@link #getIndices()} + * @param src the NdArray + * @return the new Sparse Array + */ + public static > SparseNdArray create( + Class type, U src, T defaultValue) { + DataBuffer buffer = DataBuffers.ofObjects(type, src.size()); + src.copyTo(buffer); + return new SparseNdArray<>(type, buffer, defaultValue, DimensionalSpace.create(src.shape())); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public U createDefaultArray() { + return getDefaultValue() == null + ? (U) NdArrays.ofObjects(type, Shape.scalar()) + : (U) NdArrays.scalarOfObject(getDefaultValue()); + } + + /** + * Creates a NdArray of the specified shape + * + * @param shape the shape of the dense array. + * @return a NdArray of the specified shape + */ + @SuppressWarnings("unchecked") + public U createValues(Shape shape) { + return (U) NdArrays.ofObjects(type, shape); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public U slice(long position, DimensionalSpace sliceDimensions) { + return (U) new ObjectSparseSlice<>(this, position, sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public NdArray copyTo(DataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + @SuppressWarnings("unchecked") + T[] defaults = (T[]) Array.newInstance(type, (int) dst.size()); + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + T value = getValues().getObject(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings({ + "unchecked", + }) + public NdArray copyFrom(DataBuffer src) { + List indices = new ArrayList<>(); + List values = new ArrayList<>(); + + for (long i = 0; i < src.size(); i++) { + if (!Objects.equals(src.getObject(i), getDefaultValue())) { + indices.add(toCoordinates(dimensions, i)); + values.add(src.getObject(i)); + } + } + long[][] indicesArray = new long[indices.size()][]; + // unchecked cast, suppressed. + T[] valuesArray = (T[]) Array.newInstance(type, values.size()); + for (int i = 0; i < indices.size(); i++) { + indicesArray[i] = indices.get(i); + valuesArray[i] = values.get(i); + } + + setIndices(StdArrays.ndCopyOf(indicesArray)); + + // unchecked cast, suppressed. + setValues((U) NdArrays.vectorOfObjects(valuesArray)); + return this; + } + + /** + * Converts the sparse array to a dense array + * + * @return the dense array + */ + @SuppressWarnings("unchecked") + public U toDense() { + DataBuffer dataBuffer = DataBuffers.ofObjects(type, shape().size()); + copyTo(dataBuffer); + // unchecked cast, suppressed. + return (U) NdArrays.wrap(shape(), dataBuffer); + } + + /** + * Populates this sparse array from a dense array + * + * @param src the dense array + * @return this sparse array + */ + public NdArray fromDense(NdArray src) { + DataBuffer buffer = DataBuffers.ofObjects(type, src.size()); + src.copyTo(buffer); + copyFrom(buffer); + return this; + } + + /** + * Gets the class type for this sparse array + * + * @return the class type for this sparse array. + */ + public Class getType() { + return type; + } + + /** + * A String showing the type, default value, number of elements and the dense shape of this sparse + * ndarray. + * + * @return A string containing the type, default value, number of elements and shape. + */ + @Override + public String toString() { + long numElements = getValues() == null ? 0 : getValues().size(); + String strDefault; + T defaultVal = getDefaultValue(); + if (defaultVal == null) { + strDefault = ""; + } else if (defaultVal instanceof Number) { + strDefault = defaultVal.toString(); + } else { + strDefault = "'" + defaultVal + "'"; + } + return this.getClass().getSimpleName() + + "(type=" + + type.getSimpleName() + + ", defaultValue=" + + strDefault + + ", numElements=" + + numElements + + ", shape=" + + this.shape() + + ")"; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/Validator.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/Validator.java new file mode 100644 index 00000000000..2fa77366c9d --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/Validator.java @@ -0,0 +1,46 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import org.tensorflow.ndarray.IllegalRankException; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; + +final class Validator extends org.tensorflow.ndarray.impl.Validator { + + private Validator() {} + + static void coordinates(DimensionalSpace dimensions, long[] coords, boolean isValue) { + if (coords.length > dimensions.numDimensions()) { + throw new IndexOutOfBoundsException(); + } + if (isValue && coords.length != dimensions.numDimensions()) { + throw new IllegalRankException("Not a scalar value"); + } + } + + static void denseShape(DataBuffer buffer, Shape shape) { + if (shape == null) { + throw new IllegalArgumentException("Shape cannot be null"); + } + if (shape.hasUnknownDimension()) { + throw new IllegalArgumentException("Sparse arrays cannot have unknown dimension(s)"); + } + if (buffer.size() < shape.size()) { + throw new IllegalArgumentException("Buffer size is smaller than the shape size"); + } + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/BooleanSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/BooleanSparseSlice.java new file mode 100644 index 00000000000..0f31c7181a4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/BooleanSparseSlice.java @@ -0,0 +1,138 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.BooleanNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class BooleanSparseSlice extends SparseSlice + implements BooleanNdArray { + + /** + * Creates a BooleanSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public BooleanSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray toDense() { + BooleanDataBuffer dataBuffer = DataBuffers.ofBooleans(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public boolean getBoolean(long... coordinates) { + return getObject(coordinates); + } + + @Override + public BooleanNdArray setBoolean(boolean value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public BooleanNdArray setObject(Boolean value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public BooleanNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray copyTo(DataBuffer dst) { + // zero out buf. + Boolean[] defaults = new Boolean[(int) shape().size()]; + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + boolean value = getValues().getBoolean(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public BooleanNdArray copyTo(BooleanDataBuffer dst) { + return copyTo((DataBuffer) dst); + } + + @Override + public BooleanNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public BooleanNdArray copyFrom(BooleanDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public BooleanNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public BooleanNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new BooleanSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public BooleanNdArray get(long... coordinates) { + return (BooleanNdArray) super.get(coordinates); + } + + @Override + public BooleanNdArray copyTo(NdArray dst) { + return (BooleanNdArray) super.copyTo(dst); + } + + @Override + public BooleanNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ByteSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ByteSparseSlice.java new file mode 100644 index 00000000000..1da93a133de --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ByteSparseSlice.java @@ -0,0 +1,137 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class ByteSparseSlice extends SparseSlice implements ByteNdArray { + + /** + * Creates a ByteSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public ByteSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray toDense() { + ByteDataBuffer dataBuffer = DataBuffers.ofBytes(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public byte getByte(long... coordinates) { + return getObject(coordinates); + } + + @Override + public ByteNdArray setByte(byte value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteNdArray setObject(Byte value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray copyTo(DataBuffer dst) { + // zero out buf. + Byte[] defaults = new Byte[(int) shape().size()]; + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + byte value = getValues().getByte(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public ByteNdArray copyTo(ByteDataBuffer dst) { + return this.copyTo((DataBuffer) dst); + } + + @Override + public ByteNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteNdArray copyFrom(ByteDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public ByteNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new ByteSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public ByteNdArray get(long... coordinates) { + return (ByteNdArray) super.get(coordinates); + } + + @Override + public ByteNdArray copyTo(NdArray dst) { + return (ByteNdArray) super.copyTo(dst); + } + + @Override + public ByteNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/DoubleSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/DoubleSparseSlice.java new file mode 100644 index 00000000000..0e99aa4750f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/DoubleSparseSlice.java @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class DoubleSparseSlice extends SparseSlice implements DoubleNdArray { + + /** + * Creates a DoubleSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public DoubleSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray toDense() { + DoubleDataBuffer dataBuffer = DataBuffers.ofDoubles(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public double getDouble(long... coordinates) { + return getObject(coordinates); + } + + @Override + public DoubleNdArray setDouble(double value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public DoubleNdArray setObject(Double value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public DoubleNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray copyTo(DataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Double[] defaults = new Double[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + double value = getValues().getDouble(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public DoubleNdArray copyTo(DoubleDataBuffer dst) { + return this.copyTo((DataBuffer) dst); + } + + @Override + public DoubleNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public DoubleNdArray copyFrom(DoubleDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public DoubleNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public DoubleNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new DoubleSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public DoubleNdArray get(long... coordinates) { + return (DoubleNdArray) super.get(coordinates); + } + + @Override + public DoubleNdArray copyTo(NdArray dst) { + return (DoubleNdArray) super.copyTo(dst); + } + + @Override + public DoubleNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/FloatSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/FloatSparseSlice.java new file mode 100644 index 00000000000..75abfe46f54 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/FloatSparseSlice.java @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class FloatSparseSlice extends SparseSlice implements FloatNdArray { + + /** + * Creates a FloatSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public FloatSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray toDense() { + FloatDataBuffer dataBuffer = DataBuffers.ofFloats(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public float getFloat(long... coordinates) { + return getObject(coordinates); + } + + @Override + public FloatNdArray setFloat(float value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public FloatNdArray setObject(Float value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public FloatNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray copyTo(DataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Float[] defaults = new Float[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + float value = getValues().getFloat(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public FloatNdArray copyTo(FloatDataBuffer dst) { + return this.copyTo((DataBuffer) dst); + } + + @Override + public FloatNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public FloatNdArray copyFrom(FloatDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public FloatNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public FloatNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new FloatSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public FloatNdArray get(long... coordinates) { + return (FloatNdArray) super.get(coordinates); + } + + @Override + public FloatNdArray copyTo(NdArray dst) { + return (FloatNdArray) super.copyTo(dst); + } + + @Override + public FloatNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/IntSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/IntSparseSlice.java new file mode 100644 index 00000000000..831d6727d4f --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/IntSparseSlice.java @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class IntSparseSlice extends SparseSlice implements IntNdArray { + + /** + * Creates a IntSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public IntSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray toDense() { + IntDataBuffer dataBuffer = DataBuffers.ofInts(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public int getInt(long... coordinates) { + return getObject(coordinates); + } + + @Override + public IntNdArray setInt(int value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public IntNdArray setObject(Integer value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public IntNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray copyTo(DataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Integer[] defaults = new Integer[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + int value = getValues().getInt(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public IntNdArray copyTo(IntDataBuffer dst) { + return this.copyTo((DataBuffer) dst); + } + + @Override + public IntNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public IntNdArray copyFrom(IntDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public IntNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public IntNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new IntSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public IntNdArray get(long... coordinates) { + return (IntNdArray) super.get(coordinates); + } + + @Override + public IntNdArray copyTo(NdArray dst) { + return (IntNdArray) super.copyTo(dst); + } + + @Override + public IntNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/LongSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/LongSparseSlice.java new file mode 100644 index 00000000000..e882eb08a34 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/LongSparseSlice.java @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class LongSparseSlice extends SparseSlice implements LongNdArray { + + /** + * Creates a LongSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public LongSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray toDense() { + LongDataBuffer dataBuffer = DataBuffers.ofLongs(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public long getLong(long... coordinates) { + return getObject(coordinates); + } + + @Override + public LongNdArray setLong(long value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public LongNdArray setObject(Long value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public LongNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray copyTo(DataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Long[] defaults = new Long[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicLong i = new AtomicLong(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + long value = getValues().getLong(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public LongNdArray copyTo(LongDataBuffer dst) { + return copyTo((DataBuffer) dst); + } + + @Override + public LongNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public LongNdArray copyFrom(LongDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public LongNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public LongNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new LongSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public LongNdArray get(long... coordinates) { + return (LongNdArray) super.get(coordinates); + } + + @Override + public LongNdArray copyTo(NdArray dst) { + return (LongNdArray) super.copyTo(dst); + } + + @Override + public LongNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ObjectSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ObjectSparseSlice.java new file mode 100644 index 00000000000..9d547670803 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ObjectSparseSlice.java @@ -0,0 +1,114 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.lang.reflect.Array; +import java.nio.ReadOnlyBufferException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.SparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class ObjectSparseSlice> extends SparseSlice + implements NdArray { + + /** + * Creates a BooleanSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public ObjectSparseSlice( + SparseNdArray source, long sourcePosition, DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public U toDense() { + DataBuffer dataBuffer = DataBuffers.ofObjects(getType(), shape().size()); + copyTo(dataBuffer); + // unchecked NdArray to U + return (U) NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public U setObject(T value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public U set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public U copyTo(DataBuffer dst) { + // unchecked Object to T[] + T[] defaults = (T[]) Array.newInstance(getType(), (int) dst.size()); + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicInteger i = new AtomicInteger(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + T value = getValues().getObject(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + // Unchecked cast ObjectSparseSlice to U + return (U) this; + } + + @Override + public U slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public U slice(long position, DimensionalSpace sliceDimensions) { + // unchecked ObjectSparseSlice to U + return (U) + new ObjectSparseSlice<>( + (SparseNdArray) this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public U createDefaultArray() { + return source.getDefaultArray(); + } + + public Class getType() { + return ((SparseNdArray) source).getType(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ShortSparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ShortSparseSlice.java new file mode 100644 index 00000000000..43424d25fce --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/ShortSparseSlice.java @@ -0,0 +1,139 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicLong; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.ShortNdArray; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +public class ShortSparseSlice extends SparseSlice implements ShortNdArray { + + /** + * Creates a LongSparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative source position into the source + * @param dimensions the dimensional space for the window + */ + public ShortSparseSlice( + AbstractSparseNdArray source, + long sourcePosition, + DimensionalSpace dimensions) { + super(source, sourcePosition, dimensions); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray toDense() { + ShortDataBuffer dataBuffer = DataBuffers.ofShorts(shape().size()); + copyTo(dataBuffer); + return NdArrays.wrap(shape(), dataBuffer); + } + + @Override + public short getShort(long... coordinates) { + return getObject(coordinates); + } + + @Override + public ShortNdArray setShort(short value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public ShortNdArray setObject(Short value, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + @Override + public ShortNdArray set(NdArray src, long... coordinates) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray copyTo(DataBuffer dst) { + // set the values in buf to the default, then overwrite with indices/values + Short[] defaults = new Short[(int) shape().size()]; + Arrays.fill(defaults, getDefaultValue()); + dst.write(defaults); + + AtomicLong i = new AtomicLong(); + getIndices() + .elements(0) + .forEachIndexed( + (idx, l) -> { + long[] coordinates = getIndicesCoordinates(l); + short value = getValues().getShort(i.getAndIncrement()); + dst.setObject(value, dimensions.positionOf(coordinates)); + }); + return this; + } + + @Override + public ShortNdArray copyTo(ShortDataBuffer dst) { + return this.copyTo((DataBuffer) dst); + } + + @Override + public ShortNdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public ShortNdArray copyFrom(ShortDataBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public ShortNdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public ShortNdArray slice(long position, DimensionalSpace sliceDimensions) { + return new ShortSparseSlice(this.source, position + sourcePosition, sliceDimensions); + } + + @Override + public ShortNdArray get(long... coordinates) { + return (ShortNdArray) super.get(coordinates); + } + + @Override + public ShortNdArray copyTo(NdArray dst) { + return (ShortNdArray) super.copyTo(dst); + } + + @Override + public ShortNdArray createDefaultArray() { + return source.getDefaultArray(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/SparseSlice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/SparseSlice.java new file mode 100644 index 00000000000..3f09456ec8b --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/impl/sparse/slice/SparseSlice.java @@ -0,0 +1,144 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse.slice; + +import java.nio.ReadOnlyBufferException; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.impl.dimension.RelativeDimensionalSpace; +import org.tensorflow.ndarray.impl.sequence.SingleElementSequence; +import org.tensorflow.ndarray.impl.sequence.SlicingElementSequence; +import org.tensorflow.ndarray.impl.sparse.AbstractSparseNdArray; +import org.tensorflow.ndarray.index.Index; + +/** + * A sparse window is a view into an AbstractSparseNdArray. It is used internally by the slice + * methods. + * + * @param the type that the array contains + * @param the type of dense NdArray + */ +public abstract class SparseSlice> extends AbstractSparseNdArray { + protected final AbstractSparseNdArray source; + protected final long sourcePosition; + + /** + * Creates a SparseSlice + * + * @param source the source Sparse Array that this object slices. + * @param sourcePosition the relative position into the source array + * @param dimensions the dimensional space for the window + */ + public SparseSlice( + AbstractSparseNdArray source, long sourcePosition, DimensionalSpace dimensions) { + super(source.getDefaultValue(), dimensions); + this.source = source; + this.sourcePosition = sourcePosition; + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + source.hashCode(); + result = prime * result + (int) sourcePosition; + return result; + } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof SparseSlice)) { + return super.equals(obj); + } + SparseSlice other = (SparseSlice) obj; + if (!source.equals(other.source)) { + return false; + } + if (!shape().equals(other.shape())) { + return false; + } + return sourcePosition == other.sourcePosition; + } + + /** {@inheritDoc} */ + @Override + public T getObject(long... coordinates) { + long position = dimensions().positionOf(coordinates); + long[] sourceCoordinates = toCoordinates(source.dimensions(), sourcePosition + position); + return source.getObject(sourceCoordinates); + } + + /** {@inheritDoc} */ + @Override + public NdArray get(long... coordinates) { + long position = dimensions().positionOf(coordinates); + long[] sourceCoordinates = toCoordinates(source.dimensions(), sourcePosition + position); + return source.get(sourceCoordinates); + } + + /** {@inheritDoc} */ + @Override + public NdArray slice(Index... indices) { + if (indices == null) { + throw new IllegalArgumentException("Slicing requires at least one index"); + } + RelativeDimensionalSpace sliceDimensions = dimensions().mapTo(indices); + return slice(sliceDimensions.position(), sliceDimensions); + } + + /** {@inheritDoc} */ + @Override + public NdArraySequence elements(int dimensionIdx) { + if (dimensionIdx >= shape().numDimensions()) { + throw new IllegalArgumentException( + "Cannot iterate elements in dimension '" + + dimensionIdx + + "' of array with shape " + + shape()); + } + if (rank() == 0 && dimensionIdx < 0) { + return new SingleElementSequence<>(this); + } + DimensionalSpace elemDims = dimensions().from(dimensionIdx + 1); + return new SlicingElementSequence<>(this, dimensionIdx, elemDims); + } + + /** + * Converts the sparse window to a dense NdArray + * + * @return the NdArray + */ + public abstract U toDense(); + + /** {@inheritDoc} */ + @Override + public NdArray copyFrom(DataBuffer src) { + throw new ReadOnlyBufferException(); + } + + /** {@inheritDoc} */ + @Override + public U createValues(Shape shape) { + return source.createValues(shape); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/All.java new file mode 100644 index 00000000000..e21b9030315 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -0,0 +1,56 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class All implements Index { + + static final All INSTANCE = new All(); + + @Override + public long numElements(Dimension dim) { + return dim.numElements(); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return coordinate; + } + + @Override + public Dimension apply(Dimension dim) { + return dim; + } + + private All() {} + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public String toString() { + return All.class.getSimpleName() + "()"; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/At.java new file mode 100644 index 00000000000..cbe142a84b1 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -0,0 +1,74 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class At implements Index { + + @Override + public long numElements(Dimension dim) { + return 1; + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord; + return dim.positionOf(coord); + } + + @Override + public Dimension apply(Dimension dim) { + if (!keepDim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); + } + + return dim.withIndex(this); + } + + @Override + public boolean isPoint() { + return !keepDim; + } + + At(long coord, boolean keepDim) { + this.coord = coord; + this.keepDim = keepDim; + } + + private final long coord; + private final boolean keepDim; + + @Override + public long begin() { + return coord; + } + + @Override + public long end() { + return coord + 1; + } + + @Override + public String toString() { + return new StringJoiner(", ", At.class.getSimpleName() + "(", ")") + .add("coord=" + coord) + .add("keepDim=" + keepDim) + .toString(); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java new file mode 100644 index 00000000000..244ea333bd4 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -0,0 +1,46 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Ellipsis implements Index { + + static final Ellipsis INSTANCE = new Ellipsis(); + + private Ellipsis() {} + + @Override + public long numElements(Dimension dim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); + } + + @Override + public boolean isEllipsis() { + return true; + } + + @Override + public String toString() { + return Ellipsis.class.getSimpleName() + "()"; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java new file mode 100644 index 00000000000..8d01b3d21d6 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -0,0 +1,88 @@ +/* + * Copyright 2020 Matteo Di Giovinazzo. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +/** + * A hyperslab is a rectangular pattern defined by four arrays. + * + *

The {@code start} defines the origin of the hyperslab in the original coordinates. The {@code + * stride} is the number of elements to increment between selected elements. A stride of '1' is + * every element, a stride of '2' is every second element, etc. The default stride is 1. The {@code + * count} is the number of elements in the hyperslab selection. When the stride is 1, the selection + * is a hyper rectangle with a corner at {@code start} and size {@code count[0]} by {@code count[1]} + * by ... When stride is greater than one, the hyperslab bounded by start and the corners defined by + * {@code stride[n] * count[n]}. The {@code block} is a count on the number of repetitions of the + * hyperslab. The default block size is '1', which is one hyperslab. A block of 2 would be two + * hyperslabs in that dimension, with the second starting at {@code start[n]+ (count[n] * stride[n]) + * + 1}. + * + * @see https://portal.hdfgroup.org/display/HDF5/Reading+From+or+Writing+To+a+Subset+of+a+Dataset + * @see https://portal.hdfgroup.org/display/HDF5/H5S_SELECT_HYPERSLAB + * @see https://support.hdfgroup.org/HDF5/doc1.6/UG/12_Dataspaces.html + * @author Matteo Di Giovinazzo + */ +final class Hyperslab implements Index { + + @Override + public long numElements(Dimension dimension) { + return count * block; + } + + @Override + public long mapCoordinate(long coordinate, Dimension dimension) { + return start + stride * (coordinate / block) + (coordinate % block); + } + + @Override + public Dimension apply(Dimension dim) { + return dim.withIndex(this); + } + + @Override + public boolean isPoint() { + return false; + } + + Hyperslab(long start, long stride, long count, long block) { + this.start = start; + this.stride = stride; + this.count = count; + this.block = block; + } + + private final long start; + private final long stride; + private final long count; + private final long block; + + @Override + public String toString() { + return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "Hyperslab(", ")") + .add("start=" + start) + .add("stride=" + stride) + .add("count=" + count) + .add("block=" + block) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java new file mode 100644 index 00000000000..b98bb0dc988 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -0,0 +1,126 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +/** + * An index used for slicing a view out of an N-dimensional array. + * + *

A slice, i.e. a reduced view, of an N-dimensional array is obtain by calling {@link + * NdArray#slice(Index...)}, given a list of indices that select which elements on a given dimension + * should be included/excluded from that view. + */ +public interface Index { + + /** + * Returns the number of elements that can be retrieved using this index on the given dimension. + * + *

An index that maps one-by-one all elements of the dimensions will return a value equal to + * {@code dim.numElements()}, while an index that only maps a subset of these will return a + * smaller value. + * + * @param dim the indexed dimension + * @return number of elements accessible + */ + long numElements(Dimension dim); + + /** + * Transforms an element coordinate to a new coordinate by applying this index to the given + * dimension. + * + *

For example, if the coordinate is 0 and this index flips the {@code n} elements on this + * dimension, then the returned value will be {@code n-1}. + * + * @param coordinate coordinate to transform + * @param dim dimension the indexed dimension + * @return transformed coordinate + */ + long mapCoordinate(long coordinate, Dimension dim); + + /** + * Applies this index to the given dimension. + * + *

When accessing the elements from the returned dimension, this index will automatically apply + * and may transform the original position. + * + * @param dim dimension to apply this index to + * @return an indexed dimension + */ + default Dimension apply(Dimension dim) { + return dim.withIndex(this); + } + + /** Returns true if this index is a single point, reducing the number of dimensions by one */ + default boolean isPoint() { + return false; + } + + /** Returns true if this index is a new axis, adding a dimension of size 1 */ + default boolean isNewAxis() { + return false; + } + + /** + * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible + * (and applying all() to them) + */ + default boolean isEllipsis() { + return false; + } + + /** + * Get whether the Index supports strided slice style indexing (using start, end, stride, and + * flags, i.e. TensorFlow's). + */ + default boolean isStridedSlicingCompliant() { + return true; + } + + /** Get the start of the index, for strided slice style indexing. */ + default long begin() { + return 0; + } + + /** Get the end of the index, strided slice style indexing. */ + default long end() { + return 0; + } + + /** Get the stride of the index, for strided slice style indexing. */ + default long stride() { + return 1; + } + + /** + * Get whether the Index should start at the beginning of the dimension, for strided slice style + * indexing. + */ + default boolean beginMask() { + return false; + } + + /** + * Get whether the Index should end at the beginning of the dimension, for strided slice style + * indexing. + */ + default boolean endMask() { + return false; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java new file mode 100644 index 00000000000..39f37b90205 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -0,0 +1,365 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.IllegalRankException; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffers; + +/** Helper class for instantiating {@link Index} objects. */ +public final class Indices { + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

When this index is applied to a given dimension, the dimension is resolved as a single + * element and therefore is excluded from the computation of the rank. + * + *

For example, given a 3D matrix on the axis [x, y, z], if {@code matrix.slice(all(), at(0), + * at(0)}, then the rank of the returned slice is 1 and its number of elements is {@code + * x.numElements()} + * + * @param coord coordinate of the element on the indexed axis + * @return index + */ + public static Index at(long coord) { + return new At(coord, false); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

This is equivalent to call {@link #at(long)} but where the value of the coordinate is + * provided by an N-dimensional array. + * + * @param coord scalar indicating the coordinate of the element on the indexed axis + * @return index + * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) + */ + public static Index at(NdArray coord) { + if (coord.rank() > 0) { + throw new IllegalRankException("Only scalars are accepted as a value index"); + } + return new At(coord.getObject().longValue(), false); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

When this index is applied to a given dimension, the dimension is resolved as a single + * element and therefore, if {@code keepDim} is false, is excluded from the computation of the + * rank. If {@code} keepDim is true, the dimension is collapsed down to one element. + * + *

For example, given a 3D matrix on the axis [x, y, z], if {@code matrix.slice(all(), at(0), + * at(0)}, then the rank of the returned slice is 1 and its number of elements is {@code + * x.numElements()} + * + * @param coord coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + */ + public static Index at(long coord, boolean keepDim) { + return new At(coord, keepDim); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate + * is provided by an N-dimensional array. + * + *

If {@code} keepDim is true, the dimension is collapsed down to one element instead of being + * removed. + * + * @param coord scalar indicating the coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) + */ + public static Index at(NdArray coord, boolean keepDim) { + if (coord.rank() > 0) { + throw new IllegalRankException("Only scalars are accepted as a value index"); + } + return new At(coord.getObject().longValue(), keepDim); + } + + /** + * An index that returns all elements of a dimension in the original order. + * + *

Applying this index to a given dimension will return the original dimension directly. + * + *

For example, given a vector with {@code n} elements, {@code all()} returns x0, + * x1, ..., xn-1 + * + * @return index + */ + public static Index all() { + return All.INSTANCE; + } + + /** + * An index that returns only specific elements on a given dimension. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > + * 10}, {@code seq(8, 0, 3)} returns x8, x0, x3 + * + * @param coords coordinates of the elements in the sequence + * @return index + */ + public static Index seq(long... coords) { + if (coords == null) { + throw new IllegalArgumentException(); + } + return new Sequence( + NdArrays.wrap(Shape.of(coords.length), DataBuffers.of(coords, true, false))); + } + + /** + * An index that returns only specific elements on a given dimension. + * + *

This is equivalent to {@link #seq(long...)} but where the coordinates of the elements in the + * sequence are provided by an N-dimensional array. + * + * @param coords vector of coordinates of the elements in the sequence + * @return index + * @throws IllegalRankException if {@code coords} is not a vector (rank 1) + */ + public static Index seq(NdArray coords) { + if (coords.rank() != 1) { + throw new IllegalRankException("Only vectors are accepted as an element index"); + } + return new Sequence(coords); + } + + /** + * An index that returns only elements found at an even position in the original dimension. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, + * {@code even()} returns x0, x2, ..., xn-2 + * + * @return index + */ + public static Index even() { + return step(2); + } + + /** + * An index that returns only elements found at an odd position in the original dimension. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, + * {@code odd()} returns x1, x3, ..., xn-1 + * + * @return index + */ + public static Index odd() { + return sliceFrom(1, 2); + } + + /** + * An index that skips a fixed amount of coordinates between each values returned. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, {@code step(k)} + * returns x0, xk, xk*2, ... + * + * @param stride the number of elements between each steps + * @return index + */ + public static Index step(long stride) { + return new Step(stride); + } + + /** + * An index that returns only elements on a given dimension starting at a specific coordinate. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > + * k}, {@code from(k)} returns xk, xk+1, ..., xn-1 + * + * @param start coordinate of the first element of the sequence + * @return index + */ + public static Index sliceFrom(long start) { + return sliceFrom(start, 1); + } + + /** + * An index that returns only elements on a given dimension starting at a specific coordinate, + * using the given stride. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > + * k}, {@code from(k)} returns xk, xk+1, ..., xn-1 + * + * @param start coordinate of the first element of the sequence + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceFrom(long start, long stride) { + return new SliceFrom(start, stride); + } + + /** + * An index that returns only elements on a given dimension up to a specific coordinate. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > + * k}, {@code to(k)} returns x0, x1, ..., xk + * + * @param end coordinate of the last element of the sequence (exclusive) + * @return index + */ + public static Index sliceTo(long end) { + return sliceTo(end, 1); + } + + /** + * An index that returns only elements on a given dimension up to a specific coordinate, using the + * given stride. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > + * k}, {@code to(k)} returns x0, x1, ..., xk + * + * @param end coordinate of the last element of the sequence (exclusive) + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceTo(long end, long stride) { + return new SliceTo(end, stride); + } + + /** + * An index that returns only elements on a given dimension between two coordinates. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k + * > j}, {@code range(j, k)} returns xj, xj+1, ..., xk + * + * @param start coordinate of the first element of the sequence + * @param end coordinate of the last element of the sequence (exclusive) + * @return index + */ + public static Index range(long start, long end) { + return slice(start, end); + } + + /** + * An index that returns only elements on a given dimension between two coordinates. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k + * > j}, {@code range(j, k)} returns xj, xj+1, ..., xk + * + * @return index + */ + public static Index flip() { + return slice(null, null, -1); + } + + /** + * An index that returns elements according to an hyperslab defined by {@code start}, {@code + * stride}, {@code count}, {@code block}. See {@link Hyperslab}. + * + * @param start Starting location for the hyperslab. + * @param stride The number of elements to separate each element or block to be selected. + * @param count The number of elements or blocks to select along the dimension. + * @param block The size of the block selected from the dimension. + * @return index + */ + public static Index hyperslab(long start, long stride, long count, long block) { + return new Hyperslab(start, stride, count, block); + } + + /** + * An index that inserts a new dimension of size 1 into the resulting array. + * + * @return index + */ + public static Index newAxis() { + return NewAxis.INSTANCE; + } + + /** + * An index that expands to fill all available source dimensions. Works the same as Python's + * {@code ...}. + * + * @return index + */ + public static Index ellipsis() { + return Ellipsis.INSTANCE; + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + * + *

Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If + * {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, + * respectively. + * + *

Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(long start, long end, long stride) { + return new Slice(start, end, stride); + } + + /** + * An index that returns elements between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + * + *

Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end) { + return slice(start, end, 1); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If + * {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, + * respectively. + * + *

Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static Index slice(Long start, Long end, long stride) { + if (start == null && end == null) { + if (stride == 1) { + return Indices.all(); + } else { + return Indices.step(stride); + } + } else if (start == null) { + return Indices.sliceTo(end, stride); + } else if (end == null) { + return Indices.sliceFrom(start, stride); + } + + return slice(start.longValue(), end.longValue(), stride); + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java new file mode 100644 index 00000000000..47f31bdf9b1 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -0,0 +1,51 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class NewAxis implements Index { + + static final NewAxis INSTANCE = new NewAxis(); + + private NewAxis() {} + + @Override + public long numElements(Dimension dim) { + return 1; + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return coordinate; + } + + @Override + public Dimension apply(Dimension dim) { + throw new IllegalStateException(); + } + + @Override + public boolean isNewAxis() { + return true; + } + + @Override + public String toString() { + return NewAxis.class.getSimpleName() + "()"; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java new file mode 100644 index 00000000000..beda853abb3 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java @@ -0,0 +1,52 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Sequence implements Index { + + @Override + public long numElements(Dimension dim) { + return coords.size(); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return coords.getObject(coordinate).longValue(); + } + + Sequence(NdArray coords) { + this.coords = coords; + } + + private final NdArray coords; + + @Override + public String toString() { + return new StringJoiner(", ", Sequence.class.getSimpleName() + "(", ")") + .add("coords=" + coords) + .toString(); + } + + @Override + public boolean isStridedSlicingCompliant() { + return false; + } +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java new file mode 100644 index 00000000000..74743c68fa2 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -0,0 +1,89 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Slice implements Index { + + Slice(long start, long end, long stride) { + this.start = start; + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public long end() { + return end; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Slice.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long start; + private final long end; + private final long stride; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java new file mode 100644 index 00000000000..10ae6d0f09a --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceFrom implements Index { + + SliceFrom(long start, long stride) { + this.start = start; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceFrom.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } + + private final long start; + private final long stride; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java new file mode 100644 index 00000000000..18f72585530 --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceTo implements Index { + + SliceTo(long end, long stride) { + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long end() { + return end; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceTo.class.getSimpleName() + "(", ")") + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long end; + private final long stride; +} diff --git a/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java new file mode 100644 index 00000000000..fc407bbe55b --- /dev/null +++ b/tensorflow-ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java @@ -0,0 +1,83 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Step implements Index { + + Step(long stride) { + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Step.class.getSimpleName() + "(", ")") + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } + + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } + + private final long stride; +} diff --git a/tensorflow-ndarray/src/test/java/module-info.test b/tensorflow-ndarray/src/test/java/module-info.test new file mode 100644 index 00000000000..310e500ee47 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/module-info.test @@ -0,0 +1,22 @@ +/* + Copyright 2022 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ======================================================================= + */ +module org.tensorflow.ndarray { + requires java.desktop; // required for java.awt.* + + requires transitive org.junit.jupiter.engine; + requires transitive org.junit.jupiter.api; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/BooleanNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/BooleanNdArrayTestBase.java new file mode 100644 index 00000000000..f11a1193a35 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/BooleanNdArrayTestBase.java @@ -0,0 +1,53 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.tensorflow.ndarray.NdArrays.vectorOf; + +import org.junit.jupiter.api.Test; + +public abstract class BooleanNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract BooleanNdArray allocate(Shape shape); + + @Override + protected Boolean valueOf(Long val) { + return val > 0; + } + + @Test + public void iteratePrimitiveElements() { + BooleanNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setBoolean(coords[2] > 0)); + + assertFalse(matrix3d.getBoolean(0, 0, 0)); + assertTrue(matrix3d.getBoolean(0, 0, 1)); + assertTrue(matrix3d.getBoolean(0, 0, 4)); + assertTrue(matrix3d.getBoolean(0, 1, 2)); + + matrix3d.elements(1).forEach(vector -> vector.set(vectorOf(true, false, true, false, true))); + + assertTrue(matrix3d.getBoolean(0, 0, 0)); + assertFalse(matrix3d.getBoolean(0, 0, 1)); + assertTrue(matrix3d.getBoolean(0, 0, 4)); + assertTrue(matrix3d.getBoolean(0, 1, 2)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ByteNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ByteNdArrayTestBase.java new file mode 100644 index 00000000000..be8c99a6b1e --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ByteNdArrayTestBase.java @@ -0,0 +1,55 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public abstract class ByteNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract ByteNdArray allocate(Shape shape); + + @Override + protected Byte valueOf(Long val) { + return val.byteValue(); + } + + @Test + public void iteratePrimitiveElements() { + ByteNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setByte((byte) coords[2])); + + assertEquals(0, matrix3d.getByte(0, 0, 0)); + assertEquals(1, matrix3d.getByte(0, 0, 1)); + assertEquals(4, matrix3d.getByte(0, 0, 4)); + assertEquals(2, matrix3d.getByte(0, 1, 2)); + + matrix3d + .elements(1) + .forEach( + vector -> + vector.set(NdArrays.vectorOf((byte) 5, (byte) 6, (byte) 7, (byte) 8, (byte) 9))); + + assertEquals(5, matrix3d.getByte(0, 0, 0)); + assertEquals(6, matrix3d.getByte(0, 0, 1)); + assertEquals(9, matrix3d.getByte(0, 0, 4)); + assertEquals(7, matrix3d.getByte(0, 1, 2)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/DoubleNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/DoubleNdArrayTestBase.java new file mode 100644 index 00000000000..1bcca203ff7 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/DoubleNdArrayTestBase.java @@ -0,0 +1,51 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public abstract class DoubleNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract DoubleNdArray allocate(Shape shape); + + @Override + protected Double valueOf(Long val) { + return val.doubleValue(); + } + + @Test + public void iteratePrimitiveElements() { + DoubleNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setDouble((double) coords[2])); + + assertEquals(0.0, matrix3d.getDouble(0, 0, 0), 0.0); + assertEquals(1.0, matrix3d.getDouble(0, 0, 1), 0.0); + assertEquals(4.0, matrix3d.getDouble(0, 0, 4), 0.0); + assertEquals(2.0, matrix3d.getDouble(0, 1, 2), 0.0); + + matrix3d.elements(1).forEach(vector -> vector.set(NdArrays.vectorOf(5.0, 6.0, 7.0, 8.0, 9.0))); + + assertEquals(5, matrix3d.getDouble(0, 0, 0), 0.0); + assertEquals(6, matrix3d.getDouble(0, 0, 1), 0.0); + assertEquals(9, matrix3d.getDouble(0, 0, 4), 0.0); + assertEquals(7, matrix3d.getDouble(0, 1, 2), 0.0); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/FloatNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/FloatNdArrayTestBase.java new file mode 100644 index 00000000000..6d11346df76 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/FloatNdArrayTestBase.java @@ -0,0 +1,53 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public abstract class FloatNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract FloatNdArray allocate(Shape shape); + + @Override + protected Float valueOf(Long val) { + return val.floatValue(); + } + + @Test + public void iteratePrimitiveElements() { + FloatNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setFloat((float) coords[2])); + + assertEquals(0.0f, matrix3d.getFloat(0, 0, 0), 0.0f); + assertEquals(1.0f, matrix3d.getFloat(0, 0, 1), 0.0f); + assertEquals(4.0f, matrix3d.getFloat(0, 0, 4), 0.0f); + assertEquals(2.0f, matrix3d.getFloat(0, 1, 2), 0.0f); + + matrix3d + .elements(1) + .forEach(vector -> vector.set(NdArrays.vectorOf(5.0f, 6.0f, 7.0f, 8.0f, 9.0f))); + + assertEquals(5, matrix3d.getFloat(0, 0, 0), 0.0f); + assertEquals(6, matrix3d.getFloat(0, 0, 1), 0.0f); + assertEquals(9, matrix3d.getFloat(0, 0, 4), 0.0f); + assertEquals(7, matrix3d.getFloat(0, 1, 2), 0.0f); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java new file mode 100644 index 00000000000..94897f63129 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java @@ -0,0 +1,555 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.index.Indices; + +public class IndexTest { + @Test + public void testNullConversions() { + assertTrue( + Indices.slice(null, 0L).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue( + Indices.slice(null, 0L).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue( + Indices.slice(null, null).beginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue( + Indices.slice(0L, null).endMask(), "Passed null for slice end but didn't set end mask"); + + assertTrue( + Indices.slice(0L, null).endMask(), "Passed null for slice end but didn't set end mask"); + + assertTrue( + Indices.slice(null, null).endMask(), "Passed null for slice end but didn't set end mask"); + } + + @Test + public void testIndices() { + + String[][] indexData = new String[5][4]; + for (int i = 0; i < 5; i++) { + for (int j = 0; j < 4; j++) indexData[i][j] = "(" + j + ", " + i + ")"; + } + + NdArray matrix2d = StdArrays.ndCopyOf(indexData); + assertEquals(2, matrix2d.rank()); + + /* + |(0, 0), (1, 0), (2, 0), (3, 0)| + |(0, 1), (1, 1), (2, 1), (3, 1)| + |(0, 2), (1, 2), (2, 2), (3, 2)| + |(0, 3), (1, 3), (2, 3), (3, 3)| + |(0, 4), (1, 4), (2, 4), (3, 4)| + */ + assertArrayEquals(new String[] {"(0, 0)", "(1, 0)", "(2, 0)", "(3, 0)"}, indexData[0]); + + NdArray same1 = matrix2d.slice(Indices.all()); + String[][] same1j = StdArrays.array2dCopyOf(same1, String.class); + assertEquals(2, same1.rank()); + assertEquals(same1, matrix2d); + assertEquals(matrix2d, StdArrays.ndCopyOf(same1j)); + + NdArray same2 = matrix2d.slice(Indices.ellipsis()); + String[][] same2j = StdArrays.array2dCopyOf(same2, String.class); + assertEquals(2, same2.rank()); + assertEquals(matrix2d, same2); + assertEquals(matrix2d, StdArrays.ndCopyOf(same2j)); + + // All rows, column 1 + NdArray same3 = matrix2d.slice(Indices.all(), Indices.at(1)); + assertEquals(1, same3.rank()); + String[] same3j = StdArrays.array1dCopyOf(same3, String.class); + assertArrayEquals(new String[] {"(1, 0)", "(1, 1)", "(1, 2)", "(1, 3)", "(1, 4)"}, same3j); + + // row 2, all columns + NdArray same4 = matrix2d.slice(Indices.at(2), Indices.all()); + assertEquals(1, same4.rank()); + String[] same4j = StdArrays.array1dCopyOf(same4, String.class); + assertArrayEquals(new String[] {"(0, 2)", "(1, 2)", "(2, 2)", "(3, 2)"}, same4j); + assertEquals(NdArrays.vectorOfObjects("(0, 2)", "(1, 2)", "(2, 2)", "(3, 2)"), same4); + + // row 2, column 1 + NdArray same5 = matrix2d.slice(Indices.at(2), Indices.at(1)); + assertEquals(0, same5.rank()); + assertTrue(same5.shape().isScalar()); + // Don't use an index + String same5j = same5.getObject(); + assertEquals("(1, 2)", same5j); + + // rows 1 to 2, all columns + NdArray same6 = matrix2d.slice(Indices.slice(1, 3)); + assertEquals(2, same6.rank()); + String[][] same6j = StdArrays.array2dCopyOf(same6, String.class); + assertArrayEquals( + new String[][] { + {"(0, 1)", "(1, 1)", "(2, 1)", "(3, 1)"}, + {"(0, 2)", "(1, 2)", "(2, 2)", "(3, 2)"} + }, + same6j); + + // Exception in thread "main" java.nio.BufferOverflowException + // all rows, columns 1 to 2 + NdArray same7 = matrix2d.slice(Indices.all(), Indices.slice(1, 3)); + assertEquals(2, same7.rank()); + assertEquals(Shape.of(5, 2), same7.shape()); + assertEquals(10, same7.size()); + NdArray r7_0 = same7.get(0); + NdArray r7_1 = same7.get(1); + NdArray r7_2 = same7.get(2); + NdArray r7_3 = same7.get(3); + NdArray r7_4 = same7.get(4); + assertEquals(1, r7_0.rank()); + assertEquals(Shape.of(2), r7_0.shape()); + assertEquals(2, r7_0.size()); + // TODO: I get a (0,0) which is not what I expected + // System.out.println(r7_0.getObject()); + // assertEquals("(1,0)", r7_0.getObject()); + assertEquals("(1, 0)", r7_0.getObject(0)); + assertEquals("(2, 0)", r7_0.getObject(1)); + assertEquals("(1, 1)", r7_1.getObject(0)); + assertEquals("(2, 1)", r7_1.getObject(1)); + assertEquals("(1, 2)", r7_2.getObject(0)); + assertEquals("(2, 2)", r7_2.getObject(1)); + assertEquals("(1, 3)", r7_3.getObject(0)); + assertEquals("(2, 3)", r7_3.getObject(1)); + assertEquals("(1, 4)", r7_4.getObject(0)); + assertEquals("(2, 4)", r7_4.getObject(1)); + String[][] expectedr7 = + new String[][] { + {"(1, 0)", "(2, 0)"}, + {"(1, 1)", "(2, 1)"}, + {"(1, 2)", "(2, 2)"}, + {"(1, 3)", "(2, 3)"}, + {"(1, 4)", "(2, 4)"} + }; + String[][] lArray = new String[5][2]; + StdArrays.copyFrom(same7, lArray); + assertArrayEquals(expectedr7, lArray); + String[][] same7j = StdArrays.array2dCopyOf(same7, String.class); + assertArrayEquals(expectedr7, same7j); + + // rows 1 to 2, columns 1 to 2 + NdArray same8 = matrix2d.slice(Indices.slice(1, 3), Indices.slice(1, 3)); + assertEquals(2, same8.rank()); + assertEquals(Shape.of(2, 2), same8.shape()); + assertEquals(4, same8.size()); + String[][] same8j = StdArrays.array2dCopyOf(same8, String.class); + // print2D(same8j) + String[][] expected_r8 = + new String[][] { + {"(1, 1)", "(2, 1)"}, + {"(1, 2)", "(2, 2)"} + }; + assertArrayEquals(expected_r8, same8j); + NdArray r8_0 = same8.get(0); + NdArray r8_1 = same8.get(1); + assertEquals(1, r8_0.rank()); + assertEquals(Shape.of(2), r8_0.shape()); + assertEquals(2, r8_0.size()); + assertEquals("(1, 1)", r8_0.getObject(0)); + assertEquals("(2, 1)", r8_0.getObject(1)); + assertEquals("(1, 2)", r8_1.getObject(0)); + assertEquals("(2, 2)", r8_1.getObject(1)); + + // rows 1 to 2, columns 1 to 2 + NdArray same9 = matrix2d.slice(Indices.range(1, 3), Indices.range(1, 3)); + assertEquals(2, same9.rank()); + assertEquals(Shape.of(2, 2), same9.shape()); + assertEquals(4, same9.size()); + String[][] same9j = StdArrays.array2dCopyOf(same9, String.class); + String[][] expected_r9 = + new String[][] { + {"(1, 1)", "(2, 1)"}, + {"(1, 2)", "(2, 2)"} + }; + assertArrayEquals(expected_r9, same9j); + NdArray r9_0 = same9.get(0); + NdArray r9_1 = same9.get(1); + assertEquals(1, r9_0.rank()); + assertEquals(Shape.of(2), r9_0.shape()); + assertEquals(2, r9_0.size()); + assertEquals("(1, 1)", r9_0.getObject(0)); + assertEquals("(2, 1)", r9_0.getObject(1)); + assertEquals("(1, 2)", r9_1.getObject(0)); + assertEquals("(2, 2)", r9_1.getObject(1)); + + // rows 1, 3 and 4, columns 0 to 2 + NdArray same10 = matrix2d.slice(Indices.odd(), Indices.even()); + String[][] same10j = StdArrays.array2dCopyOf(same10, String.class); + assertEquals(2, same10.rank()); + assertEquals(Shape.of(2, 2), same10.shape()); + assertEquals(4, same10.size()); + String[][] expected_r10 = + new String[][] { + {"(0, 1)", "(2, 1)"}, + {"(0, 3)", "(2, 3)"} + }; + assertArrayEquals(expected_r10, same10j); + NdArray r10_0 = same10.get(0); + NdArray r10_1 = same10.get(1); + assertEquals(1, r10_0.rank()); + assertEquals(Shape.of(2), r10_0.shape()); + assertEquals(2, r10_0.size()); + assertEquals("(0, 1)", r10_0.getObject(0)); + assertEquals("(2, 1)", r10_0.getObject(1)); + assertEquals("(0, 3)", r10_1.getObject(0)); + assertEquals("(2, 3)", r10_1.getObject(1)); + + // rows 3 and 4, columns 0 and 1. Second value is stride + NdArray same11 = matrix2d.slice(Indices.sliceFrom(3, 1), Indices.sliceFrom(2, 1)); + String[][] same11j = StdArrays.array2dCopyOf(same11, String.class); + assertEquals(2, same11.rank()); + assertEquals(Shape.of(2, 2), same11.shape()); + assertEquals(4, same11.size()); + String[][] expected_r11 = + new String[][] { + {"(2, 3)", "(3, 3)"}, + {"(2, 4)", "(3, 4)"} + }; + assertArrayEquals(expected_r11, same11j); + NdArray r11_0 = same11.get(0); + NdArray r11_1 = same11.get(1); + assertEquals(1, r11_0.rank()); + assertEquals(Shape.of(2), r11_0.shape()); + assertEquals(2, r11_0.size()); + assertEquals("(2, 3)", r11_0.getObject(0)); + assertEquals("(3, 3)", r11_0.getObject(1)); + assertEquals("(2, 4)", r11_1.getObject(0)); + assertEquals("(3, 4)", r11_1.getObject(1)); + + // rows 0 and 2, columns 0 and 1. Second value is stride. Index non inclusive + NdArray same12 = matrix2d.slice(Indices.sliceTo(3, 2), Indices.sliceTo(2, 1)); + String[][] same12j = StdArrays.array2dCopyOf(same12, String.class); + assertEquals(2, same12.rank()); + assertEquals(Shape.of(2, 2), same12.shape()); + assertEquals(4, same12.size()); + String[][] expected_r12 = + new String[][] { + {"(0, 0)", "(1, 0)"}, + {"(0, 2)", "(1, 2)"} + }; + assertArrayEquals(expected_r12, same12j); + NdArray r12_0 = same12.get(0); + NdArray r12_1 = same12.get(1); + assertEquals(1, r12_0.rank()); + assertEquals(Shape.of(2), r12_0.shape()); + assertEquals(2, r12_0.size()); + assertEquals("(0, 0)", r12_0.getObject(0)); + assertEquals("(1, 0)", r12_0.getObject(1)); + assertEquals("(0, 2)", r12_1.getObject(0)); + assertEquals("(1, 2)", r12_1.getObject(1)); + + // rows 0 and 2, columns 0 and 1. Second value is stride. Index non inclusive + NdArray same13 = matrix2d.slice(Indices.step(2), Indices.step(2)); + String[][] same13j = StdArrays.array2dCopyOf(same13, String.class); + assertEquals(2, same13.rank()); + assertEquals(Shape.of(3, 2), same13.shape()); + assertEquals(6, same13.size()); + String[][] expected_r13 = + new String[][] { + {"(0, 0)", "(2, 0)"}, + {"(0, 2)", "(2, 2)"}, + {"(0, 4)", "(2, 4)"} + }; + assertArrayEquals(expected_r13, same13j); + NdArray r13_0 = same13.get(0); + NdArray r13_1 = same13.get(1); + NdArray r13_2 = same13.get(2); + assertEquals(1, r13_0.rank()); + assertEquals(Shape.of(2), r13_0.shape()); + assertEquals(2, r13_0.size()); + assertEquals("(0, 0)", r13_0.getObject(0)); + assertEquals("(2, 0)", r13_0.getObject(1)); + assertEquals("(0, 2)", r13_1.getObject(0)); + assertEquals("(2, 2)", r13_1.getObject(1)); + assertEquals("(0, 4)", r13_2.getObject(0)); + assertEquals("(2, 4)", r13_2.getObject(1)); + + NdArray same14 = same13.slice(Indices.flip(), Indices.flip()); + String[][] same14j = StdArrays.array2dCopyOf(same14, String.class); + assertEquals(2, same14.rank()); + assertEquals(Shape.of(3, 2), same14.shape()); + assertEquals(6, same14.size()); + String[][] expected_r14 = + new String[][] { + {"(2, 4)", "(0, 4)"}, + {"(2, 2)", "(0, 2)"}, + {"(2, 0)", "(0, 0)"} + }; + assertArrayEquals(same14j, expected_r14); + NdArray r14_0 = same14.get(0); + NdArray r14_1 = same14.get(1); + NdArray r14_2 = same14.get(2); + assertEquals(1, r14_0.rank()); + assertEquals(Shape.of(2), r14_0.shape()); + assertEquals(2, r14_0.size()); + assertEquals("(0, 0)", r14_2.getObject(1)); + assertEquals("(2, 0)", r14_2.getObject(0)); + assertEquals("(0, 2)", r14_1.getObject(1)); + assertEquals("(2, 2)", r14_1.getObject(0)); + assertEquals("(0, 4)", r14_0.getObject(1)); + assertEquals("(2, 4)", r14_0.getObject(0)); + + NdArray same15 = matrix2d.slice(Indices.slice(4, 0, -2), Indices.slice(3L, null, -2)); + String[][] same15j = StdArrays.array2dCopyOf(same15, String.class); + assertEquals(2, same15.rank()); + assertEquals(Shape.of(2, 2), same15.shape()); + assertEquals(4, same15.size()); + String[][] expected_r15 = + new String[][] { + {"(3, 4)", "(1, 4)"}, + {"(3, 2)", "(1, 2)"}, + }; + assertArrayEquals(expected_r15, same15j); + NdArray r15_0 = same15.get(0); + NdArray r15_1 = same15.get(1); + assertEquals(1, r15_0.rank()); + assertEquals(Shape.of(2), r15_0.shape()); + assertEquals(2, r15_0.size()); + assertEquals("(3, 4)", r15_0.getObject(0)); + assertEquals("(1, 4)", r15_0.getObject(1)); + assertEquals("(3, 2)", r15_1.getObject(0)); + assertEquals("(1, 2)", r15_1.getObject(1)); + + NdArray same16 = matrix2d.slice(Indices.seq(4, 2), Indices.seq(3, 1)); + String[][] same16j = StdArrays.array2dCopyOf(same16, String.class); + assertEquals(2, same16.rank()); + assertEquals(Shape.of(2, 2), same16.shape()); + assertEquals(4, same16.size()); + String[][] expected_r16 = + new String[][] { + {"(3, 4)", "(1, 4)"}, + {"(3, 2)", "(1, 2)"} + }; + assertArrayEquals(expected_r16, same16j); + NdArray r16_0 = same16.get(0); + NdArray r16_1 = same16.get(1); + assertEquals(1, r16_0.rank()); + assertEquals(Shape.of(2), r16_0.shape()); + assertEquals(2, r16_0.size()); + assertEquals("(3, 4)", r16_0.getObject(0)); + assertEquals("(1, 4)", r16_0.getObject(1)); + assertEquals("(3, 2)", r16_1.getObject(0)); + assertEquals("(1, 2)", r16_1.getObject(1)); + + // New axis always has size 1 + NdArray same17 = matrix2d.slice(Indices.all(), Indices.all(), Indices.newAxis()); + String[][][] same17j = StdArrays.array3dCopyOf(same17, String.class); + assertEquals(3, same17.rank()); + assertEquals(Shape.of(5, 4, 1), same17.shape()); + assertEquals(20, same17.size()); + String[][][] expected_r17 = + new String[][][] { + {{"(0, 0)"}, {"(1, 0)"}, {"(2, 0)"}, {"(3, 0)"}}, + {{"(0, 1)"}, {"(1, 1)"}, {"(2, 1)"}, {"(3, 1)"}}, + {{"(0, 2)"}, {"(1, 2)"}, {"(2, 2)"}, {"(3, 2)"}}, + {{"(0, 3)"}, {"(1, 3)"}, {"(2, 3)"}, {"(3, 3)"}}, + {{"(0, 4)"}, {"(1, 4)"}, {"(2, 4)"}, {"(3, 4)"}} + }; + assertArrayEquals(expected_r17, same17j); + NdArray r17_0 = same17.get(0); + NdArray r17_1 = same17.get(1); + NdArray r17_2 = same17.get(2); + NdArray r17_3 = same17.get(3); + NdArray r17_4 = same17.get(4); + assertEquals(2, r17_0.rank()); + assertEquals(Shape.of(4, 1), r17_0.shape()); + assertEquals(4, r17_0.size()); + // row 0 + // What use case can we have for a new index of size 1? + // row 1 + assertEquals("(0, 1)", r17_1.getObject(0, 0)); + assertEquals("(1, 1)", r17_1.getObject(1, 0)); + assertEquals("(2, 1)", r17_1.getObject(2, 0)); + assertEquals("(3, 1)", r17_1.getObject(3, 0)); + // row 2 + assertEquals("(0, 2)", r17_2.getObject(0, 0)); + assertEquals("(1, 2)", r17_2.getObject(1, 0)); + assertEquals("(2, 2)", r17_2.getObject(2, 0)); + assertEquals("(3, 2)", r17_2.getObject(3, 0)); + // row 3 + assertEquals("(0, 3)", r17_3.getObject(0, 0)); + assertEquals("(1, 3)", r17_3.getObject(1, 0)); + assertEquals("(2, 3)", r17_3.getObject(2, 0)); + assertEquals("(3, 3)", r17_3.getObject(3, 0)); + // row 4 + assertEquals("(0, 4)", r17_4.getObject(0, 0)); + assertEquals("(1, 4)", r17_4.getObject(1, 0)); + assertEquals("(2, 4)", r17_4.getObject(2, 0)); + assertEquals("(3, 4)", r17_4.getObject(3, 0)); + } + + @Test + public void testNewaxis() { + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setInt((int) coords[2])); + + IntNdArray slice1 = + matrix3d.slice(Indices.all(), Indices.all(), Indices.all(), Indices.newAxis()); + + assertEquals(Shape.of(5, 4, 5, 1), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1, 0)); + assertEquals(4, slice1.getInt(0, 0, 4, 0)); + assertEquals(2, slice1.getInt(0, 1, 2, 0)); + + IntNdArray slice2 = + matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.all()); + + assertEquals(Shape.of(5, 4, 1, 5), slice2.shape()); + assertEquals(0, slice2.getInt(0, 0, 0, 0)); + assertEquals(1, slice2.getInt(0, 0, 0, 1)); + assertEquals(4, slice2.getInt(0, 0, 0, 4)); + assertEquals(2, slice2.getInt(0, 1, 0, 2)); + + IntNdArray slice3 = + matrix3d.slice(Indices.all(), Indices.newAxis(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(5, 1, 4, 5), slice3.shape()); + assertEquals(0, slice3.getInt(0, 0, 0, 0)); + assertEquals(1, slice3.getInt(0, 0, 0, 1)); + assertEquals(4, slice3.getInt(0, 0, 0, 4)); + assertEquals(2, slice3.getInt(0, 0, 1, 2)); + + IntNdArray slice4 = + matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(1, 5, 4, 5), slice4.shape()); + assertEquals(0, slice4.getInt(0, 0, 0, 0)); + assertEquals(1, slice4.getInt(0, 0, 0, 1)); + assertEquals(4, slice4.getInt(0, 0, 0, 4)); + assertEquals(2, slice4.getInt(0, 0, 1, 2)); + } + + @Test + public void testEllipsis() { + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setInt((int) coords[2])); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.at(0))); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.all()), + matrix3d.slice(Indices.at(0), Indices.ellipsis())); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.at(0), Indices.ellipsis(), Indices.at(0))); + + // newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.newAxis(), Indices.at(0))); + + assertEquals( + matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.newAxis(), Indices.ellipsis(), Indices.at(0))); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0), Indices.newAxis()), + matrix3d.slice(Indices.ellipsis(), Indices.at(0), Indices.newAxis())); + } + + @Test + public void testSlice() { + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setInt((int) coords[2])); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.sliceTo(3), Indices.all()); + + assertEquals(Shape.of(5, 3, 5), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1)); + assertEquals(2, slice1.getInt(0, 1, 2)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4)); + + assertEquals(Shape.of(5, 4, 3), slice2.shape()); + assertEquals(1, slice2.getInt(0, 0, 0)); + assertEquals(3, slice2.getInt(0, 0, 2)); + assertEquals(2, slice2.getInt(0, 1, 1)); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, -1))); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-4, -1))); + + assertEquals( + Shape.of(5, 4, 0), + matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4, -2)).shape()); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(4, 1, -2)); + + assertEquals(Shape.of(5, 4, 2), slice3.shape()); + assertEquals(4, slice3.getInt(0, 0, 0)); + assertEquals(2, slice3.getInt(0, 1, 1)); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, 1, -2))); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, -4, -2))); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(null, null, -1)); + + assertEquals(Shape.of(5, 4, 5), slice4.shape()); + assertEquals(4, slice4.getInt(0, 0, 0)); + assertEquals(3, slice4.getInt(0, 0, 1)); + assertEquals(2, slice4.getInt(0, 1, 2)); + } + + @Test + public void testAt() { + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setInt((int) coords[2])); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)); + + assertEquals(Shape.of(5, 4), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(3)); + + assertEquals(Shape.of(5, 4), slice2.shape()); + assertEquals(3, slice2.getInt(0, 0)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3)); + + assertEquals(Shape.of(5, 4), slice3.shape()); + assertEquals(2, slice3.getInt(0, 0)); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3, true)); + + assertEquals(Shape.of(5, 4, 1), slice4.shape()); + assertEquals(2, slice4.getInt(0, 0, 0)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java new file mode 100644 index 00000000000..f3278196901 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/IntNdArrayTestBase.java @@ -0,0 +1,77 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public abstract class IntNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract IntNdArray allocate(Shape shape); + + @Override + protected Integer valueOf(Long val) { + return val.intValue(); + } + + @Test + public void iteratePrimitiveElements() { + IntNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setInt((int) coords[2])); + + assertEquals(0, matrix3d.getInt(0, 0, 0)); + assertEquals(1, matrix3d.getInt(0, 0, 1)); + assertEquals(4, matrix3d.getInt(0, 0, 4)); + assertEquals(2, matrix3d.getInt(0, 1, 2)); + + matrix3d.elements(1).forEach(vector -> vector.set(NdArrays.vectorOf(5, 6, 7, 8, 9))); + + assertEquals(5, matrix3d.getInt(0, 0, 0)); + assertEquals(6, matrix3d.getInt(0, 0, 1)); + assertEquals(9, matrix3d.getInt(0, 0, 4)); + assertEquals(7, matrix3d.getInt(0, 1, 2)); + } + + @Test + public void streamingInts() { + IntNdArray scalar = allocate(Shape.scalar()); + scalar.setInt(1); + var values = scalar.streamOfInts().toArray(); + assertArrayEquals(new int[] {1}, values); + + IntNdArray vector = allocate(Shape.of(5)); + vector.setInt(1, 0); + vector.setInt(2, 1); + vector.setInt(3, 2); + vector.setInt(4, 3); + vector.setInt(5, 4); + values = vector.streamOfInts().toArray(); + assertArrayEquals(new int[] {1, 2, 3, 4, 5}, values); + + IntNdArray matrix = allocate(Shape.of(2, 2)); + matrix.setInt(1, 0, 0); + matrix.setInt(2, 0, 1); + matrix.setInt(3, 1, 0); + matrix.setInt(4, 1, 1); + values = matrix.streamOfInts().toArray(); + assertArrayEquals(new int[] {1, 2, 3, 4}, values); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java new file mode 100644 index 00000000000..ad8023284f1 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/LongNdArrayTestBase.java @@ -0,0 +1,77 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public abstract class LongNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract LongNdArray allocate(Shape shape); + + @Override + protected Long valueOf(Long val) { + return val; + } + + @Test + public void iteratePrimitiveElements() { + LongNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setLong(coords[2])); + + assertEquals(0, matrix3d.getLong(0, 0, 0)); + assertEquals(1, matrix3d.getLong(0, 0, 1)); + assertEquals(4, matrix3d.getLong(0, 0, 4)); + assertEquals(2, matrix3d.getLong(0, 1, 2)); + + matrix3d.elements(1).forEach(vector -> vector.set(NdArrays.vectorOf(5L, 6L, 7L, 8L, 9L))); + + assertEquals(5, matrix3d.getLong(0, 0, 0)); + assertEquals(6, matrix3d.getLong(0, 0, 1)); + assertEquals(9, matrix3d.getLong(0, 0, 4)); + assertEquals(7, matrix3d.getLong(0, 1, 2)); + } + + @Test + public void streamingLongs() { + LongNdArray scalar = allocate(Shape.scalar()); + scalar.setLong(1L); + var values = scalar.streamOfLongs().toArray(); + assertArrayEquals(new long[] {1L}, values); + + LongNdArray vector = allocate(Shape.of(5)); + vector.setLong(1L, 0); + vector.setLong(2L, 1); + vector.setLong(3L, 2); + vector.setLong(4L, 3); + vector.setLong(5L, 4); + values = vector.streamOfLongs().toArray(); + assertArrayEquals(new long[] {1L, 2L, 3L, 4L, 5L}, values); + + LongNdArray matrix = allocate(Shape.of(2, 2)); + matrix.setLong(1L, 0, 0); + matrix.setLong(2L, 0, 1); + matrix.setLong(3L, 1, 0); + matrix.setLong(4L, 1, 1); + values = matrix.streamOfLongs().toArray(); + assertArrayEquals(new long[] {1L, 2L, 3L, 4L}, values); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java new file mode 100644 index 00000000000..ce6d990dd90 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java @@ -0,0 +1,428 @@ +/* +Copyright 2019-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.*; +import static org.tensorflow.ndarray.NdArrays.vectorOfObjects; +import static org.tensorflow.ndarray.index.Indices.all; +import static org.tensorflow.ndarray.index.Indices.at; +import static org.tensorflow.ndarray.index.Indices.even; +import static org.tensorflow.ndarray.index.Indices.flip; +import static org.tensorflow.ndarray.index.Indices.odd; +import static org.tensorflow.ndarray.index.Indices.range; +import static org.tensorflow.ndarray.index.Indices.seq; +import static org.tensorflow.ndarray.index.Indices.sliceFrom; +import static org.tensorflow.ndarray.index.Indices.sliceTo; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.index.Indices; + +public abstract class NdArrayTestBase { + + protected abstract NdArray allocate(Shape shape); + + protected abstract DataBuffer allocateBuffer(long size); + + protected abstract T valueOf(Long val); + + protected T zeroOrNull() { + return valueOf(0L); + } + + @Test + public void shapeAndSizes() { + Shape scalarShape = Shape.scalar(); + NdArray scalar = allocate(scalarShape); + assertEquals(scalarShape, scalar.shape()); + assertEquals(0, scalar.rank()); + assertEquals(scalarShape, Shape.of()); + + Shape vectorShape = Shape.of(10); + NdArray vector = allocate(vectorShape); + assertEquals(vectorShape, vector.shape()); + assertEquals(1, vector.rank()); + } + + @Test + public void setAndGetValues() { + NdArray matrix = allocate(Shape.of(5, 4)); + assertEquals(zeroOrNull(), matrix.getObject(3, 3)); + + matrix.setObject(valueOf(10L), 3, 3); + assertEquals(valueOf(10L), matrix.getObject(3, 3)); + try { + matrix.setObject(valueOf(10L), 3, 4); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + matrix.setObject(valueOf(10L), -1, 3); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + matrix.getObject(3); + fail(); + } catch (IllegalRankException e) { + // as expected + } + try { + matrix.setObject(valueOf(10L), 3); + fail(); + } catch (IllegalRankException e) { + // as expected + } + + NdArray matrix2 = + allocate(Shape.of(3, 2)) + .set(vectorOfObjects(valueOf(1L), valueOf(2L)), 0) + .set(vectorOfObjects(valueOf(3L), valueOf(4L)), 1) + .setObject(valueOf(5L), 2, 0) + .setObject(valueOf(6L), 2, 1); + + assertEquals(valueOf(1L), matrix2.getObject(0, 0)); + assertEquals(valueOf(2L), matrix2.getObject(0, 1)); + assertEquals(valueOf(3L), matrix2.getObject(1, 0)); + assertEquals(valueOf(4L), matrix2.getObject(1, 1)); + assertEquals(valueOf(5L), matrix2.getObject(2, 0)); + assertEquals(valueOf(6L), matrix2.getObject(2, 1)); + } + + @Test + public void iterateElements() { + NdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d + .scalars() + .forEachIndexed( + (coords, scalar) -> { + scalar.setObject(valueOf(coords[2])); + }); + + assertEquals(valueOf(0L), matrix3d.getObject(0, 0, 0)); + assertEquals(valueOf(1L), matrix3d.getObject(0, 0, 1)); + assertEquals(valueOf(4L), matrix3d.getObject(0, 0, 4)); + assertEquals(valueOf(2L), matrix3d.getObject(0, 1, 2)); + + matrix3d + .elements(1) + .forEach( + vector -> { + vector.set( + vectorOfObjects(valueOf(5L), valueOf(6L), valueOf(7L), valueOf(8L), valueOf(9L))); + }); + + assertEquals(valueOf(5L), matrix3d.getObject(0, 0, 0)); + assertEquals(valueOf(6L), matrix3d.getObject(0, 0, 1)); + assertEquals(valueOf(9L), matrix3d.getObject(0, 0, 4)); + assertEquals(valueOf(7L), matrix3d.getObject(0, 1, 2)); + + long value = 0L; + for (NdArray matrix : matrix3d.elements(0)) { + assertEquals(2L, matrix.shape().numDimensions()); + assertEquals(4L, matrix.shape().get(0)); + assertEquals(5L, matrix.shape().get(1)); + + for (NdArray vector : matrix.elements(0)) { + assertEquals(1L, vector.shape().numDimensions()); + assertEquals(5L, vector.shape().get(0)); + + for (NdArray scalar : vector.scalars()) { + assertEquals(0L, scalar.shape().numDimensions()); + scalar.setObject(valueOf(value++)); + try { + scalar.elements(0); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + } + } + assertEquals(valueOf(0L), matrix3d.getObject(0, 0, 0)); + assertEquals(valueOf(5L), matrix3d.getObject(0, 1, 0)); + assertEquals(valueOf(9L), matrix3d.getObject(0, 1, 4)); + assertEquals(valueOf(20L), matrix3d.getObject(1, 0, 0)); + assertEquals(valueOf(25L), matrix3d.getObject(1, 1, 0)); + assertEquals(valueOf(99L), matrix3d.getObject(4, 3, 4)); + } + + @Test + public void slices() { + NdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + T val100 = valueOf(100L); + matrix3d.setObject(val100, 1, 0, 0); + T val101 = valueOf(101L); + matrix3d.setObject(val101, 1, 0, 1); + + // Vector (1,0,*) + NdArray vector10X = matrix3d.get(1, 0); + assertEquals(Shape.of(5), vector10X.shape()); + assertEquals(val100, vector10X.getObject(0)); + assertEquals(val101, vector10X.getObject(1)); + + T val102 = valueOf(102L); + vector10X.setObject(val102, 2); + assertEquals(val102, vector10X.getObject(2)); + assertEquals(val102, matrix3d.getObject(1, 0, 2)); + + // Vector (*,0,0) + NdArray vectorX00 = matrix3d.slice(all(), at(0), at(0)); + assertEquals(Shape.of(5), vectorX00.shape()); + assertEquals(val100, vectorX00.getObject(1)); + T val200 = valueOf(200L); + vectorX00.setObject(val200, 2); + assertEquals(val200, vectorX00.getObject(2)); + assertEquals(val200, matrix3d.getObject(2, 0, 0)); + + // Vector (1,0,[2,0]) + NdArray vector10_20 = matrix3d.slice(at(1), at(0), seq(2, 0)); + assertEquals(vector10_20.shape(), Shape.of(2)); + assertEquals(val102, vector10_20.getObject(0)); + assertEquals(val100, vector10_20.getObject(1)); + + // Vector (1,0,[even]) + NdArray vector10_even = matrix3d.slice(at(1), at(0), even()); + assertEquals(vector10_even.shape(), Shape.of(3)); + assertEquals(val100, vector10_even.getObject(0)); + assertEquals(val102, vector10_even.getObject(1)); + + // Vector ([odd]) from vector (1,0,[even]) + NdArray vector10_even_odd = vector10_even.slice(odd()); + assertEquals(vector10_even_odd.shape(), Shape.of(1)); + assertEquals(val102, vector10_even_odd.getObject(0)); + + // Vector (1,0,[flip]) + NdArray vector10_flip = matrix3d.slice(at(1), at(0), flip()); + assertEquals(vector10_flip.shape(), Shape.of(5)); + assertEquals(val100, vector10_flip.getObject(4)); + assertEquals(val101, vector10_flip.getObject(3)); + + // Vector (1,0,[from 1]) from vector (1,0,*) + NdArray vector10_1toX = vector10X.slice(sliceFrom(1)); + assertEquals(vector10_1toX.shape(), Shape.of(4)); + assertEquals(val101, vector10_1toX.getObject(0)); + assertEquals(val102, vector10_1toX.getObject(1)); + + // Vector (1,0,[to 1]) from vector (1,0,*) + NdArray vector10_Xto1 = vector10X.slice(sliceTo(2)); + assertEquals(vector10_Xto1.shape(), Shape.of(2)); + assertEquals(val100, vector10_Xto1.getObject(0)); + assertEquals(val101, vector10_Xto1.getObject(1)); + + // Vector (1,0,[1 to 3]) + NdArray vector10_1to3 = matrix3d.slice(at(1), at(0), range(1, 3)); + assertEquals(vector10_1to3.shape(), Shape.of(2)); + assertEquals(val101, vector10_1to3.getObject(0)); + assertEquals(val102, vector10_1to3.getObject(1)); + + // Scalar (1,0,0) from vector (1,0,*) + NdArray scalar100 = vector10X.get(0); + assertEquals(Shape.of(), scalar100.shape()); + assertEquals(val100, scalar100.getObject()); + + // Slice scalar (1,0,z) + LongNdArray z = NdArrays.scalarOf(2L); + NdArray scalar102 = matrix3d.slice(at(1), at(0), at(z)); + assertEquals(scalar102.shape(), Shape.of()); + assertEquals(val102, scalar102.getObject()); + + // Slicing the 3D matrix so we only keep the first element of the second dimension + NdArray matrix_X0Z = matrix3d.slice(all(), at(0)); + assertEquals(2, matrix_X0Z.rank()); + assertEquals(Shape.of(5, 5), matrix_X0Z.shape()); + assertEquals(val100, matrix_X0Z.getObject(1, 0)); + assertEquals(val101, matrix_X0Z.getObject(1, 1)); + assertEquals(val200, matrix_X0Z.getObject(2, 0)); + } + + @Test + public void writeAndReadWithBuffers() { + DataBuffer buffer = allocateBuffer(15L); + for (long val = 0L; val < buffer.size(); ++val) { + buffer.setObject(valueOf(val), val); + } + NdArray matrix = allocate(Shape.of(3, 5)); + matrix.copyFrom(buffer); + assertEquals(valueOf(0L), matrix.getObject(0, 0)); + assertEquals(valueOf(4L), matrix.getObject(0, 4)); + assertEquals(valueOf(5L), matrix.getObject(1, 0)); + assertEquals(valueOf(10L), matrix.getObject(2, 0)); + assertEquals(valueOf(14L), matrix.getObject(2, 4)); + + matrix.setObject(valueOf(100L), 1, 0); + matrix.copyTo(buffer); + assertEquals(valueOf(0L), buffer.getObject(0)); + assertEquals(valueOf(4L), buffer.getObject(4)); + assertEquals(valueOf(100L), buffer.getObject(5)); + assertEquals(valueOf(10L), buffer.getObject(10)); + assertEquals(valueOf(14L), buffer.getObject(14)); + + try { + matrix.copyFrom(buffer.narrow(10)); + fail(); + } catch (BufferUnderflowException e) { + // as expected + } + try { + matrix.copyTo(buffer.narrow(10)); + fail(); + } catch (BufferOverflowException e) { + // as expected + } + } + + @Test + public void ndArrayCopies() { + NdArray matrixA = allocate(Shape.of(3, 5)); + + long value = 0L; + for (NdArray s : matrixA.scalars()) { + s.setObject(valueOf(value++)); + } + NdArray matrixB = allocate(Shape.of(3, 5)).setObject(valueOf(100L), 1, 0); + matrixA.copyTo(matrixB); + assertEquals(valueOf(0L), matrixB.getObject(0, 0)); + assertEquals(valueOf(4L), matrixB.getObject(0, 4)); + assertEquals(valueOf(5L), matrixB.getObject(1, 0)); + assertEquals(valueOf(10L), matrixB.getObject(2, 0)); + assertEquals(valueOf(14L), matrixB.getObject(2, 4)); + + NdArray matrixC = allocate(Shape.of(3, 4)); + try { + matrixA.copyTo(matrixC); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + + @Test + public void equalsAndHashCode() { + NdArray array1 = allocate(Shape.of(2, 2)); + NdArray array2 = allocate(Shape.of(2, 2)); + NdArray array3 = allocate(Shape.of(2, 2)); + NdArray array4 = allocate(Shape.of(1, 2, 2)); + + @SuppressWarnings("unchecked") + T[][][] values = + (T[][][]) (new Object[][][] {{{valueOf(0L), valueOf(1L)}, {valueOf(2L), valueOf(0L)}}}); + + StdArrays.copyTo(values[0], array1); + StdArrays.copyTo(values[0], array2); + StdArrays.copyTo(values[0], array3); + array3.setObject(valueOf(0L), 0, 1); + StdArrays.copyTo(values, array4); + + assertEquals(array1, array2); + assertEquals(array1.hashCode(), array2.hashCode()); + assertNotEquals(array1, array3); + assertNotEquals(array1.hashCode(), array3.hashCode()); + assertNotEquals(array1, array4); + assertNotEquals(array1.hashCode(), array4.hashCode()); + } + + @Test + public void iterateScalarsOnSegmentedElements() { + NdArray originalTensor = allocate(Shape.of(2, 3)); + + originalTensor + .setObject(valueOf(0L), 0, 0) + .setObject(valueOf(1L), 0, 1) + .setObject(valueOf(2L), 0, 2) + .setObject(valueOf(3L), 1, 0) + .setObject(valueOf(4L), 1, 1) + .setObject(valueOf(5L), 1, 2); + + NdArray slice = originalTensor.slice(Indices.all(), Indices.sliceFrom(1)); + assertEquals(Shape.of(2, 2), slice.shape()); + + slice + .elements(0) + .forEachIndexed( + (eCoord, e) -> { + e.scalars() + .forEachIndexed( + (sCoord, s) -> { + assertEquals( + valueOf((eCoord[0] * originalTensor.shape().get(1)) + sCoord[0] + 1), + s.getObject()); + }); + }); + } + + @Test + public void streamingObjects() { + NdArray scalar = allocate(Shape.scalar()); + scalar.setObject(valueOf(1L)); + var values = scalar.streamOfObjects().collect(Collectors.toList()); + assertIterableEquals(List.of(valueOf(1L)), values); + + NdArray vector = allocate(Shape.of(5)); + vector.setObject(valueOf(1L), 0); + vector.setObject(valueOf(2L), 1); + vector.setObject(valueOf(3L), 2); + vector.setObject(valueOf(4L), 3); + vector.setObject(valueOf(5L), 4); + values = vector.streamOfObjects().collect(Collectors.toList()); + assertIterableEquals( + List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L), valueOf(5L)), values); + + NdArray matrix = allocate(Shape.of(2, 2)); + matrix.setObject(valueOf(1L), 0, 0); + matrix.setObject(valueOf(2L), 0, 1); + matrix.setObject(valueOf(3L), 1, 0); + matrix.setObject(valueOf(4L), 1, 1); + values = matrix.streamOfObjects().collect(Collectors.toList()); + assertIterableEquals(List.of(valueOf(1L), valueOf(2L), valueOf(3L), valueOf(4L)), values); + } + + @Test + public void withShape() { + Shape originalShape = Shape.scalar(); + Shape newShape = originalShape.prepend(1).prepend(1); // [1, 1] + + NdArray originalArray = allocate(originalShape); + originalArray.setObject(valueOf(10L)); + assertEquals(valueOf(10L), originalArray.getObject()); + + NdArray newArray = originalArray.withShape(newShape); + assertNotNull(newArray); + assertEquals(newShape, newArray.shape()); + assertEquals(originalShape, originalArray.shape()); + assertEquals(valueOf(10L), newArray.getObject(0, 0)); + + NdArray sameArray = originalArray.withShape(Shape.scalar()); + assertSame(originalArray, sameArray); + + assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.of(2))); + assertThrows(IllegalArgumentException.class, () -> originalArray.withShape(Shape.unknown())); + + NdArray originalMatrix = allocate(Shape.of(2, 3)); + assertThrows(IllegalArgumentException.class, () -> originalMatrix.withShape(Shape.scalar())); + NdArray newMatrix = originalMatrix.withShape(Shape.of(3, 2)); + assertEquals(Shape.of(3, 2), newMatrix.shape()); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java new file mode 100644 index 00000000000..f6bec66cb25 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java @@ -0,0 +1,176 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; + +public class ShapeTest { + + @Test + public void allKnownDimensions() { + Shape shape = Shape.of(5, 4, 5); + assertEquals(3, shape.numDimensions()); + assertEquals(5, shape.get(0)); + assertEquals(4, shape.get(1)); + assertEquals(5, shape.get(2)); + assertEquals(100, shape.size()); + assertArrayEquals(new long[] {5, 4, 5}, shape.asArray()); + try { + shape.get(3); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + assertEquals(5, shape.get(-1)); + assertEquals(4, shape.get(-2)); + assertEquals(5, shape.get(-3)); + try { + shape.get(-4); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + assertFalse(shape.isUnknown()); + assertFalse(shape.hasUnknownDimension()); + assertFalse(shape.isScalar()); + } + + @Test + public void hashCodeEquals() { + Shape shape1 = Shape.of(5, 4, 5); + Shape shape2 = Shape.of(5, 4, 5); + Shape shape3 = Shape.of(5, 4, 5, 6); + Shape shape4 = Shape.of(5, 4, 1); + + assertEquals(shape1, shape2); + assertEquals(shape1.hashCode(), shape2.hashCode()); + assertNotEquals(shape1, shape3); + assertNotEquals(shape1.hashCode(), shape3.hashCode()); + assertNotEquals(shape1, shape4); + assertNotEquals(shape1.hashCode(), shape4.hashCode()); + + Shape scalar1 = Shape.of(); + Shape scalar2 = Shape.of(); + assertEquals(scalar1, scalar2); + assertNotEquals(scalar1, shape1); + + Shape unknown1 = Shape.of(-1, 4, 5); + Shape unknown2 = Shape.of(-1, 4, 5); + assertNotEquals(unknown1, unknown2); + assertNotEquals(unknown1, shape1); + assertEquals(unknown1, unknown1); + + Shape sizeUnknown1 = Shape.unknown(); + Shape sizeUnknown2 = Shape.unknown(); + assertNotEquals(sizeUnknown1, sizeUnknown2); + assertEquals(sizeUnknown1, sizeUnknown1); + } + + @Test + public void testShapeModification() { + Shape one = Shape.of(2, 4, 6, 8); + assertEquals(one.head(), Shape.of(2)); + assertEquals(one.tail(), Shape.of(4, 6, 8)); + + Shape two = Shape.of(5); + assertEquals(two.head(), two); + assertEquals(two.tail(), Shape.of()); + + try { + Shape.of().head(); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + + assertEquals(Shape.of().tail(), Shape.of()); + + Shape three = Shape.of(2, 4, 6); + assertEquals(three.prepend(5), Shape.of(5, 2, 4, 6)); + + assertEquals(Shape.of(5, 2, 4, 6), two.append(three)); + assertEquals(Shape.of(2, 4, 6, 5), two.prepend(three)); + assertEquals(Shape.of(1, 2, 3, 4), Shape.of(1, 2).append(Shape.of(3, 4))); + assertEquals(Shape.of(1, 2, 3, 4), Shape.of(1, 2, 3).append(4)); + assertEquals(Shape.of(1, 2, 3, 4), Shape.of(1, 2, 3, 4).append(Shape.scalar())); + assertEquals(Shape.of(3, 4, 1, 2), Shape.of(1, 2).prepend(Shape.of(3, 4))); + assertEquals(Shape.of(4, 6), three.takeLast(2)); + assertEquals(Shape.scalar(), three.takeLast(0)); + assertEquals(Shape.of(2, 4), three.take(2)); + assertEquals(Shape.scalar(), three.take(0)); + + try { + Shape.unknown().append(Shape.of(1, 2)); + fail(); + } catch (NullPointerException e) { + // as expected + } + + try { + Shape.unknown().prepend(Shape.of(1, 2)); + fail(); + } catch (NullPointerException e) { + // as expected + } + + // changing the values of the array returned by asArray should not mutate the shape + long[] internalShape = one.asArray(); + assertNotNull(internalShape); + internalShape[0] = 42L; + assertEquals(2L, one.get(0)); + } + + @Test + public void testShapeCompatible() { + Shape a = Shape.unknown(); + Shape b = Shape.of(2, 2); + assertTrue(a.isCompatibleWith(b)); + assertTrue(b.isCompatibleWith(a)); + + a = Shape.of(2, 2); + assertTrue(a.isCompatibleWith(b)); + assertTrue(b.isCompatibleWith(a)); + + a = Shape.of(2, -1); + assertTrue(a.isCompatibleWith(b)); + assertTrue(b.isCompatibleWith(a)); + + a = Shape.of(-1, 2); + assertTrue(a.isCompatibleWith(b)); + assertTrue(b.isCompatibleWith(a)); + + a = Shape.of(-1, -1); + assertTrue(a.isCompatibleWith(b)); + assertTrue(b.isCompatibleWith(a)); + + a = Shape.of(1, 2); + assertFalse(a.isCompatibleWith(b)); + assertFalse(b.isCompatibleWith(a)); + + a = Shape.of(1, 2, 3); + assertFalse(a.isCompatibleWith(b)); + assertFalse(b.isCompatibleWith(a)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ShortNdArrayTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ShortNdArrayTestBase.java new file mode 100644 index 00000000000..347ac7a7b6a --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/ShortNdArrayTestBase.java @@ -0,0 +1,56 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public abstract class ShortNdArrayTestBase extends NdArrayTestBase { + + @Override + protected abstract ShortNdArray allocate(Shape shape); + + @Override + protected Short valueOf(Long val) { + return val.shortValue(); + } + + @Test + public void iteratePrimitiveElements() { + ShortNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> scalar.setShort((short) coords[2])); + + assertEquals(0, matrix3d.getShort(0, 0, 0)); + assertEquals(1, matrix3d.getShort(0, 0, 1)); + assertEquals(4, matrix3d.getShort(0, 0, 4)); + assertEquals(2, matrix3d.getShort(0, 1, 2)); + + matrix3d + .elements(1) + .forEach( + vector -> + vector.set( + NdArrays.vectorOf((short) 5, (short) 6, (short) 7, (short) 8, (short) 9))); + + assertEquals(5, matrix3d.getShort(0, 0, 0)); + assertEquals(6, matrix3d.getShort(0, 0, 1)); + assertEquals(9, matrix3d.getShort(0, 0, 4)); + assertEquals(7, matrix3d.getShort(0, 1, 2)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java new file mode 100644 index 00000000000..9c001dbaf80 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/SparseNdArrayTest.java @@ -0,0 +1,196 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.sparse.BooleanSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.ByteSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.DoubleSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.FloatSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.IntSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.LongSparseNdArray; +import org.tensorflow.ndarray.impl.sparse.ShortSparseNdArray; + +public class SparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}, {2, 3}}; + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + Shape shape = Shape.of(3, 4); + double epsilon = 0.001; + + @Test + public void testBoolean() { + BooleanSparseNdArray instance = + NdArrays.sparseOf(indices, NdArrays.vectorOf(true, true, true), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertTrue(instance.getBoolean(0, 0)); + assertFalse(instance.getBoolean(0, 1)); + assertFalse(instance.getBoolean(0, 2)); + assertFalse(instance.getBoolean(0, 3)); + + assertFalse(instance.getBoolean(1, 0)); + assertFalse(instance.getBoolean(1, 1)); + assertTrue(instance.getBoolean(1, 2)); + assertFalse(instance.getBoolean(1, 3)); + + assertFalse(instance.getBoolean(2, 0)); + assertFalse(instance.getBoolean(2, 1)); + assertFalse(instance.getBoolean(2, 2)); + assertTrue(instance.getBoolean(2, 3)); + } + + @Test + public void testByte() { + ByteSparseNdArray instance = + NdArrays.sparseOf(indices, NdArrays.vectorOf((byte) 1, (byte) 18, (byte) 0xff), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertEquals((byte) 1, instance.getByte(0, 0)); + assertEquals((byte) 0, instance.getByte(0, 1)); + assertEquals((byte) 0, instance.getByte(0, 2)); + assertEquals((byte) 0, instance.getByte(0, 3)); + + assertEquals((byte) 0, instance.getByte(1, 0)); + assertEquals((byte) 0, instance.getByte(1, 1)); + assertEquals((byte) 18, instance.getByte(1, 2)); + assertEquals((byte) 0, instance.getByte(1, 3)); + + assertEquals((byte) 0, instance.getByte(2, 0)); + assertEquals((byte) 0, instance.getByte(2, 1)); + assertEquals((byte) 0, instance.getByte(2, 2)); + assertEquals((byte) 0xff, instance.getByte(2, 3)); + } + + @Test + public void testDouble() { + DoubleSparseNdArray instance = + NdArrays.sparseOf(indices, NdArrays.vectorOf(1., 1.8, 3.14), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertEquals(1., instance.getDouble(0, 0), epsilon); + assertEquals(0, instance.getDouble(0, 1), epsilon); + assertEquals(0, instance.getDouble(0, 2), epsilon); + assertEquals(0, instance.getDouble(0, 3), epsilon); + + assertEquals(0, instance.getDouble(1, 0), epsilon); + assertEquals(0, instance.getDouble(1, 1), epsilon); + assertEquals(1.8, instance.getDouble(1, 2), epsilon); + assertEquals(0, instance.getDouble(1, 3), epsilon); + + assertEquals(0, instance.getDouble(2, 0), epsilon); + assertEquals(0, instance.getDouble(2, 1), epsilon); + assertEquals(0, instance.getDouble(2, 2), epsilon); + assertEquals(3.14, instance.getDouble(2, 3), epsilon); + } + + @Test + public void testFloat() { + FloatSparseNdArray instance = + NdArrays.sparseOf(indices, NdArrays.vectorOf(1.f, 1.8f, 3.14f), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertEquals(1.f, instance.getFloat(0, 0), epsilon); + assertEquals(0f, instance.getFloat(0, 1), epsilon); + assertEquals(0f, instance.getFloat(0, 2), epsilon); + assertEquals(0f, instance.getFloat(0, 3), epsilon); + + assertEquals(0f, instance.getFloat(1, 0), epsilon); + assertEquals(0f, instance.getFloat(1, 1), epsilon); + assertEquals(1.8f, instance.getFloat(1, 2), epsilon); + assertEquals(0f, instance.getFloat(1, 3), epsilon); + + assertEquals(0f, instance.getFloat(2, 0), epsilon); + assertEquals(0f, instance.getFloat(2, 1), epsilon); + assertEquals(0f, instance.getFloat(2, 2), epsilon); + assertEquals(3.14f, instance.getFloat(2, 3), epsilon); + } + + @Test + public void testInt() { + IntSparseNdArray instance = NdArrays.sparseOf(indices, NdArrays.vectorOf(1, 18, 256), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertEquals(1, instance.getInt(0, 0)); + assertEquals(0, instance.getInt(0, 1)); + assertEquals(0, instance.getInt(0, 2)); + assertEquals(0, instance.getInt(0, 3)); + + assertEquals(0, instance.getInt(1, 0)); + assertEquals(0, instance.getInt(1, 1)); + assertEquals(18, instance.getInt(1, 2)); + assertEquals(0, instance.getInt(1, 3)); + + assertEquals(0, instance.getInt(2, 0)); + assertEquals(0, instance.getInt(2, 1)); + assertEquals(0, instance.getInt(2, 2)); + assertEquals(256, instance.getInt(2, 3)); + } + + @Test + public void testLong() { + LongSparseNdArray instance = + NdArrays.sparseOf(indices, NdArrays.vectorOf(1L, 18L, 256L), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertEquals(1L, instance.getLong(0, 0)); + assertEquals(0L, instance.getLong(0, 1)); + assertEquals(0L, instance.getLong(0, 2)); + assertEquals(0L, instance.getLong(0, 3)); + + assertEquals(0L, instance.getLong(1, 0)); + assertEquals(0L, instance.getLong(1, 1)); + assertEquals(18, instance.getLong(1, 2)); + assertEquals(0L, instance.getLong(1, 3)); + + assertEquals(0L, instance.getLong(2, 0)); + assertEquals(0L, instance.getLong(2, 1)); + assertEquals(0L, instance.getLong(2, 2)); + assertEquals(256L, instance.getLong(2, 3)); + } + + @Test + public void testShort() { + ShortSparseNdArray instance = + NdArrays.sparseOf(indices, NdArrays.vectorOf((short) 1, (short) 18, (short) 0xff00), shape); + assertEquals(6, instance.getIndices().size()); + assertEquals(3, instance.getValues().size()); + assertEquals((short) 1, instance.getShort(0, 0)); + assertEquals((short) 0, instance.getShort(0, 1)); + assertEquals((short) 0, instance.getShort(0, 2)); + assertEquals((short) 0, instance.getShort(0, 3)); + + assertEquals((short) 0, instance.getShort(1, 0)); + assertEquals((short) 0, instance.getShort(1, 1)); + assertEquals((short) 18, instance.getShort(1, 2)); + assertEquals((short) 0, instance.getShort(1, 3)); + + assertEquals((short) 0, instance.getShort(2, 0)); + assertEquals((short) 0, instance.getShort(2, 1)); + assertEquals((short) 0, instance.getShort(2, 2)); + assertEquals((short) 0xff00, instance.getShort(2, 3)); + } + + @Test + public void withShape() { + NdArray sparseArray = NdArrays.sparseOf(indices, NdArrays.vectorOf(1, 2, 3), shape); + assertThrows( + UnsupportedOperationException.class, () -> sparseArray.withShape(shape.prepend(1))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/StdArraysTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/StdArraysTest.java new file mode 100644 index 00000000000..7b1c9663a39 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/StdArraysTest.java @@ -0,0 +1,211 @@ +package org.tensorflow.ndarray; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; + +public class StdArraysTest { + + @Test + public void vectors() { + IntNdArray vector = NdArrays.ofInts(Shape.of(2)); + + StdArrays.copyTo(new int[] {1, 2}, vector); + assertEquals(1, vector.getInt(0)); + assertEquals(2, vector.getInt(1)); + + try { + StdArrays.copyTo(new int[] {1, 2, 3}, vector); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyTo(new int[] {1, 2}, NdArrays.ofInts(Shape.of(4))); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyTo(new int[] {1, 2}, NdArrays.ofInts(Shape.of(2, 2))); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + + int[] array = StdArrays.array1dCopyOf(vector); + assertEquals(1, array[0]); + assertEquals(2, array[1]); + + array = new int[3]; + StdArrays.copyFrom(vector, array); + assertEquals(1, array[0]); + assertEquals(2, array[1]); + assertEquals(0, array[2]); + + try { + StdArrays.copyFrom(vector, new int[1]); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // as expected + } + try { + StdArrays.copyFrom(vector, new int[1][2]); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyFrom(vector, new int[2][2][2]); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + + @Test + public void matrices() { + IntNdArray matrix = NdArrays.ofInts(Shape.of(2, 2)); + + StdArrays.copyTo( + new int[][] { + {1, 2}, + {3, 4} + }, + matrix); + assertEquals(1, matrix.getInt(0, 0)); + assertEquals(2, matrix.getInt(0, 1)); + assertEquals(3, matrix.getInt(1, 0)); + assertEquals(4, matrix.getInt(1, 1)); + try { + StdArrays.copyTo(new int[][] {{1, 2, 3}, {4, 5, 6}}, matrix); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyTo(new int[][] {{1, 2}, {3, 4}}, NdArrays.ofInts(Shape.of(3, 3))); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyTo(new int[][] {{1, 2}, {3, 4}}, NdArrays.ofInts(Shape.of(2, 2, 1))); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + + int[][] array = StdArrays.array2dCopyOf(matrix); + assertEquals(1, array[0][0]); + assertEquals(2, array[0][1]); + assertEquals(3, array[1][0]); + assertEquals(4, array[1][1]); + + array = new int[3][3]; + StdArrays.copyFrom(matrix, array); + assertArrayEquals(new int[] {1, 2, 0}, array[0]); + assertArrayEquals(new int[] {3, 4, 0}, array[1]); + assertArrayEquals(new int[] {0, 0, 0}, array[2]); + + try { + StdArrays.copyFrom(matrix, new int[1][2]); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // as expected + } + try { + StdArrays.copyFrom(matrix, new int[2][1]); + fail(); + } catch (ArrayIndexOutOfBoundsException e) { + // as expected + } + try { + StdArrays.copyFrom(matrix, new int[2]); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyFrom(matrix, new int[1][2][2]); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + StdArrays.copyFrom(matrix, new int[2][2][2]); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + + @Test + public void objectMatrix() { + NdArray matrix = StdArrays.ndCopyOf(new String[][] {{"ab", "bc"}, {"cd", "de"}}); + assertEquals(NdArrays.vectorOfObjects("ab", "bc"), matrix.get(0)); + assertEquals(NdArrays.vectorOfObjects("cd", "de"), matrix.get(1)); + + String[][] array = StdArrays.array2dCopyOf(matrix, String.class); + assertEquals("ab", array[0][0]); + assertEquals("bc", array[0][1]); + assertEquals("cd", array[1][0]); + assertEquals("de", array[1][1]); + + array = new String[2][3]; + StdArrays.copyFrom(matrix, array); + assertEquals("ab", array[0][0]); + assertEquals("bc", array[0][1]); + assertNull(array[0][2]); + assertEquals("cd", array[1][0]); + assertEquals("de", array[1][1]); + assertNull(array[1][2]); + } + + @Test + public void cannotInitDenseMatrixWithRaggedArray() { + IntNdArray matrix = NdArrays.ofInts(Shape.of(2, 2)); + try { + StdArrays.copyTo( + new int[][] { + {1, 2}, + {3} + }, + matrix); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + + @Test + public void computeShapeDense3DMatrix() { + Shape shape = + StdArrays.shapeOf( + new int[][][] { + {{1, 2, 3}, {4, 5, 6}}, + {{1, 2, 3}, {4, 5, 6}} + }); + assertArrayEquals(new long[] {2, 2, 3}, shape.asArray()); + } + + @Test + public void shapeOfRagged3DMatrix() { + Shape shape = + StdArrays.shapeOf( + new int[][][] { + {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, + {{1, 2, 3}, {4, 5, 6}} + }); + assertArrayEquals(new long[] {2, Shape.UNKNOWN_SIZE, 3}, shape.asArray()); + } + + @Test + public void shapeOfEmptyArray() { + Shape shape = StdArrays.shapeOf(new int[2][2][3]); + assertArrayEquals(new long[] {2, 2, 3}, shape.asArray()); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java new file mode 100644 index 00000000000..5dbb5b034eb --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java @@ -0,0 +1,171 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.benchmark; + +import static org.tensorflow.ndarray.index.Indices.all; +import static org.tensorflow.ndarray.index.Indices.at; + +import java.awt.image.BufferedImage; +import java.awt.image.Raster; +import java.io.IOException; +import javax.imageio.ImageIO; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.runner.RunnerException; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; + +@Fork( + value = 1, + jvmArgs = {"-Xms4G", "-Xmx4G"}) +@BenchmarkMode(Mode.AverageTime) +@Warmup(iterations = 3) +@Measurement(iterations = 5) +@State(Scope.Benchmark) +public class NdArrayBenchmark { + + public static void main(String[] args) throws IOException, RunnerException { + org.openjdk.jmh.Main.main(args); + } + + @Setup + public void setUp() throws IOException { + BufferedImage image = ImageIO.read(getClass().getClassLoader().getResourceAsStream(TEST_IMAGE)); + + int numPixels = image.getWidth() * image.getHeight(); + pixels = NdArrays.ofFloats(Shape.of(numPixels, 3)); + channels = NdArrays.ofFloats(Shape.of(3, numPixels)); + + Raster imageData = image.getData(); + float[] pixel = new float[3]; + for (int y = 0, pixelIdx = 0; y < image.getHeight(); ++y) { + for (int x = 0; x < image.getWidth(); ++x, ++pixelIdx) { + imageData.getPixel(x, y, pixel); + StdArrays.copyTo(pixel, pixels.get(pixelIdx)); + StdArrays.copyTo(pixel, channels.slice(all(), at(pixelIdx))); + } + } + batches = NdArrays.ofFloats(Shape.of(BATCH_SIZE, 3, numPixels)); + firstBatch = batches.get(0); + } + + @Benchmark + @Measurement(batchSize = 2049 * 1537) + public void getElementAtIndex() { + pixels.get(0); + } + + @Benchmark + @Measurement(batchSize = 2049 * 1537) + public void slicing() { + batches.slice(at(0), all(), at(0)); + } + + @Benchmark + public void readingAllPixelsChannelsBySequence() { + pixels.scalars().forEach(pixel -> pixel.getFloat()); + } + + @Benchmark + public void readingAllPixelsChannelsBySequenceSlices() { + pixels.scalars().asSlices().forEach(pixel -> pixel.getFloat()); + } + + @Benchmark + @Measurement(batchSize = 100) + public void readingAllPixelsChannelsByIndex() { + long[] shape = pixels.shape().asArray(); + for (int i = 0; i < shape[0]; ++i) { + for (int j = 0; j < shape[1]; ++j) { + pixels.getFloat(i, j); + } + } + } + + @Benchmark + @Measurement(batchSize = BATCH_SIZE) + public void writeFirstBatchChannels() { + firstBatch.set(channels); + } + + @Benchmark + public void writeAllBatchChannels() { + batches.elements(0).forEach(batch -> batch.set(channels)); + } + + @Benchmark + @Measurement(batchSize = 2049 * 1537) + public void writeOnePixelBySlicing() { + batches.slice(at(0), all(), at(0)).set(pixels.get(0)); + } + + @Benchmark + public void writeAllPixelsBySlicing() { + batches + .elements(0) + .forEach( + batch -> + pixels + .elements(0) + .forEachIndexed( + (coords, pixel) -> batch.slice(all(), at(coords[0])).set(pixel))); + } + + @Benchmark + @Measurement(batchSize = 2049 * 1537) + public void writeOnePixelsByIndex() { + batches + .setFloat(pixels.getFloat(0, 0), 0, 0, 0) + .setFloat(pixels.getFloat(0, 1), 0, 1, 0) + .setFloat(pixels.getFloat(0, 2), 0, 2, 0); + } + + @Benchmark + public void writeAllPixelsByIndex() { + batches + .elements(0) + .forEach( + batch -> + pixels + .elements(0) + .forEachIndexed( + (coords, pixel) -> { + long pixelIndex = coords[0]; + batch + .setFloat(pixel.getFloat(0), 0, pixelIndex) + .setFloat(pixel.getFloat(1), 1, pixelIndex) + .setFloat(pixel.getFloat(2), 2, pixelIndex); + })); + } + + private static final String TEST_IMAGE = "castle.jpg"; + private static final int BATCH_SIZE = 60; + + private FloatNdArray pixels; + private FloatNdArray channels; + private FloatNdArray batches; + private FloatNdArray firstBatch; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/BooleanDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/BooleanDataBufferTestBase.java new file mode 100644 index 00000000000..e1d522e689f --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/BooleanDataBufferTestBase.java @@ -0,0 +1,134 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.BitSet; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; + +public abstract class BooleanDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract BooleanDataBuffer allocate(long size); + + @Override + protected Boolean valueOf(Long val) { + return val != 0; + } + + @Test + public void writeAndReadFromArray() { + BooleanDataBuffer buffer = allocate(10L); + boolean[] values = new boolean[] {true, false, false, true, false}; + + buffer.write(values); + assertTrue(buffer.getObject(0)); + assertFalse(buffer.getObject(1)); + + buffer.offset(5).write(values); + assertTrue(buffer.getObject(5)); + + boolean[] read = new boolean[5]; + buffer.read(read); + assertArrayEquals(values, read); + + buffer.write(values, 2, 3); + assertFalse(buffer.getObject(0)); + assertTrue(buffer.getObject(1)); + assertFalse(buffer.getObject(2)); + + Arrays.fill(read, false); + buffer.read(read, 1, 2); + assertFalse(read[0]); + assertFalse(read[1]); + assertTrue(read[2]); + assertFalse(read[3]); + } + + @Test + public void equalWithBitSetBuffer() { + BitSet bitSet1 = BitSet.valueOf(new byte[] {0x01, 0x01}); + BooleanDataBuffer bitSet1Buffer = MiscDataBufferFactory.create(bitSet1, 12, true); + + BitSet bitSet2 = BitSet.valueOf(new byte[] {0x11, 0x01}); + BooleanDataBuffer bitSet2Buffer = MiscDataBufferFactory.create(bitSet2, 12, true); + + BooleanDataBuffer buffer = allocate(12).setBoolean(true, 0).setBoolean(true, 8); + + assertTrue(bitSet1Buffer.equals(buffer)); + assertTrue(buffer.equals(bitSet1Buffer)); + assertEquals(bitSet1Buffer.hashCode(), buffer.hashCode()); + + assertFalse(bitSet2Buffer.equals(buffer)); + assertFalse(buffer.equals(bitSet2Buffer)); + assertNotEquals(bitSet2Buffer.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithBooleanArrayBuffer() { + boolean[] array1 = new boolean[] {false, false, false, true, true, false}; + BooleanDataBuffer array1Buffer = MiscDataBufferFactory.create(array1, true); + + boolean[] array2 = new boolean[] {false, false, false, true, true, true}; + BooleanDataBuffer array2Buffer = MiscDataBufferFactory.create(array2, true); + + BooleanDataBuffer buffer = allocate(6).setBoolean(true, 3).setBoolean(true, 4); + + assertTrue(array1Buffer.equals(buffer)); + assertTrue(buffer.equals(array1Buffer)); + assertEquals(array1Buffer.hashCode(), buffer.hashCode()); + + assertFalse(array2Buffer.equals(buffer)); + assertFalse(buffer.equals(array2Buffer)); + assertNotEquals(array2Buffer.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithBooleanObjectBuffer() { + Boolean[] array1 = new Boolean[] {false, false, false, true, true, false}; + DataBuffer array1Buffer = MiscDataBufferFactory.create(array1, true); + + boolean[] array2 = new boolean[] {false, false, false, true, true, true}; + DataBuffer array2Buffer = MiscDataBufferFactory.create(array2, true); + + BooleanDataBuffer buffer = allocate(6).setBoolean(true, 3).setBoolean(true, 4); + + assertTrue(array1Buffer.equals(buffer)); + assertTrue(buffer.equals(array1Buffer)); + assertEquals(array1Buffer.hashCode(), buffer.hashCode()); + + assertFalse(array2Buffer.equals(buffer)); + assertFalse(buffer.equals(array2Buffer)); + assertNotEquals(array2Buffer.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + BooleanDataBuffer buffer = allocate(2).setBoolean(false, 0).setBoolean(true, 1); + ByteDataBuffer byteBuffer = DataBuffers.of((byte) 0, (byte) 1); + + assertFalse(buffer.equals(byteBuffer)); + assertFalse(byteBuffer.equals(buffer)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/ByteDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/ByteDataBufferTestBase.java new file mode 100644 index 00000000000..59f27cabfae --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/ByteDataBufferTestBase.java @@ -0,0 +1,139 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +public abstract class ByteDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract ByteDataBuffer allocate(long size); + + @Override + protected Byte valueOf(Long val) { + return val.byteValue(); + } + + @Test + public void writeAndReadFromArray() { + ByteDataBuffer buffer = allocate(10L); + byte[] oneToFive = new byte[] {1, 2, 3, 4, 5}; + + buffer.write(oneToFive); + assertEquals(2, buffer.getByte(1)); + + buffer.offset(5).write(oneToFive); + assertEquals(2, buffer.getByte(1)); + assertEquals(2, buffer.getByte(6)); + + byte[] read = new byte[5]; + buffer.read(read); + assertArrayEquals(oneToFive, read); + + buffer.write(oneToFive, 2, 2); + assertEquals(3, buffer.getByte(0)); + assertEquals(4, buffer.getByte(1)); + assertEquals(3, buffer.getByte(2)); + + Arrays.fill(read, valueOf(0L)); + buffer.read(read, 1, 2); + assertEquals(0, read[0]); + assertEquals(3, read[1]); + assertEquals(4, read[2]); + assertEquals(0, read[3]); + } + + @Test + public void equalWithByteNioBuffer() { + ByteDataBuffer nioBuffer1 = + NioDataBufferFactory.create(ByteBuffer.wrap(new byte[] {0x01, 0x10})); + ByteDataBuffer nioBuffer2 = + NioDataBufferFactory.create(ByteBuffer.wrap(new byte[] {0x01, 0x11})); + + ByteDataBuffer buffer = allocate(2).setByte((byte) 0x01, 0).setByte((byte) 0x10, 1); + + assertTrue(nioBuffer1.equals(buffer)); + assertTrue(buffer.equals(nioBuffer1)); + assertEquals(nioBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(nioBuffer2.equals(buffer)); + assertFalse(buffer.equals(nioBuffer2)); + assertNotEquals(nioBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithByteRawBuffer() { + ByteDataBuffer rawBuffer1 = RawDataBufferFactory.create(new byte[] {0x01, 0x10}, true); + ByteDataBuffer rawBuffer2 = RawDataBufferFactory.create(new byte[] {0x01, 0x11}, true); + + ByteDataBuffer buffer = allocate(2).setByte((byte) 0x01, 0).setByte((byte) 0x10, 1); + + assertTrue(rawBuffer1.equals(buffer)); + assertTrue(buffer.equals(rawBuffer1)); + assertEquals(rawBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(rawBuffer2.equals(buffer)); + assertFalse(buffer.equals(rawBuffer2)); + assertNotEquals(rawBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithByteObjectBuffer() { + DataBuffer objBuffer1 = MiscDataBufferFactory.create(new Byte[] {0x01, 0x10}, true); + DataBuffer objBuffer2 = MiscDataBufferFactory.create(new Byte[] {0x01, 0x11}, true); + + ByteDataBuffer buffer = allocate(2).setByte((byte) 0x01, 0).setByte((byte) 0x10, 1); + + assertTrue(objBuffer1.equals(buffer)); + assertTrue(buffer.equals(objBuffer1)); + assertEquals(objBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(objBuffer2.equals(buffer)); + assertFalse(buffer.equals(objBuffer2)); + assertNotEquals(objBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + ByteDataBuffer buffer = allocate(2).setByte((byte) 1, 0).setByte((byte) 16, 1); + LongDataBuffer longBuffer = DataBuffers.of(1L, 16L); + + assertFalse(buffer.equals(longBuffer)); + assertFalse(longBuffer.equals(buffer)); + + try { + IntDataBuffer intBuffer = buffer.asInts(); + + assertFalse(buffer.equals(intBuffer)); + assertFalse(intBuffer.equals(buffer)); + + } catch (IllegalStateException e) { + // some byte buffers cannot be converted to ints, ignore the test in that case + } + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/DataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/DataBufferTestBase.java new file mode 100644 index 00000000000..46ec6520210 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/DataBufferTestBase.java @@ -0,0 +1,281 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.nio.BufferOverflowException; +import java.nio.BufferUnderflowException; +import org.junit.jupiter.api.Test; + +public abstract class DataBufferTestBase { + + protected final boolean enableLargeBufferTests = System.getProperty("testLargeBuffers") != null; + + protected long maxSize() { + return DataBuffers.MAX_32BITS; + } + + protected abstract DataBuffer allocate(long size); + + protected abstract T valueOf(Long val); + + @Test + public void bufferSize() { + DataBuffer buffer = allocate(10L); + assertEquals(10L, buffer.size()); + + buffer = allocate(0L); + assertEquals(0L, buffer.size()); + + if (enableLargeBufferTests) { + buffer = allocate(maxSize()); + assertEquals(maxSize(), buffer.size()); + } + } + + @Test + public void offsetNarrowAndSlice() { + DataBuffer buffer = allocate(10L).setObject(valueOf(1L), 5); // 0, 0, 0, 0, 0, 1, 0, 0, 0, 0 + assertEquals(10L, buffer.size()); + assertEquals(valueOf(1L), buffer.getObject(5)); + + DataBuffer subBuffer = buffer.slice(2, 6); // 0, 0, 0, 1, 0, 0 + assertEquals(6L, subBuffer.size()); + assertEquals(valueOf(1L), subBuffer.getObject(3)); + + subBuffer = subBuffer.offset(2L); // 0, 1, 0, 0 + assertEquals(4L, subBuffer.size()); + assertEquals(valueOf(1L), subBuffer.getObject(1)); + + subBuffer = subBuffer.narrow(2L); // 0, 1 + assertEquals(2L, subBuffer.size()); + assertEquals(valueOf(1L), subBuffer.getObject(1)); + try { + subBuffer.getObject(2); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + buffer.slice(2, 12); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + buffer.slice(-1, 3); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + buffer.slice(2, -1); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + buffer.offset(-1L); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + buffer.offset(11L); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + buffer.narrow(-1L); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + try { + buffer.narrow(11L); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + } + + @Test + public void putAndGet() { + DataBuffer buffer = allocate(10L); + + buffer.setObject(valueOf(5L), 5L); + assertEquals(valueOf(5L), buffer.getObject(5L)); + try { + buffer.setObject(valueOf(10L), 10L); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + buffer.getObject(10L); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + buffer.setObject(valueOf(-1L), -1L); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + buffer.getObject(-1L); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + } + + @Test + public void copyToBuffer() { + DataBuffer srcBuffer = allocate(25L); + srcBuffer.setObject(valueOf(5L), 5L); + srcBuffer.setObject(valueOf(10L), 10L); + srcBuffer.setObject(valueOf(15L), 15L); + srcBuffer.setObject(valueOf(20L), 20L); + try { + srcBuffer.copyTo(srcBuffer, srcBuffer.size()); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } + DataBuffer dstBuffer = allocate(30L); + srcBuffer.copyTo(dstBuffer, srcBuffer.size()); + assertEquals(valueOf(5L), dstBuffer.getObject(5L)); + try { + srcBuffer.copyTo(dstBuffer, dstBuffer.size()); + fail(); + } catch (BufferUnderflowException e) { + // as expected + } + try { + dstBuffer.copyTo(srcBuffer, dstBuffer.size()); + fail(); + } catch (BufferOverflowException e) { + // as expected + } + } + + @Test + public void createFromVarargs() { + DataBuffer buffer = DataBuffers.ofObjects(valueOf(1L), valueOf(2L), valueOf(3L)); + assertEquals(3, buffer.size()); + assertEquals(valueOf(1L), buffer.getObject(0)); + assertEquals(valueOf(2L), buffer.getObject(1)); + assertEquals(valueOf(3L), buffer.getObject(2)); + } + + @Test + public void equalWithObjectBuffer() { + DataBuffer buffer1 = allocate(2).setObject(valueOf(0L), 0).setObject(valueOf(1L), 1); + DataBuffer buffer2 = allocate(2).setObject(valueOf(0L), 0).setObject(valueOf(1L), 1); + DataBuffer buffer3 = allocate(2).setObject(valueOf(1L), 0).setObject(valueOf(0L), 1); + DataBuffer buffer4 = allocate(1).setObject(valueOf(0L), 0); + DataBuffer buffer5 = + allocate(3).setObject(valueOf(0L), 0).setObject(valueOf(1L), 1).setObject(valueOf(2L), 2); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer1.hashCode()); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer3.equals(buffer1)); + assertFalse(buffer1.equals(buffer3)); + assertNotEquals(buffer3.hashCode(), buffer1.hashCode()); + + assertFalse(buffer4.equals(buffer1)); + assertFalse(buffer1.equals(buffer4)); + assertNotEquals(buffer4.hashCode(), buffer1.hashCode()); + + assertFalse(buffer5.equals(buffer1)); + assertFalse(buffer1.equals(buffer5)); + assertNotEquals(buffer5.hashCode(), buffer1.hashCode()); + } + + @Test + public void bufferWindow() { + DataBuffer buffer = allocate(20); + DataBufferWindow> bufferWindow; + try { + bufferWindow = buffer.window(4); + } catch (UnsupportedOperationException e) { + return; // skip test if this buffer does not support windows + } + assertEquals(0, bufferWindow.offset()); + assertEquals(4, bufferWindow.size()); + assertEquals(4, bufferWindow.buffer().size()); + + for (long i = 0; i < buffer.size(); ++i) { + buffer.setObject(valueOf(i), i); + } + assertEquals(valueOf(2L), bufferWindow.buffer().getObject(2)); + DataBuffer windowBuffer = bufferWindow.buffer(); + + bufferWindow.slide(10); + assertEquals(10, bufferWindow.offset()); + assertEquals(4, bufferWindow.size()); + assertEquals(valueOf(12L), bufferWindow.buffer().getObject(2)); + assertSame(windowBuffer, bufferWindow.buffer()); + + bufferWindow.slide(-2); + assertEquals(8, bufferWindow.offset()); + assertEquals(4, bufferWindow.size()); + assertEquals(valueOf(10L), bufferWindow.buffer().getObject(2)); + + bufferWindow.slideTo(16); + assertEquals(16, bufferWindow.offset()); + assertEquals(4, bufferWindow.size()); + assertEquals(valueOf(18L), bufferWindow.buffer().getObject(2)); + + try { + bufferWindow.slide(1); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + bufferWindow.slide(-17); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + bufferWindow.slideTo(-1); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + try { + bufferWindow.slideTo(17); + fail(); + } catch (IndexOutOfBoundsException e) { + // as expected + } + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/DoubleDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/DoubleDataBufferTestBase.java new file mode 100644 index 00000000000..c09badfc415 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/DoubleDataBufferTestBase.java @@ -0,0 +1,129 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.DoubleBuffer; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +public abstract class DoubleDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract DoubleDataBuffer allocate(long size); + + @Override + protected Double valueOf(Long val) { + return val.doubleValue(); + } + + @Test + public void writeAndReadFromArray() { + DoubleDataBuffer buffer = allocate(10L); + double[] oneToFive = new double[] {1.0, 2.0, 3.0, 4.0, 5.0}; + + buffer.write(oneToFive); + assertEquals(2.0, buffer.getDouble(1), 0.0); + + buffer.offset(5).write(oneToFive); + assertEquals(2.0, buffer.getDouble(1), 0.0); + assertEquals(2.0, buffer.getDouble(6), 0.0); + + double[] read = new double[5]; + buffer.read(read); + assertArrayEquals(oneToFive, read, 0.0); + + buffer.write(oneToFive, 2, 2); + assertEquals(3.0, buffer.getDouble(0), 0.0); + assertEquals(4.0, buffer.getDouble(1), 0.0); + assertEquals(3.0, buffer.getDouble(2), 0.0); + + Arrays.fill(read, valueOf(0L)); + buffer.read(read, 1, 2); + assertEquals(0.0, read[0], 0.0); + assertEquals(3.0, read[1], 0.0); + assertEquals(4.0, read[2], 0.0); + assertEquals(0.0, read[3], 0.0); + } + + @Test + public void equalWithDoubleNioBuffer() { + DoubleDataBuffer nioBuffer1 = + NioDataBufferFactory.create(DoubleBuffer.wrap(new double[] {1.0, 16.0})); + DoubleDataBuffer nioBuffer2 = + NioDataBufferFactory.create(DoubleBuffer.wrap(new double[] {1.0, 25.0})); + + DoubleDataBuffer buffer = allocate(2).setDouble(1.0, 0).setDouble(16.0, 1); + + assertTrue(nioBuffer1.equals(buffer)); + assertTrue(buffer.equals(nioBuffer1)); + assertEquals(nioBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(nioBuffer2.equals(buffer)); + assertFalse(buffer.equals(nioBuffer2)); + assertNotEquals(nioBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithDoubleRawBuffer() { + DoubleDataBuffer rawBuffer1 = RawDataBufferFactory.create(new double[] {1.0, 16.0}, true); + DoubleDataBuffer rawBuffer2 = RawDataBufferFactory.create(new double[] {1.0, 25.0}, true); + + DoubleDataBuffer buffer = allocate(2).setDouble(1.0, 0).setDouble(16.0, 1); + + assertTrue(rawBuffer1.equals(buffer)); + assertTrue(buffer.equals(rawBuffer1)); + assertEquals(rawBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(rawBuffer2.equals(buffer)); + assertFalse(buffer.equals(rawBuffer2)); + assertNotEquals(rawBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithDoubleObjectBuffer() { + DataBuffer objBuffer1 = MiscDataBufferFactory.create(new Double[] {1.0, 16.0}, true); + DataBuffer objBuffer2 = MiscDataBufferFactory.create(new Double[] {1.0, 25.0}, true); + + DoubleDataBuffer buffer = allocate(2).setDouble(1.0, 0).setDouble(16.0, 1); + + assertTrue(objBuffer1.equals(buffer)); + assertTrue(buffer.equals(objBuffer1)); + assertEquals(objBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(objBuffer2.equals(buffer)); + assertFalse(buffer.equals(objBuffer2)); + assertNotEquals(objBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + DoubleDataBuffer buffer = allocate(2).setDouble(1.0, 0).setDouble(16.0, 1); + FloatDataBuffer floatBuffer = DataBuffers.of(1.0f, 16.0f); + + assertFalse(buffer.equals(floatBuffer)); + assertFalse(floatBuffer.equals(buffer)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/FloatDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/FloatDataBufferTestBase.java new file mode 100644 index 00000000000..7fca8363634 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/FloatDataBufferTestBase.java @@ -0,0 +1,129 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.FloatBuffer; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +public abstract class FloatDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract FloatDataBuffer allocate(long size); + + @Override + protected Float valueOf(Long val) { + return val.floatValue(); + } + + @Test + public void writeAndReadFromArray() { + FloatDataBuffer buffer = allocate(10L); + float[] oneToFive = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + + buffer.write(oneToFive); + assertEquals(2.0f, buffer.getFloat(1), 0.0f); + + buffer.offset(5).write(oneToFive); + assertEquals(2.0f, buffer.getFloat(1), 0.0f); + assertEquals(2.0f, buffer.getFloat(6), 0.0f); + + float[] read = new float[5]; + buffer.read(read); + assertArrayEquals(oneToFive, read, 0.0f); + + buffer.write(oneToFive, 2, 2); + assertEquals(3.0f, buffer.getFloat(0), 0.0f); + assertEquals(4.0f, buffer.getFloat(1), 0.0f); + assertEquals(3.0f, buffer.getFloat(2), 0.0f); + + Arrays.fill(read, valueOf(0L)); + buffer.read(read, 1, 2); + assertEquals(0.0f, read[0], 0.0f); + assertEquals(3.0f, read[1], 0.0f); + assertEquals(4.0f, read[2], 0.0f); + assertEquals(0.0f, read[3], 0.0f); + } + + @Test + public void equalWithFloatNioBuffer() { + FloatDataBuffer nioBuffer1 = + NioDataBufferFactory.create(FloatBuffer.wrap(new float[] {1.0f, 16.0f})); + FloatDataBuffer nioBuffer2 = + NioDataBufferFactory.create(FloatBuffer.wrap(new float[] {1.0f, 25.0f})); + + FloatDataBuffer buffer = allocate(2).setFloat(1.0f, 0).setFloat(16.0f, 1); + + assertTrue(nioBuffer1.equals(buffer)); + assertTrue(buffer.equals(nioBuffer1)); + assertEquals(nioBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(nioBuffer2.equals(buffer)); + assertFalse(buffer.equals(nioBuffer2)); + assertNotEquals(nioBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithFloatRawBuffer() { + FloatDataBuffer rawBuffer1 = RawDataBufferFactory.create(new float[] {1.0f, 16.0f}, true); + FloatDataBuffer rawBuffer2 = RawDataBufferFactory.create(new float[] {1.0f, 25.0f}, true); + + FloatDataBuffer buffer = allocate(2).setFloat(1.0f, 0).setFloat(16.0f, 1); + + assertTrue(rawBuffer1.equals(buffer)); + assertTrue(buffer.equals(rawBuffer1)); + assertEquals(rawBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(rawBuffer2.equals(buffer)); + assertFalse(buffer.equals(rawBuffer2)); + assertNotEquals(rawBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithFloatObjectBuffer() { + DataBuffer objBuffer1 = MiscDataBufferFactory.create(new Float[] {1.0f, 16.0f}, true); + DataBuffer objBuffer2 = MiscDataBufferFactory.create(new Float[] {1.0f, 25.0f}, true); + + FloatDataBuffer buffer = allocate(2).setFloat(1.0f, 0).setFloat(16.0f, 1); + + assertTrue(objBuffer1.equals(buffer)); + assertTrue(buffer.equals(objBuffer1)); + assertEquals(objBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(objBuffer2.equals(buffer)); + assertFalse(buffer.equals(objBuffer2)); + assertNotEquals(objBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + FloatDataBuffer buffer = allocate(2).setFloat(1.0f, 0).setFloat(16.0f, 1); + DoubleDataBuffer doubleBuffer = DataBuffers.of(1.0, 16.0); + + assertFalse(buffer.equals(doubleBuffer)); + assertFalse(doubleBuffer.equals(buffer)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/IntDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/IntDataBufferTestBase.java new file mode 100644 index 00000000000..7593411a85a --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/IntDataBufferTestBase.java @@ -0,0 +1,127 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.IntBuffer; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +public abstract class IntDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract IntDataBuffer allocate(long size); + + @Override + protected Integer valueOf(Long val) { + return val.intValue(); + } + + @Test + public void writeAndReadFromArray() { + IntDataBuffer buffer = allocate(10L); + int[] oneToFive = new int[] {1, 2, 3, 4, 5}; + + buffer.write(oneToFive); + assertEquals(2, buffer.getInt(1)); + + buffer.offset(5).write(oneToFive); + assertEquals(2, buffer.getInt(1)); + assertEquals(2, buffer.getInt(6)); + + int[] read = new int[5]; + buffer.read(read); + assertArrayEquals(oneToFive, read); + + buffer.write(oneToFive, 2, 2); + assertEquals(3, buffer.getInt(0)); + assertEquals(4, buffer.getInt(1)); + assertEquals(3, buffer.getInt(2)); + + Arrays.fill(read, valueOf(0L)); + buffer.read(read, 1, 2); + assertEquals(0, read[0]); + assertEquals(3, read[1]); + assertEquals(4, read[2]); + assertEquals(0, read[3]); + } + + @Test + public void equalWithIntNioBuffer() { + IntDataBuffer nioBuffer1 = NioDataBufferFactory.create(IntBuffer.wrap(new int[] {1, 16})); + IntDataBuffer nioBuffer2 = NioDataBufferFactory.create(IntBuffer.wrap(new int[] {1, 25})); + + IntDataBuffer buffer = allocate(2).setInt(1, 0).setInt(16, 1); + + assertTrue(nioBuffer1.equals(buffer)); + assertTrue(buffer.equals(nioBuffer1)); + assertEquals(nioBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(nioBuffer2.equals(buffer)); + assertFalse(buffer.equals(nioBuffer2)); + assertNotEquals(nioBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithIntRawBuffer() { + IntDataBuffer rawBuffer1 = RawDataBufferFactory.create(new int[] {1, 16}, true); + IntDataBuffer rawBuffer2 = RawDataBufferFactory.create(new int[] {1, 25}, true); + + IntDataBuffer buffer = allocate(2).setInt(1, 0).setInt(16, 1); + + assertTrue(rawBuffer1.equals(buffer)); + assertTrue(buffer.equals(rawBuffer1)); + assertEquals(rawBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(rawBuffer2.equals(buffer)); + assertFalse(buffer.equals(rawBuffer2)); + assertNotEquals(rawBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithIntObjectBuffer() { + DataBuffer objBuffer1 = MiscDataBufferFactory.create(new Integer[] {1, 16}, true); + DataBuffer objBuffer2 = MiscDataBufferFactory.create(new Integer[] {1, 25}, true); + + IntDataBuffer buffer = allocate(2).setInt(1, 0).setInt(16, 1); + + assertTrue(objBuffer1.equals(buffer)); + assertTrue(buffer.equals(objBuffer1)); + assertEquals(objBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(objBuffer2.equals(buffer)); + assertFalse(buffer.equals(objBuffer2)); + assertNotEquals(objBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + IntDataBuffer buffer = allocate(2).setInt(1, 0).setInt(16, 1); + LongDataBuffer longBuffer = DataBuffers.of(1L, 16L); + + assertFalse(buffer.equals(longBuffer)); + assertFalse(longBuffer.equals(buffer)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/LongDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/LongDataBufferTestBase.java new file mode 100644 index 00000000000..a3bdb068113 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/LongDataBufferTestBase.java @@ -0,0 +1,127 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.LongBuffer; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +public abstract class LongDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract LongDataBuffer allocate(long size); + + @Override + protected Long valueOf(Long val) { + return val; + } + + @Test + public void writeAndReadFromArray() { + LongDataBuffer buffer = allocate(10L); + long[] oneToFive = new long[] {1L, 2L, 3L, 4L, 5L}; + + buffer.write(oneToFive); + assertEquals(2, buffer.getLong(1)); + + buffer.offset(5).write(oneToFive); + assertEquals(2L, buffer.getLong(1)); + assertEquals(2L, buffer.getLong(6)); + + long[] read = new long[5]; + buffer.read(read); + assertArrayEquals(oneToFive, read); + + buffer.write(oneToFive, 2, 2); + assertEquals(3L, buffer.getLong(0)); + assertEquals(4L, buffer.getLong(1)); + assertEquals(3L, buffer.getLong(2)); + + Arrays.fill(read, valueOf(0L)); + buffer.read(read, 1, 2); + assertEquals(0L, read[0]); + assertEquals(3L, read[1]); + assertEquals(4L, read[2]); + assertEquals(0L, read[3]); + } + + @Test + public void equalWithLongNioBuffer() { + LongDataBuffer nioBuffer1 = NioDataBufferFactory.create(LongBuffer.wrap(new long[] {1, 16})); + LongDataBuffer nioBuffer2 = NioDataBufferFactory.create(LongBuffer.wrap(new long[] {1, 25})); + + LongDataBuffer buffer = allocate(2).setLong(1, 0).setLong(16, 1); + + assertTrue(nioBuffer1.equals(buffer)); + assertTrue(buffer.equals(nioBuffer1)); + assertEquals(nioBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(nioBuffer2.equals(buffer)); + assertFalse(buffer.equals(nioBuffer2)); + assertNotEquals(nioBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithLongRawBuffer() { + LongDataBuffer rawBuffer1 = RawDataBufferFactory.create(new long[] {1, 16}, true); + LongDataBuffer rawBuffer2 = RawDataBufferFactory.create(new long[] {1, 25}, true); + + LongDataBuffer buffer = allocate(2).setLong(1, 0).setLong(16, 1); + + assertTrue(rawBuffer1.equals(buffer)); + assertTrue(buffer.equals(rawBuffer1)); + assertEquals(rawBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(rawBuffer2.equals(buffer)); + assertFalse(buffer.equals(rawBuffer2)); + assertNotEquals(rawBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithLongObjectBuffer() { + DataBuffer objBuffer1 = MiscDataBufferFactory.create(new Long[] {1L, 16L}, true); + DataBuffer objBuffer2 = MiscDataBufferFactory.create(new Long[] {1L, 25L}, true); + + LongDataBuffer buffer = allocate(2).setLong(1, 0).setLong(16, 1); + + assertTrue(objBuffer1.equals(buffer)); + assertTrue(buffer.equals(objBuffer1)); + assertEquals(objBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(objBuffer2.equals(buffer)); + assertFalse(buffer.equals(objBuffer2)); + assertNotEquals(objBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + LongDataBuffer buffer = allocate(2).setLong(1L, 0).setLong(16L, 1); + IntDataBuffer intBuffer = DataBuffers.of(1, 16); + + assertFalse(buffer.equals(intBuffer)); + assertFalse(intBuffer.equals(buffer)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/ShortDataBufferTestBase.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/ShortDataBufferTestBase.java new file mode 100644 index 00000000000..40569842125 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/buffer/ShortDataBufferTestBase.java @@ -0,0 +1,127 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.buffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.ShortBuffer; +import java.util.Arrays; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.impl.buffer.misc.MiscDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; + +public abstract class ShortDataBufferTestBase extends DataBufferTestBase { + + @Override + protected abstract ShortDataBuffer allocate(long size); + + @Override + protected Short valueOf(Long val) { + return val.shortValue(); + } + + @Test + public void writeAndReadFromArray() { + ShortDataBuffer buffer = allocate(10L); + short[] oneToFive = new short[] {1, 2, 3, 4, 5}; + + buffer.write(oneToFive); + assertEquals(2, buffer.getShort(1)); + + buffer.offset(5).write(oneToFive); + assertEquals(2, buffer.getShort(1), 0); + assertEquals(2, buffer.getShort(6), 0); + + short[] read = new short[5]; + buffer.read(read); + assertArrayEquals(oneToFive, read); + + buffer.write(oneToFive, 2, 2); + assertEquals(3, buffer.getShort(0)); + assertEquals(4, buffer.getShort(1)); + assertEquals(3, buffer.getShort(2)); + + Arrays.fill(read, valueOf(0L)); + buffer.read(read, 1, 2); + assertEquals(0, read[0]); + assertEquals(3, read[1]); + assertEquals(4, read[2]); + assertEquals(0, read[3]); + } + + @Test + public void equalWithShortNioBuffer() { + ShortDataBuffer nioBuffer1 = NioDataBufferFactory.create(ShortBuffer.wrap(new short[] {1, 16})); + ShortDataBuffer nioBuffer2 = NioDataBufferFactory.create(ShortBuffer.wrap(new short[] {1, 25})); + + ShortDataBuffer buffer = allocate(2).setShort((short) 1, 0).setShort((short) 16, 1); + + assertTrue(nioBuffer1.equals(buffer)); + assertTrue(buffer.equals(nioBuffer1)); + assertEquals(nioBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(nioBuffer2.equals(buffer)); + assertFalse(buffer.equals(nioBuffer2)); + assertNotEquals(nioBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithShortRawBuffer() { + ShortDataBuffer rawBuffer1 = RawDataBufferFactory.create(new short[] {1, 16}, true); + ShortDataBuffer rawBuffer2 = RawDataBufferFactory.create(new short[] {1, 25}, true); + + ShortDataBuffer buffer = allocate(2).setShort((short) 1, 0).setShort((short) 16, 1); + + assertTrue(rawBuffer1.equals(buffer)); + assertTrue(buffer.equals(rawBuffer1)); + assertEquals(rawBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(rawBuffer2.equals(buffer)); + assertFalse(buffer.equals(rawBuffer2)); + assertNotEquals(rawBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void equalWithShortObjectBuffer() { + DataBuffer objBuffer1 = MiscDataBufferFactory.create(new Short[] {1, 16}, true); + DataBuffer objBuffer2 = MiscDataBufferFactory.create(new Short[] {1, 25}, true); + + ShortDataBuffer buffer = allocate(2).setShort((short) 1, 0).setShort((short) 16, 1); + + assertTrue(objBuffer1.equals(buffer)); + assertTrue(buffer.equals(objBuffer1)); + assertEquals(objBuffer1.hashCode(), buffer.hashCode()); + + assertFalse(objBuffer2.equals(buffer)); + assertFalse(buffer.equals(objBuffer2)); + assertNotEquals(objBuffer2.hashCode(), buffer.hashCode()); + } + + @Test + public void notEqualWithOtherTypes() { + ShortDataBuffer buffer = allocate(2).setShort((short) 1, 0).setShort((short) 16, 1); + LongDataBuffer longBuffer = DataBuffers.of(1L, 16L); + + assertFalse(buffer.equals(longBuffer)); + assertFalse(longBuffer.equals(buffer)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/BigIntegerDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/BigIntegerDataBufferAdapterTest.java new file mode 100644 index 00000000000..f8109666b1f --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/BigIntegerDataBufferAdapterTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import java.math.BigInteger; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBufferTestBase; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.layout.DataLayout; + +public class BigIntegerDataBufferAdapterTest extends DataBufferTestBase { + + @Override + protected DataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofBytes(size * LAYOUT.scale())); + } + + @Override + protected long maxSize() { + return super.maxSize() / 3; + } + + @Override + protected BigInteger valueOf(Long val) { + return BigInteger.valueOf(val); + } + + private static DataLayout LAYOUT = + new DataLayout() { + + @Override + public void writeObject(ByteDataBuffer buffer, BigInteger value, long index) { + byte[] bytes = value.toByteArray(); + buffer.setByte(bytes.length > 2 ? bytes[2] : 0, index); + buffer.setByte(bytes.length > 1 ? bytes[1] : 0, index + 1); + buffer.setByte(bytes[0], index + 2); + } + + @Override + public BigInteger readObject(ByteDataBuffer buffer, long index) { + byte byte2 = buffer.getByte(index); + byte byte1 = buffer.getByte(index + 1); + byte byte0 = buffer.getByte(index + 2); + return new BigInteger(new byte[] {byte2, byte1, byte0}); + } + + @Override + public int scale() { + return 3; + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/BooleanDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/BooleanDataBufferAdapterTest.java new file mode 100644 index 00000000000..9507cef3456 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/BooleanDataBufferAdapterTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.BooleanDataBufferTestBase; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.layout.BooleanDataLayout; + +public class BooleanDataBufferAdapterTest extends BooleanDataBufferTestBase { + + @Override + protected BooleanDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofBytes(size * LAYOUT.scale())); + } + + private static BooleanDataLayout LAYOUT = + new BooleanDataLayout() { + + @Override + public void writeBoolean(ByteDataBuffer buffer, boolean value, long index) { + buffer.setByte((byte) (value ? 1 : 0), index); + } + + @Override + public boolean readBoolean(ByteDataBuffer buffer, long index) { + return buffer.getByte(index) > 0; + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/ByteDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/ByteDataBufferAdapterTest.java new file mode 100644 index 00000000000..59462ba436a --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/ByteDataBufferAdapterTest.java @@ -0,0 +1,28 @@ +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBufferTestBase; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.ByteDataLayout; + +public class ByteDataBufferAdapterTest extends ByteDataBufferTestBase { + + public ByteDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofShorts(size * LAYOUT.scale())); + } + + private static ByteDataLayout LAYOUT = + new ByteDataLayout() { + + @Override + public void writeByte(ShortDataBuffer buffer, byte value, long index) { + buffer.setShort(value, index); + } + + @Override + public byte readByte(ShortDataBuffer buffer, long index) { + return (byte) buffer.getShort(index); + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/DoubleDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/DoubleDataBufferAdapterTest.java new file mode 100644 index 00000000000..898409f3541 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/DoubleDataBufferAdapterTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBufferTestBase; +import org.tensorflow.ndarray.buffer.layout.DoubleDataLayout; + +public class DoubleDataBufferAdapterTest extends DoubleDataBufferTestBase { + + @Override + protected DoubleDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofBytes(size * LAYOUT.scale())); + } + + @Override + protected long maxSize() { + return super.maxSize() / 3; + } + + private static DoubleDataLayout LAYOUT = + new DoubleDataLayout() { + + @Override + public void writeDouble(ByteDataBuffer buffer, double value, long index) { + long bits = Double.doubleToLongBits(value); + buffer.setByte((byte) ((bits >> 56) & 0xFF), index); + buffer.setByte((byte) ((bits >> 48) & 0xFF), index + 1); + buffer.setByte((byte) ((bits >> 40) & 0xFF), index + 2); + } + + @Override + public double readDouble(ByteDataBuffer buffer, long index) { + long byte7 = buffer.getByte(index); + long byte6 = buffer.getByte(index + 1); + long byte5 = buffer.getByte(index + 2); + return Double.longBitsToDouble( + ((byte7 & 0xFF) << 56) | ((byte6 & 0xFF) << 48) | ((byte5 & 0xFF) << 40)); + } + + @Override + public int scale() { + return 3; + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/FloatDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/FloatDataBufferAdapterTest.java new file mode 100644 index 00000000000..325ef9c05cf --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/FloatDataBufferAdapterTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBufferTestBase; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.FloatDataLayout; + +public class FloatDataBufferAdapterTest extends FloatDataBufferTestBase { + + @Override + public FloatDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofShorts(size * LAYOUT.scale())); + } + + @Override + protected long maxSize() { + return super.maxSize() / 2; + } + + private static FloatDataLayout LAYOUT = + new FloatDataLayout() { + + @Override + public void writeFloat(ShortDataBuffer buffer, float value, long index) { + int bits = Float.floatToIntBits(value); + buffer.setShort((short) (bits >> 16), index); + } + + @Override + public float readFloat(ShortDataBuffer buffer, long index) { + int i = buffer.getShort(index); + return Float.intBitsToFloat(i << 16); + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/IntDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/IntDataBufferAdapterTest.java new file mode 100644 index 00000000000..ac045e24662 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/IntDataBufferAdapterTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBufferTestBase; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.layout.IntDataLayout; + +public class IntDataBufferAdapterTest extends IntDataBufferTestBase { + + @Override + protected IntDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofShorts(size * LAYOUT.scale())); + } + + @Override + protected long maxSize() { + return super.maxSize() / 2; + } + + private static IntDataLayout LAYOUT = + new IntDataLayout() { + + @Override + public void writeInt(ShortDataBuffer buffer, int value, long index) { + buffer.setShort((short) (((value & 0x80000000) >> 16) | (value & 0x7FFF)), index); + } + + @Override + public int readInt(ShortDataBuffer buffer, long index) { + int i = buffer.getShort(index); + return ((i & 0x8000) << 16) | ((i & 0x7FFF)); + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/LongDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/LongDataBufferAdapterTest.java new file mode 100644 index 00000000000..bdb17d50fed --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/LongDataBufferAdapterTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBufferTestBase; +import org.tensorflow.ndarray.buffer.layout.LongDataLayout; + +public class LongDataBufferAdapterTest extends LongDataBufferTestBase { + + @Override + protected LongDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofBytes(size * LAYOUT.scale())); + } + + @Override + protected long maxSize() { + return super.maxSize() / 3; + } + + private static LongDataLayout LAYOUT = + new LongDataLayout() { + + @Override + public void writeLong(ByteDataBuffer buffer, long value, long index) { + buffer.setByte((byte) (((value >> 56) & 0x80) | ((value >> 16) & 0x7F)), index); + buffer.setByte((byte) ((value >> 8) & 0xFF), index + 1); + buffer.setByte((byte) (value & 0xFF), index + 2); + } + + @Override + public long readLong(ByteDataBuffer buffer, long index) { + long msb = buffer.getByte(index); + long midb = buffer.getByte(index + 1); + long lsb = buffer.getByte(index + 2); + return ((msb & 0x80) << 56) | ((msb & 0x7F) << 16) | ((midb & 0xFF) << 8) | (lsb & 0xFF); + } + + @Override + public int scale() { + return 3; + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/ShortDataBufferAdapterTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/ShortDataBufferAdapterTest.java new file mode 100644 index 00000000000..dd446028c60 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/adapter/ShortDataBufferAdapterTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.adapter; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBufferTestBase; +import org.tensorflow.ndarray.buffer.layout.ShortDataLayout; + +public class ShortDataBufferAdapterTest extends ShortDataBufferTestBase { + + public ShortDataBuffer allocate(long size) { + return LAYOUT.applyTo(DataBuffers.ofBytes(size * LAYOUT.scale())); + } + + private static ShortDataLayout LAYOUT = + new ShortDataLayout() { + + @Override + public void writeShort(ByteDataBuffer buffer, short value, long index) { + buffer.setByte((byte) (((value & 0x8000) >> 8) | (value & 0x7F)), index); + } + + @Override + public short readShort(ByteDataBuffer buffer, long index) { + int b = buffer.getByte(index); + return (short) (((b & 0x80) << 8) | (b & 0x7F)); + } + }; +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/Bfloat16LayoutTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/Bfloat16LayoutTest.java new file mode 100644 index 00000000000..30eff04bfac --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/Bfloat16LayoutTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class Bfloat16LayoutTest { + + @Test + public void testFloat32to16() { + + // Zero and subnormals + assertEquals((short) 0x0000, Bfloat16Layout.float32to16(0.0f)); + assertEquals((short) 0x8000, Bfloat16Layout.float32to16(-0.0f)); + assertEquals((short) 0x0001, Bfloat16Layout.float32to16(1e-40f)); + assertEquals((short) 0xC000, Bfloat16Layout.float32to16(-2.0f)); + assertEquals((short) 0x0000, Bfloat16Layout.float32to16(4.59e-41f)); + + // Infinite and NaN + assertEquals((short) 0x7F80, Bfloat16Layout.float32to16(Float.POSITIVE_INFINITY)); + assertEquals((short) 0xFF80, Bfloat16Layout.float32to16(Float.NEGATIVE_INFINITY)); + assertEquals((short) 0x7FC0, Bfloat16Layout.float32to16(Float.NaN)); + assertEquals((short) 0x7FC0, Bfloat16Layout.float32to16(Float.intBitsToFloat(0xFFFFFFFF))); + + // Normalized + assertEquals((short) 0x3F80, Bfloat16Layout.float32to16(1.0f)); + assertEquals((short) 0xBF80, Bfloat16Layout.float32to16(-1.0f)); + assertEquals((short) 0x42C8, Bfloat16Layout.float32to16(100.0f)); + assertEquals((short) 0xC2CA, Bfloat16Layout.float32to16(-101.0f)); + assertEquals((short) 0x3F8F, Bfloat16Layout.float32to16(1.1171875f)); + assertEquals((short) 0x4800, Bfloat16Layout.float32to16(131072f)); + assertEquals((short) 0x7F7F, Bfloat16Layout.float32to16(3.3895314e38f)); + assertEquals((short) 0xFF7F, Bfloat16Layout.float32to16(-3.3895314e38f)); + + // Rounding up + assertEquals((short) 0x3FCF, Bfloat16Layout.float32to16(1.6191406f)); // 1.6171875 + assertEquals((short) 0x4780, Bfloat16Layout.float32to16(65600.0f)); // 65536.0 + } + + @Test + public void testFloat16to32() { + + // Zero and subnormals + assertEquals(0.0f, Bfloat16Layout.float16to32((short) 0x0000), 0); + assertEquals(-0.0f, Bfloat16Layout.float16to32((short) 0x8000), 0); + assertEquals(9.18355E-41f, Bfloat16Layout.float16to32((short) 0x0001), 1e-8f); + assertEquals(-9.403955E-38, Bfloat16Layout.float16to32((short) 0x8200), 1e-8f); + + // Infinite and NaN + assertEquals(Float.POSITIVE_INFINITY, Bfloat16Layout.float16to32((short) 0x7F80), 0); + assertEquals(Float.NEGATIVE_INFINITY, Bfloat16Layout.float16to32((short) 0xFF80), 0); + assertEquals(Float.NaN, Bfloat16Layout.float16to32((short) 0x7FC0), 0); + assertEquals(Float.intBitsToFloat(0xFFFFFFFF), Bfloat16Layout.float16to32((short) 0x7FC0), 0); + + // Normalized + assertEquals(1.0f, Bfloat16Layout.float16to32((short) 0x3F80), 0); + assertEquals(-1.0f, Bfloat16Layout.float16to32((short) 0xBF80), 0); + assertEquals(100.0f, Bfloat16Layout.float16to32((short) 0x42C8), 0); + assertEquals(-101.0f, Bfloat16Layout.float16to32((short) 0xC2CA), 0); + assertEquals(1.1171875f, Bfloat16Layout.float16to32((short) 0x3F8F), 0); + assertEquals(131072f, Bfloat16Layout.float16to32((short) 0x4800), 0); + assertEquals(3.3895314e38f, Bfloat16Layout.float16to32((short) 0x7F7F), 0); + assertEquals(-3.3895314e38f, Bfloat16Layout.float16to32((short) 0xFF7F), 0); + assertEquals(1.6171875f, Bfloat16Layout.float16to32((short) 0x3FCF), 0); + assertEquals(65536.0, Bfloat16Layout.float16to32((short) 0x4780), 0); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/BoolLayoutTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/BoolLayoutTest.java new file mode 100644 index 00000000000..7cdc010e478 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/BoolLayoutTest.java @@ -0,0 +1,42 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; + +public class BoolLayoutTest { + + @Test + public void booleanToByteTest() { + assertEquals((byte) 1, BoolLayout.booleanToByte(true)); + assertEquals((byte) 0, BoolLayout.booleanToByte(false)); + } + + @Test + public void byteToBooleanTest() { + assertTrue(BoolLayout.byteToBoolean((byte) 1)); + assertTrue(BoolLayout.byteToBoolean((byte) 127)); + assertTrue(BoolLayout.byteToBoolean((byte) -128)); + assertTrue(BoolLayout.byteToBoolean((byte) 255)); + assertFalse(BoolLayout.byteToBoolean((byte) 0)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/Float16LayoutTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/Float16LayoutTest.java new file mode 100644 index 00000000000..2c7c8c281a6 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/layout/Float16LayoutTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2020 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.buffer.layout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class Float16LayoutTest { + + @Test + public void testFloat32to16() { + + // Zero and subnormals + assertEquals((short) 0x0000, Float16Layout.float32to16(0.0f)); + assertEquals((short) 0x8000, Float16Layout.float32to16(-0.0f)); + assertEquals((short) 0x0001, Float16Layout.float32to16(6e-8f)); + assertEquals((short) 0x8200, Float16Layout.float32to16(-3.052e-5f)); + assertEquals((short) 0x0000, Float16Layout.float32to16(6e-9f)); + + // Infinite and NaN + assertEquals((short) 0x7C00, Float16Layout.float32to16(Float.POSITIVE_INFINITY)); + assertEquals((short) 0xFC00, Float16Layout.float32to16(Float.NEGATIVE_INFINITY)); + assertEquals((short) 0x7C00, Float16Layout.float32to16(65520.0f)); + assertEquals((short) 0x7C00, Float16Layout.float32to16(165536.0f)); + assertEquals((short) 0xFC00, Float16Layout.float32to16(-65520.0f)); + assertEquals((short) 0x7E00, Float16Layout.float32to16(Float.NaN)); + assertEquals((short) 0x7E00, Float16Layout.float32to16(Float.intBitsToFloat(0xFFFFFFFF))); + + // Normalized + assertEquals((short) 0x7BFF, Float16Layout.float32to16(65519.0f)); + assertEquals((short) 0x3C00, Float16Layout.float32to16(1.0f)); + assertEquals((short) 0xBC00, Float16Layout.float32to16(-1.0f)); + assertEquals((short) 0x5640, Float16Layout.float32to16(100.0f)); + assertEquals((short) 0xD650, Float16Layout.float32to16(-101.0f)); + assertEquals((short) 0x3C7E, Float16Layout.float32to16(1.123f)); + + // Rounding up + assertEquals((short) 0x3C7E, Float16Layout.float32to16(1.1235f)); // 1.123 + assertEquals((short) 0x3C7F, Float16Layout.float32to16(1.1236f)); // 1.124 + assertEquals((short) 0x4000, Float16Layout.float32to16(2.0009f)); // 2.0 + assertEquals((short) 0x4001, Float16Layout.float32to16(2.001f)); // 2.002 + assertEquals((short) 0x5C00, Float16Layout.float32to16(256.125f)); // 256.0 + assertEquals((short) 0x5C01, Float16Layout.float32to16(256.126f)); // 256.3 + assertEquals((short) 0x5C01, Float16Layout.float32to16(256.30f)); // 256.3 + assertEquals((short) 0x5C01, Float16Layout.float32to16(256.374f)); // 256.3 + assertEquals((short) 0x5C02, Float16Layout.float32to16(256.375f)); // 256.5 + assertEquals((short) 0x5C02, Float16Layout.float32to16(256.51f)); // 256.5 + } + + @Test + public void testFloat16to32() { + + // Zero and subnormals + assertEquals(0.0f, Float16Layout.float16to32((short) 0x0000), 0); + assertEquals(-0.0f, Float16Layout.float16to32((short) 0x8000), 0); + assertEquals(6e-8f, Float16Layout.float16to32((short) 0x0001), 1e-8f); + assertEquals(-3.052e-5f, Float16Layout.float16to32((short) 0x8200), 1e-8f); + + // Infinite and NaN + assertEquals(Float.POSITIVE_INFINITY, Float16Layout.float16to32((short) 0x7C00), 0); + assertEquals(Float.NEGATIVE_INFINITY, Float16Layout.float16to32((short) 0xFC00), 0); + assertEquals(Float.NaN, Float16Layout.float16to32((short) 0x7E00), 0); + assertEquals(Float.intBitsToFloat(0xFFFFFFFF), Float16Layout.float16to32((short) 0x7E00), 0); + + // Normalized + assertEquals(1.0f, Float16Layout.float16to32((short) 0x3C00), 1e-1f); + assertEquals(-1.0f, Float16Layout.float16to32((short) 0xBC00), 1e-1f); + assertEquals(100.0f, Float16Layout.float16to32((short) 0x5640), 1e-1f); + assertEquals(-101.0f, Float16Layout.float16to32((short) 0xD650), 1e-1f); + assertEquals(1.123f, Float16Layout.float16to32((short) 0x3C7E), 1e-3f); + assertEquals(1.123f, Float16Layout.float16to32((short) 0x3C7E), 1e-3f); + assertEquals(-62.34f, Float16Layout.float16to32((short) 0xD3CB), 1e-2f); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/ArrayDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/ArrayDataBufferTest.java new file mode 100644 index 00000000000..60ab337c8f2 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/ArrayDataBufferTest.java @@ -0,0 +1,269 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.misc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.math.BigDecimal; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBufferTestBase; + +public class ArrayDataBufferTest extends DataBufferTestBase { + + @Override + protected DataBuffer allocate(long size) { + return new ArrayDataBuffer<>(new BigDecimal[(int) size], false); + } + + @Override + protected BigDecimal valueOf(Long val) { + return BigDecimal.valueOf(val); + } + + @Test + public void byteArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new byte[][] {{0x01}, {0x03}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new byte[][] {{0x01}, {0x03}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new byte[][] {{0x02}, {0x03}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new byte[][][] {{{0x01}}, {{0x03}}}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new byte[][][] {{{0x01}}, {{0x03}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void intArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new int[][] {{10}, {30}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new int[][] {{10}, {30}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new int[][] {{20}, {30}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new int[][][] {{{10}}, {{30}}}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new int[][][] {{{10}}, {{30}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void shortArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new short[][] {{10}, {30}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new short[][] {{10}, {30}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new short[][] {{20}, {30}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new short[][][] {{{10}}, {{30}}}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new short[][][] {{{10}}, {{30}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void longArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new long[][] {{10}, {30}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new long[][] {{10}, {30}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new long[][] {{20}, {30}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new long[][][] {{{10}}, {{30}}}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new long[][][] {{{10}}, {{30}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void floatArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new float[][] {{10}, {30}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new float[][] {{10}, {30}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new float[][] {{20}, {30}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new float[][][] {{{10}}, {{30}}}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new float[][][] {{{10}}, {{30}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void doubleArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new double[][] {{10}, {30}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new double[][] {{10}, {30}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new double[][] {{20}, {30}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new double[][][] {{{10}}, {{30}}}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new double[][][] {{{10}}, {{30}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void booleanArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new boolean[][] {{true}, {false}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new boolean[][] {{true}, {false}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new boolean[][] {{false}, {false}}, true); + DataBuffer buffer4 = + new ArrayDataBuffer<>(new boolean[][][] {{{true}}, {{false}}}, true); + DataBuffer buffer5 = + new ArrayDataBuffer<>(new boolean[][][] {{{true}}, {{false}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void objectArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new String[][] {{"10"}, {"30"}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new String[][] {{"10"}, {"30"}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new String[][] {{"20"}, {"30"}}, true); + DataBuffer buffer4 = + new ArrayDataBuffer<>(new String[][][] {{{"10"}}, {{"30"}}}, true); + DataBuffer buffer5 = + new ArrayDataBuffer<>(new String[][][] {{{"10"}}, {{"30"}}}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } + + @Test + public void nullableObjectArrayBufferEquals() { + DataBuffer buffer1 = new ArrayDataBuffer<>(new String[][] {null, {"30"}}, true); + DataBuffer buffer2 = new ArrayDataBuffer<>(new String[][] {null, {"30"}}, true); + DataBuffer buffer3 = new ArrayDataBuffer<>(new String[][] {{"20"}, {"30"}}, true); + DataBuffer buffer4 = new ArrayDataBuffer<>(new String[][][] {{{"10"}}, null}, true); + DataBuffer buffer5 = new ArrayDataBuffer<>(new String[][][] {{{"10"}}, null}, true); + + assertTrue(buffer1.equals(buffer2)); + assertTrue(buffer2.equals(buffer1)); + assertEquals(buffer1.hashCode(), buffer2.hashCode()); + + assertFalse(buffer1.equals(buffer3)); + assertFalse(buffer3.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer3.hashCode()); + + assertFalse(buffer1.equals(buffer4)); + assertFalse(buffer4.equals(buffer1)); + assertNotEquals(buffer1.hashCode(), buffer4.hashCode()); + + assertTrue(buffer4.equals(buffer5)); + assertTrue(buffer4.equals(buffer5)); + assertEquals(buffer4.hashCode(), buffer5.hashCode()); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/BitSetDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/BitSetDataBufferTest.java new file mode 100644 index 00000000000..2ebd7c492d3 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/BitSetDataBufferTest.java @@ -0,0 +1,34 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.misc; + +import java.util.BitSet; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.BooleanDataBufferTestBase; + +public class BitSetDataBufferTest extends BooleanDataBufferTestBase { + + @Override + protected BooleanDataBuffer allocate(long size) { + return new BitSetDataBuffer(new BitSet((int) size), size, false); + } + + @Override + protected Boolean valueOf(Long val) { + return val != 0; + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/StringArrayDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/StringArrayDataBufferTest.java new file mode 100644 index 00000000000..e91f44bbb9e --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/misc/StringArrayDataBufferTest.java @@ -0,0 +1,33 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.misc; + +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBufferTestBase; + +public class StringArrayDataBufferTest extends DataBufferTestBase { + + @Override + protected DataBuffer allocate(long size) { + return new ArrayDataBuffer<>(new String[(int) size], false); + } + + @Override + protected String valueOf(Long val) { + return val.toString(); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/ByteNioDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/ByteNioDataBufferTest.java new file mode 100644 index 00000000000..8c80e1cbac5 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/ByteNioDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.ByteBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBufferTestBase; + +public class ByteNioDataBufferTest extends ByteDataBufferTestBase { + + @Override + protected ByteDataBuffer allocate(long size) { + return new ByteNioDataBuffer(ByteBuffer.allocate((int) size)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/DoubleNioDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/DoubleNioDataBufferTest.java new file mode 100644 index 00000000000..47b9562ec1e --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/DoubleNioDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.DoubleBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBufferTestBase; + +public class DoubleNioDataBufferTest extends DoubleDataBufferTestBase { + + @Override + protected DoubleDataBuffer allocate(long size) { + return new DoubleNioDataBuffer(DoubleBuffer.allocate((int) size)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/FloatNioDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/FloatNioDataBufferTest.java new file mode 100644 index 00000000000..2dfe3620556 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/FloatNioDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.FloatBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBufferTestBase; + +public class FloatNioDataBufferTest extends FloatDataBufferTestBase { + + @Override + protected FloatDataBuffer allocate(long size) { + return new FloatNioDataBuffer(FloatBuffer.allocate((int) size)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/IntNioDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/IntNioDataBufferTest.java new file mode 100644 index 00000000000..28e9525f4a0 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/IntNioDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.IntBuffer; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBufferTestBase; + +public class IntNioDataBufferTest extends IntDataBufferTestBase { + + @Override + protected IntDataBuffer allocate(long size) { + return new IntNioDataBuffer(IntBuffer.allocate((int) size)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/LongNioDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/LongNioDataBufferTest.java new file mode 100644 index 00000000000..57538c7d348 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/LongNioDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.LongBuffer; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBufferTestBase; + +public class LongNioDataBufferTest extends LongDataBufferTestBase { + + @Override + protected LongDataBuffer allocate(long size) { + return new LongNioDataBuffer(LongBuffer.allocate((int) size)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/ShortNioDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/ShortNioDataBufferTest.java new file mode 100644 index 00000000000..dc2d5f8aea6 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/nio/ShortNioDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.nio; + +import java.nio.ShortBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBufferTestBase; + +public class ShortNioDataBufferTest extends ShortDataBufferTestBase { + + @Override + protected ShortDataBuffer allocate(long size) { + return new ShortNioDataBuffer(ShortBuffer.allocate((int) size)); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/BooleanRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/BooleanRawDataBufferTest.java new file mode 100644 index 00000000000..bd0f18d861c --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/BooleanRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.BooleanDataBufferTestBase; + +public class BooleanRawDataBufferTest extends BooleanDataBufferTestBase { + + @Override + protected BooleanDataBuffer allocate(long size) { + return new BooleanRawDataBuffer( + UnsafeMemoryHandle.fromArray(new boolean[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/ByteRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/ByteRawDataBufferTest.java new file mode 100644 index 00000000000..79d07e8644c --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/ByteRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.ByteDataBufferTestBase; + +public class ByteRawDataBufferTest extends ByteDataBufferTestBase { + + @Override + protected ByteDataBuffer allocate(long size) { + return new ByteRawDataBuffer( + UnsafeMemoryHandle.fromArray(new byte[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/DoubleRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/DoubleRawDataBufferTest.java new file mode 100644 index 00000000000..b2d82fc3d26 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/DoubleRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.buffer.DoubleDataBufferTestBase; + +public class DoubleRawDataBufferTest extends DoubleDataBufferTestBase { + + @Override + protected DoubleDataBuffer allocate(long size) { + return new DoubleRawDataBuffer( + UnsafeMemoryHandle.fromArray(new double[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/FloatRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/FloatRawDataBufferTest.java new file mode 100644 index 00000000000..ef4fbbce6cd --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/FloatRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.buffer.FloatDataBufferTestBase; + +public class FloatRawDataBufferTest extends FloatDataBufferTestBase { + + @Override + protected FloatDataBuffer allocate(long size) { + return new FloatRawDataBuffer( + UnsafeMemoryHandle.fromArray(new float[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/IntRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/IntRawDataBufferTest.java new file mode 100644 index 00000000000..f2efd0324cb --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/IntRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.buffer.IntDataBufferTestBase; + +public class IntRawDataBufferTest extends IntDataBufferTestBase { + + @Override + protected IntDataBuffer allocate(long size) { + return new IntRawDataBuffer( + UnsafeMemoryHandle.fromArray(new int[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/LongRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/LongRawDataBufferTest.java new file mode 100644 index 00000000000..e2cacf4a84d --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/LongRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.buffer.LongDataBufferTestBase; + +public class LongRawDataBufferTest extends LongDataBufferTestBase { + + @Override + protected LongDataBuffer allocate(long size) { + return new LongRawDataBuffer( + UnsafeMemoryHandle.fromArray(new long[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/ShortRawDataBufferTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/ShortRawDataBufferTest.java new file mode 100644 index 00000000000..887a3d747f7 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/buffer/raw/ShortRawDataBufferTest.java @@ -0,0 +1,29 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.buffer.raw; + +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.buffer.ShortDataBufferTestBase; + +public class ShortRawDataBufferTest extends ShortDataBufferTestBase { + + @Override + protected ShortDataBuffer allocate(long size) { + return new ShortRawDataBuffer( + UnsafeMemoryHandle.fromArray(new short[(int) size], (int) size), false); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArrayTest.java new file mode 100644 index 00000000000..35cbf07fab9 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/BooleanDenseNdArrayTest.java @@ -0,0 +1,47 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.BooleanNdArray; +import org.tensorflow.ndarray.BooleanNdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class BooleanDenseNdArrayTest extends BooleanNdArrayTestBase { + + @Override + protected BooleanNdArray allocate(Shape shape) { + return NdArrays.ofBooleans(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofBooleans(size); + } + + @Test + public void testToString() { + BooleanNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + Assertions.assertEquals("BooleanDenseNdArray(shape=[5, 4, 5])", matrix3d.toString()); + BooleanNdArray scalar = allocate(Shape.of()); + Assertions.assertEquals("BooleanDenseNdArray(shape=[])", scalar.toString()); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArrayTest.java new file mode 100644 index 00000000000..848999025d9 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/ByteDenseNdArrayTest.java @@ -0,0 +1,37 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.ByteNdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class ByteDenseNdArrayTest extends ByteNdArrayTestBase { + + @Override + protected ByteNdArray allocate(Shape shape) { + return NdArrays.ofBytes(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofBytes(size); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java new file mode 100644 index 00000000000..fb3a44ccb39 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java @@ -0,0 +1,56 @@ +package org.tensorflow.ndarray.impl.dense; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.index.Indices; + +public class DenseNdArrayTest { + + @Test + public void arrayEquals() { + IntNdArray array = + NdArrays.ofInts(Shape.of(2, 2)) + .set(NdArrays.vectorOf(1, 2), 0) + .set(NdArrays.vectorOf(3, 4), 1); + + assertTrue(array.equals(StdArrays.ndCopyOf(new int[][] {{1, 2}, {3, 4}}))); + assertTrue(array.equals(StdArrays.ndCopyOf(new Integer[][] {{1, 2}, {3, 4}}))); + assertFalse(array.equals(NdArrays.vectorOf(1, 2, 3, 4))); + assertFalse(array.equals(StdArrays.ndCopyOf(new int[][] {{3, 4}, {1, 2}}))); + assertFalse(array.equals(StdArrays.ndCopyOf(new long[][] {{1L, 2L}, {3L, 4L}}))); + } + + @Test + public void equalsAndHashCodeOnSlices() { + IntNdArray vector1 = NdArrays.vectorOf(3, 4); + IntNdArray vector2 = NdArrays.vectorOf(1, 2, 3, 4); + IntNdArray matrix1 = StdArrays.ndCopyOf(new int[][] {{1, 2}, {3, 4}}); + IntNdArray matrix2 = StdArrays.ndCopyOf(new int[][] {{1, 0, 2, 0}, {3, 0, 4, 0}}); + IntNdArray matrix3d1 = + StdArrays.ndCopyOf( + new int[][][] { + {{1, 2}, {3, 4}}, + {{5, 6}, {7, 8}} + }); + IntNdArray matrix3d2 = + StdArrays.ndCopyOf( + new int[][][] { + {{1, 2}, {4, 5}}, + {{3, 4}, {6, 7}} + }); + + assertTrue(vector1.equals(vector2.slice(Indices.sliceFrom(2)))); + assertTrue(vector1.equals(matrix1.get(1))); + assertTrue(vector1.equals(matrix2.get(1).slice(Indices.even()))); + assertTrue(matrix1.equals(matrix2.slice(Indices.all(), Indices.even()))); + assertTrue(matrix3d1.get(0).equals(matrix1)); + assertFalse(matrix3d1.get(0).equals(vector2)); + assertTrue(matrix1.equals(matrix3d2.slice(Indices.all(), Indices.at(0)))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArrayTest.java new file mode 100644 index 00000000000..1d5ad93bc27 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DoubleDenseNdArrayTest.java @@ -0,0 +1,49 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.DoubleNdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class DoubleDenseNdArrayTest extends DoubleNdArrayTestBase { + + @Override + protected DoubleNdArray allocate(Shape shape) { + return NdArrays.ofDoubles(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofDoubles(size); + } + + @Test + public void testToString() { + DoubleNdArray matrix3d = allocate(Shape.of(5, 4, 5)); + Assertions.assertEquals("DoubleDenseNdArray(shape=[5, 4, 5])", matrix3d.toString()); + DoubleNdArray vector = allocate(Shape.of(5)); + Assertions.assertEquals("DoubleDenseNdArray(shape=[5])", vector.toString()); + DoubleNdArray scalar = allocate(Shape.of()); + Assertions.assertEquals("DoubleDenseNdArray(shape=[])", scalar.toString()); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArrayTest.java new file mode 100644 index 00000000000..5023d832edd --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/FloatDenseNdArrayTest.java @@ -0,0 +1,76 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.FloatNdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.index.Indices; + +public class FloatDenseNdArrayTest extends FloatNdArrayTestBase { + + @Override + protected FloatNdArray allocate(Shape shape) { + return NdArrays.ofFloats(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofFloats(size); + } + + @Test + public void testSlice() { + Shape shape = Shape.of(3, 4); + Float[] values = { + 1f, 0f, 0f, 0f, + 0f, 0f, 2f, 0f, + 0f, 0f, 0f, 0f + }; + + float[] expected = {0, 0, 2, 0, 0, 0}; + + FloatDataBuffer buffer = (FloatDataBuffer) allocateBuffer(shape.size()); + buffer.write(values); + FloatNdArray instance = FloatDenseNdArray.create(buffer, shape); + + FloatNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getFloat())); + + // check values from elements(0) of a slice against the original array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getFloat()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArrayTest.java new file mode 100644 index 00000000000..8a6496976ec --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/IntDenseNdArrayTest.java @@ -0,0 +1,37 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.IntNdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class IntDenseNdArrayTest extends IntNdArrayTestBase { + + @Override + protected IntNdArray allocate(Shape shape) { + return NdArrays.ofInts(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofInts(size); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArrayTest.java new file mode 100644 index 00000000000..a8affa58ef0 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/LongDenseNdArrayTest.java @@ -0,0 +1,37 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.LongNdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class LongDenseNdArrayTest extends LongNdArrayTestBase { + + @Override + protected LongNdArray allocate(Shape shape) { + return NdArrays.ofLongs(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofLongs(size); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArrayTest.java new file mode 100644 index 00000000000..0b41cb8a575 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/ShortDenseNdArrayTest.java @@ -0,0 +1,37 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.ShortNdArray; +import org.tensorflow.ndarray.ShortNdArrayTestBase; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class ShortDenseNdArrayTest extends ShortNdArrayTestBase { + + @Override + protected ShortNdArray allocate(Shape shape) { + return NdArrays.ofShorts(shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofShorts(size); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/StringDenseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/StringDenseNdArrayTest.java new file mode 100644 index 00000000000..76168b7cc1c --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/StringDenseNdArrayTest.java @@ -0,0 +1,46 @@ +/* +Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +======================================================================= +*/ +package org.tensorflow.ndarray.impl.dense; + +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrayTestBase; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; + +public class StringDenseNdArrayTest extends NdArrayTestBase { + + @Override + protected NdArray allocate(Shape shape) { + return NdArrays.ofObjects(String.class, shape); + } + + @Override + protected DataBuffer allocateBuffer(long size) { + return DataBuffers.ofObjects(String.class, size); + } + + @Override + protected String valueOf(Long val) { + return val.toString(); + } + + protected String zeroOrNull() { + return null; + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java new file mode 100644 index 00000000000..87ebd4da4be --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sequence/ElementSequenceTest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2019 The TensorFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ======================================================================= + */ + +package org.tensorflow.ndarray.impl.sequence; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.NdArraySequence; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBufferWindow; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.AbstractNdArray; + +public class ElementSequenceTest { + + @Test + public void iterateVectorsWithIndex() { + IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); + + NdArraySequence sequence = + new SlicingElementSequence((AbstractNdArray) array, 1); + List coords = new ArrayList<>((int) array.shape().size()); + sequence.forEachIndexed((c, e) -> coords.add(Arrays.copyOf(c, c.length))); + + assertEquals(6, coords.size()); + assertArrayEquals(new long[] {0, 0}, coords.get(0)); + assertArrayEquals(new long[] {0, 1}, coords.get(1)); + assertArrayEquals(new long[] {0, 2}, coords.get(2)); + assertArrayEquals(new long[] {1, 0}, coords.get(3)); + assertArrayEquals(new long[] {1, 1}, coords.get(4)); + assertArrayEquals(new long[] {1, 2}, coords.get(5)); + } + + @Test + public void iterateScalarsWithIndex() { + IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); + + NdArraySequence cursor = + new SlicingElementSequence((AbstractNdArray) array, 2); + List coords = new ArrayList<>((int) array.shape().size()); + cursor.forEachIndexed((c, e) -> coords.add(Arrays.copyOf(c, c.length))); + + assertEquals(12, coords.size()); + assertArrayEquals(new long[] {0, 0, 0}, coords.get(0)); + assertArrayEquals(new long[] {0, 0, 1}, coords.get(1)); + assertArrayEquals(new long[] {0, 1, 0}, coords.get(2)); + assertArrayEquals(new long[] {0, 1, 1}, coords.get(3)); + assertArrayEquals(new long[] {0, 2, 0}, coords.get(4)); + assertArrayEquals(new long[] {0, 2, 1}, coords.get(5)); + assertArrayEquals(new long[] {1, 0, 0}, coords.get(6)); + assertArrayEquals(new long[] {1, 0, 1}, coords.get(7)); + assertArrayEquals(new long[] {1, 1, 0}, coords.get(8)); + assertArrayEquals(new long[] {1, 1, 1}, coords.get(9)); + assertArrayEquals(new long[] {1, 2, 0}, coords.get(10)); + assertArrayEquals(new long[] {1, 2, 1}, coords.get(11)); + } + + @Test + public void slicingElementSequenceReturnsUniqueInstances() { + IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); + NdArraySequence sequence = + new SlicingElementSequence((AbstractNdArray) array, 1); + List elements = new ArrayList<>(); + sequence.forEach( + e -> { + elements.forEach( + tmp -> { + if (tmp == e) { + fail(); + } + }); + elements.add(e); + }); + } + + @Test + public void fastElementSequenceReturnsSameInstance() { + IntNdArray array = NdArrays.ofInts(Shape.of(2, 3, 2)); + IntNdArray element = array.get(0); + NdArraySequence sequence = + new FastElementSequence( + (AbstractNdArray) array, 1, element, mockDataBufferWindow(2)); + sequence.forEach( + e -> { + if (e != element) { + fail(); + } + }); + } + + private DataBufferWindow mockDataBufferWindow(long size) { + return new DataBufferWindow() { + + @Override + public long offset() { + return offset; + } + + @Override + public long size() { + return size; + } + + @Override + public DataBufferWindow slideTo(long index) { + offset = index; + return this; + } + + @Override + public DataBufferWindow slide(long step) { + offset += step; + return this; + } + + @Override + public IntDataBuffer buffer() { + return buffer; + } + + private long offset; + private final long size = 2; + private final IntDataBuffer buffer = DataBuffers.ofInts(2); + }; + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/BooleanSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/BooleanSparseNdArrayTest.java new file mode 100644 index 00000000000..32ea120e0e1 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/BooleanSparseNdArrayTest.java @@ -0,0 +1,313 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.BooleanNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.BooleanDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class BooleanSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + boolean[] valuesArray = {true, true}; + boolean[] valuesArrayDefaultValue = {false, false}; + boolean[] denseArray = { + true, false, false, false, + false, false, true, false, + false, false, false, false + }; + boolean[][] dense2DArray = { + {true, false, false, false}, {false, false, true, false}, {false, false, false, false} + }; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + BooleanNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + BooleanDataBuffer dataBuffer = DataBuffers.ofBooleans(instance.shape().size()); + + instance.copyTo(dataBuffer); + + boolean[] array = new boolean[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + BooleanDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + BooleanSparseNdArray instance = BooleanSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // invert true/false + boolean[] denseArrayDefaultValue = new boolean[denseArray.length]; + for (int i = 0; i < denseArrayDefaultValue.length; i++) { + denseArrayDefaultValue[i] = !denseArray[i]; + } + + BooleanNdArray valuesDefault = StdArrays.ndCopyOf(new boolean[] {false, false}); + BooleanDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + BooleanSparseNdArray instance = + BooleanSparseNdArray.create(true, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(valuesDefault, instance.getValues()); + } + + @Test + public void testGetObject() { + BooleanNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetBoolean() { + BooleanNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getBoolean(n, m), instance.getBoolean(n, m)); + } + } + } + + @Test + public void testGetBooleanDefaultValue() { + // flip the truth table + BooleanNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + BooleanSparseNdArray instance = + new BooleanSparseNdArray( + indices, + NdArrays.vectorOf(valuesArrayDefaultValue), + true, + DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertNotEquals(ndArray.getBoolean(n, m), instance.getBoolean(n, m)); + } + } + } + + @Test + public void testGet() { + BooleanNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject(false, 0, 0)); + } + + @Test + public void testSet() { + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + boolean[] valuesArray = {true, true, false, true, false}; + boolean[] sortedValuesArray = {true, false, true, false, true}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + BooleanNdArray values = StdArrays.ndCopyOf(valuesArray); + BooleanNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + boolean[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + BooleanNdArray denseInstance = instance.toDense(); + BooleanNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + BooleanNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + BooleanSparseNdArray instance = + BooleanSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getBoolean(n, m), instance.getBoolean(n, m)); + } + } + } + + @Test + public void testElements1() { + boolean[] expected = {true, false, false}; + + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + BooleanNdArray dst = NdArrays.ofBooleans(shape); + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getBoolean(n, m), dst.getBoolean(n, m)); + } + } + } + + @Test + public void testCreate() { + + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + BooleanSparseNdArray instanceA = + BooleanSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + BooleanDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + BooleanSparseNdArray instanceB = BooleanSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + BooleanSparseNdArray instanceC = + BooleanSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + BooleanSparseNdArray instanceD = BooleanSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + BooleanNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + BooleanSparseNdArray instanceE = BooleanSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + boolean[] expected = {false, false, true, false, false, false}; + BooleanSparseNdArray instance = + new BooleanSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + BooleanNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getBoolean())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getBoolean()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/ByteSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/ByteSparseNdArrayTest.java new file mode 100644 index 00000000000..b0504659055 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/ByteSparseNdArrayTest.java @@ -0,0 +1,304 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.ByteNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.ByteDataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class ByteSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + byte[] valuesArray = {1, 16}; + byte[] denseArray = { + 1, 0, 0, 0, + 0, 0, 16, 0, + 0, 0, 0, 0 + }; + byte[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 16, 0}, {0, 0, 0, 0}}; + + byte[][] dense2DArrayDefaultValue = {{1, -1, -1, -1}, {-1, -1, 16, -1}, {-1, -1, -1, -1}}; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + ByteNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + ByteDataBuffer dataBuffer = DataBuffers.ofBytes(instance.shape().size()); + + instance.copyTo(dataBuffer); + + byte[] array = new byte[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + ByteDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + ByteSparseNdArray instance = ByteSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // change 0 to -1 + byte[] denseArrayDefaultValue = new byte[denseArray.length]; + for (int i = 0; i < denseArrayDefaultValue.length; i++) { + denseArrayDefaultValue[i] = denseArray[i] == 0 ? -1 : denseArray[i]; + } + ByteDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + ByteSparseNdArray instance = + ByteSparseNdArray.create((byte) -1, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals((byte) -1, instance.getByte(2, 0)); + } + + @Test + public void testGetObject() { + ByteNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetByte() { + ByteNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getByte(n, m), instance.getByte(n, m)); + } + } + } + + @Test + public void testGetByteDefaultValue() { + ByteNdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefaultValue); + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, (byte) -1, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getByte(n, m), instance.getByte(n, m)); + } + } + } + + @Test + public void testGet() { + ByteNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject((byte) 0, 0, 0)); + } + + @Test + public void testSet() { + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + byte[] valuesArray = {1, 3, 2, 5, 4}; + byte[] sortedValuesArray = {1, 2, 3, 4, 5}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + ByteNdArray values = StdArrays.ndCopyOf(valuesArray); + ByteNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + byte[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + ByteNdArray denseInstance = instance.toDense(); + ByteNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + ByteNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ByteSparseNdArray instance = ByteSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getByte(n, m), instance.getByte(n, m)); + } + } + } + + @Test + public void testElements1() { + byte[] expected = {1, 0, 0}; + + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + ByteNdArray dst = NdArrays.ofBytes(shape); + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getByte(n, m), dst.getByte(n, m)); + } + } + } + + @Test + public void testCreate() { + + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + ByteSparseNdArray instanceA = + ByteSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + ByteDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + ByteSparseNdArray instanceB = ByteSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + ByteSparseNdArray instanceC = + ByteSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + ByteSparseNdArray instanceD = ByteSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + ByteNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ByteSparseNdArray instanceE = ByteSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + byte[] expected = {0, 0, 16, 0, 0, 0}; + ByteSparseNdArray instance = + new ByteSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + ByteNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getByte())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getByte()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArrayTest.java new file mode 100644 index 00000000000..e7209902d86 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/DoubleSparseNdArrayTest.java @@ -0,0 +1,318 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.DoubleBuffer; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.DoubleDataBuffer; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class DoubleSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + double[] valuesArray = {1, 256}; + double[] denseArray = { + 1, 0, 0, 0, + 0, 0, 256, 0, + 0, 0, 0, 0 + }; + double[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 256, 0}, {0, 0, 0, 0}}; + + double[][] dense2DArrayDefaultValue = {{1, -1, -1, -1}, {-1, -1, 256, -1}, {-1, -1, -1, -1}}; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + DoubleNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + DoubleDataBuffer dataBuffer = DataBuffers.ofDoubles(instance.shape().size()); + + instance.copyTo(dataBuffer); + + double[] array = new double[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + DoubleDataBuffer dataBuffer = NioDataBufferFactory.create(DoubleBuffer.wrap(denseArray)); + // use a zero buffer + DoubleSparseNdArray instance = DoubleSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // change 0 to -1 + double[] denseArrayDefaultValue = Arrays.stream(denseArray).map(x -> x == 0 ? -1 : x).toArray(); + DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + DoubleSparseNdArray instance = DoubleSparseNdArray.create(-1d, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(-1d, instance.getDouble(2, 0)); + } + + @Test + public void testGetObject() { + + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetDouble() { + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getDouble(n, m), instance.getDouble(n, m)); + } + } + } + + @Test + public void testGetDoubleDefaultValue() { + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefaultValue); + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, -1d, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getDouble(n, m), instance.getDouble(n, m)); + } + } + } + + @Test + public void testGet() { + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject(2d, 0, 0)); + } + + @Test + public void testSet() { + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + double[] valuesArray = {1, 3, 2, 5, 4}; + double[] sortedValuesArray = {1, 2, 3, 4, 5}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + DoubleNdArray values = StdArrays.ndCopyOf(valuesArray); + DoubleNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + double[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + DoubleNdArray denseInstance = instance.toDense(); + DoubleNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + DoubleSparseNdArray instance = + DoubleSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getDouble(n, m), instance.getDouble(n, m)); + } + } + } + + @Test + public void testElements1() { + double[] expected = {1, 0, 0}; + + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + DoubleNdArray dst = NdArrays.ofDoubles(shape); + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getDouble(n, m), dst.getDouble(n, m)); + } + } + } + + @Test + public void testCreate() { + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + DoubleSparseNdArray instanceA = + DoubleSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + DoubleSparseNdArray instanceB = DoubleSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + DoubleSparseNdArray instanceC = + DoubleSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + DoubleSparseNdArray instanceD = DoubleSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + DoubleSparseNdArray instanceE = DoubleSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + double[] expected = {0, 0, 256, 0, 0, 0}; + DoubleSparseNdArray instance = + new DoubleSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + DoubleNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getDouble())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getDouble()))); + } + + @Test + public void testToString() { + DoubleNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + DoubleSparseNdArray instance = + DoubleSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + Assertions.assertEquals( + "DoubleSparseNdArray(defaultValue=0.0, numElements=2, shape=[3, 4])", instance.toString()); + DoubleSparseNdArray empty = DoubleSparseNdArray.create(DimensionalSpace.create(Shape.of(5))); + Assertions.assertEquals( + "DoubleSparseNdArray(defaultValue=0.0, numElements=0, shape=[5])", empty.toString()); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/FloatSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/FloatSparseNdArrayTest.java new file mode 100644 index 00000000000..de5d3bbb634 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/FloatSparseNdArrayTest.java @@ -0,0 +1,313 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.FloatBuffer; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.FloatDataBuffer; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class FloatSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + float[] valuesArray = {1, 2}; + float[] denseArray = { + 1, 0, 0, 0, + 0, 0, 2, 0, + 0, 0, 0, 0 + }; + + float[][] dense2DArrayDefaultValue = {{1, -1, -1, -1}, {-1, -1, 2, -1}, {-1, -1, -1, -1}}; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + FloatNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + FloatDataBuffer dataBuffer = DataBuffers.ofFloats(instance.shape().size()); + + instance.copyTo(dataBuffer); + + float[] array = new float[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + FloatDataBuffer dataBuffer = NioDataBufferFactory.create(FloatBuffer.wrap(denseArray)); + // use a zero buffer + FloatSparseNdArray instance = FloatSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // change 0 to -1 + float[] denseArrayDefaultValue = new float[denseArray.length]; + for (int i = 0; i < denseArrayDefaultValue.length; i++) { + denseArrayDefaultValue[i] = denseArray[i] == 0f ? -1f : denseArray[i]; + } + FloatDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + FloatSparseNdArray instance = FloatSparseNdArray.create(-1f, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(-1f, instance.getFloat(2, 0)); + } + + @Test + public void testGetObject() { + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + FloatNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetFloat() { + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + FloatNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getFloat(n, m), instance.getFloat(n, m)); + } + } + } + + @Test + public void testGetFloatDefaultValue() { + FloatNdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefaultValue); + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, -1, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getFloat(n, m), instance.getFloat(n, m)); + } + } + } + + @Test + public void testGet() { + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + FloatNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject(2f, 0, 0)); + } + + @Test + public void testSet() { + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + float[] valuesArray = {1, 3, 2, 5, 4}; + float[] sortedValuesArray = {1, 2, 3, 4, 5}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + FloatNdArray values = StdArrays.ndCopyOf(valuesArray); + FloatNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + float[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + FloatNdArray denseInstance = instance.toDense(); + FloatNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + FloatNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + FloatSparseNdArray instance = + FloatSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getFloat(n, m), instance.getFloat(n, m)); + } + } + } + + @Test + public void testElements1() { + float[] expected = {1, 0, 0}; + float[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + FloatNdArray dst = NdArrays.ofFloats(shape); + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getFloat(n, m), dst.getFloat(n, m)); + } + } + } + + @Test + public void testCreate() { + float[] denseArray = {1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0}; + float[][] dense2Array = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + FloatSparseNdArray instanceA = + FloatSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + FloatDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + FloatSparseNdArray instanceB = FloatSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + FloatSparseNdArray instanceC = + FloatSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + FloatSparseNdArray instanceD = FloatSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + FloatNdArray ndArray = StdArrays.ndCopyOf(dense2Array); + FloatSparseNdArray instanceE = FloatSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + float[] expected = {0, 0, 2, 0, 0, 0}; + FloatSparseNdArray instance = + new FloatSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + FloatNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getFloat())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getFloat()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/IntSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/IntSparseNdArrayTest.java new file mode 100644 index 00000000000..669cd8080e5 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/IntSparseNdArrayTest.java @@ -0,0 +1,311 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.IntBuffer; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.IntDataBuffer; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class IntSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + int[] valuesArray = {1, 2}; + int[] denseArray = { + 1, 0, 0, 0, + 0, 0, 2, 0, + 0, 0, 0, 0 + }; + + int[][] dense2DArrayDefaultValue = {{1, -1, -1, -1}, {-1, -1, 2, -1}, {-1, -1, -1, -1}}; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + IntNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + IntDataBuffer dataBuffer = DataBuffers.ofInts(instance.shape().size()); + + instance.copyTo(dataBuffer); + + int[] array = new int[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBufferBuffer() { + + IntDataBuffer dataBuffer = NioDataBufferFactory.create(IntBuffer.wrap(denseArray)); + // use a zero buffer + IntSparseNdArray instance = IntSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // change 0 to -1 + int[] denseArrayDefaultValue = Arrays.stream(denseArray).map(x -> x == 0 ? -1 : x).toArray(); + + IntDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + IntSparseNdArray instance = IntSparseNdArray.create(-1, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(-1, instance.getInt(2, 0)); + } + + @Test + public void testGetObject() { + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + IntNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetInt() { + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + IntNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getInt(n, m), instance.getInt(n, m)); + } + } + } + + @Test + public void testGetIntDefaultValue() { + IntNdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefaultValue); + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, -1, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getInt(n, m), instance.getInt(n, m)); + } + } + } + + @Test + public void testGet() { + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + IntNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject(2, 0, 0)); + } + + @Test + public void testSet() { + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + int[] valuesArray = {1, 3, 2, 5, 4}; + int[] sortedValuesArray = {1, 2, 3, 4, 5}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + IntNdArray values = StdArrays.ndCopyOf(valuesArray); + IntNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + int[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + IntNdArray denseInstance = instance.toDense(); + IntNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + IntNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + IntSparseNdArray instance = IntSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getInt(n, m), instance.getInt(n, m)); + } + } + } + + @Test + public void testElements1() { + int[] expected = {1, 0, 0}; + int[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + IntNdArray dst = NdArrays.ofInts(shape); + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getInt(n, m), dst.getInt(n, m)); + } + } + } + + @Test + public void testCreate() { + int[] denseArray = {1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0}; + int[][] dense2Array = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + IntSparseNdArray instanceA = + IntSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + IntDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + IntSparseNdArray instanceB = IntSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + IntSparseNdArray instanceC = + IntSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + IntSparseNdArray instanceD = IntSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + IntNdArray ndArray = StdArrays.ndCopyOf(dense2Array); + IntSparseNdArray instanceE = IntSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + int[] expected = {0, 0, 2, 0, 0, 0}; + IntSparseNdArray instance = + new IntSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + IntNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getInt())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getInt()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/LongSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/LongSparseNdArrayTest.java new file mode 100644 index 00000000000..93864683650 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/LongSparseNdArrayTest.java @@ -0,0 +1,310 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.LongBuffer; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class LongSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + long[] valuesArray = {1, 2}; + long[] denseArray = { + 1, 0, 0, 0, + 0, 0, 2, 0, + 0, 0, 0, 0 + }; + + long[][] dense2DArrayDefaultValue = {{1, -1, -1, -1}, {-1, -1, 2, -1}, {-1, -1, -1, -1}}; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + LongDataBuffer dataBuffer = DataBuffers.ofLongs(instance.shape().size()); + + instance.copyTo(dataBuffer); + + long[] array = new long[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + LongDataBuffer dataBuffer = NioDataBufferFactory.create(LongBuffer.wrap(denseArray)); + // use a zero buffer + LongSparseNdArray instance = LongSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // change 0 to -1 + long[] denseArrayDefaultValue = Arrays.stream(denseArray).map(x -> x == 0 ? -1 : x).toArray(); + + LongDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + LongSparseNdArray instance = LongSparseNdArray.create(-1L, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(-1L, instance.getLong(2, 0)); + } + + @Test + public void testGetObject() { + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + LongNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetLong() { + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + LongNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getLong(n, m), instance.getLong(n, m)); + } + } + } + + @Test + public void testGetLongDefaultValue() { + LongNdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefaultValue); + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, -1L, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getLong(n, m), instance.getLong(n, m)); + } + } + } + + @Test + public void testGet() { + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + LongNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject(2L, 0, 0)); + } + + @Test + public void testSet() { + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + long[] valuesArray = {1, 3, 2, 5, 4}; + long[] sortedValuesArray = {1, 2, 3, 4, 5}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + LongNdArray values = StdArrays.ndCopyOf(valuesArray); + LongNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + long[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + LongNdArray denseInstance = instance.toDense(); + LongNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + LongNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + LongSparseNdArray instance = LongSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getLong(n, m), instance.getLong(n, m)); + } + } + } + + @Test + public void testElements1() { + long[] expected = {1, 0, 0}; + long[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + LongNdArray dst = NdArrays.ofLongs(shape); + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getLong(n, m), dst.getLong(n, m)); + } + } + } + + @Test + public void testCreate() { + long[] denseArray = {1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0}; + long[][] dense2Array = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + LongSparseNdArray instanceA = + LongSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + LongDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + LongSparseNdArray instanceB = LongSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + LongSparseNdArray instanceC = + LongSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + LongSparseNdArray instanceD = LongSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + LongNdArray ndArray = StdArrays.ndCopyOf(dense2Array); + LongSparseNdArray instanceE = LongSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + long[] expected = {0, 0, 2, 0, 0, 0}; + LongSparseNdArray instance = + new LongSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + LongNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getLong())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getLong()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/ShortSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/ShortSparseNdArrayTest.java new file mode 100644 index 00000000000..ae5b8ffef44 --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/ShortSparseNdArrayTest.java @@ -0,0 +1,314 @@ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.ShortBuffer; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.ShortNdArray; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.buffer.ShortDataBuffer; +import org.tensorflow.ndarray.impl.buffer.nio.NioDataBufferFactory; +import org.tensorflow.ndarray.impl.buffer.raw.RawDataBufferFactory; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +class ShortSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + short[] valuesArray = {1, 2}; + short[] denseArray = { + 1, 0, 0, 0, + 0, 0, 2, 0, + 0, 0, 0, 0 + }; + short[][] dense2DArrayDefaultValue = {{1, -1, -1, -1}, {-1, -1, 2, -1}, {-1, -1, -1, -1}}; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + ShortNdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + ShortDataBuffer dataBuffer = DataBuffers.ofShorts(instance.shape().size()); + + instance.copyTo(dataBuffer); + + short[] array = new short[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + ShortDataBuffer dataBuffer = NioDataBufferFactory.create(ShortBuffer.wrap(denseArray)); + // use a zero buffer + ShortSparseNdArray instance = ShortSparseNdArray.create(DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + } + + @Test + public void testWriteDefaultValue() { + // change 0 to -1 + short[] denseArrayDefaultValue = new short[denseArray.length]; + for (int i = 0; i < denseArrayDefaultValue.length; i++) { + denseArrayDefaultValue[i] = denseArray[i] == 0 ? (short) -1 : denseArray[i]; + } + + ShortDataBuffer dataBuffer = RawDataBufferFactory.create(denseArrayDefaultValue, false); + // use a zero buffer + ShortSparseNdArray instance = + ShortSparseNdArray.create((short) -1, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals((short) -1, instance.getShort(2, 0)); + } + + @Test + public void testGetObject() { + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + ShortNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetShort() { + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + ShortNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getShort(n, m), instance.getShort(n, m)); + } + } + } + + @Test + public void testGetShortDefaultValue() { + ShortNdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefaultValue); + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, (short) -1, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getShort(n, m), instance.getShort(n, m)); + } + } + } + + @Test + public void testGet() { + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + ShortNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject((short) 2, 0, 0)); + } + + @Test + public void testSet() { + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + short[] valuesArray = {1, 3, 2, 5, 4}; + short[] sortedValuesArray = {1, 2, 3, 4, 5}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + ShortNdArray values = StdArrays.ndCopyOf(valuesArray); + ShortNdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + short[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + ShortNdArray denseInstance = instance.toDense(); + ShortNdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + ShortNdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + ShortSparseNdArray instance = + ShortSparseNdArray.create(DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getShort(n, m), instance.getShort(n, m)); + } + } + } + + @Test + public void testElements1() { + short[] expected = {1, 0, 0}; + short[][] dense2DArray = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + ShortNdArray dst = NdArrays.ofShorts(shape); + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getShort(n, m), dst.getShort(n, m)); + } + } + } + + @Test + public void testCreate() { + short[] denseArray = {1, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0}; + short[][] dense2Array = {{1, 0, 0, 0}, {0, 0, 2, 0}, {0, 0, 0, 0}}; + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + ShortSparseNdArray instanceA = + ShortSparseNdArray.create(indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + ShortDataBuffer dataBuffer = RawDataBufferFactory.create(denseArray, false); + // use a zero buffer + ShortSparseNdArray instanceB = ShortSparseNdArray.create(DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + ShortSparseNdArray instanceC = + ShortSparseNdArray.create(dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + ShortSparseNdArray instanceD = ShortSparseNdArray.create(dataBuffer, shape); + assertEquals(instanceB, instanceD); + + ShortNdArray ndArray = StdArrays.ndCopyOf(dense2Array); + ShortSparseNdArray instanceE = ShortSparseNdArray.create(ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + short[] expected = {0, 0, 2, 0, 0, 0}; + ShortSparseNdArray instance = + new ShortSparseNdArray(indices, values, DimensionalSpace.create(shape)); + + ShortNdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getShort())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getShort()))); + } +} diff --git a/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/StringSparseNdArrayTest.java b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/StringSparseNdArrayTest.java new file mode 100644 index 00000000000..32b83ef702f --- /dev/null +++ b/tensorflow-ndarray/src/test/java/org/tensorflow/ndarray/impl/sparse/StringSparseNdArrayTest.java @@ -0,0 +1,356 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.ndarray.impl.sparse; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.ndarray.buffer.DataBuffer; +import org.tensorflow.ndarray.buffer.DataBuffers; +import org.tensorflow.ndarray.impl.dimension.DimensionalSpace; +import org.tensorflow.ndarray.index.Indices; + +public class StringSparseNdArrayTest { + long[][] indicesArray = {{0, 0}, {1, 2}}; + String[] valuesArray = {"alpha", "omega"}; + String[] denseArray = { + "alpha", null, null, null, null, null, "omega", null, null, null, null, null + }; + String[][] dense2DArray = { + {"alpha", null, null, null}, {null, null, "omega", null}, {null, null, null, null} + }; + + String[][] dense2DArrayDefault = { + {"alpha", "default", "default", "default"}, + {"default", "default", "omega", "default"}, + {"default", "default", "default", "default"} + }; + + Shape shape = Shape.of(3, 4); + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + NdArray values = StdArrays.ndCopyOf(valuesArray); + + @Test + public void testBasic() { + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(shape, instance.shape()); + } + + @Test + public void testCopyToBuffer() { + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + DataBuffer dataBuffer = DataBuffers.ofObjects(String.class, instance.shape().size()); + + instance.copyTo(dataBuffer); + + String[] array = new String[denseArray.length]; + dataBuffer.read(array); + assertArrayEquals(denseArray, array); + } + + @Test + public void testCopyFromBuffer() { + + DataBuffer dataBuffer = DataBuffers.ofObjects(denseArray); + // use a zero buffer + SparseNdArray> instance = + SparseNdArray.create(String.class, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(indices.shape().get(0), values.size()); + assertEquals(2, values.size()); + } + + @Test + public void testWriteDefaultValue() { + String defaultValue = "default"; + String[] denseArrayDefaultValue = new String[denseArray.length]; + for (int i = 0; i < denseArrayDefaultValue.length; i++) { + denseArrayDefaultValue[i] = denseArray[i] == null ? defaultValue : denseArray[i]; + } + + DataBuffer dataBuffer = DataBuffers.ofObjects(denseArrayDefaultValue); + // use a zero buffer + SparseNdArray> instance = + SparseNdArray.create(String.class, defaultValue, DimensionalSpace.create(shape)); + instance.copyFrom(dataBuffer); + + assertEquals(indices, instance.getIndices()); + assertEquals(values, instance.getValues()); + assertEquals(2, values.size()); + assertEquals(indices.shape().get(0), values.size()); + } + + @Test + public void testGetObject() { + NdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGetObjectDefaultValue() { + String defaultValue = "default"; + + NdArray ndArray = StdArrays.ndCopyOf(dense2DArrayDefault); + SparseNdArray> instance = + new SparseNdArray<>( + String.class, indices, values, defaultValue, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testGet() { + NdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + for (int n = 0; n < ndArray.shape().get(0); n++) { + assertEquals(ndArray.get(n), instance.get(n)); + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.get(n, m), instance.get(n, m)); + } + } + } + + @Test + public void testSetObject() { + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + assertThrows(java.nio.ReadOnlyBufferException.class, () -> instance.setObject(null, 0, 0)); + } + + @Test + public void testSet() { + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + assertThrows( + java.nio.ReadOnlyBufferException.class, + () -> instance.set(instance.getDefaultArray(), 0, 0)); + } + + @Test + public void testSort() { + + long[][] indicesArray = {{0, 0}, {1, 2}, {0, 1}, {2, 3}, {1, 4}}; + long[][] sortedIndicesArray = {{0, 0}, {0, 1}, {1, 2}, {1, 4}, {2, 3}}; + String[] valuesArray = {"b", "d", "a", null, "c"}; + String[] sortedValuesArray = {"b", "a", "d", "c", null}; + + LongNdArray indices = StdArrays.ndCopyOf(indicesArray); + LongNdArray sortedIndices = StdArrays.ndCopyOf(sortedIndicesArray); + NdArray values = StdArrays.ndCopyOf(valuesArray); + NdArray sortedValues = StdArrays.ndCopyOf(sortedValuesArray); + + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + instance.sortIndicesAndValues(); + + // should be sorted in ascending row-wise coordinate order based on test values + assertEquals(sortedIndices, instance.getIndices()); + assertEquals(sortedValues, instance.getValues()); + } + + @Test + public void testElements() { + + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + instance + .elements(0) + .forEachIndexed( + (idx, item) -> { + String[] slice = dense2DArray[(int) idx[0]]; + item.scalars() + .forEachIndexed((dx, f) -> assertEquals(slice[(int) dx[0]], f.getObject())); + }); + } + + @Test + public void testDense() { + + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + NdArray denseInstance = instance.toDense(); + NdArray expectedDense = StdArrays.ndCopyOf(dense2DArray); + assertEquals(expectedDense, denseInstance); + } + + @Test + public void testFromDense() { + NdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + SparseNdArray> instance = + SparseNdArray.create(String.class, DimensionalSpace.create(ndArray.shape())); + instance.fromDense(ndArray); + assertNotNull(instance.getIndices()); + assertEquals(2, instance.getIndices().shape().get(0)); + assertNotNull(instance.getValues()); + assertEquals(2, instance.getValues().size()); + + assertEquals(ndArray.shape(), instance.shape()); + for (int n = 0; n < ndArray.shape().get(0); n++) { + for (int m = 0; m < ndArray.shape().get(1); m++) { + assertEquals(ndArray.getObject(n, m), instance.getObject(n, m)); + } + } + } + + @Test + public void testElements1() { + String[] expected = {"alpha", null, null}; + + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + instance + .elements(0) + .forEachIndexed((idx, l) -> assertEquals(expected[(int) idx[0]], l.getObject())); + instance + .elements(1) + .forEachIndexed( + (idx, l) -> assertEquals(dense2DArray[(int) idx[0]][(int) idx[1]], l.getObject())); + } + + @Test + public void testCopyTo() { + NdArray dst = NdArrays.ofObjects(String.class, shape); + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + instance.copyTo(dst); + for (int n = 0; n < instance.shape().get(0); n++) { + for (int m = 0; m < instance.shape().get(1); m++) { + assertEquals(instance.getObject(n, m), dst.getObject(n, m)); + } + } + } + + @Test + public void testCreate() { + + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + SparseNdArray> instanceA = + SparseNdArray.create(String.class, indices, values, DimensionalSpace.create(shape)); + assertEquals(instance, instanceA); + + DataBuffer dataBuffer = DataBuffers.ofObjects(denseArray); + + // use a zero buffer + SparseNdArray> instanceB = + SparseNdArray.create(String.class, DimensionalSpace.create(shape)); + instanceB.copyFrom(dataBuffer); + assertEquals(instance, instanceB); + + SparseNdArray> instanceC = + SparseNdArray.create(String.class, dataBuffer, DimensionalSpace.create(shape)); + assertEquals(instanceB, instanceC); + + SparseNdArray> instanceD = + SparseNdArray.create(String.class, dataBuffer, shape); + assertEquals(instanceB, instanceD); + + NdArray ndArray = StdArrays.ndCopyOf(dense2DArray); + SparseNdArray> instanceE = SparseNdArray.create(String.class, ndArray); + assertEquals(instance, instanceE); + } + + @Test + public void testSlice() { + String[] expected = {null, null, "omega", null, null, null}; + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + NdArray sliceInstance = instance.slice(Indices.all(), Indices.sliceFrom(2)); + // check the values of the slice against the original sparse array + AtomicInteger i = new AtomicInteger(); + sliceInstance + .scalars() + .forEachIndexed((idx, f) -> assertEquals(expected[i.getAndIncrement()], f.getObject())); + // check values from elements(0) of a slice against the original sparse array + i.set(0); + sliceInstance + .elements(0) + .forEachIndexed( + (idx, l) -> + l.scalars() + .forEachIndexed( + (lidx, f) -> assertEquals(expected[i.getAndIncrement()], f.getObject()))); + } + + @Test + public void testNullDefault() { + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + + NdArray dArray = instance.getDefaultArray(); + assertEquals(1L, dArray.size()); + assertNull(dArray.getObject()); + + instance = + new SparseNdArray<>( + String.class, indices, values, "a default", DimensionalSpace.create(shape)); + + dArray = instance.getDefaultArray(); + assertEquals(1L, dArray.size()); + assertNotNull(dArray.getObject()); + assertEquals("a default", dArray.getObject()); + } + + @Test + public void testToString() { + SparseNdArray> instance = + new SparseNdArray<>(String.class, indices, values, DimensionalSpace.create(shape)); + Assertions.assertEquals( + "SparseNdArray(type=String, defaultValue=, numElements=2, shape=[3, 4])", + instance.toString()); + instance = + new SparseNdArray<>( + String.class, indices, values, "a default", DimensionalSpace.create(shape)); + Assertions.assertEquals( + "SparseNdArray(type=String, defaultValue='a default', numElements=2, shape=[3, 4])", + instance.toString()); + } +} diff --git a/tensorflow-ndarray/src/test/resources/COPYRIGHT.txt b/tensorflow-ndarray/src/test/resources/COPYRIGHT.txt new file mode 100644 index 00000000000..5e7bd50bb48 --- /dev/null +++ b/tensorflow-ndarray/src/test/resources/COPYRIGHT.txt @@ -0,0 +1 @@ +All images in this folder and its subfolders are free of any copyright. \ No newline at end of file diff --git a/tensorflow-ndarray/src/test/resources/castle.jpg b/tensorflow-ndarray/src/test/resources/castle.jpg new file mode 100644 index 00000000000..c5b07b4bc2a Binary files /dev/null and b/tensorflow-ndarray/src/test/resources/castle.jpg differ