Discussion thread
Vote thread
JIRA: FLINK-16187

Release: 1.11

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

FLIP-39 rebuilds the Flink ML pipeline on top of the Table API and introduces a new set of Java APIs. As Python is widely used in the ML area, providing Python ML Pipeline APIs for Flink not only makes it easier for Python users to write ML jobs but also broadens the adoption of Flink ML. This document discusses the design of the Python ML Pipeline API.

Goals

  • Add a Python pipeline API that mirrors the Java pipeline API (the Python pipeline API will be adapted if the Java pipeline API changes).
  • Support native Python Transformer/Estimator/Model, i.e., users can not only write Python wrappers that call the Java Transformer/Estimator/Model implementations, but also write Transformers/Estimators/Models in pure Python.
  • Ease of use: support keyword arguments when setting parameters.


Proposed Changes

Java Pipeline API and the related Python API

To give a complete picture, the table below lists all key Java interfaces and the corresponding Python interfaces to be added.

| Interface type | Java Interface Name | Python Interface Name | Description |
| --- | --- | --- | --- |
| ML core interface | PipelineStage | PipelineStage | The base node of a Pipeline |
| | Transformer | Transformer | Native Python interface |
| | Transformer | JavaTransformer | Python wrapper for calling the Java interface |
| | Estimator | Estimator | Native Python interface |
| | Estimator | JavaEstimator | Python wrapper for calling the Java interface |
| | Model | Model | Native Python interface |
| | Model | JavaModel | Python wrapper for calling the Java interface |
| ML Pipeline | Pipeline | Pipeline | Describes an ML workflow |
| ML environment | MLEnvironment | MLEnvironment | Stores the necessary context in Flink |
| | MLEnvironmentFactory | MLEnvironmentFactory | Factory to get the MLEnvironment |
| Helper interface | Params | Params | A container of parameters |
| | ParamInfo | ParamInfo | Definition of a parameter |
| | WithParams | WithParams | Common interface for interacting with classes that have parameters |

Support native Python Transformer/Estimator/Model

A Transformer, Estimator, or Model can be regarded as an algorithm. By combining different Transformers, Estimators, and Models, users can build different machine learning jobs.

There are two ways to provide Python Transformers, Estimators, and Models. On one hand, we can provide Python Transformer/Estimator/Model classes that wrap existing Java ones to leverage the power of the Java APIs. On the other hand, there are scenarios where Python users want to write a native Python Transformer/Estimator/Model, and this should also be supported.

To support both cases, we provide two kinds of interfaces for Transformer/Estimator/Model: one for Java wrappers, the other for native Python implementations.

| Java Interface Name | Python Interface Name |
| --- | --- |
| Transformer | Transformer |
| Transformer | JavaTransformer |
| Estimator | Estimator |
| Estimator | JavaEstimator |
| Model | Model |
| Model | JavaModel |

Below, we take Transformer as an example to show what the interfaces look like and how the two kinds of interfaces are implemented.

class Transformer(PipelineStage, metaclass=ABCMeta):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    @abstractmethod
    def transform(self, table_env, table):
        raise NotImplementedError()


class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))

To write a native Python Transformer, you extend the Transformer class and implement the transform() method. For a Java wrapper, you extend the JavaTransformer class and pass the Java object to the constructor; in this case, you do not have to implement the transform() method because it is already implemented.
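For illustration, here is a minimal sketch of both styles, assuming the Transformer and JavaTransformer classes above. ColumnSelectorTransformer, MyJavaTransformer, and the Java class path used in the wrapper are hypothetical names chosen only to show the pattern.

from pyflink.java_gateway import get_gateway


class ColumnSelectorTransformer(Transformer):
    """A native Python Transformer that keeps only the configured columns."""

    def __init__(self, kept_cols):
        super().__init__()
        self._kept_cols = kept_cols

    def transform(self, table_env, table):
        # Select only the configured columns of the input table.
        return table.select(", ".join(self._kept_cols))


class MyJavaTransformer(JavaTransformer):
    """A wrapper that delegates transform() to an existing Java Transformer."""

    def __init__(self):
        # Hypothetical Java class; any Java Transformer on the classpath can be wrapped.
        j_obj = get_gateway().jvm.org.apache.flink.ml.examples.MyJavaTransformer()
        super().__init__(j_obj)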

Support keyword arguments

Parameters are widely used in machine learning algorithms. In the Java APIs, the builder pattern is used to set parameters:

// Java
KMeans kMeans = new KMeans()
  .setVectorCol("features")
  .setK(2)
  .setPredictionCol("prediction_result");

To align the APIs, the builder pattern will also be supported in Python. 

# Python, builder pattern
kmeans = KMeans() \
    .set_vector_col("features") \
    .set_k(2) \
    .set_prediction_col("prediction_result")

However, the builder pattern is not commonly used in Python because the language supports keyword arguments. To make it easier for Python users, we propose to also support keyword arguments in the constructor:

# Python, keyword arguments
kmeans = KMeans(vector_col="features", k=2, prediction_col="prediction_result")

Modules

Two more Python packages will be added under the current `pyflink` package, i.e., `ml` and `mllib`. Package `ml` holds the classes that align with the Java flink-ml-api module, and package `mllib` holds the classes that align with the Java flink-ml-lib module.

flink-python (Maven module)

  • pyflink (Python package)
    • ml
    • mllib
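As an illustration, user code could then import the two layers roughly as follows; the exact submodules and class locations are not fixed by this FLIP and are shown here only as an assumption.

# Hypothetical import paths under the proposed package layout.
from pyflink.ml import Pipeline        # aligns with the Java flink-ml-api module
from pyflink.mllib import KMeans       # aligns with the Java flink-ml-lib module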

As for components in FLINK JIRA, we can reuse the current `API/Python` component tag.

Public interfaces

The detailed design of all Python interfaces is listed below.

ML core interface

PipelineStage

class PipelineStage(WithParams):
    """
    Base class for a stage in a pipeline.
    """

    def __init__(self, params=None):
        if params is None:
            self._params = Params()
        else:
            self._params = params

    def get_params(self):
        return self._params


Transformer

class Transformer(PipelineStage, metaclass=ABCMeta):
    """
    A transformer is a PipelineStage that transforms an input Table to a result Table.
    """

    @abstractmethod
    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed.
        :returns: the transformed table.
        """
        raise NotImplementedError()


JavaTransformer

class JavaTransformer(Transformer):
    """
    Base class for :py:class:`Transformer`s that wrap Java implementations.
    Subclasses should ensure they have the transformer Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def transform(self, table_env, table):
        """
        Applies the transformer on the input table, and returns the result table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table to be transformed.
        :returns: the transformed table.
        """
        self._convert_params_to_java(self._j_obj)
        return Table(self._j_obj.transform(table_env._j_tenv, table._j_table))


Estimator

class Estimator(PipelineStage, metaclass=ABCMeta):
    """
    Estimators are PipelineStages responsible for training and generating machine learning models.

    The implementations are expected to take an input table as training samples and generate a
    Model which fits these samples.
    """

    @abstractmethod
    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        raise NotImplementedError()


JavaEstimator

class JavaEstimator(Estimator):
    """
    Base class for :py:class:`Estimator`s that wrap Java implementations.
    Subclasses should ensure they have the estimator Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__()
        self._j_obj = j_obj

    def fit(self, table_env, table):
        """
        Train and produce a Model which fits the records in the given Table.

        :param table_env: the table environment to which the input table is bound.
        :param table: the table with records to train the Model.
        :returns: a model trained to fit on the given Table.
        """
        self._convert_params_to_java(self._j_obj)
        return JavaModel(self._j_obj.fit(table_env._j_tenv, table._j_table))


Model

class Model(Transformer, metaclass=ABCMeta):
    """
    Abstract class for models that are fitted by estimators.

    A model is an ordinary Transformer except for how it is created. While ordinary transformers
    are defined by specifying the parameters directly, a model is usually generated by an Estimator
    when Estimator.fit(table_env, table) is invoked.
    """


JavaModel

class JavaModel(JavaTransformer, Model):
    """
    Base class for :py:class:`Model`s that wrap Java implementations.
    Subclasses should ensure they have the model Java object available as j_obj.
    """

    def __init__(self, j_obj):
        super().__init__(j_obj)
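To make the relationship between Estimator and Model concrete, here is a minimal sketch of a native Python pair, assuming the Estimator and Model classes above. The MeanShift names and the learned constant are hypothetical and only illustrate how fit() produces a Model.

class MeanShiftModel(Model):
    """A hypothetical Model that shifts column 'a' by a value learned during fit()."""

    def __init__(self, shift):
        super().__init__()
        self._shift = shift

    def transform(self, table_env, table):
        # Apply the learned shift to the input table.
        return table.select("a + %s as a" % self._shift)


class MeanShiftEstimator(Estimator):
    """A hypothetical Estimator whose fit() returns a MeanShiftModel."""

    def fit(self, table_env, table):
        # A real implementation would compute the shift from the training data;
        # a constant is used here to keep the sketch self-contained.
        return MeanShiftModel(shift=1.0)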

ML Pipeline

class Pipeline(Estimator, Model):
    """
    A pipeline is a linear workflow which chains Estimators and Transformers to
    execute an algorithm.
    """

    def __init__(self, stages=None):
        super().__init__()
        self.stages = []
        if stages is not None:
            self.stages = stages
        self.last_estimator_index = -1

    def _need_fit(self):
        return self.last_estimator_index >= 0

    @staticmethod
    def _is_stage_need_fit(stage):
        return (isinstance(stage, Pipeline) and stage._need_fit()) or \
               ((not isinstance(stage, Pipeline)) and isinstance(stage, Estimator))

    def append_stage(self, stage):
        if self._is_stage_need_fit(stage):
            self.last_estimator_index = len(self.stages)
        elif not isinstance(stage, Transformer):
            raise RuntimeError("All PipelineStages should be Estimator or Transformer!")
        self.stages.append(stage)
        return self

    def fit(self, t_env, input):
        """
        Train the pipeline to fit on the records in the given Table.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table with records to train the Pipeline.
        :returns: a pipeline with same stages as this Pipeline except all Estimators
                  replaced with their corresponding Models.
        """
        transform_stages = []
        for i in range(0, len(self.stages)):
            s = self.stages[i]
            if i <= self.last_estimator_index:
                need_fit = self._is_stage_need_fit(s)
                if need_fit:
                    t = s.fit(t_env, input)
                else:
                    t = s
                transform_stages.append(t)
                input = t.transform(t_env, input)
            else:
                transform_stages.append(s)
        return Pipeline(transform_stages)

    def transform(self, t_env, input):
        """
        Generate a result table by applying all the stages in this pipeline to the input table in order.

        :param t_env: the table environment to which the input table is bound.
        :param input: the table to be transformed.
        :returns: a result table with all the stages applied to the input tables in order.
        """
        if self._need_fit():
            raise RuntimeError("Pipeline contains Estimator, need to fit first.")
        for s in self.stages:
            input = s.transform(t_env, input)
        return input

ML environment 

MLEnvironmentFactory

class MLEnvironmentFactory:
    """
    Factory to get the MLEnvironment using a MLEnvironmentId.
    """

    _lock = threading.RLock()
    _default_ml_environment_id = 0
    _next_id = 1
    _map = {}

    gateway = get_gateway()
    j_ml_env = gateway.jvm.MLEnvironmentFactory.getDefault()
    _default_ml_env = MLEnvironment(
        ExecutionEnvironment(j_ml_env.getExecutionEnvironment()),
        StreamExecutionEnvironment(j_ml_env.getStreamExecutionEnvironment()),
        BatchTableEnvironment(j_ml_env.getBatchTableEnvironment()),
        StreamTableEnvironment(j_ml_env.getStreamTableEnvironment()))
    _map[_default_ml_environment_id] = _default_ml_env

    @staticmethod
    def get(ml_env_id):
        """
        Get the MLEnvironment using a MLEnvironmentId.

        :param ml_env_id: the MLEnvironmentId
        :return: the MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id not in MLEnvironmentFactory._map:
                raise ValueError(
                    "Cannot find MLEnvironment for MLEnvironmentId %s. "
                    "Did you get the MLEnvironmentId by calling "
                    "get_new_ml_environment_id?" % ml_env_id)
            return MLEnvironmentFactory._map[ml_env_id]

    @staticmethod
    def get_default():
        """
        Get the MLEnvironment using the default MLEnvironmentId.

        :return: the default MLEnvironment.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory._map[MLEnvironmentFactory._default_ml_environment_id]

    @staticmethod
    def get_new_ml_environment_id():
        """
        Create a unique MLEnvironment id and register a new MLEnvironment in the factory.

        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            return MLEnvironmentFactory.register_ml_environment(MLEnvironment())

    @staticmethod
    def register_ml_environment(ml_environment):
        """
        Register a new MLEnvironment to the factory and return a new MLEnvironment id.

        :param ml_environment: the MLEnvironment that will be stored in the factory.
        :return: the MLEnvironment id.
        """
        with MLEnvironmentFactory._lock:
            MLEnvironmentFactory._map[MLEnvironmentFactory._next_id] = ml_environment
            MLEnvironmentFactory._next_id += 1
            return MLEnvironmentFactory._next_id - 1

    @staticmethod
    def remove(ml_env_id):
        """
        Remove the MLEnvironment using the MLEnvironmentId.

        :param ml_env_id: the id.
        :return: the removed MLEnvironment
        """
        with MLEnvironmentFactory._lock:
            if ml_env_id is None:
                raise ValueError("The environment id cannot be null.")
            # Never remove the default MLEnvironment. Just return the default environment.
            if ml_env_id == MLEnvironmentFactory._default_ml_environment_id:
                return MLEnvironmentFactory.get_default()
            else:
                return MLEnvironmentFactory._map.pop(ml_env_id)


MLEnvironment

class MLEnvironment(object):
    """
    The MLEnvironment stores the necessary context in Flink. Each MLEnvironment
    will be associated with a unique ID. The operations associated with the same
    MLEnvironment ID will share the same Flink job context. Both MLEnvironment
    ID and MLEnvironment can only be retrieved from MLEnvironmentFactory.
    """

    def __init__(self, exe_env=None, stream_exe_env=None, batch_tab_env=None, stream_tab_env=None):
        self._exe_env = exe_env
        self._stream_exe_env = stream_exe_env
        self._batch_tab_env = batch_tab_env
        self._stream_tab_env = stream_tab_env

    def get_execution_environment(self):
        if self._exe_env is None:
            self._exe_env = ExecutionEnvironment.get_execution_environment()
        return self._exe_env

    def get_stream_execution_environment(self):
        if self._stream_exe_env is None:
            self._stream_exe_env = StreamExecutionEnvironment.get_execution_environment()
        return self._stream_exe_env

    def get_batch_table_environment(self):
        if self._batch_tab_env is None:
            self._batch_tab_env = BatchTableEnvironment.create(ExecutionEnvironment.get_execution_environment())
        return self._batch_tab_env

    def get_stream_table_environment(self):
        if self._stream_tab_env is None:
            self._stream_tab_env = StreamTableEnvironment.create(StreamExecutionEnvironment.get_execution_environment())
        return self._stream_tab_env
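A short usage sketch, assuming the MLEnvironmentFactory and MLEnvironment classes above:

# Register a new MLEnvironment and obtain its id.
ml_env_id = MLEnvironmentFactory.get_new_ml_environment_id()
# Look the environment up by id; operations sharing this id share one Flink job context.
ml_env = MLEnvironmentFactory.get(ml_env_id)
# Table environments are created lazily on first access.
t_env = ml_env.get_batch_table_environment()
# Remove the environment when it is no longer needed (the default one is never removed).
MLEnvironmentFactory.remove(ml_env_id)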


Params interface

Params

class Params(object):
    """
    The map-like container class for parameters. This class is provided to unify the interaction
    with parameters.
    """

    def __init__(self):
        self._paramMap = {}

    def set(self, k, v):
        self._paramMap[k] = v

    def get(self, k):
        return self._paramMap[k]


ParamInfo

class ParamInfo(object):
    """
    Definition of a parameter, including name, description, type_converter and so on.
    """

    def __init__(self, name, description, type_converter=None):
        self.name = str(name)
        self.description = str(description)
        self.type_converter = TypeConverters.identity if type_converter is None else type_converter


WithParams

class WithParams(object):
    """
    Parameters are widely used in the machine learning realm. This class defines a common
    interface to interact with classes that have parameters.
    """

    def get_params(self):
        pass

    def set(self, k, v):
        self.get_params().set(k, v)
        return self

    def get(self, k):
        return self.get_params().get(k)

    def _set(self, **kwargs):
        """
        Sets user-supplied params.
        """
        for param, value in kwargs.items():
            p = getattr(self, param)
            if value is not None:
                try:
                    value = p.type_converter(value)
                except TypeError as e:
                    raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
            self.get_params().set(p, value)
        return self
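To show how these pieces fit together with the keyword-argument support proposed above, here is a sketch of how a concrete algorithm might declare its parameters, assuming the Params/ParamInfo/WithParams classes above. The KMeans-like class, its parameters, and the setter names are illustrative assumptions rather than part of this proposal.

class MyKMeans(Estimator):

    # Parameter definitions exposed as class attributes; the default
    # type converter is TypeConverters.identity, as in ParamInfo above.
    vector_col = ParamInfo("vector_col", "name of the input vector column")
    k = ParamInfo("k", "number of clusters")

    def __init__(self, **kwargs):
        super().__init__()
        # Keyword arguments are forwarded to WithParams._set().
        self._set(**kwargs)

    # Builder-style setters delegate to WithParams.set().
    def set_vector_col(self, value):
        return self.set(self.vector_col, value)

    def set_k(self, value):
        return self.set(self.k, value)

    def fit(self, table_env, table):
        # Training logic is omitted in this sketch.
        raise NotImplementedError()


# Both styles configure the same parameters.
kmeans = MyKMeans(vector_col="features", k=2)
kmeans = MyKMeans().set_vector_col("features").set_k(2)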

Example

train_table = t_env.from_path('trainingSource')
serving_table = t_env.from_path('servingSource')

# transformer
va = VectorAssembler(selected_cols=["a", "b"], output_col="features")

# estimator
kmeans = KMeans() \
    .set_vector_col("features") \
    .set_k(2) \
    .set_reserved_cols(["a", "b"]) \
    .set_prediction_col("prediction_result") \
    .set_max_iter(100)

# pipeline
pipeline = Pipeline().append_stage(va).append_stage(kmeans)
pipeline \
    .fit(t_env, train_table) \
    .transform(t_env, serving_table) \
    .insert_into('mySink')

t_env.execute('KmeansTest')


Implementation Plan

  • Align the interfaces for MLEnvironment and MLEnvironmentFactory.
  • Add support for native Python Transformer/Estimator/Model.
  • Add support for Transformer/Estimator/Model Java wrappers.