Status

Current state: Accepted

...

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

The goal of this FLIP is to enhance the scalability and ease of use of Flink ML. In machine learning there are mainly two groups of people. As shown in the following figure, the first group is MLlib developers, who need a set of standard, well-abstracted core ML APIs to implement algorithms; every ML algorithm is a concrete implementation on top of these APIs. The second group is MLlib users, who use the existing, packaged MLlib to train or serve a model. It is quite common that the entire training or inference job is constructed from a sequence of transformations or algorithms, so it is essential to provide a workflow/pipeline API for MLlib users such that they can easily combine multiple algorithms to describe the ML workflow/pipeline.

...

  • Provide a new set of ML core interfaces (on top of the Flink Table API)
  • Provide an ML pipeline interface (on top of the Flink Table API)
  • Provide interfaces for parameter management and pipeline persistence
  • All of the above interfaces should be able to accommodate any new ML algorithm. We will gradually add various standard ML algorithms on top of these newly proposed interfaces to validate their feasibility and scalability.

Proposed Changes

Major Concepts:

This section introduces the key interfaces proposed in this FLIP. Most of them are inspired by the Scikit-learn project.

...

Persistable: This interface is provided to save and restore Pipeline and PipelineStages.

ML core interfaces:

PipelineStage:

/**
 * Base class for a stage in a pipeline. The interface is only a concept, and does not have any
 * actual functionality. Its subclasses must be either Estimator or Transformer. No other classes
 * should inherit this interface directly.
 *
 * <p>Each pipeline stage carries its parameters and is persistable.
 *
 * @param <T> The class type of the PipelineStage implementation itself, used by {@link
 *            org.apache.flink.table.ml.api.helper.param.WithParams}
 * @see WithParams
 */
interface PipelineStage<T extends PipelineStage<T>> extends WithParams<T>, Persistable, Serializable {
}


Transformer:

/**
 * A Transformer is a {@link PipelineStage} that transforms an input {@link Table} to a result {@link Table}.
 *
 * @param <T> The class type of the Transformer implementation itself, used by
 *            {@link org.apache.flink.table.ml.api.helper.param.WithParams}
 */
public interface Transformer<T extends Transformer<T>> extends PipelineStage<T> {

    /**
     * Applies the transformer on the input table, and returns the result table.
     *
     * @param input the table to be transformed
     * @return the transformed table
     */
    Table transform(Table input);
}
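To make the intended usage concrete, below is a minimal, hypothetical sketch of a Transformer implementation. The class name, the select-expression logic, and the setter are illustrative assumptions rather than part of this FLIP; the Persistable methods are omitted in this sketch since the full Persistable interface is elided in this document.

// Hypothetical Transformer that forwards the input table through a user-supplied
// select expression using the Table API. Persistable methods omitted for brevity.
public class SelectTransformer implements Transformer<SelectTransformer> {

    private final Params params = new Params();
    private String selectExpr = "*";

    public SelectTransformer setSelectExpr(String expr) {
        this.selectExpr = expr;
        return this;
    }

    @Override
    public Table transform(Table input) {
        // Delegate the actual work to the Table API.
        return input.select(selectExpr);
    }

    @Override
    public Params getParams() {
        return params;
    }
}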


Model:

/**
 * A Model is an ordinary {@link Transformer} except for how it is created. While ordinary
 * transformers are defined by specifying the parameters directly, a model is usually
 * generated by an {@link Estimator} when {@link Estimator#fit(Table)} is invoked.
 *
 * <p>We separate Model from {@link Transformer} in order to support potential
 * model-specific logic, such as linking a Model to the {@link Estimator} from which
 * it was generated.
 *
 * @param <M> The class type of the Model implementation itself,
 *            used by {@link org.apache.flink.table.ml.api.helper.param.WithParams}
 */
public interface Model<M extends Model<M>> extends Transformer<M> {
}



Estimator:

/**
 * Estimators are {@link PipelineStage}s responsible for training and generating machine learning models.
 *
 * <p>Implementations are expected to take an input table as training samples and generate a {@link Model}
 * which fits these samples.
 *
 * @param <E> class type of the Estimator implementation itself,
 *            used by {@link org.apache.flink.table.ml.api.helper.param.WithParams}.
 * @param <M> class type of the {@link Model} this Estimator produces.
 */
public interface Estimator<E extends Estimator<E, M>, M extends Model<M>> extends PipelineStage<E> {

    /**
     * Trains and produces a {@link Model} which fits the records in the given {@link Table}.
     *
     * @param input the table with records to train the Model on.
     * @return a model trained to fit the given Table.
     */
    M fit(Table input);
}
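The following is a minimal, hypothetical Estimator/Model pair illustrating the fit-then-transform contract. The class names and the "learned mean" logic are illustrative assumptions only; Persistable methods are again omitted, and the two classes are shown together for brevity.

// Hypothetical Estimator that "learns" a constant from the training table and
// produces a Model that applies it at inference time.
class MeanEstimator implements Estimator<MeanEstimator, MeanModel> {

    private final Params params = new Params();

    @Override
    public MeanModel fit(Table input) {
        // A real implementation would compute the mean of a column from `input`,
        // e.g. via a Table API aggregation; a constant is used here as a placeholder.
        double learnedMean = 0.0;
        return new MeanModel(learnedMean);
    }

    @Override
    public Params getParams() { return params; }
}

class MeanModel implements Model<MeanModel> {

    private final Params params = new Params();
    private final double mean;

    MeanModel(double mean) { this.mean = mean; }

    @Override
    public Table transform(Table input) {
        // Hypothetical: center the "value" column by the learned mean.
        return input.select("value - " + mean + " as centered_value");
    }

    @Override
    public Params getParams() { return params; }
}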


ML pipeline:

An ML Pipeline is a linear workflow consisting of a sequence of PipelineStages. Each stage is either a Transformer (Model) or an Estimator. The input Table is updated as it passes through each stage: in Transformer stages, the transform() method is called on the Table; in Estimator stages, the fit() method is called to produce a Model, and the transform() method of the returned Model is called on the new input Table during inference. If a pipeline contains an Estimator or a Model, we call it an Estimator pipeline or a Model pipeline, respectively. Otherwise, it is a Transformer pipeline.

...

The above figure shows a pipeline with two stages. The first is a Transformer and the second is an Estimator, so the entire pipeline is an Estimator pipeline (because it ends with an Estimator stage). During the training step, the Pipeline.fit() method is called on the original input table (input1). In the Transformer stage, the transform() method converts the input table (input1) into a new output table (output1). In the Estimator stage, the fit() method is called to produce a Model (a Model is a special Transformer whose params are trained by an Estimator). After the Estimator pipeline's fit() method is executed, it returns a Model pipeline, which has the same number of stages as the Estimator pipeline, but with the Estimator replaced by the trained Model. This Model pipeline is then used in the inference step. When the Model pipeline's transform() method is called on a test input table (input2), the data are passed through the entire Model pipeline: in each stage, the transform() method is called to convert the table and pass it to the next stage. Finally, the Model pipeline returns a result table after all Transformers and Models have executed their transform() methods.
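In code, the flow just described boils down to the following minimal sketch, where myTransformer, myEstimator, input1 and input2 are hypothetical placeholders:

Pipeline pipeline = new Pipeline();
pipeline.appendStage(myTransformer)   // Transformer stage: transform() will be applied
        .appendStage(myEstimator);    // Estimator stage: fit() will produce a Model

// Training: fit() returns a Model pipeline with the same number of stages,
// with the Estimator replaced by the trained Model.
Pipeline modelPipeline = pipeline.fit(input1);

// Inference: input2 passes through every stage's transform().
Table result = modelPipeline.transform(input2);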

Pipeline:


/**
 * A Pipeline is a linear workflow which chains {@link Estimator}s and {@link Transformer}s to execute an algorithm.
 *
 * <p>A pipeline itself can act either as an Estimator or as a Transformer, depending on the stages it includes. More
 * specifically:
 * <ul>
 *  <li>
 *      If a Pipeline has an {@link Estimator}, one needs to call {@link Pipeline#fit(Table)} before using the pipeline
 *      as a {@link Transformer}. In this case the Pipeline is an {@link Estimator} and produces a Pipeline as a
 *      {@link Model}.
 *  </li>
 *  <li>
 *      If a Pipeline has no {@link Estimator}, it is a {@link Transformer} and can be applied to a Table directly.
 *      In this case, {@link Pipeline#fit(Table)} simply returns the pipeline itself.
 *  </li>
 * </ul>
 *
 * <p>In addition, a pipeline can also be used as a {@link PipelineStage} in another pipeline, just like an
 * ordinary {@link Estimator} or {@link Transformer} as described above.
 */
public final class Pipeline implements Estimator<Pipeline, Pipeline>, Transformer<Pipeline>, Model<Pipeline> {

    ....(implementation details are skipped)....
}


Helper interfaces:

Params:

/**
 * The map-like container class for parameters. This class is provided to unify the interaction with parameters.
 */
public class Params implements Serializable {

    /**
     * Returns the value of the specific parameter, or the default value defined in the {@code info} if this Params
     * doesn't contain the parameter.
     *
     * @param info the info of the specific parameter, usually with a default value
     * @param <V>  the type of the specific parameter
     * @return the value of the specific parameter, or the default value defined in the {@code info} if this Params
     *         doesn't contain the parameter
     * @throws RuntimeException if this Params doesn't contain the specific parameter, while the parameter is not
     *                          optional and has no default value in the {@code info}
     */
    public <V> V get(ParamInfo<V> info) {......}

    /**
     * Sets the value of the specific parameter.
     *
     * @param info  the info of the specific parameter to set.
     * @param value the value to be set to the specific parameter.
     * @param <V>   the type of the specific parameter.
     * @return this Params, to enable chained invocations of {@code set}
     * @throws RuntimeException if the {@code info} has a validator and the {@code value} is evaluated as illegal
     *                          by the validator
     */
    public <V> Params set(ParamInfo<V> info, V value) {......}

    /**
     * Removes the specific parameter from this Params.
     *
     * @param info the info of the specific parameter to remove
     * @param <V>  the type of the specific parameter
     */
    public <V> void remove(ParamInfo<V> info) {......}

    /**
     * Creates and returns a deep clone of this Params.
     *
     * @return a deep clone of this Params
     */
    public Params clone() {......}

    /**
     * Returns a json containing all parameters in this Params. The json should be human-readable if possible.
     *
     * @return a json containing all parameters in this Params
     */
    public String toJson() {......}

    /**
     * Restores the parameters from the given json. After restoration the parameters should be exactly the same as
     * the ones that were serialized to the input json. The class mapping of the parameters in the json is required
     * because it is hard to directly restore a parameter of a user-defined type.
     * A parameter will be treated as a String if it doesn't exist in the {@code classMap}.
     *
     * @param json     the json String to restore from
     * @param classMap the classes of the parameters contained in the json
     */
    public void fromJson(String json, Map<String, Class<?>> classMap) {......}
}
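For illustration, a hypothetical usage of Params could look as follows. MAX_ITER is an assumed ParamInfo<Integer> constant defined elsewhere (its construction is not shown in this FLIP), and the json key "maxIter" is likewise an assumption about how parameters are named in the serialized form; java.util.Collections is used for the class map.

// Hypothetical usage; MAX_ITER is an assumed ParamInfo<Integer> defined elsewhere.
Params params = new Params();
params.set(MAX_ITER, 100);              // validated by MAX_ITER's validator, if any
int maxIter = params.get(MAX_ITER);     // 100, or MAX_ITER's default value if absent

String json = params.toJson();          // human-readable snapshot of all parameters

Params restored = new Params();
// The class map tells fromJson how to deserialize non-String parameter values;
// the key "maxIter" is assumed to be the parameter's name.
restored.fromJson(json, Collections.singletonMap("maxIter", (Class<?>) Integer.class));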

ParamInfo:


/**
 * Definition of a parameter, including name, type, default value, validator and so on.

...

    boolean validate(V value);
}

WithParams:


/**
 * Parameters are widely used in machine learning algorithms. This interface defines a common way to interact
 * with classes that carry parameters.
 *
 * @param <T> the actual type of this WithParams, used as the return type of the setter
 */
public interface WithParams<T> {

    /**
     * Returns all the parameters.
     *
     * @return all the parameters.
     */
    Params getParams();

    /**
     * Sets the value of a specific parameter.
     *
     * @param info  the info of the specific param to set
     * @param value the value to be set to the specific param
     * @param <V>   the type of the specific param
     * @return the WithParams itself
     */
    default <V> T set(ParamInfo<V> info, V value) {
        getParams().set(info, value);
        return (T) this;
    }

    /**
     * Returns the value of the specific param.
     *
     * @param info the info of the specific param, usually with a default value
     * @param <V>  the type of the specific param
     * @return the value of the specific param, or the default value defined in the {@code info} if the inner Params
     *         doesn't contain this param
     */
    default <V> V get(ParamInfo<V> info) {
        return getParams().get(info);
    }
}
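As a usage sketch, algorithm-specific parameters can be exposed as typed setters and getters built on the default set()/get() above. The mixin interface below is purely illustrative; the ParamInfo construction API is not spelled out in this FLIP, so a placeholder is used.

// Illustrative mixin interface exposing a typed setter/getter via WithParams.
public interface WithLearningRate<T extends WithParams<T>> extends WithParams<T> {

    // Placeholder: a real ParamInfo<Double> would be constructed here with a name,
    // description, default value and validator (construction API not shown in this FLIP).
    ParamInfo<Double> LEARNING_RATE = null;

    default T setLearningRate(double value) {
        return set(LEARNING_RATE, value);
    }

    default double getLearningRate() {
        return get(LEARNING_RATE);
    }
}

A concrete stage would then simply declare, for example, implements Transformer<MyAlgo>, WithLearningRate<MyAlgo> to inherit the typed setter without re-implementing parameter plumbing.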


Persistable:

/**
 * An interface to allow PipelineStage persistence and reload. As of now, we are using JSON as the format.

...

    void loadJson(String json);
}
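A minimal save/restore sketch is shown below. It assumes Persistable also exposes a toJson() counterpart to loadJson() (the full interface is elided above), and pipeline, trainingTable, servingTable and modelPath are hypothetical placeholders; file handling uses plain java.nio.

import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

// Hypothetical persistence round-trip of a trained (Model) pipeline.
Pipeline trained = pipeline.fit(trainingTable);
String json = trained.toJson();   // assumed counterpart of loadJson()
Files.write(Paths.get(modelPath), json.getBytes(StandardCharsets.UTF_8));

// Later, or in a separate serving job:
Pipeline restored = new Pipeline();
restored.loadJson(new String(Files.readAllBytes(Paths.get(modelPath)), StandardCharsets.UTF_8));
Table result = restored.transform(servingTable);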

Examples:

In this section, we illustrate how an ML pipeline works with a simple example. The figure above shows the usage of a pipeline in the training step as well as the inference step. In this case, the pipeline has three stages. The first two stages (Bucketize and Connect) are Transformers, and the third stage (Linear Regression) is an Estimator. For the training step, since Linear Regression is an Estimator, the pipeline is an Estimator pipeline. The pipeline's fit() method is called to produce a LinearRegression Model (a special Transformer whose params are trained by an Estimator), and the resulting pipeline therefore becomes a fitted Model pipeline. This Model pipeline can be persisted and used for inference. During the inference step, when the transform() method is called on the new input table (servingTable), the input table is passed through the entire Model pipeline: the transform() method of each stage updates the table and passes the resulting table to the next stage. The corresponding test code for this example is shown below:

val inputFields = Array("gender", "age", "income", "label")
val inputTypes = Array[DataType](DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE)

//prepare the data for training
val trainingTable = createInputTable(tEnv, generateDataList(), inputFields, inputTypes)

//create a bucketize transformer
val bucketize = new Bucketize().setInputCol("income").setBoundaries(Array[Double](1, 8, 20)).setOutputCol("income_rank")

//create a connect transformer, which connects all features into a double array as the input of lr
val connect = new Connect().setDim(3).setInputCols(Array("gender", "age", "income")).setOutputCol("data")

val lr = new LinearRegression().setFeatureCol("data").setLabelCol("label").setPredictionCol("pred").setDim(3).setMaxIter(1000).setInitLearningRate(0.001)

//initialize the pipeline
val pipeline = new Pipeline
pipeline.appendStage(bucketize).appendStage(connect).appendStage(lr)

//train the pipeline and return the model pipeline
val model = pipeline.fit(trainingTable)

//persist the model pipeline
saveStage(modelPath, model)

//prepare the data for serving
val servingTable = createInputTable(tEnv, generateDataList(), inputFields, inputTypes)

//serve the newly generated data with the model pipeline
val result1 = model.transform(servingTable)

//alternatively, the model pipeline can be reloaded from persistent storage
val restoredPipeline = loadStage[Pipeline](modelPath)
val result2 = restoredPipeline.transform(servingTable)

Compatibility, Deprecation, and Migration Plan

The newly proposed ML pipeline and ML lib interfaces are completely independent of the legacy flink-ml package, which is designed on top of the DataSet API. The changes proposed in this FLIP will be implemented in another package (flink-table-ml) in flink-libraries. Therefore, there are no compatibility problems.

The new TableAPI based Flink ML package will completely cover the functionality of the legacy Flink ML. The legacy flink-ml will be deprecated when we deprecate the DataSet API in the future (https://flink.apache.org/roadmap.html). Users who currently rely on the legacy flink-ml should consider building their applications on top of the new flink-table-ml. Since the ML libs in flink-table-ml will be a superset of the ones in flink-ml, the cost of this migration should be very low, almost negligible.

Rejected Alternatives

  • Regarding Pipeline, instead of making it a concrete class, we could alternatively define it as an interface and provide a default implementation. However, we think the abstraction and functionality of Pipeline are general enough that there is no need to allow users to override it with different implementations.

Implementation plan

Modules and Packages:

flink-ml (module under flink root)

Flink 1.9.0

  • complete the flink-ml-api module
  • implement several algorithms in the flink-ml-lib module
  • complete the integration test framework in the flink-ml-test module
  • keep adding ML libs depending on the contribution progress and the cut-off date of 1.9.0
  • try to collaborate with and integrate FLIP-23 (model serving)

Flink 1.10.0

  • Keep adding ML libs