A network architecture defines how a deep learning model is structured and, more importantly, what it’s designed to do.
Before choosing a network architecture, it’s important to understand what kind of use case you have and the common networks available to you.
You’re likely to encounter these common architectures when getting started with deep learning:
• Convolutional neural network (CNN): CNNs are commonly associated with images as input data, but they can also be used for other input data, and I’ll get into those details in question 1.
• Recurrent neural network (RNN): RNNs have connections that keep track of previous information to make future predictions. Unlike CNNs, where each input is assumed to be an independent event, RNNs can process sequences of data that might affect each other. One example is in natural language processing, where previous words influence the likelihood of what comes next.
• Long short-term memory (LSTM): LSTM networks are a commonly used type of RNN for sequence and signal data. I’ll go into more detail in question 3.
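To make “connections that keep track of previous information” a little more concrete, here is a minimal sketch of a one-unit recurrent layer in plain Python. The weights `w_in` and `w_rec` are made-up values chosen for illustration, not trained ones, and a real RNN would use vectors and matrices rather than single numbers.

```python
import math

def rnn_forward(inputs, w_in=0.5, w_rec=0.8, bias=0.0):
    """Run a one-unit recurrent layer over a sequence of numbers.

    The hidden state h is fed back in at every step, so each output
    depends on the inputs seen so far, not only on the current one.
    Weights are illustrative assumptions, not learned values.
    """
    h = 0.0
    states = []
    for x in inputs:
        h = math.tanh(w_in * x + w_rec * h + bias)
        states.append(h)
    return states

# Same values, different order: the final state differs,
# because the network "remembers" when the 1.0 arrived.
early = rnn_forward([1.0, 0.0, 0.0])
late = rnn_forward([0.0, 0.0, 1.0])
```

In `early`, the second and third states stay above zero even though the inputs at those steps are zero, which is the recurrent connection carrying earlier information forward.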
Excellent question. The short answer is you probably want a CNN to classify images.
Here’s why.
Let’s start with what CNN and LSTM networks are, and how they are commonly used.
When talking about a convolutional neural network, some people say “ConvNet,” but whenever I try dropping that into conversation, I always feel like I’m trying to be cool when I’m actually not.
CNNs consist of many layers, but they loosely follow a repeating pattern of convolution | ReLU | pooling (again and again and again). They are often useful for image classification because they are very good at matching local spatial patterns, and they generally outperform other methods for image feature extraction. Keep in mind that convolution is at the heart of a CNN: convolving the input image with a series of filters highlights features in the image while preserving the spatial relationships between adjacent pixels.
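Here is a small sketch of one pass through that convolution | ReLU | pooling pattern, written in plain Python so nothing is hidden. The image and the vertical-edge filter are hand-picked assumptions for illustration; in a trained CNN the filter values are learned, and a library would do this far more efficiently.

```python
def conv2d(image, kernel):
    """'Valid' 2-D convolution (implemented as cross-correlation,
    which is what most deep learning libraries do)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

def relu(fmap):
    """Replace negative responses with zero."""
    return [[max(0.0, v) for v in row] for row in fmap]

def max_pool(fmap, size=2):
    """Non-overlapping max pooling; leftover rows/columns are dropped."""
    return [
        [
            max(fmap[r + i][c + j] for i in range(size) for j in range(size))
            for c in range(0, len(fmap[0]) - size + 1, size)
        ]
        for r in range(0, len(fmap) - size + 1, size)
    ]

# 6x8 image: dark left half, bright right half (a vertical edge in the middle).
image = [[0.0] * 4 + [1.0] * 4 for _ in range(6)]

# Hand-picked vertical-edge filter (an assumption; real filters are learned).
kernel = [[-1.0, 0.0, 1.0],
          [-1.0, 0.0, 1.0],
          [-1.0, 0.0, 1.0]]

features = max_pool(relu(conv2d(image, kernel)))
```

The pooled feature map is 2x3 and is nonzero only in its middle column, which is where the edge sits in the input. The map is smaller than the image, but the location of the feature survives.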
There are many variations on this theme, but a few common configurations of CNNs are:
• Series network: the layers follow a straight line, one after another. AlexNet is an example.
• Directed acyclic graph (DAG) network: multiple lines and connections between layers are a classic sign of a DAG. GoogLeNet is an example.
Long short-term memory networks are primarily associated with time-series and sequence data. LSTM networks retain some portion of earlier data when making the current decision; they see data in context, which helps them make better associations.
Figure: a simple LSTM network for classification.
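For a rough idea of how an LSTM holds on to context, here is a sketch of a single-unit LSTM cell using the standard gate equations (forget, input, and output gates plus a candidate value). Everything is a scalar to keep it readable, and the numbers in `params` are arbitrary placeholders I chose for illustration, not trained weights.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h_prev, c_prev, p):
    """One time step of a single-unit LSTM cell with a scalar input."""
    f = sigmoid(p["wf"] * x + p["uf"] * h_prev + p["bf"])    # forget gate
    i = sigmoid(p["wi"] * x + p["ui"] * h_prev + p["bi"])    # input gate
    o = sigmoid(p["wo"] * x + p["uo"] * h_prev + p["bo"])    # output gate
    g = math.tanh(p["wg"] * x + p["ug"] * h_prev + p["bg"])  # candidate value
    c = f * c_prev + i * g   # cell state: keep part of the old, add part of the new
    h = o * math.tanh(c)     # hidden state exposed to the next layer
    return h, c

def run_lstm(sequence, p):
    """Feed a sequence through the cell and return the final hidden state.
    In a classifier, a fully connected layer would consume this value."""
    h, c = 0.0, 0.0
    for x in sequence:
        h, c = lstm_step(x, h, c, p)
    return h

# Placeholder weights (assumptions for illustration only).
params = {"wf": 1.0, "uf": 0.5, "bf": 0.0,
          "wi": 1.0, "ui": 0.5, "bi": 0.0,
          "wo": 1.0, "uo": 0.5, "bo": 0.0,
          "wg": 1.0, "ug": 0.5, "bg": 0.0}

with_event = run_lstm([1.0, 0.0, 0.0], params)
without_event = run_lstm([0.0, 0.0, 0.0], params)
```

The last two inputs are identical in both sequences, yet `with_event` ends up positive while `without_event` stays at zero: the cell state still carries a trace of the first input.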
My knee-jerk reaction is to always suggest LSTM networks! However, there are many techniques available for a reason, and some are going to work better in certain scenarios. I can't answer this question specifically without more information, so instead let's walk through a few possible scenarios.
Time-Series Regression Scenario #1: My input is low-complexity time-series data. I have a series of data points that I want to use to forecast future events.
In this situation, you might be best off using traditional machine learning rather than deep learning. Here’s a short (3:43) video looking at forecasting electrical load using machine learning.
Time-Series Regression Scenario #2: I have data from multiple sensors and want to predict remaining useful life (the amount of time a machine has before it needs repair or replacement).
My colleagues and I see this question with our customers in industrial automation who need to identify problems before they become dangerous or expensive. This time, you might want to use an LSTM network over machine learning regression. This approach reduces the need to identify features manually, which would be a significant task given multiple sensors.
Time-Series Regression Scenario #3: I have audio data I want to denoise.
Here you could use a CNN. The important thing about this method is to convert signals into images prior to passing them into the network. This means the signal becomes an image representation through a Fourier transform or another time-frequency transformation, for example a spectrogram. Using images provides a way to see features you might not be able to visualize in the original signal. The network used can be a pretrained network designed for images, since a time-frequency representation is a 2-D array that can be treated as an image.
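As a rough sketch of the “signal becomes an image” step, the snippet below chops a signal into overlapping frames and takes the magnitude of a discrete Fourier transform of each one. The frame length, hop size, and test tone are assumptions for illustration, and I’ve left out the window function and the fast FFT that a real signal processing library would use.

```python
import cmath
import math

def dft_magnitudes(frame):
    """Magnitudes of DFT bins 0..N/2 for one frame (naive O(N^2) DFT)."""
    n = len(frame)
    return [
        abs(sum(frame[t] * cmath.exp(-2j * math.pi * k * t / n)
                for t in range(n)))
        for k in range(n // 2 + 1)
    ]

def spectrogram(signal, frame_len=32, hop=16):
    """Rows are time frames, columns are frequency bins:
    a 2-D array that can be handed to a CNN like an image."""
    frames = [signal[s:s + frame_len]
              for s in range(0, len(signal) - frame_len + 1, hop)]
    return [dft_magnitudes(frame) for frame in frames]

# Test tone with exactly 4 cycles per 32-sample frame (illustrative input).
signal = [math.sin(2 * math.pi * 4 * t / 32) for t in range(128)]
spec_image = spectrogram(signal)

# Strongest frequency bin in each time frame.
peak_bins = [max(range(len(row)), key=row.__getitem__) for row in spec_image]
```

Here `spec_image` has 7 rows (time frames) and 17 columns (frequency bins), and every entry of `peak_bins` is 4, so the tone shows up as a single bright stripe. Noise would show up as energy scattered across the other columns, which is the kind of pattern a CNN can learn to pick out.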
Now, once again, you can do what you want here. It’s very possible that you could also use an LSTM network in Scenario 1, or a CNN in Scenario 2. These scenarios are just meant to give you a starting point.
Network architecture and pretrained networks go hand-in-hand. A pretrained model is a neural network that has already undergone training. Its weights and biases have already been tuned on the original training data, so the network can be retrained more quickly for a new task. This process, called transfer learning, can often get by with fewer images and a smaller dataset than training from scratch. An additional avenue to explore is to “create” more data through simulation or augmentation.
For now, I will say that you should use whichever network you would like regardless of the dataset size, but consider using a pretrained network so that you need less input data, or consider methods to augment your dataset. My next column will cover pretrained networks and models, so keep an eye out for more on this topic.
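To give a feel for why transfer learning needs less data, here is a toy sketch: the “pretrained” layer is frozen and only a small new output layer is trained, on just four samples. To be clear about the assumptions, `FROZEN_WEIGHTS` is a hand-written stand-in for layers that a real pretrained network would have learned from a large dataset, and the data, learning rate, and epoch count are all made up for illustration.

```python
import math

# Stand-in for pretrained layers: fixed values that are never updated.
# (Hypothetical; a real pretrained network learned these from lots of data.)
FROZEN_WEIGHTS = [[1.0, -1.0],
                  [0.5, 0.5]]

def frozen_features(x):
    """Feature extractor with ReLU; its weights stay untouched during training."""
    return [max(0.0, sum(w * v for w, v in zip(row, x)))
            for row in FROZEN_WEIGHTS]

def head_score(x, w, b):
    """Raw score of the new output layer on top of the frozen features."""
    return sum(wi * fi for wi, fi in zip(w, frozen_features(x))) + b

def train_new_head(samples, labels, epochs=200, lr=0.5):
    """Fit only a new logistic output layer with plain gradient descent."""
    w, b = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for x, y in zip(samples, labels):
            p = 1.0 / (1.0 + math.exp(-head_score(x, w, b)))
            err = p - y
            feats = frozen_features(x)
            w = [wi - lr * err * fi for wi, fi in zip(w, feats)]
            b -= lr * err
    return w, b

def predict(x, w, b):
    return 1 if head_score(x, w, b) > 0 else 0

# Tiny "new task" dataset: label 1 when the first value is the larger one.
samples = [[2.0, 0.0], [3.0, 1.0], [0.0, 2.0], [1.0, 3.0]]
labels = [1, 1, 0, 0]

w, b = train_new_head(samples, labels)
predictions = [predict(x, w, b) for x in samples]
```

Only three numbers are learned here (two weights and a bias), which is why a handful of samples is enough. The same reasoning applies, at a much larger scale, when you replace and retrain the last layers of a pretrained network.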