Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).



[This FLIP proposal is a joint work between Yun Gao, Dong Lin and Zhipeng Zhang]

Motivation

Iteration is a basic building block of an ML library. It is required for training ML models in both offline and online settings. In general, two types of iteration are required:

  1. Bounded Iteration: Usually used in the offline case. The algorithm trains on a bounded dataset and updates the parameters over multiple rounds until convergence.
  2. Unbounded Iteration: Usually used in the online case. The algorithm trains on an unbounded dataset: it accumulates a mini-batch of data and then performs one update to the parameters.

Previously, Flink supported bounded iteration with the DataSet API and unbounded iteration with the DataStream API. However, since Flink aims to deprecate the DataSet API and the iteration support in the DataStream API is rather incomplete, we need to implement a new iteration library in the flink-ml repository to support these algorithms.

There is also a performance issue with the DataSet::iterate(...) API: in order to replay the user-provided data streams multiple times, it requires the runtime to always dump the user-provided data streams to disk. This introduces storage and disk I/O overhead even if the user's algorithm would prefer to cache those values in memory, possibly in a more compact format.

To address the issues described above and to make Flink ML usable for more iteration use-cases in the long run, this FLIP proposes to add several APIs to the flink-ml repository to achieve the following goals:

  • Provide a solution for all the iteration use-cases (see the use-case section below for more detail) supported by the existing APIs, without the issues described above.
  • Provide a solution for a few use-cases (e.g. bounded streams + async mode + per-round variable update) not supported by the existing APIs.
  • Decouple the iteration-related APIs from the Flink core runtime (by moving them to the flink-ml repo) so that we can keep the Flink core runtime as simple and maintainable as possible.
  • Provide an iteration API that does not enforce the disk I/O overhead described above, so that users can optimize an iterative algorithm for the best possible performance.

Target Use-cases

In general, an ML algorithm updates the model iteratively according to the data until the model converges. The target algorithms can be classified w.r.t. three properties: the boundedness of the input datasets, the amount of data required for each variable update, and the synchronization policy.

1) Algorithms have different needs for whether the input data streams should be bounded or unbounded. We classify those algorithms into online algorithms and offline algorithms as below.

  • For online training algorithms, the training samples will be unbounded streams of data. The corresponding iteration body should ingest these unbounded streams of data, read each value in each stream once, and update the machine learning model repeatedly in near real-time. The iteration will never terminate in this case. The algorithm should be executed as a streaming job.
  • For offline training algorithms, the training samples will be bounded streams of data. The corresponding iteration body should read these bounded data streams for an arbitrary number of rounds and update the machine learning model repeatedly until a termination criterion is met (e.g. a given number of rounds is reached or the model has converged). The algorithm should be executed as a batch job.

2) An algorithm may have additional requirements in how much data should be consumed each time before a subtask can update variables. There are two categories of choices here:

  • Per-batch variable update: The algorithm wants to update variables every time an arbitrary subset of the user-provided data streams (either bounded or unbounded) is processed.
  • Per-round variable update: The algorithm wants to update variables every time all data of the user-provided bounded data streams is processed.

In the machine learning domain, some algorithms allow users to configure a batch size, and the model will be updated every time each subtask processes a batch of data. Those algorithms fit into the first category. Such an algorithm can be either online or offline.

Other algorithms only update variables every time the entire data is consumed for one round. Those algorithms fit into the second category. And such an algorithm must be offline because, by this definition, the user-provided dataset must be bounded.
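The difference between the two categories can be illustrated with a minimal plain-Java sketch that is independent of any Flink API. The class UpdateGranularitySketch, its methods, and the toy "move w toward the mean" update rule are all hypothetical and chosen only for illustration.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical illustration of update granularity; not part of the proposed API. */
final class UpdateGranularitySketch {

    /** Moves w toward the mean of the given values by one gradient step. */
    static double step(double w, List<Double> values, double learningRate) {
        double gradient = 0.0;
        for (double v : values) {
            gradient += (v - w);
        }
        return w + learningRate * gradient / values.size();
    }

    /** Per-batch variable update: one update for every batchSize records. */
    static double perBatch(double w, List<Double> data, int batchSize, double learningRate) {
        for (int start = 0; start < data.size(); start += batchSize) {
            int end = Math.min(start + batchSize, data.size());
            w = step(w, data.subList(start, end), learningRate);
        }
        return w;
    }

    /** Per-round variable update: one update after the whole bounded dataset is processed. */
    static double perRound(double w, List<Double> data, double learningRate) {
        return step(w, data, learningRate);
    }

    public static void main(String[] args) {
        List<Double> data = Arrays.asList(1.0, 2.0, 3.0, 4.0);
        System.out.println("per-batch: " + perBatch(0.0, data, 2, 0.5));
        System.out.println("per-round: " + perRound(0.0, data, 0.5));
    }
}
```

With a batch size of 2, one pass over the four records performs two updates in the per-batch case and a single update in the per-round case, so the two variants generally end the pass with different variable values.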

3) Algorithms (either online or offline) have different needs for how their parallel subtasks are synchronized.

  • In the sync mode, parallel subtasks, which execute the iteration body, update the model variables in a coordinated manner. There exist global epochs, such that all subtasks fetch the shared model variable values at the beginning of an epoch, calculate model variable updates based on the fetched variable values, and update the model variable values at the end of this epoch. In other words, all subtasks read and update model variables in global lock steps.
  • In the async mode, each parallel subtask, which executes the iteration body, could read/update the shared model variables without waiting for variable updates from other subtasks. For example, a subtask could have updated model variables 10 times when another subtask has updated model variables only 3 times.

The sync mode is useful when an algorithm should be executed in a deterministic way to achieve the best possible accuracy, and the straggler issue (i.e. there is a subtask which is considerably slower than the others) does not slow down the algorithm execution too much. In comparison, the async mode is useful for algorithms which want to be parallelized and executed across many subtasks as fast as possible, without worrying about performance issues caused by stragglers, at the possible cost of reduced accuracy.
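The determinism argument can be made concrete with a small single-threaded simulation. This is only a sketch under simplifying assumptions: the class SyncVsAsyncSketch, the scalar model variable, and the toy update rule are hypothetical, and the arrival order of async updates is passed in explicitly instead of being produced by real concurrency.

```java
/** Hypothetical simulation of sync vs. async variable updates; not a Flink API. */
final class SyncVsAsyncSketch {

    /** The update a subtask computes from the variable value it has read. */
    static double localUpdate(double w, double target) {
        return 0.5 * (target - w);
    }

    /** Sync mode: every subtask reads the same snapshot; updates are applied at epoch end. */
    static double syncEpoch(double w, double[] targets) {
        double sum = 0.0;
        for (double target : targets) {
            sum += localUpdate(w, target);
        }
        return w + sum;
    }

    /** Async mode: each subtask reads the current value and applies its update immediately. */
    static double asyncPass(double w, double[] targets, int[] arrivalOrder) {
        for (int subtask : arrivalOrder) {
            w += localUpdate(w, targets[subtask]);
        }
        return w;
    }

    public static void main(String[] args) {
        double[] targets = {2.0, 4.0};
        System.out.println("sync:        " + syncEpoch(0.0, targets));
        System.out.println("async (0,1): " + asyncPass(0.0, targets, new int[] {0, 1}));
        System.out.println("async (1,0): " + asyncPass(0.0, targets, new int[] {1, 0}));
    }
}
```

In this toy setting the sync result does not depend on which subtask finishes first, while the async result changes with the arrival order, which is the accuracy/determinism trade-off described above.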

Based on the above dimensions, the algorithms could be classified into the following types:

| Type | Bounded / Unbounded | Data Granularity | Synchronization Pattern | Support in the existing APIs | Support in the proposed API | Examples |
|------|---------------------|------------------|-------------------------|------------------------------|-----------------------------|----------|
| Non-SGD-based | Bounded | Epoch | Mostly Synchronous | Yes | Yes | Offline K-Means, Apriori, Decision Tree |
| SGD-Based Synchronous Offline algorithm | Bounded | Batch → Epoch* | Synchronous | Yes | Yes | Linear Regression, Logistic Regression, Deep Learning algorithms |
| SGD-Based Asynchronous Offline algorithm | Bounded | Batch → Epoch* | Asynchronous | No | Yes | Same as the above |
| SGD-Based Synchronous Online algorithm | Unbounded | Batch | Synchronous | Yes | Yes | Online version of the above algorithm |
| SGD-Based Asynchronous Online algorithm | Unbounded | Batch | Asynchronous | No | Yes | Online version of the above algorithm |

*Although SGD-based algorithms are also batch-based, they could be implemented with an Epoch-based method if intermediate state is allowed: the subtasks could sample a batch from all the records from the position of the last batch. 

Based on the above classification and the replacement implementation for SGD-based algorithms with bounded dataset, we mainly need to support

  1. The synchronous / asynchronous epoch-based algorithms on the bounded dataset.
  2. The synchronous / asynchronous batch-based algorithms on the unbounded dataset. 

Overview of the Iteration Paradigm

Based on the types of algorithms, we explain the iteration paradigm that has motivated our choices of the proposed APIs.

An iterative algorithm has the following behavior pattern:

  • The iterative algorithm has an iteration body that is repeatedly invoked until some termination criterion is reached (e.g. after a user-specified number of epochs has been reached). An iteration body is a subgraph of operators that implements the computation logic of e.g. an iterative machine learning algorithm, whose outputs might be fed back as the inputs of this subgraph.
  • In each invocation, the iteration body updates the model parameters based on the user-provided data as well as the most recent model parameters.
  • The iterative algorithm takes as inputs the user-provided data and the initial model parameters.
  • The iterative algorithm could output arbitrary user-defined information, such as the loss after each epoch, or the final model parameters. 

Therefore, the behavior of an iterative algorithm could be characterized with the following iteration paradigm (w.r.t. Flink concepts):

  • An iteration-body is a Flink subgraph with the following inputs and outputs:
    • Inputs: model-variables (as a list of data streams) and user-provided-data (as another list of data streams)
    • Outputs: feedback-model-variables (as a list of data streams) and user-observed-outputs (as a list of data streams)
  • A termination-condition that specifies when the iterative execution of the iteration body should terminate.
  • In order to execute an iteration body, a user needs to invoke it with the following inputs, and gets the following outputs.
    • Inputs: initial-model-variables (as a list of bounded data streams) and user-provided-data (as a list of data streams)
    • Outputs: the user-observed-output emitted by the iteration body.

It is important to note that users typically should not invoke IterationBody::process directly, because the model-variables expected by the iteration body are not the same as the initial-model-variables provided by the user. Instead, model-variables are computed as the union of the feedback-model-variables (emitted by the iteration body) and the initial-model-variables (provided by the caller of the iteration body). To relieve users from creating this union operator, we have added a utility class (see Iterations) to run an iteration-body with the user-provided inputs.
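The union of initial and feedback variables can be sketched as a single-threaded loop in plain Java. This is a simplified model under stated assumptions, not how the Flink runtime executes an iteration: IterationParadigmSketch, Body and BodyResult are hypothetical names, the model variable is a single number, and termination is modeled as "nothing is fed back any more".

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;

/** Hypothetical single-threaded model of the paradigm; none of these names are Flink APIs. */
final class IterationParadigmSketch {

    /** What one invocation of the body emits: an optional feedback value and an output. */
    static final class BodyResult {
        final Double feedback; // null: nothing is fed back, so the loop drains and terminates
        final double output;

        BodyResult(Double feedback, double output) {
            this.feedback = feedback;
            this.output = output;
        }
    }

    interface Body {
        BodyResult process(double modelVariable, List<Double> userData);
    }

    static List<Double> iterate(double initialVariable, List<Double> userData, Body body) {
        // model-variables = initial-model-variables UNION feedback-model-variables
        Deque<Double> modelVariables = new ArrayDeque<>();
        modelVariables.add(initialVariable);

        List<Double> userObservedOutputs = new ArrayList<>();
        while (!modelVariables.isEmpty()) {
            BodyResult result = body.process(modelVariables.poll(), userData);
            userObservedOutputs.add(result.output);
            if (result.feedback != null) {
                modelVariables.add(result.feedback);
            }
        }
        return userObservedOutputs;
    }

    /** Moves the variable halfway toward the data mean until it is within 0.5 of the mean. */
    static List<Double> demo() {
        Body body = (variable, data) -> {
            double mean = data.stream().mapToDouble(Double::doubleValue).average().orElse(0.0);
            double updated = variable + 0.5 * (mean - variable);
            Double feedback = Math.abs(mean - updated) > 0.5 ? Double.valueOf(updated) : null;
            return new BodyResult(feedback, updated);
        };
        return iterate(0.0, Arrays.asList(4.0, 4.0), body);
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

The caller only supplies the initial variable and the data and only sees the outputs; the queue that merges initial and feedback values plays the role that the Iterations utility class takes over for the user.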

The figure below summarizes the iteration paradigm described above. The streams in the red color are inputs provided by the user to the iteration body, as well as outputs emitted by the iteration body to the user.

Public Interfaces

We propose to make the following API changes to support the iteration paradigm described above. 


1) Add the IterationBody interface.

This interface corresponds to the iteration-body with the inputs and outputs described in the iteration paradigm. This interface should be implemented by the developer of the algorithm.

Note that the IterationBody also outputs the termination criteria which corresponds to the termination-condition described in the iteration paradigm. This allows the algorithm developer to use a stream created inside the IterationBody as the terminationCriteria.

package org.apache.flink.iteration;

@PublicEvolving
public interface IterationBody {
    /**
     * This method creates the graph for the iteration body.
     *
     * See Iterations::iterateUnboundedStreams and Iterations::iterateBoundedStreamsUntilTermination for how the
     * iteration body can be executed and when execution of the corresponding graph should terminate.
     *
     * Required: the number of feedback variable streams returned by this method must equal the number of variable
     * streams given to this method.
     *
     * @param variableStreams the variable streams.
     * @param dataStreams the data streams.
     * @return an IterationBodyResult.
     */
    IterationBodyResult process(DataStreamList variableStreams, DataStreamList dataStreams);
}


2) Add IterationBodyResult class.

This is a helper class which contains the objects returned by the IterationBody::process(...).

package org.apache.flink.iteration;

/**
 * A helper class that contains the streams returned by the iteration body.
 */
class IterationBodyResult {
    /**
     * A list of feedback variable streams. These streams will only be used during the iteration execution and will
     * not be returned to the caller of the iteration body. It is assumed that the method which executes the
     * iteration body will feed the records of the feedback variable streams back to the corresponding input variable
     * streams.
     */
    DataStreamList feedbackVariableStreams;

    /**
     * A list of output streams. These streams will be returned to the caller of the methods that execute the
     * iteration body.
     */
    DataStreamList outputStreams;

    /**
     * An optional termination criteria stream. If this stream is not null, it will be used together with the
     * feedback variable streams to determine when the iteration should terminate.
     */
    @Nullable DataStream<?> terminationCriteria;
}


3) Add the IterationListener interface.

If a UDF (a.k.a. user-defined function) used inside the IterationBody implements this interface, the callbacks on this interface will be invoked when the corresponding events happen.

This interface allows users to achieve the following goals:
- Run an algorithm in sync mode, i.e. each subtask will wait for model parameters updates from all other subtasks before reading the aggregated model parameters and starting the next epoch of execution. As long as we could find a cut in the subgraph of the iteration body that all the operators in the cut only emit records in onEpochWatermarkIncremented,  the algorithm would be synchronous. The detailed proof could be found in the appendix. 
- Emit final output after the iteration terminates.


package org.apache.flink.iteration;

/**
 * The callbacks defined below will be invoked only if the operator instance which implements this interface is used
 * within an iteration body.
 */
@PublicEvolving
public interface IterationListener<T> {
    /**
     * This callback is invoked every time the epoch watermark of this operator increments. The initial epoch watermark
     * of an operator is 0.
     *
     * The epochWatermark is the maximum integer that meets this requirement: every record that arrives at the operator
     * going forward should have an epoch larger than the epochWatermark. See Java docs in Iterations for how epoch
     * is determined for records ingested into the iteration body and for records emitted by operators within the
     * iteration body.
     *
     * If all inputs are bounded, the maximum epoch of all records ingested into this operator is used as the
     * epochWatermark parameter for the last invocation of this callback.
     *
     * @param epochWatermark The incremented epoch watermark.
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onEpochWatermarkIncremented(int epochWatermark, Context context, Collector<T> collector);

    /**
     * This callback is invoked after the execution of the iteration body has terminated.
     *
     * See Java doc of methods in Iterations for the termination conditions.
     *
     * @param context A context that allows emitting side output. The context is only valid during the invocation of
     *                this method.
     * @param collector The collector for returning result values.
     */
    void onIterationTermination(Context context, Collector<T> collector);

    /**
     * Information available in an invocation of the callbacks defined in the IterationListener.
     */
    interface Context {
        /**
         * Emits a record to the side output identified by the {@link OutputTag}.
         *
         * @param outputTag the {@code OutputTag} that identifies the side output to emit to.
         * @param value The record to emit.
         */
        <X> void output(OutputTag<X> outputTag, X value);
    }
}


4) Add the Iterations class.

This class provides APIs to execute an iteration body with the user-provided inputs. It provides two APIs to run an iteration body, which differ in the input types they accept (e.g. bounded data streams vs. unbounded data streams) and in their data replay semantics (i.e. whether to replay the user-provided data streams).

Each of these APIs provides the functionality described in the iteration paradigm: it unions the feedback variable streams (returned by the iteration body) with the initial variable streams (provided by the user) and uses the merged streams as inputs to invoke IterationBody::process(...).


package org.apache.flink.iteration;


/**
 * A helper class to create iterations. To construct an iteration, Users are required to provide
 *
 * <ul>
 *   <li>initVariableStreams: the initial values of the variable data streams which would be updated
 *       in each round.
 *   <li>dataStreams: the other data streams used inside the iteration, but would not be updated.
 *   <li>iterationBody: specifies the subgraph to update the variable streams and the outputs.
 * </ul>
 *
 * <p>The iteration body will be invoked with two parameters: The first parameter is a list of input
 * variable streams, which are created as the union of the initial variable streams and the
 * corresponding feedback variable streams (returned by the iteration body); The second parameter is
 * the data streams given to this method.
 *
 * <p>During the execution of iteration body, each of the records involved in the iteration has an
 * epoch attached, which marks the progress of the iteration. The epoch is computed as:
 *
 * <ul>
 *   <li>All records in the initial variable streams and initial data streams have epoch = 0.
 *   <li>For any record emitted by this operator into a non-feedback stream, the epoch of this
 *       emitted record = the epoch of the input record that triggers this emission. If this record
 *       is emitted by onEpochWatermarkIncremented(), then the epoch of this record =
 *       epochWatermark.
 *   <li>For any record emitted by this operator into a feedback variable stream, the epoch of the
 *       emitted record = the epoch of the input record that triggers this emission + 1.
 * </ul>
 *
 * <p>The framework notifies operators and UDFs that implement {@link IterationListener} at the
 * end of each epoch.
 *
 * <p>The limitations on constructing the subgraph inside the iteration body are described in
 * {@link IterationBody}.
 *
 * <p>An example of the iteration is like:
 *
 * <pre>{@code
 * DataStreamList result = Iterations.iterateUnboundedStreams(
 *  DataStreamList.of(first, second),
 *  DataStreamList.of(third),
 *  (variableStreams, dataStreams) -> {
 *      ...
 *      return new IterationBodyResult(
 *          DataStreamList.of(firstFeedback, secondFeedback),
 *          DataStreamList.of(output));
 *  });
 *  result.<Integer>get(0).addSink(...);
 * }</pre>
 */
public class Iterations {

    /**
     * This method uses an iteration body to process records in possibly unbounded data streams. The
     * iteration would not terminate if at least one of its inputs is unbounded. Otherwise it will
     * terminate after all the inputs are terminated and no more records are iterating.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateUnboundedStreams(
            DataStreamList initVariableStreams, DataStreamList dataStreams, IterationBody body) {
        return null;
    }

    /**
     * This method uses an iteration body to process records in some bounded data streams
     * iteratively until a termination criteria is reached (e.g. the given number of rounds is
     * completed or no further variable update is needed). Because this method does not replay
     * records in the data streams, the iteration body needs to cache those records in order to
     * visit those records repeatedly.
     *
     * @param initVariableStreams The initial variable streams, which are merged with the feedback
     *     variable streams before being used as the 1st parameter to invoke the iteration body.
     * @param dataStreams The non-variable streams also referred to in the {@code body}.
     * @param config The config for the iteration, like whether to re-create the operator on each
     *     round.
     * @param body The computation logic which takes variable/data streams and returns
     *     feedback/output streams.
     * @return The list of output streams returned by the iteration body.
     */
    public static DataStreamList iterateBoundedStreamsUntilTermination(
            DataStreamList initVariableStreams,
            ReplayableDataStreamList dataStreams,
            IterationConfig config,
            IterationBody body) {
        return null;
    }
}  
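The epoch rules stated in the Javadoc of Iterations can be traced with a small plain-Java sketch. EpochRuleSketch and EpochRecord are hypothetical stand-ins and not classes of the proposed API; the sketch covers only records triggered by an input record and leaves out the case of records emitted from onEpochWatermarkIncremented().

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for records travelling through the iteration; not a Flink class. */
final class EpochRuleSketch {

    static final class EpochRecord {
        final double value;
        final int epoch;

        EpochRecord(double value, int epoch) {
            this.value = value;
            this.epoch = epoch;
        }
    }

    /** Rule 1: records in the initial variable and data streams have epoch 0. */
    static EpochRecord initial(double value) {
        return new EpochRecord(value, 0);
    }

    /** Rule 2: a record emitted into a non-feedback stream keeps the epoch of its trigger. */
    static EpochRecord emitToOutput(EpochRecord trigger, double value) {
        return new EpochRecord(value, trigger.epoch);
    }

    /** Rule 3: a record emitted into a feedback stream gets the trigger's epoch + 1. */
    static EpochRecord emitToFeedback(EpochRecord trigger, double value) {
        return new EpochRecord(value, trigger.epoch + 1);
    }

    /** Sends one variable record around the loop and returns the epochs of the output records. */
    static List<Integer> outputEpochs(int rounds) {
        List<Integer> epochs = new ArrayList<>();
        EpochRecord variable = initial(1.0);
        for (int i = 0; i < rounds; i++) {
            epochs.add(emitToOutput(variable, variable.value).epoch);
            variable = emitToFeedback(variable, variable.value * 2);
        }
        return epochs;
    }

    public static void main(String[] args) {
        System.out.println(outputEpochs(3));
    }
}
```

Each trip through the feedback edge raises the epoch by one, so the output produced in the n-th round (counting from zero) carries epoch n.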

5) Introduce the forEachRound utility method.

forEachRound allows users to specify a sub-graph that executes in per-round mode, namely, all the operators in it are re-created for each round.

public interface IterationBody {
    
    ....


    /**
     * @param inputs The inputs of the subgraph.
     * @param perRoundSubBody The computational logic that should be executed per round.
     * @return The output of the subgraph.
     */
    static DataStreamList forEachRound(
        DataStreamList inputs, PerRoundSubBody perRoundSubBody) {
        return null;
    }


    /** The sub-graph inside the iteration body that should be executed as per-round. */
    interface PerRoundSubBody {

        DataStreamList process(DataStreamList input);
    }
}


6) Add the DataStreamList and ReplayableDataStreamList classes.

DataStreamList is a helper class that contains a list of data streams with possibly different element types. ReplayableDataStreamList is a helper class that contains a list of data streams and records, for each of them, whether it needs to be replayed in each round.

package org.apache.flink.iteration;

public class DataStreamList {
    public static DataStreamList of(DataStream<?>... streams) {...}

    // Returns the number of data streams in this list.
    public int size() {...}

    // Returns the data stream at the given index in this list.
    public <T> DataStream<T> get(int index) {...}
}  

public class ReplayableDataStreamList {

    private final List<DataStream<?>> replayedDataStreams;

    private final List<DataStream<?>> nonReplayedStreams;

    private ReplayableDataStreamList(
            List<DataStream<?>> replayedDataStreams, List<DataStream<?>> nonReplayedStreams) {
        this.replayedDataStreams = replayedDataStreams;
        this.nonReplayedStreams = nonReplayedStreams;
    }

    public static ReplayedDataStreamList replay(DataStream<?>... dataStreams) {
        return new ReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    public static NonReplayedDataStreamList notReplay(DataStream<?>... dataStreams) {
        return new NonReplayedDataStreamList(Arrays.asList(dataStreams));
    }

    List<DataStream<?>> getReplayedDataStreams() {
        return Collections.unmodifiableList(replayedDataStreams);
    }

    List<DataStream<?>> getNonReplayedStreams() {
        return Collections.unmodifiableList(nonReplayedStreams);
    }

    // Public so that callers can chain andNotReplay(...) on the result of replay(...).
    public static class ReplayedDataStreamList extends ReplayableDataStreamList {

        public ReplayedDataStreamList(List<DataStream<?>> replayedDataStreams) {
            super(replayedDataStreams, Collections.emptyList());
        }

        public ReplayableDataStreamList andNotReplay(DataStream<?>... nonReplayedStreams) {
            return new ReplayableDataStreamList(
                    getReplayedDataStreams(), Arrays.asList(nonReplayedStreams));
        }
    }

    // Public so that callers can chain andReplay(...) on the result of notReplay(...).
    public static class NonReplayedDataStreamList extends ReplayableDataStreamList {

        public NonReplayedDataStreamList(List<DataStream<?>> nonReplayedDataStreams) {
            super(Collections.emptyList(), nonReplayedDataStreams);
        }

        public ReplayableDataStreamList andReplay(DataStream<?>... replayedStreams) {
            return new ReplayableDataStreamList(
                    Arrays.asList(replayedStreams), getNonReplayedStreams());
        }
    }
}

7) Introduce the IterationConfig

IterationConfig allows users to specify the config for each iteration. For now users could only specify the default operator lifecycle inside the iteration, but it ensures the forward compatibility if we have more options in the future. 

package org.apache.flink.iteration;

/** The config for an iteration. */
public class IterationConfig {

    private final OperatorLifeCycle operatorRoundMode;

    private IterationConfig(OperatorLifeCycle operatorRoundMode) {
        this.operatorRoundMode = operatorRoundMode;
    }

    public static IterationConfigBuilder newBuilder() {

        return new IterationConfigBuilder();
    }

    public static class IterationConfigBuilder {

        private OperatorLifeCycle operatorRoundMode = OperatorLifeCycle.ALL_ROUND;

        private IterationConfigBuilder() {}

        public IterationConfigBuilder setOperatorRoundMode(OperatorLifeCycle operatorRoundMode) {
            this.operatorRoundMode = operatorRoundMode;
            return this;
        }

        public IterationConfig build() {
            return new IterationConfig(operatorRoundMode);
        }
    }

    public enum OperatorLifeCycle {
        ALL_ROUND,

        PER_ROUND
    }
}


8) Deprecate the existing DataStream::iterate() and the DataStream::iterate(long maxWaitTimeMillis) methods.

We plan to remove both methods after the APIs added in this doc are ready for production use. This change is needed to decouple the iteration-related APIs from the Flink core runtime so that we can keep the Flink core runtime as simple and maintainable as possible.

Example Usage

This section shows how commonly used ML algorithms could be implemented with the iteration API.

Offline Algorithms

We would like to first show the usage of the bounded iteration with the linear regression case: the model is Y = XA, and we would like to acquire the best estimation of A with the SGD algorithm. To simplify we assume the parameters could be held in the memory of one task.

The job graph of the algorithm is shown in Figure 2: in each round, we use the latest parameters to calculate the update to the parameters: ΔA = ∑(Y - XA)X. To achieve this, the Parameters vertex broadcasts the latest parameters to the Train vertex. Each subtask of the Train vertex holds a part of the dataset. Following the spirit of SGD, it samples a small batch of training records and calculates the update with the above equation. Then the Reducer merges ΔA from all the train tasks and emits the reduced value to the Parameters node to update the parameters. The Reducer is required since the feedback streams must have the same parallelism as the corresponding initial streams.


Figure 2. The JobGraph for the offline training of the linear regression case.
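As a worked instance of the update rule ΔA = ∑(Y - XA)X used by each Train subtask, the following plain-Java sketch computes ΔA for one small batch. The class LinearRegressionUpdateSketch and its helper methods are hypothetical and unrelated to the ArrayUtils helpers assumed in the listing of this section.

```java
/** Hypothetical helper that computes the per-batch update for Y = XA; not a Flink API. */
final class LinearRegressionUpdateSketch {

    /** Inner product of a feature vector x and the parameter vector a. */
    static double dot(double[] x, double[] a) {
        double sum = 0.0;
        for (int j = 0; j < x.length; j++) {
            sum += x[j] * a[j];
        }
        return sum;
    }

    /** Computes deltaA = sum over the batch of (y - x.a) * x. */
    static double[] modelUpdate(double[][] xs, double[] ys, double[] a) {
        double[] delta = new double[a.length];
        for (int i = 0; i < xs.length; i++) {
            double residual = ys[i] - dot(xs[i], a);
            for (int j = 0; j < a.length; j++) {
                delta[j] += residual * xs[i][j];
            }
        }
        return delta;
    }

    public static void main(String[] args) {
        double[][] xs = {{1, 0}, {0, 1}, {1, 1}};
        double[] ys = {2, 3, 5};
        double[] delta = modelUpdate(xs, ys, new double[] {0, 0});
        System.out.println(delta[0] + ", " + delta[1]);
    }
}
```

For this batch and A = (0, 0) the residuals are 2, 3 and 5, which gives ΔA = (7, 8); at the exact solution A = (2, 3) all residuals vanish and ΔA is zero. A practical implementation would usually scale ΔA by a learning rate before applying it, which the formula above leaves out.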


We start with synchronous training. Synchronous training requires that the updates from all the subtasks of the Train vertex are merged before the next round of training starts. This can be achieved by emitting the next round of parameters only at the end of a round. The code is shown as follows:

public class SynchronousBoundedLinearRegression {
    private static final int N_DIM = 50;
    private static final int N_EPOCH = 5;
    private static final int N_BATCH_PER_EPOCH = 10;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("final_model") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams =
            Iterations.iterateBoundedStreamsUntilTermination(
                DataStreamList.of(initParameters),
                ReplayableDataStreamList.notReplay(dataset),
                IterationConfig.newBuilder()
                    .setOperatorRoundMode(IterationConfig.OperatorLifeCycle.ALL_ROUND)
                    .build(),
                (variableStreams, dataStreams) -> {
                    DataStream<double[]> parameterUpdates = variableStreams.get(0);
                    DataStream<Tuple2<double[], Double>> trainData = dataStreams.get(0);

                    SingleOutputStreamOperator<double[]> parameters =
                        parameterUpdates.process(new ParametersCacheFunction());
                    DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                        .broadcast()
                        .connect(trainData)
                        .process(new TrainFunction())
                        .setParallelism(10)
                        .process(new ReduceFunction())
                        .setParallelism(1);

                    return new IterationBodyResult(
                        DataStreamList.of(modelUpdate),
                        DataStreamList.of(parameters.getSideOutput(FINAL_MODEL_OUTPUT_TAG)));
                });

        // The final model is the first (and only) output stream of the iteration.
        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.print();
    }

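    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {

        private final double[] parameters = new double[N_DIM];

        // Context is qualified below because both ProcessFunction and IterationListener declare
        // a nested type named Context.
        @Override
        public void processElement(
                double[] update, ProcessFunction<double[], double[]>.Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // Emitting the parameters only at the end of a round keeps the training synchronous.
            if (epochWatermark < N_EPOCH * N_BATCH_PER_EPOCH) {
                collector.collect(parameters);
            }
        }

        @Override
        public void onIterationTermination(IterationListener.Context context, Collector<double[]> collector) {
            context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
        }
    }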

    public static class TrainFunction extends CoProcessFunction<double[], Tuple2<double[], Double>, double[]>
        implements IterationListener<double[]> {

        private final List<Tuple2<double[], Double>> dataset = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        // Supplies the round of the record that is currently being processed.
        private Supplier<int[]> recordRoundQuerier;

        public void setCurrentRecordRoundsQuerier(Supplier<int[]> querier) {
            this.recordRoundQuerier = querier;
        }

        @Override
        public void processElement1(
                double[] parameter,
                CoProcessFunction<double[], Tuple2<double[], Double>, double[]>.Context context,
                Collector<double[]> output) {
            int[] round = recordRoundQuerier.get();
            if (round[0] == 0) {
                // Defer the first update until epoch 0 has ended, so that the dataset is cached.
                firstRoundCachedParameter = parameter;
                return;
            }

            calculateModelUpdate(parameter, output);
        }

        @Override
        public void processElement2(
                Tuple2<double[], Double> trainSample,
                CoProcessFunction<double[], Tuple2<double[], Double>, double[]>.Context context,
                Collector<double[]> output) {
            dataset.add(trainSample);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            if (epochWatermark == 0) {
                calculateModelUpdate(firstRoundCachedParameter, collector);
                firstRoundCachedParameter = null;
            }
        }

        @Override
        public void onIterationTermination(IterationListener.Context context, Collector<double[]> collector) {}

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            // Suppose we have a util to sample a small batch from the cached dataset.
            List<Tuple2<double[], Double>> samples = sample(dataset);

            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : samples) {
                // ΔA = ∑(Y - XA)X, so the residual is Y - XA.
                double diff = record.f1 - ArrayUtils.muladd(record.f0, parameters);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
	
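    public static class ReduceFunction extends ProcessFunction<double[], double[]>
        implements IterationListener<double[]> {

        private double[] mergedValue = ArrayUtils.newArray(N_DIM);

        @Override
        public void processElement(
                double[] parameter, ProcessFunction<double[], double[]>.Context context, Collector<double[]> output) {
            mergedValue = ArrayUtils.add(mergedValue, parameter);
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark, IterationListener.Context context, Collector<double[]> collector) {
            // Emit the merged update once per round and reset the accumulator.
            collector.collect(mergedValue);
            mergedValue = ArrayUtils.newArray(N_DIM);
        }

        @Override
        public void onIterationTermination(IterationListener.Context context, Collector<double[]> collector) {}
    }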
}

Here we use a specialized reduce function that handles the boundary of each round. However, outside of an iteration the same reduction can already be expressed with windowAll().reduce(). To reuse such operators inside the iteration, we could use the forEachRound utility method:

    DataStreamList resultStreams =
        Iterations.iterateBoundedStreamsUntilTermination(
            DataStreamList.of(initParameters),
            ReplayableDataStreamList.notReplay(dataset),
            IterationConfig.newBuilder().setOperatorRoundMode(ALL_ROUND).build(),
            (variableStreams, dataStreams) -> {
                DataStream<double[]> parameterUpdates = variableStreams.get(0);
                DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

                SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction());
                DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                    .broadcast()
                    .connect(dataset)
                    .coProcess(new TrainFunction())
                    .setParallelism(10);

                // Reuse the regular windowAll().reduce() to merge the model updates of each round.
                DataStream<double[]> reduced = forEachRound(DataStreamList.of(modelUpdate), streams -> {
                    return streams.<double[]>get(0).windowAll().reduce((x, y) -> ArrayUtils.add(x, y));
                }).<double[]>get(0);

                // Feed the reduced update back to the Parameters vertex.
                return new IterationBodyResult(DataStreamList.of(reduced), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));
            });

This is very helpful in complex scenarios, since it lets us reuse the existing DataStream and Table operators inside the iteration.
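To make the intended per-round semantics concrete, the following is a minimal plain-Java sketch that does not depend on Flink. It assumes each model update carries the round that produced it and merges the updates of the same round into one array. The names PerRoundReduceSketch, RoundUpdate and reducePerRound are hypothetical and are not part of the proposed API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Plain-Java illustration (not Flink code) of reducing model updates round by round. */
final class PerRoundReduceSketch {

    /** A model update tagged with the round that produced it. Hypothetical type. */
    static final class RoundUpdate {
        final int round;
        final double[] values;

        RoundUpdate(int round, double[] values) {
            this.round = round;
            this.values = values;
        }
    }

    /** Element-wise sum of two equally sized arrays; returns a new array. */
    static double[] add(double[] x, double[] y) {
        double[] sum = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            sum[i] = x[i] + y[i];
        }
        return sum;
    }

    /** Merges all updates of the same round into one array; results are ordered by round. */
    static List<double[]> reducePerRound(List<RoundUpdate> updates) {
        Map<Integer, double[]> merged = new TreeMap<>();
        for (RoundUpdate update : updates) {
            // The first update of a round is stored as is; later ones are summed into it.
            merged.merge(update.round, update.values, PerRoundReduceSketch::add);
        }
        return new ArrayList<>(merged.values());
    }
}
```

In the actual proposal the round boundary is tracked by the iteration runtime rather than by a field on the record, so this sketch only mirrors the observable result of the reduction.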



If instead we want to do asynchronous training, we would need to make the following changes:

  1. The Parameters vertex would not wait until the end of the round to make sure it has received all the updates of that round. Instead, it would immediately output the current parameter values once it receives a model update from one Train subtask.
  2. To label the source of each update, we change the input type to Tuple2<Integer, double[]>. The Parameters vertex would only send the new parameter values to the Train task that sent the update.
  3. The Reducer would no longer merge the received updates. Instead, it would forward each received update directly.

We omit the changes to the graph building code since they are trivial (change the output type and use a customized partitioner). The change to the Parameters vertex is as follows:

public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>>
    implements BoundedIterationProgressListener<double[]> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        // Suppose we have a util to add the second array to the first.
        ArrayUtils.addWith(parameters, update.f1);
        output.collect(new Tuple2<>(update.f0, parameters));
    }

    public void onIterationEnd(int[] round, Context context) {
        context.output(FINAL_MODEL_OUTPUT_TAG, parameters);
    }
}

public class AsynchronousBoundedLinearRegression {
    
    ...

    DataStreamList resultStreams = 
        Iterations.iterateBoundedStreamsUntilTermination(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> {
            DataStream<Tuple2<Integer, double[]>> parameterUpdates = variableStreams.get(0);
            DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

            SingleOutputStreamOperator<Tuple2<Integer, double[]>> parameters = parameterUpdates.process(new ParametersCacheFunction());
            DataStream<Tuple2<Integer, double[]>> modelUpdate = parameters.setParallelism(1)
                .partitionCustom((key, numPartitions) -> key % numPartitions, update -> update.f0)
                .connect(dataset)
                .coProcess(new TrainFunction())
                .setParallelism(10);

            return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));
        });

    ...
}

Online Algorithms

Suppose now we change the algorithm to use unbounded iteration. Compared to the offline case, the differences are:

  1. The dataset is unbounded. The Train operator could not cache all the data in the first round.
  2. The training algorithm might be changed to others like FTRL. But we keep using SGD in this example since it does not affect showing the usage of the iteration.

We again start with the synchronous case. For online training, the Train vertex usually does one update after accumulating one mini-batch. This is to ensure the distribution of the samples is similar to the global statistics. In this example we omit the complex data re-sampling process and just fetch the next several records as one mini-batch.

The JobGraph for online training is still the one shown in Figure 1, except that the training dataset becomes unbounded. Similar to the bounded case, synchronous training is expected to proceed as follows:

  1. The Parameters vertex broadcasts the initial values once it receives them from its input.
  2. Each Train task reads the next mini-batch of records, calculates an update and emits it to the Parameters vertex. It then waits until it receives the updated parameters from the Parameters vertex before it proceeds to the next mini-batch.
  3. The Parameters vertex waits until it has received the updates from all the Train tasks before it broadcasts the updated parameters.

Since there is no concept of a round in the unbounded case and we update once per mini-batch, we could instead use the InputSelectable functionality to implement the algorithm:

public class SynchronousUnboundedLinearRegression {
    private static final int N_DIM = 50;
    private static final OutputTag<double[]> FINAL_MODEL_OUTPUT_TAG = new OutputTag<double[]>("model_update") {};

    public static void main(String[] args) {
        DataStream<double[]> initParameters = loadParameters().setParallelism(1);
        DataStream<Tuple2<double[], Double>> dataset = loadDataSet().setParallelism(1);

        DataStreamList resultStreams = Iterations.iterateUnboundedStreams(DataStreamList.of(initParameters), DataStreamList.of(dataset), (variableStreams, dataStreams) -> {
            DataStream<double[]> parameterUpdates = variableStreams.get(0);
            DataStream<Tuple2<double[], Double>> dataset = dataStreams.get(0);

            SingleOutputStreamOperator<double[]> parameters = parameterUpdates.process(new ParametersCacheFunction(10));
            DataStream<double[]> modelUpdate = parameters.setParallelism(1)
                .broadcast()
                .connect(dataset)
                .transform(
                    "operator",
                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
                    new TrainOperators(50))
                .setParallelism(10);
            return new IterationBodyResult(DataStreamList.of(modelUpdate), DataStreamList.of(parameters.getSideOut(FINAL_MODEL_OUTPUT_TAG)));
        });

        DataStream<double[]> finalModel = resultStreams.get(0);
        finalModel.addSink(...);
    }

    public static class ParametersCacheFunction extends ProcessFunction<double[], double[]> {  
        
        private final int numOfTrainTasks;

        private int numOfUpdatesReceived = 0;
        private final double[] parameters = new double[N_DIM];

        public ParametersCacheFunction(int numOfTrainTasks) {
            this.numOfTrainTasks = numOfTrainTasks;
        }

        public void processElement(double[] update, Context ctx, Collector<double[]> output) {
            // Suppose we have a util to add the second array to the first.
            ArrayUtils.addWith(parameters, update);
            numOfUpdatesReceived++;

            if (numOfUpdatesReceived == numOfTrainTasks) {
                output.collect(parameters);
                numOfUpdatesReceived = 0;
            }
        }
    }

    public static class TrainOperators extends AbstractStreamOperator<double[]> implements TwoInputStreamOperator<double[], Tuple2<double[], Double>, double[]>, InputSelectable {

        private final int miniBatchSize;

        private final List<Tuple2<double[], Double>> miniBatch = new ArrayList<>();
        private double[] firstRoundCachedParameter;

        public TrainOperators(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
        }

        public void processElement1(double[] parameter, Context context, Collector<double[]> output) {
            // New parameters arrived: compute the update for the accumulated mini-batch, then start a new one.
            calculateModelUpdate(parameter, output);
            miniBatch.clear();
        }

        public void processElement2(Tuple2<double[], Double> trainSample, Context context, Collector<double[]> output) {
            miniBatch.add(trainSample);
        }

        public InputSelection nextSelection() {
            if (miniBatch.size() < miniBatchSize) {
                return InputSelection.SECOND;
            } else {
                return InputSelection.FIRST;
            }
        }

        private void calculateModelUpdate(double[] parameters, Collector<double[]> output) {
            double[] modelUpdate = new double[N_DIM];
            for (Tuple2<double[], Double> record : miniBatch) {
                double diff = (ArrayUtils.muladd(record.f0, parameters) - record.f1);
                ArrayUtils.addWith(modelUpdate, ArrayUtils.multiply(record.f0, diff));
            }

            output.collect(modelUpdate);
        }
    }
}

Also similar to the bounded case, for asynchronous training the Parameters vertex would not wait to receive updates from all the Train tasks. Instead, it would respond directly to the task that sent the update:

public static class ParametersCacheFunction extends ProcessFunction<Tuple2<Integer, double[]>, Tuple2<Integer, double[]>> {  
    
    private final double[] parameters = new double[N_DIM];

    public void processElement(Tuple2<Integer, double[]> update, Context ctx, Collector<Tuple2<Integer, double[]>> output) {
        ArrayUtils.addWith(parameters, update.f1);

        if (update.f0 < 0) {
            // Received the initial parameter values: send them to all the downstream tasks.
            for (int i = 0; i < 10; ++i) {
                output.collect(new Tuple2<>(i, parameters));
            }
        } else {
            output.collect(new Tuple2<>(update.f0, parameters));
        }
    }
}
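The examples above call an ArrayUtils helper that this document never defines ("Suppose we have a util ..."). The following is one possible plain-Java sketch of that helper. The semantics are inferred from how the methods are used in the examples (for instance, muladd is assumed to be a dot product because its result is compared with the label), so treat the class name ArrayUtilsSketch and every method body as assumptions rather than the actual Flink ML utility.

```java
/** Hypothetical helper inferred from the examples above; not an existing Flink ML class. */
final class ArrayUtilsSketch {

    private ArrayUtilsSketch() {}

    /** Returns a zero-initialized array of the given size. */
    static double[] newArray(int size) {
        return new double[size];
    }

    /** Assumed meaning of muladd: the dot product of two equally sized arrays. */
    static double muladd(double[] x, double[] y) {
        double result = 0.0;
        for (int i = 0; i < x.length; i++) {
            result += x[i] * y[i];
        }
        return result;
    }

    /** Returns a new array holding every element of x scaled by factor. */
    static double[] multiply(double[] x, double factor) {
        double[] scaled = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            scaled[i] = x[i] * factor;
        }
        return scaled;
    }

    /** Adds other to target in place, element by element. */
    static void addWith(double[] target, double[] other) {
        for (int i = 0; i < target.length; i++) {
            target[i] += other[i];
        }
    }

    /** Returns a new array holding the element-wise sum; the inputs are left untouched. */
    static double[] add(double[] x, double[] y) {
        double[] sum = x.clone();
        addWith(sum, y);
        return sum;
    }
}
```

With these definitions, the loop body in calculateModelUpdate accumulates (x · w - y) * x for each sample (x, y), which is the gradient of the squared loss up to a constant factor.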

Proposed Changes

In this section, we discuss a few design choices related to the implementation and usage of the proposed APIs.

Support the Feedback Edges

1) How the feedback edge is supported.

The Flink core runtime can only execute a DAG of operators that does not involve cycles. Thus extra work needs to be done to support feedback edges (which effectively introduces cycles in the data flow).

Similar to the existing iterative API, this FLIP plans to implement the feedback edge using the following approach:

  • Automatically insert the HEAD and the TAIL operators as the first and the last operators in the iteration body.
  • Co-locate the HEAD and the TAIL operators on the same task manager.
  • Have the HEAD and the TAIL operators transmit the records of the feedback edges using an in-memory queue data structure. 

Since all the forward edges have limited buffers, the feedback queues must be unbounded in order to avoid deadlocks. To bound the memory footprint, records are spilled to disk once the queue size exceeds a threshold. To reduce such spilling, the HEAD operators read data in a "feedback-first" manner: if records are available from both the initial input and the feedback edge, the feedback records are always processed first.
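The "feedback-first" read order can be sketched in plain Java as follows. This is only an illustration under simplifying assumptions: both inputs are modelled as in-memory queues, and spilling is reduced to a check against a threshold. The class FeedbackFirstHeadSketch and its methods are hypothetical names, not part of the proposed implementation.

```java
import java.util.ArrayDeque;
import java.util.Queue;

/** Plain-Java illustration of a HEAD operator that prefers feedback records. */
final class FeedbackFirstHeadSketch<T> {

    private final Queue<T> initialInput = new ArrayDeque<>();
    private final Queue<T> feedback = new ArrayDeque<>();
    private final int spillThreshold;

    FeedbackFirstHeadSketch(int spillThreshold) {
        this.spillThreshold = spillThreshold;
    }

    /** Buffers a record that arrived from the initial (non-feedback) input. */
    void offerInitial(T record) {
        initialInput.add(record);
    }

    /** Buffers a record that arrived from the TAIL operator through the feedback queue. */
    void offerFeedback(T record) {
        feedback.add(record);
    }

    /** True once the feedback queue is larger than the threshold at which spilling would start. */
    boolean wouldSpill() {
        return feedback.size() > spillThreshold;
    }

    /** Returns the next record to process, feedback first, or null if both queues are empty. */
    T pollNext() {
        T next = feedback.poll();
        return next != null ? next : initialInput.poll();
    }
}
```

Draining the feedback queue before reading new initial input keeps that queue short, which is why this order reduces the chance of reaching the spill threshold.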


2) How the termination of the iteration execution is determined.

Having the feedback edges also complicates the termination detection of the job. Since the feedback edges are not visible in the JobGraph, the HEAD operators, as the first operators in the DAG of the iteration body, decide when the whole iteration body can be terminated and initiate the termination process. The termination happens when

  1. All the inputs to the iteration body have finished.
  2. AND
    1. If users have specified a reference stream, the number of records in each epoch is counted after the epoch is done, and if the count is 0, HEAD starts to terminate.
    2. OR if users have not specified a reference stream, no records are being processed inside the iteration body any more. This is detected by having a special event travel through the whole iteration body.
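The termination condition above can be written as a single predicate. The sketch below is plain Java and only restates the rule; how the HEAD operator obtains the inputs (in particular the in-flight detection via the special event) is not modelled. The class name and parameter names are hypothetical.

```java
/** Plain-Java restatement of the termination rule decided by the HEAD operators. */
final class TerminationCheckSketch {

    private TerminationCheckSketch() {}

    /**
     * @param allInputsFinished whether all the inputs to the iteration body have finished
     * @param hasReferenceStream whether the user specified a reference stream
     * @param referenceRecordsInLastEpoch records counted on the reference stream in the epoch that just ended
     * @param recordsInFlight whether any record is still being processed inside the iteration body
     */
    static boolean shouldTerminate(
            boolean allInputsFinished,
            boolean hasReferenceStream,
            long referenceRecordsInLastEpoch,
            boolean recordsInFlight) {
        if (!allInputsFinished) {
            return false;
        }
        if (hasReferenceStream) {
            // With a reference stream, an empty epoch signals termination.
            return referenceRecordsInLastEpoch == 0;
        }
        // Without a reference stream, terminate once the iteration body is drained.
        return !recordsInFlight;
    }
}
```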

The Execution Mode

Currently Flink provides two execution modes for bounded jobs, namely blocking mode and stream mode. The two modes differ in the edge types, scheduler policy and failover policy. For jobs with iterations, to reduce the implementation complexity of coordination and failover, we expect the whole iteration body to always run together. Thus, at least initially, we would not support blocking mode.

Users may have jobs that require pre-processing and post-processing. For such jobs, we currently expect users to split the whole pipeline into multiple jobs so that the other jobs can use blocking mode.

Dealing With Failure

Since currently we only support stream mode, we could rely on the checkpoint mechanism of Flink directly. The existing checkpoint mechanism does not support feedback edges, so we would enhance it by also snapshotting the feedback records as part of the checkpoint.
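The idea of snapshotting feedback records together with the operator state can be illustrated with a small plain-Java model. It is a sketch under strong simplifications: there is a single operator, its state is one parameter array, and the checkpoint is an in-memory object. None of the names below (FeedbackCheckpointSketch, Snapshot and so on) come from Flink.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

/** Plain-Java model of a checkpoint that also captures the queued feedback records. */
final class FeedbackCheckpointSketch {

    /** Immutable copy of the state plus the feedback records that were still queued. */
    static final class Snapshot {
        final double[] parameters;
        final List<double[]> pendingFeedback;

        Snapshot(double[] parameters, List<double[]> pendingFeedback) {
            this.parameters = parameters;
            this.pendingFeedback = pendingFeedback;
        }
    }

    private double[] parameters;
    private final Queue<double[]> feedbackQueue = new ArrayDeque<>();

    FeedbackCheckpointSketch(double[] initialParameters) {
        this.parameters = initialParameters.clone();
    }

    void enqueueFeedback(double[] update) {
        feedbackQueue.add(update.clone());
    }

    /** Applies the oldest queued update to the parameters; returns false if the queue is empty. */
    boolean processOneFeedback() {
        double[] update = feedbackQueue.poll();
        if (update == null) {
            return false;
        }
        for (int i = 0; i < parameters.length; i++) {
            parameters[i] += update[i];
        }
        return true;
    }

    /** Copies both the state and the queued feedback, so no in-flight update is lost. */
    Snapshot snapshot() {
        List<double[]> pending = new ArrayList<>();
        for (double[] record : feedbackQueue) {
            pending.add(record.clone());
        }
        return new Snapshot(parameters.clone(), pending);
    }

    /** Restores the state and re-queues the feedback records captured by the snapshot. */
    void restore(Snapshot snapshot) {
        parameters = snapshot.parameters.clone();
        feedbackQueue.clear();
        for (double[] record : snapshot.pendingFeedback) {
            feedbackQueue.add(record.clone());
        }
    }

    double[] currentParameters() {
        return parameters.clone();
    }

    int pendingFeedbackCount() {
        return feedbackQueue.size();
    }
}
```

If the snapshot contained only the parameters, an update that was queued on the feedback edge at checkpoint time would be lost after recovery. Capturing the queue avoids that.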

Compatibility, Deprecation, and Migration Plan

Deprecation plan

The following APIs will be deprecated and removed in a future Flink release:

  • The entire DataSet class. See FLIP-131 for its motivation and the migration plan. The deprecation of DataSet::iterate(...) proposed by this FLIP is covered by FLIP-131.
  • The DataStream::iterate(...) and DataStream::iterate(long).

Compatibility

  • This FLIP introduces backward incompatible changes by proposing to remove DataStream::iterate(...) in the future.
  • We expect the APIs proposed in this FLIP to address most of the use-cases supported by DataStream::iterate(...) and DataSet::iterate(...). The only use-cases for which we have dropped support are those that require nested iteration on bounded data streams in sync mode. We have made this choice because we are not aware of any reasonable use-cases that require nested iteration. This support can be added if any user provides a good use-case for nested iteration.

Migration plan

  • Users will need to re-write their application code in order to migrate from the existing iterative APIs to the proposed APIs.

Rejected Alternatives

Naiad has proposed a unified model for the watermark mechanism (namely progress tracking outside of the iteration) and the progress tracking inside the iteration. It extends the event time and watermark to a vector (long timestamp, int[] rounds) and implements a vectorized alignment algorithm. Although Naiad provides an elegant model, a direct implementation on Flink would require a large amount of modification to the Flink runtime, which would cause a lot of complexity and maintenance overhead. Thus we choose to implement a simplified version on top of Flink, as a part of the flink-ml library.

For building the iteration DAG, it would be simpler if we could directly refer to the data stream variables outside of the closure of the iteration body. However, since the iteration DAG must first be created in a mock execution environment, we cannot use these variables directly. Otherwise we would directly modify the real environment and would not have the chance to add wrappers to the operators.

The existing FLIP does not explicitly support nested iterations. This is because we have not seen a clear use-case that requires nested iteration. We would prefer to only introduce additional complexity that is required by some reasonable use-case. In the future, if we decide to support nested iterations, we will need extra APIs to run the iteration body, with the semantics that operators inside the iteration body will be re-created for every round of iteration.

Appendix

1) In the following, we prove that the proposed solution can be used to implement an iterative algorithm in the sync mode.

Refer to the "Proposed Changes" section for the definition of sync mode and the description of the solution. In the following, we prove that the solution does work as expected.

Proof

In the following, we will prove that the solution described above could enforce the sync-mode execution. Note that the calculation of the record's epoch and the semantics of onEpochWatermarkIncremented(...) are described in the Java doc of the corresponding APIs.

Lemma-1: For any operator OpB defined in the IterationBody, at the time its Nth invocation of onEpochWatermarkIncremented(...) starts, it is guaranteed that:

  • If an input edge is a non-feedback edge from OpA, then OpA's Nth invocation of onEpochWatermarkIncremented(...) has been completed.
  • If an input edge is a feedback edge from OpA, then OpA's (N-1)th invocation of onEpochWatermarkIncremented(...) has been completed.

Let's prove the lemma-1 by contradiction:

  • At the time the OpB's Nth invocation starts, its epoch watermark has incremented to N, which means OpB will no longer receive any record with epoch <= N.
  • Suppose there is a non-feedback edge from OpA AND OpA's Nth invocation has not been completed. Then when OpA's Nth invocation completes, OpA can generate a record with epoch=N and send it to OpB via this non-feedback edge, which contradicts the guarantee described above.
  • Suppose there is a feedback edge from OpA AND OpA's (N-1)th invocation has not been completed. Then when OpA's (N-1)th invocation completes, OpA can generate a record with epoch=N and send it to OpB via this feedback edge, which contradicts the guarantee described above.

Lemma-2: For any operator OpB defined in the IterationBody, at the time its Nth invocation of onEpochWatermarkIncremented(...) starts, it is guaranteed that:

  • If an edge is a non-feedback input edge from OpA and this edge is part of a feedback loop, then OpA's (N+1)th invocation of onEpochWatermarkIncremented(...) has not started.
  • If an edge is a feedback input edge from OpA and this edge is part of a feedback loop, then OpA's Nth invocation of onEpochWatermarkIncremented(...) has not started.

Let's prove this lemma by contradiction:

  • Suppose there is a non-feedback edge from OpA, this edge is part of a feedback loop, and OpA's (N+1)th invocation has started. Since this non-feedback edge is part of a feedback loop, there is a backward path from OpA to OpB with exactly 1 feedback edge on this path. By applying the lemma-1 recursively for operators on this path, we can tell that OpB's Nth invocation has been completed. This contradicts the assumption that OpB's Nth invocation just started.
  • Suppose there is a feedback edge from OpA, this edge is part of a feedback loop, and OpA's Nth invocation has started. Since this feedback edge is part of feedback loop, there is a backward path from OpA to OpB with no feedback edge on this path. By applying lemma-1 recursively for operators on this path, we can tell that OpB's Nth invocation has been completed. This contradicts the assumption that OpB's Nth invocation just started.

Let's now prove that the sync-mode is achieved:

  • For any operator in the IterationBody, we define its output for the Nth epoch as the output emitted by the Nth invocation of onEpochWatermarkIncremented(). This definition is well-defined because operators only emit records in onEpochWatermarkIncremented().
  • At the time an operator OpB computes its output for the Nth epoch, this operator must have received exactly the following records from its input edges:
    • Suppose an edge is a non-feedback input edge from OpA and this edge is part of a feedback loop. It follows that OpA has emitted records for its Nth epoch (by lemma-1) and has not started to emit records for its (N+1)th epoch (by lemma-2).
    • Suppose an edge is a feedback input edge from OpA and this edge is part of a feedback loop. It follows that OpA has emitted records for its (N-1)th epoch (by lemma-1) and has not started to emit records for its Nth epoch (by lemma-2).
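The proof relies on each operator invoking onEpochWatermarkIncremented(...) only after the corresponding epoch is complete on all of its inputs. The authoritative semantics are in the Java doc of the APIs. The plain-Java sketch below shows one way such alignment could work, under the assumption (made here by analogy with event-time watermarks) that an operator's epoch watermark is the minimum of the epoch watermarks received on its input channels. The class EpochWatermarkTrackerSketch is hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Plain-Java model of aligning epoch watermarks across the input channels of one operator. */
final class EpochWatermarkTrackerSketch {

    /** Latest epoch watermark seen per input channel; -1 means none received yet. */
    private final int[] channelWatermarks;

    /** Highest epoch for which the callback has already been invoked; -1 means none. */
    private int alignedWatermark = -1;

    /** Records the epochs for which the callback would have fired, in order. */
    private final List<Integer> notifications = new ArrayList<>();

    /** @param numChannels number of input channels, assumed to be at least 1 */
    EpochWatermarkTrackerSketch(int numChannels) {
        channelWatermarks = new int[numChannels];
        Arrays.fill(channelWatermarks, -1);
    }

    /** Handles an epoch watermark from one channel and fires the callbacks that became due. */
    void onChannelWatermark(int channel, int epochWatermark) {
        channelWatermarks[channel] = Math.max(channelWatermarks[channel], epochWatermark);
        // Assumption: the operator-level watermark is the minimum over all channels.
        int min = Arrays.stream(channelWatermarks).min().getAsInt();
        while (alignedWatermark < min) {
            alignedWatermark++;
            // Stands in for onEpochWatermarkIncremented(alignedWatermark, ...).
            notifications.add(alignedWatermark);
        }
    }

    List<Integer> notifications() {
        return notifications;
    }
}
```

Under this assumption, the callback for epoch N cannot start before every upstream operator has announced epoch N, which is the property lemma-1 states for non-feedback edges.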
