SPCoast
Railroading on the Southern Pacific Coast

NBD Mina

From SPCoast

Jump to: navigation, search

Full Maven project source code: NBD implemented with Apache MINA (gzipped tar file), released under the Apache open source license.

(See also: NBD implemented with Netty)

Starting at the top, the server simply parses the command-line arguments, instantiates the protocol stack, and binds it to a port. I used the params Map to pass the command-line arguments to the session handler.

        // ... parse the command line ...

        // The session handler receives its settings through the shared params map.
        params.put("port", (Integer) port);
        params.put("file", (String) fileName);
        params.put("blocksize", (Integer) 1024);

        // One NIO acceptor serves every incoming NBD client connection.
        NioSocketAcceptor acceptor = new NioSocketAcceptor();

        // Protocol codec for the NBD wire format (false = server side, true = client side).
        ProtocolCodecFilter codecFilter = new ProtocolCodecFilter(new NBDCodecFactory(false));
        acceptor.getFilterChain().addLast("NBDcodec", codecFilter);

        acceptor.setHandler(new NBDServerHandler(params));
        acceptor.bind(new InetSocketAddress(port));

        System.out.println("Server now listening on port " + port);

Next comes the business logic in the handler: it receives READ and WRITE commands from the Linux kernel over a TCP/IP socket and actually reads and writes the corresponding blocks of a file. The "file" can be a real disk drive, or a file created just for this purpose, for example a 1000 MiB zero-filled image: `% dd if=/dev/zero of=somefile bs=1M count=1000`

    /**
     * Called when an NBD client connects.
     *
     * After the per-session setup (elided in this article), this marks the session as needing
     * negotiation and writes an empty NBDResponse. The response itself carries no data:
     * NBDResponseEncoder sees NEGOTIATE_KEY == true and emits the 152-byte negotiation packet
     * in its place, then resets NEGOTIATE_KEY to false so later responses are normal replies.
     */
    @Override
    public void sessionOpened(IoSession session) throws Exception {
          ... initialize stats, open file, ...
          // Flag the session so the encoder sends the negotiation packet instead of a reply.
          session.setAttribute(NEGOTIATE_KEY, true);
    	 NBDResponse response = new NBDResponse(); 
         session.write(response);
    }
    
    /**
     * Called by MINA after the connection has closed. It flushes the backing file to the
     * storage device and closes it.
     *
     * The file is closed in a finally block, so a failure in force()/sync() cannot leak the
     * descriptor. A session with no FILEHANDLE_KEY attribute (e.g. the file was never opened)
     * is tolerated instead of failing with a NullPointerException, so the rest of the
     * cleanup still runs.
     */
    @Override
    public void sessionClosed(IoSession session) throws Exception {
        logger.info("NBD session closed");
        RandomAccessFile fileHandle = (RandomAccessFile) session.getAttribute(FILEHANDLE_KEY);
        if (fileHandle != null) {
            try {
                // Push file content and metadata out before releasing the descriptor.
                fileHandle.getChannel().force(true);
                fileHandle.getFD().sync();
            } finally {
                fileHandle.close();
            }
        }
        // Remaining per-session cleanup (cache and other state) is elided in this article.
        ........
    }
    
    /**
     * Flushes the backing file (content and metadata) to the storage device while the
     * connection is idle. It does nothing if the session has no open file.
     *
     * This one-argument form does not override any MINA callback, so MINA never calls it
     * directly. The two-argument overload below is the real IoHandler hook and delegates here.
     */
    public void sessionIdle(IoSession session) throws Exception {
        RandomAccessFile fileHandle = (RandomAccessFile) session.getAttribute(FILEHANDLE_KEY);
        if (fileHandle == null) {
            return;
        }
        fileHandle.getChannel().force(true);
        fileHandle.getFD().sync();
    }

    /**
     * MINA's idle-event callback. It forwards to {@link #sessionIdle(IoSession)}, whatever
     * the idle status is.
     *
     * NOTE(review): MINA only raises idle events when an idle time is configured on the
     * session config. Confirm that the server sets one, or this never fires.
     */
    @Override
    public void sessionIdle(IoSession session, org.apache.mina.core.session.IdleStatus status)
            throws Exception {
        sessionIdle(session);
    }

    /**
     * Serves one decoded NBD request and, for READ and WRITE, writes the matching reply.
     *
     * <ul>
     *   <li>WRITE copies the request payload into the backing file at {@code from}. If the
     *       "readonly" parameter is present, it answers EACCES (13) instead.</li>
     *   <li>READ returns {@code len} bytes of the file starting at {@code from}.</li>
     *   <li>DISC sends no reply and closes the connection, as the NBD protocol requires of
     *       the server.</li>
     * </ul>
     * File I/O failures are logged and reported to the client as EIO (5) instead of being
     * thrown, so one bad request does not tear down the connection.
     *
     * NOTE(review): {@code from}, {@code len} and {@code blockSize} are set up in the part of
     * this method elided below. Presumably they come from the request header and the
     * "blocksize" parameter; confirm.
     *
     * @throws Exception if the request type is not READ, WRITE or DISC
     */
    @Override
    public void messageReceived(IoSession session, Object message) throws Exception {
        NBDRequest request = (NBDRequest) message;
        RandomAccessFile fileHandle = (RandomAccessFile) session.getAttribute(FILEHANDLE_KEY);
        NBDResponse response = new NBDResponse(request.getHandle(), 0);

    ....
        if (request.getType() == NBDRequest.NBD_CMD_WRITE) {  //  1
            try {
                if (params.containsKey("readonly")) {
                    logger.info("READ ONLY during WRITE...");
                    response.setError(13); // EACCES
                } else {
                    IoBuffer data = request.getData();
                    if (data == null) {
                        throw new Exception("Write packet with no data");
                    }
                    fileHandle.seek(from);
                    while (len > 0) {
                        int currentLength = (int) Math.min(len, blockSize);
                        byte[] buffer = new byte[currentLength];
                        data.get(buffer);
                        fileHandle.write(buffer);
                        len -= currentLength;
                    }
                    data.clear();
                }
            } catch (Exception ex) {
                // Report the failure to the client, but keep a trace of it on the server too.
                logger.info("WRITE failed at offset " + from + ": " + ex);
                response.setError(5); // EIO
            }
            session.write(response);
        } else if (request.getType() == NBDRequest.NBD_CMD_READ) {  // 0
            try {
                IoBuffer data = IoBuffer.allocate(len);
                fileHandle.seek(from);
                while (len > 0) {
                    int currentLength = (int) Math.min(len, blockSize);
                    byte[] buffer = new byte[currentLength];
                    // read(byte[]) may legally return fewer bytes than requested before EOF.
                    // readFully keeps reading and throws EOFException only when EOF is reached.
                    fileHandle.readFully(buffer);
                    data.put(buffer);
                    len -= currentLength;
                }
                data.flip();
                response.setData(data);
            } catch (Exception ex) {
                logger.info("READ failed at offset " + from + ": " + ex);
                response.setError(5);  // EIO
            }
            session.write(response);
        } else if (request.getType() == NBDRequest.NBD_CMD_DISC) {  // 2
            logger.info("End of session - DISCONNECTing");
            // DISC gets no reply; the server is expected to close the connection itself.
            // NOTE(review): if the MINA version in use offers a close-on-flush variant, prefer
            // it, so replies still queued for earlier requests are not dropped.
            session.close();
        } else {
            throw new Exception("Unknown message type: "
                    + request.getType());
        }
    }
}

