Comprehensive Analysis of Spark, PMLS, and TensorFlow Distributed Machine Learning Platforms

This post surveys the design approaches used in distributed machine learning platforms and proposes future research directions. I worked on this with my students Kuo Zhang and Salem Alqahtani in fall 2016, and the paper will be presented at ICCCN’17 in Vancouver.

Machine learning (especially deep learning) has achieved transformative success in areas such as speech recognition, image recognition, natural language processing, and recommendation/search engines. These technologies have promising applications in autonomous vehicles, digital healthcare systems, CRM, advertising, and the Internet of Things. Not surprisingly, money is pouring into machine learning, accelerating its development, and many machine learning platforms have emerged recently.

Since training involves huge datasets and models, machine learning platforms are generally distributed, often using dozens or hundreds of parallel workers to train models. It is estimated that in the near future the vast majority of tasks running in data centers will be machine learning tasks.

Coming from a distributed systems research background, we decided to study these machine learning platforms from a distributed systems perspective and analyze their communication and control bottlenecks. We also looked at the fault tolerance and ease of programming of these platforms.

We categorized these distributed machine learning platforms into three fundamental design methods:

  1. Basic Dataflow

  2. Parameter-Server Model

  3. Advanced Dataflow

Below I briefly introduce each approach with an example platform: Apache Spark for basic dataflow, PMLS (Petuum) for the parameter-server model, and TensorFlow and MXNet for advanced dataflow. I also present a few of our comparative performance evaluation results (more are available in the paper). Unfortunately, we were unable to conduct large-scale evaluations.

1. Spark: In Spark, a computation is modeled as a directed acyclic graph (DAG), where each vertex represents a Resilient Distributed Dataset (RDD) and each edge represents an operation on an RDD. An RDD is a collection of objects divided into logical partitions that are stored and processed in memory, with shuffles and spills to disk.

In the DAG, an edge E from vertex A to vertex B indicates that RDD B is the result of performing operation E on RDD A. There are two kinds of operations: transformations and actions. A transformation (e.g., map, filter, join) performs an operation on an RDD and produces a new RDD. An action (e.g., count, collect, reduce) triggers the computation and returns a result to the driver or writes it to storage.
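To make the transformation/action distinction concrete, here is a minimal single-process sketch in Python. `ToyRDD` is a hypothetical class, not Spark's API: transformations such as `map` and `filter` only record lineage and defer the work, while actions such as `collect` and `count` force evaluation and return a result to the caller, which plays the role of the driver. The round-robin partitioning is an assumption chosen for illustration.

```python
from typing import Callable, Iterable, List


class ToyRDD:
    """Hypothetical stand-in for an RDD: lazily computed partitions plus lineage."""

    def __init__(self, compute: Callable[[], List[list]], lineage: str):
        self._compute = compute  # deferred: nothing runs until an action is called
        self.lineage = lineage

    @staticmethod
    def parallelize(data: Iterable, num_partitions: int) -> "ToyRDD":
        items = list(data)
        # Round-robin split into logical partitions
        parts = [items[i::num_partitions] for i in range(num_partitions)]
        return ToyRDD(lambda: parts, "parallelize")

    # Transformations: return a new ToyRDD and do no work yet.
    def map(self, f) -> "ToyRDD":
        return ToyRDD(lambda: [[f(x) for x in p] for p in self._compute()],
                      self.lineage + " -> map")

    def filter(self, pred) -> "ToyRDD":
        return ToyRDD(lambda: [[x for x in p if pred(x)] for p in self._compute()],
                      self.lineage + " -> filter")

    # Actions: trigger the computation and return a value to the "driver".
    def collect(self) -> list:
        return [x for p in self._compute() for x in p]

    def count(self) -> int:
        return sum(len(p) for p in self._compute())


rdd = ToyRDD.parallelize(range(10), num_partitions=3)
evens_squared = rdd.filter(lambda x: x % 2 == 0).map(lambda x: x * x)
print(evens_squared.lineage)              # parallelize -> filter -> map
print(sorted(evens_squared.collect()))    # [0, 4, 16, 36, 64]
print(evens_squared.count())              # 5
```

Because each `ToyRDD` keeps a closure over its parent, the chain of closures doubles as lineage: any partition can be recomputed from the source, which is the idea behind RDD fault tolerance.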

Spark users model their computations as DAGs of transformations and actions on RDDs. The DAG is compiled into stages, and each stage is executed as a set of parallel tasks, one task per partition. Narrow dependencies allow efficient, pipelined execution, whereas wide dependencies introduce bottlenecks because they break the pipeline and require communication-intensive shuffle operations.
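The stage-cutting rule can be sketched as follows, under a simplifying assumption: the job is a linear chain of operations rather than a general DAG, and each operation is tagged by hand as narrow or wide. `Op`, `split_into_stages`, and `tasks_for_stage` are hypothetical helpers, not Spark's DAGScheduler. Narrow operations are pipelined into the current stage; a wide operation starts a new stage because its input has to be shuffled first.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Op:
    name: str
    wide: bool  # True if the op needs a shuffle (e.g. reduceByKey, join)


def split_into_stages(ops: List[Op]) -> List[List[str]]:
    """Cut a linear chain of ops into stages at each wide dependency."""
    stages: List[List[str]] = [[]]
    for op in ops:
        if op.wide and stages[-1]:
            stages.append([])  # shuffle boundary: start a new stage
        stages[-1].append(op.name)  # narrow ops are pipelined in the current stage
    return stages


def tasks_for_stage(stage_index: int, num_partitions: int) -> List[str]:
    """One task per partition; all tasks of a stage can run in parallel."""
    return [f"stage{stage_index}-task{p}" for p in range(num_partitions)]


word_count = [Op("textFile", False), Op("flatMap", False), Op("map", False),
              Op("reduceByKey", True), Op("mapValues", False)]
stages = split_into_stages(word_count)
print(stages)  # [['textFile', 'flatMap', 'map'], ['reduceByKey', 'mapValues']]
print(tasks_for_stage(0, num_partitions=3))  # ['stage0-task0', 'stage0-task1', 'stage0-task2']
```

In real Spark, the map side of a shuffle finishes the earlier stage and the reduce side begins the next one; placing the whole wide operation at the start of the new stage is a simplification.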

Distributed execution in Spark is achieved by partitioning the stages of this DAG across machines. The diagram shows the resulting master-worker architecture: the driver holds the tasks and two scheduler components, the DAG scheduler and the task scheduler, and it maps tasks to workers.

Spark is designed for general data processing, not specifically for machine learning. However, with MLlib, Spark's machine learning library, machine learning can also be done on Spark. In the basic setup, Spark stores the model parameters on the driver node, and workers communicate with the driver to update these parameters after each iteration. In large-scale deployments, the model parameters may not fit on the driver and are instead maintained and updated as an RDD. This incurs significant overhead, because a new RDD must be created in every iteration to hold the updated model parameters. Updating the model also involves shuffling data across machines and disks, which limits Spark's scalability. This is where Spark's basic dataflow model (a DAG) falls short: it does not support the iteration that machine learning requires well.
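The basic driver-centric setup can be sketched in plain Python for logistic regression. This is not MLlib code; the function names, learning rate, iteration count, and toy data are assumptions for illustration. Each simulated worker computes a gradient on its partition, and the driver waits for all of them (a synchronization barrier) before updating its single copy of the parameters and shipping it back out in the next iteration.

```python
import math
from typing import List, Tuple

# (feature vector, label in {0, 1}); the first feature is a constant 1.0 for the bias
Example = Tuple[List[float], int]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def worker_gradient(w: List[float], partition: List[Example]) -> List[float]:
    """Runs on a worker: sum of logistic-loss gradients over its partition."""
    grad = [0.0] * len(w)
    for x, y in partition:
        err = sigmoid(sum(wi * xi for wi, xi in zip(w, x))) - y
        for j, xj in enumerate(x):
            grad[j] += err * xj
    return grad


def train_on_driver(partitions: List[List[Example]], dim: int,
                    lr: float = 0.5, iters: int = 100) -> List[float]:
    """Runs on the driver, which holds the only copy of the parameters."""
    w = [0.0] * dim
    n = sum(len(p) for p in partitions)
    for _ in range(iters):
        # Broadcast w, then wait for every worker's gradient (barrier per iteration).
        grads = [worker_gradient(w, p) for p in partitions]
        total = [sum(g[j] for g in grads) for j in range(dim)]
        w = [wj - lr * tj / n for wj, tj in zip(w, total)]
    return w


partitions = [
    [([1.0, -2.0], 0), ([1.0, 1.5], 1)],  # worker 0's data
    [([1.0, -1.0], 0), ([1.0, 2.0], 1)],  # worker 1's data
]
w = train_on_driver(partitions, dim=2)
```

Every iteration funnels all gradients through one process and pushes the full parameter vector back out; if the parameters no longer fit on the driver, the same loop would have to materialize a new parameter RDD per iteration, which is the overhead described above.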

2. PMLS: PMLS was designed from a clean slate specifically for machine learning, without the baggage of a general-purpose data processing history. It introduces the parameter server (PS) abstraction to support the iteration-intensive machine learning training process.

In PMLS, the PS (the green box in the diagram) acts as a distributed in-memory key-value store. It is replicated and sharded: each node serves as the primary for one shard of the model (parameter space) and as a secondary, or replica, for other shards. This lets the PS scale well with the number of nodes. PS nodes store and update the model parameters and respond to requests from workers. Workers request the latest model parameters from their local PS replicas and run computations on the partition of the dataset assigned to them.
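The sharded, replicated key-value layout can be illustrated with the hypothetical `ShardedParamServer` class below; it is not PMLS's API. Keys are hashed to shards, node i is the primary for shard i and keeps replicas of the other shards, and updates are applied at the primary and then copied to the replicas. Copying eagerly is a simplifying assumption: a real parameter server propagates updates asynchronously, so replicas can be stale within the bounds of its consistency model.

```python
import zlib
from typing import Dict, List


class ShardedParamServer:
    """Hypothetical sketch: num_nodes PS nodes and num_nodes shards.
    Node i is primary for shard i and holds replicas of every other shard."""

    def __init__(self, num_nodes: int):
        self.num_nodes = num_nodes
        # store[node][shard] -> {key: value}
        self.store: List[List[Dict[str, float]]] = [
            [dict() for _ in range(num_nodes)] for _ in range(num_nodes)
        ]

    def shard_of(self, key: str) -> int:
        # crc32 gives the same placement on every run (unlike Python's salted hash()).
        return zlib.crc32(key.encode()) % self.num_nodes

    def primary_of(self, key: str) -> int:
        return self.shard_of(key)  # shard i lives primarily on node i

    def push_update(self, key: str, delta: float) -> None:
        """Apply a worker's update at the primary, then copy it to the replicas."""
        s = self.shard_of(key)
        primary = self.store[s][s]
        primary[key] = primary.get(key, 0.0) + delta
        for node in range(self.num_nodes):
            if node != s:
                self.store[node][s][key] = primary[key]  # eager propagation (simplified)

    def read_local(self, node: int, key: str) -> float:
        """A worker co-located with `node` reads from that node's local copy."""
        return self.store[node][self.shard_of(key)].get(key, 0.0)


ps = ShardedParamServer(num_nodes=3)
ps.push_update("w1", 0.5)
ps.push_update("w1", 0.25)
print([ps.read_local(n, "w1") for n in range(3)])  # [0.75, 0.75, 0.75]
```

Because every key has exactly one primary, update load spreads across nodes as the model grows, while reads can be served locally from replicas.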

PMLS also adopts the Stale Synchronous Parallel (SSP) model, which relaxes the Bulk Synchronous Parallel (BSP) model in which workers synchronize at the end of every iteration. SSP eases the synchronization burden on workers while ensuring that the fastest worker never gets more than s iterations ahead of the slowest. This relaxed consistency model is still acceptable for machine learning training because the training process tolerates a certain amount of noise. I discussed this in my April 2016 post: https://muratbuffalo.blogspot.com/2016/04/petuum-new-platform-for-distributed.html
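The SSP gate can be sketched as a check on per-worker iteration clocks. The functions and the random scheduling are hypothetical; the rule follows the usual SSP formulation in which a worker that has completed c iterations may start the next one only if c is at most s greater than the slowest worker's count. With s = 0 every worker waits for the others at the end of each iteration, which is BSP.

```python
import random
from typing import List


def may_proceed(clocks: List[int], worker: int, s: int) -> bool:
    """SSP gate: clocks[i] is the number of iterations worker i has completed.
    A worker may start its next iteration only if it is at most s ahead of the slowest."""
    return clocks[worker] - min(clocks) <= s


def simulate(num_workers: int, s: int, steps: int, seed: int = 0) -> int:
    """At each step a random worker tries to finish an iteration; it succeeds only
    if the gate allows. Returns the largest gap (in completed iterations) observed."""
    rng = random.Random(seed)
    clocks = [0] * num_workers
    max_gap = 0
    for _ in range(steps):
        w = rng.randrange(num_workers)
        if may_proceed(clocks, w, s):
            clocks[w] += 1
        max_gap = max(max_gap, max(clocks) - min(clocks))
    return max_gap


print(simulate(num_workers=4, s=0, steps=10_000))  # BSP: gap never exceeds 1
print(simulate(num_workers=4, s=3, steps=10_000))  # SSP with s=3: gap never exceeds 4
```

Because a worker passes the gate only when it is at most s ahead and then completes one more iteration, the gap in completed iterations is bounded by s + 1. The slowest worker always passes the gate, so the system never deadlocks.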

3. TensorFlow: Google had a parameter-server-based distributed machine learning platform called DistBelief. (See my review of the DistBelief paper: https://muratbuffalo.blogspot.com/2017/01/google-distbelief-paper-large-scale.html.) In my view, the main drawback of DistBelief was that writing machine learning applications required tinkering with low-level code. Google wants any of its employees to be able to write machine learning code without being an expert in distributed execution; the same motivation led Google to develop the MapReduce framework for big data processing.

To achieve this goal, Google designed TensorFlow. TensorFlow adopts the dataflow paradigm, but in a more advanced form: the computation graph need not be a DAG, it can include loops, and it supports mutable state. I think the Naiad design may have influenced TensorFlow's design.

TensorFlow represents a computation as a directed graph of nodes and edges. Nodes represent computations, which may have mutable state. Edges represent the multidimensional data arrays (tensors) that flow between nodes. TensorFlow requires users to statically declare this symbolic computation graph, and it uses graph rewriting and partitioning to place the graph on machines for distributed execution. (MXNet, and especially DyNet, use dynamic declaration of graphs, which makes programming easier and more flexible.)
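The sketch below illustrates, in plain Python, the two ideas of a statically declared graph and nodes with mutable state. It is a hypothetical toy, not TensorFlow's API: the graph is declared once, a `Variable` node keeps its value across executions, and a `make_assign_add` node mutates that value each time the graph is run. The training loop is a Python loop around repeated executions; it does not show TensorFlow's in-graph control flow.

```python
from typing import Callable, Dict, List


class Node:
    """An operation in the graph; `inputs` are its incoming edges."""

    def __init__(self, name: str, fn: Callable[..., float], inputs: List["Node"]):
        self.name = name
        self.fn = fn
        self.inputs = inputs


class Variable(Node):
    """A stateful node whose value persists across graph executions."""

    def __init__(self, name: str, value: float):
        super().__init__(name, lambda: self.value, [])
        self.value = value


def make_assign_add(var: Variable, delta: Node) -> Node:
    """Hypothetical stateful op: adds its input to `var` and returns the new value."""
    def update(d: float) -> float:
        var.value += d
        return var.value
    return Node(var.name + "/assign_add", update, [delta])


def run(target: Node, feeds: Dict[str, float]) -> float:
    """Execute the subgraph needed for `target`; each node is evaluated at most once per run."""
    cache: Dict[str, float] = {}

    def evaluate(node: Node) -> float:
        if node.name in feeds:
            return feeds[node.name]
        if node.name not in cache:
            # Values on edges are scalars here; in TensorFlow they are tensors.
            cache[node.name] = node.fn(*[evaluate(i) for i in node.inputs])
        return cache[node.name]

    return evaluate(target)


# Declare the graph once: one gradient-descent step on (w - x)^2.
w = Variable("w", 0.0)
x = Node("x", lambda: 0.0, [])  # placeholder; its value is supplied through `feeds`
grad = Node("grad", lambda wv, xv: 2.0 * (wv - xv), [w, x])
step = Node("step", lambda g: -0.1 * g, [grad])
train_op = make_assign_add(w, step)

# Execute the same graph many times; only the Variable's state changes.
for _ in range(50):
    run(train_op, {"x": 3.0})
print(round(w.value, 3))  # 3.0
```

Declaring the whole graph before running it is what gives a system like TensorFlow the chance to rewrite and partition it across machines; the price is that the structure is fixed up front, which is the flexibility that dynamic declaration in MXNet and DyNet trades for.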

Distributed machine learning training in TensorFlow uses the parameter-server approach, as shown in the diagram. When you use the PS abstraction in TensorFlow, you get parameter-server-based data parallelism. TensorFlow also allows more complex setups, but those require writing custom code and venturing into uncharted territory.
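As a concrete example of how a parameter-server cluster is described to TensorFlow, one common mechanism is the `TF_CONFIG` environment variable: a JSON document listing the hosts for each job (here `ps` and `worker`) plus the role of the current process. The helper `make_tf_config` and the host names are hypothetical, and depending on the TensorFlow version and distribution strategy a `chief` task may also be expected, so check the documentation for the version you use.

```python
import json
import os
from typing import Dict, List


def make_tf_config(ps_hosts: List[str], worker_hosts: List[str],
                   task_type: str, task_index: int) -> str:
    """Build a TF_CONFIG JSON string for one process in a PS-style cluster."""
    cluster: Dict[str, List[str]] = {"ps": ps_hosts, "worker": worker_hosts}
    if task_type not in cluster:
        raise ValueError(f"unknown task type: {task_type}")
    if not 0 <= task_index < len(cluster[task_type]):
        raise ValueError(f"task_index {task_index} out of range for {task_type}")
    return json.dumps({"cluster": cluster,
                       "task": {"type": task_type, "index": task_index}})


# Hypothetical hosts: one parameter server and two workers.
ps = ["ps0.example.com:2222"]
workers = ["worker0.example.com:2222", "worker1.example.com:2222"]

# Each process sets its own role before starting TensorFlow; this one is worker 1.
os.environ["TF_CONFIG"] = make_tf_config(ps, workers, "worker", 1)
```

Every process receives the same `cluster` section and differs only in `task`, which is how the same training script can run as either a parameter server or a worker.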

Some Evaluation Results

Our evaluation used Amazon EC2 m4.xlarge instances, each with 4 vCPUs (Intel Xeon E5-2676 v3), 16 GiB of RAM, and 750 Mbps of EBS bandwidth. We used two common machine learning tasks: binary classification with logistic regression, and image classification with multi-layer neural networks. I include only a few graphs here; more experimental details are in our paper. Our experiments have some limitations: we used a small number of machines and could not test at scale. We also restricted ourselves to CPU computation and did not test on GPUs.

This graph shows the execution speed of logistic regression on various platforms. Spark performs well but lags behind PMLS and MXNet.

This graph shows the execution speed of deep neural networks (DNNs) on the various platforms. Compared with single-layer logistic regression, Spark suffers a larger performance loss on the two-layer neural network, because two-layer networks require more iterative computation. In our Spark runs we kept the parameters in the driver, where they fit; had we stored them in an RDD and updated it after every iteration, the situation would have been worse.

This graph shows the CPU utilization of the various platforms. The Spark application shows significantly higher CPU utilization, mainly due to serialization overhead. Earlier work has pointed out this issue: https://muratbuffalo.blogspot.com/2017/05/paper-summary-making-sense-of.html

Conclusion and Future Directions

Machine learning/deep learning applications are embarrassingly parallel, and not particularly interesting from the perspective of concurrent algorithms. It is fairly safe to say that the parameter-server approach wins for training on distributed machine learning platforms. As for bottlenecks, the network remains the bottleneck for distributed machine learning applications. Providing better data/model staging is more useful than fancier general-purpose dataflow platforms; data and models should be treated as first-class citizens.

In Spark, CPU overhead becomes a bottleneck before network limitations do. Spark's use of Scala and the JVM significantly affects its performance. This underlines the need for better monitoring and/or performance prediction tools for distributed machine learning platforms. Tools such as Ernest and CherryPick have recently been proposed to address these issues for Spark data processing applications.

There are still many open questions about distributed systems support for machine learning runtimes, such as resource scheduling and runtime performance improvement. Using runtime monitoring and performance analysis, the next generation of distributed machine learning platforms should provide fine-grained, elastic configuration and scheduling of compute, memory, and network resources for task execution. There are also open questions in programming and software engineering support. What (distributed) programming abstractions suit machine learning applications? More research is also needed on the verification and validation of distributed machine learning applications (in particular, testing DNNs with problematic inputs).
