import java.io.*;

/**
 * An {@link OutputStream} that embeds written data into the low-order bits of the samples
 * of an {@code EasyBufferedImage}.
 *
 * <p>Each written bit overwrites exactly one bit of one sample. The write position advances
 * through bit positions {@code 0 .. channels-1} of the current sample, then to the next band
 * of the same pixel, then to the next column, and finally to the next row. Multi-bit values
 * ({@link #writeByte}, {@link #writeInt}, {@link #write(int)}) are written least-significant
 * bit first.
 *
 * <p>No synchronization is performed; the write position is plain mutable state.
 */
public class ImageOutputStream extends OutputStream {
    /** Image whose raster samples receive the written bits (updated via {@code setSample}). */
    private final EasyBufferedImage buffer;
    /** Number of low-order bits of each sample that are overwritten, in {@code [1, 32]}. */
    private final int numberOfChannels;
    /** Current write position: pixel coordinates, band, and bit index (0 = LSB) in the sample. */
    private int column, row, band, channel;

    /**
     * Creates a stream that starts writing at pixel (0, 0), band 0, bit 0.
     *
     * @param buf      the image to write into
     * @param channels how many low-order bits of each sample to overwrite; should not exceed
     *                 the raster's sample bit depth (NOTE(review): bits above that depth are
     *                 presumably dropped by the raster — confirm for the raster types in use)
     * @throws NullPointerException     if {@code buf} is null
     * @throws IllegalArgumentException if {@code channels} is not in {@code [1, 32]}
     */
    public ImageOutputStream(EasyBufferedImage buf, int channels) {
        if (buf == null) {
            throw new NullPointerException("buf");
        }
        // 0 would divide by zero in advanceIndices(); above 32 the shift in setBit() wraps
        // (int shift distances are masked to 5 bits) and silently reuses lower bit positions.
        if (channels < 1 || channels > Integer.SIZE) {
            throw new IllegalArgumentException(
                    "channels must be in [1, " + Integer.SIZE + "], got " + channels);
        }
        this.buffer = buf;
        this.numberOfChannels = channels;
        this.column = 0;
        this.row = 0;
        this.band = 0;
        this.channel = 0;
    }

    /**
     * Returns the image this stream writes into.
     *
     * @return the image passed to the constructor
     */
    public EasyBufferedImage getBuffer() {
        return buffer;
    }

    /**
     * Returns whether the write position is still inside the image, i.e. has not advanced
     * past the last row.
     *
     * @return {@code true} if the current row is less than the image height
     */
    public boolean hasCapacity() {
        return row < buffer.getHeight();
    }

    /**
     * Returns how many more bits can be written before the image is full.
     *
     * @return the number of remaining writable bits, never negative
     */
    public long remainingBits() {
        if (!hasCapacity()) {
            return 0L;
        }
        long bitsPerPixel = (long) buffer.getRaster().getNumBands() * numberOfChannels;
        // Pixels from the current one (inclusive) to the end of the image.
        long pixelsLeft = (long) (buffer.getHeight() - row) * buffer.getWidth() - column;
        // Subtract the bits of the current pixel that have already been consumed.
        long consumedInPixel = (long) band * numberOfChannels + channel;
        return Math.max(0L, pixelsLeft * bitsPerPixel - consumedInPixel);
    }

    /**
     * Returns {@code value} with bit {@code bitIndex} (0 = LSB) replaced by the
     * least-significant bit of {@code theBit}.
     */
    private static int setBit(int value, int theBit, int bitIndex) {
        int mask = 1 << bitIndex;
        return (theBit & 0x01) == 0 ? value & ~mask : value | mask;
    }

    /** Moves the write position one bit forward: channel, then band, then column, then row. */
    private void advanceIndices() {
        channel = (channel + 1) % numberOfChannels;
        if (channel == 0) {
            band = (band + 1) % buffer.getRaster().getNumBands();
            if (band == 0) {
                column = (column + 1) % buffer.getWidth();
                if (column == 0) {
                    row++;
                }
            }
        }
    }

    /**
     * Fails if fewer than {@code bits} bits can still be written, so that multi-bit writes
     * either complete fully or leave the image and the position untouched.
     */
    private void requireCapacity(int bits) throws IOException {
        long remaining = remainingBits();
        if (remaining < bits) {
            throw new IOException("image capacity exhausted: need " + bits
                    + " bit(s), " + remaining + " remaining");
        }
    }

    /** Writes the LSB of {@code theBit} at the current position and advances; no capacity check. */
    private void embedBit(int theBit) {
        int sample = buffer.getRaster().getSample(column, row, band);
        buffer.getRaster().setSample(column, row, band, setBit(sample, theBit, channel));
        advanceIndices();
    }

    /** Writes the low {@code count} bits of {@code value}, LSB first, all-or-nothing. */
    private void writeBits(int value, int count) throws IOException {
        requireCapacity(count);
        for (int i = 0; i < count; i++) {
            embedBit(value >>> i);
        }
    }

    /**
     * Writes the least-significant bit of {@code theBit} at the current position.
     *
     * @param theBit value whose LSB is written; higher bits are ignored
     * @throws IOException if the image has no capacity left
     */
    public void writeBit(int theBit) throws IOException {
        requireCapacity(1);
        embedBit(theBit);
    }

    /**
     * Writes the low 8 bits of {@code theByte}, least-significant bit first.
     *
     * @param theByte value whose low 8 bits are written; higher bits are ignored
     * @throws IOException if fewer than 8 bits of capacity remain (nothing is written)
     */
    public void writeByte(int theByte) throws IOException {
        writeBits(theByte, Byte.SIZE);
    }

    /**
     * Writes all 32 bits of {@code theInt}, least-significant bit first.
     *
     * @param theInt value to write
     * @throws IOException if fewer than 32 bits of capacity remain (nothing is written)
     */
    public void writeInt(int theInt) throws IOException {
        writeBits(theInt, Integer.SIZE);
    }

    /**
     * Writes the low-order 8 bits of {@code b}, as required by {@link OutputStream#write(int)}.
     *
     * @throws IOException if fewer than 8 bits of capacity remain (nothing is written)
     */
    @Override
    public void write(int b) throws IOException {
        writeByte(b);
    }
}