The request and response objects are POJOs (plain old Java objects) that mirror the Linux C structs:

/**
 * NBD request header plus, for WRITE, the payload that follows it.
 *
 * @author plocher
 * On the wire, NBDRequestDecoder reads exactly these fields, in order and with no padding:
 * nbd_request {
 *	__be32 magic;
 *	__be32 type;
 *	char handle[8];
 *	__be64 from;
 *	__be32 len;
 * } == 28 bytes
 * The header contains no flags word. For WRITE requests, {@code len} bytes of payload follow
 * the header.
 */
public class NBDRequest {
    ....
    // magic: every request header must start with this value (checked by NBDRequestDecoder)
    public final static int 	NBD_REQUEST_MAGIC = 0x25609513;
    
    // type: request command codes
    public final static int 	NBD_CMD_READ  = 0;
    public final static int 	NBD_CMD_WRITE = 1;
    public final static int 	NBD_CMD_DISC  = 2;
    
    // flags: NOTE(review) not read by the decoder or handler code shown in this article
    public final static int 	NBD_READ_ONLY   = 0x0001;
    public final static int 	NBD_WRITE_NOCHK = 0x0002;

    /**
     * @param type   one of NBD_CMD_READ, NBD_CMD_WRITE, NBD_CMD_DISC (other values are
     *               rejected later by the server handler)
     * @param handle opaque client handle, echoed back in the matching NBDResponse
     * @param from   byte offset into the export
     * @param len    number of bytes to read or write
     * @param data   WRITE payload of {@code len} bytes; null for all other request types
     */
    public NBDRequest(
    	     int type,
    	     long handle,
    	     long from,
    	     int len,
    	     IoBuffer data) {
             ....
    }
}

and


/**
 * NBD reply to a single request.
 *
 * NBDResponseEncoder writes a normal reply as magic(4) + error(4) + handle(8). When the
 * response carries data (a READ reply), the data follows that header. While the session's
 * NEGOTIATE_KEY attribute is true, the encoder writes the negotiation packet instead.
 */
public class NBDResponse {
    ....
    // magic: first word of every reply header
    final public static int NBD_REPLY_MAGIC = 0x67446698;

    /**
     * Creates an empty response: handle 0, error 0, no data. NBDServerHandler.sessionOpened
     * writes one of these only to trigger the negotiation packet.
     */
    public NBDResponse() {
    	this.handle = 0;
    	this.error  = 0;
    	this.data = null;
    }
}


After all this, the actual encoder and decoder are pretty simple:

