summaryrefslogtreecommitdiff
path: root/RhSolutions.ML.Builder/Program.cs
blob: 898e872cce1c2d25f8b10d696f9a560b9f7fdaf0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
using Microsoft.ML;

namespace RhSolutions.ML.Builder
{
    /// <summary>
    /// Console entry point that trains a multiclass classifier from
    /// <c>Data/train.tsv</c> and saves the fitted model to <c>Models/model.zip</c>.
    /// Both locations are resolved three directory levels above <see cref="_appPath"/>.
    /// </summary>
    public class Program
    {
        // Directory part of the first command-line argument; falls back to "."
        // when Path.GetDirectoryName yields null.
        // NOTE(review): presumably this is the build output folder (bin/<config>/<tfm>),
        // which is why the data/model paths climb three levels — confirm against the project layout.
        private static readonly string _appPath = Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]) ?? ".";

        // Single shared ML.NET context, created with a fixed seed (0).
        private static readonly MLContext _mlContext = new MLContext(seed: 0);

        /// <summary>
        /// Loads the training data, builds the preprocessing pipeline, fits the model
        /// and writes it to disk.
        /// </summary>
        public static void Main()
        {
            var trainDataView = _mlContext.Data.LoadFromTextFile<Product>(
                Path.Combine(_appPath, "..", "..", "..", "Data", "train.tsv"), hasHeader: true);
            var pipeline = ProcessData();
            BuildAndTrainModel(trainDataView, pipeline, out ITransformer trainedModel);
            SaveModelAsFile(_mlContext, trainDataView.Schema, trainedModel);
        }

        /// <summary>
        /// Builds the data-preparation part of the pipeline: maps the "Type" column to a
        /// key-typed "Label" column, featurizes the "Name" text column, exposes the result
        /// as "Features", and appends a cache checkpoint.
        /// </summary>
        /// <returns>The (unfitted) preprocessing estimator chain.</returns>
        private static IEstimator<ITransformer> ProcessData()
        {
            var pipeline = _mlContext.Transforms.Conversion.MapValueToKey(inputColumnName: "Type", outputColumnName: "Label")
                .Append(_mlContext.Transforms.Text.FeaturizeText(inputColumnName: "Name", outputColumnName: "NameFeaturized"))
                .Append(_mlContext.Transforms.Concatenate("Features", "NameFeaturized"))
                .AppendCacheCheckpoint(_mlContext);
            return pipeline;
        }

        /// <summary>
        /// Appends the SDCA maximum-entropy multiclass trainer and a key-to-value mapping
        /// of "PredictedLabel" to <paramref name="pipeline"/>, then fits the result.
        /// </summary>
        /// <param name="trainingDataView">Data the pipeline is fitted on.</param>
        /// <param name="pipeline">Preprocessing estimator chain to extend with the trainer.</param>
        /// <param name="trainedModel">Receives the fitted transformer.</param>
        /// <returns>The full (unfitted) training pipeline that produced <paramref name="trainedModel"/>.</returns>
        private static IEstimator<ITransformer> BuildAndTrainModel(IDataView trainingDataView, IEstimator<ITransformer> pipeline, out ITransformer trainedModel)
        {
            var trainingPipeline = pipeline.Append(_mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features"))
                .Append(_mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

            trainedModel = trainingPipeline.Fit(trainingDataView);
            return trainingPipeline;
        }

        /// <summary>
        /// Saves <paramref name="model"/> together with its input schema to
        /// <c>Models/model.zip</c>, creating the target directory when it is missing.
        /// </summary>
        /// <param name="mlContext">Context whose model catalog performs the save.</param>
        /// <param name="trainingDataViewSchema">Schema of the data the model was trained on.</param>
        /// <param name="model">Fitted transformer to persist.</param>
        private static void SaveModelAsFile(MLContext mlContext, DataViewSchema trainingDataViewSchema, ITransformer model)
        {
            var modelPath = Path.Combine(_appPath, "..", "..", "..", "Models", "model.zip");

            // Saving into a non-existent folder would throw after the (expensive) training
            // has already finished, so make sure the folder is there first.
            // CreateDirectory is a no-op when the directory already exists.
            var modelDirectory = Path.GetDirectoryName(modelPath);
            if (!string.IsNullOrEmpty(modelDirectory))
            {
                Directory.CreateDirectory(modelDirectory);
            }

            mlContext.Model.Save(model, trainingDataViewSchema, modelPath);
        }
    }
}