| Discussion thread | https://lists.apache.org/thread/m006xymqto1kp0s80sr8jqbbmw6fq9sg |
| --- | --- |
| Vote thread | https://lists.apache.org/thread/3t892xv83vvoc01nqp2tsk48kqbrlbsg |
| JIRA | |
| Release | ml-2.2.0 |
Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast)
[This FLIP proposal is joint work by Dong Lin, Fan Hong, Jiang Xin, and Zhipeng Zhang]
Motivation and Objectives
With the existing Estimator/AlgoOperator/Transformer APIs, Flink ML can perform efficient near-line data transformation on the Flink runtime. However, these APIs cannot meet the requirements of online inference, which aims to produce machine learning predictions in real-time and thus requires the inference framework to be lightweight and fast.
In this FLIP, we propose a set of APIs that support low-latency data transformation without relying on Flink runtime. We call the online inference subject `Servable`. A Servable takes one DataFrame as input and produces one DataFrame as the output.
We would like to address the following objectives with the changes proposed in this FLIP.
1) Add classes to support space-efficient representations of vector and matrix (e.g. DenseVector, SparseVector).
2) Provide API to express online serving logic that transforms one DataFrame into another DataFrame.
3) Provide API for users to periodically and dynamically update the model data used by a Servable without restarting it.
For example, a Flink ML training process may generate new model data continuously and sink it to Kafka. The inference framework can then read the model data from Kafka and update the in-memory model for better prediction results. Therefore, we need an API that supports continuously updating the model data.
4) Provide a utility tool that makes it easy for users to chain and compose a sequence of Servable instances into one Servable instance.
5) Provide API for users to instantiate a Servable from the files (i.e., metadata and model data) saved by a Transformer.
6) Provide API for serving infrastructure to tell whether the given Transformer instances specified by users (e.g. from a drag-and-drop front-end) can be chained and turned into Servables, without actually running any Flink job. Note that not every Transformer can be turned into a Servable.
Public Interfaces
We propose to make the following API changes to support the use-cases described above.
1) Add the following classes to express the input/output data of Flink ML API: Vector, DenseVector, SparseVector, Vectors, Matrix, and DenseMatrix
These changes address the 1st objective described above.
```java
/** A vector of double values. */
@PublicEvolving
public interface Vector extends Serializable {
    /** Gets the size of the vector. */
    int size();

    /** Gets the value of the ith element. */
    double get(int i);

    /** Sets the value of the ith element. */
    void set(int i, double value);

    /** Converts the instance to a double array. */
    double[] toArray();

    /** Converts the instance to a dense vector. */
    DenseVector toDense();

    /** Converts the instance to a sparse vector. */
    SparseVector toSparse();

    /** Makes a deep copy of the vector. */
    Vector clone();
}

/** A dense vector of double values. */
@PublicEvolving
public class DenseVector implements Vector {
    public final double[] values;

    public DenseVector(double[] values) {...}

    public DenseVector(int size) {...}
}

/** A sparse vector of double values. */
@PublicEvolving
public class SparseVector implements Vector {
    public final int n;
    public int[] indices;
    public double[] values;

    public SparseVector(int n, int[] indices, double[] values) {...}
}

/** Utility methods for instantiating Vector. */
@PublicEvolving
public class Vectors {
    /** Creates a dense vector from its values. */
    public static Vector dense(double... values) {...}

    /** Creates a sparse vector from its values. */
    public static Vector sparse(int size, int[] indices, double[] values) {...}
}

/** A matrix of double values. */
@PublicEvolving
public interface Matrix extends Serializable {
    /** Gets number of rows. */
    int numRows();

    /** Gets number of columns. */
    int numCols();

    /** Gets value of the (i,j) element. */
    double get(int i, int j);

    /** Converts the instance to a dense matrix. */
    DenseMatrix toDense();
}

/**
 * Column-major dense matrix. The entry values are stored in a single array of doubles with columns
 * listed in sequence.
 */
@PublicEvolving
public class DenseMatrix implements Matrix {
    /**
     * Array for internal storage of elements.
     *
     * <p>The matrix data is stored in column major format internally.
     */
    public final double[] values;

    /**
     * Constructs an m-by-n matrix of zeros.
     *
     * @param numRows Number of rows.
     * @param numCols Number of columns.
     */
    public DenseMatrix(int numRows, int numCols) {...}

    /**
     * Constructs a matrix from a 1-D array. The data in the array should be organized in column
     * major.
     *
     * @param numRows Number of rows.
     * @param numCols Number of cols.
     * @param values One-dimensional array of doubles.
     */
    public DenseMatrix(int numRows, int numCols, double[] values) {...}
}
```
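As an illustration, here is a minimal sketch of how these classes might be used, assuming only the classes proposed above:

```java
// Creates a dense vector [1.0, 0.0, 3.0].
Vector dense = Vectors.dense(1.0, 0.0, 3.0);

// Creates the equivalent sparse vector of size 3, with non-zero values at indices 0 and 2.
Vector sparse = Vectors.sparse(3, new int[] {0, 2}, new double[] {1.0, 3.0});

// Both representations expose the same element-access API.
double value = sparse.get(2); // 3.0
double[] values = sparse.toArray(); // [1.0, 0.0, 3.0]

// Conversions between the two representations.
DenseVector densified = sparse.toDense();
SparseVector sparsified = dense.toSparse();
```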
2) Add the following classes: BasicType, DataType, ScalarType, VectorType, and MatrixType.
These classes are needed to represent the type of data contained in a DataFrame (see below) so that users can process the DataFrame properly.
```java
/** This enum class lists primitive types such as boolean, int, long, etc. */
@PublicEvolving
public enum BasicType {
    BOOLEAN,
    BYTE,
    SHORT,
    INT,
    LONG,
    FLOAT,
    DOUBLE,
    STRING,
    BYTE_STRING;
}

/** This class describes the data type of a value. */
@PublicEvolving
public abstract class DataType {}

/** A DataType representing a single element of the given BasicType. */
@PublicEvolving
public final class ScalarType extends DataType {
    public ScalarType(BasicType elementType) {...}

    public BasicType getElementType() {...}
}

/** A DataType representing a Vector. */
@PublicEvolving
public final class VectorType extends DataType {
    public VectorType(BasicType elementType) {...}

    public BasicType getElementType() {...}
}

/** A DataType representing a Matrix. */
@PublicEvolving
public final class MatrixType extends DataType {
    public MatrixType(BasicType elementType) {...}

    public BasicType getElementType() {...}
}
```
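For example, a column of long values and a column of double-valued vectors would be typed as follows (a minimal sketch based only on the classes above):

```java
// The data type of a column holding long values (e.g. record ids).
DataType idType = new ScalarType(BasicType.LONG);

// The data type of a column holding vectors of doubles (e.g. feature vectors).
DataType featureType = new VectorType(BasicType.DOUBLE);
```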
3) Add the TransformerServable interface, DataFrame class and Row class.
This change addresses the 2nd objective described above.
```java
/** Represents an ordered list of values. */
@PublicEvolving
public class Row {
    public Row(List<Object> values) {...}

    /** Returns the value at the given index. */
    public Object get(int index) {...}

    /** Returns the value at the given index as the given type. */
    @SuppressWarnings("unchecked")
    public <T> T getAs(int index) {...}

    /** Adds the value to the end of this row and returns this row. */
    public Row add(Object value) {...}

    /** Returns the number of values in this row. */
    public int size() {...}
}
```
```java
/**
 * A DataFrame consists of some number of rows, each of which has the same list of column names and
 * data types.
 *
 * <p>All values in a column must have the same data type: integer, float, string etc.
 */
@PublicEvolving
public class DataFrame {

    public DataFrame(List<String> columnNames, List<DataType> dataTypes, List<Row> rows) {...}

    /** Returns a list of the names of all the columns in this DataFrame. */
    public List<String> getColumnNames() {...}

    /**
     * Returns the index of the column with the given name.
     *
     * @throws IllegalArgumentException if the column is not present in this table
     */
    public int getIndex(String name) {...}

    /**
     * Returns the data type of the column with the given name.
     *
     * @throws IllegalArgumentException if the column is not present in this table
     */
    public DataType getDataType(String name) {...}

    /**
     * Adds to this DataFrame a column with the given name, data type, and values.
     *
     * @throws IllegalArgumentException if the number of values is different from the number of
     *     rows.
     */
    public DataFrame addColumn(String columnName, DataType dataType, List<Object> values) {...}

    /** Returns all rows of this table. */
    public List<Row> collect() {...}
}
```
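For illustration, the following sketch constructs a two-column DataFrame and reads a value back, assuming only the classes proposed in this FLIP:

```java
// Constructs a DataFrame with an "id" column of longs and a "features" column of vectors.
DataFrame df =
        new DataFrame(
                Arrays.asList("id", "features"),
                Arrays.asList(
                        new ScalarType(BasicType.LONG), new VectorType(BasicType.DOUBLE)),
                Arrays.asList(
                        new Row(Arrays.asList(1L, Vectors.dense(0.5, 1.5))),
                        new Row(Arrays.asList(2L, Vectors.dense(2.5, 3.5)))));

// Looks up a column index by name and reads a typed value from a row.
int featuresIndex = df.getIndex("features");
Vector features = df.collect().get(0).getAs(featuresIndex);
```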
```java
/**
 * A TransformerServable takes a DataFrame as input and produces a DataFrame as the result. It can
 * be used to encode online inference computation logic.
 *
 * <p>NOTE: Every TransformerServable subclass should have a no-arg constructor.
 *
 * <p>NOTE: Every TransformerServable subclass should implement a static method with signature
 * {@code static T load(String path) throws IOException;}, where {@code T} refers to the concrete
 * subclass. This static method should instantiate a new TransformerServable instance based on the
 * data read from the given path.
 *
 * @param <T> The class type of the TransformerServable implementation itself.
 */
@PublicEvolving
public interface TransformerServable<T extends TransformerServable<T>> extends WithParams<T> {
    /**
     * Applies the TransformerServable on the given input DataFrame and returns the result
     * DataFrame.
     *
     * @param input the input data
     * @return the result data
     */
    DataFrame transform(DataFrame input);
}
```
4) Add the ModelServable interface.
This interface addresses the 3rd objective described above.
```java
/**
 * A ModelServable is a TransformerServable with the extra API to set model data.
 *
 * @param <T> The class type of the ModelServable implementation itself.
 */
@PublicEvolving
public interface ModelServable<T extends ModelServable<T>> extends TransformerServable<T> {

    /** Sets model data using the serialized model data from the given input stream. */
    default T setModelData(InputStream... modelData) throws IOException {
        throw new UnsupportedOperationException("this operation is not supported");
    }
}
```
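For example, a servable's model data could be refreshed from a local file as follows (a sketch; LogisticRegressionModelServable and the paths are hypothetical names used only for illustration):

```java
// Loads a hypothetical servable and refreshes its model data from a local file.
LogisticRegressionModelServable servable = LogisticRegressionModelServable.load("/models/lr");
try (InputStream modelData = new FileInputStream("/models/lr/model-data-v2")) {
    servable.setModelData(modelData);
}
```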
5) Add the PipelineModelServable class
This class addresses the 4th objective described above.
```java
/**
 * A PipelineModelServable acts as a ModelServable. It consists of an ordered list of
 * TransformerServable, each of which could be a TransformerServable or ModelServable.
 */
@PublicEvolving
public final class PipelineModelServable implements ModelServable<PipelineModelServable> {

    public PipelineModelServable(List<TransformerServable<?>> servables) {...}

    @Override
    public DataFrame transform(DataFrame input) {...}

    @Override
    public PipelineModelServable setModelData(InputStream... modelData) throws IOException {...}

    @Override
    public Map<Param<?>, Object> getParamMap() {...}

    public static PipelineModelServable load(String path) throws IOException {...}
}
```
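For illustration, here is a hedged sketch of composing two servables into one. StandardScalerServable and LogisticRegressionModelServable are hypothetical class names, not part of this FLIP:

```java
// Loads two hypothetical servables saved by the corresponding Transformers.
TransformerServable<?> scaler = StandardScalerServable.load("/models/scaler");
TransformerServable<?> model = LogisticRegressionModelServable.load("/models/lr");

// Chains them into a single servable that applies the stages in order.
PipelineModelServable pipeline = new PipelineModelServable(Arrays.asList(scaler, model));

// 'inputDf' is a DataFrame constructed as shown in the earlier sketch.
DataFrame predictions = pipeline.transform(inputDf);
```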
6) Update the Javadoc of the Transformer interface to state that if a Transformer has a corresponding TransformerServable, it should implement the static method `loadServable(...)`.
This change addresses the 5th objective described above.
The `loadServable(...)` method tells users whether the Transformer can be turned into a Servable and what the corresponding servable class is. It also simplifies the loading of the servable instance.
```java
/**
 * ...
 *
 * <p>NOTE: If a Transformer has a corresponding {@link TransformerServable}, it should implement
 * a static method with signature {@code static T loadServable(String path)}, where {@code T}
 * refers to the concrete subclass of {@link TransformerServable}. This static method should
 * instantiate a new {@link TransformerServable} instance based on the data read from the given
 * path.
 *
 * ...
 */
@PublicEvolving
public interface Transformer<T extends Transformer<T>> extends AlgoOperator<T> {...}
```
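For example, a hypothetical model class with a corresponding servable would expose the method like this (class names are illustrative, and elided bodies follow the `{...}` convention used above):

```java
// A hypothetical Model whose inference logic can also run as a servable.
public class LogisticRegressionModel implements Model<LogisticRegressionModel> {
    ...

    public static LogisticRegressionModelServable loadServable(String path) throws IOException {
        return LogisticRegressionModelServable.load(path);
    }
}
```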
7) Add a utility method to PipelineModel and GraphModel to check whether all Transformers in the pipeline/graph can be turned into TransformerServable.
This change addresses the 6th objective described above.
```java
@PublicEvolving
public final class PipelineModel implements Model<PipelineModel> {
    /**
     * Whether all stages in the pipeline have corresponding {@link TransformerServable} so that
     * the PipelineModel can be turned into a TransformerServable and used in an online inference
     * program.
     *
     * @return true if all stages have corresponding TransformerServable, false if not.
     */
    public boolean supportServable() {...}
}

@PublicEvolving
public final class GraphModel implements Model<GraphModel> {
    /**
     * Whether all stages in the graph have corresponding {@link TransformerServable} so that the
     * GraphModel can be turned into a TransformerServable and used in an online inference
     * program.
     *
     * @return true if all stages have corresponding TransformerServable, false if not.
     */
    public boolean supportServable() {...}
}
```
Example Usage
In this section, we provide example code snippets to demonstrate how we can use the APIs proposed in this FLIP to address the use-cases in the motivation section.
Here is an online inference scenario:
We have an unbounded stream of labeled data that can be used for training.
We have one Flink ML Estimator that is trained using this unbounded stream of data, and we assume the accuracy of this Estimator increases with the amount of training data it has seen.
We would like to train this Estimator using the unbounded data stream on a Flink cluster, and then use the resulting model with up-to-date model data to do inference on 10 different web servers.
In order to address this use-case, we can write the training and inference logic with the following API behaviors:
Estimator::fit(...) generates an instance of Model.
Model::save(...) saves the metadata of this model to the filesystem.
ModelServableA::load(...) takes the file path as input and loads the file into a ModelServableA instance.
ModelServable::setModelData(...) takes an InputStream as input. Its implementation deserializes the bytes from this input stream and updates its model data accordingly.
ModelServable::transform(...) takes a DataFrame as input and returns a DataFrame. The returned DataFrame represents the inference results.
Here are the code snippets that address this use-case using the proposed APIs.
First, run the following code on the Flink cluster to generate the model and continuously write the latest model data to Kafka:
```java
void runTrainingOnFlinkCluster(...) {
    Table trainingStream = ...;

    Estimator estimatorA = new EstimatorA(...);
    Model modelA = estimatorA.fit(trainingStream);

    PipelineModel pipelineModel = new PipelineModel(Arrays.asList(modelA));
    if (!pipelineModel.supportServable()) {
        throw new RuntimeException("some stages can not be turned into servable");
    }

    Table modelData = modelA.getModelData()[0];
    // This method writes the data from the given Table to a Kafka topic.
    writeToKafka(modelData, "topicA");

    // Saves the model's state/metadata to a remote path.
    modelA.save(remotePath);

    // Executes the operators generated by Estimator::fit(...), which read from
    // trainingStream and write to modelData.
    env.execute();
}
```
Then run the following code on each web server to load the model data into a ModelServable for online inference. (Note: in production situations, synchronization techniques, like locks, should be used to coordinate the model update thread and transformation thread.)
```java
void runOnlineInferenceOnWebServer(...) {
    // Loads the model's state/metadata to generate a Servable instance.
    ModelServable servableA = ModelServableA.load(remotePath);

    // Creates a thread that keeps updating the servable with the latest model
    // data generated by the above code snippet.
    new Thread(
                    () -> {
                        try {
                            InputStream modelData = readFromKafka("topicA");
                            servableA.setModelData(modelData);
                        } catch (IOException e) {
                            throw new RuntimeException(e);
                        }
                    })
            .start();

    while (true) {
        HttpRequest request = getNextRequest();
        Json inputJson = request.getBody();

        // Constructs the input DataFrame from the request data.
        DataFrame inputDf =
                new DataFrame(
                        Arrays.asList("input"),
                        Arrays.asList(new ScalarType(BasicType.LONG)),
                        Arrays.asList(convertJsonToRow(inputJson)));

        // Transforms with ModelServableA.
        DataFrame outputDf = servableA.transform(inputDf);

        // Constructs the output Json object from the result DataFrame.
        Json outputJson =
                convertRowToJson(outputDf.getDataType("output"), outputDf.collect().get(0));
        request.setResponseBody(outputJson);
    }
}
```
Compatibility, Deprecation, and Migration Plan
This FLIP proposes a new feature in Flink ML. It is fully backward compatible.
Test Plan
We will provide unit tests to validate the proposed changes.
Rejected Alternatives
1) Let the ModelServable interface provide the API "<T> T setModelData(M modelData)" where M is the concrete ModelData class.
We chose not to use this approach because it requires users to write additional code to deserialize the bytes read from e.g. Kafka into the concrete model data instance. It also makes the API signature more complicated.