public class NBDRequestDecoder extends CumulativeProtocolDecoder {
	@Override
	protected boolean doDecode(IoSession session, IoBuffer in, ProtocolDecoderOutput out) 
	throws Exception {
	    int start = in.position();   // remember our starting point...
	    // logbuffer(in, "Decoding request");
	    if (in.remaining() >= 28) {
	    	IoBuffer data;
	        int magic     = in.getInt();								//  0 + 4 =  4
	        if (magic != NBDRequest.NBD_REQUEST_MAGIC) {
	        	throw (new Exception("Bad Magic Request: " + magic));
	        }
	        int type      = in.getInt();								//  4 + 4  =  8
	        long handle   = in.getLong();							//  8 + 8  = 16
	        long from     = in.getLong();							// 16 + 8  = 24
	        int len       = in.getInt();								// 24 + 4  = 28
	        if (type == NBDRequest.NBD_CMD_WRITE) {
	        	if (in.remaining() >= len) {
	        		int limit = in.limit();
        			int end = Math.min(len, in.remaining()) + in.position();
        			in.limit(end);
        			data = in.slice();
        			in.position(end);
        			in.limit(limit);
	        	} else {
	                        in.position(start);
	                        return false;
	        	}
	        } else {
	        	data = null;
	        }
                NBDRequest request = new NBDRequest(type, handle,from,len, data);
                out.write(request);
                return true;
            } else {
                return false;
            }
      }
}

and

public class NBDResponseEncoder extends ProtocolEncoderAdapter {
	private byte[] ints2Bytes(int [] bytes) throws Exception
	{
		byte [] out = new byte[bytes.length];
		for(int index=0; index<bytes.length; index++)
			out[index] = (byte)bytes[index];
		return out;
	}

	
	@Override
	public void encode(IoSession session, Object message, ProtocolEncoderOutput out)
			throws Exception {
		
                NBDResponse nbdResponse = (NBDResponse) message;
                Boolean needNegotiation = (Boolean) session.getAttribute(NBDServerHandler.NEGOTIATE_KEY);
        
                if (needNegotiation) {
                        // NBD negotiation Packet
                        // NBD Password
            		int nbdPassword[] = { 'N', 'B', 'D', 'M', 'A', 'G', 'I', 'C' };
    	        	// cliserv_magic
    	        	int nbdMagic[] = { 0x00, 0x00, 0x42, 0x02, 0x81, 0x86, 0x12, 0x53 };
    	        	// File Size in bytes
    	        	Long fileSize = (Long) session.getAttribute(NBDServerHandler.FILESIZE_KEY);
    	        	int nbdSize[] = {
    						(int)((fileSize >> 56) & 0xFF),
    						(int)((fileSize >> 48) & 0xFF),
    						(int)((fileSize >> 40) & 0xFF),
    						(int)((fileSize >> 32) & 0xFF),
    						(int)((fileSize >> 24) & 0xFF),
    						(int)((fileSize >> 16) & 0xFF),
    						(int)((fileSize >>  8) & 0xFF),
    						(int)((fileSize      ) & 0xFF) };
    		        // flags - 4 bytes, 
    		        //		Flags used between the client and server
    		        //		#define NBD_FLAG_HAS_FLAGS (1 << 0) /* Flags are there */
    		        //		#define NBD_FLAG_READ_ONLY (1 << 1) /* read-only */
    		        int nbdFlags[] = { 0x00, 0x00, 0x00, 0x00 };
    		        if ((Boolean) session.getAttribute(NBDServerHandler.READONLY_KEY)) {
    		        	nbdFlags[3] = 0x03;	// 0b0000_0011
    		        }
    	        	// padding - 124 bytes
    	        	int nbdPadding[] = new int[124];
    		        try {
    			IoBuffer buffer = IoBuffer.allocate(152, false);
    			buffer.put(ints2Bytes(nbdPassword));
    			buffer.put(ints2Bytes(nbdMagic));
    			buffer.put(ints2Bytes(nbdSize));
    			buffer.put(ints2Bytes(nbdFlags));
    			buffer.put(ints2Bytes(nbdPadding));
    		
    	                buffer.flip();
    	                out.write(buffer);
                        session.setAttribute(NBDServerHandler.NEGOTIATE_KEY, false);
    		        } catch (Exception ex) {
            	                ex.printStackTrace();
            	                session.close();
    	                }
                } else {
	                long handle   = nbdResponse.getHandle();
	                int  error    = nbdResponse.getError();
	                IoBuffer data  = nbdResponse.getData();
	                int  buflen = 16;
	                if (data != null) { buflen += data.remaining(); }
	                try {
		                IoBuffer buffer = IoBuffer.allocate(buflen, false);
		                buffer.putInt(NBDResponse.NBD_REPLY_MAGIC);
		                buffer.putInt(error);
		                buffer.putLong(handle);
		                if (data != null) {
		                	buffer.put(data);
		                }
		                buffer.flip();
		                out.write(buffer);
	                } catch (Exception ex) {
            	                ex.printStackTrace();
            	                session.close();
	                }
                }
	}
}