Discussion thread: https://lists.apache.org/thread/880cxnw9ygsjl3x5w0kxrh4x5fw6mp4x
Vote thread: https://lists.apache.org/thread/by9mpnmk22nws1sk54z7t69802ws2z55
JIRA: FLINK-37777
Release: <Flink Version>
Current Status: Accepted

Please keep the discussion on the mailing list rather than commenting on the wiki (wiki discussions get unwieldy fast).

Motivation

FLIP-437: Support ML Models in Flink SQL introduced the ML_PREDICT  and ML_EVALUATE  functions. This FLIP details how they can be implemented in both SQL and the Table API, as well as the new interfaces to introduce.

FLIP-437 proposed syntax:

SELECT f1, f2, label FROM ML_PREDICT(`my_data`, `classifier_model`, DESCRIPTOR(f1, f2))
SELECT * FROM ML_EVALUATE(`eval_data`, `classifier_model`, DESCRIPTOR(f1, f2))


After researching the code, we decided to keep the TABLE  and MODEL  keywords to indicate table and model identifiers. So the syntax we will implement is:

Prediction

SELECT f1, f2, label FROM ML_PREDICT(TABLE `my_data`, MODEL `classifier_model`, DESCRIPTOR(f1, f2))

Evaluation

For the ML_EVALUATE  function, this proposal separates the label and feature DESCRIPTOR  arguments to make the call clearer.

SELECT * FROM ML_EVALUATE(TABLE `eval_data`, MODEL `classifier_model`, DESCRIPTOR(label), DESCRIPTOR(f1, f2))
SELECT * FROM ML_EVALUATE(input => TABLE `eval_data`, model_input => MODEL `classifier_model`, label => DESCRIPTOR(label), features => DESCRIPTOR(f1, f2), task => 'classification')

The task  parameter can explicitly override the type of model to be evaluated.

Additionally, each predict and evaluate function can take an optional MAP  runtime config that controls how the function runs, such as async  or sync  execution.

Public Interfaces

Model runtime

The classes below introduce the provider and runtime for model prediction.

The ML_PREDICT function can have different implementations depending on the model provider, similar to table connectors. For example:

CREATE MODEL mymodel 
INPUT (text STRING)
OUTPUT (result STRING)
WITH (
    'provider' = 'openai',
    'task' = 'text_generation',
    'endpoint' = 'https://api.openai.com/v1/llm/v1/chat',
    'apikey' = 'abcdefg',
    'system_prompt' = 'translate text to english'
)

With the provider being openai , this model can use the runtime provided by OpenAI. The provider  option is mandatory: it is used to construct the ModelProviderFactory  through Java SPI.
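Factory discovery by the provider  option can be sketched with simplified stand-ins. The interface and discover  helper below are illustrative only, not Flink's actual Factory  API; in Flink, candidate factories are found on the classpath via java.util.ServiceLoader, while here they are passed in directly to keep the sketch self-contained.

```java
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

public class FactoryDiscovery {

    // Simplified stand-in for the proposed factory interface: each factory
    // announces the 'provider' identifier it serves.
    interface ModelProviderFactory {
        String factoryIdentifier();
    }

    // Picks the factory whose identifier matches the mandatory 'provider'
    // option from the model's WITH clause.
    static ModelProviderFactory discover(
            List<ModelProviderFactory> candidates, Map<String, String> modelOptions) {
        String provider = modelOptions.get("provider");
        return candidates.stream()
                .filter(f -> f.factoryIdentifier().equals(provider))
                .findFirst()
                .orElseThrow(
                        () -> new NoSuchElementException("No factory for provider: " + provider));
    }
}
```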

ModelProviderFactory and ModelProvider are introduced as follows to choose the runtime based on the provider:

/**
 * Base interface for configuring a model provider.
 */
@PublicEvolving
public interface ModelProviderFactory extends Factory {

    /**
     * Creates a {@link ModelProvider} based on the given context.
     */
    ModelProvider createModelProvider(Context context);

    /** Provides catalog and session information describing the model to be accessed. */
    @PublicEvolving
    interface Context {
        /**
         * Returns the identifier of the model in the {@link Catalog}.
         *
         * <p>This identifier describes the relationship between the model instance and the
         * associated {@link Catalog} (if any).
         */
        ObjectIdentifier getObjectIdentifier();

        /**
         * Returns the resolved model information received from the {@link Catalog} or persisted
         * plan.
         *
         * <p>The {@link ResolvedCatalogModel} forwards the metadata from the catalog but offers a
         * validated {@link ResolvedSchema}. The original metadata object is available via {@link
         * ResolvedCatalogModel#getOrigin()}.
         *
         * <p>In most cases, a factory is interested in the following characteristics:
         *
         * <pre>{@code
         * // get the physical input and output data type to initialize the connector
         * context.getCatalogModel().getResolvedInputSchema().toPhysicalRowDataType()
         * context.getCatalogModel().getResolvedOutputSchema().toPhysicalRowDataType()
         *
         * // get configuration options
         * context.getCatalogModel().getOptions()
         * }</pre>
         *
         * <p>During a plan restore, usually the model information persisted in the plan is used to
         * reconstruct the catalog model.
         */
        ResolvedCatalogModel getCatalogModel();

        /** Gives read-only access to the configuration of the current session. */
        ReadableConfig getConfiguration();

        /**
         * Returns the class loader of the current session.
         *
         * <p>The class loader is in particular useful for discovering further (nested) factories.
         */
        ClassLoader getClassLoader();

        /** Whether the model is temporary. */
        boolean isTemporary();
    }
}
/**
 * Model provider base interface.
 */
interface ModelProvider {

    /**
     * Creates a copy of this instance during planning. The copy should be a deep copy of all
     * mutable members.
     */
    ModelProvider copy();

    /**
     * Context for creating runtime providers.
     */
    interface Context {

        /** Resolved catalog model. */
        ResolvedCatalogModel getCatalogModel();

        /** Runtime config provided to the provider. */
        ReadableConfig runtimeConfig();
    }
}
/**
 * Provider of a synchronous model runtime for the predict function.
 */
interface PredictRuntimeProvider extends ModelProvider {

    /**
     * Returns a synchronous ML_PREDICT function.
     */
    PredictFunction createPredictFunction(Context context);
}
/**
 * Provider of an asynchronous model runtime for the predict function.
 */
interface AsyncPredictRuntimeProvider extends ModelProvider {

    /**
     * Returns an asynchronous ML_PREDICT function.
     */
    AsyncPredictFunction createAsyncPredictFunction(Context context);
}
/**
 * A wrapper class of {@link TableFunction} for synchronous prediction.
 *
 * <p>The output type of this table function is fixed as {@link RowData}.
 */
@PublicEvolving
public abstract class PredictFunction extends TableFunction<RowData> {

    /**
     * Synchronously predicts a result based on the input row.
     *
     * @param inputRow - A {@link RowData} that wraps the input for the predict function.
     * @return A collection of predicted results.
     */
    public abstract Collection<RowData> predict(RowData inputRow);

    /** Invokes {@link #predict} and handles exceptions. */
    public final void eval(Object... args) {
        try {
            GenericRowData argsData = GenericRowData.of(args);
            Collection<RowData> results = predict(argsData);
            if (results == null) {
                return;
            }
            results.forEach(this::collect);
        } catch (Exception e) {
            throw new FlinkRuntimeException("Prediction error", e);
        }
    }
}
/**
 * A wrapper class of {@link AsyncTableFunction} for asynchronous prediction.
 *
 * <p>The output type of this table function is fixed as {@link RowData}.
 */
@PublicEvolving
public abstract class AsyncPredictFunction extends AsyncTableFunction<RowData> {

    /**
     * Asynchronously predicts a result based on the input row.
     *
     * @param inputRow - A {@link RowData} that wraps the input for the predict function.
     * @return A future that completes with a collection of predicted results.
     */
    public abstract CompletableFuture<Collection<RowData>> asyncPredict(RowData inputRow);

    /** Invokes {@link #asyncPredict} and forwards results or failures to the given future. */
    public void eval(CompletableFuture<Collection<RowData>> completableFuture, Object... args) {
        try {
            GenericRowData argsData = GenericRowData.of(args);
            asyncPredict(argsData)
                    .whenComplete(
                            (result, exception) -> {
                                if (exception != null) {
                                    completableFuture.completeExceptionally(
                                            new TableException(
                                                    "Failed to execute asynchronous prediction.",
                                                    exception));
                                    return;
                                }
                                completableFuture.complete(result);
                            });
        } catch (Exception e) {
            completableFuture.completeExceptionally(e);
        }
    }
}
/** Extension of {@link Module} to optionally expose a {@link ModelProviderFactory}. */
interface Module {
    default Optional<ModelProviderFactory> getModelProviderFactory() {
        return Optional.empty();
    }
}
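The error-handling contract of the eval() wrapper above can be illustrated with a self-contained sketch using plain Java types: String stands in for RowData, and asyncPredict is a toy implementation, both hypothetical. Results are forwarded to the caller's future, and any failure completes it exceptionally instead of escaping.

```java
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class AsyncBridge {

    // Stand-in for asyncPredict(): produces prediction rows asynchronously.
    static CompletableFuture<Collection<String>> asyncPredict(String inputRow) {
        return CompletableFuture.supplyAsync(() -> List.of(inputRow + ":predicted"));
    }

    // Mirrors the eval() wrapper: forward results to the caller's future
    // and wrap any failure instead of letting it escape.
    static void eval(CompletableFuture<Collection<String>> resultFuture, String inputRow) {
        try {
            asyncPredict(inputRow)
                    .whenComplete(
                            (result, exception) -> {
                                if (exception != null) {
                                    resultFuture.completeExceptionally(
                                            new RuntimeException(
                                                    "Failed to execute asynchronous prediction.",
                                                    exception));
                                    return;
                                }
                                resultFuture.complete(result);
                            });
        } catch (Exception e) {
            resultFuture.completeExceptionally(e);
        }
    }
}
```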

Async Predict Function Usage and Config

AsyncPredictFunction  can be used if a ModelProvider  implements AsyncPredictRuntimeProvider . To enable the async function, the user needs to set async  to true  in the runtime config map:

SELECT * FROM ML_PREDICT(TABLE input, MODEL mdl, DESCRIPTOR(f1, f2), MAP['async', 'true']);

If the runtime config map is not provided, the sync function is used by default.

Additional configs for the async function:

  • max-concurrent-operations: The max number of async I/O operations that the async predict function can trigger.

  • timeout: The total time that may pass before the invocation (including retries) is considered timed out and the task execution fails.

  • retry-strategy: FIXED_DELAY retries after a fixed amount of time.

  • retry-delay: The time to wait between retries for the FIXED_DELAY strategy. Could be the base delay time for a (not yet proposed) exponential backoff.

  • max-attempts: The maximum number of attempts while retrying.

  • output-mode: Output mode for the async operation. Can be ORDERED  (default) or ALLOW_UNORDERED .

  • async: Whether to run the async predict function. Defaults to false.

The defaults for these configs are:

table.exec.async-ml-predict.max-concurrent-operations: 10
table.exec.async-ml-predict.timeout: 30s
table.exec.async-ml-predict.retry-strategy: FIXED_DELAY
table.exec.async-ml-predict.fixed-delay: 10s
table.exec.async-ml-predict.max-attempts: 3
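As an illustration only (the parser below is hypothetical, not the actual planner code), the string entries of the runtime config MAP could be turned into typed options, with the documented defaults as fallbacks. Duration parsing is simplified to the "<n>s" form used in the defaults above.

```java
import java.time.Duration;
import java.util.Map;

public class MlPredictRuntimeConfig {
    final boolean async;
    final int maxAttempts;
    final Duration timeout;
    final String outputMode;

    // Illustrative parser: reads the string entries of the MAP argument
    // and falls back to the documented defaults when a key is absent.
    MlPredictRuntimeConfig(Map<String, String> options) {
        this.async = Boolean.parseBoolean(options.getOrDefault("async", "false"));
        this.maxAttempts = Integer.parseInt(options.getOrDefault("max-attempts", "3"));
        // "10s" -> ISO-8601 "PT10S"; real Flink config parsing is richer.
        this.timeout = Duration.parse("PT" + options.getOrDefault("timeout", "30s").toUpperCase());
        this.outputMode = options.getOrDefault("output-mode", "ORDERED");
    }
}
```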


ML_EVALUATE function

For ML_EVALUATE , we don't provide an interface for users to implement because users can easily compute evaluation metrics from predicted values themselves. Instead, Flink provides a default implementation of ML_EVALUATE  based on the task  option. The output type of the evaluation function is a MAP whose keys are different metrics depending on the task type.

Output metrics based on task type:

  • classification: Accuracy, Precision, Recall, F1 score

  • clustering: Mean Davies-Bouldin Index

  • embedding: Mean Cosine Similarity

  • regression: Mean Absolute Error, Mean Squared Error, Root Mean Squared Error, Mean Absolute Percentage Error

  • text_generation: BLEU Score, ROUGE Score, Semantic Similarity
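To make the classification output concrete, here is a self-contained sketch (illustrative only; the built-in implementation may differ, e.g. by aggregating incrementally over the stream) computing the four classification metrics over binary labels, producing the kind of MAP result ML_EVALUATE would return:

```java
import java.util.List;
import java.util.Map;

public class ClassificationMetrics {

    // Binary-classification metrics over predicted vs. actual labels (1 =
    // positive, 0 = negative), returned as a metric-name-to-value map.
    static Map<String, Double> evaluate(List<Integer> predicted, List<Integer> actual) {
        int tp = 0, fp = 0, fn = 0, correct = 0;
        for (int i = 0; i < predicted.size(); i++) {
            int p = predicted.get(i), a = actual.get(i);
            if (p == a) correct++;
            if (p == 1 && a == 1) tp++;
            if (p == 1 && a == 0) fp++;
            if (p == 0 && a == 1) fn++;
        }
        double accuracy = (double) correct / predicted.size();
        double precision = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
        double recall = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
        double f1 = precision + recall == 0 ? 0.0 : 2 * precision * recall / (precision + recall);
        return Map.of("accuracy", accuracy, "precision", precision, "recall", recall, "f1", f1);
    }
}
```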

Model SQL Table Functions

SqlMlPredictTableFunction  and SqlMlEvaluateTableFunction  will be registered as builtin SqlFunctions that provide type inference and validation. We can create a base class SqlMlFunctionTableFunction  for them to identify ML table functions in the planning phase.

public abstract class SqlMlFunctionTableFunction extends SqlFunction implements SqlTableFunction {
    @Override
    public SqlReturnTypeInference getRowTypeInference() {
        return this::inferRowType;
    }
    protected abstract RelDataType inferRowType(SqlOperatorBinding opBinding);
}
public class SqlMlPredictTableFunction extends SqlMlFunctionTableFunction {
    // Input/output validation and type inference
}
public class SqlMlEvaluateTableFunction extends SqlMlFunctionTableFunction {
    // Input/output validation and type inference
}

Proposed Changes

Sql syntax and planner changes for sql execution

  • The MODEL keyword will be introduced in the SQL parser to parse a model identifier into a SqlModelCall, which holds the model identifier. When validating the call node, we can fetch and store the model for type inference, etc. This deviates slightly from FLIP-437, which has no MODEL or TABLE keyword in the function parameters; without these keywords it is difficult to create model and table nodes in the parser.

  • We need to extend FlinkCalciteCatalogReader to pass in CatalogManager and provide getModel function.

  • We need to translate SqlModelCall to a RexModelCall node in SqlToRelConverter; during this process, the CatalogModel is translated to a ModelProvider instance that contains the ResolvedCatalogModel.

  • We need to introduce the physical planner rules StreamPhysicalModelTableFunctionRule and BatchPhysicalModelTableFunctionRule to translate the model table function call to an ExecNode.

  • During the translation, we will look at the function (ML_PREDICT/ML_EVALUATE, or some other function later on) and create the corresponding ExecNode.
  • In the ExecNode, we need to create a FunctionProvider based on the Context (model options, runtime config such as whether it's async, etc.).

Compatibility, Deprecation, and Migration Plan

  • There’s no compatibility issue since these are new features.

  • If we plan to support ML_PREDICT/ML_EVALUATE as a PTF later, compiled plans won’t be affected as long as the ExecNodes remain. We can then change the planner to drop support for the old syntax (such as removing the MODEL keyword and the model SQL and rex nodes).

Future work

More generic support of Model arguments in PTFs and the PTF Table API so users can flexibly define how models are used in their own UDFs.

Test Plan

  • All existing tests pass

  • Add new unit tests and integration tests for any new code changes

Rejected Alternatives

